diff --git a/pkg/vmcp/auth/strategies/header_injection_test.go b/pkg/vmcp/auth/strategies/header_injection_test.go index 9c0baa4628..723589e520 100644 --- a/pkg/vmcp/auth/strategies/header_injection_test.go +++ b/pkg/vmcp/auth/strategies/header_injection_test.go @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/require" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" - "github.com/stacklok/toolhive/pkg/vmcp/health" + healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context" ) func TestHeaderInjectionStrategy_Name(t *testing.T) { @@ -43,7 +43,7 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { HeaderValue: "secret-key-123", }, }, - setupCtx: func() context.Context { return health.WithHealthCheckMarker(context.Background()) }, + setupCtx: func() context.Context { return healthcontext.WithHealthCheckMarker(context.Background()) }, expectError: false, checkHeader: func(t *testing.T, req *http.Request) { t.Helper() diff --git a/pkg/vmcp/auth/strategies/tokenexchange.go b/pkg/vmcp/auth/strategies/tokenexchange.go index 15b2c95fad..ed754fd972 100644 --- a/pkg/vmcp/auth/strategies/tokenexchange.go +++ b/pkg/vmcp/auth/strategies/tokenexchange.go @@ -18,7 +18,7 @@ import ( "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/auth/tokenexchange" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" - "github.com/stacklok/toolhive/pkg/vmcp/health" + healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context" ) const ( @@ -107,7 +107,7 @@ func (s *TokenExchangeStrategy) Authenticate( // For health checks there is no user identity to exchange. If client credentials // are configured, use a client credentials grant to authenticate the probe request. // Otherwise skip authentication — the backend will be probed unauthenticated. - if health.IsHealthCheck(ctx) { + if healthcontext.IsHealthCheck(ctx) { if config.ClientID != "" && config.ClientSecret != "" { return s.authenticateWithClientCredentials(ctx, req, config) } diff --git a/pkg/vmcp/auth/strategies/tokenexchange_test.go b/pkg/vmcp/auth/strategies/tokenexchange_test.go index 2bb7688c0a..9041bb9c67 100644 --- a/pkg/vmcp/auth/strategies/tokenexchange_test.go +++ b/pkg/vmcp/auth/strategies/tokenexchange_test.go @@ -18,7 +18,7 @@ import ( "github.com/stacklok/toolhive-core/env/mocks" "github.com/stacklok/toolhive/pkg/auth" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" - "github.com/stacklok/toolhive/pkg/vmcp/health" + healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context" ) // Test constants @@ -108,7 +108,7 @@ func TestTokenExchangeStrategy_Authenticate(t *testing.T) { }{ { name: "health check without client credentials skips authentication", - setupCtx: func() context.Context { return health.WithHealthCheckMarker(context.Background()) }, + setupCtx: func() context.Context { return healthcontext.WithHealthCheckMarker(context.Background()) }, setupServer: func() *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { t.Error("token endpoint should not be called when no client credentials are configured") @@ -125,7 +125,7 @@ func TestTokenExchangeStrategy_Authenticate(t *testing.T) { }, { name: "health check with client credentials uses client credentials grant", - setupCtx: func() context.Context { return health.WithHealthCheckMarker(context.Background()) }, + setupCtx: func() context.Context { return healthcontext.WithHealthCheckMarker(context.Background()) }, setupServer: func() *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Helper() diff --git a/pkg/vmcp/auth/strategies/upstream_inject.go b/pkg/vmcp/auth/strategies/upstream_inject.go index 9540196c01..519216cc04 100644 --- a/pkg/vmcp/auth/strategies/upstream_inject.go +++ b/pkg/vmcp/auth/strategies/upstream_inject.go @@ -10,7 +10,7 @@ import ( "github.com/stacklok/toolhive/pkg/auth" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" - "github.com/stacklok/toolhive/pkg/vmcp/health" + healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context" ) // UpstreamInjectStrategy injects an upstream IDP token into backend request headers. @@ -62,7 +62,7 @@ func (*UpstreamInjectStrategy) Authenticate( ctx context.Context, req *http.Request, strategy *authtypes.BackendAuthStrategy, ) error { // Health checks have no user identity — skip authentication. - if health.IsHealthCheck(ctx) { + if healthcontext.IsHealthCheck(ctx) { return nil } diff --git a/pkg/vmcp/auth/strategies/upstream_inject_test.go b/pkg/vmcp/auth/strategies/upstream_inject_test.go index 01cbc55751..d77d8e75e7 100644 --- a/pkg/vmcp/auth/strategies/upstream_inject_test.go +++ b/pkg/vmcp/auth/strategies/upstream_inject_test.go @@ -14,7 +14,7 @@ import ( "github.com/stretchr/testify/require" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" // BackendAuthStrategy, ErrUpstreamTokenNotFound - "github.com/stacklok/toolhive/pkg/vmcp/health" + healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context" ) func TestUpstreamInjectStrategy_Name(t *testing.T) { @@ -118,7 +118,7 @@ func TestUpstreamInjectStrategy_Authenticate(t *testing.T) { }, { name: "health check bypass", - setupCtx: func() context.Context { return health.WithHealthCheckMarker(context.Background()) }, + setupCtx: func() context.Context { return healthcontext.WithHealthCheckMarker(context.Background()) }, strategy: &authtypes.BackendAuthStrategy{ Type: authtypes.StrategyTypeUpstreamInject, UpstreamInject: &authtypes.UpstreamInjectConfig{ diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index 8a2f0b202c..6f69a4a967 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -29,6 +29,7 @@ import ( vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" "github.com/stacklok/toolhive/pkg/vmcp/conversion" + healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context" ) const ( @@ -116,19 +117,31 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) } -// identityPropagatingRoundTripper propagates identity to backend HTTP requests. +// identityPropagatingRoundTripper propagates identity and health-check markers to backend HTTP requests. // This ensures that identity information from the vMCP handler is available for authentication // strategies that need it (e.g., token exchange). +// +// The health-check marker is stored at transport creation time and re-injected into every +// outgoing request, including the DELETE that mcp-go sends when closing a streamable-HTTP +// session. Without this, mcp-go's Close() creates a fresh context.Background()-based request +// that loses the health-check marker, causing auth strategies (UpstreamInjectStrategy, +// TokenExchangeStrategy) to fail with "no identity found in context". type identityPropagatingRoundTripper struct { - base http.RoundTripper - identity *auth.Identity + base http.RoundTripper + identity *auth.Identity + isHealthCheck bool } -// RoundTrip implements http.RoundTripper by adding identity to the request context. +// RoundTrip implements http.RoundTripper by adding identity and health-check marker to the request context. func (i *identityPropagatingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := req.Context() if i.identity != nil { - // Add identity to the request's context - ctx := auth.WithIdentity(req.Context(), i.identity) + ctx = auth.WithIdentity(ctx, i.identity) + } + if i.isHealthCheck { + ctx = healthcontext.WithHealthCheckMarker(ctx) + } + if i.identity != nil || i.isHealthCheck { req = req.Clone(ctx) } return i.base.RoundTrip(req) @@ -227,12 +240,16 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm target: target, } - // Extract identity from context and propagate it to backend requests - // This ensures authentication strategies (e.g., token exchange) can access identity + // Extract identity and health-check marker from context and propagate them to backend + // requests. The health-check marker must be carried through to the DELETE request that + // mcp-go emits when closing a streamable-HTTP session: mcp-go creates that request with + // context.Background(), which loses both the identity and the health-check marker that + // were present on the original ListCapabilities call context. identity, _ := auth.IdentityFromContext(ctx) baseTransport = &identityPropagatingRoundTripper{ - base: baseTransport, - identity: identity, + base: baseTransport, + identity: identity, + isHealthCheck: healthcontext.IsHealthCheck(ctx), } // Inject W3C Trace Context headers (traceparent/tracestate) into outgoing requests. diff --git a/pkg/vmcp/client/client_test.go b/pkg/vmcp/client/client_test.go index beb27a0882..c8624727d8 100644 --- a/pkg/vmcp/client/client_test.go +++ b/pkg/vmcp/client/client_test.go @@ -22,11 +22,13 @@ import ( "go.opentelemetry.io/otel/trace" "go.uber.org/mock/gomock" + pkgauth "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/auth" authmocks "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks" "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" + healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context" ) func TestHTTPBackendClient_ListCapabilities_WithMockFactory(t *testing.T) { @@ -855,3 +857,125 @@ func TestWrapBackendError(t *testing.T) { }) } } + +// --------------------------------------------------------------------------- +// identityPropagatingRoundTripper +// --------------------------------------------------------------------------- + +func TestIdentityPropagatingRoundTripper_WithIdentity_PropagatesIdentityInContext(t *testing.T) { + t.Parallel() + + base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}} + identity := &pkgauth.Identity{PrincipalInfo: pkgauth.PrincipalInfo{Subject: "user-1"}} + rt := &identityPropagatingRoundTripper{base: base, identity: identity} + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://backend.example.com/mcp", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + require.NotNil(t, base.capturedReq) + got, ok := pkgauth.IdentityFromContext(base.capturedReq.Context()) + require.True(t, ok, "identity should be in downstream request context") + assert.Equal(t, "user-1", got.Subject) +} + +func TestIdentityPropagatingRoundTripper_NilIdentity_NoIdentityInContext(t *testing.T) { + t.Parallel() + + base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}} + rt := &identityPropagatingRoundTripper{base: base, identity: nil} + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://backend.example.com/mcp", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + require.NotNil(t, base.capturedReq) + _, ok := pkgauth.IdentityFromContext(base.capturedReq.Context()) + assert.False(t, ok, "no identity should be in downstream context when nil identity configured") +} + +func TestIdentityPropagatingRoundTripper_HealthCheck_PropagatesMarker(t *testing.T) { + t.Parallel() + + base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}} + rt := &identityPropagatingRoundTripper{base: base, identity: nil, isHealthCheck: true} + + // Simulate mcp-go Close(): request created with context.Background(), no health check marker. + req, err := http.NewRequestWithContext(context.Background(), http.MethodDelete, "http://backend.example.com/mcp", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + require.NotNil(t, base.capturedReq) + assert.True(t, healthcontext.IsHealthCheck(base.capturedReq.Context()), + "health check marker should be propagated even when original request context lacks it") +} + +func TestIdentityPropagatingRoundTripper_NonHealthCheck_NoMarkerAdded(t *testing.T) { + t.Parallel() + + base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}} + rt := &identityPropagatingRoundTripper{base: base, identity: nil, isHealthCheck: false} + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://backend.example.com/mcp", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + require.NotNil(t, base.capturedReq) + assert.False(t, healthcontext.IsHealthCheck(base.capturedReq.Context()), + "health check marker should not be injected for non-health-check transports") +} + +func TestIdentityPropagatingRoundTripper_HealthCheckWithIdentity_PropagatesBoth(t *testing.T) { + t.Parallel() + + base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}} + identity := &pkgauth.Identity{PrincipalInfo: pkgauth.PrincipalInfo{Subject: "svc-account"}} + rt := &identityPropagatingRoundTripper{base: base, identity: identity, isHealthCheck: true} + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://backend.example.com/mcp", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + require.NotNil(t, base.capturedReq) + got, ok := pkgauth.IdentityFromContext(base.capturedReq.Context()) + require.True(t, ok) + assert.Equal(t, "svc-account", got.Subject) + assert.True(t, healthcontext.IsHealthCheck(base.capturedReq.Context())) +} + +// TestIdentityPropagatingRoundTripper_HealthCheckClose_OriginalRequestContextUnchanged verifies +// that when the transport is in health-check mode, RoundTrip injects the health-check marker +// into the downstream request's context without mutating the original request context. This +// covers requests (e.g. the DELETE mcp-go emits on Close()) whose context does not already +// carry the marker. +func TestIdentityPropagatingRoundTripper_HealthCheckClose_OriginalRequestContextUnchanged(t *testing.T) { + t.Parallel() + + base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}} + rt := &identityPropagatingRoundTripper{base: base, identity: nil, isHealthCheck: true} + + originalCtx := context.Background() // no health check marker — simulates mcp-go Close() + req, err := http.NewRequestWithContext(originalCtx, http.MethodDelete, "http://backend.example.com/mcp", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + // Original request context must NOT be modified. + assert.False(t, healthcontext.IsHealthCheck(originalCtx), + "original request context must not be mutated") + // But downstream context MUST have the marker. + require.NotNil(t, base.capturedReq) + assert.True(t, healthcontext.IsHealthCheck(base.capturedReq.Context()), + "downstream request must carry health check marker") +} diff --git a/pkg/vmcp/health/context/context.go b/pkg/vmcp/health/context/context.go new file mode 100644 index 0000000000..3a8df9deef --- /dev/null +++ b/pkg/vmcp/health/context/context.go @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package healthcontext provides a lightweight, dependency-free context marker +// for identifying health check requests. Keeping this in a separate package +// allows packages like pkg/vmcp/client and pkg/vmcp/auth/strategies to use +// the marker without pulling in the heavyweight pkg/vmcp/health dependencies +// (e.g. k8s.io/apimachinery). +package healthcontext + +import "context" + +// healthCheckContextKey is an unexported key type for the health check marker. +type healthCheckContextKey struct{} + +// WithHealthCheckMarker marks a context as a health check request. +// Authentication layers can use IsHealthCheck to identify and skip authentication +// for health check requests. +func WithHealthCheckMarker(ctx context.Context) context.Context { + return context.WithValue(ctx, healthCheckContextKey{}, true) +} + +// IsHealthCheck returns true if the context is marked as a health check. +// Authentication strategies use this to bypass authentication for health checks, +// since health checks verify backend availability and should not require user credentials. +// Returns false for nil contexts. +func IsHealthCheck(ctx context.Context) bool { + if ctx == nil { + return false + } + val, ok := ctx.Value(healthCheckContextKey{}).(bool) + return ok && val +} diff --git a/pkg/vmcp/health/context/context_test.go b/pkg/vmcp/health/context/context_test.go new file mode 100644 index 0000000000..d407d23d8d --- /dev/null +++ b/pkg/vmcp/health/context/context_test.go @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package healthcontext + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsHealthCheck_WrongValueType(t *testing.T) { + t.Parallel() + + ctx := context.WithValue(context.Background(), healthCheckContextKey{}, "not-a-bool") + assert.False(t, IsHealthCheck(ctx), "non-bool value should not be treated as health check marker") +} + +func TestIsHealthCheck_FalseValue(t *testing.T) { + t.Parallel() + + ctx := context.WithValue(context.Background(), healthCheckContextKey{}, false) + assert.False(t, IsHealthCheck(ctx), "explicit false value should not be treated as health check marker") +} diff --git a/pkg/vmcp/health/monitor.go b/pkg/vmcp/health/monitor.go index 7b34288848..29588146d7 100644 --- a/pkg/vmcp/health/monitor.go +++ b/pkg/vmcp/health/monitor.go @@ -14,16 +14,14 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/stacklok/toolhive/pkg/vmcp" + healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context" ) -// healthCheckContextKey is a marker for health check requests. -type healthCheckContextKey struct{} - // WithHealthCheckMarker marks a context as a health check request. // Authentication layers can use IsHealthCheck to identify and skip authentication // for health check requests. func WithHealthCheckMarker(ctx context.Context) context.Context { - return context.WithValue(ctx, healthCheckContextKey{}, true) + return healthcontext.WithHealthCheckMarker(ctx) } // IsHealthCheck returns true if the context is marked as a health check. @@ -31,11 +29,7 @@ func WithHealthCheckMarker(ctx context.Context) context.Context { // since health checks verify backend availability and should not require user credentials. // Returns false for nil contexts. func IsHealthCheck(ctx context.Context) bool { - if ctx == nil { - return false - } - val, ok := ctx.Value(healthCheckContextKey{}).(bool) - return ok && val + return healthcontext.IsHealthCheck(ctx) } // StatusProvider provides read-only access to backend health status. diff --git a/pkg/vmcp/health/monitor_test.go b/pkg/vmcp/health/monitor_test.go index 2bf2d24051..d03f305395 100644 --- a/pkg/vmcp/health/monitor_test.go +++ b/pkg/vmcp/health/monitor_test.go @@ -702,20 +702,6 @@ func TestIsHealthCheck(t *testing.T) { }, expected: false, }, - { - name: "returns false for context with wrong value type", - setupCtx: func() context.Context { - return context.WithValue(context.Background(), healthCheckContextKey{}, "not-a-bool") - }, - expected: false, - }, - { - name: "returns false for context with false value", - setupCtx: func() context.Context { - return context.WithValue(context.Background(), healthCheckContextKey{}, false) - }, - expected: false, - }, { name: "returns true when nested in parent context", setupCtx: func() context.Context {