diff --git a/ops/__tests__/test_parse_sdk_branch.py b/ops/__tests__/test_parse_sdk_branch.py index cdae4bc..d458651 100755 --- a/ops/__tests__/test_parse_sdk_branch.py +++ b/ops/__tests__/test_parse_sdk_branch.py @@ -73,12 +73,12 @@ def test(): for input_text, expected in test_cases: result = parse_sdk_branch(input_text) - if result != expected: - return False - - return True + # Assert is expected in test functions + assert result == expected, ( # noqa: S101 + f"Expected '{expected}' but got '{result}' for input: {input_text[:50]}..." + ) if __name__ == "__main__": - success = test_parse_sdk_branch() - sys.exit(0 if success else 1) + test_parse_sdk_branch() + sys.exit(0) diff --git a/src/aws_durable_execution_sdk_python/config.py b/src/aws_durable_execution_sdk_python/config.py index e9945b6..42a5d42 100644 --- a/src/aws_durable_execution_sdk_python/config.py +++ b/src/aws_durable_execution_sdk_python/config.py @@ -392,10 +392,12 @@ class InvokeConfig(Generic[P, R]): from blocking execution indefinitely. serdes_payload: Custom serialization/deserialization for the payload - sent to the invoked function. If None, uses default JSON serialization. + sent to the invoked function. Defaults to DEFAULT_JSON_SERDES when + not set. serdes_result: Custom serialization/deserialization for the result - returned from the invoked function. If None, uses default JSON serialization. + returned from the invoked function. Defaults to DEFAULT_JSON_SERDES when + not set. tenant_id: Optional tenant identifier for multi-tenant isolation. If provided, the invocation will be scoped to this tenant. diff --git a/src/aws_durable_execution_sdk_python/operation/invoke.py b/src/aws_durable_execution_sdk_python/operation/invoke.py index 924f2e4..4b1eb99 100644 --- a/src/aws_durable_execution_sdk_python/operation/invoke.py +++ b/src/aws_durable_execution_sdk_python/operation/invoke.py @@ -11,7 +11,11 @@ ChainedInvokeOptions, OperationUpdate, ) -from aws_durable_execution_sdk_python.serdes import deserialize, serialize +from aws_durable_execution_sdk_python.serdes import ( + DEFAULT_JSON_SERDES, + deserialize, + serialize, +) from aws_durable_execution_sdk_python.suspend import suspend_with_optional_resume_delay if TYPE_CHECKING: @@ -53,7 +57,7 @@ def invoke_handler( and checkpointed_result.operation.chained_invoke_details.result ): return deserialize( - serdes=config.serdes_result, + serdes=config.serdes_result or DEFAULT_JSON_SERDES, data=checkpointed_result.operation.chained_invoke_details.result, operation_id=operation_identifier.operation_id, durable_execution_arn=state.durable_execution_arn, @@ -78,7 +82,7 @@ def invoke_handler( suspend_with_optional_resume_delay(msg, config.timeout_seconds) serialized_payload: str = serialize( - serdes=config.serdes_payload, + serdes=config.serdes_payload or DEFAULT_JSON_SERDES, value=payload, operation_id=operation_identifier.operation_id, durable_execution_arn=state.durable_execution_arn, diff --git a/src/aws_durable_execution_sdk_python/serdes.py b/src/aws_durable_execution_sdk_python/serdes.py index d589629..b3b704a 100644 --- a/src/aws_durable_execution_sdk_python/serdes.py +++ b/src/aws_durable_execution_sdk_python/serdes.py @@ -441,8 +441,8 @@ def _to_json_serializable(self, obj: Any) -> Any: return obj -_DEFAULT_JSON_SERDES: SerDes[Any] = JsonSerDes() -_EXTENDED_TYPES_SERDES: SerDes[Any] = ExtendedTypeSerDes() +DEFAULT_JSON_SERDES: SerDes[Any] = JsonSerDes() +EXTENDED_TYPES_SERDES: SerDes[Any] = ExtendedTypeSerDes() def serialize( @@ -463,7 +463,7 @@ def serialize( FatalError: If serialization fails """ serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn) - active_serdes: SerDes[T] = serdes or _EXTENDED_TYPES_SERDES + active_serdes: SerDes[T] = serdes or EXTENDED_TYPES_SERDES try: return active_serdes.serialize(value, serdes_context) except Exception as e: @@ -493,7 +493,7 @@ def deserialize( FatalError: If deserialization fails """ serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn) - active_serdes: SerDes[T] = serdes or _EXTENDED_TYPES_SERDES + active_serdes: SerDes[T] = serdes or EXTENDED_TYPES_SERDES try: return active_serdes.deserialize(data, serdes_context) except Exception as e: diff --git a/tests/operation/invoke_test.py b/tests/operation/invoke_test.py index 1625b4d..ac8a86b 100644 --- a/tests/operation/invoke_test.py +++ b/tests/operation/invoke_test.py @@ -612,3 +612,57 @@ def test_invoke_handler_default_config_no_tenant_id(): chained_invoke_options = operation_update.to_dict()["ChainedInvokeOptions"] assert chained_invoke_options["FunctionName"] == "test_function" assert "TenantId" not in chained_invoke_options + + +def test_invoke_handler_defaults_to_json_serdes(): + """Test invoke_handler uses DEFAULT_JSON_SERDES when config has no serdes.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + mock_state.get_checkpoint_result.return_value = ( + CheckpointedResult.create_not_found() + ) + + config = InvokeConfig[dict, dict](serdes_payload=None, serdes_result=None) + payload = {"key": "value", "number": 42} + + with pytest.raises(SuspendExecution): + invoke_handler( + function_name="test_function", + payload=payload, + state=mock_state, + operation_identifier=OperationIdentifier("invoke_json", None, None), + config=config, + ) + + # Verify JSON serialization was used (not extended types) + operation_update = mock_state.create_checkpoint.call_args[1]["operation_update"] + assert operation_update.payload == json.dumps(payload) + + +def test_invoke_handler_result_defaults_to_json_serdes(): + """Test invoke_handler uses DEFAULT_JSON_SERDES for result deserialization.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + result_data = {"key": "value", "number": 42} + operation = Operation( + operation_id="invoke_result_json", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.SUCCEEDED, + chained_invoke_details=ChainedInvokeDetails(result=json.dumps(result_data)), + ) + mock_result = CheckpointedResult.create_from_operation(operation) + mock_state.get_checkpoint_result.return_value = mock_result + + config = InvokeConfig[dict, dict](serdes_payload=None, serdes_result=None) + + result = invoke_handler( + function_name="test_function", + payload={"input": "data"}, + state=mock_state, + operation_identifier=OperationIdentifier("invoke_result_json", None, None), + config=config, + ) + + # Verify JSON deserialization was used (not extended types) + assert result == result_data