Skip to content

fix(speculative): EAGLE-3 vocab-shrunk checkpoint resume shape mismatch#2319

Merged
HuiyingLi merged 1 commit into
mainfrom
qiaochuz-nv/fix/eagle3-vocab-shrunk-resume
May 27, 2026
Merged

fix(speculative): EAGLE-3 vocab-shrunk checkpoint resume shape mismatch#2319
HuiyingLi merged 1 commit into
mainfrom
qiaochuz-nv/fix/eagle3-vocab-shrunk-resume

Conversation

@qiaochuz-nv
Copy link
Copy Markdown
Contributor

@qiaochuz-nv qiaochuz-nv commented May 26, 2026

Summary

Fix LlamaEagle3DraftModel checkpoint resume when vocab shrinking is enabled
(draft_vocab_size < target_vocab_size) and the target model has
tie_word_embeddings=True.

Two complementary changes:

  1. components/checkpoint/utils.pyhas_local_tied_lm_head() now
    returns False when lm_head.weight.shape != embed_tokens.weight.shape,
    even if config.tie_word_embeddings is True. This is the durable fix
    and protects any other model with the same asymmetry pattern.

  2. recipes/llm/train_eagle3.py — explicitly set
    draft_config["tie_word_embeddings"] = False after copying the target
    config, so the inherited target flag cannot propagate into the draft.
    Defensive belt-and-braces that keeps the draft config self-consistent
    with the actual draft model architecture.

Adds three unit tests in tests/unit_tests/utils/test_checkpoint_utils.py
covering shape-mismatch (False), shape-match-tied (True), and untied
(False) paths for has_local_tied_lm_head.

Root cause

recipes/llm/train_eagle3.py:165-184 builds the draft config by copying
the target config wholesale:

draft_config = target_config.to_dict()
draft_config["draft_vocab_size"] = int(selected_token_ids.numel())
...

tie_word_embeddings is not overridden, so for Llama-3.2-1B-Instruct (and
any other tied-embedding target) the draft inherits True.

components/speculative/eagle/draft_llama.py:440,478 then intentionally
builds asymmetric vocab tables — embed_tokens at config.vocab_size
(full target vocab, e.g. 128256) and lm_head at self.draft_vocab_size
(e.g. 8192). The two are not actually tied; they have different shapes by
design.

components/checkpoint/utils.py:has_local_tied_lm_head previously
trusted the config flag without checking shapes:

if not is_tied_word_embeddings(model):
    return False
lm_head_weight, _ = get_lm_head_weight_and_name(model)
input_embeddings_weight, _ = get_input_embeddings_weight_and_name(model)
return lm_head_weight is not None and input_embeddings_weight is not None

That made ModelState
(components/checkpoint/stateful_wrappers.py:286-287) pop
lm_head.weight from the saved state_dict on save (treating it as a tied
alias of the embedding), and made materialize_missing_tied_lm_head
(components/checkpoint/utils.py:205-244) copy
model.embed_tokens.weight (full target vocab) into the missing
lm_head.weight slot on load. The strict matcher then crashed because
the synthesized tensor was [128256, 2048] while the real lm_head was
[8192, 2048].

Setting --checkpoint.save_consolidated false did not help: the
pop / materialize logic lives in ModelState, which is used by
Checkpointer.save_model / load_model for the DCP shards regardless of
the consolidated-safetensors addon.

Repro (before fix)

EAGLE-3 recipe with default draft_vocab_size: 8192, against
meta-llama/Llama-3.2-1B-Instruct (which has tie_word_embeddings=True):

torchrun --nproc-per-node=2 -m nemo_automodel.cli.app \
  examples/speculative/eagle3/llama_eagle3_mvp.yaml \
  --recipe_args.target_model_name_or_path meta-llama/Llama-3.2-1B-Instruct \
  --recipe_args.train_data_path /tmp/eagle3_train.jsonl \
  --recipe_args.output_dir /tmp/eagle3_out \
  --recipe_args.seq_length 256 --recipe_args.ttt_steps 2 \
  --recipe_args.num_epochs 1 \
  --checkpoint.enabled true --checkpoint.checkpoint_dir /tmp/eagle3_ckpt \
  --checkpoint.save_consolidated false
# (run again with --checkpoint.restore_from LATEST)

Crash:

File ".../recipes/llm/train_eagle3.py", line 263, in setup
    self.load_checkpoint(self.cfg.get("checkpoint.restore_from", None))
File ".../recipes/llm/train_eagle3.py", line 458, in load_checkpoint
    self.checkpointer.load_model(draft_model, os.path.join(ckpt_dir, "model"))
...
RuntimeError: Error(s) in loading state_dict for LlamaEagle3DraftModel:
    size mismatch for lm_head.weight: copying a param with shape
    torch.Size([128256, 2048]) from checkpoint, the shape in current
    model is torch.Size([8192, 2048]).

After fix — observed

Unit tests added by this PR cover both directions of the predicate change.
Local run inside the worktree:

$ python3 -m pytest tests/unit_tests/utils/test_checkpoint_utils.py -v
============================== 8 passed in 0.73s ==============================
tests/.../test_is_tied_word_embeddings_prefers_text_config_value PASSED
tests/.../test_is_tied_word_embeddings_respects_qwen3_vl_moe_exclusion PASSED
tests/.../test_is_tied_word_embeddings_falls_back_to_top_level_when_no_text_config PASSED
tests/.../test_is_tied_word_embeddings_handles_missing_config PASSED
tests/.../test_is_tied_word_embeddings_respects_exclusion_list PASSED
tests/.../test_has_local_tied_lm_head_false_when_shapes_disagree PASSED
tests/.../test_has_local_tied_lm_head_true_when_shapes_match_and_tied PASSED
tests/.../test_has_local_tied_lm_head_false_when_flag_unset PASSED

End-to-end re-verification with the daily-PR gap test
(automodel_eagle3_checkpoint_resume_daily_pr_test.sh) is queued on the
internal SLURM cluster (h100); we will update the test to drop its
--recipe_args.draft_vocab_size None workaround once this PR merges.

Why the existing test did not catch this

tests/unit_tests/recipes/llm/test_eagle_checkpoint_resume.py builds the
draft model directly without going through train_eagle3.setup(), so it
never inherits tie_word_embeddings=True from a real Llama target config.
Functional CI also did not exercise the resume path with vocab shrinking
on against a tied-embedding target. The new unit test covers the
predicate directly so the regression cannot return.

Detected by

NeMo daily-PR impact-pipeline gap test for PR #2285
(automodel_eagle3_checkpoint_resume_daily_pr_test.sh). The test
currently sidesteps the bug by passing --recipe_args.draft_vocab_size None (no shrinking); once this PR merges, the workaround will be
removed so the test exercises the realistic vocab-shrunk configuration.

Test plan

  • tests/unit_tests/utils/test_checkpoint_utils.py — 3 new unit
    tests covering shape-mismatch / shape-match-tied / untied paths;
    all 8 tests in the file pass locally.
  • ruff format --check and ruff check clean on all touched files.
  • No behavior change for ordinary tied-embedding models
    (shape-match-tied test asserts True).
  • No behavior change for untied models
    (untied test asserts False).
  • Functional EAGLE-3 resume smoke with draft_vocab_size=8192 on
    Llama-3.2-1B-Instruct — verified on H100 2026-05-26 with PR fix(speculative): EAGLE-3 vocab-shrunk checkpoint resume shape mismatch #2319
    cherry-picked at runtime; see Update section below.

Update — End-to-end verification on H100 (2026-05-26)

The daily-PR gap test
(automodel_eagle3_checkpoint_resume_daily_pr_test.sh) was re-run on
H100 with its --recipe_args.draft_vocab_size None workaround dropped,
exercising the realistic vocab-shrunk path against
meta-llama/Llama-3.2-1B-Instruct (which has
tie_word_embeddings=True). PR #2319 was applied via runtime
cherry-pick before launch.

Result: PASS. The shape-mismatch crash that previously failed this
configuration is gone.

[Run 1 — fresh training, save checkpoint]
2026-05-26 11:21:22 | Training start: start_epoch=0 num_epochs=1
                     batches_per_epoch=2 grad_accum=1 total_optim_steps=2
                     warmup_steps=1 peak_lr=1.000e-04 min_lr_ratio=0.1
2026-05-26 11:21:24 | epoch=0 step=1 train_loss=9.548348 train_acc=0.000000
2026-05-26 11:21:24 | epoch=0 step=2 train_loss=3.384131 train_acc=0.437500
2026-05-26 11:21:24 | Epoch 0 done: total_batches_seen=2 global_step=2
2026-05-26 11:21:25 | Saved checkpoint to .../checkpoints/epoch_1_step_2
2026-05-26 11:21:25 | Training complete: global_step=2

[Run 2 — resume from LATEST with default draft_vocab_size=8192]
2026-05-26 11:21:45 | Resuming from checkpoint: .../checkpoints/epoch_1_step_2
2026-05-26 11:21:46 | All 1 epochs already completed; nothing to do.
[2026-05-26T18:21:50] Run finished: status=completed failures=0

Before this PR, run 2 crashed in train_eagle3.setup()
load_checkpointCheckpointer.load_model with:

RuntimeError: Error(s) in loading state_dict for LlamaEagle3DraftModel:
    size mismatch for lm_head.weight: copying a param with shape
    torch.Size([128256, 2048]) from checkpoint, the shape in current
    model is torch.Size([8192, 2048]).

After this PR, run 2 resumes cleanly, restores global_step=2,
recognizes the single configured epoch is already complete, and exits
zero. The _load_extra_state path for selected_token_ids /
selected_token_mask round-trips through the cherry-picked fix without
the tied-lm_head pop trap firing.

Test artifacts:
/localhome/local-qiaochuz/nemo/nmfw_tests_nightly/nemo_llm/test_suite/automodel_test/logs/20260526T181815Z-01772afb

When the target model has tie_word_embeddings=True (e.g. Llama-3.2-1B/3B,
Llama-3.1-8B) and EAGLE-3 vocab shrinking is enabled (draft_vocab_size <
target_vocab_size), checkpoint resume fails with:

    RuntimeError: Error(s) in loading state_dict for LlamaEagle3DraftModel:
        size mismatch for lm_head.weight: copying a param with shape
        torch.Size([128256, 2048]) from checkpoint, the shape in current
        model is torch.Size([8192, 2048]).

Root cause: train_eagle3 builds the draft config by copying the target
config wholesale, so tie_word_embeddings=True is inherited. The draft
model intentionally builds asymmetric vocab tables -- embed_tokens at
target_vocab_size, lm_head at draft_vocab_size -- but
has_local_tied_lm_head() trusts the config flag without checking shapes.
On save, ModelState pops lm_head.weight (treating it as a tied alias);
on load, materialize_missing_tied_lm_head() copies embed_tokens.weight
(full target vocab) into the lm_head slot, producing a strict-load
shape mismatch.

Two complementary fixes:

1. components/checkpoint/utils.py: has_local_tied_lm_head() now returns
   False when lm_head.weight.shape != embed_tokens.weight.shape, even if
   config.tie_word_embeddings is True. This is the durable fix and
   protects any other model with the same asymmetry pattern.

2. recipes/llm/train_eagle3.py: explicitly clear tie_word_embeddings on
   the draft config so the inherited target flag does not propagate.
   Defensive belt-and-braces that keeps the draft config self-consistent.

Adds three unit tests in tests/unit_tests/utils/test_checkpoint_utils.py
covering shape-mismatch (False), shape-match-tied (True), and untied
(False) cases for has_local_tied_lm_head.

Signed-off-by: qiaochuz <qiaochuz@nvidia.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 26, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@qiaochuz-nv
Copy link
Copy Markdown
Contributor Author

/ok to test 0440dab

@qiaochuz-nv
Copy link
Copy Markdown
Contributor Author

test script

torchrun --nproc-per-node=2 -m nemo_automodel.cli.app
examples/speculative/eagle3/llama_eagle3_mvp.yaml
--recipe_args.target_model_name_or_path meta-llama/Llama-3.2-1B-Instruct
--recipe_args.train_data_path "$DATA_PATH"
--recipe_args.output_dir "$OUT_DIR"
--recipe_args.seq_length 256
--recipe_args.ttt_steps 2
--recipe_args.num_epochs 1
--recipe_args.log_every_steps 1
--checkpoint.enabled true
--checkpoint.checkpoint_dir "$CKPT_DIR"
--checkpoint.save_consolidated false

torchrun --nproc-per-node=2 -m nemo_automodel.cli.app
examples/speculative/eagle3/llama_eagle3_mvp.yaml
--recipe_args.target_model_name_or_path meta-llama/Llama-3.2-1B-Instruct
--recipe_args.train_data_path "$DATA_PATH"
--recipe_args.output_dir "$OUT_DIR"
--recipe_args.seq_length 256
--recipe_args.ttt_steps 2
--recipe_args.num_epochs 1
--recipe_args.log_every_steps 1
--checkpoint.enabled true
--checkpoint.checkpoint_dir "$CKPT_DIR"
--checkpoint.restore_from LATEST
--checkpoint.save_consolidated false

@khazic
Copy link
Copy Markdown
Contributor

khazic commented May 27, 2026

Thanks for the PR. I reviewed the change and agree this is a real issue caused by the draft config inheriting tie_word_embeddings=True from the target model.

For EAGLE-3 with vocab shrinking, the draft embed_tokens and lm_head intentionally have different vocab dimensions, so treating them as tied can make checkpoint resume materialize a shape-mismatched lm_head.weight. This was a mistake in my earlier code path. The shape check plus explicitly setting the draft config to untied makes sense to me.

@HuiyingLi HuiyingLi merged commit 3d7ee46 into main May 27, 2026
83 checks passed
@HuiyingLi HuiyingLi deleted the qiaochuz-nv/fix/eagle3-vocab-shrunk-resume branch May 27, 2026 03:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants