From 930bd968c0ef1f54237ca2ddd63065e3523061ec Mon Sep 17 00:00:00 2001 From: bigximik Date: Sat, 21 Mar 2026 17:04:46 +0000 Subject: [PATCH 1/4] normalize grpo_new_logprobs by num_docs instead of num_micro_batches Add generic denominator_batch_field to LossDef so any metric can be normalized by a pre-computed per-micro-batch scalar from the batch context, bypassing TP/SP/PP splitting entirely. For grpo_new_logprobs: compute num_docs = (labels_per_document > 0).sum() in language_model.py before any parallel splitting, giving a true per-document average regardless of variable document lengths. Only sequence_data_rank==0 contributes to num_docs. The runner all_reduces the denominator across the data group (which includes SDP ranks); if every SDP rank reported its own num_docs, a single document processed by SDP=2 would be counted twice, halving the metric. Also clamp num_labels_in_seq to avoid 0/0=nan for padding segments or fully-masked documents (loss_mask=0 there so the numerator is 0 too). Tests verify: - num_docs counts only unmasked documents - padding segments (pad_to_size) are excluded - with SDP=2, only rank 0 contributes num_docs so the all_reduce SUM across SDP ranks gives the correct denominator --- fast_llm/data/document/language_model.py | 15 +++ fast_llm/engine/base_model/config.py | 5 +- fast_llm/engine/schedule/runner.py | 41 ++++-- fast_llm/layers/language_model/loss/config.py | 1 + fast_llm/layers/language_model/loss/grpo.py | 6 +- tests/data/test_preprocessing.py | 121 ++++++++++++++++++ 6 files changed, 176 insertions(+), 13 deletions(-) diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 00040e576..da50468cb 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -39,6 +39,9 @@ class LanguageModelTargetInput(ModelInput): class LanguageModelInput(TokenModelInput): targets: list[LanguageModelTargetInput] = dataclasses.field(default_factory=list) image_patches: PatchModelInput | None = None + # Number of documents with at least one response token in this micro-batch. + # Computed before any TP/SP/PP splitting; used to normalize per-document metrics. + num_docs: int | None = None def set_children_attributes(self) -> None: if self.image_patches is not None: @@ -58,6 +61,7 @@ def to_kwargs(self) -> dict[str, typing.Any]: LanguageModelKwargs.advantages: [target.advantages for target in self.targets], LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets], LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets], + LanguageModelKwargs.num_docs: self.num_docs, } if self.image_patches is not None: out.update(self.image_patches.to_kwargs()) @@ -161,6 +165,17 @@ def _get_model_input( length_cumsum = torch.tensor([0] + cropped_lengths, device=self.device).cumsum(0) label_count_cumsum = mask_cumsum[length_cumsum] labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1] + # Track documents with at least one response token for per-document metric + # normalization. Only counted on sequence_data_rank==0 so that when the runner + # all_reduces the denominator across the data group (which includes SDP ranks), + # each document is counted exactly once even though all SDP ranks see the same + # documents (but process different token slices of them). + if model_input.num_docs is None: + model_input.num_docs = ( + int((labels_per_document > 0).sum().item()) + if config.distributed.sequence_data_rank == 0 + else 0 + ) # Expand to one entry per token: find each token's document index via the sorted # length cumsum, then look up that document's label count. # TODO: Document index already computed in `LengthModelInputPreprocessor`. diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 0526b9dc2..959368476 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -109,6 +109,9 @@ class LossDef: name: str formatted_name: str # The number of times this loss is evaluated by the model for each micro-batch. Used as a denominator for averaging. - # TODO: Allow variable count? Would need a reduction across PP devices. count: int = 1 dtype: DataType = DataType.float32 + # If set, normalize this metric by summing values from context.batch[i][denominator_batch_field] + # across micro-batches and DP ranks, instead of using count * data_parallel * num_inputs. + # The field must be a scalar (int or float) pre-computed before any TP/SP/PP splitting. + denominator_batch_field: str | None = None diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 24b8b3d63..88f0bbbd5 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -287,20 +287,39 @@ def run_step( def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: reduced_losses = {} for name, losses in context.losses.items(): + loss_def = self._loss_definitions[name] if losses or self._distributed.pipeline_group: if losses: - loss_count = ( - self._loss_definitions[name].count - * self._distributed_config.data_parallel - * context.schedule.config.num_inputs - ) - reduced_loss = torch.stack(losses).sum() / loss_count - if self._distributed.data_group: - all_reduce(reduced_loss, group=self._distributed.data_group) + denom_field = loss_def.denominator_batch_field + if denom_field is not None: + # Normalize by a per-micro-batch scalar from the batch data (e.g. num_docs), + # computed before any TP/SP/PP splitting. Sum numerator and denominator + # independently across DP ranks so the result is a true global average. + numerator = torch.stack(losses).sum() + denominator = torch.tensor( + sum( + batch_kwargs[denom_field] or 0 + for batch_kwargs in context.batch.values() + if denom_field in batch_kwargs + ), + dtype=numerator.dtype, + device=numerator.device, + ) + if self._distributed.data_group: + all_reduce(numerator, group=self._distributed.data_group) + all_reduce(denominator, group=self._distributed.data_group) + reduced_loss = numerator / denominator.clamp(min=1) + else: + loss_count = ( + loss_def.count + * self._distributed_config.data_parallel + * context.schedule.config.num_inputs + ) + reduced_loss = torch.stack(losses).sum() / loss_count + if self._distributed.data_group: + all_reduce(reduced_loss, group=self._distributed.data_group) else: - reduced_loss = torch.zeros( - [1], dtype=self._loss_definitions[name].dtype.torch, device=self._distributed.device - ) + reduced_loss = torch.zeros([1], dtype=loss_def.dtype.torch, device=self._distributed.device) if self._distributed.pipeline_group: all_reduce(reduced_loss, group=self._distributed.pipeline_group) else: diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 99d4bce9a..883b2b095 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -27,6 +27,7 @@ class LanguageModelLossKwargs(BlockKwargs): advantages = "advantages" old_log_probabilities = "old_log_probabilities" label_counts = "num_labels_in_seq" + num_docs = "num_docs" @config_class(registry=True) diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 2136e7918..3217f112c 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -47,8 +47,12 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: LossDef( self._logprob_metric_name, formatted_name=self._logprob_metric_name, - count=1, # This is an additive metric over the sequence. + count=1, dtype=DataType.float32, + # Normalize by the number of documents with response tokens in each micro-batch, + # giving a true per-document average regardless of variable document lengths. + # num_docs is computed before any TP/SP/PP splitting in language_model.py. + denominator_batch_field=LanguageModelLossKwargs.num_docs, ) ] diff --git a/tests/data/test_preprocessing.py b/tests/data/test_preprocessing.py index d0e56e3f0..1733c423b 100644 --- a/tests/data/test_preprocessing.py +++ b/tests/data/test_preprocessing.py @@ -4,6 +4,8 @@ from fast_llm.data.document.config import LanguageModelBatchPreprocessingConfig from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument from fast_llm.data.document.range import RangeDocument +from fast_llm.data.document.token_data import TokenDataDocument +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.utils import Assert @@ -53,3 +55,122 @@ def test_preprocessing(tokens, loss_masking_spans): Assert.eq(len(model_input.targets), 1) Assert.all_equal(model_input.targets[0].tokens, torch.cat(label_tokens)[1:]) + + +def _make_grpo_document(tokens, loss_masking_spans=None): + """Helper: create a LanguageModelDocument with GRPO fields (advantages, old_log_probabilities).""" + t = torch.tensor(tokens, dtype=torch.int64) + n = len(t) + return LanguageModelDocument( + tokens=t, + loss_masking_spans=None if loss_masking_spans is None else RangeDocument(ranges=loss_masking_spans), + advantages=TokenDataDocument(data=torch.zeros(n)), + old_log_probabilities=TokenDataDocument(data=torch.zeros(n)), + ) + + +@pytest.mark.parametrize( + ("token_lists", "loss_masking_spans_list", "expected_num_docs"), + ( + # Single doc, no masking — all tokens are response tokens except first (cross-doc mask) + ([[1, 2, 3, 4, 5]], [None], 1), + # Single doc fully masked by loss_masking_spans — no response tokens, num_docs = 0 + ([[1, 2, 3, 4, 5]], [[(0, 5)]], 0), + # Two docs, both with response tokens + ([[1, 2, 3], [4, 5, 6]], [None, None], 2), + # Two docs, one fully masked — only 1 contributes + ([[1, 2, 3], [4, 5, 6]], [[(0, 3)], None], 1), + # Two docs, both fully masked + ([[1, 2, 3], [4, 5, 6]], [[(0, 3)], [(0, 3)]], 0), + # Padding: a short doc packed into a larger micro_batch_size leaves a padding segment + ([[1, 2, 3]], [None], 1), # with pad_to_size below + ), +) +def test_num_docs_computation(token_lists, loss_masking_spans_list, expected_num_docs): + """num_docs counts only documents that have at least one non-masked response token.""" + documents = [_make_grpo_document(tokens, spans) for tokens, spans in zip(token_lists, loss_masking_spans_list)] + config = LanguageModelBatchPreprocessingConfig(use_grpo_data=True, return_label_counts=True) + (model_input,) = LanguageModelBatch.from_documents(documents).get_model_inputs(config) + Assert.eq(model_input.num_docs, expected_num_docs) + + +def test_num_docs_excludes_padding(): + """Padding appended by pad_to_size is a 0-label segment and must not count toward num_docs.""" + documents = [_make_grpo_document([1, 2, 3, 4])] + config = LanguageModelBatchPreprocessingConfig(use_grpo_data=True, return_label_counts=True) + # pad_to_size > total tokens forces a padding segment to be added + (model_input,) = LanguageModelBatch.from_documents(documents, pad_to_size=10).get_model_inputs(config) + # Only the real document counts; the padding segment (all -100) does not + Assert.eq(model_input.num_docs, 1) + + +def test_num_docs_none_without_label_counts(): + """num_docs is None when return_label_counts is False (GRPO preprocessing not requested).""" + documents = [_make_grpo_document([1, 2, 3, 4])] + config = LanguageModelBatchPreprocessingConfig() + (model_input,) = LanguageModelBatch.from_documents(documents).get_model_inputs(config) + assert model_input.num_docs is None + + +def _make_sdp_config(sdp_rank: int, sdp_size: int = 2) -> LanguageModelBatchPreprocessingConfig: + """Config simulating a given sequence-data-parallel rank.""" + return LanguageModelBatchPreprocessingConfig( + use_grpo_data=True, + return_label_counts=True, + distributed=DistributedConfig(world_size=sdp_size, rank=sdp_rank, sequence_data_parallel=sdp_size), + ) + + +def test_num_docs_sdp_only_counted_on_rank0(): + """With SDP=2, num_docs is counted only on sequence_data_rank=0. + + The runner all_reduces the denominator across the data group (which includes all SDP + ranks). If both SDP ranks reported num_docs=1 for the same document, the all_reduce + SUM would produce denominator=2 and halve the metric. Only rank 0 must contribute + to avoid this double-counting. + """ + # 9 tokens → total_input_length = 8 (divisible by SDP=2) + documents = [_make_grpo_document([1, 2, 3, 4, 5, 6, 7, 8, 9])] + batch = LanguageModelBatch.from_documents(documents) + + (model_input_rank0,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=0)) + (model_input_rank1,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=1)) + + # Rank 0 counts the document; rank 1 must not. + Assert.eq(model_input_rank0.num_docs, 1) + Assert.eq(model_input_rank1.num_docs, 0) + + # After all_reduce SUM across SDP ranks the denominator equals the true doc count. + Assert.eq(model_input_rank0.num_docs + model_input_rank1.num_docs, 1) + + +def test_num_docs_sdp_fully_masked_excluded_on_rank0(): + """A fully-masked document is excluded from num_docs even on SDP rank 0.""" + # Doc 0 fully masked; doc 1 has response tokens. 9 tokens total → 8 input tokens (div by 2). + documents = [ + _make_grpo_document([1, 2, 3, 4], loss_masking_spans=[(0, 4)]), + _make_grpo_document([5, 6, 7, 8, 9]), + ] + batch = LanguageModelBatch.from_documents(documents) + + (model_input_rank0,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=0)) + (model_input_rank1,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=1)) + + # Only the unmasked document counts; rank 1 always contributes 0. + Assert.eq(model_input_rank0.num_docs, 1) + Assert.eq(model_input_rank1.num_docs, 0) + Assert.eq(model_input_rank0.num_docs + model_input_rank1.num_docs, 1) + + +def test_num_docs_sdp_two_docs_counted_once(): + """Two documents on SDP=2 are counted once in total (not once per SDP rank).""" + # 9 tokens total → 8 input tokens, divisible by SDP=2. + documents = [_make_grpo_document([1, 2, 3, 4]), _make_grpo_document([5, 6, 7, 8, 9])] + batch = LanguageModelBatch.from_documents(documents) + + (model_input_rank0,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=0)) + (model_input_rank1,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=1)) + + Assert.eq(model_input_rank0.num_docs, 2) + Assert.eq(model_input_rank1.num_docs, 0) + Assert.eq(model_input_rank0.num_docs + model_input_rank1.num_docs, 2) From baff181873e07d9d72583f9c4915faf35043d9a8 Mon Sep 17 00:00:00 2001 From: bigximik Date: Sat, 21 Mar 2026 19:06:22 +0000 Subject: [PATCH 2/4] fix num_docs double-counting with micro_batch_splits > 1 When micro_batch_splits > 1, _get_model_input is called once per split on the same rank. Documents that span a split boundary appear in both splits' cropped_lengths, so both would count them without a guard. The runner sums num_docs across all splits in context.batch, so boundary documents would be counted multiple times. Fix: after the loop in get_model_inputs, set num_docs=None on all splits except the first. The first split already holds the correct count (guarded by sequence_data_rank==0 for SDP); subsequent splits get None which the runner treats as 0 via `batch_kwargs[field] or 0`. With micro_batch_splits=1 (the default) model_inputs[1:] is empty so there is no behaviour change. --- fast_llm/data/document/language_model.py | 6 ++++ tests/data/test_preprocessing.py | 41 ++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index da50468cb..601276df0 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -125,6 +125,12 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis model_inputs.append(model_input) + # num_docs is counted only on the first split (micro_sequence_index==0) to avoid + # double-counting documents that span a split boundary when micro_batch_splits > 1. + # The first split already has the correct count (or 0 for SDP rank > 0); clear the rest. + for model_input in model_inputs[1:]: + model_input.num_docs = None + return model_inputs def _get_model_input( diff --git a/tests/data/test_preprocessing.py b/tests/data/test_preprocessing.py index 1733c423b..6802095c0 100644 --- a/tests/data/test_preprocessing.py +++ b/tests/data/test_preprocessing.py @@ -174,3 +174,44 @@ def test_num_docs_sdp_two_docs_counted_once(): Assert.eq(model_input_rank0.num_docs, 2) Assert.eq(model_input_rank1.num_docs, 0) Assert.eq(model_input_rank0.num_docs + model_input_rank1.num_docs, 2) + + +def test_num_docs_micro_batch_splits_only_first_split_counts(): + """With micro_batch_splits=2, num_docs is non-None only on the first split. + + A document that spans the split boundary would be visible in both splits' + cropped_lengths, so both would count it without this guard. The runner sums + num_docs across all splits in context.batch; only the first split must contribute + to avoid double-counting. + """ + # 9 tokens → 8 input tokens, divisible by micro_batch_splits=2 (4 each) + documents = [_make_grpo_document([1, 2, 3, 4, 5, 6, 7, 8, 9])] + config = LanguageModelBatchPreprocessingConfig( + use_grpo_data=True, + return_label_counts=True, + micro_batch_splits=2, + ) + split0, split1 = LanguageModelBatch.from_documents(documents).get_model_inputs(config) + + # Only the first split carries the count; the second must be None (runner treats as 0). + Assert.eq(split0.num_docs, 1) + assert split1.num_docs is None + + # Simulated runner sum: 1 + (None→0) = 1, the correct denominator. + Assert.eq((split0.num_docs or 0) + (split1.num_docs or 0), 1) + + +def test_num_docs_micro_batch_splits_two_docs(): + """With micro_batch_splits=2 and two documents, only the first split counts both docs.""" + # 9 tokens total → 8 input tokens, divisible by 2 + documents = [_make_grpo_document([1, 2, 3, 4]), _make_grpo_document([5, 6, 7, 8, 9])] + config = LanguageModelBatchPreprocessingConfig( + use_grpo_data=True, + return_label_counts=True, + micro_batch_splits=2, + ) + split0, split1 = LanguageModelBatch.from_documents(documents).get_model_inputs(config) + + Assert.eq(split0.num_docs, 2) + assert split1.num_docs is None + Assert.eq((split0.num_docs or 0) + (split1.num_docs or 0), 2) From 03ea8057b015a9a3f174536372a25e3be03ca8d2 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 23 Mar 2026 13:25:41 +0000 Subject: [PATCH 3/4] fix num_docs computation on meta tensors during trainer setup get_input_meta() is called during setup/shape inference with meta tensors that have no real data. Calling .item() on a meta tensor raises: RuntimeError: Tensor.item() cannot be called on meta tensors Guard the computation with is_meta check, returning 0 for meta tensors (num_docs is not used during setup, only during actual training steps). --- fast_llm/data/document/language_model.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 601276df0..6a20dceba 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -177,11 +177,15 @@ def _get_model_input( # each document is counted exactly once even though all SDP ranks see the same # documents (but process different token slices of them). if model_input.num_docs is None: - model_input.num_docs = ( - int((labels_per_document > 0).sum().item()) - if config.distributed.sequence_data_rank == 0 - else 0 - ) + # Skip .item() on meta tensors (called during setup/shape inference). + if labels_per_document.is_meta: + model_input.num_docs = 0 + else: + model_input.num_docs = ( + int((labels_per_document > 0).sum().item()) + if config.distributed.sequence_data_rank == 0 + else 0 + ) # Expand to one entry per token: find each token's document index via the sorted # length cumsum, then look up that document's label count. # TODO: Document index already computed in `LengthModelInputPreprocessor`. From a656e9c56c44b1f8b83c003ef3f894f01757836d Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 25 Mar 2026 15:02:44 +0000 Subject: [PATCH 4/4] add data pipeline diagnostic logging for bottleneck investigation Logs per-document read events (gap_ms, latency_ms, token count) to per-rank JSONL files under experiment_directory/data_pipeline_log/ when log_data_pipeline is enabled in StreamingDatasetConfig. Also logs per-micro-batch BATCH_START/BATCH_END events with duration and sample count from the schedule runner when log_data_pipeline is enabled in ScheduleConfig. --- fast_llm/data/dataset/config.py | 6 +++ fast_llm/data/dataset/streaming.py | 84 +++++++++++++++++++++--------- fast_llm/engine/schedule/config.py | 5 ++ fast_llm/engine/schedule/runner.py | 50 ++++++++++++++++++ 4 files changed, 121 insertions(+), 24 deletions(-) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index bfe1509d6..4d8e98da7 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -291,6 +291,12 @@ class StreamingDatasetConfig[DocumentType: LanguageModelDocument](RedisConfig, S _abstract = False + log_data_pipeline: bool = Field( + default=False, + desc="Write per-read timing to data_pipeline_log/rank_{rank}.jsonl for pipeline diagnostics.", + hint=FieldHint.optional, + ) + def build_and_sample(self, config: SamplingConfig, num_samples: int, seed: int) -> SampledDataset[DocumentType]: from fast_llm.data.dataset.streaming import RedisStreamingDataset diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py index 8835612ec..bc09bf864 100644 --- a/fast_llm/data/dataset/streaming.py +++ b/fast_llm/data/dataset/streaming.py @@ -12,6 +12,7 @@ from fast_llm.data.document.language_model import LanguageModelDocument from fast_llm.data.document.range import RangeDocument from fast_llm.data.document.token_data import TokenDataDocument +from fast_llm.engine.config_utils.run import get_run from fast_llm.utils import Assert @@ -139,28 +140,63 @@ def iterate(self, config: SamplingConfig, num_samples: int, seed: int) -> typing else: raise + # Set up pipeline diagnostics log file if enabled + log_file = None + if self._config.log_data_pipeline: + run = get_run() + if run is not None and run.experiment_directory is not None: + log_dir = run.experiment_directory / "data_pipeline_log" + log_dir.mkdir(parents=True, exist_ok=True) + log_file = open(log_dir / f"rank_{config.rank}.jsonl", "a") + + last_read_time = time.time() start_time = time.time() - while True: - # XREADGROUP reads from the consumer group - # COUNT: max number of messages to fetch at once - # BLOCK: wait for new messages (milliseconds) - messages = client.xreadgroup( - groupname=REDIS_GROUP_NAME, - consumername=f"fast_llm_consumer_{config.rank}", - # ">" reads only new messages that have not been delivered to any consumer - streams={REDIS_DATA_STREAM: ">"}, - count=1, - block=1000, - # No explicit ACK: messages are processed immediately; on rank failure the job restarts, - # so message loss is acceptable and simplifies coordination - noack=True, - ) - if messages: - for stream_key, messages_ in messages: - assert stream_key == REDIS_DATA_STREAM.encode() - for message_id, message in messages_: - yield RedisStreamingDocumentData.from_message(message).to_document() - start_time = time.time() - - elif (t := time.time() - start_time) > self._config.timeout: - raise TimeoutError(f"No document received after {t} seconds") + try: + while True: + # XREADGROUP reads from the consumer group + # COUNT: max number of messages to fetch at once + # BLOCK: wait for new messages (milliseconds) + messages = client.xreadgroup( + groupname=REDIS_GROUP_NAME, + consumername=f"fast_llm_consumer_{config.rank}", + # ">" reads only new messages that have not been delivered to any consumer + streams={REDIS_DATA_STREAM: ">"}, + count=1, + block=1000, + # No explicit ACK: messages are processed immediately; on rank failure the job restarts, + # so message loss is acceptable and simplifies coordination + noack=True, + ) + if messages: + now = time.time() + for stream_key, messages_ in messages: + assert stream_key == REDIS_DATA_STREAM.encode() + for message_id, message in messages_: + doc = RedisStreamingDocumentData.from_message(message) + if log_file is not None: + write_ms = int(message_id.split(b"-")[0]) / 1000.0 + gap_ms = (now - last_read_time) * 1000.0 + latency_ms = (now - write_ms) * 1000.0 + log_file.write( + json.dumps( + { + "event": "READ", + "rank": config.rank, + "t": now, + "gap_ms": round(gap_ms, 2), + "latency_ms": round(latency_ms, 2), + "tokens": doc.num_tokens, + } + ) + + "\n" + ) + log_file.flush() + last_read_time = now + yield doc.to_document() + start_time = time.time() + + elif (t := time.time() - start_time) > self._config.timeout: + raise TimeoutError(f"No document received after {t} seconds") + finally: + if log_file is not None: + log_file.close() diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 48714db40..a849d05de 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -42,6 +42,11 @@ class ScheduleConfig(Config): hint=FieldHint.optional, valid=check_field(Assert.gt, 0), ) + log_data_pipeline: bool = Field( + default=False, + desc="Write per-micro-batch timing to data_pipeline_log/rank_{rank}.jsonl for pipeline diagnostics.", + hint=FieldHint.optional, + ) # Enable cpu throttling to avoid lag spikes, see https://arxiv.org/pdf/2211.05953.pdf, appendix D.2. throttle_cpu: bool = Field( default=True, diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 88f0bbbd5..dd41dad86 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -1,6 +1,7 @@ import collections import contextlib import dataclasses +import json import logging import time import typing @@ -447,14 +448,63 @@ def _backward(self, context: BatchContext, step: Step) -> torch.Tensor: self._record_compute(context, step) return input_grad + def _get_pipeline_log_file(self): + if not self._config.log_data_pipeline: + return None + if not hasattr(self, "_pipeline_log_file"): + run = get_run() + if run is not None and run.experiment_directory is not None: + log_dir = run.experiment_directory / "data_pipeline_log" + log_dir.mkdir(parents=True, exist_ok=True) + rank = self._distributed_config.data_rank + self._pipeline_log_file = open(log_dir / f"rank_{rank}.jsonl", "a") + else: + self._pipeline_log_file = None + return self._pipeline_log_file + def _get_forward_input(self, context: BatchContext, step: Step) -> torch.Tensor: if step.index not in context.batch: start_time = time.perf_counter() + rank = self._distributed_config.data_rank + log_file = self._get_pipeline_log_file() + + if log_file is not None: + log_file.write( + json.dumps( + { + "event": "BATCH_START", + "rank": rank, + "micro": step.index, + "t": time.time(), + } + ) + + "\n" + ) + log_file.flush() + samples_loaded = 0 while step.index not in context.batch: next(context.data_iterator) + samples_loaded += 1 data_time = (time.perf_counter() - start_time) * 1000 + + if log_file is not None: + log_file.write( + json.dumps( + { + "event": "BATCH_END", + "rank": rank, + "micro": step.index, + "t": time.time(), + "duration_ms": round(data_time, 2), + "samples": samples_loaded, + } + ) + + "\n" + ) + log_file.flush() + if data_time > self._config.data_batch_warn_time_ms: logger.warning(f"Data loading took {data_time:,.2f} ms") return context.inputs.pop(step.global_index).detach().requires_grad_(step.stage != 0)