diff --git a/mcp/client.go b/mcp/client.go index 5f142fb1..1e0e18a4 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -290,18 +290,41 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp // signals it doesn't support it, fall back to the legacy initialize // handshake. discoverCtx := context.WithValue(ctx, protocolVersionContextKey{}, protocolVersion) - discRes, fallback, err := c.discover(discoverCtx, cs) - if err != nil { - return nil, err - } - if !fallback { - cs.state.InitializeResult = discRes - if hc, ok := cs.mcpConn.(clientConnection); ok { - hc.sessionUpdated(cs.state) + // We try to discover the server's capabilities. If the server rejects the + // requested version but specifies which versions it supports, we negotiate + // a mutually supported version and try again. + for range 2 { + discRes, err := c.discover(discoverCtx, cs) + if err == nil { + cs.state.InitializeResult = discRes + if hc, ok := cs.mcpConn.(clientConnection); ok { + hc.sessionUpdated(cs.state) + } + return cs, nil + } + + var werr *jsonrpc.Error + if !errors.As(err, &werr) { + return nil, err + } + // Try to negotiate a mutually supported version if the server + // reports an UnsupportedProtocolVersionError with a supported version. + if werr.Code == CodeUnsupportedProtocolVersion && werr.Data != nil { + var data UnsupportedProtocolVersionData + if err := json.Unmarshal(werr.Data, &data); err == nil { + if negotiatedVersion := negotiateMutuallySupportedVersion(data.Supported); negotiatedVersion != "" && negotiatedVersion >= protocolVersion20260630 { + discoverCtx = context.WithValue(ctx, protocolVersionContextKey{}, negotiatedVersion) + continue + } + } } - return cs, nil + // MethodNotFound and UnsupportedProtocolVersion trigger a fallback to legacy initialize. + if werr.Code == jsonrpc.CodeMethodNotFound || werr.Code == CodeUnsupportedProtocolVersion { + break + } + return nil, err } - // Fallback to the legacy initialize handshake. + // Fallback to the legacy initialize handshake with the legacy protocol version. protocolVersion = protocolVersion20251125 } @@ -338,15 +361,7 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp // discover sends a SEP-2575 server/discover request to probe the server for // stateless protocol support. -// -// The return values have three possible combinations: -// - (result, false, nil): discovery succeeded; caller should skip legacy initialization. -// - (nil, true, nil): the server explicitly signaled it doesn't support -// discovery (Method not found, or UnsupportedProtocolVersionError, or version mismatch); -// caller should fall back to the legacy initialize handshake. -// - (nil, false, err): any other failure (transport error, malformed response, etc.); -// caller should propagate the error. -func (c *Client) discover(ctx context.Context, cs *ClientSession) (*InitializeResult, bool, error) { +func (c *Client) discover(ctx context.Context, cs *ClientSession) (*InitializeResult, error) { protocolVersion := protocolVersionFromContext(ctx) caps := c.capabilities(protocolVersion) params := &DiscoverParams{ @@ -359,13 +374,7 @@ func (c *Client) discover(ctx context.Context, cs *ClientSession) (*InitializeRe req := &DiscoverRequest{Session: cs, Params: params} res, err := handleSend[*DiscoverResult](ctx, methodDiscover, req) if err != nil { - // According to SEP-2575, only the two signals below (MethodNotFound - // and UnsupportedProtocolVersionError) should trigger a fallback. - var werr *jsonrpc.Error - if errors.As(err, &werr) && (werr.Code == jsonrpc.CodeMethodNotFound || werr.Code == CodeUnsupportedProtocolVersion) { - return nil, true, nil - } - return nil, false, err + return nil, err } // Pick the highest protocol version that both the server and this SDK support. @@ -375,17 +384,12 @@ func (c *Client) discover(ctx context.Context, cs *ClientSession) (*InitializeRe if slices.Contains(res.SupportedVersions, protocolVersion) { negotiated = protocolVersion } else { - for _, v := range supportedProtocolVersions { - if slices.Contains(res.SupportedVersions, v) { - negotiated = v - break - } - } + negotiated = negotiateMutuallySupportedVersion(res.SupportedVersions) } if negotiated == "" || negotiated < protocolVersion20260630 { // If there is no overlap, fall back to initialize so version // negotiation can happen via the legacy path. - return nil, true, nil + return nil, jsonrpc2.ErrUnsupportedProtocolVersion } return &InitializeResult{ @@ -393,7 +397,7 @@ func (c *Client) discover(ctx context.Context, cs *ClientSession) (*InitializeRe Instructions: res.Instructions, ProtocolVersion: negotiated, ServerInfo: res.ServerInfo, - }, false, nil + }, nil } // A ClientSession is a logical connection with an MCP server. Its diff --git a/mcp/client_test.go b/mcp/client_test.go index 1d629521..75d1c4f9 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -6,6 +6,7 @@ package mcp import ( "context" + "encoding/json" "fmt" "log/slog" "sync/atomic" @@ -841,3 +842,114 @@ func TestClientConnectDiscover_RequestContents(t *testing.T) { t.Errorf("clientCapabilities.sampling missing (CreateMessageHandler was set); got %v", caps) } } + +// If the server does not support the requested version, it returns an +// UnsupportedProtocolVersionError containing its list of supported +// versions. The client selects a mutually supported version from the list +// and retries. +func TestClientConnectDiscover_UnsupportedVersionNegotiation(t *testing.T) { + // Temporarily enable 2026-06-30 support in the SDK for this test so it + // is a candidate during negotiation. + oldSupported := supportedProtocolVersions + supportedProtocolVersions = append([]string{protocolVersion20260630}, supportedProtocolVersions...) + t.Cleanup(func() { + supportedProtocolVersions = oldSupported + }) + + ctx := context.Background() + + const ( + unsupportedClientVersion = "2099-12-31" + serverNegotiatedVersion = protocolVersion20260630 + ) + + var ( + discoverCalls atomic.Int32 + gotInitialize atomic.Bool + firstRequestedVersion atomic.Value // string + secondRequestedVersion atomic.Value // string + ) + + s := NewServer(testImpl, nil) + s.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + switch method { + case methodDiscover: + sr, ok := req.(*ServerRequest[*DiscoverParams]) + if !ok { + t.Errorf("discover req has unexpected type %T", req) + return nil, jsonrpc2.ErrMethodNotFound + } + requested, _ := sr.Params.GetMeta()[MetaKeyProtocolVersion].(string) + + n := discoverCalls.Add(1) + switch n { + case 1: + firstRequestedVersion.Store(requested) + data, err := json.Marshal(UnsupportedProtocolVersionData{ + Supported: []string{serverNegotiatedVersion}, + Requested: requested, + }) + if err != nil { + t.Fatalf("marshal error data: %v", err) + } + return nil, &jsonrpc.Error{ + Code: CodeUnsupportedProtocolVersion, + Message: "unsupported protocol version", + Data: data, + } + case 2: + secondRequestedVersion.Store(requested) + return &DiscoverResult{ + SupportedVersions: []string{serverNegotiatedVersion}, + Capabilities: &ServerCapabilities{ + Tools: &ToolCapabilities{ListChanged: true}, + }, + ServerInfo: &Implementation{Name: "discoverServer", Version: "v1.0.0"}, + }, nil + default: + t.Errorf("unexpected discover call #%d", n) + return nil, jsonrpc2.ErrMethodNotFound + } + case methodInitialize: + gotInitialize.Store(true) + } + return next(ctx, method, req) + } + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatalf("server Connect: %v", err) + } + defer ss.Close() + + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: unsupportedClientVersion}) + if err != nil { + t.Fatalf("Connect: %v", err) + } + defer cs.Close() + + if got, want := discoverCalls.Load(), int32(2); got != want { + t.Errorf("server/discover call count = %d, want %d", got, want) + } + if got, _ := firstRequestedVersion.Load().(string); got != unsupportedClientVersion { + t.Errorf("first discover requested version = %q, want %q", got, unsupportedClientVersion) + } + if got, _ := secondRequestedVersion.Load().(string); got != serverNegotiatedVersion { + t.Errorf("retry discover requested version = %q, want %q (server's advertised supported version)", got, serverNegotiatedVersion) + } + if gotInitialize.Load() { + t.Error("legacy initialize handshake ran, but negotiated discover should have succeeded") + } + + ir := cs.InitializeResult() + if ir == nil { + t.Fatal("InitializeResult is nil after Connect") + } + if got, want := ir.ProtocolVersion, serverNegotiatedVersion; got != want { + t.Errorf("InitializeResult.ProtocolVersion = %q, want %q", got, want) + } +} diff --git a/mcp/protocol.go b/mcp/protocol.go index f5b79aeb..d838067b 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -2103,3 +2103,14 @@ const ( // MetaKeyClientCapabilities carries the client's [ClientCapabilities]. MetaKeyClientCapabilities = "io.modelcontextprotocol/clientCapabilities" ) + +// UnsupportedProtocolVersionData is the SEP-2575 payload carried in the +// `data` field of a JSON-RPC error response with code +// [CodeUnsupportedProtocolVersion]. The server uses it to advertise which +// versions it supports so the client can pick a mutually supported one. +type UnsupportedProtocolVersionData struct { + // Supported is the list of protocol versions the server supports. + Supported []string `json:"supported"` + // Requested is the protocol version the client asked for. + Requested string `json:"requested"` +} diff --git a/mcp/shared.go b/mcp/shared.go index a0232d37..88c2464e 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -64,6 +64,17 @@ func negotiatedVersion(clientVersion string) string { return clientVersion } +// negotiateMutuallySupportedVersion returns a protocol version that is supported +// by both the client and the server. +func negotiateMutuallySupportedVersion(supported []string) string { + for _, ver := range supportedProtocolVersions { + if slices.Contains(supported, ver) { + return ver + } + } + return "" +} + // A MethodHandler handles MCP messages. // For methods, exactly one of the return values must be nil. // For notifications, both must be nil.