Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fast_llm/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


def reduce_op(
input_: torch.Tensor, group: ProcessGroup | None, *, op: ReduceOp = ReduceOp.SUM, async_op: bool = False
input_: torch.Tensor, group: ProcessGroup | None, *, op: ReduceOp.RedOpType = ReduceOp.SUM, async_op: bool = False
) -> tuple[torch.Tensor, torch.distributed.Work] | torch.Tensor:
if group:
handle = all_reduce(input_, group=group, async_op=async_op, op=op)
Expand Down
3 changes: 3 additions & 0 deletions fast_llm/data/dataset/streaming.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import json
import logging
import time
import typing

Expand All @@ -14,6 +15,8 @@
from fast_llm.data.document.token_data import TokenDataDocument
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)


@config_class()
class RedisStreamingDocumentData(Config):
Expand Down
11 changes: 11 additions & 0 deletions fast_llm/data/document/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
if typing.TYPE_CHECKING:
import torch

from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.tensor import TensorMeta


Expand Down Expand Up @@ -59,6 +60,16 @@ def to_kwargs(self) -> dict[str, typing.Any]:
AttentionKwargs.presents: self.presents,
}

@classmethod
def share_batch_data(cls, model_inputs: "list[ModelInput]", distributed: "Distributed") -> None:
    """
    Gather values depending on the entire data-parallel batch, e.g. the total number of labels or documents.
    Should be called in the main process because distributed operations are not available during preprocessing.
    Implemented as a class method so quantities shared by all model inputs are only computed once.
    Note: this may be called more than once (e.g. reference model preprocessing), so the method should be idempotent.
    The base implementation is a no-op; subclasses override it to share their own quantities.
    TODO: ====== Use as entry point for batch broadcasting? ======
    """


@dataclasses.dataclass(kw_only=True)
class Batch(Document):
Expand Down
7 changes: 6 additions & 1 deletion fast_llm/data/document/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ class LengthPreprocessingConfig(BatchPreprocessingConfig):
return_position_index: bool = Field(default=False)


@config_class()
class TokenPreprocessingConfig(LengthPreprocessingConfig):
    # Length preprocessing options extended with an opt-in document count.
    # When enabled, token preprocessing sets `num_documents` on the model input it builds.
    # NOTE(review): presumably used to normalize by the number of documents in the batch — confirm against callers.
    return_document_count: bool = Field(default=False)


@config_class()
class ImageNormalizationConfig(Config):
scale: float = Field(default=255.0)
Expand Down Expand Up @@ -62,7 +67,7 @@ def get_batch_meta(self, size: int = 1) -> "PatchBatch":


@config_class()
class LanguageModelBatchPreprocessingConfig(LengthPreprocessingConfig):
class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig):
_abstract = False
phase: PhaseType = Field(default=PhaseType.training)
micro_batch_splits: int = Field(default=1)
Expand Down
144 changes: 84 additions & 60 deletions fast_llm/data/document/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

import torch

from fast_llm.core.distributed import allreduce_scalar
from fast_llm.data.document.abstract import ModelInput
from fast_llm.data.document.config import LanguageModelBatchPreprocessingConfig
from fast_llm.data.document.patch import PatchBatch, PatchDocument, PatchModelInput
from fast_llm.data.document.range import RangeBatch, RangeDocument
from fast_llm.data.document.token import TokenBatch, TokenDocument, TokenModelInput
from fast_llm.data.document.token_data import TokenDataBatch, TokenDataDocument
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.layers.language_model.config import LanguageModelKwargs
from fast_llm.utils import div

Expand All @@ -33,13 +35,37 @@ class LanguageModelTargetInput(ModelInput):
advantages: torch.Tensor | None = None
old_log_probabilities: torch.Tensor | None = None
label_counts: torch.Tensor | None = None
num_labels: int | None = None
num_labels_in_batch: int | None = None

@classmethod
def share_batch_data(cls, model_inputs: "list[LanguageModelTargetInput]", distributed: "Distributed"):
    """
    Compute the batch-wide label count once and store it on every target input.
    Nothing happens when `num_labels` is unset or `num_labels_in_batch` is already filled,
    so calling this more than once is harmless.
    """
    first_input = model_inputs[0]
    if first_input.num_labels is None or first_input.num_labels_in_batch is not None:
        return
    # We sum over sequences but not within a sequence.
    local_label_count = sum(target_input.num_labels for target_input in model_inputs)
    shared_label_count = allreduce_scalar(
        local_label_count,
        dtype=torch.int32,
        group=distributed.batch_data_group,
    )
    for target_input in model_inputs:
        target_input.num_labels_in_batch = shared_label_count


@dataclasses.dataclass(kw_only=True)
class LanguageModelInput(TokenModelInput):
targets: list[LanguageModelTargetInput] = dataclasses.field(default_factory=list)
image_patches: PatchModelInput | None = None

@classmethod
def share_batch_data(cls, model_inputs: "list[LanguageModelInput]", distributed: "Distributed"):
    """
    Share batch-wide quantities for the inputs themselves, then for their targets and image patches.
    Targets are grouped by position, so each group holds the matching target of every model input.
    """
    super().share_batch_data(model_inputs, distributed)
    target_lists = [model_input.targets for model_input in model_inputs]
    # `strict=True` requires every model input to carry the same number of targets.
    for target_group in zip(*target_lists, strict=True):
        target_group[0].share_batch_data(target_group, distributed)
    first_patches = model_inputs[0].image_patches
    if first_patches is not None:
        all_patches = [model_input.image_patches for model_input in model_inputs]
        first_patches.share_batch_data(all_patches, distributed)

def set_children_attributes(self) -> None:
if self.image_patches is not None:
self.image_patches.set_parent_attributes(self)
Expand All @@ -58,6 +84,7 @@ def to_kwargs(self) -> dict[str, typing.Any]:
LanguageModelKwargs.advantages: [target.advantages for target in self.targets],
LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets],
LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets],
LanguageModelKwargs.num_labels_in_batch: [target.num_labels_in_batch for target in self.targets],
}
if self.image_patches is not None:
out.update(self.image_patches.to_kwargs())
Expand Down Expand Up @@ -113,6 +140,12 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis
)
):
model_input = self._get_model_input(sequence_k_past, sequence_k_past + local_input_length, config)
model_input.phase = config.phase

if config.use_image_patches:
model_input.image_patches = self.image_patches.get_model_input(
sequence_k_past, sequence_k_past + local_input_length, config.vision_encoder
)

model_input.pasts = presents
presents = None if micro_sequence_index == config.micro_batch_splits - 1 else []
Expand All @@ -121,73 +154,64 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis

model_inputs.append(model_input)

self._set_target_inputs(model_inputs, config)

return model_inputs

def _get_model_input(
self, begin: int, end: int, config: LanguageModelBatchPreprocessingConfig
) -> LanguageModelInput:
model_input = super()._get_model_input(begin, end, config)
model_input.phase = config.phase
def _set_target_inputs(
self, model_inputs: list[LanguageModelInput], config: LanguageModelBatchPreprocessingConfig
):
labels = self.tokens.clone()

if config.use_image_patches:
model_input.image_patches = self.image_patches.get_model_input(begin, end, config.vision_encoder)
# Apply loss masking spans.
if config.use_loss_masking_spans and self.loss_masking_spans is not None:
for span_begin, span_end in self.loss_masking_spans.ranges:
labels[span_begin:span_end] = -100

for prediction_distance in range(1, config.num_labels + 1):
label_begin = begin + prediction_distance
label_end = end + prediction_distance
# Keep complete documents to simplify preprocessing.
_, first_document_begin, last_document_end = self._get_cropped_lengths(begin, label_end)
cropped_lengths, _, _ = self._get_cropped_lengths(first_document_begin, last_document_end)
labels = self.tokens[first_document_begin:last_document_end].clone()
labels_in_range = labels[label_begin - first_document_begin : label_end - first_document_begin]

# Apply loss masking spans.
if config.use_loss_masking_spans and self.loss_masking_spans is not None:
for span_begin, span_end in self.loss_masking_spans.get_cropped_ranges(
first_document_begin, last_document_end
):
labels[span_begin:span_end] = -100

# Mask cross-document predictions.
document_begin = 0
for length in cropped_lengths:
labels[document_begin : document_begin + prediction_distance] = -100
for length in self.lengths:
if prediction_distance <= length:
labels[document_begin + prediction_distance - 1] = -100
document_begin += length

if config.return_label_counts:
# Count the number of non-masked labels in each document through cumulative sums.
mask = labels >= 0
mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)])
length_cumsum = torch.tensor([0] + cropped_lengths, device=self.device).cumsum(0)
label_count_cumsum = mask_cumsum[length_cumsum]
labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1]
# Expand to one entry per token: find each token's document index via the sorted
# length cumsum, then look up that document's label count.
# TODO: Document index already computed in `LengthModelInputPreprocessor`.
document_index = torch.searchsorted(
length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right"
)
label_counts = labels_per_document[document_index][
label_begin - first_document_begin : label_end - first_document_begin
]
mask = (
mask[label_begin - first_document_begin : label_end - first_document_begin]
if config.return_prediction_mask
else None
)
else:
label_counts = None
mask = labels_in_range >= 0 if config.return_prediction_mask else None

# Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions.
target_input = LanguageModelTargetInput(tokens=labels_in_range, mask=mask, label_counts=label_counts)

if config.use_grpo_data and not model_input.is_meta:
target_input.advantages = self.advantages.get_cropped_data(label_begin, label_end)
target_input.old_log_probabilities = self.old_log_probabilities.get_cropped_data(
label_begin, label_end
mask = labels >= 0
label_counts = self._get_label_counts(mask) if config.return_label_counts else None

for input_index, model_input in enumerate(model_inputs):
label_end = model_input.sequence_k_dim.size + prediction_distance
label_begin = label_end - model_input.token_dim.size

# Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions.
target_input = LanguageModelTargetInput(
tokens=labels[label_begin:label_end].clone(),
mask=mask[label_begin:label_end] if config.return_prediction_mask else None,
label_counts=label_counts[label_begin:label_end] if config.return_label_counts else None,
# Set value for the first input only so `share_batch_data` generates the correct sum.
# TODO: ====== Make optional?
num_labels=(
len(mask) if self.is_meta else mask.sum(dtype=torch.int32).item() if input_index == 0 else 0
),
)

model_input.targets.append(target_input)

return model_input
if config.use_grpo_data and not model_input.is_meta:
target_input.advantages = self.advantages.get_cropped_data(label_begin, label_end)
target_input.old_log_probabilities = self.old_log_probabilities.get_cropped_data(
label_begin, label_end
)

model_input.targets.append(target_input)

def _get_label_counts(self, mask: torch.Tensor):
    """
    For each token, return the number of non-masked labels in the document that contains it.
    Document boundaries come from `self.lengths`; `mask` marks the non-masked labels.
    """
    device = self.device
    # Cumulative document boundaries, starting at 0.
    boundaries = torch.tensor([0] + self.lengths, device=device).cumsum(0)
    # Running count of non-masked labels, with a leading zero so boundaries can index it directly.
    running_count = torch.cat([mask.new_zeros(1), mask.cumsum(0)])
    count_at_boundary = running_count[boundaries]
    per_document = count_at_boundary[1:] - count_at_boundary[:-1]
    # Expand to one entry per token: locate each token's document through the sorted
    # boundaries, then look up that document's label count.
    # TODO: Document index already computed in `LengthModelInputPreprocessor`.
    token_positions = torch.arange(len(mask), device=device)
    token_document = torch.searchsorted(boundaries[1:], token_positions, side="right")
    return per_document[token_document]
6 changes: 3 additions & 3 deletions fast_llm/data/document/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ def from_documents(
document_begin += size
return cls(ranges=ranges) if ranges else None

def get_cropped_ranges(self, begin: int, end: int) -> list[tuple[int, int]]:
cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, end - begin)) for begin_, end_ in self.ranges)
return [(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_]
# def get_cropped_ranges(self, begin: int, end: int) -> list[tuple[int, int]]:
# cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, end - begin)) for begin_, end_ in self.ranges)
# return [(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_]
44 changes: 37 additions & 7 deletions fast_llm/data/document/token.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import dataclasses
import functools
import typing

import torch

from fast_llm.core.distributed import allreduce_scalar
from fast_llm.data.document.abstract import Batch, Document
from fast_llm.data.document.block import BlockModelInput, LengthModelInputPreprocessor
from fast_llm.data.document.config import LengthPreprocessingConfig
from fast_llm.data.document.config import TokenPreprocessingConfig
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.layers.language_model.config import LanguageModelKwargs
from fast_llm.tensor import TensorMeta
from fast_llm.utils import Assert

Expand All @@ -22,12 +24,34 @@ def __len__(self) -> int:
def device(self) -> torch.device:
return self.tokens.device

@property
def is_meta(self) -> bool:
    """Whether this document's device is of type `meta`."""
    return self.device.type == "meta"


@dataclasses.dataclass(kw_only=True)
class TokenModelInput(BlockModelInput, TokenDocument):
@functools.cached_property
def is_meta(self) -> bool:
return isinstance(self.tokens, TensorMeta)
num_documents: int | None = None
num_documents_in_batch: int | None = None

@classmethod
def share_batch_data(cls, model_inputs: "list[TokenModelInput]", distributed: "Distributed"):
    """
    Compute the batch-wide document count once and store it on every model input.
    Nothing happens when `num_documents` is unset or `num_documents_in_batch` is already filled,
    so calling this more than once is harmless.
    """
    first_input = model_inputs[0]
    if first_input.num_documents is None or first_input.num_documents_in_batch is not None:
        return
    # We sum over sequences but not within a sequence.
    local_document_count = sum(token_input.num_documents for token_input in model_inputs)
    shared_document_count = allreduce_scalar(
        local_document_count,
        dtype=torch.int32,
        group=distributed.batch_data_group,
    )
    for token_input in model_inputs:
        token_input.num_documents_in_batch = shared_document_count

def to_kwargs(self) -> dict[str, typing.Any]:
    """Return the parent kwargs extended with the batch-wide document count."""
    # TODO: Avoid conversion, use `LanguageModelMicroBatch` directly instead.
    kwargs = dict(super().to_kwargs())
    kwargs[LanguageModelKwargs.num_documents_in_batch] = self.num_documents_in_batch
    return kwargs


@dataclasses.dataclass(kw_only=True)
Expand Down Expand Up @@ -74,10 +98,16 @@ def _get_cropped_lengths(self, begin: int, end: int) -> tuple[list[int], int, in

return lengths, first_document_begin, document_end

def _get_model_input(self, begin: int, end: int, config: LengthPreprocessingConfig):
def _get_model_input(self, begin: int, end: int, config: TokenPreprocessingConfig):
model_input = self._model_input_class(tokens=self.tokens[begin:end])
lengths, first_document_begin, last_document_end = self._get_cropped_lengths(begin, end)

if config.return_document_count:
# Exclude the padding "length" from the document count.
model_input.num_documents = (
len(self.lengths) - (1 if self.unpadded_length < len(self.tokens) else 0) if begin == 0 else 0
)

LengthModelInputPreprocessor(
lengths=lengths,
sequence_k_past=begin,
Expand All @@ -89,7 +119,7 @@ def _get_model_input(self, begin: int, end: int, config: LengthPreprocessingConf
).preprocess(model_input, config)

Assert.eq(model_input.token_dim.size, end - begin)
if self.tokens.device.type == "meta":
if self.is_meta:
model_input.tokens = TensorMeta.from_dims(
(model_input.token_dim,), tensor_name=f"tokens_{begin}_to_{end}", dtype=torch.int64
)
Expand Down
9 changes: 4 additions & 5 deletions fast_llm/engine/base_model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c
out += layer.get_compute_usage(input_, kwargs, config)
return out

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
def get_loss_definitions(self) -> list[LossDef]:
losses = []
for layer in self.get_layers():
if layer is not self:
losses += layer.get_loss_definitions(count)
losses += layer.get_loss_definitions()
return losses

def get_preprocessing_config(self) -> dict[str, typing.Any]:
Expand Down Expand Up @@ -178,14 +178,13 @@ def __init__(
@abc.abstractmethod
def preprocess_batch(
self,
model_inputs: list[ModelInput],
model_input: ModelInput,
*,
phase: PhaseType,
iteration: int,
metrics: dict | None = None,
extra_kwargs: dict[str, typing.Any] | None = None,
device: torch.device | None,
) -> list[tuple[torch.Tensor, dict]]:
) -> tuple[torch.Tensor, dict]:
# TODO Move batch splitting elsewhere, align interface with LayerBase
pass

Expand Down
Loading
Loading