Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion nemo_automodel/components/checkpoint/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,9 @@ def load_model(
state_dict[lm_head_param_name] = state_dict.pop(compat_tied_lm_head_source_key)

state_dict = _maybe_adapt_state_dict_from_hf(model_state.model[0], state_dict, moe_mesh=self.moe_mesh)
key_diff = _summarize_state_dict_key_diff(expected_keys, set(state_dict.keys()))
expected_keys_for_diff = {k for k in expected_keys if not k.endswith("_extra_state")}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @adil-a , i think for bf16 this should be fine, but thinking ahead, I'm wondering if that would break any fp8 workflows? Please let me know what you think.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about we put a check to ensure the model is in bf16 as well?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO, i would remove _extra_state if it's empty on the model side

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

response from CC:

Good question — I dug into this and it turns out FP8 workflows aren't affected, because the codebase already treats _extra_state as ephemeral on every save/load
path. Quick summary of why:

TE-based FP8. TE's FP8 amax history lives in _extra_state, but it's never round-tripped through our checkpointer:

  • _maybe_adapt_state_dict_to_hf in checkpointing.py calls adapter.to_hf(..., exclude_key_regex=r"._extra_state.", ...) unconditionally on save (line 1590).
  • stateful_wrappers.py monkey-patches TransformerEngineBaseModule.set_extra_state and BasicOperation.set_extra_state to no-op on DCP's _EXTRA_STATE sentinel (lines
    31–49), gated only on HAS_TE — not on FP8 being enabled.
  • te_attention.py stashes the TE attention module via object.setattr specifically so attn_module._extra_state never enters the state_dict (lines 548–553, with
    a comment to that effect).
  • On load, when a TE module has a custom get_extra_state, an empty torch.tensor([], dtype=uint8) placeholder is injected so DCP doesn't complain (lines 1437–1443).

Net result: TE FP8 amax history is rebuilt from observed activations after load. Filtering _extra_state from the mismatch warning only hides keys that the rest of
the framework is already silently dropping.

torchao Float8Linear. Doesn't use _extra_state at all — Float8Linear.{get,set}_extra_state is nn.Module.{get,set}_extra_state, its state_dict() is ['weight',
'bias'], and the weight stays as a bf16/fp32 nn.Parameter (FP8 conversion happens dynamically in forward).

So unconditionally filtering _extra_state from the warning is consistent with how all four other code paths handle it, and won't mask any real mismatch for FP8
workflows.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO, i would remove _extra_state if it's empty on the model side

It comes with the TE nn.Modules as a parameter. We'd have to do unnecessary plumbing for this. TBH checkpointing is already in a finicky spot I'd rather we avoid changing things where possible in the current state.

loaded_keys_for_diff = {k for k in state_dict if not k.endswith("_extra_state")}
key_diff = _summarize_state_dict_key_diff(expected_keys_for_diff, loaded_keys_for_diff)
if key_diff["missing_count"] or key_diff["unexpected_count"]:
logging.warning(
"Checkpoint key mismatch for %s: missing=%d unexpected=%d "
Expand Down
Loading