Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nemo_automodel/components/checkpoint/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,8 +664,8 @@ def initialize_model_weights(
if hasattr(module, "_is_hf_initialized"):
module._is_hf_initialized = False

if hasattr(model, "initialize_weights"):
model.initialize_weights()
if hasattr(model, "init_weights"):
model.init_weights()
else:
logging.warning(
"Warning: Model does not have initialize_weights method."
Expand Down
50 changes: 25 additions & 25 deletions tests/unit_tests/checkpoint/test_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,19 +627,19 @@ def test_resets_is_hf_initialized(self):
if hasattr(module, "_is_hf_initialized"):
assert module._is_hf_initialized is False

def test_calls_initialize_weights(self):
"""model.initialize_weights() should be called when available."""
def test_calls_init_weights(self):
"""model.init_weights() should be called when available."""
model = self._make_meta_model()
model.initialize_weights = MagicMock()
model.init_weights = MagicMock()

Checkpointer.initialize_model_weights(model, torch.device("cpu"))

model.initialize_weights.assert_called_once()
model.init_weights.assert_called_once()

def test_warns_when_no_initialize_weights_method(self):
"""Should log a warning when model lacks initialize_weights."""
def test_warns_when_no_init_weights_method(self):
"""Should log a warning when model lacks init_weights."""
model = self._make_meta_model()
assert not hasattr(model, "initialize_weights")
assert not hasattr(model, "init_weights")

with patch("nemo_automodel.components.checkpoint.checkpointing.logging") as mock_logging:
Checkpointer.initialize_model_weights(model, torch.device("cpu"))
Expand All @@ -650,22 +650,22 @@ def test_skips_for_nemotron_v2(self):
model = self._make_meta_model()
model.config = SimpleNamespace(architectures=["NemotronHForCausalLM"])
model._is_hf_initialized = True
model.initialize_weights = MagicMock()
model.init_weights = MagicMock()

Checkpointer.initialize_model_weights(model, torch.device("cpu"))

model.initialize_weights.assert_not_called()
model.init_weights.assert_not_called()
assert model._is_hf_initialized is True

def test_does_not_skip_for_nemotron_v3_moe(self):
"""NemotronHForCausalLM v3 (with n_routed_experts) should NOT be skipped."""
model = self._make_meta_model()
model.config = SimpleNamespace(architectures=["NemotronHForCausalLM"], n_routed_experts=8)
model.initialize_weights = MagicMock()
model.init_weights = MagicMock()

Checkpointer.initialize_model_weights(model, torch.device("cpu"))

model.initialize_weights.assert_called_once()
model.init_weights.assert_called_once()

@pytest.mark.parametrize(
"architecture",
Expand All @@ -677,28 +677,28 @@ def test_skips_for_gemma3(self, architecture):
model = self._make_meta_model()
model.config = SimpleNamespace(architectures=[architecture])
model._is_hf_initialized = True
model.initialize_weights = MagicMock()
model.init_weights = MagicMock()

Checkpointer.initialize_model_weights(model, torch.device("cpu"))

model.initialize_weights.assert_not_called()
model.init_weights.assert_not_called()
assert model._is_hf_initialized is True

def test_handles_missing_config_gracefully(self):
"""Model without config.architectures should not raise."""
with torch.device("meta"):
model = torch.nn.Linear(4, 4)
model.config = SimpleNamespace()
model.initialize_weights = MagicMock()
model.init_weights = MagicMock()

Checkpointer.initialize_model_weights(model, torch.device("cpu"))

model.initialize_weights.assert_called_once()
model.init_weights.assert_called_once()

def test_peft_init_method_calls_init_peft_adapters(self):
"""When peft_init_method is provided, _init_peft_adapters should be called."""
model = self._make_meta_model()
model.initialize_weights = MagicMock()
model.init_weights = MagicMock()

with patch("nemo_automodel.components.checkpoint.checkpointing._init_peft_adapters") as mock_init_peft:
Checkpointer.initialize_model_weights(model, torch.device("cpu"), peft_init_method="xavier")
Expand All @@ -708,7 +708,7 @@ def test_peft_init_method_calls_init_peft_adapters(self):
def test_peft_init_method_none_skips_init_peft_adapters(self):
"""When peft_init_method is None (default), _init_peft_adapters should NOT be called."""
model = self._make_meta_model()
model.initialize_weights = MagicMock()
model.init_weights = MagicMock()

with patch("nemo_automodel.components.checkpoint.checkpointing._init_peft_adapters") as mock_init_peft:
Checkpointer.initialize_model_weights(model, torch.device("cpu"))
Expand Down Expand Up @@ -1010,7 +1010,7 @@ def test_init_step_with_keymap_uses_backport(self):

class TestSkipInitWeightsOnLoadGate:
"""The Checkpointer.initialize_model_weights gate that lets a model opt
out of HF's initialize_weights() via a class attribute.
out of HF's init_weights() via a class attribute.

Without this gate, Mistral3FP8VLMForConditionalGeneration's PP load
deadlocks on stage-divergent DTensor collectives inside HF's init.
Expand All @@ -1027,33 +1027,33 @@ def test_skip_when_attr_true(self):
"""A model with _skip_init_weights_on_load=True takes the skip branch."""
model = self._make_meta_model()
model._skip_init_weights_on_load = True
model.initialize_weights = MagicMock()
model.init_weights = MagicMock()

Checkpointer.initialize_model_weights(model, torch.device("cpu"))

model.initialize_weights.assert_not_called()
model.init_weights.assert_not_called()
# And the _is_hf_initialized flag is left alone (not reset to False).
assert model._is_hf_initialized is True

def test_does_not_skip_when_attr_false(self):
"""attr=False (or attr-missing default) does NOT take the skip branch."""
model = self._make_meta_model()
model._skip_init_weights_on_load = False
model.initialize_weights = MagicMock()
model.init_weights = MagicMock()

Checkpointer.initialize_model_weights(model, torch.device("cpu"))

model.initialize_weights.assert_called_once()
model.init_weights.assert_called_once()

def test_does_not_skip_when_attr_missing(self):
"""No attr at all → default behavior (initialize_weights runs)."""
"""No attr at all → default behavior (init_weights runs)."""
model = self._make_meta_model()
assert not hasattr(model, "_skip_init_weights_on_load")
model.initialize_weights = MagicMock()
model.init_weights = MagicMock()

Checkpointer.initialize_model_weights(model, torch.device("cpu"))

model.initialize_weights.assert_called_once()
model.init_weights.assert_called_once()


class TestConsolidatedIndexUnderPPWithoutSourceIndex:
Expand Down
Loading