-
Notifications
You must be signed in to change notification settings - Fork 43
[WIP][PipelineRL] Normalization of new_logprobs and addition of other RL metrics #476
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
930bd96
baff181
03ea805
a656e9c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,6 +39,9 @@ class LanguageModelTargetInput(ModelInput): | |
| class LanguageModelInput(TokenModelInput): | ||
| targets: list[LanguageModelTargetInput] = dataclasses.field(default_factory=list) | ||
| image_patches: PatchModelInput | None = None | ||
| # Number of documents with at least one response token in this micro-batch. | ||
| # Computed before any TP/SP/PP splitting; used to normalize per-document metrics. | ||
| num_docs: int | None = None | ||
|
|
||
| def set_children_attributes(self) -> None: | ||
| if self.image_patches is not None: | ||
|
|
@@ -58,6 +61,7 @@ def to_kwargs(self) -> dict[str, typing.Any]: | |
| LanguageModelKwargs.advantages: [target.advantages for target in self.targets], | ||
| LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets], | ||
| LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets], | ||
| LanguageModelKwargs.num_docs: self.num_docs, | ||
| } | ||
| if self.image_patches is not None: | ||
| out.update(self.image_patches.to_kwargs()) | ||
|
|
@@ -121,6 +125,12 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis | |
|
|
||
| model_inputs.append(model_input) | ||
|
|
||
| # num_docs is counted only on the first split (micro_sequence_index==0) to avoid | ||
| # double-counting documents that span a split boundary when micro_batch_splits > 1. | ||
| # The first split already has the correct count (or 0 for SDP rank > 0); clear the rest. | ||
| for model_input in model_inputs[1:]: | ||
| model_input.num_docs = None | ||
|
|
||
| return model_inputs | ||
|
|
||
| def _get_model_input( | ||
|
|
@@ -161,6 +171,21 @@ def _get_model_input( | |
| length_cumsum = torch.tensor([0] + cropped_lengths, device=self.device).cumsum(0) | ||
| label_count_cumsum = mask_cumsum[length_cumsum] | ||
| labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1] | ||
| # Track documents with at least one response token for per-document metric | ||
| # normalization. Only counted on sequence_data_rank==0 so that when the runner | ||
| # all_reduces the denominator across the data group (which includes SDP ranks), | ||
| # each document is counted exactly once even though all SDP ranks see the same | ||
| # documents (but process different token slices of them). | ||
| if model_input.num_docs is None: | ||
| # Skip .item() on meta tensors (called during setup/shape inference). | ||
| if labels_per_document.is_meta: | ||
| model_input.num_docs = 0 | ||
| else: | ||
| model_input.num_docs = ( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could just reduce over the batch data group instead? |
||
| int((labels_per_document > 0).sum().item()) | ||
| if config.distributed.sequence_data_rank == 0 | ||
| else 0 | ||
| ) | ||
| # Expand to one entry per token: find each token's document index via the sorted | ||
| # length cumsum, then look up that document's label count. | ||
| # TODO: Document index already computed in `LengthModelInputPreprocessor`. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -109,6 +109,9 @@ class LossDef: | |
| name: str | ||
| formatted_name: str | ||
| # The number of times this loss is evaluated by the model for each micro-batch. Used as a denominator for averaging. | ||
| # TODO: Allow variable count? Would need a reduction across PP devices. | ||
| count: int = 1 | ||
| dtype: DataType = DataType.float32 | ||
| # If set, normalize this metric by summing values from context.batch[i][denominator_batch_field] | ||
| # across micro-batches and DP ranks, instead of using count * data_parallel * num_inputs. | ||
| # The field must be a scalar (int or float) pre-computed before any TP/SP/PP splitting. | ||
| denominator_batch_field: str | None = None | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Somewhat redundant with
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The reason for not returning (numerator, denominator) from the loss is that the loss is computed per micro-batch split (micro sequence), and a document can span a split boundary. If each split returned its own document count as the denominator, documents straddling a boundary would be counted twice. That's why num_docs is computed only on micro_sequence_index == 0. The current design handles this by precomputing num_docs at the batch level before splitting, and having the loss reference it by name via denominator_batch_field — essentially "use this precomputed scalar as my denominator" rather than computing it from within the split. Returning a denominator from the loss function directly would only work cleanly for
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In either case what matters is Anyway in #477 I'm precomputing |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import collections | ||
| import contextlib | ||
| import dataclasses | ||
| import json | ||
| import logging | ||
| import time | ||
| import typing | ||
|
|
@@ -287,20 +288,39 @@ def run_step( | |
| def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: | ||
| reduced_losses = {} | ||
| for name, losses in context.losses.items(): | ||
| loss_def = self._loss_definitions[name] | ||
| if losses or self._distributed.pipeline_group: | ||
| if losses: | ||
| loss_count = ( | ||
| self._loss_definitions[name].count | ||
| * self._distributed_config.data_parallel | ||
| * context.schedule.config.num_inputs | ||
| ) | ||
| reduced_loss = torch.stack(losses).sum() / loss_count | ||
| if self._distributed.data_group: | ||
| all_reduce(reduced_loss, group=self._distributed.data_group) | ||
| denom_field = loss_def.denominator_batch_field | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This works for metrics, but not for actual losses used for training which would need to reproduce this code. I think this could be addressed by reducing the values at the beginning of the batch, ex. in |
||
| if denom_field is not None: | ||
| # Normalize by a per-micro-batch scalar from the batch data (e.g. num_docs), | ||
| # computed before any TP/SP/PP splitting. Sum numerator and denominator | ||
| # independently across DP ranks so the result is a true global average. | ||
| numerator = torch.stack(losses).sum() | ||
| denominator = torch.tensor( | ||
| sum( | ||
| batch_kwargs[denom_field] or 0 | ||
| for batch_kwargs in context.batch.values() | ||
| if denom_field in batch_kwargs | ||
| ), | ||
| dtype=numerator.dtype, | ||
| device=numerator.device, | ||
| ) | ||
| if self._distributed.data_group: | ||
| all_reduce(numerator, group=self._distributed.data_group) | ||
| all_reduce(denominator, group=self._distributed.data_group) | ||
| reduced_loss = numerator / denominator.clamp(min=1) | ||
| else: | ||
| loss_count = ( | ||
| loss_def.count | ||
| * self._distributed_config.data_parallel | ||
| * context.schedule.config.num_inputs | ||
| ) | ||
| reduced_loss = torch.stack(losses).sum() / loss_count | ||
| if self._distributed.data_group: | ||
| all_reduce(reduced_loss, group=self._distributed.data_group) | ||
| else: | ||
| reduced_loss = torch.zeros( | ||
| [1], dtype=self._loss_definitions[name].dtype.torch, device=self._distributed.device | ||
| ) | ||
| reduced_loss = torch.zeros([1], dtype=loss_def.dtype.torch, device=self._distributed.device) | ||
| if self._distributed.pipeline_group: | ||
| all_reduce(reduced_loss, group=self._distributed.pipeline_group) | ||
| else: | ||
|
|
@@ -428,14 +448,63 @@ def _backward(self, context: BatchContext, step: Step) -> torch.Tensor: | |
| self._record_compute(context, step) | ||
| return input_grad | ||
|
|
||
| def _get_pipeline_log_file(self): | ||
| if not self._config.log_data_pipeline: | ||
| return None | ||
| if not hasattr(self, "_pipeline_log_file"): | ||
| run = get_run() | ||
| if run is not None and run.experiment_directory is not None: | ||
| log_dir = run.experiment_directory / "data_pipeline_log" | ||
| log_dir.mkdir(parents=True, exist_ok=True) | ||
| rank = self._distributed_config.data_rank | ||
| self._pipeline_log_file = open(log_dir / f"rank_{rank}.jsonl", "a") | ||
| else: | ||
| self._pipeline_log_file = None | ||
| return self._pipeline_log_file | ||
|
|
||
| def _get_forward_input(self, context: BatchContext, step: Step) -> torch.Tensor: | ||
| if step.index not in context.batch: | ||
| start_time = time.perf_counter() | ||
| rank = self._distributed_config.data_rank | ||
| log_file = self._get_pipeline_log_file() | ||
|
|
||
| if log_file is not None: | ||
| log_file.write( | ||
| json.dumps( | ||
| { | ||
| "event": "BATCH_START", | ||
| "rank": rank, | ||
| "micro": step.index, | ||
| "t": time.time(), | ||
| } | ||
| ) | ||
| + "\n" | ||
| ) | ||
| log_file.flush() | ||
|
|
||
| samples_loaded = 0 | ||
| while step.index not in context.batch: | ||
| next(context.data_iterator) | ||
| samples_loaded += 1 | ||
|
|
||
| data_time = (time.perf_counter() - start_time) * 1000 | ||
|
|
||
| if log_file is not None: | ||
| log_file.write( | ||
| json.dumps( | ||
| { | ||
| "event": "BATCH_END", | ||
| "rank": rank, | ||
| "micro": step.index, | ||
| "t": time.time(), | ||
| "duration_ms": round(data_time, 2), | ||
| "samples": samples_loaded, | ||
| } | ||
| ) | ||
| + "\n" | ||
| ) | ||
| log_file.flush() | ||
|
|
||
| if data_time > self._config.data_batch_warn_time_ms: | ||
| logger.warning(f"Data loading took {data_time:,.2f} ms") | ||
| return context.inputs.pop(step.global_index).detach().requires_grad_(step.stage != 0) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not just using the plain document count? Is there any scenario where we would expect response-free documents and need to take them into account? Seems like they shouldn't be in the batch to begin with because they don't contribute to the loss...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree that response-free documents shouldn't be in the batch. However the > 0 check is needed because of padding — by the time we reach _get_model_input, padding is just another entry in lengths and structurally indistinguishable from a real document. To get the true document count cleanly, we can store num_real_docs = len(documents) in from_documents before padding is appended.