Skip to content

Commit ecee01a

Browse files
committed
Minimize task creation for concurrent payload visiting
1 parent c0a8a01 commit ecee01a

File tree

3 files changed

+252
-311
lines changed

3 files changed

+252
-311
lines changed

scripts/gen_payload_visitor.py

Lines changed: 60 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -20,87 +20,26 @@ def name_for(desc: Descriptor) -> str:
2020
return desc.full_name.replace(".", "_")
2121

2222

23-
# ---------------------------------------------------------------------------
24-
# Emitters for the "multi-unit" case: accumulate coroutines into `coros` list
25-
# and let the caller do a single asyncio.gather(*coros) at the end.
26-
# ---------------------------------------------------------------------------
27-
28-
2923
def emit_loop(
3024
field_name: str,
3125
iter_expr: str,
3226
child_method: str,
3327
) -> str:
34-
# Emit a coros.extend() over a collection with optional skip guard
28+
# Emit a for-loop with direct await, with optional skip guard
29+
inner = (
30+
f"for v in {iter_expr}:\n"
31+
f" await self._visit_{child_method}(fs, v)"
32+
)
3533
if field_name == "headers":
36-
return (
37-
" if not self.skip_headers:\n"
38-
f" coros.extend(self._visit_{child_method}(fs, v) for v in {iter_expr})"
39-
)
34+
return f" if not self.skip_headers:\n {inner}"
4035
elif field_name == "search_attributes":
41-
return (
42-
" if not self.skip_search_attributes:\n"
43-
f" coros.extend(self._visit_{child_method}(fs, v) for v in {iter_expr})"
44-
)
36+
return f" if not self.skip_search_attributes:\n {inner}"
4537
else:
46-
return f" coros.extend(self._visit_{child_method}(fs, v) for v in {iter_expr})"
38+
return f" {inner}"
4739

4840

4941
def emit_singular(
5042
field_name: str, access_expr: str, child_method: str, presence_word: str | None
51-
) -> str:
52-
# Emit a coros.append() with optional HasField check and skip guard
53-
if presence_word:
54-
if field_name == "headers":
55-
return (
56-
" if not self.skip_headers:\n"
57-
f' {presence_word} o.HasField("{field_name}"):\n'
58-
f" coros.append(self._visit_{child_method}(fs, {access_expr}))"
59-
)
60-
else:
61-
return (
62-
f' {presence_word} o.HasField("{field_name}"):\n'
63-
f" coros.append(self._visit_{child_method}(fs, {access_expr}))"
64-
)
65-
else:
66-
if field_name == "headers":
67-
return (
68-
" if not self.skip_headers:\n"
69-
f" coros.append(self._visit_{child_method}(fs, {access_expr}))"
70-
)
71-
else:
72-
return (
73-
f" coros.append(self._visit_{child_method}(fs, {access_expr}))"
74-
)
75-
76-
77-
# ---------------------------------------------------------------------------
78-
# Emitters for the "single-unit" case: emit a direct await (no list needed).
79-
# ---------------------------------------------------------------------------
80-
81-
82-
def emit_loop_direct(
83-
field_name: str,
84-
iter_expr: str,
85-
child_method: str,
86-
) -> str:
87-
# Emit a direct await asyncio.gather(*[...]) with optional skip guard
88-
if field_name == "headers":
89-
return (
90-
" if not self.skip_headers:\n"
91-
f" await asyncio.gather(*[self._visit_{child_method}(fs, v) for v in {iter_expr}])"
92-
)
93-
elif field_name == "search_attributes":
94-
return (
95-
" if not self.skip_search_attributes:\n"
96-
f" await asyncio.gather(*[self._visit_{child_method}(fs, v) for v in {iter_expr}])"
97-
)
98-
else:
99-
return f" await asyncio.gather(*[self._visit_{child_method}(fs, v) for v in {iter_expr}])"
100-
101-
102-
def emit_singular_direct(
103-
field_name: str, access_expr: str, child_method: str, presence_word: str | None
10443
) -> str:
10544
# Emit a direct await self._visit_...() with optional HasField check and skip guard
10645
if presence_word:
@@ -144,7 +83,6 @@ def generate(self, roots: list[Descriptor]) -> str:
14483
# This file is generated by gen_payload_visitor.py. Changes should be made there.
14584
import abc
14685
import asyncio
147-
from collections.abc import Coroutine
14886
from typing import Any, MutableSequence
14987
15088
from temporalio.api.common.v1.message_pb2 import Payload
@@ -167,19 +105,53 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
167105
168106
169107
class _BoundedVisitorFunctions(VisitorFunctions):
170-
\"\"\"Wraps VisitorFunctions to cap concurrent payload visits via a semaphore.\"\"\"
108+
\"\"\"Wraps VisitorFunctions to cap concurrent payload visits via a semaphore.
109+
110+
After the full traversal, call drain() to await all in-flight tasks.
111+
\"\"\"
171112
172113
def __init__(self, inner: VisitorFunctions, sem: asyncio.Semaphore) -> None:
173114
self._inner = inner
174115
self._sem = sem
116+
self._tasks: list[asyncio.Task[None]] = []
175117
176118
async def visit_payload(self, payload: Payload) -> None:
177-
async with self._sem:
178-
await self._inner.visit_payload(payload)
119+
await self._sem.acquire()
120+
121+
async def _run() -> None:
122+
try:
123+
await self._inner.visit_payload(payload)
124+
finally:
125+
self._sem.release()
126+
127+
self._tasks.append(asyncio.create_task(_run()))
179128
180129
async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
181-
async with self._sem:
182-
await self._inner.visit_payloads(payloads)
130+
await self._sem.acquire()
131+
132+
async def _run() -> None:
133+
try:
134+
await self._inner.visit_payloads(payloads)
135+
finally:
136+
self._sem.release()
137+
138+
self._tasks.append(asyncio.create_task(_run()))
139+
140+
async def drain(self) -> None:
141+
\"\"\"Wait for all in-flight background tasks to complete.
142+
143+
On cancellation or error, cancels all remaining tasks and awaits
144+
them so their finally blocks run before this coroutine returns.
145+
\"\"\"
146+
if not self._tasks:
147+
return
148+
try:
149+
await asyncio.gather(*self._tasks)
150+
except BaseException:
151+
for task in self._tasks:
152+
task.cancel()
153+
await asyncio.gather(*self._tasks, return_exceptions=True)
154+
raise
183155
184156
185157
class PayloadVisitor:
@@ -200,10 +172,8 @@ def __init__(
200172
skip_search_attributes: If True, search attributes are not visited.
201173
skip_headers: If True, headers are not visited.
202174
concurrency_limit: Maximum number of payload visits that may run
203-
concurrently during a single call to visit(). Defaults to 1.
204-
The semaphore is applied to each visit_payload / visit_payloads
205-
call, so it limits I/O-level concurrency without risking
206-
deadlock in the recursive traversal.
175+
concurrently during a single call to visit(). Defaults to 1
176+
(sequential).
207177
\"\"\"
208178
if concurrency_limit < 1:
209179
raise ValueError("concurrency_limit must be positive")
@@ -215,13 +185,19 @@ async def visit(
215185
self, fs: VisitorFunctions, root: Any
216186
) -> None:
217187
\"\"\"Visits the given root message with the given function.\"\"\"
218-
fs = _BoundedVisitorFunctions(fs, asyncio.Semaphore(self._concurrency_limit))
219188
method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_")
220189
method = getattr(self, method_name, None)
221-
if method is not None:
222-
await method(fs, root)
223-
else:
190+
if method is None:
224191
raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}")
192+
if self._concurrency_limit == 1:
193+
await method(fs, root)
194+
return
195+
196+
bounded = _BoundedVisitorFunctions(fs, asyncio.Semaphore(self._concurrency_limit))
197+
try:
198+
await method(bounded, root)
199+
finally:
200+
await bounded.drain()
225201
226202
"""
227203

@@ -388,46 +364,25 @@ def walk(self, desc: Descriptor) -> bool:
388364
lines.append(" if self.skip_search_attributes:")
389365
lines.append(" return")
390366

391-
# Use coros accumulation only when there are multiple independent units;
392-
# a single unit is emitted with a direct await (no list overhead).
393-
use_coros = len(emit_items) > 1
394-
if use_coros:
395-
lines.append(" coros: list[Coroutine[Any, Any, None]] = []")
396-
397367
for item in emit_items:
398368
if item[0] == "loop":
399369
_, field_name, iter_expr, child_method = item
400-
lines.append(
401-
emit_loop(field_name, iter_expr, child_method)
402-
if use_coros
403-
else emit_loop_direct(field_name, iter_expr, child_method)
404-
)
370+
lines.append(emit_loop(field_name, iter_expr, child_method))
405371
elif item[0] == "singular":
406372
_, field_name, access_expr, child_method, presence_word = item
407373
lines.append(
408374
emit_singular(
409375
field_name, access_expr, child_method, presence_word
410376
)
411-
if use_coros
412-
else emit_singular_direct(
413-
field_name, access_expr, child_method, presence_word
414-
)
415377
)
416378
else: # oneof_group
417379
for field_name, access_expr, child_method, presence_word in item[1]:
418380
lines.append(
419381
emit_singular(
420382
field_name, access_expr, child_method, presence_word
421383
)
422-
if use_coros
423-
else emit_singular_direct(
424-
field_name, access_expr, child_method, presence_word
425-
)
426384
)
427385

428-
if use_coros:
429-
lines.append(" await asyncio.gather(*coros)")
430-
431386
self.methods.append("\n".join(lines) + "\n")
432387
return has_payload
433388

0 commit comments

Comments
 (0)