From cb252f056e09367f6cad0ce572796442540cc2d7 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Mon, 15 Dec 2025 19:54:07 +0100 Subject: [PATCH 1/4] Add explicit tool execution synchronization --- python/restate/ext/adk/plugin.py | 58 +++++++++++++++++++++++++------- python/restate/ext/turnstile.py | 31 +++++++++++++++++ 2 files changed, 76 insertions(+), 13 deletions(-) create mode 100644 python/restate/ext/turnstile.py diff --git a/python/restate/ext/adk/plugin.py b/python/restate/ext/adk/plugin.py index 0911190..d9d9c82 100644 --- a/python/restate/ext/adk/plugin.py +++ b/python/restate/ext/adk/plugin.py @@ -35,17 +35,25 @@ from restate.extensions import current_context +from restate.ext.turnstile import Turnstile + + +def _create_turnstile(s: LlmResponse) -> Turnstile: + ids = _get_function_call_ids(s) + turnstile = Turnstile(ids) + return turnstile + class RestatePlugin(BasePlugin): """A plugin to integrate Restate with the ADK framework.""" _models: dict[str, BaseLlm] - _locks: dict[str, asyncio.Lock] + _turnstiles: dict[str, Turnstile | None] def __init__(self, *, max_model_call_retries: int = 10): super().__init__(name="restate_plugin") self._models = {} - self._locks = {} + self._turnstiles = {} self._max_model_call_retries = max_model_call_retries async def before_agent_callback( @@ -62,7 +70,7 @@ async def before_agent_callback( ) model = agent.model if isinstance(agent.model, BaseLlm) else LLMRegistry.new_llm(agent.model) self._models[callback_context.invocation_id] = model - self._locks[callback_context.invocation_id] = asyncio.Lock() + self._turnstiles[callback_context.invocation_id] = None id = callback_context.invocation_id event = ctx.request().attempt_finished_event @@ -73,7 +81,7 @@ async def release_task(): await event.wait() finally: self._models.pop(id, None) - self._locks.pop(id, None) + self._turnstiles.pop(id, None) _ = asyncio.create_task(release_task()) return None @@ -82,7 +90,7 @@ async def after_agent_callback( self, *, agent: BaseAgent, 
callback_context: CallbackContext ) -> Optional[types.Content]: self._models.pop(callback_context.invocation_id, None) - self._locks.pop(callback_context.invocation_id, None) + self._turnstiles.pop(callback_context.invocation_id, None) return None async def after_run_callback(self, *, invocation_context: InvocationContext) -> None: @@ -100,6 +108,8 @@ async def before_model_callback( "No Restate context found, the restate plugin must be used from within a restate handler." ) response = await _generate_content_async(ctx, self._max_model_call_retries, model, llm_request) + turnstile = _create_turnstile(response) + self._turnstiles[callback_context.invocation_id] = turnstile return response async def before_tool_callback( @@ -109,11 +119,17 @@ async def before_tool_callback( tool_args: dict[str, Any], tool_context: ToolContext, ) -> Optional[dict]: - lock = self._locks[tool_context.invocation_id] + turnstile = self._turnstiles[tool_context.invocation_id] + assert turnstile is not None, "Turnstile not found for tool invocation." + + id = tool_context.function_call_id + assert id is not None, "Function call ID is required for tool invocation." + + await turnstile.wait_for(id) + ctx = current_context() - await lock.acquire() tool_context.session.state["restate_context"] = ctx - # TODO: if we want we can also automatically wrap tools with ctx.run_typed here + return None async def after_tool_callback( @@ -125,8 +141,11 @@ async def after_tool_callback( result: dict, ) -> Optional[dict]: tool_context.session.state.pop("restate_context", None) - lock = self._locks[tool_context.invocation_id] - lock.release() + turnstile = self._turnstiles[tool_context.invocation_id] + assert turnstile is not None, "Turnstile not found for tool invocation." + id = tool_context.function_call_id + assert id is not None, "Function call ID is required for tool invocation." 
+ turnstile.allow_next_after(id) return None async def on_tool_error_callback( @@ -138,13 +157,26 @@ async def on_tool_error_callback( error: Exception, ) -> Optional[dict]: tool_context.session.state.pop("restate_context", None) - lock = self._locks[tool_context.invocation_id] - lock.release() + turnstile = self._turnstiles[tool_context.invocation_id] + assert turnstile is not None, "Turnstile not found for tool invocation." + id = tool_context.function_call_id + assert id is not None, "Function call ID is required for tool invocation." + turnstile.allow_next_after(id) return None async def close(self): self._models.clear() - self._locks.clear() + self._turnstiles.clear() + + +def _get_function_call_ids(s: LlmResponse) -> list[str]: + ids = [] + if s.content and s.content.parts: + for part in s.content.parts: + if part.function_call: + if part.function_call.id: + ids.append(part.function_call.id) + return ids def _generate_client_function_call_id(s: LlmResponse) -> None: diff --git a/python/restate/ext/turnstile.py b/python/restate/ext/turnstile.py new file mode 100644 index 0000000..a2f6104 --- /dev/null +++ b/python/restate/ext/turnstile.py @@ -0,0 +1,31 @@ +from asyncio import Event + + +class Turnstile: + """A turnstile to manage ordered access based on IDs.""" + + def __init__(self, ids: list[str]): + # ordered mapping of id to next id in the sequence + # for example: + # {'id1': 'id2', 'id2': 'id3'} <-- id3 is the last. + # {} <-- no ids, no turns. + # {} <-- single id, no next turn. 
+ # + self.turns = dict(zip(ids, ids[1:])) + # mapping of id to event that signals when that id's turn is allowed + self.events = {id: Event() for id in ids} + if ids: + # make sure that the first id can proceed immediately + event = self.events[ids[0]] + event.set() + + async def wait_for(self, id: str) -> None: + event = self.events[id] + await event.wait() + + def allow_next_after(self, id: str) -> None: + next_id = self.turns.get(id) + if next_id is None: + return + next_event = self.events[next_id] + next_event.set() From c7b5fe9863c85b2c3e3d6efdb28731eede707b5e Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 16 Dec 2025 17:27:43 +0100 Subject: [PATCH 2/4] Integrate turnstile into OpenAI ext --- python/restate/ext/openai/models.py | 79 +++++++++++++ python/restate/ext/openai/runner_wrapper.py | 121 ++++---------------- python/restate/ext/openai/session.py | 61 ++++++++++ python/restate/ext/openai/utils.py | 110 ++++++++++++++++++ 4 files changed, 273 insertions(+), 98 deletions(-) create mode 100644 python/restate/ext/openai/models.py create mode 100644 python/restate/ext/openai/session.py create mode 100644 python/restate/ext/openai/utils.py diff --git a/python/restate/ext/openai/models.py b/python/restate/ext/openai/models.py new file mode 100644 index 0000000..c48cf88 --- /dev/null +++ b/python/restate/ext/openai/models.py @@ -0,0 +1,79 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +""" +This module contains the optional OpenAI integration for Restate. 
+""" + +import dataclasses + +from agents import ( + Usage, +) +from agents.items import TResponseOutputItem +from agents.items import TResponseInputItem +from datetime import timedelta +from typing import Optional +from pydantic import BaseModel + +from restate.ext.turnstile import Turnstile + + +class State: + __slots__ = ("turnstile",) + + def __init__(self) -> None: + self.turnstile = Turnstile([]) + + +@dataclasses.dataclass +class LlmRetryOpts: + max_attempts: Optional[int] = 10 + """Max number of attempts (including the initial), before giving up. + + When giving up, the LLM call will throw a `TerminalError` wrapping the original error message.""" + max_duration: Optional[timedelta] = None + """Max duration of retries, before giving up. + + When giving up, the LLM call will throw a `TerminalError` wrapping the original error message.""" + initial_retry_interval: Optional[timedelta] = timedelta(seconds=1) + """Initial interval for the first retry attempt. + Retry interval will grow by a factor specified in `retry_interval_factor`. + + If any of the other retry related fields is specified, the default for this field is 50 milliseconds, otherwise restate will fallback to the overall invocation retry policy.""" + max_retry_interval: Optional[timedelta] = None + """Max interval between retries. + Retry interval will grow by a factor specified in `retry_interval_factor`. + + The default is 10 seconds.""" + retry_interval_factor: Optional[float] = None + """Exponentiation factor to use when computing the next retry delay. + + If any of the other retry related fields is specified, the default for this field is `2`, meaning retry interval will double at each attempt, otherwise restate will fallback to the overall invocation retry policy.""" + + +# The OpenAI ModelResponse class is a dataclass with Pydantic fields. +# The Restate SDK cannot serialize this. So we turn the ModelResponse into a Pydantic model. 
+class RestateModelResponse(BaseModel): + output: list[TResponseOutputItem] + """A list of outputs (messages, tool calls, etc) generated by the model""" + + usage: Usage + """The usage information for the response.""" + + response_id: str | None + """An ID for the response which can be used to refer to the response in subsequent calls to the + model. Not supported by all model providers. + If using OpenAI models via the Responses API, this is the `response_id` parameter, and it can + be passed to `Runner.run`. + """ + + def to_input_items(self) -> list[TResponseInputItem]: + return [it.model_dump(exclude_unset=True) for it in self.output] # type: ignore diff --git a/python/restate/ext/openai/runner_wrapper.py b/python/restate/ext/openai/runner_wrapper.py index 572b819..7bc2d44 100644 --- a/python/restate/ext/openai/runner_wrapper.py +++ b/python/restate/ext/openai/runner_wrapper.py @@ -15,7 +15,6 @@ import dataclasses from agents import ( - Usage, Model, RunContextWrapper, AgentsException, @@ -28,62 +27,18 @@ Runner, ) from agents.models.multi_provider import MultiProvider -from agents.items import TResponseStreamEvent, TResponseOutputItem, ModelResponse -from agents.memory.session import SessionABC +from agents.items import TResponseStreamEvent, ModelResponse from agents.items import TResponseInputItem -from datetime import timedelta -from typing import List, Any, AsyncIterator, Optional, cast -from pydantic import BaseModel +from typing import Any, AsyncIterator from restate.exceptions import SdkInternalBaseException +from restate.ext.turnstile import Turnstile from restate.extensions import current_context -from restate import RunOptions, ObjectContext, TerminalError +from restate import RunOptions, TerminalError - -@dataclasses.dataclass -class LlmRetryOpts: - max_attempts: Optional[int] = 10 - """Max number of attempts (including the initial), before giving up. 
- - When giving up, the LLM call will throw a `TerminalError` wrapping the original error message.""" - max_duration: Optional[timedelta] = None - """Max duration of retries, before giving up. - - When giving up, the LLM call will throw a `TerminalError` wrapping the original error message.""" - initial_retry_interval: Optional[timedelta] = timedelta(seconds=1) - """Initial interval for the first retry attempt. - Retry interval will grow by a factor specified in `retry_interval_factor`. - - If any of the other retry related fields is specified, the default for this field is 50 milliseconds, otherwise restate will fallback to the overall invocation retry policy.""" - max_retry_interval: Optional[timedelta] = None - """Max interval between retries. - Retry interval will grow by a factor specified in `retry_interval_factor`. - - The default is 10 seconds.""" - retry_interval_factor: Optional[float] = None - """Exponentiation factor to use when computing the next retry delay. - - If any of the other retry related fields is specified, the default for this field is `2`, meaning retry interval will double at each attempt, otherwise restate will fallback to the overall invocation retry policy.""" - - -# The OpenAI ModelResponse class is a dataclass with Pydantic fields. -# The Restate SDK cannot serialize this. So we turn the ModelResponse int a Pydantic model. -class RestateModelResponse(BaseModel): - output: list[TResponseOutputItem] - """A list of outputs (messages, tool calls, etc) generated by the model""" - - usage: Usage - """The usage information for the response.""" - - response_id: str | None - """An ID for the response which can be used to refer to the response in subsequent calls to the - model. Not supported by all model providers. - If using OpenAI models via the Responses API, this is the `response_id` parameter, and it can - be passed to `Runner.run`. 
- """ - - def to_input_items(self) -> list[TResponseInputItem]: - return [it.model_dump(exclude_unset=True) for it in self.output] # type: ignore +from .utils import get_function_call_ids, wrap_agent_tools +from .models import LlmRetryOpts, RestateModelResponse, State +from .session import RestateSession class DurableModelCalls(MultiProvider): @@ -91,12 +46,14 @@ class DurableModelCalls(MultiProvider): A Restate model provider that wraps the OpenAI SDK's default MultiProvider. """ - def __init__(self, llm_retry_opts: LlmRetryOpts | None = None): + def __init__(self, state: State, llm_retry_opts: LlmRetryOpts | None = None): super().__init__() self.llm_retry_opts = llm_retry_opts + self.state = state def get_model(self, model_name: str | None) -> Model: - return RestateModelWrapper(super().get_model(model_name or None), self.llm_retry_opts) + model = super().get_model(model_name or None) + return RestateModelWrapper(model, self.state, self.llm_retry_opts) class RestateModelWrapper(Model): @@ -104,8 +61,9 @@ class RestateModelWrapper(Model): A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal. """ - def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts | None = None): + def __init__(self, model: Model, state: State, llm_retry_opts: LlmRetryOpts | None = None): self.model = model + self.state = state self.model_name = "RestateModelWrapper" self.llm_retry_opts = llm_retry_opts if llm_retry_opts is not None else LlmRetryOpts() @@ -133,6 +91,10 @@ async def call_llm() -> RestateModelResponse: retry_interval_factor=self.llm_retry_opts.retry_interval_factor, ), ) + # collect function call IDs, to order the upcoming tool executions via the turnstile + ids = get_function_call_ids(result.output) + self.state.turnstile = Turnstile(ids) + # convert back to original ModelResponse return ModelResponse( output=result.output, @@ -144,47 +106,6 @@ def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent raise TerminalError("Streaming is not supported in Restate. 
Use `get_response` instead.") -class RestateSession(SessionABC): - """Restate session implementation following the Session protocol.""" - - def __init__(self): - self._items: List[TResponseInputItem] | None = None - - def _ctx(self) -> ObjectContext: - return cast(ObjectContext, current_context()) - - async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]: - """Retrieve conversation history for this session.""" - if self._items is None: - self._items = await self._ctx().get("items") or [] - if limit is not None: - return self._items[-limit:] - return self._items.copy() - - async def add_items(self, items: List[TResponseInputItem]) -> None: - """Store new items for this session.""" - if self._items is None: - self._items = await self._ctx().get("items") or [] - self._items.extend(items) - - async def pop_item(self) -> TResponseInputItem | None: - """Remove and return the most recent item from this session.""" - if self._items is None: - self._items = await self._ctx().get("items") or [] - if self._items: - return self._items.pop() - return None - - def flush(self) -> None: - """Flush the session items to the context.""" - self._ctx().set("items", self._items) - - async def clear_session(self) -> None: - """Clear all items for this session.""" - self._items = [] - self._ctx().clear("items") - - class AgentsTerminalException(AgentsException, TerminalError): """Exception that is both an AgentsException and a restate.TerminalError.""" @@ -255,10 +176,13 @@ async def run( The result from Runner.run """ + # execution state + state = State() + # Set persisting model calls llm_retry_opts = kwargs.pop("llm_retry_opts", None) run_config = kwargs.pop("run_config", RunConfig()) - run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(llm_retry_opts)) + run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(state, llm_retry_opts)) # Disable parallel tool calls model_settings = run_config.model_settings @@ 
-281,9 +205,10 @@ async def run( raise TerminalError("When use_restate_session is True, session config cannot be provided.") session = RestateSession() + agent = wrap_agent_tools(starting_agent, state) try: result = await Runner.run( - starting_agent=starting_agent, input=input, run_config=run_config, session=session, **kwargs + starting_agent=agent, input=input, run_config=run_config, session=session, **kwargs ) finally: # Flush session items to Restate diff --git a/python/restate/ext/openai/session.py b/python/restate/ext/openai/session.py new file mode 100644 index 0000000..f4aefd7 --- /dev/null +++ b/python/restate/ext/openai/session.py @@ -0,0 +1,61 @@ +# +# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +""" +This module contains the optional OpenAI integration for Restate. 
+""" + +from agents.memory.session import SessionABC +from agents.items import TResponseInputItem +from typing import List, cast + +from restate.extensions import current_context +from restate import ObjectContext + + +class RestateSession(SessionABC): + """Restate session implementation following the Session protocol.""" + + def __init__(self): + self._items: List[TResponseInputItem] | None = None + + def _ctx(self) -> ObjectContext: + return cast(ObjectContext, current_context()) + + async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]: + """Retrieve conversation history for this session.""" + if self._items is None: + self._items = await self._ctx().get("items") or [] + if limit is not None: + return self._items[-limit:] + return self._items.copy() + + async def add_items(self, items: List[TResponseInputItem]) -> None: + """Store new items for this session.""" + if self._items is None: + self._items = await self._ctx().get("items") or [] + self._items.extend(items) + + async def pop_item(self) -> TResponseInputItem | None: + """Remove and return the most recent item from this session.""" + if self._items is None: + self._items = await self._ctx().get("items") or [] + if self._items: + return self._items.pop() + return None + + def flush(self) -> None: + """Flush the session items to the context.""" + self._ctx().set("items", self._items) + + async def clear_session(self) -> None: + """Clear all items for this session.""" + self._items = [] + self._ctx().clear("items") diff --git a/python/restate/ext/openai/utils.py b/python/restate/ext/openai/utils.py new file mode 100644 index 0000000..0967414 --- /dev/null +++ b/python/restate/ext/openai/utils.py @@ -0,0 +1,110 @@ +# +# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. 
+# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# + +import dataclasses + +from agents import ( + Handoff, + TContext, + Agent, +) + +from agents.tool import FunctionTool, Tool +from agents.tool_context import ToolContext + +from typing import List, Any + +from agents.items import TResponseOutputItem +from restate.extensions import current_context + +from .models import State + + +def get_function_call_ids(response: list[TResponseOutputItem]) -> List[str]: + """Extract function call IDs from the model response.""" + # TODO: support function calls in other response types + return [item.call_id for item in response if item.type == "function_call"] + + +def _create_wrapper(state, captured_tool): + async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any: + async def invoke(): + result = await captured_tool.on_invoke_tool(tool_context, tool_input) + # Ensure Pydantic objects are serialized to dict for LLM compatibility + if hasattr(result, "model_dump"): + return result.model_dump() + elif hasattr(result, "dict"): + return result.dict() + return result + + turnstile = state.turnstile + call_id = tool_context.tool_call_id + try: + await turnstile.wait_for(call_id) + ctx = current_context() + if ctx is None: + raise RuntimeError("No current Restate context found, make sure to run inside a Restate handler") + return await ctx.run_typed(captured_tool.name, invoke) + finally: + turnstile.allow_next_after(call_id) + + return on_invoke_tool_wrapper + + +def wrap_agent_tools( + agent: Agent[TContext], + state: State, +) -> Agent[TContext]: + """ + Wrap the function tools of an agent (and, recursively, of its handoffs) so that tool calls run in turnstile order. + + Returns: + A new agent with wrapped tools. 
+ """ + wrapped_tools: list[Tool] = [] + for tool in agent.tools: + if isinstance(tool, FunctionTool): + wrapped = _create_wrapper(state, tool) + wrapped_tools.append(dataclasses.replace(tool, on_invoke_tool=wrapped)) + else: + wrapped_tools.append(tool) + + wrapped_handoffs: list[Agent[Any] | Handoff[Any]] = [] + for handoff in agent.handoffs: + if isinstance(handoff, Agent): + wrapped_handoff = wrap_agent_tools(handoff, state) + wrapped_handoffs.append(wrapped_handoff) + elif isinstance(handoff, Handoff): + wrapped_handoffs.append(wrap_agent_handoff_tools(handoff, state)) + else: + raise TypeError(f"Unsupported handoff type: {type(handoff)}") + + return agent.clone(tools=wrapped_tools, handoffs=wrapped_handoffs) + + +def wrap_agent_handoff_tools( + handoff: Handoff[TContext], + state: State, +) -> Handoff[TContext]: + """ + Wrap the tools of a handoff to use the Restate error handling. + + Returns: + A new handoff with wrapped tools. + """ + + original_on_invoke_handoff = handoff.on_invoke_handoff + + async def wrapped(*args, **kwargs) -> Any: + agent = await original_on_invoke_handoff(*args, **kwargs) + return wrap_agent_tools(agent, state) + + return dataclasses.replace(handoff, on_invoke_handoff=wrapped) From ef2fc0746470e93e83c0933a538d14f291af9625 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 16 Dec 2025 18:34:38 +0100 Subject: [PATCH 3/4] Do not autowrap tools --- python/restate/ext/openai/__init__.py | 4 +++- python/restate/ext/openai/runner_wrapper.py | 17 +---------------- python/restate/ext/openai/utils.py | 19 ++----------------- 3 files changed, 6 insertions(+), 34 deletions(-) diff --git a/python/restate/ext/openai/__init__.py b/python/restate/ext/openai/__init__.py index 969d42f..dad5783 100644 --- a/python/restate/ext/openai/__init__.py +++ b/python/restate/ext/openai/__init__.py @@ -14,7 +14,9 @@ import typing -from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors, LlmRetryOpts +from .runner_wrapper 
import DurableRunner, continue_on_terminal_errors, raise_terminal_errors +from .models import LlmRetryOpts + from restate import ObjectContext, Context from restate.server_context import current_context diff --git a/python/restate/ext/openai/runner_wrapper.py b/python/restate/ext/openai/runner_wrapper.py index 7bc2d44..06f07d4 100644 --- a/python/restate/ext/openai/runner_wrapper.py +++ b/python/restate/ext/openai/runner_wrapper.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH # # This file is part of the Restate SDK for Python, # which is released under the MIT license. @@ -23,7 +23,6 @@ RunResult, Agent, ModelBehaviorError, - ModelSettings, Runner, ) from agents.models.multi_provider import MultiProvider @@ -184,20 +183,6 @@ async def run( run_config = kwargs.pop("run_config", RunConfig()) run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(state, llm_retry_opts)) - # Disable parallel tool calls - model_settings = run_config.model_settings - if model_settings is None: - model_settings = ModelSettings(parallel_tool_calls=False) - else: - model_settings = dataclasses.replace( - model_settings, - parallel_tool_calls=False, - ) - run_config = dataclasses.replace( - run_config, - model_settings=model_settings, - ) - # Use Restate session if requested, otherwise use provided session session = kwargs.pop("session", None) if use_restate_session: diff --git a/python/restate/ext/openai/utils.py b/python/restate/ext/openai/utils.py index 0967414..6fee0ce 100644 --- a/python/restate/ext/openai/utils.py +++ b/python/restate/ext/openai/utils.py @@ -9,6 +9,7 @@ # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE # +from typing import List, Any import dataclasses from agents import ( @@ -19,11 +20,7 @@ from agents.tool import FunctionTool, Tool from agents.tool_context import ToolContext - -from typing import List, Any - from agents.items 
import TResponseOutputItem -from restate.extensions import current_context from .models import State @@ -36,23 +33,11 @@ def get_function_call_ids(response: list[TResponseOutputItem]) -> List[str]: def _create_wrapper(state, captured_tool): async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any: - async def invoke(): - result = await captured_tool.on_invoke_tool(tool_context, tool_input) - # Ensure Pydantic objects are serialized to dict for LLM compatibility - if hasattr(result, "model_dump"): - return result.model_dump() - elif hasattr(result, "dict"): - return result.dict() - return result - turnstile = state.turnstile call_id = tool_context.tool_call_id try: await turnstile.wait_for(call_id) - ctx = current_context() - if ctx is None: - raise RuntimeError("No current Restate context found, make sure to run inside a Restate handler") - return await ctx.run_typed(captured_tool.name, invoke) + return await captured_tool.on_invoke_tool(tool_context, tool_input) finally: turnstile.allow_next_after(call_id) From 06ecb23a85357229a0bcc930fe9e42c879ca1ef3 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 16 Dec 2025 18:39:34 +0100 Subject: [PATCH 4/4] rename utils to functions --- python/restate/ext/openai/__init__.py | 6 +++--- python/restate/ext/openai/{utils.py => functions.py} | 0 python/restate/ext/openai/runner_wrapper.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) rename python/restate/ext/openai/{utils.py => functions.py} (100%) diff --git a/python/restate/ext/openai/__init__.py b/python/restate/ext/openai/__init__.py index dad5783..a13ac9f 100644 --- a/python/restate/ext/openai/__init__.py +++ b/python/restate/ext/openai/__init__.py @@ -14,12 +14,12 @@ import typing -from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors -from .models import LlmRetryOpts - from restate import ObjectContext, Context from restate.server_context import current_context +from .runner_wrapper 
import DurableRunner, continue_on_terminal_errors, raise_terminal_errors +from .models import LlmRetryOpts + def restate_object_context() -> ObjectContext: """Get the current Restate ObjectContext.""" diff --git a/python/restate/ext/openai/utils.py b/python/restate/ext/openai/functions.py similarity index 100% rename from python/restate/ext/openai/utils.py rename to python/restate/ext/openai/functions.py diff --git a/python/restate/ext/openai/runner_wrapper.py b/python/restate/ext/openai/runner_wrapper.py index 06f07d4..6235125 100644 --- a/python/restate/ext/openai/runner_wrapper.py +++ b/python/restate/ext/openai/runner_wrapper.py @@ -35,7 +35,7 @@ from restate.extensions import current_context from restate import RunOptions, TerminalError -from .utils import get_function_call_ids, wrap_agent_tools +from .functions import get_function_call_ids, wrap_agent_tools from .models import LlmRetryOpts, RestateModelResponse, State from .session import RestateSession