Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions fast_llm/data/document/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
if typing.TYPE_CHECKING:
import torch

from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.tensor import TensorMeta


Expand Down Expand Up @@ -59,6 +60,15 @@ def to_kwargs(self) -> dict[str, typing.Any]:
AttentionKwargs.presents: self.presents,
}

@classmethod
def share_batch_data(cls, model_inputs: "list[ModelInput]", distributed: "Distributed") -> None:
    """
    Gather values depending on the entire data-parallel batch, ex. the total number of labels or documents.
    Should be called in the main process because distributed operations are not available during preprocessing.
    Implemented as a class method so quantities shared by all model inputs are only computed once.
    TODO: ====== Use as entry point for batch broadcasting? ======
    """


@dataclasses.dataclass(kw_only=True)
class Batch(Document):
Expand Down
7 changes: 6 additions & 1 deletion fast_llm/data/document/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ class LengthPreprocessingConfig(BatchPreprocessingConfig):
return_position_index: bool = Field(default=False)


@config_class()
class TokenPreprocessingConfig(LengthPreprocessingConfig):
return_document_count: bool = Field(default=False)


@config_class()
class ImageNormalizationConfig(Config):
scale: float = Field(default=255.0)
Expand Down Expand Up @@ -62,7 +67,7 @@ def get_batch_meta(self, size: int = 1) -> "PatchBatch":


@config_class()
class LanguageModelBatchPreprocessingConfig(LengthPreprocessingConfig):
class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig):
_abstract = False
phase: PhaseType = Field(default=PhaseType.training)
micro_batch_splits: int = Field(default=1)
Expand Down
145 changes: 85 additions & 60 deletions fast_llm/data/document/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

import torch

from fast_llm.core.distributed import allreduce_scalar
from fast_llm.data.document.abstract import ModelInput
from fast_llm.data.document.config import LanguageModelBatchPreprocessingConfig
from fast_llm.data.document.patch import PatchBatch, PatchDocument, PatchModelInput
from fast_llm.data.document.range import RangeBatch, RangeDocument
from fast_llm.data.document.token import TokenBatch, TokenDocument, TokenModelInput
from fast_llm.data.document.token_data import TokenDataBatch, TokenDataDocument
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.layers.language_model.config import LanguageModelKwargs
from fast_llm.utils import div

Expand All @@ -33,13 +35,36 @@ class LanguageModelTargetInput(ModelInput):
advantages: torch.Tensor | None = None
old_log_probabilities: torch.Tensor | None = None
label_counts: torch.Tensor | None = None
num_labels: int | None = None
num_labels_in_batch: int | None = None

@classmethod
def share_batch_data(cls, model_inputs: "list[LanguageModelTargetInput]", distributed: "Distributed") -> None:
    """Fill ``num_labels_in_batch`` with the label total reduced over the data-parallel batch.

    Skips the reduction when label counting is disabled (``num_labels`` unset) or when the
    batch-wide total was already computed.
    """
    first = model_inputs[0]
    if first.num_labels is None or first.num_labels_in_batch is not None:
        return
    # We sum over sequences but not within a sequence.
    local_total = sum(model_input.num_labels for model_input in model_inputs)
    batch_total = allreduce_scalar(local_total, dtype=torch.int32, group=distributed.batch_data_group)
    for model_input in model_inputs:
        model_input.num_labels_in_batch = batch_total


@dataclasses.dataclass(kw_only=True)
class LanguageModelInput(TokenModelInput):
targets: list[LanguageModelTargetInput] = dataclasses.field(default_factory=list)
image_patches: PatchModelInput | None = None

@classmethod
def share_batch_data(cls, model_inputs: "list[LanguageModelInput]", distributed: "Distributed") -> None:
    """Share batch-wide quantities (document and label counts) across all micro-batch inputs.

    Delegates to the token-level base class, then to each prediction-distance target group,
    and finally to the image patches when vision inputs are present.
    """
    super().share_batch_data(model_inputs, distributed)
    for targets in zip(*(model_input.targets for model_input in model_inputs), strict=True):
        targets[0].share_batch_data(targets, distributed)
    # `image_patches` is Optional (None when vision inputs are disabled); guard before sharing.
    if model_inputs[0].image_patches is not None:
        model_inputs[0].image_patches.share_batch_data(
            [model_input.image_patches for model_input in model_inputs], distributed
        )

def set_children_attributes(self) -> None:
if self.image_patches is not None:
self.image_patches.set_parent_attributes(self)
Expand All @@ -58,6 +83,7 @@ def to_kwargs(self) -> dict[str, typing.Any]:
LanguageModelKwargs.advantages: [target.advantages for target in self.targets],
LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets],
LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets],
LanguageModelKwargs.num_labels_in_batch: [target.num_labels_in_batch for target in self.targets],
}
if self.image_patches is not None:
out.update(self.image_patches.to_kwargs())
Expand Down Expand Up @@ -113,6 +139,12 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis
)
):
model_input = self._get_model_input(sequence_k_past, sequence_k_past + local_input_length, config)
model_input.phase = config.phase

if config.use_image_patches:
model_input.image_patches = self.image_patches.get_model_input(
sequence_k_past, sequence_k_past + local_input_length, config.vision_encoder
)

model_input.pasts = presents
presents = None if micro_sequence_index == config.micro_batch_splits - 1 else []
Expand All @@ -121,73 +153,66 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis

model_inputs.append(model_input)

self._set_target_inputs(model_inputs, config)

return model_inputs

def _get_model_input(
self, begin: int, end: int, config: LanguageModelBatchPreprocessingConfig
) -> LanguageModelInput:
model_input = super()._get_model_input(begin, end, config)
model_input.phase = config.phase
def _set_target_inputs(
self, model_inputs: list[LanguageModelInput], config: LanguageModelBatchPreprocessingConfig
):
labels = self.tokens.clone()

if config.use_image_patches:
model_input.image_patches = self.image_patches.get_model_input(begin, end, config.vision_encoder)
# Apply loss masking spans.
if config.use_loss_masking_spans and self.loss_masking_spans is not None:
for span_begin, span_end in self.loss_masking_spans.ranges:
labels[span_begin:span_end] = -100

for prediction_distance in range(1, config.num_labels + 1):
label_begin = begin + prediction_distance
label_end = end + prediction_distance
# Keep complete documents to simplify preprocessing.
_, first_document_begin, last_document_end = self._get_cropped_lengths(begin, label_end)
cropped_lengths, _, _ = self._get_cropped_lengths(first_document_begin, last_document_end)
labels = self.tokens[first_document_begin:last_document_end].clone()
labels_in_range = labels[label_begin - first_document_begin : label_end - first_document_begin]

# Apply loss masking spans.
if config.use_loss_masking_spans and self.loss_masking_spans is not None:
for span_begin, span_end in self.loss_masking_spans.get_cropped_ranges(
first_document_begin, last_document_end
):
labels[span_begin:span_end] = -100

# Mask cross-document predictions.
document_begin = 0
for length in cropped_lengths:
labels[document_begin : document_begin + prediction_distance] = -100
for length in self.lengths:
labels[document_begin + prediction_distance - 1] = -100
document_begin += length

if config.return_label_counts:
# Count the number of non-masked labels in each document through cumulative sums.
mask = labels >= 0
mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)])
length_cumsum = torch.tensor([0] + cropped_lengths, device=self.device).cumsum(0)
label_count_cumsum = mask_cumsum[length_cumsum]
labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1]
# Expand to one entry per token: find each token's document index via the sorted
# length cumsum, then look up that document's label count.
# TODO: Document index already computed in `LengthModelInputPreprocessor`.
document_index = torch.searchsorted(
length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right"
)
label_counts = labels_per_document[document_index][
label_begin - first_document_begin : label_end - first_document_begin
]
mask = (
mask[label_begin - first_document_begin : label_end - first_document_begin]
if config.return_prediction_mask
else None
)
else:
label_counts = None
mask = labels_in_range >= 0 if config.return_prediction_mask else None

# Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions.
target_input = LanguageModelTargetInput(tokens=labels_in_range, mask=mask, label_counts=label_counts)

if config.use_grpo_data and not model_input.is_meta:
target_input.advantages = self.advantages.get_cropped_data(label_begin, label_end)
target_input.old_log_probabilities = self.old_log_probabilities.get_cropped_data(
label_begin, label_end
prediction_labels = labels[
prediction_distance : len(self.tokens) - config.num_labels + prediction_distance
].clone()
mask = prediction_labels >= 0
label_counts = self._get_label_counts(mask) if config.return_label_counts else None

for input_index, model_input in enumerate(model_inputs):
begin = model_input.sequence_k_dim.size
end = begin + model_input.token_dim.size

# Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions.
target_input = LanguageModelTargetInput(
tokens=labels[begin:end],
mask=mask[begin:end] if config.return_prediction_mask else None,
label_counts=label_counts[begin:end] if config.return_label_counts else None,
# Set value for the first input only so `share_batch_data` generates the correct sum.
# TODO: ====== Make optional?
num_labels=mask.sum(dtype=torch.int32).item() if input_index == 0 else 0,
)

model_input.targets.append(target_input)

return model_input
if config.use_grpo_data and not model_input.is_meta:
target_input.advantages = self.advantages.get_cropped_data(
begin + prediction_distance, end + prediction_distance
)
target_input.old_log_probabilities = self.old_log_probabilities.get_cropped_data(
begin + prediction_distance, end + prediction_distance
)

model_input.targets.append(target_input)

def _get_label_counts(self, mask: torch.Tensor):
# Count the number of non-masked labels in each document through cumulative sums.
mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)])
length_cumsum = torch.tensor([0] + self.lengths, device=self.device).cumsum(0)
label_count_cumsum = mask_cumsum[length_cumsum]
labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1]
# Expand to one entry per token: find each token's document index via the sorted
# length cumsum, then look up that document's label count.
# TODO: Document index already computed in `LengthModelInputPreprocessor`.
document_index = torch.searchsorted(
length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right"
)
return labels_per_document[document_index]
6 changes: 3 additions & 3 deletions fast_llm/data/document/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ def from_documents(
document_begin += size
return cls(ranges=ranges) if ranges else None

def get_cropped_ranges(self, begin: int, end: int) -> list[tuple[int, int]]:
cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, end - begin)) for begin_, end_ in self.ranges)
return [(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_]
# def get_cropped_ranges(self, begin: int, end: int) -> list[tuple[int, int]]:
# cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, end - begin)) for begin_, end_ in self.ranges)
# return [(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_]
41 changes: 34 additions & 7 deletions fast_llm/data/document/token.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import dataclasses
import functools
import typing

import torch

from fast_llm.core.distributed import allreduce_scalar
from fast_llm.data.document.abstract import Batch, Document
from fast_llm.data.document.block import BlockModelInput, LengthModelInputPreprocessor
from fast_llm.data.document.config import LengthPreprocessingConfig
from fast_llm.data.document.config import TokenPreprocessingConfig
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.layers.language_model.config import LanguageModelKwargs
from fast_llm.tensor import TensorMeta
from fast_llm.utils import Assert

Expand All @@ -22,12 +24,34 @@ def __len__(self) -> int:
def device(self) -> torch.device:
return self.tokens.device

@property
def is_meta(self) -> bool:
    """Whether this document's tokens live on the ``meta`` device (shape-only placeholders)."""
    device_type = self.device.type
    return device_type == "meta"


@dataclasses.dataclass(kw_only=True)
class TokenModelInput(BlockModelInput, TokenDocument):
@functools.cached_property
def is_meta(self) -> bool:
return isinstance(self.tokens, TensorMeta)
num_documents: int | None = None
num_documents_in_batch: int | None = None

@classmethod
def share_batch_data(cls, model_inputs: "list[TokenModelInput]", distributed: "Distributed") -> None:
    """Fill ``num_documents_in_batch`` with the document total reduced over the data-parallel batch.

    Skips the reduction when document counting is disabled (``num_documents`` unset) or when the
    batch-wide total was already computed.
    """
    first = model_inputs[0]
    if first.num_documents is None or first.num_documents_in_batch is not None:
        return
    # We sum over sequences but not within a sequence.
    local_total = sum(model_input.num_documents for model_input in model_inputs)
    batch_total = allreduce_scalar(local_total, dtype=torch.int32, group=distributed.batch_data_group)
    for model_input in model_inputs:
        model_input.num_documents_in_batch = batch_total

def to_kwargs(self) -> dict[str, typing.Any]:
    """Convert this input into the kwargs dictionary consumed by the model layers."""
    # TODO: Avoid conversion, use `LanguageModelMicroBatch` directly instead.
    kwargs = super().to_kwargs()
    kwargs[LanguageModelKwargs.num_documents_in_batch] = self.num_documents_in_batch
    return kwargs


@dataclasses.dataclass(kw_only=True)
Expand Down Expand Up @@ -74,10 +98,13 @@ def _get_cropped_lengths(self, begin: int, end: int) -> tuple[list[int], int, in

return lengths, first_document_begin, document_end

def _get_model_input(self, begin: int, end: int, config: LengthPreprocessingConfig):
def _get_model_input(self, begin: int, end: int, config: TokenPreprocessingConfig):
model_input = self._model_input_class(tokens=self.tokens[begin:end])
lengths, first_document_begin, last_document_end = self._get_cropped_lengths(begin, end)

if config.return_document_count:
model_input.num_documents = len(self.lengths) if begin == 0 else 0

LengthModelInputPreprocessor(
lengths=lengths,
sequence_k_past=begin,
Expand All @@ -89,7 +116,7 @@ def _get_model_input(self, begin: int, end: int, config: LengthPreprocessingConf
).preprocess(model_input, config)

Assert.eq(model_input.token_dim.size, end - begin)
if self.tokens.device.type == "meta":
if self.is_meta:
model_input.tokens = TensorMeta.from_dims(
(model_input.token_dim,), tensor_name=f"tokens_{begin}_to_{end}", dtype=torch.int64
)
Expand Down
5 changes: 2 additions & 3 deletions fast_llm/engine/base_model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,13 @@ def __init__(
@abc.abstractmethod
def preprocess_batch(
self,
model_inputs: list[ModelInput],
model_input: ModelInput,
*,
phase: PhaseType,
iteration: int,
metrics: dict | None = None,
extra_kwargs: dict[str, typing.Any] | None = None,
device: torch.device | None,
) -> list[tuple[torch.Tensor, dict]]:
) -> tuple[torch.Tensor, dict]:
# TODO Move batch splitting elsewhere, align interface with LayerBase
pass

Expand Down
6 changes: 3 additions & 3 deletions fast_llm/engine/inference/runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import typing

from fast_llm.data.document.abstract import ModelInput
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.engine.schedule.config import ScheduleConfig
Expand Down Expand Up @@ -57,15 +58,14 @@ def setup(self):
Assert.is_(self._runner._distributed, self._fast_llm_model.distributed)

def forward(
self, input_, kwargs: dict, *, iteration: int = 1, return_metrics: bool = False
self, model_input: ModelInput, *, iteration: int = 1, return_metrics: bool = False
) -> tuple[dict[str, float | int], dict[str, typing.Any] | None]:
# TODO: Return an actual model output.
reduced_losses, update_successful, metrics = self._runner.run_step(
iter((((input_, kwargs),),)),
iter(((model_input,),)),
self._schedule,
iteration=iteration,
return_metrics=return_metrics,
preprocessed=True,
)
assert update_successful
return reduced_losses, metrics
Loading
Loading