Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions go/ai/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,21 @@ func (p *Part) IsInterrupt() bool {
return p != nil && p.IsToolRequest() && p.Metadata != nil && p.Metadata["interrupt"] != nil
}

// IsPartial reports whether the [Part] contains a partial tool response
// streamed during tool execution (e.g., a progress update).
func (p *Part) IsPartial() bool {
return p != nil && p.IsToolResponse() && p.Metadata != nil && p.Metadata["partial"] == true
}

// NewPartialToolResponsePart returns a [Part] containing a partial tool response.
// Partial tool responses are streamed during tool execution for client-side
// display (e.g., progress indicators) and are not included in conversation history.
func NewPartialToolResponsePart(r *ToolResponse) *Part {
p := NewToolResponsePart(r)
p.Metadata = map[string]any{"partial": true}
return p
}

// IsCustom reports whether the [Part] contains custom plugin-specific data.
func (p *Part) IsCustom() bool {
return p != nil && p.Kind == PartCustom
Expand Down
78 changes: 58 additions & 20 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,27 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest,
return
}

multipartResp, err := tool.RunRawMultipart(ctx, toolReq.Input)
// Inject per-tool streaming senders so tools can stream via
// tool.SendPartial (wrapped partial responses) and
// tool.SendChunk (raw model response chunks).
toolCtx := ctx
if cb != nil {
toolCtx = base.ToolPartialSenderKey.NewContext(ctx, func(sendCtx context.Context, output any) {
cb(sendCtx, &ModelResponseChunk{
Role: RoleTool,
Content: []*Part{NewPartialToolResponsePart(&ToolResponse{
Name: toolReq.Name,
Ref: toolReq.Ref,
Output: output,
})},
})
})
toolCtx = base.ToolChunkSenderKey.NewContext(toolCtx, func(sendCtx context.Context, chunk any) {
cb(sendCtx, chunk.(*ModelResponseChunk))
})
}

multipartResp, err := tool.RunRawMultipart(toolCtx, toolReq.Input)
if err != nil {
var tie *toolInterruptError
if errors.As(err, &tie) {
Expand Down Expand Up @@ -900,11 +920,10 @@ func (mr *ModelResponse) History() []*Message {

// Reasoning concatenates all reasoning parts present in the message
func (mr *ModelResponse) Reasoning() string {
var sb strings.Builder
if mr == nil || mr.Message == nil {
return ""
}

var sb strings.Builder
for _, p := range mr.Message.Content {
if !p.IsReasoning() {
continue
Expand All @@ -918,7 +937,7 @@ func (mr *ModelResponse) Reasoning() string {
// If a format handler is set, it uses the handler's ParseOutput method.
// Otherwise, it falls back to parsing the response text as JSON.
func (mr *ModelResponse) Output(v any) error {
if mr.Message == nil || len(mr.Message.Content) == 0 {
if mr == nil || mr.Message == nil || len(mr.Message.Content) == 0 {
return errors.New("no content in response")
}

Expand All @@ -943,28 +962,28 @@ func (mr *ModelResponse) Output(v any) error {
}

// ToolRequests returns the tool requests from the response.
func (mr *ModelResponse) ToolRequests() []*ToolRequest {
toolReqs := []*ToolRequest{}
func (mr *ModelResponse) ToolRequests() []*Part {
var parts []*Part
if mr == nil || mr.Message == nil {
return toolReqs
return parts
}
for _, part := range mr.Message.Content {
if part.IsToolRequest() {
toolReqs = append(toolReqs, part.ToolRequest)
for _, p := range mr.Message.Content {
if p.IsToolRequest() {
parts = append(parts, p)
}
}
return toolReqs
return parts
}

// Interrupts returns the interrupted tool request parts from the response.
func (mr *ModelResponse) Interrupts() []*Part {
parts := []*Part{}
var parts []*Part
if mr == nil || mr.Message == nil {
return parts
}
for _, part := range mr.Message.Content {
if part.IsInterrupt() {
parts = append(parts, part)
for _, p := range mr.Message.Content {
if p.IsInterrupt() {
parts = append(parts, p)
}
}
return parts
Expand All @@ -987,12 +1006,9 @@ func (mr *ModelResponse) Media() string {
// It returns an empty string if there is no Content in the response chunk.
// For the parsed structured output, use [ModelResponseChunk.Output] instead.
func (c *ModelResponseChunk) Text() string {
if len(c.Content) == 0 {
if c == nil {
return ""
}
if len(c.Content) == 1 {
return c.Content[0].Text
}
var sb strings.Builder
for _, p := range c.Content {
if p.IsText() || p.IsData() {
Expand All @@ -1005,7 +1021,7 @@ func (c *ModelResponseChunk) Text() string {
// Reasoning returns the reasoning content of the ModelResponseChunk as a string.
// It returns an empty string if there is no Content in the response chunk.
func (c *ModelResponseChunk) Reasoning() string {
if len(c.Content) == 0 {
if c == nil {
return ""
}
var sb strings.Builder
Expand All @@ -1020,6 +1036,9 @@ func (c *ModelResponseChunk) Reasoning() string {
// Interrupts returns the interrupted tool request parts from the chunk.
func (c *ModelResponseChunk) Interrupts() []*Part {
var parts []*Part
if c == nil {
return parts
}
for _, p := range c.Content {
if p.IsInterrupt() {
parts = append(parts, p)
Expand All @@ -1028,9 +1047,28 @@ func (c *ModelResponseChunk) Interrupts() []*Part {
return parts
}

// ToolResponses returns the tool response parts from the chunk.
// Use [Part.IsPartial] to distinguish streaming progress updates
// from final tool results.
func (c *ModelResponseChunk) ToolResponses() []*Part {
var parts []*Part
if c == nil {
return parts
}
for _, p := range c.Content {
if p.IsToolResponse() {
parts = append(parts, p)
}
}
return parts
}

// Output parses the chunk using the format handler and unmarshals the result into v.
// Returns an error if the format handler is not set or does not support parsing chunks.
func (c *ModelResponseChunk) Output(v any) error {
if c == nil {
return errors.New("chunk is nil")
}
if c.formatHandler == nil {
return errors.New("output format chosen does not support parsing chunks")
}
Expand Down
196 changes: 196 additions & 0 deletions go/ai/x/tool/tool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

// Package tool provides runtime helpers for use inside tool functions.
//
// APIs in this package are under active development and may change in any
// minor version release. Use with caution in production environments.
package tool

import (
"context"
"fmt"
"maps"

"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/internal/base"
)

// --- Interrupt ---

// InterruptError is returned by [Interrupt] to signal tool interruption.
type InterruptError struct {
Data any
}

func (e *InterruptError) Error() string {
return "tool interrupted"
}

// Interrupt interrupts tool execution and sends data to the caller.
// The caller can read this data with [InterruptAs] and resume the tool
// with [Resume].
func Interrupt(data any) error {
return &InterruptError{Data: data}
}

// InterruptAs extracts typed interrupt data from an interrupted tool request [ai.Part].
// Returns the zero value and false if the part is not an interrupt or the type doesn't match.
func InterruptAs[T any](p *ai.Part) (T, bool) {
return ai.InterruptAs[T](p)
}

// Resume creates a restart [ai.Part] for resuming an interrupted tool call.
// The interruptedPart must be an interrupted tool request (as received via
// [InterruptAs] or [ai.ModelResponse.Interrupts]). The data is delivered to
// the tool function's resume parameter when it is re-executed.
//
// This is a convenience alternative to [aix.InterruptibleTool.Resume] that
// does not require access to the tool definition.
func Resume[Res any](interruptedPart *ai.Part, data Res) (*ai.Part, error) {
if interruptedPart == nil || !interruptedPart.IsInterrupt() {
return nil, fmt.Errorf("tool.Resume: part is not an interrupted tool request")
}

m, err := base.StructToMap(data)
if err != nil {
return nil, fmt.Errorf("tool.Resume: %w", err)
}

newMeta := maps.Clone(interruptedPart.Metadata)
if newMeta == nil {
newMeta = make(map[string]any)
}
newMeta["resumed"] = m
delete(newMeta, "interrupt")

newPart := ai.NewToolRequestPart(&ai.ToolRequest{
Name: interruptedPart.ToolRequest.Name,
Ref: interruptedPart.ToolRequest.Ref,
Input: interruptedPart.ToolRequest.Input,
})
newPart.Metadata = newMeta
return newPart, nil
}

// --- Respond ---

// Respond creates a tool response [ai.Part] for an interrupted tool request.
// Instead of re-executing the tool (as [Resume] does), this provides a
// pre-computed result directly.
//
// This is a convenience alternative to [aix.Tool.Respond] that does not
// require access to the tool definition.
func Respond(interruptedPart *ai.Part, output any) (*ai.Part, error) {
if interruptedPart == nil || !interruptedPart.IsInterrupt() {
return nil, fmt.Errorf("tool.Respond: part is not an interrupted tool request")
}

resp := ai.NewResponseForToolRequest(interruptedPart, output)
resp.Metadata = map[string]any{
"interruptResponse": true,
}
return resp, nil
}

// --- SendPartial ---

// SendPartial streams a partial tool response during tool execution.
// The output is arbitrary structured data (e.g., progress information)
// that will be delivered to the client as a partial [ai.ToolResponse].
//
// This is best-effort: if no streaming callback is available (e.g., the
// tool is called via a non-streaming Generate), the call is a no-op.
// The tool's final return value is always the authoritative response.
//
// Example:
//
// tool.SendPartial(ctx, map[string]any{"step": "uploading", "progress": 50})
func SendPartial(ctx context.Context, output any) {
send := base.ToolPartialSenderKey.FromContext(ctx)
if send == nil {
return
}
send(ctx, output)
}

// --- SendChunk ---

// SendChunk streams a raw [ai.ModelResponseChunk] during tool execution.
// Unlike [SendPartial], which wraps arbitrary data in a partial tool response,
// SendChunk gives the tool full control over the chunk contents.
//
// This is best-effort: if no streaming callback is available (e.g., the
// tool is called via a non-streaming Generate), the call is a no-op.
// The tool's final return value is always the authoritative response.
func SendChunk(ctx context.Context, chunk *ai.ModelResponseChunk) {
send := base.ToolChunkSenderKey.FromContext(ctx)
if send == nil {
return
}
send(ctx, chunk)
}

// --- AttachParts ---

// AttachParts attaches additional content parts (e.g., media) to the tool's
// response. This can be called from any tool to produce a multipart response
// without changing the function signature.
func AttachParts(ctx context.Context, parts ...*ai.Part) {
c := partsCollectorKey.FromContext(ctx)
if c == nil {
return
}
c.parts = append(c.parts, parts...)
}

// --- OriginalInput ---

// OriginalInput returns the original input if the caller replaced it during
// restart. Returns the zero value and false if the input was not replaced
// or the tool is not being resumed.
func OriginalInput[In any](ctx context.Context) (In, bool) {
v := originalInputKey.FromContext(ctx)
if v == nil {
var zero In
return zero, false
}
return base.ConvertTo[In](v)
}

// --- Internal plumbing (used by the aix package, not for end users) ---

var originalInputKey = base.NewContextKey[any]()
var partsCollectorKey = base.NewContextKey[*partsCollector]()

// partsCollector accumulates content parts attached during tool execution.
type partsCollector struct {
parts []*ai.Part
}

// SetOriginalInput stores the original input in the context.
// This is internal plumbing used by the aix package.
func SetOriginalInput(ctx context.Context, input any) context.Context {
return originalInputKey.NewContext(ctx, input)
}

// NewPartsContext returns a context with a parts collector.
// This is internal plumbing used by the aix package.
func NewPartsContext(ctx context.Context) (context.Context, func() []*ai.Part) {
c := &partsCollector{}
ctx = partsCollectorKey.NewContext(ctx, c)
return ctx, func() []*ai.Part { return c.parts }
}
Loading
Loading