Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions fast_llm/data/dataset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ class StreamingDatasetConfig[DocumentType: LanguageModelDocument](RedisConfig, S

_abstract = False

log_data_pipeline: bool = Field(
default=False,
desc="Write per-read timing to data_pipeline_log/rank_{rank}.jsonl for pipeline diagnostics.",
hint=FieldHint.optional,
)

def build_and_sample(self, config: SamplingConfig, num_samples: int, seed: int) -> SampledDataset[DocumentType]:
from fast_llm.data.dataset.streaming import RedisStreamingDataset

Expand Down
84 changes: 60 additions & 24 deletions fast_llm/data/dataset/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from fast_llm.data.document.language_model import LanguageModelDocument
from fast_llm.data.document.range import RangeDocument
from fast_llm.data.document.token_data import TokenDataDocument
from fast_llm.engine.config_utils.run import get_run
from fast_llm.utils import Assert


Expand Down Expand Up @@ -139,28 +140,63 @@ def iterate(self, config: SamplingConfig, num_samples: int, seed: int) -> typing
else:
raise

# Set up pipeline diagnostics log file if enabled
log_file = None
if self._config.log_data_pipeline:
run = get_run()
if run is not None and run.experiment_directory is not None:
log_dir = run.experiment_directory / "data_pipeline_log"
log_dir.mkdir(parents=True, exist_ok=True)
log_file = open(log_dir / f"rank_{config.rank}.jsonl", "a")

last_read_time = time.time()
start_time = time.time()
while True:
# XREADGROUP reads from the consumer group
# COUNT: max number of messages to fetch at once
# BLOCK: wait for new messages (milliseconds)
messages = client.xreadgroup(
groupname=REDIS_GROUP_NAME,
consumername=f"fast_llm_consumer_{config.rank}",
# ">" reads only new messages that have not been delivered to any consumer
streams={REDIS_DATA_STREAM: ">"},
count=1,
block=1000,
# No explicit ACK: messages are processed immediately; on rank failure the job restarts,
# so message loss is acceptable and simplifies coordination
noack=True,
)
if messages:
for stream_key, messages_ in messages:
assert stream_key == REDIS_DATA_STREAM.encode()
for message_id, message in messages_:
yield RedisStreamingDocumentData.from_message(message).to_document()
start_time = time.time()

elif (t := time.time() - start_time) > self._config.timeout:
raise TimeoutError(f"No document received after {t} seconds")
try:
while True:
# XREADGROUP reads from the consumer group
# COUNT: max number of messages to fetch at once
# BLOCK: wait for new messages (milliseconds)
messages = client.xreadgroup(
groupname=REDIS_GROUP_NAME,
consumername=f"fast_llm_consumer_{config.rank}",
# ">" reads only new messages that have not been delivered to any consumer
streams={REDIS_DATA_STREAM: ">"},
count=1,
block=1000,
# No explicit ACK: messages are processed immediately; on rank failure the job restarts,
# so message loss is acceptable and simplifies coordination
noack=True,
)
if messages:
now = time.time()
for stream_key, messages_ in messages:
assert stream_key == REDIS_DATA_STREAM.encode()
for message_id, message in messages_:
doc = RedisStreamingDocumentData.from_message(message)
if log_file is not None:
write_ms = int(message_id.split(b"-")[0]) / 1000.0
gap_ms = (now - last_read_time) * 1000.0
latency_ms = (now - write_ms) * 1000.0
log_file.write(
json.dumps(
{
"event": "READ",
"rank": config.rank,
"t": now,
"gap_ms": round(gap_ms, 2),
"latency_ms": round(latency_ms, 2),
"tokens": doc.num_tokens,
}
)
+ "\n"
)
log_file.flush()
last_read_time = now
yield doc.to_document()
start_time = time.time()

elif (t := time.time() - start_time) > self._config.timeout:
raise TimeoutError(f"No document received after {t} seconds")
finally:
if log_file is not None:
log_file.close()
25 changes: 25 additions & 0 deletions fast_llm/data/document/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class LanguageModelTargetInput(ModelInput):
class LanguageModelInput(TokenModelInput):
targets: list[LanguageModelTargetInput] = dataclasses.field(default_factory=list)
image_patches: PatchModelInput | None = None
# Number of documents with at least one response token in this micro-batch.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just use the plain document count? Is there any scenario where we would expect response-free documents and need to take them into account? It seems like they shouldn't be in the batch to begin with, because they don't contribute to the loss...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree that response-free documents shouldn't be in the batch. However, the > 0 check is needed because of padding — by the time we reach _get_model_input, padding is just another entry in lengths and is structurally indistinguishable from a real document. To get the true document count cleanly, we can store num_real_docs = len(documents) in from_documents before
padding is appended.

# Computed before any TP/SP/PP splitting; used to normalize per-document metrics.
num_docs: int | None = None

def set_children_attributes(self) -> None:
if self.image_patches is not None:
Expand All @@ -58,6 +61,7 @@ def to_kwargs(self) -> dict[str, typing.Any]:
LanguageModelKwargs.advantages: [target.advantages for target in self.targets],
LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets],
LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets],
LanguageModelKwargs.num_docs: self.num_docs,
}
if self.image_patches is not None:
out.update(self.image_patches.to_kwargs())
Expand Down Expand Up @@ -121,6 +125,12 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis

model_inputs.append(model_input)

# num_docs is counted only on the first split (micro_sequence_index==0) to avoid
# double-counting documents that span a split boundary when micro_batch_splits > 1.
# The first split already has the correct count (or 0 for SDP rank > 0); clear the rest.
for model_input in model_inputs[1:]:
model_input.num_docs = None

return model_inputs

def _get_model_input(
Expand Down Expand Up @@ -161,6 +171,21 @@ def _get_model_input(
length_cumsum = torch.tensor([0] + cropped_lengths, device=self.device).cumsum(0)
label_count_cumsum = mask_cumsum[length_cumsum]
labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1]
# Track documents with at least one response token for per-document metric
# normalization. Only counted on sequence_data_rank==0 so that when the runner
# all_reduces the denominator across the data group (which includes SDP ranks),
# each document is counted exactly once even though all SDP ranks see the same
# documents (but process different token slices of them).
if model_input.num_docs is None:
# Skip .item() on meta tensors (called during setup/shape inference).
if labels_per_document.is_meta:
model_input.num_docs = 0
else:
model_input.num_docs = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could just reduce over the batch data group instead?

int((labels_per_document > 0).sum().item())
if config.distributed.sequence_data_rank == 0
else 0
)
# Expand to one entry per token: find each token's document index via the sorted
# length cumsum, then look up that document's label count.
# TODO: Document index already computed in `LengthModelInputPreprocessor`.
Expand Down
5 changes: 4 additions & 1 deletion fast_llm/engine/base_model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ class LossDef:
name: str
formatted_name: str
# The number of times this loss is evaluated by the model for each micro-batch. Used as a denominator for averaging.
# TODO: Allow variable count? Would need a reduction across PP devices.
count: int = 1
dtype: DataType = DataType.float32
# If set, normalize this metric by summing values from context.batch[i][denominator_batch_field]
# across micro-batches and DP ranks, instead of using count * data_parallel * num_inputs.
# The field must be a scalar (int or float) pre-computed before any TP/SP/PP splitting.
denominator_batch_field: str | None = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Somewhat redundant with count. Forcing the denominator to be in a scalar-valued kwarg field may not be generic enough, I'd suggest returning losses as a (numerator, denominator) tuple instead, or maybe reducing the denominator beforehand?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for not returning (numerator, denominator) from the loss is that the loss is computed per micro-batch split (micro sequence), and a document can span a split boundary. If each split returned its own document count as the denominator, documents straddling a boundary would be counted twice. That's why num_docs is computed only on micro_sequence_index == 0 and
all_reduced — to count each document exactly once.

The current design handles this by precomputing num_docs at the batch level before splitting, and having the loss reference it by name via denominator_batch_field — essentially "use this precomputed scalar as my denominator" rather than computing it from within the split. Returning a denominator from the loss function directly would only work cleanly for
quantities that are local to a split (like num_tokens or num_labels), but num_docs inherently crosses split boundaries, so it can't easily be accumulated that way.

Copy link
Collaborator

@jlamypoirier jlamypoirier Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In either case, what matters is sum(numerator)/sum(denominator), so as long as the returned denominators sum to the right value (as with the values you're using), things still work.

Anyway in #477 I'm precomputing sum(denominator) so we simply return numerator/sum(denominator) in losses and the reduction is a simple sum.

5 changes: 5 additions & 0 deletions fast_llm/engine/schedule/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class ScheduleConfig(Config):
hint=FieldHint.optional,
valid=check_field(Assert.gt, 0),
)
log_data_pipeline: bool = Field(
default=False,
desc="Write per-micro-batch timing to data_pipeline_log/rank_{rank}.jsonl for pipeline diagnostics.",
hint=FieldHint.optional,
)
# Enable cpu throttling to avoid lag spikes, see https://arxiv.org/pdf/2211.05953.pdf, appendix D.2.
throttle_cpu: bool = Field(
default=True,
Expand Down
91 changes: 80 additions & 11 deletions fast_llm/engine/schedule/runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import contextlib
import dataclasses
import json
import logging
import time
import typing
Expand Down Expand Up @@ -287,20 +288,39 @@ def run_step(
def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
reduced_losses = {}
for name, losses in context.losses.items():
loss_def = self._loss_definitions[name]
if losses or self._distributed.pipeline_group:
if losses:
loss_count = (
self._loss_definitions[name].count
* self._distributed_config.data_parallel
* context.schedule.config.num_inputs
)
reduced_loss = torch.stack(losses).sum() / loss_count
if self._distributed.data_group:
all_reduce(reduced_loss, group=self._distributed.data_group)
denom_field = loss_def.denominator_batch_field
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works for metrics, but not for actual losses used for training which would need to reproduce this code. I think this could be addressed by reducing the values at the beginning of the batch, ex. in ModelInput.reduce_counts so we can use the counts freely from anywhere.

if denom_field is not None:
# Normalize by a per-micro-batch scalar from the batch data (e.g. num_docs),
# computed before any TP/SP/PP splitting. Sum numerator and denominator
# independently across DP ranks so the result is a true global average.
numerator = torch.stack(losses).sum()
denominator = torch.tensor(
sum(
batch_kwargs[denom_field] or 0
for batch_kwargs in context.batch.values()
if denom_field in batch_kwargs
),
dtype=numerator.dtype,
device=numerator.device,
)
if self._distributed.data_group:
all_reduce(numerator, group=self._distributed.data_group)
all_reduce(denominator, group=self._distributed.data_group)
reduced_loss = numerator / denominator.clamp(min=1)
else:
loss_count = (
loss_def.count
* self._distributed_config.data_parallel
* context.schedule.config.num_inputs
)
reduced_loss = torch.stack(losses).sum() / loss_count
if self._distributed.data_group:
all_reduce(reduced_loss, group=self._distributed.data_group)
else:
reduced_loss = torch.zeros(
[1], dtype=self._loss_definitions[name].dtype.torch, device=self._distributed.device
)
reduced_loss = torch.zeros([1], dtype=loss_def.dtype.torch, device=self._distributed.device)
if self._distributed.pipeline_group:
all_reduce(reduced_loss, group=self._distributed.pipeline_group)
else:
Expand Down Expand Up @@ -428,14 +448,63 @@ def _backward(self, context: BatchContext, step: Step) -> torch.Tensor:
self._record_compute(context, step)
return input_grad

def _get_pipeline_log_file(self):
if not self._config.log_data_pipeline:
return None
if not hasattr(self, "_pipeline_log_file"):
run = get_run()
if run is not None and run.experiment_directory is not None:
log_dir = run.experiment_directory / "data_pipeline_log"
log_dir.mkdir(parents=True, exist_ok=True)
rank = self._distributed_config.data_rank
self._pipeline_log_file = open(log_dir / f"rank_{rank}.jsonl", "a")
else:
self._pipeline_log_file = None
return self._pipeline_log_file

def _get_forward_input(self, context: BatchContext, step: Step) -> torch.Tensor:
if step.index not in context.batch:
start_time = time.perf_counter()
rank = self._distributed_config.data_rank
log_file = self._get_pipeline_log_file()

if log_file is not None:
log_file.write(
json.dumps(
{
"event": "BATCH_START",
"rank": rank,
"micro": step.index,
"t": time.time(),
}
)
+ "\n"
)
log_file.flush()

samples_loaded = 0
while step.index not in context.batch:
next(context.data_iterator)
samples_loaded += 1

data_time = (time.perf_counter() - start_time) * 1000

if log_file is not None:
log_file.write(
json.dumps(
{
"event": "BATCH_END",
"rank": rank,
"micro": step.index,
"t": time.time(),
"duration_ms": round(data_time, 2),
"samples": samples_loaded,
}
)
+ "\n"
)
log_file.flush()

if data_time > self._config.data_batch_warn_time_ms:
logger.warning(f"Data loading took {data_time:,.2f} ms")
return context.inputs.pop(step.global_index).detach().requires_grad_(step.stage != 0)
Expand Down
1 change: 1 addition & 0 deletions fast_llm/layers/language_model/loss/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class LanguageModelLossKwargs(BlockKwargs):
advantages = "advantages"
old_log_probabilities = "old_log_probabilities"
label_counts = "num_labels_in_seq"
num_docs = "num_docs"


@config_class(registry=True)
Expand Down
6 changes: 5 additions & 1 deletion fast_llm/layers/language_model/loss/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
LossDef(
self._logprob_metric_name,
formatted_name=self._logprob_metric_name,
count=1, # This is an additive metric over the sequence.
count=1,
dtype=DataType.float32,
# Normalize by the number of documents with response tokens in each micro-batch,
# giving a true per-document average regardless of variable document lengths.
# num_docs is computed before any TP/SP/PP splitting in language_model.py.
denominator_batch_field=LanguageModelLossKwargs.num_docs,
)
]

Expand Down
Loading
Loading