Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/aws_durable_execution_sdk_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,24 @@
# Core decorator - used in every durable function
from aws_durable_execution_sdk_python.execution import durable_execution

# Serialization - for custom input/output serialization
from aws_durable_execution_sdk_python.serdes import (
ExtendedTypeSerDes,
JsonSerDes,
SerDes,
)

# Essential context types - passed to user functions
from aws_durable_execution_sdk_python.types import BatchResult, StepContext

__all__ = [
"BatchResult",
"DurableContext",
"DurableExecutionsError",
"ExtendedTypeSerDes",
"InvocationError",
"JsonSerDes",
"SerDes",
"StepContext",
"ValidationError",
"durable_execution",
Expand Down
43 changes: 31 additions & 12 deletions src/aws_durable_execution_sdk_python/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
OperationType,
OperationUpdate,
)
from aws_durable_execution_sdk_python.serdes import (
JsonSerDes,
SerDes,
deserialize,
serialize,
)

if TYPE_CHECKING:
from collections.abc import Callable, MutableMapping
Expand Down Expand Up @@ -206,13 +212,22 @@ def durable_execution(
func: Callable[[Any, DurableContext], Any] | None = None,
*,
boto3_client: boto3.client | None = None,
input_serdes: SerDes | None = None,
output_serdes: SerDes | None = None,
) -> Callable[[Any, LambdaContext], Any]:
# Decorator called with parameters
if func is None:
logger.debug("Decorator called with parameters")
return functools.partial(durable_execution, boto3_client=boto3_client)
return functools.partial(
durable_execution,
boto3_client=boto3_client,
input_serdes=input_serdes,
output_serdes=output_serdes,
)

logger.debug("Starting durable execution handler...")
input_serdes = input_serdes or JsonSerDes()
output_serdes = output_serdes or JsonSerDes()

def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
invocation_input: DurableExecutionInvocationInput
Expand Down Expand Up @@ -250,18 +265,14 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
invocation_input.initial_execution_state.get_input_payload()
)

# Python RIC LambdaMarshaller just uses standard json deserialization for event
# https://github.com/aws/aws-lambda-python-runtime-interface-client/blob/main/awslambdaric/lambda_runtime_marshaller.py#L46
input_event: MutableMapping[str, Any] = {}
if raw_input_payload and raw_input_payload.strip():
try:
input_event = json.loads(raw_input_payload)
except json.JSONDecodeError:
logger.exception(
"Failed to parse input payload as JSON: payload: %r",
raw_input_payload,
)
raise
input_event = deserialize(
serdes=input_serdes,
data=raw_input_payload,
operation_id="EXECUTION",
durable_execution_arn=invocation_input.durable_execution_arn,
)

execution_state: ExecutionState = ExecutionState(
durable_execution_arn=invocation_input.durable_execution_arn,
Expand Down Expand Up @@ -310,7 +321,15 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
"%s exiting user-space...",
invocation_input.durable_execution_arn,
)
serialized_result = json.dumps(result)

# Serialize result using output_serdes if provided
serialized_result = serialize(
serdes=output_serdes,
value=result,
operation_id="EXECUTION",
durable_execution_arn=invocation_input.durable_execution_arn,
)

# large response handling here. Remember if checkpointing to complete, NOT to include
# payload in response
if (
Expand Down
114 changes: 111 additions & 3 deletions tests/execution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
import datetime
import json
import time
from decimal import Decimal
from typing import Any
from unittest.mock import Mock, patch

import pytest

from aws_durable_execution_sdk_python import (
DurableContext,
ExtendedTypeSerDes,
JsonSerDes,
durable_execution,
)
from aws_durable_execution_sdk_python.config import StepConfig, StepSemantics
from aws_durable_execution_sdk_python.context import DurableContext
from aws_durable_execution_sdk_python.exceptions import (
BotoClientError,
CheckpointError,
Expand All @@ -24,7 +30,6 @@
DurableExecutionInvocationOutput,
InitialExecutionState,
InvocationStatus,
durable_execution,
)

# LambdaContext no longer needed - using duck typing
Expand Down Expand Up @@ -946,7 +951,7 @@ def test_handler(event: Any, context: DurableContext) -> dict:
lambda_context.invoked_function_arn = None
lambda_context.tenant_id = None

with pytest.raises(json.JSONDecodeError):
with pytest.raises(ExecutionError):
test_handler(invocation_input, lambda_context)


Expand Down Expand Up @@ -2071,3 +2076,106 @@ def test_handler(event: Any, context: DurableContext) -> dict:
match="The payload is not the correct Durable Function input",
):
test_handler(non_dict_event, lambda_context)


# region SERDES


@pytest.mark.parametrize(
("input_serdes", "output_serdes", "input_data", "expected_output"),
[
# Both ExtendedTypeSerDes
(
ExtendedTypeSerDes(),
ExtendedTypeSerDes(),
{
"amount": Decimal("123.45"),
"timestamp": datetime.datetime(
2025, 11, 20, 18, 0, 0, tzinfo=datetime.UTC
),
},
{
"result_amount": Decimal("678.90"),
"processed_at": datetime.datetime(
2025, 11, 20, 19, 0, 0, tzinfo=datetime.UTC
),
},
),
# Both JsonSerDes
(
JsonSerDes(),
JsonSerDes(),
{"name": "test", "value": 42},
{"result": "success", "count": 100},
),
# Input ExtendedTypeSerDes, Output JsonSerDes
(
ExtendedTypeSerDes(),
JsonSerDes(),
{"amount": Decimal("123.45")},
{"result": "success"},
),
# Input JsonSerDes, Output ExtendedTypeSerDes
(
JsonSerDes(),
ExtendedTypeSerDes(),
{"name": "test"},
{
"timestamp": datetime.datetime(
2025, 11, 20, 19, 0, 0, tzinfo=datetime.UTC
)
},
),
],
)
def test_durable_execution_with_serdes(
input_serdes, output_serdes, input_data, expected_output
):
"""Test that input_serdes and output_serdes are invoked correctly."""
serialized_input = input_serdes.serialize(input_data, None)

mock_client = Mock(spec=DurableServiceClient)
mock_client.checkpoint.return_value = CheckpointOutput(
checkpoint_token="new-token", # noqa: S106
new_execution_state=CheckpointUpdatedExecutionState(
operations=[], next_marker=None
),
)

execution_op = Operation(
operation_id="EXECUTION",
operation_type=OperationType.EXECUTION,
status=OperationStatus.STARTED,
execution_details=ExecutionDetails(input_payload=serialized_input),
)

invocation_input = DurableExecutionInvocationInputWithClient(
durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test",
checkpoint_token="initial-token", # noqa: S106
initial_execution_state=InitialExecutionState(
operations=[execution_op],
next_marker="",
),
is_local_runner=False,
service_client=mock_client,
)

with (
patch.object(
input_serdes, "deserialize", wraps=input_serdes.deserialize
) as mock_input_deser,
patch.object(
output_serdes, "serialize", wraps=output_serdes.serialize
) as mock_output_ser,
):

@durable_execution(input_serdes=input_serdes, output_serdes=output_serdes)
def handler(event, context: DurableContext):
return expected_output

handler(invocation_input, Mock())

mock_input_deser.assert_called_once()
mock_output_ser.assert_called_once_with(
expected_output, mock_output_ser.call_args[0][1]
)