Skip to content

Commit be6a248

Browse files
author
Rares Polenciuc
committed
feat: implement callback token generation and processing
- Add CallbackToken generation in callback processor with observer integration - Implement SendCallbackSuccess, SendCallbackFailure, and SendCallbackHeartbeat - Add callback operation lookup and completion methods to execution - Ensure unique token generation across executions
1 parent a395662 commit be6a248

File tree

11 files changed

+758
-33
lines changed

11 files changed

+758
-33
lines changed

examples/test/test_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Tests for step example."""
22

33
import pytest
4-
from aws_durable_execution_sdk_python.execution import InvocationStatus
54

5+
from aws_durable_execution_sdk_python.execution import InvocationStatus
66
from src import step
77
from test.conftest import deserialize_operation_payload
88

src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,7 @@ def _create_callback_details(
9898
) -> CallbackDetails | None:
9999
"""Create CallbackDetails from OperationUpdate."""
100100
return (
101-
CallbackDetails(
102-
callback_id="placeholder", result=update.payload, error=update.error
103-
)
101+
CallbackDetails(callback_id="", result=update.payload, error=update.error)
104102
if update.operation_type == OperationType.CALLBACK
105103
else None
106104
)

src/aws_durable_execution_sdk_python_testing/checkpoint/processors/callback.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from aws_durable_execution_sdk_python_testing.exceptions import (
1818
InvalidParameterValueException,
1919
)
20+
from aws_durable_execution_sdk_python_testing.token import CallbackToken
2021

2122

2223
if TYPE_CHECKING:
@@ -36,13 +37,39 @@ def process(
3637
"""Process CALLBACK operation update with scheduler integration for activities."""
3738
match update.action:
3839
case OperationAction.START:
39-
# TODO: create CallbackToken (see token module). Add Observer/Notifier for on_callback_created possibly,
40-
# but token might well have enough so don't need to maintain token list on execution itself
41-
return self._translate_update_to_operation(
40+
callback_token = CallbackToken(
41+
execution_arn=execution_arn,
42+
operation_id=update.operation_id,
43+
)
44+
callback_id = callback_token.to_str()
45+
46+
notifier.notify_callback_created(
47+
execution_arn=execution_arn,
48+
operation_id=update.operation_id,
49+
callback_id=callback_id,
50+
)
51+
52+
operation = self._translate_update_to_operation(
4253
update=update,
4354
current_operation=current_op,
4455
status=OperationStatus.STARTED,
4556
)
57+
58+
# Replace callback_details with actual callback_id
59+
if operation.callback_details:
60+
from dataclasses import replace
61+
from aws_durable_execution_sdk_python.lambda_service import (
62+
CallbackDetails,
63+
)
64+
65+
new_callback_details = replace(
66+
operation.callback_details, callback_id=callback_id
67+
)
68+
operation = replace(
69+
operation, callback_details=new_callback_details
70+
)
71+
72+
return operation
4673
case _:
4774
msg: str = "Invalid action for CALLBACK operation."
4875

src/aws_durable_execution_sdk_python_testing/execution.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
from aws_durable_execution_sdk_python_testing.model import (
3131
StartDurableExecutionInput,
3232
)
33-
from aws_durable_execution_sdk_python_testing.token import CheckpointToken
33+
from aws_durable_execution_sdk_python_testing.token import (
34+
CheckpointToken,
35+
CallbackToken,
36+
)
3437

3538

3639
class CloseStatus(Enum):
@@ -305,6 +308,18 @@ def find_operation(self, operation_id: str) -> tuple[int, Operation]:
305308
msg: str = f"Attempting to update state of an Operation [{operation_id}] that doesn't exist"
306309
raise IllegalStateException(msg)
307310

311+
def find_callback_operation(self, callback_id: str) -> tuple[int, Operation]:
312+
"""Find callback operation by callback_id, return index and operation."""
313+
for i, operation in enumerate(self.operations):
314+
if (
315+
operation.operation_type == OperationType.CALLBACK
316+
and operation.callback_details
317+
and operation.callback_details.callback_id == callback_id
318+
):
319+
return i, operation
320+
msg: str = f"Callback operation with callback_id [{callback_id}] not found"
321+
raise IllegalStateException(msg)
322+
308323
def complete_wait(self, operation_id: str) -> Operation:
309324
"""Complete WAIT operation when timer fires."""
310325
index, operation = self.find_operation(operation_id)
@@ -362,3 +377,56 @@ def complete_retry(self, operation_id: str) -> Operation:
362377
# Assign
363378
self.operations[index] = updated_operation
364379
return updated_operation
380+
381+
def complete_callback_success(
382+
self, callback_id: str, result: bytes | None = None
383+
) -> Operation:
384+
"""Complete CALLBACK operation with success."""
385+
index, operation = self.find_callback_operation(callback_id)
386+
387+
if operation.status != OperationStatus.STARTED:
388+
msg: str = f"Callback operation [{callback_id}] is not in STARTED state"
389+
raise IllegalStateException(msg)
390+
391+
with self._state_lock:
392+
self._token_sequence += 1
393+
updated_callback_details = None
394+
if operation.callback_details:
395+
updated_callback_details = replace(
396+
operation.callback_details,
397+
result=result.decode() if result else None,
398+
)
399+
400+
self.operations[index] = replace(
401+
operation,
402+
status=OperationStatus.SUCCEEDED,
403+
end_timestamp=datetime.now(UTC),
404+
callback_details=updated_callback_details,
405+
)
406+
return self.operations[index]
407+
408+
def complete_callback_failure(
409+
self, callback_id: str, error: ErrorObject
410+
) -> Operation:
411+
"""Complete CALLBACK operation with failure."""
412+
index, operation = self.find_callback_operation(callback_id)
413+
414+
if operation.status != OperationStatus.STARTED:
415+
msg: str = f"Callback operation [{callback_id}] is not in STARTED state"
416+
raise IllegalStateException(msg)
417+
418+
with self._state_lock:
419+
self._token_sequence += 1
420+
updated_callback_details = None
421+
if operation.callback_details:
422+
updated_callback_details = replace(
423+
operation.callback_details, error=error
424+
)
425+
426+
self.operations[index] = replace(
427+
operation,
428+
status=OperationStatus.FAILED,
429+
end_timestamp=datetime.now(UTC),
430+
callback_details=updated_callback_details,
431+
)
432+
return self.operations[index]

src/aws_durable_execution_sdk_python_testing/executor.py

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
Execution as ExecutionSummary,
4747
)
4848
from aws_durable_execution_sdk_python_testing.observer import ExecutionObserver
49+
from aws_durable_execution_sdk_python_testing.token import CallbackToken
4950

5051

5152
if TYPE_CHECKING:
@@ -509,7 +510,7 @@ def checkpoint_execution(
509510
def send_callback_success(
510511
self,
511512
callback_id: str,
512-
result: bytes | None = None, # noqa: ARG002
513+
result: bytes | None = None,
513514
) -> SendDurableExecutionCallbackSuccessResponse:
514515
"""Send callback success response.
515516
@@ -528,16 +529,23 @@ def send_callback_success(
528529
msg: str = "callback_id is required"
529530
raise InvalidParameterValueException(msg)
530531

531-
# TODO: Implement actual callback success logic
532-
# This would involve finding the callback operation and completing it
533-
logger.info("Callback success sent for callback_id: %s", callback_id)
532+
try:
533+
callback_token = CallbackToken.from_str(callback_id)
534+
execution = self.get_execution(callback_token.execution_arn)
535+
execution.complete_callback_success(callback_id, result)
536+
self._store.update(execution)
537+
self._invoke_execution(callback_token.execution_arn)
538+
logger.info("Callback success completed for callback_id: %s", callback_id)
539+
except Exception as e:
540+
msg = f"Failed to process callback success: {e}"
541+
raise ResourceNotFoundException(msg) from e
534542

535543
return SendDurableExecutionCallbackSuccessResponse()
536544

537545
def send_callback_failure(
538546
self,
539547
callback_id: str,
540-
error: ErrorObject | None = None, # noqa: ARG002
548+
error: ErrorObject | None = None,
541549
) -> SendDurableExecutionCallbackFailureResponse:
542550
"""Send callback failure response.
543551
@@ -556,9 +564,18 @@ def send_callback_failure(
556564
msg: str = "callback_id is required"
557565
raise InvalidParameterValueException(msg)
558566

559-
# TODO: Implement actual callback failure logic
560-
# This would involve finding the callback operation and failing it
561-
logger.info("Callback failure sent for callback_id: %s", callback_id)
567+
callback_error = error or ErrorObject.from_message("Callback failed")
568+
569+
try:
570+
callback_token = CallbackToken.from_str(callback_id)
571+
execution = self.get_execution(callback_token.execution_arn)
572+
execution.complete_callback_failure(callback_id, callback_error)
573+
self._store.update(execution)
574+
self._invoke_execution(callback_token.execution_arn)
575+
logger.info("Callback failure completed for callback_id: %s", callback_id)
576+
except Exception as e:
577+
msg = f"Failed to process callback failure: {e}"
578+
raise ResourceNotFoundException(msg) from e
562579

563580
return SendDurableExecutionCallbackFailureResponse()
564581

@@ -581,9 +598,20 @@ def send_callback_heartbeat(
581598
msg: str = "callback_id is required"
582599
raise InvalidParameterValueException(msg)
583600

584-
# TODO: Implement actual callback heartbeat logic
585-
# This would involve updating the callback timeout
586-
logger.info("Callback heartbeat sent for callback_id: %s", callback_id)
601+
try:
602+
callback_token = CallbackToken.from_str(callback_id)
603+
execution = self.get_execution(callback_token.execution_arn)
604+
605+
# Find callback operation to verify it exists and is active
606+
_, operation = execution.find_callback_operation(callback_id)
607+
if operation.status != OperationStatus.STARTED:
608+
msg = f"Callback {callback_id} is not active"
609+
raise ResourceNotFoundException(msg)
610+
611+
logger.info("Callback heartbeat processed for callback_id: %s", callback_id)
612+
except Exception as e:
613+
msg = f"Failed to process callback heartbeat: {e}"
614+
raise ResourceNotFoundException(msg) from e
587615

588616
return SendDurableExecutionCallbackHeartbeatResponse()
589617

@@ -907,4 +935,15 @@ def retry_handler() -> None:
907935
retry_handler, delay=delay, completion_event=completion_event
908936
)
909937

938+
def on_callback_created(
939+
self, execution_arn: str, operation_id: str, callback_id: str
940+
) -> None:
941+
"""Handle callback creation. Observer method triggered by notifier."""
942+
logger.debug(
943+
"[%s] Callback created for operation %s with callback_id: %s",
944+
execution_arn,
945+
operation_id,
946+
callback_id,
947+
)
948+
910949
# endregion ExecutionObserver

src/aws_durable_execution_sdk_python_testing/observer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ def on_step_retry_scheduled(
4444
) -> None:
4545
"""Called when step retry scheduled."""
4646

47+
@abstractmethod
48+
def on_callback_created(
49+
self, execution_arn: str, operation_id: str, callback_id: str
50+
) -> None:
51+
"""Called when callback is created."""
52+
4753

4854
class ExecutionNotifier:
4955
"""Notifies observers about execution events. Thread-safe."""
@@ -111,4 +117,15 @@ def notify_step_retry_scheduled(
111117
delay=delay,
112118
)
113119

120+
def notify_callback_created(
121+
self, execution_arn: str, operation_id: str, callback_id: str
122+
) -> None:
123+
"""Notify observers about callback creation."""
124+
self._notify_observers(
125+
ExecutionObserver.on_callback_created,
126+
execution_arn=execution_arn,
127+
operation_id=operation_id,
128+
callback_id=callback_id,
129+
)
130+
114131
# endregion event emitters

tests/checkpoint/processors/base_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def test_create_callback_details():
249249
result = processor.create_callback_details(update)
250250

251251
assert isinstance(result, CallbackDetails)
252-
assert result.callback_id == "placeholder"
252+
assert result.callback_id == ""
253253
assert result.result == "test-payload"
254254
assert result.error == error
255255

tests/execution_edge_cases_test.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""Tests for edge cases in execution functionality."""
2+
3+
import pytest
4+
from unittest.mock import Mock, patch
5+
from aws_durable_execution_sdk_python_testing.execution import Execution, CloseStatus
6+
from aws_durable_execution_sdk_python_testing.exceptions import (
7+
InvalidParameterValueException,
8+
)
9+
10+
11+
def test_status_unexpected_close_status():
12+
"""Test status property with unexpected close status."""
13+
execution = Execution("test-arn", Mock(), [])
14+
execution.is_complete = True
15+
execution.close_status = "UNKNOWN_STATUS" # Invalid close status
16+
17+
with pytest.raises(InvalidParameterValueException, match="Unexpected close status"):
18+
_ = execution.status
19+
20+
21+
def test_start_no_invocation_id():
22+
"""Test start method when invocation_id is None."""
23+
start_input = Mock()
24+
start_input.invocation_id = None
25+
26+
execution = Execution("test-arn", start_input, [])
27+
28+
with pytest.raises(
29+
InvalidParameterValueException, match="invocation_id is required"
30+
):
31+
execution.start()
32+
33+
34+
def test_from_dict_with_none_result():
35+
"""Test from_dict with None result."""
36+
data = {
37+
"DurableExecutionArn": "test-arn",
38+
"StartInput": {"function_name": "test"},
39+
"Operations": [],
40+
"Updates": [],
41+
"UsedTokens": [],
42+
"TokenSequence": 0,
43+
"IsComplete": False,
44+
"Result": None, # None result
45+
"ConsecutiveFailedInvocationAttempts": 0,
46+
"CloseStatus": None,
47+
}
48+
49+
with patch(
50+
"aws_durable_execution_sdk_python_testing.model.StartDurableExecutionInput.from_dict"
51+
) as mock_from_dict:
52+
mock_from_dict.return_value = Mock()
53+
execution = Execution.from_dict(data)
54+
assert execution.result is None
55+
56+
57+
def test_from_dict_with_none_close_status():
58+
"""Test from_dict with None close status."""
59+
data = {
60+
"DurableExecutionArn": "test-arn",
61+
"StartInput": {"function_name": "test"},
62+
"Operations": [],
63+
"Updates": [],
64+
"UsedTokens": [],
65+
"TokenSequence": 0,
66+
"IsComplete": False,
67+
"Result": None,
68+
"ConsecutiveFailedInvocationAttempts": 0,
69+
"CloseStatus": None, # None close status
70+
}
71+
72+
with patch(
73+
"aws_durable_execution_sdk_python_testing.model.StartDurableExecutionInput.from_dict"
74+
) as mock_from_dict:
75+
mock_from_dict.return_value = Mock()
76+
execution = Execution.from_dict(data)
77+
assert execution.close_status is None

0 commit comments

Comments
 (0)