From 4a35f2b470aaac6f5ae0bb3b2532e77583a5c4e0 Mon Sep 17 00:00:00 2001
From: Kyle Avery <9327972+kyleavery@users.noreply.github.com>
Date: Mon, 3 Nov 2025 16:22:17 -0600
Subject: [PATCH 1/6] add exclude_think param to MultiTurnEnv

---
 tests/test_multiturn_env.py     | 95 +++++++++++++++++++++++++++++++++
 verifiers/envs/multiturn_env.py | 47 +++++++++++++++-
 2 files changed, 140 insertions(+), 2 deletions(-)

diff --git a/tests/test_multiturn_env.py b/tests/test_multiturn_env.py
index 46834ee8c..67cdbe045 100644
--- a/tests/test_multiturn_env.py
+++ b/tests/test_multiturn_env.py
@@ -459,3 +459,98 @@ async def test_responses_stored_in_state(self, mock_multiturn_env):
         for response in state["responses"]:
             assert hasattr(response, "choices")
             assert len(response.choices) > 0
+
+    @pytest.mark.asyncio
+    async def test_strip_think_string_content_preserves_tail_and_tools(
+        self, mock_openai_client, sample_chat_dataset
+    ):
+        """Ensure only text up to </think> is removed; tool_calls and tool messages remain."""
+        from tests.conftest import SimpleMultiTurnEnv
+
+        env = SimpleMultiTurnEnv(
+            client=mock_openai_client,
+            model="test-model",
+            dataset=sample_chat_dataset,
+            parser=Parser(),
+            rubric=Rubric(),
+            exclude_think=True,
+        )
+
+        prompt = [{"role": "user", "content": "What is 2+2?"}]
+        state = await env.init_state(
+            prompt=prompt,
+            completion=[],
+            answer="",
+            task="default",
+            info={},
+            example_id=0,
+        )
+
+        assistant_msg = {
+            "role": "assistant",
+            "content": "<think>\nprivate reasoning\n</think>\n\nCall tool A",
+            "tool_calls": [
+                {
+                    "id": "id1",
+                    "type": "function",
+                    "function": {"name": "toolA", "arguments": "{}"},
+                }
+            ],
+        }
+        tool_msg = {"role": "tool", "content": "resultA", "tool_call_id": "id1"}
+        state["completion"].extend([assistant_msg, tool_msg])
+
+        ctx = await env.get_context_messages(state)
+        assert isinstance(ctx, list)
+
+        assert ctx[0] == prompt[0]
+        assert ctx[1]["role"] == "assistant"
+        assert ctx[1]["content"] == "Call tool A"
+        assert ctx[1].get("tool_calls") == assistant_msg["tool_calls"]
+        assert ctx[2] == tool_msg
+
+    @pytest.mark.asyncio
+    async def test_no_think_content_is_passthrough(
+        self, mock_openai_client, sample_chat_dataset
+    ):
+        """If no </think> present, assistant content remains unchanged."""
+        from tests.conftest import SimpleMultiTurnEnv
+
+        env = SimpleMultiTurnEnv(
+            client=mock_openai_client,
+            model="test-model",
+            dataset=sample_chat_dataset,
+            parser=Parser(),
+            rubric=Rubric(),
+            exclude_think=True,
+        )
+
+        prompt = [{"role": "user", "content": "Q"}]
+        state = await env.init_state(
+            prompt=prompt,
+            completion=[],
+            answer="",
+            task="default",
+            info={},
+            example_id=0,
+        )
+
+        assistant_msg = {
+            "role": "assistant",
+            "content": "No CoT here, proceed to tool",
+            "tool_calls": [
+                {
+                    "id": "id3",
+                    "type": "function",
+                    "function": {"name": "toolC", "arguments": "{}"},
+                }
+            ],
+        }
+        tool_msg = {"role": "tool", "content": "resultC", "tool_call_id": "id3"}
+        state["completion"].extend([assistant_msg, tool_msg])
+
+        ctx = await env.get_context_messages(state)
+        assert isinstance(ctx, list)
+        assert ctx[1]["content"] == assistant_msg["content"]
+        assert ctx[1].get("tool_calls") == assistant_msg["tool_calls"]
+        assert ctx[2] == tool_msg
diff --git a/verifiers/envs/multiturn_env.py b/verifiers/envs/multiturn_env.py
index 6f78b3b66..76a3a3aa7 100644
--- a/verifiers/envs/multiturn_env.py
+++ b/verifiers/envs/multiturn_env.py
@@ -20,9 +20,15 @@


 class MultiTurnEnv(Environment):
-    def __init__(self, max_turns: int = -1, **kwargs):
+    def __init__(
+        self,
+        max_turns: int = -1,
+        exclude_think: bool = False,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         self.max_turns = max_turns
+        self.exclude_think = exclude_think

     async def prompt_too_long(self, state: State) -> bool:
         return state.get("prompt_too_long", False)
@@ -49,8 +55,45 @@ def env_response(
         """
         pass

+    @staticmethod
+    def _process_assistant_message(msg: ChatMessage) -> ChatMessage:
+        import re
+
+        def _strip_prefix_up_to_close(text: str) -> str:
+            return re.sub(r"(?s)^.*</think>", "", text).lstrip()
+
+        new_msg: ChatMessage = {"role": msg.get("role", "assistant")}
+        if "tool_calls" in msg:
+            new_msg["tool_calls"] = msg["tool_calls"]
+
+        content = msg.get("content")
+        if content is None:
+            new_msg["content"] = ""
+            return new_msg
+
+        if "</think>" in content:
+            new_msg["content"] = _strip_prefix_up_to_close(content)
+        else:
+            new_msg["content"] = content
+
+        return new_msg
+
     async def get_context_messages(self, state: State) -> Messages:
-        return state["prompt"] + state["completion"]
+        if not self.exclude_think:
+            return state["prompt"] + state["completion"]
+
+        prompt_msgs = state["prompt"]
+        completion_msgs = state["completion"]
+
+        processed_completion: list[ChatMessage] = []
+        for m in completion_msgs:
+            role = m.get("role")
+            if role == "assistant":
+                processed_completion.append(self._process_assistant_message(m))
+            else:
+                processed_completion.append(m)
+
+        return prompt_msgs + processed_completion

     async def rollout(
         self,

From 1b12878640b59a18fae4faef6c43b92ff89a8c49 Mon Sep 17 00:00:00 2001
From: Kyle Avery <9327972+kyleavery@users.noreply.github.com>
Date: Tue, 4 Nov 2025 18:26:01 -0600
Subject: [PATCH 2/6] implement stepwise advantages

---
 verifiers/envs/multiturn_env.py   |  68 +++--
 verifiers/rl/README.md            |   3 +
 verifiers/rl/trainer/generator.py | 452 +++++++++++++++++++++++++-----
 verifiers/rl/trainer/trainer.py   |   3 +
 4 files changed, 438 insertions(+), 88 deletions(-)

diff --git a/verifiers/envs/multiturn_env.py b/verifiers/envs/multiturn_env.py
index 76a3a3aa7..507bae201 100644
--- a/verifiers/envs/multiturn_env.py
+++ b/verifiers/envs/multiturn_env.py
@@ -58,13 +58,13 @@ async def env_response(
     @staticmethod
     def _process_assistant_message(msg: ChatMessage) -> ChatMessage:
         import re
+        from copy import deepcopy

         def _strip_prefix_up_to_close(text: str) -> str:
             return re.sub(r"(?s)^.*</think>", "", text).lstrip()

-        new_msg: ChatMessage = {"role": msg.get("role", "assistant")}
-        if "tool_calls" in msg:
-            new_msg["tool_calls"] = msg["tool_calls"]
+        new_msg: ChatMessage = deepcopy(msg)
+        new_msg["role"] = msg.get("role", "assistant")

         content = msg.get("content")
         if content is None:
@@ -102,7 +102,7 @@ async def rollout(
         prompt: Messages,
         completion: Messages | None = None,
         answer: str = "",
-        state: State = {},
+        state: State | None = None,
         task: str = "default",
         info: Info | None = None,
         example_id: int = 0,
@@ -118,6 +118,9 @@
         state = state or await self.init_state(
             prompt, completion, answer, task, info, example_id
         )
+        track_step_scores: bool = bool(kwargs.get("track_step_scores", False))
+        if track_step_scores and "step_scores" not in state:
+            state["step_scores"] = []
         start_time = time.time()
         state = await maybe_await(self.setup_state, state, **kwargs)
         if self.message_type == "chat":
@@ -128,10 +131,13 @@
             assert isinstance(state["completion"], str)
         state["responses_start_idx"] = []
         while not is_completed:
+            # 1.
Build current context and check early termination context_messages = await self.get_context_messages(state) if await maybe_await(self.is_completed, context_messages, state, **kwargs): is_completed = True break + + # 2. Model response for this turn response = await self.get_model_response( client, model, @@ -146,6 +152,8 @@ async def rollout( state["prompt_too_long"] = True break state["responses"].append(response) + + # 2a. Append assistant message to completion response_text: str = "" if self.message_type == "chat": assert isinstance(context_messages, list) @@ -161,31 +169,55 @@ async def rollout( and response.choices[0].message and response.choices[0].message.tool_calls ): - response_message["tool_calls"] = response.choices[ # type: ignore - 0 - ].message.tool_calls + response_message["tool_calls"] = response.choices[0].message.tool_calls # type: ignore state["completion"].append(response_message) else: assert isinstance(response, Completion) - state["responses_start_idx"].append(len(completion)) + # track where this assistant response starts in the running text + state["responses_start_idx"].append(len(state["completion"])) if response.choices and response.choices[0]: response_text = response.choices[0].text or "" state["completion"] += response_text + + # 3) Environment feedback for THIS turn + # Use latest context that includes the assistant message context_messages = await self.get_context_messages(state) + env_msgs, state = await maybe_await( + self.env_response, context_messages, state, **kwargs + ) + if self.message_type == "chat": + assert isinstance(env_msgs, list) + state["completion"] += env_msgs + else: + assert isinstance(env_msgs, str) + state["completion"] += env_msgs + + # 4) Now compute per-turn score after env feedback is appended + if track_step_scores: + try: + rs = await self.rubric.score_rollout( + prompt=state["prompt"], + completion=state["completion"], + answer=state.get("answer", ""), + state=state, + task=state.get("task", "default"), + info=state.get("info", {}), + example_id=state.get("example_id", example_id), + **kwargs, + ) + state.setdefault("step_scores", []).append(float(rs.reward)) + except Exception as e: + logger.error(f"Error computing step score: {e}") + # state.setdefault("step_scores", []).append(0.0) + raise RuntimeError(f"Step score computation failed: {e}") + + # 5) Prepare for next turn state["turn"] += 1 + context_messages = await self.get_context_messages(state) if await maybe_await(self.is_completed, context_messages, state, **kwargs): is_completed = True end_time = time.time() state["timing"]["generation_ms"] = (end_time - start_time) * 1000 state["timing"]["total_ms"] = (end_time - start_time) * 1000 - else: - env_msgs, state = await maybe_await( - self.env_response, context_messages, state, **kwargs - ) - if self.message_type == "chat": - assert isinstance(env_msgs, list) - state["completion"] += env_msgs - else: - assert isinstance(env_msgs, str) - state["completion"] += env_msgs + break return state["completion"], state diff --git a/verifiers/rl/README.md b/verifiers/rl/README.md index 890e2a8f0..1fdea5e5f 100644 --- a/verifiers/rl/README.md +++ b/verifiers/rl/README.md @@ -97,6 +97,9 @@ We have removed a number of features from the previous `GRPOTrainer`, in favor o - `rollouts_per_example`: rollouts per example/prompt (default is `16`) - `max_seq_len`: the maximum sequence length for the training (default is `2048`) - `max_steps`: the maximum number of steps for the training (default is `500`) + - `use_stepwise_advantage`: if 
`True`, each assistant turn becomes its own training sample and advantages are computed from a discounted per-step return (default `False`) + - `stepwise_aggregation`: aggregation for stepwise returns: `"sum"` (default) or `"max"`. + - `stepwise_gamma`: discount factor `gamma` for stepwise returns (default `0.4`). With `stepwise_aggregation="sum"`, `R_t=\sum_{i=t}^{T}{\gamma^{i-t} r_i}`; with `stepwise_aggregation="max"`, `R_t=\max_{i=t..T}{\gamma^{i-t} r_i}`. - Sampling configuration arguments: - `max_tokens`: the maximum number of tokens per request (default is `None`) - `temperature`: the temperature for the sampling (default is `0.7`) diff --git a/verifiers/rl/trainer/generator.py b/verifiers/rl/trainer/generator.py index abde0a407..c6988ef78 100644 --- a/verifiers/rl/trainer/generator.py +++ b/verifiers/rl/trainer/generator.py @@ -1,4 +1,5 @@ import asyncio +import json import logging import queue import threading @@ -13,6 +14,12 @@ from transformers import PreTrainedTokenizerBase from verifiers import Environment +from verifiers.utils.processing_utils import ( + parse_chat_completion_logprobs, + parse_chat_completion_tokens, + parse_completion_logprobs, + parse_completion_tokens, +) class Microbatch(BaseModel): @@ -66,6 +73,9 @@ def __init__( mask_truncated_completions: bool, zero_truncated_completions: bool, max_concurrent: int, + use_stepwise_advantage: bool, + stepwise_gamma: float, + stepwise_aggregation: str, ): self.env = env self.client_base_url = client_base_url @@ -87,6 +97,9 @@ def __init__( self.mask_truncated_completions = mask_truncated_completions self.zero_truncated_completions = zero_truncated_completions self.max_concurrent = max_concurrent + self.use_stepwise_advantage = use_stepwise_advantage + self.stepwise_gamma = float(stepwise_gamma) + self.stepwise_aggregation = stepwise_aggregation # queues for communication self.request_queue = queue.Queue() @@ -225,41 +238,273 @@ async def generate_batch(self, batch_id: int) -> Batch: sampling_args=self.sampling_args, score_rollouts=True, max_concurrent=self.max_concurrent, + track_step_scores=self.use_stepwise_advantage, ) self.is_generating = False wall_clock_s = time.time() - start_time - processed_results = self.env.process_env_results_vllm( - prompts=env_results.prompt, - completions=env_results.completion, - states=env_results.state, - rewards=env_results.reward, - processing_class=self.processing_class, - max_seq_len=self.max_seq_len, - mask_env_responses=self.mask_env_responses, - mask_truncated_completions=self.mask_truncated_completions, - zero_truncated_completions=self.zero_truncated_completions, - ) + if not self.use_stepwise_advantage: + processed_results = self.env.process_env_results_vllm( + prompts=env_results.prompt, + completions=env_results.completion, + states=env_results.state, + rewards=env_results.reward, + processing_class=self.processing_class, + max_seq_len=self.max_seq_len, + mask_env_responses=self.mask_env_responses, + mask_truncated_completions=self.mask_truncated_completions, + zero_truncated_completions=self.zero_truncated_completions, + ) - rewards_dict = {"reward": processed_results.rewards} - for k in env_results.metrics: - rewards_dict[k] = env_results.metrics[k] - - rewards: list[float] = processed_results.rewards - advantages: list[float] = [0.0] * len(rewards) - prompts_in_batch = len(batch_ds) - for prompt_idx in range(prompts_in_batch): - group_indices = [ - prompt_idx + k * prompts_in_batch - for k in range(self.rollouts_per_example) - if (prompt_idx + k * prompts_in_batch) < 
len(rewards) - ] - if not group_indices: - continue - group = [rewards[i] for i in group_indices] - gmean = sum(group) / float(len(group)) - for idx, r in zip(group_indices, group): - advantages[idx] = r - gmean + rewards_dict = {"reward": processed_results.rewards} + for k in env_results.metrics: + rewards_dict[k] = env_results.metrics[k] + + rewards: list[float] = processed_results.rewards + advantages: list[float] = [0.0] * len(rewards) + prompts_in_batch = len(batch_ds) + for prompt_idx in range(prompts_in_batch): + group_indices = [ + prompt_idx + k * prompts_in_batch + for k in range(self.rollouts_per_example) + if (prompt_idx + k * prompts_in_batch) < len(rewards) + ] + if not group_indices: + continue + group = [rewards[i] for i in group_indices] + gmean = sum(group) / float(len(group)) + for idx, r in zip(group_indices, group): + advantages[idx] = r - gmean + else: + # Expand each rollout into per-step training samples and compute + # discounted MC returns per step. + # Reference: https://arxiv.org/abs/2507.11948 + msg_type = getattr(self.env, "message_type", "chat") + all_prompt_ids: list[list[int]] = [] + all_prompt_masks: list[list[int]] = [] + all_completion_ids: list[list[int]] = [] + all_completion_masks: list[list[int]] = [] + all_completion_logprobs: list[list[float]] = [] + all_returns: list[float] = [] + item_meta: list[tuple[int, int]] = [] # (prompt_idx, step_idx) + + # iterate in the same order as env_results arrays + for i, (prompt, completion, state) in enumerate( + zip(env_results.prompt, env_results.completion, env_results.state) + ): + # determine prompt index within this batch + prompt_idx = i % len(batch_ds) + # build per-step tokenization + step_items: list[tuple[list[int], list[int], list[int], list[int], list[float]]] + if msg_type == "chat": + assert isinstance(prompt, list) and isinstance(completion, list) + step_items = [] + # Build context for tokenization similar to process_chat_format_vllm + responses = state["responses"] + responses_idx = 0 + zipped_steps = [] + for turn in completion: + if turn.get("role") == "assistant": + zipped_steps.append((turn, responses[responses_idx])) + responses_idx += 1 + else: + zipped_steps.append((turn, None)) + assert responses_idx == len(responses) + + # utility to deserialize tool_calls for templates that expect JSON args + def _deserialize_tool_calls(message: dict) -> dict: + def _deserialize(tc) -> dict: + tc = dict(tc) + if ( + "function" in tc + and isinstance(tc["function"], dict) + and "arguments" in tc["function"] + ): + args = tc["function"]["arguments"] + if isinstance(args, str): + try: + args = json.loads(args) + except Exception: + pass + tc["function"] = {**tc["function"], "arguments": args} + return tc + + return { + **message, + "tool_calls": [ + _deserialize(tc) + for tc in (message.get("tool_calls", []) or []) + ], + } + + def _maybe_strip_think(msg: dict) -> dict: + if getattr(self.env, "exclude_think", False) and hasattr(self.env, "_process_assistant_message"): + if msg.get("role") == "assistant": + return self.env._process_assistant_message(msg) + return msg + + messages_consumed: list[dict] = [_maybe_strip_think(dict(m)) for m in prompt] + si = 0 + j = 0 + while j < len(zipped_steps): + message, response = zipped_steps[j] + message = _deserialize_tool_calls(message) + if message.get("role") == "assistant": + assert response is not None + prompt_text = self.processing_class.apply_chat_template( + conversation=messages_consumed, + tokenize=False, + add_generation_prompt=True, + ) + prompt_ids = 
self.processing_class.encode(prompt_text) + prompt_mask = [0] * len(prompt_ids) + completion_ids = parse_chat_completion_tokens(response) + completion_mask = [1] * len(completion_ids) + completion_logprobs = parse_chat_completion_logprobs(response) + step_items.append( + ( + prompt_ids, + prompt_mask, + completion_ids, + completion_mask, + completion_logprobs, + ) + ) + messages_consumed.append(_maybe_strip_think(message)) + si += 1 + j += 1 + else: + messages_consumed.append(_maybe_strip_think(message)) + j += 1 + else: + assert isinstance(prompt, str) and isinstance(completion, str) + responses = state.get("responses", []) + starts = state.get("responses_start_idx", []) + assert len(responses) == len(starts) + step_items = [] + for ridx in range(len(responses)): + start_i = starts[ridx] + context_prefix = prompt + completion[:start_i] + prompt_ids = self.processing_class.encode(context_prefix) + prompt_mask = [0] * len(prompt_ids) + resp = responses[ridx] + completion_ids = parse_completion_tokens(resp) + completion_mask = [1] * len(completion_ids) + completion_logprobs = parse_completion_logprobs(resp) + step_items.append( + ( + prompt_ids, + prompt_mask, + completion_ids, + completion_mask, + completion_logprobs, + ) + ) + + # compute immediate rewards per step and MC returns + if msg_type == "chat": + assert isinstance(prompt, list) and isinstance(completion, list) + # Prefer precomputed per-turn scores from rollout if available + step_rewards: list[float] = [] + pre_scores = state.get("step_scores", None) + if isinstance(pre_scores, list) and pre_scores: + step_rewards = [float(x) for x in pre_scores] + else: + raise RuntimeError("Per-turn scores missing in state for stepwise advantage computation") + + returns: list[float] = self._compute_stepwise_returns(step_rewards) + + else: + assert isinstance(prompt, str) and isinstance(completion, str) + responses = state.get("responses", []) + starts = state.get("responses_start_idx", []) + step_rewards = [] + assert len(responses) == len(starts) + pre_scores = state.get("step_scores", None) + if isinstance(pre_scores, list) and pre_scores: + step_rewards = [float(x) for x in pre_scores] + step_rewards = step_rewards[: len(starts)] + else: + for ridx, start in enumerate(starts): + # Include env feedback after this assistant response + end_i = starts[ridx + 1] if (ridx + 1) < len(starts) else len(completion) + partial_text = completion[:end_i] + rs = await self.env.rubric.score_rollout( + prompt=prompt, + completion=partial_text, + answer=state.get("answer", ""), + state=state, + task=state.get("task", "default"), + info=state.get("info", {}), + example_id=state.get("example_id", i), + ) + step_rewards.append(float(rs.reward)) + + returns: list[float] = self._compute_stepwise_returns(step_rewards) + + for step_idx, item in enumerate(step_items): + p_ids, p_mask, c_ids, c_mask, c_logps = item + completion_truncated = False + if self.max_seq_len > 0: + max_c_possible = min(len(c_ids), self.max_seq_len) + keep_p = self.max_seq_len - max_c_possible + if keep_p < len(p_ids): + if keep_p <= 0: + p_ids, p_mask = [], [] + else: + p_ids = p_ids[-keep_p:] + p_mask = p_mask[-keep_p:] + + max_c_len = self.max_seq_len - len(p_ids) + if len(c_ids) > max_c_len: + completion_truncated = True + c_ids = c_ids[:max_c_len] if max_c_len > 0 else [] + c_mask = c_mask[:max_c_len] if max_c_len > 0 else [] + c_logps = c_logps[:max_c_len] if max_c_len > 0 else [] + + effective_c_mask = c_mask + if completion_truncated and self.mask_truncated_completions: + 
effective_c_mask = [0] * len(c_ids) + + ret = float(returns[step_idx]) + if completion_truncated and self.zero_truncated_completions: + ret = 0.0 + + all_prompt_ids.append(p_ids) + all_prompt_masks.append(p_mask) + all_completion_ids.append(c_ids) + all_completion_masks.append(effective_c_mask) + all_completion_logprobs.append(c_logps) + all_returns.append(ret) + item_meta.append((prompt_idx, step_idx)) + + + class _Proc: + def __init__(self): + self.prompt_ids = all_prompt_ids + self.prompt_mask = all_prompt_masks + self.completion_ids = all_completion_ids + self.completion_mask = all_completion_masks + self.completion_logprobs = all_completion_logprobs + self.rewards = all_returns + + # mimic ProcessedOutputs for downstream use + processed_results = _Proc() + rewards_dict = {"reward": all_returns} + + rewards = all_returns + # Compute stepwise group baseline across all m×n samples for each prompt + advantages: list[float] = [0.0] * len(rewards) + prompts_in_batch = len(batch_ds) + for prompt_idx in range(prompts_in_batch): + group = [j for j, (p, _s) in enumerate(item_meta) if p == prompt_idx] + if not group: + continue + group_vals = [rewards[j] for j in group] + gmean = float(np.mean(group_vals)) + gstd = float(np.std(group_vals)) + 1e-8 # prevent div by zero + for j in group: + advantages[j] = (rewards[j] - gmean) / gstd metrics_dict = {} if rewards: @@ -267,6 +512,10 @@ async def generate_batch(self, batch_id: int) -> Batch: metrics_dict["reward"] = float(rewards_arr.mean()) metrics_dict["reward/std"] = float(rewards_arr.std()) + if self.use_stepwise_advantage: + metrics_dict["stepwise/turns_per_rollout"] = float(np.mean([len(s.get("step_scores", [])) for s in env_results.state])) + metrics_dict["stepwise/rollout_length"] = float(np.mean([len(s.get("responses", [])) for s in env_results.state])) + if advantages: adv_arr = np.asarray(advantages, dtype=np.float32) metrics_dict["advantage/absmean"] = float(np.abs(adv_arr).mean()) @@ -315,48 +564,85 @@ async def generate_batch(self, batch_id: int) -> Batch: # build per-process microbatches N = len(processed_results.rewards) - per_proc = N // self.num_processes microbatches: list[list[Microbatch]] = [] items_per_process: list[int] = [] - for proc in range(self.num_processes): - ps = proc * per_proc - pe = ps + per_proc - proc_mbs: list[Microbatch] = [] - proc_item_total = 0 - for s in range(ps, pe, self.micro_batch_size): - e = min(s + self.micro_batch_size, pe) - ids_chunk = [ - processed_results.prompt_ids[i] - + processed_results.completion_ids[i] - for i in range(s, e) - ] - mask_chunk = [ - processed_results.prompt_mask[i] - + processed_results.completion_mask[i] - for i in range(s, e) - ] - slogp_chunk = [ - [0.0] * len(processed_results.prompt_mask[i]) - + processed_results.completion_logprobs[i] - for i in range(s, e) - ] - lengths = [len(mask) for mask in mask_chunk] - adv_chunk = [ - [advantages[i]] * lengths[idx] - for idx, i in enumerate(list(range(s, e))) - ] - mb_items = sum(sum(mask) for mask in mask_chunk) - microbatch = Microbatch( - input_ids=ids_chunk, - loss_mask=mask_chunk, - sampling_logprobs=slogp_chunk, - advantages=adv_chunk, - items=mb_items, - ) - proc_item_total += mb_items - proc_mbs.append(microbatch) - microbatches.append(proc_mbs) - items_per_process.append(proc_item_total) + if not self.use_stepwise_advantage: + per_proc = N // self.num_processes + for proc in range(self.num_processes): + ps = proc * per_proc + pe = ps + per_proc + proc_mbs: list[Microbatch] = [] + proc_item_total = 0 + for s in range(ps, 
pe, self.micro_batch_size): + e = min(s + self.micro_batch_size, pe) + ids_chunk = [ + processed_results.prompt_ids[i] + + processed_results.completion_ids[i] + for i in range(s, e) + ] + mask_chunk = [ + processed_results.prompt_mask[i] + + processed_results.completion_mask[i] + for i in range(s, e) + ] + slogp_chunk = [ + [0.0] * len(processed_results.prompt_mask[i]) + + processed_results.completion_logprobs[i] + for i in range(s, e) + ] + lengths = [len(mask) for mask in mask_chunk] + adv_chunk = [ + [advantages[i]] * lengths[idx] + for idx, i in enumerate(list(range(s, e))) + ] + mb_items = sum(sum(mask) for mask in mask_chunk) + microbatch = Microbatch( + input_ids=ids_chunk, + loss_mask=mask_chunk, + sampling_logprobs=slogp_chunk, + advantages=adv_chunk, + items=mb_items, + ) + proc_item_total += mb_items + proc_mbs.append(microbatch) + microbatches.append(proc_mbs) + items_per_process.append(proc_item_total) + else: + for proc in range(self.num_processes): + indices = list(range(proc, N, self.num_processes)) + proc_mbs: list[Microbatch] = [] + proc_item_total = 0 + for start in range(0, len(indices), self.micro_batch_size): + idxs = indices[start : start + self.micro_batch_size] + ids_chunk = [ + processed_results.prompt_ids[i] + + processed_results.completion_ids[i] + for i in idxs + ] + mask_chunk = [ + processed_results.prompt_mask[i] + + processed_results.completion_mask[i] + for i in idxs + ] + slogp_chunk = [ + [0.0] * len(processed_results.prompt_mask[i]) + + processed_results.completion_logprobs[i] + for i in idxs + ] + lengths = [len(mask) for mask in mask_chunk] + adv_chunk = [[advantages[i]] * lengths[k] for k, i in enumerate(idxs)] + mb_items = sum(sum(mask) for mask in mask_chunk) + microbatch = Microbatch( + input_ids=ids_chunk, + loss_mask=mask_chunk, + sampling_logprobs=slogp_chunk, + advantages=adv_chunk, + items=mb_items, + ) + proc_item_total += mb_items + proc_mbs.append(microbatch) + microbatches.append(proc_mbs) + items_per_process.append(proc_item_total) global_item_count = sum(items_per_process) @@ -371,3 +657,29 @@ async def generate_batch(self, batch_id: int) -> Batch: prompts=env_results.prompt, metrics_dict=metrics_dict, ) + + def _compute_stepwise_returns(self, step_rewards: list[float]) -> list[float]: + if not step_rewards: + return [] + + g = float(self.stepwise_gamma) + if self.stepwise_aggregation == "sum": + # R_t=\sum_{i=t}^{T}{\gamma^{i-t} r_i} + G = 0.0 + out = [0.0] * len(step_rewards) + for t in range(len(step_rewards) - 1, -1, -1): + G = float(step_rewards[t]) + g * G + out[t] = G + return out + + elif self.stepwise_aggregation == "max": + # R_t=\max_{i=t..T}{\gamma^{i-t} r_i} + out = [0.0] * len(step_rewards) + R_next: float | None = None + for t in range(len(step_rewards) - 1, -1, -1): + r = float(step_rewards[t]) + cand_future = (g * R_next) if (R_next is not None) else None + R_t = r if cand_future is None else max(r, cand_future) + out[t] = R_t + R_next = R_t + return out diff --git a/verifiers/rl/trainer/trainer.py b/verifiers/rl/trainer/trainer.py index bbeeef6a2..054685a33 100644 --- a/verifiers/rl/trainer/trainer.py +++ b/verifiers/rl/trainer/trainer.py @@ -112,6 +112,9 @@ def __init__( mask_truncated_completions=args.mask_truncated_completions, zero_truncated_completions=args.zero_truncated_completions, max_concurrent=args.max_concurrent, + use_stepwise_advantage=args.use_stepwise_advantage, + stepwise_gamma=args.stepwise_gamma, + stepwise_aggregation=args.stepwise_aggregation, ) self.generator.start() self.generator.submit_batch(0) 
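The README entry added above describes the discounted per-step return that `_compute_stepwise_returns` implements. For reference, a minimal standalone sketch of that computation follows; the helper name, gamma, and reward values are illustrative and not part of the patch.

```python
# Minimal sketch of the stepwise return computation (illustrative only; the
# actual implementation is Generator._compute_stepwise_returns in this patch).
def stepwise_returns(step_rewards: list[float], gamma: float, aggregation: str = "sum") -> list[float]:
    out = [0.0] * len(step_rewards)
    if aggregation == "sum":
        # R_t = sum_{i=t..T} gamma^(i-t) * r_i, accumulated right to left
        running = 0.0
        for t in range(len(step_rewards) - 1, -1, -1):
            running = step_rewards[t] + gamma * running
            out[t] = running
    else:
        # R_t = max_{i=t..T} gamma^(i-t) * r_i, also computed right to left
        best = None
        for t in range(len(step_rewards) - 1, -1, -1):
            best = step_rewards[t] if best is None else max(step_rewards[t], gamma * best)
            out[t] = best
    return out

# Per-turn rewards of 1.0 then 0.5 with gamma = 0.4:
print(stepwise_returns([1.0, 0.5], 0.4, "sum"))  # [1.2, 0.5]
print(stepwise_returns([1.0, 0.5], 0.4, "max"))  # [1.0, 0.5]
```

With `"max"`, a later reward only propagates backwards when its discounted value exceeds a step's own immediate reward.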
From f98f5d88b28eb471542e746ebf54b04efc8132a2 Mon Sep 17 00:00:00 2001 From: Kyle Avery <9327972+kyleavery@users.noreply.github.com> Date: Wed, 5 Nov 2025 07:42:15 -0600 Subject: [PATCH 3/6] stepwise tests --- tests/test_stepwise_advantages.py | 477 ++++++++++++++++++++++++++++++ 1 file changed, 477 insertions(+) create mode 100644 tests/test_stepwise_advantages.py diff --git a/tests/test_stepwise_advantages.py b/tests/test_stepwise_advantages.py new file mode 100644 index 000000000..3025416b3 --- /dev/null +++ b/tests/test_stepwise_advantages.py @@ -0,0 +1,477 @@ +from pathlib import Path +from typing import Any + +import numpy as np +import pytest +from datasets import Dataset + +from verifiers.types import GenerateMetadata, GenerateOutputs +from verifiers.rl.trainer.generator import Generator +from verifiers.envs.multiturn_env import MultiTurnEnv +from verifiers.rubrics.math_rubric import MathRubric + + +class DummyTokenizer: + def encode(self, text: str) -> list[int]: + # Deterministic tokenization proportional to text length + return list(range(len(text))) + + def apply_chat_template( + self, + conversation: list[dict], + tokenize: bool = False, + add_generation_prompt: bool = True, + ) -> str: + # Simplified chat template: concatenate contents only + return "".join(m.get("content", "") for m in conversation) + + +class StepwiseChatEnv(MultiTurnEnv): + def __init__( + self, + dataset: Dataset, + prepared: list[tuple[list[dict], list[dict], dict[str, Any]]], + ): + rb = MathRubric() + super().__init__( + exclude_think=True, + dataset=dataset, + message_type="chat", + rubric=rb, + parser=rb.parser, + ) + self._prepared = prepared + + async def env_response(self, messages, state, **kwargs): + return [], state + + async def generate(self, *args, **kwargs) -> GenerateOutputs: # type: ignore[override] + prompts: list[Any] = [] + completions: list[Any] = [] + states: list[dict[str, Any]] = [] + answers: list[str] = [] + tasks: list[str] = [] + infos: list[dict[str, Any]] = [] + example_ids: list[int] = [] + + for i, (p, c, s) in enumerate(self._prepared): + s = dict(s) + s.setdefault("timing", {}) + s["timing"].setdefault("generation_ms", 0.0) + s["timing"].setdefault("scoring_ms", 0.0) + s["timing"].setdefault("total_ms", 0.0) + + if "step_scores" not in s: + step_scores: list[float] = [] + j = 0 + while j < len(c): + msg = c[j] + if msg.get("role") == "assistant": + k = j + 1 + while k < len(c) and c[k].get("role") != "assistant": + k += 1 + partial = c[:k] + rs = await self.rubric.score_rollout( + prompt=p, + completion=list(partial), + answer=s.get("answer", ""), + state=s, + task=s.get("task", "default"), + info=s.get("info", {}), + example_id=i, + ) + step_scores.append(float(rs.reward)) + j = k + else: + j += 1 + s["step_scores"] = step_scores + + prompts.append(p) + completions.append(c) + states.append(s) + answers.append(s.get("answer", "")) + tasks.append(s.get("task", "default")) + infos.append(s.get("info", {})) + example_ids.append(i) + + meta = GenerateMetadata( + env_id="stub-chat", + env_args={}, + model="test-model", + base_url="http://localhost/v1", + num_examples=len(prompts), + rollouts_per_example=1, + sampling_args={}, + date="1970-01-01", + time_ms=0.0, + avg_reward=0.0, + avg_metrics={}, + state_columns=[], + path_to_save=Path("/tmp/stub"), + ) + return GenerateOutputs( + prompt=prompts, + completion=completions, + answer=answers, + state=states, + task=tasks, + info=infos, + example_id=example_ids, + reward=[0.0] * len(prompts), + metrics={}, + metadata=meta, + ) + 
+ +def chat_prompt(text: str) -> list[dict]: + return [{"role": "user", "content": text}] + + +def assistant_msg(boxed: str, think: str = "chain-of-thought") -> dict: + return { + "role": "assistant", + "content": f"{think} Visible: \\boxed{{{boxed}}}", + } + + +def user_env_msg(text: str = "ack") -> dict: + return {"role": "user", "content": text} + + +def make_chat_rollout( + prompt: list[dict], + answer: str, + boxed_sequence: list[str], + token_lengths: list[int] | None = None, + step_scores: list[float] | None = None, +) -> tuple[list[dict], list[dict], dict[str, Any]]: + token_lengths = token_lengths or [3] * len(boxed_sequence) + completion: list[dict] = [] + responses = [] + for b, tlen in zip(boxed_sequence, token_lengths): + completion.append(assistant_msg(b)) + completion.append(user_env_msg("ok")) + responses.append({"tokens_len": int(tlen)}) + + state = { + "prompt": prompt, + "completion": completion, + "responses": responses, + "turn": len(boxed_sequence), + "timing": {"total_ms": 0.0}, + "task": "default", + "info": {}, + "answer": answer, + } + if step_scores is not None: + state["step_scores"] = [float(x) for x in step_scores] + return prompt, completion, state + + +def compute_discounted_returns( + rewards: list[float], gamma: float, aggregation: str +) -> list[float]: + if not rewards: + return [] + g = float(gamma) + if aggregation == "sum": + out = [0.0] * len(rewards) + G = 0.0 + for t in range(len(rewards) - 1, -1, -1): + G = float(rewards[t]) + g * G + out[t] = G + return out + else: + out = [0.0] * len(rewards) + R_next: float | None = None + for t in range(len(rewards) - 1, -1, -1): + r = float(rewards[t]) + cand = (g * R_next) if (R_next is not None) else None + R_t = r if cand is None else max(r, cand) + out[t] = R_t + R_next = R_t + return out + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "aggregation,gamma", + [ + ("sum", 1.0), + ("max", 0.5), + ], +) +async def test_stepwise_chat_precomputed_returns_and_metrics_with_advantage_checks( + monkeypatch: pytest.MonkeyPatch, aggregation: str, gamma: float +): + import verifiers.rl.trainer.generator as gen_mod + + monkeypatch.setattr( + gen_mod, + "parse_chat_completion_tokens", + lambda resp: list(range(resp.get("tokens_len", 2))), + ) + monkeypatch.setattr( + gen_mod, + "parse_chat_completion_logprobs", + lambda resp: [-1.0] * resp.get("tokens_len", 2), + ) + + # Two prompts, two rollouts each (total 4) + pA = chat_prompt("Compute 2+2 and provide a boxed answer.") + pB = chat_prompt("Compute 1+1 and provide a boxed answer.") + + # Boxed sequences and precomputed step_scores (immediate rewards) + # A1: wrong then right -> [0,1] + A1 = make_chat_rollout(pA, answer="4", boxed_sequence=["3", "4"], step_scores=[0.0, 1.0]) + # B1: both right -> [1,1] + B1 = make_chat_rollout(pB, answer="2", boxed_sequence=["2", "2"], step_scores=[1.0, 1.0]) + # A2: right then wrong -> [1,0] + A2 = make_chat_rollout(pA, answer="4", boxed_sequence=["4", "3"], step_scores=[1.0, 0.0]) + # B2: wrong then right -> [0,1] + B2 = make_chat_rollout(pB, answer="2", boxed_sequence=["3", "2"], step_scores=[0.0, 1.0]) + + prepared = [A1, B1, A2, B2] + + ds = Dataset.from_dict({"prompt": [pA, pB]}) + env = StepwiseChatEnv(dataset=ds, prepared=prepared) + tokenizer = DummyTokenizer() + + g = Generator( + env=env, + client_base_url="http://localhost/v1", + client_api_key="test", + client_limit=1, + client_timeout=1.0, + model_name="m", + sampling_args={}, + rollouts_per_example=2, + batch_size=4, + micro_batch_size=2, + num_processes=1, + 
generation_timeout=5.0, + processing_class=tokenizer, + mask_env_responses=False, + max_seq_len=4096, + max_prompt_len=4096, + mask_truncated_completions=False, + zero_truncated_completions=False, + max_concurrent=1, + use_stepwise_advantage=True, + stepwise_gamma=gamma, + stepwise_aggregation=aggregation, + ) + + monkeypatch.setattr(env, "a_generate", env.generate) + g.client = object() + + result = await g.generate_batch(batch_id=0) + + all_step_rewards = [s["step_scores"] for _p, _c, s in prepared] + expected = [ + r for rewards in all_step_rewards for r in compute_discounted_returns(rewards, gamma, aggregation) + ] + + assert np.allclose(result.rewards_dict["reward"], expected) + + assert pytest.approx(result.metrics_dict["reward"], rel=1e-6) == float(np.mean(expected)) + assert pytest.approx(result.metrics_dict["stepwise/turns_per_rollout"], rel=1e-6) == 2.0 + assert pytest.approx(result.metrics_dict["stepwise/rollout_length"], rel=1e-6) == 2.0 + + # Advantage sanity checks + # - per-token advantages are constant within each sample + # - lengths of advantages match input_ids and loss_mask + # - advantages are z-scored within each prompt group (mean≈0, std≈1) + # - advantage/absmean > 0 for non-degenerate case + assert result.metrics_dict.get("advantage/absmean", 0.0) > 0.0 + + sample_adv_scalars: list[float] = [] + microbatches = result.microbatches[0] + for mb in microbatches: + for row_ids, row_mask, row_adv in zip(mb.input_ids, mb.loss_mask, mb.advantages): + assert len(row_ids) == len(row_mask) == len(row_adv) + uniq_vals = set(row_adv) + assert len(uniq_vals) == 1 + sample_adv_scalars.append(row_adv[0]) + + # Reconstruct prompt groups: with two steps per rollout and two prompts, + # grouping used by Generator is i % prompts_in_batch; here i == j // steps_per_rollout + steps_per_rollout = 2 + prompts_in_batch = len(ds) + N = len(sample_adv_scalars) + assert N == 4 * steps_per_rollout + + group0 = [sample_adv_scalars[j] for j in range(N) if ((j // steps_per_rollout) % prompts_in_batch) == 0] + group1 = [sample_adv_scalars[j] for j in range(N) if ((j // steps_per_rollout) % prompts_in_batch) == 1] + + assert pytest.approx(float(np.mean(group0)), abs=1e-6) == 0.0 + assert pytest.approx(float(np.mean(group1)), abs=1e-6) == 0.0 + assert pytest.approx(float(np.std(group0)), rel=1e-5) == 1.0 + assert pytest.approx(float(np.std(group1)), rel=1e-5) == 1.0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "aggregation,gamma", + [ + ("sum", 1.0), + ("max", 0.5), + ], +) +async def test_stepwise_chat_returns_precomputed_by_math_rubric( + monkeypatch: pytest.MonkeyPatch, aggregation: str, gamma: float +): + import verifiers.rl.trainer.generator as gen_mod + + monkeypatch.setattr( + gen_mod, + "parse_chat_completion_tokens", + lambda resp: list(range(resp.get("tokens_len", 2))), + ) + monkeypatch.setattr( + gen_mod, + "parse_chat_completion_logprobs", + lambda resp: [-1.0] * resp.get("tokens_len", 2), + ) + + pA = chat_prompt("Compute 2+2 and provide a boxed answer.") + pB = chat_prompt("Compute 1+1 and provide a boxed answer.") + + A1 = make_chat_rollout(pA, answer="4", boxed_sequence=["3", "4"]) + B1 = make_chat_rollout(pB, answer="2", boxed_sequence=["2", "2"]) + A2 = make_chat_rollout(pA, answer="4", boxed_sequence=["4", "3"]) + B2 = make_chat_rollout(pB, answer="2", boxed_sequence=["3", "2"]) + + prepared = [A1, B1, A2, B2] + ds = Dataset.from_dict({"prompt": [pA, pB]}) + env = StepwiseChatEnv(dataset=ds, prepared=prepared) + tokenizer = DummyTokenizer() + + g = Generator( + env=env, 
+ client_base_url="http://localhost/v1", + client_api_key="test", + client_limit=1, + client_timeout=1.0, + model_name="m", + sampling_args={}, + rollouts_per_example=2, + batch_size=4, + micro_batch_size=2, + num_processes=1, + generation_timeout=5.0, + processing_class=tokenizer, + mask_env_responses=False, + max_seq_len=4096, + max_prompt_len=4096, + mask_truncated_completions=False, + zero_truncated_completions=False, + max_concurrent=1, + use_stepwise_advantage=True, + stepwise_gamma=gamma, + stepwise_aggregation=aggregation, + ) + + monkeypatch.setattr(env, "a_generate", env.generate) + g.client = object() + + result = await g.generate_batch(batch_id=0) + + # Expected immediate rewards derived from MathRubric correctness on each step + # A1: [0,1], B1: [1,1], A2: [1,0], B2: [0,1] + step_rewards = [[0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 1.0]] + expected = [ + r for rewards in step_rewards for r in compute_discounted_returns(rewards, gamma, aggregation) + ] + + assert np.allclose(result.rewards_dict["reward"], expected) + assert pytest.approx(result.metrics_dict["reward"], rel=1e-6) == float(np.mean(expected)) + assert pytest.approx(result.metrics_dict["stepwise/rollout_length"], rel=1e-6) == 2.0 + assert pytest.approx(result.metrics_dict["stepwise/turns_per_rollout"], rel=1e-6) == 2.0 + + +@pytest.mark.asyncio +async def test_stepwise_chat_truncation_zero_reward(monkeypatch: pytest.MonkeyPatch): + import verifiers.rl.trainer.generator as gen_mod + + monkeypatch.setattr( + gen_mod, + "parse_chat_completion_tokens", + lambda resp: list(range(resp.get("tokens_len", 2))), + ) + monkeypatch.setattr( + gen_mod, + "parse_chat_completion_logprobs", + lambda resp: [-1.0] * resp.get("tokens_len", 2), + ) + + pA = chat_prompt("Compute 2+2 and provide a boxed answer.") + pB = chat_prompt("Compute 1+1 and provide a boxed answer.") + + # All steps 'correct' → immediate rewards [1,1] per rollout + # Force truncation on first step by assigning large token length + A1 = make_chat_rollout( + pA, + answer="4", + boxed_sequence=["4", "4"], + token_lengths=[50, 2], + step_scores=[1.0, 1.0], + ) + B1 = make_chat_rollout( + pB, answer="2", boxed_sequence=["2", "2"], token_lengths=[2, 2], step_scores=[1.0, 1.0] + ) + A2 = make_chat_rollout( + pA, answer="4", boxed_sequence=["4", "4"], token_lengths=[2, 2], step_scores=[1.0, 1.0] + ) + B2 = make_chat_rollout( + pB, answer="2", boxed_sequence=["2", "2"], token_lengths=[2, 2], step_scores=[1.0, 1.0] + ) + + prepared = [A1, B1, A2, B2] + ds = Dataset.from_dict({"prompt": [pA, pB]}) + env = StepwiseChatEnv(dataset=ds, prepared=prepared) + tokenizer = DummyTokenizer() + + g = Generator( + env=env, + client_base_url="http://localhost/v1", + client_api_key="test", + client_limit=1, + client_timeout=1.0, + model_name="m", + sampling_args={}, + rollouts_per_example=2, + batch_size=4, + micro_batch_size=2, + num_processes=1, + generation_timeout=5.0, + processing_class=tokenizer, + mask_env_responses=False, + max_seq_len=10, + max_prompt_len=4096, + mask_truncated_completions=True, + zero_truncated_completions=True, + max_concurrent=1, + use_stepwise_advantage=True, + stepwise_gamma=1.0, + stepwise_aggregation="sum", + ) + + monkeypatch.setattr(env, "a_generate", env.generate) + g.client = object() + + result = await g.generate_batch(batch_id=0) + + # Without truncation, each rollout [1,1] with gamma=1 → returns [2,1] + expected_naive = [2.0, 1.0] * 4 + rewards = list(result.rewards_dict["reward"]) + + # Only the very first step (A1 step 1) should be truncated → set 
to 0.0 + expected = expected_naive[:] + expected[0] = 0.0 + + assert np.allclose(rewards, expected) + + # masking should be nonzero due to truncation + assert "tokens/masked_fraction" in result.metrics_dict + assert result.metrics_dict["tokens/masked_fraction"] > 0.0 From 8f5f6d0aac996de54619f5b31f9ceb57a238a9c2 Mon Sep 17 00:00:00 2001 From: Kyle Avery <9327972+kyleavery@users.noreply.github.com> Date: Wed, 5 Nov 2025 13:51:49 -0600 Subject: [PATCH 4/6] fix stepwise tests --- tests/test_stepwise_advantages.py | 32 +++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/test_stepwise_advantages.py b/tests/test_stepwise_advantages.py index 3025416b3..7f3155a70 100644 --- a/tests/test_stepwise_advantages.py +++ b/tests/test_stepwise_advantages.py @@ -268,13 +268,19 @@ async def test_stepwise_chat_precomputed_returns_and_metrics_with_advantage_chec result = await g.generate_batch(batch_id=0) all_step_rewards = [s["step_scores"] for _p, _c, s in prepared] - expected = [ + # For logging: one reward per rollout (return at first assistant turn) + expected_log_rewards = [ + (compute_discounted_returns(rewards, gamma, aggregation)[0] if rewards else 0.0) + for rewards in all_step_rewards + ] + # For metrics: mean over per-step returns remains unchanged + expected_all_returns = [ r for rewards in all_step_rewards for r in compute_discounted_returns(rewards, gamma, aggregation) ] - assert np.allclose(result.rewards_dict["reward"], expected) + assert np.allclose(result.rewards_dict["reward"], expected_log_rewards) - assert pytest.approx(result.metrics_dict["reward"], rel=1e-6) == float(np.mean(expected)) + assert pytest.approx(result.metrics_dict["reward"], rel=1e-6) == float(np.mean(expected_all_returns)) assert pytest.approx(result.metrics_dict["stepwise/turns_per_rollout"], rel=1e-6) == 2.0 assert pytest.approx(result.metrics_dict["stepwise/rollout_length"], rel=1e-6) == 2.0 @@ -380,12 +386,15 @@ async def test_stepwise_chat_returns_precomputed_by_math_rubric( # Expected immediate rewards derived from MathRubric correctness on each step # A1: [0,1], B1: [1,1], A2: [1,0], B2: [0,1] step_rewards = [[0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 1.0]] - expected = [ + expected_log = [ + compute_discounted_returns(rewards, gamma, aggregation)[0] for rewards in step_rewards + ] + expected_all = [ r for rewards in step_rewards for r in compute_discounted_returns(rewards, gamma, aggregation) ] - assert np.allclose(result.rewards_dict["reward"], expected) - assert pytest.approx(result.metrics_dict["reward"], rel=1e-6) == float(np.mean(expected)) + assert np.allclose(result.rewards_dict["reward"], expected_log) + assert pytest.approx(result.metrics_dict["reward"], rel=1e-6) == float(np.mean(expected_all)) assert pytest.approx(result.metrics_dict["stepwise/rollout_length"], rel=1e-6) == 2.0 assert pytest.approx(result.metrics_dict["stepwise/turns_per_rollout"], rel=1e-6) == 2.0 @@ -462,15 +471,10 @@ async def test_stepwise_chat_truncation_zero_reward(monkeypatch: pytest.MonkeyPa result = await g.generate_batch(batch_id=0) - # Without truncation, each rollout [1,1] with gamma=1 → returns [2,1] - expected_naive = [2.0, 1.0] * 4 + # With step scores [1,1] and gamma=1, rollout-level reward is 2.0 for each rollout rewards = list(result.rewards_dict["reward"]) - - # Only the very first step (A1 step 1) should be truncated → set to 0.0 - expected = expected_naive[:] - expected[0] = 0.0 - - assert np.allclose(rewards, expected) + expected_log = [2.0, 2.0, 2.0, 2.0] + assert 
np.allclose(rewards, expected_log) # masking should be nonzero due to truncation assert "tokens/masked_fraction" in result.metrics_dict From f2c597fcb7b3d96f5a7220a40079ab2553953d1c Mon Sep 17 00:00:00 2001 From: Kyle Avery <9327972+kyleavery@users.noreply.github.com> Date: Wed, 5 Nov 2025 14:50:00 -0600 Subject: [PATCH 5/6] update config --- verifiers/rl/trainer/config.py | 37 +++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/verifiers/rl/trainer/config.py b/verifiers/rl/trainer/config.py index 6e35c3ba4..f28c8503a 100644 --- a/verifiers/rl/trainer/config.py +++ b/verifiers/rl/trainer/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Optional, Union +from typing import List, Literal, Optional, Union from peft import LoraConfig from transformers import TrainingArguments @@ -127,6 +127,33 @@ class RLConfig(TrainingArguments): default=False, metadata={"help": "Whether to give zero reward to truncated completions."}, ) + + # stepwise returns / advantage configuration + use_stepwise_advantage: bool = field( + default=False, + metadata={ + "help": ( + "If True, treat each assistant turn as its own training sample and use a discounted MC return per step." + ) + }, + ) + stepwise_aggregation: Literal["sum", "max"] = field( + default="sum", + metadata={ + "help": ( + "How to compute discounted per-step return R_t from future rewards." + ) + }, + ) + stepwise_gamma: float = field( + default=0.4, + metadata={ + "help": ( + "Discount factor gamma for stepwise MC return. Must be in [0, 1]." + ) + }, + ) + # sampling_args for generation max_tokens: Optional[int] = field( default=None, @@ -333,3 +360,11 @@ def __post_init__(self): assert self.rollouts_per_example > 1, ( "2 or more rollouts per example are required." ) + + assert 0.0 <= self.stepwise_gamma <= 1.0, ( + "stepwise_gamma must be between 0.0 and 1.0." + ) + + assert self.stepwise_aggregation in ["sum", "max"], ( + "stepwise_aggregation must be either 'sum' or 'max'." 
+ ) From f97729b861b08bc8ae6c0ec5825de91a28ec9db2 Mon Sep 17 00:00:00 2001 From: Kyle Avery <9327972+kyleavery@users.noreply.github.com> Date: Wed, 5 Nov 2025 14:50:26 -0600 Subject: [PATCH 6/6] fix logging --- verifiers/rl/trainer/generator.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/verifiers/rl/trainer/generator.py b/verifiers/rl/trainer/generator.py index c6988ef78..9f445b6a0 100644 --- a/verifiers/rl/trainer/generator.py +++ b/verifiers/rl/trainer/generator.py @@ -287,6 +287,8 @@ async def generate_batch(self, batch_id: int) -> Batch: all_completion_logprobs: list[list[float]] = [] all_returns: list[float] = [] item_meta: list[tuple[int, int]] = [] # (prompt_idx, step_idx) + # rollout-level rewards for logging only (one per rollout) + rollout_rewards_for_logging: list[float] = [] # iterate in the same order as env_results arrays for i, (prompt, completion, state) in enumerate( @@ -413,6 +415,8 @@ def _maybe_strip_think(msg: dict) -> dict: raise RuntimeError("Per-turn scores missing in state for stepwise advantage computation") returns: list[float] = self._compute_stepwise_returns(step_rewards) + # For logging: rollout-level reward is the return at first assistant turn + rollout_rewards_for_logging.append(float(returns[0]) if returns else 0.0) else: assert isinstance(prompt, str) and isinstance(completion, str) @@ -441,6 +445,8 @@ def _maybe_strip_think(msg: dict) -> dict: step_rewards.append(float(rs.reward)) returns: list[float] = self._compute_stepwise_returns(step_rewards) + # For logging: rollout-level reward is the return at first assistant turn + rollout_rewards_for_logging.append(float(returns[0]) if returns else 0.0) for step_idx, item in enumerate(step_items): p_ids, p_mask, c_ids, c_mask, c_logps = item @@ -490,7 +496,10 @@ def __init__(self): # mimic ProcessedOutputs for downstream use processed_results = _Proc() - rewards_dict = {"reward": all_returns} + # For logging, include one reward per rollout to align with prompts/completions + rewards_dict = {"reward": rollout_rewards_for_logging} + for k in env_results.metrics: + rewards_dict[k] = env_results.metrics[k] rewards = all_returns # Compute stepwise group baseline across all m×n samples for each prompt