diff --git a/docs/ref/extensions/sandbox/openshell/sandbox.md b/docs/ref/extensions/sandbox/openshell/sandbox.md new file mode 100644 index 0000000000..f88518b42c --- /dev/null +++ b/docs/ref/extensions/sandbox/openshell/sandbox.md @@ -0,0 +1,3 @@ +# `Sandbox` + +::: agents.extensions.sandbox.openshell.sandbox diff --git a/docs/sandbox/clients.md b/docs/sandbox/clients.md index bd21da63d3..b6b64b44ae 100644 --- a/docs/sandbox/clients.md +++ b/docs/sandbox/clients.md @@ -95,6 +95,7 @@ For provider-specific setup notes and links for the checked-in extension example | `DaytonaSandboxClient` | `openai-agents[daytona]` | [Daytona runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/daytona/daytona_runner.py) | | `E2BSandboxClient` | `openai-agents[e2b]` | [E2B runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/e2b_runner.py) | | `ModalSandboxClient` | `openai-agents[modal]` | [Modal runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/modal_runner.py) | +| `OpenShellSandboxClient` | `openai-agents[openshell]` | [OpenShell runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/openshell_runner.py) | | `RunloopSandboxClient` | `openai-agents[runloop]` | [Runloop runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/runloop/runner.py) | | `VercelSandboxClient` | `openai-agents[vercel]` | [Vercel runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/vercel_runner.py) | @@ -113,6 +114,7 @@ Hosted sandbox clients expose provider-specific mount strategies. Choose the bac | `DaytonaSandboxClient` | Supports rclone-backed cloud storage mounts with `DaytonaCloudBucketMountStrategy`; use it with `S3Mount`, `GCSMount`, `R2Mount`, `AzureBlobMount`, and `BoxMount`. 
| | `E2BSandboxClient` | Supports rclone-backed cloud storage mounts with `E2BCloudBucketMountStrategy`; use it with `S3Mount`, `GCSMount`, `R2Mount`, `AzureBlobMount`, and `BoxMount`. | | `RunloopSandboxClient` | Supports rclone-backed cloud storage mounts with `RunloopCloudBucketMountStrategy`; use it with `S3Mount`, `GCSMount`, `R2Mount`, `AzureBlobMount`, and `BoxMount`. | +| `OpenShellSandboxClient` | No hosted-specific mount strategy is currently exposed. Use manifest files, repos, or other workspace inputs instead. | | `VercelSandboxClient` | No hosted-specific mount strategy is currently exposed. Use manifest files, repos, or other workspace inputs instead. | @@ -130,6 +132,7 @@ The table below summarizes which remote storage entries each backend can mount d | `DaytonaSandboxClient` | ✓ | ✓ | ✓ | ✓ | ✓ | - | | `E2BSandboxClient` | ✓ | ✓ | ✓ | ✓ | ✓ | - | | `RunloopSandboxClient` | ✓ | ✓ | ✓ | ✓ | ✓ | - | +| `OpenShellSandboxClient` | - | - | - | - | - | - | | `VercelSandboxClient` | - | - | - | - | - | - | diff --git a/examples/sandbox/extensions/openshell_runner.py b/examples/sandbox/extensions/openshell_runner.py new file mode 100644 index 0000000000..27dfc15a5b --- /dev/null +++ b/examples/sandbox/extensions/openshell_runner.py @@ -0,0 +1,301 @@ +""" +OpenShell sandbox integration example. + +This script exercises the OpenShell sandbox extension at two levels: + +1. **Session-level** (no LLM needed): Creates a sandbox, writes files, reads them + back, runs commands, and verifies workspace persistence. This validates the + extension works end-to-end with a real OpenShell gateway. + +2. **Agent-level** (requires OPENAI_API_KEY): Runs a SandboxAgent with a shell + capability inside the OpenShell sandbox. + +Prerequisites: + - An OpenShell gateway running (local, remote, or cloud). + - ``openshell`` Python package installed: ``uv sync --extra openshell`` + - For agent mode: ``OPENAI_API_KEY`` environment variable set. 
+ +Quick start: + # Session-level only (no LLM): + uv run python examples/sandbox/extensions/openshell_runner.py --session-only + + # Full agent run: + uv run python examples/sandbox/extensions/openshell_runner.py + + # With a specific cluster: + uv run python examples/sandbox/extensions/openshell_runner.py --cluster my-gateway + + # With a custom image: + uv run python examples/sandbox/extensions/openshell_runner.py --image ubuntu:24.04 +""" + +from __future__ import annotations + +import argparse +import asyncio +import io +import os +import sys +from pathlib import Path + +try: + from agents.extensions.sandbox import ( + OpenShellSandboxClient, + OpenShellSandboxClientOptions, + ) +except Exception as exc: + raise SystemExit( + "OpenShell sandbox examples require the optional openshell extra.\n" + "Install it with: uv sync --extra openshell" + ) from exc + + +async def session_level_test( + *, + cluster: str | None, + endpoint: str | None, + image: str | None, + gpu: bool, +) -> None: + """Exercise the sandbox extension directly without an LLM.""" + + from agents.sandbox import Manifest + from agents.sandbox.entries import File + + print("=== OpenShell Session-Level Test ===\n") + + # Build a manifest with test files. + # OpenShell sandboxes default to /sandbox as the working directory. + manifest = Manifest( + root="/sandbox", + entries={ + "hello.txt": File(content=b"Hello from OpenShell sandbox!\n"), + "data/numbers.csv": File(content=b"a,b,c\n1,2,3\n4,5,6\n"), + }, + ) + + client = OpenShellSandboxClient() + options = OpenShellSandboxClientOptions( + cluster=cluster, + endpoint=endpoint, + image=image, + gpu=gpu, + ) + + print("1. Creating sandbox...") + session = await client.create(manifest=manifest, options=options) + + try: + print("2. Starting session (materializing workspace)...") + await session.start() + + print("3. 
Running 'ls -la' in workspace...") + result = await session.exec("ls", "-la", shell=False) + print(f" exit_code={result.exit_code}") + print(f" stdout:\n{result.stdout.decode()}") + + print("4. Reading hello.txt...") + content = await session.read(Path("hello.txt")) + text = content.read() + if isinstance(text, bytes): + text = text.decode("utf-8") + print(f" content: {text.strip()!r}") + assert "Hello from OpenShell sandbox!" in text, "Read verification failed." + + print("5. Writing a new file...") + await session.write( + Path("output.txt"), + io.BytesIO(b"Written by the OpenAI Agents SDK via OpenShell.\n"), + ) + + print("6. Verifying the written file...") + result = await session.exec("cat", "output.txt", shell=False) + assert result.exit_code == 0, f"cat failed: {result.stderr.decode()}" + print(f" content: {result.stdout.decode().strip()!r}") + + print("7. Running a multi-step shell command...") + result = await session.exec("wc -l data/numbers.csv && echo 'done'") + print(f" output: {result.stdout.decode().strip()}") + + print("8. Checking sandbox is running...") + is_running = await session.running() + print(f" running: {is_running}") + assert is_running, "Sandbox should be running." + + print("9. Persisting workspace (tar snapshot)...") + snapshot = await session.persist_workspace() + snapshot_bytes = snapshot.read() + print(f" snapshot size: {len(snapshot_bytes)} bytes") + assert len(snapshot_bytes) > 0, "Snapshot should not be empty." + + print("\nAll session-level checks passed.") + + finally: + print("\n10. 
Shutting down sandbox...") + await session.aclose() + print(" Done.") + + +async def agent_level_test( + *, + model: str, + cluster: str | None, + endpoint: str | None, + image: str | None, + gpu: bool, + question: str, + stream: bool, +) -> None: + """Run a SandboxAgent backed by OpenShell.""" + + from openai.types.responses import ResponseTextDeltaEvent + + from agents import ModelSettings, Runner + from agents.run import RunConfig + from agents.sandbox import Manifest, SandboxAgent, SandboxRunConfig + from agents.sandbox.entries import File + + if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + + from examples.sandbox.misc.workspace_shell import WorkspaceShellCapability + + print("\n=== OpenShell Agent-Level Test ===\n") + + manifest = Manifest( + root="/sandbox", + entries={ + "README.md": File( + content=( + b"# Project Status\n\nThis workspace contains a sample project status report.\n" + ), + ), + "status.md": File( + content=( + b"# Sprint 42 Status\n\n" + b"- Auth service: on track, shipping Tuesday.\n" + b"- Search reindex: blocked on infra ticket INFRA-1234.\n" + b"- Dashboard v2: 80% complete, needs UX review.\n" + ), + ), + }, + ) + + agent = SandboxAgent( + name="OpenShell Sandbox Assistant", + model=model, + instructions=( + "Answer questions about the sandbox workspace. Inspect the files before answering " + "and keep the response concise. " + "Do not invent files or statuses that are not present in the workspace. Cite the " + "file names you inspected." 
+ ), + default_manifest=manifest, + capabilities=[WorkspaceShellCapability()], + model_settings=ModelSettings(tool_choice="required"), + ) + + run_config = RunConfig( + sandbox=SandboxRunConfig( + client=OpenShellSandboxClient(), + options=OpenShellSandboxClientOptions( + cluster=cluster, + endpoint=endpoint, + image=image, + gpu=gpu, + ), + ), + workflow_name="OpenShell sandbox example", + ) + + if not stream: + result = await Runner.run(agent, question, run_config=run_config) + print(f"assistant> {result.final_output}") + return + + stream_result = Runner.run_streamed(agent, question, run_config=run_config) + saw_text_delta = False + async for event in stream_result.stream_events(): + if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent): + if not saw_text_delta: + print("assistant> ", end="", flush=True) + saw_text_delta = True + print(event.data.delta, end="", flush=True) + if saw_text_delta: + print() + + +async def main( + *, + model: str, + cluster: str | None, + endpoint: str | None, + image: str | None, + gpu: bool, + question: str, + stream: bool, + session_only: bool, +) -> None: + # Session-level test always runs (no LLM needed). + await session_level_test( + cluster=cluster, + endpoint=endpoint, + image=image, + gpu=gpu, + ) + + if session_only: + return + + # Agent-level test requires OPENAI_API_KEY. + if not os.environ.get("OPENAI_API_KEY"): + print("\nSkipping agent-level test (OPENAI_API_KEY not set).") + print("Set OPENAI_API_KEY and remove --session-only to run the full test.") + return + + await agent_level_test( + model=model, + cluster=cluster, + endpoint=endpoint, + image=image, + gpu=gpu, + question=question, + stream=stream, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="OpenShell sandbox integration example for the OpenAI Agents SDK." 
+ ) + parser.add_argument("--model", default="gpt-4.1-mini", help="Model name to use.") + parser.add_argument( + "--question", + default="Summarize the project status from the workspace files.", + help="Prompt to send to the agent.", + ) + parser.add_argument("--cluster", default=None, help="OpenShell gateway cluster name.") + parser.add_argument("--endpoint", default=None, help="Explicit gateway endpoint (host:port).") + parser.add_argument("--image", default=None, help="Container image for the sandbox.") + parser.add_argument("--gpu", action="store_true", default=False, help="Request GPU.") + parser.add_argument("--stream", action="store_true", default=False, help="Stream the response.") + parser.add_argument( + "--session-only", + action="store_true", + default=False, + help="Run session-level test only (no LLM needed).", + ) + args = parser.parse_args() + + asyncio.run( + main( + model=args.model, + cluster=args.cluster, + endpoint=args.endpoint, + image=args.image, + gpu=args.gpu, + question=args.question, + stream=args.stream, + session_only=args.session_only, + ) + ) diff --git a/pyproject.toml b/pyproject.toml index 4d0122049f..79c355455f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ e2b = ["e2b==2.20.0", "e2b-code-interpreter==2.4.1"] modal = ["modal==1.3.5"] runloop = ["runloop_api_client>=1.16.0,<2.0.0"] vercel = ["vercel>=0.5.6,<0.6"] +openshell = ["openshell>=0.0.0a0"] s3 = ["boto3>=1.34"] temporal = [ "temporalio==1.26.0", @@ -164,6 +165,10 @@ ignore_missing_imports = true module = ["vercel", "vercel.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["openshell", "openshell.*"] +ignore_missing_imports = true + [tool.coverage.run] source = ["src/agents"] omit = [ diff --git a/src/agents/extensions/sandbox/__init__.py b/src/agents/extensions/sandbox/__init__.py index d7b082ba1f..b4c290190e 100644 --- a/src/agents/extensions/sandbox/__init__.py +++ b/src/agents/extensions/sandbox/__init__.py @@ -109,6 +109,18 @@ 
except Exception: # pragma: no cover _HAS_VERCEL = False +try: + from .openshell import ( + OpenShellSandboxClient as OpenShellSandboxClient, + OpenShellSandboxClientOptions as OpenShellSandboxClientOptions, + OpenShellSandboxSession as OpenShellSandboxSession, + OpenShellSandboxSessionState as OpenShellSandboxSessionState, + ) + + _HAS_OPENSHELL = True +except Exception: # pragma: no cover + _HAS_OPENSHELL = False + __all__: list[str] = [] if _HAS_E2B: @@ -207,3 +219,13 @@ "RunloopUserParameters", ] ) + +if _HAS_OPENSHELL: + __all__.extend( + [ + "OpenShellSandboxClient", + "OpenShellSandboxClientOptions", + "OpenShellSandboxSession", + "OpenShellSandboxSessionState", + ] + ) diff --git a/src/agents/extensions/sandbox/openshell/__init__.py b/src/agents/extensions/sandbox/openshell/__init__.py new file mode 100644 index 0000000000..84ae50d981 --- /dev/null +++ b/src/agents/extensions/sandbox/openshell/__init__.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from .sandbox import ( + OpenShellSandboxClient as OpenShellSandboxClient, + OpenShellSandboxClientOptions as OpenShellSandboxClientOptions, + OpenShellSandboxSession as OpenShellSandboxSession, + OpenShellSandboxSessionState as OpenShellSandboxSessionState, +) + +__all__ = [ + "OpenShellSandboxClient", + "OpenShellSandboxClientOptions", + "OpenShellSandboxSession", + "OpenShellSandboxSessionState", +] diff --git a/src/agents/extensions/sandbox/openshell/sandbox.py b/src/agents/extensions/sandbox/openshell/sandbox.py new file mode 100644 index 0000000000..9f48bd4a37 --- /dev/null +++ b/src/agents/extensions/sandbox/openshell/sandbox.py @@ -0,0 +1,559 @@ +""" +OpenShell sandbox (https://github.com/NVIDIA/OpenShell) implementation. + +Export ``OPENSHELL_GATEWAY`` or configure a gateway cluster to connect. + +The ``openshell`` dependency is optional, so package-level exports should guard +imports of this module. 
Within this module, OpenShell SDK imports are lazy so +users without the extra can still import the package. +""" + +from __future__ import annotations + +import asyncio +import base64 +import functools +import io +import logging +import shlex +import uuid +from pathlib import Path, PurePosixPath +from typing import Any, Literal + +from pydantic import Field + +from ....sandbox.errors import ( + ExecTransportError, + WorkspaceArchiveReadError, + WorkspaceArchiveWriteError, + WorkspaceReadNotFoundError, + WorkspaceStartError, +) +from ....sandbox.manifest import Manifest +from ....sandbox.session import SandboxSession, SandboxSessionState +from ....sandbox.session.base_sandbox_session import BaseSandboxSession +from ....sandbox.session.dependencies import Dependencies +from ....sandbox.session.manager import Instrumentation +from ....sandbox.session.sandbox_client import BaseSandboxClient, BaseSandboxClientOptions +from ....sandbox.session.tar_workspace import shell_tar_exclude_args +from ....sandbox.snapshot import SnapshotBase, SnapshotSpec, resolve_snapshot +from ....sandbox.types import ExecResult, User +from ....sandbox.workspace_paths import posix_path_as_path, sandbox_path_str + +logger = logging.getLogger(__name__) + +# OpenShell phase constant for a ready sandbox. 
+_OPENSHELL_PHASE_READY = 2 + + +class OpenShellSandboxClientOptions(BaseSandboxClientOptions): + """Client options for the OpenShell sandbox backend.""" + + type: Literal["openshell"] = "openshell" + cluster: str | None = None + endpoint: str | None = None + tls_ca_path: str | None = None + tls_cert_path: str | None = None + tls_key_path: str | None = None + image: str | None = None + envs: dict[str, str] | None = None + gpu: bool = False + providers: list[str] | None = None + timeout: float = 30.0 + ready_timeout: float = 120.0 + + def __init__( + self, + cluster: str | None = None, + endpoint: str | None = None, + tls_ca_path: str | None = None, + tls_cert_path: str | None = None, + tls_key_path: str | None = None, + image: str | None = None, + envs: dict[str, str] | None = None, + gpu: bool = False, + providers: list[str] | None = None, + timeout: float = 30.0, + ready_timeout: float = 120.0, + *, + type: Literal["openshell"] = "openshell", + ) -> None: + super().__init__( + type=type, + cluster=cluster, + endpoint=endpoint, + tls_ca_path=tls_ca_path, + tls_cert_path=tls_cert_path, + tls_key_path=tls_key_path, + image=image, + envs=envs, + gpu=gpu, + providers=providers, + timeout=timeout, + ready_timeout=ready_timeout, + ) + + +class OpenShellSandboxSessionState(SandboxSessionState): + """Serializable state for an OpenShell-backed session.""" + + type: Literal["openshell"] = "openshell" + sandbox_id: str + sandbox_name: str + cluster: str | None = None + endpoint: str | None = None + tls_ca_path: str | None = None + tls_cert_path: str | None = None + tls_key_path: str | None = None + base_envs: dict[str, str] = Field(default_factory=dict) + image: str | None = None + gpu: bool = False + providers_list: list[str] = Field(default_factory=list) + client_timeout: float = 30.0 + ready_timeout: float = 120.0 + + +def _import_openshell_client() -> Any: + """Lazy-import the OpenShell SandboxClient class.""" + try: + from openshell.sandbox import SandboxClient + + 
return SandboxClient + except ImportError as exc: + raise ImportError( + "OpenShellSandboxClient requires the optional `openshell` dependency.\n" + "Install the openshell extra before using this sandbox backend." + ) from exc + + +def _import_openshell_proto() -> Any: + """Lazy-import the OpenShell protobuf module.""" + try: + from openshell._proto import openshell_pb2 + + return openshell_pb2 + except ImportError as exc: + raise ImportError( + "OpenShellSandboxClient requires the optional `openshell` dependency." + ) from exc + + +def _build_tls_config( + *, + ca_path: str | None, + cert_path: str | None, + key_path: str | None, +) -> Any: + """Build an OpenShell TlsConfig from file paths. + + Raises ``ValueError`` when any of the three paths is missing. + """ + import pathlib + + from openshell.sandbox import TlsConfig + + # Validate with explicit raises: ``assert`` statements are stripped under + # ``python -O``, which would let a ``None`` path reach ``pathlib.Path``. + if ca_path is None: + raise ValueError("ca_path is required for TLS") + if cert_path is None: + raise ValueError("cert_path is required for TLS") + if key_path is None: + raise ValueError("key_path is required for TLS") + return TlsConfig( + ca_path=pathlib.Path(ca_path), + cert_path=pathlib.Path(cert_path), + key_path=pathlib.Path(key_path), + ) + + +def _resolve_openshell_client(options: OpenShellSandboxClientOptions) -> Any: + """Create an OpenShell SandboxClient from options.""" + SandboxClientCls = _import_openshell_client() + if options.endpoint: + tls = ( + _build_tls_config( + ca_path=options.tls_ca_path, + cert_path=options.tls_cert_path, + key_path=options.tls_key_path, + ) + if options.tls_ca_path + else None + ) + return SandboxClientCls(options.endpoint, tls=tls, timeout=options.timeout) + return SandboxClientCls.from_active_cluster(cluster=options.cluster, timeout=options.timeout) + + +def _build_openshell_client(state: OpenShellSandboxSessionState) -> Any: + """Rebuild an OpenShell SandboxClient from persisted state.""" + SandboxClientCls = _import_openshell_client() + if state.endpoint: + tls = ( + _build_tls_config( + ca_path=state.tls_ca_path, + cert_path=state.tls_cert_path, + key_path=state.tls_key_path, + ) + if 
state.tls_ca_path + else None + ) + return SandboxClientCls(state.endpoint, tls=tls, timeout=state.client_timeout) + return SandboxClientCls.from_active_cluster(cluster=state.cluster, timeout=state.client_timeout) + + +def _build_sandbox_spec(options: OpenShellSandboxClientOptions) -> Any: + """Build an openshell_pb2.SandboxSpec from client options.""" + pb2 = _import_openshell_proto() + template = None + if options.image: + template = pb2.SandboxTemplate(image=options.image) + return pb2.SandboxSpec( + environment=dict(options.envs or {}), + template=template, + gpu=options.gpu, + providers=list(options.providers or []), + ) + + +class OpenShellSandboxSession(BaseSandboxSession): + """SandboxSession implementation backed by an OpenShell sandbox.""" + + state: OpenShellSandboxSessionState + + def __init__( + self, + *, + state: OpenShellSandboxSessionState, + openshell_client: Any, + ) -> None: + self.state = state + self._openshell_client = openshell_client + self._workspace_root_ready = state.workspace_root_ready + + # -- internal helpers ------------------------------------------------------ + + async def _run_sync(self, fn: Any, *args: Any, **kwargs: Any) -> Any: + """Run a synchronous function in the default executor.""" + loop = asyncio.get_running_loop() + bound = functools.partial(fn, *args, **kwargs) + return await loop.run_in_executor(None, bound) + + async def _resolved_envs(self) -> dict[str, str]: + """Merge base environment with manifest-declared environment variables.""" + manifest_env = await self.state.manifest.environment.resolve() + merged: dict[str, str] = {} + merged.update(self.state.base_envs) + for key, value in manifest_env.items(): + if value is not None: + merged[key] = value + return merged + + async def _validate_path_access(self, path: Path | str, *, for_write: bool = False) -> Path: + """Validate path against workspace root using local normalization. 
+ + OpenShell rejects command arguments containing newline characters, so the + remote path resolution helper (which installs a multi-line shell script + via exec) cannot be used. Local normalization is sufficient because + OpenShell enforces its own filesystem policy inside the sandbox. + """ + return self.normalize_path(path, for_write=for_write) + + def _mark_workspace_root_ready_from_probe(self) -> None: + """Record that the preserved-backend workspace root was proven ready.""" + super()._mark_workspace_root_ready_from_probe() + self._workspace_root_ready = True + + async def _prepare_backend_workspace(self) -> None: + """Ensure the workspace root directory exists inside the sandbox.""" + root = PurePosixPath(self.state.manifest.root) + try: + result = await self._exec_internal("mkdir", "-p", "--", root.as_posix()) + except Exception as exc: + raise WorkspaceStartError(path=posix_path_as_path(root), cause=exc) from exc + + if result.exit_code != 0: + raise WorkspaceStartError( + path=posix_path_as_path(root), + context={ + "exit_code": result.exit_code, + "stdout": result.stdout.decode("utf-8", errors="replace"), + "stderr": result.stderr.decode("utf-8", errors="replace"), + }, + ) + self._workspace_root_ready = True + + async def _shutdown_backend(self) -> None: + """Best-effort delete of the sandbox and close the gRPC channel.""" + try: + await self._run_sync(self._openshell_client.delete, self.state.sandbox_name) + except Exception as exc: + logger.warning("OpenShell sandbox delete failed (non-fatal): %s", exc) + try: + await self._run_sync(self._openshell_client.close) + except Exception as exc: + logger.warning("OpenShell client close failed (non-fatal): %s", exc) + + # -- exec ------------------------------------------------------------------ + + async def _exec_internal( + self, *command: str | Path, timeout: float | None = None + ) -> ExecResult: + """Execute a command inside the OpenShell sandbox.""" + cmd_list = [str(part) for part in command] + if 
not cmd_list: + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + workdir: str | None = None + if self._workspace_root_ready: + workdir = self.state.manifest.root + + envs = await self._resolved_envs() + + exec_kwargs: dict[str, Any] = { + "workdir": workdir, + "env": envs or None, + } + if timeout is not None: + import math + + # Round up instead of truncating so a sub-second timeout (e.g. 0.5) + # is not silently turned into ``timeout_seconds=0``. + exec_kwargs["timeout_seconds"] = math.ceil(timeout) + + try: + result = await self._run_sync( + self._openshell_client.exec, + self.state.sandbox_id, + cmd_list, + **exec_kwargs, + ) + # OpenShell ExecResult returns stdout/stderr as str. + stdout = ( + result.stdout.encode("utf-8") if isinstance(result.stdout, str) else result.stdout + ) + stderr = ( + result.stderr.encode("utf-8") if isinstance(result.stderr, str) else result.stderr + ) + return ExecResult(stdout=stdout, stderr=stderr, exit_code=result.exit_code) + except ExecTransportError: + raise + except Exception as exc: + raise ExecTransportError( + command=cmd_list, + context={"backend": "openshell", "sandbox_id": self.state.sandbox_id}, + cause=exc, + ) from exc + + # -- file I/O -------------------------------------------------------------- + + async def read(self, path: Path, *, user: str | User | None = None) -> io.IOBase: + """Read a file from the sandbox by base64-encoding on the remote side.""" + normalized = await self._validate_path_access(path) + path_arg = sandbox_path_str(normalized) + result = await self.exec("base64", "-w0", "--", path_arg, shell=False, user=user) + if not result.ok(): + raise WorkspaceReadNotFoundError(path=normalized) + raw = base64.b64decode(result.stdout) + return io.BytesIO(raw) + + async def write(self, path: Path, data: io.IOBase, *, user: str | User | None = None) -> None: + """Write a file into the sandbox by piping base64-encoded data.""" + normalized = await self._validate_path_access(path, for_write=True) + payload = data.read() + if isinstance(payload, str): + payload = payload.encode("utf-8") + if not isinstance(payload, bytes | bytearray): + from 
....sandbox.errors import WorkspaceWriteTypeError + + raise WorkspaceWriteTypeError(path=normalized, actual_type=type(payload).__name__) + + encoded = base64.b64encode(bytes(payload)).decode("ascii") + path_arg = sandbox_path_str(normalized) + # Ensure the parent directory exists. + parent_cmd = ("mkdir", "-p", "--", str(PurePosixPath(path_arg).parent)) + await self.exec(*parent_cmd, shell=False, user=user) + # Write the file via printf | base64 -d. + write_cmd = f"printf '%s' {shlex.quote(encoded)} | base64 -d > {shlex.quote(path_arg)}" + result = await self.exec("sh", "-c", write_cmd, shell=False, user=user) + if not result.ok(): + raise WorkspaceArchiveWriteError( + path=normalized, + context={ + "exit_code": result.exit_code, + "stderr": result.stderr.decode("utf-8", errors="replace"), + }, + ) + + # -- status ---------------------------------------------------------------- + + async def running(self) -> bool: + """Check whether the sandbox is still running.""" + if not self._workspace_root_ready: + return False + try: + ref = await self._run_sync(self._openshell_client.get, self.state.sandbox_name) + return bool(ref.phase == _OPENSHELL_PHASE_READY) + except Exception: + return False + + # -- workspace persistence ------------------------------------------------- + + def _tar_exclude_args(self) -> list[str]: + """Build tar exclude flags from the skip paths.""" + return shell_tar_exclude_args(self._persist_workspace_skip_relpaths()) + + async def persist_workspace(self) -> io.IOBase: + """Serialize the workspace to a tar archive streamed via base64.""" + root = self._workspace_root_path() + excludes = " ".join(self._tar_exclude_args()) + tar_cmd = f"tar {excludes} -C {shlex.quote(root.as_posix())} -cf - . 
| base64 -w0" + result = await self._exec_internal("sh", "-c", tar_cmd) + if result.exit_code != 0: + raise WorkspaceArchiveReadError( + path=root, + context={ + "reason": "tar_failed", + "stderr": result.stderr.decode("utf-8", errors="replace"), + }, + ) + raw = base64.b64decode(result.stdout) + return io.BytesIO(raw) + + async def hydrate_workspace(self, data: io.IOBase) -> None: + """Populate the workspace from a tar archive.""" + root = self._workspace_root_path() + payload = data.read() + if isinstance(payload, str): + payload = payload.encode("utf-8") + if not isinstance(payload, bytes | bytearray): + from ....sandbox.errors import WorkspaceWriteTypeError + + raise WorkspaceWriteTypeError(path=root, actual_type=type(payload).__name__) + + encoded = base64.b64encode(bytes(payload)).decode("ascii") + await self.mkdir(root, parents=True) + untar_cmd = ( + f"printf '%s' {shlex.quote(encoded)} " + f"| base64 -d " + f"| tar xf - -C {shlex.quote(root.as_posix())}" + ) + result = await self._exec_internal("sh", "-c", untar_cmd) + if result.exit_code != 0: + raise WorkspaceArchiveWriteError( + path=root, + context={ + "reason": "untar_failed", + "stderr": result.stderr.decode("utf-8", errors="replace"), + }, + ) + + +class OpenShellSandboxClient(BaseSandboxClient["OpenShellSandboxClientOptions"]): + """OpenShell-backed sandbox client.""" + + backend_id = "openshell" + + def __init__( + self, + *, + instrumentation: Instrumentation | None = None, + dependencies: Dependencies | None = None, + ) -> None: + super().__init__() + self._instrumentation = instrumentation or Instrumentation() + self._dependencies = dependencies + + async def create( + self, + *, + snapshot: SnapshotSpec | SnapshotBase | None = None, + manifest: Manifest | None = None, + options: OpenShellSandboxClientOptions, + ) -> SandboxSession: + manifest = manifest or Manifest() + os_client = _resolve_openshell_client(options) + spec = _build_sandbox_spec(options) + + loop = asyncio.get_running_loop() + 
sandbox_ref = await loop.run_in_executor( + None, functools.partial(os_client.create, spec=spec) + ) + sandbox_ref = await loop.run_in_executor( + None, + functools.partial( + os_client.wait_ready, + sandbox_ref.name, + timeout_seconds=options.ready_timeout, + ), + ) + + session_id = uuid.uuid4() + snapshot_instance = resolve_snapshot(snapshot, str(session_id)) + state = OpenShellSandboxSessionState( + session_id=session_id, + manifest=manifest, + snapshot=snapshot_instance, + sandbox_id=str(sandbox_ref.id), + sandbox_name=str(sandbox_ref.name), + cluster=options.cluster, + endpoint=options.endpoint, + tls_ca_path=options.tls_ca_path, + tls_cert_path=options.tls_cert_path, + tls_key_path=options.tls_key_path, + base_envs=dict(options.envs or {}), + image=options.image, + gpu=options.gpu, + providers_list=list(options.providers or []), + client_timeout=options.timeout, + ready_timeout=options.ready_timeout, + ) + inner = OpenShellSandboxSession(state=state, openshell_client=os_client) + return self._wrap_session(inner, instrumentation=self._instrumentation) + + async def delete(self, session: SandboxSession) -> SandboxSession: + inner = session._inner + if not isinstance(inner, OpenShellSandboxSession): + raise TypeError("OpenShellSandboxClient.delete expects an OpenShellSandboxSession") + try: + await inner._run_sync(inner._openshell_client.delete, inner.state.sandbox_name) + except Exception as exc: + logger.warning( + "Failed to delete OpenShell sandbox.", + extra={"sandbox_name": inner.state.sandbox_name}, + exc_info=exc, + ) + return session + + async def resume(self, state: SandboxSessionState) -> SandboxSession: + if not isinstance(state, OpenShellSandboxSessionState): + raise TypeError("OpenShellSandboxClient.resume expects an OpenShellSandboxSessionState") + + os_client = _build_openshell_client(state) + reconnected = False + loop = asyncio.get_running_loop() + try: + sandbox_ref = await loop.run_in_executor( + None, + functools.partial(os_client.get, 
state.sandbox_name), + ) + state.sandbox_id = str(sandbox_ref.id) + reconnected = True + except Exception: + # The previous sandbox could not be fetched: recreate it from the + # persisted settings so the replacement keeps the same image, + # environment, GPU request and providers instead of backend defaults. + spec = _build_sandbox_spec( + OpenShellSandboxClientOptions( + image=state.image, + envs=dict(state.base_envs), + gpu=state.gpu, + providers=list(state.providers_list), + ) + ) + sandbox_ref = await loop.run_in_executor( + None, functools.partial(os_client.create, spec=spec) + ) + sandbox_ref = await loop.run_in_executor( + None, + functools.partial( + os_client.wait_ready, + sandbox_ref.name, + timeout_seconds=state.ready_timeout, + ), + ) + state.sandbox_id = str(sandbox_ref.id) + state.sandbox_name = str(sandbox_ref.name) + state.workspace_root_ready = False + + inner = OpenShellSandboxSession(state=state, openshell_client=os_client) + inner._set_start_state_preserved(reconnected, system=reconnected) + return self._wrap_session(inner, instrumentation=self._instrumentation) + + def deserialize_session_state(self, payload: dict[str, object]) -> SandboxSessionState: + """Deserialize an OpenShell session state from a JSON-compatible payload.""" + return OpenShellSandboxSessionState.model_validate(payload) + + +__all__ = [ + "OpenShellSandboxClient", + "OpenShellSandboxClientOptions", + "OpenShellSandboxSession", + "OpenShellSandboxSessionState", +] diff --git a/tests/extensions/sandbox/test_openshell.py b/tests/extensions/sandbox/test_openshell.py new file mode 100644 index 0000000000..74948f0e30 --- /dev/null +++ b/tests/extensions/sandbox/test_openshell.py @@ -0,0 +1,703 @@ +"""Tests for the OpenShell sandbox backend.""" + +from __future__ import annotations + +import base64 +import io +import tarfile +import uuid +from typing import Any + +import pytest + +import agents.extensions.sandbox.openshell.sandbox as _openshell_mod +from agents.extensions.sandbox.openshell import ( + OpenShellSandboxClient, + OpenShellSandboxClientOptions, + OpenShellSandboxSession, + OpenShellSandboxSessionState, +) +from agents.sandbox.errors import ExecTransportError +from agents.sandbox.manifest import Manifest +from agents.sandbox.snapshot import NoopSnapshot + +# --------------------------------------------------------------------------- +# Fake 
class _FakeSandboxRef:
    """Test double standing in for ``openshell.sandbox.SandboxRef``.

    It only stores the three attributes the tests read back: ``id``, ``name``
    and ``phase``.
    """

    def __init__(self, id: str = "", name: str = "", phase: int = 0) -> None:
        # The parameter keeps the name ``id`` because callers pass it by keyword.
        self.id, self.name, self.phase = id, name, phase
class TestOpenShellReExports:
    """Verify that public symbols are exported through the package hierarchy."""

    # The four public names the OpenShell backend is expected to expose.
    _EXPORTED_NAMES = (
        "OpenShellSandboxClient",
        "OpenShellSandboxClientOptions",
        "OpenShellSandboxSession",
        "OpenShellSandboxSessionState",
    )

    def test_openshell_package_re_exports_backend_symbols(self) -> None:
        """The openshell __init__.py should re-export all four classes."""
        from agents.extensions.sandbox import openshell

        # Collect every missing name so one failure reports the whole gap.
        missing = [name for name in self._EXPORTED_NAMES if not hasattr(openshell, name)]
        assert not missing, f"not re-exported by the openshell package: {missing}"

    def test_openshell_extension_re_exports_symbols(self) -> None:
        """The sandbox __init__.py should list the OpenShell symbols in ``__all__``."""
        from agents.extensions import sandbox as sandbox_ext

        # NOTE(review): this check is unconditional, so it only passes when the
        # OpenShell names are actually listed in ``__all__``. If the package
        # exports them conditionally (extra not installed), this test presumably
        # needs a skip — confirm against the package ``__init__``.
        missing = [name for name in self._EXPORTED_NAMES if name not in sandbox_ext.__all__]
        assert not missing, f"missing from agents.extensions.sandbox.__all__: {missing}"
class TestOpenShellSessionState:
    """Serialization behavior of ``OpenShellSandboxSessionState``."""

    def test_openshell_session_state_round_trip(self) -> None:
        """Dumping the state in JSON mode and validating it back keeps its fields."""
        original = _make_state(
            sandbox_id="abc-123",
            sandbox_name="my-sandbox",
            base_envs={"FOO": "bar"},
        )

        restored = OpenShellSandboxSessionState.model_validate(
            original.model_dump(mode="json")
        )

        assert restored.type == "openshell"
        assert (restored.sandbox_id, restored.sandbox_name) == ("abc-123", "my-sandbox")
        assert restored.base_envs == {"FOO": "bar"}
        assert restored.session_id == original.session_id

    def test_openshell_deserialize_session_state(self) -> None:
        """``deserialize_session_state`` on the client should yield the OpenShell state type."""
        serialized = _make_state(sandbox_id="deser-id", sandbox_name="deser-name").model_dump(
            mode="json"
        )

        restored = OpenShellSandboxClient().deserialize_session_state(serialized)

        # The isinstance check also narrows the type before the attribute access.
        assert isinstance(restored, OpenShellSandboxSessionState)
        assert restored.sandbox_id == "deser-id"
@pytest.mark.asyncio
async def test_openshell_exec_wraps_grpc_error_as_transport_error(self) -> None:
    """A RuntimeError raised by the fake client should surface as ExecTransportError."""
    session, fake = _make_session(workspace_ready=True)
    # The fake client raises this from ``exec`` instead of returning a result.
    fake._exec_error = RuntimeError("gRPC unavailable")

    with pytest.raises(ExecTransportError) as caught:
        await session._exec_internal("failing-cmd")

    # The original error must be preserved as the cause of the wrapped one.
    cause_text = str(caught.value.__cause__)
    assert "gRPC unavailable" in cause_text
@pytest.mark.asyncio
async def test_openshell_client_create_waits_for_ready(
    self, monkeypatch: pytest.MonkeyPatch
) -> None:
    """create() should wait for readiness using the created sandbox's name and timeout."""
    fake_client = _FakeOpenShellClient()
    # Swap the module-level factories so no real gateway or spec is needed.
    replacements = {
        "_resolve_openshell_client": lambda options: fake_client,
        "_build_sandbox_spec": lambda options: None,
    }
    for attr_name, stub in replacements.items():
        monkeypatch.setattr(_openshell_mod, attr_name, stub)

    session = await OpenShellSandboxClient().create(
        options=OpenShellSandboxClientOptions(ready_timeout=60.0)
    )

    try:
        # Exactly one wait, for the name the fake ``create`` returned.
        assert fake_client.wait_ready_calls == [("sandbox-name-1", 60.0)]
    finally:
        await session.aclose()
OpenShellSandboxClient() + session = await client.resume(state) + + try: + assert len(fake_client.get_calls) == 1 + assert fake_client.get_calls[0] == "test-sandbox-name" + inner = session._inner + assert isinstance(inner, OpenShellSandboxSession) + assert inner.state.sandbox_id == "reconnected-id" + finally: + await session.aclose() + + @pytest.mark.asyncio + async def test_openshell_resume_recreates_when_sandbox_missing( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """resume() should fall back to create() when get() raises.""" + fake_client = _FakeOpenShellClient() + fake_client._get_error = RuntimeError("sandbox not found") + monkeypatch.setattr( + _openshell_mod, + "_build_openshell_client", + lambda state: fake_client, + ) + + state = _make_state( + sandbox_id="old-id", + sandbox_name="gone-sandbox", + workspace_root_ready=True, + ) + client = OpenShellSandboxClient() + session = await client.resume(state) + + try: + # get() was attempted, then create() was called as fallback. + assert len(fake_client.get_calls) == 1 + assert len(fake_client.create_calls) == 1 + inner = session._inner + assert isinstance(inner, OpenShellSandboxSession) + # The state should reflect the new sandbox. 
class TestOpenShellShutdown:
    """Backend shutdown behavior of the session."""

    @pytest.mark.asyncio
    async def test_openshell_shutdown_deletes_sandbox_best_effort(self) -> None:
        """_shutdown_backend should delete the sandbox by name and close the client."""
        session, fake = _make_session(workspace_ready=True)

        await session._shutdown_backend()

        assert fake.delete_calls == ["test-sandbox-name"]
        assert fake.close_calls == 1

    @pytest.mark.asyncio
    async def test_openshell_shutdown_logs_on_delete_failure(self) -> None:
        """A failing delete must not propagate, and close must still run."""
        session, fake = _make_session(workspace_ready=True)

        def _failing_delete(name: str) -> bool:
            raise RuntimeError("delete failed")

        fake.delete = _failing_delete  # type: ignore[assignment]

        # Reaching the assertion proves the delete error was swallowed.
        await session._shutdown_backend()

        assert fake.close_calls == 1
@pytest.mark.asyncio
async def test_openshell_write_sends_base64_content(self) -> None:
    """The write pipeline should issue a mkdir, then a base64-decoding shell command.

    Tests the write pipeline at the exec layer, bypassing path validation
    which requires a remote runtime helper script.
    """
    import shlex

    session, client = _make_session(workspace_ready=True)
    client._exec_result = _FakeOpenShellExecResult(exit_code=0)

    payload = b"file content"
    encoded = base64.b64encode(payload).decode("ascii")
    write_cmd = (
        f"printf '%s' {shlex.quote(encoded)} | base64 -d > {shlex.quote('/workspace/out.txt')}"
    )

    await session._exec_internal("mkdir", "-p", "--", "/workspace")
    await session._exec_internal("sh", "-c", write_cmd)

    assert len(client.exec_calls) == 2
    # Check what was sent, not just how many calls were made: each exec call
    # records (sandbox_id, command, kwargs) on the fake client.
    (_, mkdir_cmd, _), (_, shell_cmd, _) = client.exec_calls
    assert mkdir_cmd == ["mkdir", "-p", "--", "/workspace"]
    assert shell_cmd == ["sh", "-c", write_cmd]
    # The encoded payload must travel inside the shell command unchanged.
    assert encoded in shell_cmd[2]
class TestOpenShellPersistence:
    """Workspace persistence through the exec layer."""

    @pytest.mark.asyncio
    async def test_openshell_tar_persistence_round_trip(self) -> None:
        """persist_workspace should hand back the tar bytes decoded from base64 stdout."""
        session, fake = _make_session(workspace_ready=True)

        # Build a one-file tar archive in memory to act as the sandbox output.
        content = b"round-trip content"
        member = tarfile.TarInfo(name="test.txt")
        member.size = len(content)
        archive_buffer = io.BytesIO()
        with tarfile.open(fileobj=archive_buffer, mode="w") as archive:
            archive.addfile(member, io.BytesIO(content))
        stdout_text = base64.b64encode(archive_buffer.getvalue()).decode("ascii")

        fake._exec_result = _FakeOpenShellExecResult(exit_code=0, stdout=stdout_text, stderr="")

        persisted = await session.persist_workspace()
        persisted_bytes = persisted.read()

        # The returned bytes must parse as a tar archive containing the file.
        with tarfile.open(fileobj=io.BytesIO(persisted_bytes), mode="r") as archive:
            member_names = archive.getnames()
        assert "test.txt" in member_names
class _FakeGatewayClient:
    """Stand-in for the SandboxClient class used by the gateway resolution tests.

    Every constructed instance is appended to the class-level ``_instances``
    list; ``_from_cluster`` stays ``None`` unless the instance was built through
    ``from_active_cluster``.
    """

    _instances: list[_FakeGatewayClient] = []
    _from_cluster: str | None = None

    def __init__(self, endpoint: str = "", *, tls: Any = None, timeout: float = 30.0) -> None:
        self.endpoint, self.tls, self.timeout = endpoint, tls, timeout
        # Register on the class itself so tests can reset and inspect one list.
        _FakeGatewayClient._instances.append(self)

    @classmethod
    def from_active_cluster(
        cls, *, cluster: str | None = None, timeout: float = 30.0
    ) -> _FakeGatewayClient:
        """Build an instance with an empty endpoint and remember the requested cluster."""
        built = cls(endpoint="", timeout=timeout)
        built._from_cluster = cluster
        return built
test_openshell_client_resolves_from_active_cluster( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """When no endpoint is set, the client should use from_active_cluster.""" + _FakeGatewayClient._instances = [] + monkeypatch.setattr( + _openshell_mod, + "_import_openshell_client", + lambda: _FakeGatewayClient, + ) + + options = OpenShellSandboxClientOptions(cluster="staging-cluster") + result = _openshell_mod._resolve_openshell_client(options) + + assert isinstance(result, _FakeGatewayClient) + assert result.endpoint == "" + assert result._from_cluster == "staging-cluster"