From c7cfaee520ec923af79485ab84afb44d8abacd60 Mon Sep 17 00:00:00 2001
From: adil-a
Date: Fri, 15 May 2026 08:38:02 -0700
Subject: [PATCH] fix(checkpoint): exclude TE _extra_state keys from load-time mismatch warning

TransformerEngine modules attach `_extra_state` entries to their
state_dict for internal bookkeeping that is not present in HF
safetensors checkpoints. These were being reported as missing keys in
the load diagnostic, producing a noisy warning with up to dozens of
`_extra_state` examples on every load. The `set_extra_state` shim
already tolerates their absence, so filtering them out of the mismatch
summary keeps the warning focused on real weight mismatches.

Co-Authored-By: Claude Opus 4.7 (1M context)
Signed-off-by: adil-a
---
 nemo_automodel/components/checkpoint/checkpointing.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/nemo_automodel/components/checkpoint/checkpointing.py b/nemo_automodel/components/checkpoint/checkpointing.py
index c4796d5188..facddf582e 100644
--- a/nemo_automodel/components/checkpoint/checkpointing.py
+++ b/nemo_automodel/components/checkpoint/checkpointing.py
@@ -554,7 +554,9 @@ def load_model(
             state_dict[lm_head_param_name] = state_dict.pop(compat_tied_lm_head_source_key)
 
         state_dict = _maybe_adapt_state_dict_from_hf(model_state.model[0], state_dict, moe_mesh=self.moe_mesh)
-        key_diff = _summarize_state_dict_key_diff(expected_keys, set(state_dict.keys()))
+        expected_keys_for_diff = {k for k in expected_keys if not k.endswith("_extra_state")}
+        loaded_keys_for_diff = {k for k in state_dict if not k.endswith("_extra_state")}
+        key_diff = _summarize_state_dict_key_diff(expected_keys_for_diff, loaded_keys_for_diff)
         if key_diff["missing_count"] or key_diff["unexpected_count"]:
             logging.warning(
                 "Checkpoint key mismatch for %s: missing=%d unexpected=%d "