diff --git a/nemo_automodel/components/checkpoint/checkpointing.py b/nemo_automodel/components/checkpoint/checkpointing.py index c4796d5188..facddf582e 100644 --- a/nemo_automodel/components/checkpoint/checkpointing.py +++ b/nemo_automodel/components/checkpoint/checkpointing.py @@ -554,7 +554,9 @@ def load_model( state_dict[lm_head_param_name] = state_dict.pop(compat_tied_lm_head_source_key) state_dict = _maybe_adapt_state_dict_from_hf(model_state.model[0], state_dict, moe_mesh=self.moe_mesh) - key_diff = _summarize_state_dict_key_diff(expected_keys, set(state_dict.keys())) + expected_keys_for_diff = {k for k in expected_keys if not k.endswith("_extra_state")} + loaded_keys_for_diff = {k for k in state_dict if not k.endswith("_extra_state")} + key_diff = _summarize_state_dict_key_diff(expected_keys_for_diff, loaded_keys_for_diff) if key_diff["missing_count"] or key_diff["unexpected_count"]: logging.warning( "Checkpoint key mismatch for %s: missing=%d unexpected=%d "