diff --git a/cmd/thv-operator/controllers/mcpserver_runconfig.go b/cmd/thv-operator/controllers/mcpserver_runconfig.go index 45dad52740..6f57eb1e7f 100644 --- a/cmd/thv-operator/controllers/mcpserver_runconfig.go +++ b/cmd/thv-operator/controllers/mcpserver_runconfig.go @@ -288,8 +288,8 @@ func (r *MCPServerReconciler) createRunConfigFromMCPServer(m *mcpv1alpha1.MCPSer return runConfig, nil } -// populateScalingConfig sets BackendReplicas and SessionRedis on the RunConfig from the MCPServer spec. -// Fields are only set when present in the spec; nil means "not configured" and is left as-is. +// populateScalingConfig sets BackendReplicas, SessionRedis, and HeadlessService on the RunConfig +// from the MCPServer spec. Fields are only set when present in the spec; nil means "not configured". func populateScalingConfig(runConfig *runner.RunConfig, m *mcpv1alpha1.MCPServer) { hasBackendReplicas := m.Spec.BackendReplicas != nil hasRedis := m.Spec.SessionStorage != nil && m.Spec.SessionStorage.Provider == mcpv1alpha1.SessionStorageProviderRedis @@ -305,6 +305,18 @@ func populateScalingConfig(runConfig *runner.RunConfig, m *mcpv1alpha1.MCPServer if hasBackendReplicas { val := *m.Spec.BackendReplicas runConfig.ScalingConfig.BackendReplicas = &val + + // Always populate headless service config when BackendReplicas is set. + // This enables the proxy runner to route each session to a specific pod via + // headless DNS (e.g. myserver-0.mcp-myserver-headless.default.svc.cluster.local) + // so sessions survive proxy-runner restarts. For single-replica StatefulSets, + // ordinal 0 is always selected deterministically. 
+ runConfig.ScalingConfig.HeadlessService = &transporttypes.HeadlessServiceConfig{ + StatefulSetName: m.Name, + ServiceName: fmt.Sprintf("mcp-%s-headless", m.Name), + Namespace: m.Namespace, + Replicas: val, + } } if hasRedis { diff --git a/docs/server/docs.go b/docs/server/docs.go index dde72bfaa7..8e90b51287 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -1047,6 +1047,28 @@ const docTemplate = `{ }, "type": "object" }, + "github_com_stacklok_toolhive_pkg_runner.HeadlessServiceConfig": { + "description": "HeadlessService holds the information needed to construct pod-specific headless DNS URLs\nfor session-affinity routing in multi-replica StatefulSet deployments.\nPopulated by the operator when backendReplicas \u003e 1; nil for single-replica deployments.\n+optional", + "properties": { + "namespace": { + "description": "Namespace is the Kubernetes namespace of the StatefulSet.", + "type": "string" + }, + "replicas": { + "description": "Replicas is the StatefulSet replica count, used to select a random pod ordinal.", + "type": "integer" + }, + "service_name": { + "description": "ServiceName is the name of the headless Kubernetes service (e.g. 
\"mcp-myserver-headless\").", + "type": "string" + }, + "statefulset_name": { + "description": "StatefulSetName is the name of the backend StatefulSet (equals the MCPServer name).", + "type": "string" + } + }, + "type": "object" + }, "github_com_stacklok_toolhive_pkg_runner.RunConfig": { "properties": { "allow_docker_gateway": { @@ -1273,6 +1295,9 @@ const docTemplate = `{ "description": "BackendReplicas is the desired StatefulSet replica count for the proxy runner backend.\nWhen nil, replicas are unmanaged (preserving HPA or manual kubectl control).\nWhen set (including 0), the value is an explicit replica count.", "type": "integer" }, + "headless_service": { + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_runner.HeadlessServiceConfig" + }, "session_redis": { "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_runner.SessionRedisConfig" } diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 4f6dd8ea72..1547e703ff 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -1040,6 +1040,28 @@ }, "type": "object" }, + "github_com_stacklok_toolhive_pkg_runner.HeadlessServiceConfig": { + "description": "HeadlessService holds the information needed to construct pod-specific headless DNS URLs\nfor session-affinity routing in multi-replica StatefulSet deployments.\nPopulated by the operator when backendReplicas \u003e 1; nil for single-replica deployments.\n+optional", + "properties": { + "namespace": { + "description": "Namespace is the Kubernetes namespace of the StatefulSet.", + "type": "string" + }, + "replicas": { + "description": "Replicas is the StatefulSet replica count, used to select a random pod ordinal.", + "type": "integer" + }, + "service_name": { + "description": "ServiceName is the name of the headless Kubernetes service (e.g. 
\"mcp-myserver-headless\").", + "type": "string" + }, + "statefulset_name": { + "description": "StatefulSetName is the name of the backend StatefulSet (equals the MCPServer name).", + "type": "string" + } + }, + "type": "object" + }, "github_com_stacklok_toolhive_pkg_runner.RunConfig": { "properties": { "allow_docker_gateway": { @@ -1266,6 +1288,9 @@ "description": "BackendReplicas is the desired StatefulSet replica count for the proxy runner backend.\nWhen nil, replicas are unmanaged (preserving HPA or manual kubectl control).\nWhen set (including 0), the value is an explicit replica count.", "type": "integer" }, + "headless_service": { + "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_runner.HeadlessServiceConfig" + }, "session_redis": { "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_runner.SessionRedisConfig" } diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 306f81dc6b..12246164ae 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -988,6 +988,29 @@ components: For sensitive values (API keys, tokens), use AddHeadersFromSecret instead. type: object type: object + github_com_stacklok_toolhive_pkg_runner.HeadlessServiceConfig: + description: |- + HeadlessService holds the information needed to construct pod-specific headless DNS URLs + for session-affinity routing in multi-replica StatefulSet deployments. + Populated by the operator when backendReplicas > 1; nil for single-replica deployments. + +optional + properties: + namespace: + description: Namespace is the Kubernetes namespace of the StatefulSet. + type: string + replicas: + description: Replicas is the StatefulSet replica count, used to select a + random pod ordinal. + type: integer + service_name: + description: ServiceName is the name of the headless Kubernetes service + (e.g. "mcp-myserver-headless"). 
+ type: string + statefulset_name: + description: StatefulSetName is the name of the backend StatefulSet (equals + the MCPServer name). + type: string + type: object github_com_stacklok_toolhive_pkg_runner.RunConfig: properties: allow_docker_gateway: @@ -1194,6 +1217,8 @@ components: When nil, replicas are unmanaged (preserving HPA or manual kubectl control). When set (including 0), the value is an explicit replica count. type: integer + headless_service: + $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_runner.HeadlessServiceConfig' session_redis: $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_runner.SessionRedisConfig' type: object diff --git a/pkg/container/kubernetes/client.go b/pkg/container/kubernetes/client.go index fca0355484..2c338cef2d 100644 --- a/pkg/container/kubernetes/client.go +++ b/pkg/container/kubernetes/client.go @@ -447,7 +447,7 @@ func buildStatefulSetSpec( spec := appsv1apply.StatefulSetSpec(). WithSelector(metav1apply.LabelSelector(). WithMatchLabels(map[string]string{"app": containerName})). - WithServiceName(containerName). + WithServiceName(fmt.Sprintf("mcp-%s-headless", containerName)). WithTemplate(podTemplateSpec) if options != nil && options.ScalingConfig != nil && options.ScalingConfig.BackendReplicas != nil { spec = spec.WithReplicas(*options.ScalingConfig.BackendReplicas) diff --git a/pkg/runner/config.go b/pkg/runner/config.go index 3bda0e8b06..f3ac52593b 100644 --- a/pkg/runner/config.go +++ b/pkg/runner/config.go @@ -247,6 +247,12 @@ type ScalingConfig struct { // The Redis password is not included — it is injected as env var THV_SESSION_REDIS_PASSWORD. // +optional SessionRedis *SessionRedisConfig `json:"session_redis,omitempty" yaml:"session_redis,omitempty"` + + // HeadlessService holds the information needed to construct pod-specific headless DNS URLs + // for session-affinity routing in Kubernetes StatefulSet deployments. 
+ // Populated by the operator whenever BackendReplicas is set (including single-replica). + // +optional + HeadlessService *types.HeadlessServiceConfig `json:"headless_service,omitempty" yaml:"headless_service,omitempty"` } // SessionRedisConfig contains non-sensitive Redis connection parameters used for distributed diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 42be99f7fb..fb753610b5 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -35,6 +35,7 @@ import ( "github.com/stacklok/toolhive/pkg/transport" "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/transport/types" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" "github.com/stacklok/toolhive/pkg/workloads/statuses" ) @@ -354,6 +355,11 @@ func (r *Runner) Run(ctx context.Context) error { } } + // Enable pod-specific session routing for Kubernetes StatefulSet backends. + if r.Config.ScalingConfig != nil && r.Config.ScalingConfig.HeadlessService != nil { + transportConfig.HeadlessService = r.Config.ScalingConfig.HeadlessService + } + // When Redis session storage is configured, create a Redis-backed session store // so sessions are shared across proxy replicas instead of being pod-local. 
if r.Config.ScalingConfig != nil && r.Config.ScalingConfig.SessionRedis != nil { @@ -364,7 +370,7 @@ func (r *Runner) Run(ctx context.Context) error { } storage, err := session.NewRedisStorage(ctx, session.RedisConfig{ Addr: redisCfg.Address, - Password: os.Getenv(session.RedisPasswordEnvVar), + Password: os.Getenv(vmcpconfig.RedisPasswordEnvVar), DB: int(redisCfg.DB), KeyPrefix: keyPrefix, }, session.DefaultSessionTTL) diff --git a/pkg/transport/factory.go b/pkg/transport/factory.go index 003bab13ee..47897d34c6 100644 --- a/pkg/transport/factory.go +++ b/pkg/transport/factory.go @@ -73,6 +73,7 @@ func (*Factory) Create(config types.Config, opts ...Option) (types.Transport, er config.Middlewares..., ) httpTransport.sessionStorage = config.SessionStorage + httpTransport.headlessService = config.HeadlessService tr = httpTransport case types.TransportTypeStreamableHTTP: httpTransport := NewHTTPTransport( @@ -91,6 +92,7 @@ func (*Factory) Create(config types.Config, opts ...Option) (types.Transport, er config.Middlewares..., ) httpTransport.sessionStorage = config.SessionStorage + httpTransport.headlessService = config.HeadlessService tr = httpTransport case types.TransportTypeInspector: // HTTP transport is not implemented yet diff --git a/pkg/transport/http.go b/pkg/transport/http.go index a9c81bd343..ccbe7870b9 100644 --- a/pkg/transport/http.go +++ b/pkg/transport/http.go @@ -74,6 +74,9 @@ type HTTPTransport struct { // Mutex for protecting shared state mutex sync.Mutex + // headlessService configures pod-specific routing for Kubernetes StatefulSet deployments. + headlessService *types.HeadlessServiceConfig + // sessionStorage overrides the default in-memory session store when set. // Used for Redis-backed session sharing across replicas. 
sessionStorage session.Storage @@ -239,6 +242,33 @@ func (t *HTTPTransport) setTargetURI(targetURI string) { t.targetURI = targetURI } +// resolveTargetURI determines the proxy target URI, base path, and raw query from the +// transport configuration. For remote MCP servers it parses the remote URL; for local +// containers it returns the pre-configured targetURI. +func (t *HTTPTransport) resolveTargetURI() (targetURI, remoteBasePath, remoteRawQuery string, err error) { + if t.remoteURL != "" { + remoteURL, err := url.Parse(t.remoteURL) + if err != nil { + return "", "", "", fmt.Errorf("failed to parse remote URL: %w", err) + } + targetURI = (&url.URL{Scheme: remoteURL.Scheme, Host: remoteURL.Host}).String() + remoteBasePath = remoteURL.Path + remoteRawQuery = remoteURL.RawQuery + slog.Debug("setting up transparent proxy to forward to remote URL", + "port", t.proxyPort, "target", targetURI, "base_path", remoteBasePath, "raw_query", remoteRawQuery) + return targetURI, remoteBasePath, remoteRawQuery, nil + } + if t.containerName == "" { + return "", "", "", transporterrors.ErrContainerNameNotSet + } + if t.targetURI == "" { + return "", "", "", fmt.Errorf("target URI not set for HTTP transport") + } + slog.Debug("setting up transparent proxy to forward to target", + "port", t.proxyPort, "target", t.targetURI) + return t.targetURI, "", "", nil +} + // Start initializes the transport and begins processing messages. // The transport is responsible for starting the container. // @@ -251,52 +281,15 @@ func (t *HTTPTransport) Start(ctx context.Context) error { return fmt.Errorf("container deployer not set") } - // Determine target URI - var targetURI string - // remoteBasePath holds the path component from the remote URL (e.g., "/v2" from // "https://mcp.asana.com/v2/mcp"). This must be prepended to incoming request // paths so they reach the correct endpoint on the remote server. 
- var remoteBasePath string - // remoteRawQuery holds the raw query string from the remote URL (e.g., // "toolsets=core,alerting" from "https://mcp.example.com/mcp?toolsets=core,alerting"). // This must be forwarded on every outbound request or it is silently dropped. - var remoteRawQuery string - - if t.remoteURL != "" { - // For remote MCP servers, construct target URI from remote URL - remoteURL, err := url.Parse(t.remoteURL) - if err != nil { - return fmt.Errorf("failed to parse remote URL: %w", err) - } - targetURI = (&url.URL{ - Scheme: remoteURL.Scheme, - Host: remoteURL.Host, - }).String() - - // Extract the path prefix that needs to be prepended to incoming requests. - // The target URI only has scheme+host, so without this the remote path is lost. - remoteBasePath = remoteURL.Path - - remoteRawQuery = remoteURL.RawQuery - - //nolint:gosec // G706: logging proxy port and remote URL from config - slog.Debug("setting up transparent proxy to forward to remote URL", - "port", t.proxyPort, "target", targetURI, "base_path", remoteBasePath, "raw_query", remoteRawQuery) - } else { - if t.containerName == "" { - return transporterrors.ErrContainerNameNotSet - } - - // For local containers, use the configured target URI - if t.targetURI == "" { - return fmt.Errorf("target URI not set for HTTP transport") - } - targetURI = t.targetURI - //nolint:gosec // G706: logging proxy port and target URI from config - slog.Debug("setting up transparent proxy to forward to target", - "port", t.proxyPort, "target", targetURI) + targetURI, remoteBasePath, remoteRawQuery, err := t.resolveTargetURI() + if err != nil { + return err } // Create middlewares slice @@ -330,6 +323,21 @@ func (t *HTTPTransport) Start(ctx context.Context) error { proxyOptions = append(proxyOptions, transparent.WithSessionStorage(t.sessionStorage)) } + // Inject Redis-backed session storage for cross-replica session sharing. 
+ if t.sessionStorage != nil { + proxyOptions = append(proxyOptions, transparent.WithSessionStorage(t.sessionStorage)) + } + + // Enable pod-specific routing for Kubernetes StatefulSet backends. + // When configured, each new session is pinned to a specific pod via headless DNS + // so that session routing survives proxy-runner restarts. + if t.headlessService != nil { + proxyOptions = append(proxyOptions, transparent.WithPodHeadlessService( + t.headlessService.StatefulSetName, t.headlessService.ServiceName, + t.headlessService.Namespace, t.headlessService.Replicas, + )) + } + // Create the transparent proxy t.proxy = transparent.NewTransparentProxyWithOptions( t.host, diff --git a/pkg/transport/proxy/transparent/backend_routing_test.go b/pkg/transport/proxy/transparent/backend_routing_test.go index 4167c86406..c7485be4a4 100644 --- a/pkg/transport/proxy/transparent/backend_routing_test.go +++ b/pkg/transport/proxy/transparent/backend_routing_test.go @@ -5,9 +5,11 @@ package transparent import ( "context" + "fmt" "io" "net/http" "net/http/httptest" + "net/url" "strings" "sync/atomic" "testing" @@ -289,3 +291,128 @@ func TestRoundTripStoresBackendURLOnInitialize(t *testing.T) { require.True(t, ok, "session should have backend_url metadata") assert.Equal(t, backend.URL, backendURL) } + +// TestWithPodHeadlessServiceStoresPodURL verifies that when WithPodHeadlessService is configured, +// an initialize response causes the session's backend_url to be a pod-specific headless DNS URL +// (e.g. http://myserver-N.mcp-myserver-headless.default.svc.cluster.local:) rather than +// the ClusterIP targetURI. 
+func TestWithPodHeadlessServiceStoresPodURL(t *testing.T) { + t.Parallel() + + sessionID := uuid.New().String() + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Mcp-Session-Id", sessionID) + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + const ( + statefulSetName = "myserver" + serviceName = "mcp-myserver-headless" + namespace = "default" + replicas = int32(3) + ) + + proxy := NewTransparentProxyWithOptions( + "127.0.0.1", 0, backend.URL, + nil, nil, nil, + false, false, "sse", + nil, nil, "", false, + nil, + WithPodHeadlessService(statefulSetName, serviceName, namespace, replicas), + ) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(func() { + cancel() + stopCtx, stopCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer stopCancel() + _ = proxy.Stop(stopCtx) + }) + require.NoError(t, proxy.Start(ctx)) + addr := proxy.listener.Addr().String() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + "http://"+addr+"/mcp", + strings.NewReader(`{"method":"initialize"}`)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _ = resp.Body.Close() + + sess, ok := proxy.sessionManager.Get(normalizeSessionID(sessionID)) + require.True(t, ok, "session should have been created by RoundTrip") + backendURL, ok := sess.GetMetadataValue(sessionMetadataBackendURL) + require.True(t, ok, "session should have backend_url metadata") + + // The URL must use headless DNS, not the ClusterIP backend.URL. 
+ assert.NotEqual(t, backend.URL, backendURL, "backend_url should be pod-specific, not ClusterIP") + assert.Contains(t, backendURL, ".mcp-myserver-headless.default.svc.cluster.local", + "backend_url should contain the headless service DNS suffix") + assert.Contains(t, backendURL, "myserver-", + "backend_url should contain the StatefulSet pod name prefix") + + // Ordinal must be in range [0, replicas). + parsedURL, err := url.Parse(backendURL) + require.NoError(t, err) + host := parsedURL.Hostname() + // host is e.g. "myserver-2.mcp-myserver-headless.default.svc.cluster.local" + var ordinal int + n, err := fmt.Sscanf(host, "myserver-%d.", &ordinal) + require.NoError(t, err) + require.Equal(t, 1, n) + assert.GreaterOrEqual(t, ordinal, 0) + assert.Less(t, ordinal, int(replicas)) +} + +// TestWithPodHeadlessServiceSingleReplica verifies that WithPodHeadlessService with replicas=1 +// routes to ordinal 0 via headless DNS, not the static ClusterIP targetURI. +func TestWithPodHeadlessServiceSingleReplica(t *testing.T) { + t.Parallel() + + sessionID := uuid.New().String() + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Mcp-Session-Id", sessionID) + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + proxy := NewTransparentProxyWithOptions( + "127.0.0.1", 0, backend.URL, + nil, nil, nil, + false, false, "sse", + nil, nil, "", false, + nil, + WithPodHeadlessService("myserver", "mcp-myserver-headless", "default", 1), + ) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(func() { + cancel() + stopCtx, stopCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer stopCancel() + _ = proxy.Stop(stopCtx) + }) + require.NoError(t, proxy.Start(ctx)) + addr := proxy.listener.Addr().String() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + "http://"+addr+"/mcp", + strings.NewReader(`{"method":"initialize"}`)) + require.NoError(t, err) + 
req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _ = resp.Body.Close() + + sess, ok := proxy.sessionManager.Get(normalizeSessionID(sessionID)) + require.True(t, ok, "session should have been created by RoundTrip") + backendURL, ok := sess.GetMetadataValue(sessionMetadataBackendURL) + require.True(t, ok, "session should have backend_url metadata") + + // With replicas=1, headless DNS is still used — ordinal is always 0. + assert.NotEqual(t, backend.URL, backendURL, "backend_url should use headless DNS, not static ClusterIP") + assert.Contains(t, backendURL, "myserver-0.mcp-myserver-headless.default.svc.cluster.local", + "single-replica should always route to pod ordinal 0 via headless DNS") +} diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index 408abd7f08..714c14ba6c 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -8,11 +8,13 @@ package transparent import ( "bytes" "context" + "crypto/rand" "encoding/json" "errors" "fmt" "io" "log/slog" + "math/big" "net" "net/http" "net/http/httputil" @@ -34,6 +36,15 @@ import ( "github.com/stacklok/toolhive/pkg/transport/types" ) +// podHeadlessService holds the Kubernetes headless service information used to route +// new sessions to a specific StatefulSet pod via headless DNS in multi-replica deployments. +type podHeadlessService struct { + statefulSetName string // e.g. "myserver" (StatefulSet name == MCPServer name) + serviceName string // e.g. "mcp-myserver-headless" + namespace string // e.g. "default" + replicas int32 // number of replicas, for random pod selection +} + // TransparentProxy implements the Proxy interface as a transparent HTTP proxy // that forwards requests to a destination. // It's used by the SSE transport to forward requests to the container's HTTP server. 
@@ -128,6 +139,11 @@ type TransparentProxy struct {
 
 	// Shutdown timeout for graceful HTTP server shutdown (default: 30 seconds)
 	shutdownTimeout time.Duration
+
+	// headlessService holds Kubernetes headless service info for pod-specific routing.
+	// When set, new sessions are pinned to a specific pod via headless DNS so routing
+	// survives proxy-runner restarts. Nil for non-Kubernetes deployments.
+	headlessService *podHeadlessService
 }
 
 const (
@@ -223,6 +239,25 @@ func withShutdownTimeout(timeout time.Duration) Option {
 	}
 }
 
+// WithPodHeadlessService configures pod-specific routing for Kubernetes StatefulSet deployments.
+// When set, each new MCP session is pinned to a specific pod via its headless DNS name
+// (e.g. myserver-0.mcp-myserver-headless.default.svc.cluster.local) so that session routing
+// survives proxy-runner restarts. For single-replica StatefulSets, ordinal 0 is always selected.
+// The option is a no-op when any required field is empty or replicas is not positive.
+func WithPodHeadlessService(statefulSetName, serviceName, namespace string, replicas int32) Option {
+	return func(p *TransparentProxy) {
+		if statefulSetName == "" || serviceName == "" || namespace == "" || replicas < 1 {
+			return
+		}
+		p.headlessService = &podHeadlessService{
+			statefulSetName: statefulSetName,
+			serviceName:     serviceName,
+			namespace:       namespace,
+			replicas:        replicas,
+		}
+	}
+}
+
 // WithSessionStorage injects a custom storage backend into the session manager.
 // When not provided, the proxy uses in-memory LocalStorage (single-replica default).
 // Provide a Redis-backed storage for multi-replica deployments so all replicas
@@ -365,6 +400,32 @@ type tracingTransport struct {
 	p *TransparentProxy
 }
 
+// pickPodBackendURL selects a random StatefulSet pod and returns its headless DNS URL.
+// The URL has the form http://<statefulset>-<ordinal>.<service>.<namespace>.svc.cluster.local:<port>.
+// Falls back to p.targetURI on any parse error so routing always succeeds.
+func (p *TransparentProxy) pickPodBackendURL() string { + parsed, err := url.Parse(p.targetURI) + if err != nil || parsed.Host == "" { + return p.targetURI + } + _, port, err := net.SplitHostPort(parsed.Host) + if err != nil { + // targetURI host has no explicit port — use scheme default and fall back + return p.targetURI + } + n, err := rand.Int(rand.Reader, big.NewInt(int64(p.headlessService.replicas))) + if err != nil { + return p.targetURI + } + podHost := fmt.Sprintf("%s-%d.%s.%s.svc.cluster.local", + p.headlessService.statefulSetName, + n.Int64(), + p.headlessService.serviceName, + p.headlessService.namespace, + ) + return fmt.Sprintf("%s://%s:%s", parsed.Scheme, podHost, port) +} + func (p *TransparentProxy) setServerInitialized() { if p.isServerInitialized.CompareAndSwap(false, true) { //nolint:gosec // G706: logging target URI from config @@ -480,14 +541,16 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) internalID := normalizeSessionID(ct) if _, ok := t.p.sessionManager.Get(internalID); !ok { sess := session.NewProxySession(internalID) - // Store targetURI as the default backend_url for this session. - // In single-replica deployments targetURI is already the pod address, - // so no override is needed. In multi-replica deployments the - // vMCP/operator layer is responsible for setting backend_url to the - // actual pod DNS name (e.g. http://mcp-server-0.mcp-server.default.svc:8080) - // before the request reaches this proxy; the Rewrite closure then reads - // that value and routes follow-up requests to the correct pod. - sess.SetMetadata(sessionMetadataBackendURL, t.p.targetURI) + // Store backend_url for this session so follow-up requests are routed + // to the same pod that handled initialize. + // - Single-replica / no headless config: use the static ClusterIP targetURI. 
+ // - Multi-replica with headless config: pick a random pod via headless DNS + // so sessions survive proxy-runner restarts without hitting the wrong pod. + backendURL := t.p.targetURI + if t.p.headlessService != nil { + backendURL = t.p.pickPodBackendURL() + } + sess.SetMetadata(sessionMetadataBackendURL, backendURL) if err := t.p.sessionManager.AddSession(sess); err != nil { //nolint:gosec // G706: session ID from HTTP response header slog.Error("failed to create session from header", diff --git a/pkg/transport/types/transport.go b/pkg/transport/types/transport.go index beec37e407..8a272a413c 100644 --- a/pkg/transport/types/transport.go +++ b/pkg/transport/types/transport.go @@ -276,6 +276,26 @@ type Config struct { // Used for Redis-backed session sharing across replicas. // When nil, transports use their default in-memory LocalStorage. SessionStorage session.Storage + + // HeadlessService configures pod-specific routing for Kubernetes StatefulSet deployments. + // When set, each new MCP session is pinned to a specific pod via headless DNS + // (e.g. myserver-0.mcp-myserver-headless.ns.svc.cluster.local) so session routing + // survives proxy-runner restarts. Nil for non-Kubernetes deployments. + HeadlessService *HeadlessServiceConfig +} + +// HeadlessServiceConfig holds Kubernetes headless service information used to construct +// pod-specific DNS URLs (e.g. myserver-0.mcp-myserver-headless.default.svc.cluster.local) +// for session-affinity routing in StatefulSet deployments. +type HeadlessServiceConfig struct { + // StatefulSetName is the name of the backend StatefulSet (equals the MCPServer name). + StatefulSetName string `json:"statefulset_name,omitempty" yaml:"statefulset_name,omitempty"` + // ServiceName is the name of the headless Kubernetes service (e.g. "mcp-myserver-headless"). + ServiceName string `json:"service_name,omitempty" yaml:"service_name,omitempty"` + // Namespace is the Kubernetes namespace of the StatefulSet. 
+ Namespace string `json:"namespace,omitempty" yaml:"namespace,omitempty"` + // Replicas is the StatefulSet replica count, used to select a random pod ordinal. + Replicas int32 `json:"replicas,omitempty" yaml:"replicas,omitempty"` } // ProxyMode represents the proxy mode for stdio transport. diff --git a/test/e2e/thv-operator/virtualmcp/mcpserver_scaling_test.go b/test/e2e/thv-operator/virtualmcp/mcpserver_scaling_test.go index 5c827171c6..253548a610 100644 --- a/test/e2e/thv-operator/virtualmcp/mcpserver_scaling_test.go +++ b/test/e2e/thv-operator/virtualmcp/mcpserver_scaling_test.go @@ -103,6 +103,8 @@ func cleanupRedis(namespace, name string) { } // getReadyMCPServerPods returns all Running+Ready pods for an MCPServer. +// +//nolint:unparam // namespace kept as parameter for reusability across test contexts func getReadyMCPServerPods(mcpServerName, namespace string) ([]corev1.Pod, error) { podList := &corev1.PodList{} if err := k8sClient.List(ctx, podList, @@ -193,6 +195,195 @@ var _ = ginkgo.Describe("MCPServer Cross-Replica Session Routing with Redis", fu proxyPort = int32(8080) ) + ginkgo.Context("When MCPServer has backendReplicas=2 and proxy runner restarts", ginkgo.Ordered, func() { + var ( + mcpServerName string + redisName string + nodePortName string + nodePort int32 + ) + + ginkgo.BeforeAll(func() { + ts := time.Now().UnixNano() + mcpServerName = fmt.Sprintf("e2e-backend-scale-%d", ts) + redisName = fmt.Sprintf("e2e-redis-be-%d", ts) + nodePortName = mcpServerName + "-np" + + ginkgo.By("Deploying Redis for session storage") + deployRedis(defaultNamespace, redisName, timeout, pollInterval) + + replicas := int32(1) + backendReplicas := int32(2) + redisAddr := fmt.Sprintf("%s.%s.svc.cluster.local:6379", redisName, defaultNamespace) + + ginkgo.By("Creating MCPServer with replicas=1, backendReplicas=2, Redis session storage") + gomega.Expect(k8sClient.Create(ctx, &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: mcpServerName, Namespace: 
defaultNamespace}, + Spec: mcpv1alpha1.MCPServerSpec{ + Image: images.YardstickServerImage, + Transport: "streamable-http", + ProxyPort: proxyPort, + McpPort: 8080, + Replicas: &replicas, + BackendReplicas: &backendReplicas, + SessionAffinity: "None", + SessionStorage: &mcpv1alpha1.SessionStorageConfig{ + Provider: mcpv1alpha1.SessionStorageProviderRedis, + Address: redisAddr, + }, + }, + })).To(gomega.Succeed()) + + ginkgo.By("Waiting for MCPServer to be Running") + waitForMCPServerRunning(mcpServerName, defaultNamespace, timeout, pollInterval) + + ginkgo.By("Waiting for 1 ready proxy runner pod") + gomega.Eventually(func() (int, error) { + pods, err := getReadyMCPServerPods(mcpServerName, defaultNamespace) + if err != nil { + return 0, err + } + return len(pods), nil + }, timeout, pollInterval).Should(gomega.Equal(1)) + + ginkgo.By("Creating a NodePort service for external access to the proxy runner") + gomega.Expect(k8sClient.Create(ctx, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodePortName, + Namespace: defaultNamespace, + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeNodePort, + Selector: map[string]string{ + "app.kubernetes.io/name": "mcpserver", + "app.kubernetes.io/instance": mcpServerName, + }, + Ports: []corev1.ServicePort{{ + Port: proxyPort, + TargetPort: intstr.FromInt32(proxyPort), + Protocol: corev1.ProtocolTCP, + Name: "http", + }}, + }, + })).To(gomega.Succeed()) + + ginkgo.By("Waiting for NodePort to be assigned and accessible") + gomega.Eventually(func() error { + svc := &corev1.Service{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: nodePortName, Namespace: defaultNamespace, + }, svc); err != nil { + return err + } + if len(svc.Spec.Ports) == 0 || svc.Spec.Ports[0].NodePort == 0 { + return fmt.Errorf("nodePort not assigned") + } + nodePort = svc.Spec.Ports[0].NodePort + + if err := checkPortAccessible(nodePort, 1*time.Second); err != nil { + return fmt.Errorf("nodePort %d not accessible: %w", nodePort, 
err) + } + if err := checkHTTPHealthReady(nodePort, 2*time.Second); err != nil { + return fmt.Errorf("nodePort %d not ready: %w", nodePort, err) + } + return nil + }, timeout, pollInterval).Should(gomega.Succeed(), "NodePort should be assigned and ready") + }) + + ginkgo.AfterAll(func() { + _ = k8sClient.Delete(ctx, &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: mcpServerName, Namespace: defaultNamespace}, + }) + _ = k8sClient.Delete(ctx, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: nodePortName, Namespace: defaultNamespace}, + }) + cleanupRedis(defaultNamespace, redisName) + + gomega.Eventually(func() bool { + err := k8sClient.Get(ctx, types.NamespacedName{Name: mcpServerName, Namespace: defaultNamespace}, &mcpv1alpha1.MCPServer{}) + return apierrors.IsNotFound(err) + }, timeout, pollInterval).Should(gomega.BeTrue()) + }) + + ginkgo.It("Should route session to the correct backend after proxy runner restart", func() { + ginkgo.By("Initializing an MCP session via NodePort") + mcpClient, err := CreateInitializedMCPClient(nodePort, "e2e-backend-routing-test", 30*time.Second) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + + sessionID := mcpClient.Client.GetSessionId() + gomega.Expect(sessionID).NotTo(gomega.BeEmpty(), "session ID must be assigned after Initialize") + + ginkgo.By("Calling tools/list to verify session works before restart") + toolsBefore, err := mcpClient.Client.ListTools(mcpClient.Ctx, mcp.ListToolsRequest{}) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(toolsBefore.Tools).NotTo(gomega.BeEmpty()) + + mcpClient.Close() + + ginkgo.By("Getting the current proxy runner pod name") + var pods []corev1.Pod + gomega.Eventually(func() (int, error) { + var listErr error + pods, listErr = getReadyMCPServerPods(mcpServerName, defaultNamespace) + if listErr != nil { + return 0, listErr + } + return len(pods), nil + }, timeout, pollInterval).Should(gomega.Equal(1)) + oldPodName := pods[0].Name + + 
ginkgo.By(fmt.Sprintf("Deleting proxy runner pod %s (Deployment will recreate it)", oldPodName)) + gomega.Expect(k8sClient.Delete(ctx, &pods[0])).To(gomega.Succeed()) + + ginkgo.By("Waiting for new proxy runner pod to be Running+Ready") + gomega.Eventually(func() (string, error) { + newPods, listErr := getReadyMCPServerPods(mcpServerName, defaultNamespace) + if listErr != nil || len(newPods) == 0 { + return "", fmt.Errorf("waiting for new pod") + } + if newPods[0].Name == oldPodName { + return "", fmt.Errorf("old pod %s still present", oldPodName) + } + return newPods[0].Name, nil + }, timeout, pollInterval).ShouldNot(gomega.BeEmpty()) + + ginkgo.By("Waiting for NodePort to be accessible on the new pod") + gomega.Eventually(func() error { + if err := checkHTTPHealthReady(nodePort, 2*time.Second); err != nil { + return fmt.Errorf("nodePort %d not ready after restart: %w", nodePort, err) + } + return nil + }, timeout, pollInterval).Should(gomega.Succeed()) + + ginkgo.By("Creating a new client with the SAME session ID") + serverURL := fmt.Sprintf("http://localhost:%d/mcp", nodePort) + newClient, err := mcpclient.NewStreamableHttpClient(serverURL, transport.WithSession(sessionID)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + defer func() { _ = newClient.Close() }() + + startCtx, startCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer startCancel() + gomega.Expect(newClient.Start(startCtx)).To(gomega.Succeed()) + + // With backendReplicas=2 and sessionAffinity=None, the backend_url stored + // in Redis is the ClusterIP service URL. After proxy runner restart, + // kube-proxy may route to a different backend pod that doesn't have this + // MCP session. Send multiple requests to make routing failure reliably + // detectable: with 2 backends and random routing, + // P(all 5 hit correct backend) ≈ 3%. 
+ ginkgo.By("Sending 5 requests with the recovered session to verify backend routing") + for i := 0; i < 5; i++ { + listCtx, listCancel := context.WithTimeout(context.Background(), 30*time.Second) + toolsAfter, listErr := newClient.ListTools(listCtx, mcp.ListToolsRequest{}) + listCancel() + gomega.Expect(listErr).NotTo(gomega.HaveOccurred(), + "Request %d/5 should succeed — session should route to the correct backend", i+1) + gomega.Expect(toolsAfter.Tools).To(gomega.HaveLen(len(toolsBefore.Tools)), + "Request %d/5 should return the same tools as before restart", i+1) + } + }) + }) + ginkgo.Context("When MCPServer has replicas=2 with Redis session storage", ginkgo.Ordered, func() { var ( mcpServerName string