Skip to content

Commit af29675

Browse files
sjarmak and claude committed
feat: precision improvements — tighter prompt, pruning pass, parallel execution
- Add precision guidelines to SYSTEM_PROMPT_SUFFIX (exclude test files, docs, tangential code; aim for 1-5 files on simple bugs)
- Add prune_oracle_cli() that runs a haiku pruning pass to filter irrelevant files from agent output
- Add --prune flag to validate_on_contextbench.py
- Add --parallel N flag with ThreadPoolExecutor for concurrent tasks
- Refactor main loop into process_one_task() worker function

Phase 1 baseline: composite=0.6426, recall=0.90, precision=0.30
Target: improve precision to push composite above 0.65 threshold.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ce7da6c commit af29675

File tree

2 files changed

+204
-34
lines changed

2 files changed

+204
-34
lines changed

scripts/context_retrieval_agent.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,16 @@ def _extract_clone_urls(dockerfile_content: str) -> List[Dict[str, str]]:
980980
}
981981

982982
SYSTEM_PROMPT_SUFFIX = """
983+
## Precision Guidelines
984+
- Include ONLY source files that would need to be **read or modified** to address the task.
985+
- Do NOT include test files, documentation, or configuration files unless the task \
986+
explicitly asks about testing, docs, or configuration.
987+
- Do NOT include files that merely import or reference the relevant code — only files \
988+
that contain the logic central to the task.
989+
- When in doubt, ask: "Would a developer need to open this file to fix/understand the issue?" \
990+
If no, exclude it.
991+
- Aim for 1-5 files for simple bugs, 3-10 for moderate tasks, 10+ only for large refactors.
992+
983993
## Output
984994
When you have identified all relevant files, output a JSON object with:
985995
```json
@@ -1264,6 +1274,119 @@ def _cli_error_metadata(model: str, backend: str, start_time: float) -> Dict[str
12641274
}
12651275

12661276

1277+
# Prompt template for the pruning pass. prune_oracle_cli() fills the
# {task_description} and {file_list} placeholders via str.format(), which is
# why the literal braces of the JSON example below are doubled ({{ and }}).
# The trailing backslashes join wrapped lines into a single prompt line.
PRUNE_PROMPT = """\
1278+
You are a precision filter for code context retrieval. Given a task description \
1279+
and a list of predicted files, remove files that are NOT directly relevant.
1280+
1281+
## Rules
1282+
- Keep ONLY files that a developer would need to **read or modify** to address the task.
1283+
- Remove test files unless the task is specifically about testing.
1284+
- Remove documentation files unless the task is about docs.
1285+
- Remove files that merely import or reference the relevant code.
1286+
- When unsure, keep the file (recall > precision).
1287+
1288+
## Task
1289+
{task_description}
1290+
1291+
## Predicted Files
1292+
{file_list}
1293+
1294+
## Output
1295+
Return a JSON object with the filtered file list:
1296+
```json
1297+
{{
1298+
"files": [
1299+
{{"repo": "repo-name", "path": "relative/path/to/file"}}
1300+
],
1301+
"pruned_count": <number of files removed>,
1302+
"text": "Brief explanation of what was removed and why."
1303+
}}
1304+
```
1305+
"""
1306+
1307+
1308+
def prune_oracle_cli(
    oracle: Dict[str, Any],
    ctx: Dict[str, Any],
    prune_model: str = "claude-haiku-4-5-20251001",
    verbose: bool = False,
) -> Dict[str, Any]:
    """Run a pruning pass on the oracle output using a cheap model via CLI.

    Asks a fast model to remove irrelevant files from the agent's output.
    The pass is strictly best-effort: on *any* failure (CLI missing, timeout,
    non-zero exit, unparseable or malformed output, or a result that would
    drop every file) the original ``oracle`` is returned unchanged.

    Args:
        oracle: Agent output; only its ``"files"`` list (dicts with ``repo``
            and ``path`` keys) is inspected. It is never mutated.
        ctx: Task context; the task text is read from ``"seed_prompt"`` and
            falls back to ``"instruction"``.
        prune_model: Model name passed to the ``claude`` CLI via ``--model``.
        verbose: When true, log progress at INFO level.

    Returns:
        ``oracle`` itself when pruning is skipped or fails; otherwise a
        shallow copy whose ``"files"`` is the surviving subset of the original
        entries (original order and original dicts preserved) plus a
        ``"_prune_metadata"`` dict with ``original_count``, ``pruned_count``,
        ``prune_model`` and ``prune_cost_usd``.
    """

    def file_key(entry: Any) -> Any:
        # Identity of a file entry, using the same "?" placeholders that are
        # rendered into the prompt so the model's echo can be matched back.
        # Non-dict entries have no identity (None) and are never pruned.
        if not isinstance(entry, dict):
            return None
        return (str(entry.get("repo", "?")), str(entry.get("path", "?")))

    # `or []` also covers an explicit `"files": None`.
    files = oracle.get("files") or []
    if len(files) <= 3:
        # Too few to prune — skip
        if verbose:
            log.info(" Prune: skipping (%d files, <= 3)", len(files))
        return oracle

    task_desc = ctx.get("seed_prompt") or ctx.get("instruction") or ""
    file_list = "\n".join(
        "- {}: {}".format(*key)
        for key in (file_key(entry) for entry in files)
        if key is not None
    )

    prompt = PRUNE_PROMPT.format(
        # Cap the task text so the prompt stays small for the cheap model.
        task_description=str(task_desc)[:3000],
        file_list=file_list,
    )

    cmd = [
        "claude", "-p", prompt,
        "--output-format", "json",
        "--model", prune_model,
        "--dangerously-skip-permissions",
    ]
    # Run the child without the CLAUDECODE variable from our own environment.
    env = {k: v for k, v in os.environ.items() if k != "CLAUDECODE"}

    if verbose:
        log.info(" Prune: %d files -> calling %s", len(files), prune_model)

    try:
        result = subprocess.run(
            cmd, capture_output=True, text=True, env=env, timeout=60,
        )
    except subprocess.TimeoutExpired:
        log.warning(" Prune: timeout, keeping original")
        return oracle
    except OSError as exc:
        # e.g. the `claude` executable is not installed / not on PATH.
        log.warning(" Prune: could not run CLI (%s), keeping original", exc)
        return oracle

    if result.returncode != 0:
        log.warning(" Prune: CLI failed (rc=%d), keeping original", result.returncode)
        return oracle

    try:
        cli_output = json.loads(result.stdout)
    except (json.JSONDecodeError, ValueError):
        log.warning(" Prune: failed to parse CLI output, keeping original")
        return oracle
    if not isinstance(cli_output, dict):
        log.warning(" Prune: unexpected CLI output shape, keeping original")
        return oracle

    pruned = _extract_json_from_text(cli_output.get("result") or "")
    if not isinstance(pruned, dict) or not isinstance(pruned.get("files"), list):
        log.warning(" Prune: no valid JSON in output, keeping original")
        return oracle

    # Pruning may only *remove* files: keep the original entries the model
    # echoed back and ignore anything it invented or rewrote.
    kept_keys = {file_key(entry) for entry in pruned["files"]} - {None}
    pruned_files = [
        entry for entry in files
        if file_key(entry) is None or file_key(entry) in kept_keys
    ]
    if not pruned_files:
        # Dropping every prediction is almost certainly a model error and
        # contradicts the prompt's "recall > precision" rule.
        log.warning(" Prune: result would remove every file, keeping original")
        return oracle

    # A missing or null cost must not break formatting or the caller's sum.
    prune_cost = cli_output.get("total_cost_usd") or 0.0

    if verbose:
        log.info(" Prune: %d -> %d files ($%.4f)",
                 len(files), len(pruned_files), prune_cost)

    # Merge: keep pruned files but preserve original symbols/chain/text
    result_oracle = dict(oracle)
    result_oracle["files"] = pruned_files
    result_oracle["_prune_metadata"] = {
        "original_count": len(files),
        "pruned_count": len(files) - len(pruned_files),
        "prune_model": prune_model,
        "prune_cost_usd": prune_cost,
    }
    return result_oracle
1388+
1389+
12671390
def build_user_message(
12681391
ctx: Dict[str, Any], repo_paths: Dict[str, Path]
12691392
) -> str:

scripts/validate_on_contextbench.py

Lines changed: 81 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"""
4040

4141
import argparse
42+
import concurrent.futures
4243
import json
4344
import logging
4445
import os
@@ -732,6 +733,14 @@ def main() -> int:
732733
"--max-tasks", type=int, default=0,
733734
help="Process at most N tasks (0 = all)",
734735
)
736+
parser.add_argument(
737+
"--parallel", type=int, default=1,
738+
help="Number of tasks to run in parallel (default: 1)",
739+
)
740+
parser.add_argument(
741+
"--prune", action="store_true",
742+
help="Run a pruning pass with haiku to remove irrelevant files",
743+
)
735744
args = parser.parse_args()
736745
use_cli = not args.use_sdk
737746

@@ -844,46 +853,33 @@ def main() -> int:
844853
from context_retrieval_agent import SourcegraphClient
845854
sg = SourcegraphClient()
846855

847-
total_cost = 0.0
848-
trajectories = []
849-
evaluated_tasks = []
850-
851-
for i, task in enumerate(tasks):
852-
if args.max_tasks > 0 and i >= args.max_tasks:
853-
log.info("Max tasks limit reached (%d)", args.max_tasks)
854-
break
855-
if args.max_cost > 0 and total_cost >= args.max_cost:
856-
log.warning("Cost limit reached ($%.2f)", total_cost)
857-
break
858-
859-
instance_id = task.get("instance_id", f"task_{i}")
856+
# -- Per-task worker function (can run in parallel) --
857+
def process_one_task(task_tuple):
858+
idx, task = task_tuple
859+
instance_id = task.get("instance_id", f"task_{idx}")
860860
repo_url = task.get("repo_url", "")
861861
if not repo_url:
862-
# Construct from repo field (org/repo -> full URL)
863862
repo_slug = task.get("repo", "")
864863
if repo_slug and "/" in repo_slug:
865864
repo_url = f"https://github.com/{repo_slug}"
866865
commit = task.get("base_commit", task.get("commit", "HEAD"))
867866

868-
log.info("[%d/%d] %s", i + 1, len(tasks), instance_id)
869-
870867
if not repo_url:
871-
# Try to reconstruct from instance_id
872868
parts = instance_id.rsplit("-", 1)
873869
org_repo = parts[0].replace("__", "/") if parts else ""
874870
repo_url = f"https://github.com/{org_repo}" if org_repo else ""
875871

876872
if not repo_url:
877-
log.warning(" No repo URL, skipping")
878-
continue
873+
log.warning("[%d] No repo URL, skipping %s", idx + 1, instance_id)
874+
return None
875+
876+
log.info("[%d/%d] %s", idx + 1, len(tasks), instance_id)
879877

880-
# Clone repo
881878
repo_path = clone_for_contextbench(repo_url, commit)
882879
if not repo_path:
883-
log.warning(" Clone failed, skipping")
884-
continue
880+
log.warning("[%d] Clone failed, skipping %s", idx + 1, instance_id)
881+
return None
885882

886-
# Run agent
887883
try:
888884
result = run_retrieval_agent_on_cb_task(
889885
task, repo_path, client,
@@ -892,24 +888,75 @@ def main() -> int:
892888
use_cli=use_cli,
893889
)
894890
except Exception as e:
895-
log.error(" Agent failed: %s", e)
896-
continue
891+
log.error("[%d] Agent failed for %s: %s", idx + 1, instance_id, e)
892+
return None
893+
894+
# Optional pruning pass
895+
if args.prune:
896+
from context_retrieval_agent import prune_oracle_cli
897+
ctx_for_prune = {
898+
"seed_prompt": task.get("problem_statement", ""),
899+
"instruction": task.get("problem_statement", ""),
900+
}
901+
result["oracle"] = prune_oracle_cli(
902+
result["oracle"], ctx_for_prune, verbose=args.verbose,
903+
)
904+
prune_meta = result["oracle"].get("_prune_metadata", {})
905+
result["metadata"]["prune_cost_usd"] = prune_meta.get("prune_cost_usd", 0)
906+
result["metadata"]["cost_usd"] = (
907+
result["metadata"].get("cost_usd", 0) + prune_meta.get("prune_cost_usd", 0)
908+
)
897909

898-
total_cost += result["metadata"].get("cost_usd", 0)
910+
n_files = len(result["oracle"].get("files", []))
911+
log.info("[%d] %s -> %d files, $%.4f",
912+
idx + 1, instance_id, n_files, result["metadata"]["cost_usd"])
899913

900-
# Convert to trajectory
901914
traj = convert_to_trajectory(
902915
instance_id, result["oracle"],
903916
model_patch=task.get("patch", ""),
904917
)
905-
trajectories.append(traj)
906-
evaluated_tasks.append(task)
918+
return {"task": task, "traj": traj, "result": result}
907919

908-
n_files = len(result["oracle"].get("files", []))
909-
log.info(
910-
" -> %d files, $%.4f",
911-
n_files, result["metadata"]["cost_usd"],
912-
)
920+
# -- Apply limits --
921+
run_tasks = tasks
922+
if args.max_tasks > 0:
923+
run_tasks = tasks[:args.max_tasks]
924+
925+
# -- Execute tasks (parallel or sequential) --
926+
total_cost = 0.0
927+
trajectories = []
928+
evaluated_tasks = []
929+
930+
task_tuples = list(enumerate(run_tasks))
931+
n_parallel = max(1, args.parallel)
932+
933+
if n_parallel > 1 and len(task_tuples) > 1:
934+
log.info("Running %d tasks with %d workers", len(task_tuples), n_parallel)
935+
with concurrent.futures.ThreadPoolExecutor(max_workers=n_parallel) as executor:
936+
futures = {executor.submit(process_one_task, t): t for t in task_tuples}
937+
for future in concurrent.futures.as_completed(futures):
938+
outcome = future.result()
939+
if outcome is None:
940+
continue
941+
total_cost += outcome["result"]["metadata"].get("cost_usd", 0)
942+
if args.max_cost > 0 and total_cost >= args.max_cost:
943+
log.warning("Cost limit reached ($%.2f), cancelling remaining", total_cost)
944+
for f in futures:
945+
f.cancel()
946+
break
947+
trajectories.append(outcome["traj"])
948+
evaluated_tasks.append(outcome["task"])
949+
else:
950+
for tt in task_tuples:
951+
if args.max_cost > 0 and total_cost >= args.max_cost:
952+
log.warning("Cost limit reached ($%.2f)", total_cost)
953+
break
954+
outcome = process_one_task(tt)
955+
if outcome is None:
956+
continue
957+
total_cost += outcome["result"]["metadata"].get("cost_usd", 0)
958+
trajectories.append(outcome["traj"])
959+
evaluated_tasks.append(outcome["task"])
913960

914961
if not trajectories:
915962
log.error("No tasks completed")

0 commit comments

Comments
 (0)