From 4a35f2b470aaac6f5ae0bb3b2532e77583a5c4e0 Mon Sep 17 00:00:00 2001
From: Kyle Avery <9327972+kyleavery@users.noreply.github.com>
Date: Mon, 3 Nov 2025 16:22:17 -0600
Subject: [PATCH 1/6] add exclude_think param to MultiTurnEnv
---
tests/test_multiturn_env.py | 95 +++++++++++++++++++++++++++++++++
verifiers/envs/multiturn_env.py | 47 +++++++++++++++-
2 files changed, 140 insertions(+), 2 deletions(-)
diff --git a/tests/test_multiturn_env.py b/tests/test_multiturn_env.py
index 46834ee8c..67cdbe045 100644
--- a/tests/test_multiturn_env.py
+++ b/tests/test_multiturn_env.py
@@ -459,3 +459,98 @@ async def test_responses_stored_in_state(self, mock_multiturn_env):
for response in state["responses"]:
assert hasattr(response, "choices")
assert len(response.choices) > 0
+
+ @pytest.mark.asyncio
+ async def test_strip_think_string_content_preserves_tail_and_tools(
+ self, mock_openai_client, sample_chat_dataset
+ ):
+ """Ensure only text up to is removed; tool_calls and tool messages remain."""
+ from tests.conftest import SimpleMultiTurnEnv
+
+ env = SimpleMultiTurnEnv(
+ client=mock_openai_client,
+ model="test-model",
+ dataset=sample_chat_dataset,
+ parser=Parser(),
+ rubric=Rubric(),
+ exclude_think=True,
+ )
+
+ prompt = [{"role": "user", "content": "What is 2+2?"}]
+ state = await env.init_state(
+ prompt=prompt,
+ completion=[],
+ answer="",
+ task="default",
+ info={},
+ example_id=0,
+ )
+
+ assistant_msg = {
+ "role": "assistant",
+ "content": "\nprivate reasoning\n\nCall tool A",
+ "tool_calls": [
+ {
+ "id": "id1",
+ "type": "function",
+ "function": {"name": "toolA", "arguments": "{}"},
+ }
+ ],
+ }
+ tool_msg = {"role": "tool", "content": "resultA", "tool_call_id": "id1"}
+ state["completion"].extend([assistant_msg, tool_msg])
+
+ ctx = await env.get_context_messages(state)
+ assert isinstance(ctx, list)
+
+ assert ctx[0] == prompt[0]
+ assert ctx[1]["role"] == "assistant"
+ assert ctx[1]["content"] == "Call tool A"
+ assert ctx[1].get("tool_calls") == assistant_msg["tool_calls"]
+ assert ctx[2] == tool_msg
+
+ @pytest.mark.asyncio
+ async def test_no_think_content_is_passthrough(
+ self, mock_openai_client, sample_chat_dataset
+ ):
+ """If no present, assistant content remains unchanged."""
+ from tests.conftest import SimpleMultiTurnEnv
+
+ env = SimpleMultiTurnEnv(
+ client=mock_openai_client,
+ model="test-model",
+ dataset=sample_chat_dataset,
+ parser=Parser(),
+ rubric=Rubric(),
+ exclude_think=True,
+ )
+
+ prompt = [{"role": "user", "content": "Q"}]
+ state = await env.init_state(
+ prompt=prompt,
+ completion=[],
+ answer="",
+ task="default",
+ info={},
+ example_id=0,
+ )
+
+ assistant_msg = {
+ "role": "assistant",
+ "content": "No CoT here, proceed to tool",
+ "tool_calls": [
+ {
+ "id": "id3",
+ "type": "function",
+ "function": {"name": "toolC", "arguments": "{}"},
+ }
+ ],
+ }
+ tool_msg = {"role": "tool", "content": "resultC", "tool_call_id": "id3"}
+ state["completion"].extend([assistant_msg, tool_msg])
+
+ ctx = await env.get_context_messages(state)
+ assert isinstance(ctx, list)
+ assert ctx[1]["content"] == assistant_msg["content"]
+ assert ctx[1].get("tool_calls") == assistant_msg["tool_calls"]
+ assert ctx[2] == tool_msg
diff --git a/verifiers/envs/multiturn_env.py b/verifiers/envs/multiturn_env.py
index 6f78b3b66..76a3a3aa7 100644
--- a/verifiers/envs/multiturn_env.py
+++ b/verifiers/envs/multiturn_env.py
@@ -20,9 +20,15 @@
class MultiTurnEnv(Environment):
- def __init__(self, max_turns: int = -1, **kwargs):
+ def __init__(
+ self,
+ max_turns: int = -1,
+ exclude_think: bool = False,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.max_turns = max_turns
+ self.exclude_think = exclude_think
async def prompt_too_long(self, state: State) -> bool:
return state.get("prompt_too_long", False)
@@ -49,8 +55,45 @@ async def env_response(
"""
pass
+ @staticmethod
+ def _process_assistant_message(msg: ChatMessage) -> ChatMessage:
+ import re
+
+ def _strip_prefix_up_to_close(text: str) -> str:
+ return re.sub(r"(?s)^.*", "", text).lstrip()
+
+ new_msg: ChatMessage = {"role": msg.get("role", "assistant")}
+ if "tool_calls" in msg:
+ new_msg["tool_calls"] = msg["tool_calls"]
+
+ content = msg.get("content")
+ if content is None:
+ new_msg["content"] = ""
+ return new_msg
+
+ if "" in content:
+ new_msg["content"] = _strip_prefix_up_to_close(content)
+ else:
+ new_msg["content"] = content
+
+ return new_msg
+
async def get_context_messages(self, state: State) -> Messages:
- return state["prompt"] + state["completion"]
+ if not self.exclude_think:
+ return state["prompt"] + state["completion"]
+
+ prompt_msgs = state["prompt"]
+ completion_msgs = state["completion"]
+
+ processed_completion: list[ChatMessage] = []
+ for m in completion_msgs:
+ role = m.get("role")
+ if role == "assistant":
+ processed_completion.append(self._process_assistant_message(m))
+ else:
+ processed_completion.append(m)
+
+ return prompt_msgs + processed_completion
async def rollout(
self,
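
A minimal standalone sketch of the stripping behavior this patch adds, for reference. It mirrors the `_strip_prefix_up_to_close` helper inside `_process_assistant_message`; the example messages are hypothetical and only illustrate the intended before/after:

```python
import re

def strip_think_prefix(text: str) -> str:
    # Drop everything up to and including the closing </think> tag,
    # then trim leading whitespace; text without the tag passes through unchanged.
    return re.sub(r"(?s)^.*</think>", "", text).lstrip()

# Hypothetical assistant turn: private reasoning followed by the visible reply.
msg = {"role": "assistant", "content": "<think>\nscratch work\n</think>\nCall tool A"}
print(strip_think_prefix(msg["content"]))  # -> "Call tool A"
print(strip_think_prefix("No CoT here"))   # -> "No CoT here"
```
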
From 1b12878640b59a18fae4faef6c43b92ff89a8c49 Mon Sep 17 00:00:00 2001
From: Kyle Avery <9327972+kyleavery@users.noreply.github.com>
Date: Tue, 4 Nov 2025 18:26:01 -0600
Subject: [PATCH 2/6] implement stepwise advantages
---
verifiers/envs/multiturn_env.py | 68 +++--
verifiers/rl/README.md | 3 +
verifiers/rl/trainer/generator.py | 452 +++++++++++++++++++++++++-----
verifiers/rl/trainer/trainer.py | 3 +
4 files changed, 438 insertions(+), 88 deletions(-)
diff --git a/verifiers/envs/multiturn_env.py b/verifiers/envs/multiturn_env.py
index 76a3a3aa7..507bae201 100644
--- a/verifiers/envs/multiturn_env.py
+++ b/verifiers/envs/multiturn_env.py
@@ -58,13 +58,13 @@ async def env_response(
@staticmethod
def _process_assistant_message(msg: ChatMessage) -> ChatMessage:
import re
+ from copy import deepcopy
def _strip_prefix_up_to_close(text: str) -> str:
             return re.sub(r"(?s)^.*</think>", "", text).lstrip()
- new_msg: ChatMessage = {"role": msg.get("role", "assistant")}
- if "tool_calls" in msg:
- new_msg["tool_calls"] = msg["tool_calls"]
+ new_msg: ChatMessage = deepcopy(msg)
+ new_msg["role"] = msg.get("role", "assistant")
content = msg.get("content")
if content is None:
@@ -102,7 +102,7 @@ async def rollout(
prompt: Messages,
completion: Messages | None = None,
answer: str = "",
- state: State = {},
+ state: State | None = None,
task: str = "default",
info: Info | None = None,
example_id: int = 0,
@@ -118,6 +118,9 @@ async def rollout(
state = state or await self.init_state(
prompt, completion, answer, task, info, example_id
)
+ track_step_scores: bool = bool(kwargs.get("track_step_scores", False))
+ if track_step_scores and "step_scores" not in state:
+ state["step_scores"] = []
start_time = time.time()
state = await maybe_await(self.setup_state, state, **kwargs)
if self.message_type == "chat":
@@ -128,10 +131,13 @@ async def rollout(
assert isinstance(state["completion"], str)
state["responses_start_idx"] = []
while not is_completed:
+ # 1. Build current context and check early termination
context_messages = await self.get_context_messages(state)
if await maybe_await(self.is_completed, context_messages, state, **kwargs):
is_completed = True
break
+
+ # 2. Model response for this turn
response = await self.get_model_response(
client,
model,
@@ -146,6 +152,8 @@ async def rollout(
state["prompt_too_long"] = True
break
state["responses"].append(response)
+
+ # 2a. Append assistant message to completion
response_text: str = ""
if self.message_type == "chat":
assert isinstance(context_messages, list)
@@ -161,31 +169,55 @@ async def rollout(
and response.choices[0].message
and response.choices[0].message.tool_calls
):
- response_message["tool_calls"] = response.choices[ # type: ignore
- 0
- ].message.tool_calls
+ response_message["tool_calls"] = response.choices[0].message.tool_calls # type: ignore
state["completion"].append(response_message)
else:
assert isinstance(response, Completion)
- state["responses_start_idx"].append(len(completion))
+ # track where this assistant response starts in the running text
+ state["responses_start_idx"].append(len(state["completion"]))
if response.choices and response.choices[0]:
response_text = response.choices[0].text or ""
state["completion"] += response_text
+
+            # 3. Environment feedback for this turn
+            # Use the latest context, which includes the new assistant message
context_messages = await self.get_context_messages(state)
+ env_msgs, state = await maybe_await(
+ self.env_response, context_messages, state, **kwargs
+ )
+ if self.message_type == "chat":
+ assert isinstance(env_msgs, list)
+ state["completion"] += env_msgs
+ else:
+ assert isinstance(env_msgs, str)
+ state["completion"] += env_msgs
+
+            # 4. Compute the per-turn score after env feedback is appended
+ if track_step_scores:
+ try:
+ rs = await self.rubric.score_rollout(
+ prompt=state["prompt"],
+ completion=state["completion"],
+ answer=state.get("answer", ""),
+ state=state,
+ task=state.get("task", "default"),
+ info=state.get("info", {}),
+ example_id=state.get("example_id", example_id),
+ **kwargs,
+ )
+ state.setdefault("step_scores", []).append(float(rs.reward))
+ except Exception as e:
+ logger.error(f"Error computing step score: {e}")
+ # state.setdefault("step_scores", []).append(0.0)
+ raise RuntimeError(f"Step score computation failed: {e}")
+
+            # 5. Prepare for the next turn
state["turn"] += 1
+ context_messages = await self.get_context_messages(state)
if await maybe_await(self.is_completed, context_messages, state, **kwargs):
is_completed = True
end_time = time.time()
state["timing"]["generation_ms"] = (end_time - start_time) * 1000
state["timing"]["total_ms"] = (end_time - start_time) * 1000
- else:
- env_msgs, state = await maybe_await(
- self.env_response, context_messages, state, **kwargs
- )
- if self.message_type == "chat":
- assert isinstance(env_msgs, list)
- state["completion"] += env_msgs
- else:
- assert isinstance(env_msgs, str)
- state["completion"] += env_msgs
+ break
return state["completion"], state
diff --git a/verifiers/rl/README.md b/verifiers/rl/README.md
index 890e2a8f0..1fdea5e5f 100644
--- a/verifiers/rl/README.md
+++ b/verifiers/rl/README.md
@@ -97,6 +97,9 @@ We have removed a number of features from the previous `GRPOTrainer`, in favor o
- `rollouts_per_example`: rollouts per example/prompt (default is `16`)
- `max_seq_len`: the maximum sequence length for the training (default is `2048`)
- `max_steps`: the maximum number of steps for the training (default is `500`)
+ - `use_stepwise_advantage`: if `True`, each assistant turn becomes its own training sample and advantages are computed from a discounted per-step return (default `False`)
+ - `stepwise_aggregation`: aggregation for stepwise returns: `"sum"` (default) or `"max"`.
+ - `stepwise_gamma`: discount factor `gamma` for stepwise returns (default `0.4`). With `stepwise_aggregation="sum"`, `R_t = \sum_{i=t}^{T} \gamma^{i-t} r_i`; with `stepwise_aggregation="max"`, `R_t = \max_{t \le i \le T} \gamma^{i-t} r_i`.
- Sampling configuration arguments:
- `max_tokens`: the maximum number of tokens per request (default is `None`)
- `temperature`: the temperature for the sampling (default is `0.7`)
diff --git a/verifiers/rl/trainer/generator.py b/verifiers/rl/trainer/generator.py
index abde0a407..c6988ef78 100644
--- a/verifiers/rl/trainer/generator.py
+++ b/verifiers/rl/trainer/generator.py
@@ -1,4 +1,5 @@
import asyncio
+import json
import logging
import queue
import threading
@@ -13,6 +14,12 @@
from transformers import PreTrainedTokenizerBase
from verifiers import Environment
+from verifiers.utils.processing_utils import (
+ parse_chat_completion_logprobs,
+ parse_chat_completion_tokens,
+ parse_completion_logprobs,
+ parse_completion_tokens,
+)
class Microbatch(BaseModel):
@@ -66,6 +73,9 @@ def __init__(
mask_truncated_completions: bool,
zero_truncated_completions: bool,
max_concurrent: int,
+ use_stepwise_advantage: bool,
+ stepwise_gamma: float,
+ stepwise_aggregation: str,
):
self.env = env
self.client_base_url = client_base_url
@@ -87,6 +97,9 @@ def __init__(
self.mask_truncated_completions = mask_truncated_completions
self.zero_truncated_completions = zero_truncated_completions
self.max_concurrent = max_concurrent
+ self.use_stepwise_advantage = use_stepwise_advantage
+ self.stepwise_gamma = float(stepwise_gamma)
+ self.stepwise_aggregation = stepwise_aggregation
# queues for communication
self.request_queue = queue.Queue()
@@ -225,41 +238,273 @@ async def generate_batch(self, batch_id: int) -> Batch:
sampling_args=self.sampling_args,
score_rollouts=True,
max_concurrent=self.max_concurrent,
+ track_step_scores=self.use_stepwise_advantage,
)
self.is_generating = False
wall_clock_s = time.time() - start_time
- processed_results = self.env.process_env_results_vllm(
- prompts=env_results.prompt,
- completions=env_results.completion,
- states=env_results.state,
- rewards=env_results.reward,
- processing_class=self.processing_class,
- max_seq_len=self.max_seq_len,
- mask_env_responses=self.mask_env_responses,
- mask_truncated_completions=self.mask_truncated_completions,
- zero_truncated_completions=self.zero_truncated_completions,
- )
+ if not self.use_stepwise_advantage:
+ processed_results = self.env.process_env_results_vllm(
+ prompts=env_results.prompt,
+ completions=env_results.completion,
+ states=env_results.state,
+ rewards=env_results.reward,
+ processing_class=self.processing_class,
+ max_seq_len=self.max_seq_len,
+ mask_env_responses=self.mask_env_responses,
+ mask_truncated_completions=self.mask_truncated_completions,
+ zero_truncated_completions=self.zero_truncated_completions,
+ )
- rewards_dict = {"reward": processed_results.rewards}
- for k in env_results.metrics:
- rewards_dict[k] = env_results.metrics[k]
-
- rewards: list[float] = processed_results.rewards
- advantages: list[float] = [0.0] * len(rewards)
- prompts_in_batch = len(batch_ds)
- for prompt_idx in range(prompts_in_batch):
- group_indices = [
- prompt_idx + k * prompts_in_batch
- for k in range(self.rollouts_per_example)
- if (prompt_idx + k * prompts_in_batch) < len(rewards)
- ]
- if not group_indices:
- continue
- group = [rewards[i] for i in group_indices]
- gmean = sum(group) / float(len(group))
- for idx, r in zip(group_indices, group):
- advantages[idx] = r - gmean
+ rewards_dict = {"reward": processed_results.rewards}
+ for k in env_results.metrics:
+ rewards_dict[k] = env_results.metrics[k]
+
+ rewards: list[float] = processed_results.rewards
+ advantages: list[float] = [0.0] * len(rewards)
+ prompts_in_batch = len(batch_ds)
+ for prompt_idx in range(prompts_in_batch):
+ group_indices = [
+ prompt_idx + k * prompts_in_batch
+ for k in range(self.rollouts_per_example)
+ if (prompt_idx + k * prompts_in_batch) < len(rewards)
+ ]
+ if not group_indices:
+ continue
+ group = [rewards[i] for i in group_indices]
+ gmean = sum(group) / float(len(group))
+ for idx, r in zip(group_indices, group):
+ advantages[idx] = r - gmean
+ else:
+ # Expand each rollout into per-step training samples and compute
+ # discounted MC returns per step.
+ # Reference: https://arxiv.org/abs/2507.11948
+ msg_type = getattr(self.env, "message_type", "chat")
+ all_prompt_ids: list[list[int]] = []
+ all_prompt_masks: list[list[int]] = []
+ all_completion_ids: list[list[int]] = []
+ all_completion_masks: list[list[int]] = []
+ all_completion_logprobs: list[list[float]] = []
+ all_returns: list[float] = []
+ item_meta: list[tuple[int, int]] = [] # (prompt_idx, step_idx)
+
+ # iterate in the same order as env_results arrays
+ for i, (prompt, completion, state) in enumerate(
+ zip(env_results.prompt, env_results.completion, env_results.state)
+ ):
+ # determine prompt index within this batch
+ prompt_idx = i % len(batch_ds)
+ # build per-step tokenization
+ step_items: list[tuple[list[int], list[int], list[int], list[int], list[float]]]
+ if msg_type == "chat":
+ assert isinstance(prompt, list) and isinstance(completion, list)
+ step_items = []
+ # Build context for tokenization similar to process_chat_format_vllm
+ responses = state["responses"]
+ responses_idx = 0
+ zipped_steps = []
+ for turn in completion:
+ if turn.get("role") == "assistant":
+ zipped_steps.append((turn, responses[responses_idx]))
+ responses_idx += 1
+ else:
+ zipped_steps.append((turn, None))
+ assert responses_idx == len(responses)
+
+ # utility to deserialize tool_calls for templates that expect JSON args
+ def _deserialize_tool_calls(message: dict) -> dict:
+ def _deserialize(tc) -> dict:
+ tc = dict(tc)
+ if (
+ "function" in tc
+ and isinstance(tc["function"], dict)
+ and "arguments" in tc["function"]
+ ):
+ args = tc["function"]["arguments"]
+ if isinstance(args, str):
+ try:
+ args = json.loads(args)
+ except Exception:
+ pass
+ tc["function"] = {**tc["function"], "arguments": args}
+ return tc
+
+ return {
+ **message,
+ "tool_calls": [
+ _deserialize(tc)
+ for tc in (message.get("tool_calls", []) or [])
+ ],
+ }
+
+ def _maybe_strip_think(msg: dict) -> dict:
+ if getattr(self.env, "exclude_think", False) and hasattr(self.env, "_process_assistant_message"):
+ if msg.get("role") == "assistant":
+ return self.env._process_assistant_message(msg)
+ return msg
+
+ messages_consumed: list[dict] = [_maybe_strip_think(dict(m)) for m in prompt]
+ si = 0
+ j = 0
+ while j < len(zipped_steps):
+ message, response = zipped_steps[j]
+ message = _deserialize_tool_calls(message)
+ if message.get("role") == "assistant":
+ assert response is not None
+ prompt_text = self.processing_class.apply_chat_template(
+ conversation=messages_consumed,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ prompt_ids = self.processing_class.encode(prompt_text)
+ prompt_mask = [0] * len(prompt_ids)
+ completion_ids = parse_chat_completion_tokens(response)
+ completion_mask = [1] * len(completion_ids)
+ completion_logprobs = parse_chat_completion_logprobs(response)
+ step_items.append(
+ (
+ prompt_ids,
+ prompt_mask,
+ completion_ids,
+ completion_mask,
+ completion_logprobs,
+ )
+ )
+ messages_consumed.append(_maybe_strip_think(message))
+ si += 1
+ j += 1
+ else:
+ messages_consumed.append(_maybe_strip_think(message))
+ j += 1
+ else:
+ assert isinstance(prompt, str) and isinstance(completion, str)
+ responses = state.get("responses", [])
+ starts = state.get("responses_start_idx", [])
+ assert len(responses) == len(starts)
+ step_items = []
+ for ridx in range(len(responses)):
+ start_i = starts[ridx]
+ context_prefix = prompt + completion[:start_i]
+ prompt_ids = self.processing_class.encode(context_prefix)
+ prompt_mask = [0] * len(prompt_ids)
+ resp = responses[ridx]
+ completion_ids = parse_completion_tokens(resp)
+ completion_mask = [1] * len(completion_ids)
+ completion_logprobs = parse_completion_logprobs(resp)
+ step_items.append(
+ (
+ prompt_ids,
+ prompt_mask,
+ completion_ids,
+ completion_mask,
+ completion_logprobs,
+ )
+ )
+
+ # compute immediate rewards per step and MC returns
+ if msg_type == "chat":
+ assert isinstance(prompt, list) and isinstance(completion, list)
+ # Prefer precomputed per-turn scores from rollout if available
+ step_rewards: list[float] = []
+ pre_scores = state.get("step_scores", None)
+ if isinstance(pre_scores, list) and pre_scores:
+ step_rewards = [float(x) for x in pre_scores]
+ else:
+ raise RuntimeError("Per-turn scores missing in state for stepwise advantage computation")
+
+ returns: list[float] = self._compute_stepwise_returns(step_rewards)
+
+ else:
+ assert isinstance(prompt, str) and isinstance(completion, str)
+ responses = state.get("responses", [])
+ starts = state.get("responses_start_idx", [])
+ step_rewards = []
+ assert len(responses) == len(starts)
+ pre_scores = state.get("step_scores", None)
+ if isinstance(pre_scores, list) and pre_scores:
+ step_rewards = [float(x) for x in pre_scores]
+ step_rewards = step_rewards[: len(starts)]
+ else:
+ for ridx, start in enumerate(starts):
+ # Include env feedback after this assistant response
+ end_i = starts[ridx + 1] if (ridx + 1) < len(starts) else len(completion)
+ partial_text = completion[:end_i]
+ rs = await self.env.rubric.score_rollout(
+ prompt=prompt,
+ completion=partial_text,
+ answer=state.get("answer", ""),
+ state=state,
+ task=state.get("task", "default"),
+ info=state.get("info", {}),
+ example_id=state.get("example_id", i),
+ )
+ step_rewards.append(float(rs.reward))
+
+ returns: list[float] = self._compute_stepwise_returns(step_rewards)
+
+ for step_idx, item in enumerate(step_items):
+ p_ids, p_mask, c_ids, c_mask, c_logps = item
+ completion_truncated = False
+ if self.max_seq_len > 0:
+ max_c_possible = min(len(c_ids), self.max_seq_len)
+ keep_p = self.max_seq_len - max_c_possible
+ if keep_p < len(p_ids):
+ if keep_p <= 0:
+ p_ids, p_mask = [], []
+ else:
+ p_ids = p_ids[-keep_p:]
+ p_mask = p_mask[-keep_p:]
+
+ max_c_len = self.max_seq_len - len(p_ids)
+ if len(c_ids) > max_c_len:
+ completion_truncated = True
+ c_ids = c_ids[:max_c_len] if max_c_len > 0 else []
+ c_mask = c_mask[:max_c_len] if max_c_len > 0 else []
+ c_logps = c_logps[:max_c_len] if max_c_len > 0 else []
+
+ effective_c_mask = c_mask
+ if completion_truncated and self.mask_truncated_completions:
+ effective_c_mask = [0] * len(c_ids)
+
+ ret = float(returns[step_idx])
+ if completion_truncated and self.zero_truncated_completions:
+ ret = 0.0
+
+ all_prompt_ids.append(p_ids)
+ all_prompt_masks.append(p_mask)
+ all_completion_ids.append(c_ids)
+ all_completion_masks.append(effective_c_mask)
+ all_completion_logprobs.append(c_logps)
+ all_returns.append(ret)
+ item_meta.append((prompt_idx, step_idx))
+
+
+ class _Proc:
+ def __init__(self):
+ self.prompt_ids = all_prompt_ids
+ self.prompt_mask = all_prompt_masks
+ self.completion_ids = all_completion_ids
+ self.completion_mask = all_completion_masks
+ self.completion_logprobs = all_completion_logprobs
+ self.rewards = all_returns
+
+ # mimic ProcessedOutputs for downstream use
+ processed_results = _Proc()
+ rewards_dict = {"reward": all_returns}
+
+ rewards = all_returns
+ # Compute stepwise group baseline across all m×n samples for each prompt
+ advantages: list[float] = [0.0] * len(rewards)
+ prompts_in_batch = len(batch_ds)
+ for prompt_idx in range(prompts_in_batch):
+ group = [j for j, (p, _s) in enumerate(item_meta) if p == prompt_idx]
+ if not group:
+ continue
+ group_vals = [rewards[j] for j in group]
+ gmean = float(np.mean(group_vals))
+ gstd = float(np.std(group_vals)) + 1e-8 # prevent div by zero
+ for j in group:
+ advantages[j] = (rewards[j] - gmean) / gstd
metrics_dict = {}
if rewards:
@@ -267,6 +512,10 @@ async def generate_batch(self, batch_id: int) -> Batch:
metrics_dict["reward"] = float(rewards_arr.mean())
metrics_dict["reward/std"] = float(rewards_arr.std())
+ if self.use_stepwise_advantage:
+ metrics_dict["stepwise/turns_per_rollout"] = float(np.mean([len(s.get("step_scores", [])) for s in env_results.state]))
+ metrics_dict["stepwise/rollout_length"] = float(np.mean([len(s.get("responses", [])) for s in env_results.state]))
+
if advantages:
adv_arr = np.asarray(advantages, dtype=np.float32)
metrics_dict["advantage/absmean"] = float(np.abs(adv_arr).mean())
@@ -315,48 +564,85 @@ async def generate_batch(self, batch_id: int) -> Batch:
# build per-process microbatches
N = len(processed_results.rewards)
- per_proc = N // self.num_processes
microbatches: list[list[Microbatch]] = []
items_per_process: list[int] = []
- for proc in range(self.num_processes):
- ps = proc * per_proc
- pe = ps + per_proc
- proc_mbs: list[Microbatch] = []
- proc_item_total = 0
- for s in range(ps, pe, self.micro_batch_size):
- e = min(s + self.micro_batch_size, pe)
- ids_chunk = [
- processed_results.prompt_ids[i]
- + processed_results.completion_ids[i]
- for i in range(s, e)
- ]
- mask_chunk = [
- processed_results.prompt_mask[i]
- + processed_results.completion_mask[i]
- for i in range(s, e)
- ]
- slogp_chunk = [
- [0.0] * len(processed_results.prompt_mask[i])
- + processed_results.completion_logprobs[i]
- for i in range(s, e)
- ]
- lengths = [len(mask) for mask in mask_chunk]
- adv_chunk = [
- [advantages[i]] * lengths[idx]
- for idx, i in enumerate(list(range(s, e)))
- ]
- mb_items = sum(sum(mask) for mask in mask_chunk)
- microbatch = Microbatch(
- input_ids=ids_chunk,
- loss_mask=mask_chunk,
- sampling_logprobs=slogp_chunk,
- advantages=adv_chunk,
- items=mb_items,
- )
- proc_item_total += mb_items
- proc_mbs.append(microbatch)
- microbatches.append(proc_mbs)
- items_per_process.append(proc_item_total)
+ if not self.use_stepwise_advantage:
+ per_proc = N // self.num_processes
+ for proc in range(self.num_processes):
+ ps = proc * per_proc
+ pe = ps + per_proc
+ proc_mbs: list[Microbatch] = []
+ proc_item_total = 0
+ for s in range(ps, pe, self.micro_batch_size):
+ e = min(s + self.micro_batch_size, pe)
+ ids_chunk = [
+ processed_results.prompt_ids[i]
+ + processed_results.completion_ids[i]
+ for i in range(s, e)
+ ]
+ mask_chunk = [
+ processed_results.prompt_mask[i]
+ + processed_results.completion_mask[i]
+ for i in range(s, e)
+ ]
+ slogp_chunk = [
+ [0.0] * len(processed_results.prompt_mask[i])
+ + processed_results.completion_logprobs[i]
+ for i in range(s, e)
+ ]
+ lengths = [len(mask) for mask in mask_chunk]
+ adv_chunk = [
+ [advantages[i]] * lengths[idx]
+ for idx, i in enumerate(list(range(s, e)))
+ ]
+ mb_items = sum(sum(mask) for mask in mask_chunk)
+ microbatch = Microbatch(
+ input_ids=ids_chunk,
+ loss_mask=mask_chunk,
+ sampling_logprobs=slogp_chunk,
+ advantages=adv_chunk,
+ items=mb_items,
+ )
+ proc_item_total += mb_items
+ proc_mbs.append(microbatch)
+ microbatches.append(proc_mbs)
+ items_per_process.append(proc_item_total)
+ else:
+ for proc in range(self.num_processes):
+ indices = list(range(proc, N, self.num_processes))
+ proc_mbs: list[Microbatch] = []
+ proc_item_total = 0
+ for start in range(0, len(indices), self.micro_batch_size):
+ idxs = indices[start : start + self.micro_batch_size]
+ ids_chunk = [
+ processed_results.prompt_ids[i]
+ + processed_results.completion_ids[i]
+ for i in idxs
+ ]
+ mask_chunk = [
+ processed_results.prompt_mask[i]
+ + processed_results.completion_mask[i]
+ for i in idxs
+ ]
+ slogp_chunk = [
+ [0.0] * len(processed_results.prompt_mask[i])
+ + processed_results.completion_logprobs[i]
+ for i in idxs
+ ]
+ lengths = [len(mask) for mask in mask_chunk]
+ adv_chunk = [[advantages[i]] * lengths[k] for k, i in enumerate(idxs)]
+ mb_items = sum(sum(mask) for mask in mask_chunk)
+ microbatch = Microbatch(
+ input_ids=ids_chunk,
+ loss_mask=mask_chunk,
+ sampling_logprobs=slogp_chunk,
+ advantages=adv_chunk,
+ items=mb_items,
+ )
+ proc_item_total += mb_items
+ proc_mbs.append(microbatch)
+ microbatches.append(proc_mbs)
+ items_per_process.append(proc_item_total)
global_item_count = sum(items_per_process)
@@ -371,3 +657,29 @@ async def generate_batch(self, batch_id: int) -> Batch:
prompts=env_results.prompt,
metrics_dict=metrics_dict,
)
+
+ def _compute_stepwise_returns(self, step_rewards: list[float]) -> list[float]:
+ if not step_rewards:
+ return []
+
+ g = float(self.stepwise_gamma)
+ if self.stepwise_aggregation == "sum":
+ # R_t=\sum_{i=t}^{T}{\gamma^{i-t} r_i}
+ G = 0.0
+ out = [0.0] * len(step_rewards)
+ for t in range(len(step_rewards) - 1, -1, -1):
+ G = float(step_rewards[t]) + g * G
+ out[t] = G
+ return out
+
+ elif self.stepwise_aggregation == "max":
+ # R_t=\max_{i=t..T}{\gamma^{i-t} r_i}
+ out = [0.0] * len(step_rewards)
+ R_next: float | None = None
+ for t in range(len(step_rewards) - 1, -1, -1):
+ r = float(step_rewards[t])
+ cand_future = (g * R_next) if (R_next is not None) else None
+ R_t = r if cand_future is None else max(r, cand_future)
+ out[t] = R_t
+ R_next = R_t
+ return out
diff --git a/verifiers/rl/trainer/trainer.py b/verifiers/rl/trainer/trainer.py
index bbeeef6a2..054685a33 100644
--- a/verifiers/rl/trainer/trainer.py
+++ b/verifiers/rl/trainer/trainer.py
@@ -112,6 +112,9 @@ def __init__(
mask_truncated_completions=args.mask_truncated_completions,
zero_truncated_completions=args.zero_truncated_completions,
max_concurrent=args.max_concurrent,
+ use_stepwise_advantage=args.use_stepwise_advantage,
+ stepwise_gamma=args.stepwise_gamma,
+ stepwise_aggregation=args.stepwise_aggregation,
)
self.generator.start()
self.generator.submit_batch(0)
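
As a sanity check on the return computation this patch introduces, a standalone sketch of the same backward recurrence used by `_compute_stepwise_returns`; the gamma values and reward lists below are illustrative, not defaults taken from the trainer:

```python
def stepwise_returns(step_rewards: list[float], gamma: float, aggregation: str = "sum") -> list[float]:
    # Backward pass over per-turn rewards:
    #   "sum": R_t = r_t + gamma * R_{t+1}
    #   "max": R_t = max(r_t, gamma * R_{t+1})
    out = [0.0] * len(step_rewards)
    nxt: float | None = None
    for t in range(len(step_rewards) - 1, -1, -1):
        r = float(step_rewards[t])
        if aggregation == "sum":
            out[t] = r + (gamma * nxt if nxt is not None else 0.0)
        else:
            out[t] = r if nxt is None else max(r, gamma * nxt)
        nxt = out[t]
    return out

print(stepwise_returns([0.0, 1.0], gamma=1.0))                     # [1.0, 1.0]
print(stepwise_returns([0.0, 1.0], gamma=0.5, aggregation="max"))  # [0.5, 1.0]
```
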
From f98f5d88b28eb471542e746ebf54b04efc8132a2 Mon Sep 17 00:00:00 2001
From: Kyle Avery <9327972+kyleavery@users.noreply.github.com>
Date: Wed, 5 Nov 2025 07:42:15 -0600
Subject: [PATCH 3/6] stepwise tests
---
tests/test_stepwise_advantages.py | 477 ++++++++++++++++++++++++++++++
1 file changed, 477 insertions(+)
create mode 100644 tests/test_stepwise_advantages.py
diff --git a/tests/test_stepwise_advantages.py b/tests/test_stepwise_advantages.py
new file mode 100644
index 000000000..3025416b3
--- /dev/null
+++ b/tests/test_stepwise_advantages.py
@@ -0,0 +1,477 @@
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import pytest
+from datasets import Dataset
+
+from verifiers.types import GenerateMetadata, GenerateOutputs
+from verifiers.rl.trainer.generator import Generator
+from verifiers.envs.multiturn_env import MultiTurnEnv
+from verifiers.rubrics.math_rubric import MathRubric
+
+
+class DummyTokenizer:
+ def encode(self, text: str) -> list[int]:
+ # Deterministic tokenization proportional to text length
+ return list(range(len(text)))
+
+ def apply_chat_template(
+ self,
+ conversation: list[dict],
+ tokenize: bool = False,
+ add_generation_prompt: bool = True,
+ ) -> str:
+ # Simplified chat template: concatenate contents only
+ return "".join(m.get("content", "") for m in conversation)
+
+
+class StepwiseChatEnv(MultiTurnEnv):
+ def __init__(
+ self,
+ dataset: Dataset,
+ prepared: list[tuple[list[dict], list[dict], dict[str, Any]]],
+ ):
+ rb = MathRubric()
+ super().__init__(
+ exclude_think=True,
+ dataset=dataset,
+ message_type="chat",
+ rubric=rb,
+ parser=rb.parser,
+ )
+ self._prepared = prepared
+
+ async def env_response(self, messages, state, **kwargs):
+ return [], state
+
+ async def generate(self, *args, **kwargs) -> GenerateOutputs: # type: ignore[override]
+ prompts: list[Any] = []
+ completions: list[Any] = []
+ states: list[dict[str, Any]] = []
+ answers: list[str] = []
+ tasks: list[str] = []
+ infos: list[dict[str, Any]] = []
+ example_ids: list[int] = []
+
+ for i, (p, c, s) in enumerate(self._prepared):
+ s = dict(s)
+ s.setdefault("timing", {})
+ s["timing"].setdefault("generation_ms", 0.0)
+ s["timing"].setdefault("scoring_ms", 0.0)
+ s["timing"].setdefault("total_ms", 0.0)
+
+ if "step_scores" not in s:
+ step_scores: list[float] = []
+ j = 0
+ while j < len(c):
+ msg = c[j]
+ if msg.get("role") == "assistant":
+ k = j + 1
+ while k < len(c) and c[k].get("role") != "assistant":
+ k += 1
+ partial = c[:k]
+ rs = await self.rubric.score_rollout(
+ prompt=p,
+ completion=list(partial),
+ answer=s.get("answer", ""),
+ state=s,
+ task=s.get("task", "default"),
+ info=s.get("info", {}),
+ example_id=i,
+ )
+ step_scores.append(float(rs.reward))
+ j = k
+ else:
+ j += 1
+ s["step_scores"] = step_scores
+
+ prompts.append(p)
+ completions.append(c)
+ states.append(s)
+ answers.append(s.get("answer", ""))
+ tasks.append(s.get("task", "default"))
+ infos.append(s.get("info", {}))
+ example_ids.append(i)
+
+ meta = GenerateMetadata(
+ env_id="stub-chat",
+ env_args={},
+ model="test-model",
+ base_url="http://localhost/v1",
+ num_examples=len(prompts),
+ rollouts_per_example=1,
+ sampling_args={},
+ date="1970-01-01",
+ time_ms=0.0,
+ avg_reward=0.0,
+ avg_metrics={},
+ state_columns=[],
+ path_to_save=Path("/tmp/stub"),
+ )
+ return GenerateOutputs(
+ prompt=prompts,
+ completion=completions,
+ answer=answers,
+ state=states,
+ task=tasks,
+ info=infos,
+ example_id=example_ids,
+ reward=[0.0] * len(prompts),
+ metrics={},
+ metadata=meta,
+ )
+
+
+def chat_prompt(text: str) -> list[dict]:
+ return [{"role": "user", "content": text}]
+
+
+def assistant_msg(boxed: str, think: str = "chain-of-thought") -> dict:
+ return {
+ "role": "assistant",
+ "content": f"{think} Visible: \\boxed{{{boxed}}}",
+ }
+
+
+def user_env_msg(text: str = "ack") -> dict:
+ return {"role": "user", "content": text}
+
+
+def make_chat_rollout(
+ prompt: list[dict],
+ answer: str,
+ boxed_sequence: list[str],
+ token_lengths: list[int] | None = None,
+ step_scores: list[float] | None = None,
+) -> tuple[list[dict], list[dict], dict[str, Any]]:
+ token_lengths = token_lengths or [3] * len(boxed_sequence)
+ completion: list[dict] = []
+ responses = []
+ for b, tlen in zip(boxed_sequence, token_lengths):
+ completion.append(assistant_msg(b))
+ completion.append(user_env_msg("ok"))
+ responses.append({"tokens_len": int(tlen)})
+
+ state = {
+ "prompt": prompt,
+ "completion": completion,
+ "responses": responses,
+ "turn": len(boxed_sequence),
+ "timing": {"total_ms": 0.0},
+ "task": "default",
+ "info": {},
+ "answer": answer,
+ }
+ if step_scores is not None:
+ state["step_scores"] = [float(x) for x in step_scores]
+ return prompt, completion, state
+
+
+def compute_discounted_returns(
+ rewards: list[float], gamma: float, aggregation: str
+) -> list[float]:
+ if not rewards:
+ return []
+ g = float(gamma)
+ if aggregation == "sum":
+ out = [0.0] * len(rewards)
+ G = 0.0
+ for t in range(len(rewards) - 1, -1, -1):
+ G = float(rewards[t]) + g * G
+ out[t] = G
+ return out
+ else:
+ out = [0.0] * len(rewards)
+ R_next: float | None = None
+ for t in range(len(rewards) - 1, -1, -1):
+ r = float(rewards[t])
+ cand = (g * R_next) if (R_next is not None) else None
+ R_t = r if cand is None else max(r, cand)
+ out[t] = R_t
+ R_next = R_t
+ return out
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "aggregation,gamma",
+ [
+ ("sum", 1.0),
+ ("max", 0.5),
+ ],
+)
+async def test_stepwise_chat_precomputed_returns_and_metrics_with_advantage_checks(
+ monkeypatch: pytest.MonkeyPatch, aggregation: str, gamma: float
+):
+ import verifiers.rl.trainer.generator as gen_mod
+
+ monkeypatch.setattr(
+ gen_mod,
+ "parse_chat_completion_tokens",
+ lambda resp: list(range(resp.get("tokens_len", 2))),
+ )
+ monkeypatch.setattr(
+ gen_mod,
+ "parse_chat_completion_logprobs",
+ lambda resp: [-1.0] * resp.get("tokens_len", 2),
+ )
+
+ # Two prompts, two rollouts each (total 4)
+ pA = chat_prompt("Compute 2+2 and provide a boxed answer.")
+ pB = chat_prompt("Compute 1+1 and provide a boxed answer.")
+
+ # Boxed sequences and precomputed step_scores (immediate rewards)
+ # A1: wrong then right -> [0,1]
+ A1 = make_chat_rollout(pA, answer="4", boxed_sequence=["3", "4"], step_scores=[0.0, 1.0])
+ # B1: both right -> [1,1]
+ B1 = make_chat_rollout(pB, answer="2", boxed_sequence=["2", "2"], step_scores=[1.0, 1.0])
+ # A2: right then wrong -> [1,0]
+ A2 = make_chat_rollout(pA, answer="4", boxed_sequence=["4", "3"], step_scores=[1.0, 0.0])
+ # B2: wrong then right -> [0,1]
+ B2 = make_chat_rollout(pB, answer="2", boxed_sequence=["3", "2"], step_scores=[0.0, 1.0])
+
+ prepared = [A1, B1, A2, B2]
+
+ ds = Dataset.from_dict({"prompt": [pA, pB]})
+ env = StepwiseChatEnv(dataset=ds, prepared=prepared)
+ tokenizer = DummyTokenizer()
+
+ g = Generator(
+ env=env,
+ client_base_url="http://localhost/v1",
+ client_api_key="test",
+ client_limit=1,
+ client_timeout=1.0,
+ model_name="m",
+ sampling_args={},
+ rollouts_per_example=2,
+ batch_size=4,
+ micro_batch_size=2,
+ num_processes=1,
+ generation_timeout=5.0,
+ processing_class=tokenizer,
+ mask_env_responses=False,
+ max_seq_len=4096,
+ max_prompt_len=4096,
+ mask_truncated_completions=False,
+ zero_truncated_completions=False,
+ max_concurrent=1,
+ use_stepwise_advantage=True,
+ stepwise_gamma=gamma,
+ stepwise_aggregation=aggregation,
+ )
+
+ monkeypatch.setattr(env, "a_generate", env.generate)
+ g.client = object()
+
+ result = await g.generate_batch(batch_id=0)
+
+ all_step_rewards = [s["step_scores"] for _p, _c, s in prepared]
+ expected = [
+ r for rewards in all_step_rewards for r in compute_discounted_returns(rewards, gamma, aggregation)
+ ]
+
+ assert np.allclose(result.rewards_dict["reward"], expected)
+
+ assert pytest.approx(result.metrics_dict["reward"], rel=1e-6) == float(np.mean(expected))
+ assert pytest.approx(result.metrics_dict["stepwise/turns_per_rollout"], rel=1e-6) == 2.0
+ assert pytest.approx(result.metrics_dict["stepwise/rollout_length"], rel=1e-6) == 2.0
+
+ # Advantage sanity checks
+ # - per-token advantages are constant within each sample
+ # - lengths of advantages match input_ids and loss_mask
+ # - advantages are z-scored within each prompt group (mean≈0, std≈1)
+ # - advantage/absmean > 0 for non-degenerate case
+ assert result.metrics_dict.get("advantage/absmean", 0.0) > 0.0
+
+ sample_adv_scalars: list[float] = []
+ microbatches = result.microbatches[0]
+ for mb in microbatches:
+ for row_ids, row_mask, row_adv in zip(mb.input_ids, mb.loss_mask, mb.advantages):
+ assert len(row_ids) == len(row_mask) == len(row_adv)
+ uniq_vals = set(row_adv)
+ assert len(uniq_vals) == 1
+ sample_adv_scalars.append(row_adv[0])
+
+ # Reconstruct prompt groups: with two steps per rollout and two prompts,
+ # grouping used by Generator is i % prompts_in_batch; here i == j // steps_per_rollout
+ steps_per_rollout = 2
+ prompts_in_batch = len(ds)
+ N = len(sample_adv_scalars)
+ assert N == 4 * steps_per_rollout
+
+ group0 = [sample_adv_scalars[j] for j in range(N) if ((j // steps_per_rollout) % prompts_in_batch) == 0]
+ group1 = [sample_adv_scalars[j] for j in range(N) if ((j // steps_per_rollout) % prompts_in_batch) == 1]
+
+ assert pytest.approx(float(np.mean(group0)), abs=1e-6) == 0.0
+ assert pytest.approx(float(np.mean(group1)), abs=1e-6) == 0.0
+ assert pytest.approx(float(np.std(group0)), rel=1e-5) == 1.0
+ assert pytest.approx(float(np.std(group1)), rel=1e-5) == 1.0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "aggregation,gamma",
+ [
+ ("sum", 1.0),
+ ("max", 0.5),
+ ],
+)
+async def test_stepwise_chat_returns_precomputed_by_math_rubric(
+ monkeypatch: pytest.MonkeyPatch, aggregation: str, gamma: float
+):
+ import verifiers.rl.trainer.generator as gen_mod
+
+ monkeypatch.setattr(
+ gen_mod,
+ "parse_chat_completion_tokens",
+ lambda resp: list(range(resp.get("tokens_len", 2))),
+ )
+ monkeypatch.setattr(
+ gen_mod,
+ "parse_chat_completion_logprobs",
+ lambda resp: [-1.0] * resp.get("tokens_len", 2),
+ )
+
+ pA = chat_prompt("Compute 2+2 and provide a boxed answer.")
+ pB = chat_prompt("Compute 1+1 and provide a boxed answer.")
+
+ A1 = make_chat_rollout(pA, answer="4", boxed_sequence=["3", "4"])
+ B1 = make_chat_rollout(pB, answer="2", boxed_sequence=["2", "2"])
+ A2 = make_chat_rollout(pA, answer="4", boxed_sequence=["4", "3"])
+ B2 = make_chat_rollout(pB, answer="2", boxed_sequence=["3", "2"])
+
+ prepared = [A1, B1, A2, B2]
+ ds = Dataset.from_dict({"prompt": [pA, pB]})
+ env = StepwiseChatEnv(dataset=ds, prepared=prepared)
+ tokenizer = DummyTokenizer()
+
+ g = Generator(
+ env=env,
+ client_base_url="http://localhost/v1",
+ client_api_key="test",
+ client_limit=1,
+ client_timeout=1.0,
+ model_name="m",
+ sampling_args={},
+ rollouts_per_example=2,
+ batch_size=4,
+ micro_batch_size=2,
+ num_processes=1,
+ generation_timeout=5.0,
+ processing_class=tokenizer,
+ mask_env_responses=False,
+ max_seq_len=4096,
+ max_prompt_len=4096,
+ mask_truncated_completions=False,
+ zero_truncated_completions=False,
+ max_concurrent=1,
+ use_stepwise_advantage=True,
+ stepwise_gamma=gamma,
+ stepwise_aggregation=aggregation,
+ )
+
+ monkeypatch.setattr(env, "a_generate", env.generate)
+ g.client = object()
+
+ result = await g.generate_batch(batch_id=0)
+
+ # Expected immediate rewards derived from MathRubric correctness on each step
+ # A1: [0,1], B1: [1,1], A2: [1,0], B2: [0,1]
+ step_rewards = [[0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
+ expected = [
+ r for rewards in step_rewards for r in compute_discounted_returns(rewards, gamma, aggregation)
+ ]
+
+ assert np.allclose(result.rewards_dict["reward"], expected)
+ assert pytest.approx(result.metrics_dict["reward"], rel=1e-6) == float(np.mean(expected))
+ assert pytest.approx(result.metrics_dict["stepwise/rollout_length"], rel=1e-6) == 2.0
+ assert pytest.approx(result.metrics_dict["stepwise/turns_per_rollout"], rel=1e-6) == 2.0
+
+
+@pytest.mark.asyncio
+async def test_stepwise_chat_truncation_zero_reward(monkeypatch: pytest.MonkeyPatch):
+ import verifiers.rl.trainer.generator as gen_mod
+
+ monkeypatch.setattr(
+ gen_mod,
+ "parse_chat_completion_tokens",
+ lambda resp: list(range(resp.get("tokens_len", 2))),
+ )
+ monkeypatch.setattr(
+ gen_mod,
+ "parse_chat_completion_logprobs",
+ lambda resp: [-1.0] * resp.get("tokens_len", 2),
+ )
+
+ pA = chat_prompt("Compute 2+2 and provide a boxed answer.")
+ pB = chat_prompt("Compute 1+1 and provide a boxed answer.")
+
+ # All steps 'correct' → immediate rewards [1,1] per rollout
+ # Force truncation on first step by assigning large token length
+ A1 = make_chat_rollout(
+ pA,
+ answer="4",
+ boxed_sequence=["4", "4"],
+ token_lengths=[50, 2],
+ step_scores=[1.0, 1.0],
+ )
+ B1 = make_chat_rollout(
+ pB, answer="2", boxed_sequence=["2", "2"], token_lengths=[2, 2], step_scores=[1.0, 1.0]
+ )
+ A2 = make_chat_rollout(
+ pA, answer="4", boxed_sequence=["4", "4"], token_lengths=[2, 2], step_scores=[1.0, 1.0]
+ )
+ B2 = make_chat_rollout(
+ pB, answer="2", boxed_sequence=["2", "2"], token_lengths=[2, 2], step_scores=[1.0, 1.0]
+ )
+
+ prepared = [A1, B1, A2, B2]
+ ds = Dataset.from_dict({"prompt": [pA, pB]})
+ env = StepwiseChatEnv(dataset=ds, prepared=prepared)
+ tokenizer = DummyTokenizer()
+
+ g = Generator(
+ env=env,
+ client_base_url="http://localhost/v1",
+ client_api_key="test",
+ client_limit=1,
+ client_timeout=1.0,
+ model_name="m",
+ sampling_args={},
+ rollouts_per_example=2,
+ batch_size=4,
+ micro_batch_size=2,
+ num_processes=1,
+ generation_timeout=5.0,
+ processing_class=tokenizer,
+ mask_env_responses=False,
+ max_seq_len=10,
+ max_prompt_len=4096,
+ mask_truncated_completions=True,
+ zero_truncated_completions=True,
+ max_concurrent=1,
+ use_stepwise_advantage=True,
+ stepwise_gamma=1.0,
+ stepwise_aggregation="sum",
+ )
+
+ monkeypatch.setattr(env, "a_generate", env.generate)
+ g.client = object()
+
+ result = await g.generate_batch(batch_id=0)
+
+ # Without truncation, each rollout [1,1] with gamma=1 → returns [2,1]
+ expected_naive = [2.0, 1.0] * 4
+ rewards = list(result.rewards_dict["reward"])
+
+ # Only the very first step (A1 step 1) should be truncated → set to 0.0
+ expected = expected_naive[:]
+ expected[0] = 0.0
+
+ assert np.allclose(rewards, expected)
+
+ # masking should be nonzero due to truncation
+ assert "tokens/masked_fraction" in result.metrics_dict
+ assert result.metrics_dict["tokens/masked_fraction"] > 0.0
From 8f5f6d0aac996de54619f5b31f9ceb57a238a9c2 Mon Sep 17 00:00:00 2001
From: Kyle Avery <9327972+kyleavery@users.noreply.github.com>
Date: Wed, 5 Nov 2025 13:51:49 -0600
Subject: [PATCH 4/6] fix stepwise tests
---
tests/test_stepwise_advantages.py | 32 +++++++++++++++++--------------
1 file changed, 18 insertions(+), 14 deletions(-)
diff --git a/tests/test_stepwise_advantages.py b/tests/test_stepwise_advantages.py
index 3025416b3..7f3155a70 100644
--- a/tests/test_stepwise_advantages.py
+++ b/tests/test_stepwise_advantages.py
@@ -268,13 +268,19 @@ async def test_stepwise_chat_precomputed_returns_and_metrics_with_advantage_chec
result = await g.generate_batch(batch_id=0)
all_step_rewards = [s["step_scores"] for _p, _c, s in prepared]
- expected = [
+ # For logging: one reward per rollout (return at first assistant turn)
+ expected_log_rewards = [
+ (compute_discounted_returns(rewards, gamma, aggregation)[0] if rewards else 0.0)
+ for rewards in all_step_rewards
+ ]
+ # For metrics: mean over per-step returns remains unchanged
+ expected_all_returns = [
r for rewards in all_step_rewards for r in compute_discounted_returns(rewards, gamma, aggregation)
]
- assert np.allclose(result.rewards_dict["reward"], expected)
+ assert np.allclose(result.rewards_dict["reward"], expected_log_rewards)
- assert pytest.approx(result.metrics_dict["reward"], rel=1e-6) == float(np.mean(expected))
+ assert pytest.approx(result.metrics_dict["reward"], rel=1e-6) == float(np.mean(expected_all_returns))
assert pytest.approx(result.metrics_dict["stepwise/turns_per_rollout"], rel=1e-6) == 2.0
assert pytest.approx(result.metrics_dict["stepwise/rollout_length"], rel=1e-6) == 2.0
@@ -380,12 +386,15 @@ async def test_stepwise_chat_returns_precomputed_by_math_rubric(
# Expected immediate rewards derived from MathRubric correctness on each step
# A1: [0,1], B1: [1,1], A2: [1,0], B2: [0,1]
step_rewards = [[0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
- expected = [
+ expected_log = [
+ compute_discounted_returns(rewards, gamma, aggregation)[0] for rewards in step_rewards
+ ]
+ expected_all = [
r for rewards in step_rewards for r in compute_discounted_returns(rewards, gamma, aggregation)
]
- assert np.allclose(result.rewards_dict["reward"], expected)
- assert pytest.approx(result.metrics_dict["reward"], rel=1e-6) == float(np.mean(expected))
+ assert np.allclose(result.rewards_dict["reward"], expected_log)
+ assert pytest.approx(result.metrics_dict["reward"], rel=1e-6) == float(np.mean(expected_all))
assert pytest.approx(result.metrics_dict["stepwise/rollout_length"], rel=1e-6) == 2.0
assert pytest.approx(result.metrics_dict["stepwise/turns_per_rollout"], rel=1e-6) == 2.0
@@ -462,15 +471,10 @@ async def test_stepwise_chat_truncation_zero_reward(monkeypatch: pytest.MonkeyPa
result = await g.generate_batch(batch_id=0)
- # Without truncation, each rollout [1,1] with gamma=1 → returns [2,1]
- expected_naive = [2.0, 1.0] * 4
+ # With step scores [1,1] and gamma=1, rollout-level reward is 2.0 for each rollout
rewards = list(result.rewards_dict["reward"])
-
- # Only the very first step (A1 step 1) should be truncated → set to 0.0
- expected = expected_naive[:]
- expected[0] = 0.0
-
- assert np.allclose(rewards, expected)
+ expected_log = [2.0, 2.0, 2.0, 2.0]
+ assert np.allclose(rewards, expected_log)
# masking should be nonzero due to truncation
assert "tokens/masked_fraction" in result.metrics_dict
From f2c597fcb7b3d96f5a7220a40079ab2553953d1c Mon Sep 17 00:00:00 2001
From: Kyle Avery <9327972+kyleavery@users.noreply.github.com>
Date: Wed, 5 Nov 2025 14:50:00 -0600
Subject: [PATCH 5/6] update config
---
verifiers/rl/trainer/config.py | 37 +++++++++++++++++++++++++++++++++-
1 file changed, 36 insertions(+), 1 deletion(-)
diff --git a/verifiers/rl/trainer/config.py b/verifiers/rl/trainer/config.py
index 6e35c3ba4..f28c8503a 100644
--- a/verifiers/rl/trainer/config.py
+++ b/verifiers/rl/trainer/config.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
-from typing import List, Optional, Union
+from typing import List, Literal, Optional, Union
from peft import LoraConfig
from transformers import TrainingArguments
@@ -127,6 +127,33 @@ class RLConfig(TrainingArguments):
default=False,
metadata={"help": "Whether to give zero reward to truncated completions."},
)
+
+ # stepwise returns / advantage configuration
+ use_stepwise_advantage: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "If True, treat each assistant turn as its own training sample and use a discounted MC return per step."
+ )
+ },
+ )
+ stepwise_aggregation: Literal["sum", "max"] = field(
+ default="sum",
+ metadata={
+ "help": (
+ "How to compute discounted per-step return R_t from future rewards."
+ )
+ },
+ )
+ stepwise_gamma: float = field(
+ default=0.4,
+ metadata={
+ "help": (
+ "Discount factor gamma for stepwise MC return. Must be in [0, 1]."
+ )
+ },
+ )
+
# sampling_args for generation
max_tokens: Optional[int] = field(
default=None,
@@ -333,3 +360,11 @@ def __post_init__(self):
assert self.rollouts_per_example > 1, (
"2 or more rollouts per example are required."
)
+
+ assert 0.0 <= self.stepwise_gamma <= 1.0, (
+ "stepwise_gamma must be between 0.0 and 1.0."
+ )
+
+ assert self.stepwise_aggregation in ["sum", "max"], (
+ "stepwise_aggregation must be either 'sum' or 'max'."
+ )
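
For context, a hedged usage sketch of the new config fields added in this patch; the other RLConfig arguments shown are assumptions about the surrounding trainer setup and may not match its actual defaults:

```python
from verifiers.rl.trainer.config import RLConfig

# Treat each assistant turn as its own training sample, with a discounted
# "sum" return and gamma = 0.4 (the defaults introduced above).
config = RLConfig(
    output_dir="outputs/stepwise-run",  # hypothetical path
    use_stepwise_advantage=True,
    stepwise_aggregation="sum",
    stepwise_gamma=0.4,
)
```
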
From f97729b861b08bc8ae6c0ec5825de91a28ec9db2 Mon Sep 17 00:00:00 2001
From: Kyle Avery <9327972+kyleavery@users.noreply.github.com>
Date: Wed, 5 Nov 2025 14:50:26 -0600
Subject: [PATCH 6/6] fix logging
---
verifiers/rl/trainer/generator.py | 11 ++++++++++-
1 file changed, 10 insertions(+), 1 deletion(-)
diff --git a/verifiers/rl/trainer/generator.py b/verifiers/rl/trainer/generator.py
index c6988ef78..9f445b6a0 100644
--- a/verifiers/rl/trainer/generator.py
+++ b/verifiers/rl/trainer/generator.py
@@ -287,6 +287,8 @@ async def generate_batch(self, batch_id: int) -> Batch:
all_completion_logprobs: list[list[float]] = []
all_returns: list[float] = []
item_meta: list[tuple[int, int]] = [] # (prompt_idx, step_idx)
+ # rollout-level rewards for logging only (one per rollout)
+ rollout_rewards_for_logging: list[float] = []
# iterate in the same order as env_results arrays
for i, (prompt, completion, state) in enumerate(
@@ -413,6 +415,8 @@ def _maybe_strip_think(msg: dict) -> dict:
raise RuntimeError("Per-turn scores missing in state for stepwise advantage computation")
returns: list[float] = self._compute_stepwise_returns(step_rewards)
+ # For logging: rollout-level reward is the return at first assistant turn
+ rollout_rewards_for_logging.append(float(returns[0]) if returns else 0.0)
else:
assert isinstance(prompt, str) and isinstance(completion, str)
@@ -441,6 +445,8 @@ def _maybe_strip_think(msg: dict) -> dict:
step_rewards.append(float(rs.reward))
returns: list[float] = self._compute_stepwise_returns(step_rewards)
+ # For logging: rollout-level reward is the return at first assistant turn
+ rollout_rewards_for_logging.append(float(returns[0]) if returns else 0.0)
for step_idx, item in enumerate(step_items):
p_ids, p_mask, c_ids, c_mask, c_logps = item
@@ -490,7 +496,10 @@ def __init__(self):
# mimic ProcessedOutputs for downstream use
processed_results = _Proc()
- rewards_dict = {"reward": all_returns}
+ # For logging, include one reward per rollout to align with prompts/completions
+ rewards_dict = {"reward": rollout_rewards_for_logging}
+ for k in env_results.metrics:
+ rewards_dict[k] = env_results.metrics[k]
rewards = all_returns
# Compute stepwise group baseline across all m×n samples for each prompt