Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
db3b407
feat(nemotron_v3): pipeline parallel + MTP support, plus THD collator…
adil-a May 25, 2026
cf8f3de
chore(datasets): drop pre_rendered_chat_dataset.py from branch
adil-a May 25, 2026
285ea0c
fix(thd_utils): compute max_seqlen from final cu_seqlens to honor TE …
adil-a May 26, 2026
54915d2
chore(thd_utils): drop dead cu_seqlens_padded fallback in emit
adil-a May 26, 2026
96cc3a1
docs(thd_utils): update docstrings to match post-fix cu_seqlens seman…
adil-a May 26, 2026
8946f90
chore(thd_utils): trim verbose inline comments
adil-a May 26, 2026
aa9238e
comments
adil-a May 26, 2026
ff34ee0
fix(seq_idx): use searchsorted(right=True) to classify boundary tokens
adil-a May 26, 2026
239fe64
chore: remove debug env-var hooks from THD packing investigation
adil-a May 26, 2026
a39736f
test(mtp): pin MTPModule cumulative left-rolling of input_ids and pos…
adil-a May 26, 2026
868ba6d
chore(mtp): trim verbose comments in calculate_mtp_loss
adil-a May 26, 2026
eaa1aaa
docs(nemotron_v3): correct stale 'PP cannot chunk' comments in model.…
adil-a May 26, 2026
64ec935
chore(nemotron_v3): trim verbose comments in model.py
adil-a May 26, 2026
bb238e1
fix(nemotron_v3): seq_idx tail builder uses searchsorted(right=True)
adil-a May 27, 2026
61e13a9
docs(nemotron_v3): restore Args/Returns docstring on NemotronHForCaus…
adil-a May 27, 2026
0da9e5e
docs(nemotron_v3): fix inaccuracies in NemotronHForCausalLM.forward d…
adil-a May 27, 2026
c05ea54
chore(moe): trim verbose comment in apply_cp loop
adil-a May 27, 2026
aa4b4e4
fix(nemotron_v3): correct seq_idx and MTP loss shape for non-pp THD p…
adil-a May 28, 2026
c9098b8
feat(moe): avoid load-time OOM via in-place views in MoE to_hf split
adil-a May 28, 2026
6ac2135
fix(nemotron_v3): correct mamba seq_idx + MTP masking for mbs>1 THD p…
adil-a May 31, 2026
b8dcb08
fix(moe): backward-safe, always-on shared-expert overlap
adil-a May 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion nemo_automodel/components/checkpoint/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,16 @@ def save_model(
model_state = ModelState(model, self.config.is_peft)
state_dict = model_state.state_dict()

# Convert to HF format if using custom model implementations
# Convert to HF format if using custom model implementations.
state_dict = _maybe_adapt_state_dict_to_hf(
model_state.model[0],
state_dict,
quantization=False,
device_mesh=self.moe_mesh,
v4_compatible=self.config.v4_compatible,
)
# MoE adapters return non-contiguous views; safetensors.save rejects those.
_materialize_to_hf_views_for_save(state_dict)
# Build the consolidated model.safetensors.index.json if needed
fqn_to_file_index_mapping = self._maybe_build_consolidated_index(model_state, state_dict)

Expand Down Expand Up @@ -507,6 +509,8 @@ def load_model(
reader_key_mapping = None if has_state_dict_adapter else key_mapping
storage_reader = self._get_storage_reader(model_path, reader_key_mapping, is_init_step=is_init_step)

# MoE adapters return views into model storage; DCP writes safetensors
# data straight through them and from_hf skips the rebuild.
state_dict = _maybe_adapt_state_dict_to_hf(
model_state.model[0],
state_dict,
Expand Down Expand Up @@ -1578,6 +1582,27 @@ def _maybe_adapt_state_dict_to_hf(
return state_dict


def _materialize_to_hf_views_for_save(state_dict: dict[str, torch.Tensor]) -> None:
    """Replace non-contiguous tensor values in ``state_dict`` with contiguous copies in place.

    MoE adapters return non-contiguous strided views into the model's grouped
    expert storage for the optimized load path; ``safetensors.torch.save``
    (which the DCP HF storage writer calls) rejects non-contiguous tensors,
    so we materialize one tensor at a time here with ``empty_cache`` after
    each materialization. Per-tensor transient is bounded to a single expert
    weight instead of allocating the full grouped set up front.

    Args:
        state_dict: Mapping of parameter names to values. Mutated in place;
            entries that are not tensors or are already contiguous are left
            untouched (same object, same key order).

    Returns:
        None. The result is the in-place mutation of ``state_dict``.
    """
    if not state_dict:
        return
    cuda_available = torch.cuda.is_available()
    # Snapshot the keys only. Snapshotting ``items()`` would keep a reference
    # to every original tensor alive for the whole loop, turning ``del value``
    # into a no-op and preventing ``empty_cache`` from reclaiming anything.
    for key in list(state_dict):
        value = state_dict[key]
        if not isinstance(value, torch.Tensor) or value.is_contiguous():
            continue
        state_dict[key] = value.contiguous()
        # Drop the last local reference to the strided original before asking
        # the caching allocator to release unused blocks.
        del value
        if cuda_available:
            torch.cuda.empty_cache()


def _equally_divide_layers(num_shards: int, keys: list[str]) -> dict[str, int]:
"""
Equally divide the state dict keys into num_shards shards.
Expand Down
91 changes: 74 additions & 17 deletions nemo_automodel/components/distributed/thd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,21 @@ def process_input_for_thd(
[total_tokens, hidden_dim] for 3D embeddings
- 'labels': Reshaped labels tensor of shape [total_tokens]
- 'position_ids': Reshaped tensor of shape [total_tokens]
- 'cu_seqlens': Cumulative padded sequence lengths tensor of shape [num_sequences + 1] (int32)
- 'cu_seqlens': Cumulative REAL sequence lengths tensor of shape [num_sequences + 1] (int32)
where num_sequences is the total count of non-padded sequences across the batch.
NOTE: This contains cumulative lengths from seq_lens_padded (not seq_lens) since
CP doesn't support padding between sequences (resulting in NaNs). The labels or loss mask
will ensure that loss is computed correctly.
Built from seq_lens (the unpadded real lengths). When the trailing pack-pad is
purely at the end (cp_size == 1), the last entry is grown to total_tokens to absorb
that pad and avoid TE's ``pad_between_seqs=True`` path; see the absorption block in
the function body for the gate.
- 'cu_seqlens_padded': (optional) Cumulative PADDED sequence lengths tensor of the same
shape as ``cu_seqlens``. Only emitted when it differs from ``cu_seqlens`` after
absorption (i.e., when padding lives between sub-sequences, which is the CP case).
Forwarded to TE as ``cu_seqlens_q_padded`` / ``cu_seqlens_kv_padded`` with
``pad_between_seqs=True`` so the kernel reads memory offsets from the padded
variant while attending only over the real-length slots.
- 'max_seqlen': Scalar int32 tensor equal to ``max(cu_seqlens[i+1] - cu_seqlens[i])``
after any absorption. Honors TE's contract that
``max_seqlen_q >= max(cu_seqlens_q[i+1] - cu_seqlens_q[i])``.
- 'padding_mask': Boolean tensor of shape [total_tokens] indicating padding positions
- Non-tensor keys from input batch are preserved (e.g., 'qkv_format')

Expand All @@ -77,8 +87,11 @@ def process_input_for_thd(
>>> # result['input_ids'].shape: [12] (2D input collapsed to 1D)
>>> # result['labels'].shape: [12]
>>> # result['position_ids'].shape: [12]
>>> # result['cu_seqlens']: tensor([0, 4, 6, 12], dtype=torch.int32)
>>> # result['cu_seqlens']: tensor([0, 3, 5, 11], dtype=torch.int32)
>>> # Breakdown: [0] + cumsum([3, 2, 6]) = [0, 3, 5, 11] (from seq_lens — real lengths)
>>> # result['cu_seqlens_padded']: tensor([0, 4, 6, 12], dtype=torch.int32)
>>> # Breakdown: [0] + cumsum([4, 2, 6]) = [0, 4, 6, 12] (from seq_lens_padded)
>>> # result['max_seqlen']: tensor(6, dtype=torch.int32) # max slot width in cu_seqlens
>>> # result['padding_mask'].shape: [12]
"""
input_ids = batch["input_ids"]
Expand All @@ -96,13 +109,13 @@ def process_input_for_thd(
input_ids_thd = input_ids.reshape(total_tokens, -1).squeeze(-1)
labels_thd = labels.reshape(total_tokens, -1).squeeze(-1)

cu_seqlens = None
cu_seqlens_padded = None
max_seqlen = None
if seq_lens is not None:
# Filter out padding values and flatten
# seq_lens shape: [batch_size, num_packs] -> flatten and remove padding values
seq_lens_flat = seq_lens.reshape(-1)
valid_seq_lens = seq_lens_flat[seq_lens_flat != seq_lens_padding_value]

# Compute cumulative sequence lengths for attention
cu_seqlens = torch.cat(
[
torch.tensor([0], dtype=valid_seq_lens.dtype, device=valid_seq_lens.device),
Expand All @@ -112,7 +125,6 @@ def process_input_for_thd(
cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(device=valid_seq_lens.device)

if seq_lens_padded is not None:
# Same processing for padded sequence lengths
seq_lens_padded_flat = seq_lens_padded.reshape(-1)
valid_seq_lens_padded = seq_lens_padded_flat[seq_lens_padded_flat != seq_lens_padding_value]

Expand All @@ -121,16 +133,46 @@ def process_input_for_thd(
)
cu_seqlens_padded = cu_seqlens_padded.to(dtype=torch.int32).to(device=valid_seq_lens_padded.device)

# Trailing-only pack-pad (cp_size==1): absorb into cu_seqlens[-1] so
# the emit gate below drops cu_seqlens_padded and TE skips its
# pad_between_seqs=True path. CP>1 differs in multiple entries and
# falls through; both arrays are emitted and TE handles padding.
if (
cu_seqlens is not None
and cu_seqlens_padded is not None
and cu_seqlens.numel() == cu_seqlens_padded.numel()
and cu_seqlens.numel() > 1
and torch.equal(cu_seqlens[:-1], cu_seqlens_padded[:-1])
):
_total = int(total_tokens)
_real_total = int(cu_seqlens[-1].item())
if _real_total < _total:
_extended = cu_seqlens.clone()
_extended[-1] = _total
cu_seqlens = _extended
cu_seqlens_padded = cu_seqlens.clone()

# Compute max_seqlen from the FINAL cu_seqlens to honor TE's contract
# (``max_seqlen_q >= max(cu_seqlens[i+1] - cu_seqlens[i])``, see TE's
# cpp_extensions/fused_attn.py:152-159).
if cu_seqlens is not None and cu_seqlens.numel() > 1:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().to(dtype=torch.int32)

result = {
"input_ids": input_ids_thd,
"position_ids": position_ids_thd,
# Pass cu_seqlens_padded here since CP doesn't support padding between sequences correctly, the labels or loss mask will ensure that loss is computed correctly.
"cu_seqlens": cu_seqlens_padded,
"cu_seqlens": cu_seqlens,
"labels": labels_thd,
"padding_mask": (input_ids_thd == padding_token_id),
}
# Emit cu_seqlens_padded only when it differs from cu_seqlens — its
# presence is what flips TE's pad_between_seqs=True path in
# attention/utils.py.
if cu_seqlens_padded is not None and not torch.equal(cu_seqlens_padded, cu_seqlens):
result["cu_seqlens_padded"] = cu_seqlens_padded
if max_seqlen is not None:
result["max_seqlen"] = max_seqlen

# Preserve qkv_format and other non-tensor keys from the original batch
for key, value in batch.items():
if key not in result and not isinstance(value, torch.Tensor):
result[key] = value
Expand Down Expand Up @@ -175,8 +217,14 @@ def split_batch_into_thd_chunks(
- 'input_ids': [num_chunks, tokens_per_chunk] or [num_chunks, tokens_per_chunk, hidden_dim]
- 'labels': [num_chunks, tokens_per_chunk]
- 'position_ids': [num_chunks, tokens_per_chunk]
- 'cu_seqlens': [num_chunks, max_sequences_per_chunk + 1] (padded with seq_lens_padding_value).
Contains cumulative lengths from seq_lens_padded for CP compatibility.
- 'cu_seqlens': [num_chunks, max_sequences_per_chunk + 1] (right-padded with
seq_lens_padding_value across chunks for rectangularity). Built from seq_lens
(real lengths) per chunk; see ``process_input_for_thd`` for the absorption
semantics applied per chunk.
- 'cu_seqlens_padded': (optional) Same shape, emitted whenever ANY chunk emits it.
For chunks that absorbed (no separate padded variant), this row equals the
chunk's ``cu_seqlens``.
- 'max_seqlen': [num_chunks] per-chunk scalar tensor.
- 'padding_mask': [num_chunks, tokens_per_chunk]
- Non-tensor keys from input batch are preserved
- When num_chunks <= 1:
Expand Down Expand Up @@ -230,12 +278,21 @@ def pad_and_stack(tensor_list, padding_value):
for i in range(num_chunks)
]

# Stack results
return {
stacked: dict = {
"input_ids": torch.stack([c["input_ids"] for c in chunk_results]),
"labels": torch.stack([c["labels"] for c in chunk_results]),
"position_ids": torch.stack([c["position_ids"] for c in chunk_results]),
"cu_seqlens": pad_and_stack([c["cu_seqlens"] for c in chunk_results], seq_lens_padding_value),
"padding_mask": torch.stack([c["padding_mask"] for c in chunk_results]),
**{k: v for k, v in chunk_results[0].items() if not isinstance(v, torch.Tensor)},
}
# Emit cu_seqlens_padded whenever any chunk emits it; absorbed chunks
# fall back to their cu_seqlens (semantically equal) for rectangularity.
if any("cu_seqlens_padded" in c for c in chunk_results):
stacked["cu_seqlens_padded"] = pad_and_stack(
[c.get("cu_seqlens_padded", c["cu_seqlens"]) for c in chunk_results],
seq_lens_padding_value,
)
if all("max_seqlen" in c for c in chunk_results):
stacked["max_seqlen"] = torch.stack([c["max_seqlen"] for c in chunk_results])
stacked.update({k: v for k, v in chunk_results[0].items() if not isinstance(v, torch.Tensor)})
return stacked
87 changes: 86 additions & 1 deletion nemo_automodel/components/loss/mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def calculate_mtp_loss(
scaling_factor: float = 0.1,
num_label_tokens: Optional[int] = None,
ignore_index: int = -100,
cu_seqlens: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Compute the DeepSeek-V3 Multi-Token Prediction auxiliary loss.

Expand All @@ -51,19 +53,74 @@ def calculate_mtp_loss(
base loss for sum-reduction normalization).
ignore_index: Label value masked out of the CE loss for the trailing
``k+1`` rolled positions at depth ``k``.
cu_seqlens: Optional cumulative sequence lengths ``[num_seqs+1]``
(THD-pack layout). When supplied and ``seq_idx`` is not, builds
a per-token sub-sequence index via searchsorted. Without packing
this can be omitted.
seq_idx: Optional per-token sub-sequence index ``[B, S]`` (or ``[S]``).
Equality classes are what matter; absolute values can be any
ints. Takes precedence over ``cu_seqlens``. Used to mask label
rolls whose source position lies in a different sub-sequence.

Returns:
Scalar MTP loss with autograd graph.
"""
D = len(mtp_per_depth_h)

# Reconcile per-depth hidden-state and label dims for the THD-packed
# non-PP path: the model unsqueezes mtp_per_depth_h from ``[T, H]`` back
# to ``[1, T, H]`` (model.py post-MTP-forward), while labels arrive as
# 1D ``[T]`` from ``process_input_for_thd``. ``FusedLinearCrossEntropy``
# / ``cut_cross_entropy`` asserts ``hidden_states.shape[:-1] == labels.shape``
# so squeeze the synthetic batch axis when labels are flat.
if labels.dim() == 1:
mtp_per_depth_h = [h.squeeze(0) if (h.dim() == 3 and h.shape[0] == 1) else h for h in mtp_per_depth_h]

cur_labels = labels
total = mtp_per_depth_h[0].new_zeros(())

if seq_idx is None and cu_seqlens is not None:
cs = cu_seqlens
if cs.dim() == 2 and cs.shape[0] == 1:
cs = cs.squeeze(0)
if cs.dim() == 1:
# Span the full (padded) token axis; cu_seqlens[-1] excludes tail pad.
# Matches the model's mamba seq_idx build (nemotron_v3/layers.py).
total_len = labels.shape[-1]
positions = torch.arange(total_len, device=labels.device)
# ``right=True`` so a position equal to a boundary (the first token
# of sub-seq k, position == cu_seqlens[k]) maps to k, not k-1.
seq_idx = torch.searchsorted(cs[1:].contiguous(), positions, right=True)
if labels.dim() == 2:
seq_idx = seq_idx.unsqueeze(0).expand(labels.shape[0], -1)
elif seq_idx is not None:
if seq_idx.dim() == 1 and labels.dim() == 2:
seq_idx = seq_idx.unsqueeze(0).expand(labels.shape[0], -1)
elif seq_idx.dim() == 2 and labels.dim() == 1 and seq_idx.shape[0] == 1:
seq_idx = seq_idx.squeeze(0)
# Under PP the caller must chunk seq_idx to per-microbatch shape; a
# mismatch is a wiring bug, not a runtime condition to swallow.
if seq_idx.shape != labels.shape:
raise ValueError(
f"calculate_mtp_loss: seq_idx.shape={tuple(seq_idx.shape)} does not "
f"match labels.shape={tuple(labels.shape)}; under PP, chunk seq_idx "
f"into per-microbatch pieces before passing it in."
)

for k, h_k in enumerate(mtp_per_depth_h):
cur_labels = roll_tensor(cur_labels, shifts=-1, dim=-1)
masked = cur_labels.clone()
n_invalid = min(k + 1, masked.shape[-1])
masked[..., -n_invalid:] = ignore_index

# Mask labels whose rolled source (position t+k+1) lives in a
# different sub-seq than position t — predictions across sub-seq
# boundaries are nonsensical.
if seq_idx is not None:
rolled_seq_idx = roll_tensor(seq_idx, shifts=-(k + 1), dim=-1)
cross_seq = rolled_seq_idx != seq_idx
masked = torch.where(cross_seq, torch.full_like(masked, ignore_index), masked)

if isinstance(loss_fn, FusedLinearCrossEntropy):
depth_loss = calculate_loss(
loss_fn,
Expand All @@ -89,14 +146,40 @@ def calculate_mtp_loss(


class PipelineCausalLMLoss(nn.Module):
"""Pipeline schedule loss that can add MTP auxiliary CE on the last stage."""
"""Pipeline schedule loss that can add MTP auxiliary CE on the last stage.

Per-microbatch ``seq_idx`` is read from a trailing element of the
last-stage output tuple — the model appends an ``[B, S] int32`` tail
when MTP is enabled. This binds each microbatch's seq_idx to its loss
call via the PP runtime's output→loss contract, so the wiring is
schedule-agnostic. Legacy ``cu_seqlens`` (THD path) is a fallback for
models that don't emit a seq_idx tail.
"""

def __init__(self, loss_fn: nn.Module, model: nn.Module) -> None:
# Keep handles to the wrapped loss module and the model; ``forward`` passes
# ``self.model`` on to the MTP loss computation.
super().__init__()
self.loss_fn = loss_fn
self.model = model
# Legacy THD-pack fallback used when the model has no seq_idx tail.
# Starts unset (None); presumably assigned by the caller before the loss
# runs on packed batches — TODO confirm against the training loop.
self.cu_seqlens: Optional[torch.Tensor] = None

@staticmethod
def _extract_seq_idx_tail(output) -> tuple[Optional[torch.Tensor], object]:
"""Detect and strip a trailing per-microbatch seq_idx from output.

Convention: with MTP enabled the last-stage output is
``(logits, *mtp_per_depth_h, seq_idx)`` with an ``[B, S] int32``
tail — dtype alone discriminates.
"""
if isinstance(output, tuple) and len(output) > 0:
last = output[-1]
if isinstance(last, torch.Tensor) and last.dtype == torch.int32 and last.dim() == 2:
return last, output[:-1]
return None, output

def forward(self, output, labels: torch.Tensor) -> torch.Tensor:
seq_idx_mb, output = self._extract_seq_idx_tail(output)

if isinstance(output, tuple):
logits = output[0]
hidden_states = None
Expand All @@ -122,5 +205,7 @@ def forward(self, output, labels: torch.Tensor) -> torch.Tensor:
labels=labels,
model=self.model,
scaling_factor=scaling_factor,
cu_seqlens=self.cu_seqlens,
seq_idx=seq_idx_mb,
)
return loss
Loading
Loading