diff --git a/mcp/protocol.go b/mcp/protocol.go index 1646788a..f55055ca 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -165,10 +165,12 @@ func (x *CallToolResult) UnmarshalJSON(data []byte) error { } func (x *CallToolParams) isParams() {} +func (x *CallToolParams) isNil() bool { return x == nil } func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) } func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *CallToolParamsRaw) isParams() {} +func (x *CallToolParamsRaw) isNil() bool { return x == nil } func (x *CallToolParamsRaw) GetProgressToken() any { return getProgressToken(x) } func (x *CallToolParamsRaw) SetProgressToken(t any) { setProgressToken(x, t) } @@ -187,6 +189,7 @@ type CancelledParams struct { } func (x *CancelledParams) isParams() {} +func (x *CancelledParams) isNil() bool { return x == nil } func (x *CancelledParams) GetProgressToken() any { return getProgressToken(x) } func (x *CancelledParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -374,7 +377,8 @@ type CompleteParams struct { Ref *CompleteReference `json:"ref"` } -func (*CompleteParams) isParams() {} +func (x *CompleteParams) isParams() {} +func (x *CompleteParams) isNil() bool { return x == nil } type CompletionResultDetails struct { HasMore bool `json:"hasMore,omitempty"` @@ -422,6 +426,7 @@ type CreateMessageParams struct { } func (x *CreateMessageParams) isParams() {} +func (x *CreateMessageParams) isNil() bool { return x == nil } func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -448,6 +453,7 @@ type CreateMessageWithToolsParams struct { } func (x *CreateMessageWithToolsParams) isParams() {} +func (x *CreateMessageWithToolsParams) isNil() bool { return x == nil } func (x *CreateMessageWithToolsParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageWithToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -654,6 +660,7 @@ type GetPromptParams struct { } func (x *GetPromptParams) isParams() {} +func (x *GetPromptParams) isNil() bool { return x == nil } func (x *GetPromptParams) GetProgressToken() any { return getProgressToken(x) } func (x *GetPromptParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -706,6 +713,7 @@ func (p *initializeParamsV2) toV1() *InitializeParams { } func (x *InitializeParams) isParams() {} +func (x *InitializeParams) isNil() bool { return x == nil } func (x *InitializeParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializeParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -739,6 +747,7 @@ type InitializedParams struct { } func (x *InitializedParams) isParams() {} +func (x *InitializedParams) isNil() bool { return x == nil } func (x *InitializedParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -752,6 +761,7 @@ type ListPromptsParams struct { } func (x *ListPromptsParams) isParams() {} +func (x *ListPromptsParams) isNil() bool { return x == nil } func (x *ListPromptsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListPromptsParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListPromptsParams) cursorPtr() *string { return &x.Cursor } @@ -780,6 +790,7 @@ type ListResourceTemplatesParams struct { } func (x *ListResourceTemplatesParams) isParams() {} +func (x *ListResourceTemplatesParams) isNil() bool { return x == nil } func (x *ListResourceTemplatesParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListResourceTemplatesParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListResourceTemplatesParams) cursorPtr() *string { return &x.Cursor } @@ -808,6 +819,7 @@ type ListResourcesParams struct { } func (x *ListResourcesParams) isParams() {} +func (x *ListResourcesParams) isNil() bool { return x == nil } func (x *ListResourcesParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListResourcesParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListResourcesParams) cursorPtr() *string { return &x.Cursor } @@ -833,6 +845,7 @@ type ListRootsParams struct { } func (x *ListRootsParams) isParams() {} +func (x *ListRootsParams) isNil() bool { return x == nil } func (x *ListRootsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListRootsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -858,6 +871,7 @@ type ListToolsParams struct { } func (x *ListToolsParams) isParams() {} +func (x *ListToolsParams) isNil() bool { return x == nil } func (x *ListToolsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListToolsParams) cursorPtr() *string { return &x.Cursor } @@ -896,6 +910,7 @@ type LoggingMessageParams struct { } func (x *LoggingMessageParams) isParams() {} +func (x *LoggingMessageParams) isNil() bool { return x == nil } func (x *LoggingMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *LoggingMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -958,6 +973,7 @@ type PingParams struct { } func (x *PingParams) isParams() {} +func (x *PingParams) isNil() bool { return x == nil } func (x *PingParams) GetProgressToken() any { return getProgressToken(x) } func (x *PingParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -978,7 +994,8 @@ type ProgressNotificationParams struct { Total float64 `json:"total,omitempty"` } -func (*ProgressNotificationParams) isParams() {} +func (x *ProgressNotificationParams) isParams() {} +func (x *ProgressNotificationParams) isNil() bool { return x == nil } // IconTheme specifies the theme an icon is designed for. type IconTheme string @@ -1048,6 +1065,7 @@ type PromptListChangedParams struct { } func (x *PromptListChangedParams) isParams() {} +func (x *PromptListChangedParams) isNil() bool { return x == nil } func (x *PromptListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *PromptListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1089,6 +1107,7 @@ type ReadResourceParams struct { } func (x *ReadResourceParams) isParams() {} +func (x *ReadResourceParams) isNil() bool { return x == nil } func (x *ReadResourceParams) GetProgressToken() any { return getProgressToken(x) } func (x *ReadResourceParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1145,6 +1164,7 @@ type ResourceListChangedParams struct { } func (x *ResourceListChangedParams) isParams() {} +func (x *ResourceListChangedParams) isNil() bool { return x == nil } func (x *ResourceListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ResourceListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1205,6 +1225,7 @@ type RootsListChangedParams struct { } func (x *RootsListChangedParams) isParams() {} +func (x *RootsListChangedParams) isNil() bool { return x == nil } func (x *RootsListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1288,6 +1309,7 @@ type SetLoggingLevelParams struct { } func (x *SetLoggingLevelParams) isParams() {} +func (x *SetLoggingLevelParams) isNil() bool { return x == nil } func (x *SetLoggingLevelParams) GetProgressToken() any { return getProgressToken(x) } func (x *SetLoggingLevelParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1390,6 +1412,7 @@ type ToolListChangedParams struct { } func (x *ToolListChangedParams) isParams() {} +func (x *ToolListChangedParams) isNil() bool { return x == nil } func (x *ToolListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1403,7 +1426,8 @@ type SubscribeParams struct { URI string `json:"uri"` } -func (*SubscribeParams) isParams() {} +func (x *SubscribeParams) isParams() {} +func (x *SubscribeParams) isNil() bool { return x == nil } // Sent from the client to request cancellation of resources/updated // notifications from the server. This should follow a previous @@ -1416,7 +1440,8 @@ type UnsubscribeParams struct { URI string `json:"uri"` } -func (*UnsubscribeParams) isParams() {} +func (x *UnsubscribeParams) isParams() {} +func (x *UnsubscribeParams) isNil() bool { return x == nil } // A notification from the server to the client, informing it that a resource // has changed and may need to be read again. This should only be sent if the @@ -1429,7 +1454,8 @@ type ResourceUpdatedNotificationParams struct { URI string `json:"uri"` } -func (*ResourceUpdatedNotificationParams) isParams() {} +func (x *ResourceUpdatedNotificationParams) isParams() {} +func (x *ResourceUpdatedNotificationParams) isNil() bool { return x == nil } // TODO(jba): add CompleteRequest and related types. @@ -1468,7 +1494,8 @@ type ElicitParams struct { ElicitationID string `json:"elicitationId,omitempty"` } -func (x *ElicitParams) isParams() {} +func (x *ElicitParams) isParams() {} +func (x *ElicitParams) isNil() bool { return x == nil } func (x *ElicitParams) GetProgressToken() any { return getProgressToken(x) } func (x *ElicitParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1500,7 +1527,8 @@ type ElicitationCompleteParams struct { ElicitationID string `json:"elicitationId"` } -func (*ElicitationCompleteParams) isParams() {} +func (x *ElicitationCompleteParams) isParams() {} +func (x *ElicitationCompleteParams) isNil() bool { return x == nil } // An Implementation describes the name and version of an MCP implementation, with an optional // title for UI representation. @@ -1630,3 +1658,17 @@ const ( notificationToolListChanged = "notifications/tools/list_changed" methodUnsubscribe = "resources/unsubscribe" ) + +// Per-request _meta field names for the >= 2026-06-30 protocol version. +// +// These keys appear inside a Params._meta map and carry information that +// previously came from the initialization handshake (SEP-2575). +const ( + // MetaKeyProtocolVersion identifies the MCP protocol version that the + // request follows. + MetaKeyProtocolVersion = "io.modelcontextprotocol/protocolVersion" + // MetaKeyClientInfo carries the client's [Implementation]. + MetaKeyClientInfo = "io.modelcontextprotocol/clientInfo" + // MetaKeyClientCapabilities carries the client's [ClientCapabilities]. + MetaKeyClientCapabilities = "io.modelcontextprotocol/clientCapabilities" +) diff --git a/mcp/server.go b/mcp/server.go index 183226d1..bc6b64a6 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1450,13 +1450,32 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, initialized := ss.state.InitializeParams != nil ss.mu.Unlock() - // From the spec: - // "The client SHOULD NOT send requests other than pings before the server - // has responded to the initialize request." + // Per-request protocol detection (SEP-2575): if the request carries + // `io.modelcontextprotocol/protocolVersion` in its `_meta` field, it + // follows the new sessionless protocol. The initialization gate is + // skipped for such requests. + validatedMeta, perRequestErr := validateRequestMeta(req) + if perRequestErr != nil { + return nil, perRequestErr + } + + if !initialized && validatedMeta.usesNewProtocol && validatedMeta.initializeParams != nil { + ss.updateState(func(state *ServerSessionState) { + state.InitializeParams = validatedMeta.initializeParams + }) + } + switch req.Method { case methodInitialize, methodPing, notificationInitialized: + if validatedMeta.usesNewProtocol { + ss.server.opts.Logger.Error("method removed in the new protocol", "method", req.Method) + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeMethodNotFound, + Message: fmt.Sprintf("%q is not supported in the new protocol", req.Method), + } + } default: - if !initialized { + if !initialized && !validatedMeta.usesNewProtocol { ss.server.opts.Logger.Error("method invalid during initialization", "method", req.Method) return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) } diff --git a/mcp/server_test.go b/mcp/server_test.go index 2937ea2b..8eb462d5 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -8,6 +8,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "log" "log/slog" @@ -19,6 +20,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) type testItem struct { @@ -1007,3 +1009,167 @@ func TestServerCapabilitiesOverWire(t *testing.T) { }) } } + +// SEP-2575 removes the initialization handshake. An `initialize` request +// that opts into the new protocol via `_meta.protocolVersion` must be +// rejected with `Method not found` (-32601). +func TestServerSessionHandle_RejectsInitializeOnNewProtocol(t *testing.T) { + tests := []struct { + name string + params any + wantReject bool + }{ + { + name: "initialize with new-protocol _meta is rejected", + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + MetaKeyClientCapabilities: map[string]any{}, + }, + "protocolVersion": protocolVersion20260630, + }, + wantReject: true, + }, + { + name: "initialize without _meta is allowed (old protocol)", + params: map[string]any{ + "protocolVersion": protocolVersion20251125, + }, + wantReject: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ss := &ServerSession{server: NewServer(testImpl, nil)} + id, err := jsonrpc.MakeID("test") + if err != nil { + t.Fatal(err) + } + req := &jsonrpc.Request{ + ID: id, + Method: methodInitialize, + Params: mustMarshal(tc.params), + } + _, err = ss.handle(context.Background(), req) + if tc.wantReject { + if err == nil { + t.Fatal("expected error rejecting initialize, got nil") + } + var jerr *jsonrpc.Error + if !errors.As(err, &jerr) { + t.Fatalf("error type = %T, want *jsonrpc.Error so the wire returns the right code", err) + } + if jerr.Code != jsonrpc.CodeMethodNotFound { + t.Errorf("error code = %d, want %d (CodeMethodNotFound = -32601)", jerr.Code, jsonrpc.CodeMethodNotFound) + } + if !strings.Contains(jerr.Message, "initialize") { + t.Errorf("error message %q does not mention %q", jerr.Message, "initialize") + } + } else { + // Old-protocol initialize should be dispatched normally; any + // CodeMethodNotFound here would mean the rejection branch + // fired incorrectly. + var jerr *jsonrpc.Error + if errors.As(err, &jerr) && jerr.Code == jsonrpc.CodeMethodNotFound { + t.Errorf("old-protocol initialize was incorrectly rejected: %v", err) + } + } + }) + } + + t.Run("rejection error encodes to wire as code -32601", func(t *testing.T) { + // Belt-and-braces check that the error type produced by handle() + // actually serializes to JSON-RPC code -32601, not a bare 0. + ss := &ServerSession{server: NewServer(testImpl, nil)} + id, err := jsonrpc.MakeID("test") + if err != nil { + t.Fatal(err) + } + req := &jsonrpc.Request{ + ID: id, + Method: methodInitialize, + Params: mustMarshal(map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + MetaKeyClientCapabilities: map[string]any{}, + }, + "protocolVersion": protocolVersion20260630, + }), + } + _, handleErr := ss.handle(context.Background(), req) + if handleErr == nil { + t.Fatal("expected rejection error, got nil") + } + data, encErr := jsonrpc.EncodeMessage(&jsonrpc.Response{ID: id, Error: handleErr.(*jsonrpc.Error)}) + if encErr != nil { + t.Fatal(encErr) + } + var wire struct { + Error struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal(data, &wire); err != nil { + t.Fatal(err) + } + if wire.Error.Code != jsonrpc.CodeMethodNotFound { + t.Errorf("wire error code = %d, want %d; full response = %s", wire.Error.Code, jsonrpc.CodeMethodNotFound, data) + } + }) +} + +// TestServerSessionHandle_RejectsRemovedMethodsOnNewProtocol verifies that +// the methods removed by SEP-2575 (`initialize`, `notifications/initialized`, +// `ping`) all return Method not found when the request opts into the new +// protocol via `_meta.protocolVersion`. +func TestServerSessionHandle_RejectsRemovedMethodsOnNewProtocol(t *testing.T) { + newProtoMeta := map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + MetaKeyClientCapabilities: map[string]any{}, + }, + } + + tests := []struct { + name string + method string + }{ + {"initialize", methodInitialize}, + {"ping", methodPing}, + {"notifications/initialized", notificationInitialized}, + } + + for _, tc := range tests { + t.Run(tc.name+" rejected on new protocol", func(t *testing.T) { + ss := &ServerSession{server: NewServer(testImpl, nil)} + id, err := jsonrpc.MakeID("test") + if err != nil { + t.Fatal(err) + } + req := &jsonrpc.Request{ + ID: id, + Method: tc.method, + Params: mustMarshal(newProtoMeta), + } + _, err = ss.handle(context.Background(), req) + if err == nil { + t.Fatalf("method %q on new protocol: got nil error, want CodeMethodNotFound", tc.method) + } + var jerr *jsonrpc.Error + if !errors.As(err, &jerr) { + t.Fatalf("error type = %T, want *jsonrpc.Error", err) + } + if jerr.Code != jsonrpc.CodeMethodNotFound { + t.Errorf("method %q: code = %d, want %d", tc.method, jerr.Code, jsonrpc.CodeMethodNotFound) + } + if !strings.Contains(jerr.Message, tc.method) { + t.Errorf("method %q: message %q does not mention method name", tc.method, jerr.Message) + } + }) + } +} diff --git a/mcp/shared.go b/mcp/shared.go index 078b401b..d06e5c9b 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -465,6 +465,69 @@ func setProgressToken(p Params, pt any) { m[progressTokenKey] = pt } +// extractRequestMeta performs a lightweight partial unmarshal of the `_meta` +// field from a JSON-RPC request's raw params. +func extractRequestMeta(rawParams json.RawMessage) Meta { + if len(rawParams) == 0 { + return nil + } + var meta struct { + Meta Meta `json:"_meta"` + } + if err := internaljson.Unmarshal(rawParams, &meta); err != nil { + return nil + } + return meta.Meta +} + +type validatedMeta struct { + usesNewProtocol bool + initializeParams *InitializeParams +} + +// validateRequestMeta inspects a JSON-RPC request to detect whether it follows +// the >= 2026-06-30 protocol via the `_meta` field. +// If the request has no _meta, or no protocolVersion in _meta, it returns a non-nil +// validatedMeta with usesNewProtocol set to false, and a nil error. +// If the request has a protocolVersion in _meta: +// - For notifications, it returns usesNewProtocol set to true and a nil initializeParams. +// - For call requests, it validates the presence of clientInfo and clientCapabilities in _meta. +// If either is missing or invalid, it returns nil and a non-nil error. Otherwise, it returns +// usesNewProtocol set to true and the populated initializeParams. +func validateRequestMeta(req *jsonrpc.Request) (*validatedMeta, error) { + meta := extractRequestMeta(req.Params) + if meta == nil { + return &validatedMeta{usesNewProtocol: false, initializeParams: nil}, nil + } + protocolVersion, ok := meta[MetaKeyProtocolVersion].(string) + if !ok { + return &validatedMeta{usesNewProtocol: false, initializeParams: nil}, nil + } + // Notifications do not carry full client identity + if !req.IsCall() { + return &validatedMeta{usesNewProtocol: true, initializeParams: nil}, nil + } + clientInfo, ok := decodeMetaValue[*Implementation](meta, MetaKeyClientInfo) + if !ok { + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: fmt.Sprintf("missing or invalid _meta field %q", MetaKeyClientInfo), + } + } + capabilities, ok := decodeMetaValue[*ClientCapabilities](meta, MetaKeyClientCapabilities) + if !ok { + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: fmt.Sprintf("missing or invalid _meta field %q", MetaKeyClientCapabilities), + } + } + return &validatedMeta{usesNewProtocol: true, initializeParams: &InitializeParams{ + ProtocolVersion: protocolVersion, + Capabilities: capabilities, + ClientInfo: clientInfo, + }}, nil +} + // A Request is a method request with parameters and additional information, such as the session. // Request is implemented by [*ClientRequest] and [*ServerRequest]. type Request interface { @@ -525,6 +588,94 @@ func (r *ServerRequest[P]) GetParams() Params { return r.Params } func (r *ClientRequest[P]) GetExtra() *RequestExtra { return nil } func (r *ServerRequest[P]) GetExtra() *RequestExtra { return r.Extra } +// ProtocolVersion returns the protocol version negotiated for this request. +// +// For requests following the >= 2026-06-30 protocol, the value is read from +// the per-request `_meta` field. For older protocol requests, the value falls +// back to the session-level [InitializeParams] established during the +// initialize handshake. +func (r *ServerRequest[P]) ProtocolVersion() string { + if m := getRequestMeta(r); m != nil { + if v, ok := m[MetaKeyProtocolVersion].(string); ok { + return v + } + } + if r.Session != nil { + if p := r.Session.InitializeParams(); p != nil { + return p.ProtocolVersion + } + } + return "" +} + +// ClientInfo returns the [Implementation] identifying the calling client. +// +// For requests following the >= 2026-06-30 protocol, the value is read from +// the per-request `_meta` field. For older protocol requests, the value falls +// back to the session-level [InitializeParams]. +func (r *ServerRequest[P]) ClientInfo() *Implementation { + if m := getRequestMeta(r); m != nil { + if v, ok := decodeMetaValue[*Implementation](m, MetaKeyClientInfo); ok { + return v + } + } + if r.Session != nil { + if p := r.Session.InitializeParams(); p != nil { + return p.ClientInfo + } + } + return nil +} + +// ClientCapabilities returns the [ClientCapabilities] of the calling client. +// +// For requests following the >= 2026-06-30 protocol, the value is read from +// the per-request `_meta` field. For older protocol requests, the value falls +// back to the session-level [InitializeParams]. +func (r *ServerRequest[P]) ClientCapabilities() *ClientCapabilities { + if m := getRequestMeta(r); m != nil { + if v, ok := decodeMetaValue[*ClientCapabilities](m, MetaKeyClientCapabilities); ok { + return v + } + } + if r.Session != nil { + if p := r.Session.InitializeParams(); p != nil { + return p.Capabilities + } + } + return nil +} + +// getRequestMeta returns the raw `_meta` map from the request's params, or +// nil if the params are absent. +func getRequestMeta[P Params](r *ServerRequest[P]) map[string]any { + // In practice P is a pointer type implementing Params. + if any(r.Params) == nil || r.Params.isNil() { + return nil + } + return r.Params.GetMeta() +} + +// decodeMetaValue decodes a typed value out of a `_meta` map. Values may +// arrive either as the typed Go value (when constructed in-process) or as +// the generic JSON map produced by encoding/json after wire transit. In the +// latter case, the value is re-encoded and decoded into the target type. +func decodeMetaValue[T any](m map[string]any, key string) (T, bool) { + var zero T + raw, ok := m[key] + if !ok || raw == nil { + return zero, false + } + if v, ok := raw.(T); ok { + return v, true + } + var v T + if err := remarshal(raw, &v); err != nil { + return zero, false + } + return v, true +} + func serverRequestFor[P Params](s *ServerSession, p P) *ServerRequest[P] { return &ServerRequest[P]{Session: s, Params: p} } @@ -542,6 +693,9 @@ type Params interface { // isParams discourages implementation of Params outside of this package. isParams() + + // isNil returns true if the underlying value is nil. + isNil() bool } // RequestParams is a parameter (input) type for an MCP request. diff --git a/mcp/shared_test.go b/mcp/shared_test.go index 23818f87..e4f563d1 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -4,6 +4,247 @@ package mcp +import ( + "encoding/json" + "errors" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +func TestValidateRequestMeta(t *testing.T) { + tests := []struct { + name string + method string + isNotification bool + params any + wantUsesNew bool + wantErrContains string + }{ + { + name: "no params: old protocol", + method: methodListTools, + params: nil, + wantUsesNew: false, + }, + { + name: "no _meta: old protocol", + method: methodCallTool, + params: map[string]any{"name": "x"}, + wantUsesNew: false, + }, + { + name: "_meta without protocolVersion: old protocol", + method: methodCallTool, + params: map[string]any{ + "_meta": map[string]any{"otherKey": "v"}, + "name": "x", + }, + wantUsesNew: false, + }, + { + name: "new protocol with all required fields", + method: methodCallTool, + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + MetaKeyClientCapabilities: map[string]any{}, + }, + "name": "x", + }, + wantUsesNew: true, + }, + { + name: "new protocol missing clientInfo", + method: methodCallTool, + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientCapabilities: map[string]any{}, + }, + "name": "x", + }, + wantUsesNew: false, + wantErrContains: MetaKeyClientInfo, + }, + { + name: "new protocol missing clientCapabilities", + method: methodCallTool, + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + }, + "name": "x", + }, + wantUsesNew: false, + wantErrContains: MetaKeyClientCapabilities, + }, + { + name: "notifications exempt from required fields", + method: notificationCancelled, + isNotification: true, + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + }, + "requestId": "r1", + }, + wantUsesNew: true, + }, + { + name: "malformed _meta is ignored", + method: methodCallTool, + params: json.RawMessage(`{"_meta": "not an object", "name": "x"}`), + wantUsesNew: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var raw json.RawMessage + switch p := tc.params.(type) { + case json.RawMessage: + raw = p + default: + raw = mustMarshal(tc.params) + } + req := &jsonrpc.Request{Method: tc.method, Params: raw} + if !tc.isNotification { + req.ID = jsonrpc.ID{} + // Give the request an ID by parsing one. + id, err := jsonrpc.MakeID("test") + if err != nil { + t.Fatal(err) + } + req.ID = id + } + + vmeta, err := validateRequestMeta(req) + usesNew := vmeta != nil && vmeta.usesNewProtocol + if usesNew != tc.wantUsesNew { + t.Errorf("usesNewProtocol = %v, want %v", usesNew, tc.wantUsesNew) + } + if tc.wantErrContains == "" { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + return + } + if err == nil { + t.Fatalf("expected error containing %q, got nil", tc.wantErrContains) + } + var jerr *jsonrpc.Error + if !errors.As(err, &jerr) { + t.Fatalf("expected *jsonrpc.Error, got %T: %v", err, err) + } + if jerr.Code != jsonrpc.CodeInvalidParams { + t.Errorf("error code = %d, want %d", jerr.Code, jsonrpc.CodeInvalidParams) + } + if !strings.Contains(jerr.Message, tc.wantErrContains) { + t.Errorf("error message %q does not contain %q", jerr.Message, tc.wantErrContains) + } + }) + } +} + +func TestServerRequest_PerRequestAccessors(t *testing.T) { + // A request carrying the new-protocol _meta fields populates the + // accessors with values from _meta. + caps := &ClientCapabilities{Sampling: &SamplingCapabilities{}} + info := &Implementation{Name: "c", Version: "1"} + params := &CallToolParamsRaw{ + Meta: Meta{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: info, + MetaKeyClientCapabilities: caps, + }, + Name: "x", + } + req := &ServerRequest[*CallToolParamsRaw]{Params: params} + if got := req.ProtocolVersion(); got != protocolVersion20260630 { + t.Errorf("ProtocolVersion = %q, want %q", got, protocolVersion20260630) + } + if got := req.ClientInfo(); got == nil || got.Name != "c" { + t.Errorf("ClientInfo = %+v, want Name=c", got) + } + if got := req.ClientCapabilities(); got == nil || got.Sampling == nil { + t.Errorf("ClientCapabilities = %+v, want non-nil Sampling", got) + } +} + +func TestServerRequest_PerRequestAccessors_FromJSON(t *testing.T) { + // Values arriving over the wire are JSON maps; the accessors should + // re-decode them into typed Go values. + raw := json.RawMessage(`{ + "_meta": { + "io.modelcontextprotocol/protocolVersion": "2026-06-30", + "io.modelcontextprotocol/clientInfo": {"name": "wire-client", "version": "9"}, + "io.modelcontextprotocol/clientCapabilities": {"sampling": {}} + }, + "name": "tool" + }`) + var params CallToolParamsRaw + if err := json.Unmarshal(raw, ¶ms); err != nil { + t.Fatal(err) + } + req := &ServerRequest[*CallToolParamsRaw]{Params: ¶ms} + if got, want := req.ProtocolVersion(), protocolVersion20260630; got != want { + t.Errorf("ProtocolVersion = %q, want %q", got, want) + } + gotInfo := req.ClientInfo() + wantInfo := &Implementation{Name: "wire-client", Version: "9"} + if diff := cmp.Diff(wantInfo, gotInfo); diff != "" { + t.Errorf("ClientInfo mismatch (-want +got):\n%s", diff) + } + gotCaps := req.ClientCapabilities() + if gotCaps == nil || gotCaps.Sampling == nil { + t.Errorf("ClientCapabilities = %+v, want non-nil Sampling", gotCaps) + } +} + +func TestServerRequest_PerRequestAccessors_FallbackToInitializeParams(t *testing.T) { + // With no _meta on the request, accessors must fall back to the + // session's InitializeParams (the old-protocol path). + ss := &ServerSession{} + ss.state.InitializeParams = &InitializeParams{ + ProtocolVersion: protocolVersion20251125, + ClientInfo: &Implementation{Name: "old", Version: "0"}, + Capabilities: &ClientCapabilities{Elicitation: &ElicitationCapabilities{}}, + } + req := &ServerRequest[*CallToolParamsRaw]{ + Session: ss, + Params: &CallToolParamsRaw{Name: "x"}, + } + if got, want := req.ProtocolVersion(), protocolVersion20251125; got != want { + t.Errorf("ProtocolVersion fallback = %q, want %q", got, want) + } + if got := req.ClientInfo(); got == nil || got.Name != "old" { + t.Errorf("ClientInfo fallback = %+v, want Name=old", got) + } + if got := req.ClientCapabilities(); got == nil || got.Elicitation == nil { + t.Errorf("ClientCapabilities fallback = %+v, want non-nil Elicitation", got) + } +} + +func TestServerRequest_PerRequestAccessors_Empty(t *testing.T) { + // With no _meta and no session, accessors return zero values. + req := &ServerRequest[*CallToolParamsRaw]{ + Params: &CallToolParamsRaw{Name: "x"}, + } + if got := req.ProtocolVersion(); got != "" { + t.Errorf("ProtocolVersion = %q, want empty", got) + } + if got := req.ClientInfo(); got != nil { + t.Errorf("ClientInfo = %+v, want nil", got) + } + if got := req.ClientCapabilities(); got != nil { + t.Errorf("ClientCapabilities = %+v, want nil", got) + } +} + // TODO(v0.3.0): rewrite this test. // func TestToolValidate(t *testing.T) { // // Check that the tool returned from NewServerTool properly validates its input schema. diff --git a/mcp/streamable.go b/mcp/streamable.go index d3f3f4fa..f5e93b40 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -343,8 +343,14 @@ func (h *StreamableHTTPHandler) serveStateless(w http.ResponseWriter, req *http. return } + connectOpts, usesNewProtocol, err := h.ephemeralConnectOpts(req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + var sessionID string - if legacySessions { + if legacySessions && !usesNewProtocol { sessionID = req.Header.Get(sessionIDHeader) if sessionID == "" { sessionID = server.opts.GetSessionID() @@ -359,11 +365,6 @@ func (h *StreamableHTTPHandler) serveStateless(w http.ResponseWriter, req *http. logger: h.opts.Logger, } - connectOpts, err := h.ephemeralConnectOpts(req) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } session, err := connectStreamable(req.Context(), server, transport, connectOpts) if err != nil { h.opts.Logger.Error(fmt.Sprintf("failed to connect: %v", err)) @@ -389,10 +390,17 @@ func (h *StreamableHTTPHandler) serveStatelessLegacyDELETE(w http.ResponseWriter } // ephemeralConnectOpts peeks at the request body to determine whether it -// contains an initialize or initialized message. If not, default session state -// is constructed so that the session doesn't reject the request. +// contains an initialize or initialized message or whether the protocol version +// header indicates a protocol version >= 2026-06-30 (SEP-2575). +// +// For old-protocol requests, default session state is synthesized so that +// the session's init gate doesn't reject the request. +// // It is used for both stateless servers and stateful servers with no session ID. -func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (*ServerSessionOptions, error) { +// +// The returned usesNewProtocol bool reports whether the protocol version +// header indicates a protocol version >= 2026-06-30 (SEP-2575). +func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (opts *ServerSessionOptions, usesNewProtocol bool, err error) { protocolVersion := protocolVersionFromContext(req.Context()) if protocolVersion == "" { protocolVersion = protocolVersion20250326 @@ -401,7 +409,7 @@ func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (*Server var hasInitialize, hasInitialized bool body, err := io.ReadAll(req.Body) if err != nil { - return nil, fmt.Errorf("failed to read body") + return nil, false, fmt.Errorf("failed to read body") } req.Body.Close() req.Body = io.NopCloser(bytes.NewBuffer(body)) @@ -415,23 +423,28 @@ func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (*Server case notificationInitialized: hasInitialized = true } + if protocolVersion >= protocolVersion20260630 { + usesNewProtocol = true + } } } } state := new(ServerSessionState) - if !hasInitialize { + // Only synthesize fake InitializeParams/InitializedParams for old-protocol + // requests. + if !hasInitialize && !usesNewProtocol { state.InitializeParams = &InitializeParams{ ProtocolVersion: protocolVersion, } } - if !hasInitialized { + if !hasInitialized && !usesNewProtocol { state.InitializedParams = new(InitializedParams) } state.LogLevel = "info" return &ServerSessionOptions{ State: state, - }, nil + }, usesNewProtocol, nil } func connectStreamable(ctx context.Context, server *Server, transport *StreamableServerTransport, opts *ServerSessionOptions) (*ServerSession, error) { @@ -576,7 +589,7 @@ func (h *StreamableHTTPHandler) serveStatefulPOST(w http.ResponseWriter, req *ht // that arrives before a session exists (e.g. initialize or ping) on a // server configured this way. if sessionID == "" { - connectOpts, err := h.ephemeralConnectOpts(req) + connectOpts, _, err := h.ephemeralConnectOpts(req) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -1279,6 +1292,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques tokenInfo := auth.TokenInfoFromContext(req.Context()) isInitialize := false var initializeProtocolVersion string + headerVersion := protocolVersionFromContext(req.Context()) for _, msg := range incoming { if jreq, ok := msg.(*jsonrpc.Request); ok { // Preemptively check that this is a valid request, so that we can fail @@ -1296,6 +1310,39 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques initializeProtocolVersion = params.ProtocolVersion } } + // SEP-2575: requests carrying `_meta.protocolVersion` require the + // Mcp-Protocol-Version HTTP header to be present and to match the + // per-request `_meta.protocolVersion` value. + // The new (>= 2026-06-30) protocol is supported on the HTTP transport + // only when [StreamableHTTPOptions.Stateless] is true. + var metaVersion string + if meta := extractRequestMeta(jreq.Params); meta != nil { + metaVersion, _ = meta[MetaKeyProtocolVersion].(string) + } + if protocolVersion >= protocolVersion20260630 || metaVersion != "" { + if !c.stateless { + http.Error(w, fmt.Sprintf( + "Bad Request: protocol version %q is only supported on stateless HTTP servers (set StreamableHTTPOptions.Stateless = true)", + protocolVersion), + http.StatusBadRequest) + return + } + if headerVersion == "" { + http.Error(w, fmt.Sprintf( + "Bad Request: %s header is required for requests carrying %q", + protocolVersionHeader, MetaKeyProtocolVersion), + http.StatusBadRequest) + return + } + if headerVersion != metaVersion { + http.Error(w, fmt.Sprintf( + "Bad Request: %s header %q does not match request %s %q", + protocolVersionHeader, headerVersion, + MetaKeyProtocolVersion, metaVersion), + http.StatusBadRequest) + return + } + } // Include metadata for all requests (including notifications). jreq.Extra = &RequestExtra{ TokenInfo: tokenInfo, diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d2e54224..e566bf25 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1926,39 +1926,18 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { return &CallToolResult{}, nil }) - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) defer handler.closeAll() - initReq := req(1, methodInitialize, &InitializeParams{ProtocolVersion: minVersionForStandardHeaders}) - initResp := resp(1, &InitializeResult{ - Capabilities: &ServerCapabilities{ - Logging: &LoggingCapabilities{}, - Tools: &ToolCapabilities{ListChanged: true}, - }, - ProtocolVersion: minVersionForStandardHeaders, - ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, - }, nil) - - initialize := streamableRequest{ - method: "POST", - messages: []jsonrpc.Message{initReq}, - wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{initResp}, - wantSessionID: true, - } - initialized := streamableRequest{ - method: "POST", - headers: http.Header{ - protocolVersionHeader: {minVersionForStandardHeaders}, - methodHeader: {notificationInitialized}, - }, - messages: []jsonrpc.Message{req(0, notificationInitialized, &InitializedParams{})}, - wantStatusCode: http.StatusAccepted, + testMeta := Meta{ + MetaKeyProtocolVersion: minVersionForStandardHeaders, + MetaKeyClientInfo: map[string]any{"name": "testClient", "version": "v1.0.0"}, + MetaKeyClientCapabilities: map[string]any{}, } testStreamableHandler(t, handler, []streamableRequest{ - initialize, - initialized, { method: "POST", headers: http.Header{ @@ -1966,7 +1945,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"tools/call"}, nameHeader: {"my-tool"}, }, - messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, }, @@ -1977,7 +1956,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"prompts/get"}, nameHeader: {"my-tool"}, }, - messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusBadRequest, wantBodyContaining: "Mcp-Method header value", }, @@ -1988,7 +1967,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"tools/call"}, nameHeader: {"wrong-tool"}, }, - messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusBadRequest, wantBodyContaining: "Mcp-Name header value", }, @@ -1999,7 +1978,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"TOOLS/CALL"}, nameHeader: {"my-tool"}, }, - messages: []jsonrpc.Message{req(5, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(5, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusBadRequest, wantBodyContaining: "Mcp-Method header value", }, @@ -2010,7 +1989,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"tools/call"}, nameHeader: {"my-tool"}, }, - messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)}, }, @@ -2023,6 +2002,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { paramHeaderPrefix + "Region": {"us-west1"}, }, messages: []jsonrpc.Message{req(7, "tools/call", &CallToolParams{ + Meta: testMeta, Name: "execute_sql", Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"}, })}, @@ -2038,6 +2018,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { paramHeaderPrefix + "Region": {"eu-central1"}, }, messages: []jsonrpc.Message{req(8, "tools/call", &CallToolParams{ + Meta: testMeta, Name: "execute_sql", Arguments: map[string]any{"region": "us-west1"}, })}, @@ -2052,6 +2033,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { nameHeader: {"execute_sql"}, }, messages: []jsonrpc.Message{req(9, "tools/call", &CallToolParams{ + Meta: testMeta, Name: "execute_sql", Arguments: map[string]any{"region": "us-west1"}, })}, @@ -2061,6 +2043,68 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { }) } +// TODO: Remove this once client operations will automatically inject metadata in the requests +func injectMetaToRequest(req *http.Request) error { + if req.Body == nil { + return nil + } + body, err := io.ReadAll(req.Body) + if err != nil { + return err + } + req.Body.Close() + + var val any + if err := json.Unmarshal(body, &val); err == nil { + var method string + if m, ok := val.(map[string]any); ok { + method, _ = m["method"].(string) + } else if list, ok := val.([]any); ok && len(list) > 0 { + if m, ok := list[0].(map[string]any); ok { + method, _ = m["method"].(string) + } + } + + if method == "initialize" || method == "notifications/initialized" || strings.HasPrefix(method, "notifications/") { + req.Header.Set(protocolVersionHeader, "2025-11-25") + } else { + req.Header.Set(protocolVersionHeader, minVersionForStandardHeaders) + + var msgs []map[string]any + if m, ok := val.(map[string]any); ok { + msgs = []map[string]any{m} + } else if list, ok := val.([]any); ok { + for _, item := range list { + if m, ok := item.(map[string]any); ok { + msgs = append(msgs, m) + } + } + } + + for _, m := range msgs { + params, _ := m["params"].(map[string]any) + if params == nil { + params = make(map[string]any) + m["params"] = params + } + meta, _ := params["_meta"].(map[string]any) + if meta == nil { + meta = make(map[string]any) + params["_meta"] = meta + } + meta[MetaKeyProtocolVersion] = minVersionForStandardHeaders + meta[MetaKeyClientInfo] = map[string]any{"name": "testClient", "version": "v1.0.0"} + meta[MetaKeyClientCapabilities] = map[string]any{} + } + body, _ = json.Marshal(val) + } + } + + req.Body = io.NopCloser(bytes.NewReader(body)) + req.ContentLength = int64(len(body)) + return nil +} + // TestStreamableMcpHeaderValidationErrorFormat verifies that header // validation errors return a JSON-RPC error with code -32001 and // Content-Type application/json, per SEP-2243. @@ -2076,7 +2120,9 @@ func TestStreamableMcpHeaderValidationErrorFormat(t *testing.T) { return &CallToolResult{}, nil }) - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) defer handler.closeAll() httpServer := httptest.NewServer(mustNotPanic(t, handler)) @@ -2088,6 +2134,9 @@ func TestStreamableMcpHeaderValidationErrorFormat(t *testing.T) { customClient := &http.Client{ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if err := injectMetaToRequest(req); err != nil { + return nil, err + } var originalMethodHeader string if req.Header.Get(methodHeader) == "tools/call" { originalMethodHeader = req.Header.Get(methodHeader) @@ -2237,7 +2286,9 @@ func TestStreamableParamHeadersClientSetsHeaders(t *testing.T) { return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil }) - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) defer handler.closeAll() httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() @@ -2245,6 +2296,9 @@ func TestStreamableParamHeadersClientSetsHeaders(t *testing.T) { var capturedHeaders http.Header customClient := &http.Client{ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if err := injectMetaToRequest(req); err != nil { + return nil, err + } if req.Header.Get(methodHeader) == "tools/call" { capturedHeaders = req.Header.Clone() } @@ -2347,14 +2401,28 @@ func TestStreamableFilterValidToolsIntegration(t *testing.T) { InputSchema: &jsonschema.Schema{Type: "object"}, }, noop) - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) defer handler.closeAll() httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() + customClient := &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if err := injectMetaToRequest(req); err != nil { + return nil, err + } + return http.DefaultTransport.RoundTrip(req) + }), + } + client := NewClient(&Implementation{Name: "testClient", Version: "v1.0.0"}, nil) ctx := context.Background() - session, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders}) + session, err := client.Connect(ctx, &StreamableClientTransport{ + Endpoint: httpServer.URL, + HTTPClient: customClient, + }, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders}) if err != nil { t.Fatal(err) } @@ -3207,3 +3275,314 @@ func TestStandaloneSSEEmitsCommentForHTTP2Flush(t *testing.T) { t.Fatal("timed out waiting for first SSE bytes; the standalone SSE stream must emit a DATA frame immediately so HTTP/2 reverse proxies don't buffer the HEADERS frame") } } + +// newProtocolBody builds a raw JSON body for a tools/call request that +// carries the >= 2026-06-30 per-request _meta fields. +func newProtocolBody(t *testing.T, toolName string, args any) []byte { + t.Helper() + rawArgs, err := json.Marshal(args) + if err != nil { + t.Fatal(err) + } + body, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "new-proto-client", "version": "9.9"}, + MetaKeyClientCapabilities: map[string]any{"sampling": map[string]any{}}, + }, + "name": toolName, + "arguments": json.RawMessage(rawArgs), + }, + }) + if err != nil { + t.Fatal(err) + } + return body +} + +func TestEphemeralConnectOpts(t *testing.T) { + mkReq := func(body []byte) *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + r.Header.Set("Content-Type", "application/json") + return r + } + + h := &StreamableHTTPHandler{opts: StreamableHTTPOptions{}} + + oldProtocolBody, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{"name": "x", "arguments": map[string]any{}}, + }) + if err != nil { + t.Fatal(err) + } + initializeBody, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": methodInitialize, + "params": map[string]any{"protocolVersion": protocolVersion20250618}, + }) + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + body []byte + wantUsesNew bool + wantInitializeParams bool + wantInitializedParams bool + }{ + { + name: "new-protocol request: no synthetic state", + body: newProtocolBody(t, "x", struct{}{}), + wantUsesNew: true, + wantInitializeParams: false, + wantInitializedParams: false, + }, + { + name: "old-protocol request: synthetic state populated", + body: oldProtocolBody, + wantUsesNew: false, + wantInitializeParams: true, + wantInitializedParams: true, + }, + { + name: "initialize request: no synthetic InitializeParams", + body: initializeBody, + wantUsesNew: false, + wantInitializeParams: false, + wantInitializedParams: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := mkReq(tt.body) + var pver string + if tt.wantUsesNew { + pver = protocolVersion20260630 + } else { + pver = protocolVersion20250326 + } + req.Header.Set(protocolVersionHeader, pver) + req = req.WithContext(context.WithValue(req.Context(), protocolVersionContextKey{}, pver)) + opts, usesNew, err := h.ephemeralConnectOpts(req) + if err != nil { + t.Fatal(err) + } + if usesNew != tt.wantUsesNew { + t.Errorf("usesNewProtocol = %v, want %v", usesNew, tt.wantUsesNew) + } + if got := opts.State.InitializeParams != nil; got != tt.wantInitializeParams { + t.Errorf("InitializeParams non-nil = %v, want %v (value = %+v)", + got, tt.wantInitializeParams, opts.State.InitializeParams) + } + if got := opts.State.InitializedParams != nil; got != tt.wantInitializedParams { + t.Errorf("InitializedParams non-nil = %v, want %v (value = %+v)", + got, tt.wantInitializedParams, opts.State.InitializedParams) + } + }) + } +} + +// statelessHandlerCapture builds a stateless server with a single tool whose +// handler captures everything we want to assert about the per-request view of +// the session and the new-protocol accessors. +type statelessHandlerCapture struct { + mu sync.Mutex + sessionInitParams *InitializeParams + reqProtocolVersion string + reqClientInfo *Implementation + reqClientCapabilities *ClientCapabilities +} + +func TestStreamableStateless_NewProtocolSession_NoFakeInit(t *testing.T) { + // SEP-2575: the MCP-Protocol-Version header is mandatory for new-protocol + // requests and must be a supported version. The 2026-06-30 version is + // not yet in the global list, so register it for the duration of the test. + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + capture := &statelessHandlerCapture{} + mcpServer := NewServer(testImpl, nil) + AddTool(mcpServer, &Tool{Name: "capture", Description: "captures request info"}, + func(ctx context.Context, req *CallToolRequest, args struct{}) (*CallToolResult, any, error) { + capture.mu.Lock() + defer capture.mu.Unlock() + capture.sessionInitParams = req.Session.InitializeParams() + capture.reqProtocolVersion = req.ProtocolVersion() + capture.reqClientInfo = req.ClientInfo() + capture.reqClientCapabilities = req.ClientCapabilities() + return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil, nil + }) + + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return mcpServer }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + body := newProtocolBody(t, "capture", struct{}{}) + httpReq, err := http.NewRequest(http.MethodPost, httpServer.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "application/json, text/event-stream") + httpReq.Header.Set(protocolVersionHeader, protocolVersion20260630) + // >= 2026-06-30 also requires the Mcp-Method and Mcp-Name standard + // headers (see streamable_headers.go). + httpReq.Header.Set(methodHeader, "tools/call") + httpReq.Header.Set(nameHeader, "capture") + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + t.Fatalf("status = %d, want 200; body = %s", resp.StatusCode, respBody) + } + + capture.mu.Lock() + defer capture.mu.Unlock() + if capture.sessionInitParams == nil { + t.Errorf("Session.InitializeParams() is nil, want populated initializeParams for new-protocol session") + } else { + if got, want := capture.sessionInitParams.ProtocolVersion, protocolVersion20260630; got != want { + t.Errorf("Session.InitializeParams().ProtocolVersion = %q, want %q", got, want) + } + if got, want := capture.sessionInitParams.ClientInfo.Name, "new-proto-client"; got != want { + t.Errorf("Session.InitializeParams().ClientInfo.Name = %q, want %q", got, want) + } + } + if got, want := capture.reqProtocolVersion, protocolVersion20260630; got != want { + t.Errorf("req.ProtocolVersion() = %q, want %q", got, want) + } + if capture.reqClientInfo == nil || capture.reqClientInfo.Name != "new-proto-client" { + t.Errorf("req.ClientInfo() = %+v, want Name=new-proto-client", capture.reqClientInfo) + } + if capture.reqClientCapabilities == nil || capture.reqClientCapabilities.Sampling == nil { + t.Errorf("req.ClientCapabilities() = %+v, want non-nil Sampling", capture.reqClientCapabilities) + } +} + +// TestStreamableStateful_RejectsNewProtocol verifies that a stateful HTTP +// server rejects requests carrying _meta.protocolVersion (i.e. >= 2026-06-30 +// requests) with HTTP 400. The new protocol is +// supported on HTTP only when StreamableHTTPOptions.Stateless=true. +func TestStreamableStateful_RejectsNewProtocol(t *testing.T) { + // Make 2026-06-30 a "known" version so that the request reaches servePOST + // (otherwise the early header validation at ServeHTTP rejects it). + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "noop"}, + func(ctx context.Context, req *CallToolRequest, args struct{}) (*CallToolResult, any, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil, nil + }) + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + // Initialize a legacy session first. + initBody := strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}`) + initReq, err := http.NewRequest(http.MethodPost, httpServer.URL, initBody) + if err != nil { + t.Fatal(err) + } + initReq.Header.Set("Content-Type", "application/json") + initReq.Header.Set("Accept", "application/json, text/event-stream") + initResp, err := http.DefaultClient.Do(initReq) + if err != nil { + t.Fatal(err) + } + io.Copy(io.Discard, initResp.Body) + initResp.Body.Close() + sessionID := initResp.Header.Get(sessionIDHeader) + if sessionID == "" { + t.Fatalf("initialize response missing %s header", sessionIDHeader) + } + + // Drive the existing session with a new-protocol request whose header and + // body agree. The cross-check passes; the stateful-rejection check fires. + body := newProtocolBody(t, "noop", struct{}{}) + req, err := http.NewRequest(http.MethodPost, httpServer.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set(sessionIDHeader, sessionID) + req.Header.Set(protocolVersionHeader, protocolVersion20260630) + req.Header.Set(methodHeader, "tools/call") + req.Header.Set(nameHeader, "noop") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400; body = %s", resp.StatusCode, respBody) + } + if !strings.Contains(string(respBody), "stateless") { + t.Errorf("body = %q, want a message mentioning 'stateless'", respBody) + } +} + +// TestStreamableStateless_AcceptsNewProtocol is the positive control: +// confirms that a stateless server still accepts new-protocol requests +// (the rejection in TestStreamableStateful_RejectsNewProtocol must not +// fire on Stateless: true). +func TestStreamableStateless_AcceptsNewProtocol(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "noop"}, + func(ctx context.Context, req *CallToolRequest, args struct{}) (*CallToolResult, any, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil, nil + }) + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + body := newProtocolBody(t, "noop", struct{}{}) + req, err := http.NewRequest(http.MethodPost, httpServer.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set(protocolVersionHeader, protocolVersion20260630) + req.Header.Set(methodHeader, "tools/call") + req.Header.Set(nameHeader, "noop") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + t.Fatalf("status = %d, want 200; body = %s", resp.StatusCode, respBody) + } +}