11from __future__ import annotations
22
3+ import json
34from dataclasses import dataclass , field
4- from typing import TYPE_CHECKING , Protocol , TypeVar , cast
5+ from typing import TYPE_CHECKING , Any , Concatenate , ParamSpec , Protocol , TypeVar , cast
56
7+ from aws_durable_execution_sdk_python .execution import (
8+ InvocationStatus ,
9+ durable_handler ,
10+ )
611from aws_durable_execution_sdk_python .lambda_service import (
712 ErrorObject ,
813 OperationStatus ,
3136 import datetime
3237 from collections .abc import Callable , MutableMapping
3338
39+ from aws_durable_execution_sdk_python .context import DurableContext
3440 from aws_durable_execution_sdk_python .execution import InvocationStatus
3541
3642 from aws_durable_execution_sdk_python_testing .execution import Execution
@@ -49,6 +55,7 @@ class Operation:
4955
5056
5157T = TypeVar ("T" , bound = Operation )
58+ P = ParamSpec ("P" )
5259
5360
5461class OperationFactory (Protocol ):
@@ -90,7 +97,7 @@ def from_svc_operation(
9097@dataclass (frozen = True )
9198class ContextOperation (Operation ):
9299 child_operations : list [Operation ]
93- result : str | None = None
100+ result : Any = None
94101 error : ErrorObject | None = None
95102
96103 @staticmethod
@@ -119,9 +126,11 @@ def from_svc_operation(
119126 start_timestamp = operation .start_timestamp ,
120127 end_timestamp = operation .end_timestamp ,
121128 child_operations = child_operations ,
122- result = operation .context_details .result
123- if operation .context_details
124- else None ,
129+ result = (
130+ json .loads (operation .context_details .result )
131+ if operation .context_details and operation .context_details .result
132+ else None
133+ ),
125134 error = operation .context_details .error
126135 if operation .context_details
127136 else None ,
@@ -157,8 +166,7 @@ def get_execution(self, name: str) -> ExecutionOperation:
157166class StepOperation (ContextOperation ):
158167 attempt : int = 0
159168 next_attempt_timestamp : str | None = None
160- # TODO: deserialize?
161- result : str | None = None
169+ result : Any = None
162170 error : ErrorObject | None = None
163171
164172 @staticmethod
@@ -193,7 +201,11 @@ def from_svc_operation(
193201 if operation .step_details
194202 else None
195203 ),
196- result = operation .step_details .result if operation .step_details else None ,
204+ result = (
205+ json .loads (operation .step_details .result )
206+ if operation .step_details and operation .step_details .result
207+ else None
208+ ),
197209 error = operation .step_details .error if operation .step_details else None ,
198210 )
199211
@@ -230,7 +242,7 @@ def from_svc_operation(
230242@dataclass (frozen = True )
231243class CallbackOperation (ContextOperation ):
232244 callback_id : str | None = None
233- result : str | None = None
245+ result : Any = None
234246 error : ErrorObject | None = None
235247
236248 @staticmethod
@@ -264,9 +276,11 @@ def from_svc_operation(
264276 if operation .callback_details
265277 else None
266278 ),
267- result = operation .callback_details .result
268- if operation .callback_details
269- else None ,
279+ result = (
280+ json .loads (operation .callback_details .result )
281+ if operation .callback_details and operation .callback_details .result
282+ else None
283+ ),
270284 error = operation .callback_details .error
271285 if operation .callback_details
272286 else None ,
@@ -276,7 +290,7 @@ def from_svc_operation(
276290@dataclass (frozen = True )
277291class InvokeOperation (Operation ):
278292 durable_execution_arn : str | None = None
279- result : str | None = None
293+ result : Any = None
280294 error : ErrorObject | None = None
281295
282296 @staticmethod
@@ -301,9 +315,11 @@ def from_svc_operation(
301315 if operation .invoke_details
302316 else None
303317 ),
304- result = operation .invoke_details .result
305- if operation .invoke_details
306- else None ,
318+ result = (
319+ json .loads (operation .invoke_details .result )
320+ if operation .invoke_details and operation .invoke_details .result
321+ else None
322+ ),
307323 error = operation .invoke_details .error if operation .invoke_details else None ,
308324 )
309325
@@ -334,7 +350,7 @@ def create_operation(
334350class DurableFunctionTestResult :
335351 status : InvocationStatus
336352 operations : list [Operation ]
337- result : str | None = None
353+ result : Any = None
338354 error : ErrorObject | None = None
339355
340356 @classmethod
@@ -352,10 +368,14 @@ def create(cls, execution: Execution) -> DurableFunctionTestResult:
352368 msg : str = "Execution result must exist to create test result."
353369 raise DurableFunctionsTestError (msg )
354370
371+ deserialized_result = (
372+ json .loads (execution .result .result ) if execution .result .result else None
373+ )
374+
355375 return cls (
356376 status = execution .result .status ,
357377 operations = operations ,
358- result = execution . result . result ,
378+ result = deserialized_result ,
359379 error = execution .result .error ,
360380 )
361381
@@ -413,7 +433,7 @@ def close(self):
413433
414434 def run (
415435 self ,
416- input : str , # noqa: A002
436+ input : str | None = None , # noqa: A002
417437 timeout : int = 900 ,
418438 function_name : str = "test-function" ,
419439 execution_name : str = "execution-name" ,
@@ -451,4 +471,19 @@ def run(
451471 execution : Execution = self ._store .load (output .execution_arn )
452472 return DurableFunctionTestResult .create (execution = execution )
453473
454- # return execution
474+
475+ class DurableChildContextTestRunner (DurableFunctionTestRunner ):
476+ """Test a durable block, annotated with @durable_with_child_context, in isolation."""
477+
478+ def __init__ (
479+ self ,
480+ context_function : Callable [Concatenate [DurableContext , P ], Any ],
481+ * args ,
482+ ** kwargs ,
483+ ):
484+ # wrap the durable context around a durable handler as a convenience to run directly
485+ @durable_handler
486+ def handler (event : Any , context : DurableContext ): # noqa: ARG001
487+ return context_function (* args , ** kwargs )(context )
488+
489+ super ().__init__ (handler )
0 commit comments