From 657cb7c40a6f72bd31be13ba66f2e4284d4cc71e Mon Sep 17 00:00:00 2001 From: Konstantin Burkalev Date: Tue, 9 Dec 2025 15:41:01 +0200 Subject: [PATCH 1/3] chore(ci): Commented out dremio tests because of flaky integration tests (#10230) --- .github/workflows/push.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index c64a6de399c92..df56a8b28aecf 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -410,8 +410,8 @@ jobs: bigquery snowflake firebolt - dremio - # Athena (just to check for secrets availability) +# Athena (just to check for secrets availability) +# dremio # Commented out because of flaky integration tests DRIVERS_TESTS_ATHENA_CUBEJS_AWS_KEY: ${{ secrets.DRIVERS_TESTS_ATHENA_CUBEJS_AWS_KEY }} strategy: @@ -423,7 +423,8 @@ jobs: db: [ 'athena', 'bigquery', 'snowflake', 'trino', 'clickhouse', 'druid', 'elasticsearch', 'mssql', 'mysql', 'postgres', 'prestodb', - 'mysql-aurora-serverless', 'crate', 'mongobi', 'firebolt', 'dremio' + 'mysql-aurora-serverless', 'crate', 'mongobi', 'firebolt', +# 'dremio' # Commented out because of flaky integration tests # 'vertica' ] use_tesseract_sql_planner: [ false ] From 4c3029c4b1ccce4b251515da7a700aebe8772268 Mon Sep 17 00:00:00 2001 From: Dmitry Patsura Date: Tue, 9 Dec 2025 16:03:25 +0100 Subject: [PATCH 2/3] refactor(api-gateway): Use kebab-case for file names (#10232) --- ...bejsHandlerError.ts => cubejs-handler-error.ts} | 0 .../src/{dateParser.js => date-parser.js} | 2 +- packages/cubejs-api-gateway/src/gateway.ts | 14 +++++++------- ...{prepareAnnotation.ts => prepare-annotation.ts} | 2 +- .../helpers/{toConfigMap.ts => to-config-map.ts} | 0 ...mMetaExtended.ts => transform-meta-extended.ts} | 0 packages/cubejs-api-gateway/src/index.ts | 6 +++--- packages/cubejs-api-gateway/src/interfaces.ts | 2 +- packages/cubejs-api-gateway/src/query.js | 4 ++-- .../src/{requestParser.ts => request-parser.ts} | 0 packages/cubejs-api-gateway/src/types/responses.ts | 2 +- .../src/{UserError.ts => user-error.ts} | 2 +- .../local-subscription-store.ts} | 0 .../subscription-server.ts} | 8 ++++---- .../{dateParser.test.js => date-parser.test.js} | 2 +- ...notation.test.ts => prepare-annotation.test.ts} | 4 ++-- ...ded.test.ts => transform-meta-extended.test.ts} | 2 +- 17 files changed, 25 insertions(+), 25 deletions(-) rename packages/cubejs-api-gateway/src/{CubejsHandlerError.ts => cubejs-handler-error.ts} (100%) rename packages/cubejs-api-gateway/src/{dateParser.js => date-parser.js} (99%) rename packages/cubejs-api-gateway/src/helpers/{prepareAnnotation.ts => prepare-annotation.ts} (98%) rename packages/cubejs-api-gateway/src/helpers/{toConfigMap.ts => to-config-map.ts} (100%) rename packages/cubejs-api-gateway/src/helpers/{transformMetaExtended.ts => transform-meta-extended.ts} (100%) rename packages/cubejs-api-gateway/src/{requestParser.ts => request-parser.ts} (100%) rename packages/cubejs-api-gateway/src/{UserError.ts => user-error.ts} (69%) rename packages/cubejs-api-gateway/src/{LocalSubscriptionStore.ts => ws/local-subscription-store.ts} (100%) rename packages/cubejs-api-gateway/src/{SubscriptionServer.ts => ws/subscription-server.ts} (95%) rename packages/cubejs-api-gateway/test/{dateParser.test.js => date-parser.test.js} (99%) rename packages/cubejs-api-gateway/test/helpers/{prepareAnnotation.test.ts => prepare-annotation.test.ts} (98%) rename packages/cubejs-api-gateway/test/helpers/{transformMetaExtended.test.ts => transform-meta-extended.test.ts} (99%) diff --git a/packages/cubejs-api-gateway/src/CubejsHandlerError.ts b/packages/cubejs-api-gateway/src/cubejs-handler-error.ts similarity index 100% rename from packages/cubejs-api-gateway/src/CubejsHandlerError.ts rename to packages/cubejs-api-gateway/src/cubejs-handler-error.ts diff --git a/packages/cubejs-api-gateway/src/dateParser.js b/packages/cubejs-api-gateway/src/date-parser.js similarity index 99% rename from packages/cubejs-api-gateway/src/dateParser.js rename to packages/cubejs-api-gateway/src/date-parser.js index 136f5d86878a8..61ffb55acc44d 100644 --- a/packages/cubejs-api-gateway/src/dateParser.js +++ b/packages/cubejs-api-gateway/src/date-parser.js @@ -1,7 +1,7 @@ import moment from 'moment-timezone'; import { parse } from 'chrono-node'; -import { UserError } from './UserError'; +import { UserError } from './user-error'; const momentFromResult = (result, timezone) => { const dateMoment = moment().tz(timezone); diff --git a/packages/cubejs-api-gateway/src/gateway.ts b/packages/cubejs-api-gateway/src/gateway.ts index b5ef3ba78a8d8..4031a81582ad7 100644 --- a/packages/cubejs-api-gateway/src/gateway.ts +++ b/packages/cubejs-api-gateway/src/gateway.ts @@ -77,11 +77,11 @@ import { ContextRejectionMiddlewareFn, ContextAcceptorFn, } from './interfaces'; -import { getRequestIdFromRequest, requestParser } from './requestParser'; -import { UserError } from './UserError'; -import { CubejsHandlerError } from './CubejsHandlerError'; -import { SubscriptionServer, WebSocketSendMessageFn } from './SubscriptionServer'; -import { LocalSubscriptionStore } from './LocalSubscriptionStore'; +import { getRequestIdFromRequest, requestParser } from './request-parser'; +import { UserError } from './user-error'; +import { CubejsHandlerError } from './cubejs-handler-error'; +import { SubscriptionServer, WebSocketSendMessageFn } from './ws/subscription-server'; +import { LocalSubscriptionStore } from './ws/local-subscription-store'; import { getPivotQuery, getQueryGranularity, @@ -97,7 +97,7 @@ import { cachedHandler } from './cached-handler'; import { createJWKsFetcher } from './jwk'; import { SQLServer, SQLServerConstructorOptions } from './sql-server'; import { getJsonQueryFromGraphQLQuery, makeSchema } from './graphql'; -import { ConfigItem, prepareAnnotation } from './helpers/prepareAnnotation'; +import { ConfigItem, prepareAnnotation } from './helpers/prepare-annotation'; import { transformCube, transformMeasure, @@ -105,7 +105,7 @@ import { transformSegment, transformJoins, transformPreAggregations, -} from './helpers/transformMetaExtended'; +} from './helpers/transform-meta-extended'; type HandleErrorOptions = { e: any, diff --git a/packages/cubejs-api-gateway/src/helpers/prepareAnnotation.ts b/packages/cubejs-api-gateway/src/helpers/prepare-annotation.ts similarity index 98% rename from packages/cubejs-api-gateway/src/helpers/prepareAnnotation.ts rename to packages/cubejs-api-gateway/src/helpers/prepare-annotation.ts index db0385e7248d4..e3c5f33d7f77d 100644 --- a/packages/cubejs-api-gateway/src/helpers/prepareAnnotation.ts +++ b/packages/cubejs-api-gateway/src/helpers/prepare-annotation.ts @@ -7,7 +7,7 @@ import R from 'ramda'; import { isPredefinedGranularity } from '@cubejs-backend/shared'; -import { MetaConfig, MetaConfigMap, toConfigMap } from './toConfigMap'; +import { MetaConfig, MetaConfigMap, toConfigMap } from './to-config-map'; import { MemberType } from '../types/strings'; import { MemberType as MemberTypeEnum } from '../types/enums'; import { MemberExpression } from '../types/query'; diff --git a/packages/cubejs-api-gateway/src/helpers/toConfigMap.ts b/packages/cubejs-api-gateway/src/helpers/to-config-map.ts similarity index 100% rename from packages/cubejs-api-gateway/src/helpers/toConfigMap.ts rename to packages/cubejs-api-gateway/src/helpers/to-config-map.ts diff --git a/packages/cubejs-api-gateway/src/helpers/transformMetaExtended.ts b/packages/cubejs-api-gateway/src/helpers/transform-meta-extended.ts similarity index 100% rename from packages/cubejs-api-gateway/src/helpers/transformMetaExtended.ts rename to packages/cubejs-api-gateway/src/helpers/transform-meta-extended.ts diff --git a/packages/cubejs-api-gateway/src/index.ts b/packages/cubejs-api-gateway/src/index.ts index a990e663bf6e8..679fc7d02cb55 100644 --- a/packages/cubejs-api-gateway/src/index.ts +++ b/packages/cubejs-api-gateway/src/index.ts @@ -1,7 +1,7 @@ export * from './gateway'; export * from './sql-server'; export * from './interfaces'; -export * from './CubejsHandlerError'; -export * from './UserError'; -export { getRequestIdFromRequest } from './requestParser'; +export * from './cubejs-handler-error'; +export * from './user-error'; +export { getRequestIdFromRequest } from './request-parser'; export { TransformDataRequest } from './types/responses'; diff --git a/packages/cubejs-api-gateway/src/interfaces.ts b/packages/cubejs-api-gateway/src/interfaces.ts index 0d1b5418fdbdd..5094494417ef2 100644 --- a/packages/cubejs-api-gateway/src/interfaces.ts +++ b/packages/cubejs-api-gateway/src/interfaces.ts @@ -53,7 +53,7 @@ import { import { ConfigItem, GranularityMeta -} from './helpers/prepareAnnotation'; +} from './helpers/prepare-annotation'; export { AliasToMemberMap, diff --git a/packages/cubejs-api-gateway/src/query.js b/packages/cubejs-api-gateway/src/query.js index 55c04def32857..a2b91358b2b35 100644 --- a/packages/cubejs-api-gateway/src/query.js +++ b/packages/cubejs-api-gateway/src/query.js @@ -3,8 +3,8 @@ import moment from 'moment'; import Joi from 'joi'; import { getEnv } from '@cubejs-backend/shared'; -import { UserError } from './UserError'; -import { dateParser } from './dateParser'; +import { UserError } from './user-error'; +import { dateParser } from './date-parser'; import { QueryType } from './types/enums'; const getQueryGranularity = (queries) => R.pipe( diff --git a/packages/cubejs-api-gateway/src/requestParser.ts b/packages/cubejs-api-gateway/src/request-parser.ts similarity index 100% rename from packages/cubejs-api-gateway/src/requestParser.ts rename to packages/cubejs-api-gateway/src/request-parser.ts diff --git a/packages/cubejs-api-gateway/src/types/responses.ts b/packages/cubejs-api-gateway/src/types/responses.ts index d5e1d8ea1e0aa..4f53d564baddd 100644 --- a/packages/cubejs-api-gateway/src/types/responses.ts +++ b/packages/cubejs-api-gateway/src/types/responses.ts @@ -1,4 +1,4 @@ -import type { ConfigItem } from '../helpers/prepareAnnotation'; +import type { ConfigItem } from '../helpers/prepare-annotation'; import type { NormalizedQuery } from './query'; import type { QueryType, ResultType } from './strings'; diff --git a/packages/cubejs-api-gateway/src/UserError.ts b/packages/cubejs-api-gateway/src/user-error.ts similarity index 69% rename from packages/cubejs-api-gateway/src/UserError.ts rename to packages/cubejs-api-gateway/src/user-error.ts index e4f907908d4d1..13e2da47fcced 100644 --- a/packages/cubejs-api-gateway/src/UserError.ts +++ b/packages/cubejs-api-gateway/src/user-error.ts @@ -1,4 +1,4 @@ -import { CubejsHandlerError } from './CubejsHandlerError'; +import { CubejsHandlerError } from './cubejs-handler-error'; export class UserError extends CubejsHandlerError { public constructor(message: string) { diff --git a/packages/cubejs-api-gateway/src/LocalSubscriptionStore.ts b/packages/cubejs-api-gateway/src/ws/local-subscription-store.ts similarity index 100% rename from packages/cubejs-api-gateway/src/LocalSubscriptionStore.ts rename to packages/cubejs-api-gateway/src/ws/local-subscription-store.ts diff --git a/packages/cubejs-api-gateway/src/SubscriptionServer.ts b/packages/cubejs-api-gateway/src/ws/subscription-server.ts similarity index 95% rename from packages/cubejs-api-gateway/src/SubscriptionServer.ts rename to packages/cubejs-api-gateway/src/ws/subscription-server.ts index 15cbe057c0b52..26aa0eb2130d4 100644 --- a/packages/cubejs-api-gateway/src/SubscriptionServer.ts +++ b/packages/cubejs-api-gateway/src/ws/subscription-server.ts @@ -1,9 +1,9 @@ import { v4 as uuidv4 } from 'uuid'; -import { UserError } from './UserError'; -import type { ApiGateway } from './gateway'; -import type { LocalSubscriptionStore } from './LocalSubscriptionStore'; -import { ExtendedRequestContext, ContextAcceptorFn } from './interfaces'; +import { UserError } from '../user-error'; +import type { ApiGateway } from '../gateway'; +import type { LocalSubscriptionStore } from './local-subscription-store'; +import { ExtendedRequestContext, ContextAcceptorFn } from '../interfaces'; const methodParams: Record = { load: ['query', 'queryType'], diff --git a/packages/cubejs-api-gateway/test/dateParser.test.js b/packages/cubejs-api-gateway/test/date-parser.test.js similarity index 99% rename from packages/cubejs-api-gateway/test/dateParser.test.js rename to packages/cubejs-api-gateway/test/date-parser.test.js index 46acce9fa8d82..3ce1710e8ede4 100644 --- a/packages/cubejs-api-gateway/test/dateParser.test.js +++ b/packages/cubejs-api-gateway/test/date-parser.test.js @@ -1,6 +1,6 @@ /* globals describe,test,expect,jest */ -import { dateParser } from '../src/dateParser'; +import { dateParser } from '../src/date-parser'; describe('dateParser', () => { test('custom daily ranges returns day aligned dateRange', () => { diff --git a/packages/cubejs-api-gateway/test/helpers/prepareAnnotation.test.ts b/packages/cubejs-api-gateway/test/helpers/prepare-annotation.test.ts similarity index 98% rename from packages/cubejs-api-gateway/test/helpers/prepareAnnotation.test.ts rename to packages/cubejs-api-gateway/test/helpers/prepare-annotation.test.ts index 2977fdce0c74c..2fe52c50fb52d 100644 --- a/packages/cubejs-api-gateway/test/helpers/prepareAnnotation.test.ts +++ b/packages/cubejs-api-gateway/test/helpers/prepare-annotation.test.ts @@ -9,11 +9,11 @@ import { MemberType } from '../../src/types/enums'; import prepareAnnotationDef - from '../../src/helpers/prepareAnnotation'; + from '../../src/helpers/prepare-annotation'; import { annotation, prepareAnnotation, -} from '../../src/helpers/prepareAnnotation'; +} from '../../src/helpers/prepare-annotation'; describe('prepareAnnotation helpers', () => { test('export looks as expected', () => { diff --git a/packages/cubejs-api-gateway/test/helpers/transformMetaExtended.test.ts b/packages/cubejs-api-gateway/test/helpers/transform-meta-extended.test.ts similarity index 99% rename from packages/cubejs-api-gateway/test/helpers/transformMetaExtended.test.ts rename to packages/cubejs-api-gateway/test/helpers/transform-meta-extended.test.ts index 7b593b503e016..51f8a1ba5d4af 100644 --- a/packages/cubejs-api-gateway/test/helpers/transformMetaExtended.test.ts +++ b/packages/cubejs-api-gateway/test/helpers/transform-meta-extended.test.ts @@ -19,7 +19,7 @@ import { transformSegment, transformJoins, transformPreAggregations, -} from '../../src/helpers/transformMetaExtended'; +} from '../../src/helpers/transform-meta-extended'; const MOCK_USERS_CUBE = { measures: { From 62715203c70b71e371bf2ccad83609ea4b9ce2d1 Mon Sep 17 00:00:00 2001 From: Dmitry Patsura Date: Tue, 9 Dec 2025 18:02:17 +0100 Subject: [PATCH 3/3] fix: Improve WS request sanitization (#10231) --- packages/cubejs-api-gateway/package.json | 3 +- packages/cubejs-api-gateway/src/index.ts | 3 + packages/cubejs-api-gateway/src/ws/index.ts | 3 + .../src/ws/local-subscription-store.ts | 81 +++--- .../src/ws/message-schema.ts | 67 +++++ .../src/ws/subscription-server.ts | 147 +++++++--- .../test/ws/subscription-server.test.ts | 261 ++++++++++++++++++ .../cubejs-server-core/src/core/server.ts | 4 +- .../cubejs-server/src/websocket-server.ts | 16 +- yarn.lock | 5 + 10 files changed, 516 insertions(+), 74 deletions(-) create mode 100644 packages/cubejs-api-gateway/src/ws/index.ts create mode 100644 packages/cubejs-api-gateway/src/ws/message-schema.ts create mode 100644 packages/cubejs-api-gateway/test/ws/subscription-server.test.ts diff --git a/packages/cubejs-api-gateway/package.json b/packages/cubejs-api-gateway/package.json index 24bdb08e8971f..07d6e87cc0c3e 100644 --- a/packages/cubejs-api-gateway/package.json +++ b/packages/cubejs-api-gateway/package.json @@ -49,7 +49,8 @@ "nexus": "^1.1.0", "node-fetch": "^2.6.1", "ramda": "^0.27.0", - "uuid": "^8.3.2" + "uuid": "^8.3.2", + "zod": "^4.1.13" }, "devDependencies": { "@cubejs-backend/linter": "1.5.12", diff --git a/packages/cubejs-api-gateway/src/index.ts b/packages/cubejs-api-gateway/src/index.ts index 679fc7d02cb55..e635daf8e04e9 100644 --- a/packages/cubejs-api-gateway/src/index.ts +++ b/packages/cubejs-api-gateway/src/index.ts @@ -3,5 +3,8 @@ export * from './sql-server'; export * from './interfaces'; export * from './cubejs-handler-error'; export * from './user-error'; + export { getRequestIdFromRequest } from './request-parser'; export { TransformDataRequest } from './types/responses'; + +export type { SubscriptionServer } from './ws'; diff --git a/packages/cubejs-api-gateway/src/ws/index.ts b/packages/cubejs-api-gateway/src/ws/index.ts new file mode 100644 index 0000000000000..cdea01d621cb5 --- /dev/null +++ b/packages/cubejs-api-gateway/src/ws/index.ts @@ -0,0 +1,3 @@ +export * from './local-subscription-store'; +export * from './message-schema'; +export * from './subscription-server'; diff --git a/packages/cubejs-api-gateway/src/ws/local-subscription-store.ts b/packages/cubejs-api-gateway/src/ws/local-subscription-store.ts index a94225d9368ec..20d71cb6fcf5d 100644 --- a/packages/cubejs-api-gateway/src/ws/local-subscription-store.ts +++ b/packages/cubejs-api-gateway/src/ws/local-subscription-store.ts @@ -2,8 +2,19 @@ interface LocalSubscriptionStoreOptions { heartBeatInterval?: number; } +export type LocalSubscriptionStoreSubscription = { + message: any, + state: any, + timestamp: Date, +}; + +export type LocalSubscriptionStoreConnection = { + subscriptions: Map, + authContext?: any, +}; + export class LocalSubscriptionStore { - protected connections = {}; + protected readonly connections: Map = new Map(); protected readonly hearBeatInterval: number; @@ -12,60 +23,68 @@ export class LocalSubscriptionStore { } public async getSubscription(connectionId: string, subscriptionId: string) { - const connection = this.getConnection(connectionId); - return connection.subscriptions[subscriptionId]; + const connection = this.getConnectionOrCreate(connectionId); + return connection.subscriptions.get(subscriptionId); } public async subscribe(connectionId: string, subscriptionId: string, subscription) { - const connection = this.getConnection(connectionId); - connection.subscriptions[subscriptionId] = { + const connection = this.getConnectionOrCreate(connectionId); + connection.subscriptions.set(subscriptionId, { ...subscription, timestamp: new Date() - }; + }); } public async unsubscribe(connectionId: string, subscriptionId: string) { - const connection = this.getConnection(connectionId); - delete connection.subscriptions[subscriptionId]; + const connection = this.getConnectionOrCreate(connectionId); + connection.subscriptions.delete(subscriptionId); } - public async getAllSubscriptions() { - return Object.keys(this.connections).map(connectionId => { - Object.keys(this.connections[connectionId].subscriptions).filter( - subscriptionId => new Date().getTime() - - this.connections[connectionId].subscriptions[subscriptionId].timestamp.getTime() > - this.hearBeatInterval * 4 * 1000 - ).forEach(subscriptionId => { delete this.connections[connectionId].subscriptions[subscriptionId]; }); - - return Object.keys(this.connections[connectionId].subscriptions) - .map(subscriptionId => ({ - connectionId, - ...this.connections[connectionId].subscriptions[subscriptionId] - })); - }).reduce((a, b) => a.concat(b), []); + public getAllSubscriptions() { + const now = Date.now(); + const staleThreshold = this.hearBeatInterval * 4 * 1000; + const result: Array<{ connectionId: string } & LocalSubscriptionStoreSubscription> = []; + + for (const [connectionId, connection] of this.connections) { + for (const [subscriptionId, subscription] of connection.subscriptions) { + if (now - subscription.timestamp.getTime() > staleThreshold) { + connection.subscriptions.delete(subscriptionId); + } + } + + for (const [, subscription] of connection.subscriptions) { + result.push({ connectionId, ...subscription }); + } + } + + return result; } - public async cleanupSubscriptions(connectionId: string) { - delete this.connections[connectionId]; + public async disconnect(connectionId: string) { + this.connections.delete(connectionId); } public async getAuthContext(connectionId: string) { - return this.getConnection(connectionId).authContext; + return this.getConnectionOrCreate(connectionId).authContext; } public async setAuthContext(connectionId: string, authContext) { - this.getConnection(connectionId).authContext = authContext; + this.getConnectionOrCreate(connectionId).authContext = authContext; } - protected getConnection(connectionId: string) { - if (!this.connections[connectionId]) { - this.connections[connectionId] = { subscriptions: {} }; + protected getConnectionOrCreate(connectionId: string): LocalSubscriptionStoreConnection { + const connect = this.connections.get(connectionId); + if (connect) { + return connect; } - return this.connections[connectionId]; + const connection = { subscriptions: new Map() }; + this.connections.set(connectionId, connection); + + return connection; } public clear() { - this.connections = {}; + this.connections.clear(); } } diff --git a/packages/cubejs-api-gateway/src/ws/message-schema.ts b/packages/cubejs-api-gateway/src/ws/message-schema.ts new file mode 100644 index 0000000000000..ca981faee5fea --- /dev/null +++ b/packages/cubejs-api-gateway/src/ws/message-schema.ts @@ -0,0 +1,67 @@ +import { z } from 'zod'; + +const messageId = z.union([z.string().max(16), z.number()]); +const requestId = z.string().max(64).optional(); + +export const authMessageSchema = z.object({ + authorization: z.string(), +}).strict(); + +export const unsubscribeMessageSchema = z.object({ + unsubscribe: z.string().max(16), +}).strict(); + +const queryParams = z.object({ + query: z.unknown(), + queryType: z.string().optional(), +}).strict(); + +const queryOnlyParams = z.object({ + query: z.unknown(), +}).strict(); + +// Method-based messages using discriminatedUnion +export const methodMessageSchema = z.discriminatedUnion('method', [ + z.object({ + method: z.literal('load'), + messageId, + requestId, + params: queryParams, + }).strict(), + z.object({ + method: z.literal('sql'), + messageId, + requestId, + params: queryOnlyParams, + }).strict(), + z.object({ + method: z.literal('dry-run'), + messageId, + requestId, + params: queryOnlyParams, + }).strict(), + z.object({ + method: z.literal('meta'), + messageId, + requestId, + params: z.object({}).strict().optional(), + }).strict(), + z.object({ + method: z.literal('subscribe'), + messageId, + requestId, + params: queryParams, + }).strict(), + z.object({ + method: z.literal('unsubscribe'), + messageId, + requestId, + params: z.object({}).strict().optional(), + }).strict(), +]); + +// Export types +export type AuthMessage = z.infer; +export type UnsubscribeMessage = z.infer; +export type MethodMessage = z.infer; +export type WsMessage = AuthMessage | UnsubscribeMessage | MethodMessage; diff --git a/packages/cubejs-api-gateway/src/ws/subscription-server.ts b/packages/cubejs-api-gateway/src/ws/subscription-server.ts index 26aa0eb2130d4..2c556d13a0cc7 100644 --- a/packages/cubejs-api-gateway/src/ws/subscription-server.ts +++ b/packages/cubejs-api-gateway/src/ws/subscription-server.ts @@ -1,18 +1,27 @@ import { v4 as uuidv4 } from 'uuid'; +import type { ZodError } from 'zod'; import { UserError } from '../user-error'; +import { ExtendedRequestContext, ContextAcceptorFn } from '../interfaces'; +import { CubejsHandlerError } from '../cubejs-handler-error'; +import { + authMessageSchema, + unsubscribeMessageSchema, + methodMessageSchema, + WsMessage, +} from './message-schema'; + import type { ApiGateway } from '../gateway'; import type { LocalSubscriptionStore } from './local-subscription-store'; -import { ExtendedRequestContext, ContextAcceptorFn } from '../interfaces'; -const methodParams: Record = { +const methodParams: Record = Object.freeze({ load: ['query', 'queryType'], sql: ['query'], 'dry-run': ['query'], meta: [], subscribe: ['query', 'queryType'], unsubscribe: [], -}; +}); const calcMessageLength = (message: unknown) => Buffer.byteLength( typeof message === 'string' ? message : JSON.stringify(message) @@ -29,42 +38,97 @@ export class SubscriptionServer { ) { } - public resultFn(connectionId: string, messageId: string, requestId: string | undefined) { + protected resultFn(connectionId: string, messageId: string | number | undefined, requestId: string | undefined, logNetworkUsage: boolean = true) { return async (message, { status } = { status: 200 }) => { - this.apiGateway.log({ - type: 'Outgoing network usage', - service: 'api-ws', - bytes: calcMessageLength(message), - }, { requestId }); + if (logNetworkUsage) { + this.apiGateway.log({ type: 'Outgoing network usage', service: 'api-ws', bytes: calcMessageLength(message), }, { requestId }); + } + return this.sendMessage(connectionId, { messageId, message, status }); }; } - public async processMessage(connectionId: string, message, isSubscription) { + protected deserializeMessage(message: any): any { + try { + return JSON.parse(message); + } catch (e: any) { + throw new CubejsHandlerError(400, 'Invalid JSON payload', e.message); + } + } + + protected mapZodError(error: ZodError): string { + return error.issues + .map(e => (e.path.length ? `${e.path.join('.')}: ${e.message}` : e.message)) + .join(', '); + } + + protected validateMessage(message: object): WsMessage { + if ('authorization' in message) { + const result = authMessageSchema.safeParse(message); + if (!result.success) { + throw new CubejsHandlerError(400, 'Invalid authorization message format', this.mapZodError(result.error)); + } + + return result.data; + } + + if ('unsubscribe' in message) { + const result = unsubscribeMessageSchema.safeParse(message); + if (!result.success) { + throw new CubejsHandlerError(400, 'Invalid unsubscribe message format', this.mapZodError(result.error)); + } + + return result.data; + } + + const result = methodMessageSchema.safeParse(message); + if (!result.success) { + throw new CubejsHandlerError(400, 'Invalid message format', this.mapZodError(result.error)); + } + + return result.data; + } + + public async processMessage(connectionId: string, body: string) { + let message: any | undefined; + + try { + message = this.deserializeMessage(body); + message = this.validateMessage(message); + + await this.handleMessage(connectionId, message, false); + } catch (e) { + this.apiGateway.handleError({ + e, + query: message?.query, + res: this.resultFn(connectionId, message?.messageId, undefined, false), + }); + } + } + + protected async handleMessage(connectionId: string, message: WsMessage, isSubscription: boolean) { let authContext: any = {}; let context: Partial = {}; const bytes = calcMessageLength(message); try { - if (typeof message === 'string') { - message = JSON.parse(message); - } - - if (message.authorization) { - authContext = { isSubscription: true, protocol: 'ws' }; + if ('authorization' in message) { + authContext = { isSubscription, protocol: 'ws' }; await this.apiGateway.checkAuthFn(authContext, message.authorization); + const acceptanceResult = await this.contextAcceptor(authContext); if (!acceptanceResult.accepted) { this.sendMessage(connectionId, acceptanceResult.rejectMessage); return; } + await this.subscriptionStore.setAuthContext(connectionId, authContext); this.sendMessage(connectionId, { handshake: true }); return; } - if (message.unsubscribe) { + if ('unsubscribe' in message) { await this.subscriptionStore.unsubscribe(connectionId, message.unsubscribe); return; } @@ -74,7 +138,6 @@ export class SubscriptionServer { } authContext = await this.subscriptionStore.getAuthContext(connectionId); - if (!authContext) { await this.sendMessage( connectionId, @@ -88,16 +151,23 @@ export class SubscriptionServer { } if (!message.method) { - throw new UserError('method is required'); + throw new UserError('Method is required'); } - if (!methodParams[message.method]) { + if (!methodParams.hasOwnProperty(message.method)) { throw new UserError(`Unsupported method: ${message.method}`); } - const baseRequestId = message.requestId || `${connectionId}-${message.messageId}`; + const subscriptionId = String(message.messageId); + const baseRequestId = message.requestId || `${connectionId}-${subscriptionId}`; const requestId = `${baseRequestId}-span-${uuidv4()}`; - context = await this.apiGateway.contextByReq(message, authContext.securityContext, requestId); + + context = await this.apiGateway.contextByReq( + // TODO: We need to standardize type for WS request type + message as any, + authContext.securityContext, + requestId + ); this.apiGateway.log({ type: 'Incoming network usage', @@ -105,13 +175,17 @@ export class SubscriptionServer { bytes, }, context); - const allowedParams = methodParams[message.method]; - const params = allowedParams.map(k => ({ [k]: (message.params || {})[k] })) - .reduce((a, b) => ({ ...a, ...b }), {}); + const collectedParams: Record = Object.create(null); - const method = message.method.replace(/[^a-z]+(.)/g, (m, chr) => chr.toUpperCase()); + if (message.params) { + for (const k of methodParams[message.method]) { + collectedParams[k] = message.params[k]; + } + } + + const method = message.method.replace(/[^a-z]+(.)/g, (_m, chr) => chr.toUpperCase()); await this.apiGateway[method]({ - ...params, + ...collectedParams, connectionId, context, signedWithPlaygroundAuthSecret: authContext.signedWithPlaygroundAuthSecret, @@ -119,36 +193,39 @@ export class SubscriptionServer { apiType: 'ws', res: this.resultFn(connectionId, message.messageId, requestId), subscriptionState: async () => { - const subscription = await this.subscriptionStore.getSubscription(connectionId, message.messageId); + const subscription = await this.subscriptionStore.getSubscription(connectionId, subscriptionId); return subscription && subscription.state; }, - subscribe: async (state) => this.subscriptionStore.subscribe(connectionId, message.messageId, { + subscribe: async (state) => this.subscriptionStore.subscribe(connectionId, subscriptionId, { message, state }), - unsubscribe: async () => this.subscriptionStore.unsubscribe(connectionId, message.messageId) + unsubscribe: async () => this.subscriptionStore.unsubscribe(connectionId, subscriptionId) }); await this.sendMessage(connectionId, { messageProcessedId: message.messageId }); } catch (e) { + const messageId = 'messageId' in message ? message.messageId : undefined; + const query = 'params' in message ? message.params?.query : undefined; + this.apiGateway.handleError({ e, - query: message.query, - res: this.resultFn(connectionId, message.messageId, context.requestId), + query, + res: this.resultFn(connectionId, messageId, context.requestId), context }); } } public async processSubscriptions() { - const allSubscriptions = await this.subscriptionStore.getAllSubscriptions(); + const allSubscriptions = this.subscriptionStore.getAllSubscriptions(); await Promise.all(allSubscriptions.map(async subscription => { - await this.processMessage(subscription.connectionId, subscription.message, true); + await this.handleMessage(subscription.connectionId, subscription.message, true); })); } public async disconnect(connectionId: string) { - await this.subscriptionStore.cleanupSubscriptions(connectionId); + await this.subscriptionStore.disconnect(connectionId); } public clear() { diff --git a/packages/cubejs-api-gateway/test/ws/subscription-server.test.ts b/packages/cubejs-api-gateway/test/ws/subscription-server.test.ts new file mode 100644 index 0000000000000..3f3709c9c4654 --- /dev/null +++ b/packages/cubejs-api-gateway/test/ws/subscription-server.test.ts @@ -0,0 +1,261 @@ +import { SubscriptionServer } from '../../src/ws/subscription-server'; + +const createMocks = () => { + const sentMessages: any[] = []; + + const mockApiGateway: any = { + checkAuthFn: jest.fn().mockResolvedValue(undefined), + contextByReq: jest.fn().mockResolvedValue({ requestId: 'test-req' }), + log: jest.fn(), + handleError: jest.fn(), + load: jest.fn().mockResolvedValue(undefined), + sql: jest.fn().mockResolvedValue(undefined), + dryRun: jest.fn().mockResolvedValue(undefined), + meta: jest.fn().mockResolvedValue(undefined), + subscribe: jest.fn().mockResolvedValue(undefined), + }; + + const mockSubscriptionStore: any = { + setAuthContext: jest.fn().mockResolvedValue(undefined), + getAuthContext: jest.fn().mockResolvedValue({ securityContext: {} }), + subscribe: jest.fn().mockResolvedValue(undefined), + unsubscribe: jest.fn().mockResolvedValue(undefined), + getSubscription: jest.fn().mockResolvedValue(null), + }; + + const mockSendMessage = jest.fn().mockImplementation(async (_connId, msg) => { + sentMessages.push(msg); + }); + + const mockContextAcceptor = jest.fn().mockResolvedValue({ accepted: true }); + + return { + mockApiGateway, + mockSubscriptionStore, + mockSendMessage, + mockContextAcceptor, + sentMessages, + }; +}; + +describe('SubscriptionServer', () => { + describe('Message Validation', () => { + it('should accept valid auth message', async () => { + const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor, sentMessages } = createMocks(); + const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); + + await server.processMessage('conn-1', JSON.stringify({ authorization: 'token123' })); + + expect(mockApiGateway.checkAuthFn).toHaveBeenCalled(); + expect(sentMessages).toContainEqual({ handshake: true }); + }); + + it('should accept valid unsubscribe message', async () => { + const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor } = createMocks(); + const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); + + await server.processMessage('conn-1', JSON.stringify({ unsubscribe: 'msg-1' })); + + expect(mockSubscriptionStore.unsubscribe).toHaveBeenCalledWith('conn-1', 'msg-1'); + }); + + it('should accept valid load message', async () => { + const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor, sentMessages } = createMocks(); + const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); + + const message = { + method: 'load', + messageId: '123', + params: { query: { measures: ['Orders.count'] } } + }; + await server.processMessage('conn-1', JSON.stringify(message)); + + expect(mockApiGateway.load).toHaveBeenCalled(); + expect(sentMessages).toContainEqual({ messageProcessedId: '123' }); + }); + + it('should accept messageId as number', async () => { + const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor, sentMessages } = createMocks(); + const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); + + const message = { + method: 'load', + messageId: 123, + params: { query: { measures: ['Orders.count'] } } + }; + await server.processMessage('conn-1', JSON.stringify(message)); + + expect(mockApiGateway.load).toHaveBeenCalled(); + expect(sentMessages).toContainEqual({ messageProcessedId: 123 }); + }); + + it('should reject invalid JSON payload', async () => { + const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor } = createMocks(); + const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); + + await server.processMessage('conn-1', 'not valid json'); + + expect(mockApiGateway.handleError).toHaveBeenCalled(); + const errorCall = mockApiGateway.handleError.mock.calls[0][0]; + expect(errorCall.e.type).toBe('Invalid JSON payload'); + }); + + it('should reject message with unknown fields', async () => { + const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor } = createMocks(); + const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); + + const message = { + method: 'load', + messageId: '123', + params: { query: { measures: ['Orders.count'] } }, + fieldIsNotAllowed: true, + }; + await server.processMessage('conn-1', JSON.stringify(message)); + + expect(mockApiGateway.load).not.toHaveBeenCalled(); + expect(mockApiGateway.handleError).toHaveBeenCalled(); + }); + + it('should reject messageId & requestId', async () => { + const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor } = createMocks(); + const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); + + const message = { + method: 'load', + messageId: '12345678901234567', // 17 chars + requestId: 'a'.repeat(65), // 65 chars + params: { query: { measures: ['Orders.count'] } }, + }; + await server.processMessage('conn-1', JSON.stringify(message)); + + expect(mockApiGateway.load).not.toHaveBeenCalled(); + expect(mockApiGateway.handleError).toHaveBeenCalled(); + }); + }); + + describe('Auth Flow', () => { + it('should complete successful authorization handshake', async () => { + const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor, sentMessages } = createMocks(); + const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); + + await server.processMessage('conn-1', JSON.stringify({ authorization: 'valid-token' })); + + expect(mockApiGateway.checkAuthFn).toHaveBeenCalledWith( + expect.objectContaining({ protocol: 'ws' }), + 'valid-token' + ); + expect(mockSubscriptionStore.setAuthContext).toHaveBeenCalled(); + expect(sentMessages).toContainEqual({ handshake: true }); + }); + + it('should reject when contextAcceptor rejects', async () => { + const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor, sentMessages } = createMocks(); + mockContextAcceptor.mockResolvedValue({ accepted: false, rejectMessage: { error: 'Rejected' } }); + const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); + + await server.processMessage('conn-1', JSON.stringify({ authorization: 'token' })); + + expect(mockSubscriptionStore.setAuthContext).not.toHaveBeenCalled(); + expect(sentMessages).toContainEqual({ error: 'Rejected' }); + }); + + it('should return 403 for unauthorized method call', async () => { + const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor, sentMessages } = createMocks(); + mockSubscriptionStore.getAuthContext.mockResolvedValue(null); + const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); + + const message = { + method: 'load', + messageId: '123', + params: { query: {} } + }; + await server.processMessage('conn-1', JSON.stringify(message)); + + expect(mockApiGateway.load).not.toHaveBeenCalled(); + expect(sentMessages).toContainEqual({ + messageId: '123', + message: { error: 'Not authorized' }, + status: 403 + }); + }); + }); + + describe('Method Dispatch', () => { + it('should call load method correctly', async () => { + const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor } = createMocks(); + const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); + + const message = { + method: 'load', + messageId: '123', + params: { query: { measures: ['Orders.count'] }, queryType: 'multi' } + }; + await server.processMessage('conn-1', JSON.stringify(message)); + + expect(mockApiGateway.load).toHaveBeenCalledWith( + expect.objectContaining({ + query: { measures: ['Orders.count'] }, + queryType: 'multi', + connectionId: 'conn-1', + apiType: 'ws', + }) + ); + }); + + it('should call sql method correctly', async () => { + const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor } = createMocks(); + const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); + + const message = { + method: 'sql', + messageId: '123', + params: { query: { measures: ['Orders.count'] } } + }; + await server.processMessage('conn-1', JSON.stringify(message)); + + expect(mockApiGateway.sql).toHaveBeenCalledWith( + expect.objectContaining({ + query: { measures: ['Orders.count'] }, + connectionId: 'conn-1', + }) + ); + }); + + it('should call meta method correctly', async () => { + const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor } = createMocks(); + const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); + + const message = { + method: 'meta', + messageId: '123', + }; + await server.processMessage('conn-1', JSON.stringify(message)); + + expect(mockApiGateway.meta).toHaveBeenCalledWith( + expect.objectContaining({ + connectionId: 'conn-1', + apiType: 'ws', + }) + ); + }); + + it('should call subscribe method correctly', async () => { + const { mockApiGateway, mockSubscriptionStore, mockSendMessage, mockContextAcceptor } = createMocks(); + const server = new SubscriptionServer(mockApiGateway, mockSendMessage, mockSubscriptionStore, mockContextAcceptor); + + const message = { + method: 'subscribe', + messageId: '123', + params: { query: { measures: ['Orders.count'] } } + }; + await server.processMessage('conn-1', JSON.stringify(message)); + + expect(mockApiGateway.subscribe).toHaveBeenCalledWith( + expect.objectContaining({ + query: { measures: ['Orders.count'] }, + connectionId: 'conn-1', + }) + ); + }); + }); +}); diff --git a/packages/cubejs-server-core/src/core/server.ts b/packages/cubejs-server-core/src/core/server.ts index da38b6cbcbf1e..bf51142ab6785 100644 --- a/packages/cubejs-server-core/src/core/server.ts +++ b/packages/cubejs-server-core/src/core/server.ts @@ -26,6 +26,8 @@ import { import type { Application as ExpressApplication } from 'express'; import { BaseDriver, DriverFactoryByDataSource } from '@cubejs-backend/query-orchestrator'; +import type { SubscriptionServer } from '@cubejs-backend/api-gateway'; + import { RefreshScheduler, ScheduledRefreshOptions } from './RefreshScheduler'; import { OrchestratorApi, OrchestratorApiOptions } from './OrchestratorApi'; import { CompilerApi } from './CompilerApi'; @@ -449,7 +451,7 @@ export class CubejsServerCore { } } - public initSubscriptionServer(sendMessage) { + public initSubscriptionServer(sendMessage): SubscriptionServer { const apiGateway = this.apiGateway(); return apiGateway.initSubscriptionServer(sendMessage); } diff --git a/packages/cubejs-server/src/websocket-server.ts b/packages/cubejs-server/src/websocket-server.ts index 5c3cf364fca1f..3d42bf5b02128 100644 --- a/packages/cubejs-server/src/websocket-server.ts +++ b/packages/cubejs-server/src/websocket-server.ts @@ -2,9 +2,11 @@ import WebSocket from 'ws'; import crypto from 'crypto'; import util from 'util'; import { CancelableInterval, createCancelableInterval } from '@cubejs-backend/shared'; + import type { CubejsServerCore } from '@cubejs-backend/server-core'; import type http from 'http'; import type https from 'https'; +import type { SubscriptionServer } from '@cubejs-backend/api-gateway'; export interface WebSocketServerOptions { processSubscriptionsInterval?: number, @@ -16,7 +18,7 @@ export class WebSocketServer { protected wsServer: WebSocket.Server | null = null; - protected subscriptionServer: any = null; + protected subscriptionServer: SubscriptionServer | null = null; public constructor( protected readonly serverCore: CubejsServerCore, @@ -63,15 +65,15 @@ export class WebSocketServer { connectionIdToSocket[connectionId] = ws; ws.on('message', async (message) => { - await this.subscriptionServer.processMessage(connectionId, message, true); + await this.subscriptionServer!.processMessage(connectionId, message as string); }); ws.on('close', async () => { - await this.subscriptionServer.disconnect(connectionId); + await this.subscriptionServer!.disconnect(connectionId); }); ws.on('error', async () => { - await this.subscriptionServer.disconnect(connectionId); + await this.subscriptionServer!.disconnect(connectionId); }); }); @@ -79,7 +81,7 @@ export class WebSocketServer { this.subscriptionsTimer = createCancelableInterval( async () => { - await this.subscriptionServer.processSubscriptions(); + await this.subscriptionServer!.processSubscriptions(); }, { interval: processSubscriptionsInterval, @@ -100,6 +102,8 @@ export class WebSocketServer { await close(); } - this.subscriptionServer.clear(); + if (this.subscriptionServer) { + this.subscriptionServer.clear(); + } } } diff --git a/yarn.lock b/yarn.lock index 54cbc073f5d76..32196ded238ce 100644 --- a/yarn.lock +++ b/yarn.lock @@ -26969,3 +26969,8 @@ zip-stream@^6.0.1: archiver-utils "^5.0.0" compress-commons "^6.0.2" readable-stream "^4.0.0" + +zod@^4.1.13: + version "4.1.13" + resolved "https://registry.yarnpkg.com/zod/-/zod-4.1.13.tgz#93699a8afe937ba96badbb0ce8be6033c0a4b6b1" + integrity sha512-AvvthqfqrAhNH9dnfmrfKzX5upOdjUVJYFqNSlkmGf64gRaTzlPwz99IHYnVs28qYAybvAlBV+H7pn0saFY4Ig==