Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions ops/__tests__/test_parse_sdk_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ def test():

for input_text, expected in test_cases:
result = parse_sdk_branch(input_text)
if result != expected:
return False

return True
# Assert is expected in test functions
assert result == expected, ( # noqa: S101
f"Expected '{expected}' but got '{result}' for input: {input_text[:50]}..."
)


if __name__ == "__main__":
success = test_parse_sdk_branch()
sys.exit(0 if success else 1)
test_parse_sdk_branch()
sys.exit(0)
6 changes: 4 additions & 2 deletions src/aws_durable_execution_sdk_python/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,12 @@ class InvokeConfig(Generic[P, R]):
from blocking execution indefinitely.

serdes_payload: Custom serialization/deserialization for the payload
sent to the invoked function. If None, uses default JSON serialization.
sent to the invoked function. Defaults to DEFAULT_JSON_SERDES when
not set.

serdes_result: Custom serialization/deserialization for the result
returned from the invoked function. If None, uses default JSON serialization.
returned from the invoked function. Defaults to DEFAULT_JSON_SERDES when
not set.

tenant_id: Optional tenant identifier for multi-tenant isolation.
If provided, the invocation will be scoped to this tenant.
Expand Down
10 changes: 7 additions & 3 deletions src/aws_durable_execution_sdk_python/operation/invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
ChainedInvokeOptions,
OperationUpdate,
)
from aws_durable_execution_sdk_python.serdes import deserialize, serialize
from aws_durable_execution_sdk_python.serdes import (
DEFAULT_JSON_SERDES,
deserialize,
serialize,
)
from aws_durable_execution_sdk_python.suspend import suspend_with_optional_resume_delay

if TYPE_CHECKING:
Expand Down Expand Up @@ -53,7 +57,7 @@ def invoke_handler(
and checkpointed_result.operation.chained_invoke_details.result
):
return deserialize(
serdes=config.serdes_result,
serdes=config.serdes_result or DEFAULT_JSON_SERDES,
data=checkpointed_result.operation.chained_invoke_details.result,
operation_id=operation_identifier.operation_id,
durable_execution_arn=state.durable_execution_arn,
Expand All @@ -78,7 +82,7 @@ def invoke_handler(
suspend_with_optional_resume_delay(msg, config.timeout_seconds)

serialized_payload: str = serialize(
serdes=config.serdes_payload,
serdes=config.serdes_payload or DEFAULT_JSON_SERDES,
value=payload,
operation_id=operation_identifier.operation_id,
durable_execution_arn=state.durable_execution_arn,
Expand Down
8 changes: 4 additions & 4 deletions src/aws_durable_execution_sdk_python/serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,8 @@ def _to_json_serializable(self, obj: Any) -> Any:
return obj


_DEFAULT_JSON_SERDES: SerDes[Any] = JsonSerDes()
_EXTENDED_TYPES_SERDES: SerDes[Any] = ExtendedTypeSerDes()
DEFAULT_JSON_SERDES: SerDes[Any] = JsonSerDes()
EXTENDED_TYPES_SERDES: SerDes[Any] = ExtendedTypeSerDes()


def serialize(
Expand All @@ -463,7 +463,7 @@ def serialize(
FatalError: If serialization fails
"""
serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
active_serdes: SerDes[T] = serdes or _EXTENDED_TYPES_SERDES
active_serdes: SerDes[T] = serdes or EXTENDED_TYPES_SERDES
try:
return active_serdes.serialize(value, serdes_context)
except Exception as e:
Expand Down Expand Up @@ -493,7 +493,7 @@ def deserialize(
FatalError: If deserialization fails
"""
serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
active_serdes: SerDes[T] = serdes or _EXTENDED_TYPES_SERDES
active_serdes: SerDes[T] = serdes or EXTENDED_TYPES_SERDES
try:
return active_serdes.deserialize(data, serdes_context)
except Exception as e:
Expand Down
54 changes: 54 additions & 0 deletions tests/operation/invoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,57 @@ def test_invoke_handler_default_config_no_tenant_id():
chained_invoke_options = operation_update.to_dict()["ChainedInvokeOptions"]
assert chained_invoke_options["FunctionName"] == "test_function"
assert "TenantId" not in chained_invoke_options


def test_invoke_handler_defaults_to_json_serdes():
"""Test invoke_handler uses DEFAULT_JSON_SERDES when config has no serdes."""
mock_state = Mock(spec=ExecutionState)
mock_state.durable_execution_arn = "test_arn"
mock_state.get_checkpoint_result.return_value = (
CheckpointedResult.create_not_found()
)

config = InvokeConfig[dict, dict](serdes_payload=None, serdes_result=None)
payload = {"key": "value", "number": 42}

with pytest.raises(SuspendExecution):
invoke_handler(
function_name="test_function",
payload=payload,
state=mock_state,
operation_identifier=OperationIdentifier("invoke_json", None, None),
config=config,
)

# Verify JSON serialization was used (not extended types)
operation_update = mock_state.create_checkpoint.call_args[1]["operation_update"]
assert operation_update.payload == json.dumps(payload)


def test_invoke_handler_result_defaults_to_json_serdes():
"""Test invoke_handler uses DEFAULT_JSON_SERDES for result deserialization."""
mock_state = Mock(spec=ExecutionState)
mock_state.durable_execution_arn = "test_arn"

result_data = {"key": "value", "number": 42}
operation = Operation(
operation_id="invoke_result_json",
operation_type=OperationType.CHAINED_INVOKE,
status=OperationStatus.SUCCEEDED,
chained_invoke_details=ChainedInvokeDetails(result=json.dumps(result_data)),
)
mock_result = CheckpointedResult.create_from_operation(operation)
mock_state.get_checkpoint_result.return_value = mock_result

config = InvokeConfig[dict, dict](serdes_payload=None, serdes_result=None)

result = invoke_handler(
function_name="test_function",
payload={"input": "data"},
state=mock_state,
operation_identifier=OperationIdentifier("invoke_result_json", None, None),
config=config,
)

# Verify JSON deserialization was used (not extended types)
assert result == result_data