diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index 84b66129a..df649f9ab 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -839,6 +839,7 @@ func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events EventSink tools.ConfigureHandlers(toolset, r.elicitationHandler, r.samplingHandler, + r.samplingWithToolsHandler, func() { events.Emit(Authorization(tools.ElicitationActionAccept, a.Name())) }, r.managedOAuth, r.unmanagedOAuthRedirectURI, diff --git a/pkg/runtime/sampling.go b/pkg/runtime/sampling.go index c1ff91559..360c0d908 100644 --- a/pkg/runtime/sampling.go +++ b/pkg/runtime/sampling.go @@ -3,6 +3,7 @@ package runtime import ( "context" "encoding/base64" + "encoding/json" "errors" "fmt" "io" @@ -14,6 +15,7 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/tools" ) // Limits applied to inbound sampling requests to keep a misbehaving or @@ -29,6 +31,13 @@ const ( // maxSamplingBinaryBytes caps the size of an individual image/audio // block before we refuse to inline it as a data URL. maxSamplingBinaryBytes = 8 << 20 // 8 MiB + // maxSamplingTools caps the number of tool definitions a server can + // inject into a single sampling-with-tools request. + maxSamplingTools = 64 + // maxSamplingToolCalls caps the number of tool calls we will return + // from a single sampling-with-tools completion. Per-call argument size + // is bounded by maxSamplingTextBytes. + maxSamplingToolCalls = 32 ) // samplingHandler is the MCP-toolset-side hook that satisfies an inbound @@ -214,12 +223,21 @@ func dataURL(mimeType string, data []byte) string { // reshape the model's reply into something the MCP server didn't ask // for. 
func samplingModelOptions(req *mcp.CreateMessageParams) []options.Opt { + return samplingModelOptionsFor(req.MaxTokens) +} + +// samplingModelOptionsFor returns the per-request model options shared by the +// basic and with-tools sampling handlers. Structured output is cleared so a +// request cannot inherit the agent's JSON-schema response format; thinking is +// disabled because sampling is a delegated one-shot call rather than an agent +// turn; MaxTokens is honoured when non-zero. +func samplingModelOptionsFor(maxTokens int64) []options.Opt { opts := []options.Opt{ options.WithStructuredOutput(nil), options.WithNoThinking(), } - if req.MaxTokens > 0 { - opts = append(opts, options.WithMaxTokens(req.MaxTokens)) + if maxTokens > 0 { + opts = append(opts, options.WithMaxTokens(maxTokens)) } return opts } @@ -265,3 +283,361 @@ func stopReason(fr chat.FinishReason) string { return "endTurn" } } + +// samplingWithToolsHandler is the MCP-toolset-side hook that satisfies an +// inbound sampling/createMessage request that carries a tools array. It +// forwards the server-supplied tool definitions to the host's model and +// returns any tool_use blocks the model emits; the requesting MCP server +// then executes the tool and continues the loop in a follow-up sampling +// request with tool_result blocks added. +// +// The host never executes the server-supplied tools itself — they exist +// only to inform the model's response. The placeholder handler attached to +// each converted tool surfaces an error if a downstream call site mistakes +// these for ordinary agent tools. 
+func (r *LocalRuntime) samplingWithToolsHandler(ctx context.Context, req *mcp.CreateMessageWithToolsParams) (*mcp.CreateMessageWithToolsResult, error) {
+	// Validate the request envelope first so an oversized tool list is
+	// refused before any agent or model state is consulted.
+	switch {
+	case req == nil:
+		return nil, errors.New("sampling request is nil")
+	case len(req.Tools) > maxSamplingTools:
+		return nil, fmt.Errorf("sampling request includes %d tools (limit %d)",
+			len(req.Tools), maxSamplingTools)
+	}
+
+	slog.InfoContext(ctx, "Sampling-with-tools request received from MCP server",
+		"messages", len(req.Messages),
+		"tools", len(req.Tools),
+		"max_tokens", req.MaxTokens,
+		"system_prompt", req.SystemPrompt != "",
+	)
+
+	current := r.CurrentAgent()
+	if current == nil {
+		return nil, errors.New("no current agent available to handle sampling request")
+	}
+
+	history, err := samplingMessagesV2ToChat(req)
+	if err != nil {
+		return nil, fmt.Errorf("converting sampling messages: %w", err)
+	}
+
+	base := current.Model(ctx)
+	if base == nil {
+		return nil, errors.New("current agent has no model configured")
+	}
+
+	// Run the completion on a clone so the sampling-specific options never
+	// leak into the agent's own model.
+	sampler := provider.CloneWithOptions(ctx, base, samplingModelOptionsFor(req.MaxTokens)...)
+
+	completion, err := sampler.CreateChatCompletionStream(ctx, history, samplingToolsToChat(req.Tools))
+	if err != nil {
+		return nil, fmt.Errorf("creating sampling completion stream: %w", err)
+	}
+
+	reply, calls, finish, err := drainSamplingStreamWithTools(completion)
+	if err != nil {
+		return nil, fmt.Errorf("reading sampling completion stream: %w", err)
+	}
+
+	if len(calls) > maxSamplingToolCalls {
+		return nil, fmt.Errorf("model emitted %d tool calls (limit %d)",
+			len(calls), maxSamplingToolCalls)
+	}
+
+	// Any emitted tool call takes precedence over the provider's own finish
+	// reason: the requesting server has to run the tools next.
+	reason := "toolUse"
+	if len(calls) == 0 {
+		reason = stopReason(finish)
+	}
+
+	modelID := sampler.ID().String()
+	slog.DebugContext(ctx, "Sampling-with-tools request completed",
+		"agent", current.Name(),
+		"model", modelID,
+		"finish_reason", finish,
+		"stop_reason", reason,
+		"tool_calls", len(calls),
+		"content_bytes", len(reply),
+	)
+
+	return &mcp.CreateMessageWithToolsResult{
+		Role:       mcp.Role("assistant"),
+		Model:      modelID,
+		Content:    buildSamplingWithToolsContent(reply, calls),
+		StopReason: reason,
+	}, nil
+}
+
+// samplingMessagesV2ToChat converts a CreateMessageWithToolsParams (V2
+// messages with multi-block content) into chat.Messages. The optional system
+// prompt is prepended; per-message blocks are folded into one or more chat
+// messages depending on which content types are present.
+func samplingMessagesV2ToChat(req *mcp.CreateMessageWithToolsParams) ([]chat.Message, error) {
+	switch n := len(req.Messages); {
+	case n == 0:
+		return nil, errors.New("sampling request contains no messages")
+	case n > maxSamplingMessages:
+		return nil, fmt.Errorf("sampling request contains %d messages (limit %d)",
+			n, maxSamplingMessages)
+	}
+
+	// One slot per inbound message plus one for the optional system prompt;
+	// a message may still expand into several chat messages below.
+	messages := make([]chat.Message, 0, len(req.Messages)+1)
+	if req.SystemPrompt != "" {
+		if len(req.SystemPrompt) > maxSamplingTextBytes {
+			return nil, fmt.Errorf("sampling system prompt is too large (%d bytes, limit %d)",
+				len(req.SystemPrompt), maxSamplingTextBytes)
+		}
+		messages = append(messages, chat.Message{
+			Role:    chat.MessageRoleSystem,
+			Content: req.SystemPrompt,
+		})
+	}
+	for i, m := range req.Messages {
+		if m == nil {
+			return nil, fmt.Errorf("sampling message at index %d is nil", i)
+		}
+		// Every per-message failure carries the message index so the
+		// requesting server can locate the offending entry.
+		role, err := samplingRoleToChat(m.Role)
+		if err != nil {
+			return nil, fmt.Errorf("sampling message at index %d: %w", i, err)
+		}
+		converted, err := samplingV2BlocksToMessages(role, m.Content)
+		if err != nil {
+			return nil, fmt.Errorf("sampling message at index %d: %w", i, err)
+		}
+		messages = append(messages, converted...)
+	}
+	return messages, nil
+}
+
+// samplingV2BlocksToMessages converts a single V2 message's content blocks
+// into one or more chat.Messages. Plain blocks (text, image, audio) collapse
+// into a single message at the supplied role; tool_use blocks attach as
+// ToolCalls on an assistant message; tool_result blocks expand into one
+// MessageRoleTool row per result (matching how chat history represents
+// parallel tool calls).
+func samplingV2BlocksToMessages(role chat.MessageRole, blocks []mcp.Content) ([]chat.Message, error) {
+	var text strings.Builder
+	var parts []chat.MessagePart
+	var toolCalls []tools.ToolCall
+	var toolResults []chat.Message
+
+	for _, c := range blocks {
+		switch v := c.(type) {
+		case nil:
+			continue
+		case *mcp.ToolUseContent:
+			// A tool_use block records a call the model made, so it can only
+			// be replayed on an assistant message. Reject it elsewhere rather
+			// than silently dropping part of the conversation.
+			if role != chat.MessageRoleAssistant {
+				return nil, fmt.Errorf("tool_use block %q is only valid in assistant messages", v.ID)
+			}
+			args, err := json.Marshal(v.Input)
+			if err != nil {
+				return nil, fmt.Errorf("tool_use block %q input: %w", v.ID, err)
+			}
+			// A call without arguments marshals to "null"; normalise it to an
+			// empty JSON object so Arguments is always an object string.
+			if string(args) == "null" {
+				args = []byte("{}")
+			}
+			// Apply the same per-block size cap used for inbound text.
+			if len(args) > maxSamplingTextBytes {
+				return nil, fmt.Errorf("tool_use block %q input is too large (%d bytes, limit %d)",
+					v.ID, len(args), maxSamplingTextBytes)
+			}
+			toolCalls = append(toolCalls, tools.ToolCall{
+				ID:   v.ID,
+				Type: "function",
+				Function: tools.FunctionCall{
+					Name:      v.Name,
+					Arguments: string(args),
+				},
+			})
+		case *mcp.ToolResultContent:
+			resultText, err := samplingToolResultText(v.Content)
+			if err != nil {
+				return nil, fmt.Errorf("tool_result content: %w", err)
+			}
+			toolResults = append(toolResults, chat.Message{
+				Role:       chat.MessageRoleTool,
+				Content:    resultText,
+				ToolCallID: v.ToolUseID,
+				IsError:    v.IsError,
+			})
+		default:
+			t, p, err := samplingContentToChat(c)
+			if err != nil {
+				return nil, err
+			}
+			if t != "" {
+				if text.Len() > 0 {
+					text.WriteString("\n")
+				}
+				text.WriteString(t)
+			}
+			parts = append(parts, p...)
+		}
+	}
+
+	// Emit the role-level message only when it carries something; toolCalls
+	// is non-empty only for assistant messages (enforced above).
+	var out []chat.Message
+	if text.Len() > 0 || len(parts) > 0 || len(toolCalls) > 0 {
+		msg := chat.Message{
+			Role:    role,
+			Content: text.String(),
+		}
+		if len(parts) > 0 {
+			msg.MultiContent = parts
+		}
+		if len(toolCalls) > 0 {
+			msg.ToolCalls = toolCalls
+		}
+		out = append(out, msg)
+	}
+	// Tool results follow as one tool-role message each.
+	out = append(out, toolResults...)
+	return out, nil
+}
+
+// samplingToolResultText flattens the nested content of a tool_result block
+// into a single text string. chat.MessageRoleTool messages don't carry
+// multi-part content, so non-text blocks render as a placeholder.
+func samplingToolResultText(blocks []mcp.Content) (string, error) {
+	var (
+		flattened strings.Builder
+		skipped   int // non-text parts that a tool-role message cannot carry
+	)
+	for _, block := range blocks {
+		chunk, extra, err := samplingContentToChat(block)
+		if err != nil {
+			return "", err
+		}
+		skipped += len(extra)
+		if chunk == "" {
+			continue
+		}
+		if flattened.Len() > 0 {
+			flattened.WriteString("\n")
+		}
+		flattened.WriteString(chunk)
+	}
+	// The placeholder is used only when there is no text at all.
+	if flattened.Len() == 0 && skipped > 0 {
+		return "[tool returned non-text content]", nil
+	}
+	return flattened.String(), nil
+}
+
+// samplingToolsToChat translates the MCP tool definitions supplied by the
+// server into tools.Tool values that can be offered to the model. Nil entries
+// are skipped. Each tool gets a placeholder Handler that reports an error,
+// because the model's tool_use reply is returned to the requesting MCP server
+// for execution instead of being run here.
+func samplingToolsToChat(mcpTools []*mcp.Tool) []tools.Tool {
+	if len(mcpTools) == 0 {
+		return nil
+	}
+	converted := make([]tools.Tool, 0, len(mcpTools))
+	for _, def := range mcpTools {
+		if def != nil {
+			converted = append(converted, tools.Tool{
+				Name:         def.Name,
+				Category:     "mcp-sampling",
+				Description:  def.Description,
+				Parameters:   def.InputSchema,
+				OutputSchema: def.OutputSchema,
+				Handler:      noOpSamplingToolHandler,
+			})
+		}
+	}
+	return converted
+}
+
+// noOpSamplingToolHandler is the placeholder Handler for server-supplied
+// sampling tools; it always yields an error result and a nil error.
+func noOpSamplingToolHandler(_ context.Context, _ tools.ToolCall) (*tools.ToolCallResult, error) {
+	refusal := tools.ResultError("sampling tool execution belongs to the requesting MCP server")
+	return refusal, nil
+}
+
+// drainSamplingStreamWithTools reads a chat completion stream to completion
+// and returns the concatenated assistant text, aggregated tool calls, and
+// the final finish reason. It mirrors the tool-call aggregation in
+// pkg/runtime/streaming.go::handleStream but omits agent events, telemetry,
+// session bookkeeping, and the XML fallback — none of which apply to a
+// one-shot delegated completion.
+func drainSamplingStreamWithTools(stream chat.MessageStream) (string, []tools.ToolCall, chat.FinishReason, error) {
+	defer stream.Close()
+
+	var text strings.Builder
+	var toolCalls []tools.ToolCall
+	// toolIndex maps a tool-call ID to its position in toolCalls so that
+	// argument fragments spread over several chunks land on the same call.
+	toolIndex := make(map[string]int)
+	var providerFinishReason chat.FinishReason
+
+	for {
+		response, err := stream.Recv()
+		if errors.Is(err, io.EOF) {
+			break
+		}
+		if err != nil {
+			return "", nil, "", err
+		}
+		if len(response.Choices) == 0 {
+			continue
+		}
+		// Only the first choice is consumed.
+		choice := response.Choices[0]
+
+		if choice.Delta.Content != "" {
+			text.WriteString(choice.Delta.Content)
+		}
+
+		for _, delta := range choice.Delta.ToolCalls {
+			idx, ok := toolIndex[delta.ID]
+			if !ok {
+				// Enforce the call-count cap while streaming so a runaway
+				// completion cannot grow the slice without bound.
+				if len(toolCalls) >= maxSamplingToolCalls {
+					return "", nil, "", fmt.Errorf("model emitted more than %d tool calls",
+						maxSamplingToolCalls)
+				}
+				idx = len(toolCalls)
+				toolIndex[delta.ID] = idx
+				toolCalls = append(toolCalls, tools.ToolCall{
+					ID:   delta.ID,
+					Type: delta.Type,
+				})
+			}
+			tc := &toolCalls[idx]
+			if delta.Type != "" {
+				tc.Type = delta.Type
+			}
+			if delta.Function.Name != "" {
+				tc.Function.Name = delta.Function.Name
+			}
+			if delta.Function.Arguments != "" {
+				// Per-call argument size is bounded by maxSamplingTextBytes.
+				if len(tc.Function.Arguments)+len(delta.Function.Arguments) > maxSamplingTextBytes {
+					return "", nil, "", fmt.Errorf("tool call %q arguments exceed %d bytes",
+						tc.ID, maxSamplingTextBytes)
+				}
+				tc.Function.Arguments += delta.Function.Arguments
+			}
+		}
+
+		if choice.FinishReason != "" {
+			providerFinishReason = choice.FinishReason
+		}
+		// Stop and length are terminal; other finish reasons keep reading
+		// until the stream reports EOF.
+		if choice.FinishReason == chat.FinishReasonStop || choice.FinishReason == chat.FinishReasonLength {
+			break
+		}
+	}
+
+	// Infer a finish reason when the provider never sent one.
+	finishReason := providerFinishReason
+	if finishReason == "" {
+		switch {
+		case len(toolCalls) > 0:
+			finishReason = chat.FinishReasonToolCalls
+		case text.Len() > 0:
+			finishReason = chat.FinishReasonStop
+		default:
+			finishReason = chat.FinishReasonNull
+		}
+	}
+	// Reconcile the reported reason with what was actually received.
+	switch {
+	case finishReason == chat.FinishReasonToolCalls && len(toolCalls) == 0:
+		finishReason = chat.FinishReasonNull
+	case finishReason == chat.FinishReasonStop && len(toolCalls) > 0:
+		finishReason = chat.FinishReasonToolCalls
+	}
+
+	return text.String(), toolCalls, finishReason, nil
+}
+
+// buildSamplingWithToolsContent assembles the assistant response Content
+// slice.
+// Any leading text becomes a TextContent block; each tool call
+// becomes a ToolUseContent block with the function arguments parsed as a
+// JSON object. Malformed arguments fall back to an empty input map so the
+// server still sees the call (and can report a tool-side validation error)
+// rather than the loop terminating on the client.
+func buildSamplingWithToolsContent(text string, toolCalls []tools.ToolCall) []mcp.Content {
+	var blocks []mcp.Content
+	// Whitespace-only text is dropped; real text is forwarded untrimmed.
+	if strings.TrimSpace(text) != "" {
+		blocks = append(blocks, &mcp.TextContent{Text: text})
+	}
+	for _, tc := range toolCalls {
+		blocks = append(blocks, &mcp.ToolUseContent{
+			ID:    tc.ID,
+			Name:  tc.Function.Name,
+			Input: samplingToolCallInput(tc.Function.Arguments),
+		})
+	}
+	return blocks
+}
+
+// samplingToolCallInput decodes a tool call's raw JSON arguments into an
+// input map. It returns an empty, non-nil map when the arguments are empty,
+// malformed, not a JSON object, or the JSON literal null (which would
+// otherwise decode to a nil map).
+func samplingToolCallInput(arguments string) map[string]any {
+	var input map[string]any
+	if err := json.Unmarshal([]byte(arguments), &input); err != nil || input == nil {
+		return map[string]any{}
+	}
+	return input
+}
diff --git a/pkg/runtime/sampling_test.go b/pkg/runtime/sampling_test.go
index 70f55ca4b..304dafa59 100644
--- a/pkg/runtime/sampling_test.go
+++ b/pkg/runtime/sampling_test.go
@@ -1,6 +1,8 @@
 package runtime
 
 import (
+	"encoding/json"
+	"io"
 	"strings"
 	"testing"
 
@@ -9,6 +11,7 @@ import (
 	"github.com/stretchr/testify/require"
 
 	"github.com/docker/docker-agent/pkg/chat"
+	"github.com/docker/docker-agent/pkg/tools"
 )
 
 func TestSamplingMessagesToChat(t *testing.T) {
@@ -219,3 +222,419 @@ func TestDataURL(t *testing.T) {
 	assert.Equal(t, "data:image/png;base64,UE5HQllURVM=", dataURL("image/png", []byte("PNGBYTES")))
 	assert.Equal(t, "data:application/octet-stream;base64,YQ==", dataURL("", []byte("a")))
 }
+
+func TestSamplingMessagesV2ToChat(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		req     *mcp.CreateMessageWithToolsParams
+		want    []chat.Message
+		wantErr bool
+	}{
+		{
+			name: "single user text block",
+			req: &mcp.CreateMessageWithToolsParams{
+				Messages: []*mcp.SamplingMessageV2{
+					{Role: "user", Content: []mcp.Content{&mcp.TextContent{Text: "hello"}}},
+				},
+			},
+			want: []chat.Message{
+				{Role: chat.MessageRoleUser,
Content: "hello"}, + }, + }, + { + name: "system prompt is prepended", + req: &mcp.CreateMessageWithToolsParams{ + SystemPrompt: "be terse", + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{&mcp.TextContent{Text: "hi"}}}, + }, + }, + want: []chat.Message{ + {Role: chat.MessageRoleSystem, Content: "be terse"}, + {Role: chat.MessageRoleUser, Content: "hi"}, + }, + }, + { + name: "multiple text blocks are concatenated", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{ + &mcp.TextContent{Text: "first"}, + &mcp.TextContent{Text: "second"}, + }}, + }, + }, + want: []chat.Message{ + {Role: chat.MessageRoleUser, Content: "first\nsecond"}, + }, + }, + { + name: "text and image in one message", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{ + &mcp.TextContent{Text: "describe"}, + &mcp.ImageContent{Data: []byte("PNG"), MIMEType: "image/png"}, + }}, + }, + }, + want: []chat.Message{ + { + Role: chat.MessageRoleUser, + Content: "describe", + MultiContent: []chat.MessagePart{{ + Type: chat.MessagePartTypeImageURL, + ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,UE5H"}, + }}, + }, + }, + }, + { + name: "tool_use becomes assistant ToolCalls", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "assistant", Content: []mcp.Content{ + &mcp.ToolUseContent{ + ID: "call_1", + Name: "get_weather", + Input: map[string]any{"city": "Paris"}, + }, + }}, + }, + }, + want: []chat.Message{ + { + Role: chat.MessageRoleAssistant, + ToolCalls: []tools.ToolCall{{ + ID: "call_1", + Type: "function", + Function: tools.FunctionCall{ + Name: "get_weather", + Arguments: `{"city":"Paris"}`, + }, + }}, + }, + }, + }, + { + name: "tool_result expands to tool-role message", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: 
[]mcp.Content{ + &mcp.ToolResultContent{ + ToolUseID: "call_1", + Content: []mcp.Content{&mcp.TextContent{Text: "sunny, 22C"}}, + }, + }}, + }, + }, + want: []chat.Message{ + { + Role: chat.MessageRoleTool, + Content: "sunny, 22C", + ToolCallID: "call_1", + }, + }, + }, + { + name: "tool_result IsError surfaces", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{ + &mcp.ToolResultContent{ + ToolUseID: "call_1", + Content: []mcp.Content{&mcp.TextContent{Text: "no such city"}}, + IsError: true, + }, + }}, + }, + }, + want: []chat.Message{ + { + Role: chat.MessageRoleTool, + Content: "no such city", + ToolCallID: "call_1", + IsError: true, + }, + }, + }, + { + name: "parallel tool_results expand to multiple rows", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{ + &mcp.ToolResultContent{ToolUseID: "a", Content: []mcp.Content{&mcp.TextContent{Text: "1"}}}, + &mcp.ToolResultContent{ToolUseID: "b", Content: []mcp.Content{&mcp.TextContent{Text: "2"}}}, + }}, + }, + }, + want: []chat.Message{ + {Role: chat.MessageRoleTool, Content: "1", ToolCallID: "a"}, + {Role: chat.MessageRoleTool, Content: "2", ToolCallID: "b"}, + }, + }, + { + name: "too many messages is rejected", + req: &mcp.CreateMessageWithToolsParams{ + Messages: tooManyV2Messages(maxSamplingMessages + 1), + }, + wantErr: true, + }, + { + name: "nil message entry is rejected", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{nil}, + }, + wantErr: true, + }, + { + name: "oversize text block is rejected", + req: &mcp.CreateMessageWithToolsParams{ + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{ + &mcp.TextContent{Text: strings.Repeat("a", maxSamplingTextBytes+1)}, + }}, + }, + }, + wantErr: true, + }, + { + name: "empty messages is rejected", + req: &mcp.CreateMessageWithToolsParams{}, + wantErr: true, + }, + } + + 
for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := samplingMessagesV2ToChat(tt.req) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func tooManyV2Messages(n int) []*mcp.SamplingMessageV2 { + out := make([]*mcp.SamplingMessageV2, n) + for i := range out { + out[i] = &mcp.SamplingMessageV2{ + Role: "user", + Content: []mcp.Content{&mcp.TextContent{Text: "x"}}, + } + } + return out +} + +func TestSamplingToolsToChat(t *testing.T) { + t.Parallel() + + t.Run("nil input returns nil", func(t *testing.T) { + t.Parallel() + assert.Nil(t, samplingToolsToChat(nil)) + }) + + t.Run("converts and preserves schema", func(t *testing.T) { + t.Parallel() + schema := map[string]any{"type": "object"} + got := samplingToolsToChat([]*mcp.Tool{ + {Name: "lookup", Description: "look up a thing", InputSchema: schema}, + nil, // skipped + {Name: "other"}, + }) + require.Len(t, got, 2) + assert.Equal(t, "lookup", got[0].Name) + assert.Equal(t, "mcp-sampling", got[0].Category) + assert.Equal(t, "look up a thing", got[0].Description) + assert.Equal(t, schema, got[0].Parameters) + assert.NotNil(t, got[0].Handler) + assert.Equal(t, "other", got[1].Name) + }) + + t.Run("noOp handler returns error result", func(t *testing.T) { + t.Parallel() + res, err := noOpSamplingToolHandler(t.Context(), tools.ToolCall{}) + require.NoError(t, err) + require.NotNil(t, res) + assert.True(t, res.IsError) + }) +} + +func TestBuildSamplingWithToolsContent(t *testing.T) { + t.Parallel() + + t.Run("text only", func(t *testing.T) { + t.Parallel() + got := buildSamplingWithToolsContent("hello world", nil) + require.Len(t, got, 1) + text, ok := got[0].(*mcp.TextContent) + require.True(t, ok) + assert.Equal(t, "hello world", text.Text) + }) + + t.Run("tool calls only — empty text is dropped", func(t *testing.T) { + t.Parallel() + got := buildSamplingWithToolsContent(" ", []tools.ToolCall{ + {ID: 
"a", Function: tools.FunctionCall{Name: "fn", Arguments: `{"x":1}`}}, + }) + require.Len(t, got, 1) + tu, ok := got[0].(*mcp.ToolUseContent) + require.True(t, ok) + assert.Equal(t, "a", tu.ID) + assert.Equal(t, "fn", tu.Name) + assert.Equal(t, map[string]any{"x": float64(1)}, tu.Input) + }) + + t.Run("text plus parallel tool calls", func(t *testing.T) { + t.Parallel() + got := buildSamplingWithToolsContent("ok", []tools.ToolCall{ + {ID: "a", Function: tools.FunctionCall{Name: "fn1", Arguments: `{}`}}, + {ID: "b", Function: tools.FunctionCall{Name: "fn2", Arguments: `{}`}}, + }) + require.Len(t, got, 3) + _, isText := got[0].(*mcp.TextContent) + _, isToolA := got[1].(*mcp.ToolUseContent) + _, isToolB := got[2].(*mcp.ToolUseContent) + assert.True(t, isText) + assert.True(t, isToolA) + assert.True(t, isToolB) + }) + + t.Run("malformed JSON args fall back to empty input", func(t *testing.T) { + t.Parallel() + got := buildSamplingWithToolsContent("", []tools.ToolCall{ + {ID: "a", Function: tools.FunctionCall{Name: "fn", Arguments: `not json`}}, + }) + require.Len(t, got, 1) + tu, ok := got[0].(*mcp.ToolUseContent) + require.True(t, ok) + assert.Equal(t, map[string]any{}, tu.Input) + }) +} + +func TestSamplingWithToolsHandler_LimitRejection(t *testing.T) { + t.Parallel() + + r := &LocalRuntime{} + _, err := r.samplingWithToolsHandler(t.Context(), &mcp.CreateMessageWithToolsParams{ + Tools: make([]*mcp.Tool, maxSamplingTools+1), + Messages: []*mcp.SamplingMessageV2{ + {Role: "user", Content: []mcp.Content{&mcp.TextContent{Text: "hi"}}}, + }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "tools") +} + +// fakeStream feeds a fixed sequence of MessageStreamResponse values into +// drainSamplingStreamWithTools for unit testing. 
+type fakeStream struct { + responses []chat.MessageStreamResponse + idx int + closed bool +} + +func (f *fakeStream) Recv() (chat.MessageStreamResponse, error) { + if f.idx >= len(f.responses) { + return chat.MessageStreamResponse{}, io.EOF + } + resp := f.responses[f.idx] + f.idx++ + return resp, nil +} + +func (f *fakeStream) Close() { + f.closed = true +} + +func TestDrainSamplingStreamWithTools(t *testing.T) { + t.Parallel() + + t.Run("plain text completion", func(t *testing.T) { + t.Parallel() + s := &fakeStream{responses: []chat.MessageStreamResponse{ + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{Content: "hello "}}}}, + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{Content: "world"}, FinishReason: chat.FinishReasonStop}}}, + }} + text, calls, fr, err := drainSamplingStreamWithTools(s) + require.NoError(t, err) + assert.Equal(t, "hello world", text) + assert.Empty(t, calls) + assert.Equal(t, chat.FinishReasonStop, fr) + assert.True(t, s.closed) + }) + + t.Run("tool call aggregation across chunks", func(t *testing.T) { + t.Parallel() + s := &fakeStream{responses: []chat.MessageStreamResponse{ + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{ToolCalls: []tools.ToolCall{ + {ID: "c1", Type: "function", Function: tools.FunctionCall{Name: "fn", Arguments: `{"a":`}}, + }}}}}, + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{ToolCalls: []tools.ToolCall{ + {ID: "c1", Function: tools.FunctionCall{Arguments: `1}`}}, + }}, FinishReason: chat.FinishReasonToolCalls}}}, + }} + text, calls, fr, err := drainSamplingStreamWithTools(s) + require.NoError(t, err) + assert.Empty(t, text) + require.Len(t, calls, 1) + assert.Equal(t, "c1", calls[0].ID) + assert.Equal(t, "fn", calls[0].Function.Name) + assert.Equal(t, `{"a":1}`, calls[0].Function.Arguments) + assert.Equal(t, chat.FinishReasonToolCalls, fr) + // Sanity-check that the JSON we accumulated is parseable. 
+ var v map[string]any + require.NoError(t, json.Unmarshal([]byte(calls[0].Function.Arguments), &v)) + }) + + t.Run("parallel tool calls collected by ID", func(t *testing.T) { + t.Parallel() + s := &fakeStream{responses: []chat.MessageStreamResponse{ + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{ToolCalls: []tools.ToolCall{ + {ID: "a", Function: tools.FunctionCall{Name: "fn1", Arguments: `{}`}}, + {ID: "b", Function: tools.FunctionCall{Name: "fn2", Arguments: `{}`}}, + }}, FinishReason: chat.FinishReasonToolCalls}}}, + }} + _, calls, fr, err := drainSamplingStreamWithTools(s) + require.NoError(t, err) + require.Len(t, calls, 2) + assert.Equal(t, "a", calls[0].ID) + assert.Equal(t, "b", calls[1].ID) + assert.Equal(t, chat.FinishReasonToolCalls, fr) + }) + + t.Run("inferred tool_calls when provider omits finish reason", func(t *testing.T) { + t.Parallel() + s := &fakeStream{responses: []chat.MessageStreamResponse{ + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{ToolCalls: []tools.ToolCall{ + {ID: "x", Function: tools.FunctionCall{Name: "fn", Arguments: `{}`}}, + }}}}}, + }} + _, calls, fr, err := drainSamplingStreamWithTools(s) + require.NoError(t, err) + require.Len(t, calls, 1) + assert.Equal(t, chat.FinishReasonToolCalls, fr) + }) + + t.Run("stop reconciled to tool_calls when calls present", func(t *testing.T) { + t.Parallel() + // Provider says "stop" but also emits tool calls — reconciliation + // should treat this as a tool-call turn (the early-exit on stop fires + // in handleStream-style aggregation, then reconciliation upgrades). 
+ s := &fakeStream{responses: []chat.MessageStreamResponse{ + {Choices: []chat.MessageStreamChoice{{Delta: chat.MessageDelta{ToolCalls: []tools.ToolCall{ + {ID: "x", Function: tools.FunctionCall{Name: "fn", Arguments: `{}`}}, + }}, FinishReason: chat.FinishReasonStop}}}, + }} + _, _, fr, err := drainSamplingStreamWithTools(s) + require.NoError(t, err) + assert.Equal(t, chat.FinishReasonToolCalls, fr) + }) +} diff --git a/pkg/tools/capabilities.go b/pkg/tools/capabilities.go index 878c7feb9..3609f44f0 100644 --- a/pkg/tools/capabilities.go +++ b/pkg/tools/capabilities.go @@ -53,6 +53,14 @@ type Sampleable interface { SetSamplingHandler(handler SamplingHandler) } +// SampleableWithTools is implemented by toolsets that support MCP sampling +// requests carrying a tools array (sampling-with-tools). The handler is +// invoked instead of the basic SamplingHandler when both are registered and +// the SDK negotiates the with-tools variant on the wire. +type SampleableWithTools interface { + SetSamplingWithToolsHandler(handler SamplingWithToolsHandler) +} + // OAuthCapable is implemented by toolsets that support OAuth flows. type OAuthCapable interface { SetOAuthSuccessHandler(handler func()) @@ -81,16 +89,19 @@ type ChangeNotifier interface { } // ConfigureHandlers sets all applicable handlers on a toolset. -// It checks for Elicitable, Sampleable and OAuthCapable interfaces and -// configures them. This is a convenience function that handles the capability -// checking internally. -func ConfigureHandlers(ts ToolSet, elicitHandler ElicitationHandler, samplingHandler SamplingHandler, oauthHandler func(), managedOAuth bool, unmanagedOAuthRedirectURI string) { +// It checks for Elicitable, Sampleable, SampleableWithTools, and OAuthCapable +// interfaces and configures them. This is a convenience function that handles +// the capability checking internally. 
+func ConfigureHandlers(ts ToolSet, elicitHandler ElicitationHandler, samplingHandler SamplingHandler, samplingWithToolsHandler SamplingWithToolsHandler, oauthHandler func(), managedOAuth bool, unmanagedOAuthRedirectURI string) { if e, ok := As[Elicitable](ts); ok { e.SetElicitationHandler(elicitHandler) } if s, ok := As[Sampleable](ts); ok { s.SetSamplingHandler(samplingHandler) } + if s, ok := As[SampleableWithTools](ts); ok { + s.SetSamplingWithToolsHandler(samplingWithToolsHandler) + } if o, ok := As[OAuthCapable](ts); ok { o.SetOAuthSuccessHandler(oauthHandler) o.SetManagedOAuth(managedOAuth) diff --git a/pkg/tools/mcp/mcp.go b/pkg/tools/mcp/mcp.go index 1b08f1e89..ffa4cf2eb 100644 --- a/pkg/tools/mcp/mcp.go +++ b/pkg/tools/mcp/mcp.go @@ -137,6 +137,7 @@ type mcpClient interface { GetPrompt(ctx context.Context, request *mcp.GetPromptParams) (*mcp.GetPromptResult, error) SetElicitationHandler(handler tools.ElicitationHandler) SetSamplingHandler(handler tools.SamplingHandler) + SetSamplingWithToolsHandler(handler tools.SamplingWithToolsHandler) SetOAuthSuccessHandler(handler func()) SetManagedOAuth(managed bool) SetUnmanagedOAuthRedirectURI(uri string) @@ -198,11 +199,12 @@ var ( // Verify that Toolset implements optional capability interfaces var ( - _ tools.Instructable = (*Toolset)(nil) - _ tools.Elicitable = (*Toolset)(nil) - _ tools.Sampleable = (*Toolset)(nil) - _ tools.OAuthCapable = (*Toolset)(nil) - _ tools.ChangeNotifier = (*Toolset)(nil) + _ tools.Instructable = (*Toolset)(nil) + _ tools.Elicitable = (*Toolset)(nil) + _ tools.Sampleable = (*Toolset)(nil) + _ tools.SampleableWithTools = (*Toolset)(nil) + _ tools.OAuthCapable = (*Toolset)(nil) + _ tools.ChangeNotifier = (*Toolset)(nil) ) // NewToolsetCommand creates a new MCP toolset from a command. 
@@ -501,7 +503,9 @@ func (c *clientConnector) Connect(ctx context.Context) (lifecycle.Session, error Form: &mcp.FormElicitationCapabilities{}, URL: &mcp.URLElicitationCapabilities{}, }, - Sampling: &mcp.SamplingCapabilities{}, + Sampling: &mcp.SamplingCapabilities{ + Tools: &mcp.SamplingToolsCapabilities{}, + }, }, }, } @@ -845,6 +849,10 @@ func (ts *Toolset) SetSamplingHandler(handler tools.SamplingHandler) { ts.mcpClient.SetSamplingHandler(handler) } +func (ts *Toolset) SetSamplingWithToolsHandler(handler tools.SamplingWithToolsHandler) { + ts.mcpClient.SetSamplingWithToolsHandler(handler) +} + func (ts *Toolset) SetOAuthSuccessHandler(handler func()) { ts.mcpClient.SetOAuthSuccessHandler(handler) } diff --git a/pkg/tools/mcp/mcp_test.go b/pkg/tools/mcp/mcp_test.go index dd89d82c1..8fa4e3755 100644 --- a/pkg/tools/mcp/mcp_test.go +++ b/pkg/tools/mcp/mcp_test.go @@ -44,6 +44,8 @@ func (m *mockMCPClient) SetElicitationHandler(tools.ElicitationHandler) {} func (m *mockMCPClient) SetSamplingHandler(tools.SamplingHandler) {} +func (m *mockMCPClient) SetSamplingWithToolsHandler(tools.SamplingWithToolsHandler) {} + func (m *mockMCPClient) SetOAuthSuccessHandler(func()) {} func (m *mockMCPClient) SetManagedOAuth(bool) {} diff --git a/pkg/tools/mcp/reconnect_test.go b/pkg/tools/mcp/reconnect_test.go index 8ef8aaa0c..cd4a32271 100644 --- a/pkg/tools/mcp/reconnect_test.go +++ b/pkg/tools/mcp/reconnect_test.go @@ -69,11 +69,13 @@ func (m *failingInitClient) GetPrompt(context.Context, *gomcp.GetPromptParams) ( func (m *failingInitClient) SetElicitationHandler(tools.ElicitationHandler) {} func (m *failingInitClient) SetSamplingHandler(tools.SamplingHandler) {} -func (m *failingInitClient) SetOAuthSuccessHandler(func()) {} -func (m *failingInitClient) SetManagedOAuth(bool) {} -func (m *failingInitClient) SetUnmanagedOAuthRedirectURI(string) {} -func (m *failingInitClient) SetToolListChangedHandler(func()) {} -func (m *failingInitClient) SetPromptListChangedHandler(func()) {} 
+func (m *failingInitClient) SetSamplingWithToolsHandler(tools.SamplingWithToolsHandler) { +} +func (m *failingInitClient) SetOAuthSuccessHandler(func()) {} +func (m *failingInitClient) SetManagedOAuth(bool) {} +func (m *failingInitClient) SetUnmanagedOAuthRedirectURI(string) {} +func (m *failingInitClient) SetToolListChangedHandler(func()) {} +func (m *failingInitClient) SetPromptListChangedHandler(func()) {} func (m *failingInitClient) Wait() error { m.mu.Lock() diff --git a/pkg/tools/mcp/remote.go b/pkg/tools/mcp/remote.go index f4c5d7f54..e633f8163 100644 --- a/pkg/tools/mcp/remote.go +++ b/pkg/tools/mcp/remote.go @@ -90,12 +90,19 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *gomcp.InitializeReq toolChanged, promptChanged := c.notificationHandlers() + // Sampling: prefer the with-tools handler when registered. The SDK's two + // CreateMessage* handlers are mutually exclusive, so populate exactly one. opts := &gomcp.ClientOptions{ ElicitationHandler: c.handleElicitationRequest, - CreateMessageHandler: c.handleSamplingRequest, ToolListChangedHandler: toolChanged, PromptListChangedHandler: promptChanged, } + switch { + case c.samplingWithToolsHandler != nil: + opts.CreateMessageWithToolsHandler = c.handleSamplingWithToolsRequest + case c.samplingHandler != nil: + opts.CreateMessageHandler = c.handleSamplingRequest + } client := gomcp.NewClient(impl, opts) diff --git a/pkg/tools/mcp/session_client.go b/pkg/tools/mcp/session_client.go index 86852a861..473a912e1 100644 --- a/pkg/tools/mcp/session_client.go +++ b/pkg/tools/mcp/session_client.go @@ -23,6 +23,7 @@ type sessionClient struct { promptListChangedHandler func() elicitationHandler tools.ElicitationHandler samplingHandler tools.SamplingHandler + samplingWithToolsHandler tools.SamplingWithToolsHandler oauthSuccessHandler func() mu sync.RWMutex } @@ -188,6 +189,40 @@ func (c *sessionClient) SetSamplingHandler(handler tools.SamplingHandler) { c.mu.Unlock() } +// handleSamplingWithToolsRequest 
forwards incoming sampling/createMessage +// requests that may include tools to the registered handler. It is used as +// the gomcp CreateMessageWithToolsHandler callback for both stdio and remote +// clients when the with-tools handler is registered. +func (c *sessionClient) handleSamplingWithToolsRequest(ctx context.Context, req *gomcp.CreateMessageWithToolsRequest) (*gomcp.CreateMessageWithToolsResult, error) { + slog.DebugContext(ctx, "Received sampling-with-tools request from MCP server", + "messages", len(req.Params.Messages), + "tools", len(req.Params.Tools), + ) + + c.mu.RLock() + handler := c.samplingWithToolsHandler + c.mu.RUnlock() + + if handler == nil { + return nil, errors.New("no sampling-with-tools handler configured") + } + + result, err := handler(ctx, req.Params) + if err != nil { + return nil, fmt.Errorf("sampling failed: %w", err) + } + + return result, nil +} + +// SetSamplingWithToolsHandler sets the handler that processes sampling +// requests carrying a tools array from the MCP server. +func (c *sessionClient) SetSamplingWithToolsHandler(handler tools.SamplingWithToolsHandler) { + c.mu.Lock() + c.samplingWithToolsHandler = handler + c.mu.Unlock() +} + // requestElicitation invokes the registered elicitation handler directly. // This is used by the OAuth transport to trigger elicitation outside of // the normal MCP request flow. diff --git a/pkg/tools/mcp/stdio.go b/pkg/tools/mcp/stdio.go index feb4e3ac5..5234bb686 100644 --- a/pkg/tools/mcp/stdio.go +++ b/pkg/tools/mcp/stdio.go @@ -38,13 +38,20 @@ func (c *stdioMCPClient) Initialize(ctx context.Context, _ *gomcp.InitializeRequ toolChanged, promptChanged := c.notificationHandlers() - // Create client options with elicitation, sampling, and notification support + // Create client options with elicitation, sampling, and notification support. + // Sampling: prefer the with-tools handler when registered. The SDK's two + // CreateMessage* handlers are mutually exclusive, so populate exactly one. 
opts := &gomcp.ClientOptions{ ElicitationHandler: c.handleElicitationRequest, - CreateMessageHandler: c.handleSamplingRequest, ToolListChangedHandler: toolChanged, PromptListChangedHandler: promptChanged, } + switch { + case c.samplingWithToolsHandler != nil: + opts.CreateMessageWithToolsHandler = c.handleSamplingWithToolsRequest + case c.samplingHandler != nil: + opts.CreateMessageHandler = c.handleSamplingRequest + } client := gomcp.NewClient(&gomcp.Implementation{ Name: "docker agent", diff --git a/pkg/tools/sampling.go b/pkg/tools/sampling.go index 0bdb24e35..5af98cd52 100644 --- a/pkg/tools/sampling.go +++ b/pkg/tools/sampling.go @@ -15,3 +15,11 @@ import ( // expected to call the host's model with the supplied messages and return // the model's response (or an error if the request was declined or failed). type SamplingHandler func(ctx context.Context, req *mcp.CreateMessageParams) (*mcp.CreateMessageResult, error) + +// SamplingWithToolsHandler handles sampling/createMessage requests that may +// involve tool use. The request carries a tools array and supports messages +// with multi-block content (tool_use, tool_result). The handler is expected +// to forward the tools to the host's model and return any tool_use blocks +// the model emits — the requesting MCP server executes the tools and +// continues the loop in a follow-up sampling request. +type SamplingWithToolsHandler func(ctx context.Context, req *mcp.CreateMessageWithToolsParams) (*mcp.CreateMessageWithToolsResult, error)