diff --git a/genkit-tools/common/src/types/agent.ts b/genkit-tools/common/src/types/agent.ts new file mode 100644 index 0000000000..fb8c2ce9ec --- /dev/null +++ b/genkit-tools/common/src/types/agent.ts @@ -0,0 +1,118 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { z } from 'zod'; +import { MessageSchema, ModelResponseChunkSchema } from './model'; +import { PartSchema } from './parts'; + +/** + * Zod schema for an artifact produced during a session. + */ +export const ArtifactSchema = z.object({ + /** Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). */ + name: z.string().optional(), + /** Parts contains the artifact content (text, media, etc.). */ + parts: z.array(PartSchema), + /** Metadata contains additional artifact-specific data. */ + metadata: z.record(z.any()).optional(), +}); +export type Artifact = z.infer; + +/** + * Zod schema for snapshot event. + */ +export const SnapshotEventSchema = z.enum(['turnEnd', 'invocationEnd']); +export type SnapshotEvent = z.infer; + +/** + * Zod schema for session state. + */ +export const SessionStateSchema = z.object({ + /** Conversation history (user/model exchanges). */ + messages: z.array(MessageSchema).optional(), + /** User-defined state associated with this conversation. */ + custom: z.any().optional(), + /** Named collections of parts produced during the conversation. */ + artifacts: z.array(ArtifactSchema).optional(), + /** Input used for agent flows that require input variables. */ + inputVariables: z.any().optional(), +}); +export type SessionState = z.infer; + +/** + * Zod schema for agent flow input (per-turn). + */ +export const AgentFlowInputSchema = z.object({ + /** User's input messages for this turn. */ + messages: z.array(MessageSchema).optional(), + /** Tool request parts to re-execute interrupted tools. */ + toolRestarts: z.array(PartSchema).optional(), +}); +export type AgentFlowInput = z.infer; + +/** + * Zod schema for agent flow initialization. + */ +export const AgentFlowInitSchema = z.object({ + /** Loads state from a persisted snapshot. Mutually exclusive with state. */ + snapshotId: z.string().optional(), + /** Direct state for the invocation. Mutually exclusive with snapshotId. */ + state: SessionStateSchema.optional(), +}); +export type AgentFlowInit = z.infer; + +/** + * Zod schema for agent flow result. + */ +export const AgentFlowResultSchema = z.object({ + /** Last model response message from the conversation. */ + message: MessageSchema.optional(), + /** Artifacts produced during the session. */ + artifacts: z.array(ArtifactSchema).optional(), +}); +export type AgentFlowResult = z.infer; + +/** + * Zod schema for agent flow output. + */ +export const AgentFlowOutputSchema = z.object({ + /** ID of the snapshot created at the end of this invocation. */ + snapshotId: z.string().optional(), + /** Final conversation state (only when client-managed). */ + state: SessionStateSchema.optional(), + /** Last model response message from the conversation. */ + message: MessageSchema.optional(), + /** Artifacts produced during the session. */ + artifacts: z.array(ArtifactSchema).optional(), +}); +export type AgentFlowOutput = z.infer; + +/** + * Zod schema for agent flow stream chunk. + */ +export const AgentFlowStreamChunkSchema = z.object({ + /** Generation tokens from the model. */ + modelChunk: ModelResponseChunkSchema.optional(), + /** User-defined structured status information. */ + status: z.any().optional(), + /** A newly produced artifact. */ + artifact: ArtifactSchema.optional(), + /** ID of a snapshot that was just persisted. */ + snapshotId: z.string().optional(), + /** Signals that the agent flow has finished processing the current input. */ + endTurn: z.boolean().optional(), +}); +export type AgentFlowStreamChunk = z.infer; diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 26cc4fbf4f..bae17a88ae 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -1,6 +1,140 @@ { "$schema": "http://json-schema.org/draft-07/schema#", "$defs": { + "AgentFlowInit": { + "type": "object", + "properties": { + "snapshotId": { + "type": "string" + }, + "state": { + "$ref": "#/$defs/SessionState" + } + }, + "additionalProperties": false + }, + "AgentFlowInput": { + "type": "object", + "properties": { + "messages": { + "type": "array", + "items": { + "$ref": "#/$defs/Message" + } + }, + "toolRestarts": { + "type": "array", + "items": { + "$ref": "#/$defs/Part" + } + } + }, + "additionalProperties": false + }, + "AgentFlowOutput": { + "type": "object", + "properties": { + "snapshotId": { + "type": "string" + }, + "state": { + "$ref": "#/$defs/SessionState" + }, + "message": { + "$ref": "#/$defs/Message" + }, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/$defs/Artifact" + } + } + }, + "additionalProperties": false + }, + "AgentFlowResult": { + "type": "object", + "properties": { + "message": { + "$ref": "#/$defs/Message" + }, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/$defs/Artifact" + } + } + }, + "additionalProperties": false + }, + "AgentFlowStreamChunk": { + "type": "object", + "properties": { + "modelChunk": { + "$ref": "#/$defs/ModelResponseChunk" + }, + "status": {}, + "artifact": { + "$ref": "#/$defs/Artifact" + }, + "snapshotId": { + "type": "string" + }, + "endTurn": { + "type": "boolean" + } + }, + "additionalProperties": false + }, + "Artifact": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "parts": { + "type": "array", + "items": { + "$ref": "#/$defs/Part" + } + }, + "metadata": { + "type": "object", + "additionalProperties": {} + } + }, + "required": [ + "parts" + ], + "additionalProperties": false + }, + "SessionState": { + "type": "object", + "properties": { + "messages": { + "type": "array", + "items": { + "$ref": "#/$defs/Message" + } + }, + "custom": {}, + "artifacts": { + "type": "array", + "items": { + "$ref": "#/$defs/Artifact" + } + }, + "inputVariables": {} + }, + "additionalProperties": false + }, + "SnapshotEvent": { + "type": "string", + "enum": [ + "turnEnd", + "invocationEnd" + ] + }, "DocumentData": { "type": "object", "properties": { diff --git a/genkit-tools/scripts/schema-exporter.ts b/genkit-tools/scripts/schema-exporter.ts index 48df79b56a..e12e44fc8e 100644 --- a/genkit-tools/scripts/schema-exporter.ts +++ b/genkit-tools/scripts/schema-exporter.ts @@ -22,6 +22,7 @@ import { zodToJsonSchema } from 'zod-to-json-schema'; /** List of files that contain types to be exported. */ const EXPORTED_TYPE_MODULES = [ + '../common/src/types/agent.ts', '../common/src/types/document.ts', '../common/src/types/embedder.ts', '../common/src/types/evaluator.ts', diff --git a/go/ai/exp/agent.go b/go/ai/exp/agent.go new file mode 100644 index 0000000000..c25d46e5aa --- /dev/null +++ b/go/ai/exp/agent.go @@ -0,0 +1,663 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package exp provides experimental AI primitives for Genkit. +// +// APIs in this package are under active development and may change in any +// minor version release. +package exp + +import ( + "context" + "fmt" + "iter" + "sync" + "time" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/core/logger" + "github.com/firebase/genkit/go/core/tracing" + "github.com/firebase/genkit/go/internal/base" + "github.com/google/uuid" +) + +// --- AgentSession --- + +// AgentSession extends Session with agent-flow-specific functionality: +// turn management, snapshot persistence, and input channel handling. +type AgentSession[State any] struct { + *Session[State] + + // InputCh is the channel that delivers per-turn inputs from the client. + // It is consumed automatically by [AgentSession.Run], but is exposed + // for advanced use cases that need direct access to the input stream + // (e.g., custom turn loops or fan-out patterns). + InputCh <-chan *AgentFlowInput + // TurnIndex is the zero-based index of the current conversation turn. + // It is incremented automatically by [AgentSession.Run], but is exposed + // for advanced use cases that need to track or manipulate turn ordering + // directly. + TurnIndex int + + snapshotCallback SnapshotCallback[State] + onEndTurn func(ctx context.Context) + lastSnapshot *SessionSnapshot[State] + lastSnapshotVersion uint64 + collectTurnOutput func() any +} + +// Run loops over the input channel, calling fn for each turn. Each turn is +// wrapped in a trace span for observability. Input messages are automatically +// added to the session before fn is called. After fn returns successfully, an +// EndTurn chunk is sent and a snapshot check is triggered. +func (a *AgentSession[State]) Run(ctx context.Context, fn func(ctx context.Context, input *AgentFlowInput) error) error { + for input := range a.InputCh { + spanMeta := &tracing.SpanMetadata{ + Name: fmt.Sprintf("agentFlow/turn/%d", a.TurnIndex), + Type: "agentFlowTurn", + Subtype: "agentFlowTurn", + } + + _, err := tracing.RunInNewSpan(ctx, spanMeta, input, + func(ctx context.Context, input *AgentFlowInput) (any, error) { + a.AddMessages(input.Messages...) + + if err := fn(ctx, input); err != nil { + return nil, err + } + + a.onEndTurn(ctx) + a.TurnIndex++ + + if a.collectTurnOutput != nil { + return a.collectTurnOutput(), nil + } + return nil, nil + }, + ) + if err != nil { + return err + } + } + return nil +} + +// Result returns an [AgentFlowResult] populated from the current session state: +// the last message in the conversation history and all artifacts. +// It is a convenience for custom agent flows that don't need to construct the +// result manually. +func (a *AgentSession[State]) Result() *AgentFlowResult { + a.mu.RLock() + defer a.mu.RUnlock() + + result := &AgentFlowResult{} + if msgs := a.state.Messages; len(msgs) > 0 { + result.Message = msgs[len(msgs)-1] + } + if len(a.state.Artifacts) > 0 { + arts := make([]*Artifact, len(a.state.Artifacts)) + copy(arts, a.state.Artifacts) + result.Artifacts = arts + } + return result +} + +// maybeSnapshot creates a snapshot if conditions are met (store configured, +// callback approves, state changed). Returns the snapshot ID or empty string. +func (a *AgentSession[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { + if a.store == nil { + return "" + } + + a.mu.RLock() + currentVersion := a.version + currentState := a.copyStateLocked() + a.mu.RUnlock() + + // Skip if state hasn't changed since the last snapshot. This avoids + // redundant snapshots, e.g. the invocation-end snapshot after a + // single-turn Run where the turn-end snapshot already captured the + // same state. + if a.lastSnapshot != nil && currentVersion == a.lastSnapshotVersion { + return "" + } + + if a.snapshotCallback != nil { + var prevState *SessionState[State] + if a.lastSnapshot != nil { + prevState = &a.lastSnapshot.State + } + if !a.snapshotCallback(ctx, &SnapshotContext[State]{ + State: ¤tState, + PrevState: prevState, + TurnIndex: a.TurnIndex, + Event: event, + }) { + return "" + } + } + + snapshot := &SessionSnapshot[State]{ + SnapshotID: uuid.New().String(), + CreatedAt: time.Now(), + Event: event, + State: currentState, + } + if a.lastSnapshot != nil { + snapshot.ParentID = a.lastSnapshot.SnapshotID + } + + if err := a.store.SaveSnapshot(ctx, snapshot); err != nil { + logger.FromContext(ctx).Error("agent flow: failed to save snapshot", "err", err) + return "" + } + + // Set snapshotId in last message metadata. + a.mu.Lock() + if msgs := a.state.Messages; len(msgs) > 0 { + lastMsg := msgs[len(msgs)-1] + if lastMsg.Metadata == nil { + lastMsg.Metadata = make(map[string]any) + } + lastMsg.Metadata["snapshotId"] = snapshot.SnapshotID + } + a.mu.Unlock() + + a.lastSnapshot = snapshot + a.lastSnapshotVersion = currentVersion + + return snapshot.SnapshotID +} + +// --- Responder --- + +// Responder is the output channel for an agent flow. Artifacts sent through +// it are automatically added to the session before being forwarded to the +// client. +type Responder[Stream any] chan<- *AgentFlowStreamChunk[Stream] + +// SendModelChunk sends a generation chunk (token-level streaming). +func (r Responder[Stream]) SendModelChunk(chunk *ai.ModelResponseChunk) { + r <- &AgentFlowStreamChunk[Stream]{ModelChunk: chunk} +} + +// SendStatus sends a user-defined status update. +func (r Responder[Stream]) SendStatus(status Stream) { + r <- &AgentFlowStreamChunk[Stream]{Status: status} +} + +// SendArtifact sends an artifact to the stream and adds it to the session. +// If an artifact with the same name already exists in the session, it is replaced. +func (r Responder[Stream]) SendArtifact(artifact *Artifact) { + r <- &AgentFlowStreamChunk[Stream]{Artifact: artifact} +} + +// --- AgentFlow --- + +// AgentFlowFunc is the function signature for agent flows. +// Type parameters: +// - Stream: Type for status updates sent via the responder +// - State: Type for user-defined state in snapshots +type AgentFlowFunc[Stream, State any] = func(ctx context.Context, resp Responder[Stream], sess *AgentSession[State]) (*AgentFlowResult, error) + +// AgentFlow is a bidirectional streaming flow with automatic snapshot management. +type AgentFlow[Stream, State any] struct { + flow *core.Flow[*AgentFlowInput, *AgentFlowOutput[State], *AgentFlowStreamChunk[Stream], *AgentFlowInit[State]] +} + +// DefineCustomAgent creates an AgentFlow with automatic snapshot management and registers it. +func DefineCustomAgent[Stream, State any]( + r api.Registry, + name string, + fn AgentFlowFunc[Stream, State], + opts ...AgentFlowOption[State], +) *AgentFlow[Stream, State] { + afOpts := &agentFlowOptions[State]{} + for _, opt := range opts { + if err := opt.applyAgentFlow(afOpts); err != nil { + panic(fmt.Errorf("DefineCustomAgent %q: %w", name, err)) + } + } + + store := afOpts.store + snapshotCallback := afOpts.callback + + flow := core.DefineBidiFlow(r, name, func( + ctx context.Context, + init *AgentFlowInit[State], + inCh <-chan *AgentFlowInput, + outCh chan<- *AgentFlowStreamChunk[Stream], + ) (*AgentFlowOutput[State], error) { + session, snapshot, err := newSessionFromInit(ctx, init, store) + if err != nil { + return nil, err + } + ctx = NewSessionContext(ctx, session) + + agentSess := &AgentSession[State]{ + Session: session, + snapshotCallback: snapshotCallback, + InputCh: inCh, + lastSnapshot: snapshot, + } + + // Turn output accumulator: collects content chunks per turn for span output. + var ( + turnMu sync.Mutex + turnChunks []*AgentFlowStreamChunk[Stream] + ) + + agentSess.collectTurnOutput = func() any { + turnMu.Lock() + defer turnMu.Unlock() + result := turnChunks + turnChunks = nil + return result + } + + // Intermediary channel: intercepts artifacts, accumulates turn output, + // and forwards to outCh. + respCh := make(chan *AgentFlowStreamChunk[Stream]) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for chunk := range respCh { + if chunk.Artifact != nil { + session.AddArtifacts(chunk.Artifact) + } + // Accumulate content chunks (exclude control signals from onEndTurn). + if !chunk.EndTurn && chunk.SnapshotID == "" { + turnMu.Lock() + turnChunks = append(turnChunks, chunk) + turnMu.Unlock() + } + outCh <- chunk + } + }() + + // Wire up onEndTurn: triggers snapshot + sends EndTurn chunk. + // Writes through respCh to preserve ordering with user chunks. + agentSess.onEndTurn = func(turnCtx context.Context) { + snapshotID := agentSess.maybeSnapshot(turnCtx, SnapshotEventTurnEnd) + if snapshotID != "" { + respCh <- &AgentFlowStreamChunk[Stream]{SnapshotID: snapshotID} + } + respCh <- &AgentFlowStreamChunk[Stream]{EndTurn: true} + } + + result, fnErr := fn(ctx, Responder[Stream](respCh), agentSess) + close(respCh) + wg.Wait() + + if fnErr != nil { + return nil, fnErr + } + + // Final snapshot at invocation end. If skipped (state unchanged + // since last turn-end snapshot), use the last snapshot's ID so + // the output always reflects the latest snapshot. + snapshotID := agentSess.maybeSnapshot(ctx, SnapshotEventInvocationEnd) + if snapshotID == "" && agentSess.lastSnapshot != nil { + snapshotID = agentSess.lastSnapshot.SnapshotID + } + + out := &AgentFlowOutput[State]{ + SnapshotID: snapshotID, + } + if result != nil { + out.Message = result.Message + out.Artifacts = result.Artifacts + } + + // Only include full state when client-managed (no store). + if store == nil { + out.State = session.State() + } + + return out, nil + }) + + return &AgentFlow[Stream, State]{flow: flow} +} + +// promptMessageKey is the metadata key used to tag prompt-rendered messages +// so they can be excluded from session history after generation. +const promptMessageKey = "_genkit_prompt" + +// DefinePromptAgent creates a prompt-backed AgentFlow with an +// automatic conversation loop. Each turn renders the prompt, appends +// conversation history, calls GenerateWithRequest, streams chunks to the +// client, and adds the model response to the session. +// +// The prompt is looked up by name from the registry using +// [ai.LookupDataPrompt]. The defaultInput is used for prompt rendering +// unless overridden per invocation via WithInputVariables. +func DefinePromptAgent[State, PromptIn any]( + r api.Registry, + promptName string, + defaultInput PromptIn, + opts ...AgentFlowOption[State], +) *AgentFlow[any, State] { + p := ai.LookupDataPrompt[PromptIn, string](r, promptName) + if p == nil { + panic(fmt.Sprintf("DefinePromptAgent: prompt %q not found", promptName)) + } + + fn := func(ctx context.Context, resp Responder[any], sess *AgentSession[State]) (*AgentFlowResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + // Resolve prompt input: session state override > default. + promptInput := defaultInput + if stored := sess.InputVariables(); stored != nil { + typed, ok := base.ConvertTo[PromptIn](stored) + if !ok { + return core.NewError(core.INVALID_ARGUMENT, "input variables type mismatch: got %T, want %T", stored, promptInput) + } + promptInput = typed + } + + // Render the prompt template. + genOpts, err := p.Render(ctx, promptInput) + if err != nil { + return fmt.Errorf("prompt render: %w", err) + } + + // Tag prompt-rendered messages so we can exclude them from + // session history after generation. + for _, m := range genOpts.Messages { + if m.Metadata == nil { + m.Metadata = make(map[string]any) + } + m.Metadata[promptMessageKey] = true + } + + // Append conversation history after the prompt-rendered messages. + genOpts.Messages = append(genOpts.Messages, sess.Messages()...) + + // If tool restarts were provided, set the resume option so + // handleResumeOption re-executes the interrupted tools. + if len(input.ToolRestarts) > 0 { + for _, p := range input.ToolRestarts { + if !p.IsToolRequest() { + return core.NewError(core.INVALID_ARGUMENT, "ToolRestarts: part is not a tool request") + } + } + genOpts.Resume = ai.NewResume(input.ToolRestarts, nil) + } + + // Call the model with streaming. + modelResp, err := ai.GenerateWithRequest(ctx, r, genOpts, nil, + func(ctx context.Context, chunk *ai.ModelResponseChunk) error { + resp.SendModelChunk(chunk) + return nil + }, + ) + if err != nil { + return fmt.Errorf("generate: %w", err) + } + + // Replace session messages with the full history minus prompt + // messages. This captures intermediate tool call/response + // messages from the tool loop, not just the final response. + if modelResp.Request != nil { + var msgs []*ai.Message + for _, m := range modelResp.History() { + if m.Metadata != nil && m.Metadata[promptMessageKey] == true { + continue + } + msgs = append(msgs, m) + } + sess.SetMessages(msgs) + } else if modelResp.Message != nil { + sess.AddMessages(modelResp.Message) + } + + // Stream interrupt parts so the client can detect and + // handle them (e.g. prompt the user for confirmation). + if modelResp.FinishReason == ai.FinishReasonInterrupted { + if parts := modelResp.Interrupts(); len(parts) > 0 { + resp.SendModelChunk(&ai.ModelResponseChunk{ + Role: ai.RoleTool, + Content: parts, + }) + } + } + + return nil + }); err != nil { + return nil, err + } + return sess.Result(), nil + } + + return DefineCustomAgent(r, promptName, fn, opts...) +} + +// StreamBidi starts a new agent flow invocation with bidirectional streaming. +// Use this for multi-turn interactions where you need to send multiple inputs +// and receive streaming chunks. For single-turn usage, see Run and RunText. +func (af *AgentFlow[Stream, State]) StreamBidi( + ctx context.Context, + opts ...InvocationOption[State], +) (*AgentFlowConnection[Stream, State], error) { + invOpts, err := af.resolveOptions(opts) + if err != nil { + return nil, err + } + + conn, err := af.flow.StreamBidi(ctx, invOpts) + if err != nil { + return nil, err + } + + return &AgentFlowConnection[Stream, State]{conn: conn}, nil +} + +// Run starts a single-turn agent flow invocation with the given input. +// It sends the input, waits for the flow to complete, and returns the output. +// For multi-turn interactions or streaming, use StreamBidi instead. +func (af *AgentFlow[Stream, State]) Run( + ctx context.Context, + input *AgentFlowInput, + opts ...InvocationOption[State], +) (*AgentFlowOutput[State], error) { + conn, err := af.StreamBidi(ctx, opts...) + if err != nil { + return nil, err + } + + if err := conn.Send(input); err != nil { + return nil, err + } + if err := conn.Close(); err != nil { + return nil, err + } + + // Drain stream chunks. + for _, err := range conn.Receive() { + if err != nil { + return nil, err + } + } + + return conn.Output() +} + +// RunText is a convenience method that starts a single-turn agent flow +// invocation with a user text message. It is equivalent to calling Run with +// an AgentFlowInput containing a single user text message. +func (af *AgentFlow[Stream, State]) RunText( + ctx context.Context, + text string, + opts ...InvocationOption[State], +) (*AgentFlowOutput[State], error) { + return af.Run(ctx, &AgentFlowInput{ + Messages: []*ai.Message{ai.NewUserTextMessage(text)}, + }, opts...) +} + +// resolveOptions applies invocation options and returns the init struct. +func (af *AgentFlow[Stream, State]) resolveOptions(opts []InvocationOption[State]) (*AgentFlowInit[State], error) { + invOpts := &invocationOptions[State]{} + for _, opt := range opts { + if err := opt.applyInvocation(invOpts); err != nil { + return nil, fmt.Errorf("AgentFlow %q: %w", af.flow.Name(), err) + } + } + + init := &AgentFlowInit[State]{ + SnapshotID: invOpts.snapshotID, + State: invOpts.state, + } + if invOpts.promptInput != nil { + if init.State == nil { + init.State = &SessionState[State]{} + } + init.State.InputVariables = invOpts.promptInput + } + + return init, nil +} + +// newSessionFromInit creates a Session from initialization data. +// If resuming from a snapshot, the loaded snapshot is also returned. +func newSessionFromInit[State any]( + ctx context.Context, + init *AgentFlowInit[State], + store SessionStore[State], +) (*Session[State], *SessionSnapshot[State], error) { + s := &Session[State]{store: store} + + var snapshot *SessionSnapshot[State] + if init != nil { + if init.SnapshotID != "" && init.State != nil { + return nil, nil, core.NewError(core.INVALID_ARGUMENT, "snapshot ID and state are mutually exclusive") + } + if init.SnapshotID != "" && store == nil { + return nil, nil, core.NewError(core.FAILED_PRECONDITION, "snapshot ID %q provided but no session store configured", init.SnapshotID) + } + if init.SnapshotID != "" && store != nil { + var err error + snapshot, err = store.GetSnapshot(ctx, init.SnapshotID) + if err != nil { + return nil, nil, core.NewError(core.INTERNAL, "failed to load snapshot %q: %v", init.SnapshotID, err) + } + if snapshot == nil { + return nil, nil, core.NewError(core.NOT_FOUND, "snapshot %q not found", init.SnapshotID) + } + s.state = snapshot.State + } else if init.State != nil { + s.state = *init.State + } + } + + return s, snapshot, nil +} + +// --- AgentFlowConnection --- + +// AgentFlowConnection wraps BidiConnection with agent flow-specific functionality. +// It provides a Receive() iterator that supports multi-turn patterns: breaking out +// of the iterator between turns does not cancel the underlying connection. +type AgentFlowConnection[Stream, State any] struct { + conn *core.BidiConnection[*AgentFlowInput, *AgentFlowOutput[State], *AgentFlowStreamChunk[Stream]] + // chunks buffers stream chunks from the underlying connection so that + // breaking from Receive() between turns doesn't cancel the context. + chunks chan *AgentFlowStreamChunk[Stream] + chunkErr error + initOnce sync.Once +} + +// initReceiver starts a goroutine that drains the underlying BidiConnection's +// Receive into a channel. This goroutine never breaks from the underlying +// iterator, preventing context cancellation. +func (c *AgentFlowConnection[Stream, State]) initReceiver() { + c.initOnce.Do(func() { + c.chunks = make(chan *AgentFlowStreamChunk[Stream], 1) + go func() { + defer close(c.chunks) + for chunk, err := range c.conn.Receive() { + if err != nil { + c.chunkErr = err + return + } + c.chunks <- chunk + } + }() + }) +} + +// Send sends an AgentFlowInput to the agent flow. +func (c *AgentFlowConnection[Stream, State]) Send(input *AgentFlowInput) error { + return c.conn.Send(input) +} + +// SendMessages sends messages to the agent flow. +func (c *AgentFlowConnection[Stream, State]) SendMessages(messages ...*ai.Message) error { + return c.conn.Send(&AgentFlowInput{Messages: messages}) +} + +// SendText sends a single user text message to the agent flow. +func (c *AgentFlowConnection[Stream, State]) SendText(text string) error { + return c.conn.Send(&AgentFlowInput{ + Messages: []*ai.Message{ai.NewUserTextMessage(text)}, + }) +} + +// SendToolRestarts sends tool restart parts to resume interrupted tool calls. +// Parts should be created via [ai.ToolDef.RestartWith]. +func (c *AgentFlowConnection[Stream, State]) SendToolRestarts(parts ...*ai.Part) error { + return c.conn.Send(&AgentFlowInput{ToolRestarts: parts}) +} + +// Close signals that no more inputs will be sent. +func (c *AgentFlowConnection[Stream, State]) Close() error { + return c.conn.Close() +} + +// Receive returns an iterator for receiving stream chunks. +// Unlike the underlying BidiConnection.Receive, breaking out of this iterator +// does not cancel the connection. This enables multi-turn patterns where the +// caller breaks on EndTurn, sends the next input, then calls Receive again. +func (c *AgentFlowConnection[Stream, State]) Receive() iter.Seq2[*AgentFlowStreamChunk[Stream], error] { + c.initReceiver() + return func(yield func(*AgentFlowStreamChunk[Stream], error) bool) { + for { + chunk, ok := <-c.chunks + if !ok { + if err := c.chunkErr; err != nil { + yield(nil, err) + } + return + } + if !yield(chunk, nil) { + return + } + } + } +} + +// Output returns the final response after the agent flow completes. +func (c *AgentFlowConnection[Stream, State]) Output() (*AgentFlowOutput[State], error) { + return c.conn.Output() +} + +// Done returns a channel closed when the connection completes. +func (c *AgentFlowConnection[Stream, State]) Done() <-chan struct{} { + return c.conn.Done() +} diff --git a/go/ai/exp/agent_test.go b/go/ai/exp/agent_test.go new file mode 100644 index 0000000000..d2d75c232f --- /dev/null +++ b/go/ai/exp/agent_test.go @@ -0,0 +1,1668 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/internal/registry" +) + +type testState struct { + Counter int `json:"counter"` + Topics []string `json:"topics,omitempty"` +} + +type testStatus struct { + Phase string `json:"phase"` +} + +func newTestRegistry(t *testing.T) *registry.Registry { + t.Helper() + return registry.New() +} + +func TestAgentFlow_BasicMultiTurn(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "basicFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + resp.SendStatus(testStatus{Phase: "generating"}) + // Echo back the user's message. + if len(input.Messages) > 0 { + reply := ai.NewModelTextMessage("echo: " + input.Messages[0].Content[0].Text) + sess.AddMessages(reply) + } + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + resp.SendStatus(testStatus{Phase: "complete"}) + return nil + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Turn 1. + if err := conn.SendText("hello"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + var turn1Chunks int + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + turn1Chunks++ + if chunk.EndTurn { + break + } + } + if turn1Chunks < 2 { // at least status + endTurn + t.Errorf("expected at least 2 chunks in turn 1, got %d", turn1Chunks) + } + + // Turn 2. + if err := conn.SendText("world"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // 2 user messages + 2 echo replies = 4. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages, got %d", got) + } + if got := response.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2, got %d", got) + } +} + +func TestAgentFlow_WithSessionStore(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineCustomAgent(reg, "snapshotFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + if len(input.Messages) > 0 { + sess.AddMessages(ai.NewModelTextMessage("reply")) + } + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("turn1") + + var snapshotIDs []string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.SnapshotID != "" { + snapshotIDs = append(snapshotIDs, chunk.SnapshotID) + } + if chunk.EndTurn { + break + } + } + + if len(snapshotIDs) != 1 { + t.Fatalf("expected 1 snapshot from turn, got %d", len(snapshotIDs)) + } + + // Verify the snapshot was persisted. + snap, err := store.GetSnapshot(ctx, snapshotIDs[0]) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap == nil { + t.Fatal("expected snapshot, got nil") + } + if snap.State.Custom.Counter != 1 { + t.Errorf("expected counter=1 in snapshot, got %d", snap.State.Custom.Counter) + } + + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Final snapshot at invocation end. + if response.SnapshotID == "" { + t.Error("expected final snapshot ID") + } +} + +func TestAgentFlow_ResumeFromSnapshot(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineCustomAgent(reg, "resumeFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + if len(input.Messages) > 0 { + sess.AddMessages(ai.NewModelTextMessage("reply")) + } + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSessionStore(store), + ) + + // First invocation: create a snapshot. + conn1, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + conn1.SendText("first message") + for chunk, err := range conn1.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn1.Close() + resp1, err := conn1.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + if resp1.SnapshotID == "" { + t.Fatal("expected snapshot ID from first invocation") + } + + // Second invocation: resume from snapshot. + conn2, err := af.StreamBidi(ctx, WithSnapshotID[testState](resp1.SnapshotID)) + if err != nil { + t.Fatalf("StreamBidi with snapshot failed: %v", err) + } + conn2.SendText("continued message") + for chunk, err := range conn2.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn2.Close() + resp2, err := conn2.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // The new snapshot should reference the previous as parent. + if resp2.SnapshotID == "" { + t.Fatal("expected snapshot ID from second invocation") + } + snap2, err := store.GetSnapshot(ctx, resp2.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + + // Should have messages from both invocations: + // first: user + reply (2) + second: user + reply (2) = 4. + if got := len(snap2.State.Messages); got != 4 { + t.Errorf("expected 4 messages after resume, got %d", got) + } + // Counter should be 2 (1 from first + 1 from second). + if got := snap2.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2, got %d", got) + } + // The parent chain: snap2's parent is a turn-end snapshot from the second invocation, + // which itself has a parent from the first invocation's final snapshot. + // We just verify that the parent chain exists (not empty). + if snap2.ParentID == "" { + t.Error("expected parent ID on resumed snapshot") + } +} + +func TestAgentFlow_ClientManagedState(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "clientStateFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + if len(input.Messages) > 0 { + sess.AddMessages(ai.NewModelTextMessage("reply")) + } + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + ) + + // Start with client-provided state. + clientState := &SessionState[testState]{ + Messages: []*ai.Message{ + ai.NewUserTextMessage("previous message"), + ai.NewModelTextMessage("previous reply"), + }, + Custom: testState{Counter: 5}, + } + + conn, err := af.StreamBidi(ctx, WithState(clientState)) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("new message") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // 2 previous + 1 new user + 1 reply = 4. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages, got %d", got) + } + // Counter should be 6 (started at 5, incremented once). + if got := response.State.Custom.Counter; got != 6 { + t.Errorf("expected counter=6, got %d", got) + } + // No snapshot since no store was configured. + if response.SnapshotID != "" { + t.Errorf("expected no snapshot ID without store, got %q", response.SnapshotID) + } +} + +func TestAgentFlow_Artifacts(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "artifactFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + + resp.SendArtifact(&Artifact{ + Name: "code.go", + Parts: []*ai.Part{ai.NewTextPart("package main")}, + }) + + // Replace artifact with same name. + resp.SendArtifact(&Artifact{ + Name: "code.go", + Parts: []*ai.Part{ai.NewTextPart("package main\nfunc main() {}")}, + }) + + // Add another artifact. + resp.SendArtifact(&Artifact{ + Name: "readme.md", + Parts: []*ai.Part{ai.NewTextPart("# README")}, + }) + + sess.AddMessages(ai.NewModelTextMessage("done")) + return nil + }) + if err != nil { + return nil, err + } + return &AgentFlowResult{Artifacts: sess.Artifacts()}, nil + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("generate code") + var receivedArtifacts []*Artifact + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.Artifact != nil { + receivedArtifacts = append(receivedArtifacts, chunk.Artifact) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + if len(receivedArtifacts) != 3 { // all 3 sends are streamed + t.Errorf("expected 3 streamed artifacts, got %d", len(receivedArtifacts)) + } + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Output should have 2 unique artifacts (code.go was replaced). + if got := len(response.Artifacts); got != 2 { + t.Errorf("expected 2 artifacts, got %d", got) + } +} + +func TestAgentFlow_SnapshotCallback(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + // Only snapshot on even turns. + callbackCalls := 0 + af := DefineCustomAgent(reg, "callbackFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSessionStore(store), + WithSnapshotCallback(func(ctx context.Context, sc *SnapshotContext[testState]) bool { + callbackCalls++ + return sc.TurnIndex%2 == 0 // only snapshot on even turns + }), + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + var snapshotIDs []string + for i := 0; i < 3; i++ { + conn.SendText(fmt.Sprintf("turn %d", i)) + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error on turn %d: %v", i, err) + } + if chunk.SnapshotID != "" { + snapshotIDs = append(snapshotIDs, chunk.SnapshotID) + } + if chunk.EndTurn { + break + } + } + } + conn.Close() + conn.Output() // drain + + // Turn 0 (even) → snapshot, Turn 1 (odd) → no, Turn 2 (even) → snapshot. + // That's 2 turn snapshots from the callback. + if got := len(snapshotIDs); got != 2 { + t.Errorf("expected 2 turn snapshots, got %d", got) + } + // Callback should have been called 3 times (once per turn). + if callbackCalls < 3 { + t.Errorf("expected at least 3 callback calls, got %d", callbackCalls) + } +} + +func TestAgentFlow_SendMessages(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "sendMsgsFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + return nil + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Send multiple messages at once. + err = conn.SendMessages( + ai.NewUserTextMessage("msg1"), + ai.NewUserTextMessage("msg2"), + ) + if err != nil { + t.Fatalf("SendMessages failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Both messages should have been added. + if got := len(response.State.Messages); got != 2 { + t.Errorf("expected 2 messages, got %d", got) + } +} + +func TestAgentFlow_SessionContext(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + var retrievedCounter int + af := DefineCustomAgent(reg, "contextFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + // Session should be retrievable from context. + ctxSess := SessionFromContext[testState](ctx) + if ctxSess == nil { + t.Error("expected session from context") + return nil + } + ctxSess.UpdateCustom(func(s testState) testState { + s.Counter = 42 + return s + }) + retrievedCounter = ctxSess.Custom().Counter + return nil + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("test") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + conn.Output() + + if retrievedCounter != 42 { + t.Errorf("expected counter=42 from context, got %d", retrievedCounter) + } +} + +func TestAgentFlow_ErrorInTurn(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "errorFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + return fmt.Errorf("turn failed") + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("trigger error") + conn.Close() + + _, err = conn.Output() + if err == nil { + t.Fatal("expected error from failed turn") + } +} + +func TestAgentFlow_SetMessages(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "setMsgsFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + // Replace all messages with just one. + sess.SetMessages([]*ai.Message{ai.NewModelTextMessage("replaced")}) + return nil + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("original") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // SetMessages replaced everything with 1 message. + if got := len(response.State.Messages); got != 1 { + t.Errorf("expected 1 message after SetMessages, got %d", got) + } +} + +func TestAgentFlow_SnapshotIDInMessageMetadata(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineCustomAgent(reg, "metadataFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("reply")) + return nil + }) + if err != nil { + return nil, err + } + msgs := sess.Messages() + return &AgentFlowResult{Message: msgs[len(msgs)-1]}, nil + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("hello") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // The last model message should have snapshotId in its metadata. + if response.Message == nil { + t.Fatal("expected Message in response") + } + if response.Message.Metadata == nil { + t.Fatal("expected metadata on last message") + } + if _, ok := response.Message.Metadata["snapshotId"]; !ok { + t.Error("expected snapshotId in last message metadata") + } +} + +func TestInMemorySessionStore(t *testing.T) { + ctx := context.Background() + store := NewInMemorySessionStore[testState]() + + // Get non-existent. + snap, err := store.GetSnapshot(ctx, "nonexistent") + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap != nil { + t.Errorf("expected nil, got %v", snap) + } + + // Save and retrieve. + snapshot := &SessionSnapshot[testState]{ + SnapshotID: "snap-1", + State: SessionState[testState]{ + Custom: testState{Counter: 1}, + }, + } + if err := store.SaveSnapshot(ctx, snapshot); err != nil { + t.Fatalf("SaveSnapshot failed: %v", err) + } + + retrieved, err := store.GetSnapshot(ctx, "snap-1") + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if retrieved == nil { + t.Fatal("expected snapshot") + } + if retrieved.State.Custom.Counter != 1 { + t.Errorf("expected counter=1, got %d", retrieved.State.Custom.Counter) + } + + // Verify isolation. + snapshot.State.Custom.Counter = 999 + retrieved2, _ := store.GetSnapshot(ctx, "snap-1") + if retrieved2.State.Custom.Counter != 1 { + t.Errorf("expected counter=1 (isolation), got %d", retrieved2.State.Custom.Counter) + } +} + +func TestAgentFlow_TurnSpanOutput(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + var capturedOutputs []any + + af := DefineCustomAgent(reg, "turnOutputFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + // Wrap collectTurnOutput to capture what each turn produces. + originalCollect := sess.collectTurnOutput + sess.collectTurnOutput = func() any { + output := originalCollect() + capturedOutputs = append(capturedOutputs, output) + return output + } + + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + resp.SendStatus(testStatus{Phase: "thinking"}) + resp.SendModelChunk(&ai.ModelResponseChunk{ + Content: []*ai.Part{ai.NewTextPart("reply")}, + }) + resp.SendArtifact(&Artifact{ + Name: "out.txt", + Parts: []*ai.Part{ai.NewTextPart("content")}, + }) + sess.AddMessages(ai.NewModelTextMessage("reply")) + return nil + }) + }, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Two turns. + for turn := range 2 { + if err := conn.SendText(fmt.Sprintf("turn %d", turn)); err != nil { + t.Fatalf("SendText failed on turn %d: %v", turn, err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error on turn %d: %v", turn, err) + } + if chunk.EndTurn { + break + } + } + } + + conn.Close() + if _, err := conn.Output(); err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Should have captured output for each turn. + if len(capturedOutputs) != 2 { + t.Fatalf("expected 2 captured outputs, got %d", len(capturedOutputs)) + } + + for i, output := range capturedOutputs { + chunks, ok := output.([]*AgentFlowStreamChunk[testStatus]) + if !ok { + t.Fatalf("turn %d: expected []*AgentFlowStreamChunk[testStatus], got %T", i, output) + } + // 3 content chunks per turn: status + model chunk + artifact. + if len(chunks) != 3 { + t.Errorf("turn %d: expected 3 chunks, got %d", i, len(chunks)) + } + for j, chunk := range chunks { + if chunk.EndTurn { + t.Errorf("turn %d, chunk %d: EndTurn should not be in turn output", i, j) + } + if chunk.SnapshotID != "" { + t.Errorf("turn %d, chunk %d: SnapshotID should not be in turn output", i, j) + } + } + } +} + +func TestAgentFlow_TurnSpanOutput_WithSnapshots(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + var capturedOutputs []any + + af := DefineCustomAgent(reg, "turnOutputSnapshotFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + originalCollect := sess.collectTurnOutput + sess.collectTurnOutput = func() any { + output := originalCollect() + capturedOutputs = append(capturedOutputs, output) + return output + } + + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + resp.SendStatus(testStatus{Phase: "working"}) + sess.AddMessages(ai.NewModelTextMessage("reply")) + return nil + }) + }, + WithSessionStore(store), + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("hello") + var sawSnapshot bool + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.SnapshotID != "" { + sawSnapshot = true + } + if chunk.EndTurn { + break + } + } + conn.Close() + conn.Output() + + if !sawSnapshot { + t.Fatal("expected a snapshot chunk on the stream") + } + + // Turn output should contain only the status chunk, not the snapshot/endTurn. + if len(capturedOutputs) != 1 { + t.Fatalf("expected 1 captured output, got %d", len(capturedOutputs)) + } + chunks := capturedOutputs[0].([]*AgentFlowStreamChunk[testStatus]) + if len(chunks) != 1 { + t.Errorf("expected 1 content chunk, got %d", len(chunks)) + } + if chunks[0].Status.Phase != "working" { + t.Errorf("expected status phase 'working', got %q", chunks[0].Status.Phase) + } +} + +// setupPromptTestRegistry creates a registry with an echo model and generate action. +func setupPromptTestRegistry(t *testing.T) *registry.Registry { + t.Helper() + reg := registry.New() + ctx := context.Background() + + ai.ConfigureFormats(reg) + ai.DefineModel(reg, "test/echo", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true, SystemRole: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + // Echo back the last user message text. + var text string + for i := len(req.Messages) - 1; i >= 0; i-- { + if req.Messages[i].Role == ai.RoleUser { + text = req.Messages[i].Text() + break + } + } + if text == "" { + text = "no input" + } + + resp := &ai.ModelResponse{ + Request: req, + Message: ai.NewModelTextMessage("echo: " + text), + } + + if cb != nil { + if err := cb(ctx, &ai.ModelResponseChunk{ + Content: resp.Message.Content, + }); err != nil { + return nil, err + } + } + + return resp, nil + }, + ) + ai.DefineGenerateAction(ctx, reg) + return reg +} + +func TestPromptAgent_Basic(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + ai.DefinePrompt(reg, "testPrompt", + ai.WithModelName("test/echo"), + ai.WithSystem("You are a test assistant."), + ) + + af := DefinePromptAgent[testState, any]( + reg, "testPrompt", nil, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Turn 1. + if err := conn.SendText("hello"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + + var gotChunk bool + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.ModelChunk != nil { + gotChunk = true + } + if chunk.EndTurn { + break + } + } + if !gotChunk { + t.Error("expected at least one streaming chunk") + } + + // Turn 2. + if err := conn.SendText("world"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // 2 user messages + 2 model replies = 4. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages, got %d", got) + for i, m := range response.State.Messages { + t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) + } + } +} + +func TestPromptAgent_PromptInputOverride(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + type greetInput struct { + Name string `json:"name"` + } + + ai.DefineDataPrompt[greetInput, string](reg, "greetPrompt", + ai.WithModelName("test/echo"), + ai.WithPrompt("Hello {{name}}!"), + ) + + af := DefinePromptAgent[testState]( + reg, "greetPrompt", greetInput{Name: "default"}, + ) + + // Use WithPromptInput to override. + conn, err := af.StreamBidi(ctx, + WithInputVariables[testState](greetInput{Name: "override"}), + ) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + if err := conn.SendText("hi"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Verify the override was stored in session state. + if response.State.InputVariables == nil { + t.Fatal("expected PromptInput in state") + } + + // The model echoes the last user message, which is "hi". + // But the prompt was rendered with "override" so "Hello override!" should appear + // in the messages sent to the model (verified via the echo). + // We primarily verify the state was set correctly. + inputMap, ok := response.State.InputVariables.(map[string]any) + if !ok { + t.Fatalf("expected PromptInput to be map[string]any, got %T", response.State.InputVariables) + } + if name, _ := inputMap["name"].(string); name != "override" { + t.Errorf("expected PromptInput name='override', got %q", name) + } +} + +func TestPromptAgent_MultiTurnHistory(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + // Use a model that echoes all message count so we can verify history grows. + ai.DefineModel(reg, "test/history", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true, SystemRole: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + // Count total messages received (includes prompt-rendered + history). + var parts []string + for _, m := range req.Messages { + parts = append(parts, string(m.Role)+":"+m.Text()) + } + text := strings.Join(parts, "|") + + resp := &ai.ModelResponse{ + Request: req, + Message: ai.NewModelTextMessage(text), + } + if cb != nil { + cb(ctx, &ai.ModelResponseChunk{Content: resp.Message.Content}) + } + return resp, nil + }, + ) + + ai.DefinePrompt(reg, "historyPrompt", + ai.WithModelName("test/history"), + ai.WithSystem("system prompt"), + ) + + af := DefinePromptAgent[testState, any]( + reg, "historyPrompt", nil, + ) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Turn 1. + conn.SendText("turn1") + var turn1Response string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.ModelChunk != nil { + turn1Response += chunk.ModelChunk.Text() + } + if chunk.EndTurn { + break + } + } + + // Turn 1 should have: system message + user message "turn1" (2 messages total from prompt + history). + // The system message comes from the prompt, "turn1" from session history. + if !strings.Contains(turn1Response, "turn1") { + t.Errorf("turn1 response should contain 'turn1', got: %s", turn1Response) + } + + // Turn 2. + conn.SendText("turn2") + var turn2Response string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.ModelChunk != nil { + turn2Response += chunk.ModelChunk.Text() + } + if chunk.EndTurn { + break + } + } + + // Turn 2 should have: system + turn1 user + turn1 model reply + turn2 user (4 messages from prompt + history). + if !strings.Contains(turn2Response, "turn1") || !strings.Contains(turn2Response, "turn2") { + t.Errorf("turn2 response should contain both 'turn1' and 'turn2', got: %s", turn2Response) + } + + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Session should have: turn1 user + turn1 model + turn2 user + turn2 model = 4 messages. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages in session, got %d", got) + for i, m := range response.State.Messages { + t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) + } + } +} + +func TestPromptAgent_SnapshotPersistsPromptInput(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + ai.DefinePrompt(reg, "snapPrompt", + ai.WithModelName("test/echo"), + ai.WithSystem("You are a test assistant."), + ) + + af := DefinePromptAgent[testState, any]( + reg, "snapPrompt", nil, + WithSessionStore(store), + ) + + // Start with prompt input. + conn, err := af.StreamBidi(ctx, + WithInputVariables[testState](map[string]any{"key": "value"}), + ) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("hello") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + resp, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + if resp.SnapshotID == "" { + t.Fatal("expected snapshot ID") + } + + // Verify the snapshot contains PromptInput. + snap, err := store.GetSnapshot(ctx, resp.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap.State.InputVariables == nil { + t.Error("expected InputVariables in snapshot state") + } + + // Resume from snapshot — the PromptInput should be preserved. + conn2, err := af.StreamBidi(ctx, WithSnapshotID[testState](resp.SnapshotID)) + if err != nil { + t.Fatalf("StreamBidi with snapshot failed: %v", err) + } + + conn2.SendText("continued") + for chunk, err := range conn2.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn2.Close() + + resp2, err := conn2.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Verify state via snapshot (server-managed state). + snap2, err := store.GetSnapshot(ctx, resp2.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if got := len(snap2.State.Messages); got != 4 { + t.Errorf("expected 4 messages after resume, got %d", got) + } + if snap2.State.InputVariables == nil { + t.Error("expected PromptInput preserved after resume") + } +} + +func TestPromptAgent_ToolLoopMessages(t *testing.T) { + ctx := context.Background() + reg := registry.New() + ai.ConfigureFormats(reg) + + // Define a tool that the model will call. + ai.DefineTool(reg, "greet", "returns a greeting", + func(ctx *ai.ToolContext, input struct { + Name string `json:"name"` + }) (string, error) { + return "hello " + input.Name, nil + }, + ) + + // Model that requests a tool call on the first call, then returns + // a final text response once it sees the tool result. + ai.DefineModel(reg, "test/toolmodel", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true, SystemRole: true, Tools: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + // Check if we already got a tool response. + for _, msg := range req.Messages { + for _, p := range msg.Content { + if p.IsToolResponse() { + resp := &ai.ModelResponse{ + Request: req, + Message: ai.NewModelTextMessage("done: " + fmt.Sprintf("%v", p.ToolResponse.Output)), + } + if cb != nil { + cb(ctx, &ai.ModelResponseChunk{Content: resp.Message.Content}) + } + return resp, nil + } + } + } + // First call: request the tool. + resp := &ai.ModelResponse{ + Request: req, + Message: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ai.NewToolRequestPart(&ai.ToolRequest{ + Name: "greet", + Input: map[string]any{"name": "world"}, + })}, + }, + } + return resp, nil + }, + ) + ai.DefineGenerateAction(ctx, reg) + + ai.DefinePrompt(reg, "toolPrompt", + ai.WithModelName("test/toolmodel"), + ai.WithSystem("You are a test assistant."), + ai.WithTools(ai.ToolName("greet")), + ) + + af := DefinePromptAgent[testState, any](reg, "toolPrompt", nil) + + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("go") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Session should contain: + // 1. user message ("go") + // 2. model tool-call message + // 3. tool response message + // 4. final model text response + msgs := response.State.Messages + if got := len(msgs); got != 4 { + t.Errorf("expected 4 messages, got %d", got) + for i, m := range msgs { + t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) + } + t.FailNow() + } + + if msgs[0].Role != ai.RoleUser { + t.Errorf("msg[0] role = %s, want user", msgs[0].Role) + } + hasToolReq := false + for _, p := range msgs[1].Content { + if p.IsToolRequest() { + hasToolReq = true + break + } + } + if msgs[1].Role != ai.RoleModel || !hasToolReq { + t.Errorf("msg[1] should be a model tool-call message") + } + if msgs[2].Role != ai.RoleTool { + t.Errorf("msg[2] role = %s, want tool", msgs[2].Role) + } + if msgs[3].Role != ai.RoleModel || !strings.Contains(msgs[3].Text(), "done:") { + t.Errorf("msg[3] should be final model response, got role=%s text=%s", msgs[3].Role, msgs[3].Text()) + } +} + +func TestAgentFlow_RunText(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "runTextFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + if len(input.Messages) > 0 { + sess.AddMessages(ai.NewModelTextMessage("echo: " + input.Messages[0].Content[0].Text)) + } + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + ) + + response, err := af.RunText(ctx, "hello") + if err != nil { + t.Fatalf("RunText failed: %v", err) + } + + // 1 user message + 1 echo reply = 2. + if got := len(response.State.Messages); got != 2 { + t.Errorf("expected 2 messages, got %d", got) + } + if got := response.State.Custom.Counter; got != 1 { + t.Errorf("expected counter=1, got %d", got) + } +} + +func TestAgentFlow_Run(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "runFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + if len(input.Messages) > 0 { + sess.AddMessages(ai.NewModelTextMessage("reply")) + } + return nil + }) + }, + ) + + input := &AgentFlowInput{ + Messages: []*ai.Message{ + ai.NewUserTextMessage("msg1"), + ai.NewUserTextMessage("msg2"), + }, + } + + response, err := af.Run(ctx, input) + if err != nil { + t.Fatalf("Run failed: %v", err) + } + + // 2 user messages + 1 reply = 3. + if got := len(response.State.Messages); got != 3 { + t.Errorf("expected 3 messages, got %d", got) + } +} + +func TestAgentFlow_RunText_WithState(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + af := DefineCustomAgent(reg, "runStateFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + ) + + clientState := &SessionState[testState]{ + Messages: []*ai.Message{ + ai.NewUserTextMessage("previous"), + ai.NewModelTextMessage("previous reply"), + }, + Custom: testState{Counter: 10}, + } + + response, err := af.RunText(ctx, "new message", WithState(clientState)) + if err != nil { + t.Fatalf("RunText with state failed: %v", err) + } + + // 2 previous + 1 new user + 1 reply = 4. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages, got %d", got) + } + // Counter should be 11 (started at 10, incremented once). + if got := response.State.Custom.Counter; got != 11 { + t.Errorf("expected counter=11, got %d", got) + } +} + +func TestAgentFlow_RunText_WithSnapshot(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineCustomAgent(reg, "runSnapshotFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSessionStore(store), + ) + + // First invocation via RunText. + resp1, err := af.RunText(ctx, "first") + if err != nil { + t.Fatalf("first RunText failed: %v", err) + } + if resp1.SnapshotID == "" { + t.Fatal("expected snapshot ID from first invocation") + } + + // Resume from snapshot via RunText. + resp2, err := af.RunText(ctx, "second", WithSnapshotID[testState](resp1.SnapshotID)) + if err != nil { + t.Fatalf("second RunText failed: %v", err) + } + + snap, err := store.GetSnapshot(ctx, resp2.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + // 4 messages: first user + reply + second user + reply. + if got := len(snap.State.Messages); got != 4 { + t.Errorf("expected 4 messages after resume, got %d", got) + } + if got := snap.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2, got %d", got) + } +} + +func TestPromptAgent_RunText(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + ai.DefinePrompt(reg, "runTextPrompt", + ai.WithModelName("test/echo"), + ai.WithSystem("You are a test assistant."), + ) + + af := DefinePromptAgent[testState, any](reg, "runTextPrompt", nil) + + response, err := af.RunText(ctx, "hello") + if err != nil { + t.Fatalf("RunText failed: %v", err) + } + + // 1 user message + 1 model reply = 2. + if got := len(response.State.Messages); got != 2 { + t.Errorf("expected 2 messages, got %d", got) + for i, m := range response.State.Messages { + t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) + } + } +} + +func TestAgentFlow_SingleTurnSnapshotDedup(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineCustomAgent(reg, "dedupFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSessionStore(store), + ) + + // Single-turn invocation: should produce exactly 1 snapshot (turn-end), + // not 2 (turn-end + invocation-end with identical state). + response, err := af.RunText(ctx, "hello") + if err != nil { + t.Fatalf("RunText failed: %v", err) + } + + if response.SnapshotID == "" { + t.Fatal("expected snapshot ID in response") + } + + // Count total snapshots in the store. + snap, err := store.GetSnapshot(ctx, response.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap.Event != SnapshotEventTurnEnd { + t.Errorf("expected turn-end snapshot, got %s", snap.Event) + } + // The turn-end snapshot should have no parent (first and only snapshot). + if snap.ParentID != "" { + t.Errorf("expected no parent (single snapshot), got parent %q", snap.ParentID) + } +} + +func TestAgentFlow_MultiTurnSnapshotDedup(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineCustomAgent(reg, "multiDedupFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + return nil, sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSessionStore(store), + ) + + // Multi-turn: last turn-end snapshot should dedup with invocation-end. + conn, err := af.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + var snapshotIDs []string + for i := 0; i < 3; i++ { + conn.SendText(fmt.Sprintf("turn %d", i)) + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error on turn %d: %v", i, err) + } + if chunk.SnapshotID != "" { + snapshotIDs = append(snapshotIDs, chunk.SnapshotID) + } + if chunk.EndTurn { + break + } + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Should have 3 turn-end snapshots (one per turn), no extra invocation-end. + if got := len(snapshotIDs); got != 3 { + t.Errorf("expected 3 turn-end snapshots, got %d", got) + } + + // The output snapshot ID should reuse the last turn-end snapshot. + if response.SnapshotID == "" { + t.Fatal("expected snapshot ID in response") + } + if response.SnapshotID != snapshotIDs[len(snapshotIDs)-1] { + t.Errorf("expected output snapshot to reuse last turn-end snapshot %q, got %q", + snapshotIDs[len(snapshotIDs)-1], response.SnapshotID) + } +} + +func TestAgentFlow_InvocationEndSnapshotWhenStateChangesAfterRun(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySessionStore[testState]() + + af := DefineCustomAgent(reg, "postRunMutateFlow", + func(ctx context.Context, resp Responder[testStatus], sess *AgentSession[testState]) (*AgentFlowResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *AgentFlowInput) error { + sess.AddMessages(ai.NewModelTextMessage("reply")) + return nil + }); err != nil { + return nil, err + } + // Mutate state AFTER sess.Run returns -- this should trigger + // a separate invocation-end snapshot. + sess.UpdateCustom(func(s testState) testState { + s.Counter = 99 + return s + }) + return sess.Result(), nil + }, + WithSessionStore(store), + ) + + response, err := af.RunText(ctx, "hello") + if err != nil { + t.Fatalf("RunText failed: %v", err) + } + + if response.SnapshotID == "" { + t.Fatal("expected snapshot ID in response") + } + + // The final snapshot should be an invocation-end snapshot that captured + // the post-Run mutation. + snap, err := store.GetSnapshot(ctx, response.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap.Event != SnapshotEventInvocationEnd { + t.Errorf("expected invocation-end snapshot, got %s", snap.Event) + } + if snap.State.Custom.Counter != 99 { + t.Errorf("expected counter=99 in final snapshot, got %d", snap.State.Custom.Counter) + } + // Should have a parent (the turn-end snapshot). + if snap.ParentID == "" { + t.Error("expected parent ID (turn-end snapshot)") + } +} diff --git a/go/ai/exp/gen.go b/go/ai/exp/gen.go new file mode 100644 index 0000000000..b13c62aa86 --- /dev/null +++ b/go/ai/exp/gen.go @@ -0,0 +1,122 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// This file was generated by jsonschemagen. DO NOT EDIT. + +package exp + +import ( + "github.com/firebase/genkit/go/ai" +) + +// AgentFlowInit is the input for starting an agent flow invocation. +// Provide either SnapshotID (to load from store) or State (direct state). +type AgentFlowInit[State any] struct { + // SnapshotID loads state from a persisted snapshot. + // Mutually exclusive with State. + SnapshotID string `json:"snapshotId,omitempty"` + // State provides direct state for the invocation. + // Mutually exclusive with SnapshotID. + State *SessionState[State] `json:"state,omitempty"` +} + +// AgentFlowInput is the input sent to an agent flow during a conversation turn. +type AgentFlowInput struct { + // Messages contains the user's input for this turn. + Messages []*ai.Message `json:"messages,omitempty"` + // ToolRestarts contains tool request parts to re-execute interrupted tools. + // Use [ai.ToolDef.RestartWith] to create these parts from an interrupted + // tool request. When set, the generate call resumes with these restarts + // instead of treating Messages as tool responses. + ToolRestarts []*ai.Part `json:"toolRestarts,omitempty"` +} + +// AgentFlowOutput is the output when an agent flow invocation completes. +// It wraps AgentFlowResult with framework-managed fields. +type AgentFlowOutput[State any] struct { + // Artifacts contains artifacts produced during the session. + Artifacts []*Artifact `json:"artifacts,omitempty"` + // Message is the last model response message from the conversation. + Message *ai.Message `json:"message,omitempty"` + // SnapshotID is the ID of the snapshot created at the end of this invocation. + // Empty if no snapshot was created (callback returned false or no store configured). + SnapshotID string `json:"snapshotId,omitempty"` + // State contains the final conversation state. + // Only populated when state is client-managed (no store configured). + State *SessionState[State] `json:"state,omitempty"` +} + +// AgentFlowResult is the return value from an AgentFlowFunc. +// It contains the user-specified outputs of the agent invocation. +type AgentFlowResult struct { + // Artifacts contains artifacts produced during the session. + Artifacts []*Artifact `json:"artifacts,omitempty"` + // Message is the last model response message from the conversation. + Message *ai.Message `json:"message,omitempty"` +} + +// AgentFlowStreamChunk represents a single item in the agent flow's output stream. +// Multiple fields can be populated in a single chunk. +type AgentFlowStreamChunk[Stream any] struct { + // Artifact contains a newly produced artifact. + Artifact *Artifact `json:"artifact,omitempty"` + // EndTurn signals that the agent flow has finished processing the current input. + // When true, the client should stop iterating and may send the next input. + EndTurn bool `json:"endTurn,omitempty"` + // ModelChunk contains generation tokens from the model. + ModelChunk *ai.ModelResponseChunk `json:"modelChunk,omitempty"` + // SnapshotID contains the ID of a snapshot that was just persisted. + SnapshotID string `json:"snapshotId,omitempty"` + // Status contains user-defined structured status information. + // The Stream type parameter defines the shape of this data. + Status Stream `json:"status,omitempty"` +} + +// Artifact represents a named collection of parts produced during a session. +// Examples: generated files, images, code snippets, diagrams, etc. +type Artifact struct { + // Metadata contains additional artifact-specific data. + Metadata map[string]any `json:"metadata,omitempty"` + // Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). + Name string `json:"name,omitempty"` + // Parts contains the artifact content (text, media, etc.). + Parts []*ai.Part `json:"parts"` +} + +// SessionState is the portable conversation state that flows between client +// and server. It contains only the data needed for conversation continuity. +type SessionState[State any] struct { + // Artifacts are named collections of parts produced during the conversation. + Artifacts []*Artifact `json:"artifacts,omitempty"` + // Custom is the user-defined state associated with this conversation. + Custom State `json:"custom,omitempty"` + // InputVariables is the input used for agent flows that require input variables + // (e.g. prompt-backed agent flows). + InputVariables any `json:"inputVariables,omitempty"` + // Messages is the conversation history (user/model exchanges). + // Does NOT include prompt-rendered messages — those are rendered fresh each turn. + Messages []*ai.Message `json:"messages,omitempty"` +} + +// SnapshotEvent identifies what triggered a snapshot. +type SnapshotEvent string + +const ( + // TurnEnd indicates the snapshot was triggered at the end of a turn. + SnapshotEventTurnEnd SnapshotEvent = "turnEnd" + // InvocationEnd indicates the snapshot was triggered at the end of the invocation. + SnapshotEventInvocationEnd SnapshotEvent = "invocationEnd" +) diff --git a/go/ai/exp/option.go b/go/ai/exp/option.go new file mode 100644 index 0000000000..f7e80639ca --- /dev/null +++ b/go/ai/exp/option.go @@ -0,0 +1,134 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "context" + "errors" +) + +// --- AgentFlowOption --- + +// AgentFlowOption configures an AgentFlow. +type AgentFlowOption[State any] interface { + applyAgentFlow(*agentFlowOptions[State]) error +} + +type agentFlowOptions[State any] struct { + store SessionStore[State] + callback SnapshotCallback[State] +} + +func (o *agentFlowOptions[State]) applyAgentFlow(opts *agentFlowOptions[State]) error { + if o.store != nil { + if opts.store != nil { + return errors.New("cannot set session store more than once (WithSessionStore)") + } + opts.store = o.store + } + if o.callback != nil { + if opts.callback != nil { + return errors.New("cannot set snapshot callback more than once (WithSnapshotCallback)") + } + opts.callback = o.callback + } + return nil +} + +// WithSessionStore sets the store for persisting snapshots. +func WithSessionStore[State any](store SessionStore[State]) AgentFlowOption[State] { + return &agentFlowOptions[State]{store: store} +} + +// WithSnapshotCallback configures when snapshots are created. +// If not provided and a store is configured, snapshots are always created. +func WithSnapshotCallback[State any](cb SnapshotCallback[State]) AgentFlowOption[State] { + return &agentFlowOptions[State]{callback: cb} +} + +// WithSnapshotOn configures snapshots to be created only for the specified events. +// For example, WithSnapshotOn[MyState](SnapshotEventTurnEnd) skips the +// invocation-end snapshot. +func WithSnapshotOn[State any](events ...SnapshotEvent) AgentFlowOption[State] { + set := make(map[SnapshotEvent]struct{}, len(events)) + for _, e := range events { + set[e] = struct{}{} + } + return WithSnapshotCallback[State](func(_ context.Context, sc *SnapshotContext[State]) bool { + _, ok := set[sc.Event] + return ok + }) +} + +// --- InvocationOption --- + +// InvocationOption configures an agent flow invocation (StreamBidi, Run, or RunText). +type InvocationOption[State any] interface { + applyInvocation(*invocationOptions[State]) error +} + +type invocationOptions[State any] struct { + state *SessionState[State] + snapshotID string + promptInput any +} + +func (o *invocationOptions[State]) applyInvocation(opts *invocationOptions[State]) error { + if o.state != nil { + if opts.state != nil { + return errors.New("cannot set state more than once (WithState)") + } + if opts.snapshotID != "" { + return errors.New("WithState and WithSnapshotID are mutually exclusive") + } + opts.state = o.state + } + if o.snapshotID != "" { + if opts.snapshotID != "" { + return errors.New("cannot set snapshot ID more than once (WithSnapshotID)") + } + if opts.state != nil { + return errors.New("WithSnapshotID and WithState are mutually exclusive") + } + opts.snapshotID = o.snapshotID + } + if o.promptInput != nil { + if opts.promptInput != nil { + return errors.New("cannot set prompt input more than once (WithPromptInput)") + } + opts.promptInput = o.promptInput + } + return nil +} + +// WithState sets the initial state for the invocation. +// Use this for client-managed state where the client sends state directly. +func WithState[State any](state *SessionState[State]) InvocationOption[State] { + return &invocationOptions[State]{state: state} +} + +// WithSnapshotID loads state from a persisted snapshot by ID. +// Use this for server-managed state where snapshots are stored. +func WithSnapshotID[State any](id string) InvocationOption[State] { + return &invocationOptions[State]{snapshotID: id} +} + +// WithInputVariables overrides the default input variables for a prompt-backed agent flow. +// Used with DefinePromptAgent to customize the input variables per invocation. +func WithInputVariables[State any](input any) InvocationOption[State] { + return &invocationOptions[State]{promptInput: input} +} diff --git a/go/ai/exp/session.go b/go/ai/exp/session.go new file mode 100644 index 0000000000..40f0d6dfc8 --- /dev/null +++ b/go/ai/exp/session.go @@ -0,0 +1,275 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package exp + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/internal/base" +) + +// --- Snapshot --- + +// SessionSnapshot is a persisted point-in-time capture of session state. +type SessionSnapshot[State any] struct { + // SnapshotID is the unique identifier for this snapshot (UUID). + SnapshotID string `json:"snapshotId"` + // ParentID is the ID of the previous snapshot in this timeline. + ParentID string `json:"parentId,omitempty"` + // CreatedAt is when the snapshot was created. + CreatedAt time.Time `json:"createdAt"` + // Event is what triggered this snapshot. + Event SnapshotEvent `json:"event"` + // State is the actual conversation state. + State SessionState[State] `json:"state"` +} + +// SnapshotContext provides context for snapshot decision callbacks. +type SnapshotContext[State any] struct { + // State is the current state that will be snapshotted if the callback returns true. + State *SessionState[State] + // PrevState is the state at the last snapshot, or nil if none exists. + PrevState *SessionState[State] + // TurnIndex is the turn number in the current invocation. + TurnIndex int + // Event is what triggered this snapshot check. + Event SnapshotEvent +} + +// SnapshotCallback decides whether to create a snapshot. +// If not provided and a store is configured, snapshots are always created. +type SnapshotCallback[State any] = func(ctx context.Context, sc *SnapshotContext[State]) bool + +// --- Session store --- + +// SessionStore persists and retrieves snapshots. +type SessionStore[State any] interface { + // GetSnapshot retrieves a snapshot by ID. Returns nil if not found. + GetSnapshot(ctx context.Context, snapshotID string) (*SessionSnapshot[State], error) + // SaveSnapshot persists a snapshot. + SaveSnapshot(ctx context.Context, snapshot *SessionSnapshot[State]) error +} + +// InMemorySessionStore provides a thread-safe in-memory snapshot store. +type InMemorySessionStore[State any] struct { + snapshots map[string]*SessionSnapshot[State] + mu sync.RWMutex +} + +// NewInMemorySessionStore creates a new in-memory snapshot store. +func NewInMemorySessionStore[State any]() *InMemorySessionStore[State] { + return &InMemorySessionStore[State]{ + snapshots: make(map[string]*SessionSnapshot[State]), + } +} + +// GetSnapshot retrieves a snapshot by ID. Returns nil if not found. +func (s *InMemorySessionStore[State]) GetSnapshot(_ context.Context, snapshotID string) (*SessionSnapshot[State], error) { + s.mu.RLock() + defer s.mu.RUnlock() + + snap, exists := s.snapshots[snapshotID] + if !exists { + return nil, nil + } + + copied, err := copySnapshot(snap) + if err != nil { + return nil, err + } + return copied, nil +} + +// SaveSnapshot persists a snapshot. +func (s *InMemorySessionStore[State]) SaveSnapshot(_ context.Context, snapshot *SessionSnapshot[State]) error { + s.mu.Lock() + defer s.mu.Unlock() + + copied, err := copySnapshot(snapshot) + if err != nil { + return err + } + s.snapshots[copied.SnapshotID] = copied + return nil +} + +// copySnapshot creates a deep copy of a snapshot using JSON marshaling. +func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[State], error) { + if snap == nil { + return nil, nil + } + bytes, err := json.Marshal(snap) + if err != nil { + return nil, fmt.Errorf("copy snapshot: marshal: %w", err) + } + var copied SessionSnapshot[State] + if err := json.Unmarshal(bytes, &copied); err != nil { + return nil, fmt.Errorf("copy snapshot: unmarshal: %w", err) + } + return &copied, nil +} + +// --- Session --- + +// Session holds conversation state and provides thread-safe read/write access to messages, +// input variables, custom state, and artifacts. +type Session[State any] struct { + mu sync.RWMutex + state SessionState[State] + store SessionStore[State] + version uint64 // incremented on every mutation; used to skip redundant snapshots +} + +// State returns a copy of the current state. +func (s *Session[State]) State() *SessionState[State] { + s.mu.RLock() + defer s.mu.RUnlock() + copied := s.copyStateLocked() + return &copied +} + +// Messages returns the current conversation history. +func (s *Session[State]) Messages() []*ai.Message { + s.mu.RLock() + defer s.mu.RUnlock() + msgs := make([]*ai.Message, len(s.state.Messages)) + copy(msgs, s.state.Messages) + return msgs +} + +// AddMessages appends messages to the conversation history. +func (s *Session[State]) AddMessages(messages ...*ai.Message) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Messages = append(s.state.Messages, messages...) + s.version++ +} + +// SetMessages replaces the conversation history with the given messages. +func (s *Session[State]) SetMessages(messages []*ai.Message) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Messages = messages + s.version++ +} + +// UpdateMessages atomically reads the current messages, applies the given +// function, and writes the result back. +func (s *Session[State]) UpdateMessages(fn func([]*ai.Message) []*ai.Message) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Messages = fn(s.state.Messages) + s.version++ +} + +// Custom returns the current user-defined custom state. +func (s *Session[State]) Custom() State { + s.mu.RLock() + defer s.mu.RUnlock() + return s.state.Custom +} + +// UpdateCustom atomically reads the current custom state, applies the given +// function, and writes the result back. +func (s *Session[State]) UpdateCustom(fn func(State) State) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Custom = fn(s.state.Custom) + s.version++ +} + +// InputVariables returns the prompt input stored in the session state. +func (s *Session[State]) InputVariables() any { + s.mu.RLock() + defer s.mu.RUnlock() + return s.state.InputVariables +} + +// Artifacts returns the current artifacts. +func (s *Session[State]) Artifacts() []*Artifact { + s.mu.RLock() + defer s.mu.RUnlock() + arts := make([]*Artifact, len(s.state.Artifacts)) + copy(arts, s.state.Artifacts) + return arts +} + +// AddArtifacts adds artifacts to the session. If an artifact with the same +// name already exists, it is replaced. +func (s *Session[State]) AddArtifacts(artifacts ...*Artifact) { + s.mu.Lock() + defer s.mu.Unlock() + for _, a := range artifacts { + replaced := false + if a.Name != "" { + for i, existing := range s.state.Artifacts { + if existing.Name == a.Name { + s.state.Artifacts[i] = a + replaced = true + break + } + } + } + if !replaced { + s.state.Artifacts = append(s.state.Artifacts, a) + } + } + s.version++ +} + +// UpdateArtifacts atomically reads the current artifacts, applies the given +// function, and writes the result back. +func (s *Session[State]) UpdateArtifacts(fn func([]*Artifact) []*Artifact) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Artifacts = fn(s.state.Artifacts) + s.version++ +} + +// copyStateLocked returns a deep copy of the state. Caller must hold mu (read or write). +func (s *Session[State]) copyStateLocked() SessionState[State] { + bytes, err := json.Marshal(s.state) + if err != nil { + panic(fmt.Sprintf("agent flow: failed to marshal state: %v", err)) + } + var copied SessionState[State] + if err := json.Unmarshal(bytes, &copied); err != nil { + panic(fmt.Sprintf("agent flow: failed to unmarshal state: %v", err)) + } + return copied +} + +// --- Session context --- + +var sessionCtxKey = base.NewContextKey[any]() + +// NewSessionContext returns a new context with the session attached. +func NewSessionContext[State any](ctx context.Context, s *Session[State]) context.Context { + return sessionCtxKey.NewContext(ctx, s) +} + +// SessionFromContext retrieves the current session from context. +// Returns nil if no session is in context or if the type doesn't match. +func SessionFromContext[State any](ctx context.Context) *Session[State] { + session, _ := sessionCtxKey.FromContext(ctx).(*Session[State]) + return session +} diff --git a/go/ai/generate.go b/go/ai/generate.go index 6aeb1e6642..c641012609 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -1017,6 +1017,17 @@ func (c *ModelResponseChunk) Reasoning() string { return sb.String() } +// Interrupts returns the interrupted tool request parts from the chunk. +func (c *ModelResponseChunk) Interrupts() []*Part { + var parts []*Part + for _, p := range c.Content { + if p.IsInterrupt() { + parts = append(parts, p) + } + } + return parts +} + // Output parses the chunk using the format handler and unmarshals the result into v. // Returns an error if the format handler is not set or does not support parsing chunks. func (c *ModelResponseChunk) Output(v any) error { @@ -1100,6 +1111,27 @@ func (m *Message) Text() string { return sb.String() } +// NewResume constructs a [GenerateActionResume] from Part slices. +// This is useful when building [GenerateActionOptions] directly (e.g., from a +// rendered prompt) and need to set the Resume field from [*Part] values +// produced by [ToolDef.RestartWith] or [ToolDef.RespondWith]. +func NewResume(restarts, responds []*Part) *GenerateActionResume { + resume := &GenerateActionResume{} + for _, p := range restarts { + resume.Restart = append(resume.Restart, &toolRequestPart{ + ToolRequest: p.ToolRequest, + Metadata: p.Metadata, + }) + } + for _, p := range responds { + resume.Respond = append(resume.Respond, &toolResponsePart{ + ToolResponse: p.ToolResponse, + Metadata: p.Metadata, + }) + } + return resume +} + // NewModelRef creates a new ModelRef with the given name and configuration. func NewModelRef(name string, config any) ModelRef { return ModelRef{name: name, config: config} diff --git a/go/ai/tools.go b/go/ai/tools.go index 453ad779a3..6099d1eff5 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -248,8 +248,7 @@ func ResumedValue[T any](tc *ToolContext, key string) (T, bool) { if !ok { return zero, false } - typed, ok := v.(T) - return typed, ok + return base.ConvertTo[T](v) } // OriginalInputAs returns the original input typed appropriately. @@ -259,19 +258,7 @@ func OriginalInputAs[T any](tc *ToolContext) (T, bool) { if tc.OriginalInput == nil { return zero, false } - // Try direct type assertion first (for when input is already typed) - if typed, ok := tc.OriginalInput.(T); ok { - return typed, ok - } - // Otherwise try to convert from map[string]any (common case from JSON) - if m, ok := tc.OriginalInput.(map[string]any); ok { - result, err := base.MapToStruct[T](m) - if err != nil { - return zero, false - } - return result, true - } - return zero, false + return base.ConvertTo[T](tc.OriginalInput) } // DefineTool creates a new [ToolDef] and registers it. diff --git a/go/core/api/action.go b/go/core/api/action.go index a38958af51..704fb1b9f0 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -64,6 +64,7 @@ const ( ActionTypeCustom ActionType = "custom" ActionTypeCheckOperation ActionType = "check-operation" ActionTypeCancelOperation ActionType = "cancel-operation" + ActionTypeAgentFlow ActionType = "agent-flow" ) // ActionDesc is a descriptor of an action. diff --git a/go/core/flow.go b/go/core/flow.go index c173a0306c..b220c30479 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -93,7 +93,7 @@ func NewBidiFlow[In, Out, Stream, Init any](name string, fn BidiFunc[In, Out, St // DefineBidiFlow creates a bidirectional streaming Flow that runs fn, and registers it as an action. // Flow context is injected so that [Run] works inside the bidi function. func DefineBidiFlow[In, Out, Stream, Init any](r api.Registry, name string, fn BidiFunc[In, Out, Stream, Init]) *Flow[In, Out, Stream, Init] { - f := NewBidiFlow[In, Out, Stream, Init](name, fn) + f := NewBidiFlow(name, fn) f.Register(r) return f } diff --git a/go/core/schemas.config b/go/core/schemas.config index 70798f2eb3..b58c36e2a9 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1108,3 +1108,222 @@ Embedding.embedding type []float32 GenkitError omit GenkitErrorData omit GenkitErrorDataGenkitErrorDetails omit + +# ============================================================================ +# AGENT FLOW TYPES (generated into ai/x package) +# ============================================================================ + +# Package configuration: ai/x directory uses "aix" as Go package name. +ai/x name aix +aix import github.com/firebase/genkit/go/ai + +# ---------------------------------------------------------------------------- +# Artifact +# ---------------------------------------------------------------------------- + +Artifact pkg ai/x + +Artifact doc +Artifact represents a named collection of parts produced during a session. +Examples: generated files, images, code snippets, diagrams, etc. +. + +Artifact.name doc +Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). +. + +Artifact.parts type []*ai.Part +Artifact.parts noomitempty +Artifact.parts doc +Parts contains the artifact content (text, media, etc.). +. + +Artifact.metadata type map[string]any +Artifact.metadata doc +Metadata contains additional artifact-specific data. +. + +# ---------------------------------------------------------------------------- +# AgentFlowInput +# ---------------------------------------------------------------------------- + +AgentFlowInput pkg ai/x + +AgentFlowInput doc +AgentFlowInput is the input sent to an agent flow during a conversation turn. +. + +AgentFlowInput.messages type []*ai.Message +AgentFlowInput.messages doc +Messages contains the user's input for this turn. +. + +AgentFlowInput.toolRestarts type []*ai.Part +AgentFlowInput.toolRestarts doc +ToolRestarts contains tool request parts to re-execute interrupted tools. +Use [ai.ToolDef.RestartWith] to create these parts from an interrupted +tool request. When set, the generate call resumes with these restarts +instead of treating Messages as tool responses. +. + +# ---------------------------------------------------------------------------- +# AgentFlowInit +# ---------------------------------------------------------------------------- + +AgentFlowInit pkg ai/x +AgentFlowInit typeparams [State any] + +AgentFlowInit doc +AgentFlowInit is the input for starting an agent flow invocation. +Provide either SnapshotID (to load from store) or State (direct state). +. + +AgentFlowInit.snapshotId doc +SnapshotID loads state from a persisted snapshot. +Mutually exclusive with State. +. + +AgentFlowInit.state type *SessionState[State] +AgentFlowInit.state doc +State provides direct state for the invocation. +Mutually exclusive with SnapshotID. +. + +# ---------------------------------------------------------------------------- +# AgentFlowResult +# ---------------------------------------------------------------------------- + +AgentFlowResult pkg ai/x + +AgentFlowResult doc +AgentFlowResult is the return value from an AgentFlowFunc. +It contains the user-specified outputs of the agent invocation. +. + +AgentFlowResult.message type *ai.Message +AgentFlowResult.message doc +Message is the last model response message from the conversation. +. + +AgentFlowResult.artifacts doc +Artifacts contains artifacts produced during the session. +. + +# ---------------------------------------------------------------------------- +# AgentFlowOutput +# ---------------------------------------------------------------------------- + +AgentFlowOutput pkg ai/x +AgentFlowOutput typeparams [State any] + +AgentFlowOutput doc +AgentFlowOutput is the output when an agent flow invocation completes. +It wraps AgentFlowResult with framework-managed fields. +. + +AgentFlowOutput.snapshotId doc +SnapshotID is the ID of the snapshot created at the end of this invocation. +Empty if no snapshot was created (callback returned false or no store configured). +. + +AgentFlowOutput.state type *SessionState[State] +AgentFlowOutput.state doc +State contains the final conversation state. +Only populated when state is client-managed (no store configured). +. + +AgentFlowOutput.message type *ai.Message +AgentFlowOutput.message doc +Message is the last model response message from the conversation. +. + +AgentFlowOutput.artifacts doc +Artifacts contains artifacts produced during the session. +. + +# ---------------------------------------------------------------------------- +# AgentFlowStreamChunk +# ---------------------------------------------------------------------------- + +AgentFlowStreamChunk pkg ai/x +AgentFlowStreamChunk typeparams [Stream any] + +AgentFlowStreamChunk doc +AgentFlowStreamChunk represents a single item in the agent flow's output stream. +Multiple fields can be populated in a single chunk. +. + +AgentFlowStreamChunk.modelChunk type *ai.ModelResponseChunk +AgentFlowStreamChunk.modelChunk doc +ModelChunk contains generation tokens from the model. +. + +AgentFlowStreamChunk.status type Stream +AgentFlowStreamChunk.status doc +Status contains user-defined structured status information. +The Stream type parameter defines the shape of this data. +. + +AgentFlowStreamChunk.artifact doc +Artifact contains a newly produced artifact. +. + +AgentFlowStreamChunk.snapshotId doc +SnapshotID contains the ID of a snapshot that was just persisted. +. + +AgentFlowStreamChunk.endTurn doc +EndTurn signals that the agent flow has finished processing the current input. +When true, the client should stop iterating and may send the next input. +. + +# ---------------------------------------------------------------------------- +# SessionState +# ---------------------------------------------------------------------------- + +SessionState pkg ai/x +SessionState typeparams [State any] + +SessionState doc +SessionState is the portable conversation state that flows between client +and server. It contains only the data needed for conversation continuity. +. + +SessionState.messages type []*ai.Message +SessionState.messages doc +Messages is the conversation history (user/model exchanges). +Does NOT include prompt-rendered messages — those are rendered fresh each turn. +. + +SessionState.custom type State +SessionState.custom doc +Custom is the user-defined state associated with this conversation. +. + +SessionState.artifacts doc +Artifacts are named collections of parts produced during the conversation. +. + +SessionState.inputVariables doc +InputVariables is the input used for agent flows that require input variables +(e.g. prompt-backed agent flows). +. + +# ---------------------------------------------------------------------------- +# SnapshotEvent +# ---------------------------------------------------------------------------- + +SnapshotEvent pkg ai/x + +SnapshotEvent doc +SnapshotEvent identifies what triggered a snapshot. +. + +SnapshotEventTurnEnd doc +TurnEnd indicates the snapshot was triggered at the end of a turn. +. + +SnapshotEventInvocationEnd doc +InvocationEnd indicates the snapshot was triggered at the end of the invocation. +. + diff --git a/go/core/x/session/session.go b/go/core/x/session/session.go index 8a9f0387f9..01db9e8562 100644 --- a/go/core/x/session/session.go +++ b/go/core/x/session/session.go @@ -165,7 +165,7 @@ func New[S any](ctx context.Context, opts ...Option[S]) (*Session[S], error) { func Load[S any](ctx context.Context, store Store[S], sessionID string) (*Session[S], error) { data, err := store.Get(ctx, sessionID) if err != nil { - return nil, err + return nil, fmt.Errorf("session.Load: %w", err) } if data == nil { return nil, &NotFoundError{SessionID: sessionID} @@ -221,7 +221,7 @@ func (s *Session[S]) UpdateState(ctx context.Context, state S) error { State: state, } if err := s.store.Save(ctx, s.id, data); err != nil { - return err + return fmt.Errorf("session.UpdateState: %w", err) } } @@ -346,12 +346,12 @@ func copyData[S any](data *Data[S]) (*Data[S], error) { bytes, err := json.Marshal(data) if err != nil { - return nil, err + return nil, fmt.Errorf("copy session data: marshal: %w", err) } var copied Data[S] if err := json.Unmarshal(bytes, &copied); err != nil { - return nil, err + return nil, fmt.Errorf("copy session data: unmarshal: %w", err) } return &copied, nil diff --git a/go/core/x/session/session_test.go b/go/core/x/session/session_test.go index b34100a44c..55c25d76d1 100644 --- a/go/core/x/session/session_test.go +++ b/go/core/x/session/session_test.go @@ -776,7 +776,7 @@ func TestSession_UpdateState_StoreError(t *testing.T) { if err == nil { t.Fatal("Expected error from failing store") } - if err != expectedErr { - t.Errorf("Expected error %v, got %v", expectedErr, err) + if !errors.Is(err, expectedErr) { + t.Errorf("Expected error wrapping %v, got %v", expectedErr, err) } } diff --git a/go/custom-agent b/go/custom-agent new file mode 100755 index 0000000000..30c9ea23fa Binary files /dev/null and b/go/custom-agent differ diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 40e137a2b8..21f312de6a 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -30,6 +30,7 @@ import ( "syscall" "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/internal/registry" @@ -186,7 +187,7 @@ func WithPromptFS(fsys fs.FS) GenkitOption { // // Assumes a prompt file at ./prompts/jokePrompt.prompt // g := genkit.Init(ctx, // genkit.WithPlugins(&googlegenai.GoogleAI{}), -// genkit.WithDefaultModel("googleai/gemini-2.5-flash"), +// genkit.WithDefaultModel("googleai/gemini-3-flash-preview"), // genkit.WithPromptDir("./prompts"), // ) // @@ -407,6 +408,159 @@ func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.B return core.DefineBidiFlow(g.reg, name, fn) } +// DefineCustomAgent defines a custom agent flow with full control over the +// conversation loop, registers it as a [core.Action] of type Flow, and +// returns an [aix.AgentFlow]. +// +// Experimental: This API is under active development and may change in any +// minor version release. +// +// An AgentFlow is a stateful, multi-turn conversational flow. It builds on +// bidirectional streaming to enable ongoing conversations where each turn's +// input and output are streamed between client and server. The framework +// handles session state, conversation history, and optional snapshot +// persistence automatically. +// +// The provided function fn receives a [aix.Responder] for streaming output +// to the client and an [aix.AgentSession] for accessing conversation state. +// Call [aix.AgentSession.Run] to enter the turn loop, which blocks until the +// client sends the next message. +// +// For prompt-backed agents that follow a standard render-generate-stream loop, +// use [DefinePromptAgent] instead. +// +// # Options +// +// - [aix.WithSessionStore]: Enable snapshot persistence with a [aix.SessionStore] +// - [aix.WithSnapshotCallback]: Control when snapshots are created +// - [aix.WithSnapshotOn]: Create snapshots only for specific [aix.SnapshotEvent] types +// +// Type parameters: +// - Stream: Type for custom status updates sent via [aix.Responder.SendStatus] +// - State: Type for user-defined state persisted in snapshots +// +// Example: +// +// chatAgent := genkit.DefineCustomAgent(g, "chat", +// func(ctx context.Context, resp aix.Responder[any], sess *aix.AgentSession[any]) (*aix.AgentFlowResult, error) { +// var lastMessage *ai.Message +// err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentFlowInput) error { +// sess.AddMessages(input.Messages...) +// for result, err := range genkit.GenerateStream(ctx, g, +// ai.WithModelName("googleai/gemini-3-flash-preview"), +// ai.WithMessages(sess.Messages()...), +// ) { +// if err != nil { +// return err +// } +// if result.Done { +// lastMessage = result.Response.Message +// sess.AddMessages(lastMessage) +// } else { +// resp.SendModelChunk(result.Chunk) +// } +// } +// return nil +// }) +// if err != nil { +// return nil, err +// } +// return &aix.AgentFlowResult{Message: lastMessage}, nil +// }, +// ) +// +// // Start a conversation: +// conn, err := chatAgent.StreamBidi(ctx) +// if err != nil { +// // handle error +// } +// +// // Send a message and stream the response: +// conn.SendText("Hello!") +// for chunk, err := range conn.Receive() { +// if chunk.EndTurn { +// break +// } +// fmt.Print(chunk.ModelChunk.Text()) +// } +// conn.Close() +func DefineCustomAgent[Stream, State any]( + g *Genkit, + name string, + fn aix.AgentFlowFunc[Stream, State], + opts ...aix.AgentFlowOption[State], +) *aix.AgentFlow[Stream, State] { + return aix.DefineCustomAgent(g.reg, name, fn, opts...) +} + +// DefinePromptAgent defines a prompt-backed agent flow, registers it as a +// [core.Action] of type Flow, and returns an [aix.AgentFlow]. +// +// Experimental: This API is under active development and may change in any +// minor version release. +// +// This is a higher-level alternative to [DefineCustomAgent] for agents backed +// by a prompt (defined via [DefinePrompt] or loaded from a .prompt file). The +// conversation loop is handled automatically: each turn renders the prompt, +// appends conversation history, calls the model with streaming, and updates +// session state. +// +// The prompt is looked up by promptName from the registry. The defaultInput +// provides template variables for prompt rendering (e.g., personality, tone) +// and can be overridden per invocation via [aix.WithInputVariables]. +// +// DefinePromptAgent accepts the same options as [DefineCustomAgent]. See +// [DefineCustomAgent] for available options. +// +// Type parameters: +// - State: Type for user-defined state persisted in snapshots +// - PromptIn: The prompt input type (inferred from defaultInput) +// +// Example: +// +// // Given a .prompt file "chat.prompt" loaded via WithPromptDir: +// // --- +// // model: googleai/gemini-3-flash-preview +// // input: +// // schema: +// // personality: string +// // --- +// // {{role "system"}} +// // You are {{personality}}. +// +// type ChatInput struct { +// Personality string `json:"personality"` +// } +// +// chatAgent := genkit.DefinePromptAgent(g, "chat", +// ChatInput{Personality: "a helpful assistant"}, +// aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), +// ) +// +// // Start a conversation: +// conn, err := chatAgent.StreamBidi(ctx) +// if err != nil { +// // handle error +// } +// +// // Send a message and stream the response: +// conn.SendText("Hello!") +// for chunk, err := range conn.Receive() { +// if chunk.EndTurn { +// break +// } +// fmt.Print(chunk.ModelChunk.Text()) +// } +// conn.Close() +func DefinePromptAgent[State, PromptIn any]( + g *Genkit, + promptName string, + defaultInput PromptIn, + opts ...aix.AgentFlowOption[State], +) *aix.AgentFlow[any, State] { + return aix.DefinePromptAgent(g.reg, promptName, defaultInput, opts...) +} + // Run executes the given function `fn` within the context of the current flow run, // creating a distinct trace span for this step. It's used to add observability // to specific sub-operations within a flow defined by [DefineFlow] or [DefineStreamingFlow]. @@ -787,7 +941,7 @@ func LookupTool(g *Genkit, name string) ai.Tool { // // Define the prompt // capitalPrompt := genkit.DefinePrompt(g, "findCapital", // ai.WithDescription("Finds the capital of a country."), -// ai.WithModelName("googleai/gemini-2.5-flash"), +// ai.WithModelName("googleai/gemini-3-flash-preview"), // ai.WithSystem("You are a helpful geography assistant."), // ai.WithPrompt("What is the capital of {{country}}?"), // ai.WithInputType(GeoInput{Country: "USA"}), @@ -896,7 +1050,7 @@ func DefineSchemaFor[T any](g *Genkit) { // } // // capitalPrompt := genkit.DefineDataPrompt[GeoInput, GeoOutput](g, "findCapital", -// ai.WithModelName("googleai/gemini-2.5-flash"), +// ai.WithModelName("googleai/gemini-3-flash-preview"), // ai.WithSystem("You are a helpful geography assistant."), // ai.WithPrompt("What is the capital of {{country}}?"), // ) @@ -953,7 +1107,7 @@ func GenerateWithRequest(ctx context.Context, g *Genkit, actionOpts *ai.Generate // // Model and Configuration: // - [ai.WithModel]: Specify the model (accepts [ai.Model] or [ai.ModelRef]) -// - [ai.WithModelName]: Specify model by name string (e.g., "googleai/gemini-2.5-flash") +// - [ai.WithModelName]: Specify model by name string (e.g., "googleai/gemini-3-flash-preview") // - [ai.WithConfig]: Set generation parameters (temperature, max tokens, etc.) // // Prompting: @@ -991,7 +1145,7 @@ func GenerateWithRequest(ctx context.Context, g *Genkit, actionOpts *ai.Generate // Example: // // resp, err := genkit.Generate(ctx, g, -// ai.WithModelName("googleai/gemini-2.5-flash"), +// ai.WithModelName("googleai/gemini-3-flash-preview"), // ai.WithPrompt("Write a short poem about clouds."), // ) // if err != nil { @@ -1478,7 +1632,7 @@ func LoadPrompt(g *Genkit, path, namespace string) ai.Prompt { // Example: // // promptSource := `--- -// model: googleai/gemini-2.5-flash +// model: googleai/gemini-3-flash-preview // input: // schema: // name: string diff --git a/go/internal/base/json.go b/go/internal/base/json.go index dff45c260a..a52ba3491b 100644 --- a/go/internal/base/json.go +++ b/go/internal/base/json.go @@ -118,6 +118,24 @@ func InferJSONSchema(x any) (s *jsonschema.Schema) { return s } +// ConvertTo attempts to convert a value to type T. It tries a direct type +// assertion first, then falls back to a JSON round-trip for values that were +// deserialized from JSON (e.g., map[string]any instead of a concrete struct). +func ConvertTo[T any](v any) (T, bool) { + if typed, ok := v.(T); ok { + return typed, true + } + var result T + data, err := json.Marshal(v) + if err != nil { + return result, false + } + if err := json.Unmarshal(data, &result); err != nil { + return result, false + } + return result, true +} + // MapToStruct converts a map[string]any to a struct of type T via JSON round-trip. func MapToStruct[T any](m map[string]any) (T, error) { var result T diff --git a/go/internal/cmd/jsonschemagen/jsonschemagen.go b/go/internal/cmd/jsonschemagen/jsonschemagen.go index 6ed2d5a8e7..c1c9a3fb5e 100644 --- a/go/internal/cmd/jsonschemagen/jsonschemagen.go +++ b/go/internal/cmd/jsonschemagen/jsonschemagen.go @@ -131,9 +131,14 @@ func run(infile, defaultPkgPath, configFile, outRoot string) error { // Generate code by package. for pkgPath, schemaMap := range schemasByPackage { + // Derive package name from path, with config override. + pkgName := path.Base(pkgPath) + if pc := cfg.configFor(pkgPath); pc != nil && pc.name != "" { + pkgName = pc.name + } // Generate code for each type in the package. gen := &generator{ - pkgName: path.Base(pkgPath), + pkgName: pkgName, schemas: schemaMap, cfg: cfg, } @@ -292,7 +297,15 @@ func (g *generator) generate() ([]byte, error) { g.pr("package %s\n\n", g.pkgName) if pc := g.cfg.configFor(g.pkgName); pc != nil { - g.pr("import %q\n", pc.pkgPath) + if len(pc.imports) > 0 { + g.pr("import (\n") + for _, imp := range pc.imports { + g.pr(" %q\n", imp) + } + g.pr(")\n\n") + } else if pc.pkgPath != "" { + g.pr("import %q\n", pc.pkgPath) + } } // Sort the names so the output is deterministic. @@ -386,7 +399,7 @@ func (g *generator) generateStruct(name string, s *Schema, tcfg *itemConfig) err if goName == "" { goName = adjustIdentifier(name) } - g.pr("type %s struct {\n", goName) + g.pr("type %s%s struct {\n", goName, tcfg.typeparams) for _, field := range sortedKeys(s.Properties) { fcfg := g.cfg.configFor(name + "." + field) if fcfg == nil { @@ -416,7 +429,7 @@ func (g *generator) generateStruct(name string, s *Schema, tcfg *itemConfig) err g.generateDoc(fs, fcfg) jsonTag := fmt.Sprintf(`json:"%s,omitempty"`, field) - if skipOmitEmpty(goName, field) { + if skipOmitEmpty(goName, field) || fcfg.noOmitEmpty { jsonTag = fmt.Sprintf(`json:"%s"`, field) } g.pr(fmt.Sprintf(" %s %s `%s`\n", adjustIdentifier(field), typeExpr, jsonTag)) @@ -603,12 +616,15 @@ func (c config) configFor(name string) *itemConfig { // itemConfig is configuration for one item: a type, a field or a package. // Not all itemConfig fields apply to both, but using one type simplifies the parser. type itemConfig struct { - omit bool - name string - pkgPath string - typeExpr string - docLines []string - fields []extraField + omit bool + name string + pkgPath string + typeExpr string + docLines []string + fields []extraField + typeparams string // Go type parameters (e.g., "[State any]") + noOmitEmpty bool // omit the omitempty tag for this field + imports []string // import paths for the package } // extraField represents an additional unexported field to add to a struct. @@ -636,7 +652,11 @@ type extraField struct { // pkg // package path, relative to outdir (last component is package name) // import -// path of package to import (for packages only) +// path of package to import (for packages only, may be repeated) +// typeparams PARAMS +// Go type parameters to add to the type declaration (e.g., "[State any]") +// noomitempty +// don't add omitempty to this field's json tag // field NAME TYPE // add an unexported field to the struct (for types only) func parseConfigFile(filename string) (config, error) { @@ -703,7 +723,14 @@ func parseConfigFile(filename string) (config, error) { if len(words) < 3 { return errf("need NAME import PATH") } - ic.pkgPath = words[2] + ic.imports = append(ic.imports, words[2]) + case "typeparams": + if len(words) < 3 { + return errf("need NAME typeparams PARAMS") + } + ic.typeparams = strings.Join(words[2:], " ") + case "noomitempty": + ic.noOmitEmpty = true case "field": if len(words) < 4 { return errf("need NAME field FIELDNAME TYPE") diff --git a/go/plugins/googlegenai/googleai_live_test.go b/go/plugins/googlegenai/googleai_live_test.go index 783eccd239..f65b42ff56 100644 --- a/go/plugins/googlegenai/googleai_live_test.go +++ b/go/plugins/googlegenai/googleai_live_test.go @@ -70,7 +70,7 @@ func TestGoogleAILive(t *testing.T) { genkit.WithPlugins(&googlegenai.GoogleAI{APIKey: apiKey}), ) - embedder := googlegenai.GoogleAIEmbedder(g, "embedding-001") + embedder := googlegenai.GoogleAIEmbedder(g, "gemini-embedding-001") gablorkenTool := genkit.DefineTool(g, "gablorken", "use this tool when the user asks to calculate a gablorken, carefuly inspect the user input to determine which value from the prompt corresponds to the input structure", func(ctx *ai.ToolContext, input struct { diff --git a/go/samples/custom-agent/main.go b/go/samples/custom-agent/main.go new file mode 100644 index 0000000000..a092e88ad2 --- /dev/null +++ b/go/samples/custom-agent/main.go @@ -0,0 +1,120 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This sample demonstrates the AgentFlow API for multi-turn conversation +// with token-level streaming. It runs a CLI REPL where conversation history +// is managed automatically by the session. +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "google.golang.org/genai" +) + +func main() { + ctx := context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + chatFlow := genkit.DefineCustomAgent(g, "chat", + func(ctx context.Context, resp aix.Responder[any], sess *aix.AgentSession[any]) (*aix.AgentFlowResult, error) { + if err := sess.Run(ctx, func(ctx context.Context, input *aix.AgentFlowInput) error { + for chunk, err := range genkit.GenerateStream(ctx, g, + ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithSystem("You are a helpful assistant. Keep responses concise."), + ai.WithMessages(sess.Messages()...), + ) { + if err != nil { + return err + } + if chunk.Done { + sess.AddMessages(chunk.Response.Message) + break + } + resp.SendModelChunk(chunk.Chunk) + } + + return nil + }); err != nil { + return nil, err + } + return sess.Result(), nil + }, + aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), + aix.WithSnapshotOn[any](aix.SnapshotEventTurnEnd), + ) + + fmt.Println("Agent Flow Chat (type 'quit' to exit)") + fmt.Println() + + conn, err := chatFlow.StreamBidi(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + reader := bufio.NewReader(os.Stdin) + for { + fmt.Print("> ") + input, _ := reader.ReadString('\n') + input = strings.TrimSpace(input) + + if input == "quit" || input == "exit" { + break + } + if input == "" { + continue + } + + if err := conn.SendText(input); err != nil { + fmt.Fprintf(os.Stderr, "Send error: %v\n", err) + break + } + + fmt.Println() + + for chunk, err := range conn.Receive() { + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + break + } + if chunk.ModelChunk != nil { + fmt.Print(chunk.ModelChunk.Text()) + } + if chunk.SnapshotID != "" { + fmt.Printf("\n[snapshot: %s]", chunk.SnapshotID) + } + if chunk.EndTurn { + fmt.Println() + fmt.Println() + break + } + } + } + + conn.Close() + fmt.Println(conn.Output()) +} diff --git a/go/samples/prompt-agent/main.go b/go/samples/prompt-agent/main.go new file mode 100644 index 0000000000..0f73c52c82 --- /dev/null +++ b/go/samples/prompt-agent/main.go @@ -0,0 +1,99 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This sample demonstrates DefinePromptAgent, which creates a +// multi-turn conversational agent flow backed by a .prompt file. The +// conversation loop (render prompt, call model, stream chunks, update history) +// is handled automatically. Compare with custom-agent which wires +// the same loop manually. +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + + aix "github.com/firebase/genkit/go/ai/exp" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" +) + +type ChatPromptInput struct { + Personality string `json:"personality"` +} + +func main() { + ctx := context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + chatFlow := genkit.DefinePromptAgent( + g, "chat", ChatPromptInput{Personality: "a sarcastic pirate"}, + aix.WithSessionStore(aix.NewInMemorySessionStore[any]()), + aix.WithSnapshotCallback(func(ctx context.Context, sc *aix.SnapshotContext[any]) bool { + return sc.Event == aix.SnapshotEventInvocationEnd || sc.TurnIndex%5 == 0 + }), + ) + + fmt.Println("Prompt Agent Chat (type 'quit' to exit)") + fmt.Println() + + conn, err := chatFlow.StreamBidi(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + reader := bufio.NewReader(os.Stdin) + for { + fmt.Print("> ") + input, _ := reader.ReadString('\n') + input = strings.TrimSpace(input) + + if input == "quit" || input == "exit" { + break + } + if input == "" { + continue + } + + if err := conn.SendText(input); err != nil { + fmt.Fprintf(os.Stderr, "Send error: %v\n", err) + break + } + + fmt.Println() + + for chunk, err := range conn.Receive() { + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + break + } + if chunk.ModelChunk != nil { + fmt.Print(chunk.ModelChunk.Text()) + } + if chunk.SnapshotID != "" { + fmt.Printf("\n[snapshot: %s]", chunk.SnapshotID) + } + if chunk.EndTurn { + fmt.Println() + fmt.Println() + break + } + } + } + + conn.Close() +} diff --git a/go/samples/prompt-agent/prompts/chat.prompt b/go/samples/prompt-agent/prompts/chat.prompt new file mode 100644 index 0000000000..6a78a99b07 --- /dev/null +++ b/go/samples/prompt-agent/prompts/chat.prompt @@ -0,0 +1,12 @@ +--- +model: googleai/gemini-3-flash-preview +config: + thinkingConfig: + thinkingBudget: 0 +input: + schema: + personality: string + default: + personality: a helpful assistant +--- +You are {{personality}}. Keep responses concise. diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index 05d623c07d..37464bce79 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -39,6 +39,13 @@ class Model(RootModel[Any]): root: Any +class SnapshotEvent(StrEnum): + """SnapshotEvent data type class.""" + + TURN_END = 'turnEnd' + INVOCATION_END = 'invocationEnd' + + class Embedding(BaseModel): """Model for embedding data.""" @@ -880,6 +887,15 @@ class Content(RootModel[list[Part]]): root: list[Part] +class Artifact(BaseModel): + """Model for artifact data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + name: str | None = None + parts: list[Part] + metadata: dict[str, Any] | None = None + + class DocumentData(BaseModel): """Model for documentdata data.""" @@ -971,6 +987,43 @@ class Messages(RootModel[list[Message]]): root: list[Message] +class AgentFlowInput(BaseModel): + """Model for agentflowinput data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + messages: list[Message] | None = None + tool_restarts: list[Part] | None = Field(default=None) + + +class AgentFlowResult(BaseModel): + """Model for agentflowresult data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + message: Message | None = None + artifacts: list[Artifact] | None = None + + +class AgentFlowStreamChunk(BaseModel): + """Model for agentflowstreamchunk data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + model_chunk: ModelResponseChunk | None = Field(default=None) + status: Any | None = None + artifact: Artifact | None = None + snapshot_id: str | None = Field(default=None) + end_turn: bool | None = Field(default=None) + + +class SessionState(BaseModel): + """Model for sessionstate data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + messages: list[Message] | None = None + custom: Any | None = None + artifacts: list[Artifact] | None = None + input_variables: Any | None = Field(default=None) + + class Candidate(BaseModel): """Model for candidate data.""" @@ -1048,6 +1101,24 @@ class Request(RootModel[GenerateRequest]): root: GenerateRequest +class AgentFlowInit(BaseModel): + """Model for agentflowinit data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + snapshot_id: str | None = Field(default=None) + state: SessionState | None = None + + +class AgentFlowOutput(BaseModel): + """Model for agentflowoutput data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + snapshot_id: str | None = Field(default=None) + state: SessionState | None = None + message: Message | None = None + artifacts: list[Artifact] | None = None + + class ModelResponse(BaseModel): """Model for modelresponse data."""