diff --git a/src/bedrock_agentcore/memory/client.py b/src/bedrock_agentcore/memory/client.py index 2e4ca3e4..421dfcff 100644 --- a/src/bedrock_agentcore/memory/client.py +++ b/src/bedrock_agentcore/memory/client.py @@ -37,7 +37,7 @@ Role, StrategyType, ) -from .models.filters import EventMetadataFilter, MetadataValue +from .models.filters import EventMetadataFilter, IndexedKey, MemoryMetadataFilter, MetadataValue logger = logging.getLogger(__name__) @@ -160,8 +160,25 @@ def create_memory( event_expiry_days: int = 90, memory_execution_role_arn: Optional[str] = None, stream_delivery_resources: Optional[Dict[str, Any]] = None, + indexed_keys: Optional[List[IndexedKey]] = None, ) -> Dict[str, Any]: - """Create a memory with simplified configuration.""" + """Create a memory with simplified configuration. + + Args: + name: Name for the memory resource + strategies: Optional list of strategy configurations + description: Optional description + event_expiry_days: How long to retain events (default: 90 days) + memory_execution_role_arn: IAM role ARN for memory execution + stream_delivery_resources: Optional delivery configuration for streaming memory records + indexed_keys: Optional list of metadata keys to index for filtering. + Each entry should have 'key' (str) and 'type' ('STRING', 'STRINGLIST', or 'NUMBER'). + Once declared, indexed keys cannot be removed. + Example: [{"key": "priority", "type": "NUMBER"}, {"key": "agent_type", "type": "STRING"}] + + Returns: + Created memory object + """ if strategies is None: strategies = [] @@ -184,6 +201,9 @@ def create_memory( if stream_delivery_resources is not None: params["streamDeliveryResources"] = stream_delivery_resources + if indexed_keys is not None: + params["indexedKeys"] = indexed_keys + response = self.gmcp_client.create_memory(**params) memory = response["memory"] @@ -205,9 +225,21 @@ def create_or_get_memory( event_expiry_days: int = 90, memory_execution_role_arn: Optional[str] = None, stream_delivery_resources: Optional[Dict[str, Any]] = None, + indexed_keys: Optional[List[IndexedKey]] = None, ) -> Dict[str, Any]: """Create a memory resource or fetch the existing memory details if it already exists. + Args: + name: Name for the memory resource + strategies: Optional list of strategy configurations + description: Optional description + event_expiry_days: How long to retain events (default: 90 days) + memory_execution_role_arn: IAM role ARN for memory execution + stream_delivery_resources: Optional delivery configuration for streaming memory records + indexed_keys: Optional list of metadata keys to index for filtering. + Once declared, indexed keys cannot be removed; new keys can be added + via `update_memory(addIndexedKeys=...)`. + Returns: Memory object, either newly created or existing """ @@ -219,6 +251,7 @@ def create_or_get_memory( event_expiry_days=event_expiry_days, memory_execution_role_arn=memory_execution_role_arn, stream_delivery_resources=stream_delivery_resources, + indexed_keys=indexed_keys, ) return memory except ClientError as e: @@ -243,6 +276,7 @@ def create_memory_and_wait( stream_delivery_resources: Optional[Dict[str, Any]] = None, max_wait: int = 300, poll_interval: int = 10, + indexed_keys: Optional[List[IndexedKey]] = None, ) -> Dict[str, Any]: """Create a memory and wait for it to become ACTIVE. @@ -256,6 +290,10 @@ def create_memory_and_wait( event_expiry_days: How long to retain events (default: 90 days) memory_execution_role_arn: IAM role ARN for memory execution stream_delivery_resources: Optional delivery configuration for streaming memory records + indexed_keys: Optional list of metadata keys to index for filtering. + Each entry should have 'key' (str) and 'type' ('STRING', 'STRINGLIST', or 'NUMBER'). + Once declared, indexed keys cannot be removed; new keys can be added + via `update_memory(addIndexedKeys=...)`. max_wait: Maximum seconds to wait (default: 300) poll_interval: Seconds between status checks (default: 10) @@ -274,6 +312,7 @@ def create_memory_and_wait( event_expiry_days=event_expiry_days, memory_execution_role_arn=memory_execution_role_arn, stream_delivery_resources=stream_delivery_resources, + indexed_keys=indexed_keys, ) memory_id = memory.get("memoryId", memory.get("id")) # Handle both field names @@ -318,6 +357,7 @@ def retrieve_memories( actor_id: Optional[str] = None, top_k: int = 3, namespace_path: Optional[str] = None, + metadata_filters: Optional[List[MemoryMetadataFilter]] = None, ) -> List[Dict[str, Any]]: """Retrieve relevant memories using exact match or hierarchical path prefix. @@ -330,15 +370,29 @@ def retrieve_memories( actor_id: Optional actor ID (deprecated, use namespace) top_k: Number of results to return namespace_path: Hierarchical path prefix (e.g., "/org/team/") + metadata_filters: Optional list of metadata filter expressions to scope results. + Use MemoryMetadataFilter.build_expression() to construct filters. + The service accepts 1-5 filters. An empty list is treated as no filter. + Example: [MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("priority"), + MemoryRecordOperatorType.EQUALS_TO, + MemoryRecordRightExpression.build_string("high"), + )] Returns: List of memory records. Returns an empty list if the namespace arguments are invalid (both provided, neither provided, or contain wildcards) or if the service call fails. + + Raises: + ValueError: If `metadata_filters` exceeds the service maximum of 5. """ if query is None: raise TypeError("retrieve_memories() missing required argument: 'query'") + if metadata_filters is not None and len(metadata_filters) > 5: + raise ValueError(f"metadata_filters supports a maximum of 5 expressions; received {len(metadata_filters)}.") + try: ns_params = build_namespace_params(namespace, namespace_path) except ValueError as e: @@ -348,8 +402,13 @@ def retrieve_memories( ns_value = namespace or namespace_path try: + search_criteria = {"searchQuery": query, "topK": top_k} + if metadata_filters: + search_criteria["metadataFilters"] = metadata_filters + logger.debug("Applying %d metadata filter(s)", len(metadata_filters)) + response = self.gmdp_client.retrieve_memory_records( - memoryId=memory_id, searchCriteria={"searchQuery": query, "topK": top_k}, **ns_params + memoryId=memory_id, searchCriteria=search_criteria, **ns_params ) memories = response.get("memoryRecordSummaries", []) logger.info("Retrieved %d memories from namespace: %s", len(memories), ns_value) diff --git a/src/bedrock_agentcore/memory/models/__init__.py b/src/bedrock_agentcore/memory/models/__init__.py index 0213c5fc..752256d7 100644 --- a/src/bedrock_agentcore/memory/models/__init__.py +++ b/src/bedrock_agentcore/memory/models/__init__.py @@ -5,9 +5,15 @@ from .DictWrapper import DictWrapper from .filters import ( EventMetadataFilter, + IndexedKey, LeftExpression, + MemoryMetadataFilter, + MemoryRecordLeftExpression, + MemoryRecordOperatorType, + MemoryRecordRightExpression, MetadataKey, MetadataValue, + MetadataValueType, OperatorType, RightExpression, StringValue, @@ -101,4 +107,10 @@ def __init__(self, session_summary: Dict[str, Any]): "OperatorType", "RightExpression", "EventMetadataFilter", + "MemoryRecordOperatorType", + "MemoryRecordLeftExpression", + "MemoryRecordRightExpression", + "MemoryMetadataFilter", + "MetadataValueType", + "IndexedKey", ] diff --git a/src/bedrock_agentcore/memory/models/filters.py b/src/bedrock_agentcore/memory/models/filters.py index 9ab25f7d..88bd48c3 100644 --- a/src/bedrock_agentcore/memory/models/filters.py +++ b/src/bedrock_agentcore/memory/models/filters.py @@ -1,7 +1,14 @@ -"""Event metadata filter models for querying events based on metadata.""" +"""Metadata filter models for querying events and memory records.""" +from datetime import datetime from enum import Enum -from typing import Optional, TypedDict, Union +from typing import List, Optional, TypedDict, Union + +from typing_extensions import NotRequired + +# ============================================================================ +# Event Metadata Filters (existing) +# ============================================================================ class StringValue(TypedDict): @@ -124,3 +131,213 @@ def build_expression( if right_operand: filter["right"] = right_operand return filter + + +# ============================================================================ +# Memory Record Metadata Filters (LTM) +# ============================================================================ + + +class MemoryRecordOperatorType(Enum): + """Operator applied to memory record metadata filter expressions. + + Each operator is paired with a specific right-operand value type. Mismatches + are rejected by the service — pass the right operand built via the matching + `MemoryRecordRightExpression.build_*` factory. + + | Operator | Right operand | Builder | + |--------------------------|-----------------------------|--------------------------| + | `EQUALS_TO` | string | `build_string` | + | `EXISTS` | (none) | — | + | `NOT_EXISTS` | (none) | — | + | `BEFORE` | datetime | `build_datetime` | + | `AFTER` | datetime | `build_datetime` | + | `CONTAINS` | string list | `build_string_list` | + | `GREATER_THAN` | number | `build_number` | + | `GREATER_THAN_OR_EQUALS` | number | `build_number` | + | `LESS_THAN` | number | `build_number` | + | `LESS_THAN_OR_EQUALS` | number | `build_number` | + """ + + EQUALS_TO = "EQUALS_TO" + EXISTS = "EXISTS" + NOT_EXISTS = "NOT_EXISTS" + BEFORE = "BEFORE" + AFTER = "AFTER" + CONTAINS = "CONTAINS" + GREATER_THAN = "GREATER_THAN" + GREATER_THAN_OR_EQUALS = "GREATER_THAN_OR_EQUALS" + LESS_THAN = "LESS_THAN" + LESS_THAN_OR_EQUALS = "LESS_THAN_OR_EQUALS" + + +class MemoryRecordLeftExpression(TypedDict): + """Left operand of the memory record metadata filter expression.""" + + metadataKey: str + + @staticmethod + def build(key: str) -> "MemoryRecordLeftExpression": + """Build a MemoryRecordLeftExpression from a key name.""" + return {"metadataKey": key} + + +class MemoryRecordRightExpression(TypedDict): + """Right operand of the memory record metadata filter expression. + + Variants: + - {"metadataValue": {"stringValue": str}} + - {"metadataValue": {"numberValue": float}} + - {"metadataValue": {"dateTimeValue": datetime}} + - {"metadataValue": {"stringListValue": List[str]}} + """ + + metadataValue: dict + + @staticmethod + def build_string(value: str) -> "MemoryRecordRightExpression": + """Build a right expression with a string value.""" + return {"metadataValue": {"stringValue": value}} + + @staticmethod + def build_number(value: Union[int, float]) -> "MemoryRecordRightExpression": + """Build a right expression with a numeric value.""" + return {"metadataValue": {"numberValue": value}} + + @staticmethod + def build_datetime(value: datetime) -> "MemoryRecordRightExpression": + """Build a right expression with a datetime value.""" + return {"metadataValue": {"dateTimeValue": value}} + + @staticmethod + def build_string_list(value: List[str]) -> "MemoryRecordRightExpression": + """Build a right expression with a string list value.""" + return {"metadataValue": {"stringListValue": value}} + + +class MemoryMetadataFilter(TypedDict): + """Filter expression for querying memory records by metadata. + + Used with `retrieve_memories()` and `list_memory_records()` to scope + results by indexed metadata keys before semantic search runs. + + Args: + left: `MemoryRecordLeftExpression` specifying the metadata key. + operator: `MemoryRecordOperatorType` defining the comparison. + right: Optional `MemoryRecordRightExpression` with the value to compare against. + Not required for EXISTS/NOT_EXISTS operators. + + Example: + ```python + filter = MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("priority"), + MemoryRecordOperatorType.EQUALS_TO, + MemoryRecordRightExpression.build_string("high"), + ) + ``` + """ + + left: MemoryRecordLeftExpression + # Stored as the operator's string value (e.g. "EQUALS_TO"), not the enum itself, + # since this dict is serialized directly to the AgentCore service. + operator: str + right: NotRequired[MemoryRecordRightExpression] + + @staticmethod + def build_expression( + left_operand: "MemoryRecordLeftExpression", + operator: MemoryRecordOperatorType, + right_operand: Optional["MemoryRecordRightExpression"] = None, + ) -> "MemoryMetadataFilter": + """Build a memory metadata filter expression. + + Args: + left_operand: The metadata key to filter on. + operator: The comparison operator. + right_operand: The value to compare against. Required for all operators + except EXISTS and NOT_EXISTS, which must NOT receive a + right operand. + + Raises: + ValueError: If `right_operand` is supplied with EXISTS or NOT_EXISTS, + or if `right_operand` is missing for any other operator. + + Example: + ```python + left_operand = MemoryRecordLeftExpression.build("priority") + operator = MemoryRecordOperatorType.GREATER_THAN + right_operand = MemoryRecordRightExpression.build_number(3.0) + + filter = MemoryMetadataFilter.build_expression(left_operand, operator, right_operand) + # Result: + # { + # "left": {"metadataKey": "priority"}, + # "operator": "GREATER_THAN", + # "right": {"metadataValue": {"numberValue": 3.0}} + # } + ``` + """ + is_existence_op = operator in ( + MemoryRecordOperatorType.EXISTS, + MemoryRecordOperatorType.NOT_EXISTS, + ) + if is_existence_op and right_operand is not None: + raise ValueError(f"{operator.value} does not accept a right operand; the service rejects this combination.") + if not is_existence_op and right_operand is None: + raise ValueError(f"{operator.value} requires a right operand.") + + filter = {"left": left_operand, "operator": operator.value} + + if right_operand is not None: + filter["right"] = right_operand + return filter + + +# ============================================================================ +# Indexed Key Types (Control Plane) +# ============================================================================ + + +class MetadataValueType(Enum): + """Supported data types for indexed metadata key values.""" + + STRING = "STRING" + STRINGLIST = "STRINGLIST" + NUMBER = "NUMBER" + + +class IndexedKey(TypedDict): + r"""A metadata key indexed for filtering on memory records. + + Args: + key: The metadata key name. 1-128 characters. May contain alphanumeric + characters, whitespace, and the symbols `. _ : / = + @ -`. Pattern: + `[a-zA-Z0-9\s._:/=+@-]*`. + type: The data type of the indexed key value. + + Note: + Indexed keys are append-only on the AgentCore service: once an + indexed key is declared on a memory it cannot be removed. New keys + can be added via `update_memory(addIndexedKeys=...)`. + + Example: + ```python + indexed_keys = [ + IndexedKey.build("priority", MetadataValueType.NUMBER), + IndexedKey.build("agent_type", MetadataValueType.STRING), + ] + ``` + """ + + key: str + type: str + + @staticmethod + def build(key: str, value_type: MetadataValueType) -> "IndexedKey": + """Build an IndexedKey configuration. + + Args: + key: The metadata key name. + value_type: The MetadataValueType for this key. + """ + return {"key": key, "type": value_type.value} diff --git a/src/bedrock_agentcore/memory/session.py b/src/bedrock_agentcore/memory/session.py index 6f8a81c6..ecd9af40 100644 --- a/src/bedrock_agentcore/memory/session.py +++ b/src/bedrock_agentcore/memory/session.py @@ -21,6 +21,7 @@ Event, EventMessage, EventMetadataFilter, + MemoryMetadataFilter, MemoryRecord, MetadataValue, SessionSummary, @@ -903,6 +904,7 @@ def search_long_term_memories( max_results: int = 20, namespace: Optional[str] = None, namespace_path: Optional[str] = None, + metadata_filters: Optional[List[MemoryMetadataFilter]] = None, ) -> List[MemoryRecord]: """Performs a semantic search against the long-term memory for this actor. @@ -921,7 +923,16 @@ def search_long_term_memories( namespace: Exact-match namespace (preserves pre-redesign behavior during the service grace period) namespace_path: Hierarchical path-prefix namespace + metadata_filters: Optional list of metadata filter expressions to scope results + by indexed metadata keys before semantic search runs. The service accepts + 1-5 filters. An empty list is treated as no filter. + + Raises: + ValueError: If `metadata_filters` exceeds the service maximum of 5. """ + if metadata_filters is not None and len(metadata_filters) > 5: + raise ValueError(f"metadata_filters supports a maximum of 5 expressions; received {len(metadata_filters)}.") + resolved_namespace = resolve_namespace_prefix_deprecation(namespace_prefix, namespace) ns_params = build_namespace_params(resolved_namespace, namespace_path) ns_value = resolved_namespace or namespace_path @@ -930,6 +941,9 @@ def search_long_term_memories( search_criteria = {"searchQuery": query, "topK": top_k} if strategy_id: search_criteria["memoryStrategyId"] = strategy_id + if metadata_filters: + search_criteria["metadataFilters"] = metadata_filters + logger.debug("Applying %d metadata filter(s)", len(metadata_filters)) params = { "memoryId": self._memory_id, @@ -1252,6 +1266,7 @@ def search_long_term_memories( max_results: int = 20, namespace: Optional[str] = None, namespace_path: Optional[str] = None, + metadata_filters: Optional[List[MemoryMetadataFilter]] = None, ) -> List[MemoryRecord]: """Delegates to manager.search_long_term_memories.""" return self._manager.search_long_term_memories( @@ -1262,6 +1277,7 @@ def search_long_term_memories( max_results, namespace=namespace, namespace_path=namespace_path, + metadata_filters=metadata_filters, ) def list_long_term_memory_records( diff --git a/tests/bedrock_agentcore/memory/models/test_filters.py b/tests/bedrock_agentcore/memory/models/test_filters.py new file mode 100644 index 00000000..4e6ceeed --- /dev/null +++ b/tests/bedrock_agentcore/memory/models/test_filters.py @@ -0,0 +1,308 @@ +"""Unit tests for memory record metadata filter models.""" + +from datetime import datetime + +import pytest + +from bedrock_agentcore.memory.models import ( + IndexedKey, + MemoryMetadataFilter, + MemoryRecordLeftExpression, + MemoryRecordOperatorType, + MemoryRecordRightExpression, + MetadataValueType, +) + + +class TestMemoryRecordLeftExpression: + """Test cases for MemoryRecordLeftExpression.""" + + def test_build(self): + """Test building a left expression from a key name.""" + result = MemoryRecordLeftExpression.build("priority") + assert result == {"metadataKey": "priority"} + + def test_build_various_keys(self): + """Test building left expressions with various key names.""" + assert MemoryRecordLeftExpression.build("agent_type") == {"metadataKey": "agent_type"} + assert MemoryRecordLeftExpression.build("created_at") == {"metadataKey": "created_at"} + assert MemoryRecordLeftExpression.build("tags") == {"metadataKey": "tags"} + + +class TestMemoryRecordRightExpression: + """Test cases for MemoryRecordRightExpression.""" + + def test_build_string(self): + """Test building a string value right expression.""" + result = MemoryRecordRightExpression.build_string("high") + assert result == {"metadataValue": {"stringValue": "high"}} + + def test_build_string_empty(self): + """Test building a string value with empty string.""" + result = MemoryRecordRightExpression.build_string("") + assert result == {"metadataValue": {"stringValue": ""}} + + def test_build_number_integer(self): + """Test building a numeric right expression with integer-like float.""" + result = MemoryRecordRightExpression.build_number(5.0) + assert result == {"metadataValue": {"numberValue": 5.0}} + + def test_build_number_float(self): + """Test building a numeric right expression with decimal float.""" + result = MemoryRecordRightExpression.build_number(3.14) + assert result == {"metadataValue": {"numberValue": 3.14}} + + def test_build_number_zero(self): + """Test building a numeric right expression with zero.""" + result = MemoryRecordRightExpression.build_number(0.0) + assert result == {"metadataValue": {"numberValue": 0.0}} + + def test_build_number_negative(self): + """Test building a numeric right expression with negative value.""" + result = MemoryRecordRightExpression.build_number(-1.5) + assert result == {"metadataValue": {"numberValue": -1.5}} + + def test_build_datetime(self): + """Test building a datetime right expression.""" + dt = datetime(2024, 6, 15, 10, 30, 0) + result = MemoryRecordRightExpression.build_datetime(dt) + assert result == {"metadataValue": {"dateTimeValue": dt}} + + def test_build_string_list(self): + """Test building a string list right expression.""" + result = MemoryRecordRightExpression.build_string_list(["tag1", "tag2", "tag3"]) + assert result == {"metadataValue": {"stringListValue": ["tag1", "tag2", "tag3"]}} + + def test_build_string_list_single_item(self): + """Test building a string list with a single item.""" + result = MemoryRecordRightExpression.build_string_list(["only_one"]) + assert result == {"metadataValue": {"stringListValue": ["only_one"]}} + + def test_build_string_list_empty(self): + """Test building a string list with empty list.""" + result = MemoryRecordRightExpression.build_string_list([]) + assert result == {"metadataValue": {"stringListValue": []}} + + +class TestMemoryRecordOperatorType: + """Test cases for MemoryRecordOperatorType enum.""" + + def test_all_operators_exist(self): + """Test that all expected operator types are defined.""" + assert MemoryRecordOperatorType.EQUALS_TO.value == "EQUALS_TO" + assert MemoryRecordOperatorType.EXISTS.value == "EXISTS" + assert MemoryRecordOperatorType.NOT_EXISTS.value == "NOT_EXISTS" + assert MemoryRecordOperatorType.BEFORE.value == "BEFORE" + assert MemoryRecordOperatorType.AFTER.value == "AFTER" + assert MemoryRecordOperatorType.CONTAINS.value == "CONTAINS" + assert MemoryRecordOperatorType.GREATER_THAN.value == "GREATER_THAN" + assert MemoryRecordOperatorType.GREATER_THAN_OR_EQUALS.value == "GREATER_THAN_OR_EQUALS" + assert MemoryRecordOperatorType.LESS_THAN.value == "LESS_THAN" + assert MemoryRecordOperatorType.LESS_THAN_OR_EQUALS.value == "LESS_THAN_OR_EQUALS" + + def test_operator_count(self): + """Test that exactly 10 operators are defined.""" + assert len(MemoryRecordOperatorType) == 10 + + +class TestMemoryMetadataFilter: + """Test cases for MemoryMetadataFilter.""" + + def test_build_expression_equals_string(self): + """Test building an EQUALS_TO filter with a string value.""" + result = MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("agent_type"), + MemoryRecordOperatorType.EQUALS_TO, + MemoryRecordRightExpression.build_string("support"), + ) + + assert result == { + "left": {"metadataKey": "agent_type"}, + "operator": "EQUALS_TO", + "right": {"metadataValue": {"stringValue": "support"}}, + } + + def test_build_expression_greater_than_number(self): + """Test building a GREATER_THAN filter with a numeric value.""" + result = MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("priority"), + MemoryRecordOperatorType.GREATER_THAN, + MemoryRecordRightExpression.build_number(3.0), + ) + + assert result == { + "left": {"metadataKey": "priority"}, + "operator": "GREATER_THAN", + "right": {"metadataValue": {"numberValue": 3.0}}, + } + + def test_build_expression_less_than_or_equals(self): + """Test building a LESS_THAN_OR_EQUALS filter.""" + result = MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("score"), + MemoryRecordOperatorType.LESS_THAN_OR_EQUALS, + MemoryRecordRightExpression.build_number(100.0), + ) + + assert result == { + "left": {"metadataKey": "score"}, + "operator": "LESS_THAN_OR_EQUALS", + "right": {"metadataValue": {"numberValue": 100.0}}, + } + + def test_build_expression_before_datetime(self): + """Test building a BEFORE filter with a datetime value.""" + dt = datetime(2024, 1, 1, 0, 0, 0) + result = MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("created_at"), + MemoryRecordOperatorType.BEFORE, + MemoryRecordRightExpression.build_datetime(dt), + ) + + assert result == { + "left": {"metadataKey": "created_at"}, + "operator": "BEFORE", + "right": {"metadataValue": {"dateTimeValue": dt}}, + } + + def test_build_expression_after_datetime(self): + """Test building an AFTER filter with a datetime value.""" + dt = datetime(2024, 12, 31, 23, 59, 59) + result = MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("updated_at"), + MemoryRecordOperatorType.AFTER, + MemoryRecordRightExpression.build_datetime(dt), + ) + + assert result == { + "left": {"metadataKey": "updated_at"}, + "operator": "AFTER", + "right": {"metadataValue": {"dateTimeValue": dt}}, + } + + def test_build_expression_contains_string_list(self): + """Test building a CONTAINS filter with a string list.""" + result = MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("tags"), + MemoryRecordOperatorType.CONTAINS, + MemoryRecordRightExpression.build_string_list(["urgent", "follow-up"]), + ) + + assert result == { + "left": {"metadataKey": "tags"}, + "operator": "CONTAINS", + "right": {"metadataValue": {"stringListValue": ["urgent", "follow-up"]}}, + } + + def test_build_expression_exists_no_right_operand(self): + """Test building an EXISTS filter without a right operand.""" + result = MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("metadata_key"), + MemoryRecordOperatorType.EXISTS, + ) + + assert result == { + "left": {"metadataKey": "metadata_key"}, + "operator": "EXISTS", + } + assert "right" not in result + + def test_build_expression_not_exists_no_right_operand(self): + """Test building a NOT_EXISTS filter without a right operand.""" + result = MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("optional_field"), + MemoryRecordOperatorType.NOT_EXISTS, + ) + + assert result == { + "left": {"metadataKey": "optional_field"}, + "operator": "NOT_EXISTS", + } + assert "right" not in result + + def test_build_expression_exists_with_none_right_operand(self): + """Test building EXISTS filter with explicit None right operand.""" + result = MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("some_key"), + MemoryRecordOperatorType.EXISTS, + None, + ) + + assert result == { + "left": {"metadataKey": "some_key"}, + "operator": "EXISTS", + } + assert "right" not in result + + def test_build_expression_rejects_right_operand_with_exists(self): + """EXISTS rejects a right operand at build time.""" + with pytest.raises(ValueError, match="EXISTS does not accept a right operand"): + MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("priority"), + MemoryRecordOperatorType.EXISTS, + MemoryRecordRightExpression.build_string("high"), + ) + + def test_build_expression_rejects_right_operand_with_not_exists(self): + """NOT_EXISTS rejects a right operand at build time.""" + with pytest.raises(ValueError, match="NOT_EXISTS does not accept a right operand"): + MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("priority"), + MemoryRecordOperatorType.NOT_EXISTS, + MemoryRecordRightExpression.build_string("high"), + ) + + def test_build_expression_requires_right_operand_for_comparison_operators(self): + """Non-existence operators raise when right operand is missing.""" + with pytest.raises(ValueError, match="EQUALS_TO requires a right operand"): + MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("priority"), + MemoryRecordOperatorType.EQUALS_TO, + ) + + +class TestMetadataValueType: + """Test cases for MetadataValueType enum.""" + + def test_all_types_exist(self): + """Test that all expected value types are defined.""" + assert MetadataValueType.STRING.value == "STRING" + assert MetadataValueType.STRINGLIST.value == "STRINGLIST" + assert MetadataValueType.NUMBER.value == "NUMBER" + + def test_type_count(self): + """Test that exactly 3 value types are defined.""" + assert len(MetadataValueType) == 3 + + +class TestIndexedKey: + """Test cases for IndexedKey.""" + + def test_build_string_key(self): + """Test building an indexed key with STRING type.""" + result = IndexedKey.build("agent_type", MetadataValueType.STRING) + assert result == {"key": "agent_type", "type": "STRING"} + + def test_build_number_key(self): + """Test building an indexed key with NUMBER type.""" + result = IndexedKey.build("priority", MetadataValueType.NUMBER) + assert result == {"key": "priority", "type": "NUMBER"} + + def test_build_stringlist_key(self): + """Test building an indexed key with STRINGLIST type.""" + result = IndexedKey.build("tags", MetadataValueType.STRINGLIST) + assert result == {"key": "tags", "type": "STRINGLIST"} + + def test_build_multiple_keys(self): + """Test building a list of indexed keys for create_memory.""" + indexed_keys = [ + IndexedKey.build("priority", MetadataValueType.NUMBER), + IndexedKey.build("agent_type", MetadataValueType.STRING), + IndexedKey.build("categories", MetadataValueType.STRINGLIST), + ] + + assert indexed_keys == [ + {"key": "priority", "type": "NUMBER"}, + {"key": "agent_type", "type": "STRING"}, + {"key": "categories", "type": "STRINGLIST"}, + ] diff --git a/tests/bedrock_agentcore/memory/test_client.py b/tests/bedrock_agentcore/memory/test_client.py index 8c95d709..f7df4c4a 100644 --- a/tests/bedrock_agentcore/memory/test_client.py +++ b/tests/bedrock_agentcore/memory/test_client.py @@ -12,6 +12,12 @@ from bedrock_agentcore.memory import MemoryClient from bedrock_agentcore.memory.constants import StrategyType +from bedrock_agentcore.memory.models import ( + MemoryMetadataFilter, + MemoryRecordLeftExpression, + MemoryRecordOperatorType, + MemoryRecordRightExpression, +) def test_client_initialization(): @@ -3498,3 +3504,242 @@ def test_get_last_k_turns_explicit_max_results(): # Total events fetched should not exceed max_results total_fetched = sum(1 for _ in mock_gmdp.list_events.call_args_list) assert total_fetched <= 50 # Should stop after fetching 50 events worth of calls + + +# ============================================================================ +# LTM Metadata: indexed_keys and metadata_filters tests +# ============================================================================ + + +def test_create_memory_with_indexed_keys(): + """Test create_memory passes indexedKeys to gmcp_client when provided.""" + with patch("boto3.Session"): + client = MemoryClient() + + mock_gmcp = MagicMock() + client.gmcp_client = mock_gmcp + + mock_gmcp.create_memory.return_value = {"memory": {"memoryId": "mem-idx-1", "status": "CREATING"}} + + indexed_keys = [ + {"key": "priority", "type": "NUMBER"}, + {"key": "agent_type", "type": "STRING"}, + {"key": "tags", "type": "STRINGLIST"}, + ] + + with patch("uuid.uuid4", return_value=uuid.UUID("12345678-1234-5678-1234-567812345678")): + result = client.create_memory( + name="IndexedMemory", + strategies=[{StrategyType.SEMANTIC.value: {"name": "TestStrategy"}}], + indexed_keys=indexed_keys, + ) + + assert result["memoryId"] == "mem-idx-1" + assert mock_gmcp.create_memory.called + + args, kwargs = mock_gmcp.create_memory.call_args + assert kwargs["indexedKeys"] == indexed_keys + assert kwargs["name"] == "IndexedMemory" + + +def test_create_memory_without_indexed_keys(): + """Test create_memory does not include indexedKeys when not provided.""" + with patch("boto3.Session"): + client = MemoryClient() + + mock_gmcp = MagicMock() + client.gmcp_client = mock_gmcp + + mock_gmcp.create_memory.return_value = {"memory": {"memoryId": "mem-no-idx", "status": "CREATING"}} + + with patch("uuid.uuid4", return_value=uuid.UUID("12345678-1234-5678-1234-567812345678")): + result = client.create_memory( + name="NoIndexMemory", + strategies=[{StrategyType.SEMANTIC.value: {"name": "TestStrategy"}}], + ) + + assert result["memoryId"] == "mem-no-idx" + + args, kwargs = mock_gmcp.create_memory.call_args + assert "indexedKeys" not in kwargs + + +def test_create_memory_and_wait_with_indexed_keys(): + """Test create_memory_and_wait passes indexed_keys through to create_memory.""" + with patch("boto3.Session"): + client = MemoryClient() + + mock_gmcp = MagicMock() + client.gmcp_client = mock_gmcp + + mock_gmcp.create_memory.return_value = {"memory": {"memoryId": "mem-wait-idx", "status": "CREATING"}} + mock_gmcp.get_memory.return_value = { + "memory": {"memoryId": "mem-wait-idx", "status": "ACTIVE", "name": "WaitIndexed"} + } + + indexed_keys = [{"key": "category", "type": "STRING"}] + + with patch("time.time", return_value=0): + with patch("time.sleep"): + with patch("uuid.uuid4", return_value=uuid.UUID("12345678-1234-5678-1234-567812345678")): + result = client.create_memory_and_wait( + name="WaitIndexed", + strategies=[{StrategyType.SEMANTIC.value: {"name": "TestStrategy"}}], + indexed_keys=indexed_keys, + ) + + assert result["memoryId"] == "mem-wait-idx" + assert result["status"] == "ACTIVE" + + args, kwargs = mock_gmcp.create_memory.call_args + assert kwargs["indexedKeys"] == indexed_keys + + +def test_create_or_get_memory_with_indexed_keys(): + """Test create_or_get_memory passes indexed_keys through.""" + with patch("boto3.Session"): + client = MemoryClient() + + mock_gmcp = MagicMock() + client.gmcp_client = mock_gmcp + + mock_gmcp.create_memory.return_value = {"memory": {"memoryId": "mem-cog-idx", "status": "CREATING"}} + mock_gmcp.get_memory.return_value = { + "memory": {"memoryId": "mem-cog-idx", "status": "ACTIVE", "name": "COGIndexed"} + } + + indexed_keys = [{"key": "source", "type": "STRING"}] + + with patch("time.time", return_value=0): + with patch("time.sleep"): + with patch("uuid.uuid4", return_value=uuid.UUID("12345678-1234-5678-1234-567812345678")): + result = client.create_or_get_memory( + name="COGIndexed", + strategies=[{StrategyType.SEMANTIC.value: {"name": "TestStrategy"}}], + indexed_keys=indexed_keys, + ) + + assert result["memoryId"] == "mem-cog-idx" + + args, kwargs = mock_gmcp.create_memory.call_args + assert kwargs["indexedKeys"] == indexed_keys + + +def test_retrieve_memories_with_metadata_filters(): + """Test retrieve_memories includes metadataFilters in searchCriteria when provided.""" + with patch("boto3.Session"): + client = MemoryClient() + + mock_gmdp = MagicMock() + client.gmdp_client = mock_gmdp + + mock_gmdp.retrieve_memory_records.return_value = { + "memoryRecordSummaries": [{"content": {"text": "Filtered memory"}, "memoryRecordId": "rec-f1"}] + } + + metadata_filters = [ + { + "left": {"metadataKey": "priority"}, + "operator": "EQUALS_TO", + "right": {"metadataValue": {"stringValue": "high"}}, + }, + { + "left": {"metadataKey": "score"}, + "operator": "GREATER_THAN", + "right": {"metadataValue": {"numberValue": 0.8}}, + }, + ] + + memories = client.retrieve_memories( + memory_id="mem-123", + namespace="test/namespace/", + query="important items", + top_k=5, + metadata_filters=metadata_filters, + ) + + assert len(memories) == 1 + assert memories[0]["memoryRecordId"] == "rec-f1" + + args, kwargs = mock_gmdp.retrieve_memory_records.call_args + assert kwargs["memoryId"] == "mem-123" + assert kwargs["searchCriteria"]["searchQuery"] == "important items" + assert kwargs["searchCriteria"]["topK"] == 5 + assert kwargs["searchCriteria"]["metadataFilters"] == metadata_filters + + +def test_retrieve_memories_without_metadata_filters(): + """Test retrieve_memories does not include metadataFilters when not provided.""" + with patch("boto3.Session"): + client = MemoryClient() + + mock_gmdp = MagicMock() + client.gmdp_client = mock_gmdp + + mock_gmdp.retrieve_memory_records.return_value = { + "memoryRecordSummaries": [{"content": {"text": "Unfiltered memory"}, "memoryRecordId": "rec-u1"}] + } + + memories = client.retrieve_memories( + memory_id="mem-123", + namespace="test/namespace/", + query="general query", + top_k=3, + ) + + assert len(memories) == 1 + + args, kwargs = mock_gmdp.retrieve_memory_records.call_args + assert kwargs["searchCriteria"] == {"searchQuery": "general query", "topK": 3} + assert "metadataFilters" not in kwargs["searchCriteria"] + + +def test_retrieve_memories_with_builder_constructed_filter(): + """End-to-end: filter built via MemoryMetadataFilter.build_expression flows through to boto3.""" + with patch("boto3.Session"): + client = MemoryClient() + mock_gmdp = MagicMock() + client.gmdp_client = mock_gmdp + mock_gmdp.retrieve_memory_records.return_value = {"memoryRecordSummaries": []} + + built_filter = MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("priority"), + MemoryRecordOperatorType.EQUALS_TO, + MemoryRecordRightExpression.build_string("high"), + ) + + client.retrieve_memories( + memory_id="mem-123", + namespace="test/namespace/", + query="x", + metadata_filters=[built_filter], + ) + + _, kwargs = mock_gmdp.retrieve_memory_records.call_args + assert kwargs["searchCriteria"]["metadataFilters"] == [ + { + "left": {"metadataKey": "priority"}, + "operator": "EQUALS_TO", + "right": {"metadataValue": {"stringValue": "high"}}, + } + ] + + +def test_retrieve_memories_rejects_more_than_five_filters(): + """retrieve_memories raises ValueError when given more than 5 filters.""" + with patch("boto3.Session"): + client = MemoryClient() + client.gmdp_client = MagicMock() + + too_many = [ + MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build(f"k{i}"), + MemoryRecordOperatorType.EXISTS, + ) + for i in range(6) + ] + + with pytest.raises(ValueError, match="maximum of 5"): + client.retrieve_memories( + memory_id="mem-123", namespace="test/namespace/", query="x", metadata_filters=too_many + ) diff --git a/tests/bedrock_agentcore/memory/test_session.py b/tests/bedrock_agentcore/memory/test_session.py index a03a4a98..786a15b3 100644 --- a/tests/bedrock_agentcore/memory/test_session.py +++ b/tests/bedrock_agentcore/memory/test_session.py @@ -1902,9 +1902,43 @@ def test_session_search_long_term_memories_delegation(self): assert result == mock_records mock_search.assert_called_once_with( - "test query", "test/namespace/", 3, None, 20, namespace=None, namespace_path=None + "test query", + "test/namespace/", + 3, + None, + 20, + namespace=None, + namespace_path=None, + metadata_filters=None, ) + def test_session_search_long_term_memories_forwards_metadata_filters(self): + """MemorySession.search_long_term_memories forwards metadata_filters to the manager.""" + from bedrock_agentcore.memory.models import ( + MemoryMetadataFilter, + MemoryRecordLeftExpression, + MemoryRecordOperatorType, + MemoryRecordRightExpression, + ) + + with patch("boto3.Session"): + manager = MemorySessionManager(memory_id="testMemory-1234567890", region_name="us-west-2") + session = MemorySession( + memory_id="testMemory-1234567890", actor_id="user-123", session_id="session-456", manager=manager + ) + + built_filter = MemoryMetadataFilter.build_expression( + MemoryRecordLeftExpression.build("priority"), + MemoryRecordOperatorType.EQUALS_TO, + MemoryRecordRightExpression.build_string("high"), + ) + + with patch.object(manager, "search_long_term_memories", return_value=[]) as mock_search: + session.search_long_term_memories(query="q", namespace="test/", metadata_filters=[built_filter]) + + _, kwargs = mock_search.call_args + assert kwargs["metadata_filters"] == [built_filter] + def test_session_list_long_term_memory_records_delegation(self): """Test MemorySession.list_long_term_memory_records delegates to manager.""" with patch("boto3.Session"):