diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 72cc7408..df6ef5e7 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -720,7 +720,7 @@ func (c *Connection) write(ctx context.Context, msg Message) error { // For cancelled or rejected requests, we don't set the writeErr (which would // break the connection). They can just be returned to the caller. - if err != nil && ctx.Err() == nil && !errors.Is(err, ErrRejected) { + if err != nil && ctx.Err() == nil && !errors.Is(err, ErrRejected) && !errors.Is(err, ErrUnsupportedProtocolVersion) && !errors.Is(err, ErrMethodNotFound) { // The call to Write failed, and since ctx.Err() is nil we can't attribute // the failure (even indirectly) to Context cancellation. The writer appears // to be broken, and future writes are likely to also fail. diff --git a/internal/jsonrpc2/wire.go b/internal/jsonrpc2/wire.go index 4d123f2c..b0beae02 100644 --- a/internal/jsonrpc2/wire.go +++ b/internal/jsonrpc2/wire.go @@ -31,7 +31,7 @@ var ( // ErrUnknown should be used for all non coded errors. ErrUnknown = NewError(-32001, "unknown error") // ErrServerClosing is returned for calls that arrive while the server is closing. - ErrServerClosing = NewError(-32004, "server is closing") + ErrServerClosing = NewError(-32006, "server is closing") // ErrClientClosing is a dummy error returned for calls initiated while the client is closing. ErrClientClosing = NewError(-32003, "client is closing") @@ -45,6 +45,8 @@ var ( // should be returned to the caller to indicate that the specific request is // invalid in the current context. ErrRejected = NewError(-32005, "rejected by transport") + // ErrUnsupportedProtocolVersion is returned when a server does not support the protocol version. + ErrUnsupportedProtocolVersion = NewError(-32004, "unsupported protocol version") ) const wireVersion = "2.0" diff --git a/mcp/client.go b/mcp/client.go index 9f6f2955..c33c709d 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -165,6 +165,7 @@ type ClientOptions struct { // If non-zero, defines an interval for regular "ping" requests. // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. + // NOTE: The keepalive feature is only available for protocol versions < 2026-06-30 KeepAlive time.Duration } @@ -276,6 +277,27 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp if opts != nil && opts.protocolVersion != "" { protocolVersion = opts.protocolVersion } + + if protocolVersion >= protocolVersion20260630 { + // Per SEP-2575, try the stateless server/discover RPC first. If the server + // signals it doesn't support it, fall back to the legacy initialize + // handshake. + discoverCtx := context.WithValue(ctx, protocolVersionContextKey{}, protocolVersion) + discRes, fallback, err := c.discover(discoverCtx, cs) + if err != nil { + return nil, err + } + if !fallback { + cs.state.InitializeResult = discRes + if hc, ok := cs.mcpConn.(clientConnection); ok { + hc.sessionUpdated(cs.state) + } + return cs, nil + } + // Fallback to the legacy initialize handshake. + protocolVersion = protocolVersion20251125 + } + params := &InitializeParams{ ProtocolVersion: protocolVersion, ClientInfo: c.impl, @@ -307,6 +329,66 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp return cs, nil } +// discover sends a SEP-2575 server/discover request to probe the server for +// stateless protocol support. +// +// The return values have three possible combinations: +// - (result, false, nil): discovery succeeded; caller should skip legacy initialization. +// - (nil, true, nil): the server explicitly signaled it doesn't support +// discovery (Method not found, or UnsupportedProtocolVersionError, or version mismatch); +// caller should fall back to the legacy initialize handshake. +// - (nil, false, err): any other failure (transport error, malformed response, etc.); +// caller should propagate the error. +func (c *Client) discover(ctx context.Context, cs *ClientSession) (*InitializeResult, bool, error) { + protocolVersion := protocolVersionFromContext(ctx) + caps := c.capabilities(protocolVersion) + params := &DiscoverParams{ + Meta: Meta{ + MetaKeyProtocolVersion: protocolVersion, + MetaKeyClientInfo: c.impl, + MetaKeyClientCapabilities: caps, + }, + } + req := &DiscoverRequest{Session: cs, Params: params} + res, err := handleSend[*DiscoverResult](ctx, methodDiscover, req) + if err != nil { + // According to SEP-2575, only the two signals below (MethodNotFound + // and UnsupportedProtocolVersionError) should trigger a fallback. + var werr *jsonrpc.Error + if errors.As(err, &werr) && (werr.Code == jsonrpc.CodeMethodNotFound || werr.Code == CodeUnsupportedProtocolVersion) { + return nil, true, nil + } + return nil, false, err + } + + // Pick the highest protocol version that both the server and this SDK support. + // Since supportedProtocolVersions is defined in descending order (newest to oldest), + // the first match we find is the highest supported version. + var negotiated string + if slices.Contains(res.SupportedVersions, protocolVersion) { + negotiated = protocolVersion + } else { + for _, v := range supportedProtocolVersions { + if slices.Contains(res.SupportedVersions, v) { + negotiated = v + break + } + } + } + if negotiated == "" || negotiated < protocolVersion20260630 { + // If there is no overlap, fall back to initialize so version + // negotiation can happen via the legacy path. + return nil, true, nil + } + + return &InitializeResult{ + Capabilities: res.Capabilities, + Instructions: res.Instructions, + ProtocolVersion: negotiated, + ServerInfo: res.ServerInfo, + }, false, nil +} + // A ClientSession is a logical connection with an MCP server. Its // methods can be used to send requests or notifications to the server. Create // a session by calling [Client.Connect]. @@ -347,6 +429,42 @@ type clientSessionState struct { func (cs *ClientSession) InitializeResult() *InitializeResult { return cs.state.InitializeResult } +// usesNewProtocol reports whether this session has negotiated a protocol +// version >= 2026-06-30, which requires the SEP-2575 per-request `_meta` +// triple on every outgoing request. +func (cs *ClientSession) usesNewProtocol() bool { + res := cs.state.InitializeResult + return res != nil && res.ProtocolVersion >= protocolVersion20260630 +} + +// injectRequestMeta populates the SEP-2575 per-request `_meta` triple +// (protocolVersion, clientInfo, clientCapabilities) on the given outgoing +// request params. Keys already present in params.Meta are not overwritten. +func injectRequestMeta[T any, P interface { + *T + Params +}](cs *ClientSession, params P) P { + res := cs.state.InitializeResult + if params.isNil() { + params = new(T) + } + m := params.GetMeta() + if m == nil { + m = map[string]any{} + } + if _, ok := m[MetaKeyProtocolVersion]; !ok { + m[MetaKeyProtocolVersion] = res.ProtocolVersion + } + if _, ok := m[MetaKeyClientInfo]; !ok { + m[MetaKeyClientInfo] = cs.client.impl + } + if _, ok := m[MetaKeyClientCapabilities]; !ok { + m[MetaKeyClientCapabilities] = cs.client.capabilities(res.ProtocolVersion) + } + params.SetMeta(m) + return params +} + func (cs *ClientSession) ID() string { if c, ok := cs.mcpConn.(hasSessionID); ok { return c.SessionID() @@ -1007,16 +1125,25 @@ func (cs *ClientSession) Ping(ctx context.Context, params *PingParams) error { // ListPrompts lists prompts that are currently available on the server. func (cs *ClientSession) ListPrompts(ctx context.Context, params *ListPromptsParams) (*ListPromptsResult, error) { + if cs.usesNewProtocol() { + params = injectRequestMeta(cs, params) + } return handleSend[*ListPromptsResult](ctx, methodListPrompts, newClientRequest(cs, orZero[Params](params))) } // GetPrompt gets a prompt from the server. func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) (*GetPromptResult, error) { + if cs.usesNewProtocol() { + params = injectRequestMeta(cs, params) + } return handleSend[*GetPromptResult](ctx, methodGetPrompt, newClientRequest(cs, orZero[Params](params))) } // ListTools lists tools that are currently available on the server. func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { + if cs.usesNewProtocol() { + params = injectRequestMeta(cs, params) + } result, err := handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params))) if err != nil { return nil, err @@ -1040,6 +1167,9 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) ( if tool := cs.getCachedTool(params.Name); tool != nil { ctx = context.WithValue(ctx, toolContextKey, tool) } + if cs.usesNewProtocol() { + params = injectRequestMeta(cs, params) + } return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) } @@ -1050,20 +1180,32 @@ func (cs *ClientSession) SetLoggingLevel(ctx context.Context, params *SetLogging // ListResources lists the resources that are currently available on the server. func (cs *ClientSession) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) { + if cs.usesNewProtocol() { + params = injectRequestMeta(cs, params) + } return handleSend[*ListResourcesResult](ctx, methodListResources, newClientRequest(cs, orZero[Params](params))) } // ListResourceTemplates lists the resource templates that are currently available on the server. func (cs *ClientSession) ListResourceTemplates(ctx context.Context, params *ListResourceTemplatesParams) (*ListResourceTemplatesResult, error) { + if cs.usesNewProtocol() { + params = injectRequestMeta(cs, params) + } return handleSend[*ListResourceTemplatesResult](ctx, methodListResourceTemplates, newClientRequest(cs, orZero[Params](params))) } // ReadResource asks the server to read a resource and return its contents. func (cs *ClientSession) ReadResource(ctx context.Context, params *ReadResourceParams) (*ReadResourceResult, error) { + if cs.usesNewProtocol() { + params = injectRequestMeta(cs, params) + } return handleSend[*ReadResourceResult](ctx, methodReadResource, newClientRequest(cs, orZero[Params](params))) } func (cs *ClientSession) Complete(ctx context.Context, params *CompleteParams) (*CompleteResult, error) { + if cs.usesNewProtocol() { + params = injectRequestMeta(cs, params) + } return handleSend[*CompleteResult](ctx, methodComplete, newClientRequest(cs, orZero[Params](params))) } diff --git a/mcp/client_test.go b/mcp/client_test.go index 609fd501..1d629521 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -8,11 +8,14 @@ import ( "context" "fmt" "log/slog" + "sync/atomic" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) type Item struct { @@ -617,3 +620,224 @@ func TestClientCapabilitiesOverWire(t *testing.T) { }) } } + +// TestClientConnectDiscover exercises the SEP-2575 server/discover probe that +// Client.Connect now sends before falling back to the legacy initialize +// handshake. +// +// Each subtest installs a server-side receiving middleware that intercepts the +// "server/discover" method and returns a canned response: a successful +// DiscoverResult, a "Method not found" error, an UnsupportedProtocolVersion +// error, an unrelated failure, or a DiscoverResult whose supportedVersions +// don't overlap with the SDK. The test then asserts the resulting session +// state and whether the legacy initialize handshake ran. +func TestClientConnectDiscover(t *testing.T) { + // Temporarily enable 2026-06-30 support in the SDK for this test + oldSupported := supportedProtocolVersions + supportedProtocolVersions = append([]string{protocolVersion20260630}, supportedProtocolVersions...) + t.Cleanup(func() { + supportedProtocolVersions = oldSupported + }) + + const otherVersionsOnly = "1999-01-01" + + tests := []struct { + name string + // discoverHandler decides how the server replies to server/discover. + // Returning (nil, nil) means "let the default stub handle it" (which + // returns ErrMethodNotFound). + discoverHandler func() (Result, error) + wantConnectErr bool + // wantInitialize is true if the legacy initialize handshake should + // have run (i.e. discover signaled "not supported"). + wantInitialize bool + // wantVersion is the protocol version expected to end up on + // ClientSession.state.InitializeResult after Connect returns. + wantVersion string + }{ + { + name: "discover success skips initialize", + discoverHandler: func() (Result, error) { + return &DiscoverResult{ + SupportedVersions: []string{protocolVersion20260630}, + Capabilities: &ServerCapabilities{ + Tools: &ToolCapabilities{ListChanged: true}, + }, + ServerInfo: &Implementation{Name: "discoverServer", Version: "v1.0.0"}, + }, nil + }, + wantInitialize: false, + wantVersion: protocolVersion20260630, + }, + { + name: "method not found falls back to initialize", + discoverHandler: func() (Result, error) { + return nil, jsonrpc2.ErrMethodNotFound + }, + wantInitialize: true, + wantVersion: latestProtocolVersion, + }, + { + name: "no overlapping supported version falls back to initialize", + discoverHandler: func() (Result, error) { + return &DiscoverResult{ + SupportedVersions: []string{otherVersionsOnly}, + Capabilities: &ServerCapabilities{}, + ServerInfo: &Implementation{Name: "discoverServer", Version: "v1.0.0"}, + }, nil + }, + wantInitialize: true, + wantVersion: latestProtocolVersion, + }, + { + name: "unexpected error propagates and aborts Connect", + discoverHandler: func() (Result, error) { + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "boom", + } + }, + wantConnectErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + var ( + gotDiscover atomic.Bool + gotInitialize atomic.Bool + ) + + s := NewServer(testImpl, nil) + s.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + switch method { + case methodDiscover: + gotDiscover.Store(true) + return tc.discoverHandler() + case methodInitialize: + gotInitialize.Store(true) + } + return next(ctx, method, req) + } + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatalf("server Connect: %v", err) + } + defer ss.Close() + + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if tc.wantConnectErr { + if err == nil { + _ = cs.Close() + t.Fatal("Connect succeeded, want error") + } + if !gotDiscover.Load() { + t.Error("server did not receive server/discover") + } + if gotInitialize.Load() { + t.Error("server received initialize but discover should have aborted Connect") + } + return + } + if err != nil { + t.Fatalf("Connect: %v", err) + } + defer cs.Close() + + if !gotDiscover.Load() { + t.Error("server did not receive server/discover") + } + if got, want := gotInitialize.Load(), tc.wantInitialize; got != want { + t.Errorf("initialize invoked = %v, want %v", got, want) + } + ir := cs.InitializeResult() + if ir == nil { + t.Fatal("InitializeResult is nil after Connect") + } + if got, want := ir.ProtocolVersion, tc.wantVersion; got != want { + t.Errorf("InitializeResult.ProtocolVersion = %q, want %q", got, want) + } + }) + } +} + +// TestClientConnectDiscover_RequestContents verifies that the server/discover +// request sent by Client.Connect carries the SEP-2575 per-request _meta triple: +// protocolVersion, clientInfo, and clientCapabilities. +func TestClientConnectDiscover_RequestContents(t *testing.T) { + ctx := context.Background() + + type captured struct { + params *DiscoverParams + } + var got captured + + s := NewServer(testImpl, nil) + s.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == methodDiscover { + sr, ok := req.(*ServerRequest[*DiscoverParams]) + if !ok { + t.Errorf("discover req has unexpected type %T", req) + return nil, jsonrpc2.ErrMethodNotFound + } + got.params = sr.Params + // Make discover "not supported" so Connect proceeds (we only + // care about the discover request payload here). + return nil, jsonrpc2.ErrMethodNotFound + } + return next(ctx, method, req) + } + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatalf("server Connect: %v", err) + } + defer ss.Close() + + clientImpl := &Implementation{Name: "discover-probe-client", Version: "v9.9.9"} + c := NewClient(clientImpl, &ClientOptions{ + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { + return nil, nil + }, + }) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("Connect: %v", err) + } + defer cs.Close() + + if got.params == nil { + t.Fatal("server did not receive server/discover") + } + + meta := got.params.GetMeta() + if v, _ := meta[MetaKeyProtocolVersion].(string); v != protocolVersion20260630 { + t.Errorf("_meta[%s] = %q, want %q", MetaKeyProtocolVersion, v, protocolVersion20260630) + } + // _meta values round-trip through JSON, so on the server side they + // arrive as map[string]any rather than the typed Go pointers we sent. + info, ok := meta[MetaKeyClientInfo].(map[string]any) + if !ok { + t.Fatalf("_meta[%s] = %T, want map[string]any", MetaKeyClientInfo, meta[MetaKeyClientInfo]) + } + if got, want := info["name"], any(clientImpl.Name); got != want { + t.Errorf("clientInfo.name = %v, want %v", got, want) + } + caps, ok := meta[MetaKeyClientCapabilities].(map[string]any) + if !ok { + t.Fatalf("_meta[%s] = %T, want map[string]any", MetaKeyClientCapabilities, meta[MetaKeyClientCapabilities]) + } + if _, ok := caps["sampling"]; !ok { + t.Errorf("clientCapabilities.sampling missing (CreateMessageHandler was set); got %v", caps) + } +} diff --git a/mcp/mrtr_test.go b/mcp/mrtr_test.go index 12eba7ea..e1e61b2a 100644 --- a/mcp/mrtr_test.go +++ b/mcp/mrtr_test.go @@ -525,6 +525,28 @@ func TestMultiRoundTrip_ReadResource_ManualRetry(t *testing.T) { func mustConnect(t *testing.T, s *Server, clientOpts *ClientOptions) *ClientSession { t.Helper() + + // The mrtr tests require negotiating the 2026-06-30 protocol version. + // Server.discover is currently a stub that returns ErrMethodNotFound, which + // would cause Client.Connect to fall back to the legacy initialize handshake + // and downgrade the negotiated version. Install a receiving middleware that + // answers server/discover with a DiscoverResult advertising 2026-06-30 so + // the client can negotiate the new protocol via the discover path. + // + // TODO: Remove this once the server has a proper discover implementation. + s.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == methodDiscover { + return &DiscoverResult{ + SupportedVersions: []string{protocolVersion20260630}, + Capabilities: &ServerCapabilities{}, + ServerInfo: testImpl, + }, nil + } + return next(ctx, method, req) + } + }) + st, ct := NewInMemoryTransports() ss, err := s.Connect(t.Context(), st, nil) if err != nil { diff --git a/mcp/protocol.go b/mcp/protocol.go index bebdc196..8c658f9f 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -1071,6 +1071,29 @@ type ListPromptsParams struct { Cursor string `json:"cursor,omitempty"` } +type DiscoverParams struct { + Meta `json:"_meta,omitempty"` +} + +func (x *DiscoverParams) isParams() {} +func (x *DiscoverParams) isNil() bool { return x == nil } +func (x *DiscoverParams) GetProgressToken() any { return getProgressToken(x) } +func (x *DiscoverParams) SetProgressToken(t any) { setProgressToken(x, t) } + +type DiscoverResult struct { + Meta `json:"_meta,omitempty"` + // The versions of the Model Context Protocol that the server supports. + SupportedVersions []string `json:"supportedVersions"` + // The server's capabilities. + Capabilities *ServerCapabilities `json:"capabilities"` + // Information about the server implementation. + ServerInfo *Implementation `json:"serverInfo"` + // Instructions describing how to use the server and its features. + Instructions string `json:"instructions,omitempty"` +} + +func (*DiscoverResult) isResult() {} + func (x *ListPromptsParams) isParams() {} func (x *ListPromptsParams) isNil() bool { return x == nil } func (x *ListPromptsParams) GetProgressToken() any { return getProgressToken(x) } @@ -2039,6 +2062,7 @@ const ( methodCallTool = "tools/call" notificationCancelled = "notifications/cancelled" methodComplete = "completion/complete" + methodDiscover = "server/discover" methodCreateMessage = "sampling/createMessage" methodElicit = "elicitation/create" notificationElicitationComplete = "notifications/elicitation/complete" diff --git a/mcp/requests.go b/mcp/requests.go index 42809413..36368c99 100644 --- a/mcp/requests.go +++ b/mcp/requests.go @@ -25,6 +25,7 @@ type ( type ( CreateMessageRequest = ClientRequest[*CreateMessageParams] CreateMessageWithToolsRequest = ClientRequest[*CreateMessageWithToolsParams] + DiscoverRequest = ClientRequest[*DiscoverParams] ElicitRequest = ClientRequest[*ElicitParams] initializedClientRequest = ClientRequest[*InitializedParams] InitializeRequest = ClientRequest[*InitializeParams] diff --git a/mcp/server.go b/mcp/server.go index b86b3e6c..6a54516e 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -757,6 +757,13 @@ func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetProm return res, err } +// discover is the server-side handler for the SEP-2575 "server/discover" RPC. +// +// TODO: Complete implementation. +func (s *Server) discover(context.Context, *ServerRequest[*DiscoverParams]) (*DiscoverResult, error) { + return nil, jsonrpc2.ErrMethodNotFound +} + func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListToolsResult, error) { s.mu.Lock() defer s.mu.Unlock() @@ -1409,6 +1416,7 @@ func (s *Server) AddReceivingMiddleware(middleware ...Middleware) { // curating these method flags. var serverMethodInfos = map[string]methodInfo{ methodComplete: newServerMethodInfo(serverMethod((*Server).complete), 0), + methodDiscover: newServerMethodInfo(serverMethod((*Server).discover), missingParamsOK), methodInitialize: initializeMethodInfo(), methodPing: newServerMethodInfo(serverSessionMethod((*ServerSession).ping), missingParamsOK), methodListPrompts: newServerMethodInfo(serverMethod((*Server).listPrompts), missingParamsOK), @@ -1482,12 +1490,6 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, return nil, perRequestErr } - if !initialized && validatedMeta.usesNewProtocol && validatedMeta.initializeParams != nil { - ss.updateState(func(state *ServerSessionState) { - state.InitializeParams = validatedMeta.initializeParams - }) - } - switch req.Method { case methodInitialize, methodPing, notificationInitialized: if validatedMeta.usesNewProtocol { @@ -1497,11 +1499,20 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, Message: fmt.Sprintf("%q is not supported in the new protocol", req.Method), } } + case methodDiscover: + // In case of methodDiscover call the state.initializeParams is populated + // within the discover handle function to make sure the method is supported + // when the user is probing a pre-2026-06-30 server. default: if !initialized && !validatedMeta.usesNewProtocol { ss.server.opts.Logger.Error("method invalid during initialization", "method", req.Method) return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) } + if !initialized && validatedMeta.usesNewProtocol && validatedMeta.initializeParams != nil { + ss.updateState(func(state *ServerSessionState) { + state.InitializeParams = validatedMeta.initializeParams + }) + } } // modelcontextprotocol/go-sdk#26: handle calls asynchronously, and diff --git a/mcp/shared.go b/mcp/shared.go index 1caacac3..25044c43 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -105,6 +105,7 @@ func defaultSendingMethodHandler(ctx context.Context, method string, req Request // capabilities, so any panic here is a bug. params = initParams.toV2() } + // Notifications don't have results. if strings.HasPrefix(method, "notifications/") { return nil, req.GetSession().getConn().Notify(ctx, method, params) @@ -344,6 +345,9 @@ func clientSessionMethod[P Params, R Result](f func(*ClientSession, context.Cont // MCP-specific error codes. const ( + // CodeUnsupportedProtocolVersion is the JSON-RPC error code defined by + // SEP-2575 for UnsupportedProtocolVersionError. + CodeUnsupportedProtocolVersion = -32004 // CodeHeaderMismatch indicates that HTTP headers do not match the corresponding values // in the request body, or that required headers are missing or malformed. CodeHeaderMismatch = -32001 diff --git a/mcp/streamable.go b/mcp/streamable.go index 0f4e65b8..e6a9bfe5 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1844,6 +1844,13 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { c.initializedResult = state.InitializeResult c.mu.Unlock() + // When the protocol version is >= 2026-06-30, the standalone HTTP GET + // SSE stream is removed. + if state.InitializeResult == nil || + state.InitializeResult.ProtocolVersion >= protocolVersion20260630 { + return + } + // Start the standalone SSE stream as soon as we have the initialized // result, if continuous listening is enabled. // @@ -1903,7 +1910,7 @@ func (c *streamableClientConn) connectStandaloneSSE() { return } summary := "standalone SSE stream" - if err := c.checkResponse(summary, resp); err != nil { + if err := c.checkResponse(c.ctx, summary, resp); err != nil { c.fail(err) return } @@ -2046,10 +2053,11 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e } } - if err := c.checkResponse(requestSummary, resp); err != nil { + if err := c.checkResponse(ctx, requestSummary, resp); err != nil { // Only fail the connection for non-transient errors. // Transient errors (wrapped with ErrRejected) should not break the connection. - if !errors.Is(err, jsonrpc2.ErrRejected) { + // ErrMethodNotFound and ErrUnsupportedProtocolVersion should not break the connection as they trigger the initialize fallback. + if !errors.Is(err, jsonrpc2.ErrRejected) && !errors.Is(err, jsonrpc2.ErrMethodNotFound) && !errors.Is(err, jsonrpc2.ErrUnsupportedProtocolVersion) { c.fail(err) } return err @@ -2143,6 +2151,8 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) error { } if c.initializedResult != nil { req.Header.Set(protocolVersionHeader, c.initializedResult.ProtocolVersion) + } else if v := protocolVersionFromContext(req.Context()); v != "" { + req.Header.Set(protocolVersionHeader, v) } if c.sessionID != "" { req.Header.Set(sessionIDHeader, c.sessionID) @@ -2228,7 +2238,7 @@ func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary str } resp = newResp - if err := c.checkResponse(requestSummary, resp); err != nil { + if err := c.checkResponse(ctx, requestSummary, resp); err != nil { c.fail(err) return } @@ -2239,7 +2249,7 @@ func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary str // translates it into an error if the request was unsuccessful. // // The response body is close if a non-nil error is returned. -func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.Response) (err error) { +func (c *streamableClientConn) checkResponse(ctx context.Context, requestSummary string, resp *http.Response) (err error) { defer func() { if err != nil { resp.Body.Close() @@ -2259,6 +2269,19 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, http.StatusText(resp.StatusCode)) } if resp.StatusCode < 200 || resp.StatusCode >= 300 { + // Read the body and if we can detect vPre servers that + // reject "server/discover" as unsupported method with a plain HTTP 400, + // then return jsonrpc2.ErrMethodNotFound or jsonrpc2.ErrUnsupportedProtocolVersion to trigger the fallback. + protocolVersion := protocolVersionFromContext(ctx) + if protocolVersion != "" && protocolVersion >= protocolVersion20260630 { + body, _ := io.ReadAll(resp.Body) + if strings.Contains(string(body), fmt.Sprintf("%s: %q unsupported", jsonrpc2.ErrNotHandled, methodDiscover)) { + return fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrMethodNotFound, http.StatusText(resp.StatusCode)) + } + if strings.Contains(string(body), "Unsupported protocol version") { + return fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrUnsupportedProtocolVersion, http.StatusText(resp.StatusCode)) + } + } return fmt.Errorf("%s: %v", requestSummary, http.StatusText(resp.StatusCode)) } return nil diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 517e51af..785e8943 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -6,6 +6,7 @@ package mcp import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -102,6 +103,23 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques resp, ok := s.responses[key] if !ok { + if key.jsonrpcMethod == "server/discover" { + // Return MethodNotFound to trigger fallback to legacy initialize. + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + respErr := &jsonrpc.Error{ + Code: jsonrpc.CodeMethodNotFound, + Message: `method not found: "server/discover"`, + } + var id jsonrpc.ID + if jsonrpcReq != nil { + id = jsonrpcReq.ID + } + respMsg, _ := jsonrpc2.NewResponse(id, nil, respErr) + data, _ := jsonrpc2.EncodeMessage(respMsg) + w.Write(data) + return + } s.t.Errorf("missing response for %v", key) http.Error(w, "no response", http.StatusInternalServerError) return @@ -1212,3 +1230,408 @@ func TestStreamableClientOAuth_RetrieveError(t *testing.T) { t.Fatalf("client.Connect() error = %v, want %v", err, errTestAuthorizeFailed) } } + +// discoverResult is the canned successful DiscoverResult returned by +// fakeStreamableServer setups in the tests below. +var discoverResult = &DiscoverResult{ + SupportedVersions: []string{protocolVersion20260630}, + Capabilities: &ServerCapabilities{ + Tools: &ToolCapabilities{ListChanged: true}, + }, + ServerInfo: &Implementation{Name: "discoverServer", Version: "v1.0.0"}, + Instructions: "test discover", +} + +// TestStreamableClientConnect_DiscoverSuccess verifies that Client.Connect on +// the streamable transport: +// - sends a POST server/discover with Mcp-Protocol-Version: 2026-06-30 and +// the SEP-2575 per-request _meta triple in the body, +// - on a successful DiscoverResult, skips the legacy initialize handshake +// entirely, and +// - seeds ClientSession.InitializeResult() from the discover response, +// picking a mutually-supported protocol version. +func TestStreamableClientConnect_DiscoverSuccess(t *testing.T) { + ctx := context.Background() + + // Temporarily enable 2026-06-30 support in the SDK for this test + oldSupported := supportedProtocolVersions + supportedProtocolVersions = append([]string{protocolVersion20260630}, supportedProtocolVersions...) + t.Cleanup(func() { + supportedProtocolVersions = oldSupported + }) + + var ( + gotDiscoverMu sync.Mutex + gotDiscover *jsonrpc.Request + ) + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodDiscover, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "sess-1", + }, + wantProtocolVersion: protocolVersion20260630, + responseFunc: func(r *jsonrpc.Request) (string, int) { + gotDiscoverMu.Lock() + gotDiscover = r + gotDiscoverMu.Unlock() + return jsonBody(t, resp(1, discoverResult, nil)), http.StatusOK + }, + }, + {"DELETE", "sess-1", "", ""}: {optional: true}, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("Connect: %v", err) + } + defer session.Close() + + if missing := fake.missingRequests(); len(missing) > 0 { + t.Errorf("missing expected requests: %v", missing) + } + + gotDiscoverMu.Lock() + defer gotDiscoverMu.Unlock() + if gotDiscover == nil { + t.Fatal("server did not receive server/discover") + } + + // Inspect the discover request body for the SEP-2575 _meta triple. + var body struct { + Meta map[string]any `json:"_meta"` + } + if err := json.Unmarshal(gotDiscover.Params, &body); err != nil { + t.Fatalf("decoding discover params: %v", err) + } + if v, _ := body.Meta[MetaKeyProtocolVersion].(string); v != protocolVersion20260630 { + t.Errorf("_meta[%s] = %q, want %q", MetaKeyProtocolVersion, v, protocolVersion20260630) + } + if _, ok := body.Meta[MetaKeyClientInfo]; !ok { + t.Errorf("_meta[%s] missing", MetaKeyClientInfo) + } + if _, ok := body.Meta[MetaKeyClientCapabilities]; !ok { + t.Errorf("_meta[%s] missing", MetaKeyClientCapabilities) + } + + ir := session.InitializeResult() + if ir == nil { + t.Fatal("InitializeResult is nil after Connect") + } + if got, want := ir.ProtocolVersion, protocolVersion20260630; got != want { + t.Errorf("InitializeResult.ProtocolVersion = %q, want %q", got, want) + } + if ir.ServerInfo == nil || ir.ServerInfo.Name != "discoverServer" { + t.Errorf("InitializeResult.ServerInfo = %+v, want Name=discoverServer", ir.ServerInfo) + } + if ir.Instructions != "test discover" { + t.Errorf("InitializeResult.Instructions = %q, want %q", ir.Instructions, "test discover") + } +} + +// TestStreamableClientConnect_DiscoverMethodNotFound verifies that Client.Connect +// falls back to the legacy initialize handshake when the server responds to +// server/discover with a JSON-RPC "Method not found" error. +func TestStreamableClientConnect_DiscoverMethodNotFound(t *testing.T) { + ctx := context.Background() + + // Each request gets a fresh jsonrpc2 ID from the same client connection. + // Use responseFunc to echo the request's ID back so the client matches + // the response to the in-flight call regardless of ordering. + echoErr := func(err error) func(*jsonrpc.Request) (string, int) { + return func(r *jsonrpc.Request) (string, int) { + return jsonBody(t, &jsonrpc.Response{ID: r.ID, Error: err.(*jsonrpc.Error)}), http.StatusOK + } + } + echoResult := func(result any) func(*jsonrpc.Request) (string, int) { + return func(r *jsonrpc.Request) (string, int) { + return jsonBody(t, &jsonrpc.Response{ID: r.ID, Result: mustMarshal(result)}), http.StatusOK + } + } + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodDiscover, ""}: { + header: header{"Content-Type": "application/json"}, + wantProtocolVersion: protocolVersion20260630, + responseFunc: echoErr(&jsonrpc.Error{ + Code: jsonrpc.CodeMethodNotFound, + Message: "method not found", + }), + }, + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "fallback", + }, + responseFunc: echoResult(initResult), + }, + {"POST", "fallback", notificationInitialized, ""}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "fallback", "", ""}: { + header: header{"Content-Type": "text/event-stream"}, + wantProtocolVersion: latestProtocolVersion, + optional: true, + }, + {"DELETE", "fallback", "", ""}: {optional: true}, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("Connect: %v", err) + } + defer session.Close() + + if got := session.InitializeResult().ProtocolVersion; got != latestProtocolVersion { + t.Errorf("InitializeResult.ProtocolVersion = %q, want %q (initialize fallback)", got, latestProtocolVersion) + } +} + +// TestStreamableClientConnect_DiscoverUnsupportedVersion verifies that +// Client.Connect falls back to the legacy initialize handshake when the +// server responds to server/discover with the SEP-2575 +// UnsupportedProtocolVersionError JSON-RPC code. +func TestStreamableClientConnect_DiscoverUnsupportedVersion(t *testing.T) { + ctx := context.Background() + + echoErr := func(err error) func(*jsonrpc.Request) (string, int) { + return func(r *jsonrpc.Request) (string, int) { + return jsonBody(t, &jsonrpc.Response{ID: r.ID, Error: err.(*jsonrpc.Error)}), http.StatusOK + } + } + echoResult := func(result any) func(*jsonrpc.Request) (string, int) { + return func(r *jsonrpc.Request) (string, int) { + return jsonBody(t, &jsonrpc.Response{ID: r.ID, Result: mustMarshal(result)}), http.StatusOK + } + } + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodDiscover, ""}: { + header: header{"Content-Type": "application/json"}, + wantProtocolVersion: protocolVersion20260630, + responseFunc: echoErr(&jsonrpc.Error{ + Code: CodeUnsupportedProtocolVersion, + Message: "unsupported protocol version", + }), + }, + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "fallback", + }, + responseFunc: echoResult(initResult), + }, + {"POST", "fallback", notificationInitialized, ""}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "fallback", "", ""}: { + header: header{"Content-Type": "text/event-stream"}, + wantProtocolVersion: latestProtocolVersion, + optional: true, + }, + {"DELETE", "fallback", "", ""}: {optional: true}, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("Connect: %v", err) + } + defer session.Close() + + if got := session.InitializeResult().ProtocolVersion; got != latestProtocolVersion { + t.Errorf("InitializeResult.ProtocolVersion = %q, want %q (initialize fallback)", got, latestProtocolVersion) + } +} + +// TestStreamableClientConnect_DiscoverPropagatesOtherErrors verifies that +// Client.Connect does NOT fall back to initialize when server/discover +// returns an unrelated JSON-RPC error (here, CodeInternalError). The Connect +// call should fail with the propagated error rather than masking it. +func TestStreamableClientConnect_DiscoverPropagatesOtherErrors(t *testing.T) { + ctx := context.Background() + + var sawInitialize atomic.Bool + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodDiscover, ""}: { + header: header{"Content-Type": "application/json"}, + wantProtocolVersion: protocolVersion20260630, + responseFunc: func(r *jsonrpc.Request) (string, int) { + return jsonBody(t, &jsonrpc.Response{ + ID: r.ID, + Error: &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "boom", + }, + }), http.StatusOK + }, + }, + {"POST", "", methodInitialize, ""}: { + responseFunc: func(r *jsonrpc.Request) (string, int) { + sawInitialize.Store(true) + return jsonBody(t, &jsonrpc.Response{ID: r.ID, Result: mustMarshal(initResult)}), http.StatusOK + }, + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "fallback", + }, + optional: true, + }, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err == nil { + _ = session.Close() + t.Fatal("Connect succeeded; want propagated error") + } + if sawInitialize.Load() { + t.Error("server received initialize; Connect should have aborted on the discover error") + } +} + +// TestStreamableClientConnect_DiscoverMethodNotFoundVPre verifies that +// Client.Connect falls back to the legacy initialize handshake when a +// pre-SEP-2575 (vPre) server rejects server/discover. +func TestStreamableClientConnect_DiscoverMethodNotFoundVPre(t *testing.T) { + ctx := context.Background() + + echoResult := func(result any) func(*jsonrpc.Request) (string, int) { + return func(r *jsonrpc.Request) (string, int) { + return jsonBody(t, &jsonrpc.Response{ID: r.ID, Result: mustMarshal(result)}), http.StatusOK + } + } + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodDiscover, ""}: { + wantProtocolVersion: protocolVersion20260630, + // Reproduce the exact body a vPre server produces via + // http.Error(w, err.Error(), 400) where err comes from + // checkRequest. http.Error appends a trailing newline. + body: "JSON RPC not handled: \"server/discover\" unsupported\n", + status: http.StatusBadRequest, + header: header{"Content-Type": "text/plain; charset=utf-8"}, + }, + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "fallback", + }, + responseFunc: echoResult(initResult), + }, + {"POST", "fallback", notificationInitialized, ""}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"DELETE", "fallback", "", ""}: {optional: true}, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + DisableStandaloneSSE: true, + } + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("Connect: %v", err) + } + defer session.Close() + + if got := session.InitializeResult().ProtocolVersion; got != latestProtocolVersion { + t.Errorf("InitializeResult.ProtocolVersion = %q, want %q (initialize fallback)", got, latestProtocolVersion) + } +} + +// TestStreamableClientConnect_DiscoverUnsupportedProtocolVersion verifies that +// Client.Connect falls back to the legacy initialize handshake when a +// server rejects server/discover with a plain HTTP 400 containing "Unsupported protocol version". +func TestStreamableClientConnect_DiscoverUnsupportedVersionVPre(t *testing.T) { + ctx := context.Background() + + echoResult := func(result any) func(*jsonrpc.Request) (string, int) { + return func(r *jsonrpc.Request) (string, int) { + return jsonBody(t, &jsonrpc.Response{ID: r.ID, Result: mustMarshal(result)}), http.StatusOK + } + } + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodDiscover, ""}: { + wantProtocolVersion: protocolVersion20260630, + body: "Bad Request: Unsupported protocol version\n", + status: http.StatusBadRequest, + header: header{"Content-Type": "text/plain; charset=utf-8"}, + }, + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "fallback", + }, + responseFunc: echoResult(initResult), + }, + {"POST", "fallback", notificationInitialized, ""}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"DELETE", "fallback", "", ""}: {optional: true}, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + DisableStandaloneSSE: true, + } + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("Connect: %v", err) + } + defer session.Close() + + if got := session.InitializeResult().ProtocolVersion; got != latestProtocolVersion { + t.Errorf("InitializeResult.ProtocolVersion = %q, want %q (initialize fallback)", got, latestProtocolVersion) + } +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 53806da9..a27696a1 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -2049,68 +2049,6 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { }) } -// TODO: Remove this once client operations will automatically inject metadata in the requests -func injectMetaToRequest(req *http.Request) error { - if req.Body == nil { - return nil - } - body, err := io.ReadAll(req.Body) - if err != nil { - return err - } - req.Body.Close() - - var val any - if err := json.Unmarshal(body, &val); err == nil { - var method string - if m, ok := val.(map[string]any); ok { - method, _ = m["method"].(string) - } else if list, ok := val.([]any); ok && len(list) > 0 { - if m, ok := list[0].(map[string]any); ok { - method, _ = m["method"].(string) - } - } - - if method == "initialize" || method == "notifications/initialized" || strings.HasPrefix(method, "notifications/") { - req.Header.Set(protocolVersionHeader, "2025-11-25") - } else { - req.Header.Set(protocolVersionHeader, minVersionForStandardHeaders) - - var msgs []map[string]any - if m, ok := val.(map[string]any); ok { - msgs = []map[string]any{m} - } else if list, ok := val.([]any); ok { - for _, item := range list { - if m, ok := item.(map[string]any); ok { - msgs = append(msgs, m) - } - } - } - - for _, m := range msgs { - params, _ := m["params"].(map[string]any) - if params == nil { - params = make(map[string]any) - m["params"] = params - } - meta, _ := params["_meta"].(map[string]any) - if meta == nil { - meta = make(map[string]any) - params["_meta"] = meta - } - meta[MetaKeyProtocolVersion] = minVersionForStandardHeaders - meta[MetaKeyClientInfo] = map[string]any{"name": "testClient", "version": "v1.0.0"} - meta[MetaKeyClientCapabilities] = map[string]any{} - } - body, _ = json.Marshal(val) - } - } - - req.Body = io.NopCloser(bytes.NewReader(body)) - req.ContentLength = int64(len(body)) - return nil -} - // TestStreamableMcpHeaderValidationErrorFormat verifies that header // validation errors return a JSON-RPC error with code -32001 and // Content-Type application/json, per SEP-2243. @@ -2131,7 +2069,15 @@ func TestStreamableMcpHeaderValidationErrorFormat(t *testing.T) { }) defer handler.closeAll() - httpServer := httptest.NewServer(mustNotPanic(t, handler)) + // TODO(SEP-2575): drop discoverInterceptor and hit `handler` directly + // once Server.discover returns a real DiscoverResult instead of + // MethodNotFound. See comment on discoverInterceptor for details. + wrapped := discoverInterceptor(t, handler, + []string{minVersionForStandardHeaders}, + &ServerCapabilities{Tools: &ToolCapabilities{}}, + &Implementation{Name: "testServer", Version: "v1.0.0"}, + ) + httpServer := httptest.NewServer(mustNotPanic(t, wrapped)) defer httpServer.Close() // Use the MCP client with a custom RoundTripper to inject a bad header. @@ -2140,9 +2086,6 @@ func TestStreamableMcpHeaderValidationErrorFormat(t *testing.T) { customClient := &http.Client{ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - if err := injectMetaToRequest(req); err != nil { - return nil, err - } var originalMethodHeader string if req.Header.Get(methodHeader) == "tools/call" { originalMethodHeader = req.Header.Get(methodHeader) @@ -2296,15 +2239,20 @@ func TestStreamableParamHeadersClientSetsHeaders(t *testing.T) { Stateless: true, }) defer handler.closeAll() - httpServer := httptest.NewServer(mustNotPanic(t, handler)) + // TODO(SEP-2575): drop discoverInterceptor and hit `handler` directly + // once Server.discover returns a real DiscoverResult instead of + // MethodNotFound. See comment on discoverInterceptor for details. + wrapped := discoverInterceptor(t, handler, + []string{minVersionForStandardHeaders}, + &ServerCapabilities{Tools: &ToolCapabilities{ListChanged: true}}, + &Implementation{Name: "testServer", Version: "v1.0.0"}, + ) + httpServer := httptest.NewServer(mustNotPanic(t, wrapped)) defer httpServer.Close() var capturedHeaders http.Header customClient := &http.Client{ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - if err := injectMetaToRequest(req); err != nil { - return nil, err - } if req.Header.Get(methodHeader) == "tools/call" { capturedHeaders = req.Header.Clone() } @@ -2326,7 +2274,9 @@ func TestStreamableParamHeadersClientSetsHeaders(t *testing.T) { defer session.Close() // ListTools to populate the tool cache (needed for param headers). - if _, err := session.ListTools(ctx, nil); err != nil { + // Pass a non-nil params so the SEP-2575 per-request _meta triple is + // injected; injectMeta is a no-op when params is nil. + if _, err := session.ListTools(ctx, &ListToolsParams{}); err != nil { t.Fatal(err) } @@ -2411,23 +2361,21 @@ func TestStreamableFilterValidToolsIntegration(t *testing.T) { Stateless: true, }) defer handler.closeAll() - httpServer := httptest.NewServer(mustNotPanic(t, handler)) + // TODO(SEP-2575): drop discoverInterceptor and hit `handler` directly + // once Server.discover returns a real DiscoverResult instead of + // MethodNotFound. See comment on discoverInterceptor for details. + wrapped := discoverInterceptor(t, handler, + []string{minVersionForStandardHeaders}, + &ServerCapabilities{Tools: &ToolCapabilities{ListChanged: true}}, + &Implementation{Name: "testServer", Version: "v1.0.0"}, + ) + httpServer := httptest.NewServer(mustNotPanic(t, wrapped)) defer httpServer.Close() - customClient := &http.Client{ - Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - if err := injectMetaToRequest(req); err != nil { - return nil, err - } - return http.DefaultTransport.RoundTrip(req) - }), - } - client := NewClient(&Implementation{Name: "testClient", Version: "v1.0.0"}, nil) ctx := context.Background() session, err := client.Connect(ctx, &StreamableClientTransport{ - Endpoint: httpServer.URL, - HTTPClient: customClient, + Endpoint: httpServer.URL, }, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders}) if err != nil { t.Fatal(err) @@ -2742,6 +2690,51 @@ func TestStreamableSessionTimeout(t *testing.T) { handler.mu.Unlock() } +// discoverInterceptor wraps an HTTP handler so that POST requests carrying a +// server/discover JSON-RPC request are answered with a canned DiscoverResult +// advertising the given supportedVersions. All other requests are forwarded +// to next unchanged. +// +// TODO(SEP-2575): this is a workaround for tests that need an end-to-end +// SEP-2575 session (e.g. to exercise the Mcp-Method / Mcp-Param-* request +// headers gated on protocol >= 2026-06-30) while the server-side +// Server.discover implementation still returns MethodNotFound. Once +// server-side discover is implemented, this helper can be removed and the +// tests can hit the real handler directly. +func discoverInterceptor(t *testing.T, next http.Handler, supportedVersions []string, capabilities *ServerCapabilities, serverInfo *Implementation) http.Handler { + t.Helper() + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + next.ServeHTTP(w, req) + return + } + body, err := io.ReadAll(req.Body) + req.Body.Close() + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + req.Body = io.NopCloser(bytes.NewReader(body)) + msg, err := jsonrpc.DecodeMessage(body) + if err != nil { + next.ServeHTTP(w, req) + return + } + r, ok := msg.(*jsonrpc.Request) + if !ok || r.Method != methodDiscover { + next.ServeHTTP(w, req) + return + } + result := &DiscoverResult{ + SupportedVersions: supportedVersions, + Capabilities: capabilities, + ServerInfo: serverInfo, + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(jsonBody(t, &jsonrpc.Response{ID: r.ID, Result: mustMarshal(result)}))) + }) +} + // mustNotPanic is a helper to enforce that test handlers do not panic (see // issue #556). func mustNotPanic(t *testing.T, h http.Handler) http.Handler { @@ -3592,3 +3585,48 @@ func TestStreamableStateless_AcceptsNewProtocol(t *testing.T) { t.Fatalf("status = %d, want 200; body = %s", resp.StatusCode, respBody) } } + +// TestStreamableClientUnsupportedVersionFallback exercises the full +// SEP-2575 fallback. The client requests protocolVersion20260630, which the server does +// not advertise in supportedProtocolVersions. The server therefore rejects +// the server/discover POST at the transport-level header validation with a +// plain HTTP 400 ("Bad Request: Unsupported protocol version ..."). The +// streamable client must recognize this body, keep the connection alive, and +// successfully complete the legacy initialize handshake. +// +// TODO: once 20260630 is part of supportedProtocolVersion on server side, modify the list in the test to keep it consistent. +func TestStreamableClientUnsupportedVersionFallback(t *testing.T) { + ctx := context.Background() + + server := NewServer(testImpl, nil) + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return server }, + nil, + ) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + client := NewClient(testImpl, nil) + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + + session, err := client.Connect(ctx, transport, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("Connect: %v", err) + } + defer session.Close() + + ir := session.InitializeResult() + if ir == nil { + t.Fatal("InitializeResult is nil; expected legacy initialize to populate it") + } + if ir.ProtocolVersion != latestProtocolVersion { + t.Errorf("InitializeResult.ProtocolVersion = %q, want %q (legacy fallback)", + ir.ProtocolVersion, latestProtocolVersion) + } + + // Verify the session is fully usable after the fallback by issuing a + // real call against the server. + if err := session.Ping(ctx, nil); err != nil { + t.Errorf("Ping after fallback initialize: %v", err) + } +}