@@ -20,87 +20,26 @@ def name_for(desc: Descriptor) -> str:
2020 return desc .full_name .replace ("." , "_" )
2121
2222
23- # ---------------------------------------------------------------------------
24- # Emitters for the "multi-unit" case: accumulate coroutines into `coros` list
25- # and let the caller do a single asyncio.gather(*coros) at the end.
26- # ---------------------------------------------------------------------------
27-
28-
2923def emit_loop (
3024 field_name : str ,
3125 iter_expr : str ,
3226 child_method : str ,
3327) -> str :
34- # Emit a coros.extend() over a collection with optional skip guard
28+ # Emit a for-loop with direct await, with optional skip guard
29+ inner = (
30+ f"for v in { iter_expr } :\n "
31+ f" await self._visit_{ child_method } (fs, v)"
32+ )
3533 if field_name == "headers" :
36- return (
37- " if not self.skip_headers:\n "
38- f" coros.extend(self._visit_{ child_method } (fs, v) for v in { iter_expr } )"
39- )
34+ return f" if not self.skip_headers:\n { inner } "
4035 elif field_name == "search_attributes" :
41- return (
42- " if not self.skip_search_attributes:\n "
43- f" coros.extend(self._visit_{ child_method } (fs, v) for v in { iter_expr } )"
44- )
36+ return f" if not self.skip_search_attributes:\n { inner } "
4537 else :
46- return f" coros.extend(self._visit_ { child_method } (fs, v) for v in { iter_expr } ) "
38+ return f" { inner } "
4739
4840
4941def emit_singular (
5042 field_name : str , access_expr : str , child_method : str , presence_word : str | None
51- ) -> str :
52- # Emit a coros.append() with optional HasField check and skip guard
53- if presence_word :
54- if field_name == "headers" :
55- return (
56- " if not self.skip_headers:\n "
57- f' { presence_word } o.HasField("{ field_name } "):\n '
58- f" coros.append(self._visit_{ child_method } (fs, { access_expr } ))"
59- )
60- else :
61- return (
62- f' { presence_word } o.HasField("{ field_name } "):\n '
63- f" coros.append(self._visit_{ child_method } (fs, { access_expr } ))"
64- )
65- else :
66- if field_name == "headers" :
67- return (
68- " if not self.skip_headers:\n "
69- f" coros.append(self._visit_{ child_method } (fs, { access_expr } ))"
70- )
71- else :
72- return (
73- f" coros.append(self._visit_{ child_method } (fs, { access_expr } ))"
74- )
75-
76-
77- # ---------------------------------------------------------------------------
78- # Emitters for the "single-unit" case: emit a direct await (no list needed).
79- # ---------------------------------------------------------------------------
80-
81-
82- def emit_loop_direct (
83- field_name : str ,
84- iter_expr : str ,
85- child_method : str ,
86- ) -> str :
87- # Emit a direct await asyncio.gather(*[...]) with optional skip guard
88- if field_name == "headers" :
89- return (
90- " if not self.skip_headers:\n "
91- f" await asyncio.gather(*[self._visit_{ child_method } (fs, v) for v in { iter_expr } ])"
92- )
93- elif field_name == "search_attributes" :
94- return (
95- " if not self.skip_search_attributes:\n "
96- f" await asyncio.gather(*[self._visit_{ child_method } (fs, v) for v in { iter_expr } ])"
97- )
98- else :
99- return f" await asyncio.gather(*[self._visit_{ child_method } (fs, v) for v in { iter_expr } ])"
100-
101-
102- def emit_singular_direct (
103- field_name : str , access_expr : str , child_method : str , presence_word : str | None
10443) -> str :
10544 # Emit a direct await self._visit_...() with optional HasField check and skip guard
10645 if presence_word :
@@ -144,7 +83,6 @@ def generate(self, roots: list[Descriptor]) -> str:
14483# This file is generated by gen_payload_visitor.py. Changes should be made there.
14584import abc
14685import asyncio
147- from collections.abc import Coroutine
14886from typing import Any, MutableSequence
14987
15088from temporalio.api.common.v1.message_pb2 import Payload
@@ -167,19 +105,53 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
167105
168106
169107class _BoundedVisitorFunctions(VisitorFunctions):
170- \" \" \" Wraps VisitorFunctions to cap concurrent payload visits via a semaphore.\" \" \"
108+ \" \" \" Wraps VisitorFunctions to cap concurrent payload visits via a semaphore.
109+
110+ After the full traversal, call drain() to await all in-flight tasks.
111+ \" \" \"
171112
172113 def __init__(self, inner: VisitorFunctions, sem: asyncio.Semaphore) -> None:
173114 self._inner = inner
174115 self._sem = sem
116+ self._tasks: list[asyncio.Task[None]] = []
175117
176118 async def visit_payload(self, payload: Payload) -> None:
177- async with self._sem:
178- await self._inner.visit_payload(payload)
119+ await self._sem.acquire()
120+
121+ async def _run() -> None:
122+ try:
123+ await self._inner.visit_payload(payload)
124+ finally:
125+ self._sem.release()
126+
127+ self._tasks.append(asyncio.create_task(_run()))
179128
180129 async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
181- async with self._sem:
182- await self._inner.visit_payloads(payloads)
130+ await self._sem.acquire()
131+
132+ async def _run() -> None:
133+ try:
134+ await self._inner.visit_payloads(payloads)
135+ finally:
136+ self._sem.release()
137+
138+ self._tasks.append(asyncio.create_task(_run()))
139+
140+ async def drain(self) -> None:
141+ \" \" \" Wait for all in-flight background tasks to complete.
142+
143+ On cancellation or error, cancels all remaining tasks and awaits
144+ them so their finally blocks run before this coroutine returns.
145+ \" \" \"
146+ if not self._tasks:
147+ return
148+ try:
149+ await asyncio.gather(*self._tasks)
150+ except BaseException:
151+ for task in self._tasks:
152+ task.cancel()
153+ await asyncio.gather(*self._tasks, return_exceptions=True)
154+ raise
183155
184156
185157class PayloadVisitor:
@@ -200,10 +172,8 @@ def __init__(
200172 skip_search_attributes: If True, search attributes are not visited.
201173 skip_headers: If True, headers are not visited.
202174 concurrency_limit: Maximum number of payload visits that may run
203- concurrently during a single call to visit(). Defaults to 1.
204- The semaphore is applied to each visit_payload / visit_payloads
205- call, so it limits I/O-level concurrency without risking
206- deadlock in the recursive traversal.
175+ concurrently during a single call to visit(). Defaults to 1
176+ (sequential).
207177 \" \" \"
208178 if concurrency_limit < 1:
209179 raise ValueError("concurrency_limit must be positive")
@@ -215,13 +185,19 @@ async def visit(
215185 self, fs: VisitorFunctions, root: Any
216186 ) -> None:
217187 \" \" \" Visits the given root message with the given function.\" \" \"
218- fs = _BoundedVisitorFunctions(fs, asyncio.Semaphore(self._concurrency_limit))
219188 method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_")
220189 method = getattr(self, method_name, None)
221- if method is not None:
222- await method(fs, root)
223- else:
190+ if method is None:
224191 raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}")
192+ if self._concurrency_limit == 1:
193+ await method(fs, root)
194+ return
195+
196+ bounded = _BoundedVisitorFunctions(fs, asyncio.Semaphore(self._concurrency_limit))
197+ try:
198+ await method(bounded, root)
199+ finally:
200+ await bounded.drain()
225201
226202"""
227203
@@ -388,46 +364,25 @@ def walk(self, desc: Descriptor) -> bool:
388364 lines .append (" if self.skip_search_attributes:" )
389365 lines .append (" return" )
390366
391- # Use coros accumulation only when there are multiple independent units;
392- # a single unit is emitted with a direct await (no list overhead).
393- use_coros = len (emit_items ) > 1
394- if use_coros :
395- lines .append (" coros: list[Coroutine[Any, Any, None]] = []" )
396-
397367 for item in emit_items :
398368 if item [0 ] == "loop" :
399369 _ , field_name , iter_expr , child_method = item
400- lines .append (
401- emit_loop (field_name , iter_expr , child_method )
402- if use_coros
403- else emit_loop_direct (field_name , iter_expr , child_method )
404- )
370+ lines .append (emit_loop (field_name , iter_expr , child_method ))
405371 elif item [0 ] == "singular" :
406372 _ , field_name , access_expr , child_method , presence_word = item
407373 lines .append (
408374 emit_singular (
409375 field_name , access_expr , child_method , presence_word
410376 )
411- if use_coros
412- else emit_singular_direct (
413- field_name , access_expr , child_method , presence_word
414- )
415377 )
416378 else : # oneof_group
417379 for field_name , access_expr , child_method , presence_word in item [1 ]:
418380 lines .append (
419381 emit_singular (
420382 field_name , access_expr , child_method , presence_word
421383 )
422- if use_coros
423- else emit_singular_direct (
424- field_name , access_expr , child_method , presence_word
425- )
426384 )
427385
428- if use_coros :
429- lines .append (" await asyncio.gather(*coros)" )
430-
431386 self .methods .append ("\n " .join (lines ) + "\n " )
432387 return has_payload
433388
0 commit comments