diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py
index bfe1509d6..4d8e98da7 100644
--- a/fast_llm/data/dataset/config.py
+++ b/fast_llm/data/dataset/config.py
@@ -291,6 +291,12 @@ class StreamingDatasetConfig[DocumentType: LanguageModelDocument](RedisConfig, S
 
     _abstract = False
 
+    log_data_pipeline: bool = Field(
+        default=False,
+        desc="Write per-read timing to data_pipeline_log/rank_{rank}.jsonl for pipeline diagnostics.",
+        hint=FieldHint.optional,
+    )
+
     def build_and_sample(self, config: SamplingConfig, num_samples: int, seed: int) -> SampledDataset[DocumentType]:
         from fast_llm.data.dataset.streaming import RedisStreamingDataset
 
diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py
index 8835612ec..bc09bf864 100644
--- a/fast_llm/data/dataset/streaming.py
+++ b/fast_llm/data/dataset/streaming.py
@@ -12,4 +12,5 @@
 from fast_llm.data.document.language_model import LanguageModelDocument
 from fast_llm.data.document.range import RangeDocument
 from fast_llm.data.document.token_data import TokenDataDocument
+from fast_llm.engine.config_utils.run import get_run
 from fast_llm.utils import Assert
@@ -139,28 +140,64 @@ def iterate(self, config: SamplingConfig, num_samples: int, seed: int) -> typing
             else:
                 raise
 
+        # Set up pipeline diagnostics log file if enabled
+        log_file = None
+        if self._config.log_data_pipeline:
+            run = get_run()
+            if run is not None and run.experiment_directory is not None:
+                log_dir = run.experiment_directory / "data_pipeline_log"
+                log_dir.mkdir(parents=True, exist_ok=True)
+                log_file = open(log_dir / f"rank_{config.rank}.jsonl", "a")
+
+        last_read_time = time.time()
         start_time = time.time()
-        while True:
-            # XREADGROUP reads from the consumer group
-            # COUNT: max number of messages to fetch at once
-            # BLOCK: wait for new messages (milliseconds)
-            messages = client.xreadgroup(
-                groupname=REDIS_GROUP_NAME,
-                consumername=f"fast_llm_consumer_{config.rank}",
-                # ">" reads only new messages that have not been delivered to any consumer
-                streams={REDIS_DATA_STREAM: ">"},
-                count=1,
-                block=1000,
-                # No explicit ACK: messages are processed immediately; on rank failure the job restarts,
-                # so message loss is acceptable and simplifies coordination
-                noack=True,
-            )
-            if messages:
-                for stream_key, messages_ in messages:
-                    assert stream_key == REDIS_DATA_STREAM.encode()
-                    for message_id, message in messages_:
-                        yield RedisStreamingDocumentData.from_message(message).to_document()
-                        start_time = time.time()
-
-            elif (t := time.time() - start_time) > self._config.timeout:
-                raise TimeoutError(f"No document received after {t} seconds")
+        try:
+            while True:
+                # XREADGROUP reads from the consumer group
+                # COUNT: max number of messages to fetch at once
+                # BLOCK: wait for new messages (milliseconds)
+                messages = client.xreadgroup(
+                    groupname=REDIS_GROUP_NAME,
+                    consumername=f"fast_llm_consumer_{config.rank}",
+                    # ">" reads only new messages that have not been delivered to any consumer
+                    streams={REDIS_DATA_STREAM: ">"},
+                    count=1,
+                    block=1000,
+                    # No explicit ACK: messages are processed immediately; on rank failure the job restarts,
+                    # so message loss is acceptable and simplifies coordination
+                    noack=True,
+                )
+                if messages:
+                    now = time.time()
+                    for stream_key, messages_ in messages:
+                        assert stream_key == REDIS_DATA_STREAM.encode()
+                        for message_id, message in messages_:
+                            doc = RedisStreamingDocumentData.from_message(message)
+                            if log_file is not None:
+                                # Stream IDs are "<ms-since-epoch>-<seq>"; convert the prefix to seconds.
+                                write_time = int(message_id.split(b"-")[0]) / 1000.0
+                                gap_ms = (now - last_read_time) * 1000.0
+                                latency_ms = (now - write_time) * 1000.0
+                                log_file.write(
+                                    json.dumps(
+                                        {
+                                            "event": "READ",
+                                            "rank": config.rank,
+                                            "t": now,
+                                            "gap_ms": round(gap_ms, 2),
+                                            "latency_ms": round(latency_ms, 2),
+                                            "tokens": doc.num_tokens,
+                                        }
+                                    )
+                                    + "\n"
+                                )
+                                log_file.flush()
+                            last_read_time = now
+                            yield doc.to_document()
+                            start_time = time.time()
+
+                elif (t := time.time() - start_time) > self._config.timeout:
+                    raise TimeoutError(f"No document received after {t} seconds")
+        finally:
+            if log_file is not None:
+                log_file.close()
diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py
index 00040e576..6a20dceba 100644
--- a/fast_llm/data/document/language_model.py
+++ b/fast_llm/data/document/language_model.py
@@ -39,6 +39,9 @@ class LanguageModelTargetInput(ModelInput):
 class LanguageModelInput(TokenModelInput):
     targets: list[LanguageModelTargetInput] = dataclasses.field(default_factory=list)
     image_patches: PatchModelInput | None = None
+    # Number of documents with at least one response token in this micro-batch.
+    # Computed before any TP/SP/PP splitting; used to normalize per-document metrics.
+    num_docs: int | None = None
 
     def set_children_attributes(self) -> None:
         if self.image_patches is not None:
@@ -58,6 +61,7 @@ def to_kwargs(self) -> dict[str, typing.Any]:
             LanguageModelKwargs.advantages: [target.advantages for target in self.targets],
             LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets],
             LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets],
+            LanguageModelKwargs.num_docs: self.num_docs,
         }
         if self.image_patches is not None:
             out.update(self.image_patches.to_kwargs())
@@ -121,6 +125,12 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis
 
             model_inputs.append(model_input)
 
+        # num_docs is counted only on the first split (micro_sequence_index==0) to avoid
+        # double-counting documents that span a split boundary when micro_batch_splits > 1.
+        # The first split already has the correct count (or 0 for SDP rank > 0); clear the rest.
+        for model_input in model_inputs[1:]:
+            model_input.num_docs = None
+
         return model_inputs
 
     def _get_model_input(
@@ -161,6 +171,21 @@ def _get_model_input(
             length_cumsum = torch.tensor([0] + cropped_lengths, device=self.device).cumsum(0)
             label_count_cumsum = mask_cumsum[length_cumsum]
             labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1]
+            # Track documents with at least one response token for per-document metric
+            # normalization. Only counted on sequence_data_rank==0 so that when the runner
+            # all_reduces the denominator across the data group (which includes SDP ranks),
+            # each document is counted exactly once even though all SDP ranks see the same
+            # documents (but process different token slices of them).
+            if model_input.num_docs is None:
+                # Skip .item() on meta tensors (called during setup/shape inference).
+                if labels_per_document.is_meta:
+                    model_input.num_docs = 0
+                else:
+                    model_input.num_docs = (
+                        int((labels_per_document > 0).sum().item())
+                        if config.distributed.sequence_data_rank == 0
+                        else 0
+                    )
             # Expand to one entry per token: find each token's document index via the sorted
             # length cumsum, then look up that document's label count.
             # TODO: Document index already computed in `LengthModelInputPreprocessor`.
diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py
index 0526b9dc2..959368476 100644
--- a/fast_llm/engine/base_model/config.py
+++ b/fast_llm/engine/base_model/config.py
@@ -109,6 +109,9 @@ class LossDef:
     name: str
     formatted_name: str
     # The number of times this loss is evaluated by the model for each micro-batch. Used as a denominator for averaging.
-    # TODO: Allow variable count? Would need a reduction across PP devices.
     count: int = 1
     dtype: DataType = DataType.float32
+    # If set, normalize this metric by summing values from context.batch[i][denominator_batch_field]
+    # across micro-batches and DP ranks, instead of using count * data_parallel * num_inputs.
+    # The field must be a scalar (int or float) pre-computed before any TP/SP/PP splitting.
+    denominator_batch_field: str | None = None
diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py
index 48714db40..a849d05de 100644
--- a/fast_llm/engine/schedule/config.py
+++ b/fast_llm/engine/schedule/config.py
@@ -42,6 +42,11 @@ class ScheduleConfig(Config):
         hint=FieldHint.optional,
         valid=check_field(Assert.gt, 0),
     )
+    log_data_pipeline: bool = Field(
+        default=False,
+        desc="Write per-micro-batch timing to data_pipeline_log/rank_{rank}.jsonl for pipeline diagnostics.",
+        hint=FieldHint.optional,
+    )
     # Enable cpu throttling to avoid lag spikes, see https://arxiv.org/pdf/2211.05953.pdf, appendix D.2.
     throttle_cpu: bool = Field(
         default=True,
diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py
index 24b8b3d63..dd41dad86 100644
--- a/fast_llm/engine/schedule/runner.py
+++ b/fast_llm/engine/schedule/runner.py
@@ -1,6 +1,7 @@
 import collections
 import contextlib
 import dataclasses
+import json
 import logging
 import time
 import typing
@@ -287,20 +288,39 @@ def run_step(
     def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
         reduced_losses = {}
         for name, losses in context.losses.items():
+            loss_def = self._loss_definitions[name]
             if losses or self._distributed.pipeline_group:
                 if losses:
-                    loss_count = (
-                        self._loss_definitions[name].count
-                        * self._distributed_config.data_parallel
-                        * context.schedule.config.num_inputs
-                    )
-                    reduced_loss = torch.stack(losses).sum() / loss_count
-                    if self._distributed.data_group:
-                        all_reduce(reduced_loss, group=self._distributed.data_group)
+                    denom_field = loss_def.denominator_batch_field
+                    if denom_field is not None:
+                        # Normalize by a per-micro-batch scalar from the batch data (e.g. num_docs),
+                        # computed before any TP/SP/PP splitting. Sum numerator and denominator
+                        # independently across DP ranks so the result is a true global average.
+                        numerator = torch.stack(losses).sum()
+                        denominator = torch.tensor(
+                            sum(
+                                batch_kwargs[denom_field] or 0
+                                for batch_kwargs in context.batch.values()
+                                if denom_field in batch_kwargs
+                            ),
+                            dtype=numerator.dtype,
+                            device=numerator.device,
+                        )
+                        if self._distributed.data_group:
+                            all_reduce(numerator, group=self._distributed.data_group)
+                            all_reduce(denominator, group=self._distributed.data_group)
+                        reduced_loss = numerator / denominator.clamp(min=1)
+                    else:
+                        loss_count = (
+                            loss_def.count
+                            * self._distributed_config.data_parallel
+                            * context.schedule.config.num_inputs
+                        )
+                        reduced_loss = torch.stack(losses).sum() / loss_count
+                        if self._distributed.data_group:
+                            all_reduce(reduced_loss, group=self._distributed.data_group)
                 else:
-                    reduced_loss = torch.zeros(
-                        [1], dtype=self._loss_definitions[name].dtype.torch, device=self._distributed.device
-                    )
+                    reduced_loss = torch.zeros([1], dtype=loss_def.dtype.torch, device=self._distributed.device)
                 if self._distributed.pipeline_group:
                     all_reduce(reduced_loss, group=self._distributed.pipeline_group)
             else:
@@ -428,12 +448,62 @@ def _backward(self, context: BatchContext, step: Step) -> torch.Tensor:
         self._record_compute(context, step)
         return input_grad
 
+    def _get_pipeline_log_file(self):
+        if not self._config.log_data_pipeline:
+            return None
+        if not hasattr(self, "_pipeline_log_file"):
+            run = get_run()
+            if run is not None and run.experiment_directory is not None:
+                log_dir = run.experiment_directory / "data_pipeline_log"
+                log_dir.mkdir(parents=True, exist_ok=True)
+                rank = self._distributed_config.data_rank
+                self._pipeline_log_file = open(log_dir / f"rank_{rank}.jsonl", "a")
+            else:
+                self._pipeline_log_file = None
+        return self._pipeline_log_file
+
     def _get_forward_input(self, context: BatchContext, step: Step) -> torch.Tensor:
         if step.index not in context.batch:
             start_time = time.perf_counter()
+            rank = self._distributed_config.data_rank
+            log_file = self._get_pipeline_log_file()
+
+            if log_file is not None:
+                log_file.write(
+                    json.dumps(
+                        {
+                            "event": "BATCH_START",
+                            "rank": rank,
+                            "micro": step.index,
+                            "t": time.time(),
+                        }
+                    )
+                    + "\n"
+                )
+                log_file.flush()
+
+            samples_loaded = 0
             while step.index not in context.batch:
                 next(context.data_iterator)
+                samples_loaded += 1
             data_time = (time.perf_counter() - start_time) * 1000
+
+            if log_file is not None:
+                log_file.write(
+                    json.dumps(
+                        {
+                            "event": "BATCH_END",
+                            "rank": rank,
+                            "micro": step.index,
+                            "t": time.time(),
+                            "duration_ms": round(data_time, 2),
+                            "samples": samples_loaded,
+                        }
+                    )
+                    + "\n"
+                )
+                log_file.flush()
+
             if data_time > self._config.data_batch_warn_time_ms:
                 logger.warning(f"Data loading took {data_time:,.2f} ms")
         return context.inputs.pop(step.global_index).detach().requires_grad_(step.stage != 0)
diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py
index 99d4bce9a..883b2b095 100644
--- a/fast_llm/layers/language_model/loss/config.py
+++ b/fast_llm/layers/language_model/loss/config.py
@@ -27,6 +27,7 @@ class LanguageModelLossKwargs(BlockKwargs):
     advantages = "advantages"
     old_log_probabilities = "old_log_probabilities"
     label_counts = "num_labels_in_seq"
+    num_docs = "num_docs"
 
 
 @config_class(registry=True)
diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py
index 2136e7918..3217f112c 100644
--- a/fast_llm/layers/language_model/loss/grpo.py
+++ b/fast_llm/layers/language_model/loss/grpo.py
@@ -47,8 +47,12 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
             LossDef(
                 self._logprob_metric_name,
                 formatted_name=self._logprob_metric_name,
-                count=1,  # This is an additive metric over the sequence.
+                count=1,
                 dtype=DataType.float32,
+                # Normalize by the number of documents with response tokens in each micro-batch,
+                # giving a true per-document average regardless of variable document lengths.
+                # num_docs is computed before any TP/SP/PP splitting in language_model.py.
+                denominator_batch_field=LanguageModelLossKwargs.num_docs,
             )
         ]
 
diff --git a/tests/data/test_preprocessing.py b/tests/data/test_preprocessing.py
index d0e56e3f0..6802095c0 100644
--- a/tests/data/test_preprocessing.py
+++ b/tests/data/test_preprocessing.py
@@ -4,4 +4,6 @@
 from fast_llm.data.document.config import LanguageModelBatchPreprocessingConfig
 from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument
 from fast_llm.data.document.range import RangeDocument
+from fast_llm.data.document.token_data import TokenDataDocument
+from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.utils import Assert
@@ -53,3 +55,163 @@
 
     Assert.eq(len(model_input.targets), 1)
     Assert.all_equal(model_input.targets[0].tokens, torch.cat(label_tokens)[1:])
+
+
+def _make_grpo_document(tokens, loss_masking_spans=None):
+    """Helper: create a LanguageModelDocument with GRPO fields (advantages, old_log_probabilities)."""
+    t = torch.tensor(tokens, dtype=torch.int64)
+    n = len(t)
+    return LanguageModelDocument(
+        tokens=t,
+        loss_masking_spans=None if loss_masking_spans is None else RangeDocument(ranges=loss_masking_spans),
+        advantages=TokenDataDocument(data=torch.zeros(n)),
+        old_log_probabilities=TokenDataDocument(data=torch.zeros(n)),
+    )
+
+
+@pytest.mark.parametrize(
+    ("token_lists", "loss_masking_spans_list", "expected_num_docs"),
+    (
+        # Single doc, no masking — all tokens are response tokens except first (cross-doc mask)
+        ([[1, 2, 3, 4, 5]], [None], 1),
+        # Single doc fully masked by loss_masking_spans — no response tokens, num_docs = 0
+        ([[1, 2, 3, 4, 5]], [[(0, 5)]], 0),
+        # Two docs, both with response tokens
+        ([[1, 2, 3], [4, 5, 6]], [None, None], 2),
+        # Two docs, one fully masked — only 1 contributes
+        ([[1, 2, 3], [4, 5, 6]], [[(0, 3)], None], 1),
+        # Two docs, both fully masked
+        ([[1, 2, 3], [4, 5, 6]], [[(0, 3)], [(0, 3)]], 0),
+        # Short single doc; padding behaviour is covered by test_num_docs_excludes_padding.
+        ([[1, 2, 3]], [None], 1),
+    ),
+)
+def test_num_docs_computation(token_lists, loss_masking_spans_list, expected_num_docs):
+    """num_docs counts only documents that have at least one non-masked response token."""
+    documents = [_make_grpo_document(tokens, spans) for tokens, spans in zip(token_lists, loss_masking_spans_list)]
+    config = LanguageModelBatchPreprocessingConfig(use_grpo_data=True, return_label_counts=True)
+    (model_input,) = LanguageModelBatch.from_documents(documents).get_model_inputs(config)
+    Assert.eq(model_input.num_docs, expected_num_docs)
+
+
+def test_num_docs_excludes_padding():
+    """Padding appended by pad_to_size is a 0-label segment and must not count toward num_docs."""
+    documents = [_make_grpo_document([1, 2, 3, 4])]
+    config = LanguageModelBatchPreprocessingConfig(use_grpo_data=True, return_label_counts=True)
+    # pad_to_size > total tokens forces a padding segment to be added
+    (model_input,) = LanguageModelBatch.from_documents(documents, pad_to_size=10).get_model_inputs(config)
+    # Only the real document counts; the padding segment (all -100) does not
+    Assert.eq(model_input.num_docs, 1)
+
+
+def test_num_docs_none_without_label_counts():
+    """num_docs is None when return_label_counts is False (GRPO preprocessing not requested)."""
+    documents = [_make_grpo_document([1, 2, 3, 4])]
+    config = LanguageModelBatchPreprocessingConfig()
+    (model_input,) = LanguageModelBatch.from_documents(documents).get_model_inputs(config)
+    assert model_input.num_docs is None
+
+
+def _make_sdp_config(sdp_rank: int, sdp_size: int = 2) -> LanguageModelBatchPreprocessingConfig:
+    """Config simulating a given sequence-data-parallel rank."""
+    return LanguageModelBatchPreprocessingConfig(
+        use_grpo_data=True,
+        return_label_counts=True,
+        distributed=DistributedConfig(world_size=sdp_size, rank=sdp_rank, sequence_data_parallel=sdp_size),
+    )
+
+
+def test_num_docs_sdp_only_counted_on_rank0():
+    """With SDP=2, num_docs is counted only on sequence_data_rank=0.
+
+    The runner all_reduces the denominator across the data group (which includes all SDP
+    ranks). If both SDP ranks reported num_docs=1 for the same document, the all_reduce
+    SUM would produce denominator=2 and halve the metric. Only rank 0 must contribute
+    to avoid this double-counting.
+    """
+    # 9 tokens → total_input_length = 8 (divisible by SDP=2)
+    documents = [_make_grpo_document([1, 2, 3, 4, 5, 6, 7, 8, 9])]
+    batch = LanguageModelBatch.from_documents(documents)
+
+    (model_input_rank0,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=0))
+    (model_input_rank1,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=1))
+
+    # Rank 0 counts the document; rank 1 must not.
+    Assert.eq(model_input_rank0.num_docs, 1)
+    Assert.eq(model_input_rank1.num_docs, 0)
+
+    # After all_reduce SUM across SDP ranks the denominator equals the true doc count.
+    Assert.eq(model_input_rank0.num_docs + model_input_rank1.num_docs, 1)
+
+
+def test_num_docs_sdp_fully_masked_excluded_on_rank0():
+    """A fully-masked document is excluded from num_docs even on SDP rank 0."""
+    # Doc 0 fully masked; doc 1 has response tokens. 9 tokens total → 8 input tokens (div by 2).
+    documents = [
+        _make_grpo_document([1, 2, 3, 4], loss_masking_spans=[(0, 4)]),
+        _make_grpo_document([5, 6, 7, 8, 9]),
+    ]
+    batch = LanguageModelBatch.from_documents(documents)
+
+    (model_input_rank0,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=0))
+    (model_input_rank1,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=1))
+
+    # Only the unmasked document counts; rank 1 always contributes 0.
+    Assert.eq(model_input_rank0.num_docs, 1)
+    Assert.eq(model_input_rank1.num_docs, 0)
+    Assert.eq(model_input_rank0.num_docs + model_input_rank1.num_docs, 1)
+
+
+def test_num_docs_sdp_two_docs_counted_once():
+    """Two documents on SDP=2 are counted once in total (not once per SDP rank)."""
+    # 9 tokens total → 8 input tokens, divisible by SDP=2.
+    documents = [_make_grpo_document([1, 2, 3, 4]), _make_grpo_document([5, 6, 7, 8, 9])]
+    batch = LanguageModelBatch.from_documents(documents)
+
+    (model_input_rank0,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=0))
+    (model_input_rank1,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=1))
+
+    Assert.eq(model_input_rank0.num_docs, 2)
+    Assert.eq(model_input_rank1.num_docs, 0)
+    Assert.eq(model_input_rank0.num_docs + model_input_rank1.num_docs, 2)
+
+
+def test_num_docs_micro_batch_splits_only_first_split_counts():
+    """With micro_batch_splits=2, num_docs is non-None only on the first split.
+
+    A document that spans the split boundary would be visible in both splits'
+    cropped_lengths, so both would count it without this guard. The runner sums
+    num_docs across all splits in context.batch; only the first split must contribute
+    to avoid double-counting.
+    """
+    # 9 tokens → 8 input tokens, divisible by micro_batch_splits=2 (4 each)
+    documents = [_make_grpo_document([1, 2, 3, 4, 5, 6, 7, 8, 9])]
+    config = LanguageModelBatchPreprocessingConfig(
+        use_grpo_data=True,
+        return_label_counts=True,
+        micro_batch_splits=2,
+    )
+    split0, split1 = LanguageModelBatch.from_documents(documents).get_model_inputs(config)
+
+    # Only the first split carries the count; the second must be None (runner treats as 0).
+    Assert.eq(split0.num_docs, 1)
+    assert split1.num_docs is None
+
+    # Simulated runner sum: 1 + (None→0) = 1, the correct denominator.
+    Assert.eq((split0.num_docs or 0) + (split1.num_docs or 0), 1)
+
+
+def test_num_docs_micro_batch_splits_two_docs():
+    """With micro_batch_splits=2 and two documents, only the first split counts both docs."""
+    # 9 tokens total → 8 input tokens, divisible by 2
+    documents = [_make_grpo_document([1, 2, 3, 4]), _make_grpo_document([5, 6, 7, 8, 9])]
+    config = LanguageModelBatchPreprocessingConfig(
+        use_grpo_data=True,
+        return_label_counts=True,
+        micro_batch_splits=2,
+    )
+    split0, split1 = LanguageModelBatch.from_documents(documents).get_model_inputs(config)
+
+    Assert.eq(split0.num_docs, 2)
+    assert split1.num_docs is None
+    Assert.eq((split0.num_docs or 0) + (split1.num_docs or 0), 2)