Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 60 additions & 105 deletions scripts/gen_payload_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,87 +20,26 @@ def name_for(desc: Descriptor) -> str:
return desc.full_name.replace(".", "_")


# ---------------------------------------------------------------------------
# Emitters for the "multi-unit" case: accumulate coroutines into `coros` list
# and let the caller do a single asyncio.gather(*coros) at the end.
# ---------------------------------------------------------------------------


def emit_loop(
field_name: str,
iter_expr: str,
child_method: str,
) -> str:
# Emit a coros.extend() over a collection with optional skip guard
# Emit a for-loop with direct await, with optional skip guard
inner = (
f"for v in {iter_expr}:\n"
f" await self._visit_{child_method}(fs, v)"
)
if field_name == "headers":
return (
" if not self.skip_headers:\n"
f" coros.extend(self._visit_{child_method}(fs, v) for v in {iter_expr})"
)
return f" if not self.skip_headers:\n {inner}"
elif field_name == "search_attributes":
return (
" if not self.skip_search_attributes:\n"
f" coros.extend(self._visit_{child_method}(fs, v) for v in {iter_expr})"
)
return f" if not self.skip_search_attributes:\n {inner}"
else:
return f" coros.extend(self._visit_{child_method}(fs, v) for v in {iter_expr})"
return f" {inner}"


def emit_singular(
field_name: str, access_expr: str, child_method: str, presence_word: str | None
) -> str:
# Emit a coros.append() with optional HasField check and skip guard
if presence_word:
if field_name == "headers":
return (
" if not self.skip_headers:\n"
f' {presence_word} o.HasField("{field_name}"):\n'
f" coros.append(self._visit_{child_method}(fs, {access_expr}))"
)
else:
return (
f' {presence_word} o.HasField("{field_name}"):\n'
f" coros.append(self._visit_{child_method}(fs, {access_expr}))"
)
else:
if field_name == "headers":
return (
" if not self.skip_headers:\n"
f" coros.append(self._visit_{child_method}(fs, {access_expr}))"
)
else:
return (
f" coros.append(self._visit_{child_method}(fs, {access_expr}))"
)


# ---------------------------------------------------------------------------
# Emitters for the "single-unit" case: emit a direct await (no list needed).
# ---------------------------------------------------------------------------


def emit_loop_direct(
field_name: str,
iter_expr: str,
child_method: str,
) -> str:
# Emit a direct await asyncio.gather(*[...]) with optional skip guard
if field_name == "headers":
return (
" if not self.skip_headers:\n"
f" await asyncio.gather(*[self._visit_{child_method}(fs, v) for v in {iter_expr}])"
)
elif field_name == "search_attributes":
return (
" if not self.skip_search_attributes:\n"
f" await asyncio.gather(*[self._visit_{child_method}(fs, v) for v in {iter_expr}])"
)
else:
return f" await asyncio.gather(*[self._visit_{child_method}(fs, v) for v in {iter_expr}])"


def emit_singular_direct(
field_name: str, access_expr: str, child_method: str, presence_word: str | None
) -> str:
# Emit a direct await self._visit_...() with optional HasField check and skip guard
if presence_word:
Expand Down Expand Up @@ -144,7 +83,6 @@ def generate(self, roots: list[Descriptor]) -> str:
# This file is generated by gen_payload_visitor.py. Changes should be made there.
import abc
import asyncio
from collections.abc import Coroutine
from typing import Any, MutableSequence

from temporalio.api.common.v1.message_pb2 import Payload
Expand All @@ -167,19 +105,53 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:


class _BoundedVisitorFunctions(VisitorFunctions):
\"\"\"Wraps VisitorFunctions to cap concurrent payload visits via a semaphore.\"\"\"
\"\"\"Wraps VisitorFunctions to cap concurrent payload visits via a semaphore.

After the full traversal, call drain() to await all in-flight tasks.
\"\"\"

def __init__(self, inner: VisitorFunctions, sem: asyncio.Semaphore) -> None:
self._inner = inner
self._sem = sem
self._tasks: list[asyncio.Task[None]] = []

async def visit_payload(self, payload: Payload) -> None:
async with self._sem:
await self._inner.visit_payload(payload)
await self._sem.acquire()

async def _run() -> None:
try:
await self._inner.visit_payload(payload)
finally:
self._sem.release()

self._tasks.append(asyncio.create_task(_run()))

async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
async with self._sem:
await self._inner.visit_payloads(payloads)
await self._sem.acquire()

async def _run() -> None:
try:
await self._inner.visit_payloads(payloads)
finally:
self._sem.release()

self._tasks.append(asyncio.create_task(_run()))

async def drain(self) -> None:
\"\"\"Wait for all in-flight background tasks to complete.

On cancellation or error, cancels all remaining tasks and awaits
them so their finally blocks run before this coroutine returns.
\"\"\"
if not self._tasks:
return
try:
await asyncio.gather(*self._tasks)
except BaseException:
for task in self._tasks:
task.cancel()
await asyncio.gather(*self._tasks, return_exceptions=True)
raise


class PayloadVisitor:
Expand All @@ -200,10 +172,8 @@ def __init__(
skip_search_attributes: If True, search attributes are not visited.
skip_headers: If True, headers are not visited.
concurrency_limit: Maximum number of payload visits that may run
concurrently during a single call to visit(). Defaults to 1.
The semaphore is applied to each visit_payload / visit_payloads
call, so it limits I/O-level concurrency without risking
deadlock in the recursive traversal.
concurrently during a single call to visit(). Defaults to 1
(sequential).
\"\"\"
if concurrency_limit < 1:
raise ValueError("concurrency_limit must be positive")
Expand All @@ -215,13 +185,19 @@ async def visit(
self, fs: VisitorFunctions, root: Any
) -> None:
\"\"\"Visits the given root message with the given function.\"\"\"
fs = _BoundedVisitorFunctions(fs, asyncio.Semaphore(self._concurrency_limit))
method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_")
method = getattr(self, method_name, None)
if method is not None:
await method(fs, root)
else:
if method is None:
raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}")
if self._concurrency_limit == 1:
await method(fs, root)
return

bounded = _BoundedVisitorFunctions(fs, asyncio.Semaphore(self._concurrency_limit))
try:
await method(bounded, root)
finally:
await bounded.drain()

"""

Expand Down Expand Up @@ -388,46 +364,25 @@ def walk(self, desc: Descriptor) -> bool:
lines.append(" if self.skip_search_attributes:")
lines.append(" return")

# Use coros accumulation only when there are multiple independent units;
# a single unit is emitted with a direct await (no list overhead).
use_coros = len(emit_items) > 1
if use_coros:
lines.append(" coros: list[Coroutine[Any, Any, None]] = []")

for item in emit_items:
if item[0] == "loop":
_, field_name, iter_expr, child_method = item
lines.append(
emit_loop(field_name, iter_expr, child_method)
if use_coros
else emit_loop_direct(field_name, iter_expr, child_method)
)
lines.append(emit_loop(field_name, iter_expr, child_method))
elif item[0] == "singular":
_, field_name, access_expr, child_method, presence_word = item
lines.append(
emit_singular(
field_name, access_expr, child_method, presence_word
)
if use_coros
else emit_singular_direct(
field_name, access_expr, child_method, presence_word
)
)
else: # oneof_group
for field_name, access_expr, child_method, presence_word in item[1]:
lines.append(
emit_singular(
field_name, access_expr, child_method, presence_word
)
if use_coros
else emit_singular_direct(
field_name, access_expr, child_method, presence_word
)
)

if use_coros:
lines.append(" await asyncio.gather(*coros)")

self.methods.append("\n".join(lines) + "\n")
return has_payload

Expand Down
Loading
Loading