Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions openadapt_ml/training/grpo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ class GRPOConfig:
lora_alpha: LoRA alpha scaling factor.
num_rollouts_per_step: Group size N for GRPO advantage computation.
max_steps_per_episode: Maximum actions per rollout episode.
lora_checkpoint: Path to an existing LoRA adapter to resume from.
If set, loads the adapter via PeftModel.from_pretrained() instead
of creating a fresh LoRA. Useful for GRPO on top of an SFT LoRA.
temperature: Sampling temperature for action generation during rollouts.
server_url: URL of the WAA server for live environment interaction.
evaluate_url: URL of the evaluate server. If None, defaults to server_url.
task_ids: List of WAA task IDs to train on.
learning_rate: Optimizer learning rate for LoRA parameter updates.
num_training_steps: Total number of GRPO training steps (outer loop).
Expand All @@ -50,6 +54,7 @@ class GRPOConfig:
# LoRA
lora_r: int = 16
lora_alpha: int = 32
lora_checkpoint: str | None = None # Path to existing LoRA adapter to resume from

# GRPO-specific
num_rollouts_per_step: int = 8 # Group size N
Expand All @@ -58,6 +63,9 @@ class GRPOConfig:

# Environment
server_url: str = "http://localhost:5001"
evaluate_url: str | None = (
None # Separate evaluate endpoint; defaults to server_url
)
task_ids: list[str] = field(default_factory=list)
screen_size: tuple[int, int] = (1920, 1080) # (width, height)

Expand Down
7 changes: 6 additions & 1 deletion openadapt_ml/training/grpo/rollout_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@ def __init__(self, config: GRPOConfig) -> None:
)

self._config = config
self._adapter = WAALiveAdapter(WAALiveConfig(server_url=config.server_url))
self._adapter = WAALiveAdapter(
WAALiveConfig(
server_url=config.server_url,
evaluate_url=config.evaluate_url,
)
)
self._env = RLEnvironment(self._adapter)

@property
Expand Down
29 changes: 21 additions & 8 deletions openadapt_ml/training/grpo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,14 @@ def _format_action_as_text(
def _load_model_and_processor(config: GRPOConfig) -> tuple[Any, Any]:
"""Load a VLM with LoRA using standard HuggingFace + PEFT.

If config.lora_checkpoint is set, loads an existing LoRA adapter via
PeftModel.from_pretrained() instead of creating a fresh one. This
enables GRPO fine-tuning on top of an SFT-trained LoRA.

Returns:
(model, processor) tuple. Model has LoRA adapters attached.
"""
from peft import LoraConfig, get_peft_model
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForVision2Seq, AutoProcessor

processor = AutoProcessor.from_pretrained(config.model_name)
Expand All @@ -233,13 +237,22 @@ def _load_model_and_processor(config: GRPOConfig) -> tuple[Any, Any]:

model = AutoModelForVision2Seq.from_pretrained(config.model_name, **load_kwargs)

lora_config = LoraConfig(
r=config.lora_r,
lora_alpha=config.lora_alpha,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
if config.lora_checkpoint:
logger.info("Loading existing LoRA from %s", config.lora_checkpoint)
model = PeftModel.from_pretrained(
model,
config.lora_checkpoint,
is_trainable=True,
)
else:
lora_config = LoraConfig(
r=config.lora_r,
lora_alpha=config.lora_alpha,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

model.print_trainable_parameters()

return model, processor
Expand Down
19 changes: 16 additions & 3 deletions openadapt_ml/training/grpo/verl_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,20 @@ def train_with_verl(config: GRPOConfig) -> None:
logger.info(" python -m vagen.train --config configs/train_waa_vagen.yaml")

raise NotImplementedError(
"verl-agent training requires running via VAGEN's training script. "
"See docs/verl_agent_decision.md for setup instructions. "
"Use build_vagen_config() to generate a compatible config dict."
"verl-agent training runs out-of-process via VAGEN's training script, "
"not through this function. Use the E2E orchestration script:\n"
"\n"
" python openadapt-evals/scripts/train_verl_e2e.py \\\n"
" --server-url http://localhost:5000 \\\n"
" --task-ids <TASK_ID> \\\n"
" --model Qwen/Qwen2.5-VL-7B-Instruct\n"
"\n"
"Or build a VAGEN config from GRPOConfig:\n"
" config_dict = build_vagen_config(config)\n"
"\n"
"See also:\n"
" - openadapt-evals/scripts/train_verl_e2e.py (573-line E2E script)\n"
" - openadapt-evals/configs/train_waa_vagen.yaml (Hydra config)\n"
" - openadapt-evals/scripts/setup_gpu_training.sh (GPU VM setup)\n"
" - docs/verl_agent_decision.md (architecture rationale)"
)
83 changes: 83 additions & 0 deletions scripts/run_grpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#!/usr/bin/env python3
"""CLI entry point for standalone GRPO training.

Builds a ``GRPOConfig`` from command-line flags and runs ``GRPOTrainer``
against a live WAA server, optionally resuming from an existing
SFT-trained LoRA adapter via ``--lora-checkpoint``.

Usage:
    python scripts/run_grpo.py \\
        --server-url http://localhost:5001 \\
        --task-ids abc-123 def-456 \\
        --model-name Qwen/Qwen2.5-VL-7B-Instruct \\
        --num-steps 100 \\
        --output-dir checkpoints/grpo_run1

    # Resume from existing SFT LoRA:
    python scripts/run_grpo.py \\
        --server-url http://localhost:5001 \\
        --task-ids abc-123 \\
        --lora-checkpoint checkpoints/sft/step_500 \\
        --num-steps 50
"""

from __future__ import annotations

import argparse
import logging
import sys

# Configure root logging at import time, before main() imports the
# training modules, so log records from their module-level loggers
# propagate to a configured root handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)


def main() -> int:
    """Parse CLI flags, assemble a GRPOConfig, and run GRPO training.

    Returns:
        Process exit code: 0 on successful completion.
    """
    parser = argparse.ArgumentParser(
        description="Run standalone GRPO training against a WAA VM",
    )

    # Environment / task selection.
    parser.add_argument("--server-url", default="http://localhost:5001")
    parser.add_argument("--evaluate-url", default=None)
    parser.add_argument("--task-ids", nargs="+", required=True)

    # Model and LoRA settings. 4-bit loading is on by default; --no-4bit
    # disables it (both flags share the same dest).
    parser.add_argument("--model-name", default="Qwen/Qwen2.5-VL-7B-Instruct")
    parser.add_argument("--lora-checkpoint", default=None)
    parser.add_argument("--load-in-4bit", action="store_true", default=True)
    parser.add_argument("--no-4bit", dest="load_in_4bit", action="store_false")
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)

    # GRPO rollout and optimization hyperparameters.
    parser.add_argument("--num-rollouts", type=int, default=8)
    parser.add_argument("--max-steps-per-episode", type=int, default=15)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--num-steps", type=int, default=1000)
    parser.add_argument("--save-every", type=int, default=50)
    parser.add_argument("--output-dir", default="checkpoints/grpo")

    opts = parser.parse_args()

    # Import lazily so argument errors surface without paying the cost of
    # loading the training stack.
    from openadapt_ml.training.grpo.config import GRPOConfig
    from openadapt_ml.training.grpo.trainer import GRPOTrainer

    run_config = GRPOConfig(
        backend="standalone",
        model_name=opts.model_name,
        load_in_4bit=opts.load_in_4bit,
        lora_r=opts.lora_r,
        lora_alpha=opts.lora_alpha,
        lora_checkpoint=opts.lora_checkpoint,
        num_rollouts_per_step=opts.num_rollouts,
        max_steps_per_episode=opts.max_steps_per_episode,
        temperature=opts.temperature,
        server_url=opts.server_url,
        evaluate_url=opts.evaluate_url,
        task_ids=opts.task_ids,
        learning_rate=opts.learning_rate,
        num_training_steps=opts.num_steps,
        save_every_steps=opts.save_every,
        output_dir=opts.output_dir,
    )

    final_checkpoint = GRPOTrainer(run_config).train()
    print(f"Training complete. Final checkpoint: {final_checkpoint}")
    return 0


if __name__ == "__main__":
    # SystemExit(code) is equivalent to sys.exit(code) for script exit.
    raise SystemExit(main())
Loading
Loading