diff --git a/graphql/server/src/middleware/__tests__/upload.test.ts b/graphql/server/src/middleware/__tests__/upload.test.ts index 4dc2e3f76d..33abc00d07 100644 --- a/graphql/server/src/middleware/__tests__/upload.test.ts +++ b/graphql/server/src/middleware/__tests__/upload.test.ts @@ -184,6 +184,9 @@ describe('createUploadAuthenticateMiddleware', () => { const res = makeRes(); const next = makeNext(); + // typed rls_settings query returns no rows (table may not exist yet) + rootPool.query.mockResolvedValueOnce({ rows: [] }); + // legacy api_modules fallback rootPool.query.mockResolvedValueOnce({ rows: [ { @@ -282,6 +285,9 @@ describe('createUploadAuthenticateMiddleware', () => { const res = makeRes(); const next = makeNext(); + // typed rls_settings query returns no rows (table may not exist yet) + rootPool.query.mockResolvedValueOnce({ rows: [] }); + // legacy api_modules fallback rootPool.query.mockResolvedValueOnce({ rows: [ { @@ -330,7 +336,13 @@ describe('createUploadAuthenticateMiddleware', () => { const res = makeRes(); const next = makeNext(); + // typed rls_settings query returns no rows + rootPool.query.mockResolvedValueOnce({ rows: [] }); + // legacy api_modules by database_id returns no rows + rootPool.query.mockResolvedValueOnce({ rows: [] }); + // typed rls_settings by dbname returns no rows rootPool.query.mockResolvedValueOnce({ rows: [] }); + // legacy api_modules by dbname returns no rows rootPool.query.mockResolvedValueOnce({ rows: [] }); await middleware(req, res, next); diff --git a/graphql/server/src/middleware/api.ts b/graphql/server/src/middleware/api.ts index 430d9e53a8..51a621a330 100644 --- a/graphql/server/src/middleware/api.ts +++ b/graphql/server/src/middleware/api.ts @@ -86,6 +86,29 @@ const RLS_MODULE_SQL = ` LIMIT 1 `; +const RLS_SETTINGS_SQL = ` + SELECT + auth_schema.schema_name AS authenticate_schema, + role_schema.schema_name AS role_schema, + auth_fn.name AS authenticate, + auth_strict_fn.name AS authenticate_strict, + role_fn.name AS current_role, + role_id_fn.name AS current_role_id, + ua_fn.name AS current_user_agent, + ip_fn.name AS current_ip_address + FROM services_public.rls_settings rs + LEFT JOIN metaschema_public.schema auth_schema ON rs.authenticate_schema_id = auth_schema.id + LEFT JOIN metaschema_public.schema role_schema ON rs.role_schema_id = role_schema.id + LEFT JOIN metaschema_public.function auth_fn ON rs.authenticate_function_id = auth_fn.id + LEFT JOIN metaschema_public.function auth_strict_fn ON rs.authenticate_strict_function_id = auth_strict_fn.id + LEFT JOIN metaschema_public.function role_fn ON rs.current_role_function_id = role_fn.id + LEFT JOIN metaschema_public.function role_id_fn ON rs.current_role_id_function_id = role_id_fn.id + LEFT JOIN metaschema_public.function ua_fn ON rs.current_user_agent_function_id = ua_fn.id + LEFT JOIN metaschema_public.function ip_fn ON rs.current_ip_address_function_id = ip_fn.id + WHERE rs.database_id = $1 + LIMIT 1 +`; + /** * Discover auth settings table location via public metaschema tables. * Joins sessions_module with metaschema_public.schema to resolve @@ -249,6 +272,24 @@ const toRlsModule = (row: RlsModuleRow | null): RlsModule | undefined => { }; }; +const toRlsModuleFromSettings = (row: RlsModuleData | null): RlsModule | undefined => { + if (!row) return undefined; + return { + authenticate: row.authenticate, + authenticateStrict: row.authenticate_strict, + privateSchema: { + schemaName: row.authenticate_schema, + }, + publicSchema: { + schemaName: row.role_schema, + }, + currentRole: row.current_role, + currentRoleId: row.current_role_id, + currentIpAddress: row.current_ip_address, + currentUserAgent: row.current_user_agent, + }; +}; + const toAuthSettings = (row: AuthSettingsRow | null): AuthSettings | undefined => { if (!row) return undefined; return { @@ -263,14 +304,14 @@ const toAuthSettings = (row: AuthSettingsRow | null): AuthSettings | undefined = }; }; -const toApiStructure = (row: ApiRow, opts: ApiOptions, rlsModuleRow?: RlsModuleRow | null, authSettingsRow?: AuthSettingsRow | null): ApiStructure => ({ +const toApiStructure = (row: ApiRow, opts: ApiOptions, rlsModule?: RlsModule, authSettingsRow?: AuthSettingsRow | null): ApiStructure => ({ apiId: row.api_id, dbname: row.dbname || opts.pg?.database || '', anonRole: row.anon_role || 'anon', roleName: row.role_name || 'authenticated', schema: row.schemas || [], apiModules: [], - rlsModule: toRlsModule(rlsModuleRow ?? null), + rlsModule, domains: [], databaseId: row.database_id, isPublic: row.is_public, @@ -329,9 +370,24 @@ const queryApiList = async (pool: Pool, isPublic: boolean): Promise => { +const queryRlsSettings = async (pool: Pool, databaseId: string): Promise => { + try { + const result = await pool.query(RLS_SETTINGS_SQL, [databaseId]); + return toRlsModuleFromSettings(result.rows[0] ?? null); + } catch { + return undefined; + } +}; + +const queryRlsModuleLegacy = async (pool: Pool, apiId: string): Promise => { const result = await pool.query(RLS_MODULE_SQL, [apiId]); - return result.rows[0] ?? null; + return toRlsModule(result.rows[0] ?? null); +}; + +const queryRlsModule = async (pool: Pool, databaseId: string, apiId: string): Promise => { + const fromSettings = await queryRlsSettings(pool, databaseId); + if (fromSettings) return fromSettings; + return queryRlsModuleLegacy(pool, apiId); }; /** @@ -423,7 +479,7 @@ const resolveApiNameHeader = async (ctx: ResolveContext): Promise { }; }; +const toRlsModuleFromSettings = (row: RlsModuleData | null): RlsModule | undefined => { + if (!row) return undefined; + return { + authenticate: row.authenticate, + authenticateStrict: row.authenticate_strict, + privateSchema: { schemaName: row.authenticate_schema }, + publicSchema: { schemaName: row.role_schema }, + currentRole: row.current_role, + currentRoleId: row.current_role_id, + currentIpAddress: row.current_ip_address, + currentUserAgent: row.current_user_agent, + }; +}; + const getBearerToken = (authorization?: string): string | null => { if (!authorization) return null; const [authType, authToken] = authorization.split(' '); @@ -120,7 +181,27 @@ const getBearerToken = (authorization?: string): string | null => { return authToken; }; +const queryRlsSettingsByDatabaseId = async (pool: Pool, databaseId: string): Promise => { + try { + const result = await pool.query(RLS_SETTINGS_BY_DATABASE_ID_SQL, [databaseId]); + return toRlsModuleFromSettings(result.rows[0] ?? null); + } catch { + return undefined; + } +}; + +const queryRlsSettingsByDbname = async (pool: Pool, dbname: string): Promise => { + try { + const result = await pool.query(RLS_SETTINGS_BY_DBNAME_SQL, [dbname]); + return toRlsModuleFromSettings(result.rows[0] ?? null); + } catch { + return undefined; + } +}; + const queryRlsModuleByDatabaseId = async (pool: Pool, databaseId: string): Promise => { + const fromSettings = await queryRlsSettingsByDatabaseId(pool, databaseId); + if (fromSettings) return fromSettings; const result = await pool.query(RLS_MODULE_BY_DATABASE_ID_SQL, [databaseId]); return toRlsModule(result.rows[0] ?? null); }; @@ -131,6 +212,8 @@ const queryRlsModuleByApiId = async (pool: Pool, apiId: string): Promise => { + const fromSettings = await queryRlsSettingsByDbname(pool, dbname); + if (fromSettings) return fromSettings; const result = await pool.query(RLS_MODULE_BY_DBNAME_SQL, [dbname]); return toRlsModule(result.rows[0] ?? null); };