diff --git a/mcp/streamable.go b/mcp/streamable.go index d3f3f4fa..c461dbe2 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1322,6 +1322,27 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } } + // If the client sent both a Mcp-Protocol-Version header and an initialize + // request with a protocolVersion field, verify they match. + if isInitialize && initializeProtocolVersion != "" && protocolVersion != "" && initializeProtocolVersion != protocolVersion { + resp := &jsonrpc.Response{ + Error: jsonrpc2.NewError(CodeHeaderMismatch, fmt.Sprintf("protocol version mismatch: Mcp-Protocol-Version header '%s' does not match body protocolVersion '%s'", protocolVersion, initializeProtocolVersion)), + } + // Find the initialize request to get its ID. + for _, msg := range incoming { + if jreq, ok := msg.(*jsonrpc.Request); ok && jreq.Method == methodInitialize { + resp.ID = jreq.ID + break + } + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + if data, err := jsonrpc2.EncodeMessage(resp); err == nil { + w.Write(data) + } + return + } + // Validate MCP standard headers (Mcp-Method, Mcp-Name, Mcp-Param-*) if !isBatch && len(incoming) == 1 { if err := validateMcpHeaders(req.Header, incoming[0], c.toolLookup); err != nil { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d2e54224..ab161419 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1256,6 +1256,20 @@ func TestStreamableServerTransport(t *testing.T) { }, wantSessions: 0, }, + { + name: "protocol version mismatch", + requests: []streamableRequest{ + { + method: "POST", + headers: http.Header{protocolVersionHeader: {protocolVersion20251125}}, + messages: []jsonrpc.Message{req(1, methodInitialize, &InitializeParams{ProtocolVersion: protocolVersion20250618})}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "protocol version mismatch", + wantSessionID: false, + }, + }, + wantSessions: 0, + }, } for _, test := range tests {