Skip to content

Commit e59accd

Browse files
committed
Allow passing input/output serdes
Changes: - We allow input and output serdes on the decorator. This is particularly useful for chained invoke to ensure callees automatically have the correct serialization
1 parent 00b195d commit e59accd

File tree

4 files changed

+159
-20
lines changed

4 files changed

+159
-20
lines changed

ops/__tests__/test_parse_sdk_branch.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,7 @@ def test():
7373

7474
for input_text, expected in test_cases:
7575
result = parse_sdk_branch(input_text)
76-
if result != expected:
77-
return False
78-
79-
return True
76+
assert result == expected # noqa: S101
8077

8178

8279
if __name__ == "__main__":

src/aws_durable_execution_sdk_python/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,24 @@
1818
# Core decorator - used in every durable function
1919
from aws_durable_execution_sdk_python.execution import durable_execution
2020

21+
# Serialization - for custom input/output serialization
22+
from aws_durable_execution_sdk_python.serdes import (
23+
ExtendedTypeSerDes,
24+
JsonSerDes,
25+
SerDes,
26+
)
27+
2128
# Essential context types - passed to user functions
2229
from aws_durable_execution_sdk_python.types import BatchResult, StepContext
2330

2431
__all__ = [
2532
"BatchResult",
2633
"DurableContext",
2734
"DurableExecutionsError",
35+
"ExtendedTypeSerDes",
2836
"InvocationError",
37+
"JsonSerDes",
38+
"SerDes",
2939
"StepContext",
3040
"ValidationError",
3141
"durable_execution",

src/aws_durable_execution_sdk_python/execution.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
OperationType,
2828
OperationUpdate,
2929
)
30+
from aws_durable_execution_sdk_python.serdes import (
31+
JsonSerDes,
32+
SerDes,
33+
deserialize,
34+
serialize,
35+
)
3036

3137
if TYPE_CHECKING:
3238
from collections.abc import Callable, MutableMapping
@@ -206,15 +212,29 @@ def durable_execution(
206212
func: Callable[[Any, DurableContext], Any] | None = None,
207213
*,
208214
boto3_client: boto3.client | None = None,
215+
input_serdes: SerDes | None = None,
216+
output_serdes: SerDes | None = None,
209217
) -> Callable[[Any, LambdaContext], Any]:
210218
# Decorator called with parameters
211219
if func is None:
212220
logger.debug("Decorator called with parameters")
213-
return functools.partial(durable_execution, boto3_client=boto3_client)
221+
return functools.partial(
222+
durable_execution,
223+
boto3_client=boto3_client,
224+
input_serdes=input_serdes,
225+
output_serdes=output_serdes,
226+
)
214227

215228
logger.debug("Starting durable execution handler...")
216229

217230
def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
231+
# Set default SerDes if not provided
232+
nonlocal input_serdes, output_serdes
233+
if input_serdes is None:
234+
input_serdes = JsonSerDes()
235+
if output_serdes is None:
236+
output_serdes = JsonSerDes()
237+
218238
invocation_input: DurableExecutionInvocationInput
219239
service_client: DurableServiceClient
220240

@@ -250,18 +270,14 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
250270
invocation_input.initial_execution_state.get_input_payload()
251271
)
252272

253-
# Python RIC LambdaMarshaller just uses standard json deserialization for event
254-
# https://github.com/aws/aws-lambda-python-runtime-interface-client/blob/main/awslambdaric/lambda_runtime_marshaller.py#L46
255-
input_event: MutableMapping[str, Any] = {}
273+
input_event: MutableMapping[str, Any] = {} # type ignore
256274
if raw_input_payload and raw_input_payload.strip():
257-
try:
258-
input_event = json.loads(raw_input_payload)
259-
except json.JSONDecodeError:
260-
logger.exception(
261-
"Failed to parse input payload as JSON: payload: %r",
262-
raw_input_payload,
263-
)
264-
raise
275+
input_event = deserialize(
276+
serdes=input_serdes,
277+
data=raw_input_payload,
278+
operation_id="EXECUTION",
279+
durable_execution_arn=invocation_input.durable_execution_arn,
280+
)
265281

266282
execution_state: ExecutionState = ExecutionState(
267283
durable_execution_arn=invocation_input.durable_execution_arn,
@@ -310,7 +326,15 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
310326
"%s exiting user-space...",
311327
invocation_input.durable_execution_arn,
312328
)
313-
serialized_result = json.dumps(result)
329+
330+
# Serialize result using output_serdes if provided
331+
serialized_result = serialize(
332+
serdes=output_serdes,
333+
value=result,
334+
operation_id="EXECUTION",
335+
durable_execution_arn=invocation_input.durable_execution_arn,
336+
)
337+
314338
# large response handling here. Remember if checkpointing to complete, NOT to include
315339
# payload in response
316340
if (

tests/execution_test.py

Lines changed: 111 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@
33
import datetime
44
import json
55
import time
6+
from decimal import Decimal
67
from typing import Any
78
from unittest.mock import Mock, patch
89

910
import pytest
1011

12+
from aws_durable_execution_sdk_python import (
13+
DurableContext,
14+
ExtendedTypeSerDes,
15+
JsonSerDes,
16+
durable_execution,
17+
)
1118
from aws_durable_execution_sdk_python.config import StepConfig, StepSemantics
12-
from aws_durable_execution_sdk_python.context import DurableContext
1319
from aws_durable_execution_sdk_python.exceptions import (
1420
BotoClientError,
1521
CheckpointError,
@@ -24,7 +30,6 @@
2430
DurableExecutionInvocationOutput,
2531
InitialExecutionState,
2632
InvocationStatus,
27-
durable_execution,
2833
)
2934

3035
# LambdaContext no longer needed - using duck typing
@@ -946,7 +951,7 @@ def test_handler(event: Any, context: DurableContext) -> dict:
946951
lambda_context.invoked_function_arn = None
947952
lambda_context.tenant_id = None
948953

949-
with pytest.raises(json.JSONDecodeError):
954+
with pytest.raises(ExecutionError):
950955
test_handler(invocation_input, lambda_context)
951956

952957

@@ -2071,3 +2076,106 @@ def test_handler(event: Any, context: DurableContext) -> dict:
20712076
match="The payload is not the correct Durable Function input",
20722077
):
20732078
test_handler(non_dict_event, lambda_context)
2079+
2080+
2081+
# region SERDES
2082+
2083+
2084+
@pytest.mark.parametrize(
2085+
("input_serdes", "output_serdes", "input_data", "expected_output"),
2086+
[
2087+
# Both ExtendedTypeSerDes
2088+
(
2089+
ExtendedTypeSerDes(),
2090+
ExtendedTypeSerDes(),
2091+
{
2092+
"amount": Decimal("123.45"),
2093+
"timestamp": datetime.datetime(
2094+
2025, 11, 20, 18, 0, 0, tzinfo=datetime.UTC
2095+
),
2096+
},
2097+
{
2098+
"result_amount": Decimal("678.90"),
2099+
"processed_at": datetime.datetime(
2100+
2025, 11, 20, 19, 0, 0, tzinfo=datetime.UTC
2101+
),
2102+
},
2103+
),
2104+
# Both JsonSerDes
2105+
(
2106+
JsonSerDes(),
2107+
JsonSerDes(),
2108+
{"name": "test", "value": 42},
2109+
{"result": "success", "count": 100},
2110+
),
2111+
# Input ExtendedTypeSerDes, Output JsonSerDes
2112+
(
2113+
ExtendedTypeSerDes(),
2114+
JsonSerDes(),
2115+
{"amount": Decimal("123.45")},
2116+
{"result": "success"},
2117+
),
2118+
# Input JsonSerDes, Output ExtendedTypeSerDes
2119+
(
2120+
JsonSerDes(),
2121+
ExtendedTypeSerDes(),
2122+
{"name": "test"},
2123+
{
2124+
"timestamp": datetime.datetime(
2125+
2025, 11, 20, 19, 0, 0, tzinfo=datetime.UTC
2126+
)
2127+
},
2128+
),
2129+
],
2130+
)
2131+
def test_durable_execution_with_serdes(
2132+
input_serdes, output_serdes, input_data, expected_output
2133+
):
2134+
"""Test that input_serdes and output_serdes are invoked correctly."""
2135+
serialized_input = input_serdes.serialize(input_data, None)
2136+
2137+
mock_client = Mock(spec=DurableServiceClient)
2138+
mock_client.checkpoint.return_value = CheckpointOutput(
2139+
checkpoint_token="new-token", # noqa: S106
2140+
new_execution_state=CheckpointUpdatedExecutionState(
2141+
operations=[], next_marker=None
2142+
),
2143+
)
2144+
2145+
execution_op = Operation(
2146+
operation_id="EXECUTION",
2147+
operation_type=OperationType.EXECUTION,
2148+
status=OperationStatus.STARTED,
2149+
execution_details=ExecutionDetails(input_payload=serialized_input),
2150+
)
2151+
2152+
invocation_input = DurableExecutionInvocationInputWithClient(
2153+
durable_execution_arn="arn:aws:lambda:us-east-1:123456789012:function:test",
2154+
checkpoint_token="initial-token", # noqa: S106
2155+
initial_execution_state=InitialExecutionState(
2156+
operations=[execution_op],
2157+
next_marker="",
2158+
),
2159+
is_local_runner=False,
2160+
service_client=mock_client,
2161+
)
2162+
2163+
with (
2164+
patch.object(
2165+
input_serdes, "deserialize", wraps=input_serdes.deserialize
2166+
) as mock_input_deser,
2167+
patch.object(
2168+
output_serdes, "serialize", wraps=output_serdes.serialize
2169+
) as mock_output_ser,
2170+
):
2171+
2172+
@durable_execution(input_serdes=input_serdes, output_serdes=output_serdes)
2173+
def handler(event, context: DurableContext):
2174+
return expected_output
2175+
2176+
handler(invocation_input, Mock())
2177+
2178+
mock_input_deser.assert_called_once()
2179+
mock_output_ser.assert_called_once_with(
2180+
expected_output, mock_output_ser.call_args[0][1]
2181+
)

0 commit comments

Comments
 (0)