Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/aws_durable_execution_sdk_python/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
BatchResult,
LoggerInterface,
StepContext,
WaitForCallbackContext,
WaitForConditionCheckContext,
)
from aws_durable_execution_sdk_python.types import Callback as CallbackProtocol
Expand Down Expand Up @@ -489,7 +490,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None:

def wait_for_callback(
self,
submitter: Callable[[str], None],
submitter: Callable[[str, WaitForCallbackContext], None],
name: str | None = None,
config: WaitForCallbackConfig | None = None,
) -> Any:
Expand Down
15 changes: 11 additions & 4 deletions src/aws_durable_execution_sdk_python/operation/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
CallbackOptions,
OperationUpdate,
)
from aws_durable_execution_sdk_python.types import WaitForCallbackContext

if TYPE_CHECKING:
from collections.abc import Callable
Expand All @@ -23,7 +24,11 @@
CheckpointedResult,
ExecutionState,
)
from aws_durable_execution_sdk_python.types import Callback, DurableContext
from aws_durable_execution_sdk_python.types import (
Callback,
DurableContext,
StepContext,
)


def create_callback_handler(
Expand Down Expand Up @@ -85,7 +90,7 @@ def create_callback_handler(

def wait_for_callback_handler(
context: DurableContext,
submitter: Callable[[str], None],
submitter: Callable[[str, WaitForCallbackContext], None],
name: str | None = None,
config: WaitForCallbackConfig | None = None,
) -> Any:
Expand All @@ -98,8 +103,10 @@ def wait_for_callback_handler(
name=f"{name_with_space}create callback id", config=config
)

def submitter_step(step_context): # noqa: ARG001
return submitter(callback.callback_id)
def submitter_step(step_context: StepContext):
return submitter(
callback.callback_id, WaitForCallbackContext(logger=step_context.logger)
)

step_config = (
StepConfig(
Expand Down
5 changes: 5 additions & 0 deletions src/aws_durable_execution_sdk_python/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ class StepContext(OperationContext):
pass


@dataclass(frozen=True)
class WaitForCallbackContext(OperationContext):
"""Context provided to waitForCallback submitter functions."""


@dataclass(frozen=True)
class WaitForConditionCheckContext(OperationContext):
pass
Expand Down
31 changes: 25 additions & 6 deletions tests/operation/callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,13 +303,18 @@ def test_wait_for_callback_handler_submitter_called_with_callback_id():
def capture_step_call(func, name, config=None):
# Execute the step callable to verify submitter is called correctly
step_context = Mock(spec=StepContext)
step_context.logger = Mock()
func(step_context)

mock_context.step.side_effect = capture_step_call

wait_for_callback_handler(mock_context, mock_submitter, "test")

mock_submitter.assert_called_once_with("callback_test_id")
# Verify submitter was called with callback_id and WaitForCallbackContext
assert mock_submitter.call_count == 1
call_args = mock_submitter.call_args[0]
assert call_args[0] == "callback_test_id"
assert hasattr(call_args[1], "logger")


def test_create_callback_handler_with_none_operation_in_result():
Expand Down Expand Up @@ -350,14 +355,19 @@ def test_wait_for_callback_handler_with_none_callback_id():

def execute_step(func, name, config=None):
step_context = Mock(spec=StepContext)
step_context.logger = Mock()
return func(step_context)

mock_context.step.side_effect = execute_step

result = wait_for_callback_handler(mock_context, mock_submitter, "test")

assert result == "result_with_none_id"
mock_submitter.assert_called_once_with(None)
# Verify submitter was called with None callback_id and WaitForCallbackContext
assert mock_submitter.call_count == 1
call_args = mock_submitter.call_args[0]
assert call_args[0] is None
assert hasattr(call_args[1], "logger")


def test_wait_for_callback_handler_with_empty_string_callback_id():
Expand All @@ -371,14 +381,19 @@ def test_wait_for_callback_handler_with_empty_string_callback_id():

def execute_step(func, name, config=None):
step_context = Mock(spec=StepContext)
step_context.logger = Mock()
return func(step_context)

mock_context.step.side_effect = execute_step

result = wait_for_callback_handler(mock_context, mock_submitter, "test")

assert result == "result_with_empty_id"
mock_submitter.assert_called_once_with("")
# Verify submitter was called with empty string callback_id and WaitForCallbackContext
assert mock_submitter.call_count == 1
call_args = mock_submitter.call_args[0]
assert call_args[0] == "" # noqa: PLC1901 - explicitly testing empty string, not just falsey
assert hasattr(call_args[1], "logger")


def test_wait_for_callback_handler_with_large_data():
Expand Down Expand Up @@ -585,12 +600,13 @@ def test_wait_for_callback_handler_submitter_exception_handling():
mock_callback.result.return_value = "exception_result"
mock_context.create_callback.return_value = mock_callback

def failing_submitter(callback_id):
def failing_submitter(callback_id, context):
msg = "Submitter failed"
raise ValueError(msg)

def step_side_effect(func, name, config=None):
step_context = Mock(spec=StepContext)
step_context.logger = Mock()
func(step_context)

mock_context.step.side_effect = step_side_effect
Expand Down Expand Up @@ -775,12 +791,14 @@ def test_callback_lifecycle_complete_flow():

assert callback_id == "lifecycle_cb123"

def mock_submitter(cb_id):
def mock_submitter(cb_id, context):
assert cb_id == "lifecycle_cb123"
assert hasattr(context, "logger")
return "submitted"

def execute_step(func, name, config=None):
step_context = Mock(spec=StepContext)
step_context.logger = Mock()
return func(step_context)

mock_context.step.side_effect = execute_step
Expand Down Expand Up @@ -889,7 +907,7 @@ def test_callback_with_complex_submitter():

submission_log = []

def complex_submitter(callback_id):
def complex_submitter(callback_id, context):
submission_log.append(f"received_id: {callback_id}")
if callback_id == "complex_cb789":
submission_log.append("api_call_success")
Expand All @@ -901,6 +919,7 @@ def complex_submitter(callback_id):

def execute_step(func, name, config):
step_context = Mock(spec=StepContext)
step_context.logger = Mock()
return func(step_context)

mock_context.step.side_effect = execute_step
Expand Down
Loading