diff --git a/pyproject.toml b/pyproject.toml index 317b378cb..01eb4bfa6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ grpc = ["grpcio>=1.48.2,<2"] opentelemetry = ["opentelemetry-api>=1.11.1,<2", "opentelemetry-sdk>=1.11.1,<2"] pydantic = ["pydantic>=2.0.0,<3"] -openai-agents = ["openai-agents>=0.17.1", "mcp>=1.9.4, <2"] +openai-agents = ["openai-agents>=0.17.5", "mcp>=1.9.4, <2"] google-adk = ["google-adk>=1.27.0,<2"] langgraph = ["langgraph>=1.1.0"] langsmith = ["langsmith>=0.7.34,<0.9"] diff --git a/temporalio/contrib/openai_agents/sandbox/_sandbox_client_provider.py b/temporalio/contrib/openai_agents/sandbox/_sandbox_client_provider.py index 9e4d67644..675ccf9f7 100644 --- a/temporalio/contrib/openai_agents/sandbox/_sandbox_client_provider.py +++ b/temporalio/contrib/openai_agents/sandbox/_sandbox_client_provider.py @@ -3,10 +3,12 @@ from __future__ import annotations import io -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterator, Sequence +from contextlib import contextmanager from pathlib import Path from typing import Any +from agents.sandbox.errors import SandboxError from agents.sandbox.session.sandbox_client import BaseSandboxClient from agents.sandbox.session.sandbox_session import SandboxSession @@ -34,6 +36,22 @@ from temporalio.contrib.openai_agents.sandbox._temporal_activity_models import ( ExecResult as ExecResultModel, ) +from temporalio.exceptions import ApplicationError + + +@contextmanager +def _translate_sandbox_errors() -> Iterator[None]: + # Temporal retries every activity exception by default, so only a SandboxError + # the library has classified as terminal (retryable is False) is turned into a + # non-retryable ApplicationError. + try: + yield + except SandboxError as e: + if e.retryable is False: + raise ApplicationError( + str(e), type=str(e.error_code), non_retryable=True + ) from e + raise class SandboxClientProvider: @@ -99,133 +117,147 @@ def _get_activities(self) -> Sequence[Callable[..., Any]]: @activity.defn(name=f"{prefix}-sandbox_client_create") async def create_session(args: CreateSessionArgs) -> SessionResult: - session = await self._client.create( - snapshot=args.snapshot_spec, - manifest=args.manifest, - options=args.client_options, - ) - self._sessions[str(session.state.session_id)] = session - return SessionResult( - state=session.state, supports_pty=session.supports_pty() - ) + with _translate_sandbox_errors(): + session = await self._client.create( + snapshot=args.snapshot_spec, + manifest=args.manifest, + options=args.client_options, + ) + self._sessions[str(session.state.session_id)] = session + return SessionResult( + state=session.state, supports_pty=session.supports_pty() + ) @activity.defn(name=f"{prefix}-sandbox_client_resume") async def resume_session(args: ResumeSessionArgs) -> SessionResult: - session = await self._client.resume(args.state) - self._sessions[str(session.state.session_id)] = session - return SessionResult( - state=session.state, supports_pty=session.supports_pty() - ) + with _translate_sandbox_errors(): + session = await self._client.resume(args.state) + self._sessions[str(session.state.session_id)] = session + return SessionResult( + state=session.state, supports_pty=session.supports_pty() + ) @activity.defn(name=f"{prefix}-sandbox_client_delete") async def delete_session(args: StopArgs) -> None: - session = await self._session(args) - await self._client.delete(session) - return None + with _translate_sandbox_errors(): + session = await self._session(args) + await self._client.delete(session) + return None # -- Session-level operations (I/O and lifecycle) -- @activity.defn(name=f"{prefix}-sandbox_session_exec") async def exec_(args: ExecArgs) -> ExecResultModel: - session = await self._session(args) - result = await session.exec( - *args.command, - timeout=args.timeout, - shell=args.shell, - user=args.user, - ) - return ExecResultModel( - stdout=result.stdout, - stderr=result.stderr, - exit_code=result.exit_code, - ) + with _translate_sandbox_errors(): + session = await self._session(args) + result = await session.exec( + *args.command, + timeout=args.timeout, + shell=args.shell, + user=args.user, + ) + return ExecResultModel( + stdout=result.stdout, + stderr=result.stderr, + exit_code=result.exit_code, + ) @activity.defn(name=f"{prefix}-sandbox_session_read") async def read(args: ReadArgs) -> ReadResult: - session = await self._session(args) - handle = await session.read(Path(args.path)) - return ReadResult(data=handle.read()) + with _translate_sandbox_errors(): + session = await self._session(args) + handle = await session.read(Path(args.path)) + return ReadResult(data=handle.read()) @activity.defn(name=f"{prefix}-sandbox_session_write") async def write(args: WriteArgs) -> None: - session = await self._session(args) - await session.write(Path(args.path), io.BytesIO(args.data)) - return None + with _translate_sandbox_errors(): + session = await self._session(args) + await session.write(Path(args.path), io.BytesIO(args.data)) + return None @activity.defn(name=f"{prefix}-sandbox_session_running") async def running(args: RunningArgs) -> RunningResult: - session = await self._session(args) - return RunningResult(is_running=await session.running()) + with _translate_sandbox_errors(): + session = await self._session(args) + return RunningResult(is_running=await session.running()) @activity.defn(name=f"{prefix}-sandbox_session_persist_workspace") async def persist_workspace( args: PersistWorkspaceArgs, ) -> PersistWorkspaceResult: - session = await self._session(args) - stream = await session.persist_workspace() - return PersistWorkspaceResult(data=stream.read()) + with _translate_sandbox_errors(): + session = await self._session(args) + stream = await session.persist_workspace() + return PersistWorkspaceResult(data=stream.read()) @activity.defn(name=f"{prefix}-sandbox_session_hydrate_workspace") async def hydrate_workspace(args: HydrateWorkspaceArgs) -> None: - session = await self._session(args) - await session.hydrate_workspace(io.BytesIO(args.data)) - return None + with _translate_sandbox_errors(): + session = await self._session(args) + await session.hydrate_workspace(io.BytesIO(args.data)) + return None @activity.defn(name=f"{prefix}-sandbox_session_pty_exec_start") async def pty_exec_start(args: PtyExecStartArgs) -> PtyExecUpdateResult: - session = await self._session(args) - update = await session.pty_exec_start( - *args.command, - timeout=args.timeout, - shell=args.shell, - user=args.user, - tty=args.tty, - yield_time_s=args.yield_time_s, - max_output_tokens=args.max_output_tokens, - ) - return PtyExecUpdateResult( - process_id=update.process_id, - output=update.output, - exit_code=update.exit_code, - original_token_count=update.original_token_count, - ) + with _translate_sandbox_errors(): + session = await self._session(args) + update = await session.pty_exec_start( + *args.command, + timeout=args.timeout, + shell=args.shell, + user=args.user, + tty=args.tty, + yield_time_s=args.yield_time_s, + max_output_tokens=args.max_output_tokens, + ) + return PtyExecUpdateResult( + process_id=update.process_id, + output=update.output, + exit_code=update.exit_code, + original_token_count=update.original_token_count, + ) @activity.defn(name=f"{prefix}-sandbox_session_pty_write_stdin") async def pty_write_stdin(args: PtyWriteStdinArgs) -> PtyExecUpdateResult: - session = await self._session(args) - update = await session.pty_write_stdin( - session_id=args.session_id, - chars=args.chars, - yield_time_s=args.yield_time_s, - max_output_tokens=args.max_output_tokens, - ) - return PtyExecUpdateResult( - process_id=update.process_id, - output=update.output, - exit_code=update.exit_code, - original_token_count=update.original_token_count, - ) + with _translate_sandbox_errors(): + session = await self._session(args) + update = await session.pty_write_stdin( + session_id=args.session_id, + chars=args.chars, + yield_time_s=args.yield_time_s, + max_output_tokens=args.max_output_tokens, + ) + return PtyExecUpdateResult( + process_id=update.process_id, + output=update.output, + exit_code=update.exit_code, + original_token_count=update.original_token_count, + ) @activity.defn(name=f"{prefix}-sandbox_session_start") async def start(args: StartArgs) -> None: - session = await self._session(args) - await session.start() - return None + with _translate_sandbox_errors(): + session = await self._session(args) + await session.start() + return None @activity.defn(name=f"{prefix}-sandbox_session_stop") async def session_stop(args: StopArgs) -> None: - session = await self._session(args) - await session.stop() - return None + with _translate_sandbox_errors(): + session = await self._session(args) + await session.stop() + return None @activity.defn(name=f"{prefix}-sandbox_session_shutdown") async def session_shutdown(args: StopArgs) -> None: - key = str(args.state.session_id) - session = self._sessions.get(key) - if session is not None: - await session.shutdown() - del self._sessions[key] - return None + with _translate_sandbox_errors(): + key = str(args.state.session_id) + session = self._sessions.get(key) + if session is not None: + await session.shutdown() + del self._sessions[key] + return None return [ create_session, diff --git a/tests/contrib/openai_agents/test_openai_sandbox.py b/tests/contrib/openai_agents/test_openai_sandbox.py index 74ff80e85..66a9f5a7c 100644 --- a/tests/contrib/openai_agents/test_openai_sandbox.py +++ b/tests/contrib/openai_agents/test_openai_sandbox.py @@ -9,6 +9,11 @@ import pytest from agents import Agent, FunctionTool, RunConfig, Runner, Tool from agents.sandbox import Capability, Manifest, SandboxAgent, SandboxRunConfig +from agents.sandbox.errors import ( + ExecTransportError, + SandboxError, + WorkspaceArchiveReadError, +) from agents.sandbox.session.base_sandbox_session import BaseSandboxSession from agents.sandbox.session.sandbox_client import ( BaseSandboxClient, @@ -55,6 +60,7 @@ TestModelProvider, ) from temporalio.contrib.openai_agents.workflow import temporal_sandbox_client +from temporalio.exceptions import ApplicationError from temporalio.workflow import ActivityConfig from tests.helpers import new_worker @@ -569,6 +575,94 @@ async def test_multiple_providers_register_distinct_activities(): ) +# ── SandboxError retryable mapping tests ── + + +class _ExecRaisingSession(_MockSandboxSession): + """Mock session whose exec() raises a chosen SandboxError.""" + + def __init__(self, error: SandboxError) -> None: + super().__init__() + self._error = error + + async def _exec_internal( + self, + *command: str | Path, # type: ignore[reportUnusedParameter] + timeout: float | None = None, # type: ignore[reportUnusedParameter] + ) -> ExecResult: + raise self._error + + +async def _exec_with_error(error: SandboxError) -> None: + provider = SandboxClientProvider( + "mock", _MockSandboxClient(_ExecRaisingSession(error)) + ) + acts = _activity_map(provider) + state = ( + await acts["mock-sandbox_client_create"]( + CreateSessionArgs( + snapshot_spec=None, manifest=Manifest(), client_options=None + ) + ) + ).state + await acts["mock-sandbox_session_exec"]( + ExecArgs(state=state, command=["boom"], shell=True) + ) + + +async def test_exec_terminal_error_becomes_non_retryable_application_error(): + """retryable is False should map to a non-retryable ApplicationError.""" + with pytest.raises(ApplicationError) as exc_info: + await _exec_with_error(ExecTransportError(command=["boom"], retryable=False)) + assert exc_info.value.non_retryable is True + assert exc_info.value.type == "exec_transport_error" + + +async def test_exec_transient_error_propagates_unchanged(): + """retryable is True should let the original SandboxError propagate.""" + with pytest.raises(ExecTransportError): + await _exec_with_error(ExecTransportError(command=["boom"], retryable=True)) + + +async def test_exec_unclassified_error_propagates_unchanged(): + """retryable is None should let the original SandboxError propagate (not converted).""" + with pytest.raises(ExecTransportError): + await _exec_with_error(ExecTransportError(command=["boom"], retryable=None)) + + +class _RunningRaisingSession(_MockSandboxSession): + """Mock session whose running() raises a chosen SandboxError.""" + + def __init__(self, error: SandboxError) -> None: + super().__init__() + self._error = error + + async def running(self) -> bool: + raise self._error + + +async def test_running_terminal_error_becomes_non_retryable_application_error(): + """A terminal SandboxError from a non-exec activity also maps to a + non-retryable ApplicationError, with type set to its error_code.""" + error = WorkspaceArchiveReadError(path=Path("/workspace"), retryable=False) + provider = SandboxClientProvider( + "mock", _MockSandboxClient(_RunningRaisingSession(error)) + ) + acts = _activity_map(provider) + state = ( + await acts["mock-sandbox_client_create"]( + CreateSessionArgs( + snapshot_spec=None, manifest=Manifest(), client_options=None + ) + ) + ).state + + with pytest.raises(ApplicationError) as exc_info: + await acts["mock-sandbox_session_running"](RunningArgs(state=state)) + assert exc_info.value.non_retryable is True + assert exc_info.value.type == "workspace_archive_read_error" + + # ── End-to-end test: Runner + SandboxAgent through Temporal activities ──