29 changes: 26 additions & 3 deletions python/restate/ext/openai/__init__.py
@@ -12,11 +12,34 @@
This module contains the optional OpenAI integration for Restate.
"""

from .runner_wrapper import Runner, DurableModelCalls, continue_on_terminal_errors, raise_terminal_errors
import typing

from .runner_wrapper import DurableRunner, DurableModelCalls, continue_on_terminal_errors, raise_terminal_errors, LlmRetryOpts
from restate import ObjectContext, Context
from restate.server_context import current_context


def restate_object_context() -> ObjectContext:
"""Get the current Restate ObjectContext."""
ctx = current_context()
if ctx is None:
raise RuntimeError("No Restate context found.")
return typing.cast(ObjectContext, ctx)


def restate_context() -> Context:
"""Get the current Restate Context."""
ctx = current_context()
if ctx is None:
raise RuntimeError("No Restate context found.")
return ctx


__all__ = [
"DurableModelCalls",
"DurableRunner",
"LlmRetryOpts",
"restate_object_context",
"restate_context",
"continue_on_terminal_errors",
"raise_terminal_errors",
"Runner",
]
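A minimal sketch of how these helpers can be used from an agent tool, where no context argument is threaded through. The tool name and body are hypothetical, and the openai-agents function_tool decorator is assumed:

from agents import function_tool
from restate.ext.openai import restate_context

@function_tool
async def fetch_weather(city: str) -> str:
    # restate_context() recovers the ambient context; it raises a RuntimeError
    # when called outside of a Restate handler.
    ctx = restate_context()

    async def do_fetch() -> str:
        return f"Sunny in {city}"  # hypothetical stand-in for a real API call

    # Journal the side effect so that retries replay the recorded result.
    return await ctx.run_typed("fetch weather", do_fetch)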
204 changes: 107 additions & 97 deletions python/restate/ext/openai/runner_wrapper.py
@@ -12,40 +12,60 @@
This module contains the optional OpenAI integration for Restate.
"""

import asyncio
import dataclasses
import typing

from agents import (
Tool,
Usage,
Model,
RunContextWrapper,
AgentsException,
Runner as OpenAIRunner,
RunConfig,
TContext,
RunResult,
Agent,
ModelBehaviorError,
ModelSettings,
Runner,
)

from agents.models.multi_provider import MultiProvider
from agents.items import TResponseStreamEvent, TResponseOutputItem, ModelResponse
from agents.memory.session import SessionABC
from agents.items import TResponseInputItem
from typing import List, Any
from typing import AsyncIterator

from agents.tool import FunctionTool
from agents.tool_context import ToolContext
from datetime import timedelta
from typing import List, Any, AsyncIterator, Optional, cast
from pydantic import BaseModel

from restate.exceptions import SdkInternalBaseException
from restate.extensions import current_context

from restate import RunOptions, ObjectContext, TerminalError


@dataclasses.dataclass
class LlmRetryOpts:
max_attempts: Optional[int] = 10
"""Max number of attempts (including the initial attempt) before giving up.

When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
max_duration: Optional[timedelta] = None
"""Max duration of retries, before giving up.

When giving up, the LLM call will throw a `TerminalError` wrapping the original error message."""
initial_retry_interval: Optional[timedelta] = timedelta(seconds=1)
"""Initial interval for the first retry attempt.
The retry interval grows by the factor specified in `retry_interval_factor`.

If any of the other retry-related fields is specified, the default for this field is 1 second; otherwise Restate falls back to the overall invocation retry policy."""
max_retry_interval: Optional[timedelta] = None
"""Max interval between retries.
The retry interval grows by the factor specified in `retry_interval_factor`.

The default is 10 seconds."""
retry_interval_factor: Optional[float] = None
"""Exponentiation factor used to compute the next retry delay.

If any of the other retry-related fields is specified, the default for this field is `2`, meaning the retry interval doubles at each attempt; otherwise Restate falls back to the overall invocation retry policy."""
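These options map field-for-field onto Restate's RunOptions when the model call is journaled (see get_response below). A sketch of a tightened policy, with illustrative values only:

from datetime import timedelta

retry_opts = LlmRetryOpts(
    max_attempts=5,                               # give up after 5 calls in total
    max_duration=timedelta(minutes=2),            # or after 2 minutes of retrying
    initial_retry_interval=timedelta(seconds=1),
    max_retry_interval=timedelta(seconds=30),
    retry_interval_factor=2.0,                    # 1s, 2s, 4s, ... capped at 30s
)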


# The OpenAI ModelResponse class is a dataclass with Pydantic fields.
# The Restate SDK cannot serialize this, so we turn the ModelResponse into a Pydantic model.
class RestateModelResponse(BaseModel):
@@ -71,23 +91,23 @@ class DurableModelCalls(MultiProvider):
A Restate model provider that wraps the OpenAI SDK's default MultiProvider.
"""

def __init__(self, max_retries: int | None = 3):
def __init__(self, llm_retry_opts: LlmRetryOpts | None = None):
super().__init__()
self.max_retries = max_retries
self.llm_retry_opts = llm_retry_opts

def get_model(self, model_name: str | None) -> Model:
return RestateModelWrapper(super().get_model(model_name or None), self.max_retries)
return RestateModelWrapper(super().get_model(model_name or None), self.llm_retry_opts)


class RestateModelWrapper(Model):
"""
A wrapper around the OpenAI SDK's Model that persists LLM calls in the Restate journal.
"""

def __init__(self, model: Model, max_retries: int | None = 3):
def __init__(self, model: Model, llm_retry_opts: LlmRetryOpts | None = None):
self.model = model
self.model_name = "RestateModelWrapper"
self.max_retries = max_retries
self.llm_retry_opts = llm_retry_opts if llm_retry_opts is not None else LlmRetryOpts()

async def get_response(self, *args, **kwargs) -> ModelResponse:
async def call_llm() -> RestateModelResponse:
@@ -102,7 +122,17 @@ async def call_llm() -> RestateModelResponse:
ctx = current_context()
if ctx is None:
raise RuntimeError("No current Restate context found, make sure to run inside a Restate handler")
result = await ctx.run_typed("call LLM", call_llm, RunOptions(max_attempts=self.max_retries))
result = await ctx.run_typed(
"call LLM",
call_llm,
RunOptions(
max_attempts=self.llm_retry_opts.max_attempts,
max_duration=self.llm_retry_opts.max_duration,
initial_retry_interval=self.llm_retry_opts.initial_retry_interval,
max_retry_interval=self.llm_retry_opts.max_retry_interval,
retry_interval_factor=self.llm_retry_opts.retry_interval_factor,
),
)
# convert back to original ModelResponse
return ModelResponse(
output=result.output,
@@ -117,33 +147,41 @@ def stream_response(self, *args, **kwargs) -> AsyncIterator[TResponseStreamEvent
class RestateSession(SessionABC):
"""Restate session implementation following the Session protocol."""

def __init__(self):
self._items: List[TResponseInputItem] | None = None

def _ctx(self) -> ObjectContext:
return typing.cast(ObjectContext, current_context())
return cast(ObjectContext, current_context())

async def get_items(self, limit: int | None = None) -> List[TResponseInputItem]:
"""Retrieve conversation history for this session."""
current_items = await self._ctx().get("items", type_hint=List[TResponseInputItem]) or []
if self._items is None:
self._items = await self._ctx().get("items") or []
if limit is not None:
return current_items[-limit:]
return current_items
return self._items[-limit:]
Review comment (Contributor): are you sure about this? This seems to take the last limit items (the negative sign)?
return self._items.copy()

async def add_items(self, items: List[TResponseInputItem]) -> None:
"""Store new items for this session."""
# Your implementation here
current_items = await self.get_items() or []
self._ctx().set("items", current_items + items)
if self._items is None:
self._items = await self._ctx().get("items") or []
self._items.extend(items)

async def pop_item(self) -> TResponseInputItem | None:
"""Remove and return the most recent item from this session."""
current_items = await self.get_items() or []
if current_items:
item = current_items.pop()
self._ctx().set("items", current_items)
return item
if self._items is None:
self._items = await self._ctx().get("items") or []
if self._items:
return self._items.pop()
return None

def flush(self) -> None:
"""Flush the session items to the context."""
# Guard against overwriting stored history when nothing was loaded or added.
if self._items is not None:
self._ctx().set("items", self._items)

async def clear_session(self) -> None:
"""Clear all items for this session."""
self._items = []
self._ctx().clear("items")
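The session keeps items in an in-memory cache and writes them back in a single state update. A sketch of the lifecycle, assuming it runs inside a Restate Virtual Object handler (the message content is illustrative):

session = RestateSession()
await session.add_items([{"role": "user", "content": "hi"}])
history = await session.get_items(limit=10)  # served from the in-memory cache
session.flush()  # one ctx.set("items", ...) write back to Restate state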


@@ -189,7 +227,7 @@ def continue_on_terminal_errors(context: RunContextWrapper[Any], error: Exceptio
raise error


class Runner:
class DurableRunner:
"""
A wrapper around Runner.run that automatically configures RunConfig for Restate contexts.

@@ -201,83 +239,55 @@ class Runner:
@staticmethod
async def run(
starting_agent: Agent[TContext],
disable_tool_autowrapping: bool = False,
*args: typing.Any,
run_config: RunConfig | None = None,
input: str | list[TResponseInputItem],
*,
use_restate_session: bool = False,
**kwargs,
) -> RunResult:
"""
Run an agent with automatic Restate configuration.

Args:
use_restate_session: If True, creates a RestateSession for conversation persistence.
Requires running within a Restate Virtual Object context.

Returns:
The result from Runner.run
"""

current_run_config = run_config or RunConfig()
new_run_config = dataclasses.replace(
current_run_config,
model_provider=DurableModelCalls(),
)
restate_agent = sequentialize_and_wrap_tools(starting_agent, disable_tool_autowrapping)
return await OpenAIRunner.run(restate_agent, *args, run_config=new_run_config, **kwargs)


def sequentialize_and_wrap_tools(
agent: Agent[TContext],
disable_tool_autowrapping: bool,
) -> Agent[TContext]:
"""
Wrap the tools of an agent to use the Restate error handling.

Returns:
A new agent with wrapped tools.
"""
# Set up the model provider that persists model calls
llm_retry_opts = kwargs.pop("llm_retry_opts", None)
run_config = kwargs.pop("run_config", RunConfig())
run_config = dataclasses.replace(run_config, model_provider=DurableModelCalls(llm_retry_opts))

# Restate does not allow parallel tool calls, so we use a lock to ensure sequential execution.
# This lock only affects tools for this agent; handoff agents are wrapped recursively.
sequential_tools_lock = asyncio.Lock()
wrapped_tools: list[Tool] = []
for tool in agent.tools:
if isinstance(tool, FunctionTool):

def create_wrapper(captured_tool):
async def on_invoke_tool_wrapper(tool_context: ToolContext[Any], tool_input: Any) -> Any:
await sequential_tools_lock.acquire()

async def invoke():
result = await captured_tool.on_invoke_tool(tool_context, tool_input)
# Ensure Pydantic objects are serialized to dict for LLM compatibility
if hasattr(result, "model_dump"):
return result.model_dump()
elif hasattr(result, "dict"):
return result.dict()
return result

try:
if disable_tool_autowrapping:
return await invoke()

ctx = current_context()
if ctx is None:
raise RuntimeError(
"No current Restate context found, make sure to run inside a Restate handler"
)
return await ctx.run_typed(captured_tool.name, invoke)
finally:
sequential_tools_lock.release()

return on_invoke_tool_wrapper

wrapped_tools.append(dataclasses.replace(tool, on_invoke_tool=create_wrapper(tool)))
# Disable parallel tool calls
model_settings = run_config.model_settings
if model_settings is None:
model_settings = ModelSettings(parallel_tool_calls=False)
else:
wrapped_tools.append(tool)
model_settings = dataclasses.replace(
model_settings,
parallel_tool_calls=False,
)
run_config = dataclasses.replace(
run_config,
model_settings=model_settings,
)

# Use Restate session if requested, otherwise use provided session
session = kwargs.pop("session", None)
if use_restate_session:
if session is not None:
raise TerminalError("When use_restate_session is True, session config cannot be provided.")
session = RestateSession()

handoffs_with_wrapped_tools = []
for handoff in agent.handoffs:
# recursively wrap tools in handoff agents
handoffs_with_wrapped_tools.append(sequentialize_and_wrap_tools(handoff, disable_tool_autowrapping)) # type: ignore
try:
result = await Runner.run(
starting_agent=starting_agent, input=input, run_config=run_config, session=session, **kwargs
)
finally:
# Flush session items to Restate
if session is not None and isinstance(session, RestateSession):
session.flush()

return agent.clone(
tools=wrapped_tools,
handoffs=handoffs_with_wrapped_tools,
)
return result
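Taken together, a minimal end-to-end sketch of the new API inside a Restate Virtual Object handler; the object, handler, and agent names are hypothetical:

import restate
from agents import Agent
from restate.ext.openai import DurableRunner, LlmRetryOpts

chat = restate.VirtualObject("chat")
assistant = Agent(name="assistant", instructions="Be concise.")

@chat.handler()
async def message(ctx: restate.ObjectContext, user_input: str) -> str:
    result = await DurableRunner.run(
        assistant,
        input=user_input,
        use_restate_session=True,  # history is persisted under the "items" state key
        llm_retry_opts=LlmRetryOpts(max_attempts=5),
    )
    return result.final_output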