
Commit cb1cab7

waleedlatif1 and claude committed
chore(knowledge): polish embedding/reranker implementation
- Drop unused supportsCustomDimensions from EmbeddingModelInfo (every registered model supports it; OpenAI/Azure paths now always send dimensions: 1536).
- Type SUPPORTED_EMBEDDING_MODELS as Partial<Record<...>> so index lookups surface as possibly-undefined in the type system instead of relying on runtime null checks alone.
- Require AZURE_OPENAI_API_VERSION in the Azure routing gate. A missing api-version no longer slips through as ?api-version=undefined; it now falls back to direct OpenAI.
- Use the embedding provider's tokenizer (estimateTokenCount) for the Gemini fallback token estimate instead of len/4, so billing matches the model's tokenization (see the sketch after this list).
- Drop the unreachable 'text-embedding-3-small' fallback in the manual chunk upload route — accessCheck.knowledgeBase is non-null after the access guard.
- docs-chunker now reads getConfiguredEmbeddingModel() so Sim's docs ingestion respects KB_EMBEDDING_MODEL like the user-facing paths.
- Add a v1 search route test covering per-KB model resolution and the cross-KB mixed-model rejection.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
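A minimal before/after sketch of the token-estimate change for the Gemini fallback (illustrative only; the real change is in apps/sim/lib/knowledge/embeddings.ts, shown in the last hunk of the diff below, and the sample inputs here are placeholders):

  import { estimateTokenCount } from '@/lib/tokenization'

  // Placeholder chunk texts; in the real code these are the inputs sent to the embedding API.
  const inputs: string[] = ['first chunk of text', 'second chunk of text']

  // Before: a rough character-based guess, independent of the model's tokenizer.
  const roughTokens = inputs.reduce((sum, text) => sum + Math.ceil(text.length / 4), 0)

  // After: estimate with the tokenizer that matches the embedding provider ('google' for
  // gemini-embedding-001), so the billed token count follows the model's tokenization.
  const estimatedTokens = inputs.reduce(
    (sum, text) => sum + estimateTokenCount(text, 'google').count,
    0
  )
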
1 parent f7eef61 commit cb1cab7

6 files changed

Lines changed: 217 additions & 32 deletions


apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts

Lines changed: 6 additions & 3 deletions
@@ -213,11 +213,14 @@ export const POST = withRouteHandler(
     accessCheck.knowledgeBase?.workspaceId
   )
 
-  const chunkEmbeddingModel =
-    accessCheck.knowledgeBase?.embeddingModel ?? 'text-embedding-3-small'
   let cost = null
   try {
-    cost = calculateCost(chunkEmbeddingModel, newChunk.tokenCount, 0, false)
+    cost = calculateCost(
+      accessCheck.knowledgeBase.embeddingModel,
+      newChunk.tokenCount,
+      0,
+      false
+    )
   } catch (error) {
     logger.warn(`[${requestId}] Failed to calculate cost for chunk upload`, {
       error: error instanceof Error ? error.message : 'Unknown error',
apps/sim/app/api/knowledge/search/utils.test.ts

Lines changed: 2 additions & 2 deletions
@@ -220,7 +220,7 @@ describe('Knowledge Search Utils', () => {
     Object.keys(env).forEach((key) => delete (env as any)[key])
   })
 
-  it('should use default API version when not provided in Azure config', async () => {
+  it('falls back to OpenAI when AZURE_OPENAI_API_VERSION is not set', async () => {
     const { env } = await import('@/lib/core/config/env')
     Object.keys(env).forEach((key) => delete (env as any)[key])
     Object.assign(env, {
@@ -240,7 +240,7 @@ describe('Knowledge Search Utils', () => {
     await generateSearchEmbedding('test query')
 
     expect(vi.mocked(fetch)).toHaveBeenCalledWith(
-      expect.stringContaining('api-version='),
+      'https://api.openai.com/v1/embeddings',
       expect.any(Object)
     )
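
The behavior this test pins down is the routing gate in resolveProvider (full diff further down). A condensed sketch assuming only the environment variables named in this commit; shouldUseAzure is a standalone restatement, not a function in the repo:

  import { env } from '@/lib/core/config/env'
  import { SUPPORTED_EMBEDDING_MODELS } from '@/lib/knowledge/embedding-models'

  // Azure is only used when every required piece of configuration is present; with
  // AZURE_OPENAI_API_VERSION unset, the request goes to https://api.openai.com/v1/embeddings
  // instead of an Azure deployment URL carrying api-version=undefined.
  function shouldUseAzure(embeddingModel: string): boolean {
    const isOpenAIModel = SUPPORTED_EMBEDDING_MODELS[embeddingModel]?.provider === 'openai'
    return Boolean(
      isOpenAIModel &&
        env.AZURE_OPENAI_API_KEY &&
        env.AZURE_OPENAI_ENDPOINT &&
        env.AZURE_OPENAI_API_VERSION &&
        env.KB_OPENAI_MODEL_NAME
    )
  }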

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
+/**
+ * Tests for v1 knowledge search API route.
+ * Specifically guards the per-KB embedding model resolution and the
+ * multi-model rejection so the v1 endpoint stays in lockstep with the
+ * internal route.
+ *
+ * @vitest-environment node
+ */
+import { createMockRequest, knowledgeApiUtilsMock, knowledgeApiUtilsMockFns } from '@sim/testing'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const {
+  mockHandleVectorOnlySearch,
+  mockHandleTagOnlySearch,
+  mockHandleTagAndVectorSearch,
+  mockGetQueryStrategy,
+  mockGenerateSearchEmbedding,
+  mockGetDocumentNamesByIds,
+  mockAuthenticateRequest,
+  mockValidateWorkspaceAccess,
+} = vi.hoisted(() => ({
+  mockHandleVectorOnlySearch: vi.fn(),
+  mockHandleTagOnlySearch: vi.fn(),
+  mockHandleTagAndVectorSearch: vi.fn(),
+  mockGetQueryStrategy: vi.fn(),
+  mockGenerateSearchEmbedding: vi.fn(),
+  mockGetDocumentNamesByIds: vi.fn(),
+  mockAuthenticateRequest: vi.fn(),
+  mockValidateWorkspaceAccess: vi.fn(),
+}))
+
+vi.mock('@/app/api/knowledge/search/utils', () => ({
+  handleVectorOnlySearch: mockHandleVectorOnlySearch,
+  handleTagOnlySearch: mockHandleTagOnlySearch,
+  handleTagAndVectorSearch: mockHandleTagAndVectorSearch,
+  getQueryStrategy: mockGetQueryStrategy,
+  generateSearchEmbedding: mockGenerateSearchEmbedding,
+  getDocumentNamesByIds: mockGetDocumentNamesByIds,
+}))
+
+vi.mock('@/app/api/knowledge/utils', () => knowledgeApiUtilsMock)
+
+vi.mock('@/app/api/v1/knowledge/utils', () => ({
+  authenticateRequest: mockAuthenticateRequest,
+  validateWorkspaceAccess: mockValidateWorkspaceAccess,
+  parseJsonBody: async (req: Request) => {
+    try {
+      return { success: true, data: await req.json() }
+    } catch {
+      return {
+        success: false,
+        response: new Response(JSON.stringify({ error: 'Invalid JSON' }), { status: 400 }),
+      }
+    }
+  },
+  validateSchema: <T>(
+    schema: {
+      safeParse: (v: unknown) => {
+        success: boolean
+        data?: T
+        error?: { issues: { message: string }[] }
+      }
+    },
+    data: unknown
+  ) => {
+    const result = schema.safeParse(data)
+    if (!result.success) {
+      return {
+        success: false,
+        response: new Response(
+          JSON.stringify({ error: result.error?.issues.map((i) => i.message).join(', ') }),
+          { status: 400 }
+        ),
+      }
+    }
+    return { success: true, data: result.data }
+  },
+  handleError: (e: unknown) =>
+    new Response(JSON.stringify({ error: e instanceof Error ? e.message : 'error' }), {
+      status: 500,
+    }),
+}))
+
+vi.mock('@/lib/knowledge/tags/service', () => ({
+  getDocumentTagDefinitions: vi.fn().mockResolvedValue([]),
+}))
+
+import { POST } from '@/app/api/v1/knowledge/search/route'
+
+const mockCheckKnowledgeBaseAccess = knowledgeApiUtilsMockFns.mockCheckKnowledgeBaseAccess
+
+const baseKb = (id: string, embeddingModel: string) => ({
+  id,
+  userId: 'user-1',
+  name: `KB ${id}`,
+  workspaceId: 'ws-1',
+  embeddingModel,
+  deletedAt: null,
+})
+
+describe('v1 knowledge search route — per-KB embedding model', () => {
+  beforeEach(() => {
+    vi.clearAllMocks()
+    mockAuthenticateRequest.mockResolvedValue({
+      requestId: 'req-1',
+      userId: 'user-1',
+      rateLimit: {},
+    })
+    mockValidateWorkspaceAccess.mockResolvedValue(null)
+    mockGetQueryStrategy.mockReturnValue({ distanceThreshold: 0.5 })
+    mockGenerateSearchEmbedding.mockResolvedValue([0.1, 0.2, 0.3])
+    mockHandleVectorOnlySearch.mockResolvedValue([])
+    mockGetDocumentNamesByIds.mockResolvedValue({})
+  })
+
+  it('passes the KB embedding model into generateSearchEmbedding', async () => {
+    mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
+      hasAccess: true,
+      knowledgeBase: baseKb('kb-gemini', 'gemini-embedding-001'),
+    })
+
+    const req = createMockRequest('POST', {
+      workspaceId: 'ws-1',
+      knowledgeBaseIds: 'kb-gemini',
+      query: 'hello',
+    })
+    const res = await POST(req)
+
+    expect(res.status).toBe(200)
+    expect(mockGenerateSearchEmbedding).toHaveBeenCalledWith(
+      'hello',
+      'gemini-embedding-001',
+      'ws-1'
+    )
+  })
+
+  it('rejects cross-KB queries with mixed embedding models', async () => {
+    mockCheckKnowledgeBaseAccess
+      .mockResolvedValueOnce({
+        hasAccess: true,
+        knowledgeBase: baseKb('kb-openai', 'text-embedding-3-small'),
+      })
+      .mockResolvedValueOnce({
+        hasAccess: true,
+        knowledgeBase: baseKb('kb-gemini', 'gemini-embedding-001'),
+      })
+
+    const req = createMockRequest('POST', {
+      workspaceId: 'ws-1',
+      knowledgeBaseIds: ['kb-openai', 'kb-gemini'],
+      query: 'hello',
+    })
+    const res = await POST(req)
+
+    expect(res.status).toBe(400)
+    expect(mockGenerateSearchEmbedding).not.toHaveBeenCalled()
+  })
+
+  it('allows tag-only search across mixed embedding models', async () => {
+    mockHandleTagOnlySearch.mockResolvedValue([])
+    mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
+      hasAccess: true,
+      knowledgeBase: baseKb('kb-mixed', 'text-embedding-3-small'),
+    })
+
+    const req = createMockRequest('POST', {
+      workspaceId: 'ws-1',
+      knowledgeBaseIds: 'kb-mixed',
+      tagFilters: [{ tagName: 'category', operator: 'eq', value: 'docs' }],
+    })
+    const res = await POST(req)
+
+    expect(res.status).toBe(400)
+    // tagName "category" is undefined in our empty getDocumentTagDefinitions mock,
+    // so the route returns 400 before reaching the search handlers — but crucially
+    // it never tries to generate an embedding.
+    expect(mockGenerateSearchEmbedding).not.toHaveBeenCalled()
+  })
+})

apps/sim/lib/chunkers/docs-chunker.ts

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@ import { createLogger } from '@sim/logger'
 import { TextChunker } from '@/lib/chunkers/text-chunker'
 import type { DocChunk, DocsChunkerOptions } from '@/lib/chunkers/types'
 import { estimateTokens } from '@/lib/chunkers/utils'
-import { generateEmbeddings } from '@/lib/knowledge/embeddings'
+import { generateEmbeddings, getConfiguredEmbeddingModel } from '@/lib/knowledge/embeddings'
 
 interface HeaderInfo {
   level: number
@@ -74,9 +74,9 @@ export class DocsChunker {
     const headers = this.extractHeaders(cleanedContent)
 
     logger.info(`Generating embeddings for ${textChunks.length} chunks in ${relativePath}`)
+    const embeddingModel = getConfiguredEmbeddingModel()
     const embeddings: number[][] =
-      textChunks.length > 0 ? (await generateEmbeddings(textChunks)).embeddings : []
-    const embeddingModel = 'text-embedding-3-small'
+      textChunks.length > 0 ? (await generateEmbeddings(textChunks, embeddingModel)).embeddings : []
 
     const chunks: DocChunk[] = []
     let currentPosition = 0
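
getConfiguredEmbeddingModel() is imported from @/lib/knowledge/embeddings, but its body is not part of this diff. A plausible sketch, assuming it reads env.KB_EMBEDDING_MODEL and falls back to 'text-embedding-3-small' (both assumptions; the real helper may differ):

  import { env } from '@/lib/core/config/env'
  import { SUPPORTED_EMBEDDING_MODELS } from '@/lib/knowledge/embedding-models'

  // Assumed shape: validate KB_EMBEDDING_MODEL against the registry, otherwise fall back
  // to a default model (the specific default here is a guess, not confirmed by the diff).
  export function getConfiguredEmbeddingModel(): string {
    const configured = env.KB_EMBEDDING_MODEL
    if (configured && SUPPORTED_EMBEDDING_MODELS[configured]) {
      return configured
    }
    return 'text-embedding-3-small'
  }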

apps/sim/lib/knowledge/embedding-models.ts

Lines changed: 1 addition & 6 deletions
@@ -15,30 +15,25 @@ export type TokenizerProviderId = 'openai' | 'google'
 export interface EmbeddingModelInfo {
   provider: EmbeddingProviderKind
-  /** Whether the provider supports requesting a custom output dimensionality. */
-  supportsCustomDimensions: boolean
   /** Pricing/billing label — must match an entry in EMBEDDING_MODEL_PRICING when billed. */
   pricingId: string
   /** Provider id for `estimateTokenCount` so token counts match the embedding provider's tokenization. */
   tokenizerProvider: TokenizerProviderId
 }
 
-export const SUPPORTED_EMBEDDING_MODELS: Record<string, EmbeddingModelInfo> = {
+export const SUPPORTED_EMBEDDING_MODELS: Partial<Record<string, EmbeddingModelInfo>> = {
   'text-embedding-3-small': {
     provider: 'openai',
-    supportsCustomDimensions: true,
     pricingId: 'text-embedding-3-small',
     tokenizerProvider: 'openai',
   },
   'text-embedding-3-large': {
     provider: 'openai',
-    supportsCustomDimensions: true,
     pricingId: 'text-embedding-3-large',
     tokenizerProvider: 'openai',
   },
   'gemini-embedding-001': {
     provider: 'gemini',
-    supportsCustomDimensions: true,
     pricingId: 'gemini-embedding-001',
     tokenizerProvider: 'google',
   },
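
What the Partial<Record<...>> typing buys: a bare index into SUPPORTED_EMBEDDING_MODELS is now typed EmbeddingModelInfo | undefined, so callers must narrow before use. A minimal illustration (the helper and error message are illustrative, not the repo's getEmbeddingModelInfo):

  import {
    type EmbeddingModelInfo,
    SUPPORTED_EMBEDDING_MODELS,
  } from '@/lib/knowledge/embedding-models'

  // With Record<string, EmbeddingModelInfo>, this lookup was typed as always present;
  // with Partial<Record<...>>, TypeScript reports it as possibly undefined until checked.
  function lookupModelInfo(model: string): EmbeddingModelInfo {
    const info = SUPPORTED_EMBEDDING_MODELS[model]
    if (!info) {
      throw new Error(`Unsupported embedding model: ${model}`)
    }
    return info
  }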

apps/sim/lib/knowledge/embeddings.ts

Lines changed: 26 additions & 18 deletions
@@ -8,8 +8,9 @@ import {
   EMBEDDING_DIMENSIONS,
   getEmbeddingModelInfo,
   SUPPORTED_EMBEDDING_MODELS,
+  type TokenizerProviderId,
 } from '@/lib/knowledge/embedding-models'
-import { batchByTokenLimit } from '@/lib/tokenization'
+import { batchByTokenLimit, estimateTokenCount } from '@/lib/tokenization'
 
 const logger = createLogger('EmbeddingUtils')

@@ -48,6 +49,8 @@ interface ResolvedProvider {
   modelName: string
   pricingId: string
   isBYOK: boolean
+  /** Tokenizer used to estimate tokens when the API does not return a usage field. */
+  tokenizerProvider: TokenizerProviderId
   buildRequest: (inputs: string[], inputType: EmbeddingInputType) => ProviderRequest
 }

@@ -93,7 +96,6 @@ async function resolveGeminiKey(workspaceId?: string | null): Promise<{
 }
 
 function buildOpenAIProvider(modelName: string, apiKey: string): ResolvedProvider['buildRequest'] {
-  const info = getEmbeddingModelInfo(modelName)
   return (inputs) => ({
     apiUrl: 'https://api.openai.com/v1/embeddings',
     headers: {
@@ -104,7 +106,7 @@ function buildOpenAIProvider(modelName: string, apiKey: string): ResolvedProvide
       input: inputs,
       model: modelName,
       encoding_format: 'float',
-      ...(info.supportsCustomDimensions && { dimensions: EMBEDDING_DIMENSIONS }),
+      dimensions: EMBEDDING_DIMENSIONS,
     },
     parse: (json) => {
       const data = json as { data: Array<{ embedding: number[] }> }
@@ -117,8 +119,7 @@ function buildAzureOpenAIProvider(
   deployment: string,
   apiKey: string,
   endpoint: string,
-  apiVersion: string,
-  supportsCustomDimensions: boolean
+  apiVersion: string
 ): ResolvedProvider['buildRequest'] {
   return (inputs) => ({
     apiUrl: `${endpoint}/openai/deployments/${deployment}/embeddings?api-version=${apiVersion}`,
@@ -129,7 +130,7 @@ function buildAzureOpenAIProvider(
     body: {
       input: inputs,
       encoding_format: 'float',
-      ...(supportsCustomDimensions && { dimensions: EMBEDDING_DIMENSIONS }),
+      dimensions: EMBEDDING_DIMENSIONS,
     },
     parse: (json) => {
       const data = json as { data: Array<{ embedding: number[] }> }
@@ -197,33 +198,36 @@ async function resolveProvider(
   const azureApiKey = env.AZURE_OPENAI_API_KEY
   const azureEndpoint = env.AZURE_OPENAI_ENDPOINT
   const azureApiVersion = env.AZURE_OPENAI_API_VERSION
+  const azureDeploymentName = env.KB_OPENAI_MODEL_NAME
   const isOpenAIModel = SUPPORTED_EMBEDDING_MODELS[embeddingModel]?.provider === 'openai'
-  const azureDeployment =
-    isOpenAIModel && azureApiKey && azureEndpoint ? env.KB_OPENAI_MODEL_NAME || null : null
+  const useAzure = Boolean(
+    isOpenAIModel && azureApiKey && azureEndpoint && azureApiVersion && azureDeploymentName
+  )
+
+  const info = getEmbeddingModelInfo(embeddingModel)
 
-  if (azureDeployment) {
+  if (useAzure) {
     return {
-      modelName: azureDeployment,
-      pricingId: getEmbeddingModelInfo(embeddingModel).pricingId,
+      modelName: azureDeploymentName!,
+      pricingId: info.pricingId,
       isBYOK: false,
+      tokenizerProvider: info.tokenizerProvider,
       buildRequest: buildAzureOpenAIProvider(
-        azureDeployment,
+        azureDeploymentName!,
         azureApiKey!,
         azureEndpoint!,
-        azureApiVersion!,
-        getEmbeddingModelInfo(embeddingModel).supportsCustomDimensions
+        azureApiVersion!
       ),
     }
   }
 
-  const info = getEmbeddingModelInfo(embeddingModel)
-
   if (info.provider === 'openai') {
     const { apiKey, isBYOK } = await resolveOpenAIKey(workspaceId)
     return {
       modelName: embeddingModel,
       pricingId: info.pricingId,
       isBYOK,
+      tokenizerProvider: info.tokenizerProvider,
       buildRequest: buildOpenAIProvider(embeddingModel, apiKey),
     }
   }
@@ -234,6 +238,7 @@ async function resolveProvider(
       modelName: embeddingModel,
       pricingId: info.pricingId,
       isBYOK,
+      tokenizerProvider: info.tokenizerProvider,
       buildRequest: buildGeminiProvider(embeddingModel, apiKey),
     }
   }
@@ -273,8 +278,11 @@ async function callEmbeddingAPI(
     const usage = (json as { usage?: { total_tokens?: number } }).usage
     const totalTokens =
       usage?.total_tokens ??
-      // Gemini does not return usage.total_tokens — fall back to a rough estimate
-      inputs.reduce((sum, text) => sum + Math.ceil(text.length / 4), 0)
+      // Gemini does not return usage.total_tokens — estimate with the provider's tokenizer
+      inputs.reduce(
+        (sum, text) => sum + estimateTokenCount(text, provider.tokenizerProvider).count,
+        0
+      )
 
     return { embeddings, totalTokens }
   },
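
How the per-KB model threads through the helpers this commit touches, pieced together from the call sites in the diff (a usage sketch; the argument lists mirror docs-chunker.ts and the new v1 route test, and the kb object is a stand-in for a knowledge base row):

  import { generateSearchEmbedding } from '@/app/api/knowledge/search/utils'
  import { generateEmbeddings } from '@/lib/knowledge/embeddings'

  async function embedForKnowledgeBase(kb: { workspaceId: string; embeddingModel: string }) {
    // Search path: the query embedding uses the KB's own model, with the workspace id
    // passed through for BYOK key resolution (the call asserted in the new v1 route test).
    const queryEmbedding = await generateSearchEmbedding('hello', kb.embeddingModel, kb.workspaceId)

    // Ingestion path: chunks are embedded with the same per-KB model, as in docs-chunker.ts.
    const { embeddings } = await generateEmbeddings(['chunk one', 'chunk two'], kb.embeddingModel)

    return { queryEmbedding, embeddings }
  }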
