From fb2a1aab2cd0cf1bb906e1f8d3047b8e8bdfe348 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 25 Nov 2025 02:29:49 +0000 Subject: [PATCH 01/18] adds API for converting documents into pretraining data --- src/instructlab/training/config.py | 13 ++++ src/instructlab/training/data_process.py | 85 ++++++++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index e603b55e..ec195ad1 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -63,9 +63,22 @@ class DataProcessArgs(BaseModel): description="this is the number of CPU procs we use for data processing parallelization", ) + # Pretraining mode flag + is_pretraining: bool = Field( + default=False, + description="Enable pretraining mode: tokenizes raw documents without chat templates or chunking", + ) + # disable the protected namespace for the model_config field model_config = ConfigDict(protected_namespaces=()) + @model_validator(mode="after") + def validate_pretraining_params(self): + """Validate pretraining parameter combinations""" + if self.is_pretraining and self.chat_tmpl_path is not None: + raise ValueError("chat_tmpl_path not compatible with is_pretraining=True") + return self + # public API class TorchrunArgs(BaseModel): diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py index 2b93a48f..690e4665 100644 --- a/src/instructlab/training/data_process.py +++ b/src/instructlab/training/data_process.py @@ -1133,6 +1133,91 @@ def process_messages_into_input_ids( save_dataset(final_dataset, data_output_path, num_cpu_procs) +def process_documents_for_pretraining( + data_path: str, + data_output_path: str, + model_path: str, + num_cpu_procs: int, +) -> None: + """ + Process raw documents for pretraining by tokenizing without chunking. 
+ + Outputs one JSONL record per document with only input_ids (no labels). + Blocking/chunking happens later during training. + + Pattern: Each document → [BOS][tokens][EOS] + + Args: + data_path: Path to input JSONL with {"documents": "text"} format + data_output_path: Directory for processed data output + model_path: Path to model/tokenizer + num_cpu_procs: Number of parallel processes + """ + ensure_can_write_to_directory(data_output_path) + + # Load and validate dataset + try: + data = load_dataset("json", data_files=data_path, split="train") + except Exception as e: + raise ValueError( + "Malformed or missing data, please ensure your dataset is correctly formatted" + ) from e + + if data.num_rows == 0: + raise ValueError("The provided dataset is empty") + + + if 'document' not in data.column_names: + raise ValueError( + f"Pretraining data must have 'document' field. Found: {data.column_names}" + ) + + logger.info("Loading tokenizer from %s", model_path) + tokenizer = AutoTokenizer.from_pretrained(model_path) + + if tokenizer.eos_token_id is None: + raise ValueError("Tokenizer must have an EOS token defined for pretraining") + + logger.info("Tokenizing %d documents for pretraining...", data.num_rows) + + # Tokenize each document: encode() adds BOS, then append EOS + def tokenize_document(sample): + input_ids = tokenizer.encode(sample['document'], add_special_tokens=True) + input_ids.append(tokenizer.eos_token_id) + return { + "input_ids": input_ids, + "len": len(input_ids), + } + + tokenized_data = data.map( + tokenize_document, + num_proc=num_cpu_procs, + desc="Tokenizing documents", + remove_columns=data.column_names, + ) + + # Calculate statistics + total_tokens = sum(tokenized_data['len']) + avg_tokens = total_tokens / len(tokenized_data) + logger.info(f"Processed {len(tokenized_data):,} documents") + logger.info(f"Total tokens: {total_tokens:,}") + logger.info(f"Average tokens per document: {avg_tokens:.1f}") + + # Save to JSONL (one record per document) + 
os.makedirs(data_output_path, exist_ok=True) + output_file = Path(data_output_path) / "data.jsonl" + + tokenized_data.to_json( + output_file, + num_proc=num_cpu_procs, + lines=True, + orient="records" + ) + + logger.info(f"Saved tokenized documents to {output_file}") + logger.info("Note: Blocking into fixed-size chunks will happen during training") + + def ensure_can_write_to_directory(output_dir: str) -> None: """ Ensure that we can write to the output directory. From 6fbdcfdef121f3ab9bfa7939feb87de60b661910 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Wed, 26 Nov 2025 03:58:55 +0000 Subject: [PATCH 02/18] removes extraneous validation + account for eos already being present --- src/instructlab/training/config.py | 7 ------- src/instructlab/training/data_process.py | 24 ++++++++++++++---------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index ec195ad1..c9056e70 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -72,13 +72,6 @@ class DataProcessArgs(BaseModel): # disable the protected namespace for the model_config field model_config = ConfigDict(protected_namespaces=()) - @model_validator(mode="after") - def validate_pretraining_params(self): - """Validate pretraining parameter combinations""" - if self.is_pretraining and self.chat_tmpl_path is not None: - raise ValueError("chat_tmpl_path not compatible with is_pretraining=True") - return self - # public API class TorchrunArgs(BaseModel): diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py index 690e4665..a17f9ee1 100644 --- a/src/instructlab/training/data_process.py +++ b/src/instructlab/training/data_process.py @@ -1138,6 +1138,7 @@ def process_documents_for_pretraining( data_output_path: str, model_path: str, num_cpu_procs: int, + document_column_name: str = "document", ) -> None: """ 
Process raw documents for pretraining by tokenizing without chunking. @@ -1152,6 +1153,7 @@ def process_documents_for_pretraining( data_output_path: Directory for processed data output model_path: Path to model/tokenizer num_cpu_procs: Number of parallel processes + document_column_name: Name of the column containing the documents """ ensure_can_write_to_directory(data_output_path) @@ -1166,10 +1168,9 @@ def process_documents_for_pretraining( if data.num_rows == 0: raise ValueError("The provided dataset is empty") - - if 'document' not in data.column_names: + if document_column_name not in data.column_names: raise ValueError( - f"Pretraining data must have 'document' field. Found: {data.column_names}" + f"Pretraining data must have '{document_column_name}' field. Found: {data.column_names}" ) logger.info("Loading tokenizer from %s", model_path) @@ -1182,8 +1183,14 @@ def process_documents_for_pretraining( # Tokenize each document: encode() adds BOS, then append EOS def tokenize_document(sample): - input_ids = tokenizer.encode(sample['document'], add_special_tokens=True) - input_ids.append(tokenizer.eos_token_id) + input_ids = tokenizer.encode( + sample[document_column_name], add_special_tokens=True + ) + + # ensures eos token is present without double-adding it. 
+ if input_ids[-1] != tokenizer.eos_token_id: + input_ids.append(tokenizer.eos_token_id) + return { "input_ids": input_ids, "len": len(input_ids), @@ -1197,7 +1204,7 @@ def tokenize_document(sample): ) # Calculate statistics - total_tokens = sum(tokenized_data['len']) + total_tokens = sum(tokenized_data["len"]) avg_tokens = total_tokens / len(tokenized_data) logger.info(f"Processed {len(tokenized_data):,} documents") logger.info(f"Total tokens: {total_tokens:,}") @@ -1208,10 +1215,7 @@ def tokenize_document(sample): output_file = Path(data_output_path) / "data.jsonl" tokenized_data.to_json( - output_file, - num_proc=num_cpu_procs, - lines=True, - orient="records" + output_file, num_proc=num_cpu_procs, lines=True, orient="records" ) logger.info(f"Saved tokenized documents to {output_file}") From fd6f5f33465ea13759c22b218dbe0ca7ec2b5a5f Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Wed, 26 Nov 2025 04:30:07 +0000 Subject: [PATCH 03/18] exposes the pretraining column name in training config --- src/instructlab/training/config.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index c9056e70..db0b4b45 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -68,6 +68,10 @@ class DataProcessArgs(BaseModel): default=False, description="Enable pretraining mode: tokenizes raw documents without chat templates or chunking", ) + pretraining_column_name: str = Field( + default="document", + description="the name of the column containing the text to pretrain on", + ) # disable the protected namespace for the model_config field model_config = ConfigDict(protected_namespaces=()) From 09a0b9045cf845aecc96bad64633723a2dfb0419 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Wed, 26 Nov 2025 06:12:12 +0000 Subject: [PATCH 04/18] explicitly do not request a dict --- 
src/instructlab/training/data_process.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py index a17f9ee1..9ad2cf44 100644 --- a/src/instructlab/training/data_process.py +++ b/src/instructlab/training/data_process.py @@ -412,7 +412,8 @@ def process_messages_into_input_ids_with_chat_template(args: DataProcessArgs): logger.info("Tokenizing the dataset with %s tokenizer...", args.model_path) data_with_input_ids = data.map( lambda x: { - "input_ids": tokenizer.apply_chat_template(x["messages"], tokenize=True), + # newer versions of transformers have `return_dict=True` by default + "input_ids": tokenizer.apply_chat_template(x["messages"], tokenize=True, return_dict=False), "unmask": bool(x["unmask"]) if "unmask" in x else False, }, num_proc=NUM_PROC, @@ -687,7 +688,8 @@ def unmask_messages( if regions: message_regions_map[idx] = regions - input_ids = tokenizer.apply_chat_template(msgs_with_unmasking) + # newer versions of transformers have `return_dict=True` by default + input_ids = tokenizer.apply_chat_template(msgs_with_unmasking, return_dict=False) # Get token IDs for all unmask tokens unmask_begin_token_id = tokenizer.encode( From 967a3c418e71097519e8104be7b737f0ead7e811 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Wed, 26 Nov 2025 07:03:14 +0000 Subject: [PATCH 05/18] adds the ability for training to consume pretraining forrmat data --- src/instructlab/training/config.py | 18 ++++ src/instructlab/training/main_ds.py | 50 ++++++--- src/instructlab/training/sampler.py | 112 +++++++++++++++++++- tests/unit/test_pretraining_mode.py | 156 ++++++++++++++++++++++++++++ 4 files changed, 322 insertions(+), 14 deletions(-) create mode 100644 tests/unit/test_pretraining_mode.py diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index db0b4b45..734aa976 100644 --- 
a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -77,6 +77,16 @@ class DataProcessArgs(BaseModel): model_config = ConfigDict(protected_namespaces=()) +class PretrainingConfig(BaseModel): + """ + Configuration for pretraining mode. + """ + + block_size: int = Field( + description="Size of each block in tokens for pretraining datasets." + ) + + # public API class TorchrunArgs(BaseModel): """ @@ -276,6 +286,14 @@ class TrainingArgs(BaseModel): # "last_epoch". This works alongside the '--checkpoint_at_epoch' flag. keep_last_checkpoint_only: Optional[bool] = False + pretraining_config: Optional[PretrainingConfig] = Field( + default=None, + description=( + "Pretraining configuration. When provided, enables block-based sampling " + "for raw document pretraining datasets." + ), + ) + # TODO(osilkin): # we are only exposing this here because `run_training` today is implicitly coupled # with `process_data`. Since we don't have a specific field for data processing arguments, diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index d08afc91..9be7a19d 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -49,6 +49,7 @@ ModelTypes, TorchrunArgs, TrainingArgs, + PretrainingConfig, ) # pylint: disable=no-name-in-module @@ -364,6 +365,7 @@ def main(args): batch_size = args.effective_batch_size pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 + train_loader = get_data_loader( data_path=args.data_path, batch_size=batch_size, @@ -374,6 +376,7 @@ def main(args): num_workers=8, # I don't like this but am setting it for consistency flash_enabled=flash_enabled, pad_token_id=pad_token_id, + pretraining_config=getattr(args, "pretraining_config", None), ) if args.local_rank == 0: @@ -469,18 +472,26 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: ) if train_args.process_data: - # TODO(osilkin): - # Decouple the data 
processing logic from training. - # Now that we've decided that repos will be less tethered to the - # design choices of the `ilab` CLI, we can make this change. - dp.process_data( - data_output_path=train_args.data_output_dir, - model_path=train_args.model_path, - data_path=train_args.data_path, - max_seq_len=train_args.max_seq_len, - chat_tmpl_path=train_args.chat_tmpl_path, - num_cpu_procs=train_args.data_process_num_cpu_procs, - ) + if train_args.pretraining_config is not None: + dp.process_documents_for_pretraining( + data_path=train_args.data_path, + data_output_path=train_args.data_output_dir, + model_path=train_args.model_path, + num_cpu_procs=train_args.data_process_num_cpu_procs, + ) + else: + # TODO(osilkin): + # Decouple the data processing logic from training. + # Now that we've decided that repos will be less tethered to the + # design choices of the `ilab` CLI, we can make this change. + dp.process_data( + data_output_path=train_args.data_output_dir, + model_path=train_args.model_path, + data_path=train_args.data_path, + max_seq_len=train_args.max_seq_len, + chat_tmpl_path=train_args.chat_tmpl_path, + num_cpu_procs=train_args.data_process_num_cpu_procs, + ) if not os.path.exists(train_args.ckpt_output_dir): os.makedirs(train_args.ckpt_output_dir, exist_ok=True) @@ -537,6 +548,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: ] ) + if train_args.pretraining_config is not None: + command.append( + f"--block-size={train_args.pretraining_config.block_size}" + ) + if train_args.chat_tmpl_path is not None: command.append(f"--chat-tmpl-path={train_args.chat_tmpl_path}") @@ -784,6 +800,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: help="Which modules we should target for injecting LoRA layers. 
Defaults to selecting all projection layers when no values are provided.", ) parser.add_argument("--max_batch_len", type=int, default=60000) + parser.add_argument( + "--block-size", + type=int, + default=None, + help="When provided, enables pretraining mode with the given token block size.", + ) parser.add_argument( "--cpu_offload_optimizer", action="store_true", @@ -856,6 +878,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: help="Epsilon for numerical stability in AdamW optimizer.", ) args = parser.parse_args() + if args.block_size is not None: + args.pretraining_config = PretrainingConfig(block_size=args.block_size) + else: + args.pretraining_config = None set_random_seed(args.seed) main(args) diff --git a/src/instructlab/training/sampler.py b/src/instructlab/training/sampler.py index be2b3f4b..73aefce5 100644 --- a/src/instructlab/training/sampler.py +++ b/src/instructlab/training/sampler.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # Standard +import logging from typing import Optional # Third Party -from datasets import load_dataset +from datasets import Dataset as HFDataset, load_dataset from torch.utils.data import DataLoader, Dataset, Sampler import numpy as np import torch @@ -16,6 +17,9 @@ batch_lengths_to_minibatches_padded, ) from instructlab.training.type_definitions import CollatedItem +from instructlab.training.config import PretrainingConfig + +logger = logging.getLogger(__name__) class EpochSampler(Sampler): @@ -291,6 +295,97 @@ def __call__(self, batch: list[dict]): return all_minibatches +class PretrainingBlockDataset(Dataset): + """Dataset that concatenates documents and exposes fixed-size blocks.""" + + def __init__(self, dataset: HFDataset, block_size: int, pad_token_id: int): + if block_size <= 0: + raise ValueError(f"block_size must be positive, got {block_size}") + if "input_ids" not in dataset.column_names: + raise ValueError("Pretraining data must provide an 'input_ids' column.") + if 
pad_token_id < 0: + raise ValueError("pad_token_id must be a non-negative integer.") + + self.block_size = block_size + self.pad_token_id = pad_token_id + + all_input_ids: list[int] = [] + for sample in dataset: + ids = sample["input_ids"] + if isinstance(ids, torch.Tensor): + ids = ids.tolist() + all_input_ids.extend(ids) + + total_tokens = len(all_input_ids) + if total_tokens == 0: + raise ValueError("Pretraining dataset is empty after concatenation.") + + num_blocks, remainder = divmod(total_tokens, block_size) + if remainder: + num_blocks += 1 + + self.all_input_ids = all_input_ids + self.num_blocks = num_blocks + self.last_block_len = remainder if remainder else block_size + self.total_tokens = total_tokens + + logger.info( + "Pretraining dataset: %s tokens → %s block(s) (block_size=%s, remainder=%s)", + f"{total_tokens:,}", + f"{self.num_blocks:,}", + block_size, + remainder, + ) + + def __len__(self) -> int: + return self.num_blocks + + def __getitem__(self, index: int): + if index < 0 or index >= self.num_blocks: + raise IndexError(f"Index {index} out of range for {self.num_blocks} blocks.") + + start = index * self.block_size + end = start + self.block_size + is_last_block = index == self.num_blocks - 1 + is_partial = is_last_block and self.last_block_len != self.block_size + + if is_partial: + actual_tokens = self.all_input_ids[start:] + actual_len = len(actual_tokens) + pad_len = self.block_size - actual_len + + input_ids = actual_tokens + [self.pad_token_id] * pad_len + labels = actual_tokens + [-100] * pad_len + loss_tokens = max(actual_len - 1, 0) + else: + input_ids = self.all_input_ids[start:end] + labels = list(input_ids) + loss_tokens = self.block_size - 1 + + return { + "input_ids": torch.tensor(input_ids, dtype=torch.long), + "labels": torch.tensor(labels, dtype=torch.long), + "len": self.block_size, + "num_loss_counted_tokens": loss_tokens, + } + + @classmethod + def from_jsonl_file( + cls, + data_path: str, + block_size: int, + pad_token_id: 
int, + ) -> "PretrainingBlockDataset": + dataset = load_dataset("json", data_files=data_path, split="train") + return cls(dataset, block_size, pad_token_id) + + def get_lengths(self) -> np.ndarray: + lengths = np.full(self.num_blocks, self.block_size, dtype=np.int64) + if self.num_blocks and self.last_block_len != self.block_size: + lengths[-1] = self.last_block_len + return lengths + + class TokenDataset(Dataset): """Dataset for loading tokenized data from JSONL files. @@ -346,6 +441,7 @@ def get_data_loader( num_workers: int = 0, flash_enabled: bool = True, pad_token_id: int = 0, + pretraining_config: Optional[PretrainingConfig] = None, ): """Create a data loader with epoch-based sampling and batch packing. @@ -360,11 +456,23 @@ def get_data_loader( num_workers: Number of data loading workers flash_enabled: Whether flash attention is enabled (affects collation strategy) pad_token_id: Token ID to use for padding (only used when flash_enabled=False) + pretraining_config: When provided, enables block-based pretraining dataset loading Returns: DataLoader configured with appropriate collator based on flash_enabled """ - dataset = TokenDataset(data_path) + if pretraining_config is not None: + dataset = PretrainingBlockDataset.from_jsonl_file( + data_path, pretraining_config.block_size, pad_token_id + ) + logger.info( + "Using pretraining dataset with block_size=%s and %s block(s)", + pretraining_config.block_size, + f"{len(dataset):,}", + ) + else: + dataset = TokenDataset(data_path) + sampler = EpochSampler(len(dataset), seed=seed) # Create unified collator with appropriate mode diff --git a/tests/unit/test_pretraining_mode.py b/tests/unit/test_pretraining_mode.py new file mode 100644 index 00000000..50d3427b --- /dev/null +++ b/tests/unit/test_pretraining_mode.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import json +import os +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +# Third Party 
+import torch +from datasets import Dataset as HFDataset + +# First Party +from instructlab.training.config import PretrainingConfig +from instructlab.training.data_process import process_documents_for_pretraining +from instructlab.training.sampler import ( + PretrainingBlockDataset, + get_data_loader, +) + + +class TestPretrainingBlockDataset(unittest.TestCase): + """Tests for the PretrainingBlockDataset behavior.""" + + def test_blocks_are_padded_and_loss_counts_tracked(self): + dataset = HFDataset.from_dict({"input_ids": [[1, 2, 3], [4, 5, 6, 7]]}) + block_ds = PretrainingBlockDataset(dataset, block_size=4, pad_token_id=0) + + self.assertEqual(len(block_ds), 2) + + first = block_ds[0] + self.assertTrue( + torch.equal(first["input_ids"], torch.tensor([1, 2, 3, 4], dtype=torch.long)) + ) + self.assertTrue( + torch.equal(first["labels"], torch.tensor([1, 2, 3, 4], dtype=torch.long)) + ) + self.assertEqual(first["num_loss_counted_tokens"], 3) + self.assertEqual(first["len"], 4) + + second = block_ds[1] + self.assertTrue( + torch.equal(second["input_ids"], torch.tensor([5, 6, 7, 0], dtype=torch.long)) + ) + self.assertTrue( + torch.equal( + second["labels"], torch.tensor([5, 6, 7, -100], dtype=torch.long) + ) + ) + self.assertEqual(second["num_loss_counted_tokens"], 2) + self.assertEqual(second["len"], 4) + + lengths = block_ds.get_lengths() + self.assertEqual(lengths.tolist(), [4, 3]) + + +class TestPretrainingDataLoader(unittest.TestCase): + """Tests for the pretraining-specific data loader integration.""" + + def test_pretraining_loader_returns_packed_batches(self): + cfg = PretrainingConfig(block_size=4) + + with tempfile.TemporaryDirectory() as tmpdir: + data_path = Path(tmpdir) / "data.jsonl" + records = [ + {"input_ids": [1, 2, 3, 4]}, + {"input_ids": [5, 6, 7, 8]}, + ] + with data_path.open("w", encoding="utf-8") as fh: + for record in records: + fh.write(json.dumps(record) + "\n") + + loader = get_data_loader( + data_path=str(data_path), + batch_size=2, + 
max_tokens_per_gpu=8, + seed=42, + rank=0, + world_size=1, + pretraining_config=cfg, + pad_token_id=0, + ) + + self.assertIsInstance(loader.dataset, PretrainingBlockDataset) + self.assertEqual(len(loader.dataset), 2) + + step = next(iter(loader)) + self.assertIsInstance(step, list) + self.assertEqual(len(step), 1) + + microbatch = step[0] + self.assertIn("input_ids", microbatch) + self.assertTrue(torch.is_tensor(microbatch["input_ids"])) + self.assertEqual(microbatch["input_ids"].shape, (1, 8)) + self.assertEqual(microbatch["num_samples"], 2) + self.assertEqual(microbatch["num_loss_counted_tokens"], 6) + self.assertEqual(microbatch["batch_num_loss_counted_tokens"], 6) + + +class TestPretrainingDataProcessing(unittest.TestCase): + """Tests for the pretraining data processing pipeline.""" + + def test_process_documents_for_pretraining_outputs_expected_fields(self): + class StubTokenizer: + eos_token_id = 999 + + def encode(self, text, add_special_tokens=True): + base = [ord(ch) % 50 + 1 for ch in text] + return base if add_special_tokens else base[1:] + + documents = [ + {"document": "alpha"}, + {"document": "beta"}, + ] + + with tempfile.TemporaryDirectory() as tmpdir: + data_path = Path(tmpdir) / "raw.jsonl" + with data_path.open("w", encoding="utf-8") as fh: + for record in documents: + fh.write(json.dumps(record) + "\n") + + output_dir = Path(tmpdir) / "processed" + + with patch( + "instructlab.training.data_process.AutoTokenizer.from_pretrained", + return_value=StubTokenizer(), + ) as mock_auto: + process_documents_for_pretraining( + data_path=str(data_path), + data_output_path=str(output_dir), + model_path="stub-model", + num_cpu_procs=1, + ) + + mock_auto.assert_called_once_with("stub-model") + + output_file = output_dir / "data.jsonl" + self.assertTrue(output_file.exists()) + + with output_file.open("r", encoding="utf-8") as fh: + rows = [json.loads(line) for line in fh if line.strip()] + + self.assertEqual(len(rows), len(documents)) + for row in rows: + 
self.assertIn("input_ids", row) + self.assertIn("len", row) + self.assertIsInstance(row["input_ids"], list) + self.assertIsInstance(row["len"], int) + self.assertEqual(len(row["input_ids"]), row["len"]) + self.assertEqual(row["input_ids"][-1], StubTokenizer.eos_token_id) + + +if __name__ == "__main__": # pragma: no cover + unittest.main() From dd8e3635a37e8a468bc3ddb77d758d5eaf2b715d Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:04:41 +0000 Subject: [PATCH 06/18] linting + minor fixes --- src/instructlab/training/__init__.py | 2 ++ src/instructlab/training/config.py | 4 ++++ src/instructlab/training/data_process.py | 4 +++- src/instructlab/training/main_ds.py | 22 +++++++++++++++++++--- src/instructlab/training/sampler.py | 11 +++++++---- tests/unit/test_pretraining_mode.py | 14 +++++++++----- 6 files changed, 44 insertions(+), 13 deletions(-) diff --git a/src/instructlab/training/__init__.py b/src/instructlab/training/__init__.py index 78ba2bfd..136d1384 100644 --- a/src/instructlab/training/__init__.py +++ b/src/instructlab/training/__init__.py @@ -10,6 +10,7 @@ "FSDPOptions", "ShardingStrategies", "DistributedBackend", + "PretrainingConfig", ) # First Party @@ -23,6 +24,7 @@ DistributedBackend, FSDPOptions, LoraOptions, + PretrainingConfig, QuantizeDataType, ShardingStrategies, TorchrunArgs, diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 734aa976..3f2dd810 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -85,6 +85,10 @@ class PretrainingConfig(BaseModel): block_size: int = Field( description="Size of each block in tokens for pretraining datasets." 
) + document_column_name: str = Field( + default="document", + description="Name of the column containing raw documents for pretraining.", + ) # public API diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py index 9ad2cf44..2b546fd5 100644 --- a/src/instructlab/training/data_process.py +++ b/src/instructlab/training/data_process.py @@ -413,7 +413,9 @@ def process_messages_into_input_ids_with_chat_template(args: DataProcessArgs): data_with_input_ids = data.map( lambda x: { # newer versions of transformers have `return_dict=True` by default - "input_ids": tokenizer.apply_chat_template(x["messages"], tokenize=True, return_dict=False), + "input_ids": tokenizer.apply_chat_template( + x["messages"], tokenize=True, return_dict=False + ), "unmask": bool(x["unmask"]) if "unmask" in x else False, }, num_proc=NUM_PROC, diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 9be7a19d..4d44af99 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -47,9 +47,9 @@ from instructlab.training.config import ( DistributedBackend, ModelTypes, + PretrainingConfig, TorchrunArgs, TrainingArgs, - PretrainingConfig, ) # pylint: disable=no-name-in-module @@ -478,6 +478,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: data_output_path=train_args.data_output_dir, model_path=train_args.model_path, num_cpu_procs=train_args.data_process_num_cpu_procs, + document_column_name=train_args.pretraining_config.document_column_name, ) else: # TODO(osilkin): @@ -549,8 +550,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: ) if train_args.pretraining_config is not None: + command.append(f"--block-size={train_args.pretraining_config.block_size}") command.append( - f"--block-size={train_args.pretraining_config.block_size}" + f"--document-column-name={train_args.pretraining_config.document_column_name}" ) if 
train_args.chat_tmpl_path is not None: @@ -806,6 +808,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: default=None, help="When provided, enables pretraining mode with the given token block size.", ) + parser.add_argument( + "--document-column-name", + type=str, + default=None, + help="Column name containing raw documents for continual pretraining data.", + ) parser.add_argument( "--cpu_offload_optimizer", action="store_true", @@ -878,8 +886,16 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: help="Epsilon for numerical stability in AdamW optimizer.", ) args = parser.parse_args() + if args.document_column_name is not None and args.block_size is None: + parser.error("--document-column-name requires --block-size to be specified.") + if args.block_size is not None: - args.pretraining_config = PretrainingConfig(block_size=args.block_size) + pretraining_kwargs = {} + if args.document_column_name is not None: + pretraining_kwargs["document_column_name"] = args.document_column_name + args.pretraining_config = PretrainingConfig( + block_size=args.block_size, **pretraining_kwargs + ) else: args.pretraining_config = None set_random_seed(args.seed) diff --git a/src/instructlab/training/sampler.py b/src/instructlab/training/sampler.py index 73aefce5..35ba41bc 100644 --- a/src/instructlab/training/sampler.py +++ b/src/instructlab/training/sampler.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # Standard -import logging from typing import Optional +import logging # Third Party -from datasets import Dataset as HFDataset, load_dataset +from datasets import Dataset as HFDataset +from datasets import load_dataset from torch.utils.data import DataLoader, Dataset, Sampler import numpy as np import torch @@ -13,11 +14,11 @@ # First Party from instructlab.training.batch_packer import batch_lengths_to_minibatches_lpt +from instructlab.training.config import PretrainingConfig from 
instructlab.training.padded_batch_packer import ( batch_lengths_to_minibatches_padded, ) from instructlab.training.type_definitions import CollatedItem -from instructlab.training.config import PretrainingConfig logger = logging.getLogger(__name__) @@ -342,7 +343,9 @@ def __len__(self) -> int: def __getitem__(self, index: int): if index < 0 or index >= self.num_blocks: - raise IndexError(f"Index {index} out of range for {self.num_blocks} blocks.") + raise IndexError( + f"Index {index} out of range for {self.num_blocks} blocks." + ) start = index * self.block_size end = start + self.block_size diff --git a/tests/unit/test_pretraining_mode.py b/tests/unit/test_pretraining_mode.py index 50d3427b..92056eac 100644 --- a/tests/unit/test_pretraining_mode.py +++ b/tests/unit/test_pretraining_mode.py @@ -1,16 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # Standard +from pathlib import Path +from unittest.mock import patch import json import os import tempfile import unittest -from pathlib import Path -from unittest.mock import patch # Third Party -import torch from datasets import Dataset as HFDataset +import torch # First Party from instructlab.training.config import PretrainingConfig @@ -32,7 +32,9 @@ def test_blocks_are_padded_and_loss_counts_tracked(self): first = block_ds[0] self.assertTrue( - torch.equal(first["input_ids"], torch.tensor([1, 2, 3, 4], dtype=torch.long)) + torch.equal( + first["input_ids"], torch.tensor([1, 2, 3, 4], dtype=torch.long) + ) ) self.assertTrue( torch.equal(first["labels"], torch.tensor([1, 2, 3, 4], dtype=torch.long)) @@ -42,7 +44,9 @@ def test_blocks_are_padded_and_loss_counts_tracked(self): second = block_ds[1] self.assertTrue( - torch.equal(second["input_ids"], torch.tensor([5, 6, 7, 0], dtype=torch.long)) + torch.equal( + second["input_ids"], torch.tensor([5, 6, 7, 0], dtype=torch.long) + ) ) self.assertTrue( torch.equal( From 333a8ef4cace8efa21984d59678f19ebd6d5e983 Mon Sep 17 00:00:00 2001 From: Oleg Silkin 
<97077423+RobotSail@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:08:31 +0000 Subject: [PATCH 07/18] add docs --- README.md | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bd04999b..c2c45c9a 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ The library now supports reasoning traces through the `reasoning_content` field - [Using the library](#using-the-library) - [Data format](#data-format) - [Reasoning content support](#reasoning-content-support-1) +- [Continual pretraining mode](#continual-pretraining-mode) - [Documentation](#documentation) - [Learning about the training arguments](#learning-about-training-arguments) - [`TrainingArgs`](#trainingargs) @@ -122,6 +123,46 @@ The library now supports an optional `reasoning_content` field in addition to th } ``` +## Continual pretraining mode + +In addition to instruction tuning, the library can run document-style continual pretraining on raw text corpora. +Enable this by supplying a block size when invoking `main_ds.py`: + +```bash +torchrun main_ds.py \ + --model_name_or_path mistralai/Mistral-7B-v0.1 \ + --data_path /data/documents.jsonl \ + --ckpt_output_dir ./checkpoints \ + --effective_batch_size 128 \ + --max_batch_len 60000 \ + --block-size 8192 \ + --document-column-name text # optional, defaults to "document" +``` + +- `--block-size` (required) toggles continual pretraining and controls how many tokens are packed into each block. +- `--document-column-name` (optional) specifies which JSONL field contains the raw document text. 
+ +The same options are available programmatically via `TrainingArgs.pretraining_config`: + +```python +from instructlab.training import TrainingArgs, PretrainingConfig + +train_args = TrainingArgs( + model_name_or_path="mistralai/Mistral-7B-v0.1", + data_path="documents.jsonl", + ckpt_output_dir="./checkpoints", + max_seq_len=4096, + max_batch_len=40000, + effective_batch_size=128, + pretraining_config=PretrainingConfig( + block_size=2048, + document_column_name="text", # optional + ), +) +``` + +When a pretraining config is provided, `process_documents_for_pretraining()` is invoked under the hood to tokenize raw documents before training. + **Standard message structure:** ```json @@ -139,7 +180,7 @@ The library now supports an optional `reasoning_content` field in addition to th } ``` -#### Important Notes +### Important Notes 1. **Automatic reasoning content processing**: If `reasoning_content` exists in a message, it will always be processed and unmasked as long as the message role is targeted for unmasking. This ensures that reasoning traces are properly included in the training data. 
From 1cd4aa43994a3a12225e56440ea8ac10b27323f4 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Wed, 26 Nov 2025 17:18:40 +0000 Subject: [PATCH 08/18] updates mock tokenizer --- tests/unit/test_data_process.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/unit/test_data_process.py b/tests/unit/test_data_process.py index d0d36f6c..c69cc9f8 100644 --- a/tests/unit/test_data_process.py +++ b/tests/unit/test_data_process.py @@ -67,7 +67,9 @@ def _mock_apply_chat_template( messages: t.List[Message], tokenize: bool = True, add_special_tokens: bool = True, - ) -> t.Union[str, t.List[int]]: + return_dict: bool = False, + **kwargs, + ) -> t.Union[str, t.List[int], t.Dict[str, t.Any]]: """Mock implementation of apply_chat_template.""" template_tokens = [] @@ -91,10 +93,14 @@ def _mock_apply_chat_template( ] template_tokens.extend(reasoning_tokens) - if tokenize: - return template_tokens - else: - return " ".join([f"token_{t}" for t in template_tokens]) + result = ( + template_tokens + if tokenize + else " ".join([f"token_{t}" for t in template_tokens]) + ) + if return_dict: + return {"input_ids": result} + return result def test_single_turn_assistant_only_content(self): """Test basic single-turn conversation with assistant content only.""" @@ -555,7 +561,9 @@ def _mock_apply_chat_template( messages: t.List[Message], tokenize: bool = True, add_special_tokens: bool = True, - ) -> t.Union[str, t.List[int]]: + return_dict: bool = False, + **kwargs, + ) -> t.Union[str, t.List[int], t.Dict[str, t.Any]]: """Mock implementation of apply_chat_template.""" template_str = "" for msg in messages: @@ -566,10 +574,14 @@ def _mock_apply_chat_template( template_str += msg["reasoning_content"] template_str += "\n" - if tokenize: - return [hash(template_str) % 1000 for _ in range(len(template_str.split()))] - else: - return template_str + result = ( + [hash(template_str) % 1000 for _ in 
range(len(template_str.split()))] + if tokenize + else template_str + ) + if return_dict: + return {"input_ids": result} + return result def test_wrap_masked_messages_with_reasoning_content(self): """Test that wrap_masked_messages correctly wraps both content and reasoning_content.""" From fba8ce6b43739115475e6c1bed9de869383ee9cc Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 23 Dec 2025 19:21:08 +0000 Subject: [PATCH 09/18] tests --- src/instructlab/training/config.py | 2 +- tests/unit/test_pretraining_data_process.py | 443 ++++++++++++++++ tests/unit/test_pretraining_sampler.py | 540 ++++++++++++++++++++ 3 files changed, 984 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_pretraining_data_process.py create mode 100644 tests/unit/test_pretraining_sampler.py diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 3f2dd810..17d3a69d 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -291,7 +291,7 @@ class TrainingArgs(BaseModel): keep_last_checkpoint_only: Optional[bool] = False pretraining_config: Optional[PretrainingConfig] = Field( - default=None, + default=None, description=( "Pretraining configuration. When provided, enables block-based sampling " "for raw document pretraining datasets." 
diff --git a/tests/unit/test_pretraining_data_process.py b/tests/unit/test_pretraining_data_process.py new file mode 100644 index 00000000..7981c8ce --- /dev/null +++ b/tests/unit/test_pretraining_data_process.py @@ -0,0 +1,443 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for pretraining data processing functionality.""" + +# Standard +from unittest.mock import MagicMock, patch, mock_open +import json +import os +import tempfile + +# Third Party +import pytest +from transformers import AutoTokenizer + +# First Party +from instructlab.training.data_process import process_documents_for_pretraining + + +class TestProcessDocumentsForPretraining: + """Test suite for process_documents_for_pretraining function.""" + + @pytest.fixture + def mock_tokenizer(self): + """Mock AutoTokenizer with BOS/EOS behavior.""" + mock_tok = MagicMock() + mock_tok.bos_token_id = 1 + mock_tok.eos_token_id = 2 + + # Mock encode to add BOS automatically and generate predictable tokens + def mock_encode(text, add_special_tokens=True): + # Simple hash-based encoding for predictability + tokens = [hash(text) % 1000 + 100] + if add_special_tokens: + return [mock_tok.bos_token_id] + tokens + return tokens + + mock_tok.encode = mock_encode + return mock_tok + + @pytest.fixture + def temp_pretraining_jsonl(self, tmp_path): + """Create temp JSONL with 'documents' field.""" + data_file = tmp_path / "pretraining.jsonl" + samples = [ + {"documents": "This is document one."}, + {"documents": "This is document two with more text."}, + {"documents": "Short doc."} + ] + + with open(data_file, 'w') as f: + for sample in samples: + json.dump(sample, f) + f.write('\n') + + return str(data_file) + + @pytest.fixture + def temp_output_dir(self, tmp_path): + """Create temporary output directory.""" + output_dir = tmp_path / "output" + output_dir.mkdir() + return str(output_dir) + + @patch('instructlab.training.data_process.AutoTokenizer.from_pretrained') + 
@patch('instructlab.training.data_process.load_dataset') + def test_basic_tokenization_with_bos_eos( + self, + mock_load_dataset, + mock_from_pretrained, + mock_tokenizer, + temp_pretraining_jsonl, + temp_output_dir + ): + """Verify basic tokenization adds BOS and EOS tokens correctly.""" + # Setup mocks + mock_from_pretrained.return_value = mock_tokenizer + + # Create mock dataset + mock_ds = MagicMock() + mock_ds.num_rows = 1 + mock_ds.column_names = ['documents'] + + # Mock single document + mock_ds.__iter__ = lambda self: iter([{"documents": "Test document"}]) + mock_ds.map = MagicMock() + + # Make map return a dataset with tokenized data + def map_side_effect(func, **kwargs): + result = func({"documents": "Test document"}) + mapped_ds = MagicMock() + mapped_ds.__getitem__ = lambda self, key: [result[key]] + mapped_ds.to_json = MagicMock() + return mapped_ds + + mock_ds.map.side_effect = map_side_effect + mock_load_dataset.return_value = mock_ds + + # Run function + process_documents_for_pretraining( + data_path=temp_pretraining_jsonl, + data_output_path=temp_output_dir, + model_path="test-model", + num_cpu_procs=1, + ) + + # Verify tokenizer was loaded + mock_from_pretrained.assert_called_once_with("test-model") + + # Verify dataset map was called + assert mock_ds.map.called + + @patch('instructlab.training.data_process.AutoTokenizer.from_pretrained') + @patch('instructlab.training.data_process.load_dataset') + def test_multiple_documents_separate_records( + self, + mock_load_dataset, + mock_from_pretrained, + mock_tokenizer, + temp_output_dir + ): + """Ensure each document gets its own JSONL record.""" + # Setup + mock_from_pretrained.return_value = mock_tokenizer + + # Create mock dataset with 3 documents + mock_ds = MagicMock() + mock_ds.num_rows = 3 + mock_ds.column_names = ['documents'] + + docs = [ + {"documents": "Doc 1"}, + {"documents": "Doc 2"}, + {"documents": "Doc 3"} + ] + + # Mock map to process all documents + def map_side_effect(func, 
**kwargs): + results = [func(doc) for doc in docs] + mapped_ds = MagicMock() + mapped_ds.__len__ = lambda self: len(results) + mapped_ds.__getitem__ = lambda self, key: [r[key] for r in results] + mapped_ds.to_json = MagicMock() + return mapped_ds + + mock_ds.map.side_effect = map_side_effect + mock_load_dataset.return_value = mock_ds + + # Run + process_documents_for_pretraining( + data_path="dummy.jsonl", + data_output_path=temp_output_dir, + model_path="test-model", + num_cpu_procs=1, + ) + + # Verify map was called (which processes each document) + assert mock_ds.map.called + + @patch('instructlab.training.data_process.load_dataset') + def test_empty_dataset_raises_error(self, mock_load_dataset, temp_output_dir): + """Validate error handling for empty input.""" + # Create empty dataset + mock_ds = MagicMock() + mock_ds.num_rows = 0 + mock_load_dataset.return_value = mock_ds + + # Should raise ValueError + with pytest.raises(ValueError, match="empty"): + process_documents_for_pretraining( + data_path="dummy.jsonl", + data_output_path=temp_output_dir, + model_path="test-model", + num_cpu_procs=1, + ) + + @patch('instructlab.training.data_process.load_dataset') + def test_missing_documents_field_raises_error(self, mock_load_dataset, temp_output_dir): + """Validate schema enforcement.""" + # Create dataset with wrong field name + mock_ds = MagicMock() + mock_ds.num_rows = 1 + mock_ds.column_names = ['text'] # Wrong field name + mock_load_dataset.return_value = mock_ds + + # Should raise ValueError + with pytest.raises(ValueError, match="must have 'documents' field"): + process_documents_for_pretraining( + data_path="dummy.jsonl", + data_output_path=temp_output_dir, + model_path="test-model", + num_cpu_procs=1, + ) + + @patch('instructlab.training.data_process.AutoTokenizer.from_pretrained') + @patch('instructlab.training.data_process.load_dataset') + def test_tokenizer_without_eos_raises_error( + self, + mock_load_dataset, + mock_from_pretrained, + temp_output_dir 
+ ): + """Validate tokenizer requirements.""" + # Create valid dataset + mock_ds = MagicMock() + mock_ds.num_rows = 1 + mock_ds.column_names = ['documents'] + mock_load_dataset.return_value = mock_ds + + # Create tokenizer without EOS token + mock_tok = MagicMock(spec=AutoTokenizer) + mock_tok.eos_token_id = None # No EOS token + mock_from_pretrained.return_value = mock_tok + + # Should raise ValueError + with pytest.raises(ValueError, match="must have an EOS token"): + process_documents_for_pretraining( + data_path="dummy.jsonl", + data_output_path=temp_output_dir, + model_path="test-model", + num_cpu_procs=1, + ) + + @patch('instructlab.training.data_process.logger') + @patch('instructlab.training.data_process.AutoTokenizer.from_pretrained') + @patch('instructlab.training.data_process.load_dataset') + def test_statistics_logging( + self, + mock_load_dataset, + mock_from_pretrained, + mock_logger, + mock_tokenizer, + temp_output_dir + ): + """Verify statistics are calculated correctly.""" + # Setup + mock_from_pretrained.return_value = mock_tokenizer + + # Create dataset with known token counts + mock_ds = MagicMock() + mock_ds.num_rows = 2 + mock_ds.column_names = ['documents'] + + # Mock map to return known lengths + def map_side_effect(func, **kwargs): + # Simulate 2 documents with 5 and 10 tokens each + mapped_ds = MagicMock() + mapped_ds.__getitem__ = lambda self, key: [5, 10] if key == 'len' else None + mapped_ds.__len__ = lambda self: 2 + mapped_ds.to_json = MagicMock() + return mapped_ds + + mock_ds.map.side_effect = map_side_effect + mock_load_dataset.return_value = mock_ds + + # Run + process_documents_for_pretraining( + data_path="dummy.jsonl", + data_output_path=temp_output_dir, + model_path="test-model", + num_cpu_procs=1, + ) + + # Verify logging was called (check info was called multiple times) + assert mock_logger.info.call_count >= 3 + + @patch('instructlab.training.data_process.AutoTokenizer.from_pretrained') + 
@patch('instructlab.training.data_process.load_dataset') + def test_parallel_processing( + self, + mock_load_dataset, + mock_from_pretrained, + mock_tokenizer, + temp_output_dir + ): + """Ensure num_cpu_procs parameter works.""" + # Setup + mock_from_pretrained.return_value = mock_tokenizer + + mock_ds = MagicMock() + mock_ds.num_rows = 1 + mock_ds.column_names = ['documents'] + mock_ds.map = MagicMock() + + def map_side_effect(func, **kwargs): + mapped_ds = MagicMock() + mapped_ds.to_json = MagicMock() + return mapped_ds + + mock_ds.map.side_effect = map_side_effect + mock_load_dataset.return_value = mock_ds + + # Run with specific num_cpu_procs + process_documents_for_pretraining( + data_path="dummy.jsonl", + data_output_path=temp_output_dir, + model_path="test-model", + num_cpu_procs=4, + ) + + # Verify map was called with num_proc=4 + call_args = mock_ds.map.call_args + assert call_args[1]['num_proc'] == 4 + + def test_output_directory_creation(self, tmp_path, mock_tokenizer): + """Verify directory is created if it doesn't exist.""" + # Use non-existent output path + output_dir = tmp_path / "nonexistent" / "nested" / "dir" + + with patch('instructlab.training.data_process.AutoTokenizer.from_pretrained') as mock_from_pretrained: + with patch('instructlab.training.data_process.load_dataset') as mock_load_dataset: + mock_from_pretrained.return_value = mock_tokenizer + + mock_ds = MagicMock() + mock_ds.num_rows = 1 + mock_ds.column_names = ['documents'] + + def map_side_effect(func, **kwargs): + mapped_ds = MagicMock() + mapped_ds.to_json = MagicMock() + return mapped_ds + + mock_ds.map.side_effect = map_side_effect + mock_load_dataset.return_value = mock_ds + + # Run + process_documents_for_pretraining( + data_path="dummy.jsonl", + data_output_path=str(output_dir), + model_path="test-model", + num_cpu_procs=1, + ) + + # Verify directory was created + assert output_dir.exists() + + @patch('instructlab.training.data_process.AutoTokenizer.from_pretrained') + 
@patch('instructlab.training.data_process.load_dataset') + def test_output_jsonl_format( + self, + mock_load_dataset, + mock_from_pretrained, + mock_tokenizer, + temp_output_dir + ): + """Validate JSONL output format.""" + # Setup + mock_from_pretrained.return_value = mock_tokenizer + + mock_ds = MagicMock() + mock_ds.num_rows = 1 + mock_ds.column_names = ['documents'] + + # Track what gets written + output_file_path = None + + def map_side_effect(func, **kwargs): + result = func({"documents": "Test"}) + mapped_ds = MagicMock() + + def to_json_side_effect(path, **kw): + nonlocal output_file_path + output_file_path = path + # Write actual JSON to verify format + with open(path, 'w') as f: + json.dump(result, f) + f.write('\n') + + mapped_ds.to_json = to_json_side_effect + return mapped_ds + + mock_ds.map.side_effect = map_side_effect + mock_load_dataset.return_value = mock_ds + + # Run + process_documents_for_pretraining( + data_path="dummy.jsonl", + data_output_path=temp_output_dir, + model_path="test-model", + num_cpu_procs=1, + ) + + # Verify file was created + assert output_file_path is not None + assert os.path.exists(output_file_path) + + # Verify format + with open(output_file_path, 'r') as f: + line = f.readline() + data = json.loads(line) + + # Should have input_ids and len fields + assert 'input_ids' in data + assert 'len' in data + + # Should NOT have labels field + assert 'labels' not in data + + # input_ids should be a list starting with BOS + assert isinstance(data['input_ids'], list) + assert data['input_ids'][0] == 1 # BOS token + assert data['input_ids'][-1] == 2 # EOS token + + @pytest.mark.slow + def test_integration_with_real_tokenizer(self, temp_output_dir): + """Integration test with actual GPT2 tokenizer.""" + # Create real input file + input_file = os.path.join(temp_output_dir, "input.jsonl") + with open(input_file, 'w') as f: + json.dump({"documents": "This is a test document for GPT2 tokenization."}, f) + f.write('\n') + + # Run with real 
tokenizer + process_documents_for_pretraining( + data_path=input_file, + data_output_path=temp_output_dir, + model_path="gpt2", + num_cpu_procs=1, + ) + + # Verify output + output_file = os.path.join(temp_output_dir, "data.jsonl") + assert os.path.exists(output_file) + + with open(output_file, 'r') as f: + line = f.readline() + data = json.loads(line) + + # Verify structure + assert 'input_ids' in data + assert 'len' in data + assert len(data['input_ids']) == data['len'] + + # Load tokenizer to verify tokens + tokenizer = AutoTokenizer.from_pretrained("gpt2") + + # Verify BOS/EOS are present (GPT2 uses same token 50256 for both) + # encode() with add_special_tokens=True adds BOS + # We manually append EOS + assert data['input_ids'][0] == tokenizer.bos_token_id or data['input_ids'][0] == tokenizer.eos_token_id + assert data['input_ids'][-1] == tokenizer.eos_token_id + + # Verify token count is reasonable + assert data['len'] > 5 # Should have more than just BOS/EOS diff --git a/tests/unit/test_pretraining_sampler.py b/tests/unit/test_pretraining_sampler.py new file mode 100644 index 00000000..13cb3382 --- /dev/null +++ b/tests/unit/test_pretraining_sampler.py @@ -0,0 +1,540 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for pretraining sampler functionality.""" + +# Standard +from unittest.mock import MagicMock, patch +import json + +# Third Party +import pytest +import torch + +# First Party +from instructlab.training.sampler import PretrainingBlockDataset, get_data_loader + + +class TestPretrainingBlockDataset: + """Test suite for PretrainingBlockDataset class.""" + + @pytest.fixture + def sample_pretraining_data(self): + """Sample tokenized data (14 total tokens).""" + return [ + {"input_ids": [1, 2, 3, 4, 5], "len": 5}, + {"input_ids": [6, 7, 8, 9, 10, 11], "len": 6}, + {"input_ids": [12, 13, 14], "len": 3}, + ] + + @pytest.fixture + def mock_hf_dataset(self, sample_pretraining_data): + """Mock HuggingFace dataset.""" + mock_ds = MagicMock() + 
mock_ds.column_names = ["input_ids", "len"] + mock_ds.__len__ = lambda self: len(sample_pretraining_data) + mock_ds.__iter__ = lambda self: iter(sample_pretraining_data) + return mock_ds + + def test_dataset_initialization(self, mock_hf_dataset): + """Test basic initialization of PretrainingBlockDataset.""" + dataset = PretrainingBlockDataset( + dataset=mock_hf_dataset, + block_size=5, + pad_token_id=0 + ) + + # Verify basic attributes + assert dataset.block_size == 5 + assert dataset.pad_token_id == 0 + assert dataset.num_blocks == 3 # 14 tokens / 5 = 2 complete + 1 partial + assert dataset.last_block_len == 4 # 14 % 5 = 4 + assert len(dataset.all_input_ids) == 14 # All tokens concatenated + + def test_concatenation_of_documents(self, mock_hf_dataset): + """Verify documents are concatenated in the correct order.""" + dataset = PretrainingBlockDataset( + dataset=mock_hf_dataset, + block_size=5, + pad_token_id=0 + ) + + # Check concatenation order + expected = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] + assert dataset.all_input_ids == expected + + def test_num_blocks_calculation_with_partial(self, mock_hf_dataset): + """Test num_blocks calculation with partial block.""" + dataset = PretrainingBlockDataset( + dataset=mock_hf_dataset, + block_size=5, + pad_token_id=0 + ) + + # 14 tokens / 5 = 2 complete + 1 partial + assert dataset.num_blocks == 3 + assert dataset.last_block_len == 4 + + def test_num_blocks_calculation_exact_multiple(self, sample_pretraining_data): + """Test num_blocks calculation when tokens exactly divide by block_size.""" + # Add one more token to make it 15 (exact multiple of 5) + data = sample_pretraining_data + [{"input_ids": [15], "len": 1}] + + mock_ds = MagicMock() + mock_ds.column_names = ["input_ids", "len"] + mock_ds.__len__ = lambda self: len(data) + mock_ds.__iter__ = lambda self: iter(data) + + dataset = PretrainingBlockDataset( + dataset=mock_ds, + block_size=5, + pad_token_id=0 + ) + + # 15 tokens / 5 = 3 complete blocks + 
assert dataset.num_blocks == 3 + assert dataset.last_block_len == 5 # Last block is complete + + def test_getitem_complete_block(self, mock_hf_dataset): + """Test __getitem__ for a complete block.""" + dataset = PretrainingBlockDataset( + dataset=mock_hf_dataset, + block_size=5, + pad_token_id=0 + ) + + # Get first block (indices 0-4) + block = dataset[0] + + assert block["input_ids"].shape == (5,) + assert block["labels"].shape == (5,) + assert block["len"] == 5 + assert block["num_loss_counted_tokens"] == 4 # block_size - 1 (causal shift) + + # Check actual token values + assert torch.equal(block["input_ids"], torch.tensor([1, 2, 3, 4, 5], dtype=torch.long)) + assert torch.equal(block["labels"], torch.tensor([1, 2, 3, 4, 5], dtype=torch.long)) + + def test_getitem_partial_block_with_padding(self, mock_hf_dataset): + """Test __getitem__ for partial last block with padding.""" + dataset = PretrainingBlockDataset( + dataset=mock_hf_dataset, + block_size=5, + pad_token_id=0 + ) + + # Get last block (index 2) - should have 4 real tokens + 1 padding + block = dataset[2] + + assert block["input_ids"].shape == (5,) + assert block["labels"].shape == (5,) + assert block["len"] == 5 + + # 4 real tokens - 1 for causal shift = 3 + assert block["num_loss_counted_tokens"] == 3 + + # Last token should be pad_token_id (0) + assert block["input_ids"][-1].item() == 0 + + # Last label should be masked (-100) + assert block["labels"][-1].item() == -100 + + # First 4 tokens should be real data [11, 12, 13, 14] from position 10-13 + assert block["input_ids"][0].item() == 11 + assert block["input_ids"][1].item() == 12 + assert block["input_ids"][2].item() == 13 + assert block["input_ids"][3].item() == 14 + + def test_labels_are_copy_not_reference(self, mock_hf_dataset): + """Test that labels are a copy, not a reference to input_ids.""" + dataset = PretrainingBlockDataset( + dataset=mock_hf_dataset, + block_size=5, + pad_token_id=0 + ) + + block = dataset[0] + + # Tensors should not be 
the same object + assert block["input_ids"] is not block["labels"] + + # But values should be equal for complete blocks + assert torch.equal(block["input_ids"], block["labels"]) + + # Modify labels to verify they're independent + original_labels = block["labels"].clone() + block["labels"][0] = 999 + + # input_ids should remain unchanged + assert block["input_ids"][0].item() != 999 + assert block["input_ids"][0].item() == 1 + + def test_num_loss_counted_tokens_complete_block(self): + """Test num_loss_counted_tokens for complete blocks with various block sizes.""" + for block_size in [5, 10, 20]: + # Create data with at least 2 complete blocks + num_tokens = block_size * 2 + data = [{"input_ids": list(range(num_tokens)), "len": num_tokens}] + + mock_ds = MagicMock() + mock_ds.column_names = ["input_ids", "len"] + mock_ds.__len__ = lambda self: len(data) + mock_ds.__iter__ = lambda self: iter(data) + + dataset = PretrainingBlockDataset( + dataset=mock_ds, + block_size=block_size, + pad_token_id=0 + ) + + # Check first complete block + block = dataset[0] + assert block["num_loss_counted_tokens"] == block_size - 1 + + def test_num_loss_counted_tokens_partial_block(self, mock_hf_dataset): + """Test num_loss_counted_tokens for partial blocks.""" + dataset = PretrainingBlockDataset( + dataset=mock_hf_dataset, + block_size=5, + pad_token_id=0 + ) + + # Last block has 4 real tokens + block = dataset[2] + + # Should be actual_length - 1 = 4 - 1 = 3 + assert block["num_loss_counted_tokens"] == 3 + + def test_index_out_of_range(self, mock_hf_dataset): + """Test that accessing beyond num_blocks raises IndexError.""" + dataset = PretrainingBlockDataset( + dataset=mock_hf_dataset, + block_size=5, + pad_token_id=0 + ) + + # Try to access block beyond num_blocks (which is 3) + with pytest.raises(IndexError) as exc_info: + _ = dataset[3] + + assert "out of range" in str(exc_info.value).lower() + + def test_missing_input_ids_field_raises_error(self): + """Test that missing input_ids 
field raises ValueError.""" + # Create dataset without input_ids field + mock_ds = MagicMock() + mock_ds.column_names = ["len"] # Missing input_ids + + with pytest.raises(ValueError) as exc_info: + _ = PretrainingBlockDataset( + dataset=mock_ds, + block_size=5, + pad_token_id=0 + ) + + assert "input_ids" in str(exc_info.value) + + def test_tensor_dtype_correct(self, mock_hf_dataset): + """Test that all tensors use torch.long dtype.""" + dataset = PretrainingBlockDataset( + dataset=mock_hf_dataset, + block_size=5, + pad_token_id=0 + ) + + block = dataset[0] + + assert block["input_ids"].dtype == torch.long + assert block["labels"].dtype == torch.long + + +class TestGetDataLoaderPretraining: + """Test suite for get_data_loader with pretraining mode.""" + + @pytest.fixture + def temp_pretraining_file(self, tmp_path): + """Create temp JSONL with pretraining data.""" + data_file = tmp_path / "pretraining_data.jsonl" + samples = [ + {"input_ids": list(range(100, 150)), "len": 50}, + {"input_ids": list(range(200, 280)), "len": 80}, + {"input_ids": list(range(300, 370)), "len": 70}, + ] + + with open(data_file, 'w') as f: + for sample in samples: + json.dump(sample, f) + f.write('\n') + + return str(data_file) + + @patch('instructlab.training.sampler.load_dataset') + def test_pretraining_mode_creates_block_dataset( + self, + mock_load_dataset, + temp_pretraining_file + ): + """Test that is_pretraining=True creates PretrainingBlockDataset.""" + # Create mock dataset + mock_ds = MagicMock() + mock_ds.column_names = ["input_ids", "len"] + mock_ds.__len__ = lambda self: 3 + mock_ds.__iter__ = lambda self: iter([ + {"input_ids": [1, 2, 3], "len": 3}, + {"input_ids": [4, 5, 6], "len": 3}, + {"input_ids": [7, 8, 9], "len": 3}, + ]) + mock_load_dataset.return_value = mock_ds + + # Call with pretraining mode + loader = get_data_loader( + data_path=temp_pretraining_file, + batch_size=2, + max_tokens_per_gpu=100, + seed=42, + rank=0, + world_size=1, + is_pretraining=True, + 
block_size=128 + ) + + # Verify load_dataset was called + mock_load_dataset.assert_called_once() + + # Verify dataset is PretrainingBlockDataset + assert isinstance(loader.dataset, PretrainingBlockDataset) + + def test_instruction_tuning_mode_creates_token_dataset(self, temp_pretraining_file): + """Test that is_pretraining=False uses TokenDataset.""" + # Create a valid instruction tuning JSONL file + from pathlib import Path + inst_file = Path(temp_pretraining_file).parent / "inst_data.jsonl" + samples = [ + {"input_ids": [1, 2, 3], "labels": [1, 2, 3], "len": 3}, + {"input_ids": [4, 5, 6], "labels": [4, 5, 6], "len": 3}, + ] + with open(inst_file, 'w') as f: + for sample in samples: + json.dump(sample, f) + f.write('\n') + + # Call with instruction tuning mode (default) + loader = get_data_loader( + data_path=str(inst_file), + batch_size=2, + max_tokens_per_gpu=100, + seed=42, + rank=0, + world_size=1, + is_pretraining=False + ) + + # Verify dataset is TokenDataset (not PretrainingBlockDataset) + from instructlab.training.sampler import TokenDataset + assert isinstance(loader.dataset, TokenDataset) + assert not isinstance(loader.dataset, PretrainingBlockDataset) + + @patch('instructlab.training.sampler.load_dataset') + def test_pretraining_block_size_parameter( + self, + mock_load_dataset, + temp_pretraining_file + ): + """Test that block_size parameter is correctly passed.""" + # Create mock dataset + mock_ds = MagicMock() + mock_ds.column_names = ["input_ids", "len"] + mock_ds.__len__ = lambda self: 1 + mock_ds.__iter__ = lambda self: iter([{"input_ids": list(range(100)), "len": 100}]) + mock_load_dataset.return_value = mock_ds + + # Call with specific block_size + block_size = 256 + loader = get_data_loader( + data_path=temp_pretraining_file, + batch_size=2, + max_tokens_per_gpu=1000, + seed=42, + rank=0, + world_size=1, + is_pretraining=True, + block_size=block_size + ) + + # Verify dataset has correct block_size + assert loader.dataset.block_size == 
block_size + + @patch('instructlab.training.sampler.load_dataset') + def test_pretraining_pad_token_id_used( + self, + mock_load_dataset, + temp_pretraining_file + ): + """Test that pad_token_id is correctly passed to PretrainingBlockDataset.""" + # Create mock dataset + mock_ds = MagicMock() + mock_ds.column_names = ["input_ids", "len"] + mock_ds.__len__ = lambda self: 1 + mock_ds.__iter__ = lambda self: iter([{"input_ids": list(range(10)), "len": 10}]) + mock_load_dataset.return_value = mock_ds + + # Call with specific pad_token_id + pad_token_id = 99 + loader = get_data_loader( + data_path=temp_pretraining_file, + batch_size=2, + max_tokens_per_gpu=100, + seed=42, + rank=0, + world_size=1, + is_pretraining=True, + block_size=7, # Will create partial block + pad_token_id=pad_token_id + ) + + # Verify dataset has correct pad_token_id + assert loader.dataset.pad_token_id == pad_token_id + + @patch('instructlab.training.sampler.load_dataset') + def test_data_loader_returns_correct_structure( + self, + mock_load_dataset, + temp_pretraining_file + ): + """Test that get_data_loader returns a properly configured DataLoader.""" + # Create mock dataset + mock_ds = MagicMock() + mock_ds.column_names = ["input_ids", "len"] + mock_ds.__len__ = lambda self: 2 + mock_ds.__iter__ = lambda self: iter([ + {"input_ids": list(range(50)), "len": 50}, + {"input_ids": list(range(50, 100)), "len": 50}, + ]) + mock_load_dataset.return_value = mock_ds + + # Call get_data_loader + loader = get_data_loader( + data_path=temp_pretraining_file, + batch_size=2, + max_tokens_per_gpu=100, + seed=42, + rank=0, + world_size=1, + is_pretraining=True, + block_size=25 + ) + + # Verify it's a DataLoader + from torch.utils.data import DataLoader + assert isinstance(loader, DataLoader) + + # Verify batch_size + assert loader.batch_size == 2 + + @patch('instructlab.training.sampler.load_dataset') + def test_epoch_sampler_created( + self, + mock_load_dataset, + temp_pretraining_file + ): + """Test that 
EpochSampler is created with correct parameters.""" + # Create mock dataset with known length + mock_ds = MagicMock() + mock_ds.column_names = ["input_ids", "len"] + mock_ds.__len__ = lambda self: 1 + mock_ds.__iter__ = lambda self: iter([{"input_ids": list(range(100)), "len": 100}]) + mock_load_dataset.return_value = mock_ds + + seed = 123 + block_size = 25 + + loader = get_data_loader( + data_path=temp_pretraining_file, + batch_size=2, + max_tokens_per_gpu=100, + seed=seed, + rank=0, + world_size=1, + is_pretraining=True, + block_size=block_size + ) + + # Verify sampler is EpochSampler + from instructlab.training.sampler import EpochSampler + assert isinstance(loader.sampler, EpochSampler) + + # Verify seed is set correctly + assert loader.sampler.seed == seed + + @patch('instructlab.training.sampler.load_dataset') + def test_collator_configuration( + self, + mock_load_dataset, + temp_pretraining_file + ): + """Test that MaxTokensPerRankCollator is configured correctly.""" + # Create mock dataset + mock_ds = MagicMock() + mock_ds.column_names = ["input_ids", "len"] + mock_ds.__len__ = lambda self: 1 + mock_ds.__iter__ = lambda self: iter([{"input_ids": list(range(50)), "len": 50}]) + mock_load_dataset.return_value = mock_ds + + flash_enabled = False + pad_token_id = 42 + max_tokens = 200 + + loader = get_data_loader( + data_path=temp_pretraining_file, + batch_size=2, + max_tokens_per_gpu=max_tokens, + seed=42, + rank=0, + world_size=1, + is_pretraining=True, + block_size=25, + flash_enabled=flash_enabled, + pad_token_id=pad_token_id + ) + + # Verify collate_fn is MaxTokensPerRankCollator + from instructlab.training.sampler import MaxTokensPerRankCollator + assert isinstance(loader.collate_fn, MaxTokensPerRankCollator) + + # Verify collator configuration + assert loader.collate_fn.max_tokens_per_rank == max_tokens + assert loader.collate_fn.flash_enabled == flash_enabled + assert loader.collate_fn.pad_token_id == pad_token_id + + 
@patch('instructlab.training.sampler.load_dataset') + def test_num_workers_parameter( + self, + mock_load_dataset, + temp_pretraining_file + ): + """Test that num_workers parameter is correctly applied.""" + # Create mock dataset + mock_ds = MagicMock() + mock_ds.column_names = ["input_ids", "len"] + mock_ds.__len__ = lambda self: 1 + mock_ds.__iter__ = lambda self: iter([{"input_ids": list(range(50)), "len": 50}]) + mock_load_dataset.return_value = mock_ds + + num_workers = 4 + + loader = get_data_loader( + data_path=temp_pretraining_file, + batch_size=2, + max_tokens_per_gpu=100, + seed=42, + rank=0, + world_size=1, + is_pretraining=True, + block_size=25, + num_workers=num_workers + ) + + # Verify num_workers is set + assert loader.num_workers == num_workers + + # When num_workers > 0, persistent_workers should be True + assert loader.persistent_workers == True From f57251b979a1225c344aaf1e505c589a7fc68f7c Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 23 Dec 2025 21:21:40 +0000 Subject: [PATCH 10/18] linting --- src/instructlab/training/main_ds.py | 5 +- tests/unit/test_pretraining_data_process.py | 148 ++++++++-------- tests/unit/test_pretraining_sampler.py | 184 +++++++++----------- 3 files changed, 155 insertions(+), 182 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 4d44af99..decda926 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -6,7 +6,6 @@ import logging import os import subprocess -import sys import time import warnings @@ -458,7 +457,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: # Enable package logging propagation before setting up loggers propagate_package_logs(True) setup_root_logger(train_args.log_level) - setup_metric_logger("async", None, train_args.ckpt_output_dir) + setup_metric_logger( + train_args.logger_type, train_args.run_name, 
train_args.ckpt_output_dir + ) logger = logging.getLogger("instructlab.training") logger.info("Starting training setup...") diff --git a/tests/unit/test_pretraining_data_process.py b/tests/unit/test_pretraining_data_process.py index 7981c8ce..ab1537e5 100644 --- a/tests/unit/test_pretraining_data_process.py +++ b/tests/unit/test_pretraining_data_process.py @@ -3,14 +3,14 @@ """Unit tests for pretraining data processing functionality.""" # Standard -from unittest.mock import MagicMock, patch, mock_open +from unittest.mock import MagicMock, mock_open, patch import json import os import tempfile # Third Party -import pytest from transformers import AutoTokenizer +import pytest # First Party from instructlab.training.data_process import process_documents_for_pretraining @@ -44,13 +44,13 @@ def temp_pretraining_jsonl(self, tmp_path): samples = [ {"documents": "This is document one."}, {"documents": "This is document two with more text."}, - {"documents": "Short doc."} + {"documents": "Short doc."}, ] - with open(data_file, 'w') as f: + with open(data_file, "w") as f: for sample in samples: json.dump(sample, f) - f.write('\n') + f.write("\n") return str(data_file) @@ -61,15 +61,15 @@ def temp_output_dir(self, tmp_path): output_dir.mkdir() return str(output_dir) - @patch('instructlab.training.data_process.AutoTokenizer.from_pretrained') - @patch('instructlab.training.data_process.load_dataset') + @patch("instructlab.training.data_process.AutoTokenizer.from_pretrained") + @patch("instructlab.training.data_process.load_dataset") def test_basic_tokenization_with_bos_eos( self, mock_load_dataset, mock_from_pretrained, mock_tokenizer, temp_pretraining_jsonl, - temp_output_dir + temp_output_dir, ): """Verify basic tokenization adds BOS and EOS tokens correctly.""" # Setup mocks @@ -78,7 +78,7 @@ def test_basic_tokenization_with_bos_eos( # Create mock dataset mock_ds = MagicMock() mock_ds.num_rows = 1 - mock_ds.column_names = ['documents'] + mock_ds.column_names = ["documents"] 
# Mock single document mock_ds.__iter__ = lambda self: iter([{"documents": "Test document"}]) @@ -109,14 +109,10 @@ def map_side_effect(func, **kwargs): # Verify dataset map was called assert mock_ds.map.called - @patch('instructlab.training.data_process.AutoTokenizer.from_pretrained') - @patch('instructlab.training.data_process.load_dataset') + @patch("instructlab.training.data_process.AutoTokenizer.from_pretrained") + @patch("instructlab.training.data_process.load_dataset") def test_multiple_documents_separate_records( - self, - mock_load_dataset, - mock_from_pretrained, - mock_tokenizer, - temp_output_dir + self, mock_load_dataset, mock_from_pretrained, mock_tokenizer, temp_output_dir ): """Ensure each document gets its own JSONL record.""" # Setup @@ -125,13 +121,9 @@ def test_multiple_documents_separate_records( # Create mock dataset with 3 documents mock_ds = MagicMock() mock_ds.num_rows = 3 - mock_ds.column_names = ['documents'] + mock_ds.column_names = ["documents"] - docs = [ - {"documents": "Doc 1"}, - {"documents": "Doc 2"}, - {"documents": "Doc 3"} - ] + docs = [{"documents": "Doc 1"}, {"documents": "Doc 2"}, {"documents": "Doc 3"}] # Mock map to process all documents def map_side_effect(func, **kwargs): @@ -156,7 +148,7 @@ def map_side_effect(func, **kwargs): # Verify map was called (which processes each document) assert mock_ds.map.called - @patch('instructlab.training.data_process.load_dataset') + @patch("instructlab.training.data_process.load_dataset") def test_empty_dataset_raises_error(self, mock_load_dataset, temp_output_dir): """Validate error handling for empty input.""" # Create empty dataset @@ -173,13 +165,15 @@ def test_empty_dataset_raises_error(self, mock_load_dataset, temp_output_dir): num_cpu_procs=1, ) - @patch('instructlab.training.data_process.load_dataset') - def test_missing_documents_field_raises_error(self, mock_load_dataset, temp_output_dir): + @patch("instructlab.training.data_process.load_dataset") + def 
test_missing_documents_field_raises_error( + self, mock_load_dataset, temp_output_dir + ): """Validate schema enforcement.""" # Create dataset with wrong field name mock_ds = MagicMock() mock_ds.num_rows = 1 - mock_ds.column_names = ['text'] # Wrong field name + mock_ds.column_names = ["text"] # Wrong field name mock_load_dataset.return_value = mock_ds # Should raise ValueError @@ -191,19 +185,16 @@ def test_missing_documents_field_raises_error(self, mock_load_dataset, temp_outp num_cpu_procs=1, ) - @patch('instructlab.training.data_process.AutoTokenizer.from_pretrained') - @patch('instructlab.training.data_process.load_dataset') + @patch("instructlab.training.data_process.AutoTokenizer.from_pretrained") + @patch("instructlab.training.data_process.load_dataset") def test_tokenizer_without_eos_raises_error( - self, - mock_load_dataset, - mock_from_pretrained, - temp_output_dir + self, mock_load_dataset, mock_from_pretrained, temp_output_dir ): """Validate tokenizer requirements.""" # Create valid dataset mock_ds = MagicMock() mock_ds.num_rows = 1 - mock_ds.column_names = ['documents'] + mock_ds.column_names = ["documents"] mock_load_dataset.return_value = mock_ds # Create tokenizer without EOS token @@ -220,16 +211,16 @@ def test_tokenizer_without_eos_raises_error( num_cpu_procs=1, ) - @patch('instructlab.training.data_process.logger') - @patch('instructlab.training.data_process.AutoTokenizer.from_pretrained') - @patch('instructlab.training.data_process.load_dataset') + @patch("instructlab.training.data_process.logger") + @patch("instructlab.training.data_process.AutoTokenizer.from_pretrained") + @patch("instructlab.training.data_process.load_dataset") def test_statistics_logging( self, mock_load_dataset, mock_from_pretrained, mock_logger, mock_tokenizer, - temp_output_dir + temp_output_dir, ): """Verify statistics are calculated correctly.""" # Setup @@ -238,13 +229,13 @@ def test_statistics_logging( # Create dataset with known token counts mock_ds = MagicMock() 
mock_ds.num_rows = 2 - mock_ds.column_names = ['documents'] + mock_ds.column_names = ["documents"] # Mock map to return known lengths def map_side_effect(func, **kwargs): # Simulate 2 documents with 5 and 10 tokens each mapped_ds = MagicMock() - mapped_ds.__getitem__ = lambda self, key: [5, 10] if key == 'len' else None + mapped_ds.__getitem__ = lambda self, key: [5, 10] if key == "len" else None mapped_ds.__len__ = lambda self: 2 mapped_ds.to_json = MagicMock() return mapped_ds @@ -263,14 +254,10 @@ def map_side_effect(func, **kwargs): # Verify logging was called (check info was called multiple times) assert mock_logger.info.call_count >= 3 - @patch('instructlab.training.data_process.AutoTokenizer.from_pretrained') - @patch('instructlab.training.data_process.load_dataset') + @patch("instructlab.training.data_process.AutoTokenizer.from_pretrained") + @patch("instructlab.training.data_process.load_dataset") def test_parallel_processing( - self, - mock_load_dataset, - mock_from_pretrained, - mock_tokenizer, - temp_output_dir + self, mock_load_dataset, mock_from_pretrained, mock_tokenizer, temp_output_dir ): """Ensure num_cpu_procs parameter works.""" # Setup @@ -278,7 +265,7 @@ def test_parallel_processing( mock_ds = MagicMock() mock_ds.num_rows = 1 - mock_ds.column_names = ['documents'] + mock_ds.column_names = ["documents"] mock_ds.map = MagicMock() def map_side_effect(func, **kwargs): @@ -299,20 +286,24 @@ def map_side_effect(func, **kwargs): # Verify map was called with num_proc=4 call_args = mock_ds.map.call_args - assert call_args[1]['num_proc'] == 4 + assert call_args[1]["num_proc"] == 4 def test_output_directory_creation(self, tmp_path, mock_tokenizer): """Verify directory is created if it doesn't exist.""" # Use non-existent output path output_dir = tmp_path / "nonexistent" / "nested" / "dir" - with patch('instructlab.training.data_process.AutoTokenizer.from_pretrained') as mock_from_pretrained: - with patch('instructlab.training.data_process.load_dataset') 
as mock_load_dataset: + with patch( + "instructlab.training.data_process.AutoTokenizer.from_pretrained" + ) as mock_from_pretrained: + with patch( + "instructlab.training.data_process.load_dataset" + ) as mock_load_dataset: mock_from_pretrained.return_value = mock_tokenizer mock_ds = MagicMock() mock_ds.num_rows = 1 - mock_ds.column_names = ['documents'] + mock_ds.column_names = ["documents"] def map_side_effect(func, **kwargs): mapped_ds = MagicMock() @@ -333,14 +324,10 @@ def map_side_effect(func, **kwargs): # Verify directory was created assert output_dir.exists() - @patch('instructlab.training.data_process.AutoTokenizer.from_pretrained') - @patch('instructlab.training.data_process.load_dataset') + @patch("instructlab.training.data_process.AutoTokenizer.from_pretrained") + @patch("instructlab.training.data_process.load_dataset") def test_output_jsonl_format( - self, - mock_load_dataset, - mock_from_pretrained, - mock_tokenizer, - temp_output_dir + self, mock_load_dataset, mock_from_pretrained, mock_tokenizer, temp_output_dir ): """Validate JSONL output format.""" # Setup @@ -348,7 +335,7 @@ def test_output_jsonl_format( mock_ds = MagicMock() mock_ds.num_rows = 1 - mock_ds.column_names = ['documents'] + mock_ds.column_names = ["documents"] # Track what gets written output_file_path = None @@ -361,9 +348,9 @@ def to_json_side_effect(path, **kw): nonlocal output_file_path output_file_path = path # Write actual JSON to verify format - with open(path, 'w') as f: + with open(path, "w") as f: json.dump(result, f) - f.write('\n') + f.write("\n") mapped_ds.to_json = to_json_side_effect return mapped_ds @@ -384,30 +371,32 @@ def to_json_side_effect(path, **kw): assert os.path.exists(output_file_path) # Verify format - with open(output_file_path, 'r') as f: + with open(output_file_path, "r") as f: line = f.readline() data = json.loads(line) # Should have input_ids and len fields - assert 'input_ids' in data - assert 'len' in data + assert "input_ids" in data + assert "len" 
in data # Should NOT have labels field - assert 'labels' not in data + assert "labels" not in data # input_ids should be a list starting with BOS - assert isinstance(data['input_ids'], list) - assert data['input_ids'][0] == 1 # BOS token - assert data['input_ids'][-1] == 2 # EOS token + assert isinstance(data["input_ids"], list) + assert data["input_ids"][0] == 1 # BOS token + assert data["input_ids"][-1] == 2 # EOS token @pytest.mark.slow def test_integration_with_real_tokenizer(self, temp_output_dir): """Integration test with actual GPT2 tokenizer.""" # Create real input file input_file = os.path.join(temp_output_dir, "input.jsonl") - with open(input_file, 'w') as f: - json.dump({"documents": "This is a test document for GPT2 tokenization."}, f) - f.write('\n') + with open(input_file, "w") as f: + json.dump( + {"documents": "This is a test document for GPT2 tokenization."}, f + ) + f.write("\n") # Run with real tokenizer process_documents_for_pretraining( @@ -421,14 +410,14 @@ def test_integration_with_real_tokenizer(self, temp_output_dir): output_file = os.path.join(temp_output_dir, "data.jsonl") assert os.path.exists(output_file) - with open(output_file, 'r') as f: + with open(output_file, "r") as f: line = f.readline() data = json.loads(line) # Verify structure - assert 'input_ids' in data - assert 'len' in data - assert len(data['input_ids']) == data['len'] + assert "input_ids" in data + assert "len" in data + assert len(data["input_ids"]) == data["len"] # Load tokenizer to verify tokens tokenizer = AutoTokenizer.from_pretrained("gpt2") @@ -436,8 +425,11 @@ def test_integration_with_real_tokenizer(self, temp_output_dir): # Verify BOS/EOS are present (GPT2 uses same token 50256 for both) # encode() with add_special_tokens=True adds BOS # We manually append EOS - assert data['input_ids'][0] == tokenizer.bos_token_id or data['input_ids'][0] == tokenizer.eos_token_id - assert data['input_ids'][-1] == tokenizer.eos_token_id + assert ( + data["input_ids"][0] == 
tokenizer.bos_token_id + or data["input_ids"][0] == tokenizer.eos_token_id + ) + assert data["input_ids"][-1] == tokenizer.eos_token_id # Verify token count is reasonable - assert data['len'] > 5 # Should have more than just BOS/EOS + assert data["len"] > 5 # Should have more than just BOS/EOS diff --git a/tests/unit/test_pretraining_sampler.py b/tests/unit/test_pretraining_sampler.py index 13cb3382..62a4d203 100644 --- a/tests/unit/test_pretraining_sampler.py +++ b/tests/unit/test_pretraining_sampler.py @@ -38,9 +38,7 @@ def mock_hf_dataset(self, sample_pretraining_data): def test_dataset_initialization(self, mock_hf_dataset): """Test basic initialization of PretrainingBlockDataset.""" dataset = PretrainingBlockDataset( - dataset=mock_hf_dataset, - block_size=5, - pad_token_id=0 + dataset=mock_hf_dataset, block_size=5, pad_token_id=0 ) # Verify basic attributes @@ -53,9 +51,7 @@ def test_dataset_initialization(self, mock_hf_dataset): def test_concatenation_of_documents(self, mock_hf_dataset): """Verify documents are concatenated in the correct order.""" dataset = PretrainingBlockDataset( - dataset=mock_hf_dataset, - block_size=5, - pad_token_id=0 + dataset=mock_hf_dataset, block_size=5, pad_token_id=0 ) # Check concatenation order @@ -65,9 +61,7 @@ def test_concatenation_of_documents(self, mock_hf_dataset): def test_num_blocks_calculation_with_partial(self, mock_hf_dataset): """Test num_blocks calculation with partial block.""" dataset = PretrainingBlockDataset( - dataset=mock_hf_dataset, - block_size=5, - pad_token_id=0 + dataset=mock_hf_dataset, block_size=5, pad_token_id=0 ) # 14 tokens / 5 = 2 complete + 1 partial @@ -84,11 +78,7 @@ def test_num_blocks_calculation_exact_multiple(self, sample_pretraining_data): mock_ds.__len__ = lambda self: len(data) mock_ds.__iter__ = lambda self: iter(data) - dataset = PretrainingBlockDataset( - dataset=mock_ds, - block_size=5, - pad_token_id=0 - ) + dataset = PretrainingBlockDataset(dataset=mock_ds, block_size=5, 
pad_token_id=0) # 15 tokens / 5 = 3 complete blocks assert dataset.num_blocks == 3 @@ -97,9 +87,7 @@ def test_num_blocks_calculation_exact_multiple(self, sample_pretraining_data): def test_getitem_complete_block(self, mock_hf_dataset): """Test __getitem__ for a complete block.""" dataset = PretrainingBlockDataset( - dataset=mock_hf_dataset, - block_size=5, - pad_token_id=0 + dataset=mock_hf_dataset, block_size=5, pad_token_id=0 ) # Get first block (indices 0-4) @@ -111,15 +99,17 @@ def test_getitem_complete_block(self, mock_hf_dataset): assert block["num_loss_counted_tokens"] == 4 # block_size - 1 (causal shift) # Check actual token values - assert torch.equal(block["input_ids"], torch.tensor([1, 2, 3, 4, 5], dtype=torch.long)) - assert torch.equal(block["labels"], torch.tensor([1, 2, 3, 4, 5], dtype=torch.long)) + assert torch.equal( + block["input_ids"], torch.tensor([1, 2, 3, 4, 5], dtype=torch.long) + ) + assert torch.equal( + block["labels"], torch.tensor([1, 2, 3, 4, 5], dtype=torch.long) + ) def test_getitem_partial_block_with_padding(self, mock_hf_dataset): """Test __getitem__ for partial last block with padding.""" dataset = PretrainingBlockDataset( - dataset=mock_hf_dataset, - block_size=5, - pad_token_id=0 + dataset=mock_hf_dataset, block_size=5, pad_token_id=0 ) # Get last block (index 2) - should have 4 real tokens + 1 padding @@ -147,9 +137,7 @@ def test_getitem_partial_block_with_padding(self, mock_hf_dataset): def test_labels_are_copy_not_reference(self, mock_hf_dataset): """Test that labels are a copy, not a reference to input_ids.""" dataset = PretrainingBlockDataset( - dataset=mock_hf_dataset, - block_size=5, - pad_token_id=0 + dataset=mock_hf_dataset, block_size=5, pad_token_id=0 ) block = dataset[0] @@ -181,9 +169,7 @@ def test_num_loss_counted_tokens_complete_block(self): mock_ds.__iter__ = lambda self: iter(data) dataset = PretrainingBlockDataset( - dataset=mock_ds, - block_size=block_size, - pad_token_id=0 + dataset=mock_ds, 
block_size=block_size, pad_token_id=0 ) # Check first complete block @@ -193,9 +179,7 @@ def test_num_loss_counted_tokens_complete_block(self): def test_num_loss_counted_tokens_partial_block(self, mock_hf_dataset): """Test num_loss_counted_tokens for partial blocks.""" dataset = PretrainingBlockDataset( - dataset=mock_hf_dataset, - block_size=5, - pad_token_id=0 + dataset=mock_hf_dataset, block_size=5, pad_token_id=0 ) # Last block has 4 real tokens @@ -207,9 +191,7 @@ def test_num_loss_counted_tokens_partial_block(self, mock_hf_dataset): def test_index_out_of_range(self, mock_hf_dataset): """Test that accessing beyond num_blocks raises IndexError.""" dataset = PretrainingBlockDataset( - dataset=mock_hf_dataset, - block_size=5, - pad_token_id=0 + dataset=mock_hf_dataset, block_size=5, pad_token_id=0 ) # Try to access block beyond num_blocks (which is 3) @@ -225,20 +207,14 @@ def test_missing_input_ids_field_raises_error(self): mock_ds.column_names = ["len"] # Missing input_ids with pytest.raises(ValueError) as exc_info: - _ = PretrainingBlockDataset( - dataset=mock_ds, - block_size=5, - pad_token_id=0 - ) + _ = PretrainingBlockDataset(dataset=mock_ds, block_size=5, pad_token_id=0) assert "input_ids" in str(exc_info.value) def test_tensor_dtype_correct(self, mock_hf_dataset): """Test that all tensors use torch.long dtype.""" dataset = PretrainingBlockDataset( - dataset=mock_hf_dataset, - block_size=5, - pad_token_id=0 + dataset=mock_hf_dataset, block_size=5, pad_token_id=0 ) block = dataset[0] @@ -260,29 +236,29 @@ def temp_pretraining_file(self, tmp_path): {"input_ids": list(range(300, 370)), "len": 70}, ] - with open(data_file, 'w') as f: + with open(data_file, "w") as f: for sample in samples: json.dump(sample, f) - f.write('\n') + f.write("\n") return str(data_file) - @patch('instructlab.training.sampler.load_dataset') + @patch("instructlab.training.sampler.load_dataset") def test_pretraining_mode_creates_block_dataset( - self, - mock_load_dataset, - 
temp_pretraining_file + self, mock_load_dataset, temp_pretraining_file ): """Test that is_pretraining=True creates PretrainingBlockDataset.""" # Create mock dataset mock_ds = MagicMock() mock_ds.column_names = ["input_ids", "len"] mock_ds.__len__ = lambda self: 3 - mock_ds.__iter__ = lambda self: iter([ - {"input_ids": [1, 2, 3], "len": 3}, - {"input_ids": [4, 5, 6], "len": 3}, - {"input_ids": [7, 8, 9], "len": 3}, - ]) + mock_ds.__iter__ = lambda self: iter( + [ + {"input_ids": [1, 2, 3], "len": 3}, + {"input_ids": [4, 5, 6], "len": 3}, + {"input_ids": [7, 8, 9], "len": 3}, + ] + ) mock_load_dataset.return_value = mock_ds # Call with pretraining mode @@ -294,7 +270,7 @@ def test_pretraining_mode_creates_block_dataset( rank=0, world_size=1, is_pretraining=True, - block_size=128 + block_size=128, ) # Verify load_dataset was called @@ -306,16 +282,18 @@ def test_pretraining_mode_creates_block_dataset( def test_instruction_tuning_mode_creates_token_dataset(self, temp_pretraining_file): """Test that is_pretraining=False uses TokenDataset.""" # Create a valid instruction tuning JSONL file + # Standard from pathlib import Path + inst_file = Path(temp_pretraining_file).parent / "inst_data.jsonl" samples = [ {"input_ids": [1, 2, 3], "labels": [1, 2, 3], "len": 3}, {"input_ids": [4, 5, 6], "labels": [4, 5, 6], "len": 3}, ] - with open(inst_file, 'w') as f: + with open(inst_file, "w") as f: for sample in samples: json.dump(sample, f) - f.write('\n') + f.write("\n") # Call with instruction tuning mode (default) loader = get_data_loader( @@ -325,26 +303,28 @@ def test_instruction_tuning_mode_creates_token_dataset(self, temp_pretraining_fi seed=42, rank=0, world_size=1, - is_pretraining=False + is_pretraining=False, ) # Verify dataset is TokenDataset (not PretrainingBlockDataset) + # First Party from instructlab.training.sampler import TokenDataset + assert isinstance(loader.dataset, TokenDataset) assert not isinstance(loader.dataset, PretrainingBlockDataset) - 
@patch('instructlab.training.sampler.load_dataset') + @patch("instructlab.training.sampler.load_dataset") def test_pretraining_block_size_parameter( - self, - mock_load_dataset, - temp_pretraining_file + self, mock_load_dataset, temp_pretraining_file ): """Test that block_size parameter is correctly passed.""" # Create mock dataset mock_ds = MagicMock() mock_ds.column_names = ["input_ids", "len"] mock_ds.__len__ = lambda self: 1 - mock_ds.__iter__ = lambda self: iter([{"input_ids": list(range(100)), "len": 100}]) + mock_ds.__iter__ = lambda self: iter( + [{"input_ids": list(range(100)), "len": 100}] + ) mock_load_dataset.return_value = mock_ds # Call with specific block_size @@ -357,24 +337,24 @@ def test_pretraining_block_size_parameter( rank=0, world_size=1, is_pretraining=True, - block_size=block_size + block_size=block_size, ) # Verify dataset has correct block_size assert loader.dataset.block_size == block_size - @patch('instructlab.training.sampler.load_dataset') + @patch("instructlab.training.sampler.load_dataset") def test_pretraining_pad_token_id_used( - self, - mock_load_dataset, - temp_pretraining_file + self, mock_load_dataset, temp_pretraining_file ): """Test that pad_token_id is correctly passed to PretrainingBlockDataset.""" # Create mock dataset mock_ds = MagicMock() mock_ds.column_names = ["input_ids", "len"] mock_ds.__len__ = lambda self: 1 - mock_ds.__iter__ = lambda self: iter([{"input_ids": list(range(10)), "len": 10}]) + mock_ds.__iter__ = lambda self: iter( + [{"input_ids": list(range(10)), "len": 10}] + ) mock_load_dataset.return_value = mock_ds # Call with specific pad_token_id @@ -388,27 +368,27 @@ def test_pretraining_pad_token_id_used( world_size=1, is_pretraining=True, block_size=7, # Will create partial block - pad_token_id=pad_token_id + pad_token_id=pad_token_id, ) # Verify dataset has correct pad_token_id assert loader.dataset.pad_token_id == pad_token_id - @patch('instructlab.training.sampler.load_dataset') + 
@patch("instructlab.training.sampler.load_dataset") def test_data_loader_returns_correct_structure( - self, - mock_load_dataset, - temp_pretraining_file + self, mock_load_dataset, temp_pretraining_file ): """Test that get_data_loader returns a properly configured DataLoader.""" # Create mock dataset mock_ds = MagicMock() mock_ds.column_names = ["input_ids", "len"] mock_ds.__len__ = lambda self: 2 - mock_ds.__iter__ = lambda self: iter([ - {"input_ids": list(range(50)), "len": 50}, - {"input_ids": list(range(50, 100)), "len": 50}, - ]) + mock_ds.__iter__ = lambda self: iter( + [ + {"input_ids": list(range(50)), "len": 50}, + {"input_ids": list(range(50, 100)), "len": 50}, + ] + ) mock_load_dataset.return_value = mock_ds # Call get_data_loader @@ -420,28 +400,28 @@ def test_data_loader_returns_correct_structure( rank=0, world_size=1, is_pretraining=True, - block_size=25 + block_size=25, ) # Verify it's a DataLoader + # Third Party from torch.utils.data import DataLoader + assert isinstance(loader, DataLoader) # Verify batch_size assert loader.batch_size == 2 - @patch('instructlab.training.sampler.load_dataset') - def test_epoch_sampler_created( - self, - mock_load_dataset, - temp_pretraining_file - ): + @patch("instructlab.training.sampler.load_dataset") + def test_epoch_sampler_created(self, mock_load_dataset, temp_pretraining_file): """Test that EpochSampler is created with correct parameters.""" # Create mock dataset with known length mock_ds = MagicMock() mock_ds.column_names = ["input_ids", "len"] mock_ds.__len__ = lambda self: 1 - mock_ds.__iter__ = lambda self: iter([{"input_ids": list(range(100)), "len": 100}]) + mock_ds.__iter__ = lambda self: iter( + [{"input_ids": list(range(100)), "len": 100}] + ) mock_load_dataset.return_value = mock_ds seed = 123 @@ -455,28 +435,28 @@ def test_epoch_sampler_created( rank=0, world_size=1, is_pretraining=True, - block_size=block_size + block_size=block_size, ) # Verify sampler is EpochSampler + # First Party from 
instructlab.training.sampler import EpochSampler + assert isinstance(loader.sampler, EpochSampler) # Verify seed is set correctly assert loader.sampler.seed == seed - @patch('instructlab.training.sampler.load_dataset') - def test_collator_configuration( - self, - mock_load_dataset, - temp_pretraining_file - ): + @patch("instructlab.training.sampler.load_dataset") + def test_collator_configuration(self, mock_load_dataset, temp_pretraining_file): """Test that MaxTokensPerRankCollator is configured correctly.""" # Create mock dataset mock_ds = MagicMock() mock_ds.column_names = ["input_ids", "len"] mock_ds.__len__ = lambda self: 1 - mock_ds.__iter__ = lambda self: iter([{"input_ids": list(range(50)), "len": 50}]) + mock_ds.__iter__ = lambda self: iter( + [{"input_ids": list(range(50)), "len": 50}] + ) mock_load_dataset.return_value = mock_ds flash_enabled = False @@ -493,11 +473,13 @@ def test_collator_configuration( is_pretraining=True, block_size=25, flash_enabled=flash_enabled, - pad_token_id=pad_token_id + pad_token_id=pad_token_id, ) # Verify collate_fn is MaxTokensPerRankCollator + # First Party from instructlab.training.sampler import MaxTokensPerRankCollator + assert isinstance(loader.collate_fn, MaxTokensPerRankCollator) # Verify collator configuration @@ -505,18 +487,16 @@ def test_collator_configuration( assert loader.collate_fn.flash_enabled == flash_enabled assert loader.collate_fn.pad_token_id == pad_token_id - @patch('instructlab.training.sampler.load_dataset') - def test_num_workers_parameter( - self, - mock_load_dataset, - temp_pretraining_file - ): + @patch("instructlab.training.sampler.load_dataset") + def test_num_workers_parameter(self, mock_load_dataset, temp_pretraining_file): """Test that num_workers parameter is correctly applied.""" # Create mock dataset mock_ds = MagicMock() mock_ds.column_names = ["input_ids", "len"] mock_ds.__len__ = lambda self: 1 - mock_ds.__iter__ = lambda self: iter([{"input_ids": list(range(50)), "len": 50}]) + 
mock_ds.__iter__ = lambda self: iter( + [{"input_ids": list(range(50)), "len": 50}] + ) mock_load_dataset.return_value = mock_ds num_workers = 4 @@ -530,7 +510,7 @@ def test_num_workers_parameter( world_size=1, is_pretraining=True, block_size=25, - num_workers=num_workers + num_workers=num_workers, ) # Verify num_workers is set From a354284785321149a78441033b4506725cfe9833 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 23 Dec 2025 21:47:39 +0000 Subject: [PATCH 11/18] double-check failure --- src/instructlab/training/main_ds.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index decda926..ef8a921b 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -666,15 +666,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: return # wait for the process to exit so we can properly read the exit code - process.wait(timeout=60) - process_code = process.poll() - failure = process_code != 0 - - if not failure: + return_code = process.wait(timeout=60) # wait for 1 min or error + if return_code == 0: logger.info("Operation completed successfully! 🎉") else: logger.error( - f"Training subprocess has not exited yet. Sending SIGTERM. Process code: {process_code}" + f"Training subprocess has not exited yet. Sending SIGTERM. 
Process code: {return_code}" ) process.terminate() From 699ca7713d24e0aee925bdf5e7b6c7d2dfbb5f68 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 23 Dec 2025 21:53:29 +0000 Subject: [PATCH 12/18] fix tests --- tests/unit/test_pretraining_data_process.py | 35 +++++++++++++++------ tests/unit/test_pretraining_sampler.py | 26 +++++++-------- 2 files changed, 36 insertions(+), 25 deletions(-) diff --git a/tests/unit/test_pretraining_data_process.py b/tests/unit/test_pretraining_data_process.py index ab1537e5..4e34691e 100644 --- a/tests/unit/test_pretraining_data_process.py +++ b/tests/unit/test_pretraining_data_process.py @@ -89,6 +89,7 @@ def map_side_effect(func, **kwargs): result = func({"documents": "Test document"}) mapped_ds = MagicMock() mapped_ds.__getitem__ = lambda self, key: [result[key]] + mapped_ds.__len__ = lambda self: 1 mapped_ds.to_json = MagicMock() return mapped_ds @@ -101,6 +102,7 @@ def map_side_effect(func, **kwargs): data_output_path=temp_output_dir, model_path="test-model", num_cpu_procs=1, + document_column_name="documents", ) # Verify tokenizer was loaded @@ -143,6 +145,7 @@ def map_side_effect(func, **kwargs): data_output_path=temp_output_dir, model_path="test-model", num_cpu_procs=1, + document_column_name="documents", ) # Verify map was called (which processes each document) @@ -177,12 +180,13 @@ def test_missing_documents_field_raises_error( mock_load_dataset.return_value = mock_ds # Should raise ValueError - with pytest.raises(ValueError, match="must have 'documents' field"): + with pytest.raises(ValueError, match="must have.*field"): process_documents_for_pretraining( data_path="dummy.jsonl", data_output_path=temp_output_dir, model_path="test-model", num_cpu_procs=1, + document_column_name="documents", ) @patch("instructlab.training.data_process.AutoTokenizer.from_pretrained") @@ -209,6 +213,7 @@ def test_tokenizer_without_eos_raises_error( data_output_path=temp_output_dir, 
model_path="test-model", num_cpu_procs=1, + document_column_name="documents", ) @patch("instructlab.training.data_process.logger") @@ -249,6 +254,7 @@ def map_side_effect(func, **kwargs): data_output_path=temp_output_dir, model_path="test-model", num_cpu_procs=1, + document_column_name="documents", ) # Verify logging was called (check info was called multiple times) @@ -270,6 +276,8 @@ def test_parallel_processing( def map_side_effect(func, **kwargs): mapped_ds = MagicMock() + mapped_ds.__len__ = lambda self: 1 + mapped_ds.__getitem__ = lambda self, key: [10] if key == "len" else None mapped_ds.to_json = MagicMock() return mapped_ds @@ -282,6 +290,7 @@ def map_side_effect(func, **kwargs): data_output_path=temp_output_dir, model_path="test-model", num_cpu_procs=4, + document_column_name="documents", ) # Verify map was called with num_proc=4 @@ -307,6 +316,10 @@ def test_output_directory_creation(self, tmp_path, mock_tokenizer): def map_side_effect(func, **kwargs): mapped_ds = MagicMock() + mapped_ds.__len__ = lambda self: 1 + mapped_ds.__getitem__ = ( + lambda self, key: [10] if key == "len" else None + ) mapped_ds.to_json = MagicMock() return mapped_ds @@ -319,6 +332,7 @@ def map_side_effect(func, **kwargs): data_output_path=str(output_dir), model_path="test-model", num_cpu_procs=1, + document_column_name="documents", ) # Verify directory was created @@ -343,6 +357,8 @@ def test_output_jsonl_format( def map_side_effect(func, **kwargs): result = func({"documents": "Test"}) mapped_ds = MagicMock() + mapped_ds.__len__ = lambda self: 1 + mapped_ds.__getitem__ = lambda self, key: [result[key]] def to_json_side_effect(path, **kw): nonlocal output_file_path @@ -364,6 +380,7 @@ def to_json_side_effect(path, **kw): data_output_path=temp_output_dir, model_path="test-model", num_cpu_procs=1, + document_column_name="documents", ) # Verify file was created @@ -404,6 +421,7 @@ def test_integration_with_real_tokenizer(self, temp_output_dir): data_output_path=temp_output_dir, 
model_path="gpt2", num_cpu_procs=1, + document_column_name="documents", ) # Verify output @@ -422,14 +440,11 @@ def test_integration_with_real_tokenizer(self, temp_output_dir): # Load tokenizer to verify tokens tokenizer = AutoTokenizer.from_pretrained("gpt2") - # Verify BOS/EOS are present (GPT2 uses same token 50256 for both) - # encode() with add_special_tokens=True adds BOS - # We manually append EOS - assert ( - data["input_ids"][0] == tokenizer.bos_token_id - or data["input_ids"][0] == tokenizer.eos_token_id - ) + # Verify EOS is present at the end + # Note: GPT2's encode() with add_special_tokens=True doesn't add BOS + # (GPT2 uses the same token for BOS and EOS) + # The implementation manually appends EOS if not present assert data["input_ids"][-1] == tokenizer.eos_token_id - # Verify token count is reasonable - assert data["len"] > 5 # Should have more than just BOS/EOS + # Verify token count is reasonable (should have content tokens + EOS) + assert data["len"] > 5 diff --git a/tests/unit/test_pretraining_sampler.py b/tests/unit/test_pretraining_sampler.py index 62a4d203..76b4b16d 100644 --- a/tests/unit/test_pretraining_sampler.py +++ b/tests/unit/test_pretraining_sampler.py @@ -11,6 +11,7 @@ import torch # First Party +from instructlab.training.config import PretrainingConfig from instructlab.training.sampler import PretrainingBlockDataset, get_data_loader @@ -269,8 +270,7 @@ def test_pretraining_mode_creates_block_dataset( seed=42, rank=0, world_size=1, - is_pretraining=True, - block_size=128, + pretraining_config=PretrainingConfig(block_size=128), ) # Verify load_dataset was called @@ -303,7 +303,7 @@ def test_instruction_tuning_mode_creates_token_dataset(self, temp_pretraining_fi seed=42, rank=0, world_size=1, - is_pretraining=False, + pretraining_config=None, ) # Verify dataset is TokenDataset (not PretrainingBlockDataset) @@ -336,8 +336,7 @@ def test_pretraining_block_size_parameter( seed=42, rank=0, world_size=1, - is_pretraining=True, - 
block_size=block_size, + pretraining_config=PretrainingConfig(block_size=block_size), ) # Verify dataset has correct block_size @@ -366,8 +365,9 @@ def test_pretraining_pad_token_id_used( seed=42, rank=0, world_size=1, - is_pretraining=True, - block_size=7, # Will create partial block + pretraining_config=PretrainingConfig( + block_size=7 + ), # Will create partial block pad_token_id=pad_token_id, ) @@ -399,8 +399,7 @@ def test_data_loader_returns_correct_structure( seed=42, rank=0, world_size=1, - is_pretraining=True, - block_size=25, + pretraining_config=PretrainingConfig(block_size=25), ) # Verify it's a DataLoader @@ -434,8 +433,7 @@ def test_epoch_sampler_created(self, mock_load_dataset, temp_pretraining_file): seed=seed, rank=0, world_size=1, - is_pretraining=True, - block_size=block_size, + pretraining_config=PretrainingConfig(block_size=block_size), ) # Verify sampler is EpochSampler @@ -470,8 +468,7 @@ def test_collator_configuration(self, mock_load_dataset, temp_pretraining_file): seed=42, rank=0, world_size=1, - is_pretraining=True, - block_size=25, + pretraining_config=PretrainingConfig(block_size=25), flash_enabled=flash_enabled, pad_token_id=pad_token_id, ) @@ -508,8 +505,7 @@ def test_num_workers_parameter(self, mock_load_dataset, temp_pretraining_file): seed=42, rank=0, world_size=1, - is_pretraining=True, - block_size=25, + pretraining_config=PretrainingConfig(block_size=25), num_workers=num_workers, ) From 9c3905bdfeb46c22950bb97062efc1333b34d111 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 23 Dec 2025 22:01:02 +0000 Subject: [PATCH 13/18] revert debug change --- src/instructlab/training/main_ds.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index ef8a921b..87775bbb 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -457,9 +457,7 @@ def
run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: # Enable package logging propagation before setting up loggers propagate_package_logs(True) setup_root_logger(train_args.log_level) - setup_metric_logger( - train_args.logger_type, train_args.run_name, train_args.ckpt_output_dir - ) + setup_metric_logger("async", None, train_args.ckpt_output_dir) logger = logging.getLogger("instructlab.training") logger.info("Starting training setup...") From fa0b7d1475cd50c2e0067da82be5622b12b7bf75 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 23 Dec 2025 22:24:30 +0000 Subject: [PATCH 14/18] fix incorrect accesses --- src/instructlab/training/config.py | 2 +- src/instructlab/training/main_ds.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 17d3a69d..3f2dd810 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -291,7 +291,7 @@ class TrainingArgs(BaseModel): keep_last_checkpoint_only: Optional[bool] = False pretraining_config: Optional[PretrainingConfig] = Field( - default="document", + default=None, description=( "Pretraining configuration. When provided, enables block-based sampling " "for raw document pretraining datasets."
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 87775bbb..0cf6b33e 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -571,8 +571,8 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.mock_data: command.append("--mock_data") - if train_args.mock_len: - command.append(f"--mock_len={train_args.mock_len}") + if train_args.mock_data_len: + command.append(f"--mock_len={train_args.mock_data_len}") if train_args.disable_flash_attn: command.append("--disable_flash_attn") From 5f8f2a535f0a0cd4d7842cd8daf3e33a5ee73a94 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 23 Dec 2025 22:25:36 +0000 Subject: [PATCH 15/18] adds pydantic to mypy for tox --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 0794c417..8ce92518 100644 --- a/tox.ini +++ b/tox.ini @@ -104,6 +104,7 @@ deps = types-tqdm types-PyYAML pytest + pydantic commands = mypy {posargs:src} From d7be2a37e29503d7065e2d4a7520311825c85b1d Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 23 Dec 2025 22:28:06 +0000 Subject: [PATCH 16/18] more linting --- src/instructlab/training/accelerator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/instructlab/training/accelerator.py b/src/instructlab/training/accelerator.py index 4baa7c0e..49fa52a8 100644 --- a/src/instructlab/training/accelerator.py +++ b/src/instructlab/training/accelerator.py @@ -63,6 +63,9 @@ def __init__( self.lr_scheduler = None if self.distributed_framework == DistributedBackend.DEEPSPEED: # Standard + cpu_offload_optimizer_ratio = ( + self.deepspeed_cpu_offload_optimizer_ratio or 0.0 + ) accel_args = { "deepspeed_plugin": self.get_ds_plugin( world_size=torch.distributed.get_world_size(), @@ -70,7 +73,7 @@ def __init__( grad_accum=grad_accum, opts=DeepSpeedOptions( 
cpu_offload_optimizer=deepspeed_cpu_offload_optimizer, - cpu_offload_optimizer_ratio=self.deepspeed_cpu_offload_optimizer_ratio, + cpu_offload_optimizer_ratio=cpu_offload_optimizer_ratio, cpu_offload_optimizer_pin_memory=self.deepspeed_cpu_offload_optimizer_pin_memory, save_samples=save_samples, ), From 02754556454ebf182ec7956dfa23e2a15d370b6f Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 23 Dec 2025 22:35:12 +0000 Subject: [PATCH 17/18] revert changes --- src/instructlab/training/main_ds.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 0cf6b33e..ff5cea7d 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -664,12 +664,15 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: return # wait for the process to exit so we can properly read the exit code - return_code = process.wait(timeout=60) # wait for 1 min or error - if return_code == 0: + process.wait(timeout=60) + process_code = process.poll() + failure = process_code != 0 + + if not failure: logger.info("Operation completed successfully! 🎉") else: logger.error( - f"Training subprocess has not exited yet. Sending SIGTERM. Process code: {return_code}" + f"Training subprocess has not exited yet. Sending SIGTERM. 
Process code: {process_code}" ) process.terminate() From 77d39f63d19978a83b1cf2c84f3f2d20b1ec8903 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 23 Dec 2025 22:48:48 +0000 Subject: [PATCH 18/18] fix tests --- src/instructlab/training/data_process.py | 20 +++++- tests/unit/test_pretraining_data_process.py | 80 +++++++++++++++++---- 2 files changed, 82 insertions(+), 18 deletions(-) diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py index 2b546fd5..b0eaa268 100644 --- a/src/instructlab/training/data_process.py +++ b/src/instructlab/training/data_process.py @@ -1153,7 +1153,7 @@ def process_documents_for_pretraining( Pattern: Each document → [BOS][tokens][EOS] Args: - data_path: Path to input JSONL with {"documents": "text"} format + data_path: Path to input JSONL with {"document": "text"} format data_output_path: Directory for processed data output model_path: Path to model/tokenizer num_cpu_procs: Number of parallel processes @@ -1200,11 +1200,25 @@ def tokenize_document(sample): "len": len(input_ids), } - tokenized_data = data.map( + # Filter out empty documents before tokenization + def filter_empty_documents(batch): + return [bool(doc) for doc in batch[document_column_name]] + + filtered_data = data.filter( + filter_empty_documents, + batched=True, + num_proc=num_cpu_procs, + desc="Filtering empty documents", + ) + + dropped_count = data.num_rows - filtered_data.num_rows + if dropped_count > 0: + logger.info(f"Dropped {dropped_count:,} empty documents") + tokenized_data = filtered_data.map( tokenize_document, num_proc=num_cpu_procs, desc="Tokenizing documents", - remove_columns=data.column_names, + remove_columns=filtered_data.column_names, ) # Calculate statistics diff --git a/tests/unit/test_pretraining_data_process.py b/tests/unit/test_pretraining_data_process.py index 4e34691e..e3651e00 100644 --- a/tests/unit/test_pretraining_data_process.py +++ 
b/tests/unit/test_pretraining_data_process.py @@ -82,7 +82,14 @@ def test_basic_tokenization_with_bos_eos( # Mock single document mock_ds.__iter__ = lambda self: iter([{"documents": "Test document"}]) - mock_ds.map = MagicMock() + + # Create filtered dataset mock + filtered_ds = MagicMock() + filtered_ds.num_rows = 1 + filtered_ds.column_names = ["documents"] + + # Mock filter to return the filtered dataset + mock_ds.filter = MagicMock(return_value=filtered_ds) # Make map return a dataset with tokenized data def map_side_effect(func, **kwargs): @@ -93,7 +100,7 @@ def map_side_effect(func, **kwargs): mapped_ds.to_json = MagicMock() return mapped_ds - mock_ds.map.side_effect = map_side_effect + filtered_ds.map = MagicMock(side_effect=map_side_effect) mock_load_dataset.return_value = mock_ds # Run function @@ -108,8 +115,8 @@ def map_side_effect(func, **kwargs): # Verify tokenizer was loaded mock_from_pretrained.assert_called_once_with("test-model") - # Verify dataset map was called - assert mock_ds.map.called + # Verify dataset filter and map were called + assert mock_ds.filter.called @patch("instructlab.training.data_process.AutoTokenizer.from_pretrained") @patch("instructlab.training.data_process.load_dataset") @@ -127,6 +134,14 @@ def test_multiple_documents_separate_records( docs = [{"documents": "Doc 1"}, {"documents": "Doc 2"}, {"documents": "Doc 3"}] + # Create filtered dataset mock + filtered_ds = MagicMock() + filtered_ds.num_rows = 3 + filtered_ds.column_names = ["documents"] + + # Mock filter to return the filtered dataset + mock_ds.filter = MagicMock(return_value=filtered_ds) + # Mock map to process all documents def map_side_effect(func, **kwargs): results = [func(doc) for doc in docs] @@ -136,7 +151,7 @@ def map_side_effect(func, **kwargs): mapped_ds.to_json = MagicMock() return mapped_ds - mock_ds.map.side_effect = map_side_effect + filtered_ds.map = MagicMock(side_effect=map_side_effect) mock_load_dataset.return_value = mock_ds # Run @@ -148,8 +163,8 
@@ def map_side_effect(func, **kwargs): document_column_name="documents", ) - # Verify map was called (which processes each document) - assert mock_ds.map.called + # Verify filter and map were called (which processes each document) + assert mock_ds.filter.called @patch("instructlab.training.data_process.load_dataset") def test_empty_dataset_raises_error(self, mock_load_dataset, temp_output_dir): @@ -236,6 +251,14 @@ def test_statistics_logging( mock_ds.num_rows = 2 mock_ds.column_names = ["documents"] + # Create filtered dataset mock + filtered_ds = MagicMock() + filtered_ds.num_rows = 2 + filtered_ds.column_names = ["documents"] + + # Mock filter to return the filtered dataset + mock_ds.filter = MagicMock(return_value=filtered_ds) + # Mock map to return known lengths def map_side_effect(func, **kwargs): # Simulate 2 documents with 5 and 10 tokens each @@ -245,7 +268,7 @@ def map_side_effect(func, **kwargs): mapped_ds.to_json = MagicMock() return mapped_ds - mock_ds.map.side_effect = map_side_effect + filtered_ds.map = MagicMock(side_effect=map_side_effect) mock_load_dataset.return_value = mock_ds # Run @@ -272,7 +295,14 @@ def test_parallel_processing( mock_ds = MagicMock() mock_ds.num_rows = 1 mock_ds.column_names = ["documents"] - mock_ds.map = MagicMock() + + # Create filtered dataset mock + filtered_ds = MagicMock() + filtered_ds.num_rows = 1 + filtered_ds.column_names = ["documents"] + + # Mock filter to return the filtered dataset + mock_ds.filter = MagicMock(return_value=filtered_ds) def map_side_effect(func, **kwargs): mapped_ds = MagicMock() @@ -281,7 +311,7 @@ def map_side_effect(func, **kwargs): mapped_ds.to_json = MagicMock() return mapped_ds - mock_ds.map.side_effect = map_side_effect + filtered_ds.map = MagicMock(side_effect=map_side_effect) mock_load_dataset.return_value = mock_ds # Run with specific num_cpu_procs @@ -293,9 +323,13 @@ def map_side_effect(func, **kwargs): document_column_name="documents", ) - # Verify map was called with num_proc=4 - 
call_args = mock_ds.map.call_args - assert call_args[1]["num_proc"] == 4 + # Verify filter was called with num_proc=4 + filter_call_args = mock_ds.filter.call_args + assert filter_call_args[1]["num_proc"] == 4 + + # Verify map was also called with num_proc=4 + map_call_args = filtered_ds.map.call_args + assert map_call_args[1]["num_proc"] == 4 def test_output_directory_creation(self, tmp_path, mock_tokenizer): """Verify directory is created if it doesn't exist.""" @@ -314,6 +348,14 @@ def test_output_directory_creation(self, tmp_path, mock_tokenizer): mock_ds.num_rows = 1 mock_ds.column_names = ["documents"] + # Create filtered dataset mock + filtered_ds = MagicMock() + filtered_ds.num_rows = 1 + filtered_ds.column_names = ["documents"] + + # Mock filter to return the filtered dataset + mock_ds.filter = MagicMock(return_value=filtered_ds) + def map_side_effect(func, **kwargs): mapped_ds = MagicMock() mapped_ds.__len__ = lambda self: 1 @@ -323,7 +365,7 @@ def map_side_effect(func, **kwargs): mapped_ds.to_json = MagicMock() return mapped_ds - mock_ds.map.side_effect = map_side_effect + filtered_ds.map = MagicMock(side_effect=map_side_effect) mock_load_dataset.return_value = mock_ds # Run @@ -351,6 +393,14 @@ def test_output_jsonl_format( mock_ds.num_rows = 1 mock_ds.column_names = ["documents"] + # Create filtered dataset mock + filtered_ds = MagicMock() + filtered_ds.num_rows = 1 + filtered_ds.column_names = ["documents"] + + # Mock filter to return the filtered dataset + mock_ds.filter = MagicMock(return_value=filtered_ds) + # Track what gets written output_file_path = None @@ -371,7 +421,7 @@ def to_json_side_effect(path, **kw): mapped_ds.to_json = to_json_side_effect return mapped_ds - mock_ds.map.side_effect = map_side_effect + filtered_ds.map = MagicMock(side_effect=map_side_effect) mock_load_dataset.return_value = mock_ds # Run