Skip to content

Commit 8444ca6

Browse files
committed
wip: first pass over using sandbox and snapshots
1 parent 99ad2ee commit 8444ca6

File tree

4 files changed

+196
-16
lines changed

4 files changed

+196
-16
lines changed

codegen-examples/examples/swebench_agent_run/entry_point.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from codegen.extensions.swebench.utils import SweBenchExample
2-
from codegen.extensions.swebench.harness import run_agent_on_entry
31
import modal
2+
from codegen.extensions.swebench.harness import run_agent_on_entry
3+
from codegen.extensions.swebench.utils import SweBenchExample
44

55
image = (
66
modal.Image.debian_slim(python_version="3.13")

codegen-examples/examples/swebench_agent_run/run_eval.py

Lines changed: 92 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
import asyncio
22
import json
33
import traceback
4-
from pathlib import Path
54
import uuid
6-
import modal
7-
import click
5+
from collections import defaultdict
6+
from dataclasses import asdict
87
from datetime import datetime
9-
from codegen.extensions.swebench.utils import SWEBenchDataset, get_swe_bench_example, get_swe_bench_examples
8+
from pathlib import Path
9+
10+
import click
11+
import modal
1012
from codegen.extensions.swebench.report import generate_report
13+
from codegen.extensions.swebench.utils import (
14+
SWEBenchDataset,
15+
SweBenchExample,
16+
get_swe_bench_example,
17+
get_swe_bench_examples,
18+
)
19+
20+
from .sandbox import SandboxManager
1121

1222
PREDS_DNAME = Path(__file__).parent / "predictions"
1323
LOG_DIR = Path(__file__).parent / "logs"
@@ -61,11 +71,26 @@ async def process_batch(examples, batch_size=10):
6171
print("Traceback:")
6272
print("".join(error_info["traceback"]))
6373

64-
results.append({"instance_id": example.instance_id, "status": "error", "error_info": error_info})
74+
results.append(
75+
{
76+
"instance_id": example.instance_id,
77+
"status": "error",
78+
"error_info": error_info,
79+
}
80+
)
6581
else:
6682
if result is None:
6783
print(f"Warning: Null result for {example.instance_id}")
68-
results.append({"instance_id": example.instance_id, "status": "error", "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}})
84+
results.append(
85+
{
86+
"instance_id": example.instance_id,
87+
"status": "error",
88+
"error_info": {
89+
"error_type": "NullResult",
90+
"error_message": "Process returned None",
91+
},
92+
}
93+
)
6994
else:
7095
results.append(result)
7196

@@ -81,14 +106,24 @@ async def process_batch(examples, batch_size=10):
81106
{
82107
"instance_id": example.instance_id,
83108
"status": "error",
84-
"error_info": {"error_type": type(e).__name__, "error_message": str(e), "traceback": traceback.format_exc(), "batch_failure": True},
109+
"error_info": {
110+
"error_type": type(e).__name__,
111+
"error_message": str(e),
112+
"traceback": traceback.format_exc(),
113+
"batch_failure": True,
114+
},
85115
}
86116
)
87117

88118
return results
89119

90120

91-
async def run_eval(use_existing_preds: str | None, dataset: str, length: int, instance_id: str | None = None):
121+
async def run_eval(
122+
use_existing_preds: str | None,
123+
dataset: str,
124+
length: int,
125+
instance_id: str | None = None,
126+
):
92127
run_id = use_existing_preds or str(uuid.uuid4())
93128
predictions_dir = PREDS_DNAME / f"results_{run_id}"
94129
dataset = SWEBenchDataset(dataset)
@@ -157,13 +192,58 @@ async def run_eval(use_existing_preds: str | None, dataset: str, length: int, in
157192
raise
158193

159194

195+
# One semaphore per (repo, base_commit) pair so only a single sandbox for a
# given snapshot target runs at a time. defaultdict(asyncio.Semaphore)
# lazily creates a Semaphore(1) per key on first use.
SANDBOX_SEMAPHORES = defaultdict(asyncio.Semaphore)


async def run_example(sandbox_manager: SandboxManager, example: SweBenchExample):
    """Run the agent for one SWE-bench example inside a Modal sandbox.

    Args:
        sandbox_manager: Manager that provides (and later snapshots) the sandbox.
        example: The SWE-bench example to run.

    Returns:
        The sandbox process object for the completed run.

    Raises:
        Exception: If the sandboxed process exits with a non-zero code.
    """
    async with SANDBOX_SEMAPHORES[(example.repo, example.base_commit)]:
        async with sandbox_manager.get_sandbox(example) as sandbox:
            # Embed the serialized example as a *string literal* (via !r).
            # Interpolating json.dumps(...) output directly would splice a
            # Python dict literal into the generated source, and
            # run_agent_from_serialized_entry would receive a dict where it
            # expects a JSON string.
            serialized_entry = json.dumps(asdict(example))
            result = await sandbox.exec(
                "python3",
                "-c",
                f"from codegen.extensions.swebench.harness import run_agent_from_serialized_entry; run_agent_from_serialized_entry({serialized_entry!r})",
            )
            exit_code = await result.wait()
            if exit_code != 0:
                raise Exception(f"Sandbox exited with non-zero exit code {exit_code}")
            return result
210+
211+
212+
async def run_on_sandbox(use_existing_preds, dataset, length, instance_id):
    """Run SWE-bench examples concurrently, one Modal sandbox per example.

    Resolves the dataset name to a SWEBenchDataset, collects either the single
    requested example or the first `length` examples, and gathers the sandbox
    runs concurrently.
    """
    bench = SWEBenchDataset(dataset)
    if instance_id:
        examples = [get_swe_bench_example(instance_id, dataset=bench)]
    else:
        examples = get_swe_bench_examples(dataset=bench, length=length)

    manager = SandboxManager()
    # TODO: remote execution should push results to the database. See: codegeon-on-oss/outputs/sql_output.py
    tasks = [run_example(manager, example) for example in examples]
    return await asyncio.gather(*tasks)
222+
223+
160224
@click.command()
@click.option(
    "--use-existing-preds",
    help="The run ID of the existing predictions to use.",
    type=str,
    default=None,
)
@click.option(
    "--dataset",
    help="The dataset to use.",
    type=click.Choice([dataset.value for dataset in SWEBenchDataset]),
    default=SWEBenchDataset.LITE.value,
)
@click.option("--length", help="The number of examples to process.", type=int, default=10)
@click.option(
    "--instance-id",
    help="The instance ID of the example to process.",
    type=str,
    default=None,
)
def run_eval_command(use_existing_preds, dataset, length, instance_id):
    """CLI entry point: run the SWE-bench evaluation on Modal sandboxes."""
    # WIP: the pre-sandbox evaluation path is intentionally bypassed here in
    # favor of the sandbox-based runner. NOTE(review): run_on_sandbox does not
    # use --use-existing-preds; confirm whether that option should still apply.
    # asyncio.run(run_eval(use_existing_preds, dataset, length, instance_id))
    asyncio.run(run_on_sandbox(use_existing_preds, dataset, length, instance_id))
167247

168248

169249
if __name__ == "__main__":
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import io
2+
import json
3+
from collections import defaultdict
4+
from contextlib import asynccontextmanager
5+
6+
import modal
7+
from codegen.extensions.swebench.utils import SweBenchExample
8+
9+
# Base sandbox image: slim Debian with Python 3.13, git, and FastAPI installed.
BASE_IMAGE: modal.Image = modal.Image.debian_slim(python_version="3.13").apt_install("git").pip_install("fastapi[standard]")

# Persistent volume holding the snapshot metadata JSON file, which maps
# repo -> base_commit -> snapshot uid (read/written by SandboxManager).
SNAPSHOT_META_VOLUME = modal.Volume.from_name("swebench-agent-snapshot-volume", create_if_missing=True)
SNAPSHOT_META_FILE_PATH: str = "/root/snapshot_meta.json"


try:
    # To ensure secrets are consistent across runs, we look up existing secret
    secret = modal.Secret.from_name("swebench-agent-run-secrets")
except modal.exception.NotFoundError:
    # No named secret deployed — fall back to the local .env file.
    secret = modal.Secret.from_dotenv()

# Shared Modal app that owns every sandbox created by this module.
app = modal.App.lookup(name="swebench-agent-run", create_if_missing=True)
22+
23+
24+
class SandboxManager:
    """Creates and reuses Modal sandboxes for SWE-bench examples.

    A snapshot uid is stored per (repo, base_commit) in a JSON file on a
    shared volume, so later runs can restore a pre-warmed sandbox instead of
    building one from scratch.
    """

    # When True, sandboxes are left running on exit from get_sandbox()
    # instead of being snapshotted and terminated.
    keep_alive: bool

    def __init__(self, keep_alive: bool = False):
        self.keep_alive = keep_alive

    async def read_snapshot_meta(self) -> dict[str, dict[str, str]]:
        """Load the repo -> base_commit -> snapshot-uid mapping from the volume.

        Returns:
            A two-level mapping where missing repos *and* missing commits
            resolve to None instead of raising KeyError.
        """
        bytes_io = io.BytesIO()
        try:
            await SNAPSHOT_META_VOLUME.read_file_into_fileobj(SNAPSHOT_META_FILE_PATH, bytes_io)
            snapshot_meta = json.loads(bytes_io.getvalue().decode("utf-8"))
        except FileNotFoundError:
            snapshot_meta = {}
        # Wrap the inner dicts too: json.loads yields plain dicts, so without
        # this a known repo with an unknown commit would raise KeyError
        # instead of returning None in create_sandbox().
        return defaultdict(
            lambda: defaultdict(lambda: None),
            {repo: defaultdict(lambda: None, commits) for repo, commits in snapshot_meta.items()},
        )

    async def update_snapshot_meta(self, example: SweBenchExample, snapshot_uid: str):
        """Record `snapshot_uid` for the example's (repo, base_commit) and persist it."""
        snapshot_meta = await self.read_snapshot_meta()
        snapshot_meta[example.repo][example.base_commit] = snapshot_uid
        async with SNAPSHOT_META_VOLUME.batch_upload() as upload:
            await upload.put_file(
                io.BytesIO(json.dumps(snapshot_meta).encode("utf-8")),
                SNAPSHOT_META_FILE_PATH,
            )
        await SNAPSHOT_META_VOLUME.commit()

    async def create_sandbox(self, example: SweBenchExample) -> modal.Sandbox:
        """Return a sandbox for the example, restoring a snapshot when one exists."""
        snapshot_meta = await self.read_snapshot_meta()
        existing_snapshot_uid = snapshot_meta[example.repo][example.base_commit]
        if existing_snapshot_uid:
            return await modal.Sandbox._experimental_from_snapshot(existing_snapshot_uid)

        # TODO: test if this get local version works / add ability to install specific version
        with modal.enable_output():
            return await modal.Sandbox.create(
                app=app,
                image=BASE_IMAGE.add_local_python_source("codegen"),
                secrets=[secret],
                tags={"repo": example.repo, "commit": example.base_commit},
            )

    @asynccontextmanager
    async def get_sandbox(self, example: SweBenchExample):
        """Async context manager yielding a sandbox for the example.

        Reuses a live sandbox tagged with the same repo/commit when one is
        listed; otherwise creates (or restores) one. Unless keep_alive is set,
        the sandbox is snapshotted and terminated on exit.
        """
        # for/else: take the first listed sandbox; if the listing is empty the
        # loop body never runs and the else branch creates a fresh one.
        async for sandbox in modal.Sandbox.list(
            app_id=app.app_id,
            tags={"repo": example.repo, "commit": example.base_commit},
        ):
            break
        else:
            sandbox = await self.create_sandbox(example)

        try:
            # NOTE(review): Sandbox.wait() blocks until the sandbox process
            # *finishes* — confirm this is the intended readiness check
            # before yielding the sandbox for use.
            await sandbox.wait()
            yield sandbox
        finally:
            if not self.keep_alive:
                # Killing sandbox, so take a snapshot and save it
                stash_proc = await sandbox.exec(
                    "bash",
                    "-c",
                    f"cd /root/tmp/{example.repo}; git stash",  # cheeky little stash
                )
                # Wait for the stash to complete before snapshotting; otherwise
                # the snapshot can capture the filesystem mid-stash.
                await stash_proc.wait()
                snapshot = await sandbox._experimental_snapshot()  # commit any codegen updates
                await self.update_snapshot_meta(example, snapshot.object_id)

                # Codebase.from_repo doesn't use git to fetch/checkout the repo.
                # We could replace this with our own git commands to control the file state
                await sandbox.terminate()

src/codegen/extensions/swebench/harness.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,12 @@ def show_problems(dataset):
4848
print(f"{inst}: {problem}")
4949

5050

51-
def run_agent_on_entry(entry: SweBenchExample):
51+
def run_agent_from_serialized_entry(serialized_entry: str):
52+
entry = json.loads(serialized_entry)
53+
return run_agent_on_entry(SweBenchExample(**entry))
54+
55+
56+
def run_agent_on_entry(entry: SweBenchExample, tmp_dir="/root/tmp"):
5257
"""Process one `entry` from SWE Bench using the LLM `models` at the
5358
given `temperature`. Set `model_name_or_path` in the result json.
5459
"""
@@ -63,7 +68,12 @@ def run_agent_on_entry(entry: SweBenchExample):
6368

6469
gold_files = files_in_patch(entry.patch)
6570

66-
codebase = Codebase.from_repo(repo_full_name=entry.repo, commit=base_commit, language="python") # check out the repo
71+
codebase = Codebase.from_repo(
72+
repo_full_name=entry.repo,
73+
commit=base_commit,
74+
language="python",
75+
tmp_dir=tmp_dir,
76+
) # check out the repo
6777

6878
agent = CodeAgent(codebase=codebase)
6979

0 commit comments

Comments
 (0)