Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 167 additions & 7 deletions graphile/graphile-realtime-subscriptions/__tests__/plugin.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
* - Multiple realtime tables produce multiple fields
* - NOTIFY payload parsing (TG_OP:id1,id2,... and INVALIDATE)
* - Per-subscriber event throttling with configurable limit
* - Sparse set subscriptions (ids: [UUID!]) with row ID intersection filtering
* - RLS-aware rowId masking in payload resolvers
*/

jest.mock('@pgpmjs/logger', () => ({
Expand Down Expand Up @@ -314,7 +316,7 @@ describe('createRealtimeSubscriptionsPlugin', () => {
});

describe('type definitions', () => {
it('generates subscription field with optional id argument', () => {
it('generates subscription field with ids argument', () => {
createRealtimeSubscriptionsPlugin();

const codec = createMockCodec('documents', { realtime: true });
Expand All @@ -324,7 +326,7 @@ describe('createRealtimeSubscriptionsPlugin', () => {

const result = capturedFactory!(build);

expect(result.typeDefs).toContain('onDocumentsChanged(id: UUID): DocumentsSubscriptionPayload');
expect(result.typeDefs).toContain('onDocumentsChanged(ids: [UUID!]): DocumentsSubscriptionPayload');
});

it('generates payload type with event, row, rowId, and overflow fields', () => {
Expand All @@ -342,6 +344,7 @@ describe('createRealtimeSubscriptionsPlugin', () => {
expect(result.typeDefs).toContain('documents: Documents');
expect(result.typeDefs).toContain('rowId: UUID');
expect(result.typeDefs).toContain('overflow: Boolean!');
expect(result.typeDefs).toContain('masked when RLS denies access');
});

it('extends Subscription type', () => {
Expand Down Expand Up @@ -499,17 +502,17 @@ describe('createRealtimeSubscriptionsPlugin', () => {
const result = capturedFactory!(build);
const mockParent = { get: jest.fn((key: string) => {
if (key === 'parsed') return { event: 'INSERT', rowIds: ['row-uuid'], overflow: false };
if (key === 'subscribedId') return null;
if (key === 'subscribedIds') return null;
return null;
}) };

result.plans['TasksSubscriptionPayload'].tasks(mockParent);
expect(mockParent.get).toHaveBeenCalledWith('parsed');
expect(mockParent.get).toHaveBeenCalledWith('subscribedId');
expect(mockParent.get).toHaveBeenCalledWith('subscribedIds');
expect(mockResource.get).toHaveBeenCalled();
});

it('payload row resolver prefers subscribedId over parsed rowId', () => {
it('payload row resolver uses first matching ID when ids provided', () => {
createRealtimeSubscriptionsPlugin();

const codec = createMockCodec('tasks', { realtime: true });
Expand All @@ -523,12 +526,13 @@ describe('createRealtimeSubscriptionsPlugin', () => {

const result = capturedFactory!(build);
const mockParent = { get: jest.fn((key: string) => {
if (key === 'parsed') return { event: 'UPDATE', rowIds: ['row-uuid'], overflow: false };
if (key === 'subscribedId') return 'subscribed-uuid';
if (key === 'parsed') return { event: 'INSERT', rowIds: ['id-a', 'id-b', 'id-c'], overflow: false };
if (key === 'subscribedIds') return ['id-b', 'id-d'];
return null;
}) };

result.plans['TasksSubscriptionPayload'].tasks(mockParent);
expect(mockParent.get).toHaveBeenCalledWith('subscribedIds');
expect(mockResource.get).toHaveBeenCalled();
});
});
Expand Down Expand Up @@ -558,4 +562,160 @@ describe('createRealtimeSubscriptionsPlugin', () => {
expect(result.plans).toBeDefined();
});
});

describe('sparse set filtering (ids argument)', () => {
it('subscribePlan passes ids through object step', () => {
createRealtimeSubscriptionsPlugin();

const codec = createMockCodec('tasks', { realtime: true });
const build = createMockBuild({
tasks: createMockResource('tasks', codec),
});

const result = capturedFactory!(build);
const mockArgs = { get: jest.fn((key: string) => {
if (key === 'ids') return ['id-a', 'id-b'];
return null;
}) };

result.plans['Subscription']['onTasksChanged'].subscribePlan(null, mockArgs);

expect(mockArgs.get).toHaveBeenCalledWith('ids');

// The listen callback is captured but not invoked by the mock.
// Invoke it manually to verify ids are threaded through.
expect(mockListen).toHaveBeenCalled();
const listenCallback = mockListen.mock.calls[mockListen.mock.calls.length - 1][2];
listenCallback('INSERT:id-a');

expect(mockObject).toHaveBeenCalled();
const objectArg = mockObject.mock.calls[mockObject.mock.calls.length - 1][0];
expect(objectArg).toHaveProperty('subscribedIds');
});

it('drops events with no row ID intersection in sparse set mode', () => {
const parsed = parseNotifyPayload('INSERT:id-x,id-y');
const subscribedIds = ['id-a', 'id-b'];

const hasMatch = parsed.rowIds.some((rid: string) => subscribedIds.includes(rid));
expect(hasMatch).toBe(false);
});

it('delivers events with row ID intersection in sparse set mode', () => {
const parsed = parseNotifyPayload('UPDATE:id-a,id-x');
const subscribedIds = ['id-a', 'id-b'];

const hasMatch = parsed.rowIds.some((rid: string) => subscribedIds.includes(rid));
expect(hasMatch).toBe(true);
});

it('delivers INVALIDATE events regardless of sparse set', () => {
const parsed = parseNotifyPayload('INVALIDATE');
expect(parsed.overflow).toBe(true);
expect(parsed.rowIds).toEqual([]);
});

it('rowId resolver returns first matching ID from sparse set', () => {
createRealtimeSubscriptionsPlugin();

const codec = createMockCodec('tasks', { realtime: true });
const build = createMockBuild({
tasks: { ...createMockResource('tasks', codec), get: jest.fn() },
});

const result = capturedFactory!(build);
const payload = result.plans['TasksSubscriptionPayload'];

const mockParent = { get: jest.fn((key: string) => {
if (key === 'parsed') return { event: 'UPDATE', rowIds: ['id-x', 'id-b', 'id-a'], overflow: false };
if (key === 'subscribedIds') return ['id-a', 'id-b'];
return null;
}) };

payload.rowId(mockParent);
expect(mockParent.get).toHaveBeenCalledWith('parsed');
expect(mockParent.get).toHaveBeenCalledWith('subscribedIds');
});

it('rowId resolver returns null when no sparse set match', () => {
createRealtimeSubscriptionsPlugin();

const codec = createMockCodec('tasks', { realtime: true });
const build = createMockBuild({
tasks: { ...createMockResource('tasks', codec), get: jest.fn() },
});

const result = capturedFactory!(build);
const payload = result.plans['TasksSubscriptionPayload'];

const mockParent = { get: jest.fn((key: string) => {
if (key === 'parsed') return { event: 'INSERT', rowIds: ['id-x'], overflow: false };
if (key === 'subscribedIds') return ['id-a', 'id-b'];
return null;
}) };

payload.rowId(mockParent);
expect(mockParent.get).toHaveBeenCalledWith('subscribedIds');
});

it('rowId resolver falls back to first rowId when no sparse set provided', () => {
createRealtimeSubscriptionsPlugin();

const codec = createMockCodec('tasks', { realtime: true });
const build = createMockBuild({
tasks: { ...createMockResource('tasks', codec), get: jest.fn() },
});

const result = capturedFactory!(build);
const payload = result.plans['TasksSubscriptionPayload'];

const mockParent = { get: jest.fn((key: string) => {
if (key === 'parsed') return { event: 'INSERT', rowIds: ['id-first', 'id-second'], overflow: false };
if (key === 'subscribedIds') return null;
return null;
}) };

payload.rowId(mockParent);
expect(mockParent.get).toHaveBeenCalledWith('subscribedIds');
});
});

describe('RLS-aware event delivery', () => {
it('rowId doc comment mentions RLS masking', () => {
createRealtimeSubscriptionsPlugin();

const codec = createMockCodec('items', { realtime: true });
const build = createMockBuild({
items: createMockResource('items', codec),
});

const result = capturedFactory!(build);
expect(result.typeDefs).toContain('masked when RLS denies access');
});

it('type defs include sparse set ids argument', () => {
createRealtimeSubscriptionsPlugin();

const codec = createMockCodec('items', { realtime: true });
const build = createMockBuild({
items: createMockResource('items', codec),
});

const result = capturedFactory!(build);
expect(result.typeDefs).toContain('ids: [UUID!]');
});

it('type defs include description mentioning all subscription modes', () => {
createRealtimeSubscriptionsPlugin();

const codec = createMockCodec('items', { realtime: true });
const build = createMockBuild({
items: createMockResource('items', codec),
});

const result = capturedFactory!(build);
expect(result.typeDefs).toContain('specific rows');
expect(result.typeDefs).toContain('full collection');
});
});
});
63 changes: 46 additions & 17 deletions graphile/graphile-realtime-subscriptions/src/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* for real-time event delivery.
*
* Subscription modes:
* - Single record: onXxxChanged(id: UUID!) — subscribe to changes on one row
* - Specific rows: onXxxChanged(ids: [UUID!]) — subscribe to changes on specific rows
* - Full collection: onXxxChanged (no args) — subscribe to any change on the table
*
* NOTIFY payload format (from emit_change trigger):
Expand All @@ -25,8 +25,16 @@
* - Plugin-side: per-subscriber throttle (default 50 events/second/table)
* drops individual events and sends a single INVALIDATE when exceeded
*
* RLS enforcement is automatic — resource.get() queries through the
* authenticated user's connection with their JWT role applied.
* Security / RLS enforcement:
* - Row data is always fetched via resource.get() which runs through the
* authenticated user's connection with their JWT role and pgSettings applied.
* - For INSERT/UPDATE events, if RLS denies access (resource.get returns null),
* the rowId is masked (set to null) to prevent metadata leaks.
* - For DELETE events, row is naturally null (the row no longer exists).
* - For INVALIDATE (overflow), the client should refetch via a normal query
* which is also RLS-gated.
* - When ids are provided, only events for those specific rows are delivered,
* preventing cross-tenant event leaks.
*/

import { context as grafastContext, listen, object, constant, lambda } from 'grafast';
Expand Down Expand Up @@ -161,7 +169,7 @@ function discoverRealtimeTables(build: any): RealtimeTableInfo[] {
function buildTypeDefs(tables: RealtimeTableInfo[]): string {
const subscriptionFields = tables
.map(({ fieldName, payloadTypeName }) =>
` """Subscribe to changes on this table. Pass an id to watch a specific record."""\n ${fieldName}(id: UUID): ${payloadTypeName}`
` """Subscribe to changes on this table. Pass ids to watch specific rows, or no args for the full collection."""\n ${fieldName}(ids: [UUID!]): ${payloadTypeName}`
)
.join('\n');

Expand All @@ -173,7 +181,7 @@ function buildTypeDefs(tables: RealtimeTableInfo[]): string {
` event: String!\n` +
` """The current state of the row (null for DELETE, INVALIDATE, or if RLS denies access)."""\n` +
` ${rowFieldName}: ${typeName}\n` +
` """The ID of the changed row (null for INVALIDATE)."""\n` +
` """The ID of the changed row (null for INVALIDATE, or masked when RLS denies access)."""\n` +
` rowId: UUID\n` +
` """True when too many changes occurred and the client should refetch."""\n` +
` overflow: Boolean!\n` +
Expand All @@ -198,10 +206,11 @@ function buildPlans(
subscribePlan(_$root: any, args: any) {
const $pgSubscriber = (grafastContext() as any).get('pgSubscriber');
const $topic = constant(notifyChannel);
const $id = args.get('id');
const $ids = args.get('ids');

return listen($pgSubscriber, $topic, ($payload: any) => {
const $parsed = lambda($payload, (raw: unknown) => {
const $parsed = lambda([$payload, $ids], (pair: unknown) => {
const [raw, subscribedIds] = pair as readonly [unknown, string[] | null | undefined];
const parsed = parseNotifyPayload(String(raw));

const action = parsed.overflow ? 'deliver' : throttle.check();
Expand All @@ -218,12 +227,18 @@ function buildPlans(
};
}

// Sparse set filtering: only deliver events for subscribed row IDs
if (subscribedIds && subscribedIds.length > 0) {
const hasMatch = parsed.rowIds.some((rid: string) => subscribedIds.includes(rid));
if (!hasMatch) return null;
}

return parsed;
});

return object({
parsed: $parsed,
subscribedId: $id,
subscribedIds: $ids,
});
});
},
Expand All @@ -239,10 +254,17 @@ function buildPlans(
},
rowId($parent: any) {
const $parsed = $parent.get('parsed');
return lambda($parsed, (p: unknown) => {
const parsed = p as ParsedPayload | null;
if (!parsed || parsed.overflow || parsed.rowIds.length === 0) return null;
return parsed.rowIds[0];
const $subscribedIds = $parent.get('subscribedIds');
return lambda([$parsed, $subscribedIds], (pair: unknown) => {
const [p, subscribedIds] = pair as readonly [ParsedPayload | null, string[] | null | undefined];
if (!p || p.overflow || p.rowIds.length === 0) return null;

// When ids are provided, return the first matching row ID
if (subscribedIds && subscribedIds.length > 0) {
return p.rowIds.find((rid: string) => subscribedIds.includes(rid)) ?? null;
}

return p.rowIds[0];
});
},
overflow($parent: any) {
Expand All @@ -251,14 +273,21 @@ function buildPlans(
},
[rowFieldName]($parent: any) {
const $parsed = $parent.get('parsed');
const $subscribedId = $parent.get('subscribedId');
const $subscribedIds = $parent.get('subscribedIds');

const $rowId = lambda(
[$parsed, $subscribedId],
(pair: unknown) => {
const [p, subscribedId] = pair as readonly [ParsedPayload | null, string | null];
if (subscribedId) return subscribedId;
[$parsed, $subscribedIds],
(tuple: unknown) => {
const [p, subscribedIds] = tuple as readonly [
ParsedPayload | null,
string[] | null | undefined,
];
if (!p || p.overflow || p.rowIds.length === 0) return null;
// When ids are provided, return first matching row ID
if (subscribedIds && subscribedIds.length > 0) {
return p.rowIds.find((rid: string) => subscribedIds.includes(rid)) ?? null;
}
// Full collection mode: return first row ID
return p.rowIds[0];
},
);
Expand Down
Loading