diff --git a/scripts/gen_payload_visitor.py b/scripts/gen_payload_visitor.py index 5b6f02396..928be03e5 100644 --- a/scripts/gen_payload_visitor.py +++ b/scripts/gen_payload_visitor.py @@ -20,87 +20,26 @@ def name_for(desc: Descriptor) -> str: return desc.full_name.replace(".", "_") -# --------------------------------------------------------------------------- -# Emitters for the "multi-unit" case: accumulate coroutines into `coros` list -# and let the caller do a single asyncio.gather(*coros) at the end. -# --------------------------------------------------------------------------- - - def emit_loop( field_name: str, iter_expr: str, child_method: str, ) -> str: - # Emit a coros.extend() over a collection with optional skip guard + # Emit a for-loop with direct await, with optional skip guard + inner = ( + f"for v in {iter_expr}:\n" + f" await self._visit_{child_method}(fs, v)" + ) if field_name == "headers": - return ( - " if not self.skip_headers:\n" - f" coros.extend(self._visit_{child_method}(fs, v) for v in {iter_expr})" - ) + return f" if not self.skip_headers:\n {inner}" elif field_name == "search_attributes": - return ( - " if not self.skip_search_attributes:\n" - f" coros.extend(self._visit_{child_method}(fs, v) for v in {iter_expr})" - ) + return f" if not self.skip_search_attributes:\n {inner}" else: - return f" coros.extend(self._visit_{child_method}(fs, v) for v in {iter_expr})" + return f" {inner}" def emit_singular( field_name: str, access_expr: str, child_method: str, presence_word: str | None -) -> str: - # Emit a coros.append() with optional HasField check and skip guard - if presence_word: - if field_name == "headers": - return ( - " if not self.skip_headers:\n" - f' {presence_word} o.HasField("{field_name}"):\n' - f" coros.append(self._visit_{child_method}(fs, {access_expr}))" - ) - else: - return ( - f' {presence_word} o.HasField("{field_name}"):\n' - f" coros.append(self._visit_{child_method}(fs, {access_expr}))" - ) - else: - if field_name == "headers": - return ( - " if not self.skip_headers:\n" - f" coros.append(self._visit_{child_method}(fs, {access_expr}))" - ) - else: - return ( - f" coros.append(self._visit_{child_method}(fs, {access_expr}))" - ) - - -# --------------------------------------------------------------------------- -# Emitters for the "single-unit" case: emit a direct await (no list needed). -# --------------------------------------------------------------------------- - - -def emit_loop_direct( - field_name: str, - iter_expr: str, - child_method: str, -) -> str: - # Emit a direct await asyncio.gather(*[...]) with optional skip guard - if field_name == "headers": - return ( - " if not self.skip_headers:\n" - f" await asyncio.gather(*[self._visit_{child_method}(fs, v) for v in {iter_expr}])" - ) - elif field_name == "search_attributes": - return ( - " if not self.skip_search_attributes:\n" - f" await asyncio.gather(*[self._visit_{child_method}(fs, v) for v in {iter_expr}])" - ) - else: - return f" await asyncio.gather(*[self._visit_{child_method}(fs, v) for v in {iter_expr}])" - - -def emit_singular_direct( - field_name: str, access_expr: str, child_method: str, presence_word: str | None ) -> str: # Emit a direct await self._visit_...() with optional HasField check and skip guard if presence_word: @@ -144,7 +83,6 @@ def generate(self, roots: list[Descriptor]) -> str: # This file is generated by gen_payload_visitor.py. Changes should be made there. import abc import asyncio -from collections.abc import Coroutine from typing import Any, MutableSequence from temporalio.api.common.v1.message_pb2 import Payload @@ -167,19 +105,53 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: class _BoundedVisitorFunctions(VisitorFunctions): - \"\"\"Wraps VisitorFunctions to cap concurrent payload visits via a semaphore.\"\"\" + \"\"\"Wraps VisitorFunctions to cap concurrent payload visits via a semaphore. + + After the full traversal, call drain() to await all in-flight tasks. + \"\"\" def __init__(self, inner: VisitorFunctions, sem: asyncio.Semaphore) -> None: self._inner = inner self._sem = sem + self._tasks: list[asyncio.Task[None]] = [] async def visit_payload(self, payload: Payload) -> None: - async with self._sem: - await self._inner.visit_payload(payload) + await self._sem.acquire() + + async def _run() -> None: + try: + await self._inner.visit_payload(payload) + finally: + self._sem.release() + + self._tasks.append(asyncio.create_task(_run())) async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: - async with self._sem: - await self._inner.visit_payloads(payloads) + await self._sem.acquire() + + async def _run() -> None: + try: + await self._inner.visit_payloads(payloads) + finally: + self._sem.release() + + self._tasks.append(asyncio.create_task(_run())) + + async def drain(self) -> None: + \"\"\"Wait for all in-flight background tasks to complete. + + On cancellation or error, cancels all remaining tasks and awaits + them so their finally blocks run before this coroutine returns. + \"\"\" + if not self._tasks: + return + try: + await asyncio.gather(*self._tasks) + except BaseException: + for task in self._tasks: + task.cancel() + await asyncio.gather(*self._tasks, return_exceptions=True) + raise class PayloadVisitor: @@ -200,10 +172,8 @@ def __init__( skip_search_attributes: If True, search attributes are not visited. skip_headers: If True, headers are not visited. concurrency_limit: Maximum number of payload visits that may run - concurrently during a single call to visit(). Defaults to 1. - The semaphore is applied to each visit_payload / visit_payloads - call, so it limits I/O-level concurrency without risking - deadlock in the recursive traversal. + concurrently during a single call to visit(). Defaults to 1 + (sequential). \"\"\" if concurrency_limit < 1: raise ValueError("concurrency_limit must be positive") @@ -215,13 +185,19 @@ async def visit( self, fs: VisitorFunctions, root: Any ) -> None: \"\"\"Visits the given root message with the given function.\"\"\" - fs = _BoundedVisitorFunctions(fs, asyncio.Semaphore(self._concurrency_limit)) method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) - if method is not None: - await method(fs, root) - else: + if method is None: raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") + if self._concurrency_limit == 1: + await method(fs, root) + return + + bounded = _BoundedVisitorFunctions(fs, asyncio.Semaphore(self._concurrency_limit)) + try: + await method(bounded, root) + finally: + await bounded.drain() """ @@ -388,30 +364,16 @@ def walk(self, desc: Descriptor) -> bool: lines.append(" if self.skip_search_attributes:") lines.append(" return") - # Use coros accumulation only when there are multiple independent units; - # a single unit is emitted with a direct await (no list overhead). - use_coros = len(emit_items) > 1 - if use_coros: - lines.append(" coros: list[Coroutine[Any, Any, None]] = []") - for item in emit_items: if item[0] == "loop": _, field_name, iter_expr, child_method = item - lines.append( - emit_loop(field_name, iter_expr, child_method) - if use_coros - else emit_loop_direct(field_name, iter_expr, child_method) - ) + lines.append(emit_loop(field_name, iter_expr, child_method)) elif item[0] == "singular": _, field_name, access_expr, child_method, presence_word = item lines.append( emit_singular( field_name, access_expr, child_method, presence_word ) - if use_coros - else emit_singular_direct( - field_name, access_expr, child_method, presence_word - ) ) else: # oneof_group for field_name, access_expr, child_method, presence_word in item[1]: @@ -419,15 +381,8 @@ def walk(self, desc: Descriptor) -> bool: emit_singular( field_name, access_expr, child_method, presence_word ) - if use_coros - else emit_singular_direct( - field_name, access_expr, child_method, presence_word - ) ) - if use_coros: - lines.append(" await asyncio.gather(*coros)") - self.methods.append("\n".join(lines) + "\n") return has_payload diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index 6f596bc15..0f030ac01 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -1,7 +1,6 @@ # This file is generated by gen_payload_visitor.py. Changes should be made there. import abc import asyncio -from collections.abc import Coroutine from typing import Any, MutableSequence from temporalio.api.common.v1.message_pb2 import Payload @@ -24,19 +23,53 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: class _BoundedVisitorFunctions(VisitorFunctions): - """Wraps VisitorFunctions to cap concurrent payload visits via a semaphore.""" + """Wraps VisitorFunctions to cap concurrent payload visits via a semaphore. + + After the full traversal, call drain() to await all in-flight tasks. + """ def __init__(self, inner: VisitorFunctions, sem: asyncio.Semaphore) -> None: self._inner = inner self._sem = sem + self._tasks: list[asyncio.Task[None]] = [] async def visit_payload(self, payload: Payload) -> None: - async with self._sem: - await self._inner.visit_payload(payload) + await self._sem.acquire() + + async def _run() -> None: + try: + await self._inner.visit_payload(payload) + finally: + self._sem.release() + + self._tasks.append(asyncio.create_task(_run())) async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: - async with self._sem: - await self._inner.visit_payloads(payloads) + await self._sem.acquire() + + async def _run() -> None: + try: + await self._inner.visit_payloads(payloads) + finally: + self._sem.release() + + self._tasks.append(asyncio.create_task(_run())) + + async def drain(self) -> None: + """Wait for all in-flight background tasks to complete. + + On cancellation or error, cancels all remaining tasks and awaits + them so their finally blocks run before this coroutine returns. + """ + if not self._tasks: + return + try: + await asyncio.gather(*self._tasks) + except BaseException: + for task in self._tasks: + task.cancel() + await asyncio.gather(*self._tasks, return_exceptions=True) + raise class PayloadVisitor: @@ -57,10 +90,8 @@ def __init__( skip_search_attributes: If True, search attributes are not visited. skip_headers: If True, headers are not visited. concurrency_limit: Maximum number of payload visits that may run - concurrently during a single call to visit(). Defaults to 1. - The semaphore is applied to each visit_payload / visit_payloads - call, so it limits I/O-level concurrency without risking - deadlock in the recursive traversal. + concurrently during a single call to visit(). Defaults to 1 + (sequential). """ if concurrency_limit < 1: raise ValueError("concurrency_limit must be positive") @@ -70,13 +101,21 @@ def __init__( async def visit(self, fs: VisitorFunctions, root: Any) -> None: """Visits the given root message with the given function.""" - fs = _BoundedVisitorFunctions(fs, asyncio.Semaphore(self._concurrency_limit)) method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) - if method is not None: - await method(fs, root) - else: + if method is None: raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") + if self._concurrency_limit == 1: + await method(fs, root) + return + + bounded = _BoundedVisitorFunctions( + fs, asyncio.Semaphore(self._concurrency_limit) + ) + try: + await method(bounded, root) + finally: + await bounded.drain() async def _visit_temporal_api_common_v1_Payload(self, fs, o): await fs.visit_payload(o) @@ -108,104 +147,66 @@ async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(self, fs, o): ) async def _visit_temporal_api_failure_v1_Failure(self, fs, o): - coros: list[Coroutine[Any, Any, None]] = [] if o.HasField("encoded_attributes"): - coros.append( - self._visit_temporal_api_common_v1_Payload(fs, o.encoded_attributes) - ) + await self._visit_temporal_api_common_v1_Payload(fs, o.encoded_attributes) if o.HasField("cause"): - coros.append(self._visit_temporal_api_failure_v1_Failure(fs, o.cause)) + await self._visit_temporal_api_failure_v1_Failure(fs, o.cause) if o.HasField("application_failure_info"): - coros.append( - self._visit_temporal_api_failure_v1_ApplicationFailureInfo( - fs, o.application_failure_info - ) + await self._visit_temporal_api_failure_v1_ApplicationFailureInfo( + fs, o.application_failure_info ) elif o.HasField("timeout_failure_info"): - coros.append( - self._visit_temporal_api_failure_v1_TimeoutFailureInfo( - fs, o.timeout_failure_info - ) + await self._visit_temporal_api_failure_v1_TimeoutFailureInfo( + fs, o.timeout_failure_info ) elif o.HasField("canceled_failure_info"): - coros.append( - self._visit_temporal_api_failure_v1_CanceledFailureInfo( - fs, o.canceled_failure_info - ) + await self._visit_temporal_api_failure_v1_CanceledFailureInfo( + fs, o.canceled_failure_info ) elif o.HasField("reset_workflow_failure_info"): - coros.append( - self._visit_temporal_api_failure_v1_ResetWorkflowFailureInfo( - fs, o.reset_workflow_failure_info - ) + await self._visit_temporal_api_failure_v1_ResetWorkflowFailureInfo( + fs, o.reset_workflow_failure_info ) - await asyncio.gather(*coros) async def _visit_temporal_api_common_v1_Memo(self, fs, o): - await asyncio.gather( - *[ - self._visit_temporal_api_common_v1_Payload(fs, v) - for v in o.fields.values() - ] - ) + for v in o.fields.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_temporal_api_common_v1_SearchAttributes(self, fs, o): if self.skip_search_attributes: return - await asyncio.gather( - *[ - self._visit_temporal_api_common_v1_Payload(fs, v) - for v in o.indexed_fields.values() - ] - ) + for v in o.indexed_fields.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, fs, o): - coros: list[Coroutine[Any, Any, None]] = [] - coros.append(self._visit_payload_container(fs, o.arguments)) + await self._visit_payload_container(fs, o.arguments) if not self.skip_headers: - coros.extend( - self._visit_temporal_api_common_v1_Payload(fs, v) - for v in o.headers.values() - ) + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) if o.HasField("continued_failure"): - coros.append( - self._visit_temporal_api_failure_v1_Failure(fs, o.continued_failure) - ) + await self._visit_temporal_api_failure_v1_Failure(fs, o.continued_failure) if o.HasField("last_completion_result"): - coros.append( - self._visit_temporal_api_common_v1_Payloads( - fs, o.last_completion_result - ) + await self._visit_temporal_api_common_v1_Payloads( + fs, o.last_completion_result ) if o.HasField("memo"): - coros.append(self._visit_temporal_api_common_v1_Memo(fs, o.memo)) + await self._visit_temporal_api_common_v1_Memo(fs, o.memo) if o.HasField("search_attributes"): - coros.append( - self._visit_temporal_api_common_v1_SearchAttributes( - fs, o.search_attributes - ) + await self._visit_temporal_api_common_v1_SearchAttributes( + fs, o.search_attributes ) - await asyncio.gather(*coros) async def _visit_coresdk_workflow_activation_QueryWorkflow(self, fs, o): - coros: list[Coroutine[Any, Any, None]] = [] - coros.append(self._visit_payload_container(fs, o.arguments)) + await self._visit_payload_container(fs, o.arguments) if not self.skip_headers: - coros.extend( - self._visit_temporal_api_common_v1_Payload(fs, v) - for v in o.headers.values() - ) - await asyncio.gather(*coros) + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_activation_SignalWorkflow(self, fs, o): - coros: list[Coroutine[Any, Any, None]] = [] - coros.append(self._visit_payload_container(fs, o.input)) + await self._visit_payload_container(fs, o.input) if not self.skip_headers: - coros.extend( - self._visit_temporal_api_common_v1_Payload(fs, v) - for v in o.headers.values() - ) - await asyncio.gather(*coros) + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_activity_result_Success(self, fs, o): if o.HasField("result"): @@ -284,14 +285,10 @@ async def _visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflo await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_activation_DoUpdate(self, fs, o): - coros: list[Coroutine[Any, Any, None]] = [] - coros.append(self._visit_payload_container(fs, o.input)) + await self._visit_payload_container(fs, o.input) if not self.skip_headers: - coros.extend( - self._visit_temporal_api_common_v1_Payload(fs, v) - for v in o.headers.values() - ) - await asyncio.gather(*coros) + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart( self, fs, o @@ -358,30 +355,20 @@ async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, fs, o): ) async def _visit_coresdk_workflow_activation_WorkflowActivation(self, fs, o): - await asyncio.gather( - *[ - self._visit_coresdk_workflow_activation_WorkflowActivationJob(fs, v) - for v in o.jobs - ] - ) + for v in o.jobs: + await self._visit_coresdk_workflow_activation_WorkflowActivationJob(fs, v) async def _visit_temporal_api_sdk_v1_UserMetadata(self, fs, o): - coros: list[Coroutine[Any, Any, None]] = [] if o.HasField("summary"): - coros.append(self._visit_temporal_api_common_v1_Payload(fs, o.summary)) + await self._visit_temporal_api_common_v1_Payload(fs, o.summary) if o.HasField("details"): - coros.append(self._visit_temporal_api_common_v1_Payload(fs, o.details)) - await asyncio.gather(*coros) + await self._visit_temporal_api_common_v1_Payload(fs, o.details) async def _visit_coresdk_workflow_commands_ScheduleActivity(self, fs, o): - coros: list[Coroutine[Any, Any, None]] = [] if not self.skip_headers: - coros.extend( - self._visit_temporal_api_common_v1_Payload(fs, v) - for v in o.headers.values() - ) - coros.append(self._visit_payload_container(fs, o.arguments)) - await asyncio.gather(*coros) + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + await self._visit_payload_container(fs, o.arguments) async def _visit_coresdk_workflow_commands_QuerySuccess(self, fs, o): if o.HasField("response"): @@ -404,64 +391,42 @@ async def _visit_coresdk_workflow_commands_FailWorkflowExecution(self, fs, o): async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( self, fs, o ): - coros: list[Coroutine[Any, Any, None]] = [] - coros.append(self._visit_payload_container(fs, o.arguments)) - coros.extend( - self._visit_temporal_api_common_v1_Payload(fs, v) for v in o.memo.values() - ) + await self._visit_payload_container(fs, o.arguments) + for v in o.memo.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) if not self.skip_headers: - coros.extend( - self._visit_temporal_api_common_v1_Payload(fs, v) - for v in o.headers.values() - ) + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) if o.HasField("search_attributes"): - coros.append( - self._visit_temporal_api_common_v1_SearchAttributes( - fs, o.search_attributes - ) + await self._visit_temporal_api_common_v1_SearchAttributes( + fs, o.search_attributes ) - await asyncio.gather(*coros) async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, o): - coros: list[Coroutine[Any, Any, None]] = [] - coros.append(self._visit_payload_container(fs, o.input)) + await self._visit_payload_container(fs, o.input) if not self.skip_headers: - coros.extend( - self._visit_temporal_api_common_v1_Payload(fs, v) - for v in o.headers.values() - ) - coros.extend( - self._visit_temporal_api_common_v1_Payload(fs, v) for v in o.memo.values() - ) + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + for v in o.memo.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) if o.HasField("search_attributes"): - coros.append( - self._visit_temporal_api_common_v1_SearchAttributes( - fs, o.search_attributes - ) + await self._visit_temporal_api_common_v1_SearchAttributes( + fs, o.search_attributes ) - await asyncio.gather(*coros) async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( self, fs, o ): - coros: list[Coroutine[Any, Any, None]] = [] - coros.append(self._visit_payload_container(fs, o.args)) + await self._visit_payload_container(fs, o.args) if not self.skip_headers: - coros.extend( - self._visit_temporal_api_common_v1_Payload(fs, v) - for v in o.headers.values() - ) - await asyncio.gather(*coros) + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, fs, o): - coros: list[Coroutine[Any, Any, None]] = [] if not self.skip_headers: - coros.extend( - self._visit_temporal_api_common_v1_Payload(fs, v) - for v in o.headers.values() - ) - coros.append(self._visit_payload_container(fs, o.arguments)) - await asyncio.gather(*coros) + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + await self._visit_payload_container(fs, o.arguments) async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( self, fs, o @@ -486,92 +451,60 @@ async def _visit_coresdk_workflow_commands_ScheduleNexusOperation(self, fs, o): await self._visit_temporal_api_common_v1_Payload(fs, o.input) async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o): - coros: list[Coroutine[Any, Any, None]] = [] if o.HasField("user_metadata"): - coros.append( - self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata) - ) + await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata) if o.HasField("schedule_activity"): - coros.append( - self._visit_coresdk_workflow_commands_ScheduleActivity( - fs, o.schedule_activity - ) + await self._visit_coresdk_workflow_commands_ScheduleActivity( + fs, o.schedule_activity ) elif o.HasField("respond_to_query"): - coros.append( - self._visit_coresdk_workflow_commands_QueryResult( - fs, o.respond_to_query - ) + await self._visit_coresdk_workflow_commands_QueryResult( + fs, o.respond_to_query ) elif o.HasField("complete_workflow_execution"): - coros.append( - self._visit_coresdk_workflow_commands_CompleteWorkflowExecution( - fs, o.complete_workflow_execution - ) + await self._visit_coresdk_workflow_commands_CompleteWorkflowExecution( + fs, o.complete_workflow_execution ) elif o.HasField("fail_workflow_execution"): - coros.append( - self._visit_coresdk_workflow_commands_FailWorkflowExecution( - fs, o.fail_workflow_execution - ) + await self._visit_coresdk_workflow_commands_FailWorkflowExecution( + fs, o.fail_workflow_execution ) elif o.HasField("continue_as_new_workflow_execution"): - coros.append( - self._visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( - fs, o.continue_as_new_workflow_execution - ) + await self._visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( + fs, o.continue_as_new_workflow_execution ) elif o.HasField("start_child_workflow_execution"): - coros.append( - self._visit_coresdk_workflow_commands_StartChildWorkflowExecution( - fs, o.start_child_workflow_execution - ) + await self._visit_coresdk_workflow_commands_StartChildWorkflowExecution( + fs, o.start_child_workflow_execution ) elif o.HasField("signal_external_workflow_execution"): - coros.append( - self._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( - fs, o.signal_external_workflow_execution - ) + await self._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( + fs, o.signal_external_workflow_execution ) elif o.HasField("schedule_local_activity"): - coros.append( - self._visit_coresdk_workflow_commands_ScheduleLocalActivity( - fs, o.schedule_local_activity - ) + await self._visit_coresdk_workflow_commands_ScheduleLocalActivity( + fs, o.schedule_local_activity ) elif o.HasField("upsert_workflow_search_attributes"): - coros.append( - self._visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( - fs, o.upsert_workflow_search_attributes - ) + await self._visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( + fs, o.upsert_workflow_search_attributes ) elif o.HasField("modify_workflow_properties"): - coros.append( - self._visit_coresdk_workflow_commands_ModifyWorkflowProperties( - fs, o.modify_workflow_properties - ) + await self._visit_coresdk_workflow_commands_ModifyWorkflowProperties( + fs, o.modify_workflow_properties ) elif o.HasField("update_response"): - coros.append( - self._visit_coresdk_workflow_commands_UpdateResponse( - fs, o.update_response - ) + await self._visit_coresdk_workflow_commands_UpdateResponse( + fs, o.update_response ) elif o.HasField("schedule_nexus_operation"): - coros.append( - self._visit_coresdk_workflow_commands_ScheduleNexusOperation( - fs, o.schedule_nexus_operation - ) + await self._visit_coresdk_workflow_commands_ScheduleNexusOperation( + fs, o.schedule_nexus_operation ) - await asyncio.gather(*coros) async def _visit_coresdk_workflow_completion_Success(self, fs, o): - await asyncio.gather( - *[ - self._visit_coresdk_workflow_commands_WorkflowCommand(fs, v) - for v in o.commands - ] - ) + for v in o.commands: + await self._visit_coresdk_workflow_commands_WorkflowCommand(fs, v) async def _visit_coresdk_workflow_completion_Failure(self, fs, o): if o.HasField("failure"): diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index 15860f58c..876387393 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -3,6 +3,7 @@ import time from collections.abc import MutableSequence +import pytest from google.protobuf.duration_pb2 import Duration import temporalio.bridge.worker @@ -273,6 +274,58 @@ async def _visit(self, count: int) -> None: assert visitor_concurrent.max_concurrent == 5 +async def test_cancel_drains_background_tasks(): + """Cancelling visit() cancels in-flight tasks and awaits their cleanup.""" + tasks_started = 0 + tasks_cleaned_up = 0 + background_running = asyncio.Event() + + class SlowVisitor(VisitorFunctions): + async def visit_payload(self, payload: Payload) -> None: + pass + + async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + nonlocal tasks_started, tasks_cleaned_up + tasks_started += 1 + background_running.set() + try: + await asyncio.sleep(10) + finally: + tasks_cleaned_up += 1 + + completion = WorkflowActivationCompletion( + run_id="1", + successful=Success( + commands=[ + WorkflowCommand( + schedule_activity=ScheduleActivity( + seq=i, + activity_id=str(i), + activity_type="", + task_queue="", + arguments=[Payload(data=f"arg_{i}".encode())], + priority=Priority(), + ) + ) + for i in range(5) + ] + ), + ) + + task = asyncio.create_task( + PayloadVisitor(concurrency_limit=5).visit(SlowVisitor(), completion) + ) + await background_running.wait() + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + # All started tasks ran their finally blocks before drain() returned. + assert tasks_started > 0 + assert tasks_cleaned_up == tasks_started + + async def test_bridge_encoding(): comp = WorkflowActivationCompletion( run_id="1",