Skip to content

feat: Add gemma4 drafter model support#2240

Open
athitten wants to merge 13 commits into
mainfrom
athitten/gemma4_drafter_support
Open

feat: Add gemma4 drafter model support#2240
athitten wants to merge 13 commits into
mainfrom
athitten/gemma4_drafter_support

Conversation

@athitten
Copy link
Copy Markdown
Contributor

@athitten athitten commented May 15, 2026

What does this PR do ?

This PR adds joint fine-tuning support for Gemma 4 base (also called target model) and drafter/assistant models that enable multi-token prediction (MTP). The drafter is co-trained with the Gemma 4 base end-to-end via a composite model (Gemma4WithDrafter) that wires up shared K/V states, sqrt(H_b)-scaled embeddings, and a K-step recurrent forward matching the Gemma 4 drafter tech report. The PR provides two reference configs for joint fine-tuning of gemma-4-E4B-it and gemma-4-E4B-it-assistant, one with MedPix VQA dataset and the other with a text-only Tulu-3 + Magicoder mix. Also provides an inference benchmark script that validates speculative-decode throughput on the saved checkpoint.

The feature has been verified only against the Gemma 4 4B (E4B) base + drafter pair. The composite is architecturally model-agnostic within the Gemma 4 family, but the example YAMLs, the parity tests, and all training/inference verification in this PR target the 4B pair.

Implementation

  • Gemma4WithDrafter composite (nemo_automodel/components/models/gemma4_drafter/)

    • Wraps a Gemma4ForConditionalGeneration base + Gemma4AssistantForCausalLM drafter as a single nn.Module for FSDP2 training.
    • K-step recurrent forward: round 0 feeds cat(base.embed(input_ids), base.h_final); round k≥1 feeds cat(base.embed(input_ids_shifted_by_k), prev_drafter.last_hidden_state). shared_kv_states is captured once from a single base forward and reused at every round (per the Gemma 4 drafter tech report).
    • Saves to <ckpt>/base/ and <ckpt>/drafter/ HF-loadable subdirs for vLLM / HF inference handoff.
    • Conditionally unfreezes post_projection only when drafter_num_steps > 1
    • masked_embedding.centroids always frozen (the torch.topk inside it blocks gradient flow back to the centroids).
  • Recipe (recipes/vlm/finetune.py, recipes/base_recipe.py)

    • New _maybe_add_drafter_loss(out, base_loss, labels, …) sums per-step CE over the composite's drafter_logits list with _shift_labels_left(labels, k) per round.
    • base_recipe.load_checkpoint dispatches to model.load_pretrained when the composite exposes it, so the base/ + drafter/ subdir layout reloads correctly on resume.
  • Text-only dataset adapter (components/datasets/vlm/datasets.py)

    • make_tulu3_magicoder_text_mix_dataset interleaves allenai/tulu-3-sft-mixture (80 %) and ise-uiuc/Magicoder-OSS-Instruct-75K (20 %) into {"conversation": [...]} dicts with no image field, so default_collate_fn emits batches without pixel_values. The composite + base accept text-only inputs unchanged.
  • Example configs and benchmark (examples/vlm_finetune/gemma4_joint_drafter/)

    • gemma4_4b_joint_drafter_medpix.yaml — Joint fine-tuning recipe of google/gemma-4-E4B-it + google/gemma-4-E4B-it-assistant with MedPix VQA
    • gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml — Joint fine-tuning on tulu and magiccoder mix, same base + drafter pair
    • benchmark_mtp_inference.py — measures speculative-decode acceptance + throughput end-to-end against the trained 4B pair

Testing and Verification

All numbers below are for joint fine-tuning of gemma-4-E4B-it with /gemma-4-E4B-it-assistant

  1. Fine-tuning the joint model on MedPix-VQA: no NaNs, no loss spikes, no grad-norm spikes. Using the recipe examples/vlm_finetune/gemma4_joint_drafter/gemma4_4b_joint_drafter_medpix.yaml
    Loss curve:
Screenshot 2026-05-24 at 4 25 48 PM
  1. Large-scale fine-tuning on a Tulu-3 (80 %) + Magicoder (20 %) mix for 500 steps. Stable fine-tuning run observed with the recipe examples/vlm_finetune/gemma4_joint_drafter/gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml
    Loss curve:
Screenshot 2026-05-24 at 4 28 18 PM
  1. Inference run on the tulu + magicoder fine-tuned checkpoint after 500 steps. Ran benchmark_mtp_inference.py(uses transformers generate) against the saved base/ + drafter/ pair. Results below:
image
  1. Verified save + standalone HF reload works correctly. Checkpoints are stored in separate <ckpt>/base/ and <ckpt>/drafter/ subdirs as a prerequisite for generation. After a 5-step training run, each sub-checkpoint loads via plain HF from_pretrained without any NeMo-specific code:

    • Gemma4ForConditionalGeneration.from_pretrained("<ckpt>/base/model/consolidated")7.94 B params
    • Gemma4AssistantForCausalLM.from_pretrained("<ckpt>/drafter/model/consolidated")78.5 M params
    • State-dict key sets match the upstream released references (google/gemma-4-E4B-it and google/gemma-4-E4B-it-assistant) exactly (0 missing, 0 extra).
    • masked_embedding.token_ordering (int64[262144]) buffer survives the DCP → safetensors path.
  2. Verified resume-from-checkpoint loss parity. Run A: 10 steps fresh with ckpt_every_steps=5. Run B: same config but --checkpoint.restore_from <RunA>/epoch_0_step_4. Per-step loss is bit-identical between A and B at every overlapping step (5–9): 3.1257, 3.2992, 3.4045, 2.5375, 3.2824. Confirms model weights + optimizer state + LR scheduler state + dataloader state + RNG all restore correctly.

  3. Unit and functional tests.

  • tests/unit_tests/models/gemma4_drafter/test_composite.py — composite pre-projection input layout, shared-KV plumbing, K-step recurrence, save/load round-trip.
  • tests/unit_tests/models/gemma4_drafter/test_composite_fsdp2.py — FSDP2 wrap + expert/grad sync.
  • tests/unit_tests/models/gemma4_drafter/test_drafter_wrapper.py — drafter sub-module integration.
  • tests/unit_tests/recipes/test_vlm_drafter_helpers.py_shift_labels_left + _maybe_add_drafter_loss.
  • tests/functional_tests/L2_HF_Transformer_VLM_Gemma4_Joint_Drafter.sh — L2 single-node smoke.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

If you haven't finished some of the above items you can still open "Draft" PR.

Additional Information

  • Related to # (issue)

Signed-off-by: Abhishree <abhishreetm@gmail.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@athitten athitten changed the title Add gemma4 drafter model support [WIP] Add gemma4 drafter model support May 15, 2026
athitten and others added 3 commits May 15, 2026 14:49
Apply fixes for joint base + drafter training:

* Drop ``use_cache=False`` override in ``composite.forward``. Without the
  ``DynamicCache``, HF's sliding-window mask path silently degrades
  (SDPA mask-skip can collapse sliding layers into plain causal attention),
  inflating the initial training loss. The YAML's
  ``text_config.use_cache: true`` now takes effect.
* Change drafter label shift from ``k + 1`` to ``k``. The VLM collate
  pre-shifts labels by 1 so ``labels[t] == input_ids[t + 1]``; the prior
  ``k + 1`` shift was training the drafter to predict ``input_ids[t + 2]``
  instead of ``input_ids[t + 1]``.
* Add hard asserts: ``cp_size == 1`` and ``torch_dtype == bfloat16`` in
  ``Gemma4WithDrafter.from_pretrained``.
* Add plan knobs: ``freeze_base_for_drafter``, ``share_embedding_with_base``
  (one-shot init copy; FSDP2-safe), ``base_activation_checkpointing``.
* Recipe: factor joint loss into ``FinetuneRecipeForVLM._maybe_add_drafter_loss``,
  gate log on ``is_remote_logging_step`` (was per-microbatch), and make
  validation drafter-aware so ``val_loss`` reflects drafter drift.
* Remove dead ``from_pretrained`` override in drafter wrapper.
* Drop redundant ``text_config.output_hidden_states`` from YAML; expand the
  ``use_cache: true`` comment to explain the real reason (sliding-window mask,
  not KV sharing).
* Add ``test_post_collate_semantic_alignment`` that pins the label-shift
  convention so a future regression to ``k + 1`` fails loudly. Refine
  ``test_drafter_loss_reaches_drafter_params`` to reflect that
  ``post_projection`` only sees gradient in multi-step chains.

Signed-off-by: Abhishree <abhishreetm@gmail.com>
- Composite: K-step recurrent forward where the drafter consumes its prior
  round's post-projected last_hidden_state and a teacher-forced shifted
  token id at every k>=1. shared_kv_states captured once from a single base
  forward and reused across rounds. post_projection conditionally unfrozen
  when drafter_num_steps > 1.
- Recipe load path: dispatch to model.load_pretrained when the composite
  exposes it so the base/ + drafter/ subdir layout produced by
  save_pretrained can be reloaded for resume.
- Dataset adapter: make_tulu3_magicoder_text_mix_dataset interleaves
  allenai/tulu-3-sft-mixture (80%) and ise-uiuc/Magicoder-OSS-Instruct-75K
  (20%) into a text-only VLM-shaped list consumed by default_collate_fn
  without producing pixel_values.
- YAMLs: rename joint_drafter.yaml -> _medpix.yaml; add
  _tulu_magicoder_mix.yaml at drafter_loss_weight 0.001 (1/10 of the MedPix
  setting to compensate for ~10x larger summed CE on longer text sequences).
…mark

- Move gemma4_4b_joint_drafter_medpix.yaml and
  gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml into a dedicated
  examples/vlm_finetune/gemma4_joint_drafter/ subdir so the joint-drafter
  variants are easy to find next to each other.
- Add benchmark_mtp_inference.py for measuring speculative-decoding
  throughput / acceptance with the trained base + drafter pair.

Signed-off-by: Abhishree <abhishreetm@gmail.com>
@athitten athitten changed the title [WIP] Add gemma4 drafter model support feat: Add gemma4 drafter model support May 24, 2026
@athitten
Copy link
Copy Markdown
Contributor Author

/ok to test 753ee70

… Assistant entry

- Tag Gemma 4 E2B IT as kv-shared (matches E4B; both Es use the kv-shared
  layer pattern).
- Add Gemma 4 E4B IT Assistant alongside the E4B IT entry in Available
  Models so the assistant/drafter checkpoint is discoverable from the
  same row.
- Tighten the Gemma4AssistantForCausalLM architecture description
  ("drafter / assistant", not "drafter / assistant head").

Signed-off-by: Abhishree <abhishreetm@gmail.com>
@athitten
Copy link
Copy Markdown
Contributor Author

/ok to test fddba86

Copy link
Copy Markdown
Contributor

@jgerh jgerh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completed tech pubs review of docs/model-coverage/vlm/google/gemma4.md and left a few copyedits.

Comment thread docs/model-coverage/vlm/google/gemma4.md Outdated
Comment thread docs/model-coverage/vlm/google/gemma4.md Outdated
Comment thread docs/model-coverage/vlm/google/gemma4.md Outdated
athitten and others added 3 commits May 26, 2026 16:17
…AML + cord_v2 fixture

Three small fixes to make the L2 smoke runnable against the current PR
state. Validated locally with $TEST_DATA_DIR pointed at the tiny
hf_gemma4_e4b_{2l,assistant_2l} pair and $HF_CACHE at the staged
mini_cord_v2 fixture: 3 train steps + 3 val passes complete cleanly,
final ckpt saves both base/ and drafter/ subdirs.

- Config path: gemma4_4b_joint_drafter.yaml ->
  gemma4_joint_drafter/gemma4_4b_joint_drafter_medpix.yaml (the file was
  renamed and moved into a dedicated subdir during the joint-drafter
  example reorg).
- Dataset: swap mini_medpix for mini_cord_v2 (the standard VLM mini
  dataset already staged in CI under $HF_CACHE) and override both
  ``dataset._target_`` and ``validation_dataset._target_`` to
  ``make_cord_v2_dataset``. Also pin ``--dataset.split=train`` /
  ``--validation_dataset.split=validation`` because the YAML defaults
  use HF slice expressions (``train[:1000]``) that only resolve against
  the hub-hosted MedPix dataset, not the local parquet fixture.
- LR scheduler: add ``--lr_scheduler.lr_warmup_steps 0``. The YAML
  default ``lr_warmup_steps: 25`` is larger than ``max_steps: 3``, which
  trips ``OptimizerParamScheduler``'s
  ``assert lr_warmup_steps < lr_decay_steps`` at construction.

Signed-off-by: Abhishree <abhishreetm@gmail.com>
…xt-mix dataset, drafter helpers

Three test additions covering the new joint-drafter / text-mix paths
introduced by this PR. All run on CPU and require no external HF
downloads, so they execute in L0_Unit_Tests on every PR.

- tests/unit_tests/models/gemma4_drafter/test_composite_stubs.py (new,
  +718): stub-based tests for ``Gemma4WithDrafter`` that mimic the HF
  ``Gemma4ForConditionalGeneration`` and ``Gemma4AssistantForCausalLM``
  surface via plain ``torch.nn.Module`` doubles. Covers ``__init__``
  validation (K<1 rejection, share-embedding shape mismatch,
  base_activation_checkpointing wiring), side-effects (post_projection
  freezing for K=1 vs K>1, masked_embedding.centroids freezing,
  freeze_base_for_drafter, share_embedding_with_base copy + lm_head
  tie), and forward semantics for K=1 / K=2 (teacher-forced shifted
  ids, recurrent post_projection feedback, shared_kv_states reuse).
  Complements ``test_composite.py`` which requires the optional
  transformers TOT install.
- tests/unit_tests/datasets/vlm/test_datasets.py (+686): broader
  coverage of the new ``make_tulu3_magicoder_text_mix_dataset`` adapter
  (80/20 interleave probabilities, max_turns filter, missing
  assistant-turn drop, magicoder problem/solution shape, limit_total
  cap, no image fields in output).
- tests/unit_tests/recipes/test_vlm_drafter_helpers.py (+192):
  additional cases for the ``_maybe_add_drafter_loss`` path -- empty
  ``drafter_logits``, multi-K label-shift correctness, lambda gating,
  log-line gating off ``is_remote_logging_step``.

Signed-off-by: Abhishree <abhishreetm@gmail.com>
Co-authored-by: jgerh <163925524+jgerh@users.noreply.github.com>
@athitten
Copy link
Copy Markdown
Contributor Author

/ok to test ebf0f46

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

MTP support for gemma4

2 participants