feat: Add gemma4 drafter model support#2240
Open
athitten wants to merge 13 commits into
Signed-off-by: Abhishree <abhishreetm@gmail.com>
Apply fixes for joint base + drafter training:

* Drop ``use_cache=False`` override in ``composite.forward``. Without the ``DynamicCache``, HF's sliding-window mask path silently degrades (SDPA mask-skip can collapse sliding layers into plain causal attention), inflating the initial training loss. The YAML's ``text_config.use_cache: true`` now takes effect.
* Change drafter label shift from ``k + 1`` to ``k``. The VLM collate pre-shifts labels by 1 so ``labels[t] == input_ids[t + 1]``; the prior ``k + 1`` shift was training the drafter to predict ``input_ids[t + 2]`` instead of ``input_ids[t + 1]``.
* Add hard asserts: ``cp_size == 1`` and ``torch_dtype == bfloat16`` in ``Gemma4WithDrafter.from_pretrained``.
* Add plan knobs: ``freeze_base_for_drafter``, ``share_embedding_with_base`` (one-shot init copy; FSDP2-safe), ``base_activation_checkpointing``.
* Recipe: factor joint loss into ``FinetuneRecipeForVLM._maybe_add_drafter_loss``, gate log on ``is_remote_logging_step`` (was per-microbatch), and make validation drafter-aware so ``val_loss`` reflects drafter drift.
* Remove dead ``from_pretrained`` override in drafter wrapper.
* Drop redundant ``text_config.output_hidden_states`` from YAML; expand the ``use_cache: true`` comment to explain the real reason (sliding-window mask, not KV sharing).
* Add ``test_post_collate_semantic_alignment`` that pins the label-shift convention so a future regression to ``k + 1`` fails loudly. Refine ``test_drafter_loss_reaches_drafter_params`` to reflect that ``post_projection`` only sees gradient in multi-step chains.

Signed-off-by: Abhishree <abhishreetm@gmail.com>
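The label-shift convention in the commit above can be illustrated with a small pure-Python sketch. It assumes the collate has already pre-shifted labels so that ``labels[t] == input_ids[t + 1]``. The helper ``shift_left`` and the ``IGNORE`` pad value are hypothetical stand-ins, not the recipe's actual ``_shift_labels_left``, and the sketch reads the ``input_ids[t + 1]`` versus ``input_ids[t + 2]`` statement as describing round ``k = 0``.

```python
IGNORE = -100  # hypothetical ignore index used to pad positions with no target


def shift_left(labels, k):
    """Shift a label sequence left by k positions, padding the tail with IGNORE."""
    if k == 0:
        return list(labels)
    return list(labels[k:]) + [IGNORE] * min(k, len(labels))


input_ids = [10, 11, 12, 13, 14]
# Collate pre-shift: labels[t] == input_ids[t + 1]; the last position has no target.
labels = input_ids[1:] + [IGNORE]

fixed = shift_left(labels, 0)     # shift of k for round k = 0
previous = shift_left(labels, 1)  # shift of k + 1 for round k = 0

# With the fix, position t targets input_ids[t + 1]; before it, input_ids[t + 2].
print(fixed[0] == input_ids[1], previous[0] == input_ids[2])
```

Because the collate already applies one shift, adding another ``+ 1`` on top of ``k`` moves every drafter target one token too far, which is the regression the new alignment test is described as pinning.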
- Composite: K-step recurrent forward where the drafter consumes its prior round's post-projected last_hidden_state and a teacher-forced shifted token id at every k>=1. shared_kv_states captured once from a single base forward and reused across rounds. post_projection conditionally unfrozen when drafter_num_steps > 1.
- Recipe load path: dispatch to model.load_pretrained when the composite exposes it so the base/ + drafter/ subdir layout produced by save_pretrained can be reloaded for resume.
- Dataset adapter: make_tulu3_magicoder_text_mix_dataset interleaves allenai/tulu-3-sft-mixture (80%) and ise-uiuc/Magicoder-OSS-Instruct-75K (20%) into a text-only VLM-shaped list consumed by default_collate_fn without producing pixel_values.
- YAMLs: rename joint_drafter.yaml -> _medpix.yaml; add _tulu_magicoder_mix.yaml at drafter_loss_weight 0.001 (1/10 of the MedPix setting to compensate for ~10x larger summed CE on longer text sequences).
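The drafter_loss_weight scaling argument in the last bullet can be checked with a short arithmetic sketch. It assumes the joint loss has the additive form base_loss + weight * sum(per-round drafter CE), which the commit text suggests but does not spell out. The function name `joint_loss` and the CE magnitudes are hypothetical, chosen only to show that a 10x larger summed CE with a 1/10 weight yields the same weighted drafter term.

```python
def joint_loss(base_loss, drafter_ce_per_round, drafter_loss_weight):
    """Assumed form: base loss plus a weighted sum of per-round drafter CE."""
    return base_loss + drafter_loss_weight * sum(drafter_ce_per_round)


# Hypothetical summed-CE magnitudes, not measured values.
medpix_summed_ce = [400.0]     # shorter sequences, smaller summed CE
text_mix_summed_ce = [4000.0]  # roughly 10x larger on longer text sequences

# 0.001 is stated to be 1/10 of the MedPix setting, so MedPix uses 0.01.
medpix_term = joint_loss(0.0, medpix_summed_ce, 0.01)
text_mix_term = joint_loss(0.0, text_mix_summed_ce, 0.001)

print(medpix_term, text_mix_term)
```

Under these assumptions the drafter contributes a comparable share of the total loss in both configs, which appears to be the intent of the smaller weight.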
…mark

- Move gemma4_4b_joint_drafter_medpix.yaml and gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml into a dedicated examples/vlm_finetune/gemma4_joint_drafter/ subdir so the joint-drafter variants are easy to find next to each other.
- Add benchmark_mtp_inference.py for measuring speculative-decoding throughput / acceptance with the trained base + drafter pair.

Signed-off-by: Abhishree <abhishreetm@gmail.com>
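As background for the acceptance metric mentioned above, here is a minimal sketch of how an acceptance rate is commonly computed under greedy (exact-match) verification. It is not taken from benchmark_mtp_inference.py: the function names and the token lists are hypothetical, and the script may define acceptance differently.

```python
def accepted_prefix_length(draft_tokens, target_tokens):
    """Count leading draft tokens that match what the target model would emit."""
    accepted = 0
    for draft, target in zip(draft_tokens, target_tokens):
        if draft != target:
            break  # the first mismatch rejects the rest of the draft
        accepted += 1
    return accepted


def acceptance_rate(rounds):
    """Fraction of proposed draft tokens accepted across all verification rounds."""
    proposed = sum(len(draft) for draft, _ in rounds)
    accepted = sum(accepted_prefix_length(draft, target) for draft, target in rounds)
    return accepted / proposed if proposed else 0.0


# Hypothetical (draft, target) token pairs for three verification rounds.
rounds = [
    ([5, 6, 7], [5, 6, 9]),  # first two draft tokens accepted
    ([1, 2, 3], [1, 2, 3]),  # all three accepted
    ([4, 4, 4], [8, 4, 4]),  # rejected at the first token
]
rate = acceptance_rate(rounds)
print(f"acceptance rate: {rate:.3f}")
```

A higher acceptance rate means more tokens are committed per base-model forward pass, which is where the throughput gain of speculative decoding comes from.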
Contributor
Author
/ok to test 753ee70
… Assistant entry
- Tag Gemma 4 E2B IT as kv-shared (matches E4B; both E-series models use the
kv-shared layer pattern).
- Add Gemma 4 E4B IT Assistant alongside the E4B IT entry in Available
Models so the assistant/drafter checkpoint is discoverable from the
same row.
- Tighten the Gemma4AssistantForCausalLM architecture description
("drafter / assistant", not "drafter / assistant head").
Signed-off-by: Abhishree <abhishreetm@gmail.com>
Contributor
Author
/ok to test fddba86
jgerh (Contributor) reviewed on May 26, 2026 and left a comment:
Completed tech pubs review of docs/model-coverage/vlm/google/gemma4.md and left a few copyedits.
…AML + cord_v2 fixture
Three small fixes to make the L2 smoke runnable against the current PR
state. Validated locally with $TEST_DATA_DIR pointed at the tiny
hf_gemma4_e4b_{2l,assistant_2l} pair and $HF_CACHE at the staged
mini_cord_v2 fixture: 3 train steps + 3 val passes complete cleanly,
final ckpt saves both base/ and drafter/ subdirs.
- Config path: gemma4_4b_joint_drafter.yaml ->
gemma4_joint_drafter/gemma4_4b_joint_drafter_medpix.yaml (the file was
renamed and moved into a dedicated subdir during the joint-drafter
example reorg).
- Dataset: swap mini_medpix for mini_cord_v2 (the standard VLM mini
dataset already staged in CI under $HF_CACHE) and override both
``dataset._target_`` and ``validation_dataset._target_`` to
``make_cord_v2_dataset``. Also pin ``--dataset.split=train`` /
``--validation_dataset.split=validation`` because the YAML defaults
use HF slice expressions (``train[:1000]``) that only resolve against
the hub-hosted MedPix dataset, not the local parquet fixture.
- LR scheduler: add ``--lr_scheduler.lr_warmup_steps 0``. The YAML
default ``lr_warmup_steps: 25`` is larger than ``max_steps: 3``, which
trips ``OptimizerParamScheduler``'s
``assert lr_warmup_steps < lr_decay_steps`` at construction.
Signed-off-by: Abhishree <abhishreetm@gmail.com>
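The warmup override in the last bullet can be reproduced with a tiny sketch of the quoted constraint. This is not the ``OptimizerParamScheduler`` source: the helper names are hypothetical, the check raises ``ValueError`` rather than using ``assert``, and it assumes the decay horizon in a short smoke run is bounded by ``max_steps``.

```python
def check_scheduler_steps(lr_warmup_steps, lr_decay_steps):
    """Stand-in for the quoted constraint: warmup must be shorter than decay."""
    if not lr_warmup_steps < lr_decay_steps:
        raise ValueError(
            f"lr_warmup_steps={lr_warmup_steps} must be < lr_decay_steps={lr_decay_steps}"
        )


def smoke_config_ok(lr_warmup_steps, max_steps):
    """Return True when the scheduler could be constructed for a short run."""
    try:
        # Assumption: decay steps are bounded by max_steps in the smoke test.
        check_scheduler_steps(lr_warmup_steps, lr_decay_steps=max_steps)
        return True
    except ValueError:
        return False


print(smoke_config_ok(lr_warmup_steps=25, max_steps=3))  # YAML default
print(smoke_config_ok(lr_warmup_steps=0, max_steps=3))   # with the CLI override
```

Overriding the warmup on the command line keeps the example YAML unchanged for real runs while letting the 3-step smoke test construct its scheduler.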
…xt-mix dataset, drafter helpers

Three test additions covering the new joint-drafter / text-mix paths introduced by this PR. All run on CPU and require no external HF downloads, so they execute in L0_Unit_Tests on every PR.

- tests/unit_tests/models/gemma4_drafter/test_composite_stubs.py (new, +718): stub-based tests for ``Gemma4WithDrafter`` that mimic the HF ``Gemma4ForConditionalGeneration`` and ``Gemma4AssistantForCausalLM`` surface via plain ``torch.nn.Module`` doubles. Covers ``__init__`` validation (K<1 rejection, share-embedding shape mismatch, base_activation_checkpointing wiring), side-effects (post_projection freezing for K=1 vs K>1, masked_embedding.centroids freezing, freeze_base_for_drafter, share_embedding_with_base copy + lm_head tie), and forward semantics for K=1 / K=2 (teacher-forced shifted ids, recurrent post_projection feedback, shared_kv_states reuse). Complements ``test_composite.py`` which requires the optional transformers TOT install.
- tests/unit_tests/datasets/vlm/test_datasets.py (+686): broader coverage of the new ``make_tulu3_magicoder_text_mix_dataset`` adapter (80/20 interleave probabilities, max_turns filter, missing assistant-turn drop, magicoder problem/solution shape, limit_total cap, no image fields in output).
- tests/unit_tests/recipes/test_vlm_drafter_helpers.py (+192): additional cases for the ``_maybe_add_drafter_loss`` path -- empty ``drafter_logits``, multi-K label-shift correctness, lambda gating, log-line gating off ``is_remote_logging_step``.

Signed-off-by: Abhishree <abhishreetm@gmail.com>
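The 80/20 interleave and ``limit_total`` cap exercised by the dataset tests can be pictured with a pure-Python sketch. It is not the adapter's implementation: ``interleave_by_probability``, the seed, the stop-on-first-exhausted behavior, and the synthetic samples are all assumptions made for illustration. Only the ``{"conversation": [...]}`` shape with no ``image`` field comes from the PR description.

```python
import random


def interleave_by_probability(sources, probabilities, limit_total, seed=0):
    """Draw samples from several sources, picking the source by probability."""
    rng = random.Random(seed)
    iterators = [iter(source) for source in sources]
    mixed = []
    while len(mixed) < limit_total:
        index = rng.choices(range(len(iterators)), weights=probabilities, k=1)[0]
        try:
            mixed.append(next(iterators[index]))
        except StopIteration:
            break  # assumption: stop as soon as any source runs out
    return mixed


def make_samples(tag, count):
    """Build synthetic text-only samples shaped like {"conversation": [...]}."""
    return [
        {"conversation": [{"role": "user", "content": f"{tag} {i}"}]}
        for i in range(count)
    ]


tulu_like = make_samples("tulu", 1000)
magicoder_like = make_samples("magicoder", 1000)
mixed = interleave_by_probability([tulu_like, magicoder_like], [0.8, 0.2], limit_total=500)

tulu_share = sum(
    sample["conversation"][0]["content"].startswith("tulu") for sample in mixed
) / len(mixed)
print(len(mixed), round(tulu_share, 2))
```

Because no sample carries an ``image`` key, a collate function that only adds ``pixel_values`` when images are present would emit text-only batches, which matches the behavior the tests are described as checking.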
Co-authored-by: jgerh <163925524+jgerh@users.noreply.github.com>
Contributor
Author
/ok to test ebf0f46
What does this PR do ?
This PR adds joint fine-tuning support for the Gemma 4 base (also called the target model) and the drafter/assistant models that enable multi-token prediction (MTP). The drafter is co-trained with the Gemma 4 base end-to-end via a composite model (Gemma4WithDrafter) that wires up shared K/V states, sqrt(H_b)-scaled embeddings, and a K-step recurrent forward matching the Gemma 4 drafter tech report. The PR provides two reference configs for joint fine-tuning of gemma-4-E4B-it and gemma-4-E4B-it-assistant: one with the MedPix VQA dataset and one with a text-only Tulu-3 + Magicoder mix. It also provides an inference benchmark script that validates speculative-decode throughput on the saved checkpoint.

The feature has been verified only against the Gemma 4 4B (E4B) base + drafter pair. The composite is architecturally model-agnostic within the Gemma 4 family, but the example YAMLs, the parity tests, and all training/inference verification in this PR target the 4B pair.
Implementation

Gemma4WithDrafter composite (nemo_automodel/components/models/gemma4_drafter/)
- Wraps the Gemma4ForConditionalGeneration base and the Gemma4AssistantForCausalLM drafter as a single nn.Module for FSDP2 training.
- Round 0 feeds cat(base.embed(input_ids), base.h_final); round k≥1 feeds cat(base.embed(input_ids_shifted_by_k), prev_drafter.last_hidden_state). shared_kv_states is captured once from a single base forward and reused at every round (per the Gemma 4 drafter tech report).
- save_pretrained writes <ckpt>/base/ and <ckpt>/drafter/ as HF-loadable subdirs for vLLM / HF inference handoff.
- post_projection is unfrozen only when drafter_num_steps > 1.
- masked_embedding.centroids is always frozen (the torch.topk inside it blocks gradient flow back to the centroids).

Recipe (recipes/vlm/finetune.py, recipes/base_recipe.py)
- _maybe_add_drafter_loss(out, base_loss, labels, …) sums per-step CE over the composite's drafter_logits list with _shift_labels_left(labels, k) per round.
- base_recipe.load_checkpoint dispatches to model.load_pretrained when the composite exposes it, so the base/ + drafter/ subdir layout reloads correctly on resume.

Text-only dataset adapter (components/datasets/vlm/datasets.py)
- make_tulu3_magicoder_text_mix_dataset interleaves allenai/tulu-3-sft-mixture (80%) and ise-uiuc/Magicoder-OSS-Instruct-75K (20%) into {"conversation": [...]} dicts with no image field, so default_collate_fn emits batches without pixel_values. The composite + base accept text-only inputs unchanged.

Example configs and benchmark (examples/vlm_finetune/gemma4_joint_drafter/)
- gemma4_4b_joint_drafter_medpix.yaml: joint fine-tuning recipe for google/gemma-4-E4B-it + google/gemma-4-E4B-it-assistant with MedPix VQA.
- gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml: joint fine-tuning on the Tulu-3 and Magicoder mix, same base + drafter pair.
- benchmark_mtp_inference.py: measures speculative-decode acceptance + throughput end-to-end against the trained 4B pair.

Testing and Verification
- examples/vlm_finetune/gemma4_joint_drafter/gemma4_4b_joint_drafter_medpix.yaml. Loss curve: [figure]
- examples/vlm_finetune/gemma4_joint_drafter/gemma4_4b_joint_drafter_tulu_magicoder_mix.yaml. Loss curve: [figure]
- benchmark_mtp_inference.py (uses transformers generate) against the saved base/ + drafter/ pair. Results below: [results]

Verified save + standalone HF reload works correctly. Checkpoints are stored in separate <ckpt>/base/ and <ckpt>/drafter/ subdirs as a prerequisite for generation. After a 5-step training run, each sub-checkpoint loads via plain HF from_pretrained without any NeMo-specific code:
- Gemma4ForConditionalGeneration.from_pretrained("<ckpt>/base/model/consolidated") → 7.94 B params ✓
- Gemma4AssistantForCausalLM.from_pretrained("<ckpt>/drafter/model/consolidated") → 78.5 M params ✓
- Keys match the reference checkpoints (google/gemma-4-E4B-it and google/gemma-4-E4B-it-assistant) exactly (0 missing, 0 extra).
- The masked_embedding.token_ordering (int64[262144]) buffer survives the DCP → safetensors path.

Verified resume-from-checkpoint loss parity. Run A: 10 steps fresh with ckpt_every_steps=5. Run B: same config but --checkpoint.restore_from <RunA>/epoch_0_step_4. Per-step loss is bit-identical between A and B at every overlapping step (5–9): 3.1257, 3.2992, 3.4045, 2.5375, 3.2824. This confirms that model weights, optimizer state, LR scheduler state, dataloader state, and RNG all restore correctly.

Unit and functional tests.
- tests/unit_tests/models/gemma4_drafter/test_composite.py: composite pre-projection input layout, shared-KV plumbing, K-step recurrence, save/load round-trip.
- tests/unit_tests/models/gemma4_drafter/test_composite_fsdp2.py: FSDP2 wrap + expert/grad sync.
- tests/unit_tests/models/gemma4_drafter/test_drafter_wrapper.py: drafter sub-module integration.
- tests/unit_tests/recipes/test_vlm_drafter_helpers.py: _shift_labels_left + _maybe_add_drafter_loss.
- tests/functional_tests/L2_HF_Transformer_VLM_Gemma4_Joint_Drafter.sh: L2 single-node smoke.

Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Additional Information