Skip to content

Commit b5dc560

Browse files
authored
fix(kb): fixed kb race condition resulting in no chunks found (#487)
1 parent 2f78c5e commit b5dc560

File tree

2 files changed

+252
-22
lines changed

2 files changed

+252
-22
lines changed
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
import { beforeEach, expect, test, vi } from 'vitest'
2+
3+
vi.mock('drizzle-orm', () => ({
4+
and: (...args: any[]) => args,
5+
eq: (...args: any[]) => args,
6+
isNull: () => true,
7+
sql: (strings: TemplateStringsArray, ...expr: any[]) => ({ strings, expr }),
8+
}))
9+
10+
const dbOps: {
11+
order: string[]
12+
insertRecords: any[][]
13+
updatePayloads: any[]
14+
} = {
15+
order: [],
16+
insertRecords: [],
17+
updatePayloads: [],
18+
}
19+
20+
var mockDocumentTable: any = {}
21+
var mockKbTable: any = {}
22+
var mockEmbeddingTable: any = {}
23+
24+
let kbRows: any[] = []
25+
let docRows: any[] = []
26+
let chunkRows: any[] = []
27+
28+
function resetDatasets() {
29+
kbRows = []
30+
docRows = []
31+
chunkRows = []
32+
}
33+
34+
vi.mock('@/db', () => {
35+
const selectBuilder = {
36+
from(table: any) {
37+
return {
38+
where() {
39+
return {
40+
limit(n: number) {
41+
const tableSymbols = Object.getOwnPropertySymbols(table || {})
42+
const baseNameSymbol = tableSymbols.find((s) => s.toString().includes('BaseName'))
43+
const tableName = baseNameSymbol ? table[baseNameSymbol] : ''
44+
45+
if (tableName === 'knowledge_base') {
46+
return Promise.resolve(kbRows.slice(0, n))
47+
}
48+
if (tableName === 'document') {
49+
return Promise.resolve(docRows.slice(0, n))
50+
}
51+
if (tableName === 'embedding') {
52+
return Promise.resolve(chunkRows.slice(0, n))
53+
}
54+
55+
return Promise.resolve([])
56+
},
57+
}
58+
},
59+
}
60+
},
61+
}
62+
63+
return {
64+
db: {
65+
select: vi.fn(() => selectBuilder),
66+
update: () => ({
67+
set: () => ({
68+
where: () => Promise.resolve(),
69+
}),
70+
}),
71+
transaction: vi.fn(async (fn: any) => {
72+
await fn({
73+
insert: (table: any) => ({
74+
values: (records: any) => {
75+
dbOps.order.push('insert')
76+
dbOps.insertRecords.push(records)
77+
return Promise.resolve()
78+
},
79+
}),
80+
update: () => ({
81+
set: (payload: any) => ({
82+
where: () => {
83+
dbOps.updatePayloads.push(payload)
84+
const label = dbOps.updatePayloads.length === 1 ? 'updateDoc' : 'updateKb'
85+
dbOps.order.push(label)
86+
return Promise.resolve()
87+
},
88+
}),
89+
}),
90+
})
91+
}),
92+
},
93+
document: mockDocumentTable,
94+
knowledgeBase: mockKbTable,
95+
embedding: mockEmbeddingTable,
96+
}
97+
})
98+
99+
vi.mock('@/lib/env', () => ({ env: { OPENAI_API_KEY: 'test-key' } }))
100+
101+
vi.mock('@/lib/documents/utils', () => ({
102+
retryWithExponentialBackoff: (fn: any) => fn(),
103+
}))
104+
105+
vi.mock('@/lib/documents/document-processor', () => ({
106+
processDocuments: vi.fn().mockResolvedValue([
107+
{
108+
chunks: [
109+
{ text: 'alpha', startIndex: 0, endIndex: 4 },
110+
{ text: 'beta', startIndex: 5, endIndex: 8 },
111+
],
112+
metadata: {
113+
filename: 'dummy',
114+
fileSize: 10,
115+
mimeType: 'text/plain',
116+
characterCount: 9,
117+
tokenCount: 3,
118+
chunkCount: 2,
119+
processingMethod: 'file-parser',
120+
},
121+
},
122+
]),
123+
}))
124+
125+
vi.stubGlobal(
126+
'fetch',
127+
vi.fn().mockResolvedValue({
128+
ok: true,
129+
json: async () => ({
130+
data: [
131+
{ embedding: [0.1, 0.2], index: 0 },
132+
{ embedding: [0.3, 0.4], index: 1 },
133+
],
134+
}),
135+
})
136+
)
137+
138+
import { processDocumentAsync } from './utils'
139+
140+
beforeEach(() => {
141+
dbOps.order.length = 0
142+
dbOps.insertRecords.length = 0
143+
dbOps.updatePayloads.length = 0
144+
})
145+
146+
test('processDocumentAsync inserts embeddings before updating document counters', async () => {
147+
await processDocumentAsync(
148+
'kb1',
149+
'doc1',
150+
{
151+
filename: 'file.txt',
152+
fileUrl: 'https://example.com/file.txt',
153+
fileSize: 10,
154+
mimeType: 'text/plain',
155+
},
156+
{}
157+
)
158+
159+
expect(dbOps.order).toEqual(['insert', 'updateDoc', 'updateKb'])
160+
161+
expect(dbOps.updatePayloads[0]).toMatchObject({
162+
processingStatus: 'completed',
163+
chunkCount: 2,
164+
})
165+
166+
expect(dbOps.insertRecords[0].length).toBe(2)
167+
})
168+
169+
import {
170+
checkChunkAccess,
171+
checkDocumentAccess,
172+
checkKnowledgeBaseAccess,
173+
generateEmbeddings,
174+
} from './utils'
175+
176+
beforeEach(() => {
177+
dbOps.order.length = 0
178+
dbOps.insertRecords.length = 0
179+
dbOps.updatePayloads.length = 0
180+
resetDatasets()
181+
})
182+
183+
test('checkKnowledgeBaseAccess returns success for owner', async () => {
184+
kbRows.push({ id: 'kb1', userId: 'user1' })
185+
const res = await checkKnowledgeBaseAccess('kb1', 'user1')
186+
expect(res.hasAccess).toBe(true)
187+
})
188+
189+
test('checkKnowledgeBaseAccess returns notFound when kb missing', async () => {
190+
const res = await checkKnowledgeBaseAccess('missing', 'user1')
191+
expect(res.hasAccess).toBe(false)
192+
expect('notFound' in res && res.notFound).toBe(true)
193+
})
194+
195+
test('checkDocumentAccess unauthorized when user mismatch', async () => {
196+
kbRows.push({ id: 'kb1', userId: 'owner' })
197+
const res = await checkDocumentAccess('kb1', 'doc1', 'intruder')
198+
expect(res.hasAccess).toBe(false)
199+
if ('reason' in res) {
200+
expect(res.reason).toBe('Unauthorized knowledge base access')
201+
}
202+
})
203+
204+
test('checkChunkAccess fails when document not completed', async () => {
205+
kbRows.push({ id: 'kb1', userId: 'user1' })
206+
docRows.push({ id: 'doc1', knowledgeBaseId: 'kb1', processingStatus: 'processing' })
207+
const res = await checkChunkAccess('kb1', 'doc1', 'chunk1', 'user1')
208+
expect(res.hasAccess).toBe(false)
209+
if ('reason' in res) {
210+
expect(res.reason).toContain('Document is not ready')
211+
}
212+
})
213+
214+
test('checkChunkAccess success path', async () => {
215+
kbRows.push({ id: 'kb1', userId: 'user1' })
216+
docRows.push({ id: 'doc1', knowledgeBaseId: 'kb1', processingStatus: 'completed' })
217+
chunkRows.push({ id: 'chunk1', documentId: 'doc1' })
218+
const res = await checkChunkAccess('kb1', 'doc1', 'chunk1', 'user1')
219+
expect(res.hasAccess).toBe(true)
220+
if ('chunk' in res) {
221+
expect(res.chunk.id).toBe('chunk1')
222+
}
223+
})
224+
225+
test('generateEmbeddings returns same length as input', async () => {
226+
const result = await generateEmbeddings(['a', 'b'])
227+
expect(result.length).toBe(2)
228+
})

apps/sim/app/api/knowledge/utils.ts

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -414,18 +414,6 @@ export async function processDocumentAsync(
414414

415415
logger.info(`[${documentId}] Embeddings generated, updating document record`)
416416

417-
await db
418-
.update(document)
419-
.set({
420-
chunkCount: processed.metadata.chunkCount,
421-
tokenCount: processed.metadata.tokenCount,
422-
characterCount: processed.metadata.characterCount,
423-
processingStatus: 'completed',
424-
processingCompletedAt: now,
425-
processingError: null,
426-
})
427-
.where(eq(document.id, documentId))
428-
429417
const embeddingRecords = processed.chunks.map((chunk, chunkIndex) => ({
430418
id: crypto.randomUUID(),
431419
knowledgeBaseId,
@@ -449,17 +437,31 @@ export async function processDocumentAsync(
449437
updatedAt: now,
450438
}))
451439

452-
if (embeddingRecords.length > 0) {
453-
await db.insert(embedding).values(embeddingRecords)
454-
}
440+
await db.transaction(async (tx) => {
441+
if (embeddingRecords.length > 0) {
442+
await tx.insert(embedding).values(embeddingRecords)
443+
}
455444

456-
await db
457-
.update(knowledgeBase)
458-
.set({
459-
tokenCount: sql`${knowledgeBase.tokenCount} + ${processed.metadata.tokenCount}`,
460-
updatedAt: now,
461-
})
462-
.where(eq(knowledgeBase.id, knowledgeBaseId))
445+
await tx
446+
.update(document)
447+
.set({
448+
chunkCount: processed.metadata.chunkCount,
449+
tokenCount: processed.metadata.tokenCount,
450+
characterCount: processed.metadata.characterCount,
451+
processingStatus: 'completed',
452+
processingCompletedAt: now,
453+
processingError: null,
454+
})
455+
.where(eq(document.id, documentId))
456+
457+
await tx
458+
.update(knowledgeBase)
459+
.set({
460+
tokenCount: sql`${knowledgeBase.tokenCount} + ${processed.metadata.tokenCount}`,
461+
updatedAt: now,
462+
})
463+
.where(eq(knowledgeBase.id, knowledgeBaseId))
464+
})
463465

464466
const processingTime = Date.now() - startTime
465467
logger.info(

0 commit comments

Comments
 (0)