From e3c44bfd7f52e1810fe6dc724cd2d59aaf9b86cb Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Fri, 10 Apr 2026 17:10:30 -0400 Subject: [PATCH 1/4] feat: add support for all models to agent config --- src/strands/experimental/agent_config.py | 147 ++++- src/strands/models/bedrock.py | 26 + src/strands/models/llamacpp.py | 20 + src/strands/models/mistral.py | 22 + src/strands/models/model.py | 23 + src/strands/models/ollama.py | 21 + src/strands/models/sagemaker.py | 21 + .../strands/experimental/test_agent_config.py | 549 ++++++++++++++++++ 8 files changed, 824 insertions(+), 5 deletions(-) diff --git a/src/strands/experimental/agent_config.py b/src/strands/experimental/agent_config.py index e6fb94118..26510e6c6 100644 --- a/src/strands/experimental/agent_config.py +++ b/src/strands/experimental/agent_config.py @@ -9,9 +9,33 @@ agent = config_to_agent("config.json") # Add tools that need code-based instantiation agent.tool_registry.process_tools([ToolWithConfigArg(HttpsConnection("localhost"))]) + +The ``model`` field supports two formats: + +**String format (backward compatible — defaults to Bedrock):** + {"model": "us.anthropic.claude-sonnet-4-20250514-v1:0"} + +**Object format (supports all providers):** + { + "model": { + "provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + "max_tokens": 10000, + "client_args": {"api_key": "$ANTHROPIC_API_KEY"} + } + } + +Environment variable references (``$VAR`` or ``${VAR}``) in model config values are resolved +automatically before provider instantiation. + +Note: The following constructor parameters cannot be specified from JSON because they require +code-based instantiation: ``boto_session`` (Bedrock, SageMaker), ``client`` (OpenAI, Gemini), +``gemini_tools`` (Gemini). Use ``region_name`` / ``client_args`` as JSON-friendly alternatives. """ import json +import os +import re from pathlib import Path from typing import Any @@ -27,8 +51,25 @@ "properties": { "name": {"description": "Name of the agent", "type": ["string", "null"], "default": None}, "model": { - "description": "The model ID to use for this agent. If not specified, uses the default model.", - "type": ["string", "null"], + "description": ( + "The model to use for this agent. Can be a string (Bedrock model_id) " + "or an object with a 'provider' field for any supported provider." + ), + "oneOf": [ + {"type": "string"}, + {"type": "null"}, + { + "type": "object", + "properties": { + "provider": { + "description": "The model provider name", + "type": "string", + } + }, + "required": ["provider"], + "additionalProperties": True, + }, + ], "default": None, }, "prompt": { @@ -50,6 +91,87 @@ # Pre-compile validator for better performance _VALIDATOR = jsonschema.Draft7Validator(AGENT_CONFIG_SCHEMA) +# Pattern for matching environment variable references +_ENV_VAR_PATTERN = re.compile(r"^\$\{([^}]+)\}$|^\$([A-Za-z_][A-Za-z0-9_]*)$") + +# Provider name to model class name — resolved via strands.models lazy __getattr__ +PROVIDER_MAP: dict[str, str] = { + "bedrock": "BedrockModel", + "anthropic": "AnthropicModel", + "openai": "OpenAIModel", + "gemini": "GeminiModel", + "ollama": "OllamaModel", + "litellm": "LiteLLMModel", + "mistral": "MistralModel", + "llamaapi": "LlamaAPIModel", + "llamacpp": "LlamaCppModel", + "sagemaker": "SageMakerAIModel", + "writer": "WriterModel", + "openai_responses": "OpenAIResponsesModel", +} + + +def _resolve_env_vars(value: Any) -> Any: + """Recursively resolve environment variable references in config values. + + String values matching ``$VAR_NAME`` or ``${VAR_NAME}`` are replaced with the + corresponding environment variable value. Dicts and lists are traversed recursively. + + Args: + value: The value to resolve. Can be a string, dict, list, or any other type. + + Returns: + The resolved value with environment variable references replaced. + + Raises: + ValueError: If a referenced environment variable is not set. + """ + if isinstance(value, str): + match = _ENV_VAR_PATTERN.match(value) + if match: + var_name = match.group(1) or match.group(2) + env_value = os.environ.get(var_name) + if env_value is None: + raise ValueError(f"Environment variable '{var_name}' is not set") + return env_value + return value + if isinstance(value, dict): + return {k: _resolve_env_vars(v) for k, v in value.items()} + if isinstance(value, list): + return [_resolve_env_vars(item) for item in value] + return value + + +def _create_model_from_dict(model_config: dict[str, Any]) -> Any: + """Create a Model instance from a provider config dict. + + Routes the config to the appropriate model class based on the ``provider`` field, + then delegates to the class's ``from_dict`` method. All imports are lazy to avoid + requiring optional dependencies that are not installed. + + Args: + model_config: Dict containing at least a ``provider`` key and provider-specific params. + + Returns: + A configured Model instance for the specified provider. + + Raises: + ValueError: If the provider name is not recognized. + ImportError: If the provider's optional dependencies are not installed. + """ + config = model_config.copy() + provider = config.pop("provider") + + class_name = PROVIDER_MAP.get(provider) + if class_name is None: + supported = ", ".join(sorted(PROVIDER_MAP.keys())) + raise ValueError(f"Unknown model provider: '{provider}'. Supported providers: {supported}") + + from .. import models + + model_cls = getattr(models, class_name) + return model_cls.from_dict(config) + def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> Any: """Create an Agent from a configuration file or dictionary. @@ -83,6 +205,12 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A Create agent from dictionary: >>> config = {"model": "anthropic.claude-3-5-sonnet-20241022-v2:0", "tools": ["calculator"]} >>> agent = config_to_agent(config) + + Create agent with object model config: + >>> config = { + ... "model": {"provider": "openai", "model_id": "gpt-4o", "client_args": {"api_key": "$OPENAI_API_KEY"}} + ... } + >>> agent = config_to_agent(config) """ # Parse configuration if isinstance(config, str): @@ -114,11 +242,20 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A raise ValueError(f"Configuration validation error at {error_path}: {e.message}") from e # Prepare Agent constructor arguments - agent_kwargs = {} + agent_kwargs: dict[str, Any] = {} + + # Handle model field — string vs object format + model_value = config_dict.get("model") + if isinstance(model_value, dict): + # Object format: resolve env vars and create Model instance via factory + resolved_config = _resolve_env_vars(model_value) + agent_kwargs["model"] = _create_model_from_dict(resolved_config) + elif model_value is not None: + # String format (backward compat): pass directly as model_id to Agent + agent_kwargs["model"] = model_value - # Map configuration keys to Agent constructor parameters + # Map remaining configuration keys to Agent constructor parameters config_mapping = { - "model": "model", "prompt": "system_prompt", "tools": "tools", "name": "name", diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index bfb7b1ede..c3c246fc3 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -127,6 +127,32 @@ class BedrockConfig(TypedDict, total=False): temperature: float | None top_p: float | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "BedrockModel": + """Create a BedrockModel from a configuration dictionary. + + Handles extraction of ``region_name``, ``endpoint_url``, and conversion of + ``boto_client_config`` from a plain dict to ``botocore.config.Config``. + + Args: + config: Model configuration dictionary. + + Returns: + A configured BedrockModel instance. + """ + kwargs: dict[str, Any] = {} + + if "region_name" in config: + kwargs["region_name"] = config.pop("region_name") + if "endpoint_url" in config: + kwargs["endpoint_url"] = config.pop("endpoint_url") + if "boto_client_config" in config: + raw = config.pop("boto_client_config") + kwargs["boto_client_config"] = BotocoreConfig(**raw) if isinstance(raw, dict) else raw + + kwargs.update(config) + return cls(**kwargs) + def __init__( self, *, diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index c52509816..36da4ca03 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -131,6 +131,26 @@ class LlamaCppConfig(TypedDict, total=False): model_id: str params: dict[str, Any] | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "LlamaCppModel": + """Create a LlamaCppModel from a configuration dictionary. + + Handles extraction of ``base_url`` and ``timeout`` as separate constructor parameters. + + Args: + config: Model configuration dictionary. + + Returns: + A configured LlamaCppModel instance. + """ + kwargs: dict[str, Any] = {} + if "base_url" in config: + kwargs["base_url"] = config.pop("base_url") + if "timeout" in config: + kwargs["timeout"] = config.pop("timeout") + kwargs.update(config) + return cls(**kwargs) + def __init__( self, base_url: str = "http://localhost:8080", diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index f44a11d30..9bdb8ced9 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -53,6 +53,28 @@ class MistralConfig(TypedDict, total=False): top_p: float | None stream: bool | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "MistralModel": + """Create a MistralModel from a configuration dictionary. + + Handles extraction of ``api_key`` and ``client_args`` as separate constructor parameters. + + Args: + config: Model configuration dictionary. + + Returns: + A configured MistralModel instance. + """ + api_key = config.pop("api_key", None) + client_args = config.pop("client_args", None) + kwargs: dict[str, Any] = {} + if api_key is not None: + kwargs["api_key"] = api_key + if client_args is not None: + kwargs["client_args"] = client_args + kwargs.update(config) + return cls(**kwargs) + def __init__( self, api_key: str | None = None, diff --git a/src/strands/models/model.py b/src/strands/models/model.py index f084d24d5..0c3a5c7b8 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -1,5 +1,7 @@ """Abstract base class for Agent model providers.""" +from __future__ import annotations + import abc import logging from collections.abc import AsyncGenerator, AsyncIterable @@ -51,6 +53,27 @@ def stateful(self) -> bool: """ return False + @classmethod + def from_dict(cls, config: dict[str, Any]) -> Model: + """Create a Model instance from a configuration dictionary. + + The default implementation extracts ``client_args`` (if present) and passes + all remaining keys as keyword arguments to the constructor. Subclasses with + non-standard constructor signatures should override this method. + + Args: + config: Provider-specific configuration dictionary. + + Returns: + A configured Model instance. + """ + client_args = config.pop("client_args", None) + kwargs: dict[str, Any] = {} + if client_args is not None: + kwargs["client_args"] = client_args + kwargs.update(config) + return cls(**kwargs) + @abc.abstractmethod # pragma: no cover def update_config(self, **model_config: Any) -> None: diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 97cb7948a..37b7090b1 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -56,6 +56,27 @@ class OllamaConfig(TypedDict, total=False): temperature: float | None top_p: float | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "OllamaModel": + """Create an OllamaModel from a configuration dictionary. + + Handles extraction of ``host`` as a positional argument and mapping of + ``client_args`` to the ``ollama_client_args`` constructor parameter. + + Args: + config: Model configuration dictionary. + + Returns: + A configured OllamaModel instance. + """ + host = config.pop("host", None) + client_args = config.pop("client_args", None) + kwargs: dict[str, Any] = {} + if client_args is not None: + kwargs["ollama_client_args"] = client_args + kwargs.update(config) + return cls(host, **kwargs) + def __init__( self, host: str | None, diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 775969290..424bac85f 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -133,6 +133,27 @@ class SageMakerAIEndpointConfig(TypedDict, total=False): target_variant: str | None | None additional_args: dict[str, Any] | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "SageMakerAIModel": + """Create a SageMakerAIModel from a configuration dictionary. + + Handles extraction of ``endpoint_config``, ``payload_config``, and conversion of + ``boto_client_config`` from a plain dict to ``botocore.config.Config``. + + Args: + config: Model configuration dictionary. + + Returns: + A configured SageMakerAIModel instance. + """ + kwargs: dict[str, Any] = {} + kwargs["endpoint_config"] = config.pop("endpoint_config", {}) + kwargs["payload_config"] = config.pop("payload_config", {}) + if "boto_client_config" in config: + raw = config.pop("boto_client_config") + kwargs["boto_client_config"] = BotocoreConfig(**raw) if isinstance(raw, dict) else raw + return cls(**kwargs) + def __init__( self, endpoint_config: SageMakerAIEndpointConfig, diff --git a/tests/strands/experimental/test_agent_config.py b/tests/strands/experimental/test_agent_config.py index e6188079b..e60a24b94 100644 --- a/tests/strands/experimental/test_agent_config.py +++ b/tests/strands/experimental/test_agent_config.py @@ -3,10 +3,21 @@ import json import os import tempfile +from typing import Any +from unittest.mock import MagicMock, patch import pytest from strands.experimental import config_to_agent +from strands.experimental.agent_config import ( + PROVIDER_MAP, + _create_model_from_dict, + _resolve_env_vars, +) + +# ============================================================================= +# Backward compatibility tests (existing) +# ============================================================================= def test_config_to_agent_with_dict(): @@ -170,3 +181,541 @@ def test_config_to_agent_with_tool(): config = {"model": "test-model", "tools": ["tests.fixtures.say_tool:say"]} agent = config_to_agent(config) assert "say" in agent.tool_names + + +# ============================================================================= +# Environment variable resolution tests +# ============================================================================= + + +class TestResolveEnvVars: + """Tests for the _resolve_env_vars utility function.""" + + def test_resolve_dollar_prefix(self): + """Test resolving $VAR_NAME format.""" + with patch.dict(os.environ, {"MY_API_KEY": "secret123"}): + assert _resolve_env_vars("$MY_API_KEY") == "secret123" + + def test_resolve_braced_format(self): + """Test resolving ${VAR_NAME} format.""" + with patch.dict(os.environ, {"MY_API_KEY": "secret456"}): + assert _resolve_env_vars("${MY_API_KEY}") == "secret456" + + def test_resolve_nested_dict(self): + """Test recursive resolution in nested dicts.""" + with patch.dict(os.environ, {"KEY1": "val1", "KEY2": "val2"}): + data = {"outer": {"inner": "$KEY1"}, "flat": "${KEY2}"} + result = _resolve_env_vars(data) + assert result == {"outer": {"inner": "val1"}, "flat": "val2"} + + def test_resolve_list(self): + """Test recursive resolution in lists.""" + with patch.dict(os.environ, {"KEY1": "val1", "KEY2": "val2"}): + data = ["$KEY1", "${KEY2}", "literal"] + result = _resolve_env_vars(data) + assert result == ["val1", "val2", "literal"] + + def test_missing_env_var_raises(self): + """Test that missing env vars raise ValueError.""" + with patch.dict(os.environ, {}, clear=True): + # Ensure the var is not set + os.environ.pop("NONEXISTENT_VAR", None) + with pytest.raises(ValueError, match="Environment variable 'NONEXISTENT_VAR' is not set"): + _resolve_env_vars("$NONEXISTENT_VAR") + + def test_missing_braced_env_var_raises(self): + """Test that missing braced env vars raise ValueError.""" + with patch.dict(os.environ, {}, clear=True): + os.environ.pop("NONEXISTENT_VAR", None) + with pytest.raises(ValueError, match="Environment variable 'NONEXISTENT_VAR' is not set"): + _resolve_env_vars("${NONEXISTENT_VAR}") + + def test_non_env_string_unchanged(self): + """Test that regular strings are returned unchanged.""" + assert _resolve_env_vars("just-a-string") == "just-a-string" + + def test_non_string_values_unchanged(self): + """Test that non-string values pass through unchanged.""" + assert _resolve_env_vars(42) == 42 + assert _resolve_env_vars(True) is True + assert _resolve_env_vars(3.14) == 3.14 + assert _resolve_env_vars(None) is None + + def test_deeply_nested_resolution(self): + """Test env var resolution in deeply nested structures.""" + with patch.dict(os.environ, {"DEEP_VAL": "found"}): + data = {"a": {"b": {"c": [{"d": "$DEEP_VAL"}]}}} + result = _resolve_env_vars(data) + assert result == {"a": {"b": {"c": [{"d": "found"}]}}} + + +# ============================================================================= +# Schema validation tests — dual-format model field +# ============================================================================= + + +class TestSchemaValidation: + """Tests for the updated AGENT_CONFIG_SCHEMA that supports both string and object model formats.""" + + def test_string_model_valid(self): + """Test that string model format still passes validation.""" + config = {"model": "us.anthropic.claude-sonnet-4-20250514-v1:0"} + agent = config_to_agent(config) + assert agent.model.config["model_id"] == "us.anthropic.claude-sonnet-4-20250514-v1:0" + + def test_object_model_valid(self): + """Test that object model format passes schema validation.""" + mock_model = MagicMock() + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + config = { + "model": { + "provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + "max_tokens": 10000, + } + } + agent = config_to_agent(config) + assert agent.model is mock_model + + def test_object_model_missing_provider_raises(self): + """Test that object model without provider raises validation error.""" + config = {"model": {"model_id": "some-model"}} + with pytest.raises(ValueError, match="Configuration validation error"): + config_to_agent(config) + + def test_object_model_allows_additional_properties(self): + """Test that object model format allows provider-specific properties.""" + mock_model = MagicMock() + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + config = { + "model": { + "provider": "openai", + "model_id": "gpt-4o", + "client_args": {"api_key": "test"}, + "custom_field": "allowed", + } + } + # Should not raise + config_to_agent(config) + + def test_null_model_still_valid(self): + """Test that null model is still accepted for default behavior.""" + config = {"model": None} + agent = config_to_agent(config) + # Should use default model + assert agent is not None + + def test_model_wrong_type_raises(self): + """Test that model field with invalid type raises validation error.""" + config = {"model": 12345} + with pytest.raises(ValueError, match="Configuration validation error"): + config_to_agent(config) + + def test_object_model_from_file(self): + """Test object model format loaded from a JSON file.""" + mock_model = MagicMock() + config_data = { + "model": { + "provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + } + } + temp_path = "" + try: + with tempfile.NamedTemporaryFile(mode="w+", suffix=".json", delete=False) as f: + json.dump(config_data, f) + f.flush() + temp_path = f.name + + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + agent = config_to_agent(temp_path) + assert agent.model is mock_model + finally: + if os.path.exists(temp_path): + os.remove(temp_path) + + +# ============================================================================= +# Provider factory tests — all 12 providers +# ============================================================================= + + +class TestProviderMap: + """Test that all 12 providers are registered in PROVIDER_MAP.""" + + EXPECTED_PROVIDERS = [ + "bedrock", + "anthropic", + "openai", + "gemini", + "ollama", + "litellm", + "mistral", + "llamaapi", + "llamacpp", + "sagemaker", + "writer", + "openai_responses", + ] + + def test_all_providers_registered(self): + """Test that all 12 providers are in PROVIDER_MAP.""" + for provider in self.EXPECTED_PROVIDERS: + assert provider in PROVIDER_MAP, f"Provider '{provider}' not found in PROVIDER_MAP" + + def test_no_extra_providers(self): + """Test that only the expected 12 providers are registered.""" + assert set(PROVIDER_MAP.keys()) == set(self.EXPECTED_PROVIDERS) + + +class TestCreateModelFromConfig: + """Tests for _create_model_from_dict dispatching to cls.from_dict.""" + + def test_unknown_provider_raises(self): + """Test that an unknown provider name raises ValueError.""" + with pytest.raises(ValueError, match="Unknown model provider: 'nonexistent'"): + _create_model_from_dict({"provider": "nonexistent", "model_id": "x"}) + + def _patch_model_class(self, class_name): + """Patch a model class on the strands.models module and return the mock.""" + mock_cls = MagicMock() + mock_cls.from_dict.return_value = MagicMock() + return patch(f"strands.models.{class_name}", mock_cls, create=True), mock_cls + + def test_dispatches_to_from_dict(self): + """Test that _create_model_from_dict calls cls.from_dict on the resolved model class.""" + mock_model = MagicMock() + mock_cls = MagicMock() + mock_cls.from_dict.return_value = mock_model + + with patch("strands.models.AnthropicModel", mock_cls, create=True): + result = _create_model_from_dict( + { + "provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + "max_tokens": 8192, + "client_args": {"api_key": "test-key"}, + } + ) + mock_cls.from_dict.assert_called_once() + call_config = mock_cls.from_dict.call_args[0][0] + assert call_config["model_id"] == "claude-sonnet-4-20250514" + assert call_config["max_tokens"] == 8192 + assert call_config["client_args"] == {"api_key": "test-key"} + assert "provider" not in call_config + assert result is mock_model + + def test_does_not_mutate_input(self): + """Test that _create_model_from_dict does not mutate the input dict.""" + original = {"provider": "anthropic", "model_id": "test"} + original_copy = original.copy() + + mock_cls = MagicMock() + mock_cls.from_dict.return_value = MagicMock() + with patch("strands.models.AnthropicModel", mock_cls, create=True): + _create_model_from_dict(original) + + assert original == original_copy + + @pytest.mark.parametrize( + "provider,class_name", + list(PROVIDER_MAP.items()), + ) + def test_all_providers_dispatch(self, provider, class_name): + """Test that each registered provider dispatches to the correct class.""" + patcher, mock_cls = self._patch_model_class(class_name) + with patcher: + _create_model_from_dict({"provider": provider, "model_id": "test"}) + mock_cls.from_dict.assert_called_once() + + +# ============================================================================= +# Model from_dict tests — provider-specific parameter handling +# ============================================================================= + + +class TestModelFromConfig: + """Tests for from_dict on model classes with non-standard constructors. + + Patches __init__ on each model class to capture the arguments passed by from_dict + without actually initializing the model (which would require real provider dependencies). + """ + + def test_bedrock_from_dict_boto_client_config_conversion(self): + """Test that BedrockModel.from_dict converts boto_client_config dict to BotocoreConfig.""" + from botocore.config import Config as BotocoreConfig + + from strands.models.bedrock import BedrockModel + + with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + BedrockModel.from_dict( + { + "model_id": "test-model", + "region_name": "us-west-2", + "boto_client_config": {"read_timeout": 300}, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["region_name"] == "us-west-2" + assert isinstance(call_kwargs["boto_client_config"], BotocoreConfig) + assert call_kwargs["model_id"] == "test-model" + + def test_bedrock_from_dict_without_boto_client_config(self): + """Test BedrockModel.from_dict without boto_client_config.""" + from strands.models.bedrock import BedrockModel + + with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + BedrockModel.from_dict( + { + "model_id": "test-model", + "region_name": "us-east-1", + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["region_name"] == "us-east-1" + assert "boto_client_config" not in call_kwargs + + def test_bedrock_from_dict_endpoint_url(self): + """Test BedrockModel.from_dict with endpoint_url.""" + from strands.models.bedrock import BedrockModel + + with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + BedrockModel.from_dict( + { + "model_id": "test-model", + "endpoint_url": "https://vpce-1234.bedrock-runtime.us-west-2.vpce.amazonaws.com", + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["endpoint_url"] == "https://vpce-1234.bedrock-runtime.us-west-2.vpce.amazonaws.com" + + def test_ollama_from_dict_host_and_client_args_mapping(self): + """Test that OllamaModel.from_dict routes host and maps client_args to ollama_client_args.""" + from strands.models.ollama import OllamaModel + + with patch.object(OllamaModel, "__init__", return_value=None) as mock_init: + OllamaModel.from_dict( + { + "model_id": "llama3", + "host": "http://localhost:11434", + "client_args": {"timeout": 30}, + } + ) + call_args = mock_init.call_args + assert call_args[0][0] == "http://localhost:11434" # host is positional + assert call_args[1]["ollama_client_args"] == {"timeout": 30} + assert call_args[1]["model_id"] == "llama3" + + def test_ollama_from_dict_default_host(self): + """Test OllamaModel.from_dict with no host specified defaults to None.""" + from strands.models.ollama import OllamaModel + + with patch.object(OllamaModel, "__init__", return_value=None) as mock_init: + OllamaModel.from_dict({"model_id": "llama3"}) + call_args = mock_init.call_args + assert call_args[0][0] is None # host defaults to None + + def test_mistral_from_dict_api_key_extraction(self): + """Test that MistralModel.from_dict extracts api_key separately.""" + from strands.models.mistral import MistralModel + + with patch.object(MistralModel, "__init__", return_value=None) as mock_init: + MistralModel.from_dict( + { + "model_id": "mistral-large-latest", + "api_key": "test-key", + "client_args": {"timeout": 60}, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["api_key"] == "test-key" + assert call_kwargs["client_args"] == {"timeout": 60} + assert call_kwargs["model_id"] == "mistral-large-latest" + + def test_llamacpp_from_dict_base_url_and_timeout(self): + """Test that LlamaCppModel.from_dict extracts base_url and timeout.""" + from strands.models.llamacpp import LlamaCppModel + + with patch.object(LlamaCppModel, "__init__", return_value=None) as mock_init: + LlamaCppModel.from_dict( + { + "model_id": "default", + "base_url": "http://myhost:8080", + "timeout": 30.0, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["base_url"] == "http://myhost:8080" + assert call_kwargs["timeout"] == 30.0 + assert call_kwargs["model_id"] == "default" + + def test_sagemaker_from_dict_dict_params(self): + """Test that SageMakerAIModel.from_dict receives endpoint_config and payload_config as dicts.""" + from strands.models.sagemaker import SageMakerAIModel + + with patch.object(SageMakerAIModel, "__init__", return_value=None) as mock_init: + SageMakerAIModel.from_dict( + { + "endpoint_config": {"endpoint_name": "my-ep", "region_name": "us-west-2"}, + "payload_config": {"max_tokens": 1024, "stream": True}, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["endpoint_config"] == {"endpoint_name": "my-ep", "region_name": "us-west-2"} + assert call_kwargs["payload_config"] == {"max_tokens": 1024, "stream": True} + + def test_sagemaker_from_dict_boto_client_config_conversion(self): + """Test that SageMakerAIModel.from_dict converts boto_client_config dict to BotocoreConfig.""" + from botocore.config import Config as BotocoreConfig + + from strands.models.sagemaker import SageMakerAIModel + + with patch.object(SageMakerAIModel, "__init__", return_value=None) as mock_init: + SageMakerAIModel.from_dict( + { + "endpoint_config": {"endpoint_name": "my-ep"}, + "payload_config": {"max_tokens": 1024}, + "boto_client_config": {"read_timeout": 300}, + } + ) + call_kwargs = mock_init.call_args[1] + assert isinstance(call_kwargs["boto_client_config"], BotocoreConfig) + + def test_default_from_dict_client_args_pattern(self): + """Test the default from_dict (inherited) handles client_args + remaining kwargs.""" + from strands.models.bedrock import BedrockModel + + with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + # BedrockModel overrides from_dict, so use AnthropicModel which inherits the default + from strands.models.anthropic import AnthropicModel + + with patch.object(AnthropicModel, "__init__", return_value=None) as mock_init: + AnthropicModel.from_dict( + { + "model_id": "claude-sonnet-4-20250514", + "max_tokens": 4096, + "client_args": {"api_key": "test"}, + "params": {"temperature": 0.5}, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["client_args"] == {"api_key": "test"} + assert call_kwargs["model_id"] == "claude-sonnet-4-20250514" + assert call_kwargs["max_tokens"] == 4096 + assert call_kwargs["params"] == {"temperature": 0.5} + + def test_default_from_dict_without_client_args(self): + """Test the default from_dict works without client_args.""" + from strands.models.anthropic import AnthropicModel + + with patch.object(AnthropicModel, "__init__", return_value=None) as mock_init: + AnthropicModel.from_dict({"model_id": "test-model", "max_tokens": 1024}) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["model_id"] == "test-model" + assert call_kwargs["max_tokens"] == 1024 + assert "client_args" not in call_kwargs + + +# ============================================================================= +# Error handling tests +# ============================================================================= + + +class TestErrorHandling: + """Tests for error handling in model creation.""" + + def test_missing_optional_dependency(self): + """Test clear error when provider dependency is not installed.""" + mock_cls = MagicMock() + mock_cls.from_dict.side_effect = ImportError("No module named 'anthropic'") + + with patch("strands.models.AnthropicModel", mock_cls, create=True): + with pytest.raises(ImportError, match="anthropic"): + _create_model_from_dict( + { + "provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + } + ) + + def test_unknown_provider_error_message(self): + """Test that unknown provider gives helpful error message.""" + with pytest.raises(ValueError, match="Unknown model provider: 'my_custom_provider'"): + _create_model_from_dict({"provider": "my_custom_provider"}) + + +# ============================================================================= +# Integration: config_to_agent with object model +# ============================================================================= + + +class TestConfigToAgentObjectModel: + """Tests for config_to_agent using the object model format end-to-end.""" + + def test_object_model_creates_agent(self): + """Test that object model config creates an agent with the correct model.""" + mock_model = MagicMock() + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + config = { + "model": { + "provider": "openai", + "model_id": "gpt-4o", + }, + "prompt": "You are helpful", + } + agent = config_to_agent(config) + assert agent.model is mock_model + assert agent.system_prompt == "You are helpful" + + def test_object_model_env_var_resolution(self): + """Test that env vars are resolved in object model config before provider creation.""" + mock_model = MagicMock() + with patch.dict(os.environ, {"TEST_API_KEY": "resolved-key"}): + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ) as mock_create: + config = { + "model": { + "provider": "openai", + "model_id": "gpt-4o", + "client_args": {"api_key": "$TEST_API_KEY"}, + } + } + config_to_agent(config) + # Verify the env var was resolved before passing to the factory + call_args = mock_create.call_args[0][0] + assert call_args["client_args"]["api_key"] == "resolved-key" + + def test_string_model_backward_compat(self): + """Test that string model still works as Bedrock model_id.""" + config = {"model": "us.anthropic.claude-sonnet-4-20250514-v1:0"} + agent = config_to_agent(config) + # String model is passed directly to Agent, which interprets it as Bedrock model_id + assert agent.model.config["model_id"] == "us.anthropic.claude-sonnet-4-20250514-v1:0" + + def test_object_model_with_kwargs_override(self): + """Test that kwargs can still override when using object model.""" + mock_model = MagicMock() + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + config = { + "model": {"provider": "openai", "model_id": "gpt-4o"}, + "prompt": "Original prompt", + } + agent = config_to_agent(config, system_prompt="Override prompt") + assert agent.system_prompt == "Override prompt" From 0773285b8b0a0345a137503add2ec08fa62258ff Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Mon, 13 Apr 2026 13:09:47 -0400 Subject: [PATCH 2/4] fix: address review feedback for multi-provider agent config - Fix **kwargs type annotation (dict[str, Any] -> Any) in config_to_agent - Add defensive copy in all from_dict methods to avoid mutating caller's dict - Raise ValueError on unsupported config keys in SageMaker from_dict - Improve _create_model_from_dict return type to Model - Document env var pattern full-string-only matching --- src/strands/experimental/agent_config.py | 14 ++++++++++---- src/strands/models/bedrock.py | 4 +++- src/strands/models/llamacpp.py | 4 +++- src/strands/models/mistral.py | 4 +++- src/strands/models/model.py | 4 +++- src/strands/models/ollama.py | 4 +++- src/strands/models/sagemaker.py | 7 ++++++- 7 files changed, 31 insertions(+), 10 deletions(-) diff --git a/src/strands/experimental/agent_config.py b/src/strands/experimental/agent_config.py index 26510e6c6..256aa764b 100644 --- a/src/strands/experimental/agent_config.py +++ b/src/strands/experimental/agent_config.py @@ -33,15 +33,20 @@ ``gemini_tools`` (Gemini). Use ``region_name`` / ``client_args`` as JSON-friendly alternatives. """ +from __future__ import annotations + import json import os import re from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import jsonschema from jsonschema import ValidationError +if TYPE_CHECKING: + from ..models.model import Model + # JSON Schema for agent configuration AGENT_CONFIG_SCHEMA = { "$schema": "http://json-schema.org/draft-07/schema#", @@ -91,7 +96,8 @@ # Pre-compile validator for better performance _VALIDATOR = jsonschema.Draft7Validator(AGENT_CONFIG_SCHEMA) -# Pattern for matching environment variable references +# Only full-string env var references are resolved (no inline interpolation). +# "prefix-$VAR" is NOT resolved; construct values programmatically instead. _ENV_VAR_PATTERN = re.compile(r"^\$\{([^}]+)\}$|^\$([A-Za-z_][A-Za-z0-9_]*)$") # Provider name to model class name — resolved via strands.models lazy __getattr__ @@ -142,7 +148,7 @@ def _resolve_env_vars(value: Any) -> Any: return value -def _create_model_from_dict(model_config: dict[str, Any]) -> Any: +def _create_model_from_dict(model_config: dict[str, Any]) -> "Model": """Create a Model instance from a provider config dict. Routes the config to the appropriate model class based on the ``provider`` field, @@ -173,7 +179,7 @@ def _create_model_from_dict(model_config: dict[str, Any]) -> Any: return model_cls.from_dict(config) -def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> Any: +def config_to_agent(config: str | dict[str, Any], **kwargs: Any) -> Any: """Create an Agent from a configuration file or dictionary. This function supports tools that can be loaded declaratively (file paths, module names, diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index c3c246fc3..bc906ba1a 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -135,11 +135,13 @@ def from_dict(cls, config: dict[str, Any]) -> "BedrockModel": ``boto_client_config`` from a plain dict to ``botocore.config.Config``. Args: - config: Model configuration dictionary. + config: Model configuration dictionary. A copy is made internally; + the caller's dict is not modified. Returns: A configured BedrockModel instance. """ + config = config.copy() kwargs: dict[str, Any] = {} if "region_name" in config: diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index 36da4ca03..823a41f36 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -138,11 +138,13 @@ def from_dict(cls, config: dict[str, Any]) -> "LlamaCppModel": Handles extraction of ``base_url`` and ``timeout`` as separate constructor parameters. Args: - config: Model configuration dictionary. + config: Model configuration dictionary. A copy is made internally; + the caller's dict is not modified. Returns: A configured LlamaCppModel instance. """ + config = config.copy() kwargs: dict[str, Any] = {} if "base_url" in config: kwargs["base_url"] = config.pop("base_url") diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 9bdb8ced9..2203ea29a 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -60,11 +60,13 @@ def from_dict(cls, config: dict[str, Any]) -> "MistralModel": Handles extraction of ``api_key`` and ``client_args`` as separate constructor parameters. Args: - config: Model configuration dictionary. + config: Model configuration dictionary. A copy is made internally; + the caller's dict is not modified. Returns: A configured MistralModel instance. """ + config = config.copy() api_key = config.pop("api_key", None) client_args = config.pop("client_args", None) kwargs: dict[str, Any] = {} diff --git a/src/strands/models/model.py b/src/strands/models/model.py index 0c3a5c7b8..cfd95cb3c 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -62,11 +62,13 @@ def from_dict(cls, config: dict[str, Any]) -> Model: non-standard constructor signatures should override this method. Args: - config: Provider-specific configuration dictionary. + config: Provider-specific configuration dictionary. A copy is made internally; + the caller's dict is not modified. Returns: A configured Model instance. """ + config = config.copy() client_args = config.pop("client_args", None) kwargs: dict[str, Any] = {} if client_args is not None: diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 37b7090b1..41217dd47 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -64,11 +64,13 @@ def from_dict(cls, config: dict[str, Any]) -> "OllamaModel": ``client_args`` to the ``ollama_client_args`` constructor parameter. Args: - config: Model configuration dictionary. + config: Model configuration dictionary. A copy is made internally; + the caller's dict is not modified. Returns: A configured OllamaModel instance. """ + config = config.copy() host = config.pop("host", None) client_args = config.pop("client_args", None) kwargs: dict[str, Any] = {} diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 424bac85f..9d3178650 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -141,17 +141,22 @@ def from_dict(cls, config: dict[str, Any]) -> "SageMakerAIModel": ``boto_client_config`` from a plain dict to ``botocore.config.Config``. Args: - config: Model configuration dictionary. + config: Model configuration dictionary. A copy is made internally; + the caller's dict is not modified. Returns: A configured SageMakerAIModel instance. """ + config = config.copy() kwargs: dict[str, Any] = {} kwargs["endpoint_config"] = config.pop("endpoint_config", {}) kwargs["payload_config"] = config.pop("payload_config", {}) if "boto_client_config" in config: raw = config.pop("boto_client_config") kwargs["boto_client_config"] = BotocoreConfig(**raw) if isinstance(raw, dict) else raw + if config: + unexpected = ", ".join(sorted(config.keys())) + raise ValueError(f"Unsupported SageMaker config keys: {unexpected}") return cls(**kwargs) def __init__( From 407de2de304ff044db64b5ad4000990eadc06dbd Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Wed, 15 Apr 2026 19:58:18 +0000 Subject: [PATCH 3/4] refactor: address PR review feedback - Remove env var resolution per maintainer request (Unshure) - Move from_dict tests to respective model test files per project test mirroring convention (test_bedrock, test_ollama, test_mistral, test_llamacpp, test_sagemaker, test_model) - Add missing test for SageMaker unexpected config keys rejection - Remove dead code (unused BedrockModel import/patch block) - Fix test mocking to avoid triggering lazy imports of optional deps - Fix mypy type annotation on _create_model_from_dict return type --- src/strands/experimental/agent_config.py | 53 +-- .../strands/experimental/test_agent_config.py | 335 +++--------------- tests/strands/models/test_bedrock.py | 44 +++ tests/strands/models/test_llamacpp.py | 21 ++ tests/strands/models/test_mistral.py | 19 + tests/strands/models/test_model.py | 43 +++ tests/strands/models/test_ollama.py | 26 ++ tests/strands/models/test_sagemaker.py | 41 +++ 8 files changed, 242 insertions(+), 340 deletions(-) diff --git a/src/strands/experimental/agent_config.py b/src/strands/experimental/agent_config.py index 256aa764b..6a4e727ec 100644 --- a/src/strands/experimental/agent_config.py +++ b/src/strands/experimental/agent_config.py @@ -21,13 +21,10 @@ "provider": "anthropic", "model_id": "claude-sonnet-4-20250514", "max_tokens": 10000, - "client_args": {"api_key": "$ANTHROPIC_API_KEY"} + "client_args": {"api_key": "..."} } } -Environment variable references (``$VAR`` or ``${VAR}``) in model config values are resolved -automatically before provider instantiation. - Note: The following constructor parameters cannot be specified from JSON because they require code-based instantiation: ``boto_session`` (Bedrock, SageMaker), ``client`` (OpenAI, Gemini), ``gemini_tools`` (Gemini). Use ``region_name`` / ``client_args`` as JSON-friendly alternatives. @@ -36,8 +33,6 @@ from __future__ import annotations import json -import os -import re from pathlib import Path from typing import TYPE_CHECKING, Any @@ -96,10 +91,6 @@ # Pre-compile validator for better performance _VALIDATOR = jsonschema.Draft7Validator(AGENT_CONFIG_SCHEMA) -# Only full-string env var references are resolved (no inline interpolation). -# "prefix-$VAR" is NOT resolved; construct values programmatically instead. -_ENV_VAR_PATTERN = re.compile(r"^\$\{([^}]+)\}$|^\$([A-Za-z_][A-Za-z0-9_]*)$") - # Provider name to model class name — resolved via strands.models lazy __getattr__ PROVIDER_MAP: dict[str, str] = { "bedrock": "BedrockModel", @@ -117,38 +108,7 @@ } -def _resolve_env_vars(value: Any) -> Any: - """Recursively resolve environment variable references in config values. - - String values matching ``$VAR_NAME`` or ``${VAR_NAME}`` are replaced with the - corresponding environment variable value. Dicts and lists are traversed recursively. - - Args: - value: The value to resolve. Can be a string, dict, list, or any other type. - - Returns: - The resolved value with environment variable references replaced. - - Raises: - ValueError: If a referenced environment variable is not set. - """ - if isinstance(value, str): - match = _ENV_VAR_PATTERN.match(value) - if match: - var_name = match.group(1) or match.group(2) - env_value = os.environ.get(var_name) - if env_value is None: - raise ValueError(f"Environment variable '{var_name}' is not set") - return env_value - return value - if isinstance(value, dict): - return {k: _resolve_env_vars(v) for k, v in value.items()} - if isinstance(value, list): - return [_resolve_env_vars(item) for item in value] - return value - - -def _create_model_from_dict(model_config: dict[str, Any]) -> "Model": +def _create_model_from_dict(model_config: dict[str, Any]) -> Model: """Create a Model instance from a provider config dict. Routes the config to the appropriate model class based on the ``provider`` field, @@ -175,7 +135,7 @@ def _create_model_from_dict(model_config: dict[str, Any]) -> "Model": from .. import models - model_cls = getattr(models, class_name) + model_cls: type[Model] = getattr(models, class_name) return model_cls.from_dict(config) @@ -214,7 +174,7 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: Any) -> Any: Create agent with object model config: >>> config = { - ... "model": {"provider": "openai", "model_id": "gpt-4o", "client_args": {"api_key": "$OPENAI_API_KEY"}} + ... "model": {"provider": "openai", "model_id": "gpt-4o", "client_args": {"api_key": "..."}} ... } >>> agent = config_to_agent(config) """ @@ -253,9 +213,8 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: Any) -> Any: # Handle model field — string vs object format model_value = config_dict.get("model") if isinstance(model_value, dict): - # Object format: resolve env vars and create Model instance via factory - resolved_config = _resolve_env_vars(model_value) - agent_kwargs["model"] = _create_model_from_dict(resolved_config) + # Object format: create Model instance via factory + agent_kwargs["model"] = _create_model_from_dict(model_value) elif model_value is not None: # String format (backward compat): pass directly as model_id to Agent agent_kwargs["model"] = model_value diff --git a/tests/strands/experimental/test_agent_config.py b/tests/strands/experimental/test_agent_config.py index e60a24b94..2ccfbcb1a 100644 --- a/tests/strands/experimental/test_agent_config.py +++ b/tests/strands/experimental/test_agent_config.py @@ -3,7 +3,6 @@ import json import os import tempfile -from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -12,7 +11,6 @@ from strands.experimental.agent_config import ( PROVIDER_MAP, _create_model_from_dict, - _resolve_env_vars, ) # ============================================================================= @@ -183,72 +181,6 @@ def test_config_to_agent_with_tool(): assert "say" in agent.tool_names -# ============================================================================= -# Environment variable resolution tests -# ============================================================================= - - -class TestResolveEnvVars: - """Tests for the _resolve_env_vars utility function.""" - - def test_resolve_dollar_prefix(self): - """Test resolving $VAR_NAME format.""" - with patch.dict(os.environ, {"MY_API_KEY": "secret123"}): - assert _resolve_env_vars("$MY_API_KEY") == "secret123" - - def test_resolve_braced_format(self): - """Test resolving ${VAR_NAME} format.""" - with patch.dict(os.environ, {"MY_API_KEY": "secret456"}): - assert _resolve_env_vars("${MY_API_KEY}") == "secret456" - - def test_resolve_nested_dict(self): - """Test recursive resolution in nested dicts.""" - with patch.dict(os.environ, {"KEY1": "val1", "KEY2": "val2"}): - data = {"outer": {"inner": "$KEY1"}, "flat": "${KEY2}"} - result = _resolve_env_vars(data) - assert result == {"outer": {"inner": "val1"}, "flat": "val2"} - - def test_resolve_list(self): - """Test recursive resolution in lists.""" - with patch.dict(os.environ, {"KEY1": "val1", "KEY2": "val2"}): - data = ["$KEY1", "${KEY2}", "literal"] - result = _resolve_env_vars(data) - assert result == ["val1", "val2", "literal"] - - def test_missing_env_var_raises(self): - """Test that missing env vars raise ValueError.""" - with patch.dict(os.environ, {}, clear=True): - # Ensure the var is not set - os.environ.pop("NONEXISTENT_VAR", None) - with pytest.raises(ValueError, match="Environment variable 'NONEXISTENT_VAR' is not set"): - _resolve_env_vars("$NONEXISTENT_VAR") - - def test_missing_braced_env_var_raises(self): - """Test that missing braced env vars raise ValueError.""" - with patch.dict(os.environ, {}, clear=True): - os.environ.pop("NONEXISTENT_VAR", None) - with pytest.raises(ValueError, match="Environment variable 'NONEXISTENT_VAR' is not set"): - _resolve_env_vars("${NONEXISTENT_VAR}") - - def test_non_env_string_unchanged(self): - """Test that regular strings are returned unchanged.""" - assert _resolve_env_vars("just-a-string") == "just-a-string" - - def test_non_string_values_unchanged(self): - """Test that non-string values pass through unchanged.""" - assert _resolve_env_vars(42) == 42 - assert _resolve_env_vars(True) is True - assert _resolve_env_vars(3.14) == 3.14 - assert _resolve_env_vars(None) is None - - def test_deeply_nested_resolution(self): - """Test env var resolution in deeply nested structures.""" - with patch.dict(os.environ, {"DEEP_VAL": "found"}): - data = {"a": {"b": {"c": [{"d": "$DEEP_VAL"}]}}} - result = _resolve_env_vars(data) - assert result == {"a": {"b": {"c": [{"d": "found"}]}}} - - # ============================================================================= # Schema validation tests — dual-format model field # ============================================================================= @@ -385,19 +317,31 @@ def test_unknown_provider_raises(self): with pytest.raises(ValueError, match="Unknown model provider: 'nonexistent'"): _create_model_from_dict({"provider": "nonexistent", "model_id": "x"}) - def _patch_model_class(self, class_name): - """Patch a model class on the strands.models module and return the mock.""" + def _set_mock_on_models(self, class_name): + """Inject a mock class directly into strands.models.__dict__ to avoid triggering lazy imports.""" + import strands.models as models_pkg + mock_cls = MagicMock() mock_cls.from_dict.return_value = MagicMock() - return patch(f"strands.models.{class_name}", mock_cls, create=True), mock_cls + original = models_pkg.__dict__.get(class_name) + models_pkg.__dict__[class_name] = mock_cls + return mock_cls, original + + def _restore_models(self, class_name, original): + """Restore original state of strands.models after test.""" + import strands.models as models_pkg + + if original is None: + models_pkg.__dict__.pop(class_name, None) + else: + models_pkg.__dict__[class_name] = original def test_dispatches_to_from_dict(self): """Test that _create_model_from_dict calls cls.from_dict on the resolved model class.""" + mock_cls, original = self._set_mock_on_models("AnthropicModel") mock_model = MagicMock() - mock_cls = MagicMock() mock_cls.from_dict.return_value = mock_model - - with patch("strands.models.AnthropicModel", mock_cls, create=True): + try: result = _create_model_from_dict( { "provider": "anthropic", @@ -413,18 +357,19 @@ def test_dispatches_to_from_dict(self): assert call_config["client_args"] == {"api_key": "test-key"} assert "provider" not in call_config assert result is mock_model + finally: + self._restore_models("AnthropicModel", original) def test_does_not_mutate_input(self): """Test that _create_model_from_dict does not mutate the input dict.""" - original = {"provider": "anthropic", "model_id": "test"} - original_copy = original.copy() - - mock_cls = MagicMock() - mock_cls.from_dict.return_value = MagicMock() - with patch("strands.models.AnthropicModel", mock_cls, create=True): - _create_model_from_dict(original) - - assert original == original_copy + mock_cls, original = self._set_mock_on_models("AnthropicModel") + try: + original_input = {"provider": "anthropic", "model_id": "test"} + original_copy = original_input.copy() + _create_model_from_dict(original_input) + assert original_input == original_copy + finally: + self._restore_models("AnthropicModel", original) @pytest.mark.parametrize( "provider,class_name", @@ -432,197 +377,12 @@ def test_does_not_mutate_input(self): ) def test_all_providers_dispatch(self, provider, class_name): """Test that each registered provider dispatches to the correct class.""" - patcher, mock_cls = self._patch_model_class(class_name) - with patcher: + mock_cls, original = self._set_mock_on_models(class_name) + try: _create_model_from_dict({"provider": provider, "model_id": "test"}) mock_cls.from_dict.assert_called_once() - - -# ============================================================================= -# Model from_dict tests — provider-specific parameter handling -# ============================================================================= - - -class TestModelFromConfig: - """Tests for from_dict on model classes with non-standard constructors. - - Patches __init__ on each model class to capture the arguments passed by from_dict - without actually initializing the model (which would require real provider dependencies). - """ - - def test_bedrock_from_dict_boto_client_config_conversion(self): - """Test that BedrockModel.from_dict converts boto_client_config dict to BotocoreConfig.""" - from botocore.config import Config as BotocoreConfig - - from strands.models.bedrock import BedrockModel - - with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: - BedrockModel.from_dict( - { - "model_id": "test-model", - "region_name": "us-west-2", - "boto_client_config": {"read_timeout": 300}, - } - ) - call_kwargs = mock_init.call_args[1] - assert call_kwargs["region_name"] == "us-west-2" - assert isinstance(call_kwargs["boto_client_config"], BotocoreConfig) - assert call_kwargs["model_id"] == "test-model" - - def test_bedrock_from_dict_without_boto_client_config(self): - """Test BedrockModel.from_dict without boto_client_config.""" - from strands.models.bedrock import BedrockModel - - with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: - BedrockModel.from_dict( - { - "model_id": "test-model", - "region_name": "us-east-1", - } - ) - call_kwargs = mock_init.call_args[1] - assert call_kwargs["region_name"] == "us-east-1" - assert "boto_client_config" not in call_kwargs - - def test_bedrock_from_dict_endpoint_url(self): - """Test BedrockModel.from_dict with endpoint_url.""" - from strands.models.bedrock import BedrockModel - - with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: - BedrockModel.from_dict( - { - "model_id": "test-model", - "endpoint_url": "https://vpce-1234.bedrock-runtime.us-west-2.vpce.amazonaws.com", - } - ) - call_kwargs = mock_init.call_args[1] - assert call_kwargs["endpoint_url"] == "https://vpce-1234.bedrock-runtime.us-west-2.vpce.amazonaws.com" - - def test_ollama_from_dict_host_and_client_args_mapping(self): - """Test that OllamaModel.from_dict routes host and maps client_args to ollama_client_args.""" - from strands.models.ollama import OllamaModel - - with patch.object(OllamaModel, "__init__", return_value=None) as mock_init: - OllamaModel.from_dict( - { - "model_id": "llama3", - "host": "http://localhost:11434", - "client_args": {"timeout": 30}, - } - ) - call_args = mock_init.call_args - assert call_args[0][0] == "http://localhost:11434" # host is positional - assert call_args[1]["ollama_client_args"] == {"timeout": 30} - assert call_args[1]["model_id"] == "llama3" - - def test_ollama_from_dict_default_host(self): - """Test OllamaModel.from_dict with no host specified defaults to None.""" - from strands.models.ollama import OllamaModel - - with patch.object(OllamaModel, "__init__", return_value=None) as mock_init: - OllamaModel.from_dict({"model_id": "llama3"}) - call_args = mock_init.call_args - assert call_args[0][0] is None # host defaults to None - - def test_mistral_from_dict_api_key_extraction(self): - """Test that MistralModel.from_dict extracts api_key separately.""" - from strands.models.mistral import MistralModel - - with patch.object(MistralModel, "__init__", return_value=None) as mock_init: - MistralModel.from_dict( - { - "model_id": "mistral-large-latest", - "api_key": "test-key", - "client_args": {"timeout": 60}, - } - ) - call_kwargs = mock_init.call_args[1] - assert call_kwargs["api_key"] == "test-key" - assert call_kwargs["client_args"] == {"timeout": 60} - assert call_kwargs["model_id"] == "mistral-large-latest" - - def test_llamacpp_from_dict_base_url_and_timeout(self): - """Test that LlamaCppModel.from_dict extracts base_url and timeout.""" - from strands.models.llamacpp import LlamaCppModel - - with patch.object(LlamaCppModel, "__init__", return_value=None) as mock_init: - LlamaCppModel.from_dict( - { - "model_id": "default", - "base_url": "http://myhost:8080", - "timeout": 30.0, - } - ) - call_kwargs = mock_init.call_args[1] - assert call_kwargs["base_url"] == "http://myhost:8080" - assert call_kwargs["timeout"] == 30.0 - assert call_kwargs["model_id"] == "default" - - def test_sagemaker_from_dict_dict_params(self): - """Test that SageMakerAIModel.from_dict receives endpoint_config and payload_config as dicts.""" - from strands.models.sagemaker import SageMakerAIModel - - with patch.object(SageMakerAIModel, "__init__", return_value=None) as mock_init: - SageMakerAIModel.from_dict( - { - "endpoint_config": {"endpoint_name": "my-ep", "region_name": "us-west-2"}, - "payload_config": {"max_tokens": 1024, "stream": True}, - } - ) - call_kwargs = mock_init.call_args[1] - assert call_kwargs["endpoint_config"] == {"endpoint_name": "my-ep", "region_name": "us-west-2"} - assert call_kwargs["payload_config"] == {"max_tokens": 1024, "stream": True} - - def test_sagemaker_from_dict_boto_client_config_conversion(self): - """Test that SageMakerAIModel.from_dict converts boto_client_config dict to BotocoreConfig.""" - from botocore.config import Config as BotocoreConfig - - from strands.models.sagemaker import SageMakerAIModel - - with patch.object(SageMakerAIModel, "__init__", return_value=None) as mock_init: - SageMakerAIModel.from_dict( - { - "endpoint_config": {"endpoint_name": "my-ep"}, - "payload_config": {"max_tokens": 1024}, - "boto_client_config": {"read_timeout": 300}, - } - ) - call_kwargs = mock_init.call_args[1] - assert isinstance(call_kwargs["boto_client_config"], BotocoreConfig) - - def test_default_from_dict_client_args_pattern(self): - """Test the default from_dict (inherited) handles client_args + remaining kwargs.""" - from strands.models.bedrock import BedrockModel - - with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: - # BedrockModel overrides from_dict, so use AnthropicModel which inherits the default - from strands.models.anthropic import AnthropicModel - - with patch.object(AnthropicModel, "__init__", return_value=None) as mock_init: - AnthropicModel.from_dict( - { - "model_id": "claude-sonnet-4-20250514", - "max_tokens": 4096, - "client_args": {"api_key": "test"}, - "params": {"temperature": 0.5}, - } - ) - call_kwargs = mock_init.call_args[1] - assert call_kwargs["client_args"] == {"api_key": "test"} - assert call_kwargs["model_id"] == "claude-sonnet-4-20250514" - assert call_kwargs["max_tokens"] == 4096 - assert call_kwargs["params"] == {"temperature": 0.5} - - def test_default_from_dict_without_client_args(self): - """Test the default from_dict works without client_args.""" - from strands.models.anthropic import AnthropicModel - - with patch.object(AnthropicModel, "__init__", return_value=None) as mock_init: - AnthropicModel.from_dict({"model_id": "test-model", "max_tokens": 1024}) - call_kwargs = mock_init.call_args[1] - assert call_kwargs["model_id"] == "test-model" - assert call_kwargs["max_tokens"] == 1024 - assert "client_args" not in call_kwargs + finally: + self._restore_models(class_name, original) # ============================================================================= @@ -635,10 +395,14 @@ class TestErrorHandling: def test_missing_optional_dependency(self): """Test clear error when provider dependency is not installed.""" + import strands.models as models_pkg + mock_cls = MagicMock() mock_cls.from_dict.side_effect = ImportError("No module named 'anthropic'") - with patch("strands.models.AnthropicModel", mock_cls, create=True): + original = models_pkg.__dict__.get("AnthropicModel") + models_pkg.__dict__["AnthropicModel"] = mock_cls + try: with pytest.raises(ImportError, match="anthropic"): _create_model_from_dict( { @@ -646,6 +410,11 @@ def test_missing_optional_dependency(self): "model_id": "claude-sonnet-4-20250514", } ) + finally: + if original is None: + models_pkg.__dict__.pop("AnthropicModel", None) + else: + models_pkg.__dict__["AnthropicModel"] = original def test_unknown_provider_error_message(self): """Test that unknown provider gives helpful error message.""" @@ -679,26 +448,6 @@ def test_object_model_creates_agent(self): assert agent.model is mock_model assert agent.system_prompt == "You are helpful" - def test_object_model_env_var_resolution(self): - """Test that env vars are resolved in object model config before provider creation.""" - mock_model = MagicMock() - with patch.dict(os.environ, {"TEST_API_KEY": "resolved-key"}): - with patch( - "strands.experimental.agent_config._create_model_from_dict", - return_value=mock_model, - ) as mock_create: - config = { - "model": { - "provider": "openai", - "model_id": "gpt-4o", - "client_args": {"api_key": "$TEST_API_KEY"}, - } - } - config_to_agent(config) - # Verify the env var was resolved before passing to the factory - call_args = mock_create.call_args[0][0] - assert call_args["client_args"]["api_key"] == "resolved-key" - def test_string_model_backward_compat(self): """Test that string model still works as Bedrock model_id.""" config = {"model": "us.anthropic.claude-sonnet-4-20250514-v1:0"} diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index cd7016488..8b6a6ca18 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2964,3 +2964,47 @@ async def test_non_streaming_citations_with_only_location(bedrock_client, model, assert citation["location"] == {"web": {"url": "https://example.com", "domain": "example.com"}} assert "title" not in citation assert "sourceContent" not in citation + + +class TestBedrockFromDict: + """Tests for BedrockModel.from_dict classmethod.""" + + def test_from_dict_boto_client_config_conversion(self): + """Test that from_dict converts boto_client_config dict to BotocoreConfig.""" + with unittest.mock.patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + BedrockModel.from_dict( + { + "model_id": "test-model", + "region_name": "us-west-2", + "boto_client_config": {"read_timeout": 300}, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["region_name"] == "us-west-2" + assert isinstance(call_kwargs["boto_client_config"], BotocoreConfig) + assert call_kwargs["model_id"] == "test-model" + + def test_from_dict_without_boto_client_config(self): + """Test from_dict without boto_client_config.""" + with unittest.mock.patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + BedrockModel.from_dict( + { + "model_id": "test-model", + "region_name": "us-east-1", + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["region_name"] == "us-east-1" + assert "boto_client_config" not in call_kwargs + + def test_from_dict_endpoint_url(self): + """Test from_dict with endpoint_url.""" + with unittest.mock.patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + BedrockModel.from_dict( + { + "model_id": "test-model", + "endpoint_url": "https://vpce-1234.bedrock-runtime.us-west-2.vpce.amazonaws.com", + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["endpoint_url"] == "https://vpce-1234.bedrock-runtime.us-west-2.vpce.amazonaws.com" diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index 3e023dfce..e7a50e46c 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -706,3 +706,24 @@ def test_format_request_filters_location_source_document(caplog) -> None: assert len(user_content) == 1 assert user_content[0]["type"] == "text" assert "Location sources are not supported by llama.cpp" in caplog.text + + +class TestLlamaCppFromDict: + """Tests for LlamaCppModel.from_dict classmethod.""" + + def test_from_dict_base_url_and_timeout(self): + """Test that from_dict extracts base_url and timeout.""" + from strands.models.llamacpp import LlamaCppModel + + with patch.object(LlamaCppModel, "__init__", return_value=None) as mock_init: + LlamaCppModel.from_dict( + { + "model_id": "default", + "base_url": "http://myhost:8080", + "timeout": 30.0, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["base_url"] == "http://myhost:8080" + assert call_kwargs["timeout"] == 30.0 + assert call_kwargs["model_id"] == "default" diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index 57189748e..6bd8d657f 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -679,3 +679,22 @@ def test_format_request_filters_location_source_document(model, caplog): user_content = formatted_messages[0]["content"] assert user_content == "analyze this document" assert "Location sources are not supported by Mistral" in caplog.text + + +class TestMistralFromDict: + """Tests for MistralModel.from_dict classmethod.""" + + def test_from_dict_api_key_extraction(self): + """Test that from_dict extracts api_key separately.""" + with unittest.mock.patch.object(MistralModel, "__init__", return_value=None) as mock_init: + MistralModel.from_dict( + { + "model_id": "mistral-large-latest", + "api_key": "test-key", + "client_args": {"timeout": 60}, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["api_key"] == "test-key" + assert call_kwargs["client_args"] == {"timeout": 60} + assert call_kwargs["model_id"] == "mistral-large-latest" diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index 458e98645..37383ce44 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -213,3 +213,46 @@ def test_model_plugin_preserves_messages_when_not_stateful(model_plugin): model_plugin._on_after_invocation(event) assert len(agent.messages) == 1 + + +class TestModelFromDict: + """Tests for the default Model.from_dict classmethod.""" + + def test_from_dict_with_client_args(self): + """Test that from_dict extracts client_args and passes remaining kwargs.""" + from unittest.mock import patch + + from strands.models.bedrock import BedrockModel + + # Use BedrockModel to test from_dict invocation since it is always available; + # the base Model.from_dict is tested indirectly via the bedrock override path + # not executing (we only verify kwarg routing here). + with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + # Invoke the *base* Model.from_dict by calling it on a subclass that + # does NOT override from_dict. BedrockModel overrides it, so we call + # the base implementation directly for this test. + SAModel.from_dict.__func__( + BedrockModel, + { + "model_id": "test-model", + "client_args": {"api_key": "test"}, + "max_tokens": 4096, + }, + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["client_args"] == {"api_key": "test"} + assert call_kwargs["model_id"] == "test-model" + assert call_kwargs["max_tokens"] == 4096 + + def test_from_dict_without_client_args(self): + """Test that from_dict works without client_args.""" + from unittest.mock import patch + + from strands.models.bedrock import BedrockModel + + with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + SAModel.from_dict.__func__(BedrockModel, {"model_id": "test-model", "max_tokens": 1024}) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["model_id"] == "test-model" + assert call_kwargs["max_tokens"] == 1024 + assert "client_args" not in call_kwargs diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index 0d4fbb9e0..844f369d7 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -625,3 +625,29 @@ def test_format_request_filters_location_source_document(model, caplog): user_message = formatted_messages[0] assert user_message["content"] == "analyze this document" assert "Location sources are not supported by Ollama" in caplog.text + + +class TestOllamaFromDict: + """Tests for OllamaModel.from_dict classmethod.""" + + def test_from_dict_host_and_client_args_mapping(self): + """Test that from_dict routes host and maps client_args to ollama_client_args.""" + with unittest.mock.patch.object(OllamaModel, "__init__", return_value=None) as mock_init: + OllamaModel.from_dict( + { + "model_id": "llama3", + "host": "http://localhost:11434", + "client_args": {"timeout": 30}, + } + ) + call_args = mock_init.call_args + assert call_args[0][0] == "http://localhost:11434" + assert call_args[1]["ollama_client_args"] == {"timeout": 30} + assert call_args[1]["model_id"] == "llama3" + + def test_from_dict_default_host(self): + """Test from_dict with no host specified defaults to None.""" + with unittest.mock.patch.object(OllamaModel, "__init__", return_value=None) as mock_init: + OllamaModel.from_dict({"model_id": "llama3"}) + call_args = mock_init.call_args + assert call_args[0][0] is None diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index 5d6d6869a..1ce11f5bd 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -665,3 +665,44 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +class TestSageMakerFromDict: + """Tests for SageMakerAIModel.from_dict classmethod.""" + + def test_from_dict_dict_params(self): + """Test that from_dict receives endpoint_config and payload_config as dicts.""" + with unittest.mock.patch.object(SageMakerAIModel, "__init__", return_value=None) as mock_init: + SageMakerAIModel.from_dict( + { + "endpoint_config": {"endpoint_name": "my-ep", "region_name": "us-west-2"}, + "payload_config": {"max_tokens": 1024, "stream": True}, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["endpoint_config"] == {"endpoint_name": "my-ep", "region_name": "us-west-2"} + assert call_kwargs["payload_config"] == {"max_tokens": 1024, "stream": True} + + def test_from_dict_boto_client_config_conversion(self): + """Test that from_dict converts boto_client_config dict to BotocoreConfig.""" + with unittest.mock.patch.object(SageMakerAIModel, "__init__", return_value=None) as mock_init: + SageMakerAIModel.from_dict( + { + "endpoint_config": {"endpoint_name": "my-ep"}, + "payload_config": {"max_tokens": 1024}, + "boto_client_config": {"read_timeout": 300}, + } + ) + call_kwargs = mock_init.call_args[1] + assert isinstance(call_kwargs["boto_client_config"], BotocoreConfig) + + def test_from_dict_rejects_unexpected_keys(self): + """Test that from_dict raises ValueError on unexpected config keys.""" + with pytest.raises(ValueError, match="Unsupported SageMaker config keys"): + SageMakerAIModel.from_dict( + { + "endpoint_config": {}, + "payload_config": {}, + "model_id": "unexpected", + } + ) From b8a964c3a08c8fc080c54da639ac328ed5b2af9e Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Wed, 15 Apr 2026 20:33:08 +0000 Subject: [PATCH 4/4] refactor: remove low-value and duplicate tests - Remove TestProviderMap class (test_all_providers_dispatch already covers all provider dispatch correctness) - Remove test_string_model_backward_compat (duplicate of test_string_model_valid) - Remove test_unknown_provider_error_message (duplicate of test_unknown_provider_raises) --- .../strands/experimental/test_agent_config.py | 42 +------------------ 1 file changed, 1 insertion(+), 41 deletions(-) diff --git a/tests/strands/experimental/test_agent_config.py b/tests/strands/experimental/test_agent_config.py index 2ccfbcb1a..cd2b0fad3 100644 --- a/tests/strands/experimental/test_agent_config.py +++ b/tests/strands/experimental/test_agent_config.py @@ -277,38 +277,10 @@ def test_object_model_from_file(self): # ============================================================================= -# Provider factory tests — all 12 providers +# Provider factory tests # ============================================================================= -class TestProviderMap: - """Test that all 12 providers are registered in PROVIDER_MAP.""" - - EXPECTED_PROVIDERS = [ - "bedrock", - "anthropic", - "openai", - "gemini", - "ollama", - "litellm", - "mistral", - "llamaapi", - "llamacpp", - "sagemaker", - "writer", - "openai_responses", - ] - - def test_all_providers_registered(self): - """Test that all 12 providers are in PROVIDER_MAP.""" - for provider in self.EXPECTED_PROVIDERS: - assert provider in PROVIDER_MAP, f"Provider '{provider}' not found in PROVIDER_MAP" - - def test_no_extra_providers(self): - """Test that only the expected 12 providers are registered.""" - assert set(PROVIDER_MAP.keys()) == set(self.EXPECTED_PROVIDERS) - - class TestCreateModelFromConfig: """Tests for _create_model_from_dict dispatching to cls.from_dict.""" @@ -416,11 +388,6 @@ def test_missing_optional_dependency(self): else: models_pkg.__dict__["AnthropicModel"] = original - def test_unknown_provider_error_message(self): - """Test that unknown provider gives helpful error message.""" - with pytest.raises(ValueError, match="Unknown model provider: 'my_custom_provider'"): - _create_model_from_dict({"provider": "my_custom_provider"}) - # ============================================================================= # Integration: config_to_agent with object model @@ -448,13 +415,6 @@ def test_object_model_creates_agent(self): assert agent.model is mock_model assert agent.system_prompt == "You are helpful" - def test_string_model_backward_compat(self): - """Test that string model still works as Bedrock model_id.""" - config = {"model": "us.anthropic.claude-sonnet-4-20250514-v1:0"} - agent = config_to_agent(config) - # String model is passed directly to Agent, which interprets it as Bedrock model_id - assert agent.model.config["model_id"] == "us.anthropic.claude-sonnet-4-20250514-v1:0" - def test_object_model_with_kwargs_override(self): """Test that kwargs can still override when using object model.""" mock_model = MagicMock()