diff --git a/tests/test_env_group.py b/tests/test_env_group.py
index 9771b81f7..1022a7e08 100644
--- a/tests/test_env_group.py
+++ b/tests/test_env_group.py
@@ -234,6 +234,77 @@ def test_env_group_dataset_concatenation(self, mock_openai_client):
         assert tasks[1] == "math"
         assert tasks[2] == "code"

+    def test_env_group_dataset_interleaving_first_exhausted(self, mock_openai_client):
+        """Test interleaving with the default first_exhausted stopping strategy."""
+        env1 = SingleTurnEnv(
+            client=mock_openai_client,
+            model="test-model",
+            dataset=Dataset.from_dict(
+                {"question": ["q1", "q2"], "answer": ["a1", "a2"]}
+            ),
+            rubric=Rubric(),
+        )
+
+        env2 = SingleTurnEnv(
+            client=mock_openai_client,
+            model="test-model",
+            dataset=Dataset.from_dict({"question": ["q3"], "answer": ["a3"]}),
+            rubric=Rubric(),
+        )
+
+        env_group = EnvGroup(
+            envs=[env1, env2],
+            env_names=["math", "code"],
+            env_mix_strategy="interleave",
+        )
+
+        # Check interleaved dataset
+        dataset = env_group.get_dataset()
+        assert len(dataset) == 2
+        assert "task" in dataset.column_names
+
+        # Check task labels
+        tasks = dataset["task"]
+        assert tasks[0] == "math"
+        assert tasks[1] == "code"
+
+    def test_env_group_dataset_interleaving_all_exhausted(self, mock_openai_client):
+        """Test interleaving with stopping_strategy='all_exhausted' (oversampling)."""
+        env1 = SingleTurnEnv(
+            client=mock_openai_client,
+            model="test-model",
+            dataset=Dataset.from_dict(
+                {"question": ["q1", "q2"], "answer": ["a1", "a2"]}
+            ),
+            rubric=Rubric(),
+        )
+
+        env2 = SingleTurnEnv(
+            client=mock_openai_client,
+            model="test-model",
+            dataset=Dataset.from_dict({"question": ["q3"], "answer": ["a3"]}),
+            rubric=Rubric(),
+        )
+
+        env_group = EnvGroup(
+            envs=[env1, env2],
+            env_names=["math", "code"],
+            env_mix_strategy="interleave",
+            env_mix_kwargs=dict(stopping_strategy="all_exhausted"),
+        )
+
+        # Check interleaved dataset
+        dataset = env_group.get_dataset()
+        assert len(dataset) == 4
+        assert "task" in dataset.column_names
+
+        # Check task labels
+        tasks = dataset["task"]
+        assert tasks[0] == "math"
+        assert tasks[1] == "code"
+        assert tasks[2] == "math"
+        assert tasks[3] == "code"
+
     def test_env_group_rubric_type(self, mock_openai_client):
         """Test that EnvGroup creates EnvGroupRubric."""
         env1 = SingleTurnEnv(
diff --git a/verifiers/envs/env_group.py b/verifiers/envs/env_group.py
index 633fdeca6..d33dc8461 100644
--- a/verifiers/envs/env_group.py
+++ b/verifiers/envs/env_group.py
@@ -1,7 +1,7 @@
 from collections import defaultdict
-from typing import TYPE_CHECKING, Mapping
+from typing import TYPE_CHECKING, Literal, Mapping

-from datasets import Dataset, concatenate_datasets
+from datasets import Dataset, concatenate_datasets, interleave_datasets
 from openai import AsyncOpenAI

 from verifiers import (
@@ -99,6 +99,8 @@ def __init__(
         self,
         envs: list[Environment],
         env_names: list[str] | None = None,
+        env_mix_strategy: Literal["interleave", "concatenate"] = "concatenate",
+        env_mix_kwargs: dict = {},
         map_kwargs: dict = {},
         **kwargs,
     ):
@@ -140,8 +142,15 @@ def add_task(example):
         if env_eval_dataset is not None:
             env_eval_dataset = env_eval_dataset.map(add_task, **map_kwargs)
             eval_datasets.append(env_eval_dataset)
-        dataset = concatenate_datasets(datasets) if datasets else None
-        eval_dataset = concatenate_datasets(eval_datasets) if eval_datasets else None
+        mix_datasets = (
+            interleave_datasets
+            if env_mix_strategy == "interleave"
+            else concatenate_datasets
+        )
+        dataset = mix_datasets(datasets, **env_mix_kwargs) if datasets else None
+        eval_dataset = (
+            mix_datasets(eval_datasets, **env_mix_kwargs) if eval_datasets else None
+        )

         # wrap rubrics
         rubric = EnvGroupRubric(self.env_map)
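
Note (outside the diff): a minimal sketch of how datasets.interleave_datasets behaves under the two stopping strategies the new tests exercise, assuming two toy datasets of lengths 2 and 1:

    from datasets import Dataset, interleave_datasets

    d1 = Dataset.from_dict({"task": ["math", "math"]})
    d2 = Dataset.from_dict({"task": ["code"]})

    # Default "first_exhausted": stop as soon as the shortest dataset
    # runs out, so only one full round of alternation survives.
    print(interleave_datasets([d1, d2])["task"])
    # -> ['math', 'code']

    # "all_exhausted": oversample shorter datasets (wrapping around)
    # until every dataset has been fully consumed at least once.
    print(interleave_datasets([d1, d2], stopping_strategy="all_exhausted")["task"])
    # -> ['math', 'code', 'math', 'code']

This is why the first test expects a dataset of length 2 and the second expects length 4 with the "code" sample repeated.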