diff --git a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java index 4a375980c..1366f10b2 100644 --- a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java +++ b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java @@ -28,6 +28,7 @@ import com.google.adk.agents.Callbacks; import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; +import com.google.adk.utils.AgentEnums.AgentOrigin; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Content; @@ -35,7 +36,6 @@ import com.google.genai.types.Part; import io.a2a.client.Client; import io.a2a.client.ClientEvent; -import io.a2a.client.MessageEvent; import io.a2a.client.TaskEvent; import io.a2a.client.TaskUpdateEvent; import io.a2a.spec.A2AClientException; @@ -541,6 +541,11 @@ protected Flowable runLiveImpl(InvocationContext invocationContext) { "runLiveImpl for " + getClass() + " via A2A is not implemented."); } + @Override + public AgentOrigin toolOrigin() { + return AgentOrigin.A2A; + } + /** Exception thrown when the agent card cannot be resolved. */ public static class AgentCardResolutionError extends RuntimeException { public AgentCardResolutionError(String message) { diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 95fe838cc..cbceceed2 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -25,6 +25,7 @@ import com.google.adk.events.Event; import com.google.adk.plugins.Plugin; import com.google.adk.telemetry.Tracing; +import com.google.adk.utils.AgentEnums.AgentOrigin; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; @@ -256,6 +257,15 @@ public ImmutableList afterAgentCallback() { return afterAgentCallback; } + /** + * Returns the origin of the tool when this agent is used as a tool. + * + * @return the tool origin, defaults to "BASE_AGENT". + */ + public AgentOrigin toolOrigin() { + return AgentOrigin.BASE_AGENT; + } + /** * The resolved beforeAgentCallback field as a list. * diff --git a/core/src/main/java/com/google/adk/models/ChatCompletionsResponse.java b/core/src/main/java/com/google/adk/models/ChatCompletionsResponse.java deleted file mode 100644 index fe5cdd116..000000000 --- a/core/src/main/java/com/google/adk/models/ChatCompletionsResponse.java +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.models; - -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import java.util.List; -import java.util.Map; - -/** - * Data Transfer Objects for Chat Completion and Chat Completion Chunk API responses. - * - *

These classes are used for deserializing JSON responses from the `/chat/completions` endpoint. - */ -@JsonIgnoreProperties(ignoreUnknown = true) -final class ChatCompletionsResponse { - - private ChatCompletionsResponse() {} - - @JsonIgnoreProperties(ignoreUnknown = true) - static class ChatCompletion { - public String id; - public List choices; - public Long created; - public String model; - public String object; - - @JsonProperty("service_tier") - public String serviceTier; - - @JsonProperty("system_fingerprint") - public String systemFingerprint; - - public Usage usage; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class Choice { - @JsonProperty("finish_reason") - public String finishReason; - - public Integer index; - public Logprobs logprobs; - public Message message; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class ChatCompletionChunk { - public String id; - public List choices; - public Long created; - public String model; - public String object; - - @JsonProperty("service_tier") - public String serviceTier; - - @JsonProperty("system_fingerprint") - public String systemFingerprint; - - public Usage usage; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class ChunkChoice { - @JsonProperty("finish_reason") - public String finishReason; - - public Integer index; - public Logprobs logprobs; - public Message delta; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class Message { - public String content; - public String refusal; - public String role; - - @JsonProperty("tool_calls") - public List toolCalls; - - // function_call is not supported in ChatCompletionChunk and ChatCompletion support is - // deprecated. - @JsonProperty("function_call") - public Function functionCall; // Fallback for deprecated top-level function calls - - public List annotations; - public Audio audio; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class ToolCall { - // Index is only used in ChatCompletionChunk. - public Integer index; - public String id; - public String type; - public Function function; - public Custom custom; - - @JsonProperty("extra_content") - public Map extraContent; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class Function { - public String name; - public String arguments; // JSON string - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class Custom { - public String input; - public String name; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class Logprobs { - public List content; - public List refusal; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - @JsonInclude(JsonInclude.Include.NON_NULL) - static class TokenLogprob { - public String token; - public List bytes; - public Double logprob; - - @JsonProperty("top_logprobs") - public List topLogprobs; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class Usage { - @JsonProperty("completion_tokens") - public Integer completionTokens; - - @JsonProperty("prompt_tokens") - public Integer promptTokens; - - @JsonProperty("total_tokens") - public Integer totalTokens; - - @JsonProperty("thoughts_token_count") - public Integer thoughtsTokenCount; // Gemini-specific extension - - @JsonProperty("completion_tokens_details") - public CompletionTokensDetails completionTokensDetails; - - @JsonProperty("prompt_tokens_details") - public PromptTokensDetails promptTokensDetails; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class CompletionTokensDetails { - @JsonProperty("accepted_prediction_tokens") - public Integer acceptedPredictionTokens; - - @JsonProperty("audio_tokens") - public Integer audioTokens; - - @JsonProperty("reasoning_tokens") - public Integer reasoningTokens; - - @JsonProperty("rejected_prediction_tokens") - public Integer rejectedPredictionTokens; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class PromptTokensDetails { - @JsonProperty("audio_tokens") - public Integer audioTokens; - - @JsonProperty("cached_tokens") - public Integer cachedTokens; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class Annotation { - public String type; - - @JsonProperty("url_citation") - public UrlCitation urlCitation; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class UrlCitation { - @JsonProperty("end_index") - public Integer endIndex; - - @JsonProperty("start_index") - public Integer startIndex; - - public String title; - public String url; - } - - @JsonIgnoreProperties(ignoreUnknown = true) - static class Audio { - public String id; - public String data; - - @JsonProperty("expires_at") - public Long expiresAt; - - public String transcript; - } -} diff --git a/core/src/main/java/com/google/adk/models/LlmRegistry.java b/core/src/main/java/com/google/adk/models/LlmRegistry.java index a73d89430..acc038695 100644 --- a/core/src/main/java/com/google/adk/models/LlmRegistry.java +++ b/core/src/main/java/com/google/adk/models/LlmRegistry.java @@ -16,6 +16,7 @@ package com.google.adk.models; +import com.google.common.annotations.VisibleForTesting; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -38,6 +39,7 @@ public interface LlmFactory { static { registerLlm("gemini-.*", modelName -> Gemini.builder().modelName(modelName).build()); registerLlm("apigee/.*", modelName -> ApigeeLlm.builder().modelName(modelName).build()); + registerLlm("gemma-.*", modelName -> Gemini.builder().modelName(modelName).build()); } /** @@ -50,6 +52,17 @@ public static void registerLlm(String modelNamePattern, LlmFactory factory) { llmFactories.put(modelNamePattern, factory); } + /** + * Checks if the given model name matches any of the registered LLM factory patterns. + * + * @param modelName The model name to check. + * @return {@code true} if the model name matches at least one pattern, {@code false} otherwise. + */ + @VisibleForTesting + static boolean matchesAnyPattern(String modelName) { + return llmFactories.keySet().stream().anyMatch(modelName::matches); + } + /** * Returns an LLM instance for the given model name, using a cached or new factory-created * instance. diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java new file mode 100644 index 000000000..730f53af2 --- /dev/null +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java @@ -0,0 +1,88 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.chat; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Map; + +/** Shared models for Chat Completions Request and Response. */ +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(JsonInclude.Include.NON_NULL) +final class ChatCompletionsCommon { + + private ChatCompletionsCommon() {} + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_message_tool_call%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class ToolCall { + /** See class definition for more details. */ + public Integer index; + + /** See class definition for more details. */ + public String id; + + /** See class definition for more details. */ + public String type; + + /** See class definition for more details. */ + public Function function; + + /** See class definition for more details. */ + public Custom custom; + + /** + * Used to supply additional parameters for specific models, for example: + * https://ai.google.dev/gemini-api/docs/openai#thinking + */ + @JsonProperty("extra_content") + public Map extraContent; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_message_function_tool_call%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class Function { + /** See class definition for more details. */ + public String name; + + /** See class definition for more details. */ + public String arguments; // JSON string + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_custom_tool%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class Custom { + /** See class definition for more details. */ + public String input; + + /** See class definition for more details. */ + public String name; + } +} diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java new file mode 100644 index 000000000..4b6747fb1 --- /dev/null +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java @@ -0,0 +1,728 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.chat; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonValue; +import java.util.List; +import java.util.Map; + +/** + * Data Transfer Objects for Chat Completion API requests. + * + *

See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create + */ +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(JsonInclude.Include.NON_NULL) +final class ChatCompletionsRequest { + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20messages%20%3E%20(schema) + */ + public List messages; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20model%20%3E%20(schema) + */ + public String model; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20audio%20%3E%20(schema) + */ + public AudioParam audio; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20frequency_penalty%20%3E%20(schema) + */ + @JsonProperty("frequency_penalty") + public Double frequencyPenalty; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20logit_bias%20%3E%20(schema) + */ + @JsonProperty("logit_bias") + public Map logitBias; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20logprobs%20%3E%20(schema) + */ + public Boolean logprobs; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20max_completion_tokens%20%3E%20(schema) + */ + @JsonProperty("max_completion_tokens") + public Integer maxCompletionTokens; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20metadata%20%3E%20(schema) + */ + public Map metadata; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20modalities%20%3E%20(schema) + */ + public List modalities; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20n%20%3E%20(schema) + */ + public Integer n; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20parallel_tool_calls%20%3E%20(schema) + */ + @JsonProperty("parallel_tool_calls") + public Boolean parallelToolCalls; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20prediction%20%3E%20(schema) + */ + public Prediction prediction; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20presence_penalty%20%3E%20(schema) + */ + @JsonProperty("presence_penalty") + public Double presencePenalty; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20prompt_cache_key%20%3E%20(schema) + */ + @JsonProperty("prompt_cache_key") + public String promptCacheKey; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20prompt_cache_retention%20%3E%20(schema) + */ + @JsonProperty("prompt_cache_retention") + public String promptCacheRetention; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20reasoning_effort%20%3E%20(schema) + */ + @JsonProperty("reasoning_effort") + public String reasoningEffort; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20response_format%20%3E%20(schema) + */ + @JsonProperty("response_format") + public ResponseFormat responseFormat; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20safety_identifier%20%3E%20(schema) + */ + @JsonProperty("safety_identifier") + public String safetyIdentifier; + + /** + * Deprecated. Use temperature instead. See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20seed%20%3E%20(schema) + */ + public Long seed; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20service_tier%20%3E%20(schema) + */ + @JsonProperty("service_tier") + public String serviceTier; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20stop%20%3E%20(schema) + */ + public StopCondition stop; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20store%20%3E%20(schema) + */ + public Boolean store; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20stream%20%3E%20(schema) + */ + public Boolean stream; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20stream_options%20%3E%20(schema) + */ + @JsonProperty("stream_options") + public StreamOptions streamOptions; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20temperature%20%3E%20(schema) + */ + public Double temperature; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20tool_choice%20%3E%20(schema) + */ + @JsonProperty("tool_choice") + public ToolChoice toolChoice; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20tools%20%3E%20(schema) + */ + public List tools; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20top_logprobs%20%3E%20(schema) + */ + @JsonProperty("top_logprobs") + public Integer topLogprobs; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20top_p%20%3E%20(schema) + */ + @JsonProperty("top_p") + public Double topP; + + /** + * Deprecated, use safety_identifier and prompt_cache_key instead. See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20user%20%3E%20(schema) + */ + public String user; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20verbosity%20%3E%20(schema) + */ + public String verbosity; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20web_search_options%20%3E%20(schema) + */ + @JsonProperty("web_search_options") + public WebSearchOptions webSearchOptions; + + /** + * Additional body parameters used for specific models, for example: + * https://ai.google.dev/gemini-api/docs/openai#extra-body + */ + @JsonProperty("extra_body") + public Map extraBody; + + /** + * A catch-all class for message parameters. See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20messages%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class Message { + /** See class definition for more details. */ + public String role; + + /** See class definition for more details. */ + public MessageContent content; + + /** See class definition for more details. */ + public String name; + + /** See class definition for more details. */ + @JsonProperty("tool_calls") + public List toolCalls; + + /** Deprecated. Use tool_calls instead.See class definition for more details. */ + @JsonProperty("function_call") + public FunctionCall functionCall; + + /** See class definition for more details. */ + @JsonProperty("tool_call_id") + public String toolCallId; + + /** See class definition for more details. */ + public Audio audio; + + /** See class definition for more details. */ + public String refusal; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_content_part_text%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class ContentPart { + /** See class definition for more details. */ + public String type; + + /** See class definition for more details. */ + public String text; + + /** See class definition for more details. */ + public String refusal; + + /** See class definition for more details. */ + @JsonProperty("image_url") + public ImageUrl imageUrl; + + /** See class definition for more details. */ + @JsonProperty("input_audio") + public InputAudio inputAudio; + + /** See class definition for more details. */ + public File file; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_content_part_text%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class ImageUrl { + /** See class definition for more details. */ + public String url; + + /** See class definition for more details. */ + public String detail; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_content_part_text%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class InputAudio { + /** See class definition for more details. */ + public String data; + + /** See class definition for more details. */ + public String format; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20messages%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class File { + /** See class definition for more details. */ + @JsonProperty("file_data") + public String fileData; + + /** See class definition for more details. */ + @JsonProperty("file_id") + public String fileId; + + /** See class definition for more details. */ + public String filename; + } + + /** + * Deprecated. Function call details replaced by tool_calls. See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20messages%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class FunctionCall { + /** See class definition for more details. */ + public String name; + + /** See class definition for more details. */ + public String arguments; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20audio%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class AudioParam { + /** See class definition for more details. */ + public String format; + + /** See class definition for more details. */ + public VoiceConfig voice; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20audio%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class Audio { + /** See class definition for more details. */ + public String id; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20prediction%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class Prediction { + /** See class definition for more details. */ + public String type; + + /** See class definition for more details. */ + public Object content; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20stream_options%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class StreamOptions { + /** See class definition for more details. */ + @JsonProperty("include_obfuscation") + public Boolean includeObfuscation; + + /** See class definition for more details. */ + @JsonProperty("include_usage") + public Boolean includeUsage; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20tools%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class Tool { + /** See class definition for more details. */ + public String type; + + /** See class definition for more details. */ + public FunctionDefinition function; + + /** See class definition for more details. */ + public CustomTool custom; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20tools%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class FunctionDefinition { + /** See class definition for more details. */ + public String name; + + /** See class definition for more details. */ + public String description; + + /** See class definition for more details. */ + public Map parameters; + + /** See class definition for more details. */ + public Boolean strict; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_custom_tool%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class CustomTool { + /** See class definition for more details. */ + public String name; + + /** See class definition for more details. */ + public String description; + + /** See class definition for more details. */ + public Object format; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20web_search_options%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class WebSearchOptions { + /** See class definition for more details. */ + @JsonProperty("search_context_size") + public String searchContextSize; + + /** See class definition for more details. */ + @JsonProperty("user_location") + public UserLocation userLocation; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20web_search_options%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class UserLocation { + /** See class definition for more details. */ + public String type; + + /** See class definition for more details. */ + public ApproximateLocation approximate; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20web_search_options%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class ApproximateLocation { + /** See class definition for more details. */ + public String city; + + /** See class definition for more details. */ + public String country; + + /** See class definition for more details. */ + public String region; + + /** See class definition for more details. */ + public String timezone; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20response_format%20%3E%20(schema) + */ + interface ResponseFormat {} + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20response_format%20%3E%20(schema) + */ + static class ResponseFormatText implements ResponseFormat { + public String type = "text"; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20response_format%20%3E%20(schema) + */ + static class ResponseFormatJsonObject implements ResponseFormat { + public String type = "json_object"; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20response_format%20%3E%20(schema) + */ + static class ResponseFormatJsonSchema implements ResponseFormat { + public String type = "json_schema"; + + @JsonProperty("json_schema") + public JsonSchema jsonSchema; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20response_format%20%3E%20(schema) + */ + static class JsonSchema { + /** See class definition for more details. */ + public String name; + + /** See class definition for more details. */ + public String description; + + /** See class definition for more details. */ + public Map schema; + + /** See class definition for more details. */ + public Boolean strict; + } + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20tool_choice%20%3E%20(schema) + */ + interface ToolChoice {} + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20tool_choice%20%3E%20(schema) + */ + static class ToolChoiceMode implements ToolChoice { + private final String mode; + + public ToolChoiceMode(String mode) { + this.mode = mode; + } + + @JsonValue + public String getMode() { + return mode; + } + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20tool_choice%20%3E%20(schema) + */ + static class NamedToolChoice implements ToolChoice { + /** See class definition for more details. */ + public String type = "function"; + + /** See class definition for more details. */ + public FunctionName function; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20tool_choice%20%3E%20(schema) + */ + static class FunctionName { + /** See class definition for more details. */ + public String name; + } + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20tool_choice%20%3E%20(schema) + */ + static class NamedToolChoiceCustom implements ToolChoice { + /** See class definition for more details. */ + public String type = "custom"; + + /** See class definition for more details. */ + public CustomName custom; + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20tool_choice%20%3E%20(schema) + */ + static class CustomName { + /** See class definition for more details. */ + public String name; + } + } + + /** + * Wrapper class for stop. See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20stop%20%3E%20(schema) + */ + static class StopCondition { + private final String stringValue; + private final List listValue; + + @JsonCreator + public StopCondition(String stringValue) { + this.stringValue = stringValue; + this.listValue = null; + } + + @JsonCreator + public StopCondition(List listValue) { + this.stringValue = null; + this.listValue = listValue; + } + + @JsonValue + public Object getValue() { + return stringValue != null ? stringValue : listValue; + } + } + + /** + * Wrapper class for messages. See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20messages%20%3E%20(schema) + */ + static class MessageContent { + private final String stringValue; + private final List listValue; + + @JsonCreator + public MessageContent(String stringValue) { + this.stringValue = stringValue; + this.listValue = null; + } + + @JsonCreator + public MessageContent(List listValue) { + this.stringValue = null; + this.listValue = listValue; + } + + @JsonValue + public Object getValue() { + return stringValue != null ? stringValue : listValue; + } + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20audio%20%3E%20(schema) + */ + static class VoiceConfig { + private final String stringValue; + private final Map mapValue; + + @JsonCreator + public VoiceConfig(String stringValue) { + this.stringValue = stringValue; + this.mapValue = null; + } + + @JsonCreator + public VoiceConfig(Map mapValue) { + this.stringValue = null; + this.mapValue = mapValue; + } + + @JsonValue + public Object getValue() { + return stringValue != null ? stringValue : mapValue; + } + } +} diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java new file mode 100644 index 000000000..75a96e6ee --- /dev/null +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java @@ -0,0 +1,327 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.chat; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +/** + * Data Transfer Objects for Chat Completion and Chat Completion Chunk API responses. + * + *

See https://developers.openai.com/api/reference/resources/chat + */ +@JsonIgnoreProperties(ignoreUnknown = true) +final class ChatCompletionsResponse { + + private ChatCompletionsResponse() {} + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + static class ChatCompletion { + /** See class definition for more details. */ + public String id; + + /** See class definition for more details. */ + public List choices; + + /** See class definition for more details. */ + public Long created; + + /** See class definition for more details. */ + public String model; + + /** See class definition for more details. */ + public String object; + + /** See class definition for more details. */ + @JsonProperty("service_tier") + public String serviceTier; + + /** Deprecated. See class definition for more details. */ + @JsonProperty("system_fingerprint") + public String systemFingerprint; + + /** See class definition for more details. */ + public Usage usage; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion%20%3E%20(schema)%20%3E%20(property)%20choices + */ + @JsonIgnoreProperties(ignoreUnknown = true) + static class Choice { + /** See class definition for more details. */ + @JsonProperty("finish_reason") + public String finishReason; + + /** See class definition for more details. */ + public Integer index; + + /** See class definition for more details. */ + public Logprobs logprobs; + + /** See class definition for more details. */ + public Message message; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_chunk%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + static class ChatCompletionChunk { + /** See class definition for more details. */ + public String id; + + /** See class definition for more details. */ + public List choices; + + /** See class definition for more details. */ + public Long created; + + /** See class definition for more details. */ + public String model; + + /** See class definition for more details. */ + public String object; + + /** See class definition for more details. */ + @JsonProperty("service_tier") + public String serviceTier; + + /** Deprecated. See class definition for more details. */ + @JsonProperty("system_fingerprint") + public String systemFingerprint; + + /** See class definition for more details. */ + public Usage usage; + } + + /** + * Used for streaming responses. See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_chunk%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + static class ChunkChoice { + /** See class definition for more details. */ + @JsonProperty("finish_reason") + public String finishReason; + + /** See class definition for more details. */ + public Integer index; + + /** See class definition for more details. */ + public Logprobs logprobs; + + /** See class definition for more details. */ + public Message delta; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_message%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + static class Message { + /** See class definition for more details. */ + public String content; + + /** See class definition for more details. */ + public String refusal; + + /** See class definition for more details. */ + public String role; + + /** See class definition for more details. */ + @JsonProperty("tool_calls") + public List toolCalls; + + /** Deprecated. Use tool_calls instead. See class definition for more details. */ + @JsonProperty("function_call") + public ChatCompletionsCommon.Function functionCall; + + /** See class definition for more details. */ + public List annotations; + + /** See class definition for more details. */ + public Audio audio; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_logprobs%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + static class Logprobs { + /** See class definition for more details. */ + public List content; + + /** See class definition for more details. */ + public List refusal; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_token_logprob%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(JsonInclude.Include.NON_NULL) + static class TokenLogprob { + /** See class definition for more details. */ + public String token; + + /** See class definition for more details. */ + public List bytes; + + /** See class definition for more details. */ + public Double logprob; + + /** See class definition for more details. */ + @JsonProperty("top_logprobs") + public List topLogprobs; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/completions#(resource)%20completions%20%3E%20(model)%20completion_usage%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + static class Usage { + /** See class definition for more details. */ + @JsonProperty("completion_tokens") + public Integer completionTokens; + + /** See class definition for more details. */ + @JsonProperty("prompt_tokens") + public Integer promptTokens; + + /** See class definition for more details. */ + @JsonProperty("total_tokens") + public Integer totalTokens; + + /** See class definition for more details. */ + @JsonProperty("thoughts_token_count") + public Integer thoughtsTokenCount; + + /** See class definition for more details. */ + @JsonProperty("completion_tokens_details") + public CompletionTokensDetails completionTokensDetails; + + /** See class definition for more details. */ + @JsonProperty("prompt_tokens_details") + public PromptTokensDetails promptTokensDetails; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/completions#(resource)%20completions%20%3E%20(model)%20completion_usage%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + static class CompletionTokensDetails { + /** See class definition for more details. */ + @JsonProperty("accepted_prediction_tokens") + public Integer acceptedPredictionTokens; + + /** See class definition for more details. */ + @JsonProperty("audio_tokens") + public Integer audioTokens; + + /** See class definition for more details. */ + @JsonProperty("reasoning_tokens") + public Integer reasoningTokens; + + /** See class definition for more details. */ + @JsonProperty("rejected_prediction_tokens") + public Integer rejectedPredictionTokens; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/completions#(resource)%20completions%20%3E%20(model)%20completion_usage%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + static class PromptTokensDetails { + /** See class definition for more details. */ + @JsonProperty("audio_tokens") + public Integer audioTokens; + + /** See class definition for more details. */ + @JsonProperty("cached_tokens") + public Integer cachedTokens; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_message%20%3E%20(schema)%20%3E%20(property)%20annotations + */ + @JsonIgnoreProperties(ignoreUnknown = true) + static class Annotation { + /** See class definition for more details. */ + public String type; + + /** See class definition for more details. */ + @JsonProperty("url_citation") + public UrlCitation urlCitation; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_message%20%3E%20(schema)%20%3E%20(property)%20annotations%20%3E%20(items)%20%3E%20(property)%20url_citation + */ + @JsonIgnoreProperties(ignoreUnknown = true) + static class UrlCitation { + /** See class definition for more details. */ + @JsonProperty("end_index") + public Integer endIndex; + + /** See class definition for more details. */ + @JsonProperty("start_index") + public Integer startIndex; + + /** See class definition for more details. */ + public String title; + + /** See class definition for more details. */ + public String url; + } + + /** + * See + * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_audio%20%3E%20(schema) + */ + @JsonIgnoreProperties(ignoreUnknown = true) + static class Audio { + /** See class definition for more details. */ + public String id; + + /** See class definition for more details. */ + public String data; + + /** See class definition for more details. */ + @JsonProperty("expires_at") + public Long expiresAt; + + /** See class definition for more details. */ + public String transcript; + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java index ef826fb56..924ad228e 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java @@ -147,7 +147,12 @@ public void flush() { } else { logger.fine("Successfully wrote " + batch.size() + " rows to BigQuery."); } - } catch (AppendSerializationError ase) { + } + } catch (Exception e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + if (e.getCause() instanceof AppendSerializationError ase) { logger.log( Level.SEVERE, "Failed to write batch to BigQuery due to serialization error", ase); Map rowIndexToErrorMessage = ase.getRowIndexToErrorMessage(); @@ -161,12 +166,9 @@ public void flush() { logger.severe( "AppendSerializationError occurred, but no row-specific errors were provided."); } + } else { + logger.log(Level.SEVERE, "Failed to write batch to BigQuery", e); } - } catch (Exception e) { - if (e instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - logger.log(Level.SEVERE, "Failed to write batch to BigQuery", e); } finally { // Clear the vectors to release the memory. root.clear(); @@ -185,7 +187,12 @@ private void populateVector(FieldVector vector, int index, Object value) { return; } if (vector instanceof VarCharVector varCharVector) { - String strValue = (value instanceof JsonNode jsonNode) ? jsonNode.asText() : value.toString(); + String strValue; + if (value instanceof JsonNode jsonNode) { + strValue = jsonNode.isTextual() ? jsonNode.asText() : jsonNode.toString(); + } else { + strValue = value.toString(); + } varCharVector.setSafe(index, strValue.getBytes(UTF_8)); } else if (vector instanceof BigIntVector bigIntVector) { long longValue; diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java index 68b5fb5a1..cf7dad9df 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java @@ -16,6 +16,9 @@ package com.google.adk.plugins.agentanalytics; +import static com.google.adk.plugins.agentanalytics.JsonFormatter.convertToJsonNode; +import static com.google.adk.plugins.agentanalytics.JsonFormatter.smartTruncate; +import static com.google.adk.plugins.agentanalytics.JsonFormatter.toJavaObject; import static java.util.concurrent.TimeUnit.MILLISECONDS; import com.google.adk.agents.BaseAgent; @@ -25,8 +28,17 @@ import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; +import com.google.adk.plugins.agentanalytics.JsonFormatter.ParsedContent; +import com.google.adk.plugins.agentanalytics.JsonFormatter.TruncationResult; +import com.google.adk.plugins.agentanalytics.TraceManager.RecordData; +import com.google.adk.plugins.agentanalytics.TraceManager.SpanIds; +import com.google.adk.sessions.Session; +import com.google.adk.tools.AgentTool; import com.google.adk.tools.BaseTool; +import com.google.adk.tools.FunctionTool; import com.google.adk.tools.ToolContext; +import com.google.adk.tools.mcp.AbstractMcpTool; +import com.google.adk.utils.AgentEnums.AgentOrigin; import com.google.api.gax.core.FixedCredentialsProvider; import com.google.api.gax.retrying.RetrySettings; import com.google.auth.oauth2.GoogleCredentials; @@ -45,12 +57,13 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; -import io.opentelemetry.api.trace.Span; -import io.opentelemetry.api.trace.SpanContext; +import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import java.io.IOException; +import java.time.Duration; import java.time.Instant; import java.util.HashMap; import java.util.Map; @@ -61,7 +74,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; -import org.threeten.bp.Duration; +import org.jspecify.annotations.Nullable; /** * BigQuery Agent Analytics Plugin for Java. @@ -74,6 +87,14 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin { private static final ImmutableList DEFAULT_AUTH_SCOPES = ImmutableList.of("https://www.googleapis.com/auth/cloud-platform"); private static final AtomicLong threadCounter = new AtomicLong(0); + private static final ImmutableMap HITL_EVENT_TYPES = + ImmutableMap.of( + "adk_request_credential", + "HITL_CREDENTIAL_REQUEST", + "adk_request_confirmation", + "HITL_CONFIRMATION_REQUEST", + "adk_request_input", + "HITL_INPUT_REQUEST"); private final BigQueryLoggerConfig config; private final BigQuery bigQuery; @@ -81,6 +102,7 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin { private final ScheduledExecutorService executor; private final Object tableEnsuredLock = new Object(); @VisibleForTesting final BatchProcessor batchProcessor; + @VisibleForTesting final TraceManager traceManager; private volatile boolean tableEnsured = false; public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOException { @@ -96,6 +118,7 @@ public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQue r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement()); this.executor = Executors.newScheduledThreadPool(1, threadFactory); this.writeClient = createWriteClient(config); + this.traceManager = createTraceManager(); if (config.enabled()) { StreamWriter writer = createWriter(config); @@ -194,9 +217,10 @@ protected StreamWriter createWriter(BigQueryLoggerConfig config) { RetrySettings retrySettings = RetrySettings.newBuilder() .setMaxAttempts(retryConfig.maxRetries()) - .setInitialRetryDelay(Duration.ofMillis(retryConfig.initialDelay().toMillis())) + .setInitialRetryDelay( + org.threeten.bp.Duration.ofMillis(retryConfig.initialDelay().toMillis())) .setRetryDelayMultiplier(retryConfig.multiplier()) - .setMaxRetryDelay(Duration.ofMillis(retryConfig.maxDelay().toMillis())) + .setMaxRetryDelay(org.threeten.bp.Duration.ofMillis(retryConfig.maxDelay().toMillis())) .build(); String streamName = getStreamName(config); @@ -210,58 +234,130 @@ protected StreamWriter createWriter(BigQueryLoggerConfig config) { } } + protected TraceManager createTraceManager() { + return new TraceManager(); + } + + private void logEvent( + String eventType, + InvocationContext invocationContext, + Object content, + Optional eventData) { + logEvent(eventType, invocationContext, content, false, eventData); + } + private void logEvent( String eventType, InvocationContext invocationContext, - Optional callbackContext, Object content, - Map extraAttributes) { - if (batchProcessor == null) { + boolean isContentTruncated, + Optional eventData) { + if (!config.enabled() || batchProcessor == null) { return; } - + if (!config.eventAllowlist().isEmpty() && !config.eventAllowlist().contains(eventType)) { + return; + } + if (config.eventDenylist().contains(eventType)) { + return; + } + // Ensure table exists before logging. ensureTableExistsOnce(); - + // Log common fields Map row = new HashMap<>(); row.put("timestamp", Instant.now()); row.put("event_type", eventType); - row.put( - "agent", - callbackContext.map(CallbackContext::agentName).orElse(invocationContext.agent().name())); + row.put("agent", invocationContext.agent().name()); row.put("session_id", invocationContext.session().id()); row.put("invocation_id", invocationContext.invocationId()); row.put("user_id", invocationContext.userId()); - - if (content instanceof Content contentParts) { - row.put( - "content_parts", - JsonFormatter.formatContentParts(Optional.of(contentParts), config.maxContentLength())); - row.put( - "content", JsonFormatter.smartTruncate(content, config.maxContentLength()).toString()); - } else if (content != null) { - row.put( - "content", JsonFormatter.smartTruncate(content, config.maxContentLength()).toString()); + // Parse and log content + ParsedContent parsedContent = JsonFormatter.parse(content, config.maxContentLength()); + row.put("content_parts", parsedContent.parts()); + row.put("content", parsedContent.content()); + row.put("is_truncated", isContentTruncated || parsedContent.isTruncated()); + + EventData data = eventData.orElse(EventData.builder().build()); + row.put("status", data.status()); + data.errorMessage().ifPresent(msg -> row.put("error_message", msg)); + + Map latencyMap = extractLatency(data); + if (latencyMap != null) { + row.put("latency_ms", convertToJsonNode(latencyMap)); } + row.put("attributes", convertToJsonNode(getAttributes(data, invocationContext))); - Map attributes = new HashMap<>(config.customTags()); - if (extraAttributes != null) { - attributes.putAll(extraAttributes); - } + addTraceDetails(row, invocationContext, eventData); + batchProcessor.append(row); + } + + private void addTraceDetails( + Map row, InvocationContext invocationContext, Optional eventData) { + String traceId = + eventData + .flatMap(EventData::traceIdOverride) + .orElseGet(() -> traceManager.getTraceId(invocationContext)); + Optional ambientSpanIds = traceManager.getAmbientSpanAndParent(); + SpanIds spanIds = ambientSpanIds.orElse(traceManager.getCurrentSpanAndParent()); + + row.put("trace_id", traceId); + row.put( + "span_id", + eventData.flatMap(EventData::spanIdOverride).orElse(spanIds.spanId().orElse(null))); row.put( - "attributes", - JsonFormatter.smartTruncate(attributes, config.maxContentLength()).toString()); + "parent_span_id", + eventData + .flatMap(EventData::parentSpanIdOverride) + .orElse(spanIds.parentSpanId().orElse(null))); + } - addTraceDetails(row); - batchProcessor.append(row); + private @Nullable Map extractLatency(EventData eventData) { + Map latencyMap = new HashMap<>(); + eventData.latency().ifPresent(v -> latencyMap.put("total_ms", v.toMillis())); + eventData + .timeToFirstToken() + .ifPresent(v -> latencyMap.put("time_to_first_token_ms", v.toMillis())); + return latencyMap.isEmpty() ? null : latencyMap; } - // TODO(b/491849911): Implement own trace management functionality. - private void addTraceDetails(Map row) { - SpanContext spanContext = Span.current().getSpanContext(); - if (spanContext.isValid()) { - row.put("trace_id", spanContext.getTraceId()); - row.put("span_id", spanContext.getSpanId()); + private Map getAttributes( + EventData eventData, InvocationContext invocationContext) { + Map attributes = new HashMap<>(eventData.extraAttributes()); + + attributes.put("root_agent_name", traceManager.getRootAgentName()); + eventData.model().ifPresent(m -> attributes.put("model", m)); + eventData.modelVersion().ifPresent(mv -> attributes.put("model_version", mv)); + eventData + .usageMetadata() + .ifPresent( + um -> { + TruncationResult result = smartTruncate(um, config.maxContentLength()); + attributes.put("usage_metadata", toJavaObject(result.node())); + }); + + if (config.logSessionMetadata()) { + try { + Session session = invocationContext.session(); + Map sessionMeta = new HashMap<>(); + sessionMeta.put("session_id", session.id()); + sessionMeta.put("app_name", session.appName()); + sessionMeta.put("user_id", session.userId()); + + if (!session.state().isEmpty()) { + TruncationResult result = smartTruncate(session.state(), config.maxContentLength()); + sessionMeta.put("state", toJavaObject(result.node())); + } + attributes.put("session_metadata", sessionMeta); + } catch (RuntimeException e) { + // Ignore session enrichment errors as in Python. + } } + + if (!config.customTags().isEmpty()) { + attributes.put("custom_tags", config.customTags()); + } + + return attributes; } @Override @@ -284,77 +380,237 @@ public Completable close() { return Completable.complete(); } - @Override - public Maybe onUserMessageCallback( - InvocationContext invocationContext, Content userMessage) { - return Maybe.fromAction( - () -> logEvent("USER_MESSAGE", invocationContext, Optional.empty(), userMessage, null)); + private Optional getCompletedEventData(InvocationContext invocationContext) { + String traceId = traceManager.getTraceId(invocationContext); + // Pop the invocation span from the trace manager. + Optional popped = traceManager.popSpan(); + if (popped.isEmpty()) { + // No invocation span to pop. + logger.info("No invocation span to pop."); + return Optional.empty(); + } + Optional parentSpanId = traceManager.getCurrentSpanId(); + + EventData.Builder eventDataBuilder = EventData.builder(); + eventDataBuilder.setTraceIdOverride(traceId); + eventDataBuilder.setLatency(popped.get().duration()); + // Only override span IDs when no ambient OTel span exists. + // Keep STARTING/COMPLETED pairs consistent. + if (!traceManager.hasAmbientSpan()) { + if (parentSpanId.isPresent()) { + eventDataBuilder.setParentSpanIdOverride(parentSpanId.get()); + } + if (popped.get().spanId() != null) { + eventDataBuilder.setSpanIdOverride(popped.get().spanId()); + } + } + return Optional.of(eventDataBuilder.build()); } + // --- Plugin callbacks --- @Override - public Maybe beforeRunCallback(InvocationContext invocationContext) { + public Maybe onUserMessageCallback( + InvocationContext invocationContext, Content userMessage) { return Maybe.fromAction( - () -> logEvent("INVOCATION_START", invocationContext, Optional.empty(), null, null)); + () -> { + traceManager.ensureInvocationSpan(invocationContext); + logEvent("USER_MESSAGE_RECEIVED", invocationContext, userMessage, Optional.empty()); + if (userMessage.parts().isPresent()) { + for (Part part : userMessage.parts().get()) { + if (part.functionCall().isPresent() + && HITL_EVENT_TYPES.containsKey(part.functionCall().get().name().orElse(""))) { + String hitlEvent = HITL_EVENT_TYPES.get(part.functionCall().get().name().get()); + TruncationResult truncatedResult = smartTruncate(part, config.maxContentLength()); + logEvent( + hitlEvent + "_COMPLETED", + invocationContext, + ImmutableMap.of( + "tool", + part.functionCall().get().name().get(), + "result", + truncatedResult.node()), + truncatedResult.isTruncated(), + Optional.empty()); + } + } + } + }); } @Override public Maybe onEventCallback(InvocationContext invocationContext, Event event) { return Maybe.fromAction( () -> { - Map attrs = new HashMap<>(); - attrs.put("event_author", event.author()); + EventData.Builder eventDataBuilder = + EventData.builder() + .setExtraAttributes( + ImmutableMap.builder() + .put("state_delta", event.actions().stateDelta()) + .put("author", event.author()) + .buildOrThrow()); logEvent( - "EVENT", invocationContext, Optional.empty(), event.content().orElse(null), attrs); + "STATE_DELTA", + invocationContext, + event.content().orElse(null), + Optional.of(eventDataBuilder.build())); + + if (event.content().isPresent() && event.content().get().parts().isPresent()) { + for (Part part : event.content().get().parts().get()) { + if (part.functionCall().isPresent() + && HITL_EVENT_TYPES.containsKey(part.functionCall().get().name().orElse(""))) { + String hitlEvent = HITL_EVENT_TYPES.get(part.functionCall().get().name().get()); + TruncationResult truncatedResult = + smartTruncate(part.functionCall().get().args(), config.maxContentLength()); + logEvent( + hitlEvent + "_COMPLETED", + invocationContext, + ImmutableMap.of( + "tool", + part.functionCall().get().name().get(), + "args", + truncatedResult.node()), + truncatedResult.isTruncated(), + Optional.empty()); + } + if (part.functionResponse().isPresent() + && HITL_EVENT_TYPES.containsKey( + part.functionResponse().get().name().orElse(""))) { + String hitlEvent = HITL_EVENT_TYPES.get(part.functionResponse().get().name().get()); + TruncationResult truncatedResult = + smartTruncate( + part.functionResponse().get().response(), config.maxContentLength()); + logEvent( + hitlEvent + "_COMPLETED", + invocationContext, + ImmutableMap.of( + "tool", + part.functionResponse().get().name().get(), + "response", + truncatedResult.node()), + truncatedResult.isTruncated(), + Optional.empty()); + } + } + } }); } + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + traceManager.ensureInvocationSpan(invocationContext); + return Maybe.fromAction( + () -> logEvent("INVOCATION_STARTING", invocationContext, null, Optional.empty())); + } + @Override public Completable afterRunCallback(InvocationContext invocationContext) { return Completable.fromAction( () -> { - logEvent("INVOCATION_END", invocationContext, Optional.empty(), null, null); + logEvent( + "INVOCATION_COMPLETED", + invocationContext, + null, + getCompletedEventData(invocationContext)); batchProcessor.flush(); + traceManager.clearStack(); }); } @Override public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { return Maybe.fromAction( - () -> - logEvent( - "AGENT_START", - callbackContext.invocationContext(), - Optional.of(callbackContext), - null, - null)); + () -> { + traceManager.pushSpan("agent:" + agent.name()); + logEvent("AGENT_STARTING", callbackContext.invocationContext(), null, Optional.empty()); + }); } @Override public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { return Maybe.fromAction( - () -> - logEvent( - "AGENT_END", - callbackContext.invocationContext(), - Optional.of(callbackContext), - null, - null)); + () -> { + logEvent( + "AGENT_COMPLETED", + callbackContext.invocationContext(), + null, + getCompletedEventData(callbackContext.invocationContext())); + }); } + /** + * Callback before LLM call. + * + *

Logs the LLM request details including: 1. Prompt content 2. System instruction (if + * available) + * + *

The content is formatted as 'Prompt: {prompt} | System Prompt: {system_prompt}'. + */ @Override public Maybe beforeModelCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest) { return Maybe.fromAction( () -> { - Map attrs = new HashMap<>(); + Map attributes = new HashMap<>(); + Map llmConfig = new HashMap<>(); LlmRequest req = llmRequest.build(); - attrs.put("model", req.model().orElse("unknown")); - logEvent( - "MODEL_REQUEST", - callbackContext.invocationContext(), - Optional.of(callbackContext), - req, - attrs); + if (req.config().isPresent()) { + if (req.config().get().temperature().isPresent()) { + llmConfig.put("temperature", req.config().get().temperature().get()); + } + if (req.config().get().topP().isPresent()) { + llmConfig.put("top_p", req.config().get().topP().get()); + } + if (req.config().get().topK().isPresent()) { + llmConfig.put("top_k", req.config().get().topK().get()); + } + if (req.config().get().candidateCount().isPresent()) { + llmConfig.put("candidate_count", req.config().get().candidateCount().get()); + } + if (req.config().get().maxOutputTokens().isPresent()) { + llmConfig.put("max_output_tokens", req.config().get().maxOutputTokens().get()); + } + if (req.config().get().stopSequences().isPresent()) { + llmConfig.put("stop_sequences", req.config().get().stopSequences().get()); + } + if (req.config().get().presencePenalty().isPresent()) { + llmConfig.put("presence_penalty", req.config().get().presencePenalty().get()); + } + if (req.config().get().frequencyPenalty().isPresent()) { + llmConfig.put("frequency_penalty", req.config().get().frequencyPenalty().get()); + } + if (req.config().get().responseMimeType().isPresent()) { + llmConfig.put("response_mime_type", req.config().get().responseMimeType().get()); + } + if (req.config().get().responseSchema().isPresent()) { + llmConfig.put("response_schema", req.config().get().responseSchema().get()); + } + if (req.config().get().seed().isPresent()) { + llmConfig.put("seed", req.config().get().seed().get()); + } + if (req.config().get().responseLogprobs().isPresent()) { + llmConfig.put("response_logprobs", req.config().get().responseLogprobs().get()); + } + if (req.config().get().logprobs().isPresent()) { + llmConfig.put("logprobs", req.config().get().logprobs().get()); + } + // Put labels in attributes instead of LLM config. + if (req.config().get().labels().isPresent()) { + attributes.put("labels", req.config().get().labels().get()); + } + } + if (!llmConfig.isEmpty()) { + attributes.put("llm_config", llmConfig); + } + if (!req.tools().isEmpty()) { + attributes.put("tools", req.tools().keySet()); + } + EventData eventData = + EventData.builder() + .setModel(req.model().orElse("")) + .setExtraAttributes(attributes) + .build(); + traceManager.pushSpan("llm_request"); + logEvent("LLM_REQUEST", callbackContext.invocationContext(), req, Optional.of(eventData)); }); } @@ -363,14 +619,94 @@ public Maybe afterModelCallback( CallbackContext callbackContext, LlmResponse llmResponse) { return Maybe.fromAction( () -> { - Map attrs = new HashMap<>(); - llmResponse.usageMetadata().ifPresent(u -> attrs.put("usage_metadata", u)); + // TODO(b/495809488): Add formatting of the content + ParsedContent parsedContent = + JsonFormatter.parse(llmResponse.content().orElse(null), config.maxContentLength()); + + Map usageDict = new HashMap<>(); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage.promptTokenCount().ifPresent(c -> usageDict.put("prompt", c)); + usage.candidatesTokenCount().ifPresent(c -> usageDict.put("completion", c)); + usage.totalTokenCount().ifPresent(c -> usageDict.put("total", c)); + }); + + Map contentMap = new HashMap<>(); + if (parsedContent.content() != null && !parsedContent.content().isNull()) { + contentMap.put("response", parsedContent.content()); + } + if (!usageDict.isEmpty()) { + contentMap.put("usage", usageDict); + } + + InvocationContext invocationContext = callbackContext.invocationContext(); + Optional spanId = traceManager.getCurrentSpanId(); + SpanIds spanIds = traceManager.getCurrentSpanAndParent(); + String parentSpanId = spanIds.parentSpanId().orElse(null); + + boolean isPopped = false; + Duration duration = Duration.ZERO; + Duration ttft = null; + Optional startTime = Optional.empty(); + Optional firstTokenTime = Optional.empty(); + + if (spanId.isPresent()) { + traceManager.recordFirstToken(spanId.get()); + startTime = traceManager.getStartTime(spanId.get()); + firstTokenTime = traceManager.getFirstTokenTime(spanId.get()); + if (startTime.isPresent() && firstTokenTime.isPresent()) { + ttft = Duration.between(startTime.get(), firstTokenTime.get()); + } + } + + if (llmResponse.partial().orElse(false)) { + // Streaming chunk - do NOT pop span yet + if (startTime.isPresent()) { + duration = Duration.between(startTime.get(), Instant.now()); + } + } else { + // Final response - pop span + Optional popped = traceManager.popSpan(); + if (popped.isPresent()) { + spanId = Optional.of(popped.get().spanId()); + duration = popped.get().duration(); + isPopped = true; + } + } + + boolean hasAmbient = traceManager.hasAmbientSpan(); + boolean useOverride = isPopped && !hasAmbient; + + EventData.Builder eventDataBuilder = EventData.builder(); + if (!duration.isZero()) { + eventDataBuilder.setLatency(duration); + } + if (ttft != null) { + eventDataBuilder.setTimeToFirstToken(ttft); + } + llmResponse.modelVersion().ifPresent(eventDataBuilder::setModelVersion); + + if (!usageDict.isEmpty()) { + eventDataBuilder.setUsageMetadata(usageDict); + } + + if (useOverride) { + if (spanId.isPresent()) { + eventDataBuilder.setSpanIdOverride(spanId.get()); + } + if (parentSpanId != null) { + eventDataBuilder.setParentSpanIdOverride(parentSpanId); + } + } + logEvent( - "MODEL_RESPONSE", - callbackContext.invocationContext(), - Optional.of(callbackContext), - llmResponse, - attrs); + "LLM_RESPONSE", + invocationContext, + contentMap.isEmpty() ? null : contentMap, + parsedContent.isTruncated(), + Optional.of(eventDataBuilder.build())); }); } @@ -379,14 +715,28 @@ public Maybe onModelErrorCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { return Maybe.fromAction( () -> { - Map attrs = new HashMap<>(); - attrs.put("error_message", error.getMessage()); - logEvent( - "MODEL_ERROR", - callbackContext.invocationContext(), - Optional.of(callbackContext), - null, - attrs); + InvocationContext invocationContext = callbackContext.invocationContext(); + Optional popped = traceManager.popSpan(); + String spanId = popped.map(RecordData::spanId).orElse(null); + + SpanIds spanIds = traceManager.getCurrentSpanAndParent(); + String parentSpanId = spanIds.spanId().orElse(null); + + boolean hasAmbient = traceManager.hasAmbientSpan(); + EventData.Builder eventDataBuilder = + EventData.builder().setStatus("ERROR").setErrorMessage(error.getMessage()); + if (popped.isPresent()) { + eventDataBuilder.setLatency(popped.get().duration()); + } + if (!hasAmbient) { + if (spanId != null) { + eventDataBuilder.setSpanIdOverride(spanId); + } + if (parentSpanId != null) { + eventDataBuilder.setParentSpanIdOverride(parentSpanId); + } + } + logEvent("LLM_ERROR", invocationContext, null, Optional.of(eventDataBuilder.build())); }); } @@ -395,14 +745,12 @@ public Maybe> beforeToolCallback( BaseTool tool, Map toolArgs, ToolContext toolContext) { return Maybe.fromAction( () -> { - Map attrs = new HashMap<>(); - attrs.put("tool_name", tool.name()); - logEvent( - "TOOL_START", - toolContext.invocationContext(), - Optional.of(toolContext), - toolArgs, - attrs); + TruncationResult res = smartTruncate(toolArgs, config.maxContentLength()); + ImmutableMap contentMap = + ImmutableMap.of( + "tool_origin", getToolOrigin(tool), "tool", tool.name(), "args", res.node()); + traceManager.pushSpan("tool"); + logEvent("TOOL_STARTING", toolContext.invocationContext(), contentMap, Optional.empty()); }); } @@ -414,10 +762,35 @@ public Maybe> afterToolCallback( Map result) { return Maybe.fromAction( () -> { - Map attrs = new HashMap<>(); - attrs.put("tool_name", tool.name()); + Optional popped = traceManager.popSpan(); + TruncationResult truncationResult = smartTruncate(result, config.maxContentLength()); + ImmutableMap contentMap = + ImmutableMap.of( + "tool", + tool.name(), + "result", + truncationResult.node(), + "tool_origin", + getToolOrigin(tool)); + + SpanIds spanIds = traceManager.getCurrentSpanAndParent(); + boolean hasAmbient = traceManager.hasAmbientSpan(); + + EventData.Builder eventDataBuilder = EventData.builder(); + if (popped.isPresent()) { + eventDataBuilder.setLatency(popped.get().duration()); + } + if (!hasAmbient) { + popped.ifPresent(p -> eventDataBuilder.setSpanIdOverride(p.spanId())); + spanIds.spanId().ifPresent(eventDataBuilder::setParentSpanIdOverride); + } + logEvent( - "TOOL_END", toolContext.invocationContext(), Optional.of(toolContext), result, attrs); + "TOOL_COMPLETED", + toolContext.invocationContext(), + contentMap, + truncationResult.isTruncated(), + Optional.of(eventDataBuilder.build())); }); } @@ -426,11 +799,51 @@ public Maybe> onToolErrorCallback( BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { return Maybe.fromAction( () -> { - Map attrs = new HashMap<>(); - attrs.put("tool_name", tool.name()); - attrs.put("error_message", error.getMessage()); + Optional popped = traceManager.popSpan(); + TruncationResult truncationResult = smartTruncate(toolArgs, config.maxContentLength()); + String toolOrigin = getToolOrigin(tool); + ImmutableMap contentMap = + ImmutableMap.of( + "tool", tool.name(), "args", truncationResult.node(), "tool_origin", toolOrigin); + + SpanIds spanIds = traceManager.getCurrentSpanAndParent(); + boolean hasAmbient = traceManager.hasAmbientSpan(); + + EventData.Builder eventDataBuilder = + EventData.builder().setStatus("ERROR").setErrorMessage(error.getMessage()); + + if (popped.isPresent()) { + eventDataBuilder.setLatency(popped.get().duration()); + } + if (!hasAmbient) { + popped.ifPresent(p -> eventDataBuilder.setSpanIdOverride(p.spanId())); + spanIds.spanId().ifPresent(eventDataBuilder::setParentSpanIdOverride); + } + logEvent( - "TOOL_ERROR", toolContext.invocationContext(), Optional.of(toolContext), null, attrs); + "TOOL_ERROR", + toolContext.invocationContext(), + contentMap, + truncationResult.isTruncated(), + Optional.of(eventDataBuilder.build())); }); } + + private String getToolOrigin(BaseTool tool) { + if (tool instanceof AbstractMcpTool) { + return "MCP"; + } + if (tool instanceof AgentTool agentTool) { + return agentTool.getAgent().toolOrigin().equals(AgentOrigin.BASE_AGENT) + ? AgentOrigin.SUB_AGENT.toString() + : agentTool.getAgent().toolOrigin().toString(); + } + if (tool.name().equals("transfer_to_agent")) { + return "TRANSFER_AGENT"; + } + if (tool instanceof FunctionTool) { + return "LOCAL"; + } + return "UNKNOWN"; + } } diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java index 22ced137e..149c8a92c 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java @@ -34,14 +34,10 @@ public abstract class BigQueryLoggerConfig { // Whether the plugin is enabled. public abstract boolean enabled(); - // List of event types to log. If None, all are allowed - // TODO(b/491852782): Implement allowlist/denylist for event types. - @Nullable + // List of event types to log. If None, all are allowed. public abstract ImmutableList eventAllowlist(); // List of event types to ignore. - // TODO(b/491852782): Implement allowlist/denylist for event types. - @Nullable public abstract ImmutableList eventDenylist(); // Max length for text content before truncation. @@ -103,6 +99,8 @@ public abstract class BigQueryLoggerConfig { @Nullable public abstract Credentials credentials(); + public abstract Builder toBuilder(); + public static Builder builder() { return new AutoValue_BigQueryLoggerConfig.Builder() .enabled(true) @@ -118,6 +116,8 @@ public static Builder builder() { .queueMaxSize(10000) .logSessionMetadata(true) .customTags(ImmutableMap.of()) + .eventAllowlist(ImmutableList.of()) + .eventDenylist(ImmutableList.of()) // TODO(b/491851868): Enable auto-schema upgrade once implemented. .autoSchemaUpgrade(false); } diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/EventData.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/EventData.java new file mode 100644 index 000000000..8fd95a070 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/EventData.java @@ -0,0 +1,64 @@ +package com.google.adk.plugins.agentanalytics; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableMap; +import java.time.Duration; +import java.util.Map; +import java.util.Optional; + +/** Typed container for structured fields passed to _log_event. */ +@AutoValue +abstract class EventData { + abstract Optional spanIdOverride(); + + abstract Optional parentSpanIdOverride(); + + abstract Optional latency(); + + abstract Optional timeToFirstToken(); + + abstract Optional model(); + + abstract Optional modelVersion(); + + abstract Optional usageMetadata(); + + abstract String status(); + + abstract Optional errorMessage(); + + abstract ImmutableMap extraAttributes(); + + abstract Optional traceIdOverride(); + + static Builder builder() { + return new AutoValue_EventData.Builder().setStatus("OK").setExtraAttributes(ImmutableMap.of()); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setSpanIdOverride(String value); + + abstract Builder setParentSpanIdOverride(String value); + + abstract Builder setLatency(Duration value); + + abstract Builder setTimeToFirstToken(Duration value); + + abstract Builder setModel(String value); + + abstract Builder setModelVersion(String value); + + abstract Builder setUsageMetadata(Object value); + + abstract Builder setStatus(String value); + + abstract Builder setErrorMessage(String value); + + abstract Builder setExtraAttributes(Map value); + + abstract Builder setTraceIdOverride(String value); + + abstract EventData build(); + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java index b4b4a1049..26f436f29 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java @@ -16,26 +16,258 @@ package com.google.adk.plugins.agentanalytics; +import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.adk.models.LlmRequest; +import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.FileData; +import com.google.genai.types.FunctionCall; import com.google.genai.types.Part; +import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; +import org.jspecify.annotations.Nullable; -/** Utility for formatting and truncating content for BigQuery logging. */ +/** Utility for parsing, formatting and truncating content for BigQuery logging. */ final class JsonFormatter { private static final ObjectMapper mapper = new ObjectMapper().findAndRegisterModules(); - private JsonFormatter() {} + @AutoValue + abstract static class TruncationResult { + abstract JsonNode node(); + + abstract boolean isTruncated(); + + static TruncationResult create(JsonNode node, boolean isTruncated) { + return new AutoValue_JsonFormatter_TruncationResult(node, isTruncated); + } + } + + @AutoValue + abstract static class ParsedContent { + abstract ImmutableList parts(); + + abstract JsonNode content(); + + abstract boolean isTruncated(); + + static ParsedContent create( + ImmutableList parts, JsonNode content, boolean isTruncated) { + return new AutoValue_JsonFormatter_ParsedContent(parts, content, isTruncated); + } + } + + @AutoValue + abstract static class ParsedContentObject { + abstract ArrayNode parts(); + + abstract String summary(); + + abstract boolean isTruncated(); + + static ParsedContentObject create(ArrayNode parts, String summary, boolean isTruncated) { + return new AutoValue_JsonFormatter_ParsedContentObject(parts, summary, isTruncated); + } + } + + @AutoValue + abstract static class ContentPart { + @JsonProperty("part_index") + abstract int partIndex(); + + @JsonProperty("mime_type") + abstract @Nullable String mimeType(); + + @JsonProperty("uri") + abstract @Nullable String uri(); + + @JsonProperty("text") + abstract @Nullable String text(); + + @JsonProperty("part_attributes") + abstract String partAttributes(); + + @JsonProperty("storage_mode") + abstract String storageMode(); + + @JsonProperty("object_ref") + abstract @Nullable String objectRef(); + + static Builder builder() { + return new AutoValue_JsonFormatter_ContentPart.Builder(); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setPartIndex(int value); + + abstract Builder setMimeType(@Nullable String value); + + abstract Builder setUri(@Nullable String value); + + abstract Builder setText(@Nullable String value); + + abstract Builder setPartAttributes(String value); + + abstract Builder setStorageMode(String value); + + abstract Builder setObjectRef(@Nullable String value); + + abstract ContentPart build(); + } + } + + /** + * Parses content into JSON payload and content parts, matching Python implementation. + * + * @param content the content to parse + * @param maxLength the maximum length for text fields + * @return a ParsedContent object + */ + static ParsedContent parse(Object content, int maxLength) { + JsonNode contentNode = mapper.nullNode(); + ArrayNode contentParts = mapper.createArrayNode(); + boolean isTruncated = false; + + if (content instanceof LlmRequest llmRequest) { + ObjectNode jsonPayload = mapper.createObjectNode(); + // Handle prompt + ArrayNode messages = mapper.createArrayNode(); + List contents = llmRequest.contents(); + for (Content c : contents) { + String role = c.role().orElse("unknown"); + ParsedContentObject parsedContentObject = parseContentObject(c, maxLength); + isTruncated = isTruncated || parsedContentObject.isTruncated(); + contentParts.addAll(parsedContentObject.parts()); + + ObjectNode message = mapper.createObjectNode(); + message.put("role", role); + message.put("content", parsedContentObject.summary()); + messages.add(message); + } + if (!messages.isEmpty()) { + jsonPayload.set("prompt", messages); + } + // Handle system instruction + if (llmRequest.config().isPresent() + && llmRequest.config().get().systemInstruction().isPresent()) { + Content systemInstruction = llmRequest.config().get().systemInstruction().get(); + ParsedContentObject parsedSystemInstruction = + parseContentObject(systemInstruction, maxLength); + isTruncated = isTruncated || parsedSystemInstruction.isTruncated(); + contentParts.addAll(parsedSystemInstruction.parts()); + jsonPayload.put("system_prompt", parsedSystemInstruction.summary()); + } + contentNode = jsonPayload; + } else if (content instanceof Content || content instanceof Part) { + ParsedContentObject parsedContentObject = parseContentObject(content, maxLength); + ObjectNode summaryNode = mapper.createObjectNode(); + summaryNode.put("text_summary", parsedContentObject.summary()); + return ParsedContent.create( + ImmutableList.copyOf(parsedContentObject.parts()), + summaryNode, + parsedContentObject.isTruncated()); + } else if (content instanceof String s) { + TruncationResult result = truncateWithStatus(s, maxLength); + contentNode = result.node(); + isTruncated = result.isTruncated(); + } else { + TruncationResult result = smartTruncate(content, maxLength); + contentNode = result.node(); + isTruncated = result.isTruncated(); + } + return ParsedContent.create(ImmutableList.copyOf(contentParts), contentNode, isTruncated); + } + + /** + * Parses a Content or Part object into summary text and content parts. + * + * @param content the Content or Part object to parse + * @param maxLength the maximum length of text fields before truncation + * @return a ParsedContentObject containing parts, summary, and truncation flag + */ + private static ParsedContentObject parseContentObject(Object content, int maxLength) { + ArrayNode contentParts = mapper.createArrayNode(); + boolean isTruncated = false; + List summaryText = new ArrayList<>(); + + List parts; + if (content instanceof Content c) { + parts = c.parts().orElse(ImmutableList.of()); + } else if (content instanceof Part p) { + parts = ImmutableList.of(p); + } else { + return ParsedContentObject.create(contentParts, "", false); + } + + for (int i = 0; i < parts.size(); i++) { + Part part = parts.get(i); + ContentPart.Builder partBuilder = + ContentPart.builder() + .setPartIndex(i) + .setMimeType("text/plain") + .setUri(null) + .setText(null) + .setPartAttributes("{}") + .setStorageMode("INLINE") + .setObjectRef(null); + + // CASE A: It is already a URI (e.g. from user input) + if (part.fileData().isPresent()) { + FileData fileData = part.fileData().get(); + partBuilder + .setStorageMode("EXTERNAL_URI") + .setUri(fileData.fileUri().orElse(null)) + .setMimeType(fileData.mimeType().orElse(null)); + } + // CASE B: It is Binary/Inline Data (Image/Blob) + else if (part.inlineData().isPresent()) { + // TODO: (b/485571635) Implement GCS offloading here. + partBuilder + .setText("[BINARY DATA]") + .setMimeType(part.inlineData().get().mimeType().orElse("")); + } + // CASE C: Text + else if (part.text().isPresent()) { + String text = part.text().get(); + // TODO: (b/485571635) Implement GCS offloading if text length exceeds maxLength. + if (text.length() > maxLength) { + text = truncate(text, maxLength); + isTruncated = true; + } + partBuilder.setText(text); + summaryText.add(text); + } else if (part.functionCall().isPresent()) { + FunctionCall fc = part.functionCall().get(); + ObjectNode partAttributes = mapper.createObjectNode(); + partAttributes.put("function_name", fc.name().orElse("unknown")); + partBuilder + .setMimeType("application/json") + .setText("Function: " + fc.name().orElse("unknown")) + .setPartAttributes(partAttributes.toString()); + } + contentParts.add(mapper.valueToTree(partBuilder.build())); + } + + String summaryResult = String.join(" | ", summaryText); + if (summaryResult.length() > maxLength) { + summaryResult = truncate(summaryResult, maxLength); + isTruncated = true; + } + + return ParsedContentObject.create(contentParts, summaryResult, isTruncated); + } /** Formats Content parts into an ArrayNode for BigQuery logging. */ - public static ArrayNode formatContentParts(Optional content, int maxLength) { + static ArrayNode formatContentParts(Optional content, int maxLength) { ArrayNode partsArray = mapper.createArrayNode(); if (content.isEmpty() || content.get().parts() == null) { return partsArray; @@ -51,7 +283,7 @@ public static ArrayNode formatContentParts(Optional content, int maxLen if (part.text().isPresent()) { partObj.put("mime_type", "text/plain"); - partObj.put("text", truncateString(part.text().get(), maxLength)); + partObj.put("text", truncate(part.text().get(), maxLength)); } else if (part.inlineData().isPresent()) { Blob blob = part.inlineData().get(); partObj.put("mime_type", blob.mimeType().orElse("")); @@ -67,45 +299,84 @@ public static ArrayNode formatContentParts(Optional content, int maxLen return partsArray; } - /** Recursively truncates long strings inside an object and returns a Jackson JsonNode. */ - public static JsonNode smartTruncate(Object obj, int maxLength) { + /** Recursively truncates long strings inside an object and returns a TruncationResult. */ + static TruncationResult smartTruncate(Object obj, int maxLength) { if (obj == null) { - return mapper.nullNode(); + return TruncationResult.create(mapper.nullNode(), false); } try { return recursiveSmartTruncate(mapper.valueToTree(obj), maxLength); + } catch (IllegalArgumentException e) { + // Fallback for types that mapper can't handle directly as a tree + return truncateWithStatus(String.valueOf(obj), maxLength); + } + } + + static JsonNode convertToJsonNode(Object obj) { + if (obj == null) { + return mapper.nullNode(); + } + try { + return mapper.valueToTree(obj); } catch (IllegalArgumentException e) { // Fallback for types that mapper can't handle directly as a tree return mapper.valueToTree(String.valueOf(obj)); } } - private static JsonNode recursiveSmartTruncate(JsonNode node, int maxLength) { + private static TruncationResult recursiveSmartTruncate(JsonNode node, int maxLength) { + boolean isTruncated = false; if (node.isTextual()) { - return mapper.valueToTree(truncateString(node.asText(), maxLength)); + String text = node.asText(); + if (text.length() > maxLength) { + return TruncationResult.create(mapper.valueToTree(truncate(text, maxLength)), true); + } + return TruncationResult.create(node, false); } else if (node.isObject()) { ObjectNode newNode = mapper.createObjectNode(); - node.properties() - .iterator() - .forEachRemaining( - entry -> { - newNode.set(entry.getKey(), recursiveSmartTruncate(entry.getValue(), maxLength)); - }); - return newNode; + Set> properties = node.properties(); + for (Map.Entry entry : properties) { + TruncationResult res = recursiveSmartTruncate(entry.getValue(), maxLength); + newNode.set(entry.getKey(), res.node()); + isTruncated = isTruncated || res.isTruncated(); + } + return TruncationResult.create(newNode, isTruncated); } else if (node.isArray()) { ArrayNode newNode = mapper.createArrayNode(); for (JsonNode element : node) { - newNode.add(recursiveSmartTruncate(element, maxLength)); + TruncationResult res = recursiveSmartTruncate(element, maxLength); + newNode.add(res.node()); + isTruncated = isTruncated || res.isTruncated(); } - return newNode; + return TruncationResult.create(newNode, isTruncated); + } + return TruncationResult.create(node, false); + } + + private static TruncationResult truncateWithStatus(String s, int maxLength) { + if (s == null) { + return TruncationResult.create(mapper.nullNode(), false); + } + if (s.length() <= maxLength) { + return TruncationResult.create(mapper.valueToTree(s), false); } - return node; + return TruncationResult.create(mapper.valueToTree(truncate(s, maxLength)), true); } - private static String truncateString(String s, int maxLength) { + private static String truncate(String s, int maxLength) { if (s == null || s.length() <= maxLength) { return s; } return s.substring(0, maxLength) + "...[truncated]"; } + + /** Converts a JsonNode to a standard Java object (Map, List, etc.). */ + public static @Nullable Object toJavaObject(JsonNode node) { + if (node == null || node.isNull()) { + return null; + } + return mapper.convertValue(node, Object.class); + } + + private JsonFormatter() {} } diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/TraceManager.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/TraceManager.java new file mode 100644 index 000000000..a02ea00b4 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/TraceManager.java @@ -0,0 +1,279 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import com.google.adk.agents.InvocationContext; +import com.google.auto.value.AutoValue; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.sdk.trace.ReadableSpan; +import java.time.Duration; +import java.time.Instant; +import java.util.Iterator; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Predicate; +import java.util.logging.Logger; + +/** + * Manages OpenTelemetry-style trace and span context using InvocationContext callback data. + * + *

Uses a stack of SpanRecord objects to keep span, ID, ownership, and timing in sync. + */ +public final class TraceManager { + private static final Logger logger = Logger.getLogger(TraceManager.class.getName()); + + private final ConcurrentLinkedDeque records = new ConcurrentLinkedDeque<>(); + private String rootAgentName = "_bq_analytics_root_agent_name"; + private String activeInvocationId = "_bq_analytics_active_invocation_id"; + + private final Tracer tracer; + + TraceManager() { + this(GlobalOpenTelemetry.getTracer("google.adk.plugins.bigquery_agent_analytics")); + } + + TraceManager(Tracer tracer) { + this.tracer = tracer; + } + + @AutoValue + abstract static class SpanRecord { + abstract Span span(); + + abstract String spanId(); + + abstract boolean ownsSpan(); + + abstract Instant startTime(); + + abstract AtomicReference firstTokenTime(); + + static SpanRecord create(Span span, String spanId, boolean ownsSpan, Instant startTime) { + return new AutoValue_TraceManager_SpanRecord( + span, spanId, ownsSpan, startTime, new AtomicReference<>()); + } + } + + @AutoValue + abstract static class RecordData { + abstract String spanId(); + + abstract Duration duration(); + + static RecordData create(String spanId, Duration duration) { + return new AutoValue_TraceManager_RecordData(spanId, duration); + } + } + + @AutoValue + abstract static class SpanIds { + abstract Optional spanId(); + + abstract Optional parentSpanId(); + + static SpanIds create(String spanId, String parentSpanId) { + return new AutoValue_TraceManager_SpanIds( + Optional.ofNullable(spanId), Optional.ofNullable(parentSpanId)); + } + } + + public String getRootAgentName() { + return rootAgentName; + } + + public void initTrace(InvocationContext context) { + String rootAgentName = context.agent().rootAgent().name(); + this.rootAgentName = rootAgentName; + } + + public String getTraceId(InvocationContext context) { + if (!records.isEmpty()) { + Span currentSpan = records.peekLast().span(); + if (currentSpan.getSpanContext().isValid()) { + return currentSpan.getSpanContext().getTraceId(); + } + } + // Fallback to the ambient span. + SpanContext ambient = Span.current().getSpanContext(); + if (ambient.isValid()) { + return ambient.getTraceId(); + } + // Fallback to the invocation ID. + return context.invocationId(); + } + + public boolean hasAmbientSpan() { + return Span.current().getSpanContext().isValid(); + } + + @CanIgnoreReturnValue + public String pushSpan(String spanName) { + Context parentContext = Context.current(); + if (!records.isEmpty()) { + Span parentSpan = records.peekLast().span(); + if (parentSpan.getSpanContext().isValid()) { + parentContext = parentContext.with(parentSpan); + } + } + + Span span = tracer.spanBuilder(spanName).setParent(parentContext).startSpan(); + String spanIdStr; + if (span.getSpanContext().isValid()) { + spanIdStr = span.getSpanContext().getSpanId(); + } else { + // This span id aligns with the OpenTelemetry Span ID format. + spanIdStr = UUID.randomUUID().toString().replace("-", "").substring(0, 16); + } + + SpanRecord record = SpanRecord.create(span, spanIdStr, true, Instant.now()); + records.add(record); + return spanIdStr; + } + + @CanIgnoreReturnValue + public String attachCurrentSpan() { + Span span = Span.current(); + String spanIdStr; + if (span.getSpanContext().isValid()) { + spanIdStr = span.getSpanContext().getSpanId(); + } else { + spanIdStr = UUID.randomUUID().toString().replace("-", "").substring(0, 16); + } + + SpanRecord record = SpanRecord.create(span, spanIdStr, false, Instant.now()); + records.add(record); + return spanIdStr; + } + + public void ensureInvocationSpan(InvocationContext context) { + String currentInv = context.invocationId(); + + if (!records.isEmpty()) { + if (currentInv.equals(activeInvocationId)) { + return; + } + logger.info("Clearing stale span records from previous invocation."); + clearStack(); + } + + activeInvocationId = currentInv; + + Span ambient = Span.current(); + if (ambient.getSpanContext().isValid()) { + attachCurrentSpan(); + } else { + pushSpan("invocation"); + } + } + + @CanIgnoreReturnValue + public Optional popSpan() { + if (records.isEmpty()) { + return Optional.empty(); + } + SpanRecord record = records.pollLast(); + if (record == null) { + return Optional.empty(); + } + Duration duration = Duration.between(record.startTime(), Instant.now()); + if (record.ownsSpan()) { + record.span().end(); + } + return Optional.of(RecordData.create(record.spanId(), duration)); + } + + public void clearStack() { + for (SpanRecord record : records) { + if (record.ownsSpan()) { + record.span().end(); + } + } + records.clear(); + } + + public SpanIds getCurrentSpanAndParent() { + if (records.isEmpty()) { + return SpanIds.create(null, null); + } + + String spanId = records.peekLast().spanId(); + String parentId = + findRecord(records.descendingIterator(), record -> !record.spanId().equals(spanId)) + .map(SpanRecord::spanId) + .orElse(null); + return SpanIds.create(spanId, parentId); + } + + Optional getAmbientSpanAndParent() { + Span ambient = Span.current(); + if (!ambient.getSpanContext().isValid()) { + return Optional.empty(); + } + String spanId = ambient.getSpanContext().getSpanId(); + String parentSpanId = null; + if (ambient instanceof ReadableSpan readableSpan) { + SpanContext parentCtx = readableSpan.getParentSpanContext(); + if (parentCtx != null && parentCtx.isValid()) { + parentSpanId = parentCtx.getSpanId(); + } + } + return Optional.of(SpanIds.create(spanId, parentSpanId)); + } + + public Optional getCurrentSpanId() { + if (records.isEmpty()) { + return Optional.empty(); + } + return Optional.of(records.peekLast().spanId()); + } + + private Optional findRecord( + Iterator iterator, Predicate predicate) { + while (iterator.hasNext()) { + SpanRecord record = iterator.next(); + if (predicate.test(record)) { + return Optional.of(record); + } + } + return Optional.empty(); + } + + private Optional findSpanRecord(String spanId) { + // Search from newest to oldest for efficiency. + return findRecord(records.descendingIterator(), record -> record.spanId().equals(spanId)); + } + + public void recordFirstToken(String spanId) { + findSpanRecord(spanId) + .ifPresent(record -> record.firstTokenTime().compareAndSet(null, Instant.now())); + } + + public Optional getStartTime(String spanId) { + return findSpanRecord(spanId).map(SpanRecord::startTime); + } + + public Optional getFirstTokenTime(String spanId) { + return findSpanRecord(spanId).map(record -> record.firstTokenTime().get()); + } +} diff --git a/core/src/main/java/com/google/adk/tools/AgentTool.java b/core/src/main/java/com/google/adk/tools/AgentTool.java index 7eabc48c4..956a8eb51 100644 --- a/core/src/main/java/com/google/adk/tools/AgentTool.java +++ b/core/src/main/java/com/google/adk/tools/AgentTool.java @@ -80,7 +80,7 @@ protected AgentTool(BaseAgent agent, boolean skipSummarization) { } @VisibleForTesting - BaseAgent getAgent() { + public BaseAgent getAgent() { return agent; } diff --git a/core/src/main/java/com/google/adk/utils/AgentEnums.java b/core/src/main/java/com/google/adk/utils/AgentEnums.java new file mode 100644 index 000000000..05460b540 --- /dev/null +++ b/core/src/main/java/com/google/adk/utils/AgentEnums.java @@ -0,0 +1,13 @@ +package com.google.adk.utils; + +/** Enums for agents. */ +public final class AgentEnums { + /** Origin of the agent. */ + public static enum AgentOrigin { + BASE_AGENT, + SUB_AGENT, + A2A, + } + + private AgentEnums() {} +} diff --git a/core/src/test/java/com/google/adk/models/GemmaTest.java b/core/src/test/java/com/google/adk/models/GemmaTest.java new file mode 100644 index 000000000..fe6315dcc --- /dev/null +++ b/core/src/test/java/com/google/adk/models/GemmaTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GemmaTest { + + @Test + public void getLlm_withValidGemmaModels_succeeds() { + assertThat(LlmRegistry.matchesAnyPattern("gemma-4-26b-a4b-it")).isTrue(); + assertThat(LlmRegistry.matchesAnyPattern("gemma-4-31b-it")).isTrue(); + } + + @Test + public void getLlm_withInvalidGemmaModels_throwsException() { + assertThat(LlmRegistry.matchesAnyPattern("not-a-gemma")).isFalse(); + assertThat(LlmRegistry.matchesAnyPattern("gemma")).isFalse(); + } +} diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java new file mode 100644 index 000000000..9dc63c5d6 --- /dev/null +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java @@ -0,0 +1,221 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.chat; + +import static com.google.common.truth.Truth.assertThat; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import java.util.HashMap; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class ChatCompletionsRequestTest { + + private ObjectMapper objectMapper; + + @Before + public void setUp() { + objectMapper = new ObjectMapper(); + } + + @Test + public void testSerializeChatCompletionRequest_standard() throws Exception { + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + + ChatCompletionsRequest.Message message = new ChatCompletionsRequest.Message(); + message.role = "user"; + message.content = new ChatCompletionsRequest.MessageContent("Hello"); + request.messages = ImmutableList.of(message); + + String json = objectMapper.writeValueAsString(request); + + assertThat(json).contains("\"model\":\"gemini-3-flash-preview\""); + assertThat(json).contains("\"role\":\"user\""); + assertThat(json).contains("\"content\":\"Hello\""); + } + + @Test + public void testSerializeChatCompletionRequest_withExtraBody() throws Exception { + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + + ChatCompletionsRequest.Message message = new ChatCompletionsRequest.Message(); + message.role = "user"; + message.content = new ChatCompletionsRequest.MessageContent("Explain to me how AI works"); + request.messages = ImmutableList.of(message); + + Map thinkingConfig = new HashMap<>(); + thinkingConfig.put("thinking_level", "low"); + thinkingConfig.put("include_thoughts", true); + + Map google = new HashMap<>(); + google.put("thinking_config", thinkingConfig); + + Map extraBody = new HashMap<>(); + extraBody.put("google", google); + + request.extraBody = extraBody; + + String json = objectMapper.writeValueAsString(request); + + assertThat(json).contains("\"extra_body\":{"); + assertThat(json).contains("\"thinking_level\":\"low\""); + assertThat(json).contains("\"include_thoughts\":true"); + } + + @Test + public void testSerializeChatCompletionRequest_withToolCallsAndExtraContent() throws Exception { + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + + ChatCompletionsRequest.Message userMessage = new ChatCompletionsRequest.Message(); + userMessage.role = "user"; + userMessage.content = new ChatCompletionsRequest.MessageContent("Check flight status"); + + ChatCompletionsRequest.Message modelMessage = new ChatCompletionsRequest.Message(); + modelMessage.role = "model"; + + ChatCompletionsCommon.ToolCall toolCall = new ChatCompletionsCommon.ToolCall(); + toolCall.id = "function-call-1"; + toolCall.type = "function"; + + ChatCompletionsCommon.Function function = new ChatCompletionsCommon.Function(); + function.name = "check_flight"; + function.arguments = "{\"flight\":\"AA100\"}"; + toolCall.function = function; + + Map google = new HashMap<>(); + google.put("thought_signature", ""); + + Map extraContent = new HashMap<>(); + extraContent.put("google", google); + + toolCall.extraContent = extraContent; + + modelMessage.toolCalls = ImmutableList.of(toolCall); + + ChatCompletionsRequest.Message toolMessage = new ChatCompletionsRequest.Message(); + toolMessage.role = "tool"; + toolMessage.name = "check_flight"; + toolMessage.toolCallId = "function-call-1"; + toolMessage.content = new ChatCompletionsRequest.MessageContent("{\"status\":\"delayed\"}"); + + request.messages = ImmutableList.of(userMessage, modelMessage, toolMessage); + + String json = objectMapper.writeValueAsString(request); + + assertThat(json).contains("\"role\":\"user\""); + assertThat(json).contains("\"role\":\"model\""); + assertThat(json).contains("\"role\":\"tool\""); + assertThat(json).contains("\"extra_content\":{"); + assertThat(json).contains("\"thought_signature\":\"\""); + assertThat(json).contains("\"tool_call_id\":\"function-call-1\""); + } + + @Test + public void testSerializeChatCompletionRequest_comprehensive() throws Exception { + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + + // Developer message with name + ChatCompletionsRequest.Message devMsg = new ChatCompletionsRequest.Message(); + devMsg.role = "developer"; + devMsg.content = new ChatCompletionsRequest.MessageContent("System instruction"); + devMsg.name = "system-bot"; + + request.messages = ImmutableList.of(devMsg); + + // Response Format JSON Schema + ChatCompletionsRequest.ResponseFormatJsonSchema format = + new ChatCompletionsRequest.ResponseFormatJsonSchema(); + format.jsonSchema = new ChatCompletionsRequest.ResponseFormatJsonSchema.JsonSchema(); + format.jsonSchema.name = "MySchema"; + format.jsonSchema.strict = true; + request.responseFormat = format; + + // Tool Choice Named + ChatCompletionsRequest.NamedToolChoice choice = new ChatCompletionsRequest.NamedToolChoice(); + choice.function = new ChatCompletionsRequest.NamedToolChoice.FunctionName(); + choice.function.name = "my_function"; + request.toolChoice = choice; + + String json = objectMapper.writeValueAsString(request); + + // Assert Developer Message + assertThat(json).contains("\"role\":\"developer\""); + assertThat(json).contains("\"name\":\"system-bot\""); + assertThat(json).contains("\"content\":\"System instruction\""); + + // Assert Response Format + assertThat(json).contains("\"response_format\":{"); + assertThat(json).contains("\"type\":\"json_schema\""); + assertThat(json).contains("\"name\":\"MySchema\""); + assertThat(json).contains("\"strict\":true"); + + // Assert Tool Choice + assertThat(json).contains("\"tool_choice\":{"); + assertThat(json).contains("\"type\":\"function\""); + assertThat(json).contains("\"name\":\"my_function\""); + } + + @Test + public void testSerializeChatCompletionRequest_withToolChoiceMode() throws Exception { + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + + request.toolChoice = new ChatCompletionsRequest.ToolChoiceMode("none"); + + String json = objectMapper.writeValueAsString(request); + + assertThat(json).contains("\"tool_choice\":\"none\""); + } + + @Test + public void testSerializeChatCompletionRequest_withStopAndVoice() throws Exception { + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + + request.stop = new ChatCompletionsRequest.StopCondition("STOP"); + + ChatCompletionsRequest.AudioParam audio = new ChatCompletionsRequest.AudioParam(); + audio.voice = new ChatCompletionsRequest.VoiceConfig("alloy"); + request.audio = audio; + + String json = objectMapper.writeValueAsString(request); + + assertThat(json).contains("\"stop\":\"STOP\""); + assertThat(json).contains("\"voice\":\"alloy\""); + } + + @Test + public void testSerializeChatCompletionRequest_withStopList() throws Exception { + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + + request.stop = new ChatCompletionsRequest.StopCondition(ImmutableList.of("STOP1", "STOP2")); + + String json = objectMapper.writeValueAsString(request); + + assertThat(json).contains("\"stop\":[\"STOP1\",\"STOP2\"]"); + } +} diff --git a/core/src/test/java/com/google/adk/models/ChatCompletionsResponseTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java similarity index 96% rename from core/src/test/java/com/google/adk/models/ChatCompletionsResponseTest.java rename to core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java index 53fcdfbdf..52134476c 100644 --- a/core/src/test/java/com/google/adk/models/ChatCompletionsResponseTest.java +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java @@ -14,13 +14,13 @@ * limitations under the License. */ -package com.google.adk.models; +package com.google.adk.models.chat; import static com.google.common.truth.Truth.assertThat; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.adk.models.ChatCompletionsResponse.ChatCompletion; -import com.google.adk.models.ChatCompletionsResponse.ChatCompletionChunk; +import com.google.adk.models.chat.ChatCompletionsResponse.ChatCompletion; +import com.google.adk.models.chat.ChatCompletionsResponse.ChatCompletionChunk; import java.util.Map; import org.junit.Before; import org.junit.Test; @@ -245,7 +245,7 @@ public void testDeserializeChatCompletion_withCustomToolCall() throws Exception objectMapper.readValue(json, ChatCompletionsResponse.ChatCompletion.class); assertThat(completion.choices.get(0).message.toolCalls).hasSize(1); - ChatCompletionsResponse.ToolCall toolCall = completion.choices.get(0).message.toolCalls.get(0); + ChatCompletionsCommon.ToolCall toolCall = completion.choices.get(0).message.toolCalls.get(0); assertThat(toolCall.type).isEqualTo("custom"); assertThat(toolCall.custom.name).isEqualTo("custom_tool"); assertThat(toolCall.custom.input).isEqualTo("{\"arg\":\"val\"}"); @@ -310,7 +310,7 @@ public void testDeserializeChatCompletionChunk_withToolCallDelta() throws Except ChatCompletionChunk chunk = objectMapper.readValue(json, ChatCompletionChunk.class); assertThat(chunk.choices.get(0).delta.toolCalls).hasSize(1); - ChatCompletionsResponse.ToolCall toolCall = chunk.choices.get(0).delta.toolCalls.get(0); + ChatCompletionsCommon.ToolCall toolCall = chunk.choices.get(0).delta.toolCalls.get(0); assertThat(toolCall.index).isEqualTo(1); assertThat(toolCall.id).isEqualTo("call_abc"); assertThat(toolCall.type).isEqualTo("function"); diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java new file mode 100644 index 000000000..c4f8dc2cf --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java @@ -0,0 +1,238 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.runner.Runner; +import com.google.adk.sessions.Session; +import com.google.api.core.ApiFutures; +import com.google.auth.Credentials; +import com.google.cloud.bigquery.BigQuery; +import com.google.cloud.bigquery.BigQueryOptions; +import com.google.cloud.bigquery.Table; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class BigQueryAgentAnalyticsPluginE2ETest { + private BigQuery mockBigQuery; + private StreamWriter mockWriter; + private BigQueryWriteClient mockWriteClient; + private BigQueryLoggerConfig config; + private BigQueryAgentAnalyticsPlugin plugin; + private Runner runner; + private BaseAgent fakeAgent; + private final List> capturedRows = + Collections.synchronizedList(new ArrayList<>()); + + @Before + public void setUp() throws Exception { + mockBigQuery = mock(BigQuery.class); + mockWriter = mock(StreamWriter.class); + mockWriteClient = mock(BigQueryWriteClient.class); + + config = + BigQueryLoggerConfig.builder() + .setEnabled(true) + .setProjectId("project") + .setDatasetId("dataset") + .setTableName("table") + .setBatchSize(10) + .setBatchFlushInterval(Duration.ofSeconds(10)) + .setCredentials(mock(Credentials.class)) + .build(); + + when(mockBigQuery.getOptions()) + .thenReturn(BigQueryOptions.newBuilder().setProjectId("test-project").build()); + when(mockBigQuery.getTable(any(TableId.class))).thenReturn(mock(Table.class)); + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); + + plugin = + new BigQueryAgentAnalyticsPlugin(config, mockBigQuery) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter(BigQueryLoggerConfig config) { + return mockWriter; + } + }; + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + try (VectorSchemaRoot root = + VectorSchemaRoot.create( + BigQuerySchema.getArrowSchema(), plugin.batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + for (int i = 0; i < root.getRowCount(); i++) { + Map row = new HashMap<>(); + row.put("event_type", String.valueOf(root.getVector("event_type").getObject(i))); + row.put("agent", String.valueOf(root.getVector("agent").getObject(i))); + row.put("session_id", String.valueOf(root.getVector("session_id").getObject(i))); + row.put( + "invocation_id", + String.valueOf(root.getVector("invocation_id").getObject(i))); + row.put("user_id", String.valueOf(root.getVector("user_id").getObject(i))); + row.put( + "timestamp", ((TimeStampMicroTZVector) root.getVector("timestamp")).get(i)); + row.put("is_truncated", root.getVector("is_truncated").getObject(i)); + row.put("content", String.valueOf(root.getVector("content").getObject(i))); + capturedRows.add(row); + } + } catch (RuntimeException e) { + throw new RuntimeException("Error in thenAnswer", e); + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + fakeAgent = new FakeAgent("test_agent"); + runner = Runner.builder().agent(fakeAgent).appName("test_app").plugins(plugin).build(); + } + + @Test + public void runAgent_logsAgentStartingAndCompleted() throws Exception { + Session session = runner.sessionService().createSession("test_app", "user").blockingGet(); + String sessionId = session.id(); + + runner + .runAsync("user", sessionId, Content.fromParts(Part.fromText("hello"))) + .blockingSubscribe(); + + // Ensure everything is flushed. The BatchProcessor flushes asynchronously sometimes, + // but the direct flush() call should help. We wait up to 2 seconds for all 5 expected events. + for (int i = 0; i < 20 && capturedRows.size() < 5; i++) { + plugin.batchProcessor.flush(); + if (capturedRows.size() < 5) { + Thread.sleep(100); + } + } + + // Verify presence of expected events + List eventTypes = + capturedRows.stream().map(row -> (String) row.get("event_type")).toList(); + + assertFalse("capturedRows should not be empty", capturedRows.isEmpty()); + assertTrue( + "Events should contain AGENT_STARTING. Actual: " + eventTypes, + eventTypes.contains("AGENT_STARTING")); + assertTrue( + "Events should contain AGENT_COMPLETED. Actual: " + eventTypes, + eventTypes.contains("AGENT_COMPLETED")); + assertTrue( + "Events should contain USER_MESSAGE_RECEIVED. Actual: " + eventTypes, + eventTypes.contains("USER_MESSAGE_RECEIVED")); + assertTrue( + "Events should contain INVOCATION_STARTING. Actual: " + eventTypes, + eventTypes.contains("INVOCATION_STARTING")); + assertTrue( + "Events should contain INVOCATION_COMPLETED. Actual: " + eventTypes, + eventTypes.contains("INVOCATION_COMPLETED")); + + // Verify common fields for one of the rows + Map agentStartingRow = + capturedRows.stream() + .filter(row -> Objects.equals(row.get("event_type"), "AGENT_STARTING")) + .findFirst() + .orElseThrow(); + + assertEquals("test_agent", agentStartingRow.get("agent")); + assertEquals(sessionId, agentStartingRow.get("session_id")); + assertEquals("user", agentStartingRow.get("user_id")); + assertNotNull("invocation_id should be populated", agentStartingRow.get("invocation_id")); + assertTrue("timestamp should be positive", (Long) agentStartingRow.get("timestamp") > 0); + assertEquals(false, agentStartingRow.get("is_truncated")); + + // Verify content for USER_MESSAGE_RECEIVED + Map userMessageRow = + capturedRows.stream() + .filter(row -> Objects.equals(row.get("event_type"), "USER_MESSAGE_RECEIVED")) + .findFirst() + .orElseThrow(); + String contentJson = (String) userMessageRow.get("content"); + assertTrue("Content should contain 'hello'", contentJson.contains("hello")); + + // Verify order + int userMessageIdx = eventTypes.indexOf("USER_MESSAGE_RECEIVED"); + int invocationStartIdx = eventTypes.indexOf("INVOCATION_STARTING"); + int agentStartIdx = eventTypes.indexOf("AGENT_STARTING"); + int agentCompletedIdx = eventTypes.indexOf("AGENT_COMPLETED"); + int invocationCompletedIdx = eventTypes.indexOf("INVOCATION_COMPLETED"); + + assertTrue( + "USER_MESSAGE_RECEIVED should be first by Runner implementation", + userMessageIdx < invocationStartIdx); + assertTrue( + "INVOCATION_STARTING should be before AGENT_STARTING", invocationStartIdx < agentStartIdx); + assertTrue( + "AGENT_STARTING should be before AGENT_COMPLETED", agentStartIdx < agentCompletedIdx); + assertTrue( + "AGENT_COMPLETED should be before INVOCATION_COMPLETED", + agentCompletedIdx < invocationCompletedIdx); + } + + private static class FakeAgent extends BaseAgent { + FakeAgent(String name) { + super(name, "description", null, null, null); + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + } +} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java index 0822c2cae..1d066e632 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java @@ -33,7 +33,12 @@ import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; import com.google.adk.sessions.Session; +import com.google.adk.tools.AgentTool; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.adk.utils.AgentEnums.AgentOrigin; import com.google.api.core.ApiFutures; import com.google.auth.Credentials; import com.google.cloud.bigquery.BigQuery; @@ -43,18 +48,27 @@ import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Candidate; import com.google.genai.types.Content; +import com.google.genai.types.GenerateContentResponse; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; +import io.opentelemetry.api.GlobalOpenTelemetry; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.reactivex.rxjava3.core.Flowable; import java.time.Duration; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import java.util.logging.Handler; import java.util.logging.Level; import java.util.logging.LogRecord; @@ -67,6 +81,7 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -80,6 +95,7 @@ @RunWith(JUnit4.class) public class BigQueryAgentAnalyticsPluginTest { @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); + @Rule public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); @Mock private BigQuery mockBigQuery; @Mock private StreamWriter mockWriter; @@ -90,9 +106,11 @@ public class BigQueryAgentAnalyticsPluginTest { private BigQueryLoggerConfig config; private BigQueryAgentAnalyticsPlugin plugin; private Handler mockHandler; + private Tracer tracer; @Before public void setUp() throws Exception { + tracer = openTelemetryRule.getOpenTelemetry().getTracer("test-plugin"); fakeAgent = new FakeAgent("agent_name"); config = BigQueryLoggerConfig.builder() @@ -124,12 +142,18 @@ protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { protected StreamWriter createWriter(BigQueryLoggerConfig config) { return mockWriter; } + + @Override + protected TraceManager createTraceManager() { + return new TraceManager(tracer); + } }; - Session session = Session.builder("session_id").build(); + Session session = Session.builder("session_id").appName("test_app").userId("test_user").build(); when(mockInvocationContext.session()).thenReturn(session); when(mockInvocationContext.invocationId()).thenReturn("invocation_id"); when(mockInvocationContext.agent()).thenReturn(fakeAgent); + when(mockInvocationContext.callbackContextData()).thenReturn(new ConcurrentHashMap<>()); when(mockInvocationContext.userId()).thenReturn("user_id"); Logger logger = Logger.getLogger(BatchProcessor.class.getName()); @@ -137,6 +161,14 @@ protected StreamWriter createWriter(BigQueryLoggerConfig config) { logger.addHandler(mockHandler); } + @After + public void tearDown() { + Logger logger = Logger.getLogger(BatchProcessor.class.getName()); + if (mockHandler != null) { + logger.removeHandler(mockHandler); + } + } + @Test public void onUserMessageCallback_appendsToWriter() throws Exception { Content content = Content.builder().build(); @@ -216,12 +248,15 @@ public void onUserMessageCallback_handlesTableCreationFailure() throws Exception ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); verify(mockHandler, atLeastOnce()).publish(captor.capture()); - assertTrue( - captor - .getValue() - .getMessage() - .contains("Failed to check or create/upgrade BigQuery table")); - assertEquals(Level.WARNING, captor.getValue().getLevel()); + boolean found = + captor.getAllValues().stream() + .anyMatch( + record -> + record + .getMessage() + .contains("Failed to check or create/upgrade BigQuery table") + && Objects.equals(record.getLevel(), Level.WARNING)); + assertTrue("Should have logged table creation failure warning", found); } finally { logger.removeHandler(mockHandler); } @@ -313,7 +348,8 @@ public void logEvent_populatesCommonFields() throws Exception { if (root.getRowCount() != 1) { failureMessage[0] = "Expected 1 row, got " + root.getRowCount(); } else if (!Objects.equals( - root.getVector("event_type").getObject(0).toString(), "USER_MESSAGE")) { + root.getVector("event_type").getObject(0).toString(), + "USER_MESSAGE_RECEIVED")) { failureMessage[0] = "Wrong event_type: " + root.getVector("event_type").getObject(0); } else if (!root.getVector("agent").getObject(0).toString().equals("agent_name")) { @@ -334,6 +370,9 @@ public void logEvent_populatesCommonFields() throws Exception { failureMessage[0] = "Wrong user_id: " + root.getVector("user_id").getObject(0); } else if (((TimeStampMicroTZVector) root.getVector("timestamp")).get(0) <= 0) { failureMessage[0] = "Timestamp not populated"; + } else if (!Objects.equals(root.getVector("is_truncated").getObject(0), false)) { + failureMessage[0] = + "Wrong is_truncated: " + root.getVector("is_truncated").getObject(0); } else { // Check content and content_parts String contentJson = root.getVector("content").getObject(0).toString(); @@ -381,6 +420,8 @@ public void logEvent_populatesTraceDetails() throws Exception { Span mockSpan = Span.wrap(mockSpanContext); try (Scope scope = mockSpan.makeCurrent()) { + plugin.traceManager.attachCurrentSpan(); + Content content = Content.builder().build(); plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); @@ -414,29 +455,190 @@ public void onEventCallback_populatesCorrectFields() throws Exception { Map row = plugin.batchProcessor.queue.poll(); assertNotNull("Row not found in queue", row); - assertEquals("EVENT", row.get("event_type")); + assertEquals("STATE_DELTA", row.get("event_type")); assertEquals("agent_name", row.get("agent")); - assertTrue(row.get("attributes").toString().contains("agent_author")); + ObjectNode attributes = (ObjectNode) row.get("attributes"); + assertEquals("agent_author", attributes.get("author").asText()); assertTrue(row.get("content").toString().contains("event content")); + assertEquals(false, row.get("is_truncated")); } @Test public void onModelErrorCallback_populatesCorrectFields() throws Exception { CallbackContext mockCallbackContext = mock(CallbackContext.class); when(mockCallbackContext.invocationContext()).thenReturn(mockInvocationContext); - when(mockCallbackContext.agentName()).thenReturn("agent_in_context"); LlmRequest.Builder mockLlmRequestBuilder = mock(LlmRequest.Builder.class); Throwable error = new RuntimeException("model error message"); + plugin.traceManager.pushSpan("llm_request"); plugin .onModelErrorCallback(mockCallbackContext, mockLlmRequestBuilder, error) .blockingSubscribe(); Map row = plugin.batchProcessor.queue.poll(); assertNotNull("Row not found in queue", row); - assertEquals("MODEL_ERROR", row.get("event_type")); - assertEquals("agent_in_context", row.get("agent")); - assertTrue(row.get("attributes").toString().contains("model error message")); + assertEquals("LLM_ERROR", row.get("event_type")); + assertEquals("agent_name", row.get("agent")); + assertEquals("ERROR", row.get("status")); + assertEquals("model error message", row.get("error_message")); + assertNotNull(row.get("latency_ms")); + assertEquals(false, row.get("is_truncated")); + } + + @Test + public void afterModelCallback_populatesCorrectFields() throws Exception { + CallbackContext mockCallbackContext = mock(CallbackContext.class); + when(mockCallbackContext.invocationContext()).thenReturn(mockInvocationContext); + + GenerateContentResponseUsageMetadata usage = + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build(); + + GenerateContentResponse response = + GenerateContentResponse.builder() + .modelVersion("v1") + .usageMetadata(usage) + .candidates( + ImmutableList.of( + Candidate.builder() + .content(Content.fromParts(Part.fromText("llm response"))) + .build())) + .build(); + + LlmResponse adkResponse = LlmResponse.create(response); + + Span parentSpan = tracer.spanBuilder("parent_request").startSpan(); + Span ambientSpan = + tracer.spanBuilder("ambient").setParent(Context.current().with(parentSpan)).startSpan(); + // Set valid ambient span context + try (Scope scope = ambientSpan.makeCurrent()) { + plugin.traceManager.pushSpan("parent_request"); + plugin.traceManager.pushSpan("llm_request"); + plugin.afterModelCallback(mockCallbackContext, adkResponse).blockingSubscribe(); + } finally { + ambientSpan.end(); + } + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull("Row not found in queue", row); + assertEquals("LLM_RESPONSE", row.get("event_type")); + ObjectNode contentMap = (ObjectNode) row.get("content"); + assertNotNull(contentMap.get("response")); + ObjectNode usageMap = (ObjectNode) contentMap.get("usage"); + assertEquals(10, usageMap.get("prompt").asInt()); + + ObjectNode attributes = (ObjectNode) row.get("attributes"); + assertEquals("v1", attributes.get("model_version").asText()); + ObjectNode usageAttr = (ObjectNode) attributes.get("usage_metadata"); + assertEquals(10, usageAttr.get("prompt").asInt()); + assertEquals(false, row.get("is_truncated")); + assertNotNull(row.get("parent_span_id")); + ObjectNode latencyMs = (ObjectNode) row.get("latency_ms"); + assertNotNull("latency_ms should not be null", latencyMs); + assertTrue( + "latency_ms should contain time_to_first_token_ms", + latencyMs.has("time_to_first_token_ms")); + } + + @Test + public void afterToolCallback_populatesCorrectFields() throws Exception { + ToolContext mockToolContext = mock(ToolContext.class); + when(mockToolContext.invocationContext()).thenReturn(mockInvocationContext); + + BaseTool mockTool = mock(BaseTool.class); + when(mockTool.name()).thenReturn("test_tool"); + + ImmutableMap toolArgs = ImmutableMap.of("arg1", "value1"); + ImmutableMap result = ImmutableMap.of("res1", "value2"); + + plugin.traceManager.pushSpan("tool_request"); + plugin.afterToolCallback(mockTool, toolArgs, mockToolContext, result).blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull("Row not found in queue", row); + assertEquals("TOOL_COMPLETED", row.get("event_type")); + assertEquals("agent_name", row.get("agent")); + ObjectNode contentMap = (ObjectNode) row.get("content"); + assertEquals("test_tool", contentMap.get("tool").asText()); + assertNotNull(contentMap.get("result")); + assertEquals("UNKNOWN", contentMap.get("tool_origin").asText()); + assertEquals(false, row.get("is_truncated")); + assertNotNull(row.get("latency_ms")); + } + + @Test + public void afterToolCallback_identifiesA2AOrigin() throws Exception { + ToolContext mockToolContext = mock(ToolContext.class); + when(mockToolContext.invocationContext()).thenReturn(mockInvocationContext); + + BaseAgent a2aAgent = + new FakeAgent("a2a_agent") { + @Override + public AgentOrigin toolOrigin() { + return AgentOrigin.A2A; + } + }; + + AgentTool a2aTool = AgentTool.create(a2aAgent); + + plugin.traceManager.pushSpan("tool_request"); + plugin + .afterToolCallback(a2aTool, ImmutableMap.of(), mockToolContext, ImmutableMap.of()) + .blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull(row); + ObjectNode contentMap = (ObjectNode) row.get("content"); + assertEquals("A2A", contentMap.get("tool_origin").asText()); + } + + @Test + public void logEvent_includesSessionMetadata_whenEnabled() throws Exception { + // Config default has logSessionMetadata(true) + Content content = Content.fromParts(Part.fromText("test message")); + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull(row); + ObjectNode attributes = (ObjectNode) row.get("attributes"); + assertTrue("attributes should contain session_metadata", attributes.has("session_metadata")); + ObjectNode sessionMeta = (ObjectNode) attributes.get("session_metadata"); + assertEquals("session_id", sessionMeta.get("session_id").asText()); + assertEquals("test_user", sessionMeta.get("user_id").asText()); + assertEquals("test_app", sessionMeta.get("app_name").asText()); + } + + @Test + public void logEvent_excludesSessionMetadata_whenDisabled() throws Exception { + BigQueryLoggerConfig disabledConfig = config.toBuilder().setLogSessionMetadata(false).build(); + BigQueryAgentAnalyticsPlugin disabledPlugin = + new BigQueryAgentAnalyticsPlugin(disabledConfig, mockBigQuery) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter(BigQueryLoggerConfig config) { + return mockWriter; + } + + @Override + protected TraceManager createTraceManager() { + return new TraceManager(GlobalOpenTelemetry.getTracer("test-plugin-disabled")); + } + }; + + Content content = Content.fromParts(Part.fromText("test message")); + disabledPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + Map row = disabledPlugin.batchProcessor.queue.poll(); + assertNotNull(row); + ObjectNode attributes = (ObjectNode) row.get("attributes"); + assertFalse( + "attributes should not contain session_metadata", attributes.has("session_metadata")); } private static class FakeAgent extends BaseAgent { diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java new file mode 100644 index 000000000..739f3a7c3 --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java @@ -0,0 +1,145 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.google.adk.models.LlmRequest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import java.util.Arrays; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class JsonFormatterTest { + + @Test + public void parse_llmRequest_populatesPrompt() { + LlmRequest request = + LlmRequest.builder() + .contents( + ImmutableList.of( + Content.fromParts(Part.fromText("hello")).toBuilder().role("user").build())) + .build(); + + JsonFormatter.ParsedContent result = JsonFormatter.parse(request, 100); + + assertTrue(result.content().has("prompt")); + ArrayNode prompt = (ArrayNode) result.content().get("prompt"); + assertEquals(1, prompt.size()); + assertEquals("user", prompt.get(0).get("role").asText()); + assertEquals("hello", prompt.get(0).get("content").asText()); + assertFalse(result.isTruncated()); + } + + @Test + public void parse_llmRequest_populatesSystemPrompt() { + LlmRequest request = + LlmRequest.builder() + .config( + GenerateContentConfig.builder() + .systemInstruction(Content.fromParts(Part.fromText("be helpful"))) + .build()) + .build(); + + JsonFormatter.ParsedContent result = JsonFormatter.parse(request, 100); + + assertTrue(result.content().has("system_prompt")); + assertEquals("be helpful", result.content().get("system_prompt").asText()); + assertFalse(result.isTruncated()); + } + + @Test + public void parse_string_truncates() { + String longString = "this is a very long string that should be truncated"; + JsonFormatter.ParsedContent result = JsonFormatter.parse(longString, 10); + + assertTrue(result.isTruncated()); + assertEquals("this is a ...[truncated]", result.content().asText()); + } + + @Test + public void parse_map_truncatesNested() { + ImmutableMap map = ImmutableMap.of("key", "this is a long value"); + JsonFormatter.ParsedContent result = JsonFormatter.parse(map, 10); + + assertTrue(result.isTruncated()); + assertEquals("this is a ...[truncated]", result.content().get("key").asText()); + } + + @Test + public void parse_content_returnsSummary() { + Content content = Content.fromParts(Part.fromText("part 1"), Part.fromText("part 2")); + JsonFormatter.ParsedContent result = JsonFormatter.parse(content, 100); + + assertEquals("part 1 | part 2", result.content().get("text_summary").asText()); + assertEquals(2, result.parts().size()); + } + + @Test + public void parse_content_withFileData() { + FileData fileData = + FileData.builder().fileUri("gs://bucket/file.txt").mimeType("text/plain").build(); + Content content = Content.fromParts(Part.builder().fileData(fileData).build()); + JsonFormatter.ParsedContent result = JsonFormatter.parse(content, 100); + + assertEquals(1, result.parts().size()); + JsonNode partData = result.parts().get(0); + assertEquals("EXTERNAL_URI", partData.get("storage_mode").asText()); + assertEquals("gs://bucket/file.txt", partData.get("uri").asText()); + assertEquals("text/plain", partData.get("mime_type").asText()); + } + + @Test + public void parse_content_withFunctionCall() { + FunctionCall fc = FunctionCall.builder().name("myFunction").build(); + Content content = Content.fromParts(Part.builder().functionCall(fc).build()); + JsonFormatter.ParsedContent result = JsonFormatter.parse(content, 100); + + assertEquals(1, result.parts().size()); + JsonNode partData = result.parts().get(0); + assertEquals("application/json", partData.get("mime_type").asText()); + assertEquals("Function: myFunction", partData.get("text").asText()); + assertTrue(partData.get("part_attributes").asText().contains("myFunction")); + } + + @Test + public void parse_list_truncatesElements() { + List list = + Arrays.asList("short", "this is a very long string that should be truncated"); + JsonFormatter.ParsedContent result = JsonFormatter.parse(list, 10); + + assertTrue(result.isTruncated()); + JsonNode arrayNode = result.content(); + assertTrue(arrayNode.isArray()); + assertEquals(2, arrayNode.size()); + assertEquals("short", arrayNode.get(0).asText()); + assertEquals("this is a ...[truncated]", arrayNode.get(1).asText()); + } +} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/TraceManagerTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/TraceManagerTest.java new file mode 100644 index 000000000..b7ce7823f --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/TraceManagerTest.java @@ -0,0 +1,230 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Scope; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import io.reactivex.rxjava3.core.Flowable; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class TraceManagerTest { + @Rule public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); + private InvocationContext mockContext; + private BaseAgent mockAgent; + private Map callbackData; + private TraceManager traceManager; + private Tracer tracer; + + @Before + public void setUp() { + tracer = openTelemetryRule.getOpenTelemetry().getTracer("test"); + callbackData = new ConcurrentHashMap<>(); + mockContext = mock(InvocationContext.class); + when(mockContext.callbackContextData()).thenReturn(callbackData); + when(mockContext.invocationId()).thenReturn("test-invocation-id"); + mockAgent = + new BaseAgent("test-agent", "desc", null, null, null) { + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + }; + when(mockContext.agent()).thenReturn(mockAgent); + traceManager = new TraceManager(tracer); + } + + @Test + public void pushSpan_createsValidSpanId() { + String spanId = traceManager.pushSpan("test-span"); + assertNotNull(spanId); + assertTrue(spanId.length() >= 16); + } + + @Test + public void pushSpan_maintainsParentChildRelationship() { + String parentId = traceManager.pushSpan("parent"); + String childId = traceManager.pushSpan("child"); + + TraceManager.SpanIds ids = traceManager.getCurrentSpanAndParent(); + assertEquals(childId, ids.spanId().orElse(null)); + assertEquals(parentId, ids.parentSpanId().orElse(null)); + } + + @Test + public void popSpan_removesFromStack() { + String parentId = traceManager.pushSpan("parent"); + traceManager.pushSpan("child"); + + Optional popped = traceManager.popSpan(); + assertTrue(popped.isPresent()); + assertFalse(popped.get().duration().isNegative()); + + String currentId = traceManager.getCurrentSpanId().orElse(null); + assertEquals(parentId, currentId); + + TraceManager.SpanIds ids = traceManager.getCurrentSpanAndParent(); + assertEquals(parentId, ids.spanId().orElse(null)); + assertFalse(ids.parentSpanId().isPresent()); + } + + @Test + public void ensureInvocationSpan_isIdempotent() { + traceManager.ensureInvocationSpan(mockContext); + String id1 = traceManager.getCurrentSpanId().orElse(null); + + traceManager.ensureInvocationSpan(mockContext); + String id2 = traceManager.getCurrentSpanId().orElse(null); + + assertEquals(id1, id2); + } + + @Test + public void ensureInvocationSpan_clearsStaleRecords() { + Span ambientSpan = tracer.spanBuilder("ambient").startSpan(); + try (Scope scope = ambientSpan.makeCurrent()) { + traceManager.ensureInvocationSpan(mockContext); + } finally { + ambientSpan.end(); + } + String id1 = traceManager.getCurrentSpanId().orElse(null); + // Create a new context with same callback data but different invocation ID + InvocationContext mockContext2 = mock(InvocationContext.class); + when(mockContext2.callbackContextData()).thenReturn(callbackData); + when(mockContext2.invocationId()).thenReturn("new-invocation-id"); + when(mockContext2.agent()).thenReturn(mockAgent); + Span ambientSpan2 = tracer.spanBuilder("ambient2").startSpan(); + try (Scope scope = ambientSpan2.makeCurrent()) { + traceManager.ensureInvocationSpan(mockContext2); + } finally { + ambientSpan2.end(); + } + String id2 = traceManager.getCurrentSpanId().orElse(null); + + assertNotEquals(id1, id2); + // Should only have 1 record now + TraceManager.SpanIds ids = traceManager.getCurrentSpanAndParent(); + assertFalse(ids.parentSpanId().isPresent()); + } + + @Test + public void attachCurrentSpan_usesAmbientSpan() { + Span ambientSpan = tracer.spanBuilder("ambient").startSpan(); + try (Scope scope = ambientSpan.makeCurrent()) { + String attachedId = traceManager.attachCurrentSpan(); + String expectedId = ambientSpan.getSpanContext().getSpanId(); + assertEquals(expectedId, attachedId); + } finally { + ambientSpan.end(); + } + } + + @Test + public void getTraceId_returnsCurrentTraceId() { + traceManager.pushSpan("test"); + String traceId = traceManager.getTraceId(mockContext); + assertNotNull(traceId); + if (traceId.equals("test-invocation-id")) { + assertEquals("test-invocation-id", traceId); + } else { + assertTrue(traceId.matches("[0-9a-f]{32}")); + } + } + + @Test + public void getTraceId_returnsInvocationId_whenRecordsIsEmpty() { + String traceId = traceManager.getTraceId(mockContext); + if (traceManager.hasAmbientSpan()) { + assertTrue(traceId.matches("[0-9a-f]{32}")); + } else { + assertEquals("test-invocation-id", traceId); + } + } + + @Test + public void getTraceId_returnsAmbientTraceId_whenRecordsIsEmpty_butAmbientIsPresent() { + Span ambientSpan = tracer.spanBuilder("ambient").startSpan(); + try (Scope scope = ambientSpan.makeCurrent()) { + String expectedTraceId = ambientSpan.getSpanContext().getTraceId(); + String traceId = traceManager.getTraceId(mockContext); + assertEquals(expectedTraceId, traceId); + } finally { + ambientSpan.end(); + } + } + + @Test + public void attachCurrentSpan_worksWithoutAmbientSpan() { + // Ensure no ambient span + String attachedId = traceManager.attachCurrentSpan(); + assertNotNull(attachedId); + assertEquals(16, attachedId.length()); + + // Verify it's in records + assertEquals(attachedId, traceManager.getCurrentSpanId().orElse(null)); + } + + @Test + public void getTraceId_fallsBackToInvocationId_whenRecordSpanIsInvalid() { + // attachCurrentSpan when no ambient context exists creates an invalid span record + traceManager.attachCurrentSpan(); + + String traceId = traceManager.getTraceId(mockContext); + if (traceManager.hasAmbientSpan()) { + assertTrue(traceId.matches("[0-9a-f]{32}")); + } else { + assertEquals("test-invocation-id", traceId); + } + } + + @Test + public void popSpan_returnsEmpty_whenRecordsIsEmpty() { + Optional popped = traceManager.popSpan(); + assertFalse(popped.isPresent()); + } + + @Test + public void clearStack_doesNothing_whenRecordsIsEmpty() { + traceManager.clearStack(); + assertTrue(traceManager.getCurrentSpanAndParent().spanId().isEmpty()); + } +}