diff --git a/go/ai/background_model.go b/go/ai/background_model.go index e784e7b092..b563ff6743 100644 --- a/go/ai/background_model.go +++ b/go/ai/background_model.go @@ -52,10 +52,10 @@ type ModelOperation = core.Operation[*ModelResponse] // StartModelOpFunc starts a background model operation. type StartModelOpFunc = func(ctx context.Context, req *ModelRequest) (*ModelOperation, error) -// CheckOperationFunc checks the status of a background model operation. +// CheckModelOpFunc checks the status of a background model operation. type CheckModelOpFunc = func(ctx context.Context, op *ModelOperation) (*ModelOperation, error) -// CancelOperationFunc cancels a background model operation. +// CancelModelOpFunc cancels a background model operation. type CancelModelOpFunc = func(ctx context.Context, op *ModelOperation) (*ModelOperation, error) // BackgroundModelOptions holds configuration for defining a background model diff --git a/go/ai/gen.go b/go/ai/gen.go index e391ef2215..3c91ab50c3 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -334,13 +334,13 @@ type MultipartToolResponse struct { // Operation represents a long-running background task. type Operation struct { // Action is the name of the action being performed by this operation. - Action string `json:"action,omitempty"` + Action string `json:"action"` // Done indicates whether the operation has completed. - Done bool `json:"done,omitempty"` + Done bool `json:"done"` // Error contains error information if the operation failed. Error *OperationError `json:"error,omitempty"` // Id is the unique identifier for this operation. - Id string `json:"id,omitempty"` + Id string `json:"id"` // Metadata contains additional information about the operation. Metadata map[string]any `json:"metadata,omitempty"` // Output contains the result of the operation if it has completed successfully. 
diff --git a/go/ai/generate.go b/go/ai/generate.go index de1cc1276b..7cc240d6ea 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -284,7 +284,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi // Native constrained output is enabled only when the user has // requested it, the model supports it, and there's a JSON schema. outputCfg.Constrained = opts.Output.JsonSchema != nil && - opts.Output.Constrained && outputCfg.Constrained && m.(*model).supportsConstrained(len(toolDefs) > 0) + opts.Output.Constrained && outputCfg.Constrained && m != nil && m.(*model).supportsConstrained(len(toolDefs) > 0) // Add schema instructions to prompt when not using native constraints. // This is a no-op for unstructured output requests. @@ -313,12 +313,14 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi Output: &outputCfg, } - fn := m.Generate + var fn ModelFunc if bm != nil { if cb != nil { logger.FromContext(ctx).Warn("background model does not support streaming", "model", bm.Name()) } fn = backgroundModelToModelFn(bm.Start) + } else { + fn = m.Generate } fn = core.ChainMiddleware(mw...)(fn) diff --git a/go/core/background_action.go b/go/core/background_action.go index e6af50399b..578de9e5bd 100644 --- a/go/core/background_action.go +++ b/go/core/background_action.go @@ -33,12 +33,12 @@ type CancelOpFunc[Out any] = func(ctx context.Context, op *Operation[Out]) (*Ope // Operation represents a long-running operation started by a background action. type Operation[Out any] struct { - Action string // Key of the action that created this operation. - ID string // ID of the operation. - Done bool // Whether the operation is complete. - Output Out // Result when done. - Error error // Error if the operation failed. - Metadata map[string]any // Additional metadata. + Action string `json:"action"` // Key of the action that created this operation. + ID string `json:"id"` // ID of the operation. 
+ Done bool `json:"done"` // Whether the operation is complete. + Output Out `json:"output,omitempty"` // Result when done. + Error error `json:"error,omitempty"` // Error if the operation failed. + Metadata map[string]any `json:"metadata,omitempty"` // Additional metadata. } // BackgroundActionDef is a background action that can be used to start, check, and cancel background operations. diff --git a/go/internal/cmd/jsonschemagen/jsonschemagen.go b/go/internal/cmd/jsonschemagen/jsonschemagen.go index 6ed2d5a8e7..bf3d67e152 100644 --- a/go/internal/cmd/jsonschemagen/jsonschemagen.go +++ b/go/internal/cmd/jsonschemagen/jsonschemagen.go @@ -49,6 +49,11 @@ var ( "ModelResponseChunk": { "index": {}, // fields should be as defined in core/schemas.config }, + "Operation": { + "action": {}, + "done": {}, + "id": {}, + }, } ) diff --git a/go/internal/cmd/jsonschemagen/testdata/golden b/go/internal/cmd/jsonschemagen/testdata/golden index d5e5c0e1ab..06ae7570d2 100644 --- a/go/internal/cmd/jsonschemagen/testdata/golden +++ b/go/internal/cmd/jsonschemagen/testdata/golden @@ -181,7 +181,7 @@ type Message struct { type Operation struct { BlockedOnStep *OperationBlockedOnStep `json:"blockedOnStep,omitempty"` // If the value is false, it means the operation is still in progress. If true, the operation is completed, and either error or response is available. - Done bool `json:"done,omitempty"` + Done bool `json:"done"` // Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Metadata any `json:"metadata,omitempty"` // server-assigned name, which is only unique within the same service that originally returns it. 
diff --git a/go/plugins/googlegenai/README.md b/go/plugins/googlegenai/README.md new file mode 100644 index 0000000000..95171829ff --- /dev/null +++ b/go/plugins/googlegenai/README.md @@ -0,0 +1,577 @@ +# Google Generative AI Plugin + +The Google AI plugin provides a unified interface to connect with Google's generative AI models through the **Gemini Developer API** or **Vertex AI** using API key authentication or Google Cloud credentials. + +The plugin supports a wide range of capabilities: + +- **Language Models**: Gemini models for text generation, reasoning, and multimodal tasks +- **Embedding Models**: Text and multimodal embeddings +- **Image Models**: Imagen for generation and Gemini for image analysis +- **Video Models**: Veo for video generation and Gemini for video understanding +- **Speech Models**: Polyglot text-to-speech generation + +## Setup + +### Installation + +```bash +go get github.com/firebase/genkit/go/plugins/googlegenai +``` + +### Configuration + +You can use either the Google AI (Gemini API) or Vertex AI backend. 
+ +**Using Google AI (Gemini API):** + +```go +import ( + "context" + "log" + + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" +) + +func main() { + ctx := context.Background() + + g := genkit.Init(ctx, + genkit.WithPlugins(&googlegenai.GoogleAI{ + APIKey: "your-api-key", // Optional: defaults to GEMINI_API_KEY or GOOGLE_API_KEY env var + }), + ) +} +``` + +**Using Vertex AI:** + +```go +import ( + "context" + "log" + + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" +) + +func main() { + ctx := context.Background() + + g := genkit.Init(ctx, + genkit.WithPlugins(&googlegenai.VertexAI{ + ProjectID: "your-project-id", // Optional: defaults to GOOGLE_CLOUD_PROJECT + Location: "us-central1", // Optional: defaults to GOOGLE_CLOUD_LOCATION + }), + ) +} +``` + +### Authentication + +**Google AI**: Requires a Gemini API Key, which you can get from [Google AI Studio](https://aistudio.google.com/apikey). Set the `GEMINI_API_KEY` environment variable or pass it to the plugin configuration. + +**Vertex AI**: Requires Google Cloud credentials. Set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to your service account key file path, or use default credentials (e.g., `gcloud auth application-default login`). + +## Language Models + +You can create models that call the Google Generative AI API. The models support tool calls and some have multi-modal capabilities. + +### Available Models + +Genkit automatically discovers available models supported by the [Go GenAI SDK](https://github.com/google/go-genai). This ensures that recently released models are available immediately as they are added to the SDK, while deprecated models are automatically ignored and hidden from the list of actions. 
+ +Commonly used models include: + +- **Gemini Series**: `gemini-3-pro-preview`, `gemini-3-flash-preview`, `gemini-2.5-flash`, `gemini-2.5-pro` +- **Imagen Series**: `imagen-3.0-generate-001` +- **Veo Series**: `veo-3.0-generate-001` + +:::note +You can use any model ID supported by the underlying SDK. For a complete and up-to-date list of models and their specific capabilities, refer to the [Google Generative AI models documentation](https://ai.google.dev/gemini-api/docs/models). +::: + +### Basic Usage + +```go +import ( + "context" + "fmt" + "log" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" +) + +func main() { + // ... Init genkit with googlegenai plugin ... + + resp, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithPrompt("Explain how neural networks learn in simple terms."), + ) + if err != nil { + log.Fatal(err) + } + + fmt.Println(resp.Text()) +} +``` + +### Structured Output + +Gemini models support structured output generation, which guarantees that the model output will conform to a specified schema. Genkit Go provides type-safe generics to make this easy. 
+ +**Using `GenerateData` (Recommended):** + +```go +type Character struct { + Name string `json:"name"` + Bio string `json:"bio"` + Age int `json:"age"` +} + +// Automatically infers schema from the struct and unmarshals the result +char, resp, err := genkit.GenerateData[Character](ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithPrompt("Generate a profile for a fictional character"), +) +if err != nil { + log.Fatal(err) +} + +fmt.Printf("Name: %s, Age: %d\n", char.Name, char.Age) +``` + +**Using `Generate` (Standard):** + +You can also use the standard `Generate` function and unmarshal manually: + +```go +resp, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithPrompt("Generate a profile for a fictional character"), + ai.WithOutputType(Character{}), +) +if err != nil { + log.Fatal(err) +} + +var char Character +if err := resp.Output(&char); err != nil { + log.Fatal(err) +} +``` + +#### Schema Limitations + +The Gemini API relies on a specific subset of the OpenAPI 3.0 standard. When defining schemas (Go structs), keep the following limitations in mind: + +- **Validation**: Keywords like `pattern`, `minLength`, `maxLength` are **not supported** by the API's constrained decoding. +- **Unions**: Complex unions are often problematic. +- **Recursion**: Recursive schemas are generally not supported. + +### Thinking and Reasoning + +Gemini 2.5 and newer models use an internal thinking process that improves reasoning for complex tasks. 
+ +**Thinking Budget:** + +```go +import "google.golang.org/genai" + +resp, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithPrompt("what is heavier, one kilo of steel or one kilo of feathers"), + ai.WithConfig(&genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](1024), // Number of thinking tokens + IncludeThoughts: true, // Include thought summaries + }, + }), +) +``` + +### Context Caching + +Gemini 2.5 and newer models automatically cache common content prefixes. In Genkit Go, you can mark content for caching using `WithCacheTTL` or `WithCacheName`. + +```go +// Create a message with cached content +cachedMsg := ai.NewUserTextMessage(largeContent).WithCacheTTL(300) + +// First request - content will be cached +resp1, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithMessages(cachedMsg), + ai.WithPrompt("Task 1..."), +) + +// Second request with same prefix - eligible for cache hit +resp2, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash"), + // Reuse the history from previous response or construct messages with same prefix + ai.WithMessages(resp1.History()...), + ai.WithPrompt("Task 2..."), +) +``` + +### Safety Settings + +You can configure safety settings to control content filtering: + +```go +import "google.golang.org/genai" + +resp, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithPrompt("Your prompt here"), + ai.WithConfig(&genai.GenerateContentConfig{ + SafetySettings: []*genai.SafetySetting{ + { + Category: genai.HarmCategoryHateSpeech, + Threshold: genai.HarmBlockThresholdBlockLowAndAbove, + }, + { + Category: genai.HarmCategoryDangerousContent, + Threshold: genai.HarmBlockThresholdBlockMediumAndAbove, + }, + }, + }), +) +``` + +### Google Search Grounding + +Enable Google Search to provide answers with current information and verifiable sources. 
+ +```go +import "google.golang.org/genai" + +resp, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithPrompt("What are the top tech news stories this week?"), + ai.WithConfig(&genai.GenerateContentConfig{ + Tools: []*genai.Tool{ + { + GoogleSearch: &genai.GoogleSearch{}, + }, + }, + }), +) +``` + +### Google Maps Grounding + +Enable Google Maps to provide location-aware responses. + +```go +import "google.golang.org/genai" + +resp, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithPrompt("Find coffee shops near Times Square"), + ai.WithConfig(&genai.GenerateContentConfig{ + Tools: []*genai.Tool{ + { + GoogleMaps: &genai.GoogleMaps{ + EnableWidget: genai.Ptr(true), + }, + }, + }, + ToolConfig: &genai.ToolConfig{ + RetrievalConfig: &genai.RetrievalConfig{ + LatLng: &genai.LatLng{ + Latitude: genai.Ptr(37.7749), + Longitude: genai.Ptr(-122.4194), + }, + }, + }, + }), +) + +// Access grounding metadata (e.g., for map widget) +if custom, ok := resp.Custom["candidates"].([]*genai.Candidate); ok { + for _, cand := range custom { + if cand.GroundingMetadata != nil && cand.GroundingMetadata.GoogleMapsWidgetContextToken != "" { + fmt.Printf("Map Widget Token: %s\n", cand.GroundingMetadata.GoogleMapsWidgetContextToken) + } + } +} +``` + +### Code Execution + +Enable the model to write and execute Python code for calculations and logic. + +```go +import "google.golang.org/genai" + +resp, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/gemini-2.5-pro"), + ai.WithPrompt("Calculate the 20th Fibonacci number"), + ai.WithConfig(&genai.GenerateContentConfig{ + Tools: []*genai.Tool{ + { + CodeExecution: &genai.ToolCodeExecution{}, + }, + }, + }), +) +``` + +### Generating Text and Images + +Some Gemini models (like `gemini-2.5-flash-image`) can output images natively alongside text. 
+ +```go +import "google.golang.org/genai" + +resp, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash-image"), + ai.WithPrompt("Create a picture of a futuristic city and describe it"), + ai.WithConfig(&genai.GenerateContentConfig{ + ResponseModalities: []string{"IMAGE", "TEXT"}, + }), +) + +for _, part := range resp.Message.Content { + if part.IsMedia() { + fmt.Printf("Generated image: %s\n", part.ContentType) + // Access data via part.Text (data URI) or helper functions + } +} +``` + +### Multimodal Input Capabilities + +Genkit supports multimodal input (text, image, video, audio) via `ai.Part`. + +**Video/Image/Audio/PDF Input:** + +```go +// Using a URL +videoPart := ai.NewMediaPart("video/mp4", "https://example.com/video.mp4") + +// Using inline data (base64) +imagePart := ai.NewMediaPart("image/jpeg", "data:image/jpeg;base64,...") + +resp, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithMessages( + ai.NewUserMessage( + ai.NewTextPart("Describe this content"), + videoPart, + ), + ), +) +``` + +## Embedding Models + +### Available Models + +- `text-embedding-004` +- `gemini-embedding-001` +- `multimodalembedding` + +### Usage + +```go +res, err := genkit.Embed(ctx, g, + ai.WithEmbedderName("googleai/gemini-embedding-001"), + ai.WithTextDocs("Machine learning models process data to make predictions."), +) +if err != nil { + log.Fatal(err) +} + +fmt.Printf("Embedding: %v\n", res.Embeddings[0].Embedding) +``` + +## Image Models + +### Available Models + +**Imagen 3 Series**: + +- `imagen-3.0-generate-001` +- `imagen-3.0-fast-generate-001` + +### Usage + +```go +import "google.golang.org/genai" + +resp, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/imagen-3.0-generate-001"), + ai.WithPrompt("A serene Japanese garden with cherry blossoms"), + ai.WithConfig(&genai.GenerateImagesConfig{ + NumberOfImages: 4, + AspectRatio: "16:9", + PersonGeneration: "allow_adult", + }), +) + +// Access 
generated images in resp.Message.Content +``` + +## Video Models + +The Google AI plugin provides access to video generation capabilities through the Veo models. + +### Available Models + +**Veo 3.1 Series**: + +- `veo-3.1-generate-preview` + +**Veo 3.0 Series**: + +- `veo-3.0-generate-001` +- `veo-3.0-fast-generate-001` + +**Veo 2.0 Series**: + +- `veo-2.0-generate-001` + +### Usage + +Veo operations are long-running and support multiple generation modes. + +#### Backend-Specific Considerations + +The output format and behavior of Veo differ depending on whether you are using the **Google AI** or **Vertex AI** backend. + +##### Model Names + +Ensure you use the correct provider prefix: +- **Google AI**: `googleai/veo-3.1-generate-preview` +- **Vertex AI**: `vertexai/veo-3.1-generate-preview` + +##### Output Format (Video URLs vs. Raw Bytes) + +Depending on the backend and configuration, the generated video will be returned as either a remote URI or as raw bytes encoded in a base64 data URI. + +- **Google AI**: Typically returns a public URI for the video. To download it via HTTP, you must append your API key to the URL: `https://.../video.mp4?key=YOUR_API_KEY`. +- **Vertex AI**: Can return a Cloud Storage URI (`gs://...`) if configured, but by default often returns **raw video bytes**. The Genkit plugin automatically encodes these raw bytes as a **base64 data URI** in the message's text field. + +Your application should be prepared to handle both formats. 
For example, to save the output directly to a file: + +```go +for _, part := range op.Output.Message.Content { + if part.IsMedia() { + if strings.HasPrefix(part.Text, "data:video/mp4;base64,") { + // Handle base64 encoded bytes (Common for Vertex AI default) + data := strings.TrimPrefix(part.Text, "data:video/mp4;base64,") + b, _ := base64.StdEncoding.DecodeString(data) + os.WriteFile("video.mp4", b, 0644) + } else { + // Handle remote URI (Common for Google AI or Vertex AI with GCS) + // You would typically use an HTTP client or Google Cloud Storage client here + fmt.Printf("Video available at URI: %s\n", part.Text) + } + } +} +``` + +##### Safety Filtering (RAI) + +Veo has strict safety policies. If a prompt triggers a safety filter, the operation will complete but return no video. In this case: + +1. `FinishReason` will be `ai.FinishReasonBlocked`. +2. The output message will contain a text part listing the specific reasons the content was filtered. +3. The original API response (including RAI counts) is available in the `Raw` field. + +#### Text-to-Video + +Generate a video from a text description. + +```go +op, err := genkit.GenerateOperation(ctx, g, + ai.WithModelName("googleai/veo-3.1-generate-preview"), + ai.WithMessages(ai.NewUserTextMessage("A majestic dragon soaring over a mystical forest at dawn.")), + ai.WithConfig(&genai.GenerateVideosConfig{ + AspectRatio: "16:9", + DurationSeconds: genai.Ptr(int32(8)), + Resolution: "720p", + }), +) +if err != nil { + log.Fatal(err) +} + +// Poll for completion +op, err = genkit.CheckModelOperation(ctx, g, op) +``` + +#### Image-to-Video + +Animate a static image using a text prompt. 
+ +```go +// Load image data (e.g., base64 encoded) +imagePart := ai.NewMediaPart("image/jpeg", "data:image/jpeg;base64,...") + +op, err := genkit.GenerateOperation(ctx, g, + ai.WithModelName("googleai/veo-3.1-generate-preview"), + ai.WithMessages(ai.NewUserMessage( + ai.NewTextPart("The cat wakes up and starts accelerating the go-kart."), + imagePart, + )), + ai.WithConfig(&genai.GenerateVideosConfig{ + AspectRatio: "16:9", + }), +) +``` + +#### Video-to-Video (Video Editing) + +Edit or transform an existing video. + +:::note +Video-to-video generation requires a **Veo video URL** (a URL generated by a previous Veo model operation). Arbitrary external video URLs or files are not currently supported for this mode. +::: + +```go +// Provide the URI of a Veo-generated video to edit +videoPart := ai.NewMediaPart("video/mp4", "https://generativelanguage.googleapis.com/...") + +op, err := genkit.GenerateOperation(ctx, g, + ai.WithModelName("googleai/veo-3.1-generate-preview"), + ai.WithMessages(ai.NewUserMessage( + ai.NewTextPart("Change the video style to be a cartoon from 1950."), + videoPart, + )), + ai.WithConfig(&genai.GenerateVideosConfig{ + AspectRatio: "16:9", + }), +) +``` + +## Speech Models + +Use `gemini-2.5-flash` or `gemini-2.5-pro` with audio output modality. 
+ +### Usage + +```go +import "google.golang.org/genai" + +resp, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithPrompt("Say that Genkit is an amazing AI framework"), + ai.WithConfig(&genai.GenerateContentConfig{ + ResponseModalities: []string{"AUDIO"}, + SpeechConfig: &genai.SpeechConfig{ + VoiceConfig: &genai.VoiceConfig{ + PrebuiltVoiceConfig: &genai.PrebuiltVoiceConfig{ + VoiceName: "Algenib", + }, + }, + }, + }), +) + +// The audio data will be in resp.Message.Content as a media part +``` diff --git a/go/plugins/googlegenai/actions.go b/go/plugins/googlegenai/actions.go new file mode 100644 index 0000000000..8a6dd051b6 --- /dev/null +++ b/go/plugins/googlegenai/actions.go @@ -0,0 +1,150 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package googlegenai + +import ( + "context" + "fmt" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" + "google.golang.org/genai" +) + +// ListActions lists all the actions supported by the Google AI plugin. +func (ga *GoogleAI) ListActions(ctx context.Context) []api.ActionDesc { + return listActions(ctx, ga.gclient, googleAIProvider) +} + +// ListActions lists all the actions supported by the Vertex AI plugin. +func (v *VertexAI) ListActions(ctx context.Context) []api.ActionDesc { + return listActions(ctx, v.gclient, vertexAIProvider) +} + +// listActions is the shared implementation for listing actions. 
+func listActions(ctx context.Context, client *genai.Client, provider string) []api.ActionDesc { + models, err := listGenaiModels(ctx, client) + if err != nil { + return nil + } + + actions := []api.ActionDesc{} + + // Gemini models + for _, name := range models.gemini { + opts := GetModelOptions(name, provider) + model := newModel(client, name, opts) + if actionDef, ok := model.(api.Action); ok { + actions = append(actions, actionDef.Desc()) + } + } + + // Imagen models + for _, name := range models.imagen { + opts := GetModelOptions(name, provider) + model := newModel(client, name, opts) + if actionDef, ok := model.(api.Action); ok { + actions = append(actions, actionDef.Desc()) + } + } + + // Veo models (background models) + for _, name := range models.veo { + opts := GetModelOptions(name, provider) + veoModel := newVeoModel(client, name, opts) + if actionDef, ok := veoModel.(api.Action); ok { + actions = append(actions, actionDef.Desc()) + } + } + + // Embedders + for _, name := range models.embedders { + opts := GetEmbedderOptions(name, provider) + embedder := newEmbedder(client, name, &opts) + if actionDef, ok := embedder.(api.Action); ok { + actions = append(actions, actionDef.Desc()) + } + } + + return actions +} + +// ResolveAction resolves an action with the given name. +func (ga *GoogleAI) ResolveAction(atype api.ActionType, name string) api.Action { + return resolveAction(ga.gclient, googleAIProvider, atype, name) +} + +// ResolveAction resolves an action with the given name. +func (v *VertexAI) ResolveAction(atype api.ActionType, name string) api.Action { + return resolveAction(v.gclient, vertexAIProvider, atype, name) +} + +// resolveAction is the shared implementation for resolving actions. 
+func resolveAction(client *genai.Client, provider string, atype api.ActionType, name string) api.Action { + mt := ClassifyModel(name) + + switch atype { + case api.ActionTypeEmbedder: + opts := GetEmbedderOptions(name, provider) + return newEmbedder(client, name, &opts).(api.Action) + + case api.ActionTypeModel: + // Veo models should not be resolved as regular models + if mt == ModelTypeVeo { + return nil + } + opts := GetModelOptions(name, provider) + return newModel(client, name, opts).(api.Action) + + case api.ActionTypeBackgroundModel: + if mt != ModelTypeVeo { + return nil + } + return createVeoBackgroundAction(client, name, provider) + + case api.ActionTypeCheckOperation: + if mt != ModelTypeVeo { + return nil + } + return createVeoCheckAction(client, name, provider) + } + + return nil +} + +// createVeoBackgroundAction creates a background model action for Veo. +func createVeoBackgroundAction(client *genai.Client, name, provider string) api.Action { + opts := GetModelOptions(name, provider) + veoModel := newVeoModel(client, name, opts) + actionName := api.NewName(provider, name) + + return core.NewAction(actionName, api.ActionTypeBackgroundModel, nil, nil, + func(ctx context.Context, input *ai.ModelRequest) (*core.Operation[*ai.ModelResponse], error) { + op, err := veoModel.Start(ctx, input) + if err != nil { + return nil, err + } + op.Action = api.KeyFromName(api.ActionTypeBackgroundModel, actionName) + return op, nil + }) +} + +// createVeoCheckAction creates a check operation action for Veo. 
+func createVeoCheckAction(client *genai.Client, name, provider string) api.Action { + opts := GetModelOptions(name, provider) + veoModel := newVeoModel(client, name, opts) + actionName := api.NewName(provider, name) + + return core.NewAction(actionName, api.ActionTypeCheckOperation, + map[string]any{"description": fmt.Sprintf("Check status of %s operation", name)}, nil, + func(ctx context.Context, op *core.Operation[*ai.ModelResponse]) (*core.Operation[*ai.ModelResponse], error) { + updatedOp, err := veoModel.Check(ctx, op) + if err != nil { + return nil, err + } + updatedOp.Action = api.KeyFromName(api.ActionTypeBackgroundModel, actionName) + return updatedOp, nil + }) +} diff --git a/go/plugins/googlegenai/code_execution.go b/go/plugins/googlegenai/code_execution.go new file mode 100644 index 0000000000..71251fedb3 --- /dev/null +++ b/go/plugins/googlegenai/code_execution.go @@ -0,0 +1,121 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package googlegenai + +import ( + "github.com/firebase/genkit/go/ai" +) + +// CodeExecutionResult represents the result of a code execution. +type CodeExecutionResult struct { + Outcome string `json:"outcome"` + Output string `json:"output"` +} + +// ExecutableCode represents executable code. +type ExecutableCode struct { + Language string `json:"language"` + Code string `json:"code"` +} + +// newCodeExecutionResultPart returns a Part containing the result of code execution. +// This is internal and used by translateCandidate. +func newCodeExecutionResultPart(outcome string, output string) *ai.Part { + return ai.NewCustomPart(map[string]any{ + "codeExecutionResult": map[string]any{ + "outcome": outcome, + "output": output, + }, + }) +} + +// newExecutableCodePart returns a Part containing executable code. +// This is internal and used by translateCandidate. 
+func newExecutableCodePart(language string, code string) *ai.Part { + return ai.NewCustomPart(map[string]any{ + "executableCode": map[string]any{ + "language": language, + "code": code, + }, + }) +} + +// ToCodeExecutionResult tries to convert an ai.Part to a CodeExecutionResult. +// Returns nil if the part doesn't contain code execution results. +func ToCodeExecutionResult(part *ai.Part) *CodeExecutionResult { + if !part.IsCustom() { + return nil + } + + codeExec, ok := part.Custom["codeExecutionResult"] + if !ok { + return nil + } + + result, ok := codeExec.(map[string]any) + if !ok { + return nil + } + + outcome, _ := result["outcome"].(string) + output, _ := result["output"].(string) + + return &CodeExecutionResult{ + Outcome: outcome, + Output: output, + } +} + +// ToExecutableCode tries to convert an ai.Part to an ExecutableCode. +// Returns nil if the part doesn't contain executable code. +func ToExecutableCode(part *ai.Part) *ExecutableCode { + if !part.IsCustom() { + return nil + } + + execCode, ok := part.Custom["executableCode"] + if !ok { + return nil + } + + code, ok := execCode.(map[string]any) + if !ok { + return nil + } + + language, _ := code["language"].(string) + codeStr, _ := code["code"].(string) + + return &ExecutableCode{ + Language: language, + Code: codeStr, + } +} + +// HasCodeExecution checks if a message contains code execution results or executable code. +func HasCodeExecution(msg *ai.Message) bool { + return GetCodeExecutionResult(msg) != nil || GetExecutableCode(msg) != nil +} + +// GetExecutableCode returns the first executable code from a message. +// Returns nil if the message doesn't contain executable code. +func GetExecutableCode(msg *ai.Message) *ExecutableCode { + for _, part := range msg.Content { + if code := ToExecutableCode(part); code != nil { + return code + } + } + return nil +} + +// GetCodeExecutionResult returns the first code execution result from a message. 
+// Returns nil if the message doesn't contain a code execution result. +func GetCodeExecutionResult(msg *ai.Message) *CodeExecutionResult { + for _, part := range msg.Content { + if result := ToCodeExecutionResult(part); result != nil { + return result + } + } + return nil +} diff --git a/go/plugins/googlegenai/embedder.go b/go/plugins/googlegenai/embedder.go new file mode 100644 index 0000000000..759d5a7332 --- /dev/null +++ b/go/plugins/googlegenai/embedder.go @@ -0,0 +1,54 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package googlegenai + +import ( + "context" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" + "google.golang.org/genai" +) + +// newEmbedder creates an embedder without registering it. +func newEmbedder(client *genai.Client, name string, embedOpts *ai.EmbedderOptions) ai.Embedder { + provider := googleAIProvider + if client.ClientConfig().Backend == genai.BackendVertexAI { + provider = vertexAIProvider + } + + if embedOpts.ConfigSchema == nil { + embedOpts.ConfigSchema = core.InferSchemaMap(genai.EmbedContentConfig{}) + } + + return ai.NewEmbedder(api.NewName(provider, name), embedOpts, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { + var content []*genai.Content + var embedConfig *genai.EmbedContentConfig + + if config, ok := req.Options.(*genai.EmbedContentConfig); ok { + embedConfig = config + } + + for _, doc := range req.Input { + parts, err := toGeminiParts(doc.Content) + if err != nil { + return nil, err + } + content = append(content, &genai.Content{ + Parts: parts, + }) + } + + r, err := genai.Models.EmbedContent(*client.Models, ctx, name, content, embedConfig) + if err != nil { + return nil, err + } + var res ai.EmbedResponse + for _, emb := range r.Embeddings { + res.Embeddings = append(res.Embeddings, &ai.Embedding{Embedding: emb.Values}) + } + return &res, nil + }) +} diff --git 
a/go/plugins/googlegenai/gemini.go b/go/plugins/googlegenai/gemini.go index bfd9ef5410..05450034d8 100644 --- a/go/plugins/googlegenai/gemini.go +++ b/go/plugins/googlegenai/gemini.go @@ -19,15 +19,11 @@ package googlegenai import ( "context" "encoding/base64" - "encoding/json" "errors" "fmt" "net/http" "net/url" - "reflect" - "regexp" "slices" - "strconv" "strings" "github.com/firebase/genkit/go/ai" @@ -40,31 +36,7 @@ import ( "google.golang.org/genai" ) -const ( - // Tool name regex - toolNameRegex = "^[a-zA-Z_][a-zA-Z0-9_.-]{0,63}$" -) - var ( - // BasicText describes model capabilities for text-only Gemini models. - BasicText = ai.ModelSupports{ - Multiturn: true, - Tools: true, - ToolChoice: true, - SystemRole: true, - Media: false, - } - - // Multimodal describes model capabilities for multimodal Gemini models. - Multimodal = ai.ModelSupports{ - Multiturn: true, - Tools: true, - ToolChoice: true, - SystemRole: true, - Media: true, - Constrained: ai.ConstrainedSupportNoTools, - } - // Attribution header xGoogApiClientHeader = http.CanonicalHeaderKey("x-goog-api-client") genkitClientHeader = http.Header{ @@ -72,16 +44,6 @@ var ( } ) -// EmbedOptions are options for the Vertex AI embedder. -// Set [ai.EmbedRequest.Options] to a value of type *[EmbedOptions]. -type EmbedOptions struct { - // Document title. - Title string `json:"title,omitempty"` - // Task type: RETRIEVAL_QUERY, RETRIEVAL_DOCUMENT, and so forth. - // See the Vertex AI text embedding docs. - TaskType string `json:"task_type,omitempty"` -} - // configToMap converts a config struct to a map[string]any. 
func configToMap(config any) map[string]any { r := jsonschema.Reflector{ @@ -113,37 +75,37 @@ func configFromRequest(input *ai.ModelRequest) (*genai.GenerateContentConfig, er var err error result, err = base.MapToStruct[genai.GenerateContentConfig](config) if err != nil { - return nil, err + return nil, core.NewPublicError(core.INVALID_ARGUMENT, fmt.Sprintf("The configuration settings are not in the correct format. Check that the names and values match what the model expects: %v", err), nil) } case nil: // Empty but valid config default: - return nil, fmt.Errorf("unexpected config type: %T", input.Config) + return nil, core.NewPublicError(core.INVALID_ARGUMENT, fmt.Sprintf("Invalid configuration type: %T. Expected *genai.GenerateContentConfig. Ensure you are using the correct ModelRef helper (e.g., ModelRef) or passing the correct configuration struct.", input.Config), nil) } return &result, nil } -// newModel creates a model without registering it +// newModel creates a model without registering it. 
func newModel(client *genai.Client, name string, opts ai.ModelOptions) ai.Model { provider := googleAIProvider if client.ClientConfig().Backend == genai.BackendVertexAI { provider = vertexAIProvider } - var config any - config = &genai.GenerateContentConfig{} - if strings.Contains(name, "imagen") { - config = &genai.GenerateImagesConfig{} - } else if vi, fnd := supportedVideoModels[name]; fnd { - config = &genai.GenerateVideosConfig{} - opts = vi + mt := ClassifyModel(name) + + if opts.ConfigSchema == nil { + if config := mt.DefaultConfig(); config != nil { + opts.ConfigSchema = configToMap(config) + } } + meta := &ai.ModelOptions{ Label: opts.Label, Supports: opts.Supports, Versions: opts.Versions, - ConfigSchema: configToMap(config), + ConfigSchema: opts.ConfigSchema, Stage: opts.Stage, } @@ -152,13 +114,14 @@ func newModel(client *genai.Client, name string, opts ai.ModelOptions) ai.Model input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error, ) (*ai.ModelResponse, error) { - switch config.(type) { - case *genai.GenerateImagesConfig: + switch mt { + case ModelTypeImagen: return generateImage(ctx, client, name, input, cb) default: return generate(ctx, client, name, input, cb) } } + // the gemini api doesn't support downloading media from http(s) if opts.Supports.Media { fn = core.ChainMiddleware(ai.DownloadRequestMedia(&ai.DownloadMediaOptions{ @@ -182,49 +145,8 @@ func newModel(client *genai.Client, name string, opts ai.ModelOptions) ai.Model return ai.NewModel(api.NewName(provider, name), meta, fn) } -// newEmbedder creates an embedder without registering it -func newEmbedder(client *genai.Client, name string, embedOpts *ai.EmbedderOptions) ai.Embedder { - provider := googleAIProvider - if client.ClientConfig().Backend == genai.BackendVertexAI { - provider = vertexAIProvider - } - - if embedOpts.ConfigSchema == nil { - embedOpts.ConfigSchema = core.InferSchemaMap(genai.EmbedContentConfig{}) - } - - return 
ai.NewEmbedder(api.NewName(provider, name), embedOpts, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { - var content []*genai.Content - var embedConfig *genai.EmbedContentConfig - - if config, ok := req.Options.(*genai.EmbedContentConfig); ok { - embedConfig = config - } - - for _, doc := range req.Input { - parts, err := toGeminiParts(doc.Content) - if err != nil { - return nil, err - } - content = append(content, &genai.Content{ - Parts: parts, - }) - } - - r, err := genai.Models.EmbedContent(*client.Models, ctx, name, content, embedConfig) - if err != nil { - return nil, err - } - var res ai.EmbedResponse - for _, emb := range r.Embeddings { - res.Embeddings = append(res.Embeddings, &ai.Embedding{Embedding: emb.Values}) - } - return &res, nil - }) -} - -// Generate requests generate call to the specified model with the provided -// configuration +// generate requests generate call to the specified model with the provided +// configuration. func generate( ctx context.Context, client *genai.Client, @@ -447,344 +369,6 @@ func toGeminiRequest(input *ai.ModelRequest, cache *genai.CachedContent) (*genai return gcc, nil } -// toGeminiTools translates a slice of [ai.ToolDefinition] to a slice of [genai.Tool]. 
-func toGeminiTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) { - var outTools []*genai.Tool - functions := []*genai.FunctionDeclaration{} - - for _, t := range inTools { - if !validToolName(t.Name) { - return nil, fmt.Errorf(`invalid tool name: %q, must start with a letter or an underscore, must be alphanumeric, underscores, dots or dashes with a max length of 64 chars`, t.Name) - } - inputSchema, err := toGeminiSchema(t.InputSchema, t.InputSchema) - if err != nil { - return nil, err - } - fd := &genai.FunctionDeclaration{ - Name: t.Name, - Parameters: inputSchema, - Description: t.Description, - } - functions = append(functions, fd) - } - - if len(functions) > 0 { - outTools = append(outTools, &genai.Tool{ - FunctionDeclarations: functions, - }) - } - - return outTools, nil -} - -// toGeminiFunctionResponsePart translates a slice of [ai.Part] to a slice of [genai.FunctionResponsePart] -func toGeminiFunctionResponsePart(parts []*ai.Part) ([]*genai.FunctionResponsePart, error) { - frp := []*genai.FunctionResponsePart{} - for _, p := range parts { - switch { - case p.IsData(): - contentType, data, err := uri.Data(p) - if err != nil { - return nil, err - } - frp = append(frp, genai.NewFunctionResponsePartFromBytes(data, contentType)) - case p.IsMedia(): - if strings.HasPrefix(p.Text, "data:") { - contentType, data, err := uri.Data(p) - if err != nil { - return nil, err - } - frp = append(frp, genai.NewFunctionResponsePartFromBytes(data, contentType)) - continue - } - frp = append(frp, genai.NewFunctionResponsePartFromURI(p.Text, p.ContentType)) - default: - return nil, fmt.Errorf("unsupported function response part type: %d", p.Kind) - } - } - return frp, nil -} - -// mergeTools consolidates all FunctionDeclarations into a single Tool -// while preserving non-function tools (Retrieval, GoogleSearch, CodeExecution, etc.) 
-func mergeTools(ts []*genai.Tool) []*genai.Tool { - var decls []*genai.FunctionDeclaration - var out []*genai.Tool - - for _, t := range ts { - if t == nil { - continue - } - if len(t.FunctionDeclarations) == 0 { - out = append(out, t) - continue - } - decls = append(decls, t.FunctionDeclarations...) - if cpy := cloneToolWithoutFunctions(t); cpy != nil && !reflect.ValueOf(*cpy).IsZero() { - out = append(out, cpy) - } - } - - if len(decls) > 0 { - out = append([]*genai.Tool{{FunctionDeclarations: decls}}, out...) - } - return out -} - -func cloneToolWithoutFunctions(t *genai.Tool) *genai.Tool { - if t == nil { - return nil - } - clone := *t - clone.FunctionDeclarations = nil - return &clone -} - -// toGeminiSchema translates a map representing a standard JSON schema to a more -// limited [genai.Schema]. -func toGeminiSchema(originalSchema map[string]any, genkitSchema map[string]any) (*genai.Schema, error) { - // this covers genkitSchema == nil and {} - // genkitSchema will be {} if it's any - if len(genkitSchema) == 0 { - return nil, nil - } - if v, ok := genkitSchema["$ref"]; ok { - ref, ok := v.(string) - if !ok { - return nil, fmt.Errorf("invalid $ref value: not a string") - } - s, err := resolveRef(originalSchema, ref) - if err != nil { - return nil, err - } - return toGeminiSchema(originalSchema, s) - } - - // Handle "anyOf" subschemas by finding the first valid schema definition - if v, ok := genkitSchema["anyOf"]; ok { - if anyOfList, isList := v.([]map[string]any); isList { - for _, subSchema := range anyOfList { - if subSchemaType, hasType := subSchema["type"]; hasType { - if typeStr, isString := subSchemaType.(string); isString && typeStr != "null" { - if title, ok := genkitSchema["title"]; ok { - subSchema["title"] = title - } - if description, ok := genkitSchema["description"]; ok { - subSchema["description"] = description - } - // Found a schema like: {"type": "string"} - return toGeminiSchema(originalSchema, subSchema) - } - } - } - } - } - - schema 
:= &genai.Schema{} - typeVal, ok := genkitSchema["type"] - if !ok { - return nil, fmt.Errorf("schema is missing the 'type' field: %#v", genkitSchema) - } - - typeStr, ok := typeVal.(string) - if !ok { - return nil, fmt.Errorf("schema 'type' field is not a string, but %T", typeVal) - } - - switch typeStr { - case "string": - schema.Type = genai.TypeString - case "float64", "number": - schema.Type = genai.TypeNumber - case "integer": - schema.Type = genai.TypeInteger - case "boolean": - schema.Type = genai.TypeBoolean - case "object": - schema.Type = genai.TypeObject - case "array": - schema.Type = genai.TypeArray - default: - return nil, fmt.Errorf("schema type %q not allowed", genkitSchema["type"]) - } - if v, ok := genkitSchema["required"]; ok { - schema.Required = castToStringArray(v) - } - if v, ok := genkitSchema["propertyOrdering"]; ok { - schema.PropertyOrdering = castToStringArray(v) - } - if v, ok := genkitSchema["description"]; ok { - schema.Description = v.(string) - } - if v, ok := genkitSchema["format"]; ok { - schema.Format = v.(string) - } - if v, ok := genkitSchema["title"]; ok { - schema.Title = v.(string) - } - if v, ok := genkitSchema["minItems"]; ok { - if i64, ok := castToInt64(v); ok { - schema.MinItems = genai.Ptr(i64) - } - } - if v, ok := genkitSchema["maxItems"]; ok { - if i64, ok := castToInt64(v); ok { - schema.MaxItems = genai.Ptr(i64) - } - } - if v, ok := genkitSchema["maximum"]; ok { - if f64, ok := castToFloat64(v); ok { - schema.Maximum = genai.Ptr(f64) - } - } - if v, ok := genkitSchema["minimum"]; ok { - if f64, ok := castToFloat64(v); ok { - schema.Minimum = genai.Ptr(f64) - } - } - if v, ok := genkitSchema["enum"]; ok { - schema.Enum = castToStringArray(v) - } - if v, ok := genkitSchema["items"]; ok { - items, err := toGeminiSchema(originalSchema, v.(map[string]any)) - if err != nil { - return nil, err - } - schema.Items = items - } - if val, ok := genkitSchema["properties"]; ok { - props := map[string]*genai.Schema{} - for k, v 
:= range val.(map[string]any) { - p, err := toGeminiSchema(originalSchema, v.(map[string]any)) - if err != nil { - return nil, err - } - props[k] = p - } - schema.Properties = props - } - // Nullable -- not supported in jsonschema.Schema - - return schema, nil -} - -func resolveRef(originalSchema map[string]any, ref string) (map[string]any, error) { - tkns := strings.Split(ref, "/") - // refs look like: $/ref/foo -- we need the foo part - name := tkns[len(tkns)-1] - if defs, ok := originalSchema["$defs"].(map[string]any); ok { - if def, ok := defs[name].(map[string]any); ok { - return def, nil - } - } - // definitions (legacy) - if defs, ok := originalSchema["definitions"].(map[string]any); ok { - if def, ok := defs[name].(map[string]any); ok { - return def, nil - } - } - return nil, fmt.Errorf("unable to resolve schema reference") -} - -// castToStringArray converts either []any or []string to []string, filtering non-strings. -// This handles enum values from JSON Schema which may come as either type depending on unmarshaling. -// Filter out non-string types from if v is []any type. -func castToStringArray(v any) []string { - switch a := v.(type) { - case []string: - // Return a shallow copy to avoid aliasing - out := make([]string, 0, len(a)) - for _, s := range a { - if s != "" { - out = append(out, s) - } - } - return out - case []any: - var out []string - for _, it := range a { - if s, ok := it.(string); ok && s != "" { - out = append(out, s) - } - } - return out - default: - return nil - } -} - -// castToInt64 converts v to int64 when possible. 
-func castToInt64(v any) (int64, bool) { - switch t := v.(type) { - case int: - return int64(t), true - case int64: - return t, true - case float64: - return int64(t), true - case string: - if i, err := strconv.ParseInt(t, 10, 64); err == nil { - return i, true - } - case json.Number: - if i, err := t.Int64(); err == nil { - return i, true - } - } - return 0, false -} - -// castToFloat64 converts v to float64 when possible. -func castToFloat64(v any) (float64, bool) { - switch t := v.(type) { - case float64: - return t, true - case int: - return float64(t), true - case int64: - return float64(t), true - case string: - if f, err := strconv.ParseFloat(t, 64); err == nil { - return f, true - } - case json.Number: - if f, err := t.Float64(); err == nil { - return f, true - } - } - return 0, false -} - -func toGeminiToolChoice(toolChoice ai.ToolChoice, tools []*ai.ToolDefinition) (*genai.ToolConfig, error) { - var mode genai.FunctionCallingConfigMode - switch toolChoice { - case "": - return nil, nil - case ai.ToolChoiceAuto: - mode = genai.FunctionCallingConfigModeAuto - case ai.ToolChoiceRequired: - mode = genai.FunctionCallingConfigModeAny - case ai.ToolChoiceNone: - mode = genai.FunctionCallingConfigModeNone - default: - return nil, fmt.Errorf("tool choice mode %q not supported", toolChoice) - } - - var toolNames []string - // Per docs, only set AllowedToolNames with mode set to ANY. - if mode == genai.FunctionCallingConfigModeAny { - for _, t := range tools { - toolNames = append(toolNames, t.Name) - } - } - return &genai.ToolConfig{ - FunctionCallingConfig: &genai.FunctionCallingConfig{ - Mode: mode, - AllowedFunctionNames: toolNames, - }, - }, nil -} - // translateCandidate translates from a genai.GenerateContentResponse to an ai.ModelResponse. 
func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) { m := &ai.ModelResponse{} @@ -865,14 +449,14 @@ func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) { } if part.CodeExecutionResult != nil { partFound++ - p = NewCodeExecutionResultPart( + p = newCodeExecutionResultPart( string(part.CodeExecutionResult.Outcome), part.CodeExecutionResult.Output, ) } if part.ExecutableCode != nil { partFound++ - p = NewExecutableCodePart( + p = newExecutableCodePart( string(part.ExecutableCode.Language), part.ExecutableCode.Code, ) @@ -897,7 +481,7 @@ func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) { return m, nil } -// Translate from a genai.GenerateContentResponse to a ai.ModelResponse. +// translateResponse translates from a genai.GenerateContentResponse to a ai.ModelResponse. func translateResponse(resp *genai.GenerateContentResponse) (*ai.ModelResponse, error) { var r *ai.ModelResponse var err error @@ -1027,125 +611,3 @@ func toGeminiPart(p *ai.Part) (*genai.Part, error) { return gp, nil } - -// validToolName checks whether the provided tool name matches the -// following criteria: -// - Start with a letter or an underscore -// - Must be alphanumeric and can include underscores, dots or dashes -// - Maximum length of 64 chars -func validToolName(n string) bool { - re := regexp.MustCompile(toolNameRegex) - - return re.MatchString(n) -} - -// CodeExecutionResult represents the result of a code execution. -type CodeExecutionResult struct { - Outcome string `json:"outcome"` - Output string `json:"output"` -} - -// ExecutableCode represents executable code. -type ExecutableCode struct { - Language string `json:"language"` - Code string `json:"code"` -} - -// NewCodeExecutionResultPart returns a Part containing the result of code execution. 
-func NewCodeExecutionResultPart(outcome string, output string) *ai.Part { - return ai.NewCustomPart(map[string]any{ - "codeExecutionResult": map[string]any{ - "outcome": outcome, - "output": output, - }, - }) -} - -// NewExecutableCodePart returns a Part containing executable code. -func NewExecutableCodePart(language string, code string) *ai.Part { - return ai.NewCustomPart(map[string]any{ - "executableCode": map[string]any{ - "language": language, - "code": code, - }, - }) -} - -// ToCodeExecutionResult tries to convert an ai.Part to a CodeExecutionResult. -// Returns nil if the part doesn't contain code execution results. -func ToCodeExecutionResult(part *ai.Part) *CodeExecutionResult { - if !part.IsCustom() { - return nil - } - - codeExec, ok := part.Custom["codeExecutionResult"] - if !ok { - return nil - } - - result, ok := codeExec.(map[string]any) - if !ok { - return nil - } - - outcome, _ := result["outcome"].(string) - output, _ := result["output"].(string) - - return &CodeExecutionResult{ - Outcome: outcome, - Output: output, - } -} - -// ToExecutableCode tries to convert an ai.Part to an ExecutableCode. -// Returns nil if the part doesn't contain executable code. -func ToExecutableCode(part *ai.Part) *ExecutableCode { - if !part.IsCustom() { - return nil - } - - execCode, ok := part.Custom["executableCode"] - if !ok { - return nil - } - - code, ok := execCode.(map[string]any) - if !ok { - return nil - } - - language, _ := code["language"].(string) - codeStr, _ := code["code"].(string) - - return &ExecutableCode{ - Language: language, - Code: codeStr, - } -} - -// HasCodeExecution checks if a message contains code execution results or executable code. -func HasCodeExecution(msg *ai.Message) bool { - return GetCodeExecutionResult(msg) != nil || GetExecutableCode(msg) != nil -} - -// GetExecutableCode returns the first executable code from a message. -// Returns nil if the message doesn't contain executable code. 
-func GetExecutableCode(msg *ai.Message) *ExecutableCode { - for _, part := range msg.Content { - if code := ToExecutableCode(part); code != nil { - return code - } - } - return nil -} - -// GetCodeExecutionResult returns the first code execution result from a message. -// Returns nil if the message doesn't contain a code execution result. -func GetCodeExecutionResult(msg *ai.Message) *CodeExecutionResult { - for _, part := range msg.Content { - if result := ToCodeExecutionResult(part); result != nil { - return result - } - } - return nil -} diff --git a/go/plugins/googlegenai/gemini_test.go b/go/plugins/googlegenai/gemini_test.go index b319d91f08..a8db1ed2ee 100644 --- a/go/plugins/googlegenai/gemini_test.go +++ b/go/plugins/googlegenai/gemini_test.go @@ -262,6 +262,17 @@ func TestConvertRequest(t *testing.T) { }) } }) + t.Run("invalid config map", func(t *testing.T) { + req := ai.ModelRequest{ + Config: map[string]any{ + "temperature": "not a number", // This should fail map->struct conversion + }, + } + _, err := toGeminiRequest(&req, nil) + if err == nil { + t.Fatal("expected error for invalid config map") + } + }) t.Run("convert tools with valid tool", func(t *testing.T) { tools := []*ai.ToolDefinition{tool} gt, err := toGeminiTools(tools) diff --git a/go/plugins/googlegenai/googleai_live_test.go b/go/plugins/googlegenai/googleai_live_test.go index 783eccd239..f5b2375bae 100644 --- a/go/plugins/googlegenai/googleai_live_test.go +++ b/go/plugins/googlegenai/googleai_live_test.go @@ -70,8 +70,6 @@ func TestGoogleAILive(t *testing.T) { genkit.WithPlugins(&googlegenai.GoogleAI{APIKey: apiKey}), ) - embedder := googlegenai.GoogleAIEmbedder(g, "embedding-001") - gablorkenTool := genkit.DefineTool(g, "gablorken", "use this tool when the user asks to calculate a gablorken, carefuly inspect the user input to determine which value from the prompt corresponds to the input structure", func(ctx *ai.ToolContext, input struct { Value int @@ -89,7 +87,7 @@ func 
TestGoogleAILive(t *testing.T) { ) t.Run("embedder", func(t *testing.T) { - res, err := genkit.Embed(ctx, g, ai.WithEmbedder(embedder), ai.WithTextDocs("yellow banana")) + res, err := genkit.Embed(ctx, g, ai.WithEmbedderName("googleai/gemini-embedding-001"), ai.WithTextDocs("yellow banana")) if err != nil { t.Fatal(err) } diff --git a/go/plugins/googlegenai/googlegenai.go b/go/plugins/googlegenai/googlegenai.go index d8651f1dea..4940f7ab2c 100644 --- a/go/plugins/googlegenai/googlegenai.go +++ b/go/plugins/googlegenai/googlegenai.go @@ -9,13 +9,11 @@ import ( "fmt" "net/http" "os" - "strings" "sync" "cloud.google.com/go/auth/credentials" "cloud.google.com/go/auth/httptransport" "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/genkit" @@ -31,27 +29,6 @@ const ( vertexAILabelPrefix = "Vertex AI" ) -var ( - defaultGeminiOpts = ai.ModelOptions{ - Supports: &Multimodal, - Versions: []string{}, - Stage: ai.ModelStageUnstable, - } - - defaultImagenOpts = ai.ModelOptions{ - Supports: &Media, - Versions: []string{}, - Stage: ai.ModelStageUnstable, - } - - defaultEmbedOpts = ai.EmbedderOptions{ - Supports: &ai.EmbedderSupports{ - Input: []string{"text"}, - }, - Dimensions: 768, - } -) - // GoogleAI is a Genkit plugin for interacting with the Google AI service. type GoogleAI struct { APIKey string // API key to access the service. If empty, the values of the environment variables GEMINI_API_KEY or GOOGLE_API_KEY will be consulted, in that order. @@ -283,283 +260,34 @@ func (v *VertexAI) IsDefinedEmbedder(g *genkit.Genkit, name string) bool { return genkit.LookupEmbedder(g, api.NewName(vertexAIProvider, name)) != nil } -// ModelRef creates a new ModelRef for a Google Gen AI model with the given name and configuration. 
-func ModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef { - return ai.NewModelRef(name, config) -} - -// GoogleAIModelRef creates a new ModelRef for a Google AI model with the given ID and configuration. -func GoogleAIModelRef(id string, config *genai.GenerateContentConfig) ai.ModelRef { - return ai.NewModelRef(googleAIProvider+"/"+id, config) -} - -// VertexAIModelRef creates a new ModelRef for a Vertex AI model with the given ID and configuration. -func VertexAIModelRef(id string, config *genai.GenerateContentConfig) ai.ModelRef { - return ai.NewModelRef(vertexAIProvider+"/"+id, config) -} - // GoogleAIModel returns the [ai.Model] with the given name. // It returns nil if the model was not defined. +// +// Deprecated: Use genkit.LookupModel instead. func GoogleAIModel(g *genkit.Genkit, name string) ai.Model { return genkit.LookupModel(g, api.NewName(googleAIProvider, name)) } // VertexAIModel returns the [ai.Model] with the given name. // It returns nil if the model was not defined. +// +// Deprecated: Use genkit.LookupModel instead. func VertexAIModel(g *genkit.Genkit, name string) ai.Model { return genkit.LookupModel(g, api.NewName(vertexAIProvider, name)) } // GoogleAIEmbedder returns the [ai.Embedder] with the given name. // It returns nil if the embedder was not defined. +// +// Deprecated: Use genkit.LookupEmbedder instead. func GoogleAIEmbedder(g *genkit.Genkit, name string) ai.Embedder { return genkit.LookupEmbedder(g, api.NewName(googleAIProvider, name)) } // VertexAIEmbedder returns the [ai.Embedder] with the given name. // It returns nil if the embedder was not defined. +// +// Deprecated: Use genkit.LookupEmbedder instead. func VertexAIEmbedder(g *genkit.Genkit, name string) ai.Embedder { return genkit.LookupEmbedder(g, api.NewName(vertexAIProvider, name)) } - -// ListActions lists all the actions supported by the Google AI plugin. 
-func (ga *GoogleAI) ListActions(ctx context.Context) []api.ActionDesc { - models, err := listGenaiModels(ctx, ga.gclient) - if err != nil { - return nil - } - - actions := []api.ActionDesc{} - - // Generative models. - for _, name := range models.gemini { - var opts ai.ModelOptions - if knownOpts, ok := supportedGeminiModels[name]; ok { - opts = knownOpts - opts.Label = fmt.Sprintf("%s - %s", googleAILabelPrefix, opts.Label) - } else { - opts = defaultGeminiOpts - opts.Label = fmt.Sprintf("%s - %s", googleAILabelPrefix, name) - } - - model := newModel(ga.gclient, name, opts) - if actionDef, ok := model.(api.Action); ok { - actions = append(actions, actionDef.Desc()) - } - } - - // Imagen models. - for _, name := range models.imagen { - var opts ai.ModelOptions - if knownOpts, ok := supportedImagenModels[name]; ok { - opts = knownOpts - opts.Label = fmt.Sprintf("%s - %s", googleAILabelPrefix, opts.Label) - } else { - opts = defaultImagenOpts - opts.Label = fmt.Sprintf("%s - %s", googleAILabelPrefix, name) - } - - model := newModel(ga.gclient, name, opts) - if actionDef, ok := model.(api.Action); ok { - actions = append(actions, actionDef.Desc()) - } - } - - // Embedders. - for _, e := range models.embedders { - var embedOpts ai.EmbedderOptions - if knownOpts, ok := googleAIEmbedderConfig[e]; ok { - embedOpts = knownOpts - } else { - embedOpts = defaultEmbedOpts - embedOpts.Label = fmt.Sprintf("%s - %s", googleAILabelPrefix, e) - } - - embedder := newEmbedder(ga.gclient, e, &embedOpts) - if actionDef, ok := embedder.(api.Action); ok { - actions = append(actions, actionDef.Desc()) - } - } - - return actions -} - -// ResolveAction resolves an action with the given name. 
-func (ga *GoogleAI) ResolveAction(atype api.ActionType, name string) api.Action { - switch atype { - case api.ActionTypeEmbedder: - return newEmbedder(ga.gclient, name, &ai.EmbedderOptions{}).(api.Action) - case api.ActionTypeModel: - var supports *ai.ModelSupports - var config any - - // TODO: Add veo case. - switch { - case strings.Contains(name, "imagen"): - supports = &Media - config = &genai.GenerateImagesConfig{} - default: - supports = &Multimodal - config = &genai.GenerateContentConfig{} - } - - return newModel(ga.gclient, name, ai.ModelOptions{ - Label: fmt.Sprintf("%s - %s", googleAILabelPrefix, name), - Stage: ai.ModelStageStable, - Versions: []string{}, - Supports: supports, - ConfigSchema: configToMap(config), - }).(api.Action) - case api.ActionTypeBackgroundModel: - // Handle VEO models as background models - if strings.HasPrefix(name, "veo") { - veoModel := newVeoModel(ga.gclient, name, ai.ModelOptions{ - Label: fmt.Sprintf("%s - %s", googleAILabelPrefix, name), - Stage: ai.ModelStageStable, - Versions: []string{}, - Supports: &ai.ModelSupports{ - Media: true, - Multiturn: false, - Tools: false, - SystemRole: false, - Output: []string{"media"}, - LongRunning: true, - }, - }) - actionName := fmt.Sprintf("%s/%s", googleAIProvider, name) - return core.NewAction(actionName, api.ActionTypeBackgroundModel, nil, nil, - func(ctx context.Context, input *ai.ModelRequest) (*core.Operation[*ai.ModelResponse], error) { - op, err := veoModel.Start(ctx, input) - if err != nil { - return nil, err - } - op.Action = api.KeyFromName(api.ActionTypeBackgroundModel, actionName) - return op, nil - }) - } - return nil - case api.ActionTypeCheckOperation: - // Handle VEO model check operations - if strings.HasPrefix(name, "veo") { - veoModel := newVeoModel(ga.gclient, name, ai.ModelOptions{ - Label: fmt.Sprintf("%s - %s", googleAILabelPrefix, name), - Stage: ai.ModelStageStable, - Versions: []string{}, - Supports: &ai.ModelSupports{ - Media: true, - Multiturn: false, - 
Tools: false, - SystemRole: false, - Output: []string{"media"}, - LongRunning: true, - }, - }) - - actionName := fmt.Sprintf("%s/%s", googleAIProvider, name) - return core.NewAction(actionName, api.ActionTypeCheckOperation, - map[string]any{"description": fmt.Sprintf("Check status of %s operation", name)}, nil, - func(ctx context.Context, op *core.Operation[*ai.ModelResponse]) (*core.Operation[*ai.ModelResponse], error) { - updatedOp, err := veoModel.Check(ctx, op) - if err != nil { - return nil, err - } - updatedOp.Action = api.KeyFromName(api.ActionTypeBackgroundModel, actionName) - return updatedOp, nil - }) - } - return nil - } - return nil -} - -// ListActions lists all the actions supported by the Vertex AI plugin. -func (v *VertexAI) ListActions(ctx context.Context) []api.ActionDesc { - models, err := listGenaiModels(ctx, v.gclient) - if err != nil { - return nil - } - - actions := []api.ActionDesc{} - - // Gemini generative models. - for _, name := range models.gemini { - var opts ai.ModelOptions - if knownOpts, ok := supportedGeminiModels[name]; ok { - opts = knownOpts - opts.Label = fmt.Sprintf("%s - %s", vertexAILabelPrefix, opts.Label) - } else { - opts = defaultGeminiOpts - opts.Label = fmt.Sprintf("%s - %s", vertexAILabelPrefix, name) - } - - model := newModel(v.gclient, name, opts) - if actionDef, ok := model.(api.Action); ok { - actions = append(actions, actionDef.Desc()) - } - } - - // Imagen models. - for _, name := range models.imagen { - var opts ai.ModelOptions - if knownOpts, ok := supportedImagenModels[name]; ok { - opts = knownOpts - opts.Label = fmt.Sprintf("%s - %s", vertexAILabelPrefix, opts.Label) - } else { - opts = defaultImagenOpts - opts.Label = fmt.Sprintf("%s - %s", vertexAILabelPrefix, name) - } - - model := newModel(v.gclient, name, opts) - if actionDef, ok := model.(api.Action); ok { - actions = append(actions, actionDef.Desc()) - } - } - - // Embedders. 
- for _, e := range models.embedders { - var embedOpts ai.EmbedderOptions - if knownOpts, ok := googleAIEmbedderConfig[e]; ok { - embedOpts = knownOpts - } else { - embedOpts = defaultEmbedOpts - embedOpts.Label = fmt.Sprintf("%s - %s", vertexAILabelPrefix, e) - } - - embedder := newEmbedder(v.gclient, e, &embedOpts) - if actionDef, ok := embedder.(api.Action); ok { - actions = append(actions, actionDef.Desc()) - } - } - - return actions -} - -// ResolveAction resolves an action with the given name. -func (v *VertexAI) ResolveAction(atype api.ActionType, id string) api.Action { - switch atype { - case api.ActionTypeEmbedder: - return newEmbedder(v.gclient, id, &ai.EmbedderOptions{}).(api.Action) - case api.ActionTypeModel: - var supports *ai.ModelSupports - var config any - - // TODO: Add veo case. - switch { - case strings.Contains(id, "imagen"): - supports = &Media - config = &genai.GenerateImagesConfig{} - default: - supports = &Multimodal - config = &genai.GenerateContentConfig{} - } - - return newModel(v.gclient, id, ai.ModelOptions{ - Label: fmt.Sprintf("%s - %s", vertexAILabelPrefix, id), - Stage: ai.ModelStageStable, - Versions: []string{}, - Supports: supports, - ConfigSchema: configToMap(config), - }).(api.Action) - } - return nil -} diff --git a/go/plugins/googlegenai/imagen.go b/go/plugins/googlegenai/imagen.go index 6003494a30..5ecda8e6ca 100644 --- a/go/plugins/googlegenai/imagen.go +++ b/go/plugins/googlegenai/imagen.go @@ -22,20 +22,11 @@ import ( "fmt" "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/internal/base" "google.golang.org/genai" ) -// Media describes model capabilities for Gemini models with media and text -// input and image only output -var Media = ai.ModelSupports{ - Media: true, - Multiturn: false, - Tools: false, - ToolChoice: false, - SystemRole: false, -} - // imagenConfigFromRequest translates an [*ai.ModelRequest] configuration to [*genai.GenerateImagesConfig] func 
imagenConfigFromRequest(input *ai.ModelRequest) (*genai.GenerateImagesConfig, error) { var result genai.GenerateImagesConfig @@ -49,12 +40,12 @@ func imagenConfigFromRequest(input *ai.ModelRequest) (*genai.GenerateImagesConfi var err error result, err = base.MapToStruct[genai.GenerateImagesConfig](config) if err != nil { - return nil, err + return nil, core.NewPublicError(core.INVALID_ARGUMENT, fmt.Sprintf("The image configuration settings are not in the correct format. Check that the names and values match what the model expects: %v", err), nil) } case nil: // empty but valid config default: - return nil, fmt.Errorf("unexpected config type: %T", input.Config) + return nil, core.NewPublicError(core.INVALID_ARGUMENT, fmt.Sprintf("Invalid configuration type: %T. Expected *genai.GenerateImagesConfig. Ensure you are using the correct ModelRef helper (e.g., ImageModelRef) or passing the correct configuration struct.", input.Config), nil) } return &result, nil diff --git a/go/plugins/googlegenai/imagen_test.go b/go/plugins/googlegenai/imagen_test.go new file mode 100644 index 0000000000..ca5f3731a1 --- /dev/null +++ b/go/plugins/googlegenai/imagen_test.go @@ -0,0 +1,120 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package googlegenai + +import ( + "testing" + + "github.com/firebase/genkit/go/ai" + "google.golang.org/genai" +) + +func TestImagenConfigFromRequest(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + request *ai.ModelRequest + expectError bool + }{ + { + name: "valid config struct pointer", + request: &ai.ModelRequest{ + Config: &genai.GenerateImagesConfig{ + NumberOfImages: 2, + }, + }, + expectError: false, + }, + { + name: "valid config struct value", + request: &ai.ModelRequest{ + Config: genai.GenerateImagesConfig{ + NumberOfImages: 1, + }, + }, + expectError: false, + }, + { + name: "valid map config", + request: &ai.ModelRequest{ + Config: map[string]any{ + "numberOfImages": 4, + }, + }, + expectError: false, + }, + { + name: "nil config", + request: &ai.ModelRequest{ + Config: nil, + }, + expectError: false, + }, + { + name: "invalid config type", + request: &ai.ModelRequest{ + Config: &genai.GenerateContentConfig{}, + }, + expectError: true, + }, + { + name: "invalid map values", + request: &ai.ModelRequest{ + Config: map[string]any{ + "numberOfImages": "not-a-number", + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := imagenConfigFromRequest(tt.request) + if (err != nil) != tt.expectError { + t.Errorf("imagenConfigFromRequest() error = %v, expectError %v", err, tt.expectError) + } + }) + } +} + +func TestTranslateImagenResponse(t *testing.T) { + t.Parallel() + + resp := &genai.GenerateImagesResponse{ + GeneratedImages: []*genai.GeneratedImage{ + { + Image: &genai.Image{ + MIMEType: "image/png", + ImageBytes: []byte("fake-image-data"), + }, + }, + }, + } + + res := translateImagenResponse(resp) + if res.FinishReason != ai.FinishReasonStop { + t.Errorf("expected finish reason %s, got %s", ai.FinishReasonStop, res.FinishReason) + } + if len(res.Message.Content) != 1 { + t.Fatalf("expected 1 content part, got %d", 
len(res.Message.Content)) + } + if res.Message.Content[0].ContentType != "image/png" { + t.Errorf("expected content type image/png, got %s", res.Message.Content[0].ContentType) + } +} diff --git a/go/plugins/googlegenai/model_type.go b/go/plugins/googlegenai/model_type.go new file mode 100644 index 0000000000..201f50a25e --- /dev/null +++ b/go/plugins/googlegenai/model_type.go @@ -0,0 +1,83 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package googlegenai + +import ( + "strings" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/api" + "google.golang.org/genai" +) + +// ModelType categorizes models by their generation modality. +type ModelType int + +const ( + ModelTypeUnknown ModelType = iota + ModelTypeGemini // Text/multimodal generation (gemini-*, gemma-*) + ModelTypeImagen // Image generation (imagen-*) + ModelTypeVeo // Video generation (veo-*), long-running + ModelTypeEmbedder // Embedding models (*embedding*) +) + +// ClassifyModel determines the model type from its name. +// This is the single source of truth for model type classification. +func ClassifyModel(name string) ModelType { + switch { + case strings.HasPrefix(name, "veo"): + return ModelTypeVeo + case strings.HasPrefix(name, "imagen"), strings.HasPrefix(name, "image"): + return ModelTypeImagen + case strings.HasPrefix(name, "gemini"), strings.HasPrefix(name, "gemma"): + return ModelTypeGemini + case strings.Contains(name, "embedding"): + // Covers: text-embedding-*, embedding-*, textembedding-*, multimodalembedding + return ModelTypeEmbedder + default: + return ModelTypeUnknown + } +} + +// ActionType returns the appropriate API action type for this model type. 
+func (mt ModelType) ActionType() api.ActionType { + switch mt { + case ModelTypeVeo: + return api.ActionTypeBackgroundModel + case ModelTypeEmbedder: + return api.ActionTypeEmbedder + default: + return api.ActionTypeModel + } +} + +// DefaultSupports returns the default ModelSupports for this model type. +func (mt ModelType) DefaultSupports() *ai.ModelSupports { + switch mt { + case ModelTypeGemini: + return &Multimodal + case ModelTypeImagen: + return &Media + case ModelTypeVeo: + return &VeoSupports + default: + return nil + } +} + +// DefaultConfig returns the default config struct for this model type. +func (mt ModelType) DefaultConfig() any { + switch mt { + case ModelTypeGemini: + return &genai.GenerateContentConfig{} + case ModelTypeImagen: + return &genai.GenerateImagesConfig{} + case ModelTypeVeo: + return &genai.GenerateVideosConfig{} + case ModelTypeEmbedder: + return &genai.EmbedContentConfig{} + default: + return nil + } +} diff --git a/go/plugins/googlegenai/models.go b/go/plugins/googlegenai/models.go index 3762a9e4f4..3ca9912ffd 100644 --- a/go/plugins/googlegenai/models.go +++ b/go/plugins/googlegenai/models.go @@ -6,7 +6,6 @@ package googlegenai import ( "context" "fmt" - "log" "slices" "strings" @@ -14,33 +13,90 @@ import ( "google.golang.org/genai" ) +// Model capability definitions - these describe what different model types support. +var ( + // BasicText describes model capabilities for text-only Gemini models. + BasicText = ai.ModelSupports{ + Multiturn: true, + Tools: true, + ToolChoice: true, + SystemRole: true, + Media: false, + } + + // Multimodal describes model capabilities for multimodal Gemini models. + Multimodal = ai.ModelSupports{ + Multiturn: true, + Tools: true, + ToolChoice: true, + SystemRole: true, + Media: true, + Constrained: ai.ConstrainedSupportNoTools, + } + + // Media describes model capabilities for image generation models (Imagen). 
+ Media = ai.ModelSupports{ + Multiturn: false, + Tools: false, + SystemRole: false, + Media: true, + Output: []string{"media"}, + } + + // VeoSupports describes model capabilities for video generation models (Veo). + VeoSupports = ai.ModelSupports{ + Media: true, + Multiturn: false, + Tools: false, + SystemRole: false, + Output: []string{"media"}, + LongRunning: true, + } +) + +// Default options for unknown models of each type. +var ( + defaultGeminiOpts = ai.ModelOptions{ + Supports: &Multimodal, + Stage: ai.ModelStageUnstable, + ConfigSchema: configToMap(genai.GenerateContentConfig{}), + } + + defaultImagenOpts = ai.ModelOptions{ + Supports: &Media, + Stage: ai.ModelStageUnstable, + ConfigSchema: configToMap(genai.GenerateImagesConfig{}), + } + + defaultVeoOpts = ai.ModelOptions{ + Supports: &VeoSupports, + Stage: ai.ModelStageUnstable, + ConfigSchema: configToMap(genai.GenerateVideosConfig{}), + } + + defaultEmbedOpts = ai.EmbedderOptions{ + Supports: &ai.EmbedderSupports{Input: []string{"text"}}, + Dimensions: 768, + } +) + const ( - gemini15Flash = "gemini-1.5-flash" - gemini15Pro = "gemini-1.5-pro" - gemini15Flash8b = "gemini-1.5-flash-8b" - - gemini20Flash = "gemini-2.0-flash" - gemini20FlashExp = "gemini-2.0-flash-exp" - gemini20FlashLite = "gemini-2.0-flash-lite" - gemini20FlashLitePrev = "gemini-2.0-flash-lite-preview" - gemini20ProExp0205 = "gemini-2.0-pro-exp-02-05" - gemini20FlashThinkingExp0121 = "gemini-2.0-flash-thinking-exp-01-21" - gemini20FlashPrevImageGen = "gemini-2.0-flash-preview-image-generation" - - gemini25Flash = "gemini-2.5-flash" - gemini25FlashLite = "gemini-2.5-flash-lite" - gemini25FlashLitePrev0617 = "gemini-2.5-flash-lite-preview-06-17" - - gemini25Pro = "gemini-2.5-pro" - gemini25ProExp0325 = "gemini-2.5-pro-exp-03-25" - gemini25ProPreview0325 = "gemini-2.5-pro-preview-03-25" - gemini25ProPreview0506 = "gemini-2.5-pro-preview-05-06" + gemini20Flash = "gemini-2.0-flash" + gemini20FlashExp = "gemini-2.0-flash-exp" + 
gemini20FlashLite = "gemini-2.0-flash-lite" + + gemini25Flash = "gemini-2.5-flash" + gemini25FlashLite = "gemini-2.5-flash-lite" + + gemini25Pro = "gemini-2.5-pro" imagen3Generate001 = "imagen-3.0-generate-001" - imagen3Generate002 = "imagen-3.0-generate-002" imagen3FastGenerate001 = "imagen-3.0-fast-generate-001" - textembedding004 = "text-embedding-004" + veo20Generate001 = "veo-2.0-generate-001" + veo30Generate001 = "veo-3.0-generate-001" + veo30FastGenerate001 = "veo-3.0-fast-generate-001" + embedding001 = "embedding-001" textembeddinggecko003 = "textembedding-gecko@003" textembeddinggecko002 = "textembedding-gecko@002" @@ -48,33 +104,19 @@ const ( textembeddinggeckomultilingual001 = "textembedding-gecko-multilingual@001" textmultilingualembedding002 = "text-multilingual-embedding-002" multimodalembedding = "multimodalembedding" - veo20Generate001 = "veo-2.0-generate-001" - veo30Generate001 = "veo-3.0-generate-001" - veo30FastGenerate001 = "veo-3.0-fast-generate-001" ) var ( // eventually, Vertex AI and Google AI models will match, in the meantime, // keep them sepparated vertexAIModels = []string{ - gemini15Flash, - gemini15Pro, gemini20Flash, gemini20FlashLite, - gemini20FlashLitePrev, - gemini20ProExp0205, - gemini20FlashThinkingExp0121, - gemini20FlashPrevImageGen, gemini25Flash, gemini25FlashLite, gemini25Pro, - gemini25FlashLitePrev0617, - gemini25ProExp0325, - gemini25ProPreview0325, - gemini25ProPreview0506, imagen3Generate001, - imagen3Generate002, imagen3FastGenerate001, veo20Generate001, @@ -83,24 +125,11 @@ var ( } googleAIModels = []string{ - gemini15Flash, - gemini15Pro, - gemini15Flash8b, gemini20Flash, gemini20FlashExp, - gemini20FlashLitePrev, - gemini20ProExp0205, - gemini20FlashThinkingExp0121, - gemini20FlashPrevImageGen, gemini25Flash, gemini25FlashLite, gemini25Pro, - gemini25FlashLitePrev0617, - gemini25ProExp0325, - gemini25ProPreview0325, - gemini25ProPreview0506, - - imagen3Generate002, veo20Generate001, veo30Generate001, @@ -108,35 
+137,6 @@ var ( } supportedGeminiModels = map[string]ai.ModelOptions{ - gemini15Flash: { - Label: "Gemini 1.5 Flash", - Versions: []string{ - "gemini-1.5-flash-latest", - "gemini-1.5-flash-001", - "gemini-1.5-flash-002", - }, - Supports: &Multimodal, - Stage: ai.ModelStageStable, - }, - gemini15Pro: { - Label: "Gemini 1.5 Pro", - Versions: []string{ - "gemini-1.5-pro-latest", - "gemini-1.5-pro-001", - "gemini-1.5-pro-002", - }, - Supports: &Multimodal, - Stage: ai.ModelStageStable, - }, - gemini15Flash8b: { - Label: "Gemini 1.5 Flash 8B", - Versions: []string{ - "gemini-1.5-flash-8b-latest", - "gemini-1.5-flash-8b-001", - }, - Supports: &Multimodal, - Stage: ai.ModelStageStable, - }, gemini20Flash: { Label: "Gemini 2.0 Flash", Versions: []string{ @@ -145,12 +145,6 @@ var ( Supports: &Multimodal, Stage: ai.ModelStageStable, }, - gemini20FlashExp: { - Label: "Gemini 2.0 Flash Exp", - Versions: []string{}, - Supports: &Multimodal, - Stage: ai.ModelStageUnstable, - }, gemini20FlashLite: { Label: "Gemini 2.0 Flash Lite", Versions: []string{ @@ -159,71 +153,23 @@ var ( Supports: &Multimodal, Stage: ai.ModelStageStable, }, - gemini20FlashLitePrev: { - Label: "Gemini 2.0 Flash Lite Preview 02-05", - Versions: []string{}, - Supports: &Multimodal, - Stage: ai.ModelStageUnstable, - }, - gemini20ProExp0205: { - Label: "Gemini 2.0 Pro Exp 02-05", - Versions: []string{}, - Supports: &Multimodal, - Stage: ai.ModelStageUnstable, - }, - gemini20FlashThinkingExp0121: { - Label: "Gemini 2.0 Flash Thinking Exp 01-21", - Versions: []string{}, - Supports: &Multimodal, - Stage: ai.ModelStageUnstable, - }, - gemini20FlashPrevImageGen: { - Label: "Gemini 2.0 Flash Preview Image Generation", - Versions: []string{}, - Supports: &Multimodal, - Stage: ai.ModelStageUnstable, - }, gemini25Flash: { Label: "Gemini 2.5 Flash", Versions: []string{}, Supports: &Multimodal, Stage: ai.ModelStageStable, }, - gemini25Pro: { - Label: "Gemini 2.5 Pro", - Versions: []string{}, - Supports: &Multimodal, - 
Stage: ai.ModelStageStable, - }, - gemini25ProExp0325: { - Label: "Gemini 2.5 Pro Exp 03-25", - Versions: []string{}, - Supports: &Multimodal, - Stage: ai.ModelStageUnstable, - }, - gemini25ProPreview0325: { - Label: "Gemini 2.5 Pro Preview 03-25", - Versions: []string{}, - Supports: &Multimodal, - Stage: ai.ModelStageUnstable, - }, - gemini25ProPreview0506: { - Label: "Gemini 2.5 Pro Preview 05-06", - Versions: []string{}, - Supports: &Multimodal, - Stage: ai.ModelStageUnstable, - }, gemini25FlashLite: { Label: "Gemini 2.5 Flash Lite", Versions: []string{}, Supports: &Multimodal, Stage: ai.ModelStageStable, }, - gemini25FlashLitePrev0617: { - Label: "Gemini 2.5 Flash Lite Preview 06-17", + gemini25Pro: { + Label: "Gemini 2.5 Pro", Versions: []string{}, Supports: &Multimodal, - Stage: ai.ModelStageUnstable, + Stage: ai.ModelStageStable, }, } @@ -234,12 +180,6 @@ var ( Supports: &Media, Stage: ai.ModelStageStable, }, - imagen3Generate002: { - Label: "Imagen 3 Generate 002", - Versions: []string{}, - Supports: &Media, - Stage: ai.ModelStageStable, - }, imagen3FastGenerate001: { Label: "Imagen 3 Fast Generate 001", Versions: []string{}, @@ -250,54 +190,26 @@ var ( supportedVideoModels = map[string]ai.ModelOptions{ veo20Generate001: { - Label: "Google AI - Veo 2.0 Generate 001", + Label: "Veo 2.0 Generate 001", Versions: []string{}, - Supports: &ai.ModelSupports{ - Media: true, - Multiturn: false, - Tools: false, - SystemRole: false, - Output: []string{"media"}, - LongRunning: true, - }, - Stage: ai.ModelStageStable, + Supports: &VeoSupports, + Stage: ai.ModelStageStable, }, veo30Generate001: { - Label: "Google AI - Veo 3.0 Generate 001", + Label: "Veo 3.0 Generate 001", Versions: []string{}, - Supports: &ai.ModelSupports{ - Media: true, - Multiturn: false, - Tools: false, - SystemRole: false, - Output: []string{"media"}, - LongRunning: true, - }, - Stage: ai.ModelStageStable, + Supports: &VeoSupports, + Stage: ai.ModelStageStable, }, veo30FastGenerate001: { - Label: 
"Google AI - Veo 3.0 Fast Generate 001", + Label: "Veo 3.0 Fast Generate 001", Versions: []string{}, - Supports: &ai.ModelSupports{ - Media: true, - Multiturn: false, - Tools: false, - SystemRole: false, - Output: []string{"media"}, - LongRunning: true, - }, - Stage: ai.ModelStageStable, + Supports: &VeoSupports, + Stage: ai.ModelStageStable, }, } - googleAIEmbedderConfig = map[string]ai.EmbedderOptions{ - textembedding004: { - Dimensions: 768, - Label: "Google Gen AI - Text Embedding 001", - Supports: &ai.EmbedderSupports{ - Input: []string{"text"}, - }, - }, + embedderConfig = map[string]ai.EmbedderOptions{ embedding001: { Dimensions: 768, Label: "Google Gen AI - Text Embedding Gecko (Legacy)", @@ -354,39 +266,105 @@ var ( } ) +// GetModelOptions returns ModelOptions for a model name with provider-prefixed label. +func GetModelOptions(name, provider string) ai.ModelOptions { + mt := ClassifyModel(name) + var opts ai.ModelOptions + var ok bool + + switch mt { + case ModelTypeGemini: + opts, ok = supportedGeminiModels[name] + if !ok { + opts = defaultGeminiOpts + } + case ModelTypeImagen: + opts, ok = supportedImagenModels[name] + if !ok { + opts = defaultImagenOpts + } + case ModelTypeVeo: + opts, ok = supportedVideoModels[name] + if !ok { + opts = defaultVeoOpts + } + default: + opts = defaultGeminiOpts + } + + if opts.ConfigSchema == nil { + if cfg := mt.DefaultConfig(); cfg != nil { + opts.ConfigSchema = configToMap(cfg) + } + } + + // Set label with provider prefix + prefix := googleAILabelPrefix + if provider == vertexAIProvider { + prefix = vertexAILabelPrefix + } + if opts.Label == "" { + opts.Label = name + } + opts.Label = fmt.Sprintf("%s - %s", prefix, opts.Label) + + return opts +} + +// GetEmbedderOptions returns EmbedderOptions for an embedder name with provider-prefixed label. 
+func GetEmbedderOptions(name, provider string) ai.EmbedderOptions { + opts, ok := embedderConfig[name] + if !ok { + opts = defaultEmbedOpts + } + + prefix := googleAILabelPrefix + if provider == vertexAIProvider { + prefix = vertexAILabelPrefix + } + if opts.Label == "" { + opts.Label = name + } + opts.Label = fmt.Sprintf("%s - %s", prefix, opts.Label) + + return opts +} + // listModels returns a map of supported models and their capabilities -// based on the detected backend +// based on the detected backend. func listModels(provider string) (map[string]ai.ModelOptions, error) { var names []string - var prefix string switch provider { case googleAIProvider: names = googleAIModels - prefix = googleAILabelPrefix case vertexAIProvider: names = vertexAIModels - prefix = vertexAILabelPrefix default: return nil, fmt.Errorf("unknown provider detected %s", provider) } - models := make(map[string]ai.ModelOptions, 0) + models := make(map[string]ai.ModelOptions, len(names)) for _, n := range names { + mt := ClassifyModel(n) var m ai.ModelOptions var ok bool - if strings.HasPrefix(n, "image") { + + switch mt { + case ModelTypeImagen: m, ok = supportedImagenModels[n] - } else if strings.HasPrefix(n, "veo") { + case ModelTypeVeo: m, ok = supportedVideoModels[n] - } else { + default: m, ok = supportedGeminiModels[n] } if !ok { return nil, fmt.Errorf("model %s not found for provider %s", n, provider) } + models[n] = GetModelOptions(n, provider) + // Preserve original fields that GetModelOptions doesn't copy models[n] = ai.ModelOptions{ - Label: prefix + " - " + m.Label, + Label: models[n].Label, Versions: m.Versions, Supports: m.Supports, ConfigSchema: m.ConfigSchema, @@ -406,63 +384,38 @@ type genaiModels struct { } // listGenaiModels returns a list of supported models and embedders from the -// Go Genai SDK +// Go Genai SDK, categorized by model type. 
func listGenaiModels(ctx context.Context, client *genai.Client) (genaiModels, error) { models := genaiModels{} - allowedModels := []string{"gemini", "gemma"} for item, err := range client.Models.All(ctx) { - var name string - var description string if err != nil { - log.Fatal(err) + return genaiModels{}, fmt.Errorf("failed to list models: %w", err) } - switch { - case strings.HasPrefix(item.Name, "publishers/google/models/"): - name = strings.TrimPrefix(item.Name, "publishers/google/models/") - case strings.HasPrefix(item.Name, "models/"): - name = strings.TrimPrefix(item.Name, "models/") - default: + + name := strings.TrimPrefix(item.Name, "publishers/google/") + name = strings.TrimPrefix(name, "models/") + + // The Vertex AI backend does not populate SupportedActions, + // so we fall back to name-based categorization. + if slices.Contains(item.SupportedActions, "embedContent") || strings.Contains(name, "embed") { + models.embedders = append(models.embedders, name) continue } - description = strings.ToLower(item.Description) - if strings.Contains(description, "deprecated") { + + if strings.Contains(name, "imagen") { + models.imagen = append(models.imagen, name) continue } - // The Vertex AI backend does not populate SupportedActions, - // so fall back to name-based categorization when it's empty. 
- if len(item.SupportedActions) > 0 { - if slices.Contains(item.SupportedActions, "embedContent") { - models.embedders = append(models.embedders, name) - continue - } - if slices.Contains(item.SupportedActions, "predict") && strings.Contains(name, "imagen") { - models.imagen = append(models.imagen, name) - continue - } - if slices.Contains(item.SupportedActions, "generateContent") { - found := slices.ContainsFunc(allowedModels, func(s string) bool { - return strings.Contains(name, s) - }) - if found { - models.gemini = append(models.gemini, name) - } - } - } else { - switch { - case strings.Contains(name, "embedding") || strings.Contains(name, "embed"): - models.embedders = append(models.embedders, name) - case strings.Contains(name, "imagen"): - models.imagen = append(models.imagen, name) - default: - found := slices.ContainsFunc(allowedModels, func(s string) bool { - return strings.Contains(name, s) - }) - if found { - models.gemini = append(models.gemini, name) - } - } + + if strings.Contains(name, "veo") { + models.veo = append(models.veo, name) + continue } + + // Assume unknown models support generate content + // and let the backend error if not. + models.gemini = append(models.gemini, name) } return models, nil diff --git a/go/plugins/googlegenai/refs.go b/go/plugins/googlegenai/refs.go new file mode 100644 index 0000000000..61d4d33785 --- /dev/null +++ b/go/plugins/googlegenai/refs.go @@ -0,0 +1,55 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package googlegenai + +import ( + "github.com/firebase/genkit/go/ai" + "google.golang.org/genai" +) + +// --- Gemini (text generation) --- + +// ModelRef creates a ModelRef for a Gemini model. +// The name should include provider prefix (e.g., "googleai/gemini-2.0-flash"). +func ModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef { + return ai.NewModelRef(name, config) +} + +// GoogleAIModelRef creates a ModelRef for a Google AI Gemini model. 
+// +// Deprecated: Use ModelRef with full name instead. +func GoogleAIModelRef(id string, config *genai.GenerateContentConfig) ai.ModelRef { + return ai.NewModelRef(googleAIProvider+"/"+id, config) +} + +// VertexAIModelRef creates a ModelRef for a Vertex AI Gemini model. +// +// Deprecated: Use ModelRef with full name instead. +func VertexAIModelRef(id string, config *genai.GenerateContentConfig) ai.ModelRef { + return ai.NewModelRef(vertexAIProvider+"/"+id, config) +} + +// --- Image generation (Imagen) --- + +// ImageModelRef creates a ModelRef for an image generation model. +// The name should include provider prefix (e.g., "googleai/imagen-3.0-generate-001"). +func ImageModelRef(name string, config *genai.GenerateImagesConfig) ai.ModelRef { + return ai.NewModelRef(name, config) +} + +// --- Video generation (Veo) --- + +// VideoModelRef creates a ModelRef for a video generation model. +// The name should include provider prefix (e.g., "googleai/veo-2.0-generate-001"). +func VideoModelRef(name string, config *genai.GenerateVideosConfig) ai.ModelRef { + return ai.NewModelRef(name, config) +} + +// --- Embedders --- + +// EmbedderRef creates an EmbedderRef for an embedding model. +// The name should include provider prefix (e.g., "googleai/text-embedding-004"). +func EmbedderRef(name string, config *genai.EmbedContentConfig) ai.EmbedderRef { + return ai.NewEmbedderRef(name, config) +} diff --git a/go/plugins/googlegenai/schema.go b/go/plugins/googlegenai/schema.go new file mode 100644 index 0000000000..f352104232 --- /dev/null +++ b/go/plugins/googlegenai/schema.go @@ -0,0 +1,229 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package googlegenai + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" + + "google.golang.org/genai" +) + +// toGeminiSchema translates a map representing a standard JSON schema to a more +// limited [genai.Schema]. 
+func toGeminiSchema(originalSchema map[string]any, genkitSchema map[string]any) (*genai.Schema, error) { + // this covers genkitSchema == nil and {} + // genkitSchema will be {} if it's any + if len(genkitSchema) == 0 { + return nil, nil + } + if v, ok := genkitSchema["$ref"]; ok { + ref, ok := v.(string) + if !ok { + return nil, fmt.Errorf("invalid $ref value: not a string") + } + s, err := resolveRef(originalSchema, ref) + if err != nil { + return nil, err + } + return toGeminiSchema(originalSchema, s) + } + + // Handle "anyOf" subschemas by finding the first valid schema definition + if v, ok := genkitSchema["anyOf"]; ok { + if anyOfList, isList := v.([]map[string]any); isList { + for _, subSchema := range anyOfList { + if subSchemaType, hasType := subSchema["type"]; hasType { + if typeStr, isString := subSchemaType.(string); isString && typeStr != "null" { + if title, ok := genkitSchema["title"]; ok { + subSchema["title"] = title + } + if description, ok := genkitSchema["description"]; ok { + subSchema["description"] = description + } + // Found a schema like: {"type": "string"} + return toGeminiSchema(originalSchema, subSchema) + } + } + } + } + } + + schema := &genai.Schema{} + typeVal, ok := genkitSchema["type"] + if !ok { + return nil, fmt.Errorf("schema is missing the 'type' field: %#v", genkitSchema) + } + + typeStr, ok := typeVal.(string) + if !ok { + return nil, fmt.Errorf("schema 'type' field is not a string, but %T", typeVal) + } + + switch typeStr { + case "string": + schema.Type = genai.TypeString + case "float64", "number": + schema.Type = genai.TypeNumber + case "integer": + schema.Type = genai.TypeInteger + case "boolean": + schema.Type = genai.TypeBoolean + case "object": + schema.Type = genai.TypeObject + case "array": + schema.Type = genai.TypeArray + default: + return nil, fmt.Errorf("schema type %q not allowed", genkitSchema["type"]) + } + if v, ok := genkitSchema["required"]; ok { + schema.Required = castToStringArray(v) + } + if v, ok := 
genkitSchema["propertyOrdering"]; ok { + schema.PropertyOrdering = castToStringArray(v) + } + if v, ok := genkitSchema["description"]; ok { + schema.Description = v.(string) + } + if v, ok := genkitSchema["format"]; ok { + schema.Format = v.(string) + } + if v, ok := genkitSchema["title"]; ok { + schema.Title = v.(string) + } + if v, ok := genkitSchema["minItems"]; ok { + if i64, ok := castToInt64(v); ok { + schema.MinItems = genai.Ptr(i64) + } + } + if v, ok := genkitSchema["maxItems"]; ok { + if i64, ok := castToInt64(v); ok { + schema.MaxItems = genai.Ptr(i64) + } + } + if v, ok := genkitSchema["maximum"]; ok { + if f64, ok := castToFloat64(v); ok { + schema.Maximum = genai.Ptr(f64) + } + } + if v, ok := genkitSchema["minimum"]; ok { + if f64, ok := castToFloat64(v); ok { + schema.Minimum = genai.Ptr(f64) + } + } + if v, ok := genkitSchema["enum"]; ok { + schema.Enum = castToStringArray(v) + } + if v, ok := genkitSchema["items"]; ok { + items, err := toGeminiSchema(originalSchema, v.(map[string]any)) + if err != nil { + return nil, err + } + schema.Items = items + } + if val, ok := genkitSchema["properties"]; ok { + props := map[string]*genai.Schema{} + for k, v := range val.(map[string]any) { + p, err := toGeminiSchema(originalSchema, v.(map[string]any)) + if err != nil { + return nil, err + } + props[k] = p + } + schema.Properties = props + } + // Nullable -- not supported in jsonschema.Schema + + return schema, nil +} + +// resolveRef resolves a $ref reference in a JSON schema. 
+func resolveRef(originalSchema map[string]any, ref string) (map[string]any, error) {
+	tkns := strings.Split(ref, "/")
+	// refs look like: #/$defs/Foo -- we need the final "Foo" segment
+	name := tkns[len(tkns)-1]
+	if defs, ok := originalSchema["$defs"].(map[string]any); ok {
+		if def, ok := defs[name].(map[string]any); ok {
+			return def, nil
+		}
+	}
+	// definitions (legacy)
+	if defs, ok := originalSchema["definitions"].(map[string]any); ok {
+		if def, ok := defs[name].(map[string]any); ok {
+			return def, nil
+		}
+	}
+	return nil, fmt.Errorf("unable to resolve schema reference")
+}
+
+// castToStringArray converts either []any or []string to []string, filtering non-strings.
+// This handles enum values from JSON Schema which may come as either type depending on unmarshaling.
+// Non-string elements are dropped when v is a []any; empty strings are dropped in both cases.
+func castToStringArray(v any) []string {
+	switch a := v.(type) {
+	case []string:
+		// Return a shallow copy to avoid aliasing
+		out := make([]string, 0, len(a))
+		for _, s := range a {
+			if s != "" {
+				out = append(out, s)
+			}
+		}
+		return out
+	case []any:
+		var out []string
+		for _, it := range a {
+			if s, ok := it.(string); ok && s != "" {
+				out = append(out, s)
+			}
+		}
+		return out
+	default:
+		return nil
+	}
+}
+
+// castToInt64 converts v to int64 when possible.
+func castToInt64(v any) (int64, bool) {
+	switch t := v.(type) {
+	case int:
+		return int64(t), true
+	case int64:
+		return t, true
+	case float64:
+		return int64(t), true
+	case string:
+		if i, err := strconv.ParseInt(t, 10, 64); err == nil {
+			return i, true
+		}
+	case json.Number:
+		if i, err := t.Int64(); err == nil {
+			return i, true
+		}
+	}
+	return 0, false
+}
+
+// castToFloat64 converts v to float64 when possible.
+func castToFloat64(v any) (float64, bool) { + switch t := v.(type) { + case float64: + return t, true + case int: + return float64(t), true + case int64: + return float64(t), true + case string: + if f, err := strconv.ParseFloat(t, 64); err == nil { + return f, true + } + case json.Number: + if f, err := t.Float64(); err == nil { + return f, true + } + } + return 0, false +} diff --git a/go/plugins/googlegenai/tools.go b/go/plugins/googlegenai/tools.go new file mode 100644 index 0000000000..2a639688f4 --- /dev/null +++ b/go/plugins/googlegenai/tools.go @@ -0,0 +1,155 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package googlegenai + +import ( + "fmt" + "reflect" + "regexp" + "strings" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/plugins/internal/uri" + "google.golang.org/genai" +) + +const ( + // toolNameRegex validates tool names. + toolNameRegex = "^[a-zA-Z_][a-zA-Z0-9_.-]{0,63}$" +) + +// toGeminiTools translates a slice of [ai.ToolDefinition] to a slice of [genai.Tool]. 
+func toGeminiTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) { + var outTools []*genai.Tool + functions := []*genai.FunctionDeclaration{} + + for _, t := range inTools { + if !validToolName(t.Name) { + return nil, fmt.Errorf(`invalid tool name: %q, must start with a letter or an underscore, must be alphanumeric, underscores, dots or dashes with a max length of 64 chars`, t.Name) + } + inputSchema, err := toGeminiSchema(t.InputSchema, t.InputSchema) + if err != nil { + return nil, err + } + fd := &genai.FunctionDeclaration{ + Name: t.Name, + Parameters: inputSchema, + Description: t.Description, + } + functions = append(functions, fd) + } + + if len(functions) > 0 { + outTools = append(outTools, &genai.Tool{ + FunctionDeclarations: functions, + }) + } + + return outTools, nil +} + +// toGeminiFunctionResponsePart translates a slice of [ai.Part] to a slice of [genai.FunctionResponsePart] +func toGeminiFunctionResponsePart(parts []*ai.Part) ([]*genai.FunctionResponsePart, error) { + frp := []*genai.FunctionResponsePart{} + for _, p := range parts { + switch { + case p.IsData(): + contentType, data, err := uri.Data(p) + if err != nil { + return nil, err + } + frp = append(frp, genai.NewFunctionResponsePartFromBytes(data, contentType)) + case p.IsMedia(): + if strings.HasPrefix(p.Text, "data:") { + contentType, data, err := uri.Data(p) + if err != nil { + return nil, err + } + frp = append(frp, genai.NewFunctionResponsePartFromBytes(data, contentType)) + continue + } + frp = append(frp, genai.NewFunctionResponsePartFromURI(p.Text, p.ContentType)) + default: + return nil, fmt.Errorf("unsupported function response part type: %d", p.Kind) + } + } + return frp, nil +} + +// mergeTools consolidates all FunctionDeclarations into a single Tool +// while preserving non-function tools (Retrieval, GoogleSearch, CodeExecution, etc.) 
+func mergeTools(ts []*genai.Tool) []*genai.Tool {
+	var decls []*genai.FunctionDeclaration
+	var out []*genai.Tool
+
+	for _, t := range ts {
+		if t == nil {
+			continue
+		}
+		if len(t.FunctionDeclarations) == 0 {
+			out = append(out, t)
+			continue
+		}
+		decls = append(decls, t.FunctionDeclarations...)
+		if cpy := cloneToolWithoutFunctions(t); cpy != nil && !reflect.ValueOf(*cpy).IsZero() {
+			out = append(out, cpy)
+		}
+	}
+
+	if len(decls) > 0 {
+		out = append([]*genai.Tool{{FunctionDeclarations: decls}}, out...)
+	}
+	return out
+}
+
+func cloneToolWithoutFunctions(t *genai.Tool) *genai.Tool {
+	if t == nil {
+		return nil
+	}
+	clone := *t
+	clone.FunctionDeclarations = nil
+	return &clone
+}
+
+// toGeminiToolChoice translates tool choice settings to Gemini tool config.
+func toGeminiToolChoice(toolChoice ai.ToolChoice, tools []*ai.ToolDefinition) (*genai.ToolConfig, error) {
+	var mode genai.FunctionCallingConfigMode
+	switch toolChoice {
+	case "":
+		return nil, nil
+	case ai.ToolChoiceAuto:
+		mode = genai.FunctionCallingConfigModeAuto
+	case ai.ToolChoiceRequired:
+		mode = genai.FunctionCallingConfigModeAny
+	case ai.ToolChoiceNone:
+		mode = genai.FunctionCallingConfigModeNone
+	default:
+		return nil, fmt.Errorf("tool choice mode %q not supported", toolChoice)
+	}
+
+	var toolNames []string
+	// Per docs, only set AllowedFunctionNames with mode set to ANY.
+ if mode == genai.FunctionCallingConfigModeAny { + for _, t := range tools { + toolNames = append(toolNames, t.Name) + } + } + return &genai.ToolConfig{ + FunctionCallingConfig: &genai.FunctionCallingConfig{ + Mode: mode, + AllowedFunctionNames: toolNames, + }, + }, nil +} + +var toolNameRegexCompiled = regexp.MustCompile(toolNameRegex) + +// validToolName checks whether the provided tool name matches the +// following criteria: +// - Start with a letter or an underscore +// - Must be alphanumeric and can include underscores, dots or dashes +// - Maximum length of 64 chars +func validToolName(n string) bool { + return toolNameRegexCompiled.MatchString(n) +} diff --git a/go/plugins/googlegenai/veo.go b/go/plugins/googlegenai/veo.go index 328b4598f7..b394ebe404 100644 --- a/go/plugins/googlegenai/veo.go +++ b/go/plugins/googlegenai/veo.go @@ -18,11 +18,15 @@ package googlegenai import ( "context" + "encoding/base64" "fmt" + "strings" "time" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/internal/base" "github.com/firebase/genkit/go/plugins/internal/uri" "google.golang.org/genai" ) @@ -34,6 +38,11 @@ func newVeoModel( name string, info ai.ModelOptions, ) ai.BackgroundModel { + provider := googleAIProvider + if client.ClientConfig().Backend == genai.BackendVertexAI { + provider = vertexAIProvider + } + startFunc := func(ctx context.Context, req *ai.ModelRequest) (*ai.ModelOperation, error) { // Extract text prompt from the request prompt := extractTextFromRequest(req) @@ -41,15 +50,29 @@ func newVeoModel( return nil, fmt.Errorf("no text prompt found in request") } + video := extractVeoVideoFromRequest(req) image := extractVeoImageFromRequest(req) + videoConfig, err := toVeoParameters(req) + if err != nil { + return nil, err + } - videoConfig := toVeoParameters(req) + // prevent SDK to pick a default number of video generation (usually 2) + // if users do not provide 
this setting + if videoConfig.NumberOfVideos == 0 { + videoConfig.NumberOfVideos = 1 + } - operation, err := client.Models.GenerateVideos( + sourceConfig := &genai.GenerateVideosSource{ + Prompt: prompt, + Image: image, + Video: video, + } + + operation, err := client.Models.GenerateVideosFromSource( ctx, name, - prompt, - image, + sourceConfig, videoConfig, ) if err != nil { @@ -105,7 +128,7 @@ func newVeoModel( return updatedOp, nil } - return ai.NewBackgroundModel(name, &ai.BackgroundModelOptions{ModelOptions: info}, startFunc, checkFunc) + return ai.NewBackgroundModel(api.NewName(provider, name), &ai.BackgroundModelOptions{ModelOptions: info}, startFunc, checkFunc) } // extractTextFromRequest extracts the text prompt from a model request. @@ -131,7 +154,7 @@ func extractVeoImageFromRequest(request *ai.ModelRequest) *genai.Image { for _, message := range request.Messages { for _, part := range message.Content { - if part.IsMedia() { + if part.IsMedia() && !part.IsVideo() { _, data, err := uri.Data(part) if err != nil { return nil @@ -147,15 +170,58 @@ func extractVeoImageFromRequest(request *ai.ModelRequest) *genai.Image { return nil } +// extractVeoVideoFromRequest extracts video content from a model request for Veo. +func extractVeoVideoFromRequest(request *ai.ModelRequest) *genai.Video { + if len(request.Messages) == 0 { + return nil + } + + for _, message := range request.Messages { + for _, part := range message.Content { + if !part.IsVideo() { + continue + } + if strings.HasPrefix(part.Text, "data:") { + contentType, data, err := uri.Data(part) + if err != nil { + return nil + } + return &genai.Video{ + VideoBytes: data, + MIMEType: contentType, + } + } + return &genai.Video{ + URI: part.Text, + } + } + } + + return nil +} + // toVeoParameters converts model request configuration to Veo video generation parameters. 
-func toVeoParameters(request *ai.ModelRequest) *genai.GenerateVideosConfig { - params := &genai.GenerateVideosConfig{} - if request.Config != nil { - if config, ok := request.Config.(*genai.GenerateVideosConfig); ok { - return config +func toVeoParameters(request *ai.ModelRequest) (*genai.GenerateVideosConfig, error) { + if request.Config == nil { + return &genai.GenerateVideosConfig{}, nil + } + + switch config := request.Config.(type) { + case *genai.GenerateVideosConfig: + return config, nil + case genai.GenerateVideosConfig: + return &config, nil + case map[string]any: + var result genai.GenerateVideosConfig + var err error + result, err = base.MapToStruct[genai.GenerateVideosConfig](config) + if err != nil { + return nil, core.NewPublicError(core.INVALID_ARGUMENT, fmt.Sprintf("The video configuration settings are not in the correct format. Check that the names and values match what the model expects: %v", err), nil) } + return &result, nil + default: + return nil, core.NewPublicError(core.INVALID_ARGUMENT, fmt.Sprintf("The configuration type %T is not supported. Use the correct configuration for this model (like VideoModelRef) or a configuration struct.", request.Config), nil) } - return params } // fromVeoOperation converts a Veo API operation to a Genkit core operation. @@ -166,6 +232,14 @@ func fromVeoOperation(veoOp *genai.GenerateVideosOperation) *ai.ModelOperation { Metadata: make(map[string]any), } + // Copy any API-provided metadata (e.g. progress percentage, queue status) + // so developers can use it for debugging or UI updates. 
+ if veoOp.Metadata != nil { + for k, v := range veoOp.Metadata { + operation.Metadata[k] = v + } + } + // Handle error cases if veoOp.Error != nil { if errorMsg, ok := veoOp.Error["message"].(string); ok { @@ -191,8 +265,13 @@ func fromVeoOperation(veoOp *genai.GenerateVideosOperation) *ai.ModelOperation { if veoOp.Done && veoOp.Response != nil && veoOp.Response.GeneratedVideos != nil && len(veoOp.Response.GeneratedVideos) > 0 { content := make([]*ai.Part, 0, len(veoOp.Response.GeneratedVideos)) for _, sample := range veoOp.Response.GeneratedVideos { - if sample.Video != nil && sample.Video.URI != "" { - content = append(content, ai.NewMediaPart("video/mp4", sample.Video.URI)) + if sample.Video != nil { + if sample.Video.URI != "" { + content = append(content, ai.NewMediaPart("video/mp4", sample.Video.URI)) + } else if len(sample.Video.VideoBytes) > 0 { + importBase64 := "data:video/mp4;base64," + base64.StdEncoding.EncodeToString(sample.Video.VideoBytes) + content = append(content, ai.NewMediaPart("video/mp4", importBase64)) + } } } @@ -203,11 +282,27 @@ func fromVeoOperation(veoOp *genai.GenerateVideosOperation) *ai.ModelOperation { Content: content, }, FinishReason: ai.FinishReasonStop, + Raw: veoOp.Response, } return operation } } + // Handle Responsible AI (Safety) filtering + if veoOp.Done && veoOp.Response != nil && veoOp.Response.RAIMediaFilteredCount > 0 { + reasons := strings.Join(veoOp.Response.RAIMediaFilteredReasons, ", ") + operation.Output = &ai.ModelResponse{ + Message: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ai.NewTextPart(fmt.Sprintf("Video generation blocked by safety filters. 
Reasons: %s", reasons))}, + }, + FinishReason: ai.FinishReasonBlocked, + FinishMessage: fmt.Sprintf("%d videos filtered due to RAI policies", veoOp.Response.RAIMediaFilteredCount), + Raw: veoOp.Response, + } + return operation + } + // Handle completed operations without valid response operation.Output = &ai.ModelResponse{ Message: &ai.Message{ @@ -215,6 +310,7 @@ func fromVeoOperation(veoOp *genai.GenerateVideosOperation) *ai.ModelOperation { Content: []*ai.Part{ai.NewTextPart("Video generation completed but no videos were generated")}, }, FinishReason: ai.FinishReasonStop, + Raw: veoOp.Response, } return operation diff --git a/go/plugins/googlegenai/veo_test.go b/go/plugins/googlegenai/veo_test.go index c86ed81241..a94cdc1f91 100644 --- a/go/plugins/googlegenai/veo_test.go +++ b/go/plugins/googlegenai/veo_test.go @@ -169,16 +169,18 @@ func TestToVeoParameters(t *testing.T) { t.Parallel() tests := []struct { - name string - request *ai.ModelRequest - expected genai.GenerateVideosConfig + name string + request *ai.ModelRequest + expected genai.GenerateVideosConfig + expectError bool }{ { name: "request with no config", request: &ai.ModelRequest{ Config: nil, }, - expected: genai.GenerateVideosConfig{}, + expected: genai.GenerateVideosConfig{}, + expectError: false, }, { name: "request with valid GenerateVideosConfig", @@ -194,6 +196,23 @@ func TestToVeoParameters(t *testing.T) { DurationSeconds: genai.Ptr(int32(5)), PersonGeneration: "allow_adult", }, + expectError: false, + }, + { + name: "request with valid map config", + request: &ai.ModelRequest{ + Config: map[string]any{ + "aspectRatio": "16:9", + "durationSeconds": 5, + "personGeneration": "allow_adult", + }, + }, + expected: genai.GenerateVideosConfig{ + AspectRatio: "16:9", + DurationSeconds: genai.Ptr(int32(5)), + PersonGeneration: "allow_adult", + }, + expectError: false, }, { name: "request with different config type", @@ -202,13 +221,25 @@ func TestToVeoParameters(t *testing.T) { MaxOutputTokens: 
int32(100), }, }, - expected: genai.GenerateVideosConfig{}, // Should return default config + expected: genai.GenerateVideosConfig{}, + expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := toVeoParameters(tt.request) + result, err := toVeoParameters(tt.request) + + if tt.expectError { + if err == nil { + t.Errorf("toVeoParameters() expected error but got nil") + } + return + } + + if err != nil { + t.Fatalf("toVeoParameters() unexpected error: %v", err) + } // Compare AspectRatio if result.AspectRatio != tt.expected.AspectRatio { diff --git a/go/samples/veo/main.go b/go/samples/veo/main.go index dceb622ff9..0f4580611a 100644 --- a/go/samples/veo/main.go +++ b/go/samples/veo/main.go @@ -16,6 +16,7 @@ package main import ( "context" + "encoding/base64" "fmt" "io" "log" @@ -36,35 +37,114 @@ func main() { g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - operation, err := genkit.GenerateOperation(ctx, g, - ai.WithMessages(ai.NewUserTextMessage("Cat racing mouse")), - ai.WithModelName("googleai/veo-3.0-generate-001"), - ai.WithConfig(&genai.GenerateVideosConfig{ - NumberOfVideos: 1, - AspectRatio: "16:9", - DurationSeconds: genai.Ptr(int32(8)), - Resolution: "720p", - }), - ) - if err != nil { - log.Fatalf("Failed to start video generation: %v", err) - } + genkit.DefineFlow(g, "text-to-video", func(ctx context.Context, input string) (string, error) { + if input == "" { + input = "Cat racing mouse" + } + operation, err := genkit.GenerateOperation(ctx, g, + ai.WithMessages(ai.NewUserTextMessage(input)), + ai.WithModelName("googleai/veo-3.1-generate-preview"), + ai.WithConfig(&genai.GenerateVideosConfig{ + NumberOfVideos: 1, + AspectRatio: "16:9", + DurationSeconds: genai.Ptr(int32(8)), + Resolution: "720p", + }), + ) + if err != nil { + log.Fatalf("Failed to start video generation: %v", err) + } + printStatus(operation) - log.Printf("Started operation: %s", operation.ID) - printStatus(operation) + operation, 
err = waitForCompletion(ctx, g, operation) + if err != nil { + log.Fatalf("Operation failed: %v", err) + } + log.Println("Video generation completed successfully!") - operation, err = waitForCompletion(ctx, g, operation) - if err != nil { - log.Fatalf("Operation failed: %v", err) - } + if err := downloadGeneratedVideo(ctx, operation); err != nil { + log.Fatalf("Failed to download video: %v", err) + } + + // Return the video URI for chaining + uri, err := extractVideoURL(operation) + if err != nil { + return "", err + } + return uri, nil + }) - log.Println("Video generation completed successfully!") + genkit.DefineFlow(g, "image-to-video", func(ctx context.Context, input any) (string, error) { + imgb64, err := fetchImgAsBase64() + if err != nil { + log.Fatalf("unable to download image: %v", err) + } + operation, err := genkit.GenerateOperation(ctx, g, + ai.WithModelName("googleai/veo-3.1-generate-preview"), + ai.WithMessages(ai.NewUserMessage(ai.NewTextPart("Generate a video of the following image, the cat should wake up and start accelerating the go-kart as if it just acquired a mushroom from Mario Kart"), + ai.NewMediaPart("image/jpeg", "data:image/jpeg;base64,"+imgb64), + )), + ai.WithConfig(&genai.GenerateVideosConfig{ + NumberOfVideos: 1, + AspectRatio: "16:9", + DurationSeconds: genai.Ptr(int32(8)), + }), + ) + if err != nil { + log.Fatalf("Failed to start video generation: %v", err) + } + printStatus(operation) - if err := downloadGeneratedVideo(ctx, operation); err != nil { - log.Fatalf("Failed to download video: %v", err) - } + operation, err = waitForCompletion(ctx, g, operation) + if err != nil { + log.Fatalf("Operation failed: %v", err) + } + log.Println("Video generation completed successfully!") + + if err := downloadGeneratedVideo(ctx, operation); err != nil { + log.Fatalf("Failed to download video: %v", err) + } + + return "Video successfully downloaded to veo3_video.mp4", nil + }) + + genkit.DefineFlow(g, "video-to-video", func(ctx context.Context, 
inputURI string) (string, error) { + if inputURI == "" { + return "", fmt.Errorf("input URI is required for video extension") + } + + log.Printf("Extending video from URI: %s", inputURI) + + operation, err := genkit.GenerateOperation(ctx, g, + ai.WithModelName("googleai/veo-3.1-generate-preview"), + ai.WithMessages(ai.NewUserMessage( + ai.NewTextPart("Edit the original video backround to be a rainforest, also change the video style to be a cartoon from 1950, make the transition smooth. You must keep the characters from the original video"), + ai.NewMediaPart("video/mp4", inputURI), + )), + ai.WithConfig(&genai.GenerateVideosConfig{ + NumberOfVideos: 1, + AspectRatio: "16:9", + DurationSeconds: genai.Ptr(int32(8)), + }), + ) + if err != nil { + log.Fatalf("Failed to start video generation: %v", err) + } + printStatus(operation) + + operation, err = waitForCompletion(ctx, g, operation) + if err != nil { + log.Fatalf("Operation failed: %v", err) + } + log.Println("Video extension completed successfully!") - log.Println("Video successfully downloaded to veo3_video.mp4") + if err := downloadGeneratedVideo(ctx, operation); err != nil { + log.Fatalf("Failed to download video: %v", err) + } + + return "Video successfully downloaded to veo3_video.mp4", nil + }) + <-ctx.Done() } // waitForCompletion polls the operation status until it completes. 
@@ -193,3 +273,25 @@ func downloadVideo(ctx context.Context, url, filename string) error { return nil } + +// fetchImgAsBase64 downloads a predefined image and returns the image encoded in a base64 string +func fetchImgAsBase64() (string, error) { + // CC0 license image + imgURL := "https://pd.w.org/2025/07/896686fbbcd9990c9.84605288-2048x1365.jpg" + resp, err := http.Get(imgURL) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", err + } + + imageBytes, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + base64string := base64.StdEncoding.EncodeToString(imageBytes) + return base64string, nil +} diff --git a/go/samples/vertexai-veo/main.go b/go/samples/vertexai-veo/main.go new file mode 100644 index 0000000000..eac6014a06 --- /dev/null +++ b/go/samples/vertexai-veo/main.go @@ -0,0 +1,138 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "encoding/base64" + "fmt" + "log" + "os" + "strings" + "time" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "google.golang.org/genai" +) + +func main() { + ctx := context.Background() + + // Initialize with Vertex AI plugin. Ensure GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION are set. 
+ g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.VertexAI{})) + + genkit.DefineFlow(g, "text-to-video", func(ctx context.Context, input string) (string, error) { + if input == "" { + input = "A futuristic city at sunset, flying cars, cyberpunk style" + } + operation, err := genkit.GenerateOperation(ctx, g, + ai.WithMessages(ai.NewUserTextMessage(input)), + ai.WithModelName("vertexai/veo-3.1-generate-preview"), + ai.WithConfig(&genai.GenerateVideosConfig{ + NumberOfVideos: 1, + AspectRatio: "16:9", + DurationSeconds: genai.Ptr(int32(4)), + }), + ) + if err != nil { + log.Fatalf("Failed to start video generation: %v", err) + } + printStatus(operation) + + operation, err = waitForCompletion(ctx, g, operation) + if err != nil { + log.Fatalf("Operation failed: %v", err) + } + log.Println("Video generation completed successfully!") + + if err := saveGeneratedVideo(operation, "veo3_vertexai_video.mp4"); err != nil { + log.Fatalf("Failed to save video: %v", err) + } + + return "Video successfully saved to veo3_vertexai_video.mp4", nil + }) + + <-ctx.Done() +} + +// waitForCompletion polls the operation status until it completes. +func waitForCompletion(ctx context.Context, g *genkit.Genkit, op *core.Operation[*ai.ModelResponse]) (*core.Operation[*ai.ModelResponse], error) { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for !op.Done { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("context cancelled: %w", ctx.Err()) + case <-ticker.C: + updatedOp, err := genkit.CheckModelOperation(ctx, g, op) + if err != nil { + return nil, fmt.Errorf("failed to check status: %w", err) + } + + if updatedOp.Error != nil { + return nil, fmt.Errorf("operation error: %w", updatedOp.Error) + } + + printStatus(updatedOp) + op = updatedOp + } + } + + return op, nil +} + +// printStatus prints the current status message from the operation. 
+func printStatus(op *core.Operation[*ai.ModelResponse]) { + if op.Output != nil && !op.Done && op.Output.Message != nil && len(op.Output.Message.Content) > 0 { + log.Printf("Status Message: %s", op.Output.Message.Content[0].Text) + } +} + +// saveGeneratedVideo extracts the video from the operation output and saves it to disk. +// Vertex AI returns raw VideoBytes encoded as a data URI by the Genkit plugin. +func saveGeneratedVideo(operation *core.Operation[*ai.ModelResponse], filename string) error { + if operation.Output == nil || operation.Output.Message == nil { + return fmt.Errorf("operation output is empty") + } + + for _, part := range operation.Output.Message.Content { + if part.IsMedia() && part.Text != "" { + if strings.HasPrefix(part.Text, "data:") { + // Vertex AI returns the raw video encoded as base64 in the text field + commaIndex := strings.Index(part.Text, ",") + if commaIndex == -1 { + return fmt.Errorf("invalid data URI format") + } + + base64Data := part.Text[commaIndex+1:] + videoBytes, err := base64.StdEncoding.DecodeString(base64Data) + if err != nil { + return fmt.Errorf("failed to decode base64 video data: %w", err) + } + + return os.WriteFile(filename, videoBytes, 0o644) + } else { + // If it's a direct Cloud Storage URI (less common without output GCS bucket config) + return fmt.Errorf("received URI instead of raw bytes: %s. Use HTTP client to download", part.Text) + } + } + } + + return fmt.Errorf("no video found in the operation output") +}