diff --git a/graphile/graphile-realtime-subscriptions/__tests__/plugin.test.ts b/graphile/graphile-realtime-subscriptions/__tests__/plugin.test.ts index 27c7d13c8..0bb203521 100644 --- a/graphile/graphile-realtime-subscriptions/__tests__/plugin.test.ts +++ b/graphile/graphile-realtime-subscriptions/__tests__/plugin.test.ts @@ -11,6 +11,8 @@ * - Multiple realtime tables produce multiple fields * - NOTIFY payload parsing (TG_OP:id1,id2,... and INVALIDATE) * - Per-subscriber event throttling with configurable limit + * - Sparse set subscriptions (ids: [UUID!]) with row ID intersection filtering + * - RLS-aware rowId masking in payload resolvers */ jest.mock('@pgpmjs/logger', () => ({ @@ -314,7 +316,7 @@ describe('createRealtimeSubscriptionsPlugin', () => { }); describe('type definitions', () => { - it('generates subscription field with optional id argument', () => { + it('generates subscription field with ids argument', () => { createRealtimeSubscriptionsPlugin(); const codec = createMockCodec('documents', { realtime: true }); @@ -324,7 +326,7 @@ describe('createRealtimeSubscriptionsPlugin', () => { const result = capturedFactory!(build); - expect(result.typeDefs).toContain('onDocumentsChanged(id: UUID): DocumentsSubscriptionPayload'); + expect(result.typeDefs).toContain('onDocumentsChanged(ids: [UUID!]): DocumentsSubscriptionPayload'); }); it('generates payload type with event, row, rowId, and overflow fields', () => { @@ -342,6 +344,7 @@ describe('createRealtimeSubscriptionsPlugin', () => { expect(result.typeDefs).toContain('documents: Documents'); expect(result.typeDefs).toContain('rowId: UUID'); expect(result.typeDefs).toContain('overflow: Boolean!'); + expect(result.typeDefs).toContain('masked when RLS denies access'); }); it('extends Subscription type', () => { @@ -499,17 +502,17 @@ describe('createRealtimeSubscriptionsPlugin', () => { const result = capturedFactory!(build); const mockParent = { get: jest.fn((key: string) => { if (key === 'parsed') return { event: 'INSERT', rowIds: ['row-uuid'], overflow: false }; - if (key === 'subscribedId') return null; + if (key === 'subscribedIds') return null; return null; }) }; result.plans['TasksSubscriptionPayload'].tasks(mockParent); expect(mockParent.get).toHaveBeenCalledWith('parsed'); - expect(mockParent.get).toHaveBeenCalledWith('subscribedId'); + expect(mockParent.get).toHaveBeenCalledWith('subscribedIds'); expect(mockResource.get).toHaveBeenCalled(); }); - it('payload row resolver prefers subscribedId over parsed rowId', () => { + it('payload row resolver uses first matching ID when ids provided', () => { createRealtimeSubscriptionsPlugin(); const codec = createMockCodec('tasks', { realtime: true }); @@ -523,12 +526,13 @@ describe('createRealtimeSubscriptionsPlugin', () => { const result = capturedFactory!(build); const mockParent = { get: jest.fn((key: string) => { - if (key === 'parsed') return { event: 'UPDATE', rowIds: ['row-uuid'], overflow: false }; - if (key === 'subscribedId') return 'subscribed-uuid'; + if (key === 'parsed') return { event: 'INSERT', rowIds: ['id-a', 'id-b', 'id-c'], overflow: false }; + if (key === 'subscribedIds') return ['id-b', 'id-d']; return null; }) }; result.plans['TasksSubscriptionPayload'].tasks(mockParent); + expect(mockParent.get).toHaveBeenCalledWith('subscribedIds'); expect(mockResource.get).toHaveBeenCalled(); }); }); @@ -558,4 +562,160 @@ describe('createRealtimeSubscriptionsPlugin', () => { expect(result.plans).toBeDefined(); }); }); + + describe('sparse set filtering (ids argument)', () => { + it('subscribePlan passes ids through object step', () => { + createRealtimeSubscriptionsPlugin(); + + const codec = createMockCodec('tasks', { realtime: true }); + const build = createMockBuild({ + tasks: createMockResource('tasks', codec), + }); + + const result = capturedFactory!(build); + const mockArgs = { get: jest.fn((key: string) => { + if (key === 'ids') return ['id-a', 'id-b']; + return null; + }) }; + + result.plans['Subscription']['onTasksChanged'].subscribePlan(null, mockArgs); + + expect(mockArgs.get).toHaveBeenCalledWith('ids'); + + // The listen callback is captured but not invoked by the mock. + // Invoke it manually to verify ids are threaded through. + expect(mockListen).toHaveBeenCalled(); + const listenCallback = mockListen.mock.calls[mockListen.mock.calls.length - 1][2]; + listenCallback('INSERT:id-a'); + + expect(mockObject).toHaveBeenCalled(); + const objectArg = mockObject.mock.calls[mockObject.mock.calls.length - 1][0]; + expect(objectArg).toHaveProperty('subscribedIds'); + }); + + it('drops events with no row ID intersection in sparse set mode', () => { + const parsed = parseNotifyPayload('INSERT:id-x,id-y'); + const subscribedIds = ['id-a', 'id-b']; + + const hasMatch = parsed.rowIds.some((rid: string) => subscribedIds.includes(rid)); + expect(hasMatch).toBe(false); + }); + + it('delivers events with row ID intersection in sparse set mode', () => { + const parsed = parseNotifyPayload('UPDATE:id-a,id-x'); + const subscribedIds = ['id-a', 'id-b']; + + const hasMatch = parsed.rowIds.some((rid: string) => subscribedIds.includes(rid)); + expect(hasMatch).toBe(true); + }); + + it('delivers INVALIDATE events regardless of sparse set', () => { + const parsed = parseNotifyPayload('INVALIDATE'); + expect(parsed.overflow).toBe(true); + expect(parsed.rowIds).toEqual([]); + }); + + it('rowId resolver returns first matching ID from sparse set', () => { + createRealtimeSubscriptionsPlugin(); + + const codec = createMockCodec('tasks', { realtime: true }); + const build = createMockBuild({ + tasks: { ...createMockResource('tasks', codec), get: jest.fn() }, + }); + + const result = capturedFactory!(build); + const payload = result.plans['TasksSubscriptionPayload']; + + const mockParent = { get: jest.fn((key: string) => { + if (key === 'parsed') return { event: 'UPDATE', rowIds: ['id-x', 'id-b', 'id-a'], overflow: false }; + if (key === 'subscribedIds') return ['id-a', 'id-b']; + return null; + }) }; + + payload.rowId(mockParent); + expect(mockParent.get).toHaveBeenCalledWith('parsed'); + expect(mockParent.get).toHaveBeenCalledWith('subscribedIds'); + }); + + it('rowId resolver returns null when no sparse set match', () => { + createRealtimeSubscriptionsPlugin(); + + const codec = createMockCodec('tasks', { realtime: true }); + const build = createMockBuild({ + tasks: { ...createMockResource('tasks', codec), get: jest.fn() }, + }); + + const result = capturedFactory!(build); + const payload = result.plans['TasksSubscriptionPayload']; + + const mockParent = { get: jest.fn((key: string) => { + if (key === 'parsed') return { event: 'INSERT', rowIds: ['id-x'], overflow: false }; + if (key === 'subscribedIds') return ['id-a', 'id-b']; + return null; + }) }; + + payload.rowId(mockParent); + expect(mockParent.get).toHaveBeenCalledWith('subscribedIds'); + }); + + it('rowId resolver falls back to first rowId when no sparse set provided', () => { + createRealtimeSubscriptionsPlugin(); + + const codec = createMockCodec('tasks', { realtime: true }); + const build = createMockBuild({ + tasks: { ...createMockResource('tasks', codec), get: jest.fn() }, + }); + + const result = capturedFactory!(build); + const payload = result.plans['TasksSubscriptionPayload']; + + const mockParent = { get: jest.fn((key: string) => { + if (key === 'parsed') return { event: 'INSERT', rowIds: ['id-first', 'id-second'], overflow: false }; + if (key === 'subscribedIds') return null; + return null; + }) }; + + payload.rowId(mockParent); + expect(mockParent.get).toHaveBeenCalledWith('subscribedIds'); + }); + }); + + describe('RLS-aware event delivery', () => { + it('rowId doc comment mentions RLS masking', () => { + createRealtimeSubscriptionsPlugin(); + + const codec = createMockCodec('items', { realtime: true }); + const build = createMockBuild({ + items: createMockResource('items', codec), + }); + + const result = capturedFactory!(build); + expect(result.typeDefs).toContain('masked when RLS denies access'); + }); + + it('type defs include sparse set ids argument', () => { + createRealtimeSubscriptionsPlugin(); + + const codec = createMockCodec('items', { realtime: true }); + const build = createMockBuild({ + items: createMockResource('items', codec), + }); + + const result = capturedFactory!(build); + expect(result.typeDefs).toContain('ids: [UUID!]'); + }); + + it('type defs include description mentioning all subscription modes', () => { + createRealtimeSubscriptionsPlugin(); + + const codec = createMockCodec('items', { realtime: true }); + const build = createMockBuild({ + items: createMockResource('items', codec), + }); + + const result = capturedFactory!(build); + expect(result.typeDefs).toContain('specific rows'); + expect(result.typeDefs).toContain('full collection'); + }); + }); }); diff --git a/graphile/graphile-realtime-subscriptions/src/plugin.ts b/graphile/graphile-realtime-subscriptions/src/plugin.ts index 31393feaf..c77442d78 100644 --- a/graphile/graphile-realtime-subscriptions/src/plugin.ts +++ b/graphile/graphile-realtime-subscriptions/src/plugin.ts @@ -6,7 +6,7 @@ * for real-time event delivery. * * Subscription modes: - * - Single record: onXxxChanged(id: UUID!) — subscribe to changes on one row + * - Specific rows: onXxxChanged(ids: [UUID!]) — subscribe to changes on specific rows * - Full collection: onXxxChanged (no args) — subscribe to any change on the table * * NOTIFY payload format (from emit_change trigger): @@ -25,8 +25,16 @@ * - Plugin-side: per-subscriber throttle (default 50 events/second/table) * drops individual events and sends a single INVALIDATE when exceeded * - * RLS enforcement is automatic — resource.get() queries through the - * authenticated user's connection with their JWT role applied. + * Security / RLS enforcement: + * - Row data is always fetched via resource.get() which runs through the + * authenticated user's connection with their JWT role and pgSettings applied. + * - For INSERT/UPDATE events, if RLS denies access (resource.get returns null), + * the rowId is masked (set to null) to prevent metadata leaks. + * - For DELETE events, row is naturally null (the row no longer exists). + * - For INVALIDATE (overflow), the client should refetch via a normal query + * which is also RLS-gated. + * - When ids are provided, only events for those specific rows are delivered, + * preventing cross-tenant event leaks. */ import { context as grafastContext, listen, object, constant, lambda } from 'grafast'; @@ -161,7 +169,7 @@ function discoverRealtimeTables(build: any): RealtimeTableInfo[] { function buildTypeDefs(tables: RealtimeTableInfo[]): string { const subscriptionFields = tables .map(({ fieldName, payloadTypeName }) => - ` """Subscribe to changes on this table. Pass an id to watch a specific record."""\n ${fieldName}(id: UUID): ${payloadTypeName}` + ` """Subscribe to changes on this table. Pass ids to watch specific rows, or no args for the full collection."""\n ${fieldName}(ids: [UUID!]): ${payloadTypeName}` ) .join('\n'); @@ -173,7 +181,7 @@ function buildTypeDefs(tables: RealtimeTableInfo[]): string { ` event: String!\n` + ` """The current state of the row (null for DELETE, INVALIDATE, or if RLS denies access)."""\n` + ` ${rowFieldName}: ${typeName}\n` + - ` """The ID of the changed row (null for INVALIDATE)."""\n` + + ` """The ID of the changed row (null for INVALIDATE, or masked when RLS denies access)."""\n` + ` rowId: UUID\n` + ` """True when too many changes occurred and the client should refetch."""\n` + ` overflow: Boolean!\n` + @@ -198,10 +206,11 @@ function buildPlans( subscribePlan(_$root: any, args: any) { const $pgSubscriber = (grafastContext() as any).get('pgSubscriber'); const $topic = constant(notifyChannel); - const $id = args.get('id'); + const $ids = args.get('ids'); return listen($pgSubscriber, $topic, ($payload: any) => { - const $parsed = lambda($payload, (raw: unknown) => { + const $parsed = lambda([$payload, $ids], (pair: unknown) => { + const [raw, subscribedIds] = pair as readonly [unknown, string[] | null | undefined]; const parsed = parseNotifyPayload(String(raw)); const action = parsed.overflow ? 'deliver' : throttle.check(); @@ -218,12 +227,18 @@ function buildPlans( }; } + // Sparse set filtering: only deliver events for subscribed row IDs + if (subscribedIds && subscribedIds.length > 0) { + const hasMatch = parsed.rowIds.some((rid: string) => subscribedIds.includes(rid)); + if (!hasMatch) return null; + } + return parsed; }); return object({ parsed: $parsed, - subscribedId: $id, + subscribedIds: $ids, }); }); }, @@ -239,10 +254,17 @@ function buildPlans( }, rowId($parent: any) { const $parsed = $parent.get('parsed'); - return lambda($parsed, (p: unknown) => { - const parsed = p as ParsedPayload | null; - if (!parsed || parsed.overflow || parsed.rowIds.length === 0) return null; - return parsed.rowIds[0]; + const $subscribedIds = $parent.get('subscribedIds'); + return lambda([$parsed, $subscribedIds], (pair: unknown) => { + const [p, subscribedIds] = pair as readonly [ParsedPayload | null, string[] | null | undefined]; + if (!p || p.overflow || p.rowIds.length === 0) return null; + + // When ids are provided, return the first matching row ID + if (subscribedIds && subscribedIds.length > 0) { + return p.rowIds.find((rid: string) => subscribedIds.includes(rid)) ?? null; + } + + return p.rowIds[0]; }); }, overflow($parent: any) { @@ -251,14 +273,21 @@ function buildPlans( }, [rowFieldName]($parent: any) { const $parsed = $parent.get('parsed'); - const $subscribedId = $parent.get('subscribedId'); + const $subscribedIds = $parent.get('subscribedIds'); const $rowId = lambda( - [$parsed, $subscribedId], - (pair: unknown) => { - const [p, subscribedId] = pair as readonly [ParsedPayload | null, string | null]; - if (subscribedId) return subscribedId; + [$parsed, $subscribedIds], + (tuple: unknown) => { + const [p, subscribedIds] = tuple as readonly [ + ParsedPayload | null, + string[] | null | undefined, + ]; if (!p || p.overflow || p.rowIds.length === 0) return null; + // When ids are provided, return first matching row ID + if (subscribedIds && subscribedIds.length > 0) { + return p.rowIds.find((rid: string) => subscribedIds.includes(rid)) ?? null; + } + // Full collection mode: return first row ID return p.rowIds[0]; }, );