fix(speculative): EAGLE-3 vocab-shrunk checkpoint resume shape mismatch #2319
When the target model has tie_word_embeddings=True (e.g. Llama-3.2-1B/3B)
and EAGLE-3 vocab shrinking is enabled (draft_vocab_size <
target_vocab_size), checkpoint resume fails with:
RuntimeError: Error(s) in loading state_dict for LlamaEagle3DraftModel:
size mismatch for lm_head.weight: copying a param with shape
torch.Size([128256, 2048]) from checkpoint, the shape in current
model is torch.Size([8192, 2048]).
Root cause: train_eagle3 builds the draft config by copying the target
config wholesale, so tie_word_embeddings=True is inherited. The draft
model intentionally builds asymmetric vocab tables -- embed_tokens at
target_vocab_size, lm_head at draft_vocab_size -- but
has_local_tied_lm_head() trusts the config flag without checking shapes.
On save, ModelState pops lm_head.weight (treating it as a tied alias);
on load, materialize_missing_tied_lm_head() copies embed_tokens.weight
(full target vocab) into the lm_head slot, producing a strict-load
shape mismatch.
Two complementary fixes:
1. components/checkpoint/utils.py: has_local_tied_lm_head() now returns
False when lm_head.weight.shape != embed_tokens.weight.shape, even if
config.tie_word_embeddings is True. This is the durable fix and
protects any other model with the same asymmetry pattern.
2. recipes/llm/train_eagle3.py: explicitly clear tie_word_embeddings on
the draft config so the inherited target flag does not propagate.
Defensive belt-and-braces that keeps the draft config self-consistent.
Adds three unit tests in tests/unit_tests/utils/test_checkpoint_utils.py
covering shape-mismatch (False), shape-match-tied (True), and untied
(False) cases for has_local_tied_lm_head.
Signed-off-by: qiaochuz <qiaochuz@nvidia.com>

/ok to test 0440dab

test script torchrun --nproc-per-node=2 -m nemo_automodel.cli.app torchrun --nproc-per-node=2 -m nemo_automodel.cli.app

Thanks for the PR. I reviewed the change and agree this is a real issue caused by the draft config inheriting For EAGLE-3 with vocab shrinking, the draft
Summary
Fix `LlamaEagle3DraftModel` checkpoint resume when vocab shrinking is enabled
(`draft_vocab_size < target_vocab_size`) and the target model has
`tie_word_embeddings=True`.

Two complementary changes:

1. `components/checkpoint/utils.py`: `has_local_tied_lm_head()` now returns
   `False` when `lm_head.weight.shape != embed_tokens.weight.shape`, even if
   `config.tie_word_embeddings` is `True`. This is the durable fix and
   protects any other model with the same asymmetry pattern.
2. `recipes/llm/train_eagle3.py`: explicitly set
   `draft_config["tie_word_embeddings"] = False` after copying the target
   config, so the inherited target flag cannot propagate into the draft.
   This is a defensive belt-and-braces change that keeps the draft config
   self-consistent with the actual draft model architecture.
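
The shape-aware check described in the first change can be sketched as follows. This is a minimal illustration rather than the code in `components/checkpoint/utils.py`: the function name `has_local_tied_lm_head_sketch`, the helper `make_model`, and the flat attribute layout (`model.config`, `model.embed_tokens`, `model.lm_head`) are assumptions, and plain stand-in objects replace real `torch` modules so the snippet runs without any dependencies.

```python
from types import SimpleNamespace


def has_local_tied_lm_head_sketch(model) -> bool:
    """Report a tied lm_head only if the config says so AND the shapes agree."""
    config = getattr(model, "config", None)
    if not getattr(config, "tie_word_embeddings", False):
        return False
    lm_head = getattr(model, "lm_head", None)
    embed = getattr(model, "embed_tokens", None)
    if lm_head is None or embed is None:
        return False
    # A tied lm_head aliases the embedding table, so the shapes must be identical.
    return tuple(lm_head.weight.shape) == tuple(embed.weight.shape)


def make_model(tied: bool, embed_rows: int, head_rows: int, hidden: int = 2048):
    """Hypothetical stand-in exposing only the attributes the predicate reads."""
    return SimpleNamespace(
        config=SimpleNamespace(tie_word_embeddings=tied),
        embed_tokens=SimpleNamespace(weight=SimpleNamespace(shape=(embed_rows, hidden))),
        lm_head=SimpleNamespace(weight=SimpleNamespace(shape=(head_rows, hidden))),
    )


# The three cases named in this PR: shape mismatch, shape match with tying, untied.
shape_mismatch = has_local_tied_lm_head_sketch(make_model(True, 128256, 8192))
shape_match_tied = has_local_tied_lm_head_sketch(make_model(True, 128256, 128256))
untied = has_local_tied_lm_head_sketch(make_model(False, 128256, 128256))
print(shape_mismatch, shape_match_tied, untied)  # False True False
```

Checking shapes in the predicate, instead of relying only on the config flag, means a stale or inherited `tie_word_embeddings=True` cannot cause an independently sized `lm_head` to be treated as an alias.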
Adds three unit tests in
`tests/unit_tests/utils/test_checkpoint_utils.py` covering shape-mismatch
(`False`), shape-match-tied (`True`), and untied (`False`) paths for
`has_local_tied_lm_head`.

Root cause
`recipes/llm/train_eagle3.py:165-184` builds the draft config by copying
the target config wholesale. `tie_word_embeddings` is not overridden, so for
Llama-3.2-1B-Instruct (and any other tied-embedding target) the draft
inherits `True`.

`components/speculative/eagle/draft_llama.py:440,478` then intentionally
builds asymmetric vocab tables: `embed_tokens` at `config.vocab_size`
(full target vocab, e.g. 128256) and `lm_head` at `self.draft_vocab_size`
(e.g. 8192). The two are not actually tied; they have different shapes by
design.
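
As a rough illustration of the inheritance problem and the recipe-side fix, the sketch below copies a target config and clears the flag. The dictionary contents and the helper name `build_draft_config` are assumptions for illustration, not the code at `recipes/llm/train_eagle3.py`; only the `draft_config["tie_word_embeddings"] = False` assignment is taken from this PR.

```python
import copy

# Hypothetical subset of a tied-embedding target config (values match the
# shapes reported in the crash: vocab 128256, hidden size 2048).
target_config = {
    "vocab_size": 128256,
    "hidden_size": 2048,
    "tie_word_embeddings": True,
}


def build_draft_config(target: dict, draft_vocab_size: int) -> dict:
    """Copy the target config, then override what differs for the draft."""
    draft_config = copy.deepcopy(target)
    draft_config["draft_vocab_size"] = draft_vocab_size
    # Without this line the draft silently inherits tie_word_embeddings=True.
    draft_config["tie_word_embeddings"] = False
    return draft_config


draft_config = build_draft_config(target_config, draft_vocab_size=8192)

# The two vocab tables the draft model builds have different row counts,
# so they cannot share one weight tensor.
embed_shape = (draft_config["vocab_size"], draft_config["hidden_size"])
lm_head_shape = (draft_config["draft_vocab_size"], draft_config["hidden_size"])
print(embed_shape, lm_head_shape, draft_config["tie_word_embeddings"])
```

The deep copy keeps the target config untouched, so clearing the flag on the draft has no effect on how the target model itself is loaded.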
`components/checkpoint/utils.py:has_local_tied_lm_head` previously
trusted the config flag without checking shapes.

That made `ModelState`
(`components/checkpoint/stateful_wrappers.py:286-287`) pop
`lm_head.weight` from the saved state_dict on save (treating it as a tied
alias of the embedding), and made `materialize_missing_tied_lm_head`
(`components/checkpoint/utils.py:205-244`) copy
`model.embed_tokens.weight` (full target vocab) into the missing
`lm_head.weight` slot on load. The strict matcher then crashed because
the synthesized tensor was `[128256, 2048]` while the real `lm_head` was
`[8192, 2048]`.

Setting `--checkpoint.save_consolidated false` did not help: the
pop / materialize logic lives in `ModelState`, which is used by
`Checkpointer.save_model` / `load_model` for the DCP shards regardless of
the consolidated-safetensors addon.
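
The save/load interaction above can be modeled with shapes alone. The following is a simplified simulation under stated assumptions: the helpers `save_state`, `materialize_missing_lm_head`, `strict_mismatches`, and `round_trip` are hypothetical stand-ins for the behavior attributed to `ModelState` and `materialize_missing_tied_lm_head`, and state dicts are represented as name-to-shape mappings instead of real tensors.

```python
def save_state(shapes: dict, treat_as_tied: bool) -> dict:
    """On save, drop lm_head.weight when it is believed to alias the embedding."""
    saved = dict(shapes)
    if treat_as_tied:
        saved.pop("lm_head.weight", None)
    return saved


def materialize_missing_lm_head(saved: dict, treat_as_tied: bool) -> dict:
    """On load, rebuild a missing tied lm_head from the embedding table."""
    loaded = dict(saved)
    if treat_as_tied and "lm_head.weight" not in loaded:
        loaded["lm_head.weight"] = loaded["model.embed_tokens.weight"]
    return loaded


def strict_mismatches(loaded: dict, model_shapes: dict) -> list:
    """List every parameter whose checkpoint shape differs from the model's."""
    return [
        f"{name}: checkpoint {loaded.get(name)} vs model {shape}"
        for name, shape in model_shapes.items()
        if loaded.get(name) != shape
    ]


# Shapes of the vocab-shrunk draft model: full-vocab embedding, shrunk lm_head.
model_shapes = {
    "model.embed_tokens.weight": (128256, 2048),
    "lm_head.weight": (8192, 2048),
}


def round_trip(treat_as_tied: bool) -> list:
    saved = save_state(model_shapes, treat_as_tied)
    loaded = materialize_missing_lm_head(saved, treat_as_tied)
    return strict_mismatches(loaded, model_shapes)


before_fix = round_trip(treat_as_tied=True)   # predicate trusted the config flag
after_fix = round_trip(treat_as_tied=False)   # shape-aware predicate says "not tied"
print(before_fix)
print(after_fix)
```

With the flag trusted, the simulated load reports `lm_head.weight` at the embedding shape `(128256, 2048)` against a model expecting `(8192, 2048)`; with the predicate returning `False`, the key is never dropped and the round trip is clean.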
Repro (before fix)
EAGLE-3 recipe with default `draft_vocab_size: 8192`, against
`meta-llama/Llama-3.2-1B-Instruct` (which has `tie_word_embeddings=True`):

Crash:
After fix — observed
Unit tests added by this PR cover both directions of the predicate change.
Local run inside the worktree:
End-to-end re-verification with the daily-PR gap test
(`automodel_eagle3_checkpoint_resume_daily_pr_test.sh`) is queued on the
internal SLURM cluster (h100); we will update the test to drop its
`--recipe_args.draft_vocab_size None` workaround once this PR merges.

Why the existing test did not catch this
`tests/unit_tests/recipes/llm/test_eagle_checkpoint_resume.py` builds the
draft model directly without going through `train_eagle3.setup()`, so it
never inherits `tie_word_embeddings=True` from a real Llama target config.

Functional CI also did not exercise the resume path with vocab shrinking
enabled against a tied-embedding target. The new unit tests cover the
predicate directly, so a regression in it will be caught.
Detected by
NeMo daily-PR impact-pipeline gap test for PR #2285
(`automodel_eagle3_checkpoint_resume_daily_pr_test.sh`). The test
currently sidesteps the bug by passing
`--recipe_args.draft_vocab_size None` (no shrinking); once this PR merges,
the workaround will be removed so the test exercises the realistic
vocab-shrunk configuration.
Test plan
- `tests/unit_tests/utils/test_checkpoint_utils.py`: 3 new unit tests
  covering shape-mismatch / shape-match-tied / untied paths; all 8 tests
  in the file pass locally.
- `ruff format --check` and `ruff check` clean on all touched files.
- (shape-match-tied test asserts `True`).
- (untied test asserts `False`).
- `draft_vocab_size=8192` on Llama-3.2-1B-Instruct: verified on H100
  2026-05-26 with PR #2319 cherry-picked at runtime; see Update section
  below.
Update — End-to-end verification on H100 (2026-05-26)
The daily-PR gap test
(`automodel_eagle3_checkpoint_resume_daily_pr_test.sh`) was re-run on
H100 with its `--recipe_args.draft_vocab_size None` workaround dropped,
exercising the realistic vocab-shrunk path against
`meta-llama/Llama-3.2-1B-Instruct` (which has
`tie_word_embeddings=True`). PR #2319 was applied via runtime
cherry-pick before launch.
Result: PASS. The shape-mismatch crash that previously failed this
configuration is gone.
Before this PR, run 2 crashed in `train_eagle3.setup()` →
`load_checkpoint` → `Checkpointer.load_model` with:

After this PR, run 2 resumes cleanly, restores `global_step=2`,
recognizes the single configured epoch is already complete, and exits
zero. The `_load_extra_state` path for `selected_token_ids` /
`selected_token_mask` round-trips through the cherry-picked fix without
the tied-lm_head pop trap firing.
Test artifacts:
/localhome/local-qiaochuz/nemo/nmfw_tests_nightly/nemo_llm/test_suite/automodel_test/logs/20260526T181815Z-01772afb