diff --git a/fixtures/anthropic/single_builtin_tool.txtar b/fixtures/anthropic/single_builtin_tool.txtar index 5df793b1..50ca93f1 100644 --- a/fixtures/anthropic/single_builtin_tool.txtar +++ b/fixtures/anthropic/single_builtin_tool.txtar @@ -4,6 +4,22 @@ Claude Code has builtin tools to (e.g.) explore the filesystem. { "model": "claude-sonnet-4-20250514", "max_tokens": 1024, + "tools": [ + { + "name": "Read", + "description": "Read the contents of a file at the given path.", + "input_schema": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "The absolute path to the file to read" + } + }, + "required": ["file_path"] + } + } + ], "messages": [ { "role": "user", diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index 7a755e06..b85bc7b4 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -103,10 +103,13 @@ func (i *interceptionBase) newErrorResponse(err error) map[string]any { } func (i *interceptionBase) injectTools() { - if i.req == nil || i.mcpProxy == nil { + if i.req == nil || i.mcpProxy == nil || !i.hasInjectableTools() { return } + // Disable parallel tool calls when injectable tools are present to simplify the inner agentic loop. + i.req.ParallelToolCalls = openai.Bool(false) + // Inject tools. for _, tool := range i.mcpProxy.ListTools() { fn := openai.ChatCompletionToolUnionParam{ @@ -171,6 +174,10 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *err } } +func (i *interceptionBase) hasInjectableTools() bool { + return i.mcpProxy != nil && len(i.mcpProxy.ListTools()) > 0 +} + func sumUsage(ref, in openai.CompletionUsage) openai.CompletionUsage { return openai.CompletionUsage{ CompletionTokens: ref.CompletionTokens + in.CompletionTokens, diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index ff3b78c6..550bb448 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -97,12 +97,6 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. }() - // TODO: implement parallel tool calls. - // TODO: don't send if not supported by model (i.e. o4-mini). - if len(i.req.Tools) > 0 { // If no tools are specified but this setting is set, it'll cause a 400 Bad Request. - i.req.ParallelToolCalls = openai.Bool(false) - } - // Force responses to only have one choice. // It's unnecessary to generate multiple responses, and would complicate our stream processing logic if // multiple choices were returned. diff --git a/intercept/messages/base.go b/intercept/messages/base.go index c61a3a7c..37522380 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -101,19 +101,15 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) } func (i *interceptionBase) injectTools() { - if i.req == nil || i.mcpProxy == nil { + if i.req == nil || i.mcpProxy == nil || !i.hasInjectableTools() { return } - tools := i.mcpProxy.ListTools() - if len(tools) == 0 { - // No injected tools: no need to influence parallel tool calling. - return - } + i.disableParallelToolCalls() // Inject tools. var injectedTools []anthropic.ToolUnionParam - for _, tool := range tools { + for _, tool := range i.mcpProxy.ListTools() { injectedTools = append(injectedTools, anthropic.ToolUnionParam{ OfTool: &anthropic.ToolParam{ InputSchema: anthropic.ToolInputSchemaParam{ @@ -137,7 +133,9 @@ func (i *interceptionBase) injectTools() { if err != nil { i.logger.Warn(context.Background(), "failed to set inject tools in request payload", slog.Error(err)) } +} +func (i *interceptionBase) disableParallelToolCalls() { // Note: Parallel tool calls are disabled to avoid tool_use/tool_result block mismatches. // https://github.com/coder/aibridge/issues/2 toolChoiceType := i.req.ToolChoice.GetType() @@ -163,6 +161,7 @@ func (i *interceptionBase) injectTools() { case string(constant.ValueOf[constant.None]()): // No-op; if tool_choice=none then tools are not used at all. } + var err error i.payload, err = sjson.SetBytes(i.payload, "tool_choice", i.req.ToolChoice) if err != nil { i.logger.Warn(context.Background(), "failed to set tool_choice in request payload", slog.Error(err)) @@ -315,6 +314,10 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *Err } } +func (i *interceptionBase) hasInjectableTools() bool { + return i.mcpProxy != nil && len(i.mcpProxy.ListTools()) > 0 +} + // accumulateUsage accumulates usage statistics from source into dest. // It handles both [anthropic.Usage] and [anthropic.MessageDeltaUsage] types through [any]. // The function uses reflection to handle the differences between the types: diff --git a/intercept/responses/base.go b/intercept/responses/base.go index 8b7c3ded..69db3878 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -326,6 +326,10 @@ func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, respon } } +func (i *responsesInterceptionBase) hasInjectableTools() bool { + return i.mcpProxy != nil && len(i.mcpProxy.ListTools()) > 0 +} + // responseCopier helper struct to send original response to the client type responseCopier struct { buff deltaBuffer diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 3e94a6cc..0c11a541 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -59,7 +59,6 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * } i.injectTools() - i.disableParallelToolCalls() var ( response *responses.Response diff --git a/intercept/responses/injected_tools.go b/intercept/responses/injected_tools.go index 8a478012..c3934fa3 100644 --- a/intercept/responses/injected_tools.go +++ b/intercept/responses/injected_tools.go @@ -16,17 +16,14 @@ import ( ) func (i *responsesInterceptionBase) injectTools() { - if i.req == nil || i.mcpProxy == nil { + if i.req == nil || i.mcpProxy == nil || !i.hasInjectableTools() { return } - tools := i.mcpProxy.ListTools() - if len(tools) == 0 { - return - } + i.disableParallelToolCalls() // Inject tools. - for _, tool := range tools { + for _, tool := range i.mcpProxy.ListTools() { var params map[string]any if tool.Params != nil { @@ -67,12 +64,10 @@ func (i *responsesInterceptionBase) injectTools() { // TODO: implement parallel tool calls. func (i *responsesInterceptionBase) disableParallelToolCalls() { // Disable parallel tool calls to simplify inner agentic loop; best-effort. - if len(i.req.Tools) > 0 { - var err error - i.reqPayload, err = sjson.SetBytes(i.reqPayload, "parallel_tool_calls", false) - if err != nil { - i.logger.Warn(context.Background(), "failed to disable parallel_tool_calls", slog.Error(err)) - } + var err error + i.reqPayload, err = sjson.SetBytes(i.reqPayload, "parallel_tool_calls", false) + if err != nil { + i.logger.Warn(context.Background(), "failed to disable parallel_tool_calls", slog.Error(err)) } } diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 6925d86f..38f5771b 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -70,7 +70,6 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r } i.injectTools() - i.disableParallelToolCalls() events := eventstream.NewEventStream(ctx, i.logger.Named("sse-sender"), nil) go events.Start(w, r) diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index 04f5ad14..01eb5815 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -21,6 +21,7 @@ import ( "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" + "github.com/coder/aibridge/utils" "github.com/google/uuid" "github.com/openai/openai-go/v3" oaissestream "github.com/openai/openai-go/v3/packages/ssestream" @@ -1191,62 +1192,211 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { cases := []struct { name string + fixture []byte toolChoice any // nil, or map with "type" key. withInjectedTools bool - expectDisableParallel bool + expectDisableParallel *bool // nil = field should not be present, non-nil = expected value. expectToolChoiceTypeInRequest string }{ - // With injected tools - disable_parallel_tool_use should be set. + // With injected tools - disable_parallel_tool_use should be set to true. { name: "with injected tools: no tool_choice defined defaults to auto", + fixture: fixtures.AntSimple, toolChoice: nil, withInjectedTools: true, - expectDisableParallel: true, + expectDisableParallel: utils.PtrTo(true), expectToolChoiceTypeInRequest: toolChoiceAuto, }, { name: "with injected tools: tool_choice auto", + fixture: fixtures.AntSimple, toolChoice: map[string]any{"type": toolChoiceAuto}, withInjectedTools: true, - expectDisableParallel: true, + expectDisableParallel: utils.PtrTo(true), expectToolChoiceTypeInRequest: toolChoiceAuto, }, { name: "with injected tools: tool_choice any", + fixture: fixtures.AntSimple, toolChoice: map[string]any{"type": toolChoiceAny}, withInjectedTools: true, - expectDisableParallel: true, + expectDisableParallel: utils.PtrTo(true), expectToolChoiceTypeInRequest: toolChoiceAny, }, { name: "with injected tools: tool_choice tool", + fixture: fixtures.AntSimple, toolChoice: map[string]any{"type": toolChoiceTool, "name": "some_tool"}, withInjectedTools: true, - expectDisableParallel: true, + expectDisableParallel: utils.PtrTo(true), expectToolChoiceTypeInRequest: toolChoiceTool, }, { name: "with injected tools: tool_choice none", + fixture: fixtures.AntSimple, toolChoice: map[string]any{"type": toolChoiceNone}, withInjectedTools: true, - expectDisableParallel: false, + expectDisableParallel: nil, expectToolChoiceTypeInRequest: toolChoiceNone, }, - // Without injected tools - disable_parallel_tool_use should NOT be set. + // With injected tools and builtin tools - disable_parallel_tool_use should be set to true. { - name: "without injected tools: tool_choice auto", + name: "with injected and builtin tools: no tool_choice defined defaults to auto", + fixture: fixtures.AntSingleBuiltinTool, + toolChoice: nil, + withInjectedTools: true, + expectDisableParallel: utils.PtrTo(true), + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "with injected and builtin tools: tool_choice auto", + fixture: fixtures.AntSingleBuiltinTool, + toolChoice: map[string]any{"type": toolChoiceAuto}, + withInjectedTools: true, + expectDisableParallel: utils.PtrTo(true), + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "with injected and builtin tools: tool_choice any", + fixture: fixtures.AntSingleBuiltinTool, + toolChoice: map[string]any{"type": toolChoiceAny}, + withInjectedTools: true, + expectDisableParallel: utils.PtrTo(true), + expectToolChoiceTypeInRequest: toolChoiceAny, + }, + { + name: "with injected and builtin tools: tool_choice tool", + fixture: fixtures.AntSingleBuiltinTool, + toolChoice: map[string]any{"type": toolChoiceTool, "name": "some_tool"}, + withInjectedTools: true, + expectDisableParallel: utils.PtrTo(true), + expectToolChoiceTypeInRequest: toolChoiceTool, + }, + { + name: "with injected and builtin tools: tool_choice none", + fixture: fixtures.AntSingleBuiltinTool, + toolChoice: map[string]any{"type": toolChoiceNone}, + withInjectedTools: true, + expectDisableParallel: nil, + expectToolChoiceTypeInRequest: toolChoiceNone, + }, + { + name: "with injected and builtin tools: request already disables parallel", + fixture: fixtures.AntSingleBuiltinTool, + toolChoice: map[string]any{"type": toolChoiceAuto, "disable_parallel_tool_use": true}, + withInjectedTools: true, + expectDisableParallel: utils.PtrTo(true), + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "with injected and builtin tools: request explicitly enables parallel", + fixture: fixtures.AntSingleBuiltinTool, + toolChoice: map[string]any{"type": toolChoiceAuto, "disable_parallel_tool_use": false}, + withInjectedTools: true, + expectDisableParallel: utils.PtrTo(true), + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + // Without injected or builtin tools - disable_parallel_tool_use should NOT be set. + { + name: "without injected tools or builtin tools: tool_choice auto", + fixture: fixtures.AntSimple, toolChoice: map[string]any{"type": toolChoiceAuto}, withInjectedTools: false, - expectDisableParallel: false, + expectDisableParallel: nil, expectToolChoiceTypeInRequest: toolChoiceAuto, }, { - name: "without injected tools: tool_choice any", + name: "without injected tools or builtin tools: tool_choice any", + fixture: fixtures.AntSimple, toolChoice: map[string]any{"type": toolChoiceAny}, withInjectedTools: false, - expectDisableParallel: false, + expectDisableParallel: nil, expectToolChoiceTypeInRequest: toolChoiceAny, }, + // With builtin tools but without injected tools - disable_parallel_tool_use should NOT be set. + { + name: "with builtin tools only: tool_choice auto", + fixture: fixtures.AntSingleBuiltinTool, + toolChoice: map[string]any{"type": toolChoiceAuto}, + withInjectedTools: false, + expectDisableParallel: nil, + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "with builtin tools only: tool_choice any", + fixture: fixtures.AntSingleBuiltinTool, + toolChoice: map[string]any{"type": toolChoiceAny}, + withInjectedTools: false, + expectDisableParallel: nil, + expectToolChoiceTypeInRequest: toolChoiceAny, + }, + { + name: "with builtin tools only: request explicitly disables parallel", + fixture: fixtures.AntSingleBuiltinTool, + toolChoice: map[string]any{"type": toolChoiceAuto, "disable_parallel_tool_use": true}, + withInjectedTools: false, + expectDisableParallel: utils.PtrTo(true), + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "with builtin tools only: request explicitly enables parallel", + fixture: fixtures.AntSingleBuiltinTool, + toolChoice: map[string]any{"type": toolChoiceAuto, "disable_parallel_tool_use": false}, + withInjectedTools: false, + expectDisableParallel: utils.PtrTo(false), + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + // Without injected or builtin tools - disable_parallel_tool_use should be preserved if set. + { + name: "no tools: request explicitly disables parallel", + fixture: fixtures.AntSimple, + toolChoice: map[string]any{"type": toolChoiceAuto, "disable_parallel_tool_use": true}, + withInjectedTools: false, + expectDisableParallel: utils.PtrTo(true), + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "no tools: request explicitly enables parallel", + fixture: fixtures.AntSimple, + toolChoice: map[string]any{"type": toolChoiceAuto, "disable_parallel_tool_use": false}, + withInjectedTools: false, + expectDisableParallel: utils.PtrTo(false), + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + // Request already has disable_parallel_tool_use set - with injected tools it should be set to true. + { + name: "with injected tools: request already disables parallel", + fixture: fixtures.AntSimple, + toolChoice: map[string]any{"type": toolChoiceAuto, "disable_parallel_tool_use": true}, + withInjectedTools: true, + expectDisableParallel: utils.PtrTo(true), + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "with injected tools: request explicitly enables parallel", + fixture: fixtures.AntSimple, + toolChoice: map[string]any{"type": toolChoiceAuto, "disable_parallel_tool_use": false}, + withInjectedTools: true, + expectDisableParallel: utils.PtrTo(true), + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + // Request already has disable_parallel_tool_use set - without injected tools it should be preserved. + { + name: "without injected tools: request already disables parallel", + fixture: fixtures.AntSimple, + toolChoice: map[string]any{"type": toolChoiceAuto, "disable_parallel_tool_use": true}, + withInjectedTools: false, + expectDisableParallel: utils.PtrTo(true), + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, + { + name: "without injected tools: request explicitly enables parallel", + fixture: fixtures.AntSimple, + toolChoice: map[string]any{"type": toolChoiceAuto, "disable_parallel_tool_use": false}, + withInjectedTools: false, + expectDisableParallel: utils.PtrTo(false), + expectToolChoiceTypeInRequest: toolChoiceAuto, + }, } for _, tc := range cases { @@ -1264,7 +1414,7 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { mockMCP = newNoopMCPManager() } - fix := fixtures.Parse(t, fixtures.AntSimple) + fix := fixtures.Parse(t, tc.fixture) upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, @@ -1298,16 +1448,167 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { // See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use disableParallel, hasDisableParallel := toolChoice["disable_parallel_tool_use"].(bool) - if tc.expectDisableParallel { - require.True(t, hasDisableParallel, "expected disable_parallel_tool_use in tool_choice") - assert.True(t, disableParallel, "expected disable_parallel_tool_use to be true") - } else { - assert.False(t, hasDisableParallel, "expected disable_parallel_tool_use to not be set") + require.Equal(t, tc.expectDisableParallel != nil, hasDisableParallel, + "disable_parallel_tool_use presence mismatch") + if tc.expectDisableParallel != nil { + assert.Equal(t, *tc.expectDisableParallel, disableParallel) } }) } } +// TestChatCompletionsParallelToolCallsDisabled verifies that parallel_tool_calls +// is set to false only when injectable MCP tools are present and the request +// includes tools. +func TestChatCompletionsParallelToolCallsDisabled(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + fixture []byte + withInjectedTools bool + initialSetting *bool + expectedSetting *bool + }{ + // With injected tools and builtin tools: parallel_tool_calls should be forced false. + { + name: "with injected and builtin tools: parallel_tool_calls true", + fixture: fixtures.OaiChatSingleBuiltinTool, + withInjectedTools: true, + initialSetting: utils.PtrTo(true), + expectedSetting: utils.PtrTo(false), + }, + { + name: "with injected and builtin tools: parallel_tool_calls false", + fixture: fixtures.OaiChatSingleBuiltinTool, + withInjectedTools: true, + initialSetting: utils.PtrTo(false), + expectedSetting: utils.PtrTo(false), + }, + { + name: "with injected and builtin tools: parallel_tool_calls unset", + fixture: fixtures.OaiChatSingleBuiltinTool, + withInjectedTools: true, + initialSetting: nil, + expectedSetting: utils.PtrTo(false), + }, + // With injected tools but without builtin tools: parallel_tool_calls should be forced false. + { + name: "with injected tools only: parallel_tool_calls true", + fixture: fixtures.OaiChatSimple, + withInjectedTools: true, + initialSetting: utils.PtrTo(true), + expectedSetting: utils.PtrTo(false), + }, + { + name: "with injected tools only: parallel_tool_calls false", + fixture: fixtures.OaiChatSimple, + withInjectedTools: true, + initialSetting: utils.PtrTo(false), + expectedSetting: utils.PtrTo(false), + }, + { + name: "with injected tools only: parallel_tool_calls unset", + fixture: fixtures.OaiChatSimple, + withInjectedTools: true, + initialSetting: nil, + expectedSetting: utils.PtrTo(false), + }, + // With builtin tools but without injected tools: parallel_tool_calls should be preserved. + { + name: "with builtin tools only: parallel_tool_calls true", + fixture: fixtures.OaiChatSingleBuiltinTool, + withInjectedTools: false, + initialSetting: utils.PtrTo(true), + expectedSetting: utils.PtrTo(true), + }, + { + name: "with builtin tools only: parallel_tool_calls false", + fixture: fixtures.OaiChatSingleBuiltinTool, + withInjectedTools: false, + initialSetting: utils.PtrTo(false), + expectedSetting: utils.PtrTo(false), + }, + { + name: "with builtin tools only: parallel_tool_calls unset", + fixture: fixtures.OaiChatSingleBuiltinTool, + withInjectedTools: false, + initialSetting: nil, + expectedSetting: nil, + }, + // Without any tools: nothing is modified. + { + name: "no tools: parallel_tool_calls true", + fixture: fixtures.OaiChatSimple, + withInjectedTools: false, + initialSetting: utils.PtrTo(true), + expectedSetting: utils.PtrTo(true), + }, + { + name: "no tools: parallel_tool_calls false", + fixture: fixtures.OaiChatSimple, + withInjectedTools: false, + initialSetting: utils.PtrTo(false), + expectedSetting: utils.PtrTo(false), + }, + { + name: "no tools: parallel_tool_calls unset", + fixture: fixtures.OaiChatSimple, + withInjectedTools: false, + initialSetting: nil, + expectedSetting: nil, + }, + } + + for _, tc := range cases { + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + fix := fixtures.Parse(t, tc.fixture) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + + var opts []bridgeOption + if tc.withInjectedTools { + opts = append(opts, withMCP(setupMCPForTest(t, defaultTracer))) + } + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...) + + var ( + reqBody = fix.Request() + err error + ) + if tc.initialSetting != nil { + reqBody, err = sjson.SetBytes(reqBody, "parallel_tool_calls", *tc.initialSetting) + require.NoError(t, err) + } + reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) + require.NoError(t, err) + + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + received := upstream.receivedRequests() + require.Len(t, received, 1) + + var upstreamReq map[string]any + require.NoError(t, json.Unmarshal(received[0].Body, &upstreamReq)) + + ptc, ok := upstreamReq["parallel_tool_calls"].(bool) + require.Equal(t, tc.expectedSetting != nil, ok, + "parallel_tool_calls presence mismatch") + if tc.expectedSetting != nil { + assert.Equal(t, *tc.expectedSetting, ptc) + } + }) + } + } +} + func TestThinkingAdaptiveIsPreserved(t *testing.T) { t.Parallel() diff --git a/internal/integrationtest/responses_test.go b/internal/integrationtest/responses_test.go index 483ce903..eee1235f 100644 --- a/internal/integrationtest/responses_test.go +++ b/internal/integrationtest/responses_test.go @@ -3,6 +3,7 @@ package integrationtest import ( "context" "encoding/json" + "fmt" "io" "net" "net/http" @@ -18,8 +19,11 @@ import ( "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" + "github.com/coder/aibridge/utils" "github.com/openai/openai-go/v3/responses" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tidwall/sjson" ) type keyVal struct { @@ -438,99 +442,147 @@ func TestResponsesBackgroundModeForbidden(t *testing.T) { func TestResponsesParallelToolsOverwritten(t *testing.T) { t.Parallel() - tests := []struct { - name string - request string - streaming bool - expectParallelToolCalls bool - expectParallelToolCallsValue bool + cases := []struct { + name string + fixture [2][]byte // [blocking, streaming] fixture pair. + withInjectedTools bool + initialSetting *bool + expectedSetting *bool // nil = field should not be present, non-nil = expected value. }{ + // With injected tools and builtin tools: parallel_tool_calls should be forced false. { - name: "blocking_with_tools", - request: `{"input": "hello", "model": "gpt-4o-mini", "stream": false, "parallel_tool_calls": true, "tools": [{"type": "function", "name": "test", "parameters": {}}]}`, - streaming: false, - expectParallelToolCalls: true, - expectParallelToolCallsValue: false, + name: "with injected and builtin tools: parallel_tool_calls true", + fixture: [2][]byte{fixtures.OaiResponsesBlockingSingleBuiltinTool, fixtures.OaiResponsesStreamingBuiltinTool}, + withInjectedTools: true, + initialSetting: utils.PtrTo(true), + expectedSetting: utils.PtrTo(false), }, { - name: "streaming_with_tools", - request: `{"input": "hello", "model": "gpt-4o-mini", "stream": true, "parallel_tool_calls": true, "tools": [{"type": "function", "name": "test", "parameters": {}}]}`, - streaming: true, - expectParallelToolCalls: true, - expectParallelToolCallsValue: false, + name: "with injected and builtin tools: parallel_tool_calls false", + fixture: [2][]byte{fixtures.OaiResponsesBlockingSingleBuiltinTool, fixtures.OaiResponsesStreamingBuiltinTool}, + withInjectedTools: true, + initialSetting: utils.PtrTo(false), + expectedSetting: utils.PtrTo(false), }, { - name: "blocking_with_tools_no_parallel_param", - request: `{"input": "hello", "model": "gpt-4o-mini", "stream": false, "tools": [{"type": "function", "name": "test", "parameters": {}}]}`, - streaming: false, - expectParallelToolCalls: true, - expectParallelToolCallsValue: false, + name: "with injected and builtin tools: parallel_tool_calls unset", + fixture: [2][]byte{fixtures.OaiResponsesBlockingSingleBuiltinTool, fixtures.OaiResponsesStreamingBuiltinTool}, + withInjectedTools: true, + initialSetting: nil, + expectedSetting: utils.PtrTo(false), }, + // With injected tools but without builtin tools: parallel_tool_calls should be forced false. { - name: "streaming_with_tools_no_parallel_param", - request: `{"input": "hello", "model": "gpt-4o-mini", "stream": true, "tools": [{"type": "function", "name": "test", "parameters": {}}]}`, - streaming: true, - expectParallelToolCalls: true, - expectParallelToolCallsValue: false, + name: "with injected tools only: parallel_tool_calls true", + fixture: [2][]byte{fixtures.OaiResponsesBlockingSimple, fixtures.OaiResponsesStreamingSimple}, + withInjectedTools: true, + initialSetting: utils.PtrTo(true), + expectedSetting: utils.PtrTo(false), }, { - name: "blocking_without_tools", - request: `{"input": "hello", "model": "gpt-4o-mini", "stream": false}`, - streaming: false, + name: "with injected tools only: parallel_tool_calls false", + fixture: [2][]byte{fixtures.OaiResponsesBlockingSimple, fixtures.OaiResponsesStreamingSimple}, + withInjectedTools: true, + initialSetting: utils.PtrTo(false), + expectedSetting: utils.PtrTo(false), }, { - name: "streaming_without_tools", - request: `{"input": "hello", "model": "gpt-4o-mini", "stream": true}`, - streaming: true, + name: "with injected tools only: parallel_tool_calls unset", + fixture: [2][]byte{fixtures.OaiResponsesBlockingSimple, fixtures.OaiResponsesStreamingSimple}, + withInjectedTools: true, + initialSetting: nil, + expectedSetting: utils.PtrTo(false), + }, + // With builtin tools but without injected tools: parallel_tool_calls should be preserved. + { + name: "with builtin tools only: parallel_tool_calls true", + fixture: [2][]byte{fixtures.OaiResponsesBlockingSingleBuiltinTool, fixtures.OaiResponsesStreamingBuiltinTool}, + withInjectedTools: false, + initialSetting: utils.PtrTo(true), + expectedSetting: utils.PtrTo(true), + }, + { + name: "with builtin tools only: parallel_tool_calls false", + fixture: [2][]byte{fixtures.OaiResponsesBlockingSingleBuiltinTool, fixtures.OaiResponsesStreamingBuiltinTool}, + withInjectedTools: false, + initialSetting: utils.PtrTo(false), + expectedSetting: utils.PtrTo(false), + }, + { + name: "with builtin tools only: parallel_tool_calls unset", + fixture: [2][]byte{fixtures.OaiResponsesBlockingSingleBuiltinTool, fixtures.OaiResponsesStreamingBuiltinTool}, + withInjectedTools: false, + initialSetting: nil, + expectedSetting: nil, }, + // Without any tools: nothing is modified. { - name: "blocking_without_tools_parallel_true", - request: `{"input": "hello", "model": "gpt-4o-mini", "stream": false, "parallel_tool_calls": true}`, - streaming: false, - expectParallelToolCalls: true, - expectParallelToolCallsValue: true, + name: "no tools: parallel_tool_calls true", + fixture: [2][]byte{fixtures.OaiResponsesBlockingSimple, fixtures.OaiResponsesStreamingSimple}, + withInjectedTools: false, + initialSetting: utils.PtrTo(true), + expectedSetting: utils.PtrTo(true), }, { - name: "streaming_without_tools_parallel_true", - request: `{"input": "hello", "model": "gpt-4o-mini", "stream": true, "parallel_tool_calls": true}`, - streaming: true, - expectParallelToolCalls: true, - expectParallelToolCallsValue: true, + name: "no tools: parallel_tool_calls false", + fixture: [2][]byte{fixtures.OaiResponsesBlockingSimple, fixtures.OaiResponsesStreamingSimple}, + withInjectedTools: false, + initialSetting: utils.PtrTo(false), + expectedSetting: utils.PtrTo(false), + }, + { + name: "no tools: parallel_tool_calls unset", + fixture: [2][]byte{fixtures.OaiResponsesBlockingSimple, fixtures.OaiResponsesStreamingSimple}, + withInjectedTools: false, + initialSetting: nil, + expectedSetting: nil, }, } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - - fix := fixtures.OaiResponsesBlockingSimple - if tc.streaming { - fix = fixtures.OaiResponsesStreamingSimple - } - upstream := newMockUpstream(t, ctx, newFixtureResponse(fixtures.Parse(t, fix))) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + for _, tc := range cases { + for i, streaming := range []bool{false, true} { + t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + fix := fixtures.Parse(t, tc.fixture[i]) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + + var opts []bridgeOption + if tc.withInjectedTools { + opts = append(opts, withMCP(setupMCPForTest(t, defaultTracer))) + } + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...) + + var ( + reqBody = fix.Request() + err error + ) + if tc.initialSetting != nil { + reqBody, err = sjson.SetBytes(reqBody, "parallel_tool_calls", *tc.initialSetting) + require.NoError(t, err) + } + + resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBody) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, []byte(tc.request)) - _, err := io.ReadAll(resp.Body) - require.NoError(t, err) + received := upstream.receivedRequests() + require.Len(t, received, 1) - received := upstream.receivedRequests() - require.Len(t, received, 1) + var upstreamReq map[string]any + require.NoError(t, json.Unmarshal(received[0].Body, &upstreamReq)) - var receivedRequest map[string]any - require.NoError(t, json.Unmarshal(received[0].Body, &receivedRequest)) - if tc.expectParallelToolCalls { - parallelToolCalls, ok := receivedRequest["parallel_tool_calls"].(bool) - require.True(t, ok, "parallel_tool_calls should be present in upstream request") - require.Equal(t, tc.expectParallelToolCallsValue, parallelToolCalls) - } else { - _, ok := receivedRequest["parallel_tool_calls"] - require.False(t, ok, "parallel_tool_calls should not be present when not set") - } - }) + ptc, ok := upstreamReq["parallel_tool_calls"].(bool) + require.Equal(t, tc.expectedSetting != nil, ok, + "parallel_tool_calls presence mismatch") + if tc.expectedSetting != nil { + assert.Equal(t, *tc.expectedSetting, ptc) + } + }) + } } }