From 6519d3f92588c9be4cbb94654a7f586fafbc3e6e Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Mon, 16 Mar 2026 22:08:41 -0400 Subject: [PATCH 1/2] feat: add evaluate_url, lora_checkpoint, validation script, and CLI for GRPO training MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add evaluate_url field to GRPOConfig for separate evaluate endpoint - Add lora_checkpoint field to resume GRPO from existing SFT LoRA adapter - Pass evaluate_url through rollout collector to WAALiveConfig - Load existing LoRA via PeftModel.from_pretrained() when lora_checkpoint set - Update verl_backend.py error message with actionable instructions - Add 5-phase validation script (connectivity → rollout → inference → train → multi-step) - Add CLI entry point (scripts/run_grpo.py) for running GRPO without writing Python Co-Authored-By: Claude Opus 4.6 (1M context) --- openadapt_ml/training/grpo/config.py | 6 + .../training/grpo/rollout_collector.py | 7 +- openadapt_ml/training/grpo/trainer.py | 29 +- openadapt_ml/training/grpo/verl_backend.py | 19 +- scripts/run_grpo.py | 83 +++++ scripts/validate_grpo_waa.py | 312 ++++++++++++++++++ 6 files changed, 444 insertions(+), 12 deletions(-) create mode 100644 scripts/run_grpo.py create mode 100644 scripts/validate_grpo_waa.py diff --git a/openadapt_ml/training/grpo/config.py b/openadapt_ml/training/grpo/config.py index 37ae4e5..f5d349b 100644 --- a/openadapt_ml/training/grpo/config.py +++ b/openadapt_ml/training/grpo/config.py @@ -30,8 +30,12 @@ class GRPOConfig: lora_alpha: LoRA alpha scaling factor. num_rollouts_per_step: Group size N for GRPO advantage computation. max_steps_per_episode: Maximum actions per rollout episode. + lora_checkpoint: Path to an existing LoRA adapter to resume from. + If set, loads the adapter via PeftModel.from_pretrained() instead + of creating a fresh LoRA. Useful for GRPO on top of an SFT LoRA. temperature: Sampling temperature for action generation during rollouts. server_url: URL of the WAA server for live environment interaction. + evaluate_url: URL of the evaluate server. If None, defaults to server_url. task_ids: List of WAA task IDs to train on. learning_rate: Optimizer learning rate for LoRA parameter updates. num_training_steps: Total number of GRPO training steps (outer loop). @@ -50,6 +54,7 @@ class GRPOConfig: # LoRA lora_r: int = 16 lora_alpha: int = 32 + lora_checkpoint: str | None = None # Path to existing LoRA adapter to resume from # GRPO-specific num_rollouts_per_step: int = 8 # Group size N @@ -58,6 +63,7 @@ class GRPOConfig: # Environment server_url: str = "http://localhost:5001" + evaluate_url: str | None = None # Separate evaluate endpoint; defaults to server_url task_ids: list[str] = field(default_factory=list) screen_size: tuple[int, int] = (1920, 1080) # (width, height) diff --git a/openadapt_ml/training/grpo/rollout_collector.py b/openadapt_ml/training/grpo/rollout_collector.py index aa0a5fc..05bda00 100644 --- a/openadapt_ml/training/grpo/rollout_collector.py +++ b/openadapt_ml/training/grpo/rollout_collector.py @@ -81,7 +81,12 @@ def __init__(self, config: GRPOConfig) -> None: ) self._config = config - self._adapter = WAALiveAdapter(WAALiveConfig(server_url=config.server_url)) + self._adapter = WAALiveAdapter( + WAALiveConfig( + server_url=config.server_url, + evaluate_url=config.evaluate_url, + ) + ) self._env = RLEnvironment(self._adapter) @property diff --git a/openadapt_ml/training/grpo/trainer.py b/openadapt_ml/training/grpo/trainer.py index f22b2e4..cee3fd0 100644 --- a/openadapt_ml/training/grpo/trainer.py +++ b/openadapt_ml/training/grpo/trainer.py @@ -210,10 +210,14 @@ def _format_action_as_text( def _load_model_and_processor(config: GRPOConfig) -> tuple[Any, Any]: """Load a VLM with LoRA using standard HuggingFace + PEFT. + If config.lora_checkpoint is set, loads an existing LoRA adapter via + PeftModel.from_pretrained() instead of creating a fresh one. This + enables GRPO fine-tuning on top of an SFT-trained LoRA. + Returns: (model, processor) tuple. Model has LoRA adapters attached. """ - from peft import LoraConfig, get_peft_model + from peft import LoraConfig, PeftModel, get_peft_model from transformers import AutoModelForVision2Seq, AutoProcessor processor = AutoProcessor.from_pretrained(config.model_name) @@ -233,13 +237,22 @@ def _load_model_and_processor(config: GRPOConfig) -> tuple[Any, Any]: model = AutoModelForVision2Seq.from_pretrained(config.model_name, **load_kwargs) - lora_config = LoraConfig( - r=config.lora_r, - lora_alpha=config.lora_alpha, - target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], - task_type="CAUSAL_LM", - ) - model = get_peft_model(model, lora_config) + if config.lora_checkpoint: + logger.info("Loading existing LoRA from %s", config.lora_checkpoint) + model = PeftModel.from_pretrained( + model, + config.lora_checkpoint, + is_trainable=True, + ) + else: + lora_config = LoraConfig( + r=config.lora_r, + lora_alpha=config.lora_alpha, + target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], + task_type="CAUSAL_LM", + ) + model = get_peft_model(model, lora_config) + model.print_trainable_parameters() return model, processor diff --git a/openadapt_ml/training/grpo/verl_backend.py b/openadapt_ml/training/grpo/verl_backend.py index 300e483..3904a5c 100644 --- a/openadapt_ml/training/grpo/verl_backend.py +++ b/openadapt_ml/training/grpo/verl_backend.py @@ -119,7 +119,20 @@ def train_with_verl(config: GRPOConfig) -> None: logger.info(" python -m vagen.train --config configs/train_waa_vagen.yaml") raise NotImplementedError( - "verl-agent training requires running via VAGEN's training script. " - "See docs/verl_agent_decision.md for setup instructions. " - "Use build_vagen_config() to generate a compatible config dict." + "verl-agent training runs out-of-process via VAGEN's training script, " + "not through this function. Use the E2E orchestration script:\n" + "\n" + " python openadapt-evals/scripts/train_verl_e2e.py \\\n" + " --server-url http://localhost:5000 \\\n" + " --task-ids \\\n" + " --model Qwen/Qwen2.5-VL-7B-Instruct\n" + "\n" + "Or build a VAGEN config from GRPOConfig:\n" + " config_dict = build_vagen_config(config)\n" + "\n" + "See also:\n" + " - openadapt-evals/scripts/train_verl_e2e.py (573-line E2E script)\n" + " - openadapt-evals/configs/train_waa_vagen.yaml (Hydra config)\n" + " - openadapt-evals/scripts/setup_gpu_training.sh (GPU VM setup)\n" + " - docs/verl_agent_decision.md (architecture rationale)" ) diff --git a/scripts/run_grpo.py b/scripts/run_grpo.py new file mode 100644 index 0000000..cb80f4a --- /dev/null +++ b/scripts/run_grpo.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +"""CLI entry point for standalone GRPO training. + +Usage: + python scripts/run_grpo.py \\ + --server-url http://localhost:5001 \\ + --task-ids abc-123 def-456 \\ + --model-name Qwen/Qwen2.5-VL-7B-Instruct \\ + --num-steps 100 \\ + --output-dir checkpoints/grpo_run1 + + # Resume from existing SFT LoRA: + python scripts/run_grpo.py \\ + --server-url http://localhost:5001 \\ + --task-ids abc-123 \\ + --lora-checkpoint checkpoints/sft/step_500 \\ + --num-steps 50 +""" + +from __future__ import annotations + +import argparse +import logging +import sys + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", +) + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Run standalone GRPO training against a WAA VM", + ) + parser.add_argument("--server-url", default="http://localhost:5001") + parser.add_argument("--evaluate-url", default=None) + parser.add_argument("--task-ids", nargs="+", required=True) + parser.add_argument("--model-name", default="Qwen/Qwen2.5-VL-7B-Instruct") + parser.add_argument("--lora-checkpoint", default=None) + parser.add_argument("--load-in-4bit", action="store_true", default=True) + parser.add_argument("--no-4bit", dest="load_in_4bit", action="store_false") + parser.add_argument("--lora-r", type=int, default=16) + parser.add_argument("--lora-alpha", type=int, default=32) + parser.add_argument("--num-rollouts", type=int, default=8) + parser.add_argument("--max-steps-per-episode", type=int, default=15) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--learning-rate", type=float, default=5e-6) + parser.add_argument("--num-steps", type=int, default=1000) + parser.add_argument("--save-every", type=int, default=50) + parser.add_argument("--output-dir", default="checkpoints/grpo") + args = parser.parse_args() + + from openadapt_ml.training.grpo.config import GRPOConfig + from openadapt_ml.training.grpo.trainer import GRPOTrainer + + config = GRPOConfig( + backend="standalone", + model_name=args.model_name, + load_in_4bit=args.load_in_4bit, + lora_r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_checkpoint=args.lora_checkpoint, + num_rollouts_per_step=args.num_rollouts, + max_steps_per_episode=args.max_steps_per_episode, + temperature=args.temperature, + server_url=args.server_url, + evaluate_url=args.evaluate_url, + task_ids=args.task_ids, + learning_rate=args.learning_rate, + num_training_steps=args.num_steps, + save_every_steps=args.save_every, + output_dir=args.output_dir, + ) + + trainer = GRPOTrainer(config) + checkpoint = trainer.train() + print(f"Training complete. Final checkpoint: {checkpoint}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/validate_grpo_waa.py b/scripts/validate_grpo_waa.py new file mode 100644 index 0000000..7357f76 --- /dev/null +++ b/scripts/validate_grpo_waa.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +"""Phased validation of GRPO training against a WAA VM. + +Each phase builds on the previous one, with clear success criteria and +failure diagnostics. Run with --phase N to execute phases 1 through N. + +Phases: + 1. Connectivity: Verify WAA server is reachable (/screenshot, /evaluate) + 2. Single rollout: Reset environment, take one action, get reward + 3. Model inference: Load model, generate an action from a screenshot + 4. Single training step: Collect rollout group, compute loss, backward + 5. Multi-step training: Run 3 full GRPO steps, verify checkpoint saved + +Usage: + python scripts/validate_grpo_waa.py --server-url http://localhost:5001 --phase 3 + python scripts/validate_grpo_waa.py --server-url http://VM_IP:5000 --phase 5 --task-id + python scripts/validate_grpo_waa.py --mock --phase 4 # Use mock adapter (no VM) +""" + +from __future__ import annotations + +import argparse +import logging +import sys +import time + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", +) +logger = logging.getLogger("validate_grpo_waa") + + +def phase1_connectivity(server_url: str, evaluate_url: str | None) -> bool: + """Phase 1: Check WAA server connectivity.""" + import requests + + logger.info("=== Phase 1: Connectivity Check ===") + + # Check screenshot endpoint + try: + r = requests.get(f"{server_url}/screenshot", timeout=10) + if r.status_code == 200 and len(r.content) > 100: + logger.info(" /screenshot OK (%d bytes)", len(r.content)) + else: + logger.error(" /screenshot failed: status=%d, len=%d", r.status_code, len(r.content)) + return False + except Exception as e: + logger.error(" /screenshot unreachable: %s", e) + return False + + # Check evaluate endpoint + eval_base = evaluate_url or server_url + try: + r = requests.get(f"{eval_base}/probe", timeout=10) + logger.info(" /probe status=%d", r.status_code) + except Exception as e: + logger.warning(" /probe unreachable (non-fatal): %s", e) + + logger.info("Phase 1 PASSED") + return True + + +def phase2_single_rollout(server_url: str, evaluate_url: str | None, task_id: str, mock: bool) -> bool: + """Phase 2: Reset env, take one action, get reward.""" + logger.info("=== Phase 2: Single Rollout ===") + + if mock: + from openadapt_evals.adapters.waa.mock import WAAMockAdapter + adapter = WAAMockAdapter() + else: + from openadapt_evals.adapters.waa.live import WAALiveAdapter, WAALiveConfig + adapter = WAALiveAdapter( + WAALiveConfig(server_url=server_url, evaluate_url=evaluate_url) + ) + + from openadapt_evals.adapters.base import BenchmarkAction + from openadapt_evals.adapters.rl_env import RLEnvironment + + env = RLEnvironment(adapter) + + # Reset + obs = env.reset(task_id=task_id) + if obs is None or obs.screenshot is None: + logger.error(" Reset returned no observation or screenshot") + return False + logger.info(" Reset OK, screenshot=%d bytes", len(obs.screenshot)) + + # Take one action + action = BenchmarkAction(type="click", x=500, y=400) + step = env.step(action) + logger.info(" Step OK: reward=%.2f, done=%s", step.reward, step.done) + + # Check screen size + logger.info(" Screen size: %s", env.screen_size) + + if hasattr(adapter, "close"): + adapter.close() + + logger.info("Phase 2 PASSED") + return True + + +def phase3_model_inference(server_url: str, evaluate_url: str | None, model_name: str, task_id: str, mock: bool) -> bool: + """Phase 3: Load model, generate action from screenshot.""" + logger.info("=== Phase 3: Model Inference ===") + + import io + + import torch + from PIL import Image + + from openadapt_ml.training.grpo.config import GRPOConfig + from openadapt_ml.training.grpo.trainer import ( + _build_agent_messages, + _load_model_and_processor, + _parse_vlm_output_to_action, + ) + + config = GRPOConfig( + model_name=model_name, + server_url=server_url, + evaluate_url=evaluate_url, + ) + + logger.info(" Loading model: %s", model_name) + t0 = time.time() + model, processor = _load_model_and_processor(config) + logger.info(" Model loaded in %.1fs", time.time() - t0) + + # Get a screenshot + if mock: + screenshot = Image.new("RGB", (1920, 1080), color=(50, 50, 80)) + else: + import requests + r = requests.get(f"{server_url}/screenshot", timeout=10) + screenshot = Image.open(io.BytesIO(r.content)) + + logger.info(" Screenshot: %s", screenshot.size) + + # Build prompt and generate + messages = _build_agent_messages("Click the Start button") + if hasattr(processor, "apply_chat_template"): + text_input = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + else: + text_input = messages[-1]["content"] + + inputs = processor(text_input, images=[screenshot], return_tensors="pt") + inputs = {k: v.to(model.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.7, do_sample=True) + + decoded = processor.decode( + outputs[0][inputs["input_ids"].shape[1]:], + skip_special_tokens=True, + ) + logger.info(" Model output: %s", decoded.strip()[:200]) + + action = _parse_vlm_output_to_action(decoded) + logger.info(" Parsed action: type=%s", action.type) + + logger.info("Phase 3 PASSED") + return True + + +def phase4_single_training_step( + server_url: str, + evaluate_url: str | None, + model_name: str, + task_id: str, + lora_checkpoint: str | None, + mock: bool, +) -> bool: + """Phase 4: Collect rollout group, compute loss, one gradient step.""" + logger.info("=== Phase 4: Single Training Step ===") + + import tempfile + + from openadapt_ml.training.grpo.config import GRPOConfig + from openadapt_ml.training.grpo.trainer import GRPOTrainer + + with tempfile.TemporaryDirectory() as tmpdir: + config = GRPOConfig( + model_name=model_name, + server_url=server_url, + evaluate_url=evaluate_url, + task_ids=[task_id], + lora_checkpoint=lora_checkpoint, + num_rollouts_per_step=2, # Small group for validation + max_steps_per_episode=3, # Short episodes + num_training_steps=1, + save_every_steps=1, + output_dir=tmpdir, + ) + + trainer = GRPOTrainer(config) + logger.info(" Config: rollouts=%d, max_steps=%d", config.num_rollouts_per_step, config.max_steps_per_episode) + + t0 = time.time() + checkpoint_path = trainer.train() + elapsed = time.time() - t0 + + logger.info(" Training step completed in %.1fs", elapsed) + logger.info(" Checkpoint: %s", checkpoint_path) + + # Verify checkpoint exists + from pathlib import Path + ckpt = Path(checkpoint_path) + if not ckpt.exists(): + logger.error(" Checkpoint directory missing!") + return False + + adapter_files = list(ckpt.glob("adapter_*")) + logger.info(" Checkpoint files: %s", [f.name for f in adapter_files]) + + logger.info("Phase 4 PASSED") + return True + + +def phase5_multi_step_training( + server_url: str, + evaluate_url: str | None, + model_name: str, + task_id: str, + lora_checkpoint: str | None, + mock: bool, +) -> bool: + """Phase 5: Run 3 GRPO steps, verify checkpoints.""" + logger.info("=== Phase 5: Multi-Step Training ===") + + import tempfile + from pathlib import Path + + from openadapt_ml.training.grpo.config import GRPOConfig + from openadapt_ml.training.grpo.trainer import GRPOTrainer + + with tempfile.TemporaryDirectory() as tmpdir: + config = GRPOConfig( + model_name=model_name, + server_url=server_url, + evaluate_url=evaluate_url, + task_ids=[task_id], + lora_checkpoint=lora_checkpoint, + num_rollouts_per_step=2, + max_steps_per_episode=3, + num_training_steps=3, + save_every_steps=1, + output_dir=tmpdir, + ) + + trainer = GRPOTrainer(config) + t0 = time.time() + checkpoint_path = trainer.train() + elapsed = time.time() - t0 + + logger.info(" 3 training steps completed in %.1fs", elapsed) + + # Verify all checkpoints + for step in [1, 2, 3]: + ckpt = Path(tmpdir) / f"step_{step}" + if ckpt.exists(): + logger.info(" step_%d checkpoint: OK", step) + else: + logger.warning(" step_%d checkpoint: MISSING", step) + + logger.info("Phase 5 PASSED") + return True + + +def main() -> int: + parser = argparse.ArgumentParser(description="Validate GRPO training against WAA") + parser.add_argument("--server-url", default="http://localhost:5001") + parser.add_argument("--evaluate-url", default=None) + parser.add_argument("--model-name", default="Qwen/Qwen2.5-VL-3B-Instruct") + parser.add_argument("--task-id", default="notepad_1") + parser.add_argument("--lora-checkpoint", default=None) + parser.add_argument("--phase", type=int, default=5, help="Run phases 1 through N") + parser.add_argument("--mock", action="store_true", help="Use mock adapter (no VM)") + args = parser.parse_args() + + phases = [ + (1, lambda: phase1_connectivity(args.server_url, args.evaluate_url)), + (2, lambda: phase2_single_rollout(args.server_url, args.evaluate_url, args.task_id, args.mock)), + (3, lambda: phase3_model_inference(args.server_url, args.evaluate_url, args.model_name, args.task_id, args.mock)), + (4, lambda: phase4_single_training_step(args.server_url, args.evaluate_url, args.model_name, args.task_id, args.lora_checkpoint, args.mock)), + (5, lambda: phase5_multi_step_training(args.server_url, args.evaluate_url, args.model_name, args.task_id, args.lora_checkpoint, args.mock)), + ] + + # Skip phase 1 when using mock (no server to connect to) + if args.mock: + phases = [(n, fn) for n, fn in phases if n != 1] + + for phase_num, phase_fn in phases: + if phase_num > args.phase: + break + try: + if not phase_fn(): + logger.error("Phase %d FAILED", phase_num) + return 1 + except Exception: + logger.exception("Phase %d raised an exception", phase_num) + return 1 + + logger.info("All phases through %d PASSED", args.phase) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) From 53f80ee61b89d8e1d54df51a8500888b2743508a Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Tue, 17 Mar 2026 13:10:34 -0400 Subject: [PATCH 2/2] style: fix ruff formatting in config and validation script Co-Authored-By: Claude Opus 4.6 (1M context) --- openadapt_ml/training/grpo/config.py | 4 +- scripts/validate_grpo_waa.py | 70 ++++++++++++++++++++++++---- 2 files changed, 63 insertions(+), 11 deletions(-) diff --git a/openadapt_ml/training/grpo/config.py b/openadapt_ml/training/grpo/config.py index f5d349b..2298e5d 100644 --- a/openadapt_ml/training/grpo/config.py +++ b/openadapt_ml/training/grpo/config.py @@ -63,7 +63,9 @@ class GRPOConfig: # Environment server_url: str = "http://localhost:5001" - evaluate_url: str | None = None # Separate evaluate endpoint; defaults to server_url + evaluate_url: str | None = ( + None # Separate evaluate endpoint; defaults to server_url + ) task_ids: list[str] = field(default_factory=list) screen_size: tuple[int, int] = (1920, 1080) # (width, height) diff --git a/scripts/validate_grpo_waa.py b/scripts/validate_grpo_waa.py index 7357f76..770050e 100644 --- a/scripts/validate_grpo_waa.py +++ b/scripts/validate_grpo_waa.py @@ -43,7 +43,9 @@ def phase1_connectivity(server_url: str, evaluate_url: str | None) -> bool: if r.status_code == 200 and len(r.content) > 100: logger.info(" /screenshot OK (%d bytes)", len(r.content)) else: - logger.error(" /screenshot failed: status=%d, len=%d", r.status_code, len(r.content)) + logger.error( + " /screenshot failed: status=%d, len=%d", r.status_code, len(r.content) + ) return False except Exception as e: logger.error(" /screenshot unreachable: %s", e) @@ -61,15 +63,19 @@ def phase1_connectivity(server_url: str, evaluate_url: str | None) -> bool: return True -def phase2_single_rollout(server_url: str, evaluate_url: str | None, task_id: str, mock: bool) -> bool: +def phase2_single_rollout( + server_url: str, evaluate_url: str | None, task_id: str, mock: bool +) -> bool: """Phase 2: Reset env, take one action, get reward.""" logger.info("=== Phase 2: Single Rollout ===") if mock: from openadapt_evals.adapters.waa.mock import WAAMockAdapter + adapter = WAAMockAdapter() else: from openadapt_evals.adapters.waa.live import WAALiveAdapter, WAALiveConfig + adapter = WAALiveAdapter( WAALiveConfig(server_url=server_url, evaluate_url=evaluate_url) ) @@ -101,7 +107,9 @@ def phase2_single_rollout(server_url: str, evaluate_url: str | None, task_id: st return True -def phase3_model_inference(server_url: str, evaluate_url: str | None, model_name: str, task_id: str, mock: bool) -> bool: +def phase3_model_inference( + server_url: str, evaluate_url: str | None, model_name: str, task_id: str, mock: bool +) -> bool: """Phase 3: Load model, generate action from screenshot.""" logger.info("=== Phase 3: Model Inference ===") @@ -133,6 +141,7 @@ def phase3_model_inference(server_url: str, evaluate_url: str | None, model_name screenshot = Image.new("RGB", (1920, 1080), color=(50, 50, 80)) else: import requests + r = requests.get(f"{server_url}/screenshot", timeout=10) screenshot = Image.open(io.BytesIO(r.content)) @@ -151,10 +160,12 @@ def phase3_model_inference(server_url: str, evaluate_url: str | None, model_name inputs = {k: v.to(model.device) for k, v in inputs.items()} with torch.no_grad(): - outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.7, do_sample=True) + outputs = model.generate( + **inputs, max_new_tokens=100, temperature=0.7, do_sample=True + ) decoded = processor.decode( - outputs[0][inputs["input_ids"].shape[1]:], + outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True, ) logger.info(" Model output: %s", decoded.strip()[:200]) @@ -197,7 +208,11 @@ def phase4_single_training_step( ) trainer = GRPOTrainer(config) - logger.info(" Config: rollouts=%d, max_steps=%d", config.num_rollouts_per_step, config.max_steps_per_episode) + logger.info( + " Config: rollouts=%d, max_steps=%d", + config.num_rollouts_per_step, + config.max_steps_per_episode, + ) t0 = time.time() checkpoint_path = trainer.train() @@ -208,6 +223,7 @@ def phase4_single_training_step( # Verify checkpoint exists from pathlib import Path + ckpt = Path(checkpoint_path) if not ckpt.exists(): logger.error(" Checkpoint directory missing!") @@ -283,10 +299,44 @@ def main() -> int: phases = [ (1, lambda: phase1_connectivity(args.server_url, args.evaluate_url)), - (2, lambda: phase2_single_rollout(args.server_url, args.evaluate_url, args.task_id, args.mock)), - (3, lambda: phase3_model_inference(args.server_url, args.evaluate_url, args.model_name, args.task_id, args.mock)), - (4, lambda: phase4_single_training_step(args.server_url, args.evaluate_url, args.model_name, args.task_id, args.lora_checkpoint, args.mock)), - (5, lambda: phase5_multi_step_training(args.server_url, args.evaluate_url, args.model_name, args.task_id, args.lora_checkpoint, args.mock)), + ( + 2, + lambda: phase2_single_rollout( + args.server_url, args.evaluate_url, args.task_id, args.mock + ), + ), + ( + 3, + lambda: phase3_model_inference( + args.server_url, + args.evaluate_url, + args.model_name, + args.task_id, + args.mock, + ), + ), + ( + 4, + lambda: phase4_single_training_step( + args.server_url, + args.evaluate_url, + args.model_name, + args.task_id, + args.lora_checkpoint, + args.mock, + ), + ), + ( + 5, + lambda: phase5_multi_step_training( + args.server_url, + args.evaluate_url, + args.model_name, + args.task_id, + args.lora_checkpoint, + args.mock, + ), + ), ] # Skip phase 1 when using mock (no server to connect to)