Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/vmcp/auth/strategies/header_injection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/stretchr/testify/require"

authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
"github.com/stacklok/toolhive/pkg/vmcp/health"
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
)

func TestHeaderInjectionStrategy_Name(t *testing.T) {
Expand Down Expand Up @@ -43,7 +43,7 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) {
HeaderValue: "secret-key-123",
},
},
setupCtx: func() context.Context { return health.WithHealthCheckMarker(context.Background()) },
setupCtx: func() context.Context { return healthcontext.WithHealthCheckMarker(context.Background()) },
expectError: false,
checkHeader: func(t *testing.T, req *http.Request) {
t.Helper()
Expand Down
4 changes: 2 additions & 2 deletions pkg/vmcp/auth/strategies/tokenexchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/auth/tokenexchange"
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
"github.com/stacklok/toolhive/pkg/vmcp/health"
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
)

const (
Expand Down Expand Up @@ -107,7 +107,7 @@ func (s *TokenExchangeStrategy) Authenticate(
// For health checks there is no user identity to exchange. If client credentials
// are configured, use a client credentials grant to authenticate the probe request.
// Otherwise skip authentication — the backend will be probed unauthenticated.
if health.IsHealthCheck(ctx) {
if healthcontext.IsHealthCheck(ctx) {
if config.ClientID != "" && config.ClientSecret != "" {
return s.authenticateWithClientCredentials(ctx, req, config)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/vmcp/auth/strategies/tokenexchange_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"github.com/stacklok/toolhive-core/env/mocks"
"github.com/stacklok/toolhive/pkg/auth"
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
"github.com/stacklok/toolhive/pkg/vmcp/health"
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
)

// Test constants
Expand Down Expand Up @@ -108,7 +108,7 @@ func TestTokenExchangeStrategy_Authenticate(t *testing.T) {
}{
{
name: "health check without client credentials skips authentication",
setupCtx: func() context.Context { return health.WithHealthCheckMarker(context.Background()) },
setupCtx: func() context.Context { return healthcontext.WithHealthCheckMarker(context.Background()) },
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
t.Error("token endpoint should not be called when no client credentials are configured")
Expand All @@ -125,7 +125,7 @@ func TestTokenExchangeStrategy_Authenticate(t *testing.T) {
},
{
name: "health check with client credentials uses client credentials grant",
setupCtx: func() context.Context { return health.WithHealthCheckMarker(context.Background()) },
setupCtx: func() context.Context { return healthcontext.WithHealthCheckMarker(context.Background()) },
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Helper()
Expand Down
4 changes: 2 additions & 2 deletions pkg/vmcp/auth/strategies/upstream_inject.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

"github.com/stacklok/toolhive/pkg/auth"
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
"github.com/stacklok/toolhive/pkg/vmcp/health"
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
)

// UpstreamInjectStrategy injects an upstream IDP token into backend request headers.
Expand Down Expand Up @@ -62,7 +62,7 @@ func (*UpstreamInjectStrategy) Authenticate(
ctx context.Context, req *http.Request, strategy *authtypes.BackendAuthStrategy,
) error {
// Health checks have no user identity — skip authentication.
if health.IsHealthCheck(ctx) {
if healthcontext.IsHealthCheck(ctx) {
return nil
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/vmcp/auth/strategies/upstream_inject_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/stretchr/testify/require"

authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" // BackendAuthStrategy, ErrUpstreamTokenNotFound
"github.com/stacklok/toolhive/pkg/vmcp/health"
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
)

func TestUpstreamInjectStrategy_Name(t *testing.T) {
Expand Down Expand Up @@ -118,7 +118,7 @@ func TestUpstreamInjectStrategy_Authenticate(t *testing.T) {
},
{
name: "health check bypass",
setupCtx: func() context.Context { return health.WithHealthCheckMarker(context.Background()) },
setupCtx: func() context.Context { return healthcontext.WithHealthCheckMarker(context.Background()) },
strategy: &authtypes.BackendAuthStrategy{
Type: authtypes.StrategyTypeUpstreamInject,
UpstreamInject: &authtypes.UpstreamInjectConfig{
Expand Down
37 changes: 27 additions & 10 deletions pkg/vmcp/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth"
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
"github.com/stacklok/toolhive/pkg/vmcp/conversion"
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
)

const (
Expand Down Expand Up @@ -116,19 +117,31 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}

// identityPropagatingRoundTripper propagates identity to backend HTTP requests.
// identityPropagatingRoundTripper propagates identity and health-check markers to backend HTTP requests.
// This ensures that identity information from the vMCP handler is available for authentication
// strategies that need it (e.g., token exchange).
//
// The health-check marker is stored at transport creation time and re-injected into every
// outgoing request, including the DELETE that mcp-go sends when closing a streamable-HTTP
// session. Without this, mcp-go's Close() creates a fresh context.Background()-based request
// that loses the health-check marker, causing auth strategies (UpstreamInjectStrategy,
// TokenExchangeStrategy) to fail with "no identity found in context".
type identityPropagatingRoundTripper struct {
base http.RoundTripper
identity *auth.Identity
base http.RoundTripper
identity *auth.Identity
isHealthCheck bool
}

// RoundTrip implements http.RoundTripper by adding identity to the request context.
// RoundTrip implements http.RoundTripper by adding identity and health-check marker to the request context.
func (i *identityPropagatingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context()
if i.identity != nil {
// Add identity to the request's context
ctx := auth.WithIdentity(req.Context(), i.identity)
ctx = auth.WithIdentity(ctx, i.identity)
}
if i.isHealthCheck {
ctx = healthcontext.WithHealthCheckMarker(ctx)
}
if i.identity != nil || i.isHealthCheck {
req = req.Clone(ctx)
}
return i.base.RoundTrip(req)
Expand Down Expand Up @@ -227,12 +240,16 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm
target: target,
}

// Extract identity from context and propagate it to backend requests
// This ensures authentication strategies (e.g., token exchange) can access identity
// Extract identity and health-check marker from context and propagate them to backend
// requests. The health-check marker must be carried through to the DELETE request that
// mcp-go emits when closing a streamable-HTTP session: mcp-go creates that request with
// context.Background(), which loses both the identity and the health-check marker that
// were present on the original ListCapabilities call context.
identity, _ := auth.IdentityFromContext(ctx)
baseTransport = &identityPropagatingRoundTripper{
base: baseTransport,
identity: identity,
base: baseTransport,
identity: identity,
isHealthCheck: healthcontext.IsHealthCheck(ctx),
}

// Inject W3C Trace Context headers (traceparent/tracestate) into outgoing requests.
Expand Down
124 changes: 124 additions & 0 deletions pkg/vmcp/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ import (
"go.opentelemetry.io/otel/trace"
"go.uber.org/mock/gomock"

pkgauth "github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/vmcp"
"github.com/stacklok/toolhive/pkg/vmcp/auth"
authmocks "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks"
"github.com/stacklok/toolhive/pkg/vmcp/auth/strategies"
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
)

func TestHTTPBackendClient_ListCapabilities_WithMockFactory(t *testing.T) {
Expand Down Expand Up @@ -855,3 +857,125 @@ func TestWrapBackendError(t *testing.T) {
})
}
}

// ---------------------------------------------------------------------------
// identityPropagatingRoundTripper
// ---------------------------------------------------------------------------

func TestIdentityPropagatingRoundTripper_WithIdentity_PropagatesIdentityInContext(t *testing.T) {
t.Parallel()

base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}}
identity := &pkgauth.Identity{PrincipalInfo: pkgauth.PrincipalInfo{Subject: "user-1"}}
rt := &identityPropagatingRoundTripper{base: base, identity: identity}

req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://backend.example.com/mcp", nil)
require.NoError(t, err)

_, err = rt.RoundTrip(req)
require.NoError(t, err)

require.NotNil(t, base.capturedReq)
got, ok := pkgauth.IdentityFromContext(base.capturedReq.Context())
require.True(t, ok, "identity should be in downstream request context")
assert.Equal(t, "user-1", got.Subject)
}

func TestIdentityPropagatingRoundTripper_NilIdentity_NoIdentityInContext(t *testing.T) {
t.Parallel()

base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}}
rt := &identityPropagatingRoundTripper{base: base, identity: nil}

req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://backend.example.com/mcp", nil)
require.NoError(t, err)

_, err = rt.RoundTrip(req)
require.NoError(t, err)

require.NotNil(t, base.capturedReq)
_, ok := pkgauth.IdentityFromContext(base.capturedReq.Context())
assert.False(t, ok, "no identity should be in downstream context when nil identity configured")
}

func TestIdentityPropagatingRoundTripper_HealthCheck_PropagatesMarker(t *testing.T) {
t.Parallel()

base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}}
rt := &identityPropagatingRoundTripper{base: base, identity: nil, isHealthCheck: true}

// Simulate mcp-go Close(): request created with context.Background(), no health check marker.
req, err := http.NewRequestWithContext(context.Background(), http.MethodDelete, "http://backend.example.com/mcp", nil)
require.NoError(t, err)

_, err = rt.RoundTrip(req)
require.NoError(t, err)

require.NotNil(t, base.capturedReq)
assert.True(t, healthcontext.IsHealthCheck(base.capturedReq.Context()),
"health check marker should be propagated even when original request context lacks it")
}

func TestIdentityPropagatingRoundTripper_NonHealthCheck_NoMarkerAdded(t *testing.T) {
t.Parallel()

base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}}
rt := &identityPropagatingRoundTripper{base: base, identity: nil, isHealthCheck: false}

req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://backend.example.com/mcp", nil)
require.NoError(t, err)

_, err = rt.RoundTrip(req)
require.NoError(t, err)

require.NotNil(t, base.capturedReq)
assert.False(t, healthcontext.IsHealthCheck(base.capturedReq.Context()),
"health check marker should not be injected for non-health-check transports")
}

func TestIdentityPropagatingRoundTripper_HealthCheckWithIdentity_PropagatesBoth(t *testing.T) {
t.Parallel()

base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}}
identity := &pkgauth.Identity{PrincipalInfo: pkgauth.PrincipalInfo{Subject: "svc-account"}}
rt := &identityPropagatingRoundTripper{base: base, identity: identity, isHealthCheck: true}

req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://backend.example.com/mcp", nil)
require.NoError(t, err)

_, err = rt.RoundTrip(req)
require.NoError(t, err)

require.NotNil(t, base.capturedReq)
got, ok := pkgauth.IdentityFromContext(base.capturedReq.Context())
require.True(t, ok)
assert.Equal(t, "svc-account", got.Subject)
assert.True(t, healthcontext.IsHealthCheck(base.capturedReq.Context()))
}

// TestIdentityPropagatingRoundTripper_HealthCheckClose_OriginalRequestContextUnchanged verifies
// that when the transport is in health-check mode, RoundTrip injects the health-check marker
// into the downstream request's context without mutating the original request context. This
// covers requests (e.g. the DELETE mcp-go emits on Close()) whose context does not already
// carry the marker.
func TestIdentityPropagatingRoundTripper_HealthCheckClose_OriginalRequestContextUnchanged(t *testing.T) {
t.Parallel()

base := &mockRoundTripper{response: &http.Response{StatusCode: http.StatusOK}}
rt := &identityPropagatingRoundTripper{base: base, identity: nil, isHealthCheck: true}

originalCtx := context.Background() // no health check marker — simulates mcp-go Close()
req, err := http.NewRequestWithContext(originalCtx, http.MethodDelete, "http://backend.example.com/mcp", nil)
require.NoError(t, err)

_, err = rt.RoundTrip(req)
require.NoError(t, err)

// Original request context must NOT be modified.
assert.False(t, healthcontext.IsHealthCheck(originalCtx),
"original request context must not be mutated")
// But downstream context MUST have the marker.
require.NotNil(t, base.capturedReq)
assert.True(t, healthcontext.IsHealthCheck(base.capturedReq.Context()),
"downstream request must carry health check marker")
}
33 changes: 33 additions & 0 deletions pkg/vmcp/health/context/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

// Package healthcontext provides a lightweight, dependency-free context marker
// for identifying health check requests. Keeping this in a separate package
// allows packages like pkg/vmcp/client and pkg/vmcp/auth/strategies to use
// the marker without pulling in the heavyweight pkg/vmcp/health dependencies
// (e.g. k8s.io/apimachinery).
package healthcontext

import "context"

// healthCheckContextKey is an unexported key type for the health check marker.
type healthCheckContextKey struct{}

// WithHealthCheckMarker marks a context as a health check request.
// Authentication layers can use IsHealthCheck to identify and skip authentication
// for health check requests.
func WithHealthCheckMarker(ctx context.Context) context.Context {
return context.WithValue(ctx, healthCheckContextKey{}, true)
}

// IsHealthCheck returns true if the context is marked as a health check.
// Authentication strategies use this to bypass authentication for health checks,
// since health checks verify backend availability and should not require user credentials.
// Returns false for nil contexts.
func IsHealthCheck(ctx context.Context) bool {
if ctx == nil {
return false
}
val, ok := ctx.Value(healthCheckContextKey{}).(bool)
return ok && val
}
25 changes: 25 additions & 0 deletions pkg/vmcp/health/context/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package healthcontext

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
)

func TestIsHealthCheck_WrongValueType(t *testing.T) {
t.Parallel()

ctx := context.WithValue(context.Background(), healthCheckContextKey{}, "not-a-bool")
assert.False(t, IsHealthCheck(ctx), "non-bool value should not be treated as health check marker")
}

func TestIsHealthCheck_FalseValue(t *testing.T) {
t.Parallel()

ctx := context.WithValue(context.Background(), healthCheckContextKey{}, false)
assert.False(t, IsHealthCheck(ctx), "explicit false value should not be treated as health check marker")
}
12 changes: 3 additions & 9 deletions pkg/vmcp/health/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,22 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/stacklok/toolhive/pkg/vmcp"
healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context"
)

// healthCheckContextKey is a marker for health check requests.
type healthCheckContextKey struct{}

// WithHealthCheckMarker marks a context as a health check request.
// Authentication layers can use IsHealthCheck to identify and skip authentication
// for health check requests.
func WithHealthCheckMarker(ctx context.Context) context.Context {
return context.WithValue(ctx, healthCheckContextKey{}, true)
return healthcontext.WithHealthCheckMarker(ctx)
}

// IsHealthCheck returns true if the context is marked as a health check.
// Authentication strategies use this to bypass authentication for health checks,
// since health checks verify backend availability and should not require user credentials.
// Returns false for nil contexts.
func IsHealthCheck(ctx context.Context) bool {
if ctx == nil {
return false
}
val, ok := ctx.Value(healthCheckContextKey{}).(bool)
return ok && val
return healthcontext.IsHealthCheck(ctx)
}

// StatusProvider provides read-only access to backend health status.
Expand Down
Loading
Loading