Skip to content

Commit beb2a47

Browse files
committed
feat: initial subagent aggregation + tests
1 parent 79f2b70 commit beb2a47

File tree

15 files changed

+1150
-73
lines changed

15 files changed

+1150
-73
lines changed

backend/src/__tests__/cost-aggregation-integration.test.ts

Lines changed: 521 additions & 0 deletions
Large diffs are not rendered by default.

backend/src/__tests__/cost-aggregation.test.ts

Lines changed: 414 additions & 0 deletions
Large diffs are not rendered by default.

backend/src/__tests__/sandbox-generator.test.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ describe('QuickJS Sandbox Generator', () => {
2727
agentContext: {},
2828
subagents: [],
2929
stepsRemaining: 10,
30+
creditsUsed: 0,
3031
}
3132

3233
// Base template structure - will be customized per test

backend/src/llm-apis/message-cost-tracker.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ export const saveMessage = async (value: {
543543
usesUserApiKey?: boolean
544544
chargeUser?: boolean
545545
costOverrideDollars?: number
546-
}) =>
546+
}): Promise<number> =>
547547
withLoggerContext(
548548
{
549549
messageId: value.messageId,
@@ -592,7 +592,7 @@ export const saveMessage = async (value: {
592592
},
593593
`Credits used by test user (${creditsUsed})`,
594594
)
595-
return
595+
return creditsUsed
596596
}
597597

598598
if (VERBOSE) {
@@ -625,7 +625,7 @@ export const saveMessage = async (value: {
625625
{ messageId: value.messageId, userId: value.userId },
626626
'Skipping further processing (no user ID or failed to save message).',
627627
)
628-
return null
628+
return 0
629629
}
630630

631631
const consumptionResult = await updateUserCycleUsage(
@@ -656,6 +656,6 @@ export const saveMessage = async (value: {
656656
)
657657
}
658658

659-
return savedMessageResult
659+
return creditsUsed
660660
},
661661
)

backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ export const promptAiSdkStream = async function* (
7373
thinkingBudget?: number
7474
userInputId: string
7575
maxRetries?: number
76+
onCostCalculated?: (credits: number) => Promise<void>
7677
} & Omit<Parameters<typeof streamText>[0], 'model'>,
7778
) {
7879
if (
@@ -192,7 +193,7 @@ export const promptAiSdkStream = async function* (
192193
}
193194
}
194195

195-
saveMessage({
196+
const creditsUsedPromise = saveMessage({
196197
messageId,
197198
userId: options.userId,
198199
clientSessionId: options.clientSessionId,
@@ -210,6 +211,12 @@ export const promptAiSdkStream = async function* (
210211
chargeUser: options.chargeUser ?? true,
211212
costOverrideDollars,
212213
})
214+
215+
// Call the cost callback if provided
216+
if (options.onCostCalculated) {
217+
const creditsUsed = await creditsUsedPromise
218+
await options.onCostCalculated(creditsUsed)
219+
}
213220
}
214221

215222
// TODO: figure out a nice way to unify stream & non-stream versions maybe?
@@ -222,6 +229,7 @@ export const promptAiSdk = async function (
222229
model: Model
223230
userId: string | undefined
224231
chargeUser?: boolean
232+
onCostCalculated?: (credits: number) => Promise<void>
225233
} & Omit<Parameters<typeof generateText>[0], 'model'>,
226234
): Promise<string> {
227235
if (
@@ -250,12 +258,11 @@ export const promptAiSdk = async function (
250258
model: aiSDKModel,
251259
messages: convertCbToModelMessages(options),
252260
})
253-
254261
const content = response.text
255262
const inputTokens = response.usage.inputTokens || 0
256263
const outputTokens = response.usage.inputTokens || 0
257264

258-
saveMessage({
265+
const creditsUsedPromise = saveMessage({
259266
messageId: generateCompactId(),
260267
userId: options.userId,
261268
clientSessionId: options.clientSessionId,
@@ -271,6 +278,12 @@ export const promptAiSdk = async function (
271278
chargeUser: options.chargeUser ?? true,
272279
})
273280

281+
// Call the cost callback if provided
282+
if (options.onCostCalculated) {
283+
const creditsUsed = await creditsUsedPromise
284+
await options.onCostCalculated(creditsUsed)
285+
}
286+
274287
return content
275288
}
276289

@@ -287,6 +300,7 @@ export const promptAiSdkStructured = async function <T>(options: {
287300
temperature?: number
288301
timeout?: number
289302
chargeUser?: boolean
303+
onCostCalculated?: (credits: number) => Promise<void>
290304
}): Promise<T> {
291305
if (
292306
!checkLiveUserInput(
@@ -318,12 +332,11 @@ export const promptAiSdkStructured = async function <T>(options: {
318332
const response = await (options.timeout === undefined
319333
? responsePromise
320334
: withTimeout(responsePromise, options.timeout))
321-
322335
const content = response.object
323336
const inputTokens = response.usage.inputTokens || 0
324337
const outputTokens = response.usage.inputTokens || 0
325338

326-
saveMessage({
339+
const creditsUsedPromise = saveMessage({
327340
messageId: generateCompactId(),
328341
userId: options.userId,
329342
clientSessionId: options.clientSessionId,
@@ -339,6 +352,12 @@ export const promptAiSdkStructured = async function <T>(options: {
339352
chargeUser: options.chargeUser ?? true,
340353
})
341354

355+
// Call the cost callback if provided
356+
if (options.onCostCalculated) {
357+
const creditsUsed = await creditsUsedPromise
358+
await options.onCostCalculated(creditsUsed)
359+
}
360+
342361
return content
343362
}
344363

backend/src/prompt-agent-stream.ts

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,18 @@ export const getAgentStreamFromTemplate = (params: {
1212
fingerprintId: string
1313
userInputId: string
1414
userId: string | undefined
15+
onCostCalculated?: (credits: number) => Promise<void>
1516

1617
template: AgentTemplate
1718
}) => {
18-
const { clientSessionId, fingerprintId, userInputId, userId, template } =
19-
params
19+
const {
20+
clientSessionId,
21+
fingerprintId,
22+
userInputId,
23+
userId,
24+
onCostCalculated,
25+
template,
26+
} = params
2027

2128
if (!template) {
2229
throw new Error('Agent template is null/undefined')
@@ -34,6 +41,7 @@ export const getAgentStreamFromTemplate = (params: {
3441
userInputId,
3542
userId,
3643
maxOutputTokens: 32_000,
44+
onCostCalculated,
3745
}
3846

3947
// Add Gemini-specific options if needed

backend/src/run-agent-step.ts

Lines changed: 98 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import { renderToolResults } from '@codebuff/common/tools/utils'
1010
import { buildArray } from '@codebuff/common/util/array'
1111
import { generateCompactId } from '@codebuff/common/util/string'
1212

13+
1314
import { asyncAgentManager } from './async-agent-manager'
1415
import { getFileReadingUpdates } from './get-file-reading-updates'
1516
import { checkLiveUserInput } from './live-user-inputs'
@@ -253,6 +254,31 @@ export const runAgentStep = async (
253254
userInputId,
254255
userId,
255256
template: agentTemplate,
257+
onCostCalculated: async (credits: number) => {
258+
try {
259+
agentState.creditsUsed += credits
260+
logger.debug(
261+
{
262+
agentId: agentState.agentId,
263+
credits,
264+
totalCredits: agentState.creditsUsed,
265+
},
266+
'Added LLM cost to agent state',
267+
)
268+
269+
// Transactional cost attribution: ensure costs are actually deducted
270+
// This is already handled by the saveMessage function which calls updateUserCycleUsage
271+
// If that fails, the promise rejection will bubble up and halt agent execution
272+
} catch (error) {
273+
logger.error(
274+
{ agentId: agentState.agentId, credits, error },
275+
'Failed to add cost to agent state',
276+
)
277+
throw new Error(
278+
`Cost tracking failed for agent ${agentState.agentId}: ${error}`,
279+
)
280+
}
281+
},
256282
})
257283

258284
const iterationNum = agentState.messageHistory.length
@@ -519,68 +545,91 @@ export const loopAgentSteps = async (
519545
let currentPrompt = prompt
520546
let currentParams = params
521547

522-
while (checkLiveUserInput(userId, userInputId, clientSessionId)) {
523-
// 1. Run programmatic step first if it exists
524-
if (agentTemplate.handleSteps) {
525-
const { agentState: programmaticAgentState, endTurn } =
526-
await runProgrammaticStep(currentAgentState, {
548+
try {
549+
while (checkLiveUserInput(userId, userInputId, clientSessionId)) {
550+
// 1. Run programmatic step first if it exists
551+
if (agentTemplate.handleSteps) {
552+
const { agentState: programmaticAgentState, endTurn } =
553+
await runProgrammaticStep(currentAgentState, {
554+
userId,
555+
userInputId,
556+
clientSessionId,
557+
fingerprintId,
558+
onResponseChunk,
559+
agentType,
560+
fileContext,
561+
ws,
562+
template: agentTemplate,
563+
localAgentTemplates,
564+
prompt: currentPrompt,
565+
params: currentParams,
566+
stepsComplete: shouldEndTurn,
567+
})
568+
currentAgentState = programmaticAgentState
569+
570+
if (endTurn) {
571+
shouldEndTurn = true
572+
}
573+
}
574+
575+
if (ASYNC_AGENTS_ENABLED) {
576+
const hasMessages =
577+
asyncAgentManager.getMessages(agentState.agentId).length > 0
578+
if (hasMessages) {
579+
shouldEndTurn = false
580+
}
581+
}
582+
583+
// End turn if programmatic step ended turn, or if the previous runAgentStep ended turn
584+
if (shouldEndTurn) {
585+
return {
586+
agentState: currentAgentState,
587+
}
588+
}
589+
590+
const { agentState: newAgentState, shouldEndTurn: llmShouldEndTurn } =
591+
await runAgentStep(ws, {
527592
userId,
528593
userInputId,
529594
clientSessionId,
530595
fingerprintId,
531596
onResponseChunk,
597+
localAgentTemplates,
532598
agentType,
533599
fileContext,
534-
ws,
535-
template: agentTemplate,
536-
localAgentTemplates,
600+
agentState: currentAgentState,
537601
prompt: currentPrompt,
538602
params: currentParams,
539-
stepsComplete: shouldEndTurn,
540603
})
541-
currentAgentState = programmaticAgentState
542-
543-
if (endTurn) {
544-
shouldEndTurn = true
545-
}
546-
}
547604

548-
if (ASYNC_AGENTS_ENABLED) {
549-
const hasMessages =
550-
asyncAgentManager.getMessages(agentState.agentId).length > 0
551-
if (hasMessages) {
552-
shouldEndTurn = false
553-
}
554-
}
605+
currentAgentState = newAgentState
606+
shouldEndTurn = llmShouldEndTurn
555607

556-
// End turn if programmatic step ended turn, or if the previous runAgentStep ended turn
557-
if (shouldEndTurn) {
558-
return {
559-
agentState: currentAgentState,
560-
}
608+
currentPrompt = undefined
609+
currentParams = undefined
561610
}
562611

563-
const { agentState: newAgentState, shouldEndTurn: llmShouldEndTurn } =
564-
await runAgentStep(ws, {
565-
userId,
566-
userInputId,
567-
clientSessionId,
568-
fingerprintId,
569-
onResponseChunk,
570-
localAgentTemplates,
571-
agentType,
572-
fileContext,
573-
agentState: currentAgentState,
574-
prompt: currentPrompt,
575-
params: currentParams,
576-
})
577-
578-
currentAgentState = newAgentState
579-
shouldEndTurn = llmShouldEndTurn
580-
581-
currentPrompt = undefined
582-
currentParams = undefined
612+
return { agentState: currentAgentState }
613+
} catch (error) {
614+
// Log the error but still return the state with partial costs
615+
logger.error(
616+
{
617+
error,
618+
agentId: currentAgentState.agentId,
619+
creditsUsed: currentAgentState.creditsUsed,
620+
},
621+
'Agent execution failed but returning state with partial costs',
622+
)
623+
throw error
624+
} finally {
625+
// Ensure costs are always captured, even on failure
626+
logger.debug(
627+
{
628+
agentId: currentAgentState.agentId,
629+
creditsUsed: currentAgentState.creditsUsed,
630+
status: 'completed_or_failed',
631+
},
632+
'Agent execution completed with cost tracking',
633+
)
583634
}
584-
585-
return { agentState: currentAgentState }
586635
}

backend/src/tools/handlers/tool/spawn-agent-inline.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ export const handleSpawnAgentInline = ((params: {
7777
subagents: [],
7878
messageHistory: getLatestState().messages, // Share the same message array
7979
stepsRemaining: MAX_AGENT_STEPS_DEFAULT,
80+
creditsUsed: 0,
8081
output: undefined,
8182
parentId: agentState.agentId,
8283
}

backend/src/tools/handlers/tool/spawn-agent-utils.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ export function createAgentState(
255255
subagents: [],
256256
messageHistory,
257257
stepsRemaining: MAX_AGENT_STEPS_DEFAULT,
258+
creditsUsed: 0,
258259
output: undefined,
259260
parentId: parentAgentState.agentId,
260261
}

0 commit comments

Comments
 (0)