From 896592734d5ba52493e2766ab8b478988ce4e0d8 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 20 Feb 2026 18:02:09 -0800 Subject: [PATCH 1/8] added experimental ways to define tools and interrupts --- go/ai/x/tool/tool.go | 138 ++++++++++ go/ai/x/tools.go | 242 ++++++++++++++++++ go/genkit/genkit.go | 119 +++++++++ go/samples/x-agent-interrupts/main.go | 219 ++++++++++++++++ .../prompts/paymentAgent.prompt | 10 + 5 files changed, 728 insertions(+) create mode 100644 go/ai/x/tool/tool.go create mode 100644 go/ai/x/tools.go create mode 100644 go/samples/x-agent-interrupts/main.go create mode 100644 go/samples/x-agent-interrupts/prompts/paymentAgent.prompt diff --git a/go/ai/x/tool/tool.go b/go/ai/x/tool/tool.go new file mode 100644 index 0000000000..0b13f9bf05 --- /dev/null +++ b/go/ai/x/tool/tool.go @@ -0,0 +1,138 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package tool provides runtime helpers for use inside tool functions. +// +// APIs in this package are under active development and may change in any +// minor version release. Use with caution in production environments. +package tool + +import ( + "context" + "fmt" + "maps" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/internal/base" +) + +// --- Interrupt --- + +// InterruptError is returned by [Interrupt] to signal tool interruption. +type InterruptError struct { + Data any +} + +func (e *InterruptError) Error() string { + return "tool interrupted" +} + +// Interrupt interrupts tool execution and sends data to the caller. +// The caller can read this data with [InterruptAs] and resume the tool +// with [Resume]. +func Interrupt(data any) error { + return &InterruptError{Data: data} +} + +// InterruptAs extracts typed interrupt data from an interrupted tool request [ai.Part]. +// Returns the zero value and false if the part is not an interrupt or the type doesn't match. +func InterruptAs[T any](p *ai.Part) (T, bool) { + return ai.InterruptAs[T](p) +} + +// Resume creates a restart [ai.Part] for resuming an interrupted tool call. +// The interruptedPart must be an interrupted tool request (as received via +// [InterruptAs] or [ai.ModelResponse.Interrupts]). The data is delivered to +// the tool function's resume parameter when it is re-executed. +// +// This is a convenience alternative to [aix.InterruptibleTool.Resume] that +// does not require access to the tool definition. +func Resume[Res any](interruptedPart *ai.Part, data Res) (*ai.Part, error) { + if interruptedPart == nil || !interruptedPart.IsToolRequest() { + return nil, fmt.Errorf("tool.Resume: part is not a tool request") + } + + m, err := base.StructToMap(data) + if err != nil { + return nil, fmt.Errorf("tool.Resume: %w", err) + } + + newMeta := maps.Clone(interruptedPart.Metadata) + if newMeta == nil { + newMeta = make(map[string]any) + } + newMeta["resumed"] = m + delete(newMeta, "interrupt") + + newPart := ai.NewToolRequestPart(&ai.ToolRequest{ + Name: interruptedPart.ToolRequest.Name, + Ref: interruptedPart.ToolRequest.Ref, + Input: interruptedPart.ToolRequest.Input, + }) + newPart.Metadata = newMeta + return newPart, nil +} + +// --- AttachParts --- + +// AttachParts attaches additional content parts (e.g., media) to the tool's +// response. This can be called from any tool to produce a multipart response +// without changing the function signature. +func AttachParts(ctx context.Context, parts ...*ai.Part) { + c := partsCollectorKey.FromContext(ctx) + if c == nil { + return + } + c.parts = append(c.parts, parts...) +} + +// --- OriginalInput --- + +// OriginalInput returns the original input if the caller replaced it during +// restart. Returns the zero value and false if the input was not replaced +// or the tool is not being resumed. +func OriginalInput[In any](ctx context.Context) (In, bool) { + v := originalInputKey.FromContext(ctx) + if v == nil { + var zero In + return zero, false + } + return base.ConvertTo[In](v) +} + +// --- Internal plumbing (used by the aix package, not for end users) --- + +var originalInputKey = base.NewContextKey[any]() +var partsCollectorKey = base.NewContextKey[*partsCollector]() + +// partsCollector accumulates content parts attached during tool execution. +type partsCollector struct { + parts []*ai.Part +} + +// SetOriginalInput stores the original input in the context. +// This is internal plumbing used by the aix package. +func SetOriginalInput(ctx context.Context, input any) context.Context { + return originalInputKey.NewContext(ctx, input) +} + +// NewPartsContext returns a context with a parts collector. +// This is internal plumbing used by the aix package. +func NewPartsContext(ctx context.Context) (context.Context, func() []*ai.Part) { + c := &partsCollector{} + ctx = partsCollectorKey.NewContext(ctx, c) + return ctx, func() []*ai.Part { return c.parts } +} diff --git a/go/ai/x/tools.go b/go/ai/x/tools.go new file mode 100644 index 0000000000..9b85768096 --- /dev/null +++ b/go/ai/x/tools.go @@ -0,0 +1,242 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package aix + +import ( + "context" + "errors" + "fmt" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/ai/x/tool" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/internal/base" +) + +// ToolFunc is the function signature for tools created with [DefineTool] and [NewTool]. +type ToolFunc[In, Out any] = func(ctx context.Context, input In) (Out, error) + +// InterruptibleToolFunc is the function signature for tools created with +// [DefineInterruptibleTool] and [NewInterruptibleTool]. The resumed parameter +// is non-nil when the tool is being re-executed after an interrupt. +type InterruptibleToolFunc[In, Out, Resume any] = func(ctx context.Context, input In, res *Resume) (Out, error) + +// Tool wraps an [ai.Tool] with experimental x package features +// such as a plain [context.Context] function signature and [tool.AttachParts]. +// +// DEPRECATED(breaking): With breaking changes, Tool would not wrap ai.ToolDef. +// It would be the primary tool type, backed directly by core.DefineAction, +// eliminating the inner field and all delegation methods below. +type Tool[In, Out any] struct { + inner *ai.ToolDef[In, *ai.MultipartToolResponse] // DEPRECATED(breaking): remove wrapper; Tool owns the action directly. +} + +// DEPRECATED(breaking): All methods below exist only to implement ai.Tool by +// delegating to the wrapped ai.ToolDef. With breaking changes, Tool would own +// the action directly and implement these natively without delegation. + +// Name returns the name of the tool. +func (t *Tool[In, Out]) Name() string { return t.inner.Name() } + +// Definition returns the [ai.ToolDefinition] for this tool. +func (t *Tool[In, Out]) Definition() *ai.ToolDefinition { return t.inner.Definition() } + +// RunRaw runs the tool with raw input. +func (t *Tool[In, Out]) RunRaw(ctx context.Context, input any) (any, error) { + return t.inner.RunRaw(ctx, input) +} + +// RunRawMultipart runs the tool with raw input and returns the full multipart response. +func (t *Tool[In, Out]) RunRawMultipart(ctx context.Context, input any) (*ai.MultipartToolResponse, error) { + return t.inner.RunRawMultipart(ctx, input) +} + +// Respond creates a tool response part for an interrupted tool request. +func (t *Tool[In, Out]) Respond(toolReq *ai.Part, outputData any, opts *ai.RespondOptions) *ai.Part { + return t.inner.Respond(toolReq, outputData, opts) +} + +// Restart creates a restart part using the legacy [ai.RestartOptions]. +// +// DEPRECATED(breaking): Remove entirely. Superseded by [InterruptibleTool.Resume]. +func (t *Tool[In, Out]) Restart(toolReq *ai.Part, opts *ai.RestartOptions) *ai.Part { + return t.inner.Restart(toolReq, opts) +} + +// Register registers the tool with the given registry. +func (t *Tool[In, Out]) Register(r api.Registry) { t.inner.Register(r) } + +// InterruptibleTool is a [Tool] that supports typed interrupt/resume. +// The Res type parameter is the type of data the caller sends back when +// resuming the tool after an interrupt. +type InterruptibleTool[In, Out, Res any] struct { + Tool[In, Out] +} + +// Resume creates a restart part for resuming this interrupted tool with typed data. +// The data will be deserialized into the *Res parameter of the tool function +// when it is re-executed. +// +// Unlike [tool.Resume], this method also validates that the interrupted part +// belongs to this tool. +func (t *InterruptibleTool[In, Out, Res]) Resume(interruptedPart *ai.Part, data Res) (*ai.Part, error) { + if interruptedPart == nil || !interruptedPart.IsToolRequest() { + return nil, fmt.Errorf("Resume: part is not a tool request") + } + if interruptedPart.ToolRequest.Name != t.Name() { + return nil, fmt.Errorf("Resume: tool request is for %q, not %q", interruptedPart.ToolRequest.Name, t.Name()) + } + return tool.Resume(interruptedPart, data) +} + +// DefineTool creates a new tool with a simple function signature and registers it. +// The function receives a plain [context.Context] instead of [ai.ToolContext]. +// Use [tool.AttachParts] inside the function to return additional content parts. +func DefineTool[In, Out any]( + r api.Registry, + name, description string, + fn ToolFunc[In, Out], + opts ...ai.ToolOption, +) *Tool[In, Out] { + t := NewTool(name, description, fn, opts...) + t.Register(r) + return t +} + +// NewTool creates a new unregistered tool with a simple function signature. +// Use [tool.AttachParts] inside the function to return additional content parts. +func NewTool[In, Out any]( + name, description string, + fn ToolFunc[In, Out], + opts ...ai.ToolOption, +) *Tool[In, Out] { + // DEPRECATED(breaking): Call core.NewAction directly instead of wrapping ai.NewMultipartTool. + inner := ai.NewMultipartTool(name, description, wrapSimpleFunc(fn), opts...) + return &Tool[In, Out]{inner: inner} +} + +// DefineInterruptibleTool creates a new interruptible tool and registers it. +// The resumed parameter is non-nil when the tool is being resumed after an +// interrupt. Use [tool.Interrupt] inside the function to interrupt execution +// and send data to the caller. +func DefineInterruptibleTool[In, Out, Res any]( + r api.Registry, + name, description string, + fn InterruptibleToolFunc[In, Out, Res], + opts ...ai.ToolOption, +) *InterruptibleTool[In, Out, Res] { + t := NewInterruptibleTool(name, description, fn, opts...) + t.Register(r) + return t +} + +// NewInterruptibleTool creates a new unregistered interruptible tool. +func NewInterruptibleTool[In, Out, Res any]( + name, description string, + fn InterruptibleToolFunc[In, Out, Res], + opts ...ai.ToolOption, +) *InterruptibleTool[In, Out, Res] { + // DEPRECATED(breaking): Call core.NewAction directly instead of wrapping ai.NewMultipartTool. + inner := ai.NewMultipartTool(name, description, wrapInterruptibleFunc(fn), opts...) + return &InterruptibleTool[In, Out, Res]{Tool: Tool[In, Out]{inner: inner}} +} + +// DEPRECATED(breaking): wrapSimpleFunc exists to adapt our func(context.Context, In) (Out, error) +// to ai.MultipartToolFunc[In] (which takes *ai.ToolContext). With breaking changes, +// core.DefineAction would accept our function signature directly, and the ToolContext +// adapter, resumed/originalInput extraction from ToolContext, and interrupt error +// conversion would all be unnecessary. +func wrapSimpleFunc[In, Out any](fn ToolFunc[In, Out]) ai.MultipartToolFunc[In] { + return func(tc *ai.ToolContext, input In) (*ai.MultipartToolResponse, error) { + ctx := tc.Context + ctx, collector := tool.NewPartsContext(ctx) + if tc.OriginalInput != nil { + ctx = tool.SetOriginalInput(ctx, tc.OriginalInput) + } + + output, err := fn(ctx, input) + if err != nil { + return nil, convertInterruptError(tc, err) + } + + resp := &ai.MultipartToolResponse{Output: output} + if parts := collector(); len(parts) > 0 { + resp.Content = parts + } + return resp, nil + } +} + +// DEPRECATED(breaking): Same as wrapSimpleFunc — exists only to bridge between +// the new function signature and ai.MultipartToolFunc/ai.ToolContext. +func wrapInterruptibleFunc[In, Out, Res any](fn InterruptibleToolFunc[In, Out, Res]) ai.MultipartToolFunc[In] { + return func(tc *ai.ToolContext, input In) (*ai.MultipartToolResponse, error) { + ctx := tc.Context + ctx, collector := tool.NewPartsContext(ctx) + if tc.OriginalInput != nil { + ctx = tool.SetOriginalInput(ctx, tc.OriginalInput) + } + + // DEPRECATED(breaking): Resumed data would come from context keys set by + // the generate loop directly, not from ai.ToolContext.Resumed. + var res *Res + if tc.Resumed != nil { + r, err := base.MapToStruct[Res](tc.Resumed) + if err == nil { + res = &r + } + } + + output, err := fn(ctx, input, res) + if err != nil { + return nil, convertInterruptError(tc, err) + } + + resp := &ai.MultipartToolResponse{Output: output} + if parts := collector(); len(parts) > 0 { + resp.Content = parts + } + return resp, nil + } +} + +// DEPRECATED(breaking): convertInterruptError exists because tool.InterruptError +// must be converted to ai's unexported toolInterruptError (via tc.Interrupt) for +// the generate loop to recognize it. With breaking changes, the generate loop +// would recognize tool.InterruptError directly. +func convertInterruptError(tc *ai.ToolContext, err error) error { + var ie *tool.InterruptError + if errors.As(err, &ie) { + m, mapErr := toMap(ie.Data) + if mapErr != nil { + return fmt.Errorf("tool.Interrupt: failed to convert data: %w", mapErr) + } + return tc.Interrupt(&ai.InterruptOptions{Metadata: m}) + } + return err +} + +// DEPRECATED(breaking): toMap exists only for convertInterruptError above. +func toMap(v any) (map[string]any, error) { + if v == nil { + return nil, nil + } + if m, ok := v.(map[string]any); ok { + return m, nil + } + return base.StructToMap(v) +} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 0455eac406..951cd06742 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -863,6 +863,125 @@ func DefineMultipartTool[In any](g *Genkit, name, description string, fn ai.Mult return ai.DefineMultipartTool(g.reg, name, description, fn, opts...) } +// DefineXTool defines a tool with a simplified function signature, registers it +// as a [core.Action] of type Tool, and returns an [aix.Tool]. +// +// Experimental: This API is under active development and may change in any +// minor version release. +// +// Unlike [DefineTool], the function receives a plain [context.Context] instead +// of [ai.ToolContext]. Use [tool.AttachParts] inside the function to return +// additional content parts alongside the output. +// +// For tools that don't need to be registered (e.g., dynamically created tools), +// use [aix.NewTool] instead. +// +// # Options +// +// - [ai.WithInputSchema]: Provide a custom JSON schema instead of inferring from the type parameter +// - [ai.WithInputSchemaName]: Reference a pre-registered schema by name +// +// Example: +// +// type WeatherInput struct { +// City string `json:"city" jsonschema:"description=city name"` +// } +// +// weatherTool := genkit.DefineXTool(g, "getWeather", "Fetches the weather for a given city", +// func(ctx context.Context, input WeatherInput) (string, error) { +// if input.City == "Paris" { +// return "Sunny, 25°C", nil +// } +// return "Cloudy, 18°C", nil +// }, +// ) +// +// resp, err := genkit.Generate(ctx, g, +// ai.WithPrompt("What's the weather like in Paris?"), +// ai.WithTools(weatherTool), +// ) +// if err != nil { +// log.Fatalf("Generate failed: %v", err) +// } +// fmt.Println(resp.Text()) +func DefineXTool[In, Out any](g *Genkit, name, description string, fn aix.ToolFunc[In, Out], opts ...ai.ToolOption) *aix.Tool[In, Out] { + return aix.DefineTool(g.reg, name, description, fn, opts...) +} + +// DefineInterruptibleTool defines a tool that supports typed interrupt/resume, +// registers it as a [core.Action] of type Tool, and returns an +// [aix.InterruptibleTool]. +// +// Experimental: This API is under active development and may change in any +// minor version release. +// +// The function receives a plain [context.Context], the tool input, and a +// resumed parameter that is non-nil when the tool is being re-executed after +// an interrupt. Inside the function, call [tool.Interrupt] to pause execution +// and send data to the caller. The caller can inspect the interrupt with +// [tool.InterruptAs] and resume the tool with [tool.Resume] or the typed +// [aix.InterruptibleTool.Resume] method. +// +// For tools that don't need to be registered (e.g., dynamically created tools), +// use [aix.NewInterruptibleTool] instead. +// +// # Options +// +// - [ai.WithInputSchema]: Provide a custom JSON schema instead of inferring from the type parameter +// - [ai.WithInputSchemaName]: Reference a pre-registered schema by name +// +// Example: +// +// type TransferInput struct { +// ToAccount string `json:"toAccount"` +// Amount float64 `json:"amount"` +// } +// +// type TransferOutput struct { +// Status string `json:"status"` +// Balance float64 `json:"balance"` +// } +// +// type Confirmation struct { +// Approved bool `json:"approved"` +// } +// +// genkit.DefineInterruptibleTool(g, "transfer", +// "Transfers money to another account.", +// func(ctx context.Context, input TransferInput, confirm *Confirmation) (*TransferOutput, error) { +// if confirm != nil && !confirm.Approved { +// return &TransferOutput{Status: "cancelled"}, nil +// } +// if confirm == nil && input.Amount > 100 { +// // Pause and ask the caller for confirmation. +// return nil, tool.Interrupt(map[string]any{ +// "reason": "large_amount", +// "amount": input.Amount, +// }) +// } +// return &TransferOutput{Status: "completed", Balance: 50}, nil +// }, +// ) +// +// // In a generate loop, handle the interrupt: +// resp, _ := genkit.Generate(ctx, g, +// ai.WithPrompt("Transfer $200 to Alice"), +// ai.WithTools(transferTool), +// ) +// if resp.FinishReason == ai.FinishReasonInterrupted { +// for _, interrupt := range resp.Interrupts() { +// restart, _ := tool.Resume(interrupt, Confirmation{Approved: true}) +// resp, _ = genkit.Generate(ctx, g, +// ai.WithMessages(resp.History()...), +// ai.WithTools(transferTool), +// ai.WithToolRestarts(restart), +// ) +// } +// } +func DefineInterruptibleTool[In, Out, Res any](g *Genkit, name, description string, fn aix.InterruptibleToolFunc[In, Out, Res], opts ...ai.ToolOption) *aix.InterruptibleTool[In, Out, Res] { + return aix.DefineInterruptibleTool(g.reg, name, description, fn, opts...) +} + // LookupTool retrieves a registered tool by its name. // It returns the tool instance if found, or `nil` if no tool with the // given name is registered (e.g., via [DefineTool]). diff --git a/go/samples/x-agent-interrupts/main.go b/go/samples/x-agent-interrupts/main.go new file mode 100644 index 0000000000..ba8f501d16 --- /dev/null +++ b/go/samples/x-agent-interrupts/main.go @@ -0,0 +1,219 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// x-agent-interrupts demonstrates the experimental tool interrupts API +// using DefinePromptAgent. Unlike x-interrupts (which handles interrupts +// inside a flow), this sample separates concerns: the agent streams +// interrupts to the client, and the client handles user interaction +// and sends resume data back. +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "strconv" + "strings" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/ai/x/tool" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" +) + +// --- Tool schemas --- + +type TransferInput struct { + ToAccount string `json:"toAccount" jsonschema:"description=destination account ID"` + Amount float64 `json:"amount" jsonschema:"description=amount in dollars (e.g. 50.00 for $50)"` +} + +type TransferOutput struct { + Status string `json:"status"` + Message string `json:"message,omitempty"` + NewBalance float64 `json:"newBalance,omitempty"` +} + +type TransferInterrupt struct { + Reason string `json:"reason"` + ToAccount string `json:"toAccount"` + Amount float64 `json:"amount"` + Balance float64 `json:"balance,omitempty"` +} + +type Confirmation struct { + Approved bool `json:"approved"` + AdjustedAmount *float64 `json:"adjustedAmount,omitempty"` +} + +var accountBalance = 150.00 + +func main() { + ctx := context.Background() + reader := bufio.NewReader(os.Stdin) + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + genkit.DefineInterruptibleTool(g, "transferMoney", + "Transfers money to another account. Use this when the user wants to send money.", + func(ctx context.Context, input TransferInput, confirm *Confirmation) (*TransferOutput, error) { + if confirm != nil { + if !confirm.Approved { + return &TransferOutput{"cancelled", "Transfer cancelled by user.", accountBalance}, nil + } + if confirm.AdjustedAmount != nil { + input.Amount = *confirm.AdjustedAmount + } + } + + if input.Amount > accountBalance { + if accountBalance <= 0 { + return &TransferOutput{"rejected", "Account balance is 0. Please add funds.", accountBalance}, nil + } + return nil, tool.Interrupt(TransferInterrupt{ + "insufficient_balance", input.ToAccount, input.Amount, accountBalance, + }) + } + + if confirm == nil && input.Amount > 100 { + return nil, tool.Interrupt(TransferInterrupt{ + "confirm_large", input.ToAccount, input.Amount, accountBalance, + }) + } + + accountBalance -= input.Amount + return &TransferOutput{ + "completed", + fmt.Sprintf("Transferred $%.2f to %s.", input.Amount, input.ToAccount), + accountBalance, + }, nil + }) + + paymentAgent := genkit.DefinePromptAgent[any, any](g, "paymentAgent", nil) + + fmt.Println("Payment Agent (Prompt Agent + Interrupts)") + fmt.Printf("Balance: $%.2f\n", accountBalance) + fmt.Println("Type 'quit' to exit.") + fmt.Println() + + conn, err := paymentAgent.StreamBidi(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + for { + fmt.Print("> ") + input, _ := reader.ReadString('\n') + input = strings.TrimSpace(input) + if input == "quit" || input == "exit" { + conn.Close() + break + } + if input == "" { + continue + } + + if err := conn.SendText(input); err != nil { + fmt.Fprintf(os.Stderr, "Send error: %v\n", err) + break + } + + fmt.Println() + var interrupts []*ai.Part + for chunk, err := range conn.Receive() { + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + break + } + + if chunk.ModelChunk != nil { + fmt.Print(chunk.ModelChunk.Text()) + interrupts = append(interrupts, chunk.ModelChunk.Interrupts()...) + } + + if chunk.EndTurn { + if len(interrupts) == 0 { + fmt.Println() + fmt.Println() + break + } + + var restarts []*ai.Part + for _, interrupt := range interrupts { + if restart, err := handleInterrupt(reader, interrupt); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + } else if restart != nil { + restarts = append(restarts, restart) + } + } + interrupts = nil + + if err := conn.SendToolRestarts(restarts...); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + break + } + } + } + } + + conn.Output() +} + +func handleInterrupt(reader *bufio.Reader, part *ai.Part) (*ai.Part, error) { + meta, ok := tool.InterruptAs[TransferInterrupt](part) + if !ok { + return nil, nil + } + + switch meta.Reason { + case "insufficient_balance": + fmt.Printf("\n[Insufficient Balance] You requested $%.2f but only have $%.2f\n", meta.Amount, meta.Balance) + fmt.Printf("Options: [1] Transfer $%.2f instead [2] Cancel\n", meta.Balance) + fmt.Print("Choice: ") + + if promptChoice(reader, 1, 2) == 1 { + return tool.Resume(part, Confirmation{ + Approved: true, + AdjustedAmount: &meta.Balance, + }) + } + return tool.Resume(part, Confirmation{Approved: false}) + + case "confirm_large": + fmt.Printf("\n[Confirm Large Transfer] Send $%.2f to %s? (yes/no): ", meta.Amount, meta.ToAccount) + return tool.Resume(part, Confirmation{Approved: promptYesNo(reader)}) + + default: + return tool.Resume(part, Confirmation{Approved: true}) + } +} + +func promptChoice(reader *bufio.Reader, min, max int) int { + for { + text, _ := reader.ReadString('\n') + text = strings.TrimSpace(text) + n, err := strconv.Atoi(text) + if err == nil && n >= min && n <= max { + return n + } + fmt.Printf("Please enter a number between %d and %d: ", min, max) + } +} + +func promptYesNo(reader *bufio.Reader) bool { + text, _ := reader.ReadString('\n') + text = strings.ToLower(strings.TrimSpace(text)) + return text == "yes" || text == "y" +} diff --git a/go/samples/x-agent-interrupts/prompts/paymentAgent.prompt b/go/samples/x-agent-interrupts/prompts/paymentAgent.prompt new file mode 100644 index 0000000000..c95c6a05bc --- /dev/null +++ b/go/samples/x-agent-interrupts/prompts/paymentAgent.prompt @@ -0,0 +1,10 @@ +--- +model: googleai/gemini-3-flash-preview +config: + thinkingConfig: + thinkingBudget: 0 +tools: + - transferMoney +--- +{{role "system"}} +You are a helpful payment assistant. When the user wants to transfer money, use the transferMoney tool. Always confirm the result with the user. From 83f2cce3e48bfad6f84a1145cbe304804c84920b Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 20 Feb 2026 18:45:51 -0800 Subject: [PATCH 2/8] added streaming from tools --- go/ai/document.go | 15 +++++++++++++++ go/ai/generate.go | 31 ++++++++++++++++++++++++++++++- go/ai/x/tool/tool.go | 21 +++++++++++++++++++++ go/internal/base/context_key.go | 4 ++++ 4 files changed, 70 insertions(+), 1 deletion(-) diff --git a/go/ai/document.go b/go/ai/document.go index 1c6407ef17..7b0270b7fd 100644 --- a/go/ai/document.go +++ b/go/ai/document.go @@ -156,6 +156,21 @@ func (p *Part) IsInterrupt() bool { return p != nil && p.IsToolRequest() && p.Metadata != nil && p.Metadata["interrupt"] != nil } +// IsPartial reports whether the [Part] contains a partial tool response +// streamed during tool execution (e.g., a progress update). +func (p *Part) IsPartial() bool { + return p != nil && p.IsToolResponse() && p.Metadata != nil && p.Metadata["partial"] == true +} + +// NewPartialToolResponsePart returns a [Part] containing a partial tool response. +// Partial tool responses are streamed during tool execution for client-side +// display (e.g., progress indicators) and are not included in conversation history. +func NewPartialToolResponsePart(r *ToolResponse) *Part { + p := NewToolResponsePart(r) + p.Metadata = map[string]any{"partial": true} + return p +} + // IsCustom reports whether the [Part] contains custom plugin-specific data. func (p *Part) IsCustom() bool { return p != nil && p.Kind == PartCustom diff --git a/go/ai/generate.go b/go/ai/generate.go index c641012609..32dd6a1309 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -796,7 +796,23 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return } - multipartResp, err := tool.RunRawMultipart(ctx, toolReq.Input) + // Inject a per-tool chunk sender so tools can stream partial + // responses (e.g., progress updates) via tool.SendChunk. + toolCtx := ctx + if cb != nil { + toolCtx = base.ToolPartialSenderKey.NewContext(ctx, func(sendCtx context.Context, output any) { + cb(sendCtx, &ModelResponseChunk{ + Role: RoleTool, + Content: []*Part{NewPartialToolResponsePart(&ToolResponse{ + Name: toolReq.Name, + Ref: toolReq.Ref, + Output: output, + })}, + }) + }) + } + + multipartResp, err := tool.RunRawMultipart(toolCtx, toolReq.Input) if err != nil { var tie *toolInterruptError if errors.As(err, &tie) { @@ -1028,6 +1044,19 @@ func (c *ModelResponseChunk) Interrupts() []*Part { return parts } +// ToolResponses returns the tool response parts from the chunk. +// Use [Part.IsPartial] to distinguish streaming progress updates +// from final tool results. +func (c *ModelResponseChunk) ToolResponses() []*Part { + var parts []*Part + for _, p := range c.Content { + if p.IsToolResponse() { + parts = append(parts, p) + } + } + return parts +} + // Output parses the chunk using the format handler and unmarshals the result into v. // Returns an error if the format handler is not set or does not support parsing chunks. func (c *ModelResponseChunk) Output(v any) error { diff --git a/go/ai/x/tool/tool.go b/go/ai/x/tool/tool.go index 0b13f9bf05..9a8b6db94f 100644 --- a/go/ai/x/tool/tool.go +++ b/go/ai/x/tool/tool.go @@ -86,6 +86,27 @@ func Resume[Res any](interruptedPart *ai.Part, data Res) (*ai.Part, error) { return newPart, nil } +// --- SendChunk --- + +// SendPartial streams a partial tool response during tool execution. +// The output is arbitrary structured data (e.g., progress information) +// that will be delivered to the client as a partial [ai.ToolResponse]. +// +// This is best-effort: if no streaming callback is available (e.g., the +// tool is called via a non-streaming Generate), the call is a no-op. +// The tool's final return value is always the authoritative response. +// +// Example: +// +// tool.SendPartial(ctx, map[string]any{"step": "uploading", "progress": 50}) +func SendPartial(ctx context.Context, output any) { + send := base.ToolPartialSenderKey.FromContext(ctx) + if send == nil { + return + } + send(ctx, output) +} + // --- AttachParts --- // AttachParts attaches additional content parts (e.g., media) to the tool's diff --git a/go/internal/base/context_key.go b/go/internal/base/context_key.go index aab9d5c2bf..4f1c33298f 100644 --- a/go/internal/base/context_key.go +++ b/go/internal/base/context_key.go @@ -41,3 +41,7 @@ func (k ContextKey[T]) FromContext(ctx context.Context) T { t, _ := ctx.Value(k.key).(T) return t } + +// ToolPartialSenderKey is the context key for streaming partial tool responses. +// Set by ai/generate.go (handleToolRequests), read by ai/x/tool (SendPartial). +var ToolPartialSenderKey = NewContextKey[func(context.Context, any)]() From 2c123b223dc530cbf98cf9bc9dff24f208e56f36 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 20 Feb 2026 20:23:33 -0800 Subject: [PATCH 3/8] feedback --- go/ai/x/tools.go | 11 ++++++----- go/genkit/genkit.go | 16 +++++++++++----- go/samples/x-agent-interrupts/main.go | 3 ++- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/go/ai/x/tools.go b/go/ai/x/tools.go index 9b85768096..413babc00f 100644 --- a/go/ai/x/tools.go +++ b/go/ai/x/tools.go @@ -183,7 +183,7 @@ func wrapSimpleFunc[In, Out any](fn ToolFunc[In, Out]) ai.MultipartToolFunc[In] // DEPRECATED(breaking): Same as wrapSimpleFunc — exists only to bridge between // the new function signature and ai.MultipartToolFunc/ai.ToolContext. -func wrapInterruptibleFunc[In, Out, Res any](fn InterruptibleToolFunc[In, Out, Res]) ai.MultipartToolFunc[In] { +func wrapInterruptibleFunc[In, Out, Resume any](fn InterruptibleToolFunc[In, Out, Resume]) ai.MultipartToolFunc[In] { return func(tc *ai.ToolContext, input In) (*ai.MultipartToolResponse, error) { ctx := tc.Context ctx, collector := tool.NewPartsContext(ctx) @@ -193,12 +193,13 @@ func wrapInterruptibleFunc[In, Out, Res any](fn InterruptibleToolFunc[In, Out, R // DEPRECATED(breaking): Resumed data would come from context keys set by // the generate loop directly, not from ai.ToolContext.Resumed. - var res *Res + var res *Resume if tc.Resumed != nil { - r, err := base.MapToStruct[Res](tc.Resumed) - if err == nil { - res = &r + r, err := base.MapToStruct[Resume](tc.Resumed) + if err != nil { + return nil, fmt.Errorf("aix.wrapInterruptibleFunc: failed to convert resumed data: %w", err) } + res = &r } output, err := fn(ctx, input, res) diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 951cd06742..abf5ec5d7e 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -942,11 +942,16 @@ func DefineXTool[In, Out any](g *Genkit, name, description string, fn aix.ToolFu // Balance float64 `json:"balance"` // } // +// type TransferInterrupt struct { +// Reason string `json:"reason"` +// Amount float64 `json:"amount"` +// } +// // type Confirmation struct { // Approved bool `json:"approved"` // } // -// genkit.DefineInterruptibleTool(g, "transfer", +// transferTool := genkit.DefineInterruptibleTool(g, "transfer", // "Transfers money to another account.", // func(ctx context.Context, input TransferInput, confirm *Confirmation) (*TransferOutput, error) { // if confirm != nil && !confirm.Approved { @@ -954,9 +959,9 @@ func DefineXTool[In, Out any](g *Genkit, name, description string, fn aix.ToolFu // } // if confirm == nil && input.Amount > 100 { // // Pause and ask the caller for confirmation. -// return nil, tool.Interrupt(map[string]any{ -// "reason": "large_amount", -// "amount": input.Amount, +// return nil, tool.Interrupt(TransferInterrupt{ +// Reason: "large_amount", +// Amount: input.Amount, // }) // } // return &TransferOutput{Status: "completed", Balance: 50}, nil @@ -970,6 +975,7 @@ func DefineXTool[In, Out any](g *Genkit, name, description string, fn aix.ToolFu // ) // if resp.FinishReason == ai.FinishReasonInterrupted { // for _, interrupt := range resp.Interrupts() { +// // Ask the user for confirmation. // restart, _ := tool.Resume(interrupt, Confirmation{Approved: true}) // resp, _ = genkit.Generate(ctx, g, // ai.WithMessages(resp.History()...), @@ -978,7 +984,7 @@ func DefineXTool[In, Out any](g *Genkit, name, description string, fn aix.ToolFu // ) // } // } -func DefineInterruptibleTool[In, Out, Res any](g *Genkit, name, description string, fn aix.InterruptibleToolFunc[In, Out, Res], opts ...ai.ToolOption) *aix.InterruptibleTool[In, Out, Res] { +func DefineInterruptibleTool[In, Out, Resume any](g *Genkit, name, description string, fn aix.InterruptibleToolFunc[In, Out, Resume], opts ...ai.ToolOption) *aix.InterruptibleTool[In, Out, Resume] { return aix.DefineInterruptibleTool(g.reg, name, description, fn, opts...) } diff --git a/go/samples/x-agent-interrupts/main.go b/go/samples/x-agent-interrupts/main.go index ba8f501d16..da2b48c361 100644 --- a/go/samples/x-agent-interrupts/main.go +++ b/go/samples/x-agent-interrupts/main.go @@ -196,7 +196,8 @@ func handleInterrupt(reader *bufio.Reader, part *ai.Part) (*ai.Part, error) { return tool.Resume(part, Confirmation{Approved: promptYesNo(reader)}) default: - return tool.Resume(part, Confirmation{Approved: true}) + fmt.Printf("\n[Unknown Interrupt] Reason: %q. Cancelling transaction.\n", meta.Reason) + return tool.Resume(part, Confirmation{Approved: false}) } } From b5dc72cee5bbeb3999e6def34e6e3b6eb2d5ef43 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 24 Feb 2026 21:03:10 -0800 Subject: [PATCH 4/8] added `InterruptibleTool.Respond()` --- go/ai/x/tool/tool.go | 24 ++++++++++++++++++++++-- go/ai/x/tools.go | 20 ++++++++++++++++++-- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/go/ai/x/tool/tool.go b/go/ai/x/tool/tool.go index 9a8b6db94f..f1f2ea5c6b 100644 --- a/go/ai/x/tool/tool.go +++ b/go/ai/x/tool/tool.go @@ -61,8 +61,8 @@ func InterruptAs[T any](p *ai.Part) (T, bool) { // This is a convenience alternative to [aix.InterruptibleTool.Resume] that // does not require access to the tool definition. func Resume[Res any](interruptedPart *ai.Part, data Res) (*ai.Part, error) { - if interruptedPart == nil || !interruptedPart.IsToolRequest() { - return nil, fmt.Errorf("tool.Resume: part is not a tool request") + if interruptedPart == nil || !interruptedPart.IsInterrupt() { + return nil, fmt.Errorf("tool.Resume: part is not an interrupted tool request") } m, err := base.StructToMap(data) @@ -86,6 +86,26 @@ func Resume[Res any](interruptedPart *ai.Part, data Res) (*ai.Part, error) { return newPart, nil } +// --- Respond --- + +// Respond creates a tool response [ai.Part] for an interrupted tool request. +// Instead of re-executing the tool (as [Resume] does), this provides a +// pre-computed result directly. +// +// This is a convenience alternative to [aix.Tool.Respond] that does not +// require access to the tool definition. +func Respond(interruptedPart *ai.Part, output any) (*ai.Part, error) { + if interruptedPart == nil || !interruptedPart.IsInterrupt() { + return nil, fmt.Errorf("tool.Respond: part is not an interrupted tool request") + } + + resp := ai.NewResponseForToolRequest(interruptedPart, output) + resp.Metadata = map[string]any{ + "interruptResponse": true, + } + return resp, nil +} + // --- SendChunk --- // SendPartial streams a partial tool response during tool execution. diff --git a/go/ai/x/tools.go b/go/ai/x/tools.go index 413babc00f..3da776396a 100644 --- a/go/ai/x/tools.go +++ b/go/ai/x/tools.go @@ -94,8 +94,8 @@ type InterruptibleTool[In, Out, Res any] struct { // Unlike [tool.Resume], this method also validates that the interrupted part // belongs to this tool. func (t *InterruptibleTool[In, Out, Res]) Resume(interruptedPart *ai.Part, data Res) (*ai.Part, error) { - if interruptedPart == nil || !interruptedPart.IsToolRequest() { - return nil, fmt.Errorf("Resume: part is not a tool request") + if interruptedPart == nil || !interruptedPart.IsInterrupt() { + return nil, fmt.Errorf("Resume: part is not an interrupted tool request") } if interruptedPart.ToolRequest.Name != t.Name() { return nil, fmt.Errorf("Resume: tool request is for %q, not %q", interruptedPart.ToolRequest.Name, t.Name()) @@ -103,6 +103,22 @@ func (t *InterruptibleTool[In, Out, Res]) Resume(interruptedPart *ai.Part, data return tool.Resume(interruptedPart, data) } +// Respond creates a tool response [ai.Part] for an interrupted tool request. +// Instead of re-executing the tool (as [Resume] does), this provides a +// pre-computed result directly. +// +// Unlike [tool.Respond], this method validates that the interrupted part +// belongs to this tool and accepts a strongly-typed output. +func (t *InterruptibleTool[In, Out, Res]) Respond(interruptedPart *ai.Part, output Out) (*ai.Part, error) { + if interruptedPart == nil || !interruptedPart.IsInterrupt() { + return nil, fmt.Errorf("Respond: part is not an interrupted tool request") + } + if interruptedPart.ToolRequest.Name != t.Name() { + return nil, fmt.Errorf("Respond: tool request is for %q, not %q", interruptedPart.ToolRequest.Name, t.Name()) + } + return tool.Respond(interruptedPart, output) +} + // DefineTool creates a new tool with a simple function signature and registers it. // The function receives a plain [context.Context] instead of [ai.ToolContext]. // Use [tool.AttachParts] inside the function to return additional content parts. From 5a84dfe57057179381fe0905ee0b21b4660730f1 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 24 Feb 2026 22:00:30 -0800 Subject: [PATCH 5/8] Update generate.go --- go/ai/generate.go | 43 ++++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index 32dd6a1309..3ef503b3ff 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -916,11 +916,10 @@ func (mr *ModelResponse) History() []*Message { // Reasoning concatenates all reasoning parts present in the message func (mr *ModelResponse) Reasoning() string { - var sb strings.Builder if mr == nil || mr.Message == nil { return "" } - + var sb strings.Builder for _, p := range mr.Message.Content { if !p.IsReasoning() { continue @@ -934,7 +933,7 @@ func (mr *ModelResponse) Reasoning() string { // If a format handler is set, it uses the handler's ParseOutput method. // Otherwise, it falls back to parsing the response text as JSON. func (mr *ModelResponse) Output(v any) error { - if mr.Message == nil || len(mr.Message.Content) == 0 { + if mr == nil || mr.Message == nil || len(mr.Message.Content) == 0 { return errors.New("no content in response") } @@ -959,28 +958,28 @@ func (mr *ModelResponse) Output(v any) error { } // ToolRequests returns the tool requests from the response. -func (mr *ModelResponse) ToolRequests() []*ToolRequest { - toolReqs := []*ToolRequest{} +func (mr *ModelResponse) ToolRequests() []*Part { + var parts []*Part if mr == nil || mr.Message == nil { - return toolReqs + return parts } - for _, part := range mr.Message.Content { - if part.IsToolRequest() { - toolReqs = append(toolReqs, part.ToolRequest) + for _, p := range mr.Message.Content { + if p.IsToolRequest() { + parts = append(parts, p) } } - return toolReqs + return parts } // Interrupts returns the interrupted tool request parts from the response. func (mr *ModelResponse) Interrupts() []*Part { - parts := []*Part{} + var parts []*Part if mr == nil || mr.Message == nil { return parts } - for _, part := range mr.Message.Content { - if part.IsInterrupt() { - parts = append(parts, part) + for _, p := range mr.Message.Content { + if p.IsInterrupt() { + parts = append(parts, p) } } return parts @@ -1003,12 +1002,9 @@ func (mr *ModelResponse) Media() string { // It returns an empty string if there is no Content in the response chunk. // For the parsed structured output, use [ModelResponseChunk.Output] instead. func (c *ModelResponseChunk) Text() string { - if len(c.Content) == 0 { + if c == nil { return "" } - if len(c.Content) == 1 { - return c.Content[0].Text - } var sb strings.Builder for _, p := range c.Content { if p.IsText() || p.IsData() { @@ -1021,7 +1017,7 @@ func (c *ModelResponseChunk) Text() string { // Reasoning returns the reasoning content of the ModelResponseChunk as a string. // It returns an empty string if there is no Content in the response chunk. func (c *ModelResponseChunk) Reasoning() string { - if len(c.Content) == 0 { + if c == nil { return "" } var sb strings.Builder @@ -1036,6 +1032,9 @@ func (c *ModelResponseChunk) Reasoning() string { // Interrupts returns the interrupted tool request parts from the chunk. func (c *ModelResponseChunk) Interrupts() []*Part { var parts []*Part + if c == nil { + return parts + } for _, p := range c.Content { if p.IsInterrupt() { parts = append(parts, p) @@ -1049,6 +1048,9 @@ func (c *ModelResponseChunk) Interrupts() []*Part { // from final tool results. func (c *ModelResponseChunk) ToolResponses() []*Part { var parts []*Part + if c == nil { + return parts + } for _, p := range c.Content { if p.IsToolResponse() { parts = append(parts, p) @@ -1060,6 +1062,9 @@ func (c *ModelResponseChunk) ToolResponses() []*Part { // Output parses the chunk using the format handler and unmarshals the result into v. // Returns an error if the format handler is not set or does not support parsing chunks. func (c *ModelResponseChunk) Output(v any) error { + if c == nil { + return errors.New("chunk is nil") + } if c.formatHandler == nil { return errors.New("output format chosen does not support parsing chunks") } From 689820ac461877a0cc118b9a4286697390dd38ec Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 24 Feb 2026 22:20:19 -0800 Subject: [PATCH 6/8] Update tools.go --- go/ai/x/tools.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/go/ai/x/tools.go b/go/ai/x/tools.go index 3da776396a..d888e0cc23 100644 --- a/go/ai/x/tools.go +++ b/go/ai/x/tools.go @@ -93,14 +93,14 @@ type InterruptibleTool[In, Out, Res any] struct { // // Unlike [tool.Resume], this method also validates that the interrupted part // belongs to this tool. -func (t *InterruptibleTool[In, Out, Res]) Resume(interruptedPart *ai.Part, data Res) (*ai.Part, error) { - if interruptedPart == nil || !interruptedPart.IsInterrupt() { +func (t *InterruptibleTool[In, Out, Resume]) Resume(part *ai.Part, res Resume) (*ai.Part, error) { + if part == nil || !part.IsInterrupt() { return nil, fmt.Errorf("Resume: part is not an interrupted tool request") } - if interruptedPart.ToolRequest.Name != t.Name() { - return nil, fmt.Errorf("Resume: tool request is for %q, not %q", interruptedPart.ToolRequest.Name, t.Name()) + if part.ToolRequest.Name != t.Name() { + return nil, fmt.Errorf("Resume: tool request is for %q, not %q", part.ToolRequest.Name, t.Name()) } - return tool.Resume(interruptedPart, data) + return tool.Resume(part, res) } // Respond creates a tool response [ai.Part] for an interrupted tool request. @@ -109,14 +109,14 @@ func (t *InterruptibleTool[In, Out, Res]) Resume(interruptedPart *ai.Part, data // // Unlike [tool.Respond], this method validates that the interrupted part // belongs to this tool and accepts a strongly-typed output. -func (t *InterruptibleTool[In, Out, Res]) Respond(interruptedPart *ai.Part, output Out) (*ai.Part, error) { - if interruptedPart == nil || !interruptedPart.IsInterrupt() { +func (t *InterruptibleTool[In, Out, Resume]) Respond(part *ai.Part, output Out) (*ai.Part, error) { + if part == nil || !part.IsInterrupt() { return nil, fmt.Errorf("Respond: part is not an interrupted tool request") } - if interruptedPart.ToolRequest.Name != t.Name() { - return nil, fmt.Errorf("Respond: tool request is for %q, not %q", interruptedPart.ToolRequest.Name, t.Name()) + if part.ToolRequest.Name != t.Name() { + return nil, fmt.Errorf("Respond: tool request is for %q, not %q", part.ToolRequest.Name, t.Name()) } - return tool.Respond(interruptedPart, output) + return tool.Respond(part, output) } // DefineTool creates a new tool with a simple function signature and registers it. From 59e6cf8a1efb505b51ad37a5a1dee68c81f3d9df Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 25 Feb 2026 07:39:57 -0800 Subject: [PATCH 7/8] added `tool.SendChunk()` --- go/ai/generate.go | 8 ++++++-- go/ai/x/tool/tool.go | 19 ++++++++++++++++++- go/internal/base/context_key.go | 7 +++++++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index 3ef503b3ff..f62c17334b 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -796,8 +796,9 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return } - // Inject a per-tool chunk sender so tools can stream partial - // responses (e.g., progress updates) via tool.SendChunk. + // Inject per-tool streaming senders so tools can stream via + // tool.SendPartial (wrapped partial responses) and + // tool.SendChunk (raw model response chunks). toolCtx := ctx if cb != nil { toolCtx = base.ToolPartialSenderKey.NewContext(ctx, func(sendCtx context.Context, output any) { @@ -810,6 +811,9 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, })}, }) }) + toolCtx = base.ToolChunkSenderKey.NewContext(toolCtx, func(sendCtx context.Context, chunk any) { + cb(sendCtx, chunk.(*ModelResponseChunk)) + }) } multipartResp, err := tool.RunRawMultipart(toolCtx, toolReq.Input) diff --git a/go/ai/x/tool/tool.go b/go/ai/x/tool/tool.go index f1f2ea5c6b..ed42dba2e8 100644 --- a/go/ai/x/tool/tool.go +++ b/go/ai/x/tool/tool.go @@ -106,7 +106,7 @@ func Respond(interruptedPart *ai.Part, output any) (*ai.Part, error) { return resp, nil } -// --- SendChunk --- +// --- SendPartial --- // SendPartial streams a partial tool response during tool execution. // The output is arbitrary structured data (e.g., progress information) @@ -127,6 +127,23 @@ func SendPartial(ctx context.Context, output any) { send(ctx, output) } +// --- SendChunk --- + +// SendChunk streams a raw [ai.ModelResponseChunk] during tool execution. +// Unlike [SendPartial], which wraps arbitrary data in a partial tool response, +// SendChunk gives the tool full control over the chunk contents. +// +// This is best-effort: if no streaming callback is available (e.g., the +// tool is called via a non-streaming Generate), the call is a no-op. +// The tool's final return value is always the authoritative response. +func SendChunk(ctx context.Context, chunk *ai.ModelResponseChunk) { + send := base.ToolChunkSenderKey.FromContext(ctx) + if send == nil { + return + } + send(ctx, chunk) +} + // --- AttachParts --- // AttachParts attaches additional content parts (e.g., media) to the tool's diff --git a/go/internal/base/context_key.go b/go/internal/base/context_key.go index 4f1c33298f..74f4fdc443 100644 --- a/go/internal/base/context_key.go +++ b/go/internal/base/context_key.go @@ -45,3 +45,10 @@ func (k ContextKey[T]) FromContext(ctx context.Context) T { // ToolPartialSenderKey is the context key for streaming partial tool responses. // Set by ai/generate.go (handleToolRequests), read by ai/x/tool (SendPartial). var ToolPartialSenderKey = NewContextKey[func(context.Context, any)]() + +// ToolChunkSenderKey is the context key for streaming raw model response chunks +// from within a tool. Set by ai/generate.go (handleToolRequests), read by +// ai/x/tool (SendChunk). The any value is *ai.ModelResponseChunk (typed as any +// to avoid a circular import). +var ToolChunkSenderKey = NewContextKey[func(context.Context, any)]() + From ad74c5ad7b7e4cf2da7773209972c5c6109d2434 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 25 Feb 2026 07:42:28 -0800 Subject: [PATCH 8/8] Update context_key.go --- go/internal/base/context_key.go | 1 - 1 file changed, 1 deletion(-) diff --git a/go/internal/base/context_key.go b/go/internal/base/context_key.go index 74f4fdc443..f85ebea4ce 100644 --- a/go/internal/base/context_key.go +++ b/go/internal/base/context_key.go @@ -51,4 +51,3 @@ var ToolPartialSenderKey = NewContextKey[func(context.Context, any)]() // ai/x/tool (SendChunk). The any value is *ai.ModelResponseChunk (typed as any // to avoid a circular import). var ToolChunkSenderKey = NewContextKey[func(context.Context, any)]() -