diff --git a/contributing/samples/token_usage/agent.py b/contributing/samples/token_usage/agent.py
index a73f9e7638..35b9775706 100755
--- a/contributing/samples/token_usage/agent.py
+++ b/contributing/samples/token_usage/agent.py
@@ -26,26 +26,26 @@
 
 
 def roll_die(sides: int, tool_context: ToolContext) -> int:
-  """Roll a die and return the rolled result.
+    """Roll a die and return the rolled result.
 
-  Args:
-    sides: The integer number of sides the die has.
+    Args:
+        sides: The integer number of sides the die has.
 
-  Returns:
-    An integer of the result of rolling the die.
-  """
-  result = random.randint(1, sides)
-  if 'rolls' not in tool_context.state:
-    tool_context.state['rolls'] = []
+    Returns:
+        An integer of the result of rolling the die.
+    """
+    result = random.randint(1, sides)
+    if "rolls" not in tool_context.state:
+        tool_context.state["rolls"] = []
 
-  tool_context.state['rolls'] = tool_context.state['rolls'] + [result]
-  return result
+    tool_context.state["rolls"] = tool_context.state["rolls"] + [result]
+    return result
 
 
 roll_agent_with_openai = LlmAgent(
-    model=LiteLlm(model='openai/gpt-4o'),
-    description='Handles rolling dice of different sizes.',
-    name='roll_agent_with_openai',
+    model=LiteLlm(model="openai/gpt-4o"),
+    description="Handles rolling dice of different sizes.",
+    name="roll_agent_with_openai",
     instruction="""
       You are responsible for rolling dice based on the user's request.
       When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
@@ -54,9 +54,9 @@ def roll_die(sides: int, tool_context: ToolContext) -> int:
 )
 
 roll_agent_with_claude = LlmAgent(
-    model=Claude(model='claude-3-7-sonnet@20250219'),
-    description='Handles rolling dice of different sizes.',
-    name='roll_agent_with_claude',
+    model=Claude(model="claude-3-7-sonnet@20250219"),
+    description="Handles rolling dice of different sizes.",
+    name="roll_agent_with_claude",
     instruction="""
       You are responsible for rolling dice based on the user's request.
       When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
@@ -65,9 +65,9 @@ def roll_die(sides: int, tool_context: ToolContext) -> int:
 )
 
 roll_agent_with_litellm_claude = LlmAgent(
-    model=LiteLlm(model='vertex_ai/claude-3-7-sonnet'),
-    description='Handles rolling dice of different sizes.',
-    name='roll_agent_with_litellm_claude',
+    model=LiteLlm(model="vertex_ai/claude-3-7-sonnet"),
+    description="Handles rolling dice of different sizes.",
+    name="roll_agent_with_litellm_claude",
     instruction="""
       You are responsible for rolling dice based on the user's request.
      When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
@@ -76,9 +76,9 @@ def roll_die(sides: int, tool_context: ToolContext) -> int:
 )
 
 roll_agent_with_gemini = LlmAgent(
-    model='gemini-2.0-flash',
-    description='Handles rolling dice of different sizes.',
-    name='roll_agent_with_gemini',
+    model="gemini-2.0-flash",
+    description="Handles rolling dice of different sizes.",
+    name="roll_agent_with_gemini",
     instruction="""
       You are responsible for rolling dice based on the user's request.
       When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
@@ -87,7 +87,7 @@ def roll_die(sides: int, tool_context: ToolContext) -> int:
 )
 
 root_agent = SequentialAgent(
-    name='code_pipeline_agent',
+    name="code_pipeline_agent",
     sub_agents=[
         roll_agent_with_openai,
         roll_agent_with_claude,
diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py
index 6c20b1b9a5..730689efa5 100644
--- a/src/google/adk/models/anthropic_llm.py
+++ b/src/google/adk/models/anthropic_llm.py
@@ -29,9 +29,11 @@
 from typing import TYPE_CHECKING
 from typing import Union
 
-from anthropic import AnthropicVertex
+from anthropic import AsyncAnthropic
+from anthropic import AsyncAnthropicVertex
 from anthropic import NOT_GIVEN
 from anthropic import types as anthropic_types
+from anthropic.resources.messages import AsyncMessages
 from google.genai import types
 from pydantic import BaseModel
 from typing_extensions import override
@@ -244,8 +246,8 @@ def function_declaration_to_tool_param(
   )
 
 
-class Claude(BaseLlm):
-  """Integration with Claude models served from Vertex AI.
+class AnthropicClaude(BaseLlm):
+  """Integration with Claude models served from Anthropic.
 
   Attributes:
     model: The name of the Claude model.
@@ -284,7 +286,7 @@ async def generate_content_async(
         else NOT_GIVEN
     )
     # TODO(b/421255973): Enable streaming for anthropic models.
-    message = self._anthropic_client.messages.create(
+    message = await self._anthropic_client.create(
         model=llm_request.model,
         system=llm_request.config.system_instruction,
         messages=messages,
@@ -295,7 +297,21 @@
     yield message_to_generate_content_response(message)
 
   @cached_property
-  def _anthropic_client(self) -> AnthropicVertex:
+  def _anthropic_client(self) -> AsyncMessages:
+    return AsyncAnthropic().messages
+
+
+class Claude(AnthropicClaude):
+  """Integration with Claude models served from Vertex AI.
+
+  Attributes:
+    model: The name of the Claude model.
+    max_tokens: The maximum number of tokens to generate.
+  """
+
+  @cached_property
+  @override
+  def _anthropic_client(self) -> AsyncMessages:
     if (
         "GOOGLE_CLOUD_PROJECT" not in os.environ
         or "GOOGLE_CLOUD_LOCATION" not in os.environ
@@ -305,7 +321,7 @@ def _anthropic_client(self) -> AnthropicVertex:
           " Anthropic on Vertex."
       )
 
-    return AnthropicVertex(
+    return AsyncAnthropicVertex(
         project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
         region=os.environ["GOOGLE_CLOUD_LOCATION"],
-    )
+    ).messages
diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py
index a81fbc7252..fe4c00718f 100644
--- a/tests/unittests/models/test_anthropic_llm.py
+++ b/tests/unittests/models/test_anthropic_llm.py
@@ -295,7 +295,9 @@ async def test_function_declaration_to_tool_param(
 async def test_generate_content_async(
     claude_llm, llm_request, generate_content_response, generate_llm_response
 ):
-  with mock.patch.object(claude_llm, "_anthropic_client") as mock_client:
+  with mock.patch.object(
+      claude_llm, "_anthropic_client"
+  ) as mock_messages_client:
     with mock.patch.object(
         anthropic_llm,
         "message_to_generate_content_response",
@@ -306,7 +308,7 @@ async def mock_coro():
         return generate_content_response
 
       # Assign the coroutine to the mocked method
-      mock_client.messages.create.return_value = mock_coro()
+      mock_messages_client.create.return_value = mock_coro()
 
       responses = [
           resp
@@ -324,7 +326,9 @@ async def test_generate_content_async_with_max_tokens(
     llm_request, generate_content_response, generate_llm_response
 ):
   claude_llm = Claude(model="claude-3-5-sonnet-v2@20241022", max_tokens=4096)
-  with mock.patch.object(claude_llm, "_anthropic_client") as mock_client:
+  with mock.patch.object(
+      claude_llm, "_anthropic_client"
+  ) as mock_messages_client:
     with mock.patch.object(
         anthropic_llm,
         "message_to_generate_content_response",
@@ -335,7 +339,7 @@ async def mock_coro():
         return generate_content_response
 
      # Assign the coroutine to the mocked method
-      mock_client.messages.create.return_value = mock_coro()
+      mock_messages_client.create.return_value = mock_coro()
 
       _ = [
           resp
@@ -343,6 +347,6 @@ async def mock_coro():
              llm_request, stream=False
           )
       ]
-      mock_client.messages.create.assert_called_once()
-      _, kwargs = mock_client.messages.create.call_args
+      mock_messages_client.create.assert_called_once()
+      _, kwargs = mock_messages_client.create.call_args
       assert kwargs["max_tokens"] == 4096
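
A minimal usage sketch of the resulting class split, for reviewers trying the change locally. This is illustrative only, not part of the patch: it assumes `LlmAgent` is importable from `google.adk.agents`, that `AsyncAnthropic` picks up `ANTHROPIC_API_KEY` from the environment (standard anthropic SDK behavior), and that the dashed model ID is valid for the direct API; the agent names and placeholder project values are made up.

```python
import os

from google.adk.agents import LlmAgent  # assumed public import path
from google.adk.models.anthropic_llm import AnthropicClaude
from google.adk.models.anthropic_llm import Claude

# Direct Anthropic API: AnthropicClaude._anthropic_client is
# AsyncAnthropic().messages, so the SDK reads ANTHROPIC_API_KEY
# from the environment.
direct_agent = LlmAgent(
    model=AnthropicClaude(model="claude-3-7-sonnet-20250219"),  # dashed ID assumed
    name="direct_anthropic_agent",  # hypothetical name
    instruction="Answer concisely.",
)

# Vertex AI: the Claude subclass overrides _anthropic_client with
# AsyncAnthropicVertex(...).messages and raises ValueError unless
# both variables below are set.
os.environ.setdefault("GOOGLE_CLOUD_PROJECT", "my-project")  # placeholder
os.environ.setdefault("GOOGLE_CLOUD_LOCATION", "us-east5")  # placeholder
vertex_agent = LlmAgent(
    model=Claude(model="claude-3-7-sonnet@20250219"),
    name="vertex_claude_agent",  # hypothetical name
    instruction="Answer concisely.",
)
```

Both classes share `generate_content_async`, which now awaits `self._anthropic_client.create(...)`; only the cached client property differs, so the tests above can patch `_anthropic_client` directly instead of reaching through `.messages`.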