diff --git a/fast_llm/data/document/abstract.py b/fast_llm/data/document/abstract.py
index 85014452f..06ea0534b 100644
--- a/fast_llm/data/document/abstract.py
+++ b/fast_llm/data/document/abstract.py
@@ -9,6 +9,7 @@
 if typing.TYPE_CHECKING:
     import torch

+    from fast_llm.engine.distributed.distributed import Distributed
     from fast_llm.tensor import TensorMeta


@@ -59,6 +60,15 @@ def to_kwargs(self) -> dict[str, typing.Any]:
             AttentionKwargs.presents: self.presents,
         }

+    @classmethod
+    def share_batch_data(cls, model_inputs: "list[ModelInput]", distributed: "Distributed"):
+        """
+        Gather values that depend on the entire data-parallel batch, e.g. the total number of labels or documents.
+        Should be called in the main process because distributed operations are not available during preprocessing.
+        Implemented as a class method so quantities shared by all model inputs are only computed once.
+        TODO: ====== Use as entry point for batch broadcasting? ======
+        """
+

 @dataclasses.dataclass(kw_only=True)
 class Batch(Document):
diff --git a/fast_llm/data/document/config.py b/fast_llm/data/document/config.py
index 8967227e8..352311b51 100644
--- a/fast_llm/data/document/config.py
+++ b/fast_llm/data/document/config.py
@@ -29,6 +29,11 @@ class LengthPreprocessingConfig(BatchPreprocessingConfig):
     return_position_index: bool = Field(default=False)


+@config_class()
+class TokenPreprocessingConfig(LengthPreprocessingConfig):
+    return_document_count: bool = Field(default=False)
+
+
 @config_class()
 class ImageNormalizationConfig(Config):
     scale: float = Field(default=255.0)
@@ -62,7 +67,7 @@ def get_batch_meta(self, size: int = 1) -> "PatchBatch":


 @config_class()
-class LanguageModelBatchPreprocessingConfig(LengthPreprocessingConfig):
+class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig):
     _abstract = False
     phase: PhaseType = Field(default=PhaseType.training)
     micro_batch_splits: int = Field(default=1)
diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py
index 00040e576..076f3abb3 100644
--- a/fast_llm/data/document/language_model.py
+++ b/fast_llm/data/document/language_model.py
@@ -4,12 +4,14 @@

 import torch

+from fast_llm.core.distributed import allreduce_scalar
 from fast_llm.data.document.abstract import ModelInput
 from fast_llm.data.document.config import LanguageModelBatchPreprocessingConfig
 from fast_llm.data.document.patch import PatchBatch, PatchDocument, PatchModelInput
 from fast_llm.data.document.range import RangeBatch, RangeDocument
 from fast_llm.data.document.token import TokenBatch, TokenDocument, TokenModelInput
 from fast_llm.data.document.token_data import TokenDataBatch, TokenDataDocument
+from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.layers.language_model.config import LanguageModelKwargs
 from fast_llm.utils import div

@@ -33,6 +35,20 @@ class LanguageModelTargetInput(ModelInput):
     advantages: torch.Tensor | None = None
     old_log_probabilities: torch.Tensor | None = None
     label_counts: torch.Tensor | None = None
+    num_labels: int | None = None
+    num_labels_in_batch: int | None = None
+
+    @classmethod
+    def share_batch_data(cls, model_inputs: "list[LanguageModelTargetInput]", distributed: "Distributed"):
+        if model_inputs[0].num_labels is not None and model_inputs[0].num_labels_in_batch is None:
+            # We sum over sequences but not within a sequence.
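+            # Each sequence contributes its label count once (on its first micro-sequence split), so the
+            # local sum plus an all-reduce over the batch-data group yields the same batch-wide total on
+            # every rank, presumably so losses can be normalized consistently across data-parallel ranks.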
+            num_labels_in_batch = allreduce_scalar(
+                sum(model_input.num_labels for model_input in model_inputs),
+                dtype=torch.int32,
+                group=distributed.batch_data_group,
+            )
+            for model_input in model_inputs:
+                model_input.num_labels_in_batch = num_labels_in_batch


 @dataclasses.dataclass(kw_only=True)
@@ -40,6 +56,15 @@ class LanguageModelInput(TokenModelInput):
     targets: list[LanguageModelTargetInput] = dataclasses.field(default_factory=list)
     image_patches: PatchModelInput | None = None

+    @classmethod
+    def share_batch_data(cls, model_inputs: "list[LanguageModelInput]", distributed: "Distributed"):
+        super().share_batch_data(model_inputs, distributed)
+        for targets in zip(*(model_input.targets for model_input in model_inputs), strict=True):
+            targets[0].share_batch_data(targets, distributed)
+        if model_inputs[0].image_patches is not None:
+            model_inputs[0].image_patches.share_batch_data(
+                [model_input.image_patches for model_input in model_inputs], distributed
+            )
+
     def set_children_attributes(self) -> None:
         if self.image_patches is not None:
             self.image_patches.set_parent_attributes(self)
@@ -58,6 +83,7 @@ def to_kwargs(self) -> dict[str, typing.Any]:
             LanguageModelKwargs.advantages: [target.advantages for target in self.targets],
             LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets],
             LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets],
+            LanguageModelKwargs.num_labels_in_batch: [target.num_labels_in_batch for target in self.targets],
         }
         if self.image_patches is not None:
             out.update(self.image_patches.to_kwargs())
@@ -113,6 +139,12 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis
             )
         ):
             model_input = self._get_model_input(sequence_k_past, sequence_k_past + local_input_length, config)
+            model_input.phase = config.phase
+
+            if config.use_image_patches:
+                model_input.image_patches = self.image_patches.get_model_input(
+                    sequence_k_past, sequence_k_past + local_input_length, config.vision_encoder
+                )

             model_input.pasts = presents
             presents = None if micro_sequence_index == config.micro_batch_splits - 1 else []
@@ -121,73 +153,66 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis

             model_inputs.append(model_input)

+        self._set_target_inputs(model_inputs, config)
+
         return model_inputs

-    def _get_model_input(
-        self, begin: int, end: int, config: LanguageModelBatchPreprocessingConfig
-    ) -> LanguageModelInput:
-        model_input = super()._get_model_input(begin, end, config)
-        model_input.phase = config.phase
+    def _set_target_inputs(
+        self, model_inputs: list[LanguageModelInput], config: LanguageModelBatchPreprocessingConfig
+    ):
+        labels = self.tokens.clone()

-        if config.use_image_patches:
-            model_input.image_patches = self.image_patches.get_model_input(begin, end, config.vision_encoder)
+        # Apply loss masking spans.
+        if config.use_loss_masking_spans and self.loss_masking_spans is not None:
+            for span_begin, span_end in self.loss_masking_spans.ranges:
+                labels[span_begin:span_end] = -100

         for prediction_distance in range(1, config.num_labels + 1):
-            label_begin = begin + prediction_distance
-            label_end = end + prediction_distance
-            # Keep complete documents to simplify preprocessing.
-            _, first_document_begin, last_document_end = self._get_cropped_lengths(begin, label_end)
-            cropped_lengths, _, _ = self._get_cropped_lengths(first_document_begin, last_document_end)
-            labels = self.tokens[first_document_begin:last_document_end].clone()
-            labels_in_range = labels[label_begin - first_document_begin : label_end - first_document_begin]
-
-            # Apply loss masking spans.
-            if config.use_loss_masking_spans and self.loss_masking_spans is not None:
-                for span_begin, span_end in self.loss_masking_spans.get_cropped_ranges(
-                    first_document_begin, last_document_end
-                ):
-                    labels[span_begin:span_end] = -100
-
             # Mask cross-document predictions.
             document_begin = 0
-            for length in cropped_lengths:
-                labels[document_begin : document_begin + prediction_distance] = -100
+            for length in self.lengths:
+                labels[document_begin + prediction_distance - 1] = -100
                 document_begin += length

-            if config.return_label_counts:
-                # Count the number of non-masked labels in each document through cumulative sums.
-                mask = labels >= 0
-                mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)])
-                length_cumsum = torch.tensor([0] + cropped_lengths, device=self.device).cumsum(0)
-                label_count_cumsum = mask_cumsum[length_cumsum]
-                labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1]
-                # Expand to one entry per token: find each token's document index via the sorted
-                # length cumsum, then look up that document's label count.
-                # TODO: Document index already computed in `LengthModelInputPreprocessor`.
-                document_index = torch.searchsorted(
-                    length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right"
-                )
-                label_counts = labels_per_document[document_index][
-                    label_begin - first_document_begin : label_end - first_document_begin
-                ]
-                mask = (
-                    mask[label_begin - first_document_begin : label_end - first_document_begin]
-                    if config.return_prediction_mask
-                    else None
-                )
-            else:
-                label_counts = None
-                mask = labels_in_range >= 0 if config.return_prediction_mask else None
-
-            # Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions.
-            target_input = LanguageModelTargetInput(tokens=labels_in_range, mask=mask, label_counts=label_counts)
-
-            if config.use_grpo_data and not model_input.is_meta:
-                target_input.advantages = self.advantages.get_cropped_data(label_begin, label_end)
-                target_input.old_log_probabilities = self.old_log_probabilities.get_cropped_data(
-                    label_begin, label_end
+            prediction_labels = labels[
+                prediction_distance : len(self.tokens) - config.num_labels + prediction_distance
+            ].clone()
+            mask = prediction_labels >= 0
+            label_counts = self._get_label_counts(mask) if config.return_label_counts else None
+
+            for input_index, model_input in enumerate(model_inputs):
+                begin = model_input.sequence_k_dim.size
+                end = begin + model_input.token_dim.size
+
+                # Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions.
+                target_input = LanguageModelTargetInput(
+                    tokens=prediction_labels[begin:end],
+                    mask=mask[begin:end] if config.return_prediction_mask else None,
+                    label_counts=label_counts[begin:end] if config.return_label_counts else None,
+                    # Set the value for the first input only so `share_batch_data` generates the correct sum.
+                    # TODO: ====== Make optional? ======
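+                    # `mask` covers the full sequence, so attributing the total to the first
+                    # micro-sequence split and zero to the others keeps the per-split sum correct.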
+                    num_labels=mask.sum(dtype=torch.int32).item() if input_index == 0 else 0,
                 )
-
-            model_input.targets.append(target_input)
-
-        return model_input
+                if config.use_grpo_data and not model_input.is_meta:
+                    target_input.advantages = self.advantages.get_cropped_data(
+                        begin + prediction_distance, end + prediction_distance
+                    )
+                    target_input.old_log_probabilities = self.old_log_probabilities.get_cropped_data(
+                        begin + prediction_distance, end + prediction_distance
+                    )
+
+                model_input.targets.append(target_input)
+
+    def _get_label_counts(self, mask: torch.Tensor):
+        # Count the number of non-masked labels in each document through cumulative sums.
+        mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)])
+        length_cumsum = torch.tensor([0] + self.lengths, device=self.device).cumsum(0)
+        label_count_cumsum = mask_cumsum[length_cumsum]
+        labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1]
+        # Expand to one entry per token: find each token's document index via the sorted
+        # length cumsum, then look up that document's label count.
+        # TODO: Document index already computed in `LengthModelInputPreprocessor`.
+        document_index = torch.searchsorted(
+            length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right"
+        )
+        return labels_per_document[document_index]
diff --git a/fast_llm/data/document/range.py b/fast_llm/data/document/range.py
index ea5d0e7fd..ed2503455 100644
--- a/fast_llm/data/document/range.py
+++ b/fast_llm/data/document/range.py
@@ -32,6 +32,6 @@ def from_documents(
             document_begin += size
         return cls(ranges=ranges) if ranges else None

-    def get_cropped_ranges(self, begin: int, end: int) -> list[tuple[int, int]]:
-        cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, end - begin)) for begin_, end_ in self.ranges)
-        return [(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_]
+    # def get_cropped_ranges(self, begin: int, end: int) -> list[tuple[int, int]]:
+    #     cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, end - begin)) for begin_, end_ in self.ranges)
+    #     return [(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_]
diff --git a/fast_llm/data/document/token.py b/fast_llm/data/document/token.py
index 1871b2c83..8aeabb694 100644
--- a/fast_llm/data/document/token.py
+++ b/fast_llm/data/document/token.py
@@ -1,12 +1,14 @@
 import dataclasses
-import functools
 import typing

 import torch

+from fast_llm.core.distributed import allreduce_scalar
 from fast_llm.data.document.abstract import Batch, Document
 from fast_llm.data.document.block import BlockModelInput, LengthModelInputPreprocessor
-from fast_llm.data.document.config import LengthPreprocessingConfig
+from fast_llm.data.document.config import TokenPreprocessingConfig
+from fast_llm.engine.distributed.distributed import Distributed
+from fast_llm.layers.language_model.config import LanguageModelKwargs
 from fast_llm.tensor import TensorMeta
 from fast_llm.utils import Assert

@@ -22,12 +24,34 @@ def __len__(self) -> int:

     @property
     def device(self) -> torch.device:
         return self.tokens.device

+    @property
+    def is_meta(self) -> bool:
+        return self.device.type == "meta"
+

 @dataclasses.dataclass(kw_only=True)
 class TokenModelInput(BlockModelInput, TokenDocument):
-    @functools.cached_property
-    def is_meta(self) -> bool:
-        return isinstance(self.tokens, TensorMeta)
+    num_documents: int | None = None
+    num_documents_in_batch: int | None = None
+
+    @classmethod
+    def share_batch_data(cls, model_inputs: "list[TokenModelInput]", distributed: "Distributed"):
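+        # Same scheme as `LanguageModelTargetInput.share_batch_data`: all-reduce the per-rank document
+        # count across the batch-data group, presumably for per-document loss averaging.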
+        if model_inputs[0].num_documents is not None and model_inputs[0].num_documents_in_batch is None:
+            # We sum over sequences but not within a sequence.
+            num_documents_in_batch = allreduce_scalar(
+                sum(model_input.num_documents for model_input in model_inputs),
+                dtype=torch.int32,
+                group=distributed.batch_data_group,
+            )
+            for model_input in model_inputs:
+                model_input.num_documents_in_batch = num_documents_in_batch
+
+    def to_kwargs(self) -> dict[str, typing.Any]:
+        # TODO: Avoid conversion, use `LanguageModelMicroBatch` directly instead.
+        return {
+            **super().to_kwargs(),
+            LanguageModelKwargs.num_documents_in_batch: self.num_documents_in_batch,
+        }


 @dataclasses.dataclass(kw_only=True)
@@ -74,10 +98,13 @@ def _get_cropped_lengths(self, begin: int, end: int) -> tuple[list[int], int, in

         return lengths, first_document_begin, document_end

-    def _get_model_input(self, begin: int, end: int, config: LengthPreprocessingConfig):
+    def _get_model_input(self, begin: int, end: int, config: TokenPreprocessingConfig):
         model_input = self._model_input_class(tokens=self.tokens[begin:end])
         lengths, first_document_begin, last_document_end = self._get_cropped_lengths(begin, end)

+        if config.return_document_count:
+            model_input.num_documents = len(self.lengths) if begin == 0 else 0
+
         LengthModelInputPreprocessor(
             lengths=lengths,
             sequence_k_past=begin,
@@ -89,7 +116,7 @@ def _get_model_input(self, begin: int, end: int, config: LengthPreprocessingConf
         ).preprocess(model_input, config)
         Assert.eq(model_input.token_dim.size, end - begin)

-        if self.tokens.device.type == "meta":
+        if self.is_meta:
             model_input.tokens = TensorMeta.from_dims(
                 (model_input.token_dim,), tensor_name=f"tokens_{begin}_to_{end}", dtype=torch.int64
             )
diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py
index a12b68c17..a9f6887c7 100644
--- a/fast_llm/engine/base_model/base_model.py
+++ b/fast_llm/engine/base_model/base_model.py
@@ -178,14 +178,13 @@ def __init__(
     @abc.abstractmethod
     def preprocess_batch(
         self,
-        model_inputs: list[ModelInput],
+        model_input: ModelInput,
         *,
         phase: PhaseType,
         iteration: int,
         metrics: dict | None = None,
         extra_kwargs: dict[str, typing.Any] | None = None,
-        device: torch.device | None,
-    ) -> list[tuple[torch.Tensor, dict]]:
+    ) -> tuple[torch.Tensor, dict]:
         # TODO Move batch splitting elsewhere, align interface with LayerBase
         pass
diff --git a/fast_llm/engine/inference/runner.py b/fast_llm/engine/inference/runner.py
index f3b16c647..d9ed695ec 100644
--- a/fast_llm/engine/inference/runner.py
+++ b/fast_llm/engine/inference/runner.py
@@ -1,6 +1,7 @@
 import abc
 import typing

+from fast_llm.data.document.abstract import ModelInput
 from fast_llm.engine.distributed.config import PhaseType
 from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
 from fast_llm.engine.schedule.config import ScheduleConfig
@@ -57,15 +58,14 @@ def setup(self):
         Assert.is_(self._runner._distributed, self._fast_llm_model.distributed)

     def forward(
-        self, input_, kwargs: dict, *, iteration: int = 1, return_metrics: bool = False
+        self, model_input: ModelInput, *, iteration: int = 1, return_metrics: bool = False
     ) -> tuple[dict[str, float | int], dict[str, typing.Any] | None]:
         # TODO: Return an actual model output.
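+        # `run_step` expects one entry per sequential micro-batch, each holding one model input per
+        # micro-batch split, so a single inference input is wrapped as one micro-batch with one split.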
         reduced_losses, update_successful, metrics = self._runner.run_step(
-            iter((((input_, kwargs),),)),
+            iter(((model_input,),)),
             self._schedule,
             iteration=iteration,
             return_metrics=return_metrics,
-            preprocessed=True,
         )
         assert update_successful
         return reduced_losses, metrics
diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py
index 24b8b3d63..20a777a70 100644
--- a/fast_llm/engine/schedule/runner.py
+++ b/fast_llm/engine/schedule/runner.py
@@ -146,7 +146,6 @@ def run_step(
         *,
         iteration: int = 1,
         return_metrics: bool = False,
-        preprocessed: bool = False,
     ) -> tuple[dict[str, float | int], bool, dict[str, typing.Any] | None]:
         assert self._is_setup
         assert schedule._config is self._config  # Noqa
@@ -161,7 +160,7 @@
             losses={loss_def: [] for loss_def in self._loss_definitions},
             metrics=metrics,
         )
-        context.data_iterator = self._preprocess_data(context, data_iterator, preprocessed)
+        context.data_iterator = self._preprocess_data(context, data_iterator)

         if self._multi_stage.config.multi_stage.debug_activation_memory:
             log_pipeline_parallel_main_rank(
@@ -328,16 +327,20 @@ def _train_step(self, context: BatchContext, step: Step) -> None:
         self._reduce(context, step)

     def _preprocess_data(
-        self, context: BatchContext, data_iterator: typing.Iterator, preprocessed: bool
+        self, context: BatchContext, data_iterator: typing.Iterator
     ) -> typing.Generator[None, None, None]:
         grad_output = (
             self._optimizer.grad_scale / self._config.num_inputs if context.schedule.phase.is_training else None
         )
-        for micro_batch in range(self._config.sequential_micro_batches):
-            micro_batch_data = next(data_iterator)
-            if not preprocessed:
-                micro_batch_data = self._multi_stage.base_model.preprocess_batch(
-                    micro_batch_data,
+        model_inputs = [next(data_iterator) for _ in range(self._config.sequential_micro_batches)]
+        # Gather batch-wide quantities (label and document counts) over a flat list of all model inputs.
+        flat_model_inputs = [model_input for inputs_ in model_inputs for model_input in inputs_]
+        flat_model_inputs[0].share_batch_data(flat_model_inputs, self._distributed)
+
+        for micro_batch, model_inputs_ in enumerate(model_inputs):
+            Assert.eq(len(model_inputs_), self._config.micro_batch_splits)
+            for micro_batch_split, model_input in enumerate(model_inputs_):
+                input_, kwargs = self._multi_stage.base_model.preprocess_batch(
+                    model_input,
                     phase=context.phase,
                     iteration=context.iteration,
                     metrics=context.metrics,
@@ -347,10 +350,7 @@
                         "num_micro_batches": self._config.sequential_micro_batches,
                         "micro_batch_splits": self._config.micro_batch_splits,
                     },
-                    device=self._distributed.device,
                 )
-            Assert.eq(len(micro_batch_data), self._config.micro_batch_splits)
-            for micro_batch_split, (input_, kwargs) in enumerate(micro_batch_data):
                 kwargs.update(micro_batch_split=micro_batch_split)
                 data_index = micro_batch * self._config.micro_batch_splits + micro_batch_split
                 if self._stages_owned[0]:
diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py
index bc425520f..361772818 100644
--- a/fast_llm/engine/schedule/schedule.py
+++ b/fast_llm/engine/schedule/schedule.py
@@ -127,12 +127,14 @@ def __init__(
             warnings.warn("Not enough input to achieve true pipeline parallelism.")

         # Setup the activation metas.
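+        # `batch_meta` now holds one meta input per micro-batch split; preprocess each one to build
+        # the per-split activation metas.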
-        self._preprocessed_meta = self._multi_stage.base_model.preprocess_batch(
-            batch_meta,
-            phase=self._phase,
-            iteration=0,
-            device=None,
-        )
+        self._preprocessed_meta = [
+            self._multi_stage.base_model.preprocess_batch(
+                model_input,
+                phase=self._phase,
+                iteration=0,
+            )
+            for model_input in batch_meta
+        ]

         self._steps, self._first_grad_stage = self._create_steps()
diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py
index a199ad154..4a8efdab6 100644
--- a/fast_llm/layers/language_model/config.py
+++ b/fast_llm/layers/language_model/config.py
@@ -24,6 +24,7 @@ class LanguageModelKwargs(LanguageModelLossKwargs):
     token_map = "token_map"
     sample_map = "sample_map"
     embedding_map = "embedding_map"
+    num_documents_in_batch = "num_documents_in_batch"
     # TODO: These are generic
     phase = "phase"
     loss_mask = "loss_mask"
diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py
index 99d4bce9a..5168aecfb 100644
--- a/fast_llm/layers/language_model/loss/config.py
+++ b/fast_llm/layers/language_model/loss/config.py
@@ -27,6 +27,7 @@ class LanguageModelLossKwargs(BlockKwargs):
     advantages = "advantages"
     old_log_probabilities = "old_log_probabilities"
     label_counts = "num_labels_in_seq"
+    num_labels_in_batch = "num_labels_in_batch"


 @config_class(registry=True)
diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py
index def664d66..a53c234f0 100644
--- a/fast_llm/models/gpt/huggingface.py
+++ b/fast_llm/models/gpt/huggingface.py
@@ -114,11 +114,10 @@ def _inner_forward(
             use_cache,
             output_hidden_states,
         )
-        ((input_, kwargs),) = self.fast_llm_base_model.preprocess_batch(
-            [model_input],
+        input_, kwargs = self.fast_llm_base_model.preprocess_batch(
+            model_input,
             phase=PhaseType.inference,
             iteration=iteration,
-            device=self._fast_llm_model.distributed.device,
         )

-        self._inference_runner.forward(input_, kwargs, iteration=iteration)
+        # `InferenceRunner.forward` now takes the raw model input instead of a preprocessed pair.
+        self._inference_runner.forward(model_input, iteration=iteration)
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index fc4537ee7..a21bdee7e 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -1,3 +1,4 @@
+import dataclasses
 import functools
 import logging
 import re
@@ -42,54 +43,46 @@ def __init__(

     def preprocess_batch(
         self,
-        model_inputs: list[LanguageModelInput],
+        model_input: LanguageModelInput,
         *,
         phase: PhaseType,
         iteration: int,
         metrics: dict | None = None,
         extra_kwargs: dict[str, typing.Any] | None = None,
-        device: torch.device | None,
-    ) -> list[tuple[torch.Tensor, dict]]:
-        reference_preprocessed_batches = {}
-        for name, reference_model in self._reference_models.items():
-            reference_preprocessed_batches[name] = reference_model.fast_llm_model.base_model.preprocess_batch(
-                model_inputs,
-                phase=PhaseType.inference,
-                iteration=iteration,
-                device=device,
-            )
-
-        preprocessed = []
-        for input_index, model_input in enumerate(model_inputs):
-            if device is not None:
-                model_input.to_device_(device)
-            kwargs = model_input.to_kwargs()
-            kwargs[LanguageModelKwargs.iteration] = iteration
-            if extra_kwargs is not None:
-                Assert.empty(kwargs.keys() & extra_kwargs.keys())
-                kwargs.update(extra_kwargs)
-            if phase == PhaseType.inference:
-                kwargs[BlockKwargs.output_hidden_states].add(re.compile(r"head\..*logits.*$"))
-
-            if not model_input.is_meta:
-                for name, reference_model in self._reference_models.items():
-                    reference_tokens, reference_kwargs = reference_preprocessed_batches[name][input_index]
-                    if name in self._decoder_reference_models:
-                        # TODO: Get the actual names
-                        reference_kwargs[BlockKwargs.output_hidden_states].add(
-                            re.compile(r"decoder\.\d+\.mixer_output$")
-                        )
-
-                    reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration)
-
-                    kwargs[f"reference_{name}_hidden_states"] = {
-                        layer_name: tensor
-                        for layer_name, (meta, tensor) in reference_kwargs[BlockKwargs.hidden_states].items()
-                    }
-            self.preprocess(kwargs)
-            preprocessed.append((model_input.tokens, kwargs))
-
-        return preprocessed
+    ) -> tuple[torch.Tensor, dict]:
+        if not model_input.is_meta:
+            model_input.to_device_(self._distributed.device)
+        kwargs = model_input.to_kwargs()
+        kwargs[LanguageModelKwargs.iteration] = iteration
+        if extra_kwargs is not None:
+            Assert.empty(kwargs.keys() & extra_kwargs.keys())
+            kwargs.update(extra_kwargs)
+        if phase == PhaseType.inference:
+            kwargs[BlockKwargs.output_hidden_states].add(re.compile(r"head\..*logits.*$"))
+
+        if not model_input.is_meta:
+            for name, reference_model in self._reference_models.items():
+                output_hidden_states = set()
+                if name in self._head_reference_models:
+                    output_hidden_states.add(re.compile(r"head\..*logits.*$"))
+                if name in self._decoder_reference_models:
+                    # TODO: Get the actual names
+                    output_hidden_states.add(re.compile(r"decoder\.\d+\.mixer_output$"))
+                assert len(output_hidden_states) >= 1
+                reference_model_input = dataclasses.replace(
+                    model_input,
+                    output_hidden_states=output_hidden_states,
+                    hidden_states={},
+                )
+                reference_model_input.set_children_attributes()
+                # Forward the copy so the requested hidden states are captured on `reference_model_input`.
+                reference_model.forward(reference_model_input, iteration=iteration)
+
+                kwargs[f"reference_{name}_hidden_states"] = {
+                    layer_name: tensor for layer_name, (meta, tensor) in reference_model_input.hidden_states.items()
+                }
+        self.preprocess(kwargs)
+
+        return model_input.tokens, kwargs

     def get_tied_parameters(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
         # TODO: Integrate to the `LayerBase` interface, move to `LanguageModel`, `MultiTokenPrediction`?