diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 068474cd..98b3c551 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "net/http/httptest" + "slices" "strings" "sync" "testing" @@ -267,6 +268,150 @@ func TestAWSBedrockIntegration(t *testing.T) { }) } }) + + // Tests that Bedrock-incompatible fields are stripped and adaptive thinking + // is handled correctly per model. Different Bedrock model names trigger + // different behavior for beta flag filtering and field stripping. + t.Run("unsupported fields removed", func(t *testing.T) { + t.Parallel() + + // All fields in the fixture request that Bedrock may strip. Fields + // listed in a test case's expectKeptFields survive; all others must + // be absent from the forwarded body. + strippableFields := []string{ + "metadata", "service_tier", "container", "inference_geo", // always stripped + "output_config", "context_management", // stripped unless their beta flag survives + } + + cases := []struct { + name string + model string + smallFastModel string + expectThinkingType string + expectBudgetTokens int64 // 0 means budget_tokens should not be present + expectKeptFields []string // fields from strippableFields expected to survive + expectedBetaFlags []string // values expected in the anthropic_beta array in the forwarded body + }{ + // "beddel" matches no model prefix, so adaptive thinking is converted + // to enabled with budget, and all model-gated beta flags are stripped. + { + name: "beddel", + model: "beddel", + smallFastModel: "modrock", + expectThinkingType: "enabled", + expectBudgetTokens: 16000, // 32000 * 0.5 (medium effort) + expectedBetaFlags: []string{"interleaved-thinking-2025-05-14"}, + }, + // Opus 4.5 supports the effort beta, so output_config is kept. + { + name: "opus-4.5", + model: "anthropic.claude-opus-4-5-20250514-v1:0", + smallFastModel: "anthropic.claude-haiku-4-5-20241022-v1:0", + expectThinkingType: "enabled", + expectBudgetTokens: 16000, + expectKeptFields: []string{"output_config"}, + expectedBetaFlags: []string{"interleaved-thinking-2025-05-14", "effort-2025-11-24"}, + }, + // Sonnet 4.5 supports context-management beta, so context_management is kept. + { + name: "sonnet-4.5", + model: "anthropic.claude-sonnet-4-5-20241022-v2:0", + smallFastModel: "anthropic.claude-haiku-4-5-20241022-v1:0", + expectThinkingType: "enabled", + expectBudgetTokens: 16000, + expectKeptFields: []string{"context_management"}, + expectedBetaFlags: []string{"interleaved-thinking-2025-05-14", "context-management-2025-06-27"}, + }, + // Opus 4.6 supports adaptive thinking natively, so it is kept as-is. + // Neither effort nor context-management betas apply to this model. + { + name: "opus-4.6", + model: "anthropic.claude-opus-4-6-20260619-v1:0", + smallFastModel: "anthropic.claude-haiku-4-5-20241022-v1:0", + expectThinkingType: "adaptive", + expectedBetaFlags: []string{"interleaved-thinking-2025-05-14"}, + }, + } + + for _, tc := range cases { + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + fix := fixtures.Parse(t, fixtures.AntSimpleBedrock) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + + bCfg := &config.AWSBedrock{ + Region: "us-west-2", + AccessKey: "test-access-key", + AccessKeySecret: "test-secret-key", + Model: tc.model, + SmallFastModel: tc.smallFastModel, + BaseURL: upstream.URL, + } + + recorderClient := &testutil.MockRecorder{} + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + b, err := aibridge.NewRequestBridge( + ctx, []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bCfg)}, + recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + require.NoError(t, err) + + mockBridgeSrv := httptest.NewUnstartedServer(b) + t.Cleanup(mockBridgeSrv.Close) + mockBridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibcontext.AsActor(ctx, userID, nil) + } + mockBridgeSrv.Start() + + reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) + require.NoError(t, err) + + // Send with Anthropic-Beta header containing flags that should be filtered. + req := createAnthropicMessagesReq(t, mockBridgeSrv.URL, reqBody) + req.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14,effort-2025-11-24,context-management-2025-06-27,prompt-caching-scope-2026-01-05") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + received := upstream.ReceivedRequests() + require.Len(t, received, 1) + body := received[0].Body + + // Verify strippable fields: kept only if listed in expectKeptFields. + for _, field := range strippableFields { + assert.Equal(t, slices.Contains(tc.expectKeptFields, field), gjson.GetBytes(body, field).Exists(), "field %s", field) + } + + // Verify thinking behavior. + assert.Equal(t, tc.expectThinkingType, gjson.GetBytes(body, "thinking.type").String(), "thinking type mismatch") + if tc.expectBudgetTokens > 0 { + assert.Equal(t, tc.expectBudgetTokens, gjson.GetBytes(body, "thinking.budget_tokens").Int(), "budget_tokens mismatch") + } else { + assert.False(t, gjson.GetBytes(body, "thinking.budget_tokens").Exists(), "budget_tokens should not be present") + } + + // The Bedrock SDK middleware moves Anthropic-Beta from the header + // into the body as "anthropic_beta". + betaArr := gjson.GetBytes(body, "anthropic_beta").Array() + var gotFlags []string + for _, v := range betaArr { + gotFlags = append(gotFlags, v.String()) + } + assert.Equal(t, tc.expectedBetaFlags, gotFlags, "beta flags mismatch") + + recorderClient.VerifyAllInterceptionsEnded(t) + }) + } + } + }) } func TestOpenAIChatCompletions(t *testing.T) { diff --git a/config/config.go b/config/config.go index f4107e42..b64323b7 100644 --- a/config/config.go +++ b/config/config.go @@ -1,6 +1,9 @@ package config -import "time" +import ( + "net/http" + "time" +) const ( ProviderAnthropic = "anthropic" @@ -14,7 +17,7 @@ type Anthropic struct { APIDumpDir string CircuitBreaker *CircuitBreaker SendActorHeaders bool - ExtraHeaders map[string]string + ExtraHeaders http.Header } type AWSBedrock struct { @@ -33,7 +36,7 @@ type OpenAI struct { APIDumpDir string CircuitBreaker *CircuitBreaker SendActorHeaders bool - ExtraHeaders map[string]string + ExtraHeaders http.Header } // CircuitBreaker holds configuration for circuit breakers. diff --git a/fixtures/anthropic/simple_bedrock.txtar b/fixtures/anthropic/simple_bedrock.txtar new file mode 100644 index 00000000..45979381 --- /dev/null +++ b/fixtures/anthropic/simple_bedrock.txtar @@ -0,0 +1,51 @@ +Simple Bedrock request. Tests that fields unsupported by Bedrock are removed +and adaptive thinking is converted to enabled with a budget. Includes all +bedrockUnsupportedFields (metadata, service_tier, container, inference_geo) +and beta-gated fields (output_config, context_management). + +-- request -- +{ + "model": "claude-sonnet-4-6", + "max_tokens": 32000, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Hello." + } + ] + } + ], + "thinking": {"type": "adaptive"}, + "metadata": {"user_id": "session_abc123"}, + "service_tier": "auto", + "container": {"type": "ephemeral"}, + "inference_geo": {"allow": ["us"]}, + "output_config": {"effort": "medium"}, + "context_management": {"edits": [{"type": "clear_thinking_20251015", "keep": "all"}]}, + "stream": true +} + +-- streaming -- +event: message_start +data: {"type":"message_start","message":{"id":"msg_bdrk_01Test","type":"message","role":"assistant","model":"claude-sonnet-4-5-20250929","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":4}}} + +event: content_block_start +data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello! How can I help?"}} + +event: content_block_stop +data: {"type":"content_block_stop","index":0} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":10}} + +event: message_stop +data: {"type":"message_stop"} + +-- non-streaming -- +{"id":"msg_bdrk_01Test","type":"message","role":"assistant","model":"claude-sonnet-4-5-20250929","content":[{"type":"text","text":"Hello! How can I help?"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":10,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":10}} diff --git a/fixtures/fixtures.go b/fixtures/fixtures.go index 3c150471..6e04ba36 100644 --- a/fixtures/fixtures.go +++ b/fixtures/fixtures.go @@ -26,6 +26,9 @@ var ( //go:embed anthropic/non_stream_error.txtar AntNonStreamError []byte + + //go:embed anthropic/simple_bedrock.txtar + AntSimpleBedrock []byte ) var ( diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index ac7476bf..a09ccd3c 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -41,8 +41,10 @@ func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. - for key, value := range i.cfg.ExtraHeaders { - opts = append(opts, option.WithHeader(key, value)) + for key, values := range i.cfg.ExtraHeaders { + for _, v := range values { + opts = append(opts, option.WithHeaderAdd(key, v)) + } } // Add API dump middleware if configured diff --git a/intercept/messages/base.go b/intercept/messages/base.go index fd751865..a633a705 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -23,7 +23,6 @@ import ( "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" "github.com/coder/quartz" - "github.com/tidwall/sjson" "github.com/google/uuid" "go.opentelemetry.io/otel/attribute" @@ -32,10 +31,38 @@ import ( "cdr.dev/slog/v3" ) +// bedrockSupportedBetaFlags is the set of Anthropic-Beta flags that AWS Bedrock +// accepts. Flags not in this set cause a 400 "invalid beta flag" error. +// +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html +var bedrockSupportedBetaFlags = map[string]bool{ + // Supported on Claude 3.7 Sonnet. + "computer-use-2025-01-24": true, + // Supported on Claude 3.7 Sonnet and Claude 4+. + "token-efficient-tools-2025-02-19": true, + // Supported on Claude 4+ models. + "interleaved-thinking-2025-05-14": true, + // Supported on Claude 3.7 Sonnet. + "output-128k-2025-02-19": true, + // Supported on Claude 4+ models. Requires account team access. + "dev-full-thinking-2025-05-14": true, + // Supported on Claude Sonnet 4. + "context-1m-2025-08-07": true, + // Supported on Claude Sonnet 4.5 and Claude Haiku 4.5. + // Enables context_management body field for thinking block clearing. + "context-management-2025-06-27": true, + // Supported on Claude Opus 4.5. + // Enables output_config body field for effort control. + "effort-2025-11-24": true, + // Supported on Claude Opus 4.5. + "tool-search-tool-2025-10-19": true, + // Supported on Claude Opus 4.5. + "tool-examples-2025-10-29": true, +} + type interceptionBase struct { - id uuid.UUID - req *MessageNewParamsWrapper - payload []byte + id uuid.UUID + reqPayload MessagesRequestPayload cfg aibconfig.Anthropic bedrockCfg *aibconfig.AWSBedrock @@ -58,7 +85,7 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, } func (i *interceptionBase) Model() string { - if i.req == nil { + if len(i.reqPayload) == 0 { return "coder-aibridge-unknown" } @@ -70,7 +97,7 @@ func (i *interceptionBase) Model() string { return model } - return string(i.req.Model) + return i.reqPayload.model() } func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { @@ -86,7 +113,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) } func (i *interceptionBase) injectTools() { - if i.req == nil || i.mcpProxy == nil { + if i.mcpProxy == nil { return } @@ -115,43 +142,21 @@ func (i *interceptionBase) injectTools() { // Prepend the injected tools in order to maintain any configured cache breakpoints. // The order of injected tools is expected to be stable, and therefore will not cause // any cache invalidation when prepended. - i.req.Tools = append(injectedTools, i.req.Tools...) - - var err error - i.payload, err = sjson.SetBytes(i.payload, "tools", i.req.Tools) + updated, err := i.reqPayload.injectTools(injectedTools) if err != nil { i.logger.Warn(context.Background(), "failed to set inject tools in request payload", slog.Error(err)) + return } + i.reqPayload = updated // Note: Parallel tool calls are disabled to avoid tool_use/tool_result block mismatches. // https://github.com/coder/aibridge/issues/2 - toolChoiceType := i.req.ToolChoice.GetType() - var toolChoiceTypeStr string - if toolChoiceType != nil { - toolChoiceTypeStr = *toolChoiceType - } - - switch toolChoiceTypeStr { - // If no tool_choice was defined, assume auto. - // See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use. - case "", string(constant.ValueOf[constant.Auto]()): - // We only set OfAuto if no tool_choice was provided (the default). - // "auto" is the default when a zero value is provided, so we can safely disable parallel checks on it. - if i.req.ToolChoice.OfAuto == nil { - i.req.ToolChoice.OfAuto = &anthropic.ToolChoiceAutoParam{} - } - i.req.ToolChoice.OfAuto.DisableParallelToolUse = anthropic.Bool(true) - case string(constant.ValueOf[constant.Any]()): - i.req.ToolChoice.OfAny.DisableParallelToolUse = anthropic.Bool(true) - case string(constant.ValueOf[constant.Tool]()): - i.req.ToolChoice.OfTool.DisableParallelToolUse = anthropic.Bool(true) - case string(constant.ValueOf[constant.None]()): - // No-op; if tool_choice=none then tools are not used at all. - } - i.payload, err = sjson.SetBytes(i.payload, "tool_choice", i.req.ToolChoice) + updated, err = i.reqPayload.disableParallelToolCalls() if err != nil { i.logger.Warn(context.Background(), "failed to set tool_choice in request payload", slog.Error(err)) + return } + i.reqPayload = updated } // IsSmallFastModel checks if the model is a small/fast model (Haiku 3.5). @@ -159,19 +164,13 @@ func (i *interceptionBase) injectTools() { // See `ANTHROPIC_SMALL_FAST_MODEL`: https://docs.anthropic.com/en/docs/claude-code/settings#environment-variables // https://docs.claude.com/en/docs/claude-code/costs#background-token-usage func (i *interceptionBase) isSmallFastModel() bool { - return strings.Contains(string(i.req.Model), "haiku") + return strings.Contains(i.reqPayload.model(), "haiku") } func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) { opts = append(opts, option.WithAPIKey(i.cfg.Key)) opts = append(opts, option.WithBaseURL(i.cfg.BaseURL)) - // Add extra headers if configured. - // Some providers require additional headers that are not added by the SDK. - for key, value := range i.cfg.ExtraHeaders { - opts = append(opts, option.WithHeader(key, value)) - } - // Add API dump middleware if configured if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, aibconfig.ProviderAnthropic, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) @@ -188,26 +187,23 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio i.augmentRequestForBedrock() } + // Add extra headers after augmentRequestForBedrock() so that any + // Bedrock-specific header filtering (e.g. Anthropic-Beta) is applied first. + for key, values := range i.cfg.ExtraHeaders { + for _, v := range values { + opts = append(opts, option.WithHeaderAdd(key, v)) + } + } + return anthropic.NewMessageService(opts...), nil } -// withBody returns a per-request option that sends the current i.payload as the -// request body. This is called for each API request so that the latest payload (including -// any messages appended during the agentic tool loop) is always sent. +// withBody returns a per-request option that sends the current raw request +// payload as the request body. This is called for each API request so that the +// latest payload (including any messages appended during the agentic tool loop) +// is always sent. func (i *interceptionBase) withBody() option.RequestOption { - return option.WithRequestBody("application/json", i.payload) -} - -// syncPayloadMessages updates the raw payload's "messages" field to match the given messages. -// This must be called before the next API request in the agentic loop so that -// withBody() picks up the updated messages. -func (i *interceptionBase) syncPayloadMessages(messages []anthropic.MessageParam) error { - var err error - i.payload, err = sjson.SetBytes(i.payload, "messages", messages) - if err != nil { - return fmt.Errorf("sync payload messages: %w", err) - } - return nil + return option.WithRequestBody("application/json", []byte(i.reqPayload)) } func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) { @@ -258,18 +254,95 @@ func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibco } // augmentRequestForBedrock will change the model used for the request since AWS Bedrock doesn't support -// Anthropics' model names. +// Anthropics' model names. It also converts adaptive thinking to enabled with a budget for models that +// don't support adaptive thinking natively. func (i *interceptionBase) augmentRequestForBedrock() { if i.bedrockCfg == nil { return } - i.req.MessageNewParams.Model = anthropic.Model(i.Model()) - - var err error - i.payload, err = sjson.SetBytes(i.payload, "model", i.Model()) + model := i.Model() + updated, err := i.reqPayload.withModel(model) if err != nil { i.logger.Warn(context.Background(), "failed to set model in request payload for Bedrock", slog.Error(err)) + return + } + i.reqPayload = updated + + if !bedrockModelSupportsAdaptiveThinking(model) { + updated, err = i.reqPayload.convertAdaptiveThinkingForBedrock() + if err != nil { + i.logger.Warn(context.Background(), "failed to convert adaptive thinking for Bedrock", slog.Error(err)) + return + } + i.reqPayload = updated + } + + // Filter Anthropic-Beta header to only include Bedrock-supported flags + // that the current model supports. + filterBedrockBetaFlags(i.cfg.ExtraHeaders, model) + + // Strip body fields that Bedrock does not accept. + updated, err = i.reqPayload.removeUnsupportedBedrockFields(i.cfg.ExtraHeaders) + if err != nil { + i.logger.Warn(context.Background(), "failed to remove unsupported fields for Bedrock", slog.Error(err)) + return + } + i.reqPayload = updated +} + +// bedrockModelSupportsAdaptiveThinking returns true if the given Bedrock model ID +// supports the "adaptive" thinking type natively (i.e. Claude 4.6 models). +// See https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-adaptive-thinking.html +func bedrockModelSupportsAdaptiveThinking(model string) bool { + return strings.Contains(model, "anthropic.claude-opus-4-6") || + strings.Contains(model, "anthropic.claude-sonnet-4-6") +} + +// filterBedrockBetaFlags removes unsupported beta flags from the Anthropic-Beta +// header and also removes model-gated flags the current model doesn't support. +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html +func filterBedrockBetaFlags(headers http.Header, model string) { + // Collect all flags regardless of whether the client sent them as a single + // comma-separated value (eg. Claude Code sends them in that format) + // or as multiple separate header lines. + // https://httpwg.org/specs/rfc9110.html#rfc.section.5.3 + var flags []string + for _, v := range headers.Values("Anthropic-Beta") { + for _, flag := range strings.Split(v, ",") { + flags = append(flags, flag) + } + } + + if len(flags) == 0 { + return + } + + var keep []string + for _, flag := range flags { + trimmed := strings.TrimSpace(flag) + if !bedrockSupportedBetaFlags[trimmed] { + continue + } + + // effort is only supported in Opus 4.5 on Bedrock. + if trimmed == "effort-2025-11-24" && !strings.Contains(model, "anthropic.claude-opus-4-5") { + continue + } + + // context_management is only supported in Sonnet 4.5 and Haiku 4.5 models on Bedrock. + if trimmed == "context-management-2025-06-27" && + !strings.Contains(model, "anthropic.claude-sonnet-4-5") && + !strings.Contains(model, "anthropic.claude-haiku-4-5") { + continue + } + + keep = append(keep, trimmed) + } + + headers.Del("Anthropic-Beta") + for _, flag := range keep { + headers.Add("Anthropic-Beta", flag) } } diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index 5413a7d8..d1b14994 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -2,14 +2,17 @@ package messages import ( "context" + "net/http" "testing" + "cdr.dev/slog/v3" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge/config" "github.com/coder/aibridge/mcp" mcpgo "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) func TestAWSBedrockValidation(t *testing.T) { @@ -310,28 +313,19 @@ func TestInjectTools_CacheBreakpoints(t *testing.T) { // Request has existing tool with cache control, but no tools to inject. i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Tools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "existing_tool", - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: constant.ValueOf[constant.Ephemeral](), - }, - }, - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tools":[`+ + `{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`), mcpProxy: &mockServerProxier{tools: nil}, + logger: slog.Make(), } i.injectTools() // Cache control should remain untouched since no tools were injected. - require.Len(t, i.req.Tools, 1) - require.Equal(t, constant.ValueOf[constant.Ephemeral](), i.req.Tools[0].OfTool.CacheControl.Type) + toolItems := gjson.GetBytes(i.reqPayload, "tools").Array() + require.Len(t, toolItems, 1) + require.Equal(t, "existing_tool", toolItems[0].Get("name").String()) + require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[0].Get("cache_control.type").String()) }) t.Run("cache control breakpoint is preserved by prepending injected tools", func(t *testing.T) { @@ -339,36 +333,26 @@ func TestInjectTools_CacheBreakpoints(t *testing.T) { // Request has existing tool with cache control. i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Tools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "existing_tool", - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: constant.ValueOf[constant.Ephemeral](), - }, - }, - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tools":[`+ + `{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{ {ID: "injected_tool", Name: "injected", Description: "Injected tool"}, }, }, + logger: slog.Make(), } i.injectTools() - require.Len(t, i.req.Tools, 2) + toolItems := gjson.GetBytes(i.reqPayload, "tools").Array() + require.Len(t, toolItems, 2) // Injected tools are prepended. - require.Equal(t, "injected_tool", i.req.Tools[0].OfTool.Name) - require.Zero(t, i.req.Tools[0].OfTool.CacheControl) + require.Equal(t, "injected_tool", toolItems[0].Get("name").String()) + require.Empty(t, toolItems[0].Get("cache_control.type").String()) // Original tool's cache control should be preserved at the end. - require.Equal(t, "existing_tool", i.req.Tools[1].OfTool.Name) - require.Equal(t, constant.ValueOf[constant.Ephemeral](), i.req.Tools[1].OfTool.CacheControl.Type) + require.Equal(t, "existing_tool", toolItems[1].Get("name").String()) + require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[1].Get("cache_control.type").String()) }) // The cache breakpoint SHOULD be on the final tool, but may not be; we must preserve that intention. @@ -377,43 +361,29 @@ func TestInjectTools_CacheBreakpoints(t *testing.T) { // Request has multiple tools with cache control breakpoints. i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Tools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "tool_with_cache_1", - CacheControl: anthropic.CacheControlEphemeralParam{ - Type: constant.ValueOf[constant.Ephemeral](), - }, - }, - }, - { - OfTool: &anthropic.ToolParam{ - Name: "tool_with_cache_2", - }, - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tools":[`+ + `{"name":"tool_with_cache_1","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}},`+ + `{"name":"tool_with_cache_2","type":"custom","input_schema":{"type":"object","properties":{}}}]}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{ {ID: "injected_tool", Name: "injected", Description: "Injected tool"}, }, }, + logger: slog.Make(), } i.injectTools() - require.Len(t, i.req.Tools, 3) + toolItems := gjson.GetBytes(i.reqPayload, "tools").Array() + require.Len(t, toolItems, 3) // Injected tool is prepended without cache control. - require.Equal(t, "injected_tool", i.req.Tools[0].OfTool.Name) - require.Zero(t, i.req.Tools[0].OfTool.CacheControl) + require.Equal(t, "injected_tool", toolItems[0].Get("name").String()) + require.Empty(t, toolItems[0].Get("cache_control.type").String()) // Both original tools' cache controls should remain. - require.Equal(t, "tool_with_cache_1", i.req.Tools[1].OfTool.Name) - require.Equal(t, constant.ValueOf[constant.Ephemeral](), i.req.Tools[1].OfTool.CacheControl.Type) - require.Equal(t, "tool_with_cache_2", i.req.Tools[2].OfTool.Name) - require.Zero(t, i.req.Tools[2].OfTool.CacheControl) + require.Equal(t, "tool_with_cache_1", toolItems[1].Get("name").String()) + require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[1].Get("cache_control.type").String()) + require.Equal(t, "tool_with_cache_2", toolItems[2].Get("name").String()) + require.Empty(t, toolItems[2].Get("cache_control.type").String()) }) t.Run("no cache control added when none originally set", func(t *testing.T) { @@ -421,33 +391,26 @@ func TestInjectTools_CacheBreakpoints(t *testing.T) { // Request has tools but none with cache control. i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Tools: []anthropic.ToolUnionParam{ - { - OfTool: &anthropic.ToolParam{ - Name: "existing_tool_no_cache", - }, - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tools":[`+ + `{"name":"existing_tool_no_cache","type":"custom","input_schema":{"type":"object","properties":{}}}]}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{ {ID: "injected_tool", Name: "injected", Description: "Injected tool"}, }, }, + logger: slog.Make(), } i.injectTools() - require.Len(t, i.req.Tools, 2) + toolItems := gjson.GetBytes(i.reqPayload, "tools").Array() + require.Len(t, toolItems, 2) // Injected tool is prepended without cache control. - require.Equal(t, "injected_tool", i.req.Tools[0].OfTool.Name) - require.Zero(t, i.req.Tools[0].OfTool.CacheControl) + require.Equal(t, "injected_tool", toolItems[0].Get("name").String()) + require.Empty(t, toolItems[0].Get("cache_control.type").String()) // Original tool remains at the end without cache control. - require.Equal(t, "existing_tool_no_cache", i.req.Tools[1].OfTool.Name) - require.Zero(t, i.req.Tools[1].OfTool.CacheControl) + require.Equal(t, "existing_tool_no_cache", toolItems[1].Get("name").String()) + require.Empty(t, toolItems[1].Get("cache_control.type").String()) }) } @@ -458,152 +421,295 @@ func TestInjectTools_ParallelToolCalls(t *testing.T) { t.Parallel() i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfAuto: &anthropic.ToolChoiceAutoParam{ - Type: constant.ValueOf[constant.Auto](), - }, - }, - }, - }, - mcpProxy: &mockServerProxier{tools: nil}, // No tools to inject. + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"auto"}}`), + mcpProxy: &mockServerProxier{tools: nil}, // No tools to inject. + logger: slog.Make(), } i.injectTools() // Tool choice should remain unchanged - DisableParallelToolUse should not be set. - require.NotNil(t, i.req.ToolChoice.OfAuto) - require.False(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Valid()) + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String()) + require.False(t, toolChoice.Get("disable_parallel_tool_use").Exists()) }) - t.Run("disables parallel tool use for auto tool choice (default)", func(t *testing.T) { + t.Run("disables parallel tool use for empty tool choice (default)", func(t *testing.T) { t.Parallel() i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - // No tool choice set (default). - }, - }, + reqPayload: mustMessagesPayload(t, `{}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, }, + logger: slog.Make(), } i.injectTools() - require.NotNil(t, i.req.ToolChoice.OfAuto) - require.True(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Valid()) - require.True(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Value) + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool()) }) t.Run("disables parallel tool use for explicit auto tool choice", func(t *testing.T) { t.Parallel() i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfAuto: &anthropic.ToolChoiceAutoParam{ - Type: constant.ValueOf[constant.Auto](), - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"auto"}}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, }, + logger: slog.Make(), } i.injectTools() - require.NotNil(t, i.req.ToolChoice.OfAuto) - require.True(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Valid()) - require.True(t, i.req.ToolChoice.OfAuto.DisableParallelToolUse.Value) + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool()) }) t.Run("disables parallel tool use for any tool choice", func(t *testing.T) { t.Parallel() i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfAny: &anthropic.ToolChoiceAnyParam{ - Type: constant.ValueOf[constant.Any](), - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"any"}}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, }, + logger: slog.Make(), } i.injectTools() - require.NotNil(t, i.req.ToolChoice.OfAny) - require.True(t, i.req.ToolChoice.OfAny.DisableParallelToolUse.Valid()) - require.True(t, i.req.ToolChoice.OfAny.DisableParallelToolUse.Value) + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Any]()), toolChoice.Get("type").String()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool()) }) t.Run("disables parallel tool use for tool choice type", func(t *testing.T) { t.Parallel() i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfTool: &anthropic.ToolChoiceToolParam{ - Type: constant.ValueOf[constant.Tool](), - Name: "specific_tool", - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"tool","name":"specific_tool"}}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, }, + logger: slog.Make(), } i.injectTools() - require.NotNil(t, i.req.ToolChoice.OfTool) - require.True(t, i.req.ToolChoice.OfTool.DisableParallelToolUse.Valid()) - require.True(t, i.req.ToolChoice.OfTool.DisableParallelToolUse.Value) + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.Tool]()), toolChoice.Get("type").String()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists()) + require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool()) }) t.Run("no-op for none tool choice type", func(t *testing.T) { t.Parallel() i := &interceptionBase{ - req: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - ToolChoice: anthropic.ToolChoiceUnionParam{ - OfNone: &anthropic.ToolChoiceNoneParam{ - Type: constant.ValueOf[constant.None](), - }, - }, - }, - }, + reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"none"}}`), mcpProxy: &mockServerProxier{ tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}}, }, + logger: slog.Make(), } i.injectTools() // Tools are still injected. - require.Len(t, i.req.Tools, 1) + require.Len(t, gjson.GetBytes(i.reqPayload, "tools").Array(), 1) // But no parallel tool use modification for "none" type. - require.Nil(t, i.req.ToolChoice.OfAuto) - require.Nil(t, i.req.ToolChoice.OfAny) - require.Nil(t, i.req.ToolChoice.OfTool) - require.NotNil(t, i.req.ToolChoice.OfNone) + toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice") + require.Equal(t, string(constant.ValueOf[constant.None]()), toolChoice.Get("type").String()) + require.False(t, toolChoice.Get("disable_parallel_tool_use").Exists()) }) } +func TestAugmentRequestForBedrock_AdaptiveThinking(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + + bedrockModel string + requestBody string + clientBetaFlags string + + expectThinkingType string + expectBudgetTokens int64 // 0 means budget_tokens should not be present + expectRemovedFields []string + expectKeptFields []string + expectBetaValues []string // expected separate Anthropic-Beta header values + }{ + { + name: "non_4_6_model_with_adaptive_thinking_gets_converted", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + requestBody: `{"max_tokens":10000,"thinking":{"type":"adaptive"}}`, + expectThinkingType: "enabled", + expectBudgetTokens: 8000, // 10000 * 0.8 (default/high effort) + }, + { + name: "non_4_6_model_with_adaptive_thinking_and_small_max_tokens_disables_thinking", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + requestBody: `{"max_tokens":1000,"thinking":{"type":"adaptive"}}`, + expectThinkingType: "disabled", + }, + { + name: "opus_4_6_model_with_adaptive_thinking_is_not_converted", + bedrockModel: "anthropic.claude-opus-4-6-v1", + requestBody: `{"max_tokens":10000,"thinking":{"type":"adaptive"}}`, + expectThinkingType: "adaptive", + }, + { + name: "sonnet_4_6_model_with_adaptive_thinking_is_not_converted", + bedrockModel: "anthropic.claude-sonnet-4-6", + requestBody: `{"max_tokens":10000,"thinking":{"type":"adaptive"}}`, + expectThinkingType: "adaptive", + }, + { + name: "non_4_6_model_with_no_thinking_field_is_unchanged", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + requestBody: `{"max_tokens":10000}`, + }, + { + name: "non_4_6_model_with_enabled_thinking_is_unchanged", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + requestBody: `{"max_tokens":10000,"thinking":{"type":"enabled","budget_tokens":5000}}`, + expectThinkingType: "enabled", + expectBudgetTokens: 5000, + }, + { + name: "output_config_stripped_without_beta_flag_and_effort_used_for_budget", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + requestBody: `{"max_tokens":10000,"thinking":{"type":"adaptive"},"output_config":{"effort":"low"}}`, + expectThinkingType: "enabled", + expectBudgetTokens: 2000, // 10000 * 0.2 (low effort) + expectRemovedFields: []string{"output_config"}, + }, + { + name: "output_config_kept_when_effort_beta_flag_present_on_opus_4_5", + bedrockModel: "anthropic.claude-opus-4-5-20250929-v1:0", + clientBetaFlags: "effort-2025-11-24,interleaved-thinking-2025-05-14", + requestBody: `{"max_tokens":10000,"output_config":{"effort":"high"}}`, + expectKeptFields: []string{"output_config"}, + expectBetaValues: []string{"effort-2025-11-24", "interleaved-thinking-2025-05-14"}, + }, + { + name: "output_config_stripped_for_non_opus_4_5_even_with_effort_beta_flag", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + clientBetaFlags: "effort-2025-11-24,interleaved-thinking-2025-05-14", + requestBody: `{"max_tokens":10000,"output_config":{"effort":"high"}}`, + expectRemovedFields: []string{"output_config"}, + expectBetaValues: []string{"interleaved-thinking-2025-05-14"}, + }, + { + name: "context_management_kept_when_beta_flag_present", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + clientBetaFlags: "context-management-2025-06-27", + requestBody: `{"max_tokens":10000,"context_management":{"type":"auto"}}`, + expectKeptFields: []string{"context_management"}, + expectBetaValues: []string{"context-management-2025-06-27"}, + }, + { + name: "context_management_stripped_without_beta_flag", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + requestBody: `{"max_tokens":10000,"context_management":{"type":"auto"}}`, + expectRemovedFields: []string{"context_management"}, + }, + { + name: "context_management_stripped_for_unsupported_model_even_with_beta_flag", + bedrockModel: "anthropic.claude-opus-4-6-v1", + clientBetaFlags: "context-management-2025-06-27", + requestBody: `{"max_tokens":10000,"thinking":{"type":"adaptive"},"context_management":{"type":"auto"}}`, + expectThinkingType: "adaptive", + expectRemovedFields: []string{"context_management"}, + }, + { + name: "unsupported_beta_flags_are_filtered_out", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + clientBetaFlags: "claude-code-20250219,interleaved-thinking-2025-05-14,prompt-caching-scope-2026-01-05", + requestBody: `{"max_tokens":10000}`, + expectBetaValues: []string{"interleaved-thinking-2025-05-14"}, + }, + { + name: "all_unsupported_fields_stripped_and_beta_flags_filtered", + bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0", + clientBetaFlags: "claude-code-20250219,prompt-caching-scope-2026-01-05", + requestBody: `{"max_tokens":10000,"output_config":{"effort":"high"},"metadata":{"user_id":"u123"},"service_tier":"auto","container":"ctr_abc","inference_geo":"us","context_management":{"type":"auto"}}`, + expectRemovedFields: []string{"output_config", "metadata", "service_tier", "container", "inference_geo", "context_management"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + extraHeaders := make(http.Header) + if tc.clientBetaFlags != "" { + extraHeaders.Set("Anthropic-Beta", tc.clientBetaFlags) + } + + i := &interceptionBase{ + reqPayload: mustMessagesPayload(t, tc.requestBody), + cfg: config.Anthropic{ExtraHeaders: extraHeaders}, + bedrockCfg: &config.AWSBedrock{ + Model: tc.bedrockModel, + SmallFastModel: "anthropic.claude-haiku-3-5", + }, + logger: slog.Make(), + } + + i.augmentRequestForBedrock() + + thinkingType := gjson.GetBytes(i.reqPayload, "thinking.type") + if tc.expectThinkingType == "" { + require.False(t, thinkingType.Exists()) + } else { + require.Equal(t, tc.expectThinkingType, thinkingType.String()) + } + + budgetTokens := gjson.GetBytes(i.reqPayload, "thinking.budget_tokens") + if tc.expectBudgetTokens == 0 { + require.False(t, budgetTokens.Exists(), "budget_tokens should not be set") + } else { + require.Equal(t, tc.expectBudgetTokens, budgetTokens.Int()) + } + + // Model should always be set to the bedrock model. + require.Equal(t, tc.bedrockModel, gjson.GetBytes(i.reqPayload, "model").String()) + + // Verify expected fields are removed. + for _, field := range tc.expectRemovedFields { + require.False(t, gjson.GetBytes(i.reqPayload, field).Exists(), "%s should be removed", field) + } + + // Verify expected fields are kept. + for _, field := range tc.expectKeptFields { + require.True(t, gjson.GetBytes(i.reqPayload, field).Exists(), "%s should be kept", field) + } + + got := extraHeaders.Values("Anthropic-Beta") + require.Equal(t, tc.expectBetaValues, got) + }) + } +} + +func mustMessagesPayload(t *testing.T, requestBody string) MessagesRequestPayload { + t.Helper() + + payload, err := NewMessagesRequestPayload([]byte(requestBody)) + require.NoError(t, err) + + return payload +} + // mockServerProxier is a test implementation of mcp.ServerProxier. type mockServerProxier struct { tools []*mcp.Tool @@ -633,3 +739,98 @@ func (m *mockServerProxier) GetTool(id string) *mcp.Tool { func (m *mockServerProxier) CallTool(context.Context, string, any) (*mcpgo.CallToolResult, error) { return nil, nil } + +func TestFilterBedrockBetaFlags(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + model string + inputValues []string // header values to set (each element is a separate header value) + expectValues []string // expected separate header values after filtering + }{ + { + name: "empty header", + model: "anthropic.claude-sonnet-4-5-20250929-v1:0", + inputValues: nil, + expectValues: nil, + }, + { + name: "all supported flags kept", + model: "anthropic.claude-opus-4-5-20250929-v1:0", + inputValues: []string{"interleaved-thinking-2025-05-14,effort-2025-11-24"}, + expectValues: []string{"interleaved-thinking-2025-05-14", "effort-2025-11-24"}, + }, + { + name: "unsupported flags removed", + model: "anthropic.claude-sonnet-4-5-20250929-v1:0", + inputValues: []string{"claude-code-20250219,interleaved-thinking-2025-05-14,prompt-caching-scope-2026-01-05"}, + expectValues: []string{"interleaved-thinking-2025-05-14"}, + }, + { + name: "header removed when all flags unsupported", + model: "anthropic.claude-sonnet-4-5-20250929-v1:0", + inputValues: []string{"claude-code-20250219,prompt-caching-scope-2026-01-05"}, + expectValues: nil, + }, + { + name: "effort flag removed for non opus 4.5 model", + model: "anthropic.claude-sonnet-4-5-20250929-v1:0", + inputValues: []string{"effort-2025-11-24,interleaved-thinking-2025-05-14"}, + expectValues: []string{"interleaved-thinking-2025-05-14"}, + }, + { + name: "effort flag kept for opus 4.5 model", + model: "anthropic.claude-opus-4-5-20250929-v1:0", + inputValues: []string{"effort-2025-11-24,interleaved-thinking-2025-05-14"}, + expectValues: []string{"effort-2025-11-24", "interleaved-thinking-2025-05-14"}, + }, + { + name: "context management kept for sonnet 4.5", + model: "anthropic.claude-sonnet-4-5-20250929-v1:0", + inputValues: []string{"context-management-2025-06-27"}, + expectValues: []string{"context-management-2025-06-27"}, + }, + { + name: "context management kept for haiku 4.5", + model: "anthropic.claude-haiku-4-5-20250929-v1:0", + inputValues: []string{"context-management-2025-06-27"}, + expectValues: []string{"context-management-2025-06-27"}, + }, + { + name: "context management removed for unsupported model", + model: "anthropic.claude-opus-4-6-v1", + inputValues: []string{"context-management-2025-06-27,interleaved-thinking-2025-05-14"}, + expectValues: []string{"interleaved-thinking-2025-05-14"}, + }, + { + name: "separate header values are handled correctly", + model: "anthropic.claude-sonnet-4-5-20250929-v1:0", + inputValues: []string{"interleaved-thinking-2025-05-14", "context-management-2025-06-27"}, + expectValues: []string{"interleaved-thinking-2025-05-14", "context-management-2025-06-27"}, + }, + { + name: "mixed comma-joined and separate header values", + model: "anthropic.claude-opus-4-5-20250929-v1:0", + inputValues: []string{"interleaved-thinking-2025-05-14,effort-2025-11-24", "token-efficient-tools-2025-02-19"}, + expectValues: []string{"interleaved-thinking-2025-05-14", "effort-2025-11-24", "token-efficient-tools-2025-02-19"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + headers := http.Header{} + for _, v := range tc.inputValues { + headers.Add("Anthropic-Beta", v) + } + + filterBedrockBetaFlags(headers, tc.model) + + // Each kept flag should be a separate header value. + got := headers.Values("Anthropic-Beta") + require.Equal(t, tc.expectValues, got) + }) + } +} diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 7ab2bedf..51781d91 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -28,11 +28,10 @@ type BlockingInterception struct { interceptionBase } -func NewBlockingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *BlockingInterception { +func NewBlockingInterceptor(id uuid.UUID, reqPayload MessagesRequestPayload, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ id: id, - req: req, - payload: payload, + reqPayload: reqPayload, cfg: cfg, bedrockCfg: bedrockCfg, tracer: tracer, @@ -52,8 +51,8 @@ func (s *BlockingInterception) Streaming() bool { } func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { - if i.req == nil { - return fmt.Errorf("developer error: req is nil") + if len(i.reqPayload) == 0 { + return fmt.Errorf("developer error: request payload is empty") } ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) @@ -61,15 +60,13 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req i.injectTools() - var ( - prompt *string - err error - ) - // Track user prompt if not a small/fast model + var prompt *string if !i.isSmallFastModel() { - prompt, err = i.req.lastUserPrompt() - if err != nil { - i.logger.Warn(ctx, "failed to retrieve last user prompt", slog.Error(err)) + promptText, promptFound, promptErr := i.reqPayload.lastUserPrompt() + if promptErr != nil { + i.logger.Warn(ctx, "failed to retrieve last user prompt", slog.Error(promptErr)) + } else if promptFound { + prompt = &promptText } } @@ -85,8 +82,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req return err } - messages := i.req.MessageNewParams - logger := i.logger.With(slog.F("model", i.req.Model)) + logger := i.logger.With(slog.F("model", i.Model())) var resp *anthropic.Message // Accumulate usage across the entire streaming interaction (including tool reinvocations). @@ -94,7 +90,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req for { // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) - resp, err = i.newMessage(ctx, svc, messages) + resp, err = i.newMessage(ctx, svc) if err != nil { if eventstream.IsConnError(err) { // Can't write a response, just error out. @@ -163,9 +159,8 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req break } - // Append the assistant's message (which contains the tool_use block) - // to the messages for the next API call. - messages.Messages = append(messages.Messages, resp.ToParam()) + var loopMessages []anthropic.MessageParam + loopMessages = append(loopMessages, resp.ToParam()) // Process each pending tool call. for _, tc := range pendingToolCalls { @@ -177,7 +172,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req if tool == nil { logger.Warn(ctx, "tool not found in manager", slog.F("tool", tc.Name)) // Continue to next tool call, but still append an error tool_result - messages.Messages = append(messages.Messages, + loopMessages = append(loopMessages, anthropic.NewUserMessage(anthropic.NewToolResultBlock(tc.ID, fmt.Sprintf("Error: tool %s not found", tc.Name), true)), ) continue @@ -197,7 +192,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req if err != nil { // Always provide a tool_result even if the tool call failed - messages.Messages = append(messages.Messages, + loopMessages = append(loopMessages, anthropic.NewUserMessage(anthropic.NewToolResultBlock(tc.ID, fmt.Sprintf("Error: calling tool: %v", err), true)), ) continue @@ -276,16 +271,16 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req } if len(toolResult.OfToolResult.Content) > 0 { - messages.Messages = append(messages.Messages, anthropic.NewUserMessage(toolResult)) + loopMessages = append(loopMessages, anthropic.NewUserMessage(toolResult)) } } - // Sync the raw payload with updated messages so that withBody() - // sends the updated payload on the next iteration. - if err := i.syncPayloadMessages(messages.Messages); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return fmt.Errorf("sync payload for agentic loop: %w", err) + updatedPayload, rewriteErr := i.reqPayload.appendedMessages(loopMessages) + if rewriteErr != nil { + http.Error(w, rewriteErr.Error(), http.StatusInternalServerError) + return fmt.Errorf("rewrite payload for agentic loop: %w", rewriteErr) } + i.reqPayload = updatedPayload } if resp == nil { @@ -311,9 +306,9 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req return nil } -func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService, msgParams anthropic.MessageNewParams) (_ *anthropic.Message, outErr error) { +func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService) (_ *anthropic.Message, outErr error) { ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) - return svc.New(ctx, msgParams, i.withBody()) + return svc.New(ctx, anthropic.MessageNewParams{}, i.withBody()) } diff --git a/intercept/messages/paramswrap.go b/intercept/messages/paramswrap.go deleted file mode 100644 index bd5175aa..00000000 --- a/intercept/messages/paramswrap.go +++ /dev/null @@ -1,142 +0,0 @@ -package messages - -import ( - "encoding/json" - "errors" - - "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/packages/param" -) - -// MessageNewParamsWrapper exists because the "stream" param is not included in anthropic.MessageNewParams. -type MessageNewParamsWrapper struct { - anthropic.MessageNewParams `json:""` - Stream bool `json:"stream,omitempty"` -} - -func (b MessageNewParamsWrapper) MarshalJSON() ([]byte, error) { - type shadow MessageNewParamsWrapper - return param.MarshalWithExtras(b, (*shadow)(&b), map[string]any{ - "stream": b.Stream, - }) -} - -func (b *MessageNewParamsWrapper) UnmarshalJSON(raw []byte) error { - // Parse JSON once and extract both stream field and do content conversion - // to avoid double-parsing the same payload. - var modifiedJSON map[string]any - if err := json.Unmarshal(raw, &modifiedJSON); err != nil { - return err - } - - // Extract stream field from already-parsed map - if stream, ok := modifiedJSON["stream"].(bool); ok { - b.Stream = stream - } - - // Convert string content to array format if needed - if _, hasMessages := modifiedJSON["messages"]; hasMessages { - convertStringContentRecursive(modifiedJSON) - } - - // Marshal back for SDK parsing - convertedRaw, err := json.Marshal(modifiedJSON) - if err != nil { - return err - } - - return b.MessageNewParams.UnmarshalJSON(convertedRaw) -} - -func (b *MessageNewParamsWrapper) lastUserPrompt() (*string, error) { - if b == nil { - return nil, errors.New("nil struct") - } - - if len(b.Messages) == 0 { - return nil, errors.New("no messages") - } - - // We only care if the last message was issued by a user. - msg := b.Messages[len(b.Messages)-1] - if msg.Role != anthropic.MessageParamRoleUser { - return nil, nil - } - - if len(msg.Content) == 0 { - return nil, nil - } - - // Walk backwards on "user"-initiated message content. Clients often inject - // content ahead of the actual prompt to provide context to the model, - // so the last item in the slice is most likely the user's prompt. - for i := len(msg.Content) - 1; i >= 0; i-- { - // Only text content is supported currently. - if textContent := msg.Content[i].GetText(); textContent != nil { - return textContent, nil - } - } - - return nil, nil -} - -// convertStringContentRecursive recursively scans JSON data and converts string "content" fields -// to proper text block arrays where needed for Anthropic SDK compatibility. -// Returns true if any modifications were made. -func convertStringContentRecursive(data any) bool { - modified := false - switch v := data.(type) { - case map[string]any: - // Check if this object has a "content" field with string value - if content, hasContent := v["content"]; hasContent { - if contentStr, isString := content.(string); isString { - // Check if this needs conversion based on context - if shouldConvertContentField(v) { - v["content"] = []map[string]any{ - { - "type": "text", - "text": contentStr, - }, - } - modified = true - } - } - } - - // Recursively process all values in the map - for _, value := range v { - if convertStringContentRecursive(value) { - modified = true - } - } - - case []any: - // Recursively process all items in the array - for _, item := range v { - if convertStringContentRecursive(item) { - modified = true - } - } - } - return modified -} - -// shouldConvertContentField determines if a "content" string field should be converted to text block array -func shouldConvertContentField(obj map[string]any) bool { - // Check if this is a message-level content (has "role" field) - if _, hasRole := obj["role"]; hasRole { - return true - } - - // Check if this is a tool_result block (but not mcp_tool_result which supports strings) - if objType, hasType := obj["type"].(string); hasType { - switch objType { - case "tool_result": - return true // Regular tool_result needs array format - case "mcp_tool_result": - return false // MCP tool_result supports strings - } - } - - return false -} diff --git a/intercept/messages/paramswrap_test.go b/intercept/messages/paramswrap_test.go deleted file mode 100644 index 7f8793d7..00000000 --- a/intercept/messages/paramswrap_test.go +++ /dev/null @@ -1,303 +0,0 @@ -package messages - -import ( - "testing" - - "github.com/anthropics/anthropic-sdk-go" - "github.com/stretchr/testify/require" -) - -func TestMessageNewParamsWrapperUnmarshalJSON(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input string - expectedStream bool - checkContent func(t *testing.T, w *MessageNewParamsWrapper) - }{ - { - name: "message with string content converts to array", - input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":"Hello world"}]}`, - expectedStream: false, - checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { - require.Len(t, w.Messages, 1) - require.Equal(t, anthropic.MessageParamRoleUser, w.Messages[0].Role) - text := w.Messages[0].Content[0].GetText() - require.NotNil(t, text) - require.Equal(t, "Hello world", *text) - }, - }, - { - name: "stream field extracted", - input: `{"model":"claude-3","max_tokens":1000,"stream":true,"messages":[{"role":"user","content":"Hi"}]}`, - expectedStream: true, - checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { - require.Len(t, w.Messages, 1) - }, - }, - { - name: "stream false", - input: `{"model":"claude-3","max_tokens":1000,"stream":false,"messages":[{"role":"user","content":"Hi"}]}`, - expectedStream: false, - checkContent: nil, - }, - { - name: "array content unchanged", - input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`, - expectedStream: false, - checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { - require.Len(t, w.Messages, 1) - text := w.Messages[0].Content[0].GetText() - require.NotNil(t, text) - require.Equal(t, "Hello", *text) - }, - }, - { - name: "multiple messages with mixed content", - input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":"First"},{"role":"assistant","content":[{"type":"text","text":"Response"}]},{"role":"user","content":"Second"}]}`, - expectedStream: false, - checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { - require.Len(t, w.Messages, 3) - text0 := w.Messages[0].Content[0].GetText() - require.NotNil(t, text0) - require.Equal(t, "First", *text0) - text2 := w.Messages[2].Content[0].GetText() - require.NotNil(t, text2) - require.Equal(t, "Second", *text2) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var wrapper MessageNewParamsWrapper - err := wrapper.UnmarshalJSON([]byte(tt.input)) - require.NoError(t, err) - require.Equal(t, tt.expectedStream, wrapper.Stream) - if tt.checkContent != nil { - tt.checkContent(t, &wrapper) - } - }) - } -} - -func TestShouldConvertContentField(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - obj map[string]any - expected bool - }{ - { - name: "message with role", - obj: map[string]any{ - "role": "user", - "content": "test", - }, - expected: true, - }, - { - name: "tool_result type", - obj: map[string]any{ - "type": "tool_result", - "content": "result", - }, - expected: true, - }, - { - name: "mcp_tool_result type", - obj: map[string]any{ - "type": "mcp_tool_result", - "content": "result", - }, - expected: false, - }, - { - name: "other type", - obj: map[string]any{ - "type": "text", - "content": "text", - }, - expected: false, - }, - { - name: "no role or type", - obj: map[string]any{ - "content": "test", - }, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := shouldConvertContentField(tt.obj) - require.Equal(t, tt.expected, result) - }) - } -} - -func TestAnthropicLastUserPrompt(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - wrapper *MessageNewParamsWrapper - expected string - expectError bool - errorMsg string - }{ - { - name: "nil struct", - expectError: true, - errorMsg: "nil struct", - }, - { - name: "no messages", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{}, - }, - }, - expectError: true, - errorMsg: "no messages", - }, - { - name: "last message not from user", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("user message"), - }, - }, - { - Role: anthropic.MessageParamRoleAssistant, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("assistant message"), - }, - }, - }, - }, - }, - }, - { - name: "last user message with empty content", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{}, - }, - }, - }, - }, - }, - { - name: "last user message with single text content", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("Hello, world!"), - }, - }, - }, - }, - }, - expected: "Hello, world!", - }, - { - name: "last user message with multiple content blocks - text at end", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewImageBlockBase64("image/png", "base64data"), - anthropic.NewTextBlock("First text"), - anthropic.NewImageBlockBase64("image/jpeg", "moredata"), - anthropic.NewTextBlock("Last text"), - }, - }, - }, - }, - }, - expected: "Last text", - }, - { - name: "last user message with only non-text content", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewImageBlockBase64("image/png", "base64data"), - anthropic.NewImageBlockBase64("image/jpeg", "moredata"), - }, - }, - }, - }, - }, - }, - { - name: "multiple messages with last being user", - wrapper: &MessageNewParamsWrapper{ - MessageNewParams: anthropic.MessageNewParams{ - Messages: []anthropic.MessageParam{ - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("First user message"), - }, - }, - { - Role: anthropic.MessageParamRoleAssistant, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("Assistant response"), - }, - }, - { - Role: anthropic.MessageParamRoleUser, - Content: []anthropic.ContentBlockParamUnion{ - anthropic.NewTextBlock("Second user message"), - }, - }, - }, - }, - }, - expected: "Second user message", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := tt.wrapper.lastUserPrompt() - - if tt.expectError { - require.Error(t, err) - require.Contains(t, err.Error(), tt.errorMsg) - require.Nil(t, result) - } else { - require.NoError(t, err) - // Check pointer equality - both nil or both non-nil - if tt.expected == "" { - require.Nil(t, result) - } else { - require.NotNil(t, result) - // The result should point to the same string from the content block - require.Equal(t, tt.expected, *result) - } - } - }) - } -} diff --git a/intercept/messages/reqpayload.go b/intercept/messages/reqpayload.go new file mode 100644 index 00000000..a139f9c1 --- /dev/null +++ b/intercept/messages/reqpayload.go @@ -0,0 +1,412 @@ +package messages + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "slices" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/shared/constant" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + // Absolute JSON paths from the request root. + messagesReqPathMessages = "messages" + messagesReqPathMaxTokens = "max_tokens" + messagesReqPathModel = "model" + messagesReqPathOutputConfig = "output_config" + messagesReqPathOutputConfigEffort = "output_config.effort" + messagesReqPathMetadata = "metadata" + messagesReqPathServiceTier = "service_tier" + messagesReqPathContainer = "container" + messagesReqPathInferenceGeo = "inference_geo" + messagesReqPathContextManagement = "context_management" + messagesReqPathStream = "stream" + messagesReqPathThinking = "thinking" + messagesReqPathThinkingBudgetTokens = "thinking.budget_tokens" + messagesReqPathThinkingType = "thinking.type" + messagesReqPathToolChoice = "tool_choice" + messagesReqPathToolChoiceDisableParallel = "tool_choice.disable_parallel_tool_use" + messagesReqPathToolChoiceType = "tool_choice.type" + messagesReqPathTools = "tools" + + // Relative field names used within sub-objects. + messagesReqFieldContent = "content" + messagesReqFieldRole = "role" + messagesReqFieldText = "text" + messagesReqFieldToolUseID = "tool_use_id" + messagesReqFieldType = "type" +) + +const ( + constAdaptive = "adaptive" + constDisabled = "disabled" + constEnabled = "enabled" +) + +var ( + constAny = string(constant.ValueOf[constant.Any]()) + constAuto = string(constant.ValueOf[constant.Auto]()) + constNone = string(constant.ValueOf[constant.None]()) + constText = string(constant.ValueOf[constant.Text]()) + constTool = string(constant.ValueOf[constant.Tool]()) + constToolResult = string(constant.ValueOf[constant.ToolResult]()) + constUser = string(anthropic.MessageParamRoleUser) + + // bedrockUnsupportedFields are top-level fields present in the Anthropic Messages + // API that are absent from the Bedrock request body schema. Sending them results + // in a 400 "Extra inputs are not permitted" error. + // + // Anthropic API fields: https://platform.claude.com/docs/en/api/messages/create + // Bedrock request body: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html + bedrockUnsupportedFields = []string{ + messagesReqPathMetadata, + messagesReqPathServiceTier, + messagesReqPathContainer, + messagesReqPathInferenceGeo, + } + + // bedrockBetaGatedFields maps body fields to the beta flag that enables them. + // If the beta flag is present in the (already-filtered) Anthropic-Beta header, + // the field is kept; otherwise it is stripped. Model-specific beta flags must + // be removed from the header before this check (see filterBedrockBetaFlags). + bedrockBetaGatedFields = map[string]string{ + // output_config requires the effort beta (Opus 4.5 only). + messagesReqPathOutputConfig: "effort-2025-11-24", + // context_management requires the context-management beta (Sonnet 4.5, Haiku 4.5). + messagesReqPathContextManagement: "context-management-2025-06-27", + } +) + +// MessagesRequestPayload is raw JSON bytes of an Anthropic Messages API request. +// Methods provide package-specific reads and rewrites while preserving the +// original body for upstream pass-through. +type MessagesRequestPayload []byte + +func NewMessagesRequestPayload(raw []byte) (MessagesRequestPayload, error) { + if len(bytes.TrimSpace(raw)) == 0 { + return nil, fmt.Errorf("messages empty request body") + } + if !json.Valid(raw) { + return nil, fmt.Errorf("messages invalid JSON request body") + } + + return MessagesRequestPayload(raw), nil +} + +func (p MessagesRequestPayload) Stream() bool { + v := gjson.GetBytes(p, messagesReqPathStream) + if !v.IsBool() { + return false + } + return v.Bool() +} + +func (p MessagesRequestPayload) model() string { + return gjson.GetBytes(p, messagesReqPathModel).Str +} + +func (p MessagesRequestPayload) correlatingToolCallID() *string { + messages := gjson.GetBytes(p, messagesReqPathMessages) + if !messages.IsArray() { + return nil + } + + messageItems := messages.Array() + if len(messageItems) == 0 { + return nil + } + + content := messageItems[len(messageItems)-1].Get(messagesReqFieldContent) + if !content.IsArray() { + return nil + } + + contentItems := content.Array() + for idx := len(contentItems) - 1; idx >= 0; idx-- { + contentItem := contentItems[idx] + if contentItem.Get(messagesReqFieldType).String() != constToolResult { + continue + } + + toolUseID := contentItem.Get(messagesReqFieldToolUseID).String() + if toolUseID == "" { + continue + } + + return &toolUseID + } + + return nil +} + +// lastUserPrompt returns the prompt text from the last user message. If no prompt +// is found, it returns empty string, false, nil. Unexpected shapes are treated as +// unsupported and do not fail the request path. +func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) { + messages := gjson.GetBytes(p, messagesReqPathMessages) + if !messages.Exists() || messages.Type == gjson.Null { + return "", false, nil + } + if !messages.IsArray() { + return "", false, fmt.Errorf("unexpected messages type: %s", messages.Type) + } + + messageItems := messages.Array() + if len(messageItems) == 0 { + return "", false, nil + } + + lastMessage := messageItems[len(messageItems)-1] + if lastMessage.Get(messagesReqFieldRole).String() != constUser { + return "", false, nil + } + + content := lastMessage.Get(messagesReqFieldContent) + if !content.Exists() || content.Type == gjson.Null { + return "", false, nil + } + if content.Type == gjson.String { + return content.String(), true, nil + } + if !content.IsArray() { + return "", false, fmt.Errorf("unexpected message content type: %s", content.Type) + } + + contentItems := content.Array() + for idx := len(contentItems) - 1; idx >= 0; idx-- { + contentItem := contentItems[idx] + if contentItem.Get(messagesReqFieldType).String() != constText { + continue + } + + text := contentItem.Get(messagesReqFieldText) + if text.Type != gjson.String { + continue + } + + return text.String(), true, nil + } + + return "", false, nil +} + +func (p MessagesRequestPayload) injectTools(injected []anthropic.ToolUnionParam) (MessagesRequestPayload, error) { + if len(injected) == 0 { + return p, nil + } + + existing, err := p.tools() + if err != nil { + return p, fmt.Errorf("get existing tools: %w", err) + } + + // Using []json.Marshaler to merge differently-typed slices ([]anthropic.ToolUnionParam + // and []json.Marshaler containing json.RawMessage) keeps JSON re-marshalings to a minimum: + // sjson.SetBytes marshals each element exactly once, and json.RawMessage + // elements are passed through without re-serialization. + allTools := make([]json.Marshaler, 0, len(injected)+len(existing)) + for _, tool := range injected { + allTools = append(allTools, tool) + } + + for _, e := range existing { + allTools = append(allTools, e) + } + + return p.set(messagesReqPathTools, allTools) +} + +func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPayload, error) { + toolChoice := gjson.GetBytes(p, messagesReqPathToolChoice) + + // If no tool_choice was defined, assume auto. + // See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use. + if !toolChoice.Exists() || toolChoice.Type == gjson.Null { + updated, err := p.set(messagesReqPathToolChoiceType, constAuto) + if err != nil { + return p, fmt.Errorf("set tool choice type: %w", err) + } + return updated.set(messagesReqPathToolChoiceDisableParallel, true) + } + if !toolChoice.IsObject() { + return p, fmt.Errorf("unsupported tool_choice type: %s", toolChoice.Type) + } + + toolChoiceType := gjson.GetBytes(p, messagesReqPathToolChoiceType) + if toolChoiceType.Exists() && toolChoiceType.Type != gjson.String { + return p, fmt.Errorf("unsupported tool_choice.type type: %s", toolChoiceType.Type) + } + + switch toolChoiceType.String() { + case "": + updated, err := p.set(messagesReqPathToolChoiceType, constAuto) + if err != nil { + return p, fmt.Errorf("set tool_choice.type: %w", err) + } + return updated.set(messagesReqPathToolChoiceDisableParallel, true) + case constAuto, constAny, constTool: + return p.set(messagesReqPathToolChoiceDisableParallel, true) + case constNone: + return p, nil + default: + return p, fmt.Errorf("unsupported tool_choice.type value: %q", toolChoiceType.String()) + } +} + +func (p MessagesRequestPayload) appendedMessages(newMessages []anthropic.MessageParam) (MessagesRequestPayload, error) { + if len(newMessages) == 0 { + return p, nil + } + + existing, err := p.messages() + if err != nil { + return p, fmt.Errorf("get existing messages: %w", err) + } + + // Using []json.Marshaler to merge differently-typed slices ([]json.Marshaler containing + // json.RawMessage and []anthropic.MessageParam) keeps JSON re-marshalings + // to a minimum: sjson.SetBytes marshals each element exactly once, and + // json.RawMessage elements are passed through without re-serialization. + allMessages := make([]json.Marshaler, 0, len(existing)+len(newMessages)) + + for _, e := range existing { + allMessages = append(allMessages, e) + } + + for _, new := range newMessages { + allMessages = append(allMessages, new) + } + + return p.set(messagesReqPathMessages, allMessages) +} + +func (p MessagesRequestPayload) withModel(model string) (MessagesRequestPayload, error) { + return p.set(messagesReqPathModel, model) +} + +func (p MessagesRequestPayload) messages() ([]json.RawMessage, error) { + messages := gjson.GetBytes(p, messagesReqPathMessages) + if !messages.Exists() || messages.Type == gjson.Null { + return nil, nil + } + if !messages.IsArray() { + return nil, fmt.Errorf("unsupported messages type: %s", messages.Type) + } + + return p.resultToRawMessage(messages.Array()), nil +} + +func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) { + tools := gjson.GetBytes(p, messagesReqPathTools) + if !tools.Exists() || tools.Type == gjson.Null { + return nil, nil + } + if !tools.IsArray() { + return nil, fmt.Errorf("unsupported tools type: %s", tools.Type) + } + + return p.resultToRawMessage(tools.Array()), nil +} + +func (p MessagesRequestPayload) resultToRawMessage(items []gjson.Result) []json.RawMessage { + // gjson.Result conversion to json.RawMessage is needed because + // gjson.Result does not implement json.Marshaler — would + // serialize its struct fields instead of the raw JSON it represents. + rawMessages := make([]json.RawMessage, 0, len(items)) + for _, item := range items { + rawMessages = append(rawMessages, json.RawMessage(item.Raw)) + } + return rawMessages +} + +// convertAdaptiveThinkingForBedrock converts thinking.type "adaptive" to "enabled" with a calculated budget_tokens +// conversion is needed for Bedrock models that does not support the "adaptive" thinking.type +func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesRequestPayload, error) { + thinkingType := gjson.GetBytes(p, messagesReqPathThinkingType) + if thinkingType.String() != constAdaptive { + return p, nil + } + + maxTokens := gjson.GetBytes(p, messagesReqPathMaxTokens).Int() + if maxTokens <= 0 { + // max_tokens is required by messages API + return p, fmt.Errorf("max_tokens: field required") + } + + effort := gjson.GetBytes(p, messagesReqPathOutputConfigEffort).String() + + // Enabled thinking type requires budget_tokens set. + // Heuristically calculate value based on the effort level. + // Effort-to-ratio mapping adapted from OpenRouter: + // https://openrouter.ai/docs/guides/best-practices/reasoning-tokens#reasoning-effort-level + var ratio float64 + switch effort { + case "low": + ratio = 0.2 + case "medium": + ratio = 0.5 + case "max": + ratio = 0.95 + default: // "high" or absent (high is the default effort) + ratio = 0.8 + } + + // budget_tokens must be ≥ 1024 && < max_tokens. If the calculated budget + // doesn't meet the minimum, disable thinking entirely rather than forcing + // an artificially high budget that would starve the output. + // https://platform.claude.com/docs/en/api/messages/create#create.thinking + // https://platform.claude.com/docs/en/build-with-claude/extended-thinking#how-to-use-extended-thinking + budgetTokens := int64(float64(maxTokens) * ratio) + if budgetTokens < 1024 { + return p.set(messagesReqPathThinking, map[string]string{"type": constDisabled}) + } + + return p.set(messagesReqPathThinking, map[string]any{ + "type": constEnabled, + "budget_tokens": budgetTokens, + }) +} + +// removeUnsupportedBedrockFields strips top-level fields that Bedrock does not +// support from the payload. Fields that are gated behind a beta flag are only +// removed when the corresponding flag is absent from the Anthropic-Beta header. +// Model-specific beta flags must already be filtered from the header before +// calling this method (see filterBedrockBetaFlags). +func (p MessagesRequestPayload) removeUnsupportedBedrockFields(headers http.Header) (MessagesRequestPayload, error) { + var payloadMap map[string]any + if err := json.Unmarshal(p, &payloadMap); err != nil { + return p, fmt.Errorf("failed to unmarshal request payload when removing unsupported Bedrock fields: %w", err) + } + + // Always strip unconditionally unsupported fields. + for _, field := range bedrockUnsupportedFields { + delete(payloadMap, field) + } + + // Strip beta-gated fields only when their beta flag is missing. + betaValues := headers.Values("Anthropic-Beta") + for field, requiredFlag := range bedrockBetaGatedFields { + if !slices.Contains(betaValues, requiredFlag) { + delete(payloadMap, field) + } + } + + result, err := json.Marshal(payloadMap) + if err != nil { + return p, fmt.Errorf("failed to marshal request payload when removing unsupported Bedrock fields: %w", err) + } + return MessagesRequestPayload(result), nil +} + +func (p MessagesRequestPayload) set(path string, value any) (MessagesRequestPayload, error) { + out, err := sjson.SetBytes(p, path, value) + if err != nil { + return p, fmt.Errorf("set %s: %w", path, err) + } + return MessagesRequestPayload(out), nil +} diff --git a/intercept/messages/reqpayload_test.go b/intercept/messages/reqpayload_test.go new file mode 100644 index 00000000..fcfdd39b --- /dev/null +++ b/intercept/messages/reqpayload_test.go @@ -0,0 +1,476 @@ +package messages + +import ( + "testing" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/shared/constant" + "github.com/coder/aibridge/utils" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestNewMessagesRequestPayload(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody []byte + + expectError bool + }{ + { + name: "empty body", + requestBody: []byte(" \n\t "), + expectError: true, + }, + { + name: "invalid json", + requestBody: []byte(`{"model":`), + expectError: true, + }, + { + name: "valid json", + requestBody: []byte(`{"model":"claude-opus-4-5","max_tokens":1024}`), + expectError: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload, err := NewMessagesRequestPayload(testCase.requestBody) + if testCase.expectError { + require.Error(t, err) + require.Nil(t, payload) + return + } + + require.NoError(t, err) + require.Equal(t, MessagesRequestPayload(testCase.requestBody), payload) + }) + } +} + +func TestMessagesRequestPayloadStream(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody string + + expectedStream bool + }{ + { + name: "stream true", + requestBody: `{"stream":true}`, + expectedStream: true, + }, + { + name: "stream false", + requestBody: `{"stream":false}`, + expectedStream: false, + }, + { + name: "stream missing", + requestBody: `{}`, + expectedStream: false, + }, + { + name: "stream wrong type", + requestBody: `{"stream":"true"}`, + expectedStream: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, testCase.requestBody) + require.Equal(t, testCase.expectedStream, payload.Stream()) + }) + } +} + +func TestMessagesRequestPayloadModel(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + requestBody string + expectedModel string + }{ + { + name: "model present", + requestBody: `{"model":"claude-opus-4-5"}`, + expectedModel: "claude-opus-4-5", + }, + { + name: "model missing", + requestBody: `{}`, + expectedModel: "", + }, + { + name: "model wrong type", + requestBody: `{"model":123}`, + expectedModel: "", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, testCase.requestBody) + require.Equal(t, testCase.expectedModel, payload.model()) + }) + } +} + +func TestMessagesRequestPayloadLastUserPrompt(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody string + + expectedPrompt string + + expectedFound bool + + expectError bool + }{ + { + name: "last user message string content", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`, + expectedPrompt: "hello", + expectedFound: true, + expectError: false, + }, + { + name: "last user message typed content returns last text block", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc"}},{"type":"text","text":"first"},{"type":"text","text":"last"}]}]}`, + expectedPrompt: "last", + expectedFound: true, + expectError: false, + }, + { + name: "last message not from user", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"assistant","content":"hello"}]}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "no messages key", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "empty messages array", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[]}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "last user message with empty content array", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[]}]}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "last user message with only non text content", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc"}},{"type":"image","source":{"type":"base64","media_type":"image/jpeg","data":"def"}}]}]}`, + expectedPrompt: "", + expectedFound: false, + expectError: false, + }, + { + name: "multiple messages with last being user", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"first"},{"role":"assistant","content":[{"type":"text","text":"response"}]},{"role":"user","content":"second"}]}`, + expectedPrompt: "second", + expectedFound: true, + expectError: false, + }, + { + name: "messages wrong type returns error", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":{}}`, + expectedPrompt: "", + expectedFound: false, + expectError: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, testCase.requestBody) + prompt, found, err := payload.lastUserPrompt() + if testCase.expectError { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, testCase.expectedFound, found) + require.Equal(t, testCase.expectedPrompt, prompt) + }) + } +} + +func TestMessagesRequestPayloadCorrelatingToolCallID(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody string + + expectedToolUseID *string + }{ + { + name: "no tool result block", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`, + expectedToolUseID: nil, + }, + { + name: "returns last tool result from final message", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"},{"type":"tool_result","tool_use_id":"toolu_second","content":"second"}]}]}`, + expectedToolUseID: utils.PtrTo("toolu_second"), + }, + { + name: "ignores earlier message tool result", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"}]},{"role":"assistant","content":"done"}]}`, + expectedToolUseID: nil, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, testCase.requestBody) + require.Equal(t, testCase.expectedToolUseID, payload.correlatingToolCallID()) + }) + } +} + +func TestMessagesRequestPayloadInjectTools(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`) + + updatedPayload, err := payload.injectTools([]anthropic.ToolUnionParam{ + { + OfTool: &anthropic.ToolParam{ + Name: "injected_tool", + Type: anthropic.ToolTypeCustom, + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: map[string]interface{}{}, + }, + }, + }, + }) + require.NoError(t, err) + + toolItems := gjson.GetBytes(updatedPayload, "tools").Array() + require.Len(t, toolItems, 2) + require.Equal(t, "injected_tool", toolItems[0].Get("name").String()) + require.Equal(t, "existing_tool", toolItems[1].Get("name").String()) + require.Equal(t, "ephemeral", toolItems[1].Get("cache_control.type").String()) +} + +func TestMessagesRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + + requestBody string + + expectedThinkingType string + expectedBudgetTokens int64 + expectError bool + }{ + { + name: "no_thinking_field_is_no_op", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"messages":[]}`, + expectedThinkingType: "", + }, + { + name: "non_adaptive_thinking_type_is_no_op", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"enabled","budget_tokens":5000},"messages":[]}`, + expectedThinkingType: "enabled", + expectedBudgetTokens: 5000, + }, + { + name: "adaptive_with_no_effort_defaults_to_80%", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"adaptive"},"messages":[]}`, + expectedThinkingType: "enabled", + expectedBudgetTokens: 8000, // 10000 * 0.8 (default/high effort) + }, + { + name: "adaptive_with_explicit_effort_uses_correct_percentage", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":10000,"thinking":{"type":"adaptive"},"output_config":{"effort":"low"},"messages":[]}`, + expectedThinkingType: "enabled", + expectedBudgetTokens: 2000, // 10000 * 0.2 + }, + { + name: "adaptive_disables_thinking_when_budget_below_minimum", + requestBody: `{"model":"claude-sonnet-4-5","max_tokens":512,"thinking":{"type":"adaptive"},"messages":[]}`, + expectedThinkingType: "disabled", // 512 * 0.8 = 409, below 1024 minimum + }, + { + name: "adaptive_without_max_tokens_returns_error", + requestBody: `{"model":"claude-sonnet-4-5","thinking":{"type":"adaptive"},"messages":[]}`, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, tc.requestBody) + updatedPayload, err := payload.convertAdaptiveThinkingForBedrock() + if tc.expectError { + require.Error(t, err) + return + } + require.NoError(t, err) + + thinking := gjson.GetBytes(updatedPayload, messagesReqPathThinking) + require.NotEqual(t, tc.expectedThinkingType == "", thinking.Exists(), "thinking should not be set") + require.Equal(t, tc.expectedThinkingType, gjson.GetBytes(updatedPayload, messagesReqPathThinkingType).String()) // non existing field returns zero value + + budgetTokens := gjson.GetBytes(updatedPayload, messagesReqPathThinkingBudgetTokens) + require.NotEqual(t, tc.expectedBudgetTokens == 0, budgetTokens.Exists(), "budget_tokens should not be set") + require.Equal(t, tc.expectedBudgetTokens, budgetTokens.Int()) // non existing field returns zero value + }) + } +} + +func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + requestBody string + expectError string + expectedType string + expectedDisableParallel *bool + }{ + { + name: "defaults to auto when missing", + requestBody: `{"model":"claude-opus-4-5","max_tokens":1024}`, + expectedType: string(constant.ValueOf[constant.Auto]()), + expectedDisableParallel: utils.PtrTo(true), + }, + { + name: "auto gets disabled", + requestBody: `{"tool_choice":{"type":"auto"}}`, + expectedType: string(constant.ValueOf[constant.Auto]()), + expectedDisableParallel: utils.PtrTo(true), + }, + { + name: "any gets disabled", + requestBody: `{"tool_choice":{"type":"any"}}`, + expectedType: string(constant.ValueOf[constant.Any]()), + expectedDisableParallel: utils.PtrTo(true), + }, + { + name: "tool gets disabled", + requestBody: `{"tool_choice":{"type":"tool","name":"abc"}}`, + expectedType: string(constant.ValueOf[constant.Tool]()), + expectedDisableParallel: utils.PtrTo(true), + }, + { + name: "none remains unchanged", + requestBody: `{"tool_choice":{"type":"none"}}`, + expectedType: string(constant.ValueOf[constant.None]()), + expectedDisableParallel: nil, + }, + { + name: "empty type defaults to auto", + requestBody: `{"tool_choice":{}}`, + expectedType: string(constant.ValueOf[constant.Auto]()), + expectedDisableParallel: utils.PtrTo(true), + }, + { + name: "non-object tool_choice returns error", + requestBody: `{"tool_choice":"auto"}`, + expectError: "unsupported tool_choice type", + }, + { + name: "non-string tool_choice type returns error", + requestBody: `{"tool_choice":{"type":123}}`, + expectError: "unsupported tool_choice.type type", + }, + { + name: "unsupported tool_choice type returns error", + requestBody: `{"tool_choice":{"type":"unknown"}}`, + expectError: "unsupported tool_choice.type value", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, testCase.requestBody) + updatedPayload, err := payload.disableParallelToolCalls() + if testCase.expectError != "" { + require.ErrorContains(t, err, testCase.expectError) + return + } + require.NoError(t, err) + + toolChoice := gjson.GetBytes(updatedPayload, "tool_choice") + require.Equal(t, testCase.expectedType, toolChoice.Get("type").String()) + + disableParallelResult := toolChoice.Get("disable_parallel_tool_use") + if testCase.expectedDisableParallel == nil { + require.False(t, disableParallelResult.Exists()) + return + } + + require.True(t, disableParallelResult.Exists()) + require.Equal(t, *testCase.expectedDisableParallel, disableParallelResult.Bool()) + }) + } +} + +func TestMessagesRequestPayloadAppendedMessages(t *testing.T) { + t.Parallel() + + payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`) + + updatedPayload, err := payload.appendedMessages([]anthropic.MessageParam{ + { + Role: anthropic.MessageParamRoleAssistant, + Content: []anthropic.ContentBlockParamUnion{ + anthropic.NewTextBlock("assistant response"), + }, + }, + anthropic.NewUserMessage(anthropic.NewToolResultBlock("toolu_123", "tool output", false)), + }) + require.NoError(t, err) + + messageItems := gjson.GetBytes(updatedPayload, "messages").Array() + require.Len(t, messageItems, 3) + require.Equal(t, "hello", messageItems[0].Get("content").String()) + require.Equal(t, "assistant", messageItems[1].Get("role").String()) + require.Equal(t, "assistant response", messageItems[1].Get("content.0.text").String()) + require.Equal(t, "tool_result", messageItems[2].Get("content.0.type").String()) + require.Equal(t, "toolu_123", messageItems[2].Get("content.0.tool_use_id").String()) +} diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 4fc19fdf..5fa829ca 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -34,11 +34,10 @@ type StreamingInterception struct { interceptionBase } -func NewStreamingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *StreamingInterception { +func NewStreamingInterceptor(id uuid.UUID, reqPayload MessagesRequestPayload, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ id: id, - req: req, - payload: payload, + reqPayload: reqPayload, cfg: cfg, bedrockCfg: bedrockCfg, tracer: tracer, @@ -77,8 +76,8 @@ func (s *StreamingInterception) TraceAttributes(r *http.Request) []attribute.Key // results relayed to the SERVER. The response from the server will be handled synchronously, and this loop // can continue until all injected tool invocations are completed and the response is relayed to the client. func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { - if i.req == nil { - return fmt.Errorf("developer error: req is nil") + if len(i.reqPayload) == 0 { + return fmt.Errorf("developer error: request payload is empty") } ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) @@ -89,16 +88,17 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. - logger := i.logger.With(slog.F("model", i.req.Model)) + logger := i.logger.With(slog.F("model", i.Model())) var ( - prompt *string - err error + prompt string + promptFound bool + err error ) // Claude Code uses a "small/fast model" for certain tasks. if !i.isSmallFastModel() { - prompt, err = i.req.lastUserPrompt() + prompt, promptFound, err = i.reqPayload.lastUserPrompt() if err != nil { logger.Warn(ctx, "failed to determine last user prompt", slog.Error(err)) } @@ -129,8 +129,6 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. }() - messages := i.req.MessageNewParams - // Accumulate usage across the entire streaming interaction (including tool reinvocations). var cumulativeUsage anthropic.Usage @@ -146,7 +144,7 @@ newStream: break } - stream := i.newStream(streamCtx, svc, messages) + stream := i.newStream(streamCtx, svc) var message anthropic.Message var lastToolName string @@ -254,7 +252,8 @@ newStream: case string(constant.ValueOf[constant.MessageStop]()): if len(pendingToolCalls) > 0 { // Append the whole message from this stream as context since we'll be sending a new request with the tool results. - messages.Messages = append(messages.Messages, message.ToParam()) + var loopMessages []anthropic.MessageParam + loopMessages = append(loopMessages, message.ToParam()) for name, id := range pendingToolCalls { if i.mcpProxy == nil { @@ -307,7 +306,7 @@ newStream: if err != nil { // Always provide a tool_result even if the tool call failed - messages.Messages = append(messages.Messages, + loopMessages = append(loopMessages, anthropic.NewUserMessage(anthropic.NewToolResultBlock(id, fmt.Sprintf("Error calling tool: %v", err), true)), ) continue @@ -385,16 +384,18 @@ newStream: } if len(toolResult.OfToolResult.Content) > 0 { - messages.Messages = append(messages.Messages, anthropic.NewUserMessage(toolResult)) + loopMessages = append(loopMessages, anthropic.NewUserMessage(toolResult)) } } // Sync the raw payload with updated messages so that withBody() // sends the updated payload on the next iteration. - if syncErr := i.syncPayloadMessages(messages.Messages); syncErr != nil { + updatedPayload, syncErr := i.reqPayload.appendedMessages(loopMessages) + if syncErr != nil { lastErr = fmt.Errorf("sync payload for agentic loop: %w", syncErr) break } + i.reqPayload = updatedPayload // Causes a new stream to be run with updated messages. isFirst = false @@ -439,13 +440,14 @@ newStream: } } - if prompt != nil { + if promptFound { _ = i.recorder.RecordPromptUsage(ctx, &recorder.PromptUsageRecord{ InterceptionID: i.ID().String(), MsgID: message.ID, - Prompt: *prompt, + Prompt: prompt, }) - prompt = nil + prompt = "" + promptFound = false } if events.IsStreaming() { @@ -553,10 +555,10 @@ func (s *StreamingInterception) encodeForStream(payload []byte, typ string) []by return buf.Bytes() } -// newStream traces svc.NewStreaming(streamCtx, messages) -func (s *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService, messages anthropic.MessageNewParams) *ssestream.Stream[anthropic.MessageStreamEventUnion] { +// newStream traces svc.NewStreaming() call. +func (s *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] { _, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer span.End() - return svc.NewStreaming(ctx, messages, s.withBody()) + return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, s.withBody()) } diff --git a/intercept/responses/base.go b/intercept/responses/base.go index dcd72a0d..cad980f1 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -54,8 +54,10 @@ func (i *responsesInterceptionBase) newResponsesService() responses.ResponseServ // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. - for key, value := range i.cfg.ExtraHeaders { - opts = append(opts, option.WithHeader(key, value)) + for key, values := range i.cfg.ExtraHeaders { + for _, v := range values { + opts = append(opts, option.WithHeaderAdd(key, v)) + } } // Add API dump middleware if configured diff --git a/provider/anthropic.go b/provider/anthropic.go index 3625a31f..ac5e54e7 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -1,8 +1,6 @@ package provider import ( - "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -102,8 +100,9 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr if err != nil { return nil, fmt.Errorf("read body: %w", err) } - var req messages.MessageNewParamsWrapper - if err := json.NewDecoder(bytes.NewReader(payload)).Decode(&req); err != nil { + + reqPayload, err := messages.NewMessagesRequestPayload(payload) + if err != nil { return nil, fmt.Errorf("unmarshal request body: %w", err) } @@ -111,10 +110,10 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr cfg.ExtraHeaders = extractAnthropicHeaders(r) var interceptor intercept.Interceptor - if req.Stream { - interceptor = messages.NewStreamingInterceptor(id, &req, payload, cfg, p.bedrockCfg, tracer) + if reqPayload.Stream() { + interceptor = messages.NewStreamingInterceptor(id, reqPayload, cfg, p.bedrockCfg, tracer) } else { - interceptor = messages.NewBlockingInterceptor(id, &req, payload, cfg, p.bedrockCfg, tracer) + interceptor = messages.NewBlockingInterceptor(id, reqPayload, cfg, p.bedrockCfg, tracer) } span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil @@ -147,11 +146,11 @@ func (p *Anthropic) CircuitBreakerConfig() *config.CircuitBreaker { // extractAnthropicHeaders extracts headers required by the Anthropic API from // the incoming request. // TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192 -func extractAnthropicHeaders(r *http.Request) map[string]string { - headers := make(map[string]string, len(anthropicForwardHeaders)) +func extractAnthropicHeaders(r *http.Request) http.Header { + headers := make(http.Header, len(anthropicForwardHeaders)) for _, h := range anthropicForwardHeaders { - if v := r.Header.Get(h); v != "" { - headers[h] = v + if values := r.Header.Values(h); len(values) > 0 { + headers[h] = values } } return headers diff --git a/provider/anthropic_test.go b/provider/anthropic_test.go index 924c0f98..5087fccc 100644 --- a/provider/anthropic_test.go +++ b/provider/anthropic_test.go @@ -127,27 +127,27 @@ func TestExtractAnthropicHeaders(t *testing.T) { tests := []struct { name string headers map[string]string - expected map[string]string + expected http.Header }{ { name: "no headers", headers: map[string]string{}, - expected: map[string]string{}, + expected: http.Header{}, }, { name: "single beta", headers: map[string]string{"Anthropic-Beta": "claude-code-20250219"}, - expected: map[string]string{"Anthropic-Beta": "claude-code-20250219"}, + expected: http.Header{"Anthropic-Beta": {"claude-code-20250219"}}, }, { name: "multiple betas in single header", headers: map[string]string{"Anthropic-Beta": "claude-code-20250219,adaptive-thinking-2026-01-28,context-management-2025-06-27,prompt-caching-scope-2026-01-05,effort-2025-11-24"}, - expected: map[string]string{"Anthropic-Beta": "claude-code-20250219,adaptive-thinking-2026-01-28,context-management-2025-06-27,prompt-caching-scope-2026-01-05,effort-2025-11-24"}, + expected: http.Header{"Anthropic-Beta": {"claude-code-20250219,adaptive-thinking-2026-01-28,context-management-2025-06-27,prompt-caching-scope-2026-01-05,effort-2025-11-24"}}, }, { name: "ignores other headers", headers: map[string]string{"Anthropic-Beta": "claude-code-20250219,context-management-2025-06-27", "X-Api-Key": "secret"}, - expected: map[string]string{"Anthropic-Beta": "claude-code-20250219,context-management-2025-06-27"}, + expected: http.Header{"Anthropic-Beta": {"claude-code-20250219,context-management-2025-06-27"}}, }, } diff --git a/provider/copilot.go b/provider/copilot.go index 34fab491..112a9499 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -187,11 +187,11 @@ func extractBearerToken(auth string) string { // extractCopilotHeaders extracts headers required by the Copilot API from the // incoming request. Copilot requires certain client headers to be forwarded. -func extractCopilotHeaders(r *http.Request) map[string]string { - headers := make(map[string]string, len(copilotForwardHeaders)) +func extractCopilotHeaders(r *http.Request) http.Header { + headers := make(http.Header, len(copilotForwardHeaders)) for _, h := range copilotForwardHeaders { - if v := r.Header.Get(h); v != "" { - headers[h] = v + if values := r.Header.Values(h); len(values) > 0 { + headers[h] = values } } return headers diff --git a/provider/copilot_test.go b/provider/copilot_test.go index 697b6990..0e26b56f 100644 --- a/provider/copilot_test.go +++ b/provider/copilot_test.go @@ -354,27 +354,27 @@ func TestExtractCopilotHeaders(t *testing.T) { tests := []struct { name string headers map[string]string - expected map[string]string + expected http.Header }{ { name: "all headers present", headers: map[string]string{"Editor-Version": "vscode/1.85.0", "Copilot-Integration-Id": "some-id"}, - expected: map[string]string{"Editor-Version": "vscode/1.85.0", "Copilot-Integration-Id": "some-id"}, + expected: http.Header{"Editor-Version": {"vscode/1.85.0"}, "Copilot-Integration-Id": {"some-id"}}, }, { name: "some headers present", headers: map[string]string{"Editor-Version": "vscode/1.85.0"}, - expected: map[string]string{"Editor-Version": "vscode/1.85.0"}, + expected: http.Header{"Editor-Version": {"vscode/1.85.0"}}, }, { name: "no headers", headers: map[string]string{}, - expected: map[string]string{}, + expected: http.Header{}, }, { name: "ignores other headers", headers: map[string]string{"Editor-Version": "vscode/1.85.0", "Authorization": "Bearer token"}, - expected: map[string]string{"Editor-Version": "vscode/1.85.0"}, + expected: http.Header{"Editor-Version": {"vscode/1.85.0"}}, }, }