From db6fac2dcf598070e38aeb4269dbfe59224080fd Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Wed, 5 Nov 2025 18:53:32 +0530 Subject: [PATCH 1/2] Support env mixing with interleave_datasets --- tests/test_env_group.py | 71 +++++++++++++++++++++++++++++++++++++ verifiers/envs/env_group.py | 18 +++++++--- 2 files changed, 85 insertions(+), 4 deletions(-) diff --git a/tests/test_env_group.py b/tests/test_env_group.py index 9771b81f7..1022a7e08 100644 --- a/tests/test_env_group.py +++ b/tests/test_env_group.py @@ -234,6 +234,77 @@ def test_env_group_dataset_concatenation(self, mock_openai_client): assert tasks[1] == "math" assert tasks[2] == "code" + def test_env_group_dataset_interleaving_first_exhausted(self, mock_openai_client): + """Test that EnvGroup properly interleaves datasets with task labels.""" + env1 = SingleTurnEnv( + client=mock_openai_client, + model="test-model", + dataset=Dataset.from_dict( + {"question": ["q1", "q2"], "answer": ["a1", "a2"]} + ), + rubric=Rubric(), + ) + + env2 = SingleTurnEnv( + client=mock_openai_client, + model="test-model", + dataset=Dataset.from_dict({"question": ["q3"], "answer": ["a3"]}), + rubric=Rubric(), + ) + + env_group = EnvGroup( + envs=[env1, env2], + env_names=["math", "code"], + env_mix_strategy="interleave", + ) + + # Check concatenated dataset + dataset = env_group.get_dataset() + assert len(dataset) == 2 + assert "task" in dataset.column_names + + # Check task labels + tasks = dataset["task"] + assert tasks[0] == "math" + assert tasks[1] == "code" + + def test_env_group_dataset_interleaving_all_exhausted(self, mock_openai_client): + """Test that EnvGroup properly interleaves datasets with task labels.""" + env1 = SingleTurnEnv( + client=mock_openai_client, + model="test-model", + dataset=Dataset.from_dict( + {"question": ["q1", "q2"], "answer": ["a1", "a2"]} + ), + rubric=Rubric(), + ) + + env2 = SingleTurnEnv( + client=mock_openai_client, + model="test-model", + dataset=Dataset.from_dict({"question": ["q3"], "answer": ["a3"]}), + rubric=Rubric(), + ) + + env_group = EnvGroup( + envs=[env1, env2], + env_names=["math", "code"], + env_mix_strategy="interleave", + env_mix_kwargs=dict(stopping_strategy="all_exhausted"), + ) + + # Check concatenated dataset + dataset = env_group.get_dataset() + assert len(dataset) == 4 + assert "task" in dataset.column_names + + # Check task labels + tasks = dataset["task"] + assert tasks[0] == "math" + assert tasks[1] == "code" + assert tasks[2] == "math" + assert tasks[1] == "code" + def test_env_group_rubric_type(self, mock_openai_client): """Test that EnvGroup creates EnvGroupRubric.""" env1 = SingleTurnEnv( diff --git a/verifiers/envs/env_group.py b/verifiers/envs/env_group.py index 633fdeca6..7036f46d7 100644 --- a/verifiers/envs/env_group.py +++ b/verifiers/envs/env_group.py @@ -1,7 +1,7 @@ from collections import defaultdict -from typing import TYPE_CHECKING, Mapping +from typing import TYPE_CHECKING, Literal, Mapping -from datasets import Dataset, concatenate_datasets +from datasets import Dataset, concatenate_datasets, interleave_datasets from openai import AsyncOpenAI from verifiers import ( @@ -99,6 +99,9 @@ def __init__( self, envs: list[Environment], env_names: list[str] | None = None, + probabilities: list[float] | None = None, + env_mix_strategy: Literal["interleave", "concatenate"] = "concatenate", + env_mix_kwargs: dict = {}, map_kwargs: dict = {}, **kwargs, ): @@ -140,8 +143,15 @@ def add_task(example): if env_eval_dataset is not None: env_eval_dataset = env_eval_dataset.map(add_task, **map_kwargs) eval_datasets.append(env_eval_dataset) - dataset = concatenate_datasets(datasets) if datasets else None - eval_dataset = concatenate_datasets(eval_datasets) if eval_datasets else None + mix_datasets = ( + interleave_datasets + if env_mix_strategy == "interleave" + else concatenate_datasets + ) + dataset = mix_datasets(datasets, **env_mix_kwargs) if datasets else None + eval_dataset = ( + mix_datasets(eval_datasets, **env_mix_kwargs) if eval_datasets else None + ) # wrap rubrics rubric = EnvGroupRubric(self.env_map) From a2261ef6f64924a9f0a4c3a2621f114fa388074e Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Sat, 8 Nov 2025 13:54:44 +0000 Subject: [PATCH 2/2] Remove probabilities --- verifiers/envs/env_group.py | 1 - 1 file changed, 1 deletion(-) diff --git a/verifiers/envs/env_group.py b/verifiers/envs/env_group.py index 7036f46d7..d33dc8461 100644 --- a/verifiers/envs/env_group.py +++ b/verifiers/envs/env_group.py @@ -99,7 +99,6 @@ def __init__( self, envs: list[Environment], env_names: list[str] | None = None, - probabilities: list[float] | None = None, env_mix_strategy: Literal["interleave", "concatenate"] = "concatenate", env_mix_kwargs: dict = {}, map_kwargs: dict = {},