diff --git a/fast_llm/core/ops.py b/fast_llm/core/ops.py index 7d361a22e..46dea8fce 100644 --- a/fast_llm/core/ops.py +++ b/fast_llm/core/ops.py @@ -16,7 +16,7 @@ def reduce_op( - input_: torch.Tensor, group: ProcessGroup | None, *, op: ReduceOp = ReduceOp.SUM, async_op: bool = False + input_: torch.Tensor, group: ProcessGroup | None, *, op: ReduceOp.RedOpType = ReduceOp.SUM, async_op: bool = False ) -> tuple[torch.Tensor, torch.distributed.Work] | torch.Tensor: if group: handle = all_reduce(input_, group=group, async_op=async_op, op=op) diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py index 8835612ec..e3fce4eb3 100644 --- a/fast_llm/data/dataset/streaming.py +++ b/fast_llm/data/dataset/streaming.py @@ -1,5 +1,6 @@ import functools import json +import logging import time import typing @@ -14,6 +15,8 @@ from fast_llm.data.document.token_data import TokenDataDocument from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + @config_class() class RedisStreamingDocumentData(Config): diff --git a/fast_llm/data/document/abstract.py b/fast_llm/data/document/abstract.py index 85014452f..6f546e9c3 100644 --- a/fast_llm/data/document/abstract.py +++ b/fast_llm/data/document/abstract.py @@ -9,6 +9,7 @@ if typing.TYPE_CHECKING: import torch + from fast_llm.engine.distributed.distributed import Distributed from fast_llm.tensor import TensorMeta @@ -59,6 +60,16 @@ def to_kwargs(self) -> dict[str, typing.Any]: AttentionKwargs.presents: self.presents, } + @classmethod + def share_batch_data(cls, model_inputs: "list[ModelInput]", distributed: "Distributed"): + """ + Gather values depending on the entire data-parallel batch, ex. the total number of labels or documents. + Should be called in the main process because distributed operations are not available during preprocessing. + Implemented as a class method so quantities shared by all models inputs are only computed once. + Note: this may be called more than once (ex. reference model preprocessing), so the method should be idempotent. + TODO: ====== Use as entry point for batch broadcasting? ====== + """ + @dataclasses.dataclass(kw_only=True) class Batch(Document): diff --git a/fast_llm/data/document/config.py b/fast_llm/data/document/config.py index 8967227e8..352311b51 100644 --- a/fast_llm/data/document/config.py +++ b/fast_llm/data/document/config.py @@ -29,6 +29,11 @@ class LengthPreprocessingConfig(BatchPreprocessingConfig): return_position_index: bool = Field(default=False) +@config_class() +class TokenPreprocessingConfig(LengthPreprocessingConfig): + return_document_count: bool = Field(default=False) + + @config_class() class ImageNormalizationConfig(Config): scale: float = Field(default=255.0) @@ -62,7 +67,7 @@ def get_batch_meta(self, size: int = 1) -> "PatchBatch": @config_class() -class LanguageModelBatchPreprocessingConfig(LengthPreprocessingConfig): +class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig): _abstract = False phase: PhaseType = Field(default=PhaseType.training) micro_batch_splits: int = Field(default=1) diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 00040e576..7821b81c5 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -4,12 +4,14 @@ import torch +from fast_llm.core.distributed import allreduce_scalar from fast_llm.data.document.abstract import ModelInput from fast_llm.data.document.config import LanguageModelBatchPreprocessingConfig from fast_llm.data.document.patch import PatchBatch, PatchDocument, PatchModelInput from fast_llm.data.document.range import RangeBatch, RangeDocument from fast_llm.data.document.token import TokenBatch, TokenDocument, TokenModelInput from fast_llm.data.document.token_data import TokenDataBatch, TokenDataDocument +from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.utils import div @@ -33,6 +35,20 @@ class LanguageModelTargetInput(ModelInput): advantages: torch.Tensor | None = None old_log_probabilities: torch.Tensor | None = None label_counts: torch.Tensor | None = None + num_labels: int | None = None + num_labels_in_batch: int | None = None + + @classmethod + def share_batch_data(cls, model_inputs: "list[LanguageModelTargetInput]", distributed: "Distributed"): + if model_inputs[0].num_labels is not None and model_inputs[0].num_labels_in_batch is None: + # We sum over sequences but not within a sequence. + num_labels_in_batch = allreduce_scalar( + sum(model_input.num_labels for model_input in model_inputs), + dtype=torch.int32, + group=distributed.batch_data_group, + ) + for model_input in model_inputs: + model_input.num_labels_in_batch = num_labels_in_batch @dataclasses.dataclass(kw_only=True) @@ -40,6 +56,16 @@ class LanguageModelInput(TokenModelInput): targets: list[LanguageModelTargetInput] = dataclasses.field(default_factory=list) image_patches: PatchModelInput | None = None + @classmethod + def share_batch_data(cls, model_inputs: "list[LanguageModelInput]", distributed: "Distributed"): + super().share_batch_data(model_inputs, distributed) + for targets in zip(*(model_input.targets for model_input in model_inputs), strict=True): + targets[0].share_batch_data(targets, distributed) + if model_inputs[0].image_patches is not None: + model_inputs[0].image_patches.share_batch_data( + [model_input.image_patches for model_input in model_inputs], distributed + ) + def set_children_attributes(self) -> None: if self.image_patches is not None: self.image_patches.set_parent_attributes(self) @@ -58,6 +84,7 @@ def to_kwargs(self) -> dict[str, typing.Any]: LanguageModelKwargs.advantages: [target.advantages for target in self.targets], LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets], LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets], + LanguageModelKwargs.num_labels_in_batch: [target.num_labels_in_batch for target in self.targets], } if self.image_patches is not None: out.update(self.image_patches.to_kwargs()) @@ -113,6 +140,12 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis ) ): model_input = self._get_model_input(sequence_k_past, sequence_k_past + local_input_length, config) + model_input.phase = config.phase + + if config.use_image_patches: + model_input.image_patches = self.image_patches.get_model_input( + sequence_k_past, sequence_k_past + local_input_length, config.vision_encoder + ) model_input.pasts = presents presents = None if micro_sequence_index == config.micro_batch_splits - 1 else [] @@ -121,73 +154,64 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis model_inputs.append(model_input) + self._set_target_inputs(model_inputs, config) + return model_inputs - def _get_model_input( - self, begin: int, end: int, config: LanguageModelBatchPreprocessingConfig - ) -> LanguageModelInput: - model_input = super()._get_model_input(begin, end, config) - model_input.phase = config.phase + def _set_target_inputs( + self, model_inputs: list[LanguageModelInput], config: LanguageModelBatchPreprocessingConfig + ): + labels = self.tokens.clone() - if config.use_image_patches: - model_input.image_patches = self.image_patches.get_model_input(begin, end, config.vision_encoder) + # Apply loss masking spans. + if config.use_loss_masking_spans and self.loss_masking_spans is not None: + for span_begin, span_end in self.loss_masking_spans.ranges: + labels[span_begin:span_end] = -100 for prediction_distance in range(1, config.num_labels + 1): - label_begin = begin + prediction_distance - label_end = end + prediction_distance - # Keep complete documents to simplify preprocessing. - _, first_document_begin, last_document_end = self._get_cropped_lengths(begin, label_end) - cropped_lengths, _, _ = self._get_cropped_lengths(first_document_begin, last_document_end) - labels = self.tokens[first_document_begin:last_document_end].clone() - labels_in_range = labels[label_begin - first_document_begin : label_end - first_document_begin] - - # Apply loss masking spans. - if config.use_loss_masking_spans and self.loss_masking_spans is not None: - for span_begin, span_end in self.loss_masking_spans.get_cropped_ranges( - first_document_begin, last_document_end - ): - labels[span_begin:span_end] = -100 - # Mask cross-document predictions. document_begin = 0 - for length in cropped_lengths: - labels[document_begin : document_begin + prediction_distance] = -100 + for length in self.lengths: + if prediction_distance <= length: + labels[document_begin + prediction_distance - 1] = -100 document_begin += length - if config.return_label_counts: - # Count the number of non-masked labels in each document through cumulative sums. - mask = labels >= 0 - mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)]) - length_cumsum = torch.tensor([0] + cropped_lengths, device=self.device).cumsum(0) - label_count_cumsum = mask_cumsum[length_cumsum] - labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1] - # Expand to one entry per token: find each token's document index via the sorted - # length cumsum, then look up that document's label count. - # TODO: Document index already computed in `LengthModelInputPreprocessor`. - document_index = torch.searchsorted( - length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right" - ) - label_counts = labels_per_document[document_index][ - label_begin - first_document_begin : label_end - first_document_begin - ] - mask = ( - mask[label_begin - first_document_begin : label_end - first_document_begin] - if config.return_prediction_mask - else None - ) - else: - label_counts = None - mask = labels_in_range >= 0 if config.return_prediction_mask else None - - # Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions. - target_input = LanguageModelTargetInput(tokens=labels_in_range, mask=mask, label_counts=label_counts) - - if config.use_grpo_data and not model_input.is_meta: - target_input.advantages = self.advantages.get_cropped_data(label_begin, label_end) - target_input.old_log_probabilities = self.old_log_probabilities.get_cropped_data( - label_begin, label_end + mask = labels >= 0 + label_counts = self._get_label_counts(mask) if config.return_label_counts else None + + for input_index, model_input in enumerate(model_inputs): + label_end = model_input.sequence_k_dim.size + prediction_distance + label_begin = label_end - model_input.token_dim.size + + # Labels contain all four sources of masking: padding, user-defined spans, image placeholders, cross-document predictions. + target_input = LanguageModelTargetInput( + tokens=labels[label_begin:label_end].clone(), + mask=mask[label_begin:label_end] if config.return_prediction_mask else None, + label_counts=label_counts[label_begin:label_end] if config.return_label_counts else None, + # Set value for the first input only so `share_batch_data` generated the correct sum. + # TODO: ====== Make optional? + num_labels=( + len(mask) if self.is_meta else mask.sum(dtype=torch.int32).item() if input_index == 0 else 0 + ), ) - - model_input.targets.append(target_input) - - return model_input + if config.use_grpo_data and not model_input.is_meta: + target_input.advantages = self.advantages.get_cropped_data(label_begin, label_end) + target_input.old_log_probabilities = self.old_log_probabilities.get_cropped_data( + label_begin, label_end + ) + + model_input.targets.append(target_input) + + def _get_label_counts(self, mask: torch.Tensor): + # Count the number of non-masked labels in each document through cumulative sums. + mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)]) + length_cumsum = torch.tensor([0] + self.lengths, device=self.device).cumsum(0) + label_count_cumsum = mask_cumsum[length_cumsum] + labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1] + # Expand to one entry per token: find each token's document index via the sorted + # length cumsum, then look up that document's label count. + # TODO: Document index already computed in `LengthModelInputPreprocessor`. + document_index = torch.searchsorted( + length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right" + ) + return labels_per_document[document_index] diff --git a/fast_llm/data/document/range.py b/fast_llm/data/document/range.py index ea5d0e7fd..ed2503455 100644 --- a/fast_llm/data/document/range.py +++ b/fast_llm/data/document/range.py @@ -32,6 +32,6 @@ def from_documents( document_begin += size return cls(ranges=ranges) if ranges else None - def get_cropped_ranges(self, begin: int, end: int) -> list[tuple[int, int]]: - cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, end - begin)) for begin_, end_ in self.ranges) - return [(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_] + # def get_cropped_ranges(self, begin: int, end: int) -> list[tuple[int, int]]: + # cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, end - begin)) for begin_, end_ in self.ranges) + # return [(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_] diff --git a/fast_llm/data/document/token.py b/fast_llm/data/document/token.py index 1871b2c83..70261a152 100644 --- a/fast_llm/data/document/token.py +++ b/fast_llm/data/document/token.py @@ -1,12 +1,14 @@ import dataclasses -import functools import typing import torch +from fast_llm.core.distributed import allreduce_scalar from fast_llm.data.document.abstract import Batch, Document from fast_llm.data.document.block import BlockModelInput, LengthModelInputPreprocessor -from fast_llm.data.document.config import LengthPreprocessingConfig +from fast_llm.data.document.config import TokenPreprocessingConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -22,12 +24,34 @@ def __len__(self) -> int: def device(self) -> torch.device: return self.tokens.device + @property + def is_meta(self) -> bool: + return self.device.type == "meta" + @dataclasses.dataclass(kw_only=True) class TokenModelInput(BlockModelInput, TokenDocument): - @functools.cached_property - def is_meta(self) -> bool: - return isinstance(self.tokens, TensorMeta) + num_documents: int | None = None + num_documents_in_batch: int | None = None + + @classmethod + def share_batch_data(cls, model_inputs: "list[TokenModelInput]", distributed: "Distributed"): + if model_inputs[0].num_documents is not None and model_inputs[0].num_documents_in_batch is None: + # We sum over sequences but not within a sequence. + num_documents_in_batch = allreduce_scalar( + sum(model_input.num_documents for model_input in model_inputs), + dtype=torch.int32, + group=distributed.batch_data_group, + ) + for model_input in model_inputs: + model_input.num_documents_in_batch = num_documents_in_batch + + def to_kwargs(self) -> dict[str, typing.Any]: + # TODO: Avoid conversion, use `LanguageModelMicroBatch` directly instead. + return { + **super().to_kwargs(), + LanguageModelKwargs.num_documents_in_batch: self.num_documents_in_batch, + } @dataclasses.dataclass(kw_only=True) @@ -74,10 +98,16 @@ def _get_cropped_lengths(self, begin: int, end: int) -> tuple[list[int], int, in return lengths, first_document_begin, document_end - def _get_model_input(self, begin: int, end: int, config: LengthPreprocessingConfig): + def _get_model_input(self, begin: int, end: int, config: TokenPreprocessingConfig): model_input = self._model_input_class(tokens=self.tokens[begin:end]) lengths, first_document_begin, last_document_end = self._get_cropped_lengths(begin, end) + if config.return_document_count: + # Exclude the padding "length" from the document count. + model_input.num_documents = ( + len(self.lengths) - (1 if self.unpadded_length < len(self.tokens) else 0) if begin == 0 else 0 + ) + LengthModelInputPreprocessor( lengths=lengths, sequence_k_past=begin, @@ -89,7 +119,7 @@ def _get_model_input(self, begin: int, end: int, config: LengthPreprocessingConf ).preprocess(model_input, config) Assert.eq(model_input.token_dim.size, end - begin) - if self.tokens.device.type == "meta": + if self.is_meta: model_input.tokens = TensorMeta.from_dims( (model_input.token_dim,), tensor_name=f"tokens_{begin}_to_{end}", dtype=torch.int64 ) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index a12b68c17..4cb529463 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -48,11 +48,11 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c out += layer.get_compute_usage(input_, kwargs, config) return out - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: losses = [] for layer in self.get_layers(): if layer is not self: - losses += layer.get_loss_definitions(count) + losses += layer.get_loss_definitions() return losses def get_preprocessing_config(self) -> dict[str, typing.Any]: @@ -178,14 +178,13 @@ def __init__( @abc.abstractmethod def preprocess_batch( self, - model_inputs: list[ModelInput], + model_input: ModelInput, *, phase: PhaseType, iteration: int, metrics: dict | None = None, extra_kwargs: dict[str, typing.Any] | None = None, - device: torch.device | None, - ) -> list[tuple[torch.Tensor, dict]]: + ) -> tuple[torch.Tensor, dict]: # TODO Move batch splitting elsewhere, align interface with LayerBase pass diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 0526b9dc2..2770e67a2 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -1,5 +1,6 @@ import abc import dataclasses +import enum import typing from fast_llm.config import MISSING, Config, Field, FieldHint, FieldVerboseLevel, config_class @@ -11,6 +12,7 @@ import torch from fast_llm.engine.base_model.base_model import BaseModel + from fast_llm.engine.distributed.distributed import Distributed @config_class() @@ -103,12 +105,81 @@ class ResourceUsageConfig: backward: int = 1 +class ReductionType(enum.StrEnum): + """ + An enum to represent data types independently of third party libraries, + so we can swap them more easily and allow for lazy imports. + """ + + sum = "float64" + average = "float32" + minimum = "float16" + maximum = "bfloat16" + + @property + def torch(self) -> "typing.Callable[[torch.Tensor], torch.Tensor]": + if not _TORCH_REDUCTION_MAP: + _set_torch_reduction_map() + return _TORCH_REDUCTION_MAP[self] + + @property + def distributed(self) -> "torch.distributed.ReduceOp.RedOpType": + if not _DISTRIBUTED_REDUCTION_MAP: + _set_distributed_reduction_map() + return _DISTRIBUTED_REDUCTION_MAP[self] + + +_TORCH_REDUCTION_MAP: dict[ReductionType, "typing.Callable[[torch.Tensor], torch.Tensor]"] = {} + + +def _set_torch_reduction_map() -> None: + import torch + + global _TORCH_REDUCTION_MAP + + _TORCH_REDUCTION_MAP = { + ReductionType.sum: torch.sum, + ReductionType.average: torch.mean, + ReductionType.minimum: torch.min, + ReductionType.maximum: torch.max, + } + + +_DISTRIBUTED_REDUCTION_MAP: dict[ReductionType, "torch.distributed.ReduceOp.RedOpType"] = {} + + +def _set_distributed_reduction_map() -> None: + import torch + + global _DISTRIBUTED_REDUCTION_MAP + + _DISTRIBUTED_REDUCTION_MAP = { + ReductionType.sum: torch.distributed.ReduceOp.SUM, + ReductionType.average: torch.distributed.ReduceOp.AVG, + ReductionType.minimum: torch.distributed.ReduceOp.MIN, + ReductionType.maximum: torch.distributed.ReduceOp.MAX, + } + + @dataclasses.dataclass() class LossDef: # A name for the loss name: str - formatted_name: str - # The number of times this loss is evaluated by the model for each micro-batch. Used as a denominator for averaging. - # TODO: Allow variable count? Would need a reduction across PP devices. - count: int = 1 dtype: DataType = DataType.float32 + reduction: ReductionType = ReductionType.sum + + def reduce(self, losses: "list[torch.Tensor]", distributed: "Distributed") -> "torch.Tensor | None": + import torch + + from fast_llm.core.ops import reduce_op + + if losses or distributed.pipeline_group: + if losses: + reduced_loss = losses[0] if len(losses) == 1 else self.reduction.torch(torch.stack(losses)) + reduce_op(reduced_loss, group=distributed.data_group, op=self.reduction.distributed) + else: + reduced_loss = torch.zeros([1], dtype=self.dtype.torch, device=distributed.device) + reduce_op(reduced_loss, group=distributed.pipeline_group, op=self.reduction.distributed) + return reduced_loss + else: + return None diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index c3950cedf..a214e8e50 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -80,6 +80,16 @@ class DistributedDim: def __post_init__(self): self._is_setup = False + def __getstate__(self): + # Prevent process groups from being pickled, ex. in the data loader. + state = self.__dict__.copy() + if "_group" in state: + del state["_group"] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + @property def group(self) -> "ProcessGroup|None": assert hasattr(self, "_group") @@ -117,7 +127,9 @@ def from_sizes_and_strides(cls, name: str, global_rank: int, *sizes_and_strides: elif isinstance(global_ranks, range) and stride == global_ranks.stop - global_ranks.start: global_ranks = range(start, start + size * stride, global_ranks.step) else: - global_ranks = [rank0 + rank1 for rank1 in range(0, size * stride, stride) for rank0 in global_ranks] + global_ranks = tuple( + rank0 + rank1 for rank1 in range(0, size * stride, stride) for rank0 in global_ranks + ) Assert.eq(len(global_ranks), world_size) return DistributedDim(name=name, size=world_size, rank=rank, global_ranks=global_ranks) diff --git a/fast_llm/engine/inference/runner.py b/fast_llm/engine/inference/runner.py index f3b16c647..d9ed695ec 100644 --- a/fast_llm/engine/inference/runner.py +++ b/fast_llm/engine/inference/runner.py @@ -1,6 +1,7 @@ import abc import typing +from fast_llm.data.document.abstract import ModelInput from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.config import ScheduleConfig @@ -57,15 +58,14 @@ def setup(self): Assert.is_(self._runner._distributed, self._fast_llm_model.distributed) def forward( - self, input_, kwargs: dict, *, iteration: int = 1, return_metrics: bool = False + self, model_input: ModelInput, *, iteration: int = 1, return_metrics: bool = False ) -> tuple[dict[str, float | int], dict[str, typing.Any] | None]: # TODO: Return an actual model output. reduced_losses, update_successful, metrics = self._runner.run_step( - iter((((input_, kwargs),),)), + iter(((model_input,),)), self._schedule, iteration=iteration, return_metrics=return_metrics, - preprocessed=True, ) assert update_successful return reduced_losses, metrics diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 24b8b3d63..7ad03b24c 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -146,7 +146,6 @@ def run_step( *, iteration: int = 1, return_metrics: bool = False, - preprocessed: bool = False, ) -> tuple[dict[str, float | int], bool, dict[str, typing.Any] | None]: assert self._is_setup assert schedule._config is self._config # Noqa @@ -161,7 +160,7 @@ def run_step( losses={loss_def: [] for loss_def in self._loss_definitions}, metrics=metrics, ) - context.data_iterator = self._preprocess_data(context, data_iterator, preprocessed) + context.data_iterator = self._preprocess_data(context, data_iterator) if self._multi_stage.config.multi_stage.debug_activation_memory: log_pipeline_parallel_main_rank( @@ -285,30 +284,12 @@ def run_step( return self._reduce_losses(context), update_successful, metrics def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: - reduced_losses = {} - for name, losses in context.losses.items(): - if losses or self._distributed.pipeline_group: - if losses: - loss_count = ( - self._loss_definitions[name].count - * self._distributed_config.data_parallel - * context.schedule.config.num_inputs - ) - reduced_loss = torch.stack(losses).sum() / loss_count - if self._distributed.data_group: - all_reduce(reduced_loss, group=self._distributed.data_group) - else: - reduced_loss = torch.zeros( - [1], dtype=self._loss_definitions[name].dtype.torch, device=self._distributed.device - ) - if self._distributed.pipeline_group: - all_reduce(reduced_loss, group=self._distributed.pipeline_group) - else: - reduced_loss = 0.0 - reduced_losses[name] = reduced_loss + reduced_losses = { + name: self._loss_definitions[name].reduce(losses, self._distributed) + for name, losses in context.losses.items() + } return { - name: reduced_loss.item() if isinstance(reduced_loss, torch.Tensor) else reduced_loss - for name, reduced_loss in reduced_losses.items() + name: 0.0 if reduced_loss is None else reduced_loss.item() for name, reduced_loss in reduced_losses.items() } def _train_step(self, context: BatchContext, step: Step) -> None: @@ -328,16 +309,25 @@ def _train_step(self, context: BatchContext, step: Step) -> None: self._reduce(context, step) def _preprocess_data( - self, context: BatchContext, data_iterator: typing.Iterator, preprocessed: bool + self, context: BatchContext, data_iterator: typing.Iterator ) -> typing.Generator[None, None, None]: + # We multiply by the data-parallel size to improve numerical stability (reduce numerical underflow). + # This factor is canceled in the averaging during gradient reduction. grad_output = ( - self._optimizer.grad_scale / self._config.num_inputs if context.schedule.phase.is_training else None + self._optimizer.grad_scale * self._distributed_config.data_parallel + if context.schedule.phase.is_training + else None + ) + model_inputs = [next(data_iterator) for _ in range(self._config.sequential_micro_batches)] + model_inputs[0][0].share_batch_data( + [model_input for model_inputs_ in model_inputs for model_input in model_inputs_], self._distributed ) - for micro_batch in range(self._config.sequential_micro_batches): - micro_batch_data = next(data_iterator) - if not preprocessed: - micro_batch_data = self._multi_stage.base_model.preprocess_batch( - micro_batch_data, + + for micro_batch, model_inputs_ in enumerate(model_inputs): + Assert.eq(len(model_inputs_), self._config.micro_batch_splits) + for micro_batch_split, model_input in enumerate(model_inputs_): + input_, kwargs = self._multi_stage.base_model.preprocess_batch( + model_input, phase=context.phase, iteration=context.iteration, metrics=context.metrics, @@ -347,10 +337,7 @@ def _preprocess_data( "num_micro_batches": self._config.sequential_micro_batches, "micro_batch_splits": self._config.micro_batch_splits, }, - device=self._distributed.device, ) - Assert.eq(len(micro_batch_data), self._config.micro_batch_splits) - for micro_batch_split, (input_, kwargs) in enumerate(micro_batch_data): kwargs.update(micro_batch_split=micro_batch_split) data_index = micro_batch * self._config.micro_batch_splits + micro_batch_split if self._stages_owned[0]: @@ -408,7 +395,7 @@ def _recv(self, context: BatchContext, step: Step) -> None: step.recv_event.wait() self._record_event(context, EventType.compute_wait_pipe, step) - def _forward(self, context: BatchContext, step: Step) -> None: + def _forward(self, context: BatchContext, step: Step) -> torch.Tensor | None: output, grad_context = self._stages[step.stage].forward( self._get_forward_input(context, step), context.batch[step.index], diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index bc425520f..e2a9c75b5 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -127,12 +127,14 @@ def __init__( warnings.warn("Not enough input to achieve true pipeline parallelism.") # Setup the activation metas. - self._preprocessed_meta = self._multi_stage.base_model.preprocess_batch( - batch_meta, - phase=self._phase, - iteration=0, - device=None, - ) + self._preprocessed_meta = [ + self._multi_stage.base_model.preprocess_batch( + model_input, + phase=self._phase, + iteration=0, + ) + for model_input in batch_meta + ] self._steps, self._first_grad_stage = self._create_steps() @@ -536,6 +538,6 @@ def compute_usage(self) -> tuple[int | None, int | None]: def get_compute_metrics(self, time_per_iteration: float) -> dict[str, float]: model_compute, hardware_compute = self.compute_usage return { - "model_tflops": math.nan if model_compute is None else model_compute / time_per_iteration, - "hardware_tflops": math.nan if hardware_compute is None else hardware_compute / time_per_iteration, + "model_tflops": math.nan if model_compute is None else model_compute / time_per_iteration / 1e12, + "hardware_tflops": math.nan if hardware_compute is None else hardware_compute / time_per_iteration / 1e12, } diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index 65dcee32b..05eaae520 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -2,6 +2,7 @@ from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce from fast_llm.functional.config import EntropyLossType, TargetFormat +from fast_llm.functional.utils import reduce_losses from fast_llm.utils import Assert @@ -285,18 +286,21 @@ def fused_entropy_loss_forward_backward( temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, + divisor: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. It is an improvement over the pytorch implementation because of the fused casting, both in speed and memory, but still suboptimal because it needs multiple kernels. """ - grad_output = None if grad_output is None else grad_output / logits.shape[:-1].numel() * logits_scale_factor + if divisor is None: + divisor = logits.shape[:-1].numel() + grad_output = None if grad_output is None else grad_output / divisor * logits_scale_factor if target_format == TargetFormat.labels: assert entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl) assert loss_mask is None loss_mask = target >= 0 - per_sample_loss, grad = _fused_cross_entropy_base_from_labels( + losses, grad = _fused_cross_entropy_base_from_labels( logits, target, loss_mask, @@ -305,7 +309,7 @@ def fused_entropy_loss_forward_backward( group, ) elif entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl): - per_sample_loss, grad = _fused_cross_entropy_base_from_distribution( + losses, grad = _fused_cross_entropy_base_from_distribution( logits, target, grad_output, @@ -316,7 +320,7 @@ def fused_entropy_loss_forward_backward( return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl, ) elif entropy_loss_type == EntropyLossType.reverse_kl: - per_sample_loss, grad = _fused_reverse_kl_base_from_distribution( + losses, grad = _fused_reverse_kl_base_from_distribution( logits, target, grad_output, @@ -328,9 +332,7 @@ def fused_entropy_loss_forward_backward( else: raise NotImplementedError(entropy_loss_type) - if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask - loss = per_sample_loss.mean() + loss = reduce_losses(losses, divisor, loss_mask) if grad is not None: if loss_mask is not None: diff --git a/fast_llm/functional/linear.py b/fast_llm/functional/linear.py index 38658ffc5..e1742a1bb 100644 --- a/fast_llm/functional/linear.py +++ b/fast_llm/functional/linear.py @@ -8,7 +8,6 @@ from fast_llm.core.distributed import ProcessGroup from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op -from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.sparse_copy import SparseMap from fast_llm.functional.triton.sparse_linear import ( @@ -17,6 +16,7 @@ input_row_sparse_matmul, output_sparse_matmul, ) +from fast_llm.functional.utils import wrap_forward_backward from fast_llm.tensor import accumulate_gradient, param_get_and_unset_is_zero diff --git a/fast_llm/functional/triton/entropy_loss.py b/fast_llm/functional/triton/entropy_loss.py index 3d9937439..9ec13a7d4 100644 --- a/fast_llm/functional/triton/entropy_loss.py +++ b/fast_llm/functional/triton/entropy_loss.py @@ -2,6 +2,7 @@ from fast_llm.functional.config import EntropyLossType, TargetFormat from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.utils import reduce_losses @triton_jit() @@ -656,7 +657,7 @@ def _cross_entropy_loss_from_labels( sum_exp_logits: torch.Tensor, max_logits: torch.Tensor, ) -> torch.Tensor: - return torch.where(target.flatten() >= 0, sum_exp_logits.log() + max_logits - predicted_logits, 0).mean() + return torch.where(target.flatten() >= 0, sum_exp_logits.log() + max_logits - predicted_logits, 0) @torch.compile @@ -700,6 +701,7 @@ def triton_entropy_loss_forward_backward( entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, block_size: int | None = None, num_warps: int | None = None, + divisor: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -712,6 +714,8 @@ def triton_entropy_loss_forward_backward( assert target.is_contiguous() n_rows = logits.shape[:-1].numel() n_cols = logits.size(-1) + if divisor is None: + divisor = n_rows if block_size is None: block_size = min(triton.next_power_of_2(n_cols), 32768) if num_warps is None: @@ -730,7 +734,7 @@ def triton_entropy_loss_forward_backward( grad_logits = torch.empty_like(logits) if grad_logits is None else grad_logits backward_kwargs = { "grad_logits_ptr": grad_logits, - "grad_losses": grad_output / n_rows, + "grad_losses": grad_output / divisor, "grad_logits_stride_0": grad_logits.stride(-2), "accumulate": accumulate, } @@ -745,23 +749,22 @@ def triton_entropy_loss_forward_backward( **kwargs, **backward_kwargs, ) - loss = losses.mean() else: - partial_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - local_max_logits = torch.empty_like(partial_losses) - sum_exp_logits = torch.empty_like(partial_losses) + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + local_max_logits = torch.empty_like(losses) + sum_exp_logits = torch.empty_like(losses) triton_cross_entropy_forward_from_labels_parallel_kernel[(n_rows,)]( logits, target, max_logits_ptr=local_max_logits, sum_exp_logits_ptr=sum_exp_logits, - predicted_logits_ptr=partial_losses, + predicted_logits_ptr=losses, col_min=n_cols * group.rank(), **kwargs, ) max_logits, sum_exp_logits = parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) - torch.distributed.all_reduce(partial_losses, op=torch.distributed.ReduceOp.SUM, group=group) - loss = _cross_entropy_loss_from_labels(partial_losses, target, sum_exp_logits, max_logits) + torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM, group=group) + losses = _cross_entropy_loss_from_labels(losses, target, sum_exp_logits, max_logits) if grad_output is not None: triton_cross_entropy_forward_backward_from_labels_kernel[(n_rows,)]( logits, @@ -798,14 +801,13 @@ def triton_entropy_loss_forward_backward( **kwargs, **backward_kwargs, ) - loss = losses.mean() else: - partial_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - local_max_logits = torch.empty_like(partial_losses) - sum_exp_logits = torch.empty_like(partial_losses) + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + local_max_logits = torch.empty_like(losses) + sum_exp_logits = torch.empty_like(losses) if target_format == TargetFormat.logits: - local_target_max_logits = torch.empty_like(partial_losses) - target_sum_exp_logits = torch.empty_like(partial_losses) + local_target_max_logits = torch.empty_like(losses) + target_sum_exp_logits = torch.empty_like(losses) else: local_target_max_logits = target_sum_exp_logits = None @@ -823,7 +825,7 @@ def triton_entropy_loss_forward_backward( sum_exp_logits_ptr=sum_exp_logits, target_max_logits_ptr=local_target_max_logits, target_sum_exp_logits_ptr=target_sum_exp_logits, - partial_losses_ptr=partial_losses, + partial_losses_ptr=losses, target_stride_0=target.stride(-2), target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, @@ -835,14 +837,12 @@ def triton_entropy_loss_forward_backward( target_sum_exp_logits, local_target_max_logits, group ) if entropy_loss_type != EntropyLossType.reverse_kl: - partial_losses = _rescale_predicted_logits( - partial_losses, local_target_max_logits, target_max_logits - ) + losses = _rescale_predicted_logits(losses, local_target_max_logits, target_max_logits) else: target_max_logits = None if entropy_loss_type == EntropyLossType.reverse_kl: - partial_losses = _rescale_predicted_logits(partial_losses, local_max_logits, max_logits) - torch.distributed.all_reduce(partial_losses, op=torch.distributed.ReduceOp.SUM, group=group) + losses = _rescale_predicted_logits(losses, local_max_logits, max_logits) + torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM, group=group) kernel[(n_rows,)]( logits, @@ -852,13 +852,13 @@ def triton_entropy_loss_forward_backward( sum_exp_logits_ptr=sum_exp_logits, target_max_logits_ptr=target_max_logits, target_sum_exp_logits_ptr=target_sum_exp_logits, - partial_losses_ptr=partial_losses, - losses_ptr=partial_losses, + partial_losses_ptr=losses, + losses_ptr=losses, target_stride_0=target.stride(-2), target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, **kwargs, **backward_kwargs, ) - loss = partial_losses.mean() + loss = reduce_losses(losses, divisor) return loss, grad_logits diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index 7949faaf0..4a8c5f179 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -5,7 +5,6 @@ from fast_llm.core.distributed import ProcessGroup from fast_llm.core.ops import gather_op -from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig from fast_llm.functional.linear import ( input_parallel_linear_forward, @@ -23,6 +22,7 @@ copy_sparse_to_dense_forward, ) from fast_llm.functional.triton.sparse_linear import output_sparse_matmul +from fast_llm.functional.utils import wrap_forward_backward from fast_llm.tensor import param_get_and_unset_is_zero diff --git a/fast_llm/functional/triton/normalization.py b/fast_llm/functional/triton/normalization.py index 9538a9275..7c25ce735 100644 --- a/fast_llm/functional/triton/normalization.py +++ b/fast_llm/functional/triton/normalization.py @@ -2,9 +2,9 @@ import torch -from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, tl_full, triton, triton_jit +from fast_llm.functional.utils import wrap_forward_backward from fast_llm.tensor import param_get_and_unset_is_zero diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index 3d9c07145..f07046a52 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -1,8 +1,8 @@ import torch -from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.utils import wrap_forward_backward from fast_llm.utils import div diff --git a/fast_llm/functional/triton/sparse_copy.py b/fast_llm/functional/triton/sparse_copy.py index e68692d9c..6af0c7828 100644 --- a/fast_llm/functional/triton/sparse_copy.py +++ b/fast_llm/functional/triton/sparse_copy.py @@ -3,9 +3,9 @@ import torch from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import MAX_DROPLESS_BLOCK_SIZE_ROW, TritonConfig from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.utils import wrap_forward_backward @dataclasses.dataclass() diff --git a/fast_llm/functional/triton/z_loss.py b/fast_llm/functional/triton/z_loss.py index cb3220131..d9592a4f4 100644 --- a/fast_llm/functional/triton/z_loss.py +++ b/fast_llm/functional/triton/z_loss.py @@ -6,6 +6,7 @@ triton_cross_entropy_forward_from_labels_parallel_kernel, triton_fused_softmax_base, ) +from fast_llm.functional.utils import reduce_losses @triton_jit() @@ -83,12 +84,15 @@ def triton_z_loss_forward_backward( logits_scale_factor: float = 1.0, block_size: int | None = None, num_warps: int | None = None, + divisor: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: assert logits.is_contiguous() if loss_mask is not None: assert loss_mask.is_contiguous() n_rows = logits.shape[:-1].numel() n_cols = logits.size(-1) + if divisor is None: + divisor = logits.shape[:-1].numel() if block_size is None: block_size = min(triton.next_power_of_2(n_cols), 32768) if num_warps is None: @@ -108,7 +112,7 @@ def triton_z_loss_forward_backward( backward_kwargs = { "grad_logits_ptr": grad_logits, - "grad_losses": grad_output / n_rows, + "grad_losses": grad_output / divisor, "grad_logits_stride_0": grad_logits.stride(-2), "accumulate": accumulate, } @@ -141,4 +145,5 @@ def triton_z_loss_forward_backward( **kwargs, **backward_kwargs, ) - return losses.mean(), grad_logits + loss = reduce_losses(losses, divisor) + return loss, grad_logits diff --git a/fast_llm/functional/autograd.py b/fast_llm/functional/utils.py similarity index 91% rename from fast_llm/functional/autograd.py rename to fast_llm/functional/utils.py index 586f833b3..b2fc4589d 100644 --- a/fast_llm/functional/autograd.py +++ b/fast_llm/functional/utils.py @@ -69,3 +69,12 @@ def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float | Non @staticmethod def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: # noqa return grad_output, ctx.grad, None + + +@torch.compile +def reduce_losses( + losses: torch.Tensor, divisor: float | None = None, mask: torch.Tensor | None = None +) -> torch.Tensor: + if mask is not None: + losses = losses * mask + return losses.mean() if divisor is None else losses.sum() / divisor diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 16caf2d66..be40317f3 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -9,7 +9,7 @@ from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.functional.autograd import wrap_forward_backward +from fast_llm.functional.utils import wrap_forward_backward from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias diff --git a/fast_llm/layers/block/sequence.py b/fast_llm/layers/block/sequence.py index b085961bf..d2a8c7f3b 100644 --- a/fast_llm/layers/block/sequence.py +++ b/fast_llm/layers/block/sequence.py @@ -69,10 +69,8 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: kwargs[BlockKwargs.num_blocks_in_sequence] = self._config.num_blocks self._layers_with_namespace[0].preprocess(kwargs) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return ( - self[0].get_loss_definitions(count=count * self._config.num_blocks) if self._config.num_blocks > 0 else [] - ) + def get_loss_definitions(self) -> list[LossDef]: + return self[0].get_loss_definitions() if self._config.num_blocks > 0 else [] class PatternBlockSequence[ConfigType: PatternBlockSequenceConfig](BlockBase[ConfigType], torch.nn.ModuleList): @@ -139,11 +137,11 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: kwargs[BlockKwargs.num_blocks_in_sequence] = self._config.expanded_pattern.count(name) self._layers_with_namespace[index].preprocess(kwargs) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: # TODO: Prevent name conflicts. return sum( ( - self[self._config.preprocessing_layers[name]].get_loss_definitions(count=count * count_) + self[self._config.preprocessing_layers[name]].get_loss_definitions() for name, count_ in collections.Counter(self._config.expanded_pattern).items() ), [], diff --git a/fast_llm/layers/common/linear/linear.py b/fast_llm/layers/common/linear/linear.py index d0ea7a681..f19e97a94 100644 --- a/fast_llm/layers/common/linear/linear.py +++ b/fast_llm/layers/common/linear/linear.py @@ -5,7 +5,6 @@ from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.distributed.config import DistributedDim -from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( input_parallel_linear_autograd, input_parallel_linear_backward, @@ -15,6 +14,7 @@ output_parallel_linear_backward, output_parallel_linear_forward, ) +from fast_llm.functional.utils import wrap_forward_backward from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert diff --git a/fast_llm/layers/common/peft/lora.py b/fast_llm/layers/common/peft/lora.py index fcff5d496..eaf9f67f0 100644 --- a/fast_llm/layers/common/peft/lora.py +++ b/fast_llm/layers/common/peft/lora.py @@ -3,7 +3,7 @@ import torch from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.functional.autograd import wrap_forward_backward +from fast_llm.functional.utils import wrap_forward_backward from fast_llm.layers.common.linear.linear import Linear, LinearBase from fast_llm.tensor import ParameterMeta diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index a9d213912..a2f2d3519 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -9,7 +9,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.functional.autograd import AuxiliaryLoss +from fast_llm.functional.utils import AuxiliaryLoss from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig @@ -216,18 +216,8 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: # TODO: add layer_index _distillation_loss_name = "activation_distillation_loss" - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: loss_definitions = [] if self._config.distillation_model is not None: - loss_definitions.append( - LossDef( - name=self._distillation_loss_name, - formatted_name=self._distillation_loss_name, - count=count, - ) - ) - return ( - loss_definitions - + self.mixer.get_loss_definitions(count=count) - + self.mlp.get_loss_definitions(count=count) - ) + loss_definitions.append(LossDef(name=self._distillation_loss_name)) + return loss_definitions + self.mixer.get_loss_definitions() + self.mlp.get_loss_definitions() diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 13ba79a7a..48bc5a5e1 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -9,9 +9,9 @@ from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.functional.autograd import AuxiliaryLoss from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map +from fast_llm.functional.utils import AuxiliaryLoss from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType @@ -247,24 +247,12 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c ) return super().get_compute_usage(moe_input, kwargs, config) + self.router.get_compute_usage(input_, config) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: loss_definitions = [] if self._config.routing == RoutingType.topk: - loss_definitions.append( - LossDef( - name=MLPLossNames.load_balancing_loss, - formatted_name="load balancing loss", - count=1, - ) - ) + loss_definitions.append(LossDef(name=MLPLossNames.load_balancing_loss)) if self._config.z_loss_coefficient: - loss_definitions.append( - LossDef( - name=MLPLossNames.router_z_loss, - formatted_name="router z loss", - count=1, - ) - ) + loss_definitions.append(LossDef(name=MLPLossNames.router_z_loss)) return loss_definitions diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index 9def3895c..97bd1f477 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -231,7 +231,7 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c return int(expected_usage) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: """ Merge loss definitions from all mixers with namespacing. @@ -241,13 +241,11 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: """ all_losses = [] for mixer_name, mixer in self.mixers.items(): - mixer_losses = mixer.get_loss_definitions(count=count) + mixer_losses = mixer.get_loss_definitions() # Namespace each loss with the mixer name to avoid conflicts for loss_def in mixer_losses: namespaced_loss = LossDef( name=f"{mixer_name}/{loss_def.name}", - formatted_name=f"{mixer_name}/{loss_def.formatted_name}", - count=loss_def.count, dtype=loss_def.dtype, ) all_losses.append(namespaced_loss) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index a199ad154..4a8efdab6 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -24,6 +24,7 @@ class LanguageModelKwargs(LanguageModelLossKwargs): token_map = "token_map" sample_map = "sample_map" embedding_map = "embedding_map" + num_documents_in_batch = "num_documents_in_batch" # TODO: These are generic phase = "phase" loss_mask = "loss_mask" diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b6b749095..d57b465bf 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -10,8 +10,8 @@ from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.functional.autograd import AuxiliaryLoss, grad_is_context, wrap_forward_backward from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward +from fast_llm.functional.utils import AuxiliaryLoss, grad_is_context, wrap_forward_backward from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( @@ -136,7 +136,7 @@ def forward( (scalar_dim,), tensor_name=f"{self.module_name} output", reductions=( - (self._distributed_config.get_distributed_dim(DistributedDimNames.data), ReduceOp.AVG), + (self._distributed_config.get_distributed_dim(DistributedDimNames.data), ReduceOp.SUM), ), ) else: @@ -221,13 +221,13 @@ def _logits_loss_forward_backward( total_losses.append(total_loss_) # TODO: Avoid copy with explicit out argument. input_grad_.copy_(grad_) - total_loss = sum(total_losses) / self._config.cross_entropy_splits if total_losses else None + total_loss = torch.stack(total_losses).sum() if total_losses else None # TODO: ====== Drop return value, treat as normal loss ====== # Return value only needed because stage expects a return tensor if self._sequence_parallel_logits: # TODO: Async - all_reduce(total_loss, op=ReduceOp.AVG, group=self._parallel_dim.group) + all_reduce(total_loss, op=ReduceOp.SUM, group=self._parallel_dim.group) if losses is not None: losses[self._total_loss_name].append(total_loss) @@ -277,14 +277,10 @@ def _logits_loss_forward_backward_partial( output_parallel_linear_backward(grad, context) if self.training else None ) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: return [ - LossDef(name=self._total_loss_name, formatted_name=self._total_loss_name, count=count), - *( - loss_ - for loss in self.losses - for loss_ in loss.get_loss_definitions(count * self._config.cross_entropy_splits) - ), + LossDef(name=self._total_loss_name), + *(loss_ for loss in self.losses for loss_ in loss.get_loss_definitions()), ] def _get_full_loss_name(self, name) -> str: diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index c3dd625ec..1f12c5b52 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -82,14 +82,14 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self.head.preprocess(kwargs) self.multi_token_prediction.preprocess(kwargs) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? return sum( ( - self.embeddings.get_loss_definitions(count), - self.decoder.get_loss_definitions(count), - self.head.get_loss_definitions(count), - self.multi_token_prediction.get_loss_definitions(count), + self.embeddings.get_loss_definitions(), + self.decoder.get_loss_definitions(), + self.head.get_loss_definitions(), + self.multi_token_prediction.get_loss_definitions(), ), [], ) diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 99d4bce9a..a2c067a95 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -26,7 +26,8 @@ class LanguageModelLossKwargs(BlockKwargs): rejected_spans = "rejected_spans" advantages = "advantages" old_log_probabilities = "old_log_probabilities" - label_counts = "num_labels_in_seq" + label_counts = "label_counts" + num_labels_in_batch = "num_labels_in_batch" @config_class(registry=True) diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index f16b6de44..48c1556f3 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -35,6 +35,7 @@ def _forward_backward( logits_scale_factor=self._logits_scale_factor, target_format=TargetFormat.labels, entropy_loss_type=self._config.loss_type, + divisor=self._get_label_count(kwargs), ) @@ -61,6 +62,7 @@ def _forward_backward( logits_scale_factor=self._logits_scale_factor, target_format=TargetFormat.logits, entropy_loss_type=self._config.loss_type, + divisor=self._get_label_count(kwargs), ) def get_preprocessing_config(self) -> dict[str, typing.Any]: diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 2136e7918..62f591d9f 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -4,8 +4,8 @@ import torch from fast_llm.engine.base_model.config import LossDef -from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base +from fast_llm.functional.utils import reduce_losses from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs from fast_llm.layers.language_model.loss.loss import LanguageModelLoss @@ -35,6 +35,7 @@ def _forward_backward( if losses is None else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) ), + divisor=self._get_label_count(kwargs), ) self._register_loss( @@ -42,15 +43,8 @@ def _forward_backward( ) return loss, grad - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return super().get_loss_definitions(count) + [ - LossDef( - self._logprob_metric_name, - formatted_name=self._logprob_metric_name, - count=1, # This is an additive metric over the sequence. - dtype=DataType.float32, - ) - ] + def get_loss_definitions(self) -> list[LossDef]: + return super().get_loss_definitions() + [LossDef(self._logprob_metric_name)] def get_preprocessing_config( self, @@ -77,8 +71,11 @@ def fused_grpo_loss_forward_backward( num_labels_in_seq: ( torch.Tensor | None ) = None, # (*batch,) — response-span length broadcast per token, 0 for non-response + divisor: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: - grad_output = None if grad_output is None else grad_output / logits.shape[:-1].numel() * logits_scale_factor + if divisor is None: + divisor = logits.shape[:-1].numel() + grad_output = None if grad_output is None else grad_output / divisor * logits_scale_factor loss_mask = target >= 0 logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) @@ -88,12 +85,11 @@ def fused_grpo_loss_forward_backward( new_log_probs = predicted_logits - sum_exp_logits.log() probability_ratio = (new_log_probs - old_log_probabilities).exp() - per_sample_loss = -torch.min( + losses = -torch.min( probability_ratio * advantages, torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * advantages, ) - per_sample_loss = per_sample_loss * loss_mask - loss = per_sample_loss.mean() + loss = reduce_losses(losses, divisor, loss_mask) # Sum of per-sequence mean log-probs, matching pipelinerl's new_logprobs metric: # sum_sum(new_logprobs / num_labels_in_seq, masks_shifted, segments) diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 990d4c3a1..9a92661c9 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -6,7 +6,6 @@ from fast_llm.config import Configurable from fast_llm.core.ops import split_op from fast_llm.engine.base_model.config import LossDef -from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.loss.config import LanguageModelLossConfig, LanguageModelLossKwargs @@ -68,19 +67,8 @@ def _forward_backward( ) -> "tuple[torch.Tensor, torch.Tensor | None]": pass - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return ( - [ - LossDef( - name=self.name, - formatted_name=self.name, - count=count, - dtype=DataType.float32, - ) - ] - if self._do_register_loss - else [] - ) + def get_loss_definitions(self) -> list[LossDef]: + return [LossDef(name=self.name)] if self._do_register_loss else [] def get_preprocessing_config( self, @@ -88,7 +76,7 @@ def get_preprocessing_config( return {} def _register_loss( - self, name: str, value: torch.Tensor, losses: dict | None, reduce_op=torch.distributed.ReduceOp.AVG + self, name: str, value: torch.Tensor, losses: dict | None, reduce_op=torch.distributed.ReduceOp.SUM ): if losses is None: return @@ -127,18 +115,14 @@ def _prepare_target( def _get_grad_output(self, kwargs: dict[str, typing.Any]) -> float | None: grad_output = kwargs.get(LanguageModelKwargs.grad_output) - if grad_output is not None: - grad_output = ( - grad_output - * self._weight - / (self._parallel_dim.size if self._sequence_parallel else 1) - / self._num_splits - ) - return grad_output + return None if grad_output is None else grad_output * self._weight def _get_labels(self, kwargs: dict[str, typing.Any], split_index: int = 0): return self._prepare_target(kwargs[LanguageModelLossKwargs.labels], split_index) + def _get_label_count(self, kwargs: dict[str, typing.Any]): + return kwargs[LanguageModelKwargs.num_labels_in_batch][self._prediction_distance - 1] + def _get_loss_mask(self, kwargs: dict[str, typing.Any], split_index: int = 0): loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) return None if loss_mask is None else self._prepare_target(loss_mask, split_index) diff --git a/fast_llm/layers/language_model/loss/z_loss.py b/fast_llm/layers/language_model/loss/z_loss.py index 5565294d5..2e5f90b1d 100644 --- a/fast_llm/layers/language_model/loss/z_loss.py +++ b/fast_llm/layers/language_model/loss/z_loss.py @@ -5,6 +5,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.entropy_loss import fused_softmax_base from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward +from fast_llm.functional.utils import reduce_losses from fast_llm.layers.language_model.loss.config import LanguageModelZLossConfig from fast_llm.layers.language_model.loss.loss import LanguageModelLoss @@ -29,8 +30,12 @@ def _forward_backward( group=self._parallel_dim.group if self._vocab_parallel else None, logits_scale_factor=self._logits_scale_factor, grad_logits=grad_logits, + divisor=self._get_label_count(kwargs), ) + def get_preprocessing_config(self) -> dict[str, typing.Any]: + return {"return_prediction_mask": True} + @torch.compile def z_loss( @@ -54,19 +59,19 @@ def fused_z_loss_forward_backward( grad_output: float | None = None, group: torch.distributed.ProcessGroup | None = None, logits_scale_factor: float = 1.0, + divisor: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Z-loss = mean(logsumexp(logits, dim=-1) ** 2) Grad = 2 * log_sum_exp_logits * softmax(logits) """ - grad_output = None if grad_output is None else grad_output / logits.shape[:-1].numel() * logits_scale_factor + if divisor is None: + divisor = logits.shape[:-1].numel() + grad_output = None if grad_output is None else grad_output / divisor * logits_scale_factor logits_norm, exp_logits, sum_exp_logits, logits_max = fused_softmax_base(logits, logits_scale_factor, group) log_sum_exp_logits = sum_exp_logits.log() + logits_max - per_sample_loss = log_sum_exp_logits**2 - if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask - loss = per_sample_loss.mean() + loss = reduce_losses(log_sum_exp_logits**2, divisor, loss_mask) if grad_output is not None: grad_base = 2 * grad_output * (log_sum_exp_logits / sum_exp_logits) diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py index c7be11b70..9766182b8 100644 --- a/fast_llm/layers/language_model/multi_token_prediction.py +++ b/fast_llm/layers/language_model/multi_token_prediction.py @@ -99,10 +99,10 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: if self._enabled: self._layers_with_namespace[0].preprocess(kwargs) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: return ( - self.blocks[0].get_loss_definitions(count=count * (self._config.prediction_heads - 1)) - + [loss_definition for head in self.heads for loss_definition in head.get_loss_definitions(count=count)] + self.blocks[0].get_loss_definitions() + + [loss_definition for head in self.heads for loss_definition in head.get_loss_definitions()] if self._enabled else [] ) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index f694d80a6..cf5bc0bc4 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -227,7 +227,7 @@ def __init__( self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft ) - if _fast_gdn_available: + if _fast_gdn_available and distributed_config.use_cuda: self.chunk_gated_delta_rule = chunk_gated_delta_rule else: logger.warning( diff --git a/fast_llm/layers/vision/vision_encoder.py b/fast_llm/layers/vision/vision_encoder.py index 3116702e6..0b94beec9 100644 --- a/fast_llm/layers/vision/vision_encoder.py +++ b/fast_llm/layers/vision/vision_encoder.py @@ -69,12 +69,12 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self.encoder.preprocess(kwargs) self.adapter.preprocess(kwargs) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + def get_loss_definitions(self) -> list[LossDef]: # Needed because the base class uses `get_layers` which may bypass the decoder. TODO: Avoidable? return ( - self.embeddings.get_loss_definitions(count) - + self.encoder.get_loss_definitions(count) - + self.adapter.get_loss_definitions(count) + self.embeddings.get_loss_definitions() + + self.encoder.get_loss_definitions() + + self.adapter.get_loss_definitions() ) @@ -123,8 +123,8 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self._vision_encoder_with_namespace.preprocess(kwargs) super().preprocess(kwargs) - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return self.vision_encoder.get_loss_definitions(count) + super().get_loss_definitions(count) + def get_loss_definitions(self) -> list[LossDef]: + return self.vision_encoder.get_loss_definitions() + super().get_loss_definitions() @functools.cached_property def _vision_encoder_namespace(self) -> str: diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 3f45c8184..2619883d6 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -112,7 +112,7 @@ def format_metrics( **{key: metrics.pop(key, _NAN) for key in _METRIC_FORMATS_KEYS[phase]}, ) ] - outputs.extend([f"{loss_def.formatted_name}: {metrics.pop(loss_def.name, _NAN):.5f}" for loss_def in loss_defs]) + outputs.extend([f"{loss_def.name}: {metrics.pop(loss_def.name, _NAN):.5f}" for loss_def in loss_defs]) if metrics: outputs.extend([f"{key}: {value}" for key, value in metrics.items()]) diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index 5ce91fbac..cb9c5c1f2 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -13,6 +13,7 @@ LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, + get_parameter_converter, ) from fast_llm.utils import Assert, safe_merge_dicts @@ -38,7 +39,21 @@ def get_converters( config: LanguageModelConfig, exported_config: dict, ) -> list[WeightConverter]: - converters = super().get_converters(config, exported_config) + # Override: map head.final_norm to model.mtp_norms.0 (not model.norm as in standard Llama), + # since MTPLlamaModel uses mtp_norms[0] for the first prediction head. + converters = [ + *cls.normalization_converter_class.get_converters( + config.head.normalization, + "head.final_norm", + "model.mtp_norms.0", + ), + get_parameter_converter( + "head.output_weights", + "lm_head.weight", + drop_on_import=exported_config["tie_word_embeddings"], + drop_on_export=exported_config["tie_word_embeddings"], + ), + ] for prediction_distance in range(2, config.head.prediction_heads + 1): converters += cls.block_converter_class.get_converters( config.decoder.last_block_config, diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index def664d66..55c30c7ee 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -11,7 +11,6 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel -from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner @@ -81,7 +80,7 @@ def _get_batch( def _inner_forward( self, - batch: LanguageModelInput, + batch: LanguageModelBatch, input_shape: tuple[int], past_key_values=None, inputs_embeds: torch.FloatTensor | None = None, @@ -114,19 +113,12 @@ def _inner_forward( use_cache, output_hidden_states, ) - ((input_, kwargs),) = self.fast_llm_base_model.preprocess_batch( - [model_input], - phase=PhaseType.inference, - iteration=iteration, - device=self._fast_llm_model.distributed.device, - ) - - self._inference_runner.forward(input_, kwargs, iteration=iteration) + self._inference_runner.forward(model_input, iteration=iteration) # TODO: Make a proper way of returning the model output. hidden_states = { name: meta.local_to_global(tensor)[0].unflatten(0, input_shape) - for name, (meta, tensor) in kwargs[AttentionKwargs.hidden_states].items() + for name, (meta, tensor) in model_input.hidden_states.items() } # TODO: Handle MTP. @@ -135,7 +127,7 @@ def _inner_forward( output = transformers.modeling_outputs.CausalLMOutputWithPast( logits=logits, hidden_states=hidden_states or None, - past_key_values=kwargs[AttentionKwargs.presents], + past_key_values=model_input.presents, ) return ( output diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index fc4537ee7..83abaca21 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -1,3 +1,4 @@ +import dataclasses import functools import logging import re @@ -42,54 +43,46 @@ def __init__( def preprocess_batch( self, - model_inputs: list[LanguageModelInput], + model_input: LanguageModelInput, *, phase: PhaseType, iteration: int, metrics: dict | None = None, extra_kwargs: dict[str, typing.Any] | None = None, - device: torch.device | None, - ) -> list[tuple[torch.Tensor, dict]]: - reference_preprocessed_batches = {} - for name, reference_model in self._reference_models.items(): - reference_preprocessed_batches[name] = reference_model.fast_llm_model.base_model.preprocess_batch( - model_inputs, - phase=PhaseType.inference, - iteration=iteration, - device=device, - ) - - preprocessed = [] - for input_index, model_input in enumerate(model_inputs): - if device is not None: - model_input.to_device_(device) - kwargs = model_input.to_kwargs() - kwargs[LanguageModelKwargs.iteration] = iteration - if extra_kwargs is not None: - Assert.empty(kwargs.keys() & extra_kwargs.keys()) - kwargs.update(extra_kwargs) - if phase == PhaseType.inference: - kwargs[BlockKwargs.output_hidden_states].add(re.compile(r"head\..*logits.*$")) - - if not model_input.is_meta: - for name, reference_model in self._reference_models.items(): - reference_tokens, reference_kwargs = reference_preprocessed_batches[name][input_index] - if name in self._decoder_reference_models: - # TODO: Get the actual names - reference_kwargs[BlockKwargs.output_hidden_states].add( - re.compile(r"decoder\.\d+\.mixer_output$") - ) - - reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) - - kwargs[f"reference_{name}_hidden_states"] = { - layer_name: tensor - for layer_name, (meta, tensor) in reference_kwargs[BlockKwargs.hidden_states].items() - } - self.preprocess(kwargs) - preprocessed.append((model_input.tokens, kwargs)) - - return preprocessed + ) -> tuple[torch.Tensor, dict]: + if not model_input.is_meta: + model_input.to_device_(self._distributed.device) + kwargs = model_input.to_kwargs() + kwargs[LanguageModelKwargs.iteration] = iteration + if extra_kwargs is not None: + Assert.empty(kwargs.keys() & extra_kwargs.keys()) + kwargs.update(extra_kwargs) + if phase == PhaseType.inference: + kwargs[BlockKwargs.output_hidden_states].add(re.compile(r"head\..*logits.*$")) + + if not model_input.is_meta: + for name, reference_model in self._reference_models.items(): + output_hidden_states = set() + if name in self._head_reference_models: + output_hidden_states.add(re.compile(r"head\..*logits.*$")) + if name in self._decoder_reference_models: + # TODO: Get the actual names + output_hidden_states.add(re.compile(r"decoder\.\d+\.mixer_output$")) + assert len(output_hidden_states) >= 1 + reference_model_input = dataclasses.replace( + model_input, + output_hidden_states=output_hidden_states, + hidden_states={}, + ) + reference_model_input.set_children_attributes() + reference_model.forward(reference_model_input, iteration=iteration) + + kwargs[f"reference_{name}_hidden_states"] = { + layer_name: tensor for layer_name, (meta, tensor) in reference_model_input.hidden_states.items() + } + self.preprocess(kwargs) + + return model_input.tokens, kwargs def get_tied_parameters(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: # TODO: Integrate to the `LayerBase` interface, move to `LanguageModel`, `MultiTokenPrediction`? diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index ea0611953..9e82dfc4f 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -2839,7 +2839,7 @@ def forward( # Reshape back to [batch, num_patches, text_hidden] image_features = image_features.squeeze(0).view(batch_size, num_patches_per_image, -1) - return image_features, (*all_hidden_states, hidden_states, image_features) + return image_features, (*all_hidden_states, hidden_states, image_features) if output_hidden_states else None class SimpleMLP(nn.Module): diff --git a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py index c5268f23c..8734aa02c 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py +++ b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py @@ -481,7 +481,7 @@ def test_batch_processing_behavior(self, model_pair): with torch.no_grad(): # Batch processing batch_src = get_pixtral_vision_features(source, pixel_values) - batch_tgt, _ = target.get_image_features(pixel_values).view(-1, batch_src.shape[-1]) + batch_tgt = target.get_image_features(pixel_values)[0].view(-1, batch_src.shape[-1]) # Sequential processing singles_src = [get_pixtral_vision_features(source, pixel_values[i : i + 1]) for i in range(3)] diff --git a/tests/data/test_preprocessing.py b/tests/data/test_preprocessing.py index d0e56e3f0..ae58121ae 100644 --- a/tests/data/test_preprocessing.py +++ b/tests/data/test_preprocessing.py @@ -1,55 +1,418 @@ +import dataclasses +import functools + import pytest import torch from fast_llm.data.document.config import LanguageModelBatchPreprocessingConfig from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument from fast_llm.data.document.range import RangeDocument -from fast_llm.utils import Assert +from fast_llm.data.document.token_data import TokenDataDocument +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.utils import Assert, div -# TODO: Test padding, more scenarios -# TODO: Check rest of preprocessing output -@pytest.mark.parametrize( - ("tokens", "loss_masking_spans"), - ( - ([[100, 101, 102, 103, 104, 105, 106, 107]], [None]), # Simple case - ([[100, 101, -100, -100, 104, 105, 106, 107]], [None]), # Negative tokens - ([[100, 101, 102, 103, 104, 105, 106, 107]], [[(3, 5)]]), # Loss masking span - ([[100, 101, 102, 103, -100, -100, 106, 107]], [[(2, 3)]]), # Both - ( +def _get_cropped_lengths(batch_lengths: list[int], begin: int, end: int) -> tuple[list[int], int]: + """Return (cropped_lengths, first_document_begin) for the token window [begin, end).""" + doc_begin = 0 + cropped = [] + first_doc_begin = 0 + for length in batch_lengths: + doc_end = doc_begin + length + crop = min(doc_end, end) - max(doc_begin, begin) + if crop > 0: + if not cropped: + first_doc_begin = doc_begin + cropped.append(crop) + if doc_end > end: + break + doc_begin = doc_end + return cropped, first_doc_begin + + +def _compute_label_counts(batch_lengths: list[int], labels: list[int]) -> torch.Tensor: + """For each token, compute the count of valid (non-negative) labels in its document.""" + result = [] + offset = 0 + for length in batch_lengths: + count = sum(1 for label in labels[offset : offset + length] if label >= 0) + result.extend([count] * length) + offset += length + return torch.tensor(result, dtype=torch.int64) + + +def _assert_tensor_equal_or_none(actual: torch.Tensor | None, expected: torch.Tensor | None) -> None: + if expected is None: + assert actual is None + else: + Assert.all_equal(actual, expected) + + +@dataclasses.dataclass +class PreprocessingTestConfig: + name: str + tokens: list[list[int]] + loss_masking_spans: list[list[tuple[int, int]] | None] | None = None + padding: int | None = None + advantages: list[list[float]] | None = None + log_probabilities: list[list[float]] | None = None + phase: PhaseType = PhaseType.training + predicted_tokens: int = 1 + micro_batch_splits: int = 1 + use_grpo_data: bool = False + return_prediction_mask: bool = False + return_label_counts: bool = False + return_position_index: bool = False + return_document_count: bool = False + return_cumulative_sequence_lengths: bool = False + + @functools.cached_property + def config_kwargs(self) -> dict: + return { + "phase": self.phase, + "predicted_tokens": self.predicted_tokens, + "micro_batch_splits": self.micro_batch_splits, + "use_grpo_data": self.use_grpo_data, + "return_prediction_mask": self.return_prediction_mask, + "return_label_counts": self.return_label_counts, + "return_position_index": self.return_position_index, + "return_document_count": self.return_document_count, + "return_cumulative_sequence_lengths": self.return_cumulative_sequence_lengths, + } + + @functools.cached_property + def padding_size(self) -> int: + return 0 if self.padding is None else self.padding + + @functools.cached_property + def unpadded_size(self) -> int: + return sum(self.unpadded_lengths) + + @functools.cached_property + def padded_size(self) -> int: + return self.unpadded_size + self.padding_size + + @functools.cached_property + def unpadded_lengths(self) -> list[int]: + return [len(tokens) for tokens in self.tokens] + + @functools.cached_property + def padded_lengths(self) -> list[int]: + return self.unpadded_lengths + ([self.padding_size] if self.padding_size > 0 else []) + + @functools.cached_property + def num_labels(self) -> int: + return self.padded_size - self.predicted_tokens + + @functools.cached_property + def split_size(self) -> int: + return div(self.num_labels, self.micro_batch_splits) + + @functools.cached_property + def all_flat_tokens(self) -> list[int]: + return sum(self.tokens, []) + [-100] * self.padding_size + + @functools.cached_property + def base_labels(self) -> list[int]: + """Tokens with loss masking spans applied, but no cross-document masking.""" + labels = list(self.all_flat_tokens) + if self.loss_masking_spans is not None: + offset = 0 + for doc_tokens, spans in zip(self.tokens, self.loss_masking_spans): + if spans is not None: + for begin, end in spans: + labels[offset + begin : offset + end] = [-100] * (end - begin) + offset += len(doc_tokens) + return labels + + @functools.cached_property + def labels_per_distance(self) -> list[list[int]]: + """For each prediction distance d, labels with cumulative cross-document masking.""" + result = [] + labels = list(self.base_labels) + for d in range(1, self.predicted_tokens + 1): + offset = 0 + for doc_tokens in self.tokens: + if d <= len(doc_tokens): + labels[offset + d - 1] = -100 + offset += len(doc_tokens) + result.append(list(labels)) + return result + + @functools.cached_property + def _split_ranges(self) -> list[tuple[int, int]]: + return [(i * self.split_size, (i + 1) * self.split_size) for i in range(self.micro_batch_splits)] + + @functools.cached_property + def _cropped_lengths_per_split(self) -> list[tuple[list[int], int]]: + return [_get_cropped_lengths(self.padded_lengths, begin, end) for begin, end in self._split_ranges] + + @functools.cached_property + def expected_input_tokens(self) -> list[torch.Tensor]: + all_tokens = torch.tensor(self.all_flat_tokens, dtype=torch.int64) + return [all_tokens[begin:end] for begin, end in self._split_ranges] + + @functools.cached_property + def expected_target_tokens(self) -> list[list[torch.Tensor]]: + labels_tensors = [torch.tensor(labels, dtype=torch.int64) for labels in self.labels_per_distance] + return [ + [ + labels_tensors[target_index][begin + d : end + d] + for target_index, d in enumerate(range(1, self.predicted_tokens + 1)) + ] + for begin, end in self._split_ranges + ] + + @functools.cached_property + def expected_target_mask(self) -> list[list[torch.Tensor | None]]: + if not self.return_prediction_mask: + return [[None] * self.predicted_tokens for _ in range(self.micro_batch_splits)] + return [[tokens >= 0 for tokens in split_targets] for split_targets in self.expected_target_tokens] + + @functools.cached_property + def expected_target_label_counts(self) -> list[list[torch.Tensor | None]]: + if not self.return_label_counts: + return [[None] * self.predicted_tokens for _ in range(self.micro_batch_splits)] + return [ [ - [100, 101, -100, 103, -100, -100, 106, 107], - [100, 101, 102, 103, 104, 105, 106, 107], - ], - [[(2, 3)], None], - ), # Two samples + _compute_label_counts(self.padded_lengths, self.labels_per_distance[target_index])[begin + d : end + d] + for target_index, d in enumerate(range(1, self.predicted_tokens + 1)) + ] + for begin, end in self._split_ranges + ] + + @functools.cached_property + def expected_advantages(self) -> list[list[torch.Tensor | None]]: + if self.advantages is None: + return [[None] * self.predicted_tokens for _ in range(self.micro_batch_splits)] + flat = torch.tensor(sum(self.advantages, []) + [0.0] * self.padding_size, dtype=torch.float32) + return [ + [flat[begin + d : end + d] for d in range(1, self.predicted_tokens + 1)] + for begin, end in self._split_ranges + ] + + @functools.cached_property + def expected_log_probabilities(self) -> list[list[torch.Tensor | None]]: + if self.log_probabilities is None: + return [[None] * self.predicted_tokens for _ in range(self.micro_batch_splits)] + flat = torch.tensor(sum(self.log_probabilities, []) + [0.0] * self.padding_size, dtype=torch.float32) + return [ + [flat[begin + d : end + d] for d in range(1, self.predicted_tokens + 1)] + for begin, end in self._split_ranges + ] + + @functools.cached_property + def expected_position_index(self) -> list[torch.Tensor | None]: + if not self.return_position_index: + return [None] * self.micro_batch_splits + result = [] + for split_index, (begin, _end) in enumerate(self._split_ranges): + cropped_lengths, first_doc_begin = self._cropped_lengths_per_split[split_index] + pos_in_doc = begin - first_doc_begin + positions = [] + remaining = cropped_lengths[0] if cropped_lengths else 0 + doc_index = 0 + for _ in range(self.split_size): + positions.append(pos_in_doc) + pos_in_doc += 1 + remaining -= 1 + if remaining == 0 and doc_index + 1 < len(cropped_lengths): + doc_index += 1 + remaining = cropped_lengths[doc_index] + pos_in_doc = 0 + result.append(torch.tensor(positions, dtype=torch.int32)) + return result + + @functools.cached_property + def expected_cumulative_lengths(self) -> list[tuple[torch.Tensor | None, torch.Tensor | None]]: + if not self.return_cumulative_sequence_lengths: + return [(None, None)] * self.micro_batch_splits + result = [] + for split_index, (begin, _end) in enumerate(self._split_ranges): + cropped_lengths, first_doc_begin = self._cropped_lengths_per_split[split_index] + cu_q = torch.tensor([0] + cropped_lengths, dtype=torch.int32).cumsum(dim=0) + cu_k = (cu_q + begin).clone() + cu_k[0] = first_doc_begin + result.append((cu_q, cu_k)) + return result + + @functools.cached_property + def expected_num_documents(self) -> list[int | None]: + if self.return_document_count: + return [len(self.tokens) if split_index == 0 else 0 for split_index in range(self.micro_batch_splits)] + else: + return [None] * self.micro_batch_splits + + +_BASE_TEST_CASES = [ + PreprocessingTestConfig( + name="simple", + tokens=[[100, 101, 102, 103, 104, 105, 106, 107, 108]], + ), + PreprocessingTestConfig( + name="negative_tokens", + tokens=[[100, 101, -100, -100, 104, 105, 106, 107, 108]], + ), + PreprocessingTestConfig( + name="loss_masking_span", + tokens=[[100, 101, 102, 103, 104, 105, 106, 107, 108]], + loss_masking_spans=[[(3, 5)]], + ), + PreprocessingTestConfig( + name="negative_tokens_and_loss_masking", + tokens=[[100, 101, 102, 103, -100, -100, 106, 107, 108]], + loss_masking_spans=[[(2, 3)]], + ), + PreprocessingTestConfig( + name="two_documents", + tokens=[[100, 101, -100, 103, -100, -100, 106, 107], [100, 101, 102, 103, 104, 105, 106, 107, 108]], + loss_masking_spans=[[(2, 3)], None], + ), + PreprocessingTestConfig( + name="three_documents", + tokens=[[100, 101, 102], [103, 104, 105], [106, 107, 108]], + loss_masking_spans=[[(1, 2)], None, [(0, 2)]], + ), + PreprocessingTestConfig( + # Document of length 1 is shorter than predicted_tokens=3; cross-document masking must not go out of bounds. + name="short_document", + tokens=[[100], [101, 102, 103, 104, 105, 106, 107, 108]], + ), + PreprocessingTestConfig( + name="multiple_loss_masking_spans", + tokens=[[100, 101, 102, 103, 104, 105, 106, 107, 108]], + loss_masking_spans=[[(1, 3), (5, 7)]], + ), + PreprocessingTestConfig( + # use_grpo_data attaches per-token advantages and log-probabilities to the target. + name="grpo_data", + tokens=[[100, 101, 102, 103, 104, 105, 106]], + advantages=[[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]], + log_probabilities=[[-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7]], + use_grpo_data=True, + ), + PreprocessingTestConfig( + name="two_documents_grpo_data", + tokens=[[100, 101, 102, 103], [104, 105, 106, 107, 108]], + advantages=[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8, 0.9]], + log_probabilities=[[-0.1, -0.2, -0.3, -0.4], [-0.5, -0.6, -0.7, -0.8, -0.9]], + use_grpo_data=True, + ), + PreprocessingTestConfig( + # In inference phase num_labels=0, so the full token sequence is the input and there are no targets. + name="inference", + tokens=[[100, 101, 102, 103, 104]], + phase=PhaseType.inference, ), +] + +# Each base case is run with each return configuration and both values of micro_batch_splits, +# except inference which has no labels to return or split. +_RETURN_CONFIG_VARIANTS: dict[str, dict] = { + "": {}, + "prediction_mask": {"return_prediction_mask": True}, + "label_counts": {"return_label_counts": True}, + "position_index": {"return_position_index": True}, + "document_count": {"return_document_count": True}, + "cumulative_sequence_lengths": {"return_cumulative_sequence_lengths": True}, + "all": { + "return_prediction_mask": True, + "return_label_counts": True, + "return_position_index": True, + "return_document_count": True, + "return_cumulative_sequence_lengths": True, + }, +} + + +def _make_name( + base_name: str, return_name: str, predicted_tokens: int, micro_batch_splits: int, padding: int | None +) -> str: + parts = [base_name] + if return_name: + parts.append(f"return_{return_name}") + if predicted_tokens > 1: + parts.append(f"predicted_tokens_{predicted_tokens}") + if micro_batch_splits > 1: + parts.append(f"splits_{micro_batch_splits}") + if padding is not None: + parts.append(f"padding_{padding}") + return "_".join(parts) + + +_PREPROCESSING_TEST_CASES = [ + dataclasses.replace( + base_case, + name=_make_name(base_case.name, return_name, predicted_tokens, micro_batch_splits, padding), + predicted_tokens=predicted_tokens, + micro_batch_splits=micro_batch_splits, + padding=padding, + **return_config, + ) + for base_case in _BASE_TEST_CASES + for return_name, return_config in _RETURN_CONFIG_VARIANTS.items() + for predicted_tokens in (1, 3) + for micro_batch_splits in (1, 2) + for padding in (None, 0, 2) + if base_case.phase != PhaseType.inference + or (not return_config and predicted_tokens == 1 and micro_batch_splits == 1) +] + + +@pytest.mark.parametrize( + "test_config", [pytest.param(test_config, id=test_config.name) for test_config in _PREPROCESSING_TEST_CASES] ) -def test_preprocessing(tokens, loss_masking_spans): +def test_preprocessing(test_config: PreprocessingTestConfig): + config = LanguageModelBatchPreprocessingConfig(**test_config.config_kwargs) + documents = [ LanguageModelDocument( - tokens=torch.tensor(tokens_, dtype=torch.int64), - loss_masking_spans=None if loss_masking_spans_ is None else RangeDocument(ranges=loss_masking_spans_), + tokens=torch.tensor(tokens, dtype=torch.int64), + loss_masking_spans=None if spans is None else RangeDocument(ranges=spans), + advantages=None if doc_advantages is None else TokenDataDocument(data=torch.tensor(doc_advantages)), + old_log_probabilities=( + None if doc_log_probs is None else TokenDataDocument(data=torch.tensor(doc_log_probs)) + ), + ) + for tokens, spans, doc_advantages, doc_log_probs in zip( + test_config.tokens, + test_config.loss_masking_spans or [None] * len(test_config.tokens), + test_config.advantages or [None] * len(test_config.tokens), + test_config.log_probabilities or [None] * len(test_config.tokens), + strict=True, ) - for tokens_, loss_masking_spans_ in zip(tokens, loss_masking_spans, strict=True) ] - (model_input,) = LanguageModelBatch.from_documents(documents).get_model_inputs( - LanguageModelBatchPreprocessingConfig() + batch = LanguageModelBatch.from_documents( + documents, pad_to_size=test_config.padded_size if test_config.padding is not None else None ) + model_inputs = batch.get_model_inputs(config) + + # Inference: full token sequence as input, no targets. + if config.phase == PhaseType.inference: + Assert.eq(len(model_inputs), 1) + Assert.all_equal(model_inputs[0].tokens, batch.tokens) + Assert.eq(len(model_inputs[0].targets), 0) + return + + Assert.eq(len(model_inputs), test_config.micro_batch_splits) + for split_index, model_input in enumerate(model_inputs): + Assert.all_equal(model_input.tokens, test_config.expected_input_tokens[split_index]) + Assert.eq(len(model_input.targets), test_config.predicted_tokens) + + for target_index, target in enumerate(model_input.targets): + Assert.all_equal(target.tokens, test_config.expected_target_tokens[split_index][target_index]) + _assert_tensor_equal_or_none(target.mask, test_config.expected_target_mask[split_index][target_index]) + _assert_tensor_equal_or_none( + target.label_counts, test_config.expected_target_label_counts[split_index][target_index] + ) + _assert_tensor_equal_or_none(target.advantages, test_config.expected_advantages[split_index][target_index]) + _assert_tensor_equal_or_none( + target.old_log_probabilities, test_config.expected_log_probabilities[split_index][target_index] + ) - Assert.all_equal(model_input.tokens, torch.cat([document.tokens for document in documents])[:-1]) - - label_tokens = [] - for document in documents: - label_tokens_ = document.tokens.clone() - # Mask cross-document attention - label_tokens_[0] = -100 - # Loss masking spans - if document.loss_masking_spans is not None: - for begin, end in document.loss_masking_spans.ranges: - label_tokens_[begin:end] = -100 - label_tokens.append(label_tokens_) - - Assert.eq(len(model_input.targets), 1) - Assert.all_equal(model_input.targets[0].tokens, torch.cat(label_tokens)[1:]) + _assert_tensor_equal_or_none(model_input.position_index, test_config.expected_position_index[split_index]) + cu_q, cu_k = test_config.expected_cumulative_lengths[split_index] + _assert_tensor_equal_or_none(model_input.cumulative_lengths_q, cu_q) + _assert_tensor_equal_or_none(model_input.cumulative_lengths_k, cu_k) + Assert.eq(model_input.num_documents, test_config.expected_num_documents[split_index]) diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index c7088eae3..83f7657a0 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -144,7 +144,7 @@ def test_streaming_sampled_dataset( assert batch.old_log_probabilities is None -_NUM_BATCHES = 2 +_NUM_BATCHES = 10 _SEQUENCE_LENGTH = 10 @@ -160,7 +160,9 @@ def _get_distributed_config(distributed_config_dict: dict[str, typing.Any], worl ) -def _run_test_data_streaming(path: pathlib.Path, distributed_config: DistributedConfig, port: int): +def _run_test_data_streaming( + path: pathlib.Path, distributed_config: DistributedConfig, port: int, num_workers: int = 1 +): redis_config = RedisConfig(port=port + 100, timeout=1) data = GPTData( @@ -186,7 +188,7 @@ def _run_test_data_streaming(path: pathlib.Path, distributed_config: Distributed distributed_config.batch_data_parallel * _NUM_BATCHES, ) data_iter = data.get_iterator( - "train", consumed_samples=0, num_workers=0, prefetch_factor=None, timeout=5, preprocess=False + "train", consumed_samples=0, num_workers=num_workers, prefetch_factor=None, timeout=5, preprocess=False ) batches = [next(data_iter) for _ in range(_NUM_BATCHES)] path.mkdir(parents=True, exist_ok=True) @@ -228,10 +230,11 @@ def _run_test_data_streaming_distributed( _run_test_data_streaming(base_path / name, distributed_config, port) -def test_data_streaming(result_path, worker_resources): +@pytest.mark.parametrize("num_workers", (0, 1)) +def test_data_streaming(result_path, worker_resources, num_workers): distributed_config = _get_distributed_config({}) path = result_path / "data_streaming/single_gpu" - _run_test_data_streaming(path, distributed_config, worker_resources.torchrun_port) + _run_test_data_streaming(path, distributed_config, worker_resources.torchrun_port, num_workers) check_data_streaming_results(path, distributed_config) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index c2bde6a8b..73e9f4807 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -102,6 +102,11 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: torch.randint(0, 2, (NUM_TOKENS,), dtype=torch.bool, device=device) for _ in range(self.prediction_heads) ] + kwargs[LanguageModelKwargs.num_labels_in_batch] = [ + loss_mask.sum().item() for loss_mask in kwargs[LanguageModelKwargs.loss_mask] + ] + else: + kwargs[LanguageModelKwargs.num_labels_in_batch] = [NUM_TOKENS for _ in range(self.prediction_heads)] if self.actual_label_loss is not False or self.grpo_loss is not False: labels = [ torch.randint( @@ -166,32 +171,34 @@ def get_reference_outputs( names_losses_weights = [] + loss_mask = ( + kwargs[LanguageModelKwargs.loss_mask][head._prediction_distance - 1] + if LanguageModelKwargs.loss_mask in kwargs + else None + ) + if self.actual_label_loss is not False or self.grpo_loss is not False: labels = kwargs[LanguageModelKwargs.labels][head._prediction_distance - 1] if self.actual_label_loss is not False: - label_loss = torch.nn.functional.cross_entropy(logits, labels, reduction="none").mean() + label_loss = torch.nn.functional.cross_entropy(logits, labels) names_losses_weights.append(("label", label_loss, float(self.actual_label_loss))) - # total_loss = total_loss + float(self.actual_label_loss) * label_loss if self.distillation_loss is not False: distillation_loss = torch.nn.functional.cross_entropy( logits, torch.softmax(kwargs[f"reference_distillation_hidden_states"]["head.logits"], -1), - reduction="none", + reduction="mean" if loss_mask is None else "none", ) - if LanguageModelKwargs.loss_mask in kwargs: - distillation_loss = ( - distillation_loss * kwargs[LanguageModelKwargs.loss_mask][head._prediction_distance - 1] - ) - distillation_loss = distillation_loss.mean() + if loss_mask is not None: + distillation_loss = (distillation_loss * loss_mask).sum() / loss_mask.sum() names_losses_weights.append(("distillation", distillation_loss, float(self.distillation_loss))) if self.z_loss is not False: z_loss = torch.logsumexp(logits, dim=-1) ** 2 - if LanguageModelKwargs.loss_mask in kwargs: - z_loss = z_loss * kwargs[LanguageModelKwargs.loss_mask][head._prediction_distance - 1] - z_loss = z_loss.mean() + if loss_mask is not None: + z_loss = z_loss * loss_mask + z_loss = z_loss.mean() if loss_mask is None else (z_loss * loss_mask).sum() / loss_mask.sum() names_losses_weights.append(("z_loss", z_loss, float(self.z_loss))) if self.grpo_loss is not False: @@ -317,7 +324,6 @@ def test_lm_head(test_config: LMHeadTestConfig): losses = collections.defaultdict(list) output, context = stage.forward(head_input, kwargs, losses) - print(losses) stage.backward(output_grad, context) threshold = 1e-5 if distributed.config.compute_dtype == DataType.float32 else 5e-3 min_threshold = ( @@ -330,8 +336,8 @@ def test_lm_head(test_config: LMHeadTestConfig): Assert.eq(losses.keys(), ref_losses.keys(), loss_definitions.keys()) losses = { - name: loss[0] if len(loss) == 1 else torch.stack(loss).sum() / loss_definitions[name].count - for name, loss in losses.items() + name: loss_definition.reduce(losses[name], distributed) + for name, loss_definition in loss_definitions.items() } for name, loss in losses.items(): diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index a719e44e8..3a68a999f 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -147,7 +147,7 @@ def reference_grpo_loss( # new_logprobs: sum of per-sequence mean log-probs log_probs = torch.nn.functional.log_softmax(logits_, -1).gather(-1, labels.unsqueeze(-1)).squeeze(-1) new_logprobs = (log_probs * loss_mask).sum() / max(float(loss_mask.sum()), 1.0) - return (loss * loss_mask).mean(), new_logprobs + return (loss * loss_mask).sum() / loss_mask.sum(), new_logprobs _BATCH_SHAPES = ((64,), (16, 8)) @@ -272,6 +272,7 @@ def _test_grpo_loss( grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, + divisor=(target >= 0).sum().item(), ) _compare_losses_and_grads(out_fused, out_ref, grad_output is not None, grad_fused, grad_ref, group=group) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 74c51719d..5f0f5a80f 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -472,13 +472,12 @@ def test_save_and_load_in_parallel(run_parallel_script, run_test_script_base_pat # Save and load checkpoints to and from various distributed configurations. # Combined in a single test to mitigate process creation overhead. # TODO: Test beyond 2 gpu configs? - if torch.cuda.device_count() < 2: - pytest.skip(f"Not enough GPUs2") run_parallel_script( _save_and_load_in_parallel, (run_test_script_base_path, model_testing_config), world_size=2, backend=model_testing_config.distributed_backend, + use_cuda=torch.cuda.is_available(), ) @@ -503,6 +502,7 @@ def test_load_parallel_checkpoint_in_single_gpu( load_and_compare_checkpoints, reference_distributed_shard, report_subtest, + testing_device, ): if ( model_testing_config.checkpoint_format is None @@ -514,16 +514,16 @@ def test_load_parallel_checkpoint_in_single_gpu( distributed_save_load_config = distributed_save_load_config.resolve( base_path=run_test_script_base_path, model_testing_config=model_testing_config ) - if torch.cuda.device_count() < distributed_save_load_config.num_gpus: - pytest.skip( - f"Not enough GPUs to run dependency: {torch.cuda.device_count()} < {distributed_save_load_config.num_gpus}" - ) - report_subtest(distributed_save_load_config.save_path, distributed_save_load_config.num_gpus) + report_subtest( + distributed_save_load_config.save_path, + distributed_save_load_config.num_gpus, + use_cuda=torch.cuda.is_available(), + ) load_and_compare_checkpoints( DistributedCheckpointFormat, distributed_save_load_config.save_path / DistributedCheckpointFormat.name, None, - reference_distributed_shard.to(device="cuda"), + reference_distributed_shard.to(device=testing_device), ) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 0c58afade..f3a9a1d7c 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -68,13 +68,12 @@ def _run_model_distributed( ModelTestingGroup.distributed, ) def test_run_model_distributed(run_parallel_script, model_testing_config, run_test_script_base_path): - if torch.cuda.device_count() < 2: - pytest.skip(f"Not enough GPUs") run_parallel_script( _run_model_distributed, (run_test_script_base_path, model_testing_config), - world_size=torch.cuda.device_count(), + world_size=torch.cuda.device_count() if torch.cuda.is_available() else 8, backend=model_testing_config.distributed_backend, + use_cuda=torch.cuda.is_available(), ) @@ -94,9 +93,7 @@ def test_model_distributed( config = DISTRIBUTED_TESTING_CONFIGS[config_name] if model_testing_config.should_skip(config): pytest.skip(f"Configuration not supported.") - if torch.cuda.device_count() < config.num_gpus: - pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < {config.num_gpus}") - report_subtest(run_test_script_base_path / config.name, config.num_gpus) + report_subtest(run_test_script_base_path / config.name, config.num_gpus, use_cuda=torch.cuda.is_available()) if config.compare is not None: if not check_subtest_success(run_test_script_base_path / config.compare): pytest.fail(f"Test {config.compare} failed", pytrace=False) diff --git a/tests/models/test_streaming.py b/tests/models/test_streaming.py index 7b39a62f2..0c40f0a48 100644 --- a/tests/models/test_streaming.py +++ b/tests/models/test_streaming.py @@ -132,7 +132,7 @@ def _run_model_streaming_configs( model_testing_config, None, updates={ - ("data", "datasets"): {"training": {"port": port}}, + ("data", "datasets"): {"training": {"port": port, "timeout": 1.0}}, ("training", "export"): {"format": model_testing_config.checkpoint_format.name, "interval": 1}, "callbacks": { "streaming": { @@ -143,6 +143,7 @@ def _run_model_streaming_configs( "external_world_size": config.consumer_count, }, "export": {"format": model_testing_config.checkpoint_format.name}, + "timeout": 1.0, } }, # Disable tensor logging. diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index b085f0994..405aa1bcd 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -18,6 +18,7 @@ class DistributedTestingConfig: compare_config: CompareConfig | None = None # Scale the comparison thresholds for specific distributed configs. compare_factor: float = 1.0 + requires_cuda: bool = False def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareConfig: @@ -53,6 +54,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon _compare_layer_mismatch_duplicate_gradients = copy.deepcopy(_compare_layer_mismatch) _compare_layer_mismatch_duplicate_gradients.sub_configs[(None, "bias")].ignore_duplicates = True _compare_layer_mismatch_duplicate_gradients.sub_configs[(None, "gradient")].ignore_duplicates = True +_pp_tied_weight_compare.sub_configs[(None, "bias")].ignore_duplicates = True _pp_tied_weight_compare.sub_configs[(None, "gradient")].ignore_duplicates = True _pp_tied_weight_compare.sub_configs[("init", None)].ignore_duplicates = True for tensor in ("fw", "bw"): @@ -70,7 +72,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon if torch.cuda.is_available() else { (None, "norm"): get_config(ignore_tensors=True), - (None, "word_embeddings_weight"): get_config(8e-2, 1e-4), + (None, "embeddings_weight"): get_config(8e-2, 1e-4), } ), (None, "bias"): get_config(2e-2, 1e-3) if torch.cuda.is_available() else get_config(2e-2, 2e-3), @@ -113,6 +115,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon ) _SINGLE_GPU_TESTING_CONFIGS = [ + # TODO: 16-bit matmuls extremely slow on cpu DistributedTestingConfig( name="bf16", compare="simple", @@ -124,6 +127,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon ], num_gpus=1, compare_config=_bf16_compare, + requires_cuda=True, ), DistributedTestingConfig( name="fp16", @@ -131,6 +135,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=["model.distributed.compute_dtype=fp16", "data.micro_batch_size=4096"], num_gpus=1, compare_config=_fp16_compare, + requires_cuda=True, ), # Cross-entropy splits. DistributedTestingConfig( @@ -397,17 +402,17 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Simple DistributedTestingConfig( name="dp2_stp2_pp2s2_bf4", - compare="dp2_z2_df4", + compare="df8", config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", "model.distributed.pipeline_parallel=2", "model.multi_stage.layers_per_stage=2", "schedule.breadth_first_micro_batches=4", - "data.micro_batch_size=412", + "data.micro_batch_size=512", ], num_gpus=8, - compare_config=_compare_layer_match, + compare_config=_compare_layer_mismatch, ), # Tied weights on different ranks DistributedTestingConfig( @@ -427,7 +432,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Micro-sequence DistributedTestingConfig( name="sdp2_stp2_pp2s2_ms4", - compare="df2", + compare="simple", config_args=[ "model.distributed.sequence_data_parallel=2", "model.distributed.tensor_parallel=2", @@ -435,7 +440,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon "model.distributed.pipeline_parallel=2", "model.multi_stage.layers_per_stage=2", "schedule.micro_batch_splits=4", - "data.micro_batch_size=2048", + "data.micro_batch_size=4096", ], num_gpus=8, compare_config=_compare_layer_mismatch, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 42802f1c7..6268ac194 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -154,7 +154,9 @@ def distributed_backend(self): return DistributedBackend(self.config_dict["model"]["distributed"]["backend"]) def should_skip(self, distributed_config: DistributedTestingConfig) -> bool: - return any(re.search(pattern, distributed_config.name) for pattern in self.skip_tests) + return (distributed_config.requires_cuda and not torch.cuda.is_available()) or any( + re.search(pattern, distributed_config.name) for pattern in self.skip_tests + ) def update_and_add_testing_config( @@ -264,7 +266,7 @@ def update_and_add_testing_config( "distributed": { "reproducible_init": True, "timeout": 20, - "backend": "nccl", + "backend": DistributedBackend.nccl if torch.cuda.device_count() >= 2 else DistributedBackend.gloo, "use_cuda": torch.cuda.is_available(), }, }, @@ -802,7 +804,7 @@ def update_and_add_testing_config( # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). # we should be using STP with this model, not TP! skip_tests=("sdp", "ms", TP_NO_STP), - requires_cuda=False, + requires_cuda=True, # GDN available on CPU, but not in the converted model (also runs very slow). ) _gdn_block = MODEL_CONFIGS["apriel2_gdn"].config_dict["model"]["base_model"]["decoder"]["block"] diff --git a/tests/utils/redis.py b/tests/utils/redis.py index 8160ef8c0..6004425dc 100644 --- a/tests/utils/redis.py +++ b/tests/utils/redis.py @@ -66,8 +66,6 @@ def producer_loop(): @contextlib.contextmanager def fake_redis_server(config: RedisConfig): - # We search for free port as port from previous test can still be not free even after server shutdown - # ----- Monkey-patch handler to suppress broken pipes ----- orig_handle = fakeredis._tcp_server.TCPFakeRequestHandler.handle @@ -83,6 +81,34 @@ def safe_handle(self): fakeredis._tcp_server.TCPFakeRequestHandler.handle = safe_handle + # ----- Monkey-patch setup to use Resp2Writer instead of Resp3Writer ----- + # fakeredis 2.34+ hardcodes Resp3Writer for all connections, causing blocked + # XREADGROUP timeouts to return RESP3 null (b'_\r\n') even on RESP2 connections + # (i.e. clients that never sent HELLO 3). The redis-py RESP2 parser raises + # Protocol Error: b'_' on this byte. Fix: replace with Resp2Writer at setup time. + # The Resp2Writer class was introduced alongside the bug in 2.34, so use its + # presence as the version guard. + orig_setup = fakeredis._tcp_server.TCPFakeRequestHandler.setup + if hasattr(fakeredis._tcp_server, "Resp3Writer"): + # fakeredis 2.34+ hardcodes Resp3Writer for all connections, causing blocked + # XREADGROUP timeouts to return RESP3 null (b'_\r\n') even on RESP2 connections + # (i.e. clients that never sent HELLO 3). The redis-py RESP2 parser raises + # Protocol Error: b'_' on this byte. Fix: replace with Resp2Writer at setup time. + if not hasattr(fakeredis._tcp_server, "Resp2Writer"): + raise RuntimeError( + f"fakeredis {fakeredis.__version__} has Resp3Writer but not Resp2Writer — " + "the workaround for the RESP2/RESP3 null encoding bug no longer applies. " + "See tests/utils/redis.py for details." + ) + + def resp2_setup(self): + orig_setup(self) + if not isinstance(self.writer, fakeredis._tcp_server.Resp2Writer): + self.writer = fakeredis._tcp_server.Resp2Writer(self.client_address, self.wfile, self) + self.current_client.writer = self.writer + + fakeredis._tcp_server.TCPFakeRequestHandler.setup = resp2_setup + server = fakeredis.TcpFakeServer((config.host, config.port), server_type="redis") print(f"Starting fake redis server at {config.host}:{config.port}") thread = threading.Thread(target=server.serve_forever, daemon=True) @@ -96,3 +122,5 @@ def safe_handle(self): server.shutdown() server.server_close() thread.join() + fakeredis._tcp_server.TCPFakeRequestHandler.setup = orig_setup + fakeredis._tcp_server.TCPFakeRequestHandler.handle = orig_handle diff --git a/tests/utils/save_load_configs.py b/tests/utils/save_load_configs.py index 3e7cbf10f..6bc619825 100644 --- a/tests/utils/save_load_configs.py +++ b/tests/utils/save_load_configs.py @@ -4,6 +4,7 @@ import typing import pytest +import torch from fast_llm.engine.checkpoint.config import CheckpointFormat, DistributedCheckpointFormat, FastLLMCheckpointFormat from tests.utils.model_configs import ModelTestingConfig @@ -17,6 +18,10 @@ class DistributedSaveLoadConfig: distributed: dict[str, typing.Any] num_gpus: int = 2 + def __post_init__(self): + self.distributed["use_cuda"] = torch.cuda.is_available() + self.distributed["backend"] = "nccl" if torch.cuda.device_count() >= self.num_gpus else "gloo" + def resolve(self, base_path: pathlib.Path, model_testing_config: ModelTestingConfig) -> typing.Self: if model_testing_config.checkpoint_format is None: format = {