71 changes: 71 additions & 0 deletions tests/test_env_group.py
@@ -234,6 +234,77 @@ def test_env_group_dataset_concatenation(self, mock_openai_client):
assert tasks[1] == "math"
assert tasks[2] == "code"

def test_env_group_dataset_interleaving_first_exhausted(self, mock_openai_client):
"""Test that EnvGroup properly interleaves datasets with task labels."""
env1 = SingleTurnEnv(
client=mock_openai_client,
model="test-model",
dataset=Dataset.from_dict(
{"question": ["q1", "q2"], "answer": ["a1", "a2"]}
),
rubric=Rubric(),
)

env2 = SingleTurnEnv(
client=mock_openai_client,
model="test-model",
dataset=Dataset.from_dict({"question": ["q3"], "answer": ["a3"]}),
rubric=Rubric(),
)

env_group = EnvGroup(
envs=[env1, env2],
env_names=["math", "code"],
env_mix_strategy="interleave",
)

        # Check interleaved dataset (truncated once the shorter dataset runs out)
dataset = env_group.get_dataset()
assert len(dataset) == 2
assert "task" in dataset.column_names

# Check task labels
tasks = dataset["task"]
assert tasks[0] == "math"
assert tasks[1] == "code"

def test_env_group_dataset_interleaving_all_exhausted(self, mock_openai_client):
"""Test that EnvGroup properly interleaves datasets with task labels."""
env1 = SingleTurnEnv(
client=mock_openai_client,
model="test-model",
dataset=Dataset.from_dict(
{"question": ["q1", "q2"], "answer": ["a1", "a2"]}
),
rubric=Rubric(),
)

env2 = SingleTurnEnv(
client=mock_openai_client,
model="test-model",
dataset=Dataset.from_dict({"question": ["q3"], "answer": ["a3"]}),
rubric=Rubric(),
)

env_group = EnvGroup(
envs=[env1, env2],
env_names=["math", "code"],
env_mix_strategy="interleave",
env_mix_kwargs=dict(stopping_strategy="all_exhausted"),
)

        # Check interleaved dataset (shorter dataset is oversampled to 4 rows)
dataset = env_group.get_dataset()
assert len(dataset) == 4
assert "task" in dataset.column_names

# Check task labels
tasks = dataset["task"]
assert tasks[0] == "math"
assert tasks[1] == "code"
assert tasks[2] == "math"
assert tasks[1] == "code"

def test_env_group_rubric_type(self, mock_openai_client):
"""Test that EnvGroup creates EnvGroupRubric."""
env1 = SingleTurnEnv(
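Side note for reviewers: the expected lengths in the two tests above (2 vs. 4) come straight from the `stopping_strategy` argument of Hugging Face `datasets.interleave_datasets`, which the new code path delegates to. A minimal standalone sketch (toy data, not from this repo):

```python
from datasets import Dataset, interleave_datasets

d1 = Dataset.from_dict({"question": ["q1", "q2"]})  # 2 rows
d2 = Dataset.from_dict({"question": ["q3"]})        # 1 row

# Default "first_exhausted": round-robin stops once the shortest
# dataset runs out, so only complete rounds survive -> 2 rows.
short = interleave_datasets([d1, d2])
print(short["question"])  # ['q1', 'q3']

# "all_exhausted": shorter datasets are oversampled (rows repeated)
# until every dataset has been fully consumed -> 4 rows.
full = interleave_datasets([d1, d2], stopping_strategy="all_exhausted")
print(full["question"])  # ['q1', 'q3', 'q2', 'q3']
```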
17 changes: 13 additions & 4 deletions verifiers/envs/env_group.py
@@ -1,7 +1,7 @@
 from collections import defaultdict
-from typing import TYPE_CHECKING, Mapping
+from typing import TYPE_CHECKING, Literal, Mapping

-from datasets import Dataset, concatenate_datasets
+from datasets import Dataset, concatenate_datasets, interleave_datasets
 from openai import AsyncOpenAI

 from verifiers import (
@@ -99,6 +99,8 @@ def __init__(
         self,
         envs: list[Environment],
         env_names: list[str] | None = None,
+        env_mix_strategy: Literal["interleave", "concatenate"] = "concatenate",
+        env_mix_kwargs: dict = {},
         map_kwargs: dict = {},
         **kwargs,
     ):
@@ -140,8 +142,15 @@ def add_task(example):
             if env_eval_dataset is not None:
                 env_eval_dataset = env_eval_dataset.map(add_task, **map_kwargs)
                 eval_datasets.append(env_eval_dataset)
-        dataset = concatenate_datasets(datasets) if datasets else None
-        eval_dataset = concatenate_datasets(eval_datasets) if eval_datasets else None
+        mix_datasets = (
+            interleave_datasets
+            if env_mix_strategy == "interleave"
+            else concatenate_datasets
+        )
+        dataset = mix_datasets(datasets, **env_mix_kwargs) if datasets else None
+        eval_dataset = (
+            mix_datasets(eval_datasets, **env_mix_kwargs) if eval_datasets else None
+        )
         # wrap rubrics
         rubric = EnvGroupRubric(self.env_map)

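For context, here is how a caller might exercise the new parameters end to end. This is a sketch mirroring the tests above; it assumes `EnvGroup`, `SingleTurnEnv`, and `Rubric` are importable from the package root as in the test suite, and uses a real `AsyncOpenAI` client where the tests use a mock:

```python
from datasets import Dataset
from openai import AsyncOpenAI

from verifiers import EnvGroup, Rubric, SingleTurnEnv

client = AsyncOpenAI()  # placeholder client; the tests use a mock instead

math_env = SingleTurnEnv(
    client=client,
    model="test-model",
    dataset=Dataset.from_dict({"question": ["q1", "q2"], "answer": ["a1", "a2"]}),
    rubric=Rubric(),
)
code_env = SingleTurnEnv(
    client=client,
    model="test-model",
    dataset=Dataset.from_dict({"question": ["q3"], "answer": ["a3"]}),
    rubric=Rubric(),
)

# "interleave" alternates rows across environments; env_mix_kwargs is
# forwarded verbatim to datasets.interleave_datasets, so any of its
# options (e.g. stopping_strategy) can be passed through.
group = EnvGroup(
    envs=[math_env, code_env],
    env_names=["math", "code"],
    env_mix_strategy="interleave",
    env_mix_kwargs=dict(stopping_strategy="all_exhausted"),
)

print(group.get_dataset()["task"])  # ['math', 'code', 'math', 'code']
```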