diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index 7a755e06..8e0e3be2 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -9,6 +9,7 @@ import ( "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" @@ -29,6 +30,10 @@ type interceptionBase struct { req *ChatCompletionNewParamsWrapper cfg config.OpenAI + // clientHeaders holds the original client request headers to forward + // to upstream providers. + clientHeaders http.Header + logger slog.Logger tracer trace.Tracer @@ -37,7 +42,18 @@ type interceptionBase struct { } func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService { - opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL)} + var opts []option.RequestOption + + // Forward sanitized client headers to the upstream provider. + // Client headers are added first so that SDK auth appended + // below takes priority on any conflict. + for k, vals := range intercept.SanitizeClientHeaders(i.clientHeaders) { + for _, v := range vals { + opts = append(opts, option.WithHeader(k, v)) + } + } + + opts = append(opts, option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL)) // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index 9a84d143..8bf11055 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -28,12 +28,19 @@ type BlockingInterception struct { interceptionBase } -func NewBlockingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg config.OpenAI, tracer trace.Tracer) *BlockingInterception { +func NewBlockingInterceptor( + id uuid.UUID, + req *ChatCompletionNewParamsWrapper, + cfg config.OpenAI, + clientHeaders http.Header, + tracer trace.Tracer, +) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - cfg: cfg, - tracer: tracer, + id: id, + req: req, + cfg: cfg, + clientHeaders: clientHeaders, + tracer: tracer, }} } diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index ff3b78c6..b220766f 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -33,12 +33,19 @@ type StreamingInterception struct { interceptionBase } -func NewStreamingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg config.OpenAI, tracer trace.Tracer) *StreamingInterception { +func NewStreamingInterceptor( + id uuid.UUID, + req *ChatCompletionNewParamsWrapper, + cfg config.OpenAI, + clientHeaders http.Header, + tracer trace.Tracer, +) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - cfg: cfg, - tracer: tracer, + id: id, + req: req, + cfg: cfg, + clientHeaders: clientHeaders, + tracer: tracer, }} } diff --git a/intercept/chatcompletions/streaming_test.go b/intercept/chatcompletions/streaming_test.go index 7d8d4d57..d95f8d7c 100644 --- a/intercept/chatcompletions/streaming_test.go +++ b/intercept/chatcompletions/streaming_test.go @@ -82,7 +82,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) { } tracer := otel.Tracer("test") - interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, tracer) + interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, nil, tracer) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) interceptor.Setup(logger, &testutil.MockRecorder{}, nil) diff --git a/intercept/client_headers.go b/intercept/client_headers.go new file mode 100644 index 00000000..c479fccb --- /dev/null +++ b/intercept/client_headers.go @@ -0,0 +1,58 @@ +package intercept + +import "net/http" + +// hopByHopHeaders are connection-level headers specific to the connection +// between client and AI Bridge, not meant for the upstream. +// See https://www.rfc-editor.org/rfc/rfc2616#section-13.5.1. +var hopByHopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", + "Trailer", + "Transfer-Encoding", + "Upgrade", +} + +// nonForwardedHeaders are headers that should not be forwarded to the upstream provider: +// - Connection-specific headers managed by AI Bridge or the HTTP transport. +// - Auth headers that are re-injected by the SDK from the provider configuration. +// - User-Agent is set by the SDK to identify AI Bridge as the upstream client. +var nonForwardedHeaders = []string{ + "Host", + "Accept-Encoding", + "Content-Length", + "Content-Type", + "Authorization", + "X-Api-Key", + "User-Agent", +} + +// SanitizeClientHeaders clones headers and returns a sanitized copy suitable +// for forwarding to an upstream provider. +// +// It removes: +// - Hop-by-hop headers +// - Non-forwarded headers (connection-specific, transport-managed, auth, and SDK-managed headers) +// +// Callers should apply these headers first, so that any subsequently +// added headers take priority in case of conflict. +func SanitizeClientHeaders(headers http.Header) http.Header { + if headers == nil { + return http.Header{} + } + + outHeaders := headers.Clone() + + for _, h := range hopByHopHeaders { + outHeaders.Del(h) + } + + for _, h := range nonForwardedHeaders { + outHeaders.Del(h) + } + + return outHeaders +} diff --git a/intercept/client_headers_test.go b/intercept/client_headers_test.go new file mode 100644 index 00000000..51625a9f --- /dev/null +++ b/intercept/client_headers_test.go @@ -0,0 +1,109 @@ +package intercept + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSanitizeClientHeaders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input http.Header + expectAbsent []string + expectPresent map[string][]string + expectEmpty bool + }{ + { + name: "no headers returns empty header", + input: nil, + expectEmpty: true, + }, + { + name: "hop-by-hop headers are removed", + input: http.Header{ + "Connection": []string{"keep-alive"}, + "Keep-Alive": []string{"timeout=5"}, + "Transfer-Encoding": []string{"chunked"}, + "Upgrade": []string{"websocket"}, + }, + expectAbsent: []string{"Connection", "Keep-Alive", "Transfer-Encoding", "Upgrade"}, + }, + { + name: "bridge headers are removed", + input: http.Header{ + "Host": []string{"example.com"}, + "Accept-Encoding": []string{"gzip"}, + "Content-Length": []string{"42"}, + "Content-Type": []string{"application/json"}, + "User-Agent": []string{"client/1.0"}, + }, + expectAbsent: []string{"Host", "Accept-Encoding", "Content-Length", "Content-Type", "User-Agent"}, + }, + { + name: "auth headers are removed", + input: http.Header{ + "Authorization": []string{"Bearer some-token"}, + "X-Api-Key": []string{"some-key"}, + }, + expectAbsent: []string{"Authorization", "X-Api-Key"}, + }, + { + name: "custom headers are preserved", + input: http.Header{ + "X-Custom-Header": []string{"custom-value"}, + "X-Request-Id": []string{"req-123"}, + }, + expectPresent: map[string][]string{ + "X-Custom-Header": {"custom-value"}, + "X-Request-Id": {"req-123"}, + }, + }, + { + name: "multi-value headers are preserved", + input: http.Header{ + "X-Custom-Header": []string{"value-1", "value-2"}, + }, + expectPresent: map[string][]string{ + "X-Custom-Header": {"value-1", "value-2"}, + }, + }, + { + name: "input is not mutated", + input: http.Header{ + "Connection": []string{"keep-alive"}, + "X-Custom-Header": []string{"custom-value"}, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Capture original state to verify no mutation. + originalCopy := tc.input.Clone() + + result := SanitizeClientHeaders(tc.input) + + if tc.expectEmpty { + require.Empty(t, result) + return + } + + for _, h := range tc.expectAbsent { + require.Empty(t, result.Get(h), "expected header %q to be absent", h) + } + + for h, vals := range tc.expectPresent { + require.Equal(t, vals, result[h], "expected header %q to be present", h) + } + + // Verify input was not mutated. + require.Equal(t, originalCopy, tc.input) + }) + } +} diff --git a/intercept/messages/base.go b/intercept/messages/base.go index c61a3a7c..57e45a44 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -18,6 +18,7 @@ import ( "github.com/aws/aws-sdk-go-v2/credentials" aibconfig "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" @@ -40,6 +41,10 @@ type interceptionBase struct { cfg aibconfig.Anthropic bedrockCfg *aibconfig.AWSBedrock + // clientHeaders holds the original client request headers to forward + // to upstream providers. + clientHeaders http.Header + tracer trace.Tracer logger slog.Logger @@ -178,6 +183,15 @@ func (i *interceptionBase) isSmallFastModel() bool { } func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) { + // Forward sanitized client headers to the upstream provider. + // Client headers are added first so that SDK auth appended + // below takes priority on any conflict. + for k, vals := range intercept.SanitizeClientHeaders(i.clientHeaders) { + for _, v := range vals { + opts = append(opts, option.WithHeader(k, v)) + } + } + opts = append(opts, option.WithAPIKey(i.cfg.Key)) opts = append(opts, option.WithBaseURL(i.cfg.BaseURL)) diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index e22b97f8..e755ee05 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -28,14 +28,23 @@ type BlockingInterception struct { interceptionBase } -func NewBlockingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *BlockingInterception { +func NewBlockingInterceptor( + id uuid.UUID, + req *MessageNewParamsWrapper, + payload []byte, + cfg config.Anthropic, + bedrockCfg *config.AWSBedrock, + clientHeaders http.Header, + tracer trace.Tracer, +) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - payload: payload, - cfg: cfg, - bedrockCfg: bedrockCfg, - tracer: tracer, + id: id, + req: req, + payload: payload, + cfg: cfg, + bedrockCfg: bedrockCfg, + clientHeaders: clientHeaders, + tracer: tracer, }} } diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 4e87fd85..46481110 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -34,14 +34,23 @@ type StreamingInterception struct { interceptionBase } -func NewStreamingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *StreamingInterception { +func NewStreamingInterceptor( + id uuid.UUID, + req *MessageNewParamsWrapper, + payload []byte, + cfg config.Anthropic, + bedrockCfg *config.AWSBedrock, + clientHeaders http.Header, + tracer trace.Tracer, +) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - payload: payload, - cfg: cfg, - bedrockCfg: bedrockCfg, - tracer: tracer, + id: id, + req: req, + payload: payload, + cfg: cfg, + bedrockCfg: bedrockCfg, + clientHeaders: clientHeaders, + tracer: tracer, }} } diff --git a/intercept/responses/base.go b/intercept/responses/base.go index 8b7c3ded..e65480aa 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -17,6 +17,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/metrics" @@ -41,16 +42,30 @@ type responsesInterceptionBase struct { req *ResponsesNewParamsWrapper reqPayload []byte cfg config.OpenAI - model string - recorder recorder.Recorder - mcpProxy mcp.ServerProxier - logger slog.Logger - metrics metrics.Metrics - tracer trace.Tracer + // clientHeaders holds the original client request headers to forward + // to upstream providers. + clientHeaders http.Header + model string + recorder recorder.Recorder + mcpProxy mcp.ServerProxier + logger slog.Logger + metrics metrics.Metrics + tracer trace.Tracer } func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService { - opts := []option.RequestOption{option.WithBaseURL(i.cfg.BaseURL), option.WithAPIKey(i.cfg.Key)} + var opts []option.RequestOption + + // Forward sanitized client headers to the upstream provider. + // Client headers are added first so that SDK auth appended + // below takes priority on any conflict. + for k, vals := range intercept.SanitizeClientHeaders(i.clientHeaders) { + for _, v := range vals { + opts = append(opts, option.WithHeader(k, v)) + } + } + + opts = append(opts, option.WithBaseURL(i.cfg.BaseURL), option.WithAPIKey(i.cfg.Key)) // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 3e94a6cc..88b3cc19 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -25,15 +25,24 @@ type BlockingResponsesInterceptor struct { responsesInterceptionBase } -func NewBlockingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, model string, tracer trace.Tracer) *BlockingResponsesInterceptor { +func NewBlockingInterceptor( + id uuid.UUID, + req *ResponsesNewParamsWrapper, + reqPayload []byte, + cfg config.OpenAI, + model string, + clientHeaders http.Header, + tracer trace.Tracer, +) *BlockingResponsesInterceptor { return &BlockingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ - id: id, - req: req, - reqPayload: reqPayload, - cfg: cfg, - model: model, - tracer: tracer, + id: id, + req: req, + reqPayload: reqPayload, + cfg: cfg, + model: model, + clientHeaders: clientHeaders, + tracer: tracer, }, } } diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 6925d86f..0db17103 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -32,15 +32,24 @@ type StreamingResponsesInterceptor struct { responsesInterceptionBase } -func NewStreamingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, model string, tracer trace.Tracer) *StreamingResponsesInterceptor { +func NewStreamingInterceptor( + id uuid.UUID, + req *ResponsesNewParamsWrapper, + reqPayload []byte, + cfg config.OpenAI, + model string, + clientHeaders http.Header, + tracer trace.Tracer, +) *StreamingResponsesInterceptor { return &StreamingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ - id: id, - req: req, - reqPayload: reqPayload, - cfg: cfg, - model: model, - tracer: tracer, + id: id, + req: req, + reqPayload: reqPayload, + cfg: cfg, + model: model, + clientHeaders: clientHeaders, + tracer: tracer, }, } } diff --git a/provider/anthropic.go b/provider/anthropic.go index e682fdb7..1c3932b8 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -112,9 +112,9 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr var interceptor intercept.Interceptor if req.Stream { - interceptor = messages.NewStreamingInterceptor(id, &req, payload, cfg, p.bedrockCfg, tracer) + interceptor = messages.NewStreamingInterceptor(id, &req, payload, p.cfg, p.bedrockCfg, r.Header, tracer) } else { - interceptor = messages.NewBlockingInterceptor(id, &req, payload, cfg, p.bedrockCfg, tracer) + interceptor = messages.NewBlockingInterceptor(id, &req, payload, p.cfg, p.bedrockCfg, r.Header, tracer) } span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil diff --git a/provider/anthropic_test.go b/provider/anthropic_test.go index 924c0f98..a4d438b1 100644 --- a/provider/anthropic_test.go +++ b/provider/anthropic_test.go @@ -4,6 +4,7 @@ import ( "bytes" "net/http" "net/http/httptest" + "strings" "testing" "cdr.dev/slog/v3" @@ -61,52 +62,6 @@ func TestAnthropic_CreateInterceptor(t *testing.T) { assert.Contains(t, err.Error(), "unmarshal request body") }) - t.Run("Messages_ForwardsAnthropicBetaHeaderToUpstream", func(t *testing.T) { - t.Parallel() - - var receivedHeaders http.Header - - // Mock upstream that captures headers. - mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedHeaders = r.Header.Clone() - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"id":"msg-123","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-opus-4-5","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`)) - })) - t.Cleanup(mockUpstream.Close) - - provider := NewAnthropic(config.Anthropic{ - BaseURL: mockUpstream.URL, - Key: "test-key", - }, nil) - - // Use a realistic multi-beta value as sent by Claude Code clients. - betaHeader := "claude-code-20250219,adaptive-thinking-2026-01-28,context-management-2025-06-27,prompt-caching-scope-2026-01-05,effort-2025-11-24" - - body := `{"model": "claude-opus-4-5", "max_tokens": 1024, "messages": [{"role": "user", "content": "hello"}], "stream": false}` - req := httptest.NewRequest(http.MethodPost, routeMessages, bytes.NewBufferString(body)) - req.Header.Set("Anthropic-Beta", betaHeader) - req.Header.Set("X-Custom-Header", "should-not-forward") - w := httptest.NewRecorder() - - interceptor, err := provider.CreateInterceptor(w, req, testTracer) - require.NoError(t, err) - require.NotNil(t, interceptor) - - logger := slog.Make() - interceptor.Setup(logger, &testutil.MockRecorder{}, nil) - - processReq := httptest.NewRequest(http.MethodPost, routeMessages, nil) - err = interceptor.ProcessRequest(w, processReq) - require.NoError(t, err) - - // Verify the full Anthropic-Beta header (all betas) was forwarded unchanged. - assert.Equal(t, betaHeader, receivedHeaders.Get("Anthropic-Beta")) - - // Verify non-Anthropic headers are not forwarded. - assert.Empty(t, receivedHeaders.Get("X-Custom-Header"), "non-Anthropic headers should not be forwarded") - }) - t.Run("UnknownRoute", func(t *testing.T) { t.Parallel() @@ -188,3 +143,53 @@ func Test_anthropicIsFailure(t *testing.T) { assert.Equal(t, tt.isFailure, anthropicIsFailure(tt.statusCode), "status code %d", tt.statusCode) } } + +// TestAnthropic_ForwardsHeadersToUpstream verifies that custom client headers are forwarded +// to the upstream, while filtered headers (auth, hop-by-hop) are stripped and the +// configured API key is always used. +func TestAnthropic_ForwardsHeadersToUpstream(t *testing.T) { + t.Parallel() + + var receivedHeaders http.Header + + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"msg-123","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-3-haiku-20240307","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`)) + })) + t.Cleanup(mockUpstream.Close) + + const configuredKey = "configured-key" + p := NewAnthropic(config.Anthropic{ + BaseURL: mockUpstream.URL, + Key: configuredKey, + }, nil) + + body := `{"model":"claude-3-haiku-20240307","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"stream":false}` + req := httptest.NewRequest(http.MethodPost, routeMessages, bytes.NewBufferString(body)) + req.Header.Set("X-Custom-Header", "custom-value") // should be forwarded + req.Header.Set("Authorization", "Bearer client-fake-key") // should be stripped (re-injected by SDK) + req.Header.Set("X-Api-Key", "client-fake-key") // should be stripped (re-injected by SDK) + req.Header.Set("Upgrade", "websocket") // should be stripped (hop-by-hop) + w := httptest.NewRecorder() + + interceptor, err := p.CreateInterceptor(w, req, testTracer) + require.NoError(t, err) + require.NotNil(t, interceptor) + + interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) + + err = interceptor.ProcessRequest(w, httptest.NewRequest(http.MethodPost, routeMessages, nil)) + require.NoError(t, err) + + // Custom headers must reach the upstream. + assert.Equal(t, "custom-value", receivedHeaders.Get("X-Custom-Header")) + // Hop-by-hop headers must not reach the upstream. + assert.Empty(t, receivedHeaders.Get("Upgrade")) + // Configured key must be used, not the client-provided fake. + assert.Equal(t, configuredKey, receivedHeaders.Get("X-Api-Key")) + // User-Agent must be set by the SDK, not forwarded from the client. + assert.True(t, strings.HasPrefix(receivedHeaders.Get("User-Agent"), "Anthropic/Go"), + "upstream User-Agent should be set by the SDK") +} diff --git a/provider/copilot.go b/provider/copilot.go index 9b128cab..826d41eb 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -148,9 +148,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, r.Header, tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, r.Header, tracer) } case routeCopilotResponses: @@ -164,9 +164,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } if req.Stream { - interceptor = responses.NewStreamingInterceptor(id, &req, payload, cfg, req.Model, tracer) + interceptor = responses.NewStreamingInterceptor(id, &req, payload, cfg, req.Model, r.Header, tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, &req, payload, cfg, req.Model, tracer) + interceptor = responses.NewBlockingInterceptor(id, &req, payload, cfg, req.Model, r.Header, tracer) } default: diff --git a/provider/copilot_test.go b/provider/copilot_test.go index 697b6990..0b85f074 100644 --- a/provider/copilot_test.go +++ b/provider/copilot_test.go @@ -4,6 +4,7 @@ import ( "bytes" "net/http" "net/http/httptest" + "strings" "testing" "cdr.dev/slog/v3" @@ -24,29 +25,29 @@ func TestCopilot_InjectAuthHeader(t *testing.T) { // so InjectAuthHeader should not modify any headers. provider := NewCopilot(config.Copilot{}) - t.Run("ExistingHeaders_Unchanged", func(t *testing.T) { + t.Run("EmptyHeaders_NoneAdded", func(t *testing.T) { t.Parallel() headers := http.Header{} - headers.Set("Authorization", "Bearer user-token") - headers.Set("X-Custom-Header", "custom-value") provider.InjectAuthHeader(&headers) - assert.Equal(t, "Bearer user-token", headers.Get("Authorization"), - "Authorization header should remain unchanged") - assert.Equal(t, "custom-value", headers.Get("X-Custom-Header"), - "other headers should remain unchanged") + assert.Empty(t, headers, "no headers should be added") }) - t.Run("EmptyHeaders_NoneAdded", func(t *testing.T) { + t.Run("ExistingHeaders_Unchanged", func(t *testing.T) { t.Parallel() headers := http.Header{} + headers.Set("Authorization", "Bearer user-token") + headers.Set("X-Custom-Header", "custom-value") provider.InjectAuthHeader(&headers) - assert.Empty(t, headers, "no headers should be added") + assert.Equal(t, "Bearer user-token", headers.Get("Authorization"), + "Authorization header should remain unchanged") + assert.Equal(t, "custom-value", headers.Get("X-Custom-Header"), + "other headers should remain unchanged") }) } @@ -129,53 +130,6 @@ func TestCopilot_CreateInterceptor(t *testing.T) { assert.Contains(t, err.Error(), "unmarshal chat completions request body") }) - t.Run("ChatCompletions_ForwardsHeadersToUpstream", func(t *testing.T) { - t.Parallel() - - var receivedHeaders http.Header - - // Mock upstream that captures headers - mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedHeaders = r.Header.Clone() - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`)) - })) - t.Cleanup(mockUpstream.Close) - - // Create provider with mock upstream URL - provider := NewCopilot(config.Copilot{ - BaseURL: mockUpstream.URL, - }) - - body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}], "stream": false}` - req := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, bytes.NewBufferString(body)) - req.Header.Set("Authorization", "Bearer test-token") - req.Header.Set("Editor-Version", "vscode/1.85.0") - req.Header.Set("Copilot-Integration-Id", "test-integration") - req.Header.Set("X-Custom-Header", "should-not-forward") - w := httptest.NewRecorder() - - interceptor, err := provider.CreateInterceptor(w, req, testTracer) - require.NoError(t, err) - require.NotNil(t, interceptor) - - // Setup and process request - logger := slog.Make() - interceptor.Setup(logger, &testutil.MockRecorder{}, nil) - - processReq := httptest.NewRequest(http.MethodPost, routeCopilotChatCompletions, nil) - err = interceptor.ProcessRequest(w, processReq) - require.NoError(t, err) - - // Verify headers were forwarded - assert.Equal(t, "vscode/1.85.0", receivedHeaders.Get("Editor-Version")) - assert.Equal(t, "test-integration", receivedHeaders.Get("Copilot-Integration-Id")) - - // Verify non-Copilot headers are not forwarded - assert.Empty(t, receivedHeaders.Get("X-Custom-Header"), "non-Copilot headers should not be forwarded") - }) - t.Run("Responses_NonStreamingRequest_BlockingInterceptor", func(t *testing.T) { t.Parallel() @@ -221,53 +175,6 @@ func TestCopilot_CreateInterceptor(t *testing.T) { assert.Contains(t, err.Error(), "unmarshal responses request body") }) - t.Run("Responses_ForwardsHeadersToUpstream", func(t *testing.T) { - t.Parallel() - - var receivedHeaders http.Header - - // Mock upstream that captures headers - mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedHeaders = r.Header.Clone() - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"id":"resp-123","object":"responses.response","created":1677652288,"model":"gpt-5-mini","output":[],"usage":{"input_tokens":5,"output_tokens":10,"total_tokens":15}}`)) - })) - t.Cleanup(mockUpstream.Close) - - // Create provider with mock upstream URL - provider := NewCopilot(config.Copilot{ - BaseURL: mockUpstream.URL, - }) - - body := `{"model": "gpt-5-mini", "input": "hello", "stream": false}` - req := httptest.NewRequest(http.MethodPost, routeCopilotResponses, bytes.NewBufferString(body)) - req.Header.Set("Authorization", "Bearer test-token") - req.Header.Set("Editor-Version", "vscode/1.85.0") - req.Header.Set("Copilot-Integration-Id", "test-integration") - req.Header.Set("X-Custom-Header", "should-not-forward") - w := httptest.NewRecorder() - - interceptor, err := provider.CreateInterceptor(w, req, testTracer) - require.NoError(t, err) - require.NotNil(t, interceptor) - - // Setup and process request - logger := slog.Make() - interceptor.Setup(logger, &testutil.MockRecorder{}, nil) - - processReq := httptest.NewRequest(http.MethodPost, routeCopilotResponses, nil) - err = interceptor.ProcessRequest(w, processReq) - require.NoError(t, err) - - // Verify headers were forwarded - assert.Equal(t, "vscode/1.85.0", receivedHeaders.Get("Editor-Version")) - assert.Equal(t, "test-integration", receivedHeaders.Get("Copilot-Integration-Id")) - - // Verify non-Copilot headers are not forwarded - assert.Empty(t, receivedHeaders.Get("X-Custom-Header"), "non-Copilot headers should not be forwarded") - }) - t.Run("UnknownRoute", func(t *testing.T) { t.Parallel() @@ -283,6 +190,84 @@ func TestCopilot_CreateInterceptor(t *testing.T) { }) } +// TestCopilot_ForwardsHeadersToUpstream verifies that custom client headers are forwarded +// to the upstream, while filtered headers (hop-by-hop) are stripped, Copilot-specific +// headers are always forwarded, and the per-user token is used for auth. +func TestCopilot_ForwardsHeadersToUpstream(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + route string + body string + mockResponse string + }{ + { + name: "chat-completions", + route: routeCopilotChatCompletions, + body: `{"model":"gpt-4","messages":[{"role":"user","content":"hello"}],"stream":false}`, + mockResponse: `{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, + }, + { + name: "responses", + route: routeCopilotResponses, + body: `{"model":"gpt-5-mini","input":"hello","stream":false}`, + mockResponse: `{"id":"resp-123","object":"responses.response","created":1677652288,"model":"gpt-5-mini","output":[],"usage":{"input_tokens":5,"output_tokens":10,"total_tokens":15}}`, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var receivedHeaders http.Header + + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(tc.mockResponse)) + })) + t.Cleanup(mockUpstream.Close) + + p := NewCopilot(config.Copilot{ + BaseURL: mockUpstream.URL, + }) + + req := httptest.NewRequest(http.MethodPost, tc.route, bytes.NewBufferString(tc.body)) + req.Header.Set("Authorization", "Bearer user-token") // per-user key, must reach upstream + req.Header.Set("Editor-Version", "vscode/1.85.0") // Copilot-specific, must reach upstream + req.Header.Set("Copilot-Integration-Id", "test-id") // Copilot-specific, must reach upstream + req.Header.Set("X-Custom-Header", "custom-value") // should be forwarded + req.Header.Set("X-Api-Key", "client-fake-key") // should be stripped (re-injected by SDK) + req.Header.Set("Upgrade", "websocket") // should be stripped (hop-by-hop) + w := httptest.NewRecorder() + + interceptor, err := p.CreateInterceptor(w, req, testTracer) + require.NoError(t, err) + require.NotNil(t, interceptor) + + interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) + + err = interceptor.ProcessRequest(w, httptest.NewRequest(http.MethodPost, tc.route, nil)) + require.NoError(t, err) + + // Copilot-specific headers must reach the upstream. + assert.Equal(t, "vscode/1.85.0", receivedHeaders.Get("Editor-Version")) + assert.Equal(t, "test-id", receivedHeaders.Get("Copilot-Integration-Id")) + // Custom headers must reach the upstream. + assert.Equal(t, "custom-value", receivedHeaders.Get("X-Custom-Header")) + // Hop-by-hop headers must not reach the upstream. + assert.Empty(t, receivedHeaders.Get("Upgrade")) + // Per-user token must be used for auth as Copilot uses the client's token, not a global key. + assert.Equal(t, "Bearer user-token", receivedHeaders.Get("Authorization")) + // User-Agent must be set by the SDK, not forwarded from the client. + assert.True(t, strings.HasPrefix(receivedHeaders.Get("User-Agent"), "OpenAI/Go"), + "upstream User-Agent should be set by the SDK") + }) + } +} + func Test_extractBearerToken(t *testing.T) { t.Parallel() diff --git a/provider/openai.go b/provider/openai.go index 43d6811e..68d568d9 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -105,9 +105,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.cfg, tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.cfg, r.Header, tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.cfg, tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.cfg, r.Header, tracer) } case routeResponses: @@ -120,9 +120,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace return nil, fmt.Errorf("unmarshal request body: %w", err) } if req.Stream { - interceptor = responses.NewStreamingInterceptor(id, &req, payload, p.cfg, string(req.Model), tracer) + interceptor = responses.NewStreamingInterceptor(id, &req, payload, p.cfg, string(req.Model), r.Header, tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, &req, payload, p.cfg, string(req.Model), tracer) + interceptor = responses.NewBlockingInterceptor(id, &req, payload, p.cfg, string(req.Model), r.Header, tracer) } default: diff --git a/provider/openai_test.go b/provider/openai_test.go index f2654b07..c9b986f9 100644 --- a/provider/openai_test.go +++ b/provider/openai_test.go @@ -9,9 +9,14 @@ import ( "strings" "testing" - "github.com/coder/aibridge/config" + "cdr.dev/slog/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace/noop" "golang.org/x/sync/errgroup" + + "github.com/coder/aibridge/config" + "github.com/coder/aibridge/internal/testutil" ) type message struct { @@ -150,6 +155,82 @@ func generateResponsesPayload(payloadSize int, inputCount int, stream bool) []by return bodyBytes } +// TestOpenAI_ForwardsHeadersToUpstream verifies that custom client headers are forwarded +// to the upstream, while filtered headers (auth, hop-by-hop) are stripped and the +// configured API key is always used. +func TestOpenAI_ForwardsHeadersToUpstream(t *testing.T) { + t.Parallel() + + const configuredKey = "configured-key" + + testCases := []struct { + name string + route string + body string + mockResponse string + }{ + { + name: "chat-completions", + route: routeChatCompletions, + body: `{"model":"gpt-4","messages":[{"role":"user","content":"hello"}],"stream":false}`, + mockResponse: `{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, + }, + { + name: "responses", + route: routeResponses, + body: `{"model":"gpt-5-mini","input":"hello","stream":false}`, + mockResponse: `{"id":"resp-123","object":"realtime.response","created":1677652288,"model":"gpt-5-mini","output":[],"usage":{"input_tokens":5,"output_tokens":10,"total_tokens":15}}`, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var receivedHeaders http.Header + + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(tc.mockResponse)) + })) + t.Cleanup(mockUpstream.Close) + + p := NewOpenAI(config.OpenAI{ + BaseURL: mockUpstream.URL, + Key: configuredKey, + }) + + req := httptest.NewRequest(http.MethodPost, tc.route, bytes.NewBufferString(tc.body)) + req.Header.Set("X-Custom-Header", "custom-value") // should be forwarded + req.Header.Set("Authorization", "Bearer client-fake-key") // should be stripped (re-injected by SDK) + req.Header.Set("X-Api-Key", "client-fake-key") // should be stripped (re-injected by SDK) + req.Header.Set("Upgrade", "websocket") // should be stripped (hop-by-hop) + w := httptest.NewRecorder() + + interceptor, err := p.CreateInterceptor(w, req, testTracer) + require.NoError(t, err) + require.NotNil(t, interceptor) + + interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) + + err = interceptor.ProcessRequest(w, httptest.NewRequest(http.MethodPost, tc.route, nil)) + require.NoError(t, err) + + // Custom headers must reach the upstream. + assert.Equal(t, "custom-value", receivedHeaders.Get("X-Custom-Header")) + // Hop-by-hop headers must not reach the upstream. + assert.Empty(t, receivedHeaders.Get("Upgrade")) + // Configured key must be used, not the client-provided fake. + assert.Equal(t, "Bearer "+configuredKey, receivedHeaders.Get("Authorization")) + // User-Agent must be set by the SDK, not forwarded from the client. + assert.True(t, strings.HasPrefix(receivedHeaders.Get("User-Agent"), "OpenAI/Go"), + "upstream User-Agent should be set by the SDK") + }) + } +} + func BenchmarkOpenAI_CreateInterceptor_ChatCompletions(b *testing.B) { provider := NewOpenAI(config.OpenAI{ BaseURL: "https://api.openai.com/v1/",