diff --git a/src/aws_durable_execution_sdk_python/__init__.py b/src/aws_durable_execution_sdk_python/__init__.py index 8c3e46c..6deb42c 100644 --- a/src/aws_durable_execution_sdk_python/__init__.py +++ b/src/aws_durable_execution_sdk_python/__init__.py @@ -18,6 +18,13 @@ # Core decorator - used in every durable function from aws_durable_execution_sdk_python.execution import durable_execution +# Serialization - for custom input/output serialization +from aws_durable_execution_sdk_python.serdes import ( + ExtendedTypeSerDes, + JsonSerDes, + SerDes, +) + # Essential context types - passed to user functions from aws_durable_execution_sdk_python.types import BatchResult, StepContext @@ -25,7 +32,10 @@ "BatchResult", "DurableContext", "DurableExecutionsError", + "ExtendedTypeSerDes", "InvocationError", + "JsonSerDes", + "SerDes", "StepContext", "ValidationError", "durable_execution", diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index a31e227..b1464e7 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -27,6 +27,12 @@ OperationType, OperationUpdate, ) +from aws_durable_execution_sdk_python.serdes import ( + JsonSerDes, + SerDes, + deserialize, + serialize, +) if TYPE_CHECKING: from collections.abc import Callable, MutableMapping @@ -206,13 +212,22 @@ def durable_execution( func: Callable[[Any, DurableContext], Any] | None = None, *, boto3_client: boto3.client | None = None, + input_serdes: SerDes | None = None, + output_serdes: SerDes | None = None, ) -> Callable[[Any, LambdaContext], Any]: # Decorator called with parameters if func is None: logger.debug("Decorator called with parameters") - return functools.partial(durable_execution, boto3_client=boto3_client) + return functools.partial( + durable_execution, + boto3_client=boto3_client, + input_serdes=input_serdes, + output_serdes=output_serdes, + ) logger.debug("Starting durable execution handler...") + input_serdes = input_serdes or JsonSerDes() + output_serdes = output_serdes or JsonSerDes() def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: invocation_input: DurableExecutionInvocationInput @@ -250,18 +265,14 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: invocation_input.initial_execution_state.get_input_payload() ) - # Python RIC LambdaMarshaller just uses standard json deserialization for event - # https://github.com/aws/aws-lambda-python-runtime-interface-client/blob/main/awslambdaric/lambda_runtime_marshaller.py#L46 input_event: MutableMapping[str, Any] = {} if raw_input_payload and raw_input_payload.strip(): - try: - input_event = json.loads(raw_input_payload) - except json.JSONDecodeError: - logger.exception( - "Failed to parse input payload as JSON: payload: %r", - raw_input_payload, - ) - raise + input_event = deserialize( + serdes=input_serdes, + data=raw_input_payload, + operation_id="EXECUTION", + durable_execution_arn=invocation_input.durable_execution_arn, + ) execution_state: ExecutionState = ExecutionState( durable_execution_arn=invocation_input.durable_execution_arn, @@ -310,7 +321,15 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: "%s exiting user-space...", invocation_input.durable_execution_arn, ) - serialized_result = json.dumps(result) + + # Serialize result using output_serdes if provided + serialized_result = serialize( + serdes=output_serdes, + value=result, + operation_id="EXECUTION", + durable_execution_arn=invocation_input.durable_execution_arn, + ) + # large response handling here. Remember if checkpointing to complete, NOT to include # payload in response if ( diff --git a/tests/execution_test.py b/tests/execution_test.py index 4d11298..0a0f215 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -3,13 +3,19 @@ import datetime import json import time +from decimal import Decimal from typing import Any from unittest.mock import Mock, patch import pytest +from aws_durable_execution_sdk_python import ( + DurableContext, + ExtendedTypeSerDes, + JsonSerDes, + durable_execution, +) from aws_durable_execution_sdk_python.config import StepConfig, StepSemantics -from aws_durable_execution_sdk_python.context import DurableContext from aws_durable_execution_sdk_python.exceptions import ( BotoClientError, CheckpointError, @@ -24,7 +30,6 @@ DurableExecutionInvocationOutput, InitialExecutionState, InvocationStatus, - durable_execution, ) # LambdaContext no longer needed - using duck typing @@ -946,7 +951,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: lambda_context.invoked_function_arn = None lambda_context.tenant_id = None - with pytest.raises(json.JSONDecodeError): + with pytest.raises(ExecutionError): test_handler(invocation_input, lambda_context) @@ -2071,3 +2076,106 @@ def test_handler(event: Any, context: DurableContext) -> dict: match="The payload is not the correct Durable Function input", ): test_handler(non_dict_event, lambda_context) + + +# region SERDES + + +@pytest.mark.parametrize( + ("input_serdes", "output_serdes", "input_data", "expected_output"), + [ + # Both ExtendedTypeSerDes + ( + ExtendedTypeSerDes(), + ExtendedTypeSerDes(), + { + "amount": Decimal("123.45"), + "timestamp": datetime.datetime( + 2025, 11, 20, 18, 0, 0, tzinfo=datetime.UTC + ), + }, + { + "result_amount": Decimal("678.90"), + "processed_at": datetime.datetime( + 2025, 11, 20, 19, 0, 0, tzinfo=datetime.UTC + ), + }, + ), + # Both JsonSerDes + ( + JsonSerDes(), + JsonSerDes(), + {"name": "test", "value": 42}, + {"result": "success", "count": 100}, + ), + # Input ExtendedTypeSerDes, Output JsonSerDes + ( + ExtendedTypeSerDes(), + JsonSerDes(), + {"amount": Decimal("123.45")}, + {"result": "success"}, + ), + # Input JsonSerDes, Output ExtendedTypeSerDes + ( + JsonSerDes(), + ExtendedTypeSerDes(), + {"name": "test"}, + { + "timestamp": datetime.datetime( + 2025, 11, 20, 19, 0, 0, tzinfo=datetime.UTC + ) + }, + ), + ], +) +def test_durable_execution_with_serdes( + input_serdes, output_serdes, input_data, expected_output +): + """Test that input_serdes and output_serdes are invoked correctly.""" + serialized_input = input_serdes.serialize(input_data, None) + + mock_client = Mock(spec=DurableServiceClient) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new-token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[], next_marker=None + ), + ) + + execution_op = Operation( + operation_id="EXECUTION", + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + execution_details=ExecutionDetails(input_payload=serialized_input), + ) + + invocation_input = DurableExecutionInvocationInputWithClient( + durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test", + checkpoint_token="initial-token", # noqa: S106 + initial_execution_state=InitialExecutionState( + operations=[execution_op], + next_marker="", + ), + is_local_runner=False, + service_client=mock_client, + ) + + with ( + patch.object( + input_serdes, "deserialize", wraps=input_serdes.deserialize + ) as mock_input_deser, + patch.object( + output_serdes, "serialize", wraps=output_serdes.serialize + ) as mock_output_ser, + ): + + @durable_execution(input_serdes=input_serdes, output_serdes=output_serdes) + def handler(event, context: DurableContext): + return expected_output + + handler(invocation_input, Mock()) + + mock_input_deser.assert_called_once() + mock_output_ser.assert_called_once_with( + expected_output, mock_output_ser.call_args[0][1] + )