From 5135d8964fcaddc2b082c67e2de9cf085e8c5cc0 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Tue, 19 May 2026 11:06:00 +0800 Subject: [PATCH] fix: reject initialize protocol version mismatch --- mcp/streamable.go | 21 ++++++++++++++++++++- mcp/streamable_test.go | 13 +++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index d3f3f4fa..c8a0697a 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1258,7 +1258,8 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques return } - protocolVersion := protocolVersionFromContext(req.Context()) + headerProtocolVersion := protocolVersionFromContext(req.Context()) + protocolVersion := headerProtocolVersion if protocolVersion == "" { protocolVersion = protocolVersion20250326 } @@ -1278,6 +1279,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques calls := make(map[jsonrpc.ID]struct{}) tokenInfo := auth.TokenInfoFromContext(req.Context()) isInitialize := false + var initializeID jsonrpc.ID var initializeProtocolVersion string for _, msg := range incoming { if jreq, ok := msg.(*jsonrpc.Request); ok { @@ -1290,6 +1292,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } if jreq.Method == methodInitialize { isInitialize = true + initializeID = jreq.ID // Extract the protocol version from InitializeParams. var params InitializeParams if err := internaljson.Unmarshal(jreq.Params, ¶ms); err == nil { @@ -1322,6 +1325,22 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } } + if headerProtocolVersion != "" && initializeProtocolVersion != "" && headerProtocolVersion != initializeProtocolVersion { + resp := &jsonrpc.Response{ + ID: initializeID, + Error: jsonrpc2.NewError( + CodeHeaderMismatch, + fmt.Sprintf("header mismatch: %s header value %q does not match body protocolVersion %q", protocolVersionHeader, headerProtocolVersion, initializeProtocolVersion), + ), + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + if data, err := jsonrpc2.EncodeMessage(resp); err == nil { + w.Write(data) + } + return + } + // Validate MCP standard headers (Mcp-Method, Mcp-Name, Mcp-Param-*) if !isBatch && len(incoming) == 1 { if err := validateMcpHeaders(req.Header, incoming[0], c.toolLookup); err != nil { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d2e54224..25cb0242 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -966,6 +966,19 @@ func TestStreamableServerTransport(t *testing.T) { }, wantSessions: 1, }, + { + name: "initialize protocol version header mismatch", + requests: []streamableRequest{ + { + method: "POST", + headers: http.Header{protocolVersionHeader: {protocolVersion20251125}}, + messages: []jsonrpc.Message{req(1, methodInitialize, &InitializeParams{ProtocolVersion: protocolVersion20250618})}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "header mismatch", + }, + }, + wantSessions: 0, + }, { name: "batch rejected on 2025-06-18", requests: []streamableRequest{