Skip to content

fix: call init_weights() instead of initialize_weights() to restore w…#2213

Open
Meiyim wants to merge 2 commits into
NVIDIA-NeMo:mainfrom
Meiyim:fix-init-weight-for-meta-device
Open

fix: call init_weights() instead of initialize_weights() to restore w…#2213
Meiyim wants to merge 2 commits into
NVIDIA-NeMo:mainfrom
Meiyim:fix-init-weight-for-meta-device

Conversation

@Meiyim
Copy link
Copy Markdown

@Meiyim Meiyim commented May 12, 2026

initialize_weights() only runs per-module _init_weights; init_weights() additionally calls tie_weights() afterward, restoring the lm_head/embed_tokens weight tie that FSDP2's fully_shard() breaks by replacing shared parameters with separate DTensors.

What does this PR do ?

Fix broken weight tying for models with shared embeddings (e.g. tie_word_embeddings=True) when training from scratch with FSDP2.

Changelog

  • checkpointing.py: Replace model.initialize_weights() with model.init_weights() in initialize_model_weights().

Before your PR is "Ready for review"

Additional Information

  • Related to # (issue)

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 12, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@akoumpa
Copy link
Copy Markdown
Contributor

akoumpa commented May 12, 2026

Hi @Meiyim thanks a lot for the fix!

I see that there's a conflict with main, would you mind taking a look? I can then help trigger the CI.

Please let me know if I can help with anything. Thank you

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-customer Waiting on the original author to respond label May 12, 2026
@Meiyim Meiyim force-pushed the fix-init-weight-for-meta-device branch 2 times, most recently from 82fa837 to 43d7766 Compare May 13, 2026 10:07
@Meiyim
Copy link
Copy Markdown
Author

Meiyim commented May 13, 2026

hi @akoumpa, i resolved the conflict. can you trigger CI ,thanks.

@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the waiting-on-customer Waiting on the original author to respond label May 13, 2026
@akoumpa
Copy link
Copy Markdown
Contributor

akoumpa commented May 13, 2026

/ok to test 43d7766

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-maintainers Waiting on maintainers to respond label May 15, 2026
@akoumpa
Copy link
Copy Markdown
Contributor

akoumpa commented May 19, 2026

Hi @Meiyim thanks for making this, I see there's a few tests failing, can you take a look? it's ok to update the tests if needed. Please let me know if there's anything I can help.

@svcnvidia-nemo-ci svcnvidia-nemo-ci added waiting-on-customer Waiting on the original author to respond and removed waiting-on-maintainers Waiting on maintainers to respond labels May 19, 2026
Meiyim added 2 commits May 20, 2026 11:11
@Meiyim Meiyim force-pushed the fix-init-weight-for-meta-device branch from 43d7766 to 2ea4e19 Compare May 20, 2026 03:26
@Meiyim
Copy link
Copy Markdown
Author

Meiyim commented May 20, 2026

hi @akoumpa , i update some of the testcase, can you trigger another round of unittesting ?

@HuiyingLi
Copy link
Copy Markdown
Contributor

/ok to test 2ea4e19

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-request waiting-on-maintainers Waiting on maintainers to respond

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants