From 8bf6cb8e394506673698bcb370f6b224457ddecd Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 16 Dec 2025 19:45:18 +0100 Subject: [PATCH 1/6] Shuffle imports --- python/restate/ext/openai/__init__.py | 3 +- python/restate/ext/openai/functions.py | 34 ++++++++++++++- python/restate/ext/openai/models.py | 16 +++++++ python/restate/ext/openai/runner_wrapper.py | 48 +-------------------- 4 files changed, 52 insertions(+), 49 deletions(-) diff --git a/python/restate/ext/openai/__init__.py b/python/restate/ext/openai/__init__.py index a13ac9f..42b561f 100644 --- a/python/restate/ext/openai/__init__.py +++ b/python/restate/ext/openai/__init__.py @@ -17,8 +17,9 @@ from restate import ObjectContext, Context from restate.server_context import current_context -from .runner_wrapper import DurableRunner, continue_on_terminal_errors, raise_terminal_errors +from .runner_wrapper import DurableRunner from .models import LlmRetryOpts +from .functions import continue_on_terminal_errors, raise_terminal_errors def restate_object_context() -> ObjectContext: diff --git a/python/restate/ext/openai/functions.py b/python/restate/ext/openai/functions.py index 6fee0ce..6f6867f 100644 --- a/python/restate/ext/openai/functions.py +++ b/python/restate/ext/openai/functions.py @@ -16,13 +16,45 @@ Handoff, TContext, Agent, + RunContextWrapper, + ModelBehaviorError, ) from agents.tool import FunctionTool, Tool from agents.tool_context import ToolContext from agents.items import TResponseOutputItem -from .models import State +from restate import TerminalError + +from .models import State, AgentsTerminalException + + +def raise_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str: + """A custom function to provide a user-friendly error message.""" + # Raise terminal errors and cancellations + if isinstance(error, TerminalError): + # For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError + # so we create a new exception that inherits from both + raise AgentsTerminalException(error.message) + + if isinstance(error, ModelBehaviorError): + return f"An error occurred while calling the tool: {str(error)}" + + raise error + + +def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str: + """A custom function to provide a user-friendly error message.""" + # Raise terminal errors and cancellations + if isinstance(error, TerminalError): + # For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError + # so we create a new exception that inherits from both + return f"An error occurred while running the tool: {str(error)}" + + if isinstance(error, ModelBehaviorError): + return f"An error occurred while calling the tool: {str(error)}" + + raise error def get_function_call_ids(response: list[TResponseOutputItem]) -> List[str]: diff --git a/python/restate/ext/openai/models.py b/python/restate/ext/openai/models.py index c48cf88..0ce3d12 100644 --- a/python/restate/ext/openai/models.py +++ b/python/restate/ext/openai/models.py @@ -16,6 +16,7 @@ from agents import ( Usage, + AgentsException, ) from agents.items import TResponseOutputItem from agents.items import TResponseInputItem @@ -24,6 +25,7 @@ from pydantic import BaseModel from restate.ext.turnstile import Turnstile +from restate import TerminalError, SdkInternalBaseException class State: @@ -77,3 +79,17 @@ class RestateModelResponse(BaseModel): def to_input_items(self) -> list[TResponseInputItem]: return [it.model_dump(exclude_unset=True) for it in self.output] # type: ignore + + +class AgentsTerminalException(AgentsException, TerminalError): + """Exception that is both an AgentsException and a restate.TerminalError.""" + + def __init__(self, *args: object) -> None: + super().__init__(*args) + + +class AgentsSuspension(AgentsException, SdkInternalBaseException): + """Exception that is both an AgentsException and a restate SdkInternalBaseException.""" + + def __init__(self, *args: object) -> None: + super().__init__(*args) diff --git a/python/restate/ext/openai/runner_wrapper.py b/python/restate/ext/openai/runner_wrapper.py index 6235125..d93f6db 100644 --- a/python/restate/ext/openai/runner_wrapper.py +++ b/python/restate/ext/openai/runner_wrapper.py @@ -16,21 +16,17 @@ from agents import ( Model, - RunContextWrapper, - AgentsException, RunConfig, TContext, RunResult, Agent, - ModelBehaviorError, Runner, ) from agents.models.multi_provider import MultiProvider from agents.items import TResponseStreamEvent, ModelResponse from agents.items import TResponseInputItem -from typing import Any, AsyncIterator +from typing import AsyncIterator -from restate.exceptions import SdkInternalBaseException from restate.ext.turnstile import Turnstile from restate.extensions import current_context from restate import RunOptions, TerminalError @@ -105,48 +101,6 @@ def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent raise TerminalError("Streaming is not supported in Restate. Use `get_response` instead.") -class AgentsTerminalException(AgentsException, TerminalError): - """Exception that is both an AgentsException and a restate.TerminalError.""" - - def __init__(self, *args: object) -> None: - super().__init__(*args) - - -class AgentsSuspension(AgentsException, SdkInternalBaseException): - """Exception that is both an AgentsException and a restate SdkInternalBaseException.""" - - def __init__(self, *args: object) -> None: - super().__init__(*args) - - -def raise_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str: - """A custom function to provide a user-friendly error message.""" - # Raise terminal errors and cancellations - if isinstance(error, TerminalError): - # For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError - # so we create a new exception that inherits from both - raise AgentsTerminalException(error.message) - - if isinstance(error, ModelBehaviorError): - return f"An error occurred while calling the tool: {str(error)}" - - raise error - - -def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str: - """A custom function to provide a user-friendly error message.""" - # Raise terminal errors and cancellations - if isinstance(error, TerminalError): - # For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError - # so we create a new exception that inherits from both - return f"An error occurred while running the tool: {str(error)}" - - if isinstance(error, ModelBehaviorError): - return f"An error occurred while calling the tool: {str(error)}" - - raise error - - class DurableRunner: """ A wrapper around Runner.run that automatically configures RunConfig for Restate contexts. From c87e987f2a608e4f0589b5a7ad1e8814d677a79c Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 16 Dec 2025 20:12:42 +0100 Subject: [PATCH 2/6] Preconfigure a function tool --- python/restate/ext/openai/__init__.py | 3 ++- python/restate/ext/openai/functions.py | 13 ++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/python/restate/ext/openai/__init__.py b/python/restate/ext/openai/__init__.py index 42b561f..61be8d6 100644 --- a/python/restate/ext/openai/__init__.py +++ b/python/restate/ext/openai/__init__.py @@ -19,7 +19,7 @@ from .runner_wrapper import DurableRunner from .models import LlmRetryOpts -from .functions import continue_on_terminal_errors, raise_terminal_errors +from .functions import continue_on_terminal_errors, raise_terminal_errors, function_tool def restate_object_context() -> ObjectContext: @@ -45,4 +45,5 @@ def restate_context() -> Context: "restate_context", "continue_on_terminal_errors", "raise_terminal_errors", + "function_tool", ] diff --git a/python/restate/ext/openai/functions.py b/python/restate/ext/openai/functions.py index 6f6867f..2b193f8 100644 --- a/python/restate/ext/openai/functions.py +++ b/python/restate/ext/openai/functions.py @@ -20,7 +20,7 @@ ModelBehaviorError, ) -from agents.tool import FunctionTool, Tool +from agents.tool import FunctionTool, Tool, ToolFunction, function_tool as oai_function_tool from agents.tool_context import ToolContext from agents.items import TResponseOutputItem @@ -29,6 +29,17 @@ from .models import State, AgentsTerminalException +def function_tool(func: ToolFunction, *args, **kwargs) -> FunctionTool: + failure_error_function = kwargs.pop("failure_error_function", raise_terminal_errors) + + return oai_function_tool( + func, + *args, + failure_error_function=failure_error_function, + **kwargs, + ) + + def raise_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str: """A custom function to provide a user-friendly error message.""" # Raise terminal errors and cancellations From eaee96deaa4f50b6616f3755c3131f5426f8c039 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 16 Dec 2025 21:27:07 +0100 Subject: [PATCH 3/6] use our own @function_tool that sets a differnt default failure function --- python/restate/ext/openai/functions.py | 91 ++++++++++++++++++++++---- 1 file changed, 80 insertions(+), 11 deletions(-) diff --git a/python/restate/ext/openai/functions.py b/python/restate/ext/openai/functions.py index 2b193f8..16ee5ea 100644 --- a/python/restate/ext/openai/functions.py +++ b/python/restate/ext/openai/functions.py @@ -9,7 +9,7 @@ # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE # -from typing import List, Any +from typing import List, Any, overload, Callable, Union, TypeVar, Awaitable import dataclasses from agents import ( @@ -18,9 +18,18 @@ Agent, RunContextWrapper, ModelBehaviorError, + AgentBase, ) -from agents.tool import FunctionTool, Tool, ToolFunction, function_tool as oai_function_tool +from agents.function_schema import DocstringStyle + +from agents.tool import ( + FunctionTool, + Tool, + ToolFunction, + ToolErrorFunction, + function_tool as oai_function_tool, +) from agents.tool_context import ToolContext from agents.items import TResponseOutputItem @@ -28,16 +37,9 @@ from .models import State, AgentsTerminalException +T = TypeVar("T") -def function_tool(func: ToolFunction, *args, **kwargs) -> FunctionTool: - failure_error_function = kwargs.pop("failure_error_function", raise_terminal_errors) - - return oai_function_tool( - func, - *args, - failure_error_function=failure_error_function, - **kwargs, - ) +MaybeAwaitable = Union[Awaitable[T], T] def raise_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str: @@ -68,6 +70,73 @@ def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exceptio raise error +@overload +def function_tool( + func: ToolFunction[...], + *, + name_override: str | None = None, + description_override: str | None = None, + docstring_style: DocstringStyle | None = None, + use_docstring_info: bool = True, + failure_error_function: ToolErrorFunction | None = None, + strict_mode: bool = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, +) -> FunctionTool: + """Overload for usage as @function_tool (no parentheses).""" + ... + + +@overload +def function_tool( + *, + name_override: str | None = None, + description_override: str | None = None, + docstring_style: DocstringStyle | None = None, + use_docstring_info: bool = True, + failure_error_function: ToolErrorFunction | None = None, + strict_mode: bool = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, +) -> Callable[[ToolFunction[...]], FunctionTool]: + """Overload for usage as @function_tool(...).""" + ... + + +def function_tool( + func: ToolFunction[...] | None = None, + *, + name_override: str | None = None, + description_override: str | None = None, + docstring_style: DocstringStyle | None = None, + use_docstring_info: bool = True, + failure_error_function: ToolErrorFunction | None = raise_terminal_errors, + strict_mode: bool = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, +) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: + # If func is actually a callable, we were used as @function_tool with no parentheses + if callable(func): + return oai_function_tool( + func=func, + name_override=name_override, + description_override=description_override, + docstring_style=docstring_style, + use_docstring_info=use_docstring_info, + failure_error_function=failure_error_function or raise_terminal_errors, + strict_mode=strict_mode, + is_enabled=is_enabled, + ) + + # Otherwise, we were used as @function_tool(...), so return a decorator + return oai_function_tool( + name_override=name_override, + description_override=description_override, + docstring_style=docstring_style, + use_docstring_info=use_docstring_info, + failure_error_function=failure_error_function or raise_terminal_errors, + strict_mode=strict_mode, + is_enabled=is_enabled, + ) + + def get_function_call_ids(response: list[TResponseOutputItem]) -> List[str]: """Extract function call IDs from the model response.""" # TODO: support function calls in other response types From f20757fc2475e6b2b3c026396be7fd93a0709797 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 16 Dec 2025 22:09:14 +0100 Subject: [PATCH 4/6] Make sure to cancel any subsequent tools on failure --- python/restate/ext/openai/functions.py | 16 +++++++++++++--- python/restate/ext/turnstile.py | 12 ++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/python/restate/ext/openai/functions.py b/python/restate/ext/openai/functions.py index 16ee5ea..3db918f 100644 --- a/python/restate/ext/openai/functions.py +++ b/python/restate/ext/openai/functions.py @@ -147,11 +147,21 @@ def _create_wrapper(state, captured_tool): async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any: turnstile = state.turnstile call_id = tool_context.tool_call_id + # wait for our turn + await turnstile.wait_for(call_id) try: - await turnstile.wait_for(call_id) - return await captured_tool.on_invoke_tool(tool_context, tool_input) - finally: + # invoke the original tool + res = await captured_tool.on_invoke_tool(tool_context, tool_input) + # allow the next tool to proceed turnstile.allow_next_after(call_id) + return res + except BaseException as ex: + # if there was an error, it will be propagated up, towards the handler + # but we need to make sure that all subsequent tools will not execute + # as they might interact with the restate context. + turnstile.cancel_all_after(call_id) + # re-raise the exception + raise ex from None return on_invoke_tool_wrapper diff --git a/python/restate/ext/turnstile.py b/python/restate/ext/turnstile.py index a2f6104..a091061 100644 --- a/python/restate/ext/turnstile.py +++ b/python/restate/ext/turnstile.py @@ -1,4 +1,5 @@ from asyncio import Event +from restate.exceptions import SdkInternalException class Turnstile: @@ -14,6 +15,7 @@ def __init__(self, ids: list[str]): self.turns = dict(zip(ids, ids[1:])) # mapping of id to event that signals when that id's turn is allowed self.events = {id: Event() for id in ids} + self.canceled = False if ids: # make sure that the first id can proceed immediately event = self.events[ids[0]] @@ -22,6 +24,16 @@ def __init__(self, ids: list[str]): async def wait_for(self, id: str) -> None: event = self.events[id] await event.wait() + if self.canceled: + raise SdkInternalException() from None + + def cancel_all_after(self, id: str) -> None: + self.canceled = True + next_id = self.turns.get(id) + while next_id is not None: + next_event = self.events[next_id] + next_event.set() + next_id = self.turns.get(next_id) def allow_next_after(self, id: str) -> None: next_id = self.turns.get(id) From 49bd011145c348a4013434c036b22375a652419c Mon Sep 17 00:00:00 2001 From: igalshilman Date: Wed, 17 Dec 2025 13:27:19 +0100 Subject: [PATCH 5/6] Cancel all pending tools on attempt finish --- python/restate/ext/adk/plugin.py | 6 ++++-- python/restate/ext/turnstile.py | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/restate/ext/adk/plugin.py b/python/restate/ext/adk/plugin.py index d9d9c82..65bdbc2 100644 --- a/python/restate/ext/adk/plugin.py +++ b/python/restate/ext/adk/plugin.py @@ -81,7 +81,9 @@ async def release_task(): await event.wait() finally: self._models.pop(id, None) - self._turnstiles.pop(id, None) + maybe_turnstile = self._turnstiles.pop(id, None) + if maybe_turnstile is not None: + maybe_turnstile.cancel_all() _ = asyncio.create_task(release_task()) return None @@ -161,7 +163,7 @@ async def on_tool_error_callback( assert turnstile is not None, "Turnstile not found for tool invocation." id = tool_context.function_call_id assert id is not None, "Function call ID is required for tool invocation." - turnstile.allow_next_after(id) + turnstile.cancel_all_after(id) return None async def close(self): diff --git a/python/restate/ext/turnstile.py b/python/restate/ext/turnstile.py index a091061..08981c6 100644 --- a/python/restate/ext/turnstile.py +++ b/python/restate/ext/turnstile.py @@ -41,3 +41,8 @@ def allow_next_after(self, id: str) -> None: return next_event = self.events[next_id] next_event.set() + + def cancel_all(self) -> None: + self.canceled = True + for event in self.events.values(): + event.set() From 4bd159fc0740adc3bd927b306c083b3e7582c3f2 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Wed, 17 Dec 2025 14:11:20 +0100 Subject: [PATCH 6/6] rename function_tool to durable_function_tool and propagate cancellations --- python/restate/ext/openai/__init__.py | 6 +-- python/restate/ext/openai/functions.py | 52 +++++++++++++------------- python/restate/ext/openai/models.py | 9 +---- 3 files changed, 30 insertions(+), 37 deletions(-) diff --git a/python/restate/ext/openai/__init__.py b/python/restate/ext/openai/__init__.py index 61be8d6..94e6d0b 100644 --- a/python/restate/ext/openai/__init__.py +++ b/python/restate/ext/openai/__init__.py @@ -19,7 +19,7 @@ from .runner_wrapper import DurableRunner from .models import LlmRetryOpts -from .functions import continue_on_terminal_errors, raise_terminal_errors, function_tool +from .functions import propagate_cancellation, raise_terminal_errors, durable_function_tool def restate_object_context() -> ObjectContext: @@ -43,7 +43,7 @@ def restate_context() -> Context: "LlmRetryOpts", "restate_object_context", "restate_context", - "continue_on_terminal_errors", "raise_terminal_errors", - "function_tool", + "propagate_cancellation", + "durable_function_tool", ] diff --git a/python/restate/ext/openai/functions.py b/python/restate/ext/openai/functions.py index 3db918f..621f5c6 100644 --- a/python/restate/ext/openai/functions.py +++ b/python/restate/ext/openai/functions.py @@ -29,6 +29,7 @@ ToolFunction, ToolErrorFunction, function_tool as oai_function_tool, + default_tool_error_function, ) from agents.tool_context import ToolContext from agents.items import TResponseOutputItem @@ -56,22 +57,21 @@ def raise_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> raise error -def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exception) -> str: - """A custom function to provide a user-friendly error message.""" - # Raise terminal errors and cancellations - if isinstance(error, TerminalError): - # For the agent SDK it needs to be an AgentsException, for restate it needs to be a TerminalError - # so we create a new exception that inherits from both - return f"An error occurred while running the tool: {str(error)}" +def propagate_cancellation(failure_error_function: ToolErrorFunction | None = None) -> ToolErrorFunction: + _fn = failure_error_function if failure_error_function is not None else default_tool_error_function - if isinstance(error, ModelBehaviorError): - return f"An error occurred while calling the tool: {str(error)}" + def inner(context: RunContextWrapper[Any], error: Exception): + """Raise cancellations as exceptions.""" + if isinstance(error, TerminalError): + if error.status_code == 409: + raise error from None + return _fn(context, error) - raise error + return inner @overload -def function_tool( +def durable_function_tool( func: ToolFunction[...], *, name_override: str | None = None, @@ -87,7 +87,7 @@ def function_tool( @overload -def function_tool( +def durable_function_tool( *, name_override: str | None = None, description_override: str | None = None, @@ -101,7 +101,7 @@ def function_tool( ... -def function_tool( +def durable_function_tool( func: ToolFunction[...] | None = None, *, name_override: str | None = None, @@ -112,7 +112,8 @@ def function_tool( strict_mode: bool = True, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: - # If func is actually a callable, we were used as @function_tool with no parentheses + failure_fn = propagate_cancellation(failure_error_function) + if callable(func): return oai_function_tool( func=func, @@ -120,21 +121,20 @@ def function_tool( description_override=description_override, docstring_style=docstring_style, use_docstring_info=use_docstring_info, - failure_error_function=failure_error_function or raise_terminal_errors, + failure_error_function=failure_fn, + strict_mode=strict_mode, + is_enabled=is_enabled, + ) + else: + return oai_function_tool( + name_override=name_override, + description_override=description_override, + docstring_style=docstring_style, + use_docstring_info=use_docstring_info, + failure_error_function=failure_fn, strict_mode=strict_mode, is_enabled=is_enabled, ) - - # Otherwise, we were used as @function_tool(...), so return a decorator - return oai_function_tool( - name_override=name_override, - description_override=description_override, - docstring_style=docstring_style, - use_docstring_info=use_docstring_info, - failure_error_function=failure_error_function or raise_terminal_errors, - strict_mode=strict_mode, - is_enabled=is_enabled, - ) def get_function_call_ids(response: list[TResponseOutputItem]) -> List[str]: diff --git a/python/restate/ext/openai/models.py b/python/restate/ext/openai/models.py index 0ce3d12..650c03a 100644 --- a/python/restate/ext/openai/models.py +++ b/python/restate/ext/openai/models.py @@ -25,7 +25,7 @@ from pydantic import BaseModel from restate.ext.turnstile import Turnstile -from restate import TerminalError, SdkInternalBaseException +from restate import TerminalError class State: @@ -86,10 +86,3 @@ class AgentsTerminalException(AgentsException, TerminalError): def __init__(self, *args: object) -> None: super().__init__(*args) - - -class AgentsSuspension(AgentsException, SdkInternalBaseException): - """Exception that is both an AgentsException and a restate SdkInternalBaseException.""" - - def __init__(self, *args: object) -> None: - super().__init__(*args)