diff --git a/src/strands/experimental/agent_config.py b/src/strands/experimental/agent_config.py index e6fb94118..6a4e727ec 100644 --- a/src/strands/experimental/agent_config.py +++ b/src/strands/experimental/agent_config.py @@ -9,15 +9,39 @@ agent = config_to_agent("config.json") # Add tools that need code-based instantiation agent.tool_registry.process_tools([ToolWithConfigArg(HttpsConnection("localhost"))]) + +The ``model`` field supports two formats: + +**String format (backward compatible — defaults to Bedrock):** + {"model": "us.anthropic.claude-sonnet-4-20250514-v1:0"} + +**Object format (supports all providers):** + { + "model": { + "provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + "max_tokens": 10000, + "client_args": {"api_key": "..."} + } + } + +Note: The following constructor parameters cannot be specified from JSON because they require +code-based instantiation: ``boto_session`` (Bedrock, SageMaker), ``client`` (OpenAI, Gemini), +``gemini_tools`` (Gemini). Use ``region_name`` / ``client_args`` as JSON-friendly alternatives. """ +from __future__ import annotations + import json from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import jsonschema from jsonschema import ValidationError +if TYPE_CHECKING: + from ..models.model import Model + # JSON Schema for agent configuration AGENT_CONFIG_SCHEMA = { "$schema": "http://json-schema.org/draft-07/schema#", @@ -27,8 +51,25 @@ "properties": { "name": {"description": "Name of the agent", "type": ["string", "null"], "default": None}, "model": { - "description": "The model ID to use for this agent. If not specified, uses the default model.", - "type": ["string", "null"], + "description": ( + "The model to use for this agent. Can be a string (Bedrock model_id) " + "or an object with a 'provider' field for any supported provider." 
+ ), + "oneOf": [ + {"type": "string"}, + {"type": "null"}, + { + "type": "object", + "properties": { + "provider": { + "description": "The model provider name", + "type": "string", + } + }, + "required": ["provider"], + "additionalProperties": True, + }, + ], "default": None, }, "prompt": { @@ -50,8 +91,55 @@ # Pre-compile validator for better performance _VALIDATOR = jsonschema.Draft7Validator(AGENT_CONFIG_SCHEMA) +# Provider name to model class name — resolved via strands.models lazy __getattr__ +PROVIDER_MAP: dict[str, str] = { + "bedrock": "BedrockModel", + "anthropic": "AnthropicModel", + "openai": "OpenAIModel", + "gemini": "GeminiModel", + "ollama": "OllamaModel", + "litellm": "LiteLLMModel", + "mistral": "MistralModel", + "llamaapi": "LlamaAPIModel", + "llamacpp": "LlamaCppModel", + "sagemaker": "SageMakerAIModel", + "writer": "WriterModel", + "openai_responses": "OpenAIResponsesModel", +} + -def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> Any: +def _create_model_from_dict(model_config: dict[str, Any]) -> Model: + """Create a Model instance from a provider config dict. + + Routes the config to the appropriate model class based on the ``provider`` field, + then delegates to the class's ``from_dict`` method. All imports are lazy to avoid + requiring optional dependencies that are not installed. + + Args: + model_config: Dict containing at least a ``provider`` key and provider-specific params. + + Returns: + A configured Model instance for the specified provider. + + Raises: + ValueError: If the provider name is not recognized. + ImportError: If the provider's optional dependencies are not installed. + """ + config = model_config.copy() + provider = config.pop("provider") + + class_name = PROVIDER_MAP.get(provider) + if class_name is None: + supported = ", ".join(sorted(PROVIDER_MAP.keys())) + raise ValueError(f"Unknown model provider: '{provider}'. Supported providers: {supported}") + + from .. 
import models + + model_cls: type[Model] = getattr(models, class_name) + return model_cls.from_dict(config) + + +def config_to_agent(config: str | dict[str, Any], **kwargs: Any) -> Any: """Create an Agent from a configuration file or dictionary. This function supports tools that can be loaded declaratively (file paths, module names, @@ -83,6 +171,12 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A Create agent from dictionary: >>> config = {"model": "anthropic.claude-3-5-sonnet-20241022-v2:0", "tools": ["calculator"]} >>> agent = config_to_agent(config) + + Create agent with object model config: + >>> config = { + ... "model": {"provider": "openai", "model_id": "gpt-4o", "client_args": {"api_key": "..."}} + ... } + >>> agent = config_to_agent(config) """ # Parse configuration if isinstance(config, str): @@ -114,11 +208,19 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A raise ValueError(f"Configuration validation error at {error_path}: {e.message}") from e # Prepare Agent constructor arguments - agent_kwargs = {} - - # Map configuration keys to Agent constructor parameters + agent_kwargs: dict[str, Any] = {} + + # Handle model field — string vs object format + model_value = config_dict.get("model") + if isinstance(model_value, dict): + # Object format: create Model instance via factory + agent_kwargs["model"] = _create_model_from_dict(model_value) + elif model_value is not None: + # String format (backward compat): pass directly as model_id to Agent + agent_kwargs["model"] = model_value + + # Map remaining configuration keys to Agent constructor parameters config_mapping = { - "model": "model", "prompt": "system_prompt", "tools": "tools", "name": "name", diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index bfb7b1ede..bc906ba1a 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -127,6 +127,34 @@ class BedrockConfig(TypedDict, total=False): 
temperature: float | None top_p: float | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "BedrockModel": + """Create a BedrockModel from a configuration dictionary. + + Handles extraction of ``region_name``, ``endpoint_url``, and conversion of + ``boto_client_config`` from a plain dict to ``botocore.config.Config``. + + Args: + config: Model configuration dictionary. A copy is made internally; + the caller's dict is not modified. + + Returns: + A configured BedrockModel instance. + """ + config = config.copy() + kwargs: dict[str, Any] = {} + + if "region_name" in config: + kwargs["region_name"] = config.pop("region_name") + if "endpoint_url" in config: + kwargs["endpoint_url"] = config.pop("endpoint_url") + if "boto_client_config" in config: + raw = config.pop("boto_client_config") + kwargs["boto_client_config"] = BotocoreConfig(**raw) if isinstance(raw, dict) else raw + + kwargs.update(config) + return cls(**kwargs) + def __init__( self, *, diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index c52509816..823a41f36 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -131,6 +131,28 @@ class LlamaCppConfig(TypedDict, total=False): model_id: str params: dict[str, Any] | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "LlamaCppModel": + """Create a LlamaCppModel from a configuration dictionary. + + Handles extraction of ``base_url`` and ``timeout`` as separate constructor parameters. + + Args: + config: Model configuration dictionary. A copy is made internally; + the caller's dict is not modified. + + Returns: + A configured LlamaCppModel instance. 
+ """ + config = config.copy() + kwargs: dict[str, Any] = {} + if "base_url" in config: + kwargs["base_url"] = config.pop("base_url") + if "timeout" in config: + kwargs["timeout"] = config.pop("timeout") + kwargs.update(config) + return cls(**kwargs) + def __init__( self, base_url: str = "http://localhost:8080", diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index f44a11d30..2203ea29a 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -53,6 +53,30 @@ class MistralConfig(TypedDict, total=False): top_p: float | None stream: bool | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "MistralModel": + """Create a MistralModel from a configuration dictionary. + + Handles extraction of ``api_key`` and ``client_args`` as separate constructor parameters. + + Args: + config: Model configuration dictionary. A copy is made internally; + the caller's dict is not modified. + + Returns: + A configured MistralModel instance. + """ + config = config.copy() + api_key = config.pop("api_key", None) + client_args = config.pop("client_args", None) + kwargs: dict[str, Any] = {} + if api_key is not None: + kwargs["api_key"] = api_key + if client_args is not None: + kwargs["client_args"] = client_args + kwargs.update(config) + return cls(**kwargs) + def __init__( self, api_key: str | None = None, diff --git a/src/strands/models/model.py b/src/strands/models/model.py index f084d24d5..cfd95cb3c 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -1,5 +1,7 @@ """Abstract base class for Agent model providers.""" +from __future__ import annotations + import abc import logging from collections.abc import AsyncGenerator, AsyncIterable @@ -51,6 +53,29 @@ def stateful(self) -> bool: """ return False + @classmethod + def from_dict(cls, config: dict[str, Any]) -> Model: + """Create a Model instance from a configuration dictionary. 
+ + The default implementation extracts ``client_args`` (if present) and passes + all remaining keys as keyword arguments to the constructor. Subclasses with + non-standard constructor signatures should override this method. + + Args: + config: Provider-specific configuration dictionary. A copy is made internally; + the caller's dict is not modified. + + Returns: + A configured Model instance. + """ + config = config.copy() + client_args = config.pop("client_args", None) + kwargs: dict[str, Any] = {} + if client_args is not None: + kwargs["client_args"] = client_args + kwargs.update(config) + return cls(**kwargs) + @abc.abstractmethod # pragma: no cover def update_config(self, **model_config: Any) -> None: diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 97cb7948a..41217dd47 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -56,6 +56,29 @@ class OllamaConfig(TypedDict, total=False): temperature: float | None top_p: float | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "OllamaModel": + """Create an OllamaModel from a configuration dictionary. + + Handles extraction of ``host`` as a positional argument and mapping of + ``client_args`` to the ``ollama_client_args`` constructor parameter. + + Args: + config: Model configuration dictionary. A copy is made internally; + the caller's dict is not modified. + + Returns: + A configured OllamaModel instance. 
+ """ + config = config.copy() + host = config.pop("host", None) + client_args = config.pop("client_args", None) + kwargs: dict[str, Any] = {} + if client_args is not None: + kwargs["ollama_client_args"] = client_args + kwargs.update(config) + return cls(host, **kwargs) + def __init__( self, host: str | None, diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 775969290..9d3178650 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -133,6 +133,32 @@ class SageMakerAIEndpointConfig(TypedDict, total=False): target_variant: str | None | None additional_args: dict[str, Any] | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "SageMakerAIModel": + """Create a SageMakerAIModel from a configuration dictionary. + + Handles extraction of ``endpoint_config``, ``payload_config``, and conversion of + ``boto_client_config`` from a plain dict to ``botocore.config.Config``. + + Args: + config: Model configuration dictionary. A copy is made internally; + the caller's dict is not modified. + + Returns: + A configured SageMakerAIModel instance. 
+ """ + config = config.copy() + kwargs: dict[str, Any] = {} + kwargs["endpoint_config"] = config.pop("endpoint_config", {}) + kwargs["payload_config"] = config.pop("payload_config", {}) + if "boto_client_config" in config: + raw = config.pop("boto_client_config") + kwargs["boto_client_config"] = BotocoreConfig(**raw) if isinstance(raw, dict) else raw + if config: + unexpected = ", ".join(sorted(config.keys())) + raise ValueError(f"Unsupported SageMaker config keys: {unexpected}") + return cls(**kwargs) + def __init__( self, endpoint_config: SageMakerAIEndpointConfig, diff --git a/tests/strands/experimental/test_agent_config.py b/tests/strands/experimental/test_agent_config.py index e6188079b..cd2b0fad3 100644 --- a/tests/strands/experimental/test_agent_config.py +++ b/tests/strands/experimental/test_agent_config.py @@ -3,10 +3,19 @@ import json import os import tempfile +from unittest.mock import MagicMock, patch import pytest from strands.experimental import config_to_agent +from strands.experimental.agent_config import ( + PROVIDER_MAP, + _create_model_from_dict, +) + +# ============================================================================= +# Backward compatibility tests (existing) +# ============================================================================= def test_config_to_agent_with_dict(): @@ -170,3 +179,252 @@ def test_config_to_agent_with_tool(): config = {"model": "test-model", "tools": ["tests.fixtures.say_tool:say"]} agent = config_to_agent(config) assert "say" in agent.tool_names + + +# ============================================================================= +# Schema validation tests — dual-format model field +# ============================================================================= + + +class TestSchemaValidation: + """Tests for the updated AGENT_CONFIG_SCHEMA that supports both string and object model formats.""" + + def test_string_model_valid(self): + """Test that string model format still passes validation.""" + config = 
{"model": "us.anthropic.claude-sonnet-4-20250514-v1:0"} + agent = config_to_agent(config) + assert agent.model.config["model_id"] == "us.anthropic.claude-sonnet-4-20250514-v1:0" + + def test_object_model_valid(self): + """Test that object model format passes schema validation.""" + mock_model = MagicMock() + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + config = { + "model": { + "provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + "max_tokens": 10000, + } + } + agent = config_to_agent(config) + assert agent.model is mock_model + + def test_object_model_missing_provider_raises(self): + """Test that object model without provider raises validation error.""" + config = {"model": {"model_id": "some-model"}} + with pytest.raises(ValueError, match="Configuration validation error"): + config_to_agent(config) + + def test_object_model_allows_additional_properties(self): + """Test that object model format allows provider-specific properties.""" + mock_model = MagicMock() + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + config = { + "model": { + "provider": "openai", + "model_id": "gpt-4o", + "client_args": {"api_key": "test"}, + "custom_field": "allowed", + } + } + # Should not raise + config_to_agent(config) + + def test_null_model_still_valid(self): + """Test that null model is still accepted for default behavior.""" + config = {"model": None} + agent = config_to_agent(config) + # Should use default model + assert agent is not None + + def test_model_wrong_type_raises(self): + """Test that model field with invalid type raises validation error.""" + config = {"model": 12345} + with pytest.raises(ValueError, match="Configuration validation error"): + config_to_agent(config) + + def test_object_model_from_file(self): + """Test object model format loaded from a JSON file.""" + mock_model = MagicMock() + config_data = { + "model": { + 
"provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + } + } + temp_path = "" + try: + with tempfile.NamedTemporaryFile(mode="w+", suffix=".json", delete=False) as f: + json.dump(config_data, f) + f.flush() + temp_path = f.name + + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + agent = config_to_agent(temp_path) + assert agent.model is mock_model + finally: + if os.path.exists(temp_path): + os.remove(temp_path) + + +# ============================================================================= +# Provider factory tests +# ============================================================================= + + +class TestCreateModelFromConfig: + """Tests for _create_model_from_dict dispatching to cls.from_dict.""" + + def test_unknown_provider_raises(self): + """Test that an unknown provider name raises ValueError.""" + with pytest.raises(ValueError, match="Unknown model provider: 'nonexistent'"): + _create_model_from_dict({"provider": "nonexistent", "model_id": "x"}) + + def _set_mock_on_models(self, class_name): + """Inject a mock class directly into strands.models.__dict__ to avoid triggering lazy imports.""" + import strands.models as models_pkg + + mock_cls = MagicMock() + mock_cls.from_dict.return_value = MagicMock() + original = models_pkg.__dict__.get(class_name) + models_pkg.__dict__[class_name] = mock_cls + return mock_cls, original + + def _restore_models(self, class_name, original): + """Restore original state of strands.models after test.""" + import strands.models as models_pkg + + if original is None: + models_pkg.__dict__.pop(class_name, None) + else: + models_pkg.__dict__[class_name] = original + + def test_dispatches_to_from_dict(self): + """Test that _create_model_from_dict calls cls.from_dict on the resolved model class.""" + mock_cls, original = self._set_mock_on_models("AnthropicModel") + mock_model = MagicMock() + mock_cls.from_dict.return_value = mock_model + try: + result 
= _create_model_from_dict( + { + "provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + "max_tokens": 8192, + "client_args": {"api_key": "test-key"}, + } + ) + mock_cls.from_dict.assert_called_once() + call_config = mock_cls.from_dict.call_args[0][0] + assert call_config["model_id"] == "claude-sonnet-4-20250514" + assert call_config["max_tokens"] == 8192 + assert call_config["client_args"] == {"api_key": "test-key"} + assert "provider" not in call_config + assert result is mock_model + finally: + self._restore_models("AnthropicModel", original) + + def test_does_not_mutate_input(self): + """Test that _create_model_from_dict does not mutate the input dict.""" + mock_cls, original = self._set_mock_on_models("AnthropicModel") + try: + original_input = {"provider": "anthropic", "model_id": "test"} + original_copy = original_input.copy() + _create_model_from_dict(original_input) + assert original_input == original_copy + finally: + self._restore_models("AnthropicModel", original) + + @pytest.mark.parametrize( + "provider,class_name", + list(PROVIDER_MAP.items()), + ) + def test_all_providers_dispatch(self, provider, class_name): + """Test that each registered provider dispatches to the correct class.""" + mock_cls, original = self._set_mock_on_models(class_name) + try: + _create_model_from_dict({"provider": provider, "model_id": "test"}) + mock_cls.from_dict.assert_called_once() + finally: + self._restore_models(class_name, original) + + +# ============================================================================= +# Error handling tests +# ============================================================================= + + +class TestErrorHandling: + """Tests for error handling in model creation.""" + + def test_missing_optional_dependency(self): + """Test clear error when provider dependency is not installed.""" + import strands.models as models_pkg + + mock_cls = MagicMock() + mock_cls.from_dict.side_effect = ImportError("No module named 'anthropic'") + 
+ original = models_pkg.__dict__.get("AnthropicModel") + models_pkg.__dict__["AnthropicModel"] = mock_cls + try: + with pytest.raises(ImportError, match="anthropic"): + _create_model_from_dict( + { + "provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + } + ) + finally: + if original is None: + models_pkg.__dict__.pop("AnthropicModel", None) + else: + models_pkg.__dict__["AnthropicModel"] = original + + +# ============================================================================= +# Integration: config_to_agent with object model +# ============================================================================= + + +class TestConfigToAgentObjectModel: + """Tests for config_to_agent using the object model format end-to-end.""" + + def test_object_model_creates_agent(self): + """Test that object model config creates an agent with the correct model.""" + mock_model = MagicMock() + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + config = { + "model": { + "provider": "openai", + "model_id": "gpt-4o", + }, + "prompt": "You are helpful", + } + agent = config_to_agent(config) + assert agent.model is mock_model + assert agent.system_prompt == "You are helpful" + + def test_object_model_with_kwargs_override(self): + """Test that kwargs can still override when using object model.""" + mock_model = MagicMock() + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + config = { + "model": {"provider": "openai", "model_id": "gpt-4o"}, + "prompt": "Original prompt", + } + agent = config_to_agent(config, system_prompt="Override prompt") + assert agent.system_prompt == "Override prompt" diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index cd7016488..8b6a6ca18 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2964,3 +2964,47 @@ async def 
test_non_streaming_citations_with_only_location(bedrock_client, model, assert citation["location"] == {"web": {"url": "https://example.com", "domain": "example.com"}} assert "title" not in citation assert "sourceContent" not in citation + + +class TestBedrockFromDict: + """Tests for BedrockModel.from_dict classmethod.""" + + def test_from_dict_boto_client_config_conversion(self): + """Test that from_dict converts boto_client_config dict to BotocoreConfig.""" + with unittest.mock.patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + BedrockModel.from_dict( + { + "model_id": "test-model", + "region_name": "us-west-2", + "boto_client_config": {"read_timeout": 300}, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["region_name"] == "us-west-2" + assert isinstance(call_kwargs["boto_client_config"], BotocoreConfig) + assert call_kwargs["model_id"] == "test-model" + + def test_from_dict_without_boto_client_config(self): + """Test from_dict without boto_client_config.""" + with unittest.mock.patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + BedrockModel.from_dict( + { + "model_id": "test-model", + "region_name": "us-east-1", + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["region_name"] == "us-east-1" + assert "boto_client_config" not in call_kwargs + + def test_from_dict_endpoint_url(self): + """Test from_dict with endpoint_url.""" + with unittest.mock.patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + BedrockModel.from_dict( + { + "model_id": "test-model", + "endpoint_url": "https://vpce-1234.bedrock-runtime.us-west-2.vpce.amazonaws.com", + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["endpoint_url"] == "https://vpce-1234.bedrock-runtime.us-west-2.vpce.amazonaws.com" diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index 3e023dfce..e7a50e46c 100644 --- a/tests/strands/models/test_llamacpp.py +++ 
b/tests/strands/models/test_llamacpp.py @@ -706,3 +706,24 @@ def test_format_request_filters_location_source_document(caplog) -> None: assert len(user_content) == 1 assert user_content[0]["type"] == "text" assert "Location sources are not supported by llama.cpp" in caplog.text + + +class TestLlamaCppFromDict: + """Tests for LlamaCppModel.from_dict classmethod.""" + + def test_from_dict_base_url_and_timeout(self): + """Test that from_dict extracts base_url and timeout.""" + from strands.models.llamacpp import LlamaCppModel + + with patch.object(LlamaCppModel, "__init__", return_value=None) as mock_init: + LlamaCppModel.from_dict( + { + "model_id": "default", + "base_url": "http://myhost:8080", + "timeout": 30.0, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["base_url"] == "http://myhost:8080" + assert call_kwargs["timeout"] == 30.0 + assert call_kwargs["model_id"] == "default" diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index 57189748e..6bd8d657f 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -679,3 +679,22 @@ def test_format_request_filters_location_source_document(model, caplog): user_content = formatted_messages[0]["content"] assert user_content == "analyze this document" assert "Location sources are not supported by Mistral" in caplog.text + + +class TestMistralFromDict: + """Tests for MistralModel.from_dict classmethod.""" + + def test_from_dict_api_key_extraction(self): + """Test that from_dict extracts api_key separately.""" + with unittest.mock.patch.object(MistralModel, "__init__", return_value=None) as mock_init: + MistralModel.from_dict( + { + "model_id": "mistral-large-latest", + "api_key": "test-key", + "client_args": {"timeout": 60}, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["api_key"] == "test-key" + assert call_kwargs["client_args"] == {"timeout": 60} + assert call_kwargs["model_id"] == "mistral-large-latest" 
diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index 458e98645..37383ce44 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -213,3 +213,46 @@ def test_model_plugin_preserves_messages_when_not_stateful(model_plugin): model_plugin._on_after_invocation(event) assert len(agent.messages) == 1 + + +class TestModelFromDict: + """Tests for the default Model.from_dict classmethod.""" + + def test_from_dict_with_client_args(self): + """Test that from_dict extracts client_args and passes remaining kwargs.""" + from unittest.mock import patch + + from strands.models.bedrock import BedrockModel + + # Use BedrockModel to test from_dict invocation since it is always available; + # the base Model.from_dict is tested indirectly via the bedrock override path + # not executing (we only verify kwarg routing here). + with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + # Invoke the *base* Model.from_dict by calling it on a subclass that + # does NOT override from_dict. BedrockModel overrides it, so we call + # the base implementation directly for this test. 
+            # super() skips BedrockModel's from_dict override and binds cls=BedrockModel,
+            super(BedrockModel, BedrockModel).from_dict(
+                {
+                    "model_id": "test-model",
+                    "client_args": {"api_key": "test"},
+                    "max_tokens": 4096,
+                },
+            )
+            call_kwargs = mock_init.call_args[1]
+            assert call_kwargs["client_args"] == {"api_key": "test"}
+            assert call_kwargs["model_id"] == "test-model"
+            assert call_kwargs["max_tokens"] == 4096
+
+    def test_from_dict_without_client_args(self):
+        """Test that from_dict works without client_args."""
+        from unittest.mock import patch
+
+        from strands.models.bedrock import BedrockModel
+
+        with patch.object(BedrockModel, "__init__", return_value=None) as mock_init:
+            super(BedrockModel, BedrockModel).from_dict({"model_id": "test-model", "max_tokens": 1024})
+            call_kwargs = mock_init.call_args[1]
+            assert call_kwargs["model_id"] == "test-model"
+            assert call_kwargs["max_tokens"] == 1024
+            assert "client_args" not in call_kwargs
diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py
index 0d4fbb9e0..844f369d7 100644
--- a/tests/strands/models/test_ollama.py
+++ b/tests/strands/models/test_ollama.py
@@ -625,3 +625,29 @@ def test_format_request_filters_location_source_document(model, caplog):
     user_message = formatted_messages[0]
     assert user_message["content"] == "analyze this document"
     assert "Location sources are not supported by Ollama" in caplog.text
+
+
+class TestOllamaFromDict:
+    """Tests for OllamaModel.from_dict classmethod."""
+
+    def test_from_dict_host_and_client_args_mapping(self):
+        """Test that from_dict routes host and maps client_args to ollama_client_args."""
+        with unittest.mock.patch.object(OllamaModel, "__init__", return_value=None) as mock_init:
+            OllamaModel.from_dict(
+                {
+                    "model_id": "llama3",
+                    "host": "http://localhost:11434",
+                    "client_args": {"timeout": 30},
+                }
+            )
+            call_args = mock_init.call_args
+            assert call_args[0][0] == "http://localhost:11434"
+            assert call_args[1]["ollama_client_args"] == {"timeout": 30}
+            assert call_args[1]["model_id"] == "llama3"
+
+    def 
test_from_dict_default_host(self): + """Test from_dict with no host specified defaults to None.""" + with unittest.mock.patch.object(OllamaModel, "__init__", return_value=None) as mock_init: + OllamaModel.from_dict({"model_id": "llama3"}) + call_args = mock_init.call_args + assert call_args[0][0] is None diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index 5d6d6869a..1ce11f5bd 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -665,3 +665,44 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +class TestSageMakerFromDict: + """Tests for SageMakerAIModel.from_dict classmethod.""" + + def test_from_dict_dict_params(self): + """Test that from_dict receives endpoint_config and payload_config as dicts.""" + with unittest.mock.patch.object(SageMakerAIModel, "__init__", return_value=None) as mock_init: + SageMakerAIModel.from_dict( + { + "endpoint_config": {"endpoint_name": "my-ep", "region_name": "us-west-2"}, + "payload_config": {"max_tokens": 1024, "stream": True}, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["endpoint_config"] == {"endpoint_name": "my-ep", "region_name": "us-west-2"} + assert call_kwargs["payload_config"] == {"max_tokens": 1024, "stream": True} + + def test_from_dict_boto_client_config_conversion(self): + """Test that from_dict converts boto_client_config dict to BotocoreConfig.""" + with unittest.mock.patch.object(SageMakerAIModel, "__init__", return_value=None) as mock_init: + SageMakerAIModel.from_dict( + { + "endpoint_config": {"endpoint_name": "my-ep"}, + "payload_config": {"max_tokens": 1024}, + "boto_client_config": {"read_timeout": 300}, + } + ) + call_kwargs = mock_init.call_args[1] + assert 
isinstance(call_kwargs["boto_client_config"], BotocoreConfig) + + def test_from_dict_rejects_unexpected_keys(self): + """Test that from_dict raises ValueError on unexpected config keys.""" + with pytest.raises(ValueError, match="Unsupported SageMaker config keys"): + SageMakerAIModel.from_dict( + { + "endpoint_config": {}, + "payload_config": {}, + "model_id": "unexpected", + } + )