From a95e3acf4292e2759f691fc7c90765a4ebf98ab9 Mon Sep 17 00:00:00 2001
From: Giselle van Dongen
Date: Mon, 15 Dec 2025 14:13:45 +0100
Subject: [PATCH 1/5] Update OpenAI integration: flush session at end of
 runner, allow setting model retry opts, set parallel_tool_calls to false

---
 python/restate/ext/openai/__init__.py       |  36 +++-
 python/restate/ext/openai/runner_wrapper.py | 196 ++++++++++----------
 2 files changed, 133 insertions(+), 99 deletions(-)

diff --git a/python/restate/ext/openai/__init__.py b/python/restate/ext/openai/__init__.py
index e2c431c..e724bcb 100644
--- a/python/restate/ext/openai/__init__.py
+++ b/python/restate/ext/openai/__init__.py
@@ -12,11 +12,43 @@
 This module contains the optional OpenAI integration for Restate.
 """
 
-from .runner_wrapper import Runner, DurableModelCalls, continue_on_terminal_errors, raise_terminal_errors
+import typing
+
+from .runner_wrapper import (
+    DurableRunner,
+    DurableModelCalls,
+    continue_on_terminal_errors,
+    raise_terminal_errors,
+    RestateSession,
+    LlmRetryOpts
+)
+from restate import ObjectContext, Context
+from restate.server_context import current_context
+
+
+def restate_object_context() -> ObjectContext:
+    """Get the current Restate ObjectContext."""
+    ctx = current_context()
+    if ctx is None:
+        raise RuntimeError("No Restate context found.")
+    return typing.cast(ObjectContext, ctx)
+
+
+def restate_context() -> Context:
+    """Get the current Restate Context."""
+    ctx = current_context()
+    if ctx is None:
+        raise RuntimeError("No Restate context found.")
+    return ctx
+
 
 __all__ = [
     "DurableModelCalls",
     "continue_on_terminal_errors",
     "raise_terminal_errors",
-    "Runner",
+    "RestateSession",
+    "DurableRunner",
+    "LlmRetryOpts",
+    "restate_object_context",
+    "restate_context",
 ]

diff --git a/python/restate/ext/openai/runner_wrapper.py b/python/restate/ext/openai/runner_wrapper.py
index c9e9f5a..1a59adf 100644
--- a/python/restate/ext/openai/runner_wrapper.py
+++ b/python/restate/ext/openai/runner_wrapper.py
@@ -12,39 +12,61 @@
 This module contains the optional OpenAI integration for Restate.
 """
 
-import asyncio
 import dataclasses
-import typing
 
 from agents import (
-    Tool,
     Usage,
     Model,
     RunContextWrapper,
     AgentsException,
-    Runner as OpenAIRunner,
     RunConfig,
     TContext,
     RunResult,
     Agent,
     ModelBehaviorError,
+    ModelSettings,
 )
-
 from agents.models.multi_provider import MultiProvider
 from agents.items import TResponseStreamEvent, TResponseOutputItem, ModelResponse
 from agents.memory.session import SessionABC
 from agents.items import TResponseInputItem
-from typing import List, Any
-from typing import AsyncIterator
-
-from agents.tool import FunctionTool
-from agents.tool_context import ToolContext
+from agents.run import (
+    AgentRunner,
+    DEFAULT_AGENT_RUNNER,
+)
+from datetime import timedelta
+from typing import List, Any, AsyncIterator, Optional, cast
 from pydantic import BaseModel
+
 from restate.exceptions import SdkInternalBaseException
 from restate.extensions import current_context
-
 from restate import RunOptions, ObjectContext, TerminalError
 
+@dataclasses.dataclass
+class LlmRetryOpts:
+    max_attempts: Optional[int] = 10
+    """Max number of attempts (including the initial attempt) before giving up.
+
+    When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
+    max_duration: Optional[timedelta] = None
+    """Max duration of retries before giving up.
+
+    When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
+    initial_retry_interval: Optional[timedelta] = timedelta(seconds=1)
+    """Initial interval for the first retry attempt.
+    The retry interval grows by the factor specified in `retry_interval_factor`.
+
+    If any of the other retry-related fields is specified, the default for this field is 1 second; otherwise Restate falls back to the overall invocation retry policy."""
+    max_retry_interval: Optional[timedelta] = None
+    """Max interval between retries.
+    The retry interval grows by the factor specified in `retry_interval_factor`.
+
+    The default is 10 seconds."""
+    retry_interval_factor: Optional[float] = None
+    """Exponentiation factor used to compute the next retry delay.
+
+    If any of the other retry-related fields is specified, the default for this field is `2`, meaning the retry interval doubles at each attempt; otherwise Restate falls back to the overall invocation retry policy."""
+
 
 # The OpenAI ModelResponse class is a dataclass with Pydantic fields.
 # The Restate SDK cannot serialize this, so we turn the ModelResponse into a Pydantic model.
@@ -71,12 +93,12 @@ class DurableModelCalls(MultiProvider):
     A Restate model provider that wraps the OpenAI SDK's default MultiProvider.
     """
 
-    def __init__(self, max_retries: int | None = 3):
+    def __init__(self, llm_retry_opts: LlmRetryOpts):
         super().__init__()
-        self.max_retries = max_retries
+        self.llm_retry_opts = llm_retry_opts
 
     def get_model(self, model_name: str | None) -> Model:
-        return RestateModelWrapper(super().get_model(model_name or None), self.max_retries)
+        return RestateModelWrapper(super().get_model(model_name or None), self.llm_retry_opts)
 
 
 class RestateModelWrapper(Model):
@@ -84,10 +106,10 @@ class RestateModelWrapper(Model):
     A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
""" - def __init__(self, model: Model, max_retries: int | None = 3): + def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts): self.model = model self.model_name = "RestateModelWrapper" - self.max_retries = max_retries + self.llm_retry_opts = llm_retry_opts async def get_response(self, *args, **kwargs) -> ModelResponse: async def call_llm() -> RestateModelResponse: @@ -102,7 +124,18 @@ async def call_llm() -> RestateModelResponse: ctx = current_context() if ctx is None: raise RuntimeError("No current Restate context found, make sure to run inside a Restate handler") - result = await ctx.run_typed("call LLM", call_llm, RunOptions(max_attempts=self.max_retries)) + print("Calling LLM with retry options:", self.llm_retry_opts) + result = await ctx.run_typed( + "call LLM", + call_llm, + RunOptions( + max_attempts=self.llm_retry_opts.max_attempts, + max_duration=self.llm_retry_opts.max_duration, + initial_retry_interval=self.llm_retry_opts.initial_retry_interval, + max_retry_interval=self.llm_retry_opts.max_retry_interval, + retry_interval_factor=self.llm_retry_opts.retry_interval_factor, + ), + ) # convert back to original ModelResponse return ModelResponse( output=result.output, @@ -117,33 +150,43 @@ def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent class RestateSession(SessionABC): """Restate session implementation following the Session protocol.""" + def __init__(self): + self._items: List[TResponseInputItem] | None = None + def _ctx(self) -> ObjectContext: - return typing.cast(ObjectContext, current_context()) + return cast(ObjectContext, current_context()) + + async def _load_items_if_needed(self) -> None: + """Load items from context if not already loaded.""" + if self._items is None: + self._items = await self._ctx().get("items", type_hint=List[TResponseInputItem]) or [] async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]: """Retrieve conversation history for this session.""" - current_items = await self._ctx().get("items", type_hint=List[TResponseInputItem]) or [] + await self._load_items_if_needed() if limit is not None: - return current_items[-limit:] - return current_items + return self._items[-limit:] + return self._items.copy() async def add_items(self, items: List[TResponseInputItem]) -> None: """Store new items for this session.""" - # Your implementation here - current_items = await self.get_items() or [] - self._ctx().set("items", current_items + items) + await self._load_items_if_needed() + self._items.extend(items) async def pop_item(self) -> TResponseInputItem | None: """Remove and return the most recent item from this session.""" - current_items = await self.get_items() or [] - if current_items: - item = current_items.pop() - self._ctx().set("items", current_items) - return item + await self._load_items_if_needed() + if self._items: + return self._items.pop() return None + async def flush(self) -> None: + """Flush the session items to the context.""" + self._ctx().set("items", self._items) + async def clear_session(self) -> None: """Clear all items for this session.""" + self._items = [] self._ctx().clear("items") @@ -189,7 +232,7 @@ def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exceptio raise error -class Runner: +class DurableRunner: """ A wrapper around Runner.run that automatically configures RunConfig for Restate contexts. 
@@ -201,9 +244,7 @@ class Runner: @staticmethod async def run( starting_agent: Agent[TContext], - disable_tool_autowrapping: bool = False, - *args: typing.Any, - run_config: RunConfig | None = None, + input: str | list[TResponseInputItem], **kwargs, ) -> RunResult: """ @@ -213,71 +254,32 @@ async def run( The result from Runner.run """ - current_run_config = run_config or RunConfig() - new_run_config = dataclasses.replace( - current_run_config, - model_provider=DurableModelCalls(), - ) - restate_agent = sequentialize_and_wrap_tools(starting_agent, disable_tool_autowrapping) - return await OpenAIRunner.run(restate_agent, *args, run_config=new_run_config, **kwargs) - + # Set persisting model calls + llm_retry_opts = kwargs.get("llm_retry_opts", LlmRetryOpts()) + run_config = kwargs.pop("run_config", RunConfig()) + run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(llm_retry_opts)) -def sequentialize_and_wrap_tools( - agent: Agent[TContext], - disable_tool_autowrapping: bool, -) -> Agent[TContext]: - """ - Wrap the tools of an agent to use the Restate error handling. - - Returns: - A new agent with wrapped tools. - """ - - # Restate does not allow parallel tool calls, so we use a lock to ensure sequential execution. - # This lock only affects tools for this agent; handoff agents are wrapped recursively. - sequential_tools_lock = asyncio.Lock() - wrapped_tools: list[Tool] = [] - for tool in agent.tools: - if isinstance(tool, FunctionTool): - - def create_wrapper(captured_tool): - async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any: - await sequential_tools_lock.acquire() - - async def invoke(): - result = await captured_tool.on_invoke_tool(tool_context, tool_input) - # Ensure Pydantic objects are serialized to dict for LLM compatibility - if hasattr(result, "model_dump"): - return result.model_dump() - elif hasattr(result, "dict"): - return result.dict() - return result - - try: - if disable_tool_autowrapping: - return await invoke() - - ctx = current_context() - if ctx is None: - raise RuntimeError( - "No current Restate context found, make sure to run inside a Restate handler" - ) - return await ctx.run_typed(captured_tool.name, invoke) - finally: - sequential_tools_lock.release() - - return on_invoke_tool_wrapper - - wrapped_tools.append(dataclasses.replace(tool, on_invoke_tool=create_wrapper(tool))) + # Disable parallel tool calls + model_settings = run_config.model_settings + if model_settings is None: + model_settings = ModelSettings(parallel_tool_calls=False) else: - wrapped_tools.append(tool) + model_settings = dataclasses.replace( + model_settings, + parallel_tool_calls=False, + ) + run_config = dataclasses.replace( + run_config, + model_settings=model_settings, + ) - handoffs_with_wrapped_tools = [] - for handoff in agent.handoffs: - # recursively wrap tools in handoff agents - handoffs_with_wrapped_tools.append(sequentialize_and_wrap_tools(handoff, disable_tool_autowrapping)) # type: ignore + runner = DEFAULT_AGENT_RUNNER or AgentRunner() + try: + result = await runner.run(starting_agent=starting_agent, input=input, run_config=run_config, **kwargs) + finally: + # Flush session items to Restate + session = kwargs.get("session", None) + if session is not None and isinstance(session, RestateSession): + await session.flush() - return agent.clone( - tools=wrapped_tools, - handoffs=handoffs_with_wrapped_tools, - ) + return result From c610405d9da9a850ebbaef619e2cabf584ec7d62 Mon Sep 17 00:00:00 2001 From: Giselle van Dongen 
From c610405d9da9a850ebbaef619e2cabf584ec7d62 Mon Sep 17 00:00:00 2001
From: Giselle van Dongen
Date: Mon, 15 Dec 2025 14:14:13 +0100
Subject: [PATCH 2/5] Update OpenAI integration: flush session at end of
 runner, allow setting model retry opts, set parallel_tool_calls to false

---
 python/restate/ext/openai/__init__.py       | 8 +++-----
 python/restate/ext/openai/runner_wrapper.py | 7 +++----
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/python/restate/ext/openai/__init__.py b/python/restate/ext/openai/__init__.py
index e724bcb..8b62a2d 100644
--- a/python/restate/ext/openai/__init__.py
+++ b/python/restate/ext/openai/__init__.py
@@ -16,7 +16,6 @@
 
 from .runner_wrapper import (
     DurableRunner,
-    DurableModelCalls,
     continue_on_terminal_errors,
     raise_terminal_errors,
     RestateSession,
@@ -43,12 +42,11 @@ def restate_context() -> Context:
 
 
 __all__ = [
-    "DurableModelCalls",
-    "continue_on_terminal_errors",
-    "raise_terminal_errors",
-    "RestateSession",
     "DurableRunner",
+    "RestateSession",
     "LlmRetryOpts",
     "restate_object_context",
    "restate_context",
+    "continue_on_terminal_errors",
+    "raise_terminal_errors",
 ]

diff --git a/python/restate/ext/openai/runner_wrapper.py b/python/restate/ext/openai/runner_wrapper.py
index 1a59adf..b27d52f 100644
--- a/python/restate/ext/openai/runner_wrapper.py
+++ b/python/restate/ext/openai/runner_wrapper.py
@@ -93,7 +93,7 @@ class DurableModelCalls(MultiProvider):
     A Restate model provider that wraps the OpenAI SDK's default MultiProvider.
     """
 
-    def __init__(self, llm_retry_opts: LlmRetryOpts):
+    def __init__(self, llm_retry_opts: LlmRetryOpts | None = None):
         super().__init__()
         self.llm_retry_opts = llm_retry_opts
 
@@ -106,7 +106,7 @@ class RestateModelWrapper(Model):
     A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
     """
 
-    def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts):
+    def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts | None = LlmRetryOpts()):
         self.model = model
         self.model_name = "RestateModelWrapper"
         self.llm_retry_opts = llm_retry_opts
@@ -124,7 +124,6 @@ async def call_llm() -> RestateModelResponse:
         ctx = current_context()
         if ctx is None:
             raise RuntimeError("No current Restate context found, make sure to run inside a Restate handler")
-        print("Calling LLM with retry options:", self.llm_retry_opts)
         result = await ctx.run_typed(
             "call LLM",
             call_llm,
@@ -255,7 +254,7 @@ async def run(
         """
 
         # Set up persisted model calls
-        llm_retry_opts = kwargs.get("llm_retry_opts", LlmRetryOpts())
+        llm_retry_opts = kwargs.get("llm_retry_opts", None)
         run_config = kwargs.pop("run_config", RunConfig())
         run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(llm_retry_opts))
 
From 23722419a7dd6870603d64d55223302d12a302d1 Mon Sep 17 00:00:00 2001
From: Giselle van Dongen
Date: Mon, 15 Dec 2025 14:25:04 +0100
Subject: [PATCH 3/5] flush sync

---
 python/restate/ext/openai/runner_wrapper.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/restate/ext/openai/runner_wrapper.py b/python/restate/ext/openai/runner_wrapper.py
index b27d52f..88db48e 100644
--- a/python/restate/ext/openai/runner_wrapper.py
+++ b/python/restate/ext/openai/runner_wrapper.py
@@ -179,7 +179,7 @@ async def pop_item(self) -> TResponseInputItem | None:
             return self._items.pop()
         return None
 
-    async def flush(self) -> None:
+    def flush(self) -> None:
         """Flush the session items to the context."""
         self._ctx().set("items", self._items)
 
@@ -279,6 +279,6 @@ async def run(
             # Flush session items to Restate
             session = kwargs.get("session", None)
             if session is not None and isinstance(session, RestateSession):
-                await session.flush()
+                session.flush()
 
         return result
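
Making `flush` synchronous is sound here: unlike `ctx.get`, a Restate state write is
buffered locally and needs no await, so the session acts as a write-back cache — reads
populate the in-memory list once, mutations touch only that list, and a single `ctx.set`
persists everything. The lifecycle, sketched with the session from this series:

    items = await session.get_items()    # first read awaits ctx.get("items")
    await session.add_items(new_items)   # mutates only the cached list
    last = await session.pop_item()      # still no state write
    session.flush()                      # one buffered ctx.set("items", ...)
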
""" - def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts | None = LlmRetryOpts()): + def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts | None = None): self.model = model self.model_name = "RestateModelWrapper" - self.llm_retry_opts = llm_retry_opts + self.llm_retry_opts = llm_retry_opts if llm_retry_opts is not None else LlmRetryOpts() async def get_response(self, *args, **kwargs) -> ModelResponse: async def call_llm() -> RestateModelResponse: @@ -155,26 +153,24 @@ def __init__(self): def _ctx(self) -> ObjectContext: return cast(ObjectContext, current_context()) - async def _load_items_if_needed(self) -> None: - """Load items from context if not already loaded.""" - if self._items is None: - self._items = await self._ctx().get("items", type_hint=List[TResponseInputItem]) or [] - async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]: """Retrieve conversation history for this session.""" - await self._load_items_if_needed() + if self._items is None: + self._items = await self._ctx().get("items") or [] if limit is not None: return self._items[-limit:] return self._items.copy() async def add_items(self, items: List[TResponseInputItem]) -> None: """Store new items for this session.""" - await self._load_items_if_needed() + if self._items is None: + self._items = await self._ctx().get("items") or [] self._items.extend(items) async def pop_item(self) -> TResponseInputItem | None: """Remove and return the most recent item from this session.""" - await self._load_items_if_needed() + if self._items is None: + self._items = await self._ctx().get("items") or [] if self._items: return self._items.pop() return None @@ -244,6 +240,8 @@ class DurableRunner: async def run( starting_agent: Agent[TContext], input: str | list[TResponseInputItem], + *, + use_restate_session: bool = False, **kwargs, ) -> RunResult: """ @@ -254,7 +252,7 @@ async def run( """ # Set persisting model calls - llm_retry_opts = kwargs.get("llm_retry_opts", None) + llm_retry_opts = kwargs.pop("llm_retry_opts", None) run_config = kwargs.pop("run_config", RunConfig()) run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(llm_retry_opts)) @@ -272,12 +270,19 @@ async def run( model_settings=model_settings, ) - runner = DEFAULT_AGENT_RUNNER or AgentRunner() + # Use Restate session if requested, otherwise use provided session + session = kwargs.pop("session", None) + if use_restate_session: + if session is not None: + raise TerminalError("When use_restate_session is True, session config cannot be provided.") + session = RestateSession() + try: - result = await runner.run(starting_agent=starting_agent, input=input, run_config=run_config, **kwargs) + result = await Runner.run( + starting_agent=starting_agent, input=input, run_config=run_config, session=session, **kwargs + ) finally: # Flush session items to Restate - session = kwargs.get("session", None) if session is not None and isinstance(session, RestateSession): session.flush() From 73bfa10f4b6b309f47bb749e23f993f8370627b7 Mon Sep 17 00:00:00 2001 From: Giselle van Dongen Date: Tue, 16 Dec 2025 12:28:33 +0100 Subject: [PATCH 5/5] Add docstring --- python/restate/ext/openai/runner_wrapper.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/restate/ext/openai/runner_wrapper.py b/python/restate/ext/openai/runner_wrapper.py index 4f7fa63..572b819 100644 --- a/python/restate/ext/openai/runner_wrapper.py +++ b/python/restate/ext/openai/runner_wrapper.py @@ -247,6 +247,10 @@ async def run( """ Run an agent with 
From 73bfa10f4b6b309f47bb749e23f993f8370627b7 Mon Sep 17 00:00:00 2001
From: Giselle van Dongen
Date: Tue, 16 Dec 2025 12:28:33 +0100
Subject: [PATCH 5/5] Add docstring

---
 python/restate/ext/openai/runner_wrapper.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/python/restate/ext/openai/runner_wrapper.py b/python/restate/ext/openai/runner_wrapper.py
index 4f7fa63..572b819 100644
--- a/python/restate/ext/openai/runner_wrapper.py
+++ b/python/restate/ext/openai/runner_wrapper.py
@@ -247,6 +247,10 @@ async def run(
         """
         Run an agent with automatic Restate configuration.
 
+        Args:
+            use_restate_session: If True, creates a RestateSession for conversation persistence.
+                Requires running within a Restate Virtual Object context.
+
         Returns:
             The result from Runner.run
         """
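
For reference, the backoff that the `LlmRetryOpts` fields describe. The actual scheduling
happens inside the Restate runtime via `RunOptions`; this sketch only reproduces the
arithmetic implied by the field docstrings:

    from datetime import timedelta

    def retry_delays(opts: LlmRetryOpts, retries: int) -> list[timedelta]:
        # Delay before retry n is initial * factor**n, capped at max_retry_interval.
        initial = opts.initial_retry_interval or timedelta(seconds=1)
        factor = opts.retry_interval_factor or 2.0
        cap = opts.max_retry_interval or timedelta(seconds=10)
        return [min(initial * factor**n, cap) for n in range(retries)]

    # Defaults give 1s, 2s, 4s, 8s, 10s, 10s, ... until max_attempts or
    # max_duration is exhausted, after which a TerminalError surfaces.
    print(retry_delays(LlmRetryOpts(), 6))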