diff --git a/nemo_automodel/components/checkpoint/checkpointing.py b/nemo_automodel/components/checkpoint/checkpointing.py index c4796d5188..ee6402dc04 100644 --- a/nemo_automodel/components/checkpoint/checkpointing.py +++ b/nemo_automodel/components/checkpoint/checkpointing.py @@ -296,7 +296,7 @@ def save_model( model_state = ModelState(model, self.config.is_peft) state_dict = model_state.state_dict() - # Convert to HF format if using custom model implementations + # Convert to HF format if using custom model implementations. state_dict = _maybe_adapt_state_dict_to_hf( model_state.model[0], state_dict, @@ -304,6 +304,8 @@ def save_model( device_mesh=self.moe_mesh, v4_compatible=self.config.v4_compatible, ) + # MoE adapters return non-contiguous views; safetensors.save rejects those. + _materialize_to_hf_views_for_save(state_dict) # Build the consolidated model.safetensors.index.json if needed fqn_to_file_index_mapping = self._maybe_build_consolidated_index(model_state, state_dict) @@ -507,6 +509,8 @@ def load_model( reader_key_mapping = None if has_state_dict_adapter else key_mapping storage_reader = self._get_storage_reader(model_path, reader_key_mapping, is_init_step=is_init_step) + # MoE adapters return views into model storage; DCP writes safetensors + # data straight through them and from_hf skips the rebuild. state_dict = _maybe_adapt_state_dict_to_hf( model_state.model[0], state_dict, @@ -1578,6 +1582,27 @@ def _maybe_adapt_state_dict_to_hf( return state_dict +def _materialize_to_hf_views_for_save(state_dict: dict[str, torch.Tensor]) -> None: + """Replace non-contiguous tensor values in ``state_dict`` with contiguous copies in place. + + MoE adapters return non-contiguous strided views into the model's grouped + expert storage for the optimized load path; ``safetensors.torch.save`` + (which the DCP HF storage writer calls) rejects non-contiguous tensors, + so we materialize one tensor at a time here with ``empty_cache`` between + iterations. Per-tensor transient is bounded to a single expert weight + instead of allocating the full grouped set up front. + """ + if not state_dict: + return + cuda_available = torch.cuda.is_available() + for key, value in list(state_dict.items()): + if isinstance(value, torch.Tensor) and not value.is_contiguous(): + state_dict[key] = value.contiguous() + del value + if cuda_available: + torch.cuda.empty_cache() + + def _equally_divide_layers(num_shards: int, keys: list[str]) -> dict[str, int]: """ Equally divide the state dict keys into num_shards shards. diff --git a/nemo_automodel/components/distributed/thd_utils.py b/nemo_automodel/components/distributed/thd_utils.py index 3747abd65d..1365160702 100644 --- a/nemo_automodel/components/distributed/thd_utils.py +++ b/nemo_automodel/components/distributed/thd_utils.py @@ -54,11 +54,21 @@ def process_input_for_thd( [total_tokens, hidden_dim] for 3D embeddings - 'labels': Reshaped labels tensor of shape [total_tokens] - 'position_ids': Reshaped tensor of shape [total_tokens] - - 'cu_seqlens': Cumulative padded sequence lengths tensor of shape [num_sequences + 1] (int32) + - 'cu_seqlens': Cumulative REAL sequence lengths tensor of shape [num_sequences + 1] (int32) where num_sequences is the total count of non-padded sequences across the batch. - NOTE: This contains cumulative lengths from seq_lens_padded (not seq_lens) since - CP doesn't support padding between sequences (resulting in NaNs). The labels or loss mask - will ensure that loss is computed correctly. + Built from seq_lens (the unpadded real lengths). When the trailing pack-pad is + purely at the end (cp_size == 1), the last entry is grown to total_tokens to absorb + that pad and avoid TE's ``pad_between_seqs=True`` path; see the absorption block in + the function body for the gate. + - 'cu_seqlens_padded': (optional) Cumulative PADDED sequence lengths tensor of the same + shape as ``cu_seqlens``. Only emitted when it differs from ``cu_seqlens`` after + absorption (i.e., when padding lives between sub-sequences, which is the CP case). + Forwarded to TE as ``cu_seqlens_q_padded`` / ``cu_seqlens_kv_padded`` with + ``pad_between_seqs=True`` so the kernel reads memory offsets from the padded + variant while attending only over the real-length slots. + - 'max_seqlen': Scalar int32 tensor equal to ``max(cu_seqlens[i+1] - cu_seqlens[i])`` + after any absorption. Honors TE's contract that + ``max_seqlen_q >= max(cu_seqlens_q[i+1] - cu_seqlens_q[i])``. - 'padding_mask': Boolean tensor of shape [total_tokens] indicating padding positions - Non-tensor keys from input batch are preserved (e.g., 'qkv_format') @@ -77,8 +87,11 @@ def process_input_for_thd( >>> # result['input_ids'].shape: [12] (2D input collapsed to 1D) >>> # result['labels'].shape: [12] >>> # result['position_ids'].shape: [12] - >>> # result['cu_seqlens']: tensor([0, 4, 6, 12], dtype=torch.int32) + >>> # result['cu_seqlens']: tensor([0, 3, 5, 11], dtype=torch.int32) + >>> # Breakdown: [0] + cumsum([3, 2, 6]) = [0, 3, 5, 11] (from seq_lens — real lengths) + >>> # result['cu_seqlens_padded']: tensor([0, 4, 6, 12], dtype=torch.int32) >>> # Breakdown: [0] + cumsum([4, 2, 6]) = [0, 4, 6, 12] (from seq_lens_padded) + >>> # result['max_seqlen']: tensor(6, dtype=torch.int32) # max slot width in cu_seqlens >>> # result['padding_mask'].shape: [12] """ input_ids = batch["input_ids"] @@ -96,13 +109,13 @@ def process_input_for_thd( input_ids_thd = input_ids.reshape(total_tokens, -1).squeeze(-1) labels_thd = labels.reshape(total_tokens, -1).squeeze(-1) + cu_seqlens = None + cu_seqlens_padded = None + max_seqlen = None if seq_lens is not None: - # Filter out padding values and flatten - # seq_lens shape: [batch_size, num_packs] -> flatten and remove padding values seq_lens_flat = seq_lens.reshape(-1) valid_seq_lens = seq_lens_flat[seq_lens_flat != seq_lens_padding_value] - # Compute cumulative sequence lengths for attention cu_seqlens = torch.cat( [ torch.tensor([0], dtype=valid_seq_lens.dtype, device=valid_seq_lens.device), @@ -112,7 +125,6 @@ def process_input_for_thd( cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(device=valid_seq_lens.device) if seq_lens_padded is not None: - # Same processing for padded sequence lengths seq_lens_padded_flat = seq_lens_padded.reshape(-1) valid_seq_lens_padded = seq_lens_padded_flat[seq_lens_padded_flat != seq_lens_padding_value] @@ -121,16 +133,46 @@ def process_input_for_thd( ) cu_seqlens_padded = cu_seqlens_padded.to(dtype=torch.int32).to(device=valid_seq_lens_padded.device) + # Trailing-only pack-pad (cp_size==1): absorb into cu_seqlens[-1] so + # the emit gate below drops cu_seqlens_padded and TE skips its + # pad_between_seqs=True path. CP>1 differs in multiple entries and + # falls through; both arrays are emitted and TE handles padding. + if ( + cu_seqlens is not None + and cu_seqlens_padded is not None + and cu_seqlens.numel() == cu_seqlens_padded.numel() + and cu_seqlens.numel() > 1 + and torch.equal(cu_seqlens[:-1], cu_seqlens_padded[:-1]) + ): + _total = int(total_tokens) + _real_total = int(cu_seqlens[-1].item()) + if _real_total < _total: + _extended = cu_seqlens.clone() + _extended[-1] = _total + cu_seqlens = _extended + cu_seqlens_padded = cu_seqlens.clone() + + # Compute max_seqlen from the FINAL cu_seqlens to honor TE's contract + # (``max_seqlen_q >= max(cu_seqlens[i+1] - cu_seqlens[i])``, see TE's + # cpp_extensions/fused_attn.py:152-159). + if cu_seqlens is not None and cu_seqlens.numel() > 1: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().to(dtype=torch.int32) + result = { "input_ids": input_ids_thd, "position_ids": position_ids_thd, - # Pass cu_seqlens_padded here since CP doesn't support padding between sequences correctly, the labels or loss mask will ensure that loss is computed correctly. - "cu_seqlens": cu_seqlens_padded, + "cu_seqlens": cu_seqlens, "labels": labels_thd, "padding_mask": (input_ids_thd == padding_token_id), } + # Emit cu_seqlens_padded only when it differs from cu_seqlens — its + # presence is what flips TE's pad_between_seqs=True path in + # attention/utils.py. + if cu_seqlens_padded is not None and not torch.equal(cu_seqlens_padded, cu_seqlens): + result["cu_seqlens_padded"] = cu_seqlens_padded + if max_seqlen is not None: + result["max_seqlen"] = max_seqlen - # Preserve qkv_format and other non-tensor keys from the original batch for key, value in batch.items(): if key not in result and not isinstance(value, torch.Tensor): result[key] = value @@ -175,8 +217,14 @@ def split_batch_into_thd_chunks( - 'input_ids': [num_chunks, tokens_per_chunk] or [num_chunks, tokens_per_chunk, hidden_dim] - 'labels': [num_chunks, tokens_per_chunk] - 'position_ids': [num_chunks, tokens_per_chunk] - - 'cu_seqlens': [num_chunks, max_sequences_per_chunk + 1] (padded with seq_lens_padding_value). - Contains cumulative lengths from seq_lens_padded for CP compatibility. + - 'cu_seqlens': [num_chunks, max_sequences_per_chunk + 1] (right-padded with + seq_lens_padding_value across chunks for rectangularity). Built from seq_lens + (real lengths) per chunk; see ``process_input_for_thd`` for the absorption + semantics applied per chunk. + - 'cu_seqlens_padded': (optional) Same shape, emitted whenever ANY chunk emits it. + For chunks that absorbed (no separate padded variant), this row equals the + chunk's ``cu_seqlens``. + - 'max_seqlen': [num_chunks] per-chunk scalar tensor. - 'padding_mask': [num_chunks, tokens_per_chunk] - Non-tensor keys from input batch are preserved - When num_chunks <= 1: @@ -230,12 +278,21 @@ def pad_and_stack(tensor_list, padding_value): for i in range(num_chunks) ] - # Stack results - return { + stacked: dict = { "input_ids": torch.stack([c["input_ids"] for c in chunk_results]), "labels": torch.stack([c["labels"] for c in chunk_results]), "position_ids": torch.stack([c["position_ids"] for c in chunk_results]), "cu_seqlens": pad_and_stack([c["cu_seqlens"] for c in chunk_results], seq_lens_padding_value), "padding_mask": torch.stack([c["padding_mask"] for c in chunk_results]), - **{k: v for k, v in chunk_results[0].items() if not isinstance(v, torch.Tensor)}, } + # Emit cu_seqlens_padded whenever any chunk emits it; absorbed chunks + # fall back to their cu_seqlens (semantically equal) for rectangularity. + if any("cu_seqlens_padded" in c for c in chunk_results): + stacked["cu_seqlens_padded"] = pad_and_stack( + [c.get("cu_seqlens_padded", c["cu_seqlens"]) for c in chunk_results], + seq_lens_padding_value, + ) + if all("max_seqlen" in c for c in chunk_results): + stacked["max_seqlen"] = torch.stack([c["max_seqlen"] for c in chunk_results]) + stacked.update({k: v for k, v in chunk_results[0].items() if not isinstance(v, torch.Tensor)}) + return stacked diff --git a/nemo_automodel/components/loss/mtp.py b/nemo_automodel/components/loss/mtp.py index 13a3ff41dd..437d6257f1 100644 --- a/nemo_automodel/components/loss/mtp.py +++ b/nemo_automodel/components/loss/mtp.py @@ -32,6 +32,8 @@ def calculate_mtp_loss( scaling_factor: float = 0.1, num_label_tokens: Optional[int] = None, ignore_index: int = -100, + cu_seqlens: Optional[torch.Tensor] = None, + seq_idx: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Compute the DeepSeek-V3 Multi-Token Prediction auxiliary loss. @@ -52,6 +54,14 @@ def calculate_mtp_loss( base loss for sum-reduction normalization). ignore_index: Label value masked out of the CE loss for the trailing ``k+1`` rolled positions at depth ``k``. + cu_seqlens: Optional cumulative sequence lengths ``[num_seqs+1]`` + (THD-pack layout). When supplied and ``seq_idx`` is not, builds + a per-token sub-sequence index via searchsorted. Without packing + this can be omitted. + seq_idx: Optional per-token sub-sequence index ``[B, S]`` (or ``[S]``). + Equality classes are what matter; absolute values can be any + ints. Takes precedence over ``cu_seqlens``. Used to mask label + rolls whose source position lies in a different sub-sequence. Returns: Scalar MTP loss with autograd graph. @@ -60,15 +70,62 @@ def calculate_mtp_loss( raise ValueError("Provide exactly one of mtp_per_depth_h or mtp_per_depth_logits") mtp_outputs = mtp_per_depth_logits if mtp_per_depth_logits is not None else mtp_per_depth_h + + # Reconcile per-depth output and label dims for the THD-packed non-PP path: + # the model unsqueezes outputs from ``[T, *]`` back to ``[1, T, *]`` (model.py + # post-MTP-forward), while labels arrive as 1D ``[T]`` from + # ``process_input_for_thd``. ``FusedLinearCrossEntropy`` / ``cut_cross_entropy`` + # asserts ``hidden_states.shape[:-1] == labels.shape`` so squeeze the synthetic + # batch axis when labels are flat. + if labels.dim() == 1: + mtp_outputs = [h.squeeze(0) if (h.dim() == 3 and h.shape[0] == 1) else h for h in mtp_outputs] + D = len(mtp_outputs) cur_labels = labels total = mtp_outputs[0].new_zeros(()) + + if seq_idx is None and cu_seqlens is not None: + cs = cu_seqlens + if cs.dim() == 2 and cs.shape[0] == 1: + cs = cs.squeeze(0) + if cs.dim() == 1: + # Span the full (padded) token axis; cu_seqlens[-1] excludes tail pad. + # Matches the model's mamba seq_idx build (nemotron_v3/layers.py). + total_len = labels.shape[-1] + positions = torch.arange(total_len, device=labels.device) + # ``right=True`` so a position equal to a boundary (the first token + # of sub-seq k, position == cu_seqlens[k]) maps to k, not k-1. + seq_idx = torch.searchsorted(cs[1:].contiguous(), positions, right=True) + if labels.dim() == 2: + seq_idx = seq_idx.unsqueeze(0).expand(labels.shape[0], -1) + elif seq_idx is not None: + if seq_idx.dim() == 1 and labels.dim() == 2: + seq_idx = seq_idx.unsqueeze(0).expand(labels.shape[0], -1) + elif seq_idx.dim() == 2 and labels.dim() == 1 and seq_idx.shape[0] == 1: + seq_idx = seq_idx.squeeze(0) + # Under PP the caller must chunk seq_idx to per-microbatch shape; a + # mismatch is a wiring bug, not a runtime condition to swallow. + if seq_idx.shape != labels.shape: + raise ValueError( + f"calculate_mtp_loss: seq_idx.shape={tuple(seq_idx.shape)} does not " + f"match labels.shape={tuple(labels.shape)}; under PP, chunk seq_idx " + f"into per-microbatch pieces before passing it in." + ) + for k, mtp_output in enumerate(mtp_outputs): cur_labels = roll_tensor(cur_labels, shifts=-1, dim=-1) masked = cur_labels.clone() n_invalid = min(k + 1, masked.shape[-1]) masked[..., -n_invalid:] = ignore_index + # Mask labels whose rolled source (position t+k+1) lives in a + # different sub-seq than position t — predictions across sub-seq + # boundaries are nonsensical. + if seq_idx is not None: + rolled_seq_idx = roll_tensor(seq_idx, shifts=-(k + 1), dim=-1) + cross_seq = rolled_seq_idx != seq_idx + masked = torch.where(cross_seq, torch.full_like(masked, ignore_index), masked) + if mtp_per_depth_logits is not None: if isinstance(loss_fn, FusedLinearCrossEntropy): raise ValueError("MTP logits are incompatible with FusedLinearCrossEntropy") @@ -104,14 +161,40 @@ def calculate_mtp_loss( class PipelineCausalLMLoss(nn.Module): - """Pipeline schedule loss that can add MTP auxiliary CE on the last stage.""" + """Pipeline schedule loss that can add MTP auxiliary CE on the last stage. + + Per-microbatch ``seq_idx`` is read from a trailing element of the + last-stage output tuple — the model appends an ``[B, S] int32`` tail + when MTP is enabled. This binds each microbatch's seq_idx to its loss + call via the PP runtime's output→loss contract, so the wiring is + schedule-agnostic. Legacy ``cu_seqlens`` (THD path) is a fallback for + models that don't emit a seq_idx tail. + """ def __init__(self, loss_fn: nn.Module, model: nn.Module): super().__init__() self.loss_fn = loss_fn self.model = model + # Legacy THD-pack fallback used when the model has no seq_idx tail. + self.cu_seqlens: Optional[torch.Tensor] = None + + @staticmethod + def _extract_seq_idx_tail(output) -> tuple[Optional[torch.Tensor], object]: + """Detect and strip a trailing per-microbatch seq_idx from output. + + Convention: with MTP enabled the last-stage output is + ``(logits, *mtp_per_depth_h, seq_idx)`` with an ``[B, S] int32`` + tail — dtype alone discriminates. + """ + if isinstance(output, tuple) and len(output) > 0: + last = output[-1] + if isinstance(last, torch.Tensor) and last.dtype == torch.int32 and last.dim() == 2: + return last, output[:-1] + return None, output def forward(self, output, labels: torch.Tensor) -> torch.Tensor: + seq_idx_mb, output = self._extract_seq_idx_tail(output) + if isinstance(output, tuple): logits = output[0] hidden_states = None @@ -145,5 +228,7 @@ def forward(self, output, labels: torch.Tensor) -> torch.Tensor: labels=labels, model=self.model, scaling_factor=scaling_factor, + cu_seqlens=self.cu_seqlens, + seq_idx=seq_idx_mb, ) return loss diff --git a/nemo_automodel/components/models/common/mtp/mtp.py b/nemo_automodel/components/models/common/mtp/mtp.py index 834207295b..ae9fb3dc72 100644 --- a/nemo_automodel/components/models/common/mtp/mtp.py +++ b/nemo_automodel/components/models/common/mtp/mtp.py @@ -171,22 +171,40 @@ def pattern_length(self) -> int: def forward( self, - input_ids: torch.LongTensor, hidden_states: torch.Tensor, - embed_fn: Callable[[torch.LongTensor], torch.Tensor], + *, + input_ids: torch.LongTensor | None = None, + embed_fn: Callable[[torch.LongTensor], torch.Tensor] | None = None, + embed_inputs: tuple[torch.Tensor, ...] | None = None, position_ids: torch.LongTensor | None = None, **block_kwargs, ) -> list[torch.Tensor]: """Iterate over MTP depths and return per-depth hidden states. + Two mutually-exclusive input modes: + + * **Single-rank / first-stage PP** (default): pass ``input_ids`` plus + ``embed_fn``. The module rolls ``input_ids`` cumulatively left by 1 + per depth and applies ``embed_fn`` to produce the future-token + embedding for that depth. + * **Final-stage PP**: pass ``embed_inputs`` (a tuple of pre-rolled + per-depth embeddings, length ``num_depths``). Used when the last PP + stage no longer owns ``embed_tokens``; the first PP stage has + already produced the rolled embeddings and propagated them through + the pipeline. + Args: - input_ids: Token ids ``[B, S]`` (or ``[T]`` in THD). Rolled - cumulatively left by 1 per depth. hidden_states: Output of the main model's final norm (``h_0``); shape matches the model's residual stream. + input_ids: Token ids ``[B, S]`` (or ``[T]`` in THD). Rolled + cumulatively left by 1 per depth. Mutually exclusive with + ``embed_inputs``. embed_fn: Callable applied to rolled ``input_ids`` to produce the future-token embedding (typically the model's input embedding - layer). + layer). Required when ``input_ids`` is supplied. + embed_inputs: Optional tuple of ``num_depths`` pre-computed + future-token embeddings, one per depth in MTP order. + Mutually exclusive with ``input_ids``/``embed_fn``. position_ids: Position ids matching ``input_ids``. When supplied, rolled cumulatively per depth in lockstep with ``input_ids`` (so slot ``t`` carries the original position of the rolled @@ -200,6 +218,15 @@ def forward( List of length ``num_depths`` containing the hidden state produced at each depth. """ + if embed_inputs is not None: + if input_ids is not None or embed_fn is not None: + raise ValueError("embed_inputs is mutually exclusive with input_ids/embed_fn") + if len(embed_inputs) != self.num_depths: + raise ValueError(f"embed_inputs length {len(embed_inputs)} does not match num_depths {self.num_depths}") + else: + if input_ids is None or embed_fn is None: + raise ValueError("MTPModule.forward requires either embed_inputs or (input_ids, embed_fn)") + num_iterations = self.num_depths num_sublayers_per_depth = self.pattern_length use_repeated = self.mtp_config.use_repeated_layer @@ -207,11 +234,14 @@ def forward( cur_input_ids = input_ids cur_position_ids = position_ids for depth in range(num_iterations): - cur_input_ids = roll_tensor(cur_input_ids, shifts=-1, dim=-1) + if embed_inputs is not None: + decoder_input = embed_inputs[depth] + else: + cur_input_ids = roll_tensor(cur_input_ids, shifts=-1, dim=-1) + decoder_input = embed_fn(cur_input_ids) if cur_position_ids is not None: cur_position_ids = roll_tensor(cur_position_ids, shifts=-1, dim=-1) - decoder_input = embed_fn(cur_input_ids) physical_depth = 0 if use_repeated else depth for sublayer_idx in range(num_sublayers_per_depth): sublayer = self.layers[physical_depth * num_sublayers_per_depth + sublayer_idx] diff --git a/nemo_automodel/components/models/common/utils.py b/nemo_automodel/components/models/common/utils.py index 5a2ccb9593..515f94db9f 100644 --- a/nemo_automodel/components/models/common/utils.py +++ b/nemo_automodel/components/models/common/utils.py @@ -157,10 +157,6 @@ class BackendConfig: manager instance across MoE layers. dispatcher_async_dispatch: Whether DeepEP/UCCL-EP dispatch should return asynchronously and allocate dispatched tensors on the communication stream. - disable_shared_expert_overlap: When True, run shared experts sequentially on the - current CUDA stream instead of overlapping them on a side stream with the - grouped-expert dispatch. Useful as an escape hatch when the side-stream - overlap interacts poorly with the dispatcher backend. enable_deepep: Deprecated. Use dispatcher="deepep" and experts="gmm" instead. fake_balanced_gate: If True, replace the learned Gate with FakeBalancedGate that assigns tokens to experts without learned routing weights. @@ -190,7 +186,6 @@ class BackendConfig: dispatcher_num_sms: int = 20 dispatcher_share_token_dispatcher: bool = True dispatcher_async_dispatch: bool = False - disable_shared_expert_overlap: bool = False enable_deepep: bool | None = None # Deprecated: use dispatcher="deepep" instead fake_balanced_gate: bool = False # Approximate max/mean load ratios (64 experts, top-8, 4096 tokens): diff --git a/nemo_automodel/components/models/nemotron_v3/layers.py b/nemo_automodel/components/models/nemotron_v3/layers.py index 879fcf2be6..9459a39dfd 100644 --- a/nemo_automodel/components/models/nemotron_v3/layers.py +++ b/nemo_automodel/components/models/nemotron_v3/layers.py @@ -44,6 +44,8 @@ def __init__(self, config, backend: BackendConfig | None = None): self.hidden_size = config.hidden_size self.attention_bias = getattr(config, "attention_bias", False) self.attention_dropout = getattr(config, "attention_dropout", 0.0) + # Cached for debug-print role disambiguation (backbone vs mtp sublayer). + self.num_hidden_layers = int(getattr(config, "num_hidden_layers", 0)) self.q_proj = initialize_linear_module( self.backend.linear, self.hidden_size, self.num_attention_heads * self.head_dim, self.attention_bias @@ -280,15 +282,39 @@ def forward( # Build seq_idx for Mamba kernel (marks sequence boundaries for packing / CP). seq_idx = kwargs.get("seq_idx", None) if seq_idx is None and "cu_seqlens" in kwargs: - cu_seqlens = kwargs["cu_seqlens"] - # cu_seqlens from the THD batch is GLOBAL (pre-TE-partitioning). - # When CP is active, the mamba kernel receives the global sequence - # (after all-to-all gather). Scale total_len by cp_size so that - # seq_idx has the correct global length. + # [FIX mbs>1 THD] hidden_states use the PADDED layout (each packed bin padded + # to packed_sequence_size), so seq_idx must segment with cu_seqlens_PADDED. + # The real (contiguous) cu_seqlens is misaligned at mbs>1 — mamba's SSD scan + # would carry state across bin boundaries (bin-k pad + bin-(k+1) start lumped + # into one sub-seq), corrupting the hidden states. At mbs=1 only trailing pad + # mismatches (harmless), which is why packing previously required lbs==1. + cu_seqlens = kwargs.get("cu_seqlens_padded") + if cu_seqlens is None: + cu_seqlens = kwargs["cu_seqlens"] cp_size = self.cp.cp_size if self.cp is not None else 1 - total_len = (hidden_states.shape[1] if hidden_states.dim() == 3 else hidden_states.shape[0]) * cp_size - positions = torch.arange(total_len, device=hidden_states.device) - seq_idx = (torch.searchsorted(cu_seqlens[1:], positions)).unsqueeze(0).to(torch.int32) + if hidden_states.dim() == 3 and batch_size > 1: + # BSHD with B > 1 (e.g. default-collater validation): the cu_seqlens + # derived from a 2D attention_mask is global (cumsum across rows), not + # per-row. mamba_ssm asserts seq_idx.shape == (B, S) and processes each + # row independently, so treat each row as a single sub-sequence. (Gating + # on ``cu_seqlens`` being present is load-bearing: plain batched BSHD + # passes no cu_seqlens, so seq_idx stays None and the kernel handles the + # batched / CP-gathered sequence itself — forcing zeros here breaks + # BSHD context-parallel, where the kernel sees the gathered length.) + seq_idx = torch.zeros(batch_size, seq_len, device=hidden_states.device, dtype=torch.int32) + else: + # cu_seqlens from the THD batch is GLOBAL (pre-TE-partitioning). + # When CP is active, the mamba kernel receives the global sequence + # (after all-to-all gather). Scale total_len by cp_size so that + # seq_idx has the correct global length. + total_len = (hidden_states.shape[1] if hidden_states.dim() == 3 else hidden_states.shape[0]) * cp_size + positions = torch.arange(total_len, device=hidden_states.device) + # ``right=True`` so a position equal to a boundary (the first token of + # a new sub-seq, position == cu_seqlens[k]) maps to ``k``, not ``k-1``. + # Without this mamba's SSD scan resets state one token late at every + # sub-seq boundary — the first token of a new sub-seq sees the + # previous sub-seq's accumulated state. + seq_idx = torch.searchsorted(cu_seqlens[1:], positions, right=True).unsqueeze(0).to(torch.int32) # --- Path A: Training (no cache) → fused kernel --- if not use_cache: diff --git a/nemo_automodel/components/models/nemotron_v3/model.py b/nemo_automodel/components/models/nemotron_v3/model.py index 03c4276358..0a73398878 100644 --- a/nemo_automodel/components/models/nemotron_v3/model.py +++ b/nemo_automodel/components/models/nemotron_v3/model.py @@ -17,6 +17,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from transformers import AutoConfig from transformers.generation import GenerationConfig, GenerationMixin from transformers.modeling_outputs import CausalLMOutputWithPast @@ -30,6 +31,7 @@ from nemo_automodel.components.models.common.utils import cast_model_to_dtype from nemo_automodel.components.models.nemotron_v3.layers import NemotronV3Block from nemo_automodel.components.models.nemotron_v3.mtp import ( + _resolve_block_types_per_sublayer, build_mtp_config_from_hf, build_nemotron_v3_mtp, ) @@ -139,42 +141,92 @@ def forward( cache_position: torch.LongTensor | None = None, **kwargs: Any, ) -> torch.Tensor: - """Forward pass through the model. Supports BSHD ``[B, S, H]`` and THD ``[T, H]``.""" - # Get embeddings + """Forward pass through the model. Supports BSHD ``[B, S, H]`` and THD ``[T, H]``. + + Pipeline-parallel awareness: when ``self.embed_tokens is None`` (non-first + PP stage), ``input_ids`` is interpreted as the upstream hidden-state + tensor and routed through the ``inputs_embeds`` branch. When + ``self.norm is None`` (non-last PP stage), the final norm is skipped. + """ + # Get embeddings (PP-aware: a trimmed mid-stage has embed_tokens=None + # and receives the prior stage's hidden states as input_ids). if inputs_embeds is None: - if input_ids is None: - raise ValueError("input_ids must be provided if inputs_embeds is not provided") - hidden_states = self.embed_tokens(input_ids) + if getattr(self, "embed_tokens", None) is not None: + if input_ids is None: + raise ValueError("input_ids must be provided if inputs_embeds is not provided") + hidden_states = self.embed_tokens(input_ids) + else: + if input_ids is None or input_ids.dtype not in (torch.float16, torch.bfloat16, torch.float32): + raise ValueError("Non-first PP stage expects an upstream hidden-state tensor") + hidden_states = input_ids else: hidden_states = inputs_embeds - # When qkv_format="thd" is explicitly requested with batch_size=1, - # squeeze to 2D [T, H] so attention layers receive the correct shape - # for TE's thd qkv_format. Note: cu_seqlens alone does NOT trigger - # the squeeze because cu_seqlens may be present solely for mamba's - # seq_idx construction (e.g. packed sequences with TE p2p CP where - # attention must stay in BSHD format). + # TE natively supports THD; squeeze to [T, H] so attention layers + # pick the THD branch. SDPA/flex only support 4D BSHD. + _attn_impl = getattr(getattr(self, "backend", None), "attn", None) squeezed_for_thd = False - if kwargs.get("qkv_format") == "thd" and hidden_states.dim() == 3 and hidden_states.shape[0] == 1: + if ( + kwargs.get("qkv_format") == "thd" + and _attn_impl == "te" + and hidden_states.dim() == 3 + and hidden_states.shape[0] == 1 + ): hidden_states = hidden_states.squeeze(0) squeezed_for_thd = True is_thd = hidden_states.dim() == 2 - # TODO: attention mask currently does not work. A default causal mask is applied. + # Non-THD-collater path doesn't emit cu_seqlens — recover it from a + # 2D indexed attention_mask so mamba's seq_idx derivation has input. + # Gated on B==1: cu_seqlens describes sub-sequence boundaries within a + # single flattened token stream. At B>1 the rows are independent + # sequences (the batch dim already separates them), so cumsum(0) would + # be a wrong cross-row global offset. Leave cu_seqlens unset there and + # let each consumer treat every row as one sequence (mamba seq_idx + # falls back to per-row, the PP seq_idx tail to its no-mask sentinel). + if ( + "cu_seqlens" not in kwargs + and attention_mask is not None + and attention_mask.dim() == 2 + and attention_mask.shape[0] == 1 + and attention_mask.dtype != torch.bool + ): + seq_lens = attention_mask.sum(dim=-1).to(torch.int32) + kwargs["cu_seqlens"] = F.pad(seq_lens.cumsum(0).to(torch.int32), (1, 0)) + + # Per-microbatch arrives as [1, K] (THD collator stacks per-MB and PP + # tensor_splits dim 0). Squeeze to 1D, strip the -1000 right-pad + # sentinels from pad_and_stack, and clone to a fresh tensor so the + # values survive TE backward (PP may free view-backed kwarg storage). + _SEQLEN_SENTINEL = -1000 + for _k in ("cu_seqlens", "cu_seqlens_padded"): + _v = kwargs.get(_k) + if isinstance(_v, torch.Tensor) and _v.dim() == 2 and _v.shape[0] == 1: + _v = _v.squeeze(0) + if isinstance(_v, torch.Tensor) and _v.dim() == 1: + if (_v == _SEQLEN_SENTINEL).any(): + _v = _v[_v != _SEQLEN_SENTINEL] + kwargs[_k] = _v.contiguous().clone() + _ms = kwargs.get("max_seqlen") + if isinstance(_ms, torch.Tensor) and _ms.dim() >= 1 and _ms.numel() == 1: + kwargs["max_seqlen"] = _ms.flatten()[0].clone() - # Get 4D causal mask for attention layers (from precomputed masks). causal_mask = causal_mask_mapping.get("full_attention") if causal_mask_mapping is not None else None - # Apply transformer layers + # Neat-packed SDPA path: seq_idx came from _packed_seq_ids upstream + # and attention_mask is the 4D block-causal bool mask. Mamba uses + # seq_idx (no 2D mask); attention uses the 4D mask directly. + _neat_packed = "seq_idx" in kwargs and attention_mask is not None and attention_mask.dim() == 4 + for layer in self.layers.values(): - # Pass appropriate mask based on layer type if is_thd: mask = None + elif _neat_packed: + mask = attention_mask if layer.block_type == "attention" else None elif layer.block_type == "attention": - # Attention layers use 4D causal mask; fall back to 2D attention_mask - # when causal_mask is None (e.g. during TE+CP training where CP split - # removes the precomputed 4D mask) so TE can use padding_causal mode. + # Fall back to 2D attention_mask when causal_mask is None + # (e.g. TE+CP, where the CP split drops the precomputed 4D mask). mask = causal_mask if causal_mask is not None else attention_mask elif layer.block_type == "mamba": # Mamba layers use 2D padding mask during prefill, None during decode @@ -191,10 +243,10 @@ def forward( **kwargs, ) - # Final norm - hidden_states = self.norm(hidden_states) + # Norm is None on non-last PP stages (splitter trims it). + if getattr(self, "norm", None) is not None: + hidden_states = self.norm(hidden_states) - # Restore batch dimension if we squeezed for THD if squeezed_for_thd: hidden_states = hidden_states.unsqueeze(0) @@ -204,15 +256,18 @@ def forward( def initialize_weights(self, buffer_device: torch.device | None = None) -> None: """Initialize model weights according to NemotronV3 spec. + After PP splitting, ``embed_tokens`` may be ``None`` on non-first + stages and ``norm`` may be ``None`` on non-last stages; guard each. + Args: buffer_device: Device to use for buffer initialization """ - # Embedding weights: normal initialization with buffer_device: - nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=self.config.initializer_range) - self.norm.reset_parameters() + if getattr(self, "embed_tokens", None) is not None: + nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=self.config.initializer_range) + if getattr(self, "norm", None) is not None: + self.norm.reset_parameters() - # Initialize all layers via delegation for block in self.layers.values(): block.init_weights(buffer_device=buffer_device) @@ -224,11 +279,13 @@ class NemotronHForCausalLM(HFCheckpointingMixin, GenerationMixin, nn.Module, MoE per-step KV caching for attention layers and recurrent state caching for Mamba2 layers. """ - # Prevent GenerationMixin from creating a DynamicCache: the hybrid Mamba2/Attention - # architecture uses its own NemotronHybridCache. + # Hybrid Mamba2/Attention uses NemotronHybridCache, not DynamicCache. _is_stateful: bool = True main_input_name: str = "input_ids" + # Skip patch_hf_model_for_pp; our forward already handles PP routing. + _pp_keep_self_forward: bool = True + @classmethod def from_config( cls, @@ -323,7 +380,6 @@ def __init__( dtype=dtype, ) - # self.mtp is None when num_nextn_predict_layers is absent or 0. self.mtp_config = build_mtp_config_from_hf( config, loss_scaling_factor=mtp_loss_scaling_factor, @@ -331,17 +387,18 @@ def __init__( use_repeated_layer=mtp_use_repeated_layer, ) if self.mtp_config.enabled: + block_types = _resolve_block_types_per_sublayer(config) self.mtp = build_nemotron_v3_mtp( config, mtp_config=self.mtp_config, backend=self.backend, moe_config=self.model.moe_config, dtype=dtype, + block_types=block_types, ) else: self.mtp = None - # Create state_dict_adapter if enabled (needed to convert HF checkpoints) if self.backend.enable_hf_state_dict_adapter: self.state_dict_adapter = NemotronV3StateDictAdapter( config=config, @@ -350,7 +407,7 @@ def __init__( dtype=dtype, ) - # Required by GenerationMixin.generate() + # Required by GenerationMixin.generate(). self.generation_config = GenerationConfig() @property @@ -375,9 +432,131 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def _is_pipeline_parallel_stage(self) -> bool: + """True when this module instance has been trimmed to a PP stage subset. + + Detection mirrors ``DeepseekV4ForCausalLM._is_pipeline_parallel_stage``: + any of (a) ``lm_head`` is None, (b) inner ``embed_tokens`` is None, + (c) ``model.layers`` count diverges from ``config.num_hidden_layers`` + is sufficient — the PP splitter nulls these attributes when trimming. + + The checks use ``hasattr`` to distinguish "splitter nulled the + attribute" (attribute present, value is None) from "caller replaced + ``self.model`` with a stub that doesn't declare the attribute" + (attribute absent). Tests that swap in stub inner modules should not + be misclassified as PP stages. + """ + if self.lm_head is None: + return True + if hasattr(self.model, "embed_tokens") and self.model.embed_tokens is None: + return True + if hasattr(self.model, "layers"): + try: + return len(self.model.layers) != int(self.config.num_hidden_layers) + except TypeError: + return False + return False + + def _build_mtp_embed_inputs_for_pp(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Build the per-depth rolled-token embeddings on the first PP stage. + + The first PP stage owns ``embed_tokens`` and is the only rank that can + produce the future-token embeddings consumed by the MTP head on the + final stage. The tuple flows alongside ``hidden_states`` through every + intermediate stage as additional positional outputs (see ``forward``). + + Args: + input_ids: Token ids ``[B, S]`` (int). + + Returns: + Tuple of length ``self.mtp_config.num_layers`` containing + ``[B, S, hidden]`` embeddings for depths 1..D (i.e. for predicting + tokens shifted left by 1..D positions). + """ + if getattr(self.model, "embed_tokens", None) is None: + raise ValueError("First PP stage must own embed_tokens to build MTP embeddings") + if input_ids.dtype not in (torch.int32, torch.int64, torch.long): + raise ValueError("First PP stage must receive token ids to build MTP embeddings") + + from nemo_automodel.components.models.common.mtp import roll_tensor # noqa: PLC0415 + + cur_input_ids = input_ids + embeds: list[torch.Tensor] = [] + for _ in range(self.mtp_config.num_layers): + cur_input_ids = roll_tensor(cur_input_ids, shifts=-1, dim=-1) + embeds.append(self.model.embed_tokens(cur_input_ids)) + return tuple(embeds) + + def customize_pipeline_stage_modules( + self, + module_names_per_stage: list[list[str]], + *, + layers_prefix: str, + text_model: nn.Module | None = None, + ) -> list[list[str]]: + """Pin the MTP head to the last PP stage's FQN list. + + Called by ``split_model_into_stages`` (functional.py:494-502) after the + default per-stage FQN auto-generation. The auto-generator includes + ``embed_tokens`` on the first stage and ``norm``/``lm_head`` on the + last stage but doesn't know about ``model.mtp``; this hook appends it. + """ + del layers_prefix, text_model # unused — no per-stage rotary to replicate + stage_modules = [list(m) for m in module_names_per_stage] + if self.mtp is not None and stage_modules: + last = stage_modules[-1] + if "mtp" not in last: + last.append("mtp") + return stage_modules + + def get_pipeline_stage_metas( + self, + *, + is_first: bool, + microbatch_size: int, + seq_len: int, + dtype: torch.dtype, + ) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]: + """Return analytical (inputs_meta, outputs_meta) for a PP stage. + + Inter-stage tensors are plain ``[B, S, H]`` (no HC stream). With MTP + enabled, every transfer carries ``1 + D`` tensors so the variadic + forward signature is exercised on every microbatch. + """ + hidden_shape = (microbatch_size, seq_len, self.config.hidden_size) + mtp_depth = int(getattr(self.mtp_config, "num_layers", 0) or 0) + + def meta(shape: tuple[int, ...], d: torch.dtype = dtype) -> torch.Tensor: + return torch.empty(*shape, device="meta", dtype=d) + + def append_mtp(primary: torch.Tensor) -> tuple[torch.Tensor, ...]: + if mtp_depth == 0: + return (primary,) + return (primary, *(meta(hidden_shape) for _ in range(mtp_depth))) + + if is_first: + inputs_meta: tuple[torch.Tensor, ...] = ( + torch.empty(microbatch_size, seq_len, device="meta", dtype=torch.long), + ) + else: + inputs_meta = append_mtp(meta(hidden_shape)) + + if self.lm_head is not None: + primary_out = meta((microbatch_size, seq_len, self.config.vocab_size)) + else: + primary_out = meta(hidden_shape) + outputs_meta = append_mtp(primary_out) + # Last stage appends an int32 [B, S] seq_idx so the loss fn can mask + # MTP label rolls across sub-seq boundaries — bonded to its microbatch + # via the PP output-tuple contract (schedule-agnostic). + if self.lm_head is not None and mtp_depth > 0: + outputs_meta = (*outputs_meta, meta((microbatch_size, seq_len), d=torch.int32)) + return inputs_meta, outputs_meta + def forward( self, input_ids: Optional[torch.LongTensor] = None, + *mtp_embed_inputs: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, causal_mask_mapping: Optional[dict[str, torch.Tensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, @@ -395,48 +574,114 @@ def forward( """Forward pass with optional loss computation. Supports both BSHD format (``input_ids`` shape ``[B, S]``) and THD format - (``input_ids`` shape ``[T]`` after ``squeeze_input_for_thd``). When - ``kwargs["qkv_format"] == "thd"``, inputs are squeezed to THD before the - base-model forward and logits are unsqueezed back to ``[1, T, V]`` on exit. + (``input_ids`` shape ``[T]`` after ``squeeze_input_for_thd``). When + ``kwargs["qkv_format"] == "thd"`` AND the attention backend is TE, + inputs are squeezed to THD before the base-model forward and logits + are unsqueezed back to ``[1, T, V]`` on exit. SDPA / flex stay in BSHD. + + Pipeline-parallel awareness: when run as a PP stage, ``input_ids`` is + the upstream stage's hidden-state tensor on non-first stages, and + ``*mtp_embed_inputs`` carries ``num_nextn_predict_layers`` future-token + embeddings produced by the first stage. See the Returns section below + for the per-stage tuple contract. The single-rank (no-PP) path returns + :class:`NemotronHCausalLMOutputWithPast` unchanged. Args: - input_ids: Input token IDs. BSHD: ``[B, S]``; THD: ``[1, T]`` (squeezed internally). + input_ids: Input token IDs. BSHD: ``[B, S]``; THD: ``[1, T]`` + (squeezed internally). On non-first PP stages this slot + instead carries the upstream stage's hidden-state tensor. + *mtp_embed_inputs: Pre-computed future-token embeddings produced + by the first PP stage and forwarded between stages as + positional args. Empty on the single-rank (no-PP) path. attention_mask: 2D padding mask ``[B, S]``. - causal_mask_mapping: Dict with precomputed 4D causal masks. + causal_mask_mapping: Dict with precomputed 4D causal masks + (key ``"full_attention"`` is consumed). inputs_embeds: Pre-computed input embeddings (optional). - labels: Token IDs for loss computation ``[B, S]`` (optional). - past_key_values: Optional NemotronHybridCache for incremental decoding. - use_cache: Whether to return past_key_values for subsequent steps. + labels: Token IDs for loss computation ``[B, S]`` (optional; + under PP, loss is computed by ``PipelineCausalLMLoss``). + past_key_values: Optional ``NemotronHybridCache`` for incremental decoding. + use_cache: Whether to return ``past_key_values`` for subsequent steps. cache_position: Token position indices for cache updates. - position_ids: Unused -- accepted for API compatibility with GenerationMixin. - padding_mask: Padding mask ``[B, S]`` used by THD squeeze helper. - logits_to_keep: If > 0, only compute logits for the last ``logits_to_keep`` - token positions (avoids materialising the full logit matrix during generation). + position_ids: Position IDs (forwarded into MTP sublayer kwargs). + padding_mask: Padding mask ``[B, S]`` used by the THD squeeze helper + and as the MoE / mamba 2D mask source. + logits_to_keep: If > 0, only compute logits for the last + ``logits_to_keep`` token positions. output_hidden_states: Whether to return hidden states. - return_dict: Accepted for API compatibility (always returns CausalLMOutputWithPast). + return_dict: Accepted for API compatibility (always returns a + ``NemotronHCausalLMOutputWithPast`` off-PP). **kwargs: Additional arguments forwarded to the base model - (e.g. seq_idx, cu_seqlens, qkv_format, CP kwargs). + (e.g. ``qkv_format``, ``cu_seqlens``, ``cu_seqlens_padded``, + ``max_seqlen``, ``seq_idx``, ``cp_rank``, ``cp_size``, + ``_packed_seq_ids``). Returns: - :class:`~transformers.modeling_outputs.CausalLMOutputWithPast` with - ``logits`` (float32), optional ``loss``, ``past_key_values``, and - ``hidden_states``. + Off-PP: :class:`NemotronHCausalLMOutputWithPast` with ``logits``, + optional ``loss``, ``past_key_values``, ``hidden_states``, and the + MTP per-depth hidden states / loss-scaling factor when MTP is on. + + Under PP, returns a positional tuple instead: + * mid stages: ``(hidden_states, *mtp_embed_inputs)`` arity ``1 + D``. + * last stage: ``(logits, *mtp_per_depth_h, seq_idx)`` arity + ``1 + D + 1`` when MTP is enabled, else ``logits`` alone. """ + is_pp_stage = self._is_pipeline_parallel_stage() + is_first_stage = getattr(self.model, "embed_tokens", None) is not None + has_lm_head = self.lm_head is not None + mtp_depth = int(getattr(self.mtp_config, "num_layers", 0) or 0) + pp_mtp_enabled = is_pp_stage and self.mtp_config.enabled + + # Neat-packed SDPA: convert _packed_seq_ids (1-based [B,S] int, 0=pad) + # to mamba's seq_idx and derive a 2D padding_mask. The neat collater + # already supplies a 4D attention_mask, but mamba's mixer multiplies + # hidden_states by a 2D mask, so we need both. + _packed_seq_ids = kwargs.pop("_packed_seq_ids", None) + if isinstance(_packed_seq_ids, torch.Tensor) and "seq_idx" not in kwargs: + kwargs["seq_idx"] = _packed_seq_ids.to(torch.int32).contiguous() + if padding_mask is None: + padding_mask = _packed_seq_ids == 0 + output_hidden_states = ( output_hidden_states if output_hidden_states is not None else getattr(self.config, "output_hidden_states", False) ) + # Stash pre-squeeze [B, S] input_ids: the MTP embed tuple must be + # built AFTER self.model() runs (FSDP2 root lazy-init requires the + # root forward first) but with the pre-squeeze shape so emitted + # tensors match the [B, S, H] contract from get_pipeline_stage_metas. + pre_squeeze_input_ids = ( + input_ids + if ( + pp_mtp_enabled + and is_first_stage + and not has_lm_head + and not mtp_embed_inputs + and input_ids is not None + and input_ids.dtype in (torch.int32, torch.int64, torch.long) + ) + else None + ) + + # Squeezing to [T, H] only helps TE; SDPA/flex need 4D BSHD. Keep + # is_thd true regardless so the post-forward unsqueeze still fires. + _attn_impl = getattr(getattr(self, "backend", None), "attn", None) is_thd = kwargs.get("qkv_format") == "thd" - if is_thd: + squeeze_for_thd = is_thd and _attn_impl == "te" + if squeeze_for_thd and is_first_stage: input_ids, position_ids, padding_mask, kwargs = squeeze_input_for_thd( input_ids, position_ids, padding_mask, kwargs ) attention_mask = None causal_mask_mapping = None - # Forward through base model + # MoE needs a padding_mask; derive from attention_mask when missing. + if padding_mask is None and attention_mask is not None and attention_mask.dim() == 2: + padding_mask = attention_mask.bool().logical_not() + + # On non-first PP stages, the upstream hidden-state tensor arrives in + # the input_ids slot; the inner model routes it via inputs_embeds. hidden_states = self.model( input_ids, attention_mask=attention_mask, @@ -447,24 +692,31 @@ def forward( **kwargs, ) - # Mark cache as having state after the first forward pass (prefill done) + # Root forward has run; FSDP2 lazy-init is satisfied. Build MTP embed + # tuple from pre-squeeze [B, S] ids so emitted shapes match + # get_pipeline_stage_metas. + if pre_squeeze_input_ids is not None: + mtp_embed_inputs = self._build_mtp_embed_inputs_for_pp(pre_squeeze_input_ids) + if past_key_values is not None: past_key_values.has_previous_state = True - # Optionally restrict logit computation to the last few positions. - # When logits_to_keep == 0 we compute all positions (training default). - if isinstance(logits_to_keep, int) and logits_to_keep == 0: - logits = self.lm_head(hidden_states) - else: - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - if hidden_states.dim() == 2: - logits = self.lm_head(hidden_states[slice_indices, :]) + # lm_head is None on non-last PP stages; return raw hidden_states. + if has_lm_head: + if isinstance(logits_to_keep, int) and logits_to_keep == 0: + logits = self.lm_head(hidden_states) else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + if hidden_states.dim() == 2: + logits = self.lm_head(hidden_states[slice_indices, :]) + else: + logits = self.lm_head(hidden_states[:, slice_indices, :]) + else: + logits = hidden_states loss = None - if labels is not None: - # Shift for next-token prediction + # PP path defers loss to PipelineCausalLMLoss; only compute here off-PP. + if labels is not None and has_lm_head and not is_pp_stage: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = nn.functional.cross_entropy( @@ -472,30 +724,132 @@ def forward( shift_labels.view(-1), ) - # Per-depth hidden states for the MTP auxiliary head; the recipe - # dispatches CE per depth via the configured loss class. - mtp_per_depth_h = None + # MTP head: last PP stage in training only. Other stages / eval emit + # placeholder empties below so the tuple arity stays 1 + D. + mtp_per_depth_h: list[torch.Tensor] | None = None if self.mtp is not None and self.training: - mtp_per_depth_h = self.mtp( - input_ids=input_ids, - hidden_states=hidden_states, - embed_fn=self.model.embed_tokens, - position_ids=position_ids, - attention_mask=causal_mask_mapping.get("full_attention") - if causal_mask_mapping is not None - else attention_mask, + mtp_attention_mask = ( + causal_mask_mapping.get("full_attention") if causal_mask_mapping is not None else attention_mask ) - - # Restore the batch dim for THD only when the inner forward returned - # 2D logits. When the caller feeds the model via ``inputs_embeds`` - # (shape ``[1, T, H]``), ``NemotronHModel.forward`` squeezes to - # ``[T, H]`` for the layer stack and unsqueezes back to ``[1, T, H]`` - # before returning (see the ``squeezed_for_thd`` branch); the lm_head - # then yields ``[1, T, V]`` already and a second unsqueeze here would - # produce a spurious ``[1, 1, T, V]``. + # Forward THD-packing context to the MTP sublayers; without it + # they'd auto-detect bshd from the 3D shape and bleed attention + # across sub-seq boundaries. cp_rank/cp_size must also be + # forwarded — the CP forward pre-hook only attaches to the + # backbone attention, not to MTP sublayers that share the class. + mtp_kwargs = { + "position_ids": position_ids, + "attention_mask": mtp_attention_mask, + } + for _k in ( + "qkv_format", + "cu_seqlens", + "cu_seqlens_padded", + "max_seqlen", + "max_seqlen_q", + "max_seqlen_kv", + "seq_idx", + "cp_rank", + "cp_size", + ): + if _k in kwargs: + _v = kwargs[_k] + # Same per-microbatch [1, K] → 1D normalization as in + # NemotronV3Model.forward; re-applied for the MTP chain. + if _k in ("cu_seqlens", "cu_seqlens_padded", "seq_idx") and isinstance(_v, torch.Tensor): + if _v.dim() == 2 and _v.shape[0] == 1: + _v = _v.squeeze(0) + if _v.dim() == 1: + # Strip pad_and_stack's -1000 sentinels; TE would + # interpret pairs like (1024, -1000) as a negative- + # length sub-seq and OOB-write during backward. + if (_v == -1000).any(): + _v = _v[_v != -1000] + _v = _v.contiguous().clone() + mtp_kwargs[_k] = _v + # Squeeze to 2D ``[T, H]`` when THD-packed so the attention layer + # selects its THD branch. Also squeeze the propagated MTP embed + # tensors for the same reason. + mtp_hidden = hidden_states + mtp_embeds_for_call = tuple(mtp_embed_inputs) if mtp_embed_inputs else () + if is_thd: + if mtp_hidden.dim() == 3 and mtp_hidden.shape[0] == 1: + mtp_hidden = mtp_hidden.squeeze(0) + mtp_embeds_for_call = tuple( + e.squeeze(0) if (e.dim() == 3 and e.shape[0] == 1) else e for e in mtp_embeds_for_call + ) + + if mtp_embeds_for_call: + # Final PP stage: embeddings produced upstream. + mtp_per_depth_h = self.mtp( + hidden_states=mtp_hidden, + embed_inputs=mtp_embeds_for_call, + **mtp_kwargs, + ) + else: + # Non-PP single-rank: roll input_ids locally. + mtp_per_depth_h = self.mtp( + hidden_states=mtp_hidden, + input_ids=input_ids, + embed_fn=self.model.embed_tokens, + **mtp_kwargs, + ) + if is_thd and mtp_per_depth_h is not None: + mtp_per_depth_h = [h.unsqueeze(0) if h.dim() == 2 else h for h in mtp_per_depth_h] + elif pp_mtp_enabled and has_lm_head: + # Eval, or no MTP on this rank — emit empties to keep tuple arity. + mtp_per_depth_h = [hidden_states.new_empty(hidden_states.shape) for _ in range(mtp_depth)] + + # Restore batch dim only when inner forward returned 2D — inputs_embeds + # path already produces [1, T, V] and would become [1, 1, T, V] here. if is_thd and logits.dim() == 2: logits = logits.unsqueeze(0) + # PP return contract: + # mid stages: (hidden_states, *mtp_embed_inputs) arity 1+D + # last stage: (logits, *mtp_per_depth_h, seq_idx) arity 1+D+1 + # The seq_idx tail binds the per-microbatch sub-seq layout to its loss + # call via the PP runtime's forward(mb_i)→loss(mb_i) contract. + if is_pp_stage: + if pp_mtp_enabled: + if not has_lm_head: + return (logits, *mtp_embed_inputs) + assert mtp_per_depth_h is not None + # seq_idx tail sources, in order of preference: + # 1. kwargs["seq_idx"] (neat-path). + # 2. derived from kwargs["cu_seqlens"] (THD/TE path). + # 3. all-1 sentinel — loss-fn cross-boundary mask is a no-op. + if logits.dim() == 3: + _B, _S = logits.shape[:2] + elif logits.dim() == 2: + _B, _S = 1, logits.shape[0] + else: + _B, _S = 1, hidden_states.shape[-2] + + _seq_idx_tail = kwargs.get("seq_idx", None) + if not isinstance(_seq_idx_tail, torch.Tensor): + _cu = kwargs.get("cu_seqlens", None) + if isinstance(_cu, torch.Tensor): + _cu1d = _cu.squeeze(0) if (_cu.dim() == 2 and _cu.shape[0] == 1) else _cu + if _cu1d.dim() == 1: + _positions = torch.arange(_S, device=_cu1d.device) + # ``right=True`` so a position equal to a boundary + # (first token of sub-seq k, position == cu_seqlens[k]) + # maps to k, not k-1 — matches mtp.py and layers.py. + _seq_idx_1d = torch.searchsorted(_cu1d[1:].contiguous(), _positions, right=True).to( + torch.int32 + ) + _seq_idx_tail = _seq_idx_1d.unsqueeze(0).expand(_B, _S).contiguous() + if not isinstance(_seq_idx_tail, torch.Tensor): + _seq_idx_tail = torch.ones((_B, _S), dtype=torch.int32, device=logits.device) + else: + if _seq_idx_tail.dim() == 1: + _seq_idx_tail = _seq_idx_tail.unsqueeze(0) + if _seq_idx_tail.dtype != torch.int32: + _seq_idx_tail = _seq_idx_tail.to(torch.int32) + return (logits, *mtp_per_depth_h, _seq_idx_tail) + return logits + + # Non-PP: dataclass return for the recipe's MTP loss reader. return NemotronHCausalLMOutputWithPast( loss=loss, logits=logits, @@ -611,6 +965,10 @@ def initialize_weights( ) -> None: """Initialize model weights. + PP-aware: skips ``lm_head`` and ``mtp`` initialization when those have + been trimmed to ``None`` on a non-owning stage. ``self.model`` itself + also internally guards ``embed_tokens`` and ``norm``. + Args: buffer_device: Device to use for buffer initialization dtype: Target dtype for model weights @@ -618,7 +976,8 @@ def initialize_weights( buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}") with buffer_device: self.model.initialize_weights(buffer_device=buffer_device) - nn.init.normal_(self.lm_head.weight, mean=0.0, std=self.config.initializer_range) + if self.lm_head is not None: + nn.init.normal_(self.lm_head.weight, mean=0.0, std=self.config.initializer_range) if self.mtp is not None: for sublayer in self.mtp.layers: sublayer.init_weights(buffer_device=buffer_device) diff --git a/nemo_automodel/components/models/nemotron_v3/mtp.py b/nemo_automodel/components/models/nemotron_v3/mtp.py index ec4c50ae88..e3d0c4c105 100644 --- a/nemo_automodel/components/models/nemotron_v3/mtp.py +++ b/nemo_automodel/components/models/nemotron_v3/mtp.py @@ -167,12 +167,49 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: self.final_layernorm.reset_parameters() +_VALID_BLOCK_TYPES = frozenset(_PATTERN_SYMBOL_TO_BLOCK_TYPE.values()) + + +def _resolve_block_types_per_sublayer(config) -> list[str] | None: + """Resolve the per-depth MTP block-type list from either HF field. + + Super-V3 ships ``mtp_hybrid_override_pattern`` (symbol-string form like + ``"*E"``). Newer NemotronH variants ship ``mtp_layers_block_type`` + (list-of-strings form like ``["attention", "moe"]``). Either is + accepted. + + Args: + config: HF NemotronH config. + + Returns: + Parsed list of block-type names, or ``None`` when neither field is set. + + Raises: + ValueError: If ``mtp_layers_block_type`` contains an unknown block type. + """ + pattern = getattr(config, "mtp_hybrid_override_pattern", None) + if pattern: + return parse_mtp_layer_pattern(pattern) + block_types = getattr(config, "mtp_layers_block_type", None) + if block_types: + block_types = list(block_types) + for bt in block_types: + if bt not in _VALID_BLOCK_TYPES: + raise ValueError( + f"Unknown MTP block type {bt!r} in mtp_layers_block_type; " + f"valid types are {sorted(_VALID_BLOCK_TYPES)}" + ) + return block_types + return None + + def build_nemotron_v3_mtp( config, mtp_config: MTPConfig, backend: BackendConfig, moe_config, dtype: torch.dtype, + block_types: list[str] | None = None, ) -> MTPModule: """Construct the NemotronV3 MTP block. @@ -183,13 +220,26 @@ def build_nemotron_v3_mtp( moe_config: MoE configuration shared with the main backbone (required when the MTP pattern contains MoE sublayers). dtype: Target dtype for newly created linear modules. + block_types: Optional pre-parsed list of block-type names (one per + inner sublayer). When supplied, bypasses + :func:`parse_mtp_layer_pattern` on ``mtp_config.layer_pattern``. + Required when ``mtp_config.layer_pattern`` is a length-only + sentinel (e.g. produced from ``mtp_layers_block_type``). Returns: A configured :class:`MTPModule`. Caller should not invoke this when ``mtp_config.enabled`` is ``False``. """ base_layer_idx = config.num_hidden_layers - block_types_per_sublayer = parse_mtp_layer_pattern(mtp_config.layer_pattern) + if block_types is None: + block_types_per_sublayer = parse_mtp_layer_pattern(mtp_config.layer_pattern) + else: + block_types_per_sublayer = list(block_types) + if len(block_types_per_sublayer) != mtp_config.pattern_length: + raise ValueError( + f"block_types length {len(block_types_per_sublayer)} does not match " + f"mtp_config.pattern_length {mtp_config.pattern_length}" + ) def factory(*, global_idx, depth, sublayer_idx, block_type, has_fusion, has_final_norm): return NemotronV3MTPSublayer( @@ -219,10 +269,15 @@ def build_mtp_config_from_hf( ) -> MTPConfig: """Construct an :class:`MTPConfig` from an HF NemotronH config. - Reads ``num_nextn_predict_layers`` and ``mtp_hybrid_override_pattern`` - directly off the HF config object (both present on the released Super V3 - ``config.json``). Returns a disabled config (``num_layers=0``) when MTP - is not configured. + Reads ``num_nextn_predict_layers`` and resolves the per-depth pattern from + either ``mtp_hybrid_override_pattern`` (Super-V3 symbol-string form) or + ``mtp_layers_block_type`` (list-of-strings form). Returns a disabled + config (``num_layers=0``) when MTP is not configured. + + When the pattern source is the list form, :attr:`MTPConfig.layer_pattern` + is set to a length-matching sentinel string of ``"X"`` characters — the + actual block-type names are carried separately into + :func:`build_nemotron_v3_mtp` via its ``block_types`` kwarg. Args: config: HF NemotronH config. @@ -245,7 +300,14 @@ def build_mtp_config_from_hf( num_layers = int(getattr(config, "num_nextn_predict_layers", 0) or 0) else: num_layers = int(num_nextn_predict_layers) - pattern = getattr(config, "mtp_hybrid_override_pattern", "") or "" + + pattern = getattr(config, "mtp_hybrid_override_pattern", None) or "" + if not pattern: + block_types = getattr(config, "mtp_layers_block_type", None) + if block_types: + # Length-only sentinel; real block-type list flows through + # build_nemotron_v3_mtp's block_types kwarg. + pattern = "X" * len(list(block_types)) return MTPConfig( num_layers=num_layers, layer_pattern=pattern, diff --git a/nemo_automodel/components/models/nemotron_v3/state_dict_adapter.py b/nemo_automodel/components/models/nemotron_v3/state_dict_adapter.py index 323c872586..66b14ec802 100644 --- a/nemo_automodel/components/models/nemotron_v3/state_dict_adapter.py +++ b/nemo_automodel/components/models/nemotron_v3/state_dict_adapter.py @@ -130,6 +130,18 @@ def from_hf( Returns: Internal format state dict """ + # Drop checkpoint keys for backbone layers past ``num_hidden_layers`` + # (e.g. when loading the first N layers of a larger checkpoint for a + # downsized smoke run). The matcher tolerates both ``backbone.layers.{i}`` + # and ``model.layers.{i}`` since the prefix is normalized after this. + num_layers = int(getattr(self.config, "num_hidden_layers", 0) or 0) + if num_layers > 0: + layer_idx_pattern = re.compile(r"^(?:backbone|model)\.layers\.(\d+)\.") + for key in list(hf_state_dict.keys()): + m = layer_idx_pattern.match(key) + if m is not None and int(m.group(1)) >= num_layers: + hf_state_dict.pop(key) + # Separate MTP keys; they live in their own top-level namespace and # are not subject to the backbone/model rename. mtp_state_dict: dict[str, Any] = {} diff --git a/nemo_automodel/components/moe/layers.py b/nemo_automodel/components/moe/layers.py index d68315c2ca..3e8b8f571a 100644 --- a/nemo_automodel/components/moe/layers.py +++ b/nemo_automodel/components/moe/layers.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings -from contextlib import nullcontext from functools import partial from typing import Optional @@ -41,6 +40,77 @@ _shared_experts_stream: Optional[torch.cuda.Stream] = None +def _record_stream_safe(t: torch.Tensor, stream: Optional[torch.cuda.Stream]) -> None: + """Tell the caching allocator that ``stream`` uses ``t``'s storage. + + Required whenever a tensor allocated on one CUDA stream is read/written on + another stream: without it the allocator may recycle the block while + ``stream`` is still using it (cross-stream use-after-free -> corrupted + values). No-op for non-CUDA tensors; handles DTensor by recording on the + local shard. + """ + if stream is None or not isinstance(t, torch.Tensor): + return + local = t.to_local() if hasattr(t, "to_local") else t + if local.is_cuda: + local.record_stream(stream) + + +class _SharedExpertStreamFork(torch.autograd.Function): + """Fork the shared-expert input ``x`` into the side stream. + + Data-identity op; it only enforces cross-stream ordering + allocator safety + so the side-stream overlap is correct in BOTH passes. A raw ``wait_stream`` + is forward-only and has no backward mirror, which (together with missing + ``record_stream``) is the cross-stream race that corrupts gradients. + + * forward (main stream): the side stream waits for main so it observes ``x``. + * backward (main stream): ``grad_x`` was produced by the shared-expert + subgraph on the side stream, so main waits for side before that gradient + is accumulated into ``x``'s ``.grad`` alongside the main-stream branches. + """ + + @staticmethod + def forward(ctx, x, side_stream): + ctx.main_stream = torch.cuda.current_stream() + ctx.side_stream = side_stream + side_stream.wait_stream(ctx.main_stream) + _record_stream_safe(x, side_stream) + return x + + @staticmethod + def backward(ctx, grad_x): + ctx.main_stream.wait_stream(ctx.side_stream) + _record_stream_safe(grad_x, ctx.main_stream) + return grad_x, None + + +class _SharedExpertStreamJoin(torch.autograd.Function): + """Join the shared-expert output ``z`` back to the main stream. + + * forward (main stream): main waits for the side stream so ``z`` is ready + before it is added to the routed output. Applied AFTER ``experts()`` is + launched so the shared-expert compute overlaps the dispatch comm. + * backward (main stream): ``grad_z`` is produced on the main stream by the + ``y + z`` add; the side stream waits for main before the shared-expert + subgraph (replayed on the side stream) consumes it. + """ + + @staticmethod + def forward(ctx, z, side_stream): + ctx.main_stream = torch.cuda.current_stream() + ctx.side_stream = side_stream + ctx.main_stream.wait_stream(side_stream) + _record_stream_safe(z, ctx.main_stream) + return z + + @staticmethod + def backward(ctx, grad_z): + ctx.side_stream.wait_stream(ctx.main_stream) + _record_stream_safe(grad_z, ctx.side_stream) + return grad_z, None + + class MLP(nn.Module): """ Multi-Layer Perceptron (MLP) used as a feed-forward layer. @@ -689,8 +759,6 @@ def __init__(self, config: MoEConfig, backend: BackendConfig): # Set during model parallelization (see parallelizer.apply_cp) self.cp_mesh: Optional[DeviceMesh] = None - self._disable_shared_expert_overlap = backend.disable_shared_expert_overlap - def forward( self, x: torch.Tensor, @@ -727,27 +795,37 @@ def forward( weights, indices, aux_loss = self.gate(x, token_mask, cp_mesh) - # Shared-expert output (optionally gated). Run on a side CUDA stream - # to overlap with the grouped-expert dispatch comm, unless overlap is - # disabled (in which case run sequentially on the current stream). + # Shared-expert output (optionally gated). Run on a side CUDA stream to + # overlap with the grouped-expert dispatch comm. The fork/join autograd + # fences (_SharedExpertStreamFork / _SharedExpertStreamJoin) make the + # overlap correct in BOTH forward and backward: a raw wait_stream is + # forward-only, and its missing backward mirror plus the missing + # record_stream is a cross-stream race that corrupts gradients (grad-norm + # explosion at high tokens/microbatch on some hardware). shared_experts + # stays in the normal autograd graph, so FSDP2's post-backward + # reduce-scatter hooks and gradient accumulation are unaffected; expert + # parallelism only touches the routed ``experts`` path on the main stream. z = None side_stream = None if self.shared_experts is not None: - if not self._disable_shared_expert_overlap: - global _shared_experts_stream - if _shared_experts_stream is None: - _shared_experts_stream = torch.cuda.Stream() - side_stream = _shared_experts_stream - side_stream.wait_stream(torch.cuda.current_stream()) - stream_ctx = torch.cuda.stream(side_stream) if side_stream is not None else nullcontext() - with stream_ctx: - z = self.shared_experts(x) + global _shared_experts_stream + if _shared_experts_stream is None: + _shared_experts_stream = torch.cuda.Stream() + side_stream = _shared_experts_stream + + # Fork into the side stream (fences main->side fwd, side->main bwd). + x_se = _SharedExpertStreamFork.apply(x, side_stream) + with torch.cuda.stream(side_stream): + z = self.shared_experts(x_se) if self.shared_expert_gate is not None: - z = torch.nn.functional.sigmoid(self.shared_expert_gate(x)) * z + z = torch.nn.functional.sigmoid(self.shared_expert_gate(x_se)) * z + # Routed experts on the main stream — runs concurrently with the + # shared-expert side stream above; the join below waits for it. y = self.experts(x_latent, token_mask, weights, indices) if side_stream is not None: - torch.cuda.current_stream().wait_stream(side_stream) + # Join back to the main stream (fences side->main fwd, main->side bwd). + z = _SharedExpertStreamJoin.apply(z, side_stream) if self.fc2_latent_proj is not None: y = self.fc2_latent_proj(y) diff --git a/nemo_automodel/components/moe/parallelizer.py b/nemo_automodel/components/moe/parallelizer.py index 5a8d4f4a48..53fa1b672b 100644 --- a/nemo_automodel/components/moe/parallelizer.py +++ b/nemo_automodel/components/moe/parallelizer.py @@ -360,7 +360,7 @@ def apply_cp(model: torch.nn.Module, cp_mesh: DeviceMesh, cp_comm_type: str = "p # "Padding mask not supported with context parallelism!". _model._cp_enabled = True - for _, block in _model.layers.named_children(): + for _parent, _layer_id, block in _iter_transformer_and_mtp_blocks(model): layer_type = getattr(block, "layer_type", getattr(block, "attention_type", "full_attention")) if layer_type in ("full_attention", "sliding_attention"): diff --git a/nemo_automodel/components/moe/state_dict_mixin.py b/nemo_automodel/components/moe/state_dict_mixin.py index 8d37213cad..e1b9d73d5c 100644 --- a/nemo_automodel/components/moe/state_dict_mixin.py +++ b/nemo_automodel/components/moe/state_dict_mixin.py @@ -53,6 +53,26 @@ def _is_gated_moe(self) -> bool: return is_gated_activation(self.moe_config.expert_activation) + def _register_inplace_loaded_key(self, fqn: str, prefix_override: str | None) -> None: + """Mark ``fqn`` as loaded via in-place views so ``_from_hf_w_merged_experts`` skips its rebuild. + + The tracked key must match the native_key that the from_hf merge loop + reconstructs from the HF per-expert keys. For backbone tensors the + native_key equals ``fqn``; for MTP tensors (``prefix_override="mtp."``) + the HF keys live under the ``mtp.`` namespace and from_hf processes + them with that prefix stripped, so the tracked key is also the + ``mtp.``-less form. The user of this set (``_from_hf_w_merged_experts``) + receives the matching stripped key when called via the adapter's + per-namespace dispatch. + """ + if prefix_override is not None and prefix_override.endswith("."): + tracked = fqn[len(prefix_override) :] if fqn.startswith(prefix_override) else fqn + else: + tracked = fqn + if not hasattr(self, "_inplace_loaded_native_keys") or self._inplace_loaded_native_keys is None: + self._inplace_loaded_native_keys = set() + self._inplace_loaded_native_keys.add(tracked) + @property def _hf_prefix(self) -> str: """Prefix for HuggingFace format keys. Override in subclass.""" @@ -340,14 +360,18 @@ def _recombine_lora_expert_keys(self, state_dict: dict[str, Any]) -> dict[str, A return result - def _to_hf_w_split_experts(self, state_dict: dict[str, Any]) -> dict[str, Any]: + def _to_hf_w_split_experts(self, state_dict: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """Convert DeepEP format to HuggingFace format. - Handles: gate_and_up_projs, down_projs -> individual expert weights + + Handles ``gate_and_up_projs`` / ``down_projs`` -> individual expert + weights. Forwards ``**kwargs`` to + ``_convert_single_merged_expert_to_hf_split_experts`` for adapter + compatibility (e.g. ``exclude_key_regex``). """ hf_state_dict: dict[str, Any] = {} for fqn, tensor in state_dict.items(): - converted = self._convert_single_merged_expert_to_hf_split_experts(fqn, tensor) + converted = self._convert_single_merged_expert_to_hf_split_experts(fqn, tensor, **kwargs) if converted is not None: for key, value in converted: hf_state_dict[key] = value @@ -401,6 +425,9 @@ def _from_hf_w_merged_experts( rf"(?P(?:model\.)?(?:language_model\.)?)layers\.(\d+)\.{re.escape(expert_segment)}\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight" ) + inplace_loaded_keys: set = getattr(self, "_inplace_loaded_native_keys", None) or set() + consumed_inplace_keys: set = set() + for key in list(hf_state_dict.keys()): value = hf_state_dict.pop(key) if f".{expert_segment}." in key and key.endswith(".weight"): @@ -413,17 +440,23 @@ def _from_hf_w_merged_experts( layer_num, expert_num, which = m.group(2), m.group(3), m.group(4) expert_num = int(expert_num) + if which in ["gate_proj", "up_proj"]: + native_key = f"{prefix}layers.{layer_num}.{expert_segment}.gate_and_up_projs" + else: # down_proj + native_key = f"{prefix}layers.{layer_num}.{expert_segment}.down_projs" + + # Skip rebuild: DCP wrote through the view; model already holds the data. + if native_key in inplace_loaded_keys: + consumed_inplace_keys.add(native_key) + del value + continue + if not should_load_expert_for_rank(expert_num, device_mesh, n_experts): continue if layer_num not in expert_weights_by_layer: expert_weights_by_layer[layer_num] = {} - if which in ["gate_proj", "up_proj"]: - native_key = f"{prefix}layers.{layer_num}.{expert_segment}.gate_and_up_projs" - else: # down_proj - native_key = f"{prefix}layers.{layer_num}.{expert_segment}.down_projs" - if native_key not in expert_weights_by_layer[layer_num]: expert_weights_by_layer[layer_num][native_key] = {} @@ -473,11 +506,20 @@ def _from_hf_w_merged_experts( stacked = torch.stack(tensors, dim=0).to(self.dtype) state_dict[native_key] = create_dtensor_from_local(stacked, device_mesh, rank) - # Free completed expert tensors to release GPU memory + # Aggressively release intermediates so the per-layer + # transient does not pile on top of the model's + # already-materialized GPU DTensors. Without this, + # ``tensors``/``stacked`` and the per-expert dict + # entries hang around until Python's refcount GC + # eventually runs — too late under tight GPU budgets + # (e.g. a large MoE on 2 nodes / 8 GPUs). + del tensors, stacked del expert_weights_by_layer[layer_num][native_key] if not expert_weights_by_layer[layer_num]: del expert_weights_by_layer[layer_num] gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() else: # down_proj expert_weights_by_layer[layer_num][native_key][expert_num] = value @@ -499,19 +541,25 @@ def _from_hf_w_merged_experts( stacked = torch.stack(ordered, dim=0) stacked = stacked.to(self.dtype) - dtensor = create_dtensor_from_local(stacked, device_mesh, rank) - state_dict[native_key] = dtensor + state_dict[native_key] = create_dtensor_from_local(stacked, device_mesh, rank) - # Free completed expert tensors to release GPU memory + # See gate/up branch above for the cleanup rationale. + del ordered, stacked del expert_weights_by_layer[layer_num][native_key] if not expert_weights_by_layer[layer_num]: del expert_weights_by_layer[layer_num] gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() else: if not key.endswith("_scale_inv"): state_dict[key] = value + # Drop consumed entries so a subsequent from_hf (e.g. MTP merge after backbone) starts clean. + if consumed_inplace_keys: + self._inplace_loaded_native_keys -= consumed_inplace_keys + # Recombine any per-expert HF LoRA keys back to grouped format state_dict = self._recombine_lora_expert_keys(state_dict) @@ -527,6 +575,21 @@ def _convert_single_merged_expert_to_hf_split_experts( ) -> list[tuple[str, torch.Tensor]]: """Convert a single merged expert tensor from native format to split HuggingFace format. + When ``tensor`` is a model DTensor with a plain (non-DTensor) local + split — i.e. ``ep_shard == 1`` — the per-expert outputs are returned + as **non-contiguous strided views** into the local storage of the + model's grouped DTensor instead of newly-allocated contiguous copies. + DCP's ``target.copy_(source)`` then writes safetensors data directly + through the views into the model's storage, and + ``_from_hf_w_merged_experts`` skips the rebuild for the corresponding + native key (tracked in ``_inplace_loaded_native_keys``). For loads of + large MoE checkpoints this avoids tens of GB of per-expert + scratch on top of the already-materialized model. + + Save callers must materialize the views before serializing — + ``safetensors.torch.save`` rejects non-contiguous tensors. See + ``_materialize_to_hf_views_for_save`` in ``checkpointing.py``. + Args: fqn: Fully qualified name of the tensor in native format. tensor: The tensor to convert. @@ -537,38 +600,55 @@ def _convert_single_merged_expert_to_hf_split_experts( that forward arbitrary state-dict kwargs (e.g. ``exclude_key_regex``). Returns: - List of (fqn, tensor) tuples in HuggingFace format, or None if not an expert tensor + List of (fqn, tensor) tuples in HuggingFace format, or None if not an expert tensor. """ n_experts = self.moe_config.n_routed_experts inter_dim = self.moe_config.moe_inter_dim prefix = prefix_override if prefix_override is not None else self._hf_prefix expert_segment = self._expert_path_segment + from nemo_automodel.components.moe.state_dict_utils import ( + is_dtensor, + validate_dtensor_expert_sharding, + ) + if f".{expert_segment}.gate_and_up_projs" in fqn and fqn.endswith(".gate_and_up_projs"): layer_num = re.search(r"layers\.(\d+)", fqn).group(1) - from nemo_automodel.components.moe.state_dict_utils import ( - is_dtensor, - validate_dtensor_expert_sharding, - ) - if is_dtensor(tensor): validate_dtensor_expert_sharding(tensor, n_experts, f"gate_and_up_projs layer {layer_num}") splits = self._split_experts_weights(tensor, n_experts) + + # In-place views only engage when splits are plain (ep_shard==1). + inplace_ok = is_dtensor(tensor) and len(splits) > 0 and not is_dtensor(splits[0]) + if inplace_ok: + self._register_inplace_loaded_key(fqn, prefix_override) + result = [] for i, w in enumerate(splits): expert_id = self._last_expert_ids[i] if self._is_gated_moe: # Gated: split into gate_proj and up_proj - w_gate = w[:, :inter_dim].transpose(0, 1).contiguous() - w_up = w[:, inter_dim:].transpose(0, 1).contiguous() + if inplace_ok: + w_gate = w[:, :inter_dim].transpose(0, 1) + w_up = w[:, inter_dim:].transpose(0, 1) + else: + w_gate = w[:, :inter_dim].transpose(0, 1).contiguous() + w_up = w[:, inter_dim:].transpose(0, 1).contiguous() result.append((f"{prefix}layers.{layer_num}.{expert_segment}.{expert_id}.gate_proj.weight", w_gate)) result.append((f"{prefix}layers.{layer_num}.{expert_segment}.{expert_id}.up_proj.weight", w_up)) else: # Non-gated: only up_proj (tensor is [dim, inter_dim], not [dim, 2*inter_dim]) - w_up = w.transpose(0, 1).contiguous() + if inplace_ok: + w_up = w.transpose(0, 1) + else: + w_up = w.transpose(0, 1).contiguous() result.append((f"{prefix}layers.{layer_num}.{expert_segment}.{expert_id}.up_proj.weight", w_up)) + del splits + if not inplace_ok and isinstance(tensor, torch.Tensor) and not tensor.is_meta and torch.cuda.is_available(): + gc.collect() + torch.cuda.empty_cache() return result elif ( @@ -579,24 +659,32 @@ def _convert_single_merged_expert_to_hf_split_experts( ): layer_num = re.search(r"layers\.(\d+)", fqn).group(1) - from nemo_automodel.components.moe.state_dict_utils import ( - is_dtensor, - validate_dtensor_expert_sharding, - ) - if is_dtensor(tensor): validate_dtensor_expert_sharding(tensor, n_experts, f"down_projs (DeepEP) layer {layer_num}") splits = self._split_experts_weights(tensor, n_experts) + inplace_ok = is_dtensor(tensor) and len(splits) > 0 and not is_dtensor(splits[0]) + if inplace_ok: + self._register_inplace_loaded_key(fqn, prefix_override) + result = [] for i, w in enumerate(splits): expert_id = self._last_expert_ids[i] + if inplace_ok: + w_down = w.transpose(0, 1) + else: + w_down = w.transpose(0, 1).contiguous() result.append( ( f"{prefix}layers.{layer_num}.{expert_segment}.{expert_id}.down_proj.weight", - w.transpose(0, 1).contiguous(), + w_down, ) ) + # See gate_and_up branch above for the cleanup rationale. + del splits + if not inplace_ok and isinstance(tensor, torch.Tensor) and not tensor.is_meta and torch.cuda.is_available(): + gc.collect() + torch.cuda.empty_cache() return result # MoE expert LoRA keys: split grouped 3-D adapter tensors into per-expert diff --git a/nemo_automodel/recipes/llm/train_ft.py b/nemo_automodel/recipes/llm/train_ft.py index 61bd194714..d9714f01f5 100644 --- a/nemo_automodel/recipes/llm/train_ft.py +++ b/nemo_automodel/recipes/llm/train_ft.py @@ -154,13 +154,14 @@ def _uses_te_dot_product_attention(model_or_cfg): def _uses_thd_collater(cfg_dataloader): + """Return True if the dataloader's collate_fn is ``packed_sequence_thd_collater``. + + ``collate_fn`` ends in ``_fn``, so ConfigNode resolves the YAML dotted-path string to + the actual callable at load time — the value here is always the function, never a string. + """ from nemo_automodel.components.datasets.utils import packed_sequence_thd_collater - return ( - True - if hasattr(cfg_dataloader, "collate_fn") and cfg_dataloader.collate_fn == packed_sequence_thd_collater - else False - ) + return getattr(cfg_dataloader, "collate_fn", None) is packed_sequence_thd_collater def _should_precompute_pp_causal_masks(model_config: Any) -> bool: @@ -1373,15 +1374,19 @@ def _forward_backward_step( ) for k, v in batch.items() } + _thd_collater = _uses_thd_collater(self.cfg.dataloader) + # Gate THD/cu_seqlens processing on the dataset being THD-packed, not on TE + # attention being present on this rank: both TE attention and mamba need + # cu_seqlens, and gating on attention would drop PP stages with no attention + # layers (mamba+moe only) and leave cu_seqlens unbuilt downstream. + _use_te_value = _thd_collater + _num_chunks_value = _get_num_thd_chunks(self.pp_enabled, self.cfg) train_ctx, batch = make_cp_batch_and_ctx( self.device_mesh, batch, - use_te=_uses_te_dot_product_attention( - self.model_parts[0] if hasattr(self, "model_parts") else self.cfg.model - ) - and _uses_thd_collater(self.cfg.dataloader), + use_te=_use_te_value, padding_token_id=self.tokenizer.pad_token_id if self.tokenizer else 0, - num_chunks=_get_num_thd_chunks(self.pp_enabled, self.cfg), + num_chunks=_num_chunks_value, ) labels = batch.pop("labels") fp8_ctx = self.te_fp8.maybe_te_autocast() if self.te_fp8 is not None else nullcontext() @@ -1405,7 +1410,16 @@ def _forward_backward_step( batch_filtered = { k: v for k, v in batch.items() if v is not None and not (isinstance(v, dict) and len(v) == 0) } - + # Hand the THD ``cu_seqlens`` to the PP loss to mask cross-sequence boundaries — + # the fallback when the model emits no per-microbatch seq_idx tail (which the loss + # prefers). One cu_seqlens encodes a single shared layout, so it is only correct at + # one pack/microbatch per step; the seq_idx tail handles differing per-microbatch boundaries. + cu_seqlens = batch_filtered.get("cu_seqlens") + if isinstance(cu_seqlens, torch.Tensor) and cu_seqlens.dim() == 2: + cu_seqlens = cu_seqlens.squeeze(0) # [1, T] -> [T] + pp_loss_fn = getattr(self.pp.info.schedule, "_loss_fn", None) if self.pp.info.has_last_stage else None + if pp_loss_fn is not None and hasattr(pp_loss_fn, "cu_seqlens"): + pp_loss_fn.cu_seqlens = cu_seqlens if is_train: # Use step for training (forward + backward) if self.pp.info.has_first_stage: @@ -1467,6 +1481,8 @@ def _forward_backward_step( model=model, scaling_factor=out.mtp_loss_scaling_factor, num_label_tokens=num_label_tokens, + # mask cross-boundary MTP label rolls in THD packing (matches the PP path) + cu_seqlens=batch.get("cu_seqlens"), ) loss_buffer.append(local_loss.clone().detach()) if is_train: diff --git a/tests/functional_tests/context_parallel/run_attention_cp.py b/tests/functional_tests/context_parallel/run_attention_cp.py index 7f923c76ea..8350a8fd31 100644 --- a/tests/functional_tests/context_parallel/run_attention_cp.py +++ b/tests/functional_tests/context_parallel/run_attention_cp.py @@ -51,6 +51,7 @@ # Shared helpers # --------------------------------------------------------------------------- + def dual_chunk_swap_unsplit(chunks_per_rank, cp_size, seq_dim=1): """Reconstruct full sequence from DualChunkSwap-ordered rank outputs.""" all_chunks = [None] * (2 * cp_size) @@ -157,6 +158,7 @@ def _compare_results( # Packed-sequence batch creation (used by qwen3_moe / deepseek_v3 thd_te) # --------------------------------------------------------------------------- + def create_packed_sequence_batch(batch_size, seq_lens_per_batch, device, padding_token_id=0): """ Create a packed sequence batch for testing. @@ -214,6 +216,7 @@ def create_packed_sequence_batch(batch_size, seq_lens_per_batch, device, padding # Model factory # --------------------------------------------------------------------------- + def get_model_config_and_attention(model_type, device): """Get model configuration and attention layer based on model type. @@ -323,7 +326,9 @@ def get_freqs_cis(position_ids, qkv_format, cp_size=1): attn_with_cp = MLA(config, backend).to(device).to(torch.bfloat16) def get_freqs_cis(position_ids, qkv_format, cp_size=1): - return freqs_cis_from_position_ids(position_ids, rope_freqs, qkv_format=qkv_format, for_fused_rope=True, cp_size=cp_size) + return freqs_cis_from_position_ids( + position_ids, rope_freqs, qkv_format=qkv_format, for_fused_rope=True, cp_size=cp_size + ) elif model_type == "nemotron_v3": @@ -349,6 +354,7 @@ def __init__(self): # NemotronV3 attention pair creation # --------------------------------------------------------------------------- + def _create_nemotron_v3_attn_pair(config, backend, device): """Create a pair of identical NemotronV3Attention modules with synced weights.""" from nemo_automodel.components.models.nemotron_v3.layers import NemotronV3Attention @@ -370,8 +376,8 @@ def _create_nemotron_v3_attn_pair(config, backend, device): # Config: bshd_te # --------------------------------------------------------------------------- -def run_bshd_te(model_type, config, rank, world_size, device, - attn_no_cp=None, attn_with_cp=None, get_freqs_cis=None): + +def run_bshd_te(model_type, config, rank, world_size, device, attn_no_cp=None, attn_with_cp=None, get_freqs_cis=None): """3D BSHD input with TE p2p CP and DualChunkSwap. Only supported for nemotron_v3 (qwen3_moe / deepseek_v3 do not use this config). @@ -465,8 +471,8 @@ def run_bshd_te(model_type, config, rank, world_size, device, # Config: thd_te # --------------------------------------------------------------------------- -def run_thd_te(model_type, config, rank, world_size, device, - attn_no_cp=None, attn_with_cp=None, get_freqs_cis=None): + +def run_thd_te(model_type, config, rank, world_size, device, attn_no_cp=None, attn_with_cp=None, get_freqs_cis=None): """THD input with TE p2p CP. For qwen3_moe / deepseek_v3: uses make_cp_batch_for_te + apply_cp flow. @@ -475,12 +481,12 @@ def run_thd_te(model_type, config, rank, world_size, device, if model_type == "nemotron_v3": return _run_thd_te_nemotron_v3(config, rank, world_size, device) else: - return _run_thd_te_qwen_deepseek(model_type, config, rank, world_size, device, - attn_no_cp, attn_with_cp, get_freqs_cis) + return _run_thd_te_qwen_deepseek( + model_type, config, rank, world_size, device, attn_no_cp, attn_with_cp, get_freqs_cis + ) -def _run_thd_te_qwen_deepseek(model_type, config, rank, world_size, device, - attn_no_cp, attn_with_cp, get_freqs_cis): +def _run_thd_te_qwen_deepseek(model_type, config, rank, world_size, device, attn_no_cp, attn_with_cp, get_freqs_cis): """THD test flow for qwen3_moe / deepseek_v3 (preserves original run_test behavior).""" try: import transformer_engine.pytorch # This creates transformer_engine_torch module @@ -521,7 +527,9 @@ def _run_thd_te_qwen_deepseek(model_type, config, rank, world_size, device, ) total_tokens_no_cp = batch_no_cp["input_ids"].shape[0] - x_no_cp = torch.randn(total_tokens_no_cp, config.hidden_size, device=device, dtype=torch.bfloat16, requires_grad=True) + x_no_cp = torch.randn( + total_tokens_no_cp, config.hidden_size, device=device, dtype=torch.bfloat16, requires_grad=True + ) freqs_cis_no_cp = get_freqs_cis(batch_no_cp["position_ids"], qkv_format="thd") @@ -537,7 +545,11 @@ def _run_thd_te_qwen_deepseek(model_type, config, rank, world_size, device, output_no_cp = attn_no_cp( x_no_cp, freqs_cis=freqs_cis_no_cp, - cu_seqlens=batch_no_cp["cu_seqlens"], + # thd_utils now emits ``cu_seqlens`` as REAL lengths and a separate + # ``cu_seqlens_padded``; TE attention operates on the padded token + # layout (the input is padded to slot width), so use the padded + # boundaries here — matching the CP path's _shard output. + cu_seqlens=batch_no_cp.get("cu_seqlens_padded", batch_no_cp["cu_seqlens"]), max_seqlen=max_seqlen_no_cp, qkv_format=batch_no_cp.get("qkv_format", "thd"), ) @@ -806,8 +818,8 @@ def _run_thd_te_nemotron_v3(config, rank, world_size, device): # Config: bshd_sdpa # --------------------------------------------------------------------------- -def run_bshd_sdpa(model_type, config, rank, world_size, device, - attn_no_cp=None, attn_with_cp=None, get_freqs_cis=None): + +def run_bshd_sdpa(model_type, config, rank, world_size, device, attn_no_cp=None, attn_with_cp=None, get_freqs_cis=None): """3D BSHD input with DTensor context_parallel() and SDPA backend. Only supported for nemotron_v3 (qwen3_moe / deepseek_v3 do not use this config). @@ -946,9 +958,7 @@ def main(): torch.cuda.manual_seed_all(42) # Get model configuration and attention layers - config, attn_no_cp, attn_with_cp, get_freqs_cis = get_model_config_and_attention( - args.model_type, device - ) + config, attn_no_cp, attn_with_cp, get_freqs_cis = get_model_config_and_attention(args.model_type, device) # Run selected configs and collect results results = {} @@ -957,7 +967,11 @@ def main(): runner = CONFIG_RUNNERS[config_name] try: results[config_name] = runner( - args.model_type, config, rank, world_size, device, + args.model_type, + config, + rank, + world_size, + device, attn_no_cp=attn_no_cp, attn_with_cp=attn_with_cp, get_freqs_cis=get_freqs_cis, diff --git a/tests/unit_tests/distributed/test_seq_idx_from_cu_seqlens.py b/tests/unit_tests/distributed/test_seq_idx_from_cu_seqlens.py new file mode 100644 index 0000000000..a632d97c23 --- /dev/null +++ b/tests/unit_tests/distributed/test_seq_idx_from_cu_seqlens.py @@ -0,0 +1,216 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pins the cu_seqlens → seq_idx derivation against a brute-force reference. + +Two production sites derive ``seq_idx`` via searchsorted on ``cu_seqlens[1:]``: + + * ``nemo_automodel/components/loss/mtp.py:calculate_mtp_loss`` — cross-seq + mask in the MTP loss. + * ``nemo_automodel/components/models/nemotron_v3/layers.py`` — mamba SSD + scan state-reset. + +Both must use ``side="right"`` (i.e. ``right=True``) so a position equal to a +boundary (``t == cu_seqlens[k]``, the FIRST token of sub-seq k) maps to k, +not k-1. The default ``side="left"`` is off-by-one at every internal +boundary. +""" + +from __future__ import annotations + +import random + +import pytest +import torch + + +def _brute_force_seq_idx(cu_seqlens: list[int], total_len: int) -> list[int]: + """Ground-truth reference: position t belongs to sub-seq k iff + ``cu_seqlens[k] <= t < cu_seqlens[k+1]``. + """ + K = len(cu_seqlens) - 1 + out: list[int] = [] + for t in range(total_len): + for k in range(K): + if cu_seqlens[k] <= t < cu_seqlens[k + 1]: + out.append(k) + break + else: # pragma: no cover — only triggers on malformed cu_seqlens + raise AssertionError(f"position {t} fits no sub-seq with cu_seqlens={cu_seqlens}") + return out + + +def _searchsorted_right(cu_seqlens: list[int]) -> list[int]: + """The production derivation: searchsorted on cu_seqlens[1:] with right=True.""" + cu = torch.tensor(cu_seqlens, dtype=torch.int32) + pos = torch.arange(int(cu[-1].item())) + return torch.searchsorted(cu[1:].contiguous(), pos, right=True).tolist() + + +# ───── Deterministic edge cases ───────────────────────────────────────────── + + +def test_single_subseq(): + """K=1: every position must be sub-seq 0.""" + cu = [0, 7] + assert _searchsorted_right(cu) == _brute_force_seq_idx(cu, 7) + assert _searchsorted_right(cu) == [0] * 7 + + +def test_two_equal_subseqs(): + """K=2, equal widths. The boundary position is 4 (start of sub-seq 1).""" + cu = [0, 4, 8] + ref = _brute_force_seq_idx(cu, 8) + assert ref == [0, 0, 0, 0, 1, 1, 1, 1] + assert _searchsorted_right(cu) == ref + + +def test_uneven_widths(): + """K=3 with widths 3, 2, 4.""" + cu = [0, 3, 5, 9] + ref = _brute_force_seq_idx(cu, 9) + assert ref == [0, 0, 0, 1, 1, 2, 2, 2, 2] + assert _searchsorted_right(cu) == ref + + +def test_unit_width_subseqs(): + """K=5 with width-1 sub-seqs — every position is itself a boundary.""" + cu = [0, 1, 2, 3, 4, 5] + ref = _brute_force_seq_idx(cu, 5) + assert ref == [0, 1, 2, 3, 4] + assert _searchsorted_right(cu) == ref + + +def test_large_then_small_widths(): + """Mixed widths: large slot followed by tiny ones.""" + cu = [0, 100, 101, 102, 103] + ref = _brute_force_seq_idx(cu, 103) + assert ref[:100] == [0] * 100 + assert ref[100:103] == [1, 2, 3] + assert _searchsorted_right(cu) == ref + + +def test_left_side_default_is_off_by_one(): + """Sanity: confirm the pre-fix default (side='left') is off by one at + every internal boundary. This is the bug the right=True fix addresses. + """ + cu_seqlens = [0, 4, 8] + cu_t = torch.tensor(cu_seqlens) + pos = torch.arange(8) + left = torch.searchsorted(cu_t[1:], pos).tolist() + right = torch.searchsorted(cu_t[1:], pos, right=True).tolist() + ref = _brute_force_seq_idx(cu_seqlens, 8) + + assert right == ref + # The left variant disagrees with the reference at exactly the + # internal-boundary positions. + diffs = [i for i, (a, b) in enumerate(zip(left, ref)) if a != b] + assert diffs == [4], f"expected mismatch only at position 4, got {diffs}" + + +# ───── Randomized exhaustive verification ─────────────────────────────────── + + +@pytest.mark.parametrize("seed", list(range(20))) +def test_random_shapes_match_brute_force(seed: int): + """For a randomized cu_seqlens with mixed widths, every position's + derived seq_idx must match the brute-force reference exactly. + """ + rng = random.Random(seed) + num_subseqs = rng.randint(2, 25) + widths = [rng.randint(1, 50) for _ in range(num_subseqs)] + cu = [0] + for w in widths: + cu.append(cu[-1] + w) + + ref = _brute_force_seq_idx(cu, cu[-1]) + derived = _searchsorted_right(cu) + assert derived == ref, f"seed={seed} cu_seqlens={cu}: {sum(1 for a, b in zip(derived, ref) if a != b)} mismatches" + + +@pytest.mark.parametrize("seed", list(range(5))) +def test_random_shapes_left_side_fails_at_boundaries(seed: int): + """Negative control: the pre-fix default (left side) ALWAYS mis-classifies + every internal boundary position. Verifies the magnitude of the bug. + """ + rng = random.Random(100 + seed) + num_subseqs = rng.randint(3, 15) + widths = [rng.randint(2, 30) for _ in range(num_subseqs)] + cu = [0] + for w in widths: + cu.append(cu[-1] + w) + + cu_t = torch.tensor(cu) + pos = torch.arange(cu[-1]) + left = torch.searchsorted(cu_t[1:], pos).tolist() + ref = _brute_force_seq_idx(cu, cu[-1]) + + # Every internal boundary t = cu[k] (k = 1..K-1) is mis-classified by + # left-side: it returns k-1 instead of k. + expected_buggy_positions = set(cu[1:-1]) # exclude the last (== total_len) + actual_buggy_positions = {i for i, (a, b) in enumerate(zip(left, ref)) if a != b} + assert actual_buggy_positions == expected_buggy_positions + + +# ───── Cross-check production call sites ──────────────────────────────────── + + +def test_production_site_loss_mtp_matches_brute_force(): + """Re-runs the exact derivation in + ``nemo_automodel.components.loss.mtp:calculate_mtp_loss`` (lines ~75-87) + against the brute-force reference, ensuring the production site uses + ``right=True``. + """ + cu_seqlens = torch.tensor([0, 5, 11, 17], dtype=torch.int32) + positions = torch.arange(17) + derived = torch.searchsorted(cu_seqlens[1:].contiguous(), positions, right=True).tolist() + ref = _brute_force_seq_idx(cu_seqlens.tolist(), 17) + assert derived == ref + + +def test_production_site_layers_mamba_matches_brute_force(): + """Re-runs the exact derivation in + ``nemo_automodel.components.models.nemotron_v3.layers`` (line ~330) + against the brute-force reference, ensuring the production site uses + ``right=True``. + """ + cu_seqlens = torch.tensor([0, 7, 13, 20], dtype=torch.int32) + total_len = int(cu_seqlens[-1].item()) + positions = torch.arange(total_len) + derived = torch.searchsorted(cu_seqlens[1:], positions, right=True).unsqueeze(0).to(torch.int32) + ref = torch.tensor(_brute_force_seq_idx(cu_seqlens.tolist(), total_len), dtype=torch.int32).unsqueeze(0) + assert torch.equal(derived, ref) + + +def test_searchsorted_call_sites_use_right_true(): + """Static-analysis style check: both production call sites must use + ``right=True`` (or equivalent ``side="right"``) on ``searchsorted`` over + ``cu_seqlens[1:]``. Catches regressions that revert to the default + ``side="left"``. + """ + import inspect + + from nemo_automodel.components.loss import mtp as _mtp_mod + from nemo_automodel.components.models.nemotron_v3 import layers as _layers_mod + + for mod in (_mtp_mod, _layers_mod): + src = inspect.getsource(mod) + # Each searchsorted call on cu_seqlens-derived array must have right=True + # (or side="right") on the same line. Strip comments first. + for raw_line in src.splitlines(): + code = raw_line.split("#", 1)[0] + if "searchsorted(" in code and ("cu_seqlens" in code or "cs[1:]" in code): + assert "right=True" in code or 'side="right"' in code, ( + f"{mod.__name__}: searchsorted on cu_seqlens without right=True: {raw_line.strip()}" + ) diff --git a/tests/unit_tests/distributed/test_thd_utils.py b/tests/unit_tests/distributed/test_thd_utils.py index 3195f314e8..b789e819e7 100644 --- a/tests/unit_tests/distributed/test_thd_utils.py +++ b/tests/unit_tests/distributed/test_thd_utils.py @@ -62,11 +62,15 @@ def test_with_multiple_packed_sequences(self): assert result["input_ids"].shape == (12,) assert result["position_ids"].shape == (12,) - # Check cu_seqlens - uses seq_lens_padded values for CP compatibility - # First batch: padded lengths [4, 2] -> cumsum [0, 4, 6] - # Second batch: padded lengths [3, 3] -> cumsum [6, 9, 12] - expected_cu_seqlens = torch.tensor([0, 4, 6, 9, 12], dtype=torch.int32) + # cu_seqlens uses real (unpadded) seq_lens; cu_seqlens_padded uses + # padded values. Both are emitted because they differ in multiple + # entries (inter-sub-seq padding) — TE will use pad_between_seqs=True. + # First batch: real [3, 2] -> cumsum [0, 3, 5], padded [4, 2] -> [0, 4, 6] + # Second batch: real [2, 3] -> [5, 7, 10], padded [3, 3] -> [6, 9, 12] + expected_cu_seqlens = torch.tensor([0, 3, 5, 7, 10], dtype=torch.int32) + expected_cu_seqlens_padded = torch.tensor([0, 4, 6, 9, 12], dtype=torch.int32) assert torch.equal(result["cu_seqlens"], expected_cu_seqlens) + assert torch.equal(result["cu_seqlens_padded"], expected_cu_seqlens_padded) def test_with_variable_num_sequences_and_padding(self): """Test with variable number of sequences per example (seq_lens padding with -1000).""" @@ -84,11 +88,14 @@ def test_with_variable_num_sequences_and_padding(self): assert result["input_ids"].shape == (12,) assert result["position_ids"].shape == (12,) - # Check cu_seqlens - uses seq_lens_padded values (filters out -1000) - # First batch: padded lengths [4, 2] -> cumsum [0, 4, 6] - # Second batch: padded length [6] (second is -1000, filtered) -> cumsum [6, 12] - expected_cu_seqlens = torch.tensor([0, 4, 6, 12], dtype=torch.int32) + # cu_seqlens uses real seq_lens (filtered for -1000); cu_seqlens_padded + # uses padded. Both are emitted because they differ in multiple entries. + # First batch: real [3, 2] -> [0, 3, 5], padded [4, 2] -> [0, 4, 6] + # Second batch: real [6] -> [5, 11], padded [6] -> [6, 12] + expected_cu_seqlens = torch.tensor([0, 3, 5, 11], dtype=torch.int32) + expected_cu_seqlens_padded = torch.tensor([0, 4, 6, 12], dtype=torch.int32) assert torch.equal(result["cu_seqlens"], expected_cu_seqlens) + assert torch.equal(result["cu_seqlens_padded"], expected_cu_seqlens_padded) def test_with_qkv_format_preservation(self): """Test that non-tensor keys like qkv_format are preserved.""" @@ -278,15 +285,18 @@ def test_chunking_with_packed_sequences(self): assert result["input_ids"].shape == (2, 12) assert result["labels"].shape == (2, 12) - # Check cu_seqlens for first chunk - uses seq_lens_padded values [4, 2] and [3, 3] - # First batch: [4, 2] -> cumsum [0, 4, 6] - # Second batch: [3, 3] -> cumsum [6, 9, 12] - expected_cu_seqlens_0 = torch.tensor([0, 4, 6, 9, 12], dtype=torch.int32) - assert torch.equal(result["cu_seqlens"][0], expected_cu_seqlens_0) - - # Check cu_seqlens for second chunk - uses seq_lens_padded values [4, 2] and [3, 3] - expected_cu_seqlens_1 = torch.tensor([0, 4, 6, 9, 12], dtype=torch.int32) - assert torch.equal(result["cu_seqlens"][1], expected_cu_seqlens_1) + # cu_seqlens uses real (unpadded) seq_lens; cu_seqlens_padded uses + # padded values. Both arrays present because they differ in multiple + # entries (inter-sub-seq padding). For each chunk: + # Real: [3, 2] -> [0, 3, 5]; [2, 3] -> [5, 7, 10] + # Padded: [4, 2] -> [0, 4, 6]; [3, 3] -> [6, 9, 12] + expected_cu_seqlens_chunk = torch.tensor([0, 3, 5, 7, 10], dtype=torch.int32) + expected_cu_padded_chunk = torch.tensor([0, 4, 6, 9, 12], dtype=torch.int32) + assert torch.equal(result["cu_seqlens"][0], expected_cu_seqlens_chunk) + assert torch.equal(result["cu_seqlens"][1], expected_cu_seqlens_chunk) + assert "cu_seqlens_padded" in result + assert torch.equal(result["cu_seqlens_padded"][0], expected_cu_padded_chunk) + assert torch.equal(result["cu_seqlens_padded"][1], expected_cu_padded_chunk) def test_chunking_with_embeddings(self): """Test chunking with 3D embeddings input.""" @@ -337,17 +347,20 @@ def test_variable_length_cu_seqlens_padding(self): # cu_seqlens should be [num_chunks, max_seqs_across_chunks+1] assert result["cu_seqlens"].shape[0] == 2 - # First chunk - uses seq_lens_padded values [4, 2] and [6] - # First batch: [4, 2] -> cumsum [0, 4, 6] - # Second batch: [6] -> cumsum [6, 12] - expected_cu_seqlens_0 = torch.tensor([0, 4, 6, 12], dtype=torch.int32) - assert torch.equal(result["cu_seqlens"][0], expected_cu_seqlens_0) - - # Second chunk - uses seq_lens_padded values [4] and [3, 3] - # Third batch: [4] -> cumsum [0, 4] - # Fourth batch: [3, 3] -> cumsum [4, 7, 10] - expected_cu_seqlens_1 = torch.tensor([0, 4, 7, 10], dtype=torch.int32) - assert torch.equal(result["cu_seqlens"][1], expected_cu_seqlens_1) + # cu_seqlens uses REAL (unpadded) seq_lens; cu_seqlens_padded uses + # padded values. Both arrays are stacked because at least one chunk + # emits cu_seqlens_padded (multiple per-chunk entries differ). + # Chunk 0: real [3, 2, 6] -> [0, 3, 5, 11]; padded [4, 2, 6] -> [0, 4, 6, 12] + # Chunk 1: real [4, 2, 3] -> [0, 4, 6, 9]; padded [4, 3, 3] -> [0, 4, 7, 10] + expected_cu_0 = torch.tensor([0, 3, 5, 11], dtype=torch.int32) + expected_cu_1 = torch.tensor([0, 4, 6, 9], dtype=torch.int32) + expected_cu_padded_0 = torch.tensor([0, 4, 6, 12], dtype=torch.int32) + expected_cu_padded_1 = torch.tensor([0, 4, 7, 10], dtype=torch.int32) + assert torch.equal(result["cu_seqlens"][0], expected_cu_0) + assert torch.equal(result["cu_seqlens"][1], expected_cu_1) + assert "cu_seqlens_padded" in result + assert torch.equal(result["cu_seqlens_padded"][0], expected_cu_padded_0) + assert torch.equal(result["cu_seqlens_padded"][1], expected_cu_padded_1) def test_single_chunk(self): """Test with num_chunks=1 (no actual chunking).""" @@ -524,3 +537,140 @@ def test_different_chunk_sizes(self, num_chunks): assert result["input_ids"].shape == (num_chunks, tokens_per_chunk) assert result["cu_seqlens"].shape[0] == num_chunks + + +class TestTrailingPadAbsorption: + """Tests for the trailing-pack-pad absorption in process_input_for_thd. + + The original captured bug: a "short" microbatch (5 sub-seqs of 112 in a + 1024-pack) had its trailing 464-token pad absorbed into the last + cu_seqlens slot (576 wide), while the collator told TE ``max_seqlen=112`` + — a documented TE-contract violation (``max_seqlen_q`` MUST be >= the + actual max slot width per ``fused_attn.h:548-551``). cuDNN-fused-attn-bwd + then wrote OOB. + + The fix: compute ``max_seqlen`` from the FINAL cu_seqlens (after + absorption), so the value handed to TE always reflects the true max slot + width. With this in place, absorption is contract-clean for any trailing + pad size, and the previous dummy-slot extension workaround is no longer + needed. Verified safe via + ``/opt/Automodel/te_bug_report/te_thd_repro_MINIMAL.py``: TE handles + a 576-wide slot cleanly when given truthful ``max_seqlen=576``. + """ + + def test_short_microbatch_absorbs_with_truthful_max_seqlen(self): + """Captured failing case (5×112 + 464 trailing pad): the absorbed + cu_seqlens last slot is 576 wide. ``max_seqlen`` must reflect that + post-absorption width, not the pre-absorption 112. + """ + packed = 1024 + sub = 112 + seq_lens = torch.tensor([[sub, sub, sub, sub, sub, -1000, -1000, -1000, -1000]]) + seq_lens_padded = torch.tensor([[sub, sub, sub, sub, sub + (packed - 5 * sub), + -1000, -1000, -1000, -1000]]) + batch = { + "input_ids": torch.zeros((1, packed), dtype=torch.long), + "labels": torch.zeros((1, packed), dtype=torch.long), + "position_ids": torch.arange(packed).unsqueeze(0), + "seq_lens": seq_lens, + "seq_lens_padded": seq_lens_padded, + } + result = process_input_for_thd(batch) + + # Absorption fires → cu_seqlens = [0,112,224,336,448,1024]; + # last slot is 576 (real 112 + trailing pad 464). + expected_cu = torch.tensor([0, 112, 224, 336, 448, 1024], dtype=torch.int32) + assert torch.equal(result["cu_seqlens"], expected_cu), ( + f"cu_seqlens should be absorbed: expected {expected_cu.tolist()}, " + f"got {result['cu_seqlens'].tolist()}" + ) + # cu_seqlens_padded is dropped (equal to cu_seqlens, gated out). + assert "cu_seqlens_padded" not in result + # CRITICAL: max_seqlen reflects the absorbed slot width (576), not + # the pre-absorption max real sub-seq length (112). This is what + # makes the layout TE-contract-clean. + assert int(result["max_seqlen"].item()) == 576, ( + f"max_seqlen should reflect post-absorption slot width 576; " + f"got {int(result['max_seqlen'].item())}" + ) + + def test_full_microbatch_absorbs_with_bumped_max_seqlen(self): + """Common-case (9×112 + 16 trailing pad): absorption fires and + ``max_seqlen`` reflects the absorbed last slot (128), not the + pre-absorption max (112). The full-pack perf path is preserved. + """ + packed = 1024 + sub = 112 + seq_lens = torch.tensor([[sub] * 9]) + seq_lens_padded = torch.tensor([[sub] * 8 + [sub + 16]]) + batch = { + "input_ids": torch.zeros((1, packed), dtype=torch.long), + "labels": torch.zeros((1, packed), dtype=torch.long), + "position_ids": torch.arange(packed).unsqueeze(0), + "seq_lens": seq_lens, + "seq_lens_padded": seq_lens_padded, + } + result = process_input_for_thd(batch) + + # Absorption fires → cu_seqlens[-1] == packed_size. + assert int(result["cu_seqlens"][-1].item()) == packed + # cu_seqlens_padded dropped. + assert "cu_seqlens_padded" not in result, ( + "cu_seqlens_padded should be omitted when absorption fired" + ) + # max_seqlen reflects the absorbed last slot width = 112 + 16 = 128. + assert int(result["max_seqlen"].item()) == 128, ( + f"max_seqlen should reflect post-absorption slot 128; " + f"got {int(result['max_seqlen'].item())}" + ) + + def test_split_into_chunks_mixed_short_and_full(self): + """Two-chunk batch where chunk 0 is a near-full pack (16 trailing + pad) and chunk 1 is short (464 trailing pad). Both absorb; their + max_seqlen values differ because the absorbed last-slot widths differ. + """ + packed = 1024 + sub = 112 + seq_lens = torch.tensor([ + [sub] * 9, + [sub, sub, sub, sub, sub, -1000, -1000, -1000, -1000], + ]) + seq_lens_padded = torch.tensor([ + [sub] * 8 + [sub + 16], + [sub, sub, sub, sub, sub + (packed - 5 * sub), -1000, -1000, -1000, -1000], + ]) + batch = { + "input_ids": torch.zeros((2, packed), dtype=torch.long), + "labels": torch.zeros((2, packed), dtype=torch.long), + "position_ids": torch.arange(packed).unsqueeze(0).expand(2, -1), + "seq_lens": seq_lens, + "seq_lens_padded": seq_lens_padded, + } + result = split_batch_into_thd_chunks(batch, num_chunks=2) + + assert "cu_seqlens" in result + # Both chunks absorbed (cu_seqlens_padded == cu_seqlens for both), + # so split_batch_into_thd_chunks omits the padded key. + assert "cu_seqlens_padded" not in result + + # Chunk 0 (full pack, absorbed) — last non-sentinel value == packed_size. + c0_cu = result["cu_seqlens"][0] + c0_real = c0_cu[c0_cu != -1000] + assert int(c0_real[-1].item()) == packed + c0_widths = c0_real[1:] - c0_real[:-1] + assert int(c0_widths.max().item()) == 128 # absorbed last slot + + # Chunk 1 (short, absorbed) — last non-sentinel value == packed_size, + # absorbed last slot is wider (576). + c1_cu = result["cu_seqlens"][1] + c1_real = c1_cu[c1_cu != -1000] + assert int(c1_real[-1].item()) == packed + c1_widths = c1_real[1:] - c1_real[:-1] + assert int(c1_widths.max().item()) == 576 # absorbed last slot + + # Per-chunk max_seqlen reflects each chunk's max slot width. + # split_batch_into_thd_chunks stacks them, so result["max_seqlen"] + # is a tensor of shape (2,). + assert result["max_seqlen"].shape == (2,) + assert int(result["max_seqlen"][0].item()) == 128 + assert int(result["max_seqlen"][1].item()) == 576 diff --git a/tests/unit_tests/loss/test_mtp_cross_boundary.py b/tests/unit_tests/loss/test_mtp_cross_boundary.py new file mode 100644 index 0000000000..194111612b --- /dev/null +++ b/tests/unit_tests/loss/test_mtp_cross_boundary.py @@ -0,0 +1,301 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for MTP cross-sequence-boundary label masking. + +When MTP is enabled on a packed sequence, the loss at depth k uses labels +shifted left by ``k+1``. If the source position (``t + k + 1``) falls into a +*different* sub-sequence than position ``t``, the prediction would cross a +packing boundary — which is nonsensical — and must be masked out with +``ignore_index`` so it does not contribute to the loss. + +The cross-boundary logic lives at +``nemo_automodel/components/loss/mtp.py::calculate_mtp_loss`` lines 117-120 +and is driven by either an explicit ``seq_idx`` or a derived one from +``cu_seqlens``. +""" + +from __future__ import annotations + +from unittest import mock + +import pytest +import torch + +from nemo_automodel.components.loss.mtp import calculate_mtp_loss + +IGNORE = -100 + + +def _hand_masked_labels( + labels: torch.Tensor, + seq_idx: torch.Tensor, + depth: int, +) -> torch.Tensor: + """Compute the expected per-depth masked labels by hand. + + Mirrors the in-function logic so we can compare against it exactly: + - left-shift labels by depth+1 (trailing positions filled with 0) + - mask the trailing ``depth+1`` positions with ``ignore_index`` + - mask positions where ``seq_idx[t+depth+1] != seq_idx[t]`` + """ + shift = depth + 1 + if labels.dim() == 1: + L = labels.shape[0] + rolled = torch.cat([labels[shift:], torch.zeros(shift, dtype=labels.dtype)]) + out = rolled.clone() + out[-shift:] = IGNORE + rolled_seq = torch.cat([seq_idx[shift:], torch.zeros(shift, dtype=seq_idx.dtype)]) + out = torch.where(rolled_seq != seq_idx, torch.full_like(out, IGNORE), out) + return out + # 2D path + B, S = labels.shape + rolled = torch.cat([labels[:, shift:], torch.zeros(B, shift, dtype=labels.dtype)], dim=1) + out = rolled.clone() + out[:, -shift:] = IGNORE + rolled_seq = torch.cat([seq_idx[:, shift:], torch.zeros(B, shift, dtype=seq_idx.dtype)], dim=1) + out = torch.where(rolled_seq != seq_idx, torch.full_like(out, IGNORE), out) + return out + + +class _CaptureLoss: + """Mock loss callable that records the labels it received at each call.""" + + def __init__(self): + self.captured_labels: list[torch.Tensor] = [] + + def __call__(self, **kwargs): + self.captured_labels.append(kwargs["labels"].clone()) + return torch.zeros((), requires_grad=True) + + +def _run_capture( + *, + labels: torch.Tensor, + seq_idx: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + depths: int = 2, + hidden_dim: int = 4, +) -> list[torch.Tensor]: + """Run calculate_mtp_loss with a capture loss; return masked labels per depth.""" + # mtp_per_depth_h: D x [B, S, H] (or [T, H]) — only shape matters because + # calculate_loss is mocked. + if labels.dim() == 1: + h_shape = (labels.shape[0], hidden_dim) + else: + h_shape = (labels.shape[0], labels.shape[1], hidden_dim) + mtp_per_depth_h = [torch.zeros(h_shape, dtype=torch.float32, requires_grad=True) for _ in range(depths)] + cap = _CaptureLoss() + # Patch calculate_loss in the mtp module's namespace. + with mock.patch("nemo_automodel.components.loss.mtp.calculate_loss", side_effect=lambda loss_fn, **kw: cap(**kw)): + calculate_mtp_loss( + loss_fn=mock.MagicMock(), # signature only matters because calculate_loss is patched + mtp_per_depth_h=mtp_per_depth_h, + labels=labels, + model=mock.MagicMock(), + scaling_factor=1.0, + cu_seqlens=cu_seqlens, + seq_idx=seq_idx, + ignore_index=IGNORE, + ) + return cap.captured_labels + + +def test_cross_boundary_masked_via_seq_idx_1d(): + """Two 4-token sub-seqs in an 8-token packed sample. With seq_idx supplied + directly, depth-0 must mask position 3 (rolls into sub-seq 1); depth-1 must + mask positions 2 and 3 (both roll into sub-seq 1). + """ + labels = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.long) + seq_idx = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.long) + + captured = _run_capture(labels=labels, seq_idx=seq_idx, depths=2) + assert len(captured) == 2 + + # Depth 0 (shift=1): rolled labels = [2,3,4,5,6,7,8,0]; cross-seq at t=3 + # (rolled is from sub-seq 1 while t=3 is sub-seq 0). Trailing 1 position + # is also masked. So positions {3, 7} are IGNORE; rest are real. + d0 = captured[0] + assert d0.tolist() == [2, 3, 4, IGNORE, 6, 7, 8, IGNORE] + + # Depth 1 (shift=2): rolled labels = [3,4,5,6,7,8,0,0]; cross-seq at t∈{2,3} + # (sources at t=4,5 are sub-seq 1). Trailing 2 positions are also masked. + # So positions {2, 3, 6, 7} are IGNORE. + d1 = captured[1] + assert d1.tolist() == [3, 4, IGNORE, IGNORE, 7, 8, IGNORE, IGNORE] + + +def test_cross_boundary_masked_via_cu_seqlens_1d(): + """Same scenario but seq_idx is derived from cu_seqlens via searchsorted.""" + labels = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.long) + cu_seqlens = torch.tensor([0, 4, 8], dtype=torch.int32) + + captured = _run_capture(labels=labels, cu_seqlens=cu_seqlens, depths=2) + assert len(captured) == 2 + + # Derived seq_idx = searchsorted([4, 8], [0..7]) = [0,0,0,0,1,1,1,1]. + # Expected masks match the previous test. + assert captured[0].tolist() == [2, 3, 4, IGNORE, 6, 7, 8, IGNORE] + assert captured[1].tolist() == [3, 4, IGNORE, IGNORE, 7, 8, IGNORE, IGNORE] + + +def test_no_masking_when_seq_idx_is_constant(): + """If every token belongs to a single sub-sequence, cross-seq mask is a + no-op and only the trailing-shift mask applies. + """ + labels = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.long) + seq_idx = torch.zeros(8, dtype=torch.long) # all same sub-seq + + captured = _run_capture(labels=labels, seq_idx=seq_idx, depths=2) + + # Depth 0: only trailing 1 masked. + assert captured[0].tolist() == [2, 3, 4, 5, 6, 7, 8, IGNORE] + # Depth 1: only trailing 2 masked. + assert captured[1].tolist() == [3, 4, 5, 6, 7, 8, IGNORE, IGNORE] + + +def test_three_subseqs_uneven_widths(): + """Three sub-seqs of widths 3, 2, 4 in a 9-token pack. Verifies that + masking respects unequal boundaries at depth 0 and depth 2. + """ + labels = torch.arange(1, 10, dtype=torch.long) # [1..9] + seq_idx = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2], dtype=torch.long) + cu_seqlens = torch.tensor([0, 3, 5, 9], dtype=torch.int32) + + captured_seq = _run_capture(labels=labels, seq_idx=seq_idx, depths=3) + captured_cu = _run_capture(labels=labels, cu_seqlens=cu_seqlens, depths=3) + + # Hand-mask reference (validates the in-function logic against an + # independent re-implementation). + expected = [_hand_masked_labels(labels, seq_idx, d).tolist() for d in range(3)] + + for d in range(3): + assert captured_seq[d].tolist() == expected[d], f"seq_idx path mismatch at depth {d}" + assert captured_cu[d].tolist() == expected[d], f"cu_seqlens path mismatch at depth {d}" + + +def test_cross_boundary_2d_batch(): + """2D ``[B, S]`` labels. seq_idx is broadcast across the batch by the + function under test. Verifies the broadcasting + masking together. + """ + labels = torch.tensor( + [[10, 11, 12, 13, 14, 15, 16, 17], [20, 21, 22, 23, 24, 25, 26, 27]], + dtype=torch.long, + ) + seq_idx_1d = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.long) + + captured = _run_capture(labels=labels, seq_idx=seq_idx_1d, depths=2) + seq_idx_2d = seq_idx_1d.unsqueeze(0).expand(2, -1) + expected = [_hand_masked_labels(labels, seq_idx_2d, d) for d in range(2)] + for d in range(2): + assert torch.equal(captured[d], expected[d]), f"depth {d} mismatch" + + +def test_seq_idx_shape_mismatch_raises_under_2d(): + """If a 2D seq_idx is supplied whose shape does NOT match labels, the + function should raise — silent broadcasting would mask the wrong tokens + under PP chunking. + """ + labels = torch.zeros(2, 8, dtype=torch.long) + bad_seq_idx = torch.zeros(3, 8, dtype=torch.long) # batch dim doesn't match + with pytest.raises(ValueError, match="seq_idx.shape"): + _run_capture(labels=labels, seq_idx=bad_seq_idx, depths=1) + + +def test_depth_beyond_subseq_length_fully_masked(): + """A 2-token sub-seq at depth=2 has no in-bounds rolled label. The mask + must be IGNORE everywhere for that sub-seq's positions. + """ + labels = torch.tensor([1, 2, 3, 4], dtype=torch.long) + # Two 2-token sub-seqs. + seq_idx = torch.tensor([0, 0, 1, 1], dtype=torch.long) + + captured = _run_capture(labels=labels, seq_idx=seq_idx, depths=3) + + # Depth 2 (shift=3): rolled = [4, 0, 0, 0]; trailing 3 → IGNORE. + # Position 0 rolls to position 3 (sub-seq 1) ≠ sub-seq 0 → IGNORE. + # So all 4 positions are IGNORE. + assert captured[2].tolist() == [IGNORE, IGNORE, IGNORE, IGNORE] + + +def test_thd_flat_labels_squeeze_3d_hidden_states(): + """Regression: under THD packing the model unsqueezes ``mtp_per_depth_h`` + back to ``[1, T, H]`` (model.py post-MTP-forward), but the recipe pops 1D + ``[T]`` labels from the THD-flattened batch. ``cut_cross_entropy`` asserts + ``hidden_states.shape[:-1] == labels.shape``, so calculate_mtp_loss must + squeeze the synthetic batch axis when labels are 1D. + """ + T, H = 8, 4 + labels = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.long) # (T,) + # Mimic model.py:790: per-depth hidden states are unsqueezed to (1, T, H). + mtp_per_depth_h = [torch.zeros((1, T, H), dtype=torch.float32, requires_grad=True) for _ in range(2)] + + captured_hidden: list[torch.Tensor] = [] + captured_labels: list[torch.Tensor] = [] + + def _capture(loss_fn, **kw): + captured_hidden.append(kw["hidden_states"]) + captured_labels.append(kw["labels"]) + return torch.zeros((), requires_grad=True) + + with mock.patch("nemo_automodel.components.loss.mtp.calculate_loss", side_effect=_capture): + from nemo_automodel.components.loss.linear_ce import FusedLinearCrossEntropy + + # Pass a FusedLinearCrossEntropy instance so calculate_mtp_loss takes + # the hidden_states branch (the bug only manifests there). + calculate_mtp_loss( + loss_fn=FusedLinearCrossEntropy(), + mtp_per_depth_h=mtp_per_depth_h, + labels=labels, + model=mock.MagicMock(), + scaling_factor=1.0, + ignore_index=IGNORE, + ) + + assert len(captured_hidden) == 2 + for d, (h, lab) in enumerate(zip(captured_hidden, captured_labels)): + # After the fix, the synthetic batch axis must be squeezed so + # h.shape[:-1] == lab.shape (cce's invariant). + assert h.shape == (T, H), f"depth {d}: expected hidden_states (T, H), got {tuple(h.shape)}" + assert lab.shape == (T,), f"depth {d}: expected labels (T,), got {tuple(lab.shape)}" + assert h.shape[:-1] == lab.shape + + +def test_2d_labels_and_3d_hidden_states_unchanged(): + """Sanity: when labels are already 2D ``[B, S]`` and hidden states are + ``[B, S, H]`` (the BSHD path), the reconciliation must be a no-op.""" + B, S, H = 2, 8, 4 + labels = torch.arange(1, B * S + 1, dtype=torch.long).view(B, S) + mtp_per_depth_h = [torch.zeros((B, S, H), dtype=torch.float32, requires_grad=True) for _ in range(2)] + + captured_hidden: list[torch.Tensor] = [] + + def _capture(loss_fn, **kw): + captured_hidden.append(kw["hidden_states"]) + return torch.zeros((), requires_grad=True) + + with mock.patch("nemo_automodel.components.loss.mtp.calculate_loss", side_effect=_capture): + from nemo_automodel.components.loss.linear_ce import FusedLinearCrossEntropy + + calculate_mtp_loss( + loss_fn=FusedLinearCrossEntropy(), + mtp_per_depth_h=mtp_per_depth_h, + labels=labels, + model=mock.MagicMock(), + scaling_factor=1.0, + ignore_index=IGNORE, + ) + + for d, h in enumerate(captured_hidden): + assert h.shape == (B, S, H), f"depth {d}: BSHD path should be unchanged, got {tuple(h.shape)}" diff --git a/tests/unit_tests/models/common/test_mtp_rolling.py b/tests/unit_tests/models/common/test_mtp_rolling.py new file mode 100644 index 0000000000..687690915c --- /dev/null +++ b/tests/unit_tests/models/common/test_mtp_rolling.py @@ -0,0 +1,220 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pins MTPModule's cumulative left-rolling of input_ids and position_ids. + +At depth ``k`` the MTP module embeds the token originally at position +``t + k + 1`` for slot ``t``. This is implemented by cumulatively rolling +``cur_input_ids`` (and ``cur_position_ids``) left by one at each depth in +``MTPModule.forward`` (see ``components/models/common/mtp/mtp.py``). + +These tests intercept what reaches each sublayer and verify the rolled +inputs match a hand-rolled reference. The label-rolling on the loss side is +covered separately by ``tests/unit_tests/loss/test_mtp_cross_boundary.py``. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn + +from nemo_automodel.components.models.common.mtp.mtp import MTPConfig, MTPModule, roll_tensor + + +class _RecordingSublayer(nn.Module): + """Sublayer that records its kwargs and returns hidden_states unchanged.""" + + def __init__(self): + super().__init__() + self.calls: list[dict] = [] + + def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + # Detach + clone so a later in-place op can't mutate what we recorded. + rec = {} + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + rec[k] = v.detach().clone() + else: + rec[k] = v + self.calls.append(rec) + return hidden_states + + +def _build_module(num_depths: int = 3, pattern_length: int = 1) -> MTPModule: + """Build an MTPModule with recording sublayers (no real attention/MoE).""" + cfg = MTPConfig( + num_layers=num_depths, + layer_pattern="A" * pattern_length, + loss_scaling_factor=1.0, + use_repeated_layer=False, + ) + block_types = ["recording"] * pattern_length + + def factory(global_idx, depth, sublayer_idx, block_type, has_fusion, has_final_norm): + return _RecordingSublayer() + + return MTPModule(cfg, block_types, factory) + + +def _embed_fn_identity(input_ids: torch.LongTensor) -> torch.Tensor: + """Embedding stub: cast int IDs to float, keep the seq dim intact. + + The MTP module passes the embedded result as ``embed_input`` to sublayer 0 + of each depth; we recover the un-embedded IDs by casting back to long. + """ + return input_ids.to(torch.float32).unsqueeze(-1) # shape [..., S, 1] + + +def _extract_embed_ids(rec_embed: torch.Tensor) -> list[int]: + """Inverse of ``_embed_fn_identity``: pull the original IDs out.""" + return rec_embed.squeeze(-1).to(torch.long).tolist() + + +def test_cumulative_input_ids_rolling_1d(): + """Depth k sees input_ids rolled left by ``k+1`` (cumulative).""" + mtp = _build_module(num_depths=3, pattern_length=1) + input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.long) + hidden = torch.zeros(input_ids.shape[0], 4) + + mtp.forward(hidden, input_ids=input_ids, embed_fn=_embed_fn_identity) + + # 3 sublayers (1 per depth), each recorded one call. + sublayer_calls = [layer.calls[0] for layer in mtp.layers] + assert len(sublayer_calls) == 3 + + for depth in range(3): + expected_ids = roll_tensor(input_ids, shifts=-(depth + 1), dim=-1).tolist() + got_ids = _extract_embed_ids(sublayer_calls[depth]["embed_input"]) + assert got_ids == expected_ids, ( + f"depth {depth}: expected ids rolled by -{depth + 1}, got {got_ids} (expected {expected_ids})" + ) + + +def test_cumulative_position_ids_rolling_1d(): + """Depth k sees position_ids rolled left by ``k+1`` (cumulative).""" + mtp = _build_module(num_depths=3, pattern_length=1) + input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.long) + position_ids = torch.arange(8, dtype=torch.long) + hidden = torch.zeros(input_ids.shape[0], 4) + + mtp.forward(hidden, input_ids=input_ids, embed_fn=_embed_fn_identity, position_ids=position_ids) + + for depth in range(3): + expected_pos = roll_tensor(position_ids, shifts=-(depth + 1), dim=-1).tolist() + got_pos = mtp.layers[depth].calls[0]["position_ids"].tolist() + assert got_pos == expected_pos, ( + f"depth {depth}: expected position_ids rolled by -{depth + 1}, got {got_pos} (expected {expected_pos})" + ) + + +def test_cumulative_rolling_2d_batch(): + """Per-row cumulative rolling in 2D ``[B, S]`` mode.""" + mtp = _build_module(num_depths=2, pattern_length=1) + input_ids = torch.tensor( + [[10, 11, 12, 13, 14, 15, 16, 17], + [20, 21, 22, 23, 24, 25, 26, 27]], + dtype=torch.long, + ) + position_ids = torch.arange(8, dtype=torch.long).unsqueeze(0).expand(2, -1).contiguous() + hidden = torch.zeros(2, 8, 4) + + mtp.forward(hidden, input_ids=input_ids, embed_fn=_embed_fn_identity, position_ids=position_ids) + + for depth in range(2): + expected_ids = roll_tensor(input_ids, shifts=-(depth + 1), dim=-1) + got_ids = mtp.layers[depth].calls[0]["embed_input"].squeeze(-1).to(torch.long) + assert torch.equal(got_ids, expected_ids), ( + f"depth {depth}: input_ids rolling mismatch in 2D batch" + ) + + expected_pos = roll_tensor(position_ids, shifts=-(depth + 1), dim=-1) + got_pos = mtp.layers[depth].calls[0]["position_ids"] + assert torch.equal(got_pos, expected_pos), ( + f"depth {depth}: position_ids rolling mismatch in 2D batch" + ) + + +def test_multi_sublayer_per_depth_sees_same_rolled_inputs(): + """When pattern_length > 1, all sublayers of a single depth see the + rolled inputs from THAT depth (no further intra-depth rolling). + """ + mtp = _build_module(num_depths=2, pattern_length=2) + input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.long) + position_ids = torch.arange(8, dtype=torch.long) + hidden = torch.zeros(input_ids.shape[0], 4) + + mtp.forward(hidden, input_ids=input_ids, embed_fn=_embed_fn_identity, position_ids=position_ids) + + # 4 sublayers (2 depths × 2 sublayers/depth). + assert len(mtp.layers) == 4 + for depth in range(2): + for sub_in_depth in range(2): + flat = depth * 2 + sub_in_depth + got_pos = mtp.layers[flat].calls[0]["position_ids"].tolist() + expected_pos = roll_tensor(position_ids, shifts=-(depth + 1), dim=-1).tolist() + assert got_pos == expected_pos, ( + f"depth {depth} sublayer {sub_in_depth}: position_ids mismatch" + ) + # Only sublayer 0 of each depth gets embed_input. + assert "embed_input" in mtp.layers[0].calls[0] + assert "embed_input" not in mtp.layers[1].calls[0] + assert "embed_input" in mtp.layers[2].calls[0] + assert "embed_input" not in mtp.layers[3].calls[0] + + +def test_precomputed_embed_inputs_path_skips_token_rolling(): + """When ``embed_inputs`` is supplied, the per-depth token rolling is + skipped — the caller has already prepared the rolled embeddings. + Position_ids rolling still happens (it's independent of embed source). + """ + mtp = _build_module(num_depths=2, pattern_length=1) + # Two pre-computed embeddings (one per depth) marked with distinct values + # so we can tell them apart in the captured embed_input. + emb0 = torch.full((8, 4), 11.0) + emb1 = torch.full((8, 4), 22.0) + position_ids = torch.arange(8, dtype=torch.long) + hidden = torch.zeros(8, 4) + + mtp.forward(hidden, embed_inputs=(emb0, emb1), position_ids=position_ids) + + # Depth 0 sees emb0 verbatim; depth 1 sees emb1 verbatim. + assert torch.equal(mtp.layers[0].calls[0]["embed_input"], emb0) + assert torch.equal(mtp.layers[1].calls[0]["embed_input"], emb1) + + # Position_ids still roll cumulatively even on the embed_inputs path. + for depth in range(2): + expected_pos = roll_tensor(position_ids, shifts=-(depth + 1), dim=-1).tolist() + got_pos = mtp.layers[depth].calls[0]["position_ids"].tolist() + assert got_pos == expected_pos, f"depth {depth}: position_ids rolling expected on embed_inputs path" + + +def test_trailing_positions_zero_after_roll(): + """The roll is left-shift with trailing zeros (not wrap-around). At depth + k, the last ``k+1`` slots of the rolled tensor are zero — these are the + positions for which there is no valid future token to predict. + """ + mtp = _build_module(num_depths=3, pattern_length=1) + input_ids = torch.arange(1, 9, dtype=torch.long) # [1,2,3,4,5,6,7,8] + hidden = torch.zeros(8, 4) + + mtp.forward(hidden, input_ids=input_ids, embed_fn=_embed_fn_identity) + + for depth in range(3): + ids = _extract_embed_ids(mtp.layers[depth].calls[0]["embed_input"]) + n_trailing = depth + 1 + assert ids[-n_trailing:] == [0] * n_trailing, ( + f"depth {depth}: expected {n_trailing} trailing zeros, got tail {ids[-n_trailing:]}" + ) + # The remaining prefix should match the original IDs shifted by n_trailing. + assert ids[:-n_trailing] == input_ids[n_trailing:].tolist() diff --git a/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_layers.py b/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_layers.py index 4fb658a155..709f173e5b 100644 --- a/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_layers.py +++ b/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_layers.py @@ -27,13 +27,13 @@ try: import mamba_ssm + _has_mamba_ssm = True except ImportError: _has_mamba_ssm = False skip_if_no_mamba = pytest.mark.skipif( - not torch.cuda.is_available() or not _has_mamba_ssm, - reason="CUDA and mamba_ssm required for Mamba Triton kernels" + not torch.cuda.is_available() or not _has_mamba_ssm, reason="CUDA and mamba_ssm required for Mamba Triton kernels" ) @@ -376,9 +376,7 @@ def test_block_init_weights_attention(self, config, backend): block.init_weights(buffer_device=device) # Verify biases are zeroed - assert torch.allclose( - block.mixer.q_proj.bias, torch.zeros_like(block.mixer.q_proj.bias) - ) + assert torch.allclose(block.mixer.q_proj.bias, torch.zeros_like(block.mixer.q_proj.bias)) def test_block_init_weights_mlp(self, config, backend): """Test weight initialization for MLP block.""" @@ -389,9 +387,7 @@ def test_block_init_weights_mlp(self, config, backend): block.init_weights(buffer_device=device) # Weights should be initialized (not all zeros) - assert not torch.allclose( - block.mixer.up_proj.weight, torch.zeros_like(block.mixer.up_proj.weight) - ) + assert not torch.allclose(block.mixer.up_proj.weight, torch.zeros_like(block.mixer.up_proj.weight)) def test_block_uses_relu2_for_mlp(self, config, backend): """Test that MLP block uses relu2 activation by default.""" @@ -562,3 +558,117 @@ def test_block_mlp_no_cache_args(self, config, backend): hidden = torch.randn(2, 8, config.hidden_size, dtype=torch.bfloat16) out = block(hidden) assert out.shape == (2, 8, config.hidden_size) + + +class TestMambaMixerSeqIdxConstruction: + """Regression tests for the seq_idx construction in NemotronV3Mamba2Mixer. + + The mamba kernel asserts ``seq_idx.shape == (batch_size, seqlen)``. When + the model receives a (B, S, H) BSHD batch with a globally-cumulated + ``cu_seqlens`` (derived in ``model.py`` from a 2D attention_mask), the + pre-fix mixer produced ``seq_idx`` of shape ``(1, S)`` and crashed. After + the fix the mixer must produce a per-row ``seq_idx`` of shape ``(B, S)``. + """ + + @pytest.fixture + def config(self): + # Small dimensions so CPU init is fast. + return MockNemotronV3Config( + hidden_size=64, + mamba_num_heads=4, + mamba_head_dim=16, + n_groups=1, + ssm_state_size=8, + chunk_size=8, + conv_kernel=4, + ) + + @staticmethod + def _patch_mamba_kernel(monkeypatch, capture): + """Install a fake ``mamba_split_conv1d_scan_combined`` that records + the ``seq_idx`` kwarg and returns a correctly-shaped output tensor.""" + import sys + import types + + def _fake_kernel(projected_states, *args, **kwargs): + capture["seq_idx"] = kwargs.get("seq_idx", None) + # Out shape matches the projected gate path: (B, S, intermediate_size) + # after the outproj_weight matmul → (B, S, hidden_size). + outproj_weight = kwargs["outproj_weight"] + B, S, _ = projected_states.shape + return projected_states.new_zeros((B, S, outproj_weight.shape[0])) + + # Stub the import chain so the deferred ``from mamba_ssm.ops.triton. + # ssd_combined import mamba_split_conv1d_scan_combined`` resolves to + # our fake without requiring the real package. + for name in ("mamba_ssm", "mamba_ssm.ops", "mamba_ssm.ops.triton"): + if name not in sys.modules: + monkeypatch.setitem(sys.modules, name, types.ModuleType(name)) + fake_ssd = types.ModuleType("mamba_ssm.ops.triton.ssd_combined") + fake_ssd.mamba_split_conv1d_scan_combined = _fake_kernel + monkeypatch.setitem(sys.modules, "mamba_ssm.ops.triton.ssd_combined", fake_ssd) + + def test_bshd_b_gt_1_with_cu_seqlens_yields_per_row_seq_idx(self, monkeypatch, config): + from nemo_automodel.components.models.nemotron_v3.layers import NemotronV3Mamba2Mixer + + mixer = NemotronV3Mamba2Mixer(config, layer_idx=0) + + B, S = 2, 16 + hidden_states = torch.randn(B, S, config.hidden_size) + # cu_seqlens as model.py derives it from a 2D attention_mask: a global + # cumsum across the batch (one boundary per row). + cu_seqlens = torch.tensor([0, S, 2 * S], dtype=torch.int32) + + capture: dict = {} + self._patch_mamba_kernel(monkeypatch, capture) + mixer.forward(hidden_states, cu_seqlens=cu_seqlens) + + seq_idx = capture["seq_idx"] + assert seq_idx is not None, "mamba kernel did not receive a seq_idx" + assert seq_idx.shape == (B, S), f"expected (B, S)=({B}, {S}), got {tuple(seq_idx.shape)}" + assert seq_idx.dtype == torch.int32 + + def test_bshd_b_eq_1_with_cu_seqlens_keeps_flat_construction(self, monkeypatch, config): + """When B == 1 (THD flat layout), the legacy searchsorted construction + still applies and should yield shape (1, S).""" + from nemo_automodel.components.models.nemotron_v3.layers import NemotronV3Mamba2Mixer + + mixer = NemotronV3Mamba2Mixer(config, layer_idx=0) + + B, S = 1, 16 + hidden_states = torch.randn(B, S, config.hidden_size) + # Two intra-row sub-sequences: boundaries at positions 8 and 16. + cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int32) + + capture: dict = {} + self._patch_mamba_kernel(monkeypatch, capture) + mixer.forward(hidden_states, cu_seqlens=cu_seqlens) + + seq_idx = capture["seq_idx"] + assert seq_idx is not None + assert seq_idx.shape == (1, S) + # right=True so boundary positions map to the new sub-seq. + assert seq_idx[0, :8].tolist() == [0] * 8 + assert seq_idx[0, 8:].tolist() == [1] * 8 + + def test_bshd_b_gt_1_passes_through_upstream_seq_idx(self, monkeypatch, config): + """If seq_idx is supplied upstream (e.g. via _packed_seq_ids), the + construction path must NOT override it — neat-packing semantics.""" + from nemo_automodel.components.models.nemotron_v3.layers import NemotronV3Mamba2Mixer + + mixer = NemotronV3Mamba2Mixer(config, layer_idx=0) + + B, S = 2, 16 + hidden_states = torch.randn(B, S, config.hidden_size) + # Pretend the wrapper has already set seq_idx from _packed_seq_ids. + upstream = torch.tensor( + [[0] * 8 + [1] * 8, [0] * 6 + [1] * 5 + [2] * 5], + dtype=torch.int32, + ) + + capture: dict = {} + self._patch_mamba_kernel(monkeypatch, capture) + mixer.forward(hidden_states, cu_seqlens=torch.tensor([0, 16, 32], dtype=torch.int32), seq_idx=upstream) + + seq_idx = capture["seq_idx"] + assert seq_idx is upstream, "upstream seq_idx must not be replaced by the construction path" diff --git a/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_mtp.py b/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_mtp.py index 6aecf01199..4c4ffecdf5 100644 --- a/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_mtp.py +++ b/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_mtp.py @@ -28,7 +28,11 @@ MTPConfig, roll_tensor, ) -from nemo_automodel.components.models.nemotron_v3.mtp import parse_mtp_layer_pattern +from nemo_automodel.components.models.nemotron_v3.mtp import ( + _resolve_block_types_per_sublayer, + build_mtp_config_from_hf, + parse_mtp_layer_pattern, +) class MockNemotronV3Config: @@ -114,6 +118,48 @@ def test_empty_pattern_raises(self): with pytest.raises(ValueError, match="empty"): parse_mtp_layer_pattern("") + def test_resolve_from_symbol_pattern(self): + """Super-V3 path: ``mtp_hybrid_override_pattern`` is honored.""" + cfg = MockNemotronV3Config(num_nextn_predict_layers=1, mtp_hybrid_override_pattern="*E") + cfg.mtp_layers_block_type = None + assert _resolve_block_types_per_sublayer(cfg) == ["attention", "moe"] + + def test_resolve_from_layers_block_type(self): + """List-form path: ``mtp_layers_block_type`` list-of-strings is honored.""" + cfg = MockNemotronV3Config(num_nextn_predict_layers=1) + cfg.mtp_hybrid_override_pattern = None + cfg.mtp_layers_block_type = ["attention", "moe"] + assert _resolve_block_types_per_sublayer(cfg) == ["attention", "moe"] + + def test_resolve_returns_none_when_both_absent(self): + cfg = MockNemotronV3Config() + cfg.mtp_hybrid_override_pattern = None + cfg.mtp_layers_block_type = None + assert _resolve_block_types_per_sublayer(cfg) is None + + def test_resolve_rejects_unknown_block_type_in_list(self): + cfg = MockNemotronV3Config() + cfg.mtp_hybrid_override_pattern = None + cfg.mtp_layers_block_type = ["attention", "bogus"] + with pytest.raises(ValueError, match="Unknown MTP block type"): + _resolve_block_types_per_sublayer(cfg) + + def test_build_mtp_config_from_layers_block_type(self): + """List-form config (no symbol pattern) yields an enabled MTPConfig. + + Regression for the ``mtp_layers_block_type`` fallback added so that + checkpoints lacking ``mtp_hybrid_override_pattern`` can build MTP + without raising ``MTP layer pattern is empty``. + """ + cfg = MockNemotronV3Config(num_nextn_predict_layers=1) + cfg.mtp_hybrid_override_pattern = None + cfg.mtp_layers_block_type = ["attention", "moe"] + mtp_config = build_mtp_config_from_hf(cfg) + assert mtp_config.enabled + assert mtp_config.num_layers == 1 + assert mtp_config.pattern_length == 2 + assert mtp_config.total_sublayers == 2 + class TestRollTensor: def test_left_shift_zeros_trailing(self): diff --git a/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_pp_mtp.py b/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_pp_mtp.py new file mode 100644 index 0000000000..0fdb3fc61d --- /dev/null +++ b/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_pp_mtp.py @@ -0,0 +1,392 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pipeline-parallel + MTP wiring for NemotronH (nemotron_v3). + +Mirrors ``tests/unit_tests/models/deepseek_v4/test_deepseek_v4_mtp.py``'s +``TestPipelineHooks`` and PP-forward cases, adapted for the nemotron-h +hybrid Mamba/Attention/MoE layout with plain ``[B, S, H]`` inter-stage +tensors (no HC stream) and ``mtp_layers_block_type`` list-form MTP pattern. +""" + +import types + +import pytest +import torch + +from nemo_automodel.components.models.common import BackendConfig + +# Reuse the CPU-friendly Mock config from the existing MTP test module. +from tests.unit_tests.models.nemotron_v3.test_nemotron_v3_mtp import MockNemotronV3Config + + +@pytest.fixture +def backend(): + return BackendConfig( + linear="torch", + attn="sdpa", + rms_norm="torch", + enable_deepep=False, + fake_balanced_gate=True, + enable_hf_state_dict_adapter=False, + ) + + +def _make_model(backend, *, mtp_layers=0, mtp_pattern="", mtp_layers_block_type=None, **cfg_overrides): + from nemo_automodel.components.models.nemotron_v3.model import NemotronHForCausalLM + + cfg = MockNemotronV3Config( + num_nextn_predict_layers=mtp_layers, + mtp_hybrid_override_pattern=mtp_pattern, + **cfg_overrides, + ) + if mtp_layers_block_type is not None: + cfg.mtp_layers_block_type = mtp_layers_block_type + model = NemotronHForCausalLM(cfg, backend=backend) + return model.to(torch.bfloat16), cfg + + +# --------------------------------------------------------------------------- +# Stage-detection + helpers +# --------------------------------------------------------------------------- + + +class TestIsPipelineParallelStage: + def test_full_model_is_not_pp_stage(self, backend): + model, _ = _make_model(backend) + assert model._is_pipeline_parallel_stage() is False + + def test_missing_lm_head_marks_pp(self, backend): + model, _ = _make_model(backend) + model.lm_head = None + assert model._is_pipeline_parallel_stage() is True + + def test_missing_embed_tokens_marks_pp(self, backend): + model, _ = _make_model(backend) + model.model.embed_tokens = None + assert model._is_pipeline_parallel_stage() is True + + def test_trimmed_layer_count_marks_pp(self, backend): + model, _ = _make_model(backend) + # Pop one layer to simulate the splitter trimming. + keys = list(model.model.layers.keys()) + del model.model.layers[keys[-1]] + assert model._is_pipeline_parallel_stage() is True + + +class TestBuildMTPEmbedInputsForPP: + def test_rolls_input_ids_and_embeds_per_depth(self, backend): + model, cfg = _make_model( + backend, + mtp_layers=2, + mtp_layers_block_type=["attention", "moe"], + ) + # Deterministic embedding for assertion. + with torch.no_grad(): + model.model.embed_tokens.weight.copy_( + torch.arange(cfg.vocab_size * cfg.hidden_size, dtype=torch.float32).view( + cfg.vocab_size, cfg.hidden_size + ) + ) + + input_ids = torch.tensor([[10, 11, 12, 13]]) + out = model._build_mtp_embed_inputs_for_pp(input_ids) + assert len(out) == 2 + expected_d0_ids = torch.tensor([[11, 12, 13, 0]]) + expected_d1_ids = torch.tensor([[12, 13, 0, 0]]) + torch.testing.assert_close(out[0], model.model.embed_tokens(expected_d0_ids)) + torch.testing.assert_close(out[1], model.model.embed_tokens(expected_d1_ids)) + + +# --------------------------------------------------------------------------- +# customize_pipeline_stage_modules +# --------------------------------------------------------------------------- + + +class TestCustomizePipelineStageModules: + def test_appends_mtp_to_last_stage_only(self, backend): + model, _ = _make_model( + backend, + mtp_layers=1, + mtp_layers_block_type=["attention", "moe"], + ) + stages = [ + ["model.embed_tokens", "model.layers.0", "model.layers.1"], + ["model.layers.2", "model.layers.3", "model.norm", "lm_head"], + ] + out = model.customize_pipeline_stage_modules(stages, layers_prefix="model.", text_model=model.model) + assert "mtp" not in out[0] + assert "mtp" in out[-1] + + def test_no_mtp_when_disabled(self, backend): + model, _ = _make_model(backend) + stages = [["model.embed_tokens", "model.layers.0"], ["model.norm", "lm_head"]] + out = model.customize_pipeline_stage_modules(stages, layers_prefix="model.", text_model=model.model) + assert all("mtp" not in s for s in out) + + +# --------------------------------------------------------------------------- +# get_pipeline_stage_metas +# --------------------------------------------------------------------------- + + +class TestPipelineStageMetas: + def test_first_middle_final_arity_and_shapes_with_mtp(self, backend): + D = 2 + first, cfg = _make_model( + backend, + mtp_layers=D, + mtp_layers_block_type=["attention", "moe"], + ) + first.lm_head = None + first.model.norm = None + first.mtp = None + f_in, f_out = first.get_pipeline_stage_metas(is_first=True, microbatch_size=2, seq_len=16, dtype=torch.bfloat16) + assert f_in[0].shape == (2, 16) and f_in[0].dtype == torch.long + assert len(f_out) == 1 + D + assert f_out[0].shape == (2, 16, cfg.hidden_size) + for h in f_out[1:]: + assert h.shape == (2, 16, cfg.hidden_size) + + middle, _ = _make_model( + backend, + mtp_layers=D, + mtp_layers_block_type=["attention", "moe"], + ) + middle.model.embed_tokens = None + middle.lm_head = None + middle.model.norm = None + middle.mtp = None + m_in, m_out = middle.get_pipeline_stage_metas( + is_first=False, microbatch_size=2, seq_len=16, dtype=torch.bfloat16 + ) + assert len(m_in) == 1 + D and len(m_out) == 1 + D + assert m_in[0].shape == (2, 16, cfg.hidden_size) + + final, _ = _make_model( + backend, + mtp_layers=D, + mtp_layers_block_type=["attention", "moe"], + ) + final.model.embed_tokens = None + l_in, l_out = final.get_pipeline_stage_metas( + is_first=False, microbatch_size=2, seq_len=16, dtype=torch.bfloat16 + ) + assert len(l_in) == 1 + D + # Final stage appends an int32 [B, S] seq_idx tail: (logits, *mtp_h, seq_idx). + assert len(l_out) == 1 + D + 1 + assert l_out[0].shape == (2, 16, cfg.vocab_size) + for h in l_out[1 : 1 + D]: + assert h.shape == (2, 16, cfg.hidden_size) + assert l_out[-1].shape == (2, 16) and l_out[-1].dtype == torch.int32 + + def test_no_mtp_arity_is_one(self, backend): + model, cfg = _make_model(backend) + f_in, f_out = model.get_pipeline_stage_metas(is_first=True, microbatch_size=1, seq_len=8, dtype=torch.bfloat16) + assert len(f_in) == 1 and f_in[0].dtype == torch.long + assert len(f_out) == 1 + + +# --------------------------------------------------------------------------- +# PP forward variants (stubbed backbone) +# --------------------------------------------------------------------------- + + +class TestPPForward: + def test_first_stage_propagates_shifted_mtp_embeddings(self, backend): + model, cfg = _make_model( + backend, + mtp_layers=1, + mtp_layers_block_type=["attention", "moe"], + ) + model.train() + # Simulate a first-stage trim: lm_head + norm + mtp absent, embed_tokens kept. + model.lm_head = None + model.model.norm = None + model.mtp = None + with torch.no_grad(): + model.model.embed_tokens.weight.copy_( + torch.arange(cfg.vocab_size * cfg.hidden_size, dtype=torch.float32).view( + cfg.vocab_size, cfg.hidden_size + ) + ) + + def fake_inner(self, input_ids, **kwargs): + del self, kwargs + return torch.ones(input_ids.shape[0], input_ids.shape[1], cfg.hidden_size, dtype=torch.bfloat16) + + model.model.forward = types.MethodType(fake_inner, model.model) + + input_ids = torch.tensor([[10, 11, 12, 13]]) + out = model(input_ids) + + assert isinstance(out, tuple) + assert len(out) == 2 # 1 + D + assert out[0].shape == (1, 4, cfg.hidden_size) + expected_ids = torch.tensor([[11, 12, 13, 0]]) + torch.testing.assert_close(out[1], model.model.embed_tokens(expected_ids)) + + def test_middle_stage_passes_through_mtp_embeds(self, backend): + model, cfg = _make_model( + backend, + mtp_layers=1, + mtp_layers_block_type=["attention", "moe"], + ) + model.train() + # Simulate a middle-stage trim: nothing owned except some backbone layers. + model.lm_head = None + model.model.embed_tokens = None + model.model.norm = None + model.mtp = None + + captured = {} + + def fake_inner(self, input_ids, **kwargs): + captured["received_as_input_ids"] = input_ids + return input_ids # passthrough; same shape + + model.model.forward = types.MethodType(fake_inner, model.model) + + activation = torch.zeros(1, 4, cfg.hidden_size, dtype=torch.bfloat16) + mtp_embed = torch.randn(1, 4, cfg.hidden_size, dtype=torch.bfloat16) + out = model(activation, mtp_embed) + + assert isinstance(out, tuple) + assert len(out) == 2 + assert out[0].shape == (1, 4, cfg.hidden_size) + # mtp_embed flows through unchanged + torch.testing.assert_close(out[1], mtp_embed) + # Backbone received the inter-stage tensor through the input_ids slot. + assert captured["received_as_input_ids"] is activation + + def test_final_stage_uses_propagated_mtp_embeddings(self, backend): + model, cfg = _make_model( + backend, + mtp_layers=1, + mtp_layers_block_type=["attention", "moe"], + ) + model.train() + # Simulate a final-stage trim: embed_tokens absent, lm_head + mtp owned. + model.model.embed_tokens = None + + captured = {} + + def fake_inner(self, input_ids, **kwargs): + del self, kwargs + return torch.ones(input_ids.shape[0], input_ids.shape[1], cfg.hidden_size, dtype=torch.bfloat16) + + def fake_mtp_forward(self, **kwargs): + captured.update(kwargs) + return [kwargs["hidden_states"].clone()] + + model.model.forward = types.MethodType(fake_inner, model.model) + model.mtp.forward = types.MethodType(fake_mtp_forward, model.mtp) + + activation = torch.zeros(1, 4, cfg.hidden_size, dtype=torch.bfloat16) + mtp_embed = torch.randn(1, 4, cfg.hidden_size, dtype=torch.bfloat16) + out = model(activation, mtp_embed) + + assert isinstance(out, tuple) + assert len(out) == 3 # (logits, mtp_per_depth_h[0], seq_idx) + assert out[0].shape == (1, 4, cfg.vocab_size) + assert out[-1].shape == (1, 4) and out[-1].dtype == torch.int32 + # The MTP head was given the upstream embedding via embed_inputs and + # NOT via the input_ids/embed_fn path. + assert "embed_inputs" in captured + torch.testing.assert_close(captured["embed_inputs"][0], mtp_embed) + assert captured.get("input_ids") is None + assert captured.get("embed_fn") is None + + def test_final_stage_eval_emits_placeholders(self, backend): + """In eval mode, the last stage keeps the (1 + D + seq_idx) tuple arity.""" + D = 1 + model, cfg = _make_model( + backend, + mtp_layers=D, + mtp_layers_block_type=["attention", "moe"], + ) + model.eval() + model.model.embed_tokens = None + + def fake_inner(self, input_ids, **kwargs): + del self, kwargs + return torch.ones(input_ids.shape[0], input_ids.shape[1], cfg.hidden_size, dtype=torch.bfloat16) + + model.model.forward = types.MethodType(fake_inner, model.model) + + activation = torch.zeros(1, 4, cfg.hidden_size, dtype=torch.bfloat16) + mtp_embed = torch.randn(1, 4, cfg.hidden_size, dtype=torch.bfloat16) + with torch.no_grad(): + out = model(activation, mtp_embed) + + assert isinstance(out, tuple) + assert len(out) == 1 + D + 1 + # Logits live on out[0]; placeholders match the activation's hidden shape; + # the int32 [B, S] seq_idx tail is last. + assert out[0].shape == (1, 4, cfg.vocab_size) + for ph in out[1 : 1 + D]: + assert ph.shape == (1, 4, cfg.hidden_size) + assert out[-1].shape == (1, 4) and out[-1].dtype == torch.int32 + + +# --------------------------------------------------------------------------- +# Initialize_weights on a trimmed stage +# --------------------------------------------------------------------------- + + +class TestInitializeWeightsOnTrimmedStage: + def test_middle_stage_init_no_attribute_error(self, backend): + """Stage with embed_tokens=None, norm=None, lm_head=None, mtp=None must init cleanly.""" + model, _ = _make_model(backend, mtp_layers=1, mtp_layers_block_type=["attention", "moe"]) + model.lm_head = None + model.model.embed_tokens = None + model.model.norm = None + model.mtp = None + # Should not raise AttributeError on any of the trimmed attrs. + model.initialize_weights(buffer_device=torch.device("cpu")) + + def test_first_stage_init_no_attribute_error(self, backend): + model, _ = _make_model(backend, mtp_layers=1, mtp_layers_block_type=["attention", "moe"]) + model.lm_head = None + model.model.norm = None + model.mtp = None + model.initialize_weights(buffer_device=torch.device("cpu")) + + +# --------------------------------------------------------------------------- +# MoE FSDP iterator on a trimmed stage +# --------------------------------------------------------------------------- + + +class TestMoEIterOnTrimmedStage: + def test_iter_skips_absent_mtp(self, backend): + from nemo_automodel.components.moe.parallelizer import _iter_transformer_and_mtp_blocks + + model, _ = _make_model(backend, mtp_layers=1, mtp_layers_block_type=["attention", "moe"]) + # Mimic a middle stage that holds backbone layers but no mtp. + model.lm_head = None + model.model.embed_tokens = None + model.model.norm = None + model.mtp = None + + yielded = list(_iter_transformer_and_mtp_blocks(model)) + # Only backbone layers should be iterated; no MTP-side blocks. + assert len(yielded) == len(model.model.layers) + for parent_layers, layer_id, _block in yielded: + assert parent_layers is model.model.layers + assert layer_id in model.model.layers + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-vv"])) diff --git a/tests/unit_tests/moe/test_backend_config.py b/tests/unit_tests/moe/test_backend_config.py index cfbb474583..e1557d2270 100644 --- a/tests/unit_tests/moe/test_backend_config.py +++ b/tests/unit_tests/moe/test_backend_config.py @@ -240,16 +240,6 @@ def test_dispatcher_async_dispatch_custom(self): config = BackendConfig(dispatcher="deepep", dispatcher_async_dispatch=True) assert config.dispatcher_async_dispatch is True - def test_disable_shared_expert_overlap_default(self): - """Test that disable_shared_expert_overlap defaults to False.""" - config = BackendConfig() - assert config.disable_shared_expert_overlap is False - - def test_disable_shared_expert_overlap_custom(self): - """Test that disable_shared_expert_overlap accepts an explicit value.""" - config = BackendConfig(disable_shared_expert_overlap=True) - assert config.disable_shared_expert_overlap is True - def test_te_experts_falls_back_with_hybridep(self): """Test that te experts with hybridep dispatcher is valid (no fallback).""" config = BackendConfig(experts="te", dispatcher="hybridep") diff --git a/tests/unit_tests/moe/test_latent_projection.py b/tests/unit_tests/moe/test_latent_projection.py index a5d5fc9308..85e7d20d39 100644 --- a/tests/unit_tests/moe/test_latent_projection.py +++ b/tests/unit_tests/moe/test_latent_projection.py @@ -211,6 +211,9 @@ def test_forward_with_shared_experts_latent_enabled(self, moe_config, backend_co patch("torch.cuda.Stream") as mock_stream_class, patch("torch.cuda.current_stream") as mock_current_stream, patch("torch.cuda.stream") as mock_stream_context, + # The shared-expert fork/join calls Tensor.record_stream with the + # mocked stream; no-op the helper so the mock isn't passed to it. + patch("nemo_automodel.components.moe.layers._record_stream_safe"), ): mock_stream = Mock() mock_stream.wait_stream = Mock() diff --git a/tests/unit_tests/moe/test_layers.py b/tests/unit_tests/moe/test_layers.py index 60917cff29..8388891cad 100644 --- a/tests/unit_tests/moe/test_layers.py +++ b/tests/unit_tests/moe/test_layers.py @@ -1422,6 +1422,9 @@ def test_moe_forward_with_shared_experts(self, moe_config, backend_config, devic patch("torch.cuda.Stream") as mock_stream_class, patch("torch.cuda.current_stream") as mock_current_stream, patch("torch.cuda.stream") as mock_stream_context, + # The shared-expert fork/join calls Tensor.record_stream with the + # mocked stream; no-op the helper so the mock isn't passed to it. + patch("nemo_automodel.components.moe.layers._record_stream_safe"), ): mock_stream = Mock() mock_stream.wait_stream = Mock() diff --git a/tests/unit_tests/moe/test_state_dict_mixin.py b/tests/unit_tests/moe/test_state_dict_mixin.py index 6008811952..f5e7754e19 100644 --- a/tests/unit_tests/moe/test_state_dict_mixin.py +++ b/tests/unit_tests/moe/test_state_dict_mixin.py @@ -244,7 +244,7 @@ def test_multiple_layers_first_complete(self): "abstract_key2": { 0: torch.randn(512, 1024), 1: torch.randn(512, 1024), - } + }, } } @@ -337,7 +337,7 @@ def test_dtensor_validation(self, mock_is_dtensor, mock_validate): def test_without_model_prefix(self): mixin = MockMoEStateDictMixin(n_experts=4, uses_model_prefix=False) - with patch.object(mixin, '_split_experts_weights') as mock_split: + with patch.object(mixin, "_split_experts_weights") as mock_split: gate_and_up_weights = [torch.randn(1024, 1024) for _ in range(4)] mock_split.return_value = gate_and_up_weights mixin._last_expert_ids = [0, 1, 2, 3] @@ -378,7 +378,7 @@ def test_gate_and_up_projs_conversion(self, mock_is_dtensor): assert gate_key in result assert up_key in result assert result[gate_key].shape == (512, 1024) # [inter_dim, dim] - assert result[up_key].shape == (512, 1024) # [inter_dim, dim] + assert result[up_key].shape == (512, 1024) # [inter_dim, dim] @patch("nemo_automodel.components.moe.state_dict_mixin.is_dtensor") def test_down_projs_conversion_n2(self, mock_is_dtensor): @@ -423,7 +423,6 @@ def test_dtensor_validation_n2(self, mock_is_dtensor, mock_validate): mock_validate.assert_called_once_with(mock_dtensor, 2, "gate_and_up_projs layer 0") - # Tests merged into TestToHfWSplitExperts @@ -444,7 +443,7 @@ def test_basic_conversion(self, mock_should_load, mock_create_dtensor): key_up = f"model.layers.0.mlp.experts.{expert_id}.up_proj.weight" hf_state_dict[key_up] = torch.randn(512, 1024) - with patch.object(mixin, '_validate_expert_availability'): + with patch.object(mixin, "_validate_expert_availability"): result = mixin._from_hf_w_merged_experts(hf_state_dict) # Check that gate_and_up_projs tensor was created @@ -463,10 +462,15 @@ def test_partial_expert_loading(self): key_up = f"model.layers.0.mlp.experts.{expert_id}.up_proj.weight" hf_state_dict[key_up] = torch.randn(512, 1024) - with patch.object(mixin, '_validate_expert_availability'): - with patch("nemo_automodel.components.moe.state_dict_mixin.should_load_expert_for_rank") as mock_should_load: + with patch.object(mixin, "_validate_expert_availability"): + with patch( + "nemo_automodel.components.moe.state_dict_mixin.should_load_expert_for_rank" + ) as mock_should_load: mock_should_load.side_effect = lambda expert_id, *args: expert_id == 1 # Only load expert 1 - with patch("nemo_automodel.components.moe.state_dict_mixin.create_dtensor_from_local", side_effect=lambda x, *args: x): + with patch( + "nemo_automodel.components.moe.state_dict_mixin.create_dtensor_from_local", + side_effect=lambda x, *args: x, + ): result = mixin._from_hf_w_merged_experts(hf_state_dict) # When only partial experts are loaded, no tensor should be created until all are available @@ -485,9 +489,12 @@ def test_without_model_prefix(self): key_up = f"layers.0.mlp.experts.{expert_id}.up_proj.weight" hf_state_dict[key_up] = torch.randn(512, 1024) - with patch.object(mixin, '_validate_expert_availability'): + with patch.object(mixin, "_validate_expert_availability"): with patch("nemo_automodel.components.moe.state_dict_mixin.should_load_expert_for_rank", return_value=True): - with patch("nemo_automodel.components.moe.state_dict_mixin.create_dtensor_from_local", side_effect=lambda x, *args: x): + with patch( + "nemo_automodel.components.moe.state_dict_mixin.create_dtensor_from_local", + side_effect=lambda x, *args: x, + ): result = mixin._from_hf_w_merged_experts(hf_state_dict) # Result key preserves the empty prefix from input @@ -506,9 +513,12 @@ def test_with_language_model_prefix(self): key_up = f"model.language_model.layers.0.mlp.experts.{expert_id}.up_proj.weight" hf_state_dict[key_up] = torch.randn(512, 1024) - with patch.object(mixin, '_validate_expert_availability'): + with patch.object(mixin, "_validate_expert_availability"): with patch("nemo_automodel.components.moe.state_dict_mixin.should_load_expert_for_rank", return_value=True): - with patch("nemo_automodel.components.moe.state_dict_mixin.create_dtensor_from_local", side_effect=lambda x, *args: x): + with patch( + "nemo_automodel.components.moe.state_dict_mixin.create_dtensor_from_local", + side_effect=lambda x, *args: x, + ): result = mixin._from_hf_w_merged_experts(hf_state_dict) # Result key should preserve the language_model prefix @@ -525,9 +535,12 @@ def test_with_language_model_prefix_down_proj(self): key = f"model.language_model.layers.0.mlp.experts.{expert_id}.down_proj.weight" hf_state_dict[key] = torch.randn(1024, 512) # [dim, inter_dim] - with patch.object(mixin, '_validate_expert_availability'): + with patch.object(mixin, "_validate_expert_availability"): with patch("nemo_automodel.components.moe.state_dict_mixin.should_load_expert_for_rank", return_value=True): - with patch("nemo_automodel.components.moe.state_dict_mixin.create_dtensor_from_local", side_effect=lambda x, *args: x): + with patch( + "nemo_automodel.components.moe.state_dict_mixin.create_dtensor_from_local", + side_effect=lambda x, *args: x, + ): result = mixin._from_hf_w_merged_experts(hf_state_dict) # Result key should preserve the language_model prefix @@ -555,9 +568,12 @@ def test_with_device_mesh(self, mock_get_submesh, mock_get_expert_range): "model.layers.0.mlp.experts.0.up_proj.weight": torch.randn(512, 1024), } - with patch.object(mixin, '_validate_expert_availability'): + with patch.object(mixin, "_validate_expert_availability"): with patch("nemo_automodel.components.moe.state_dict_mixin.should_load_expert_for_rank", return_value=True): - with patch("nemo_automodel.components.moe.state_dict_mixin.create_dtensor_from_local", side_effect=lambda x, *args: x): + with patch( + "nemo_automodel.components.moe.state_dict_mixin.create_dtensor_from_local", + side_effect=lambda x, *args: x, + ): result = mixin._from_hf_w_merged_experts(hf_state_dict, mock_device_mesh) expected_key = "model.layers.0.mlp.experts.gate_and_up_projs" @@ -574,10 +590,10 @@ def test_gate_and_up_combination(self, mock_should_load, mock_create_dtensor): hf_state_dict = { "model.layers.0.mlp.experts.0.gate_proj.weight": torch.randn(512, 1024), # [inter_dim, dim] - "model.layers.0.mlp.experts.0.up_proj.weight": torch.randn(512, 1024), # [inter_dim, dim] + "model.layers.0.mlp.experts.0.up_proj.weight": torch.randn(512, 1024), # [inter_dim, dim] } - with patch.object(mixin, '_validate_expert_availability'): + with patch.object(mixin, "_validate_expert_availability"): result = mixin._from_hf_w_merged_experts(hf_state_dict) # Should create gate_and_up_projs tensor @@ -597,7 +613,7 @@ def test_down_proj_transpose(self, mock_should_load, mock_create_dtensor): "model.layers.0.mlp.experts.0.down_proj.weight": torch.randn(1024, 512), # [dim, inter_dim] } - with patch.object(mixin, '_validate_expert_availability'): + with patch.object(mixin, "_validate_expert_availability"): result = mixin._from_hf_w_merged_experts(hf_state_dict) # Should create transposed down_projs tensor @@ -623,9 +639,12 @@ def test_dtensor_input_handling(self, mock_is_dtensor): "model.layers.0.mlp.experts.0.up_proj.weight": mock_up_dtensor, } - with patch.object(mixin, '_validate_expert_availability'): + with patch.object(mixin, "_validate_expert_availability"): with patch("nemo_automodel.components.moe.state_dict_mixin.should_load_expert_for_rank", return_value=True): - with patch("nemo_automodel.components.moe.state_dict_mixin.create_dtensor_from_local", side_effect=lambda x, *args: x): + with patch( + "nemo_automodel.components.moe.state_dict_mixin.create_dtensor_from_local", + side_effect=lambda x, *args: x, + ): result = mixin._from_hf_w_merged_experts(hf_state_dict) # Verify to_local was called on DTensor inputs @@ -640,13 +659,12 @@ def test_skip_scale_inv_keys(self): "some_weight_scale_inv": torch.randn(10), # Should be skipped } - with patch.object(mixin, '_validate_expert_availability'): + with patch.object(mixin, "_validate_expert_availability"): result = mixin._from_hf_w_merged_experts(hf_state_dict) assert "some_weight" in result assert "some_weight_scale_inv" not in result - # Tests merged into TestFromHfWMergedExperts @@ -751,3 +769,129 @@ def test_dtensor_validation_called(self, mock_is_dtensor, mock_validate): mock_validate.assert_called_once_with(mock_dtensor, 2, "gate_and_up_projs layer 0") assert result is not None + + +class TestInplaceLoadViews: + """to_hf returns non-contiguous views into the model's grouped tensor's + local storage whenever the source is a model DTensor with plain + (non-DTensor) per-expert splits. DCP writes safetensors data through the + views into model storage, and ``_from_hf_w_merged_experts`` skips the + rebuild for those native keys (the model already holds the data). Save + callers must materialize the views to contiguous before serializing — + see ``_materialize_to_hf_views_for_save`` in checkpointing. + + The mixin re-imports ``is_dtensor`` from ``state_dict_utils`` inside the + conversion function, so patches must target that module path. + """ + + def _run_inplace_conversion(self, mixin, fqn, mock_dtensor, splits): + mixin._split_experts_weights = Mock(return_value=splits) + mixin._last_expert_ids = list(range(len(splits))) + + with ( + patch( + "nemo_automodel.components.moe.state_dict_utils.is_dtensor", + side_effect=lambda x: x is mock_dtensor, + ), + patch("nemo_automodel.components.moe.state_dict_utils.validate_dtensor_expert_sharding"), + ): + return mixin._convert_single_merged_expert_to_hf_split_experts(fqn, mock_dtensor) + + def test_inplace_load_gate_and_up_returns_views(self): + mixin = MockMoEStateDictMixin(n_experts=2, inter_dim=512) + # local[i] for gated has shape (dim=1024, 2*inter=1024). + local_storage = torch.randn(2, 1024, 1024) + splits = [local_storage[i] for i in range(2)] + mock_dtensor = Mock() + + result = self._run_inplace_conversion( + mixin, "model.layers.0.mlp.experts.gate_and_up_projs", mock_dtensor, splits + ) + + assert result is not None + src_ptr = local_storage.untyped_storage().data_ptr() + for k, v in result: + assert v.untyped_storage().data_ptr() == src_ptr, f"in-place view for {k} should alias model storage" + assert not v.is_contiguous(), f"in-place view for {k} must be the strided transpose, not a copy" + assert "model.layers.0.mlp.experts.gate_and_up_projs" in mixin._inplace_loaded_native_keys + + def test_inplace_load_down_projs_returns_views(self): + mixin = MockMoEStateDictMixin(n_experts=2, inter_dim=512) + # local[i] for down has shape (inter=512, dim=1024). + local_storage = torch.randn(2, 512, 1024) + splits = [local_storage[i] for i in range(2)] + + # The down branch dispatches via ``tensor.shape[1] == inter_dim``, so + # the mock must answer that check before splits are computed. + mock_dtensor = Mock(spec=["ndim", "shape", "is_meta"]) + mock_dtensor.ndim = 3 + mock_dtensor.shape = (2, 512, 1024) + mock_dtensor.is_meta = False + + result = self._run_inplace_conversion(mixin, "model.layers.3.mlp.experts.down_projs", mock_dtensor, splits) + + assert result is not None and len(result) == 2 + src_ptr = local_storage.untyped_storage().data_ptr() + for k, v in result: + assert v.untyped_storage().data_ptr() == src_ptr, f"in-place view for {k} should alias model storage" + assert "model.layers.3.mlp.experts.down_projs" in mixin._inplace_loaded_native_keys + + def test_inplace_load_skips_when_source_not_dtensor(self): + # When tensor is a plain CPU tensor (not from the model), the in-place + # path must not engage — there is no model storage to alias and + # contiguous copies are the correct fallback. + mixin = MockMoEStateDictMixin(n_experts=2, inter_dim=512) + tensor = torch.randn(2, 1024, 1024) + fqn = "model.layers.0.mlp.experts.gate_and_up_projs" + + with patch( + "nemo_automodel.components.moe.state_dict_utils.is_dtensor", + return_value=False, + ): + result = mixin._convert_single_merged_expert_to_hf_split_experts(fqn, tensor) + + assert result is not None + for _, v in result: + assert v.is_contiguous(), "non-DTensor source should emit contiguous copies" + assert not hasattr( + mixin, "_inplace_loaded_native_keys" + ) or "model.layers.0.mlp.experts.gate_and_up_projs" not in (mixin._inplace_loaded_native_keys or set()) + + def test_inplace_load_writes_through_to_model_storage(self): + # Simulate DCP-style copy_ on the emitted views and verify the model's + # underlying storage is updated at the correct slice. + mixin = MockMoEStateDictMixin(n_experts=2, inter_dim=512) + local_storage = torch.zeros(2, 1024, 1024) + splits = [local_storage[i] for i in range(2)] + mock_dtensor = Mock() + + result = self._run_inplace_conversion( + mixin, "model.layers.0.mlp.experts.gate_and_up_projs", mock_dtensor, splits + ) + + gate0 = next(v for k, v in result if k.endswith("0.gate_proj.weight")) + gate0.copy_(torch.full_like(gate0, 7.0)) + # Gated layout: local[0, :, :inter] holds gate, local[0, :, inter:] holds up. + assert torch.allclose(local_storage[0, :, :512], torch.full((1024, 512), 7.0)) + assert torch.all(local_storage[0, :, 512:] == 0) + assert torch.all(local_storage[1] == 0) + + def test_from_hf_skips_rebuild_for_inplace_loaded_keys(self): + # When _inplace_loaded_native_keys contains a layer's grouped key, the + # per-expert HF keys for that layer must NOT be merged back into a + # native key in the output state_dict. + mixin = MockMoEStateDictMixin(n_experts=2, inter_dim=512) + mixin._inplace_loaded_native_keys = { + "model.layers.0.mlp.experts.gate_and_up_projs", + "model.layers.0.mlp.experts.down_projs", + } + hf_state_dict = {} + for expert_id in range(2): + for proj in ("gate_proj", "up_proj", "down_proj"): + hf_state_dict[f"model.layers.0.mlp.experts.{expert_id}.{proj}.weight"] = torch.randn(512, 1024) + + out = mixin._from_hf_w_merged_experts(hf_state_dict) + + assert "model.layers.0.mlp.experts.gate_and_up_projs" not in out + assert "model.layers.0.mlp.experts.down_projs" not in out + assert mixin._inplace_loaded_native_keys == set()