Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ classifiers = [
grpc = ["grpcio>=1.48.2,<2"]
opentelemetry = ["opentelemetry-api>=1.11.1,<2", "opentelemetry-sdk>=1.11.1,<2"]
pydantic = ["pydantic>=2.0.0,<3"]
openai-agents = ["openai-agents>=0.17.1", "mcp>=1.9.4, <2"]
openai-agents = ["openai-agents>=0.17.5", "mcp>=1.9.4, <2"]
google-adk = ["google-adk>=1.27.0,<2"]
langgraph = ["langgraph>=1.1.0"]
langsmith = ["langsmith>=0.7.34,<0.9"]
Expand Down
202 changes: 117 additions & 85 deletions temporalio/contrib/openai_agents/sandbox/_sandbox_client_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from __future__ import annotations

import io
from collections.abc import Callable, Sequence
from collections.abc import Callable, Iterator, Sequence
from contextlib import contextmanager
from pathlib import Path
from typing import Any

from agents.sandbox.errors import SandboxError
from agents.sandbox.session.sandbox_client import BaseSandboxClient
from agents.sandbox.session.sandbox_session import SandboxSession

Expand Down Expand Up @@ -34,6 +36,22 @@
from temporalio.contrib.openai_agents.sandbox._temporal_activity_models import (
ExecResult as ExecResultModel,
)
from temporalio.exceptions import ApplicationError


@contextmanager
def _translate_sandbox_errors() -> Iterator[None]:
# Temporal retries every activity exception by default, so only a SandboxError
# the library has classified as terminal (retryable is False) is turned into a
# non-retryable ApplicationError.
try:
yield
except SandboxError as e:
if e.retryable is False:
raise ApplicationError(
str(e), type=str(e.error_code), non_retryable=True
) from e
raise


class SandboxClientProvider:
Expand Down Expand Up @@ -99,133 +117,147 @@ def _get_activities(self) -> Sequence[Callable[..., Any]]:

@activity.defn(name=f"{prefix}-sandbox_client_create")
async def create_session(args: CreateSessionArgs) -> SessionResult:
session = await self._client.create(
snapshot=args.snapshot_spec,
manifest=args.manifest,
options=args.client_options,
)
self._sessions[str(session.state.session_id)] = session
return SessionResult(
state=session.state, supports_pty=session.supports_pty()
)
with _translate_sandbox_errors():
session = await self._client.create(
snapshot=args.snapshot_spec,
manifest=args.manifest,
options=args.client_options,
)
self._sessions[str(session.state.session_id)] = session
return SessionResult(
state=session.state, supports_pty=session.supports_pty()
)

@activity.defn(name=f"{prefix}-sandbox_client_resume")
async def resume_session(args: ResumeSessionArgs) -> SessionResult:
session = await self._client.resume(args.state)
self._sessions[str(session.state.session_id)] = session
return SessionResult(
state=session.state, supports_pty=session.supports_pty()
)
with _translate_sandbox_errors():
session = await self._client.resume(args.state)
self._sessions[str(session.state.session_id)] = session
return SessionResult(
state=session.state, supports_pty=session.supports_pty()
)

@activity.defn(name=f"{prefix}-sandbox_client_delete")
async def delete_session(args: StopArgs) -> None:
session = await self._session(args)
await self._client.delete(session)
return None
with _translate_sandbox_errors():
session = await self._session(args)
await self._client.delete(session)
return None

# -- Session-level operations (I/O and lifecycle) --

@activity.defn(name=f"{prefix}-sandbox_session_exec")
async def exec_(args: ExecArgs) -> ExecResultModel:
session = await self._session(args)
result = await session.exec(
*args.command,
timeout=args.timeout,
shell=args.shell,
user=args.user,
)
return ExecResultModel(
stdout=result.stdout,
stderr=result.stderr,
exit_code=result.exit_code,
)
with _translate_sandbox_errors():
session = await self._session(args)
result = await session.exec(
*args.command,
timeout=args.timeout,
shell=args.shell,
user=args.user,
)
return ExecResultModel(
stdout=result.stdout,
stderr=result.stderr,
exit_code=result.exit_code,
)

@activity.defn(name=f"{prefix}-sandbox_session_read")
async def read(args: ReadArgs) -> ReadResult:
session = await self._session(args)
handle = await session.read(Path(args.path))
return ReadResult(data=handle.read())
with _translate_sandbox_errors():
session = await self._session(args)
handle = await session.read(Path(args.path))
return ReadResult(data=handle.read())

@activity.defn(name=f"{prefix}-sandbox_session_write")
async def write(args: WriteArgs) -> None:
session = await self._session(args)
await session.write(Path(args.path), io.BytesIO(args.data))
return None
with _translate_sandbox_errors():
session = await self._session(args)
await session.write(Path(args.path), io.BytesIO(args.data))
return None

@activity.defn(name=f"{prefix}-sandbox_session_running")
async def running(args: RunningArgs) -> RunningResult:
session = await self._session(args)
return RunningResult(is_running=await session.running())
with _translate_sandbox_errors():
session = await self._session(args)
return RunningResult(is_running=await session.running())

@activity.defn(name=f"{prefix}-sandbox_session_persist_workspace")
async def persist_workspace(
args: PersistWorkspaceArgs,
) -> PersistWorkspaceResult:
session = await self._session(args)
stream = await session.persist_workspace()
return PersistWorkspaceResult(data=stream.read())
with _translate_sandbox_errors():
session = await self._session(args)
stream = await session.persist_workspace()
return PersistWorkspaceResult(data=stream.read())

@activity.defn(name=f"{prefix}-sandbox_session_hydrate_workspace")
async def hydrate_workspace(args: HydrateWorkspaceArgs) -> None:
session = await self._session(args)
await session.hydrate_workspace(io.BytesIO(args.data))
return None
with _translate_sandbox_errors():
session = await self._session(args)
await session.hydrate_workspace(io.BytesIO(args.data))
return None

@activity.defn(name=f"{prefix}-sandbox_session_pty_exec_start")
async def pty_exec_start(args: PtyExecStartArgs) -> PtyExecUpdateResult:
session = await self._session(args)
update = await session.pty_exec_start(
*args.command,
timeout=args.timeout,
shell=args.shell,
user=args.user,
tty=args.tty,
yield_time_s=args.yield_time_s,
max_output_tokens=args.max_output_tokens,
)
return PtyExecUpdateResult(
process_id=update.process_id,
output=update.output,
exit_code=update.exit_code,
original_token_count=update.original_token_count,
)
with _translate_sandbox_errors():
session = await self._session(args)
update = await session.pty_exec_start(
*args.command,
timeout=args.timeout,
shell=args.shell,
user=args.user,
tty=args.tty,
yield_time_s=args.yield_time_s,
max_output_tokens=args.max_output_tokens,
)
return PtyExecUpdateResult(
process_id=update.process_id,
output=update.output,
exit_code=update.exit_code,
original_token_count=update.original_token_count,
)

@activity.defn(name=f"{prefix}-sandbox_session_pty_write_stdin")
async def pty_write_stdin(args: PtyWriteStdinArgs) -> PtyExecUpdateResult:
session = await self._session(args)
update = await session.pty_write_stdin(
session_id=args.session_id,
chars=args.chars,
yield_time_s=args.yield_time_s,
max_output_tokens=args.max_output_tokens,
)
return PtyExecUpdateResult(
process_id=update.process_id,
output=update.output,
exit_code=update.exit_code,
original_token_count=update.original_token_count,
)
with _translate_sandbox_errors():
session = await self._session(args)
update = await session.pty_write_stdin(
session_id=args.session_id,
chars=args.chars,
yield_time_s=args.yield_time_s,
max_output_tokens=args.max_output_tokens,
)
return PtyExecUpdateResult(
process_id=update.process_id,
output=update.output,
exit_code=update.exit_code,
original_token_count=update.original_token_count,
)

@activity.defn(name=f"{prefix}-sandbox_session_start")
async def start(args: StartArgs) -> None:
session = await self._session(args)
await session.start()
return None
with _translate_sandbox_errors():
session = await self._session(args)
await session.start()
return None

@activity.defn(name=f"{prefix}-sandbox_session_stop")
async def session_stop(args: StopArgs) -> None:
session = await self._session(args)
await session.stop()
return None
with _translate_sandbox_errors():
session = await self._session(args)
await session.stop()
return None

@activity.defn(name=f"{prefix}-sandbox_session_shutdown")
async def session_shutdown(args: StopArgs) -> None:
key = str(args.state.session_id)
session = self._sessions.get(key)
if session is not None:
await session.shutdown()
del self._sessions[key]
return None
with _translate_sandbox_errors():
key = str(args.state.session_id)
session = self._sessions.get(key)
if session is not None:
await session.shutdown()
del self._sessions[key]
Comment on lines +255 to +259
return None

return [
create_session,
Expand Down
94 changes: 94 additions & 0 deletions tests/contrib/openai_agents/test_openai_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
import pytest
from agents import Agent, FunctionTool, RunConfig, Runner, Tool
from agents.sandbox import Capability, Manifest, SandboxAgent, SandboxRunConfig
from agents.sandbox.errors import (
ExecTransportError,
SandboxError,
WorkspaceArchiveReadError,
)
from agents.sandbox.session.base_sandbox_session import BaseSandboxSession
from agents.sandbox.session.sandbox_client import (
BaseSandboxClient,
Expand Down Expand Up @@ -55,6 +60,7 @@
TestModelProvider,
)
from temporalio.contrib.openai_agents.workflow import temporal_sandbox_client
from temporalio.exceptions import ApplicationError
from temporalio.workflow import ActivityConfig
from tests.helpers import new_worker

Expand Down Expand Up @@ -569,6 +575,94 @@ async def test_multiple_providers_register_distinct_activities():
)


# ── SandboxError retryable mapping tests ──


class _ExecRaisingSession(_MockSandboxSession):
    """Mock session whose exec() raises a chosen SandboxError."""

    def __init__(self, error: SandboxError) -> None:
        """Remember ``error`` so ``_exec_internal`` can raise it later."""
        super().__init__()
        self._error = error

    async def _exec_internal(
        self,
        *command: str | Path,  # type: ignore[reportUnusedParameter]
        timeout: float | None = None,  # type: ignore[reportUnusedParameter]
    ) -> ExecResult:
        """Ignore ``command`` and ``timeout`` and raise the stored error."""
        raise self._error


async def _exec_with_error(error: SandboxError) -> None:
    """Run the exec activity against a mock session that raises ``error``.

    Builds a ``SandboxClientProvider`` named ``"mock"`` around an
    ``_ExecRaisingSession``, creates a session through the
    ``mock-sandbox_client_create`` activity, and then calls
    ``mock-sandbox_session_exec`` with the returned state. Whatever that
    activity raises propagates to the caller.
    """
    provider = SandboxClientProvider(
        "mock", _MockSandboxClient(_ExecRaisingSession(error))
    )
    acts = _activity_map(provider)
    # The create activity registers the session with the provider; its state
    # is what later activities use to look the session up.
    state = (
        await acts["mock-sandbox_client_create"](
            CreateSessionArgs(
                snapshot_spec=None, manifest=Manifest(), client_options=None
            )
        )
    ).state
    await acts["mock-sandbox_session_exec"](
        ExecArgs(state=state, command=["boom"], shell=True)
    )


async def test_exec_terminal_error_becomes_non_retryable_application_error():
    """retryable is False should map to a non-retryable ApplicationError."""
    with pytest.raises(ApplicationError) as exc_info:
        await _exec_with_error(ExecTransportError(command=["boom"], retryable=False))
    # Checked after the ``raises`` block: statements placed after the raising
    # call inside the block would never run.
    assert exc_info.value.non_retryable is True
    assert exc_info.value.type == "exec_transport_error"


async def test_exec_transient_error_propagates_unchanged():
    """retryable is True should let the original SandboxError propagate."""
    with pytest.raises(ExecTransportError):
        await _exec_with_error(ExecTransportError(command=["boom"], retryable=True))


async def test_exec_unclassified_error_propagates_unchanged():
    """retryable is None should let the original SandboxError propagate (not converted)."""
    with pytest.raises(ExecTransportError):
        await _exec_with_error(ExecTransportError(command=["boom"], retryable=None))


class _RunningRaisingSession(_MockSandboxSession):
    """Mock session whose running() raises a chosen SandboxError."""

    def __init__(self, error: SandboxError) -> None:
        """Remember ``error`` so ``running`` can raise it later."""
        super().__init__()
        self._error = error

    async def running(self) -> bool:
        """Always raise the stored error; never returns a value."""
        raise self._error


async def test_running_terminal_error_becomes_non_retryable_application_error():
    """A terminal SandboxError from a non-exec activity also maps to a
    non-retryable ApplicationError, with type set to its error_code."""
    error = WorkspaceArchiveReadError(path=Path("/workspace"), retryable=False)
    provider = SandboxClientProvider(
        "mock", _MockSandboxClient(_RunningRaisingSession(error))
    )
    acts = _activity_map(provider)
    # Create the session first so the running activity has state to look up.
    state = (
        await acts["mock-sandbox_client_create"](
            CreateSessionArgs(
                snapshot_spec=None, manifest=Manifest(), client_options=None
            )
        )
    ).state

    with pytest.raises(ApplicationError) as exc_info:
        await acts["mock-sandbox_session_running"](RunningArgs(state=state))
    # Checked after the ``raises`` block so the assertions actually execute.
    assert exc_info.value.non_retryable is True
    assert exc_info.value.type == "workspace_archive_read_error"


# ── End-to-end test: Runner + SandboxAgent through Temporal activities ──


Expand Down
Loading