diff --git a/mcp/client.go b/mcp/client.go index 6e24c5a3..0d0477a4 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -1022,17 +1022,56 @@ func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) // // The params.Arguments can be any value that marshals into a JSON object. func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (*CallToolResult, error) { + params = normalizeCallToolParams(params) + ctx = cs.toolCallContext(ctx, params.Name) + return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) +} + +// CallToolRaw is like [ClientSession.CallTool] but unmarshals the response +// into a [CallToolResultRaw], leaving Content and StructuredContent as raw +// JSON bytes. It is intended for proxy and gateway implementations that +// forward tool calls without paying the cost of decoding and re-encoding +// large content payloads. +// +// Outbound, gateway authors can wrap upstream Content items in [RawContent] +// to splice their bytes verbatim into a downstream [CallToolResult] without +// a second marshal. +func (cs *ClientSession) CallToolRaw(ctx context.Context, params *CallToolParams) (*CallToolResultRaw, error) { + params = normalizeCallToolParams(params) + ctx = cs.toolCallContext(ctx, params.Name) + req := &callToolRawRequest{ClientRequest: newClientRequest(cs, params)} + return handleSend[*CallToolResultRaw](ctx, methodCallTool, req) +} + +// callToolRawRequest is a ClientRequest variant that decodes the response +// into *CallToolResultRaw rather than the *CallToolResult registered for +// tools/call. It implements resultOverrider so the sending dispatcher +// allocates the correct destination while preserving middleware behavior. +type callToolRawRequest struct { + *ClientRequest[*CallToolParams] +} + +func (*callToolRawRequest) newResult() Result { return new(CallToolResultRaw) } + +// normalizeCallToolParams returns a non-nil CallToolParams whose Arguments +// field is safe to send over the wire (tools/call requires a JSON object). +func normalizeCallToolParams(params *CallToolParams) *CallToolParams { if params == nil { params = new(CallToolParams) } if params.Arguments == nil { - // Avoid sending nil over the wire. params.Arguments = map[string]any{} } - if tool := cs.getCachedTool(params.Name); tool != nil { - ctx = context.WithValue(ctx, toolContextKey, tool) + return params +} + +// toolCallContext returns a context that carries the cached Tool definition, +// if any, so that the transport layer can use it (e.g. for header injection). +func (cs *ClientSession) toolCallContext(ctx context.Context, name string) context.Context { + if tool := cs.getCachedTool(name); tool != nil { + return context.WithValue(ctx, toolContextKey, tool) } - return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) + return ctx } func (cs *ClientSession) SetLoggingLevel(ctx context.Context, params *SetLoggingLevelParams) error { diff --git a/mcp/content.go b/mcp/content.go index 95ea40d8..a5bf68ac 100644 --- a/mcp/content.go +++ b/mcp/content.go @@ -15,10 +15,13 @@ import ( ) // A Content is a [TextContent], [ImageContent], [AudioContent], -// [ResourceLink], [EmbeddedResource], [ToolUseContent], or [ToolResultContent]. +// [ResourceLink], [EmbeddedResource], [ToolUseContent], [ToolResultContent], +// or [RawContent]. // // Note: [ToolUseContent] and [ToolResultContent] are only valid in sampling // message contexts (CreateMessageParams/CreateMessageResult). +// +// [RawContent] is an outbound-only passthrough type for gateways and proxies. type Content interface { MarshalJSON() ([]byte, error) fromWire(*wireContent) @@ -53,6 +56,24 @@ func (c *TextContent) fromWire(wire *wireContent) { c.Annotations = wire.Annotations } +// RawContent is a [Content] backed by pre-encoded JSON bytes. Its MarshalJSON +// returns Raw verbatim, allowing gateways and proxies to forward upstream +// tool result content items without typed re-encoding. +// +// RawContent is outbound-only; inbound parsing rebuilds typed Content values. +type RawContent struct { + Raw json.RawMessage +} + +func (c *RawContent) MarshalJSON() ([]byte, error) { + if len(c.Raw) == 0 { + return []byte("null"), nil + } + return c.Raw, nil +} + +func (c *RawContent) fromWire(*wireContent) {} + // ImageContent contains base64-encoded image data. type ImageContent struct { Meta Meta diff --git a/mcp/protocol.go b/mcp/protocol.go index 1646788a..5259e6fc 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -114,6 +114,25 @@ type CallToolResult struct { err error } +// CallToolResultRaw is the raw form of a [CallToolResult], returned by +// [ClientSession.CallToolRaw]. Its content fields are left as undecoded JSON, +// so that callers (such as gateways and proxies) can forward the payload +// without paying the cost of decoding and re-encoding [Content] values. +type CallToolResultRaw struct { + // Meta is reserved by the protocol to allow clients and servers to attach + // additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Content is the raw JSON content array from the wire. + Content json.RawMessage `json:"content"` + // StructuredContent is the raw JSON structured content from the wire, if any. + StructuredContent json.RawMessage `json:"structuredContent,omitempty"` + // IsError reports whether the tool call ended in an error. See + // [CallToolResult.IsError] for the full semantics. + IsError bool `json:"isError,omitempty"` +} + +func (*CallToolResultRaw) isResult() {} + // seterroroverwrite is a compatibility parameter that restores the pre-1.6.0 // behavior of [CallToolResult.SetError], where Content was always overwritten // with the error text. See the documentation for the mcpgodebug package for diff --git a/mcp/shared.go b/mcp/shared.go index 078b401b..e5d4a2c7 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -109,15 +109,29 @@ func defaultSendingMethodHandler(ctx context.Context, method string, req Request if strings.HasPrefix(method, "notifications/") { return nil, req.GetSession().getConn().Notify(ctx, method, params) } - // Create the result to unmarshal into. - // The concrete type of the result is the return type of the receiving function. - res := info.newResult() + // Create the result to unmarshal into. The concrete type is normally the + // return type registered on the receiving function, but a request may + // override it (e.g. CallToolRaw decodes into *CallToolResultRaw). + var res Result + if o, ok := req.(resultOverrider); ok { + res = o.newResult() + } else { + res = info.newResult() + } if err := call(ctx, req.GetSession().getConn(), method, params, res); err != nil { return nil, err } return res, nil } +// resultOverrider lets a Request supply a custom Result destination, +// bypassing the methodInfo's registered newResult. This enables alternate +// decode shapes (e.g. raw-bytes variants) without registering a separate +// JSON-RPC method. +type resultOverrider interface { + newResult() Result +} + // Helper method to avoid typed nil. func orZero[T any, P *U, U any](p P) T { if p == nil { diff --git a/mcp/tool_test.go b/mcp/tool_test.go index dfd859be..e05da585 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -146,6 +146,256 @@ func TestToolErrorHandling(t *testing.T) { }) } +// TestCallToolRaw verifies that ClientSession.CallToolRaw returns raw JSON +// content for both structured and unstructured tool results, normalizes +// nil/empty arguments to a JSON object, and surfaces tool errors via IsError +// rather than protocol errors. +func TestCallToolRaw(t *testing.T) { + type echoIn struct { + Msg string `json:"msg"` + } + type echoOut struct { + Echo string `json:"echo"` + } + + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "echo"}, func(_ context.Context, _ *CallToolRequest, in echoIn) (*CallToolResult, echoOut, error) { + return nil, echoOut{Echo: in.Msg}, nil + }) + AddTool(server, &Tool{Name: "boom"}, func(_ context.Context, _ *CallToolRequest, _ struct{}) (*CallToolResult, any, error) { + return nil, nil, errors.New("tool failed") + }) + + ct, st := NewInMemoryTransports() + if _, err := server.Connect(context.Background(), st, nil); err != nil { + t.Fatal(err) + } + cs, err := NewClient(testImpl, nil).Connect(context.Background(), ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + ctx := context.Background() + + t.Run("structured", func(t *testing.T) { + got, err := cs.CallToolRaw(ctx, &CallToolParams{ + Name: "echo", + Arguments: map[string]any{"msg": "hello"}, + }) + if err != nil { + t.Fatalf("CallToolRaw failed: %v", err) + } + if got.IsError { + t.Errorf("unexpected IsError=true; content=%s", got.Content) + } + // StructuredContent should contain exactly the bytes the tool produced; + // no decode/re-encode should round-trip through Go types. + if want := `{"echo":"hello"}`; string(got.StructuredContent) != want { + t.Errorf("StructuredContent = %s, want %s", got.StructuredContent, want) + } + if len(got.Content) == 0 || got.Content[0] != '[' { + t.Errorf("Content = %q, want non-empty JSON array", got.Content) + } + }) + + t.Run("raw_arguments", func(t *testing.T) { + // Gateway-style use: pass raw JSON bytes through CallToolParams.Arguments + // without remarshaling them. + got, err := cs.CallToolRaw(ctx, &CallToolParams{ + Name: "echo", + Arguments: json.RawMessage(`{"msg":"raw"}`), + }) + if err != nil { + t.Fatalf("CallToolRaw failed: %v", err) + } + if want := `{"echo":"raw"}`; string(got.StructuredContent) != want { + t.Errorf("StructuredContent = %s, want %s", got.StructuredContent, want) + } + }) + + t.Run("nil_params", func(t *testing.T) { + got, err := cs.CallToolRaw(ctx, nil) + if err == nil { + t.Fatalf("CallToolRaw(nil) succeeded; want error for missing tool name; result=%+v", got) + } + }) + + t.Run("tool_error", func(t *testing.T) { + got, err := cs.CallToolRaw(ctx, &CallToolParams{Name: "boom"}) + if err != nil { + t.Fatalf("CallToolRaw failed: %v", err) + } + if !got.IsError { + t.Errorf("IsError = false, want true") + } + }) +} + +// TestRawContent_MarshalVerbatim verifies that marshaling a []Content +// composed of *RawContent items reproduces the underlying raw bytes +// verbatim, enabling gateway-style splicing without typed re-encoding. +func TestRawContent_MarshalVerbatim(t *testing.T) { + raw := json.RawMessage(`[{"type":"text","text":"hello"},{"type":"image","data":"AAA=","mimeType":"image/png"}]`) + + var items []json.RawMessage + if err := json.Unmarshal(raw, &items); err != nil { + t.Fatalf("unmarshal items: %v", err) + } + content := make([]Content, len(items)) + for i, b := range items { + content[i] = &RawContent{Raw: b} + } + + got, err := json.Marshal(content) + if err != nil { + t.Fatalf("marshal []Content: %v", err) + } + if string(got) != string(raw) { + t.Errorf("RawContent splice mismatch:\n got = %s\nwant = %s", got, raw) + } + + t.Run("nil_raw", func(t *testing.T) { + b, err := json.Marshal(&RawContent{}) + if err != nil { + t.Fatalf("marshal empty RawContent: %v", err) + } + if string(b) != "null" { + t.Errorf("empty RawContent = %s, want null", b) + } + }) +} + +// TestCallToolRaw_RunsSendingMiddleware locks in middleware parity between +// CallTool and CallToolRaw: a sending middleware registered on the client +// fires exactly once for each call regardless of which method is used. +func TestCallToolRaw_RunsSendingMiddleware(t *testing.T) { + type out struct { + N int `json:"n"` + } + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "n"}, func(_ context.Context, _ *CallToolRequest, _ struct{}) (*CallToolResult, out, error) { + return nil, out{N: 1}, nil + }) + + ct, st := NewInMemoryTransports() + if _, err := server.Connect(context.Background(), st, nil); err != nil { + t.Fatal(err) + } + + var calls int + c := NewClient(testImpl, nil) + c.AddSendingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == methodCallTool { + calls++ + } + return next(ctx, method, req) + } + }) + cs, err := c.Connect(context.Background(), ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + if _, err := cs.CallTool(context.Background(), &CallToolParams{Name: "n"}); err != nil { + t.Fatalf("CallTool: %v", err) + } + if got, want := calls, 1; got != want { + t.Fatalf("after CallTool, middleware fired %d times, want %d", got, want) + } + if _, err := cs.CallToolRaw(context.Background(), &CallToolParams{Name: "n"}); err != nil { + t.Fatalf("CallToolRaw: %v", err) + } + if got, want := calls, 2; got != want { + t.Fatalf("after CallToolRaw, middleware fired %d times, want %d", got, want) + } +} + +// TestCallToolRawPassthrough is the canonical "MCP gateway" example: an +// upstream tool result is forwarded through a gateway server using +// CallToolRaw inbound and RawContent for outbound splicing, so the +// downstream client observes byte-identical Content and StructuredContent +// without any typed re-encoding in the middle. +func TestCallToolRawPassthrough(t *testing.T) { + type out struct { + N int `json:"n"` + } + upstream := NewServer(&Implementation{Name: "upstream", Version: "v1"}, nil) + AddTool(upstream, &Tool{Name: "n"}, func(_ context.Context, _ *CallToolRequest, _ struct{}) (*CallToolResult, out, error) { + return nil, out{N: 7}, nil + }) + + uct, ust := NewInMemoryTransports() + if _, err := upstream.Connect(context.Background(), ust, nil); err != nil { + t.Fatal(err) + } + upstreamCS, err := NewClient(testImpl, nil).Connect(context.Background(), uct, nil) + if err != nil { + t.Fatal(err) + } + defer upstreamCS.Close() + + // Capture what upstream returns so we can assert byte-for-byte passthrough. + upstreamRaw, err := upstreamCS.CallToolRaw(context.Background(), &CallToolParams{Name: "n"}) + if err != nil { + t.Fatalf("upstream CallToolRaw: %v", err) + } + + gateway := NewServer(&Implementation{Name: "gateway", Version: "v1"}, nil) + gateway.AddTool(&Tool{Name: "n", InputSchema: &jsonschema.Schema{Type: "object"}}, func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + raw, err := upstreamCS.CallToolRaw(ctx, &CallToolParams{ + Name: req.Params.Name, + Arguments: req.Params.Arguments, + }) + if err != nil { + return nil, err + } + // Shallow-parse the Content array once and wrap each item in + // RawContent so the outbound marshal splices the bytes verbatim. + var items []json.RawMessage + if len(raw.Content) > 0 { + if err := json.Unmarshal(raw.Content, &items); err != nil { + return nil, err + } + } + content := make([]Content, len(items)) + for i, b := range items { + content[i] = &RawContent{Raw: b} + } + return &CallToolResult{ + Content: content, + StructuredContent: raw.StructuredContent, + IsError: raw.IsError, + }, nil + }) + + gct, gst := NewInMemoryTransports() + if _, err := gateway.Connect(context.Background(), gst, nil); err != nil { + t.Fatal(err) + } + gatewayCS, err := NewClient(testImpl, nil).Connect(context.Background(), gct, nil) + if err != nil { + t.Fatal(err) + } + defer gatewayCS.Close() + + got, err := gatewayCS.CallToolRaw(context.Background(), &CallToolParams{Name: "n"}) + if err != nil { + t.Fatalf("gateway CallToolRaw: %v", err) + } + if got.IsError { + t.Errorf("IsError = true, want false") + } + if string(got.StructuredContent) != string(upstreamRaw.StructuredContent) { + t.Errorf("StructuredContent mismatch:\n got = %s\nwant = %s", got.StructuredContent, upstreamRaw.StructuredContent) + } + if string(got.Content) != string(upstreamRaw.Content) { + t.Errorf("Content mismatch:\n got = %s\nwant = %s", got.Content, upstreamRaw.Content) + } +} + func TestValidateToolName(t *testing.T) { t.Run("valid", func(t *testing.T) { validTests := []struct {