
feat: Megatron LoRA GRPO w/ Weight Merging #1889

Open
vadam5 wants to merge 21 commits into main from vadams/sync-lora-grpo-megatron

Conversation


@vadam5 vadam5 commented Feb 5, 2026

What does this PR do ?

Supports sync, async, and non-colocated LoRA GRPO via the Megatron path, with weight merging for rollouts: this PR merges the LoRA adapter weights into the base model weights before exporting them to vLLM for rollouts.
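For context, the merge itself just folds each adapter pair back into its base weight so the rollout engine only ever sees dense weights. A minimal sketch of the arithmetic (an illustration, not the actual export path; the helper name and the alpha / r scaling convention are assumptions):

import torch

def merge_lora_weight(base_w, lora_a, lora_b, alpha, r):
    """Fold a LoRA adapter into its base weight: W' = W + (alpha / r) * B @ A."""
    return base_w + (alpha / r) * (lora_b @ lora_a)

# Toy shapes: base weight [out, in], A [r, in], B [out, r].
w = torch.randn(16, 32)
a = torch.randn(4, 32)
b = torch.zeros(16, 4)  # a freshly initialized B is zero, so merging is a no-op
assert torch.allclose(merge_lora_weight(w, a, b, alpha=32, r=4), w)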

Issues

closes: #1372

Usage

uv run examples/run_grpo.py \
    --config examples/configs/recipes/llm/grpo-qwen3-8b-base-1n8g-megatron-lora.yaml \
    grpo.max_num_steps=20 \
    grpo.num_prompts_per_step=8 \
    policy.train_global_batch_size=128 \
    policy.generation.colocated.enabled=false \
    policy.generation.colocated.resources.gpus_per_node=4 \
    policy.generation.colocated.resources.num_nodes=1 \
    policy.generation.vllm_cfg.tensor_parallel_size=4 \
    cluster.gpus_per_node=8 \
    policy.megatron_cfg.tensor_model_parallel_size=4 \
    policy.generation.vllm_cfg.async_engine=true \
    grpo.async_grpo.enabled=true \
    loss_fn.use_importance_sampling_correction=true \
    logger.log_dir=results/grpo-async-qwen3-8b-base-1n8g-megatron-lora/logs \
    logger.wandb_enabled=True \
    logger.wandb.project=lora-rl \
    logger.wandb.name=grpo-async-qwen3-8b-base-1n8g-megatron-lora \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=False \
    checkpointing.enabled=True \
    checkpointing.checkpoint_dir=results/grpo-async-qwen3-8b-base-1n8g-megatron-lora

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

https://wandb.ai/nvidia/nemo-rl?nw=s1m0n39d4le

[Screenshots of results from the W&B runs linked above, including one captured Feb 17, 2026, omitted here.]

Summary by CodeRabbit

  • New Features

    • Added GRPO training support with LoRA fine-tuning for Megatron models, including single-node and distributed configurations.
    • Introduced example configurations for Qwen3 8B and other models using the new PEFT/LoRA framework.
  • Tests

    • Added functional test coverage for GRPO with LoRA in synchronous, asynchronous, and non-colocated deployment scenarios.
  • Chores

    • Updated dependencies including accelerate, transformers, transformer-engine, and new packages (fastapi, flash-linear-attention).
    • Updated Megatron-LM and Megatron-Bridge submodule references.

@vadam5 vadam5 changed the title feat: Megatron Sync LoRA GRPO w/ Weight Merging feat: Megatron LoRA GRPO w/ Weight Merging Feb 10, 2026
@vadam5 vadam5 marked this pull request as ready for review February 10, 2026 02:51
@vadam5 vadam5 requested review from a team as code owners February 10, 2026 02:51

coderabbitai bot commented Feb 10, 2026

📝 Walkthrough

This PR updates the Megatron-LM and Megatron-Bridge submodule references to new repositories and commits, updates dependencies in setup.py files to support newer versions, introduces PEFT/LoRA configuration blocks for GRPO, adds a state dict remapping utility method to MegatronPolicyWorker, and introduces new functional test scripts for Megatron-LoRA GRPO training experiments.

Changes

Cohort / File(s) / Summary

  • Submodule References (.gitmodules, 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge, 3rdparty/Megatron-LM-workspace/Megatron-LM): Updated the Megatron-LM repository URL from terrykong to yaoyu-33 and the branch from yuya/nemo-rl-use-dev to main; updated the Megatron-Bridge commit pointer to a3fc5d5; updated the Megatron-LM commit pointer to 11dcbaca.
  • Dependency Updates (3rdparty/Megatron-Bridge-workspace/setup.py, 3rdparty/Megatron-LM-workspace/setup.py): Added an "accelerate" dependency; pinned transformers to "==4.57.1"; updated transformer-engine bounds from ">=2.9.0a0,<2.10.0" to ">=2.10.0a0,<2.12.0"; added "datasets"; relaxed numpy and nvidia-modelopt version constraints; added "fastapi~=0.50" and "flash-linear-attention~=0.3.2"; updated flashinfer-python to "~=0.5.0".
  • PEFT/LoRA Configuration (examples/configs/grpo_math_1B_megatron.yaml, examples/configs/grpo_math_1B_megatron_lora.yaml, examples/configs/recipes/llm/grpo-qwen3-8b-base-1n8g-megatron-lora.yaml): Introduced a PEFT configuration block under megatron_cfg with an enabled flag, target/exclude modules, LoRA dimensions, alpha, dropout, initialization methods, and experimental flags; created a comprehensive GRPO Megatron-LoRA config for the 1B math model and a Qwen3 8B recipe with detailed hyperparameters for loss, checkpointing, optimization, and distributed training.
  • Policy Worker Implementation (nemo_rl/models/policy/workers/megatron_policy_worker.py): Added a _remap_reference_state_dict() utility method for mapping LoRA-unwrapped state dict names (a hedged sketch of such a remapping follows this table); propagated ProcessGroupCollection usage in the reference model loading paths.
  • Functional Test Scripts (tests/functional/grpo_megatron_lora.sh, tests/functional/grpo_megatron_lora_async.sh, tests/functional/grpo_megatron_lora_non_colocated.sh): Added three new GPU functional test scripts orchestrating Megatron-based GRPO experiments with LoRA, covering synchronous, asynchronous, and non-colocated distributed training configurations; each script sets up directories, executes training via uv run, extracts metrics, and validates reward thresholds.
  • Test Infrastructure Updates (tests/L1_Functional_Tests_GPU.sh, tests/test_suites/llm/grpo-qwen3-8b-base-1n8g-megatron-lora.sh, tests/test_suites/nightly.txt): Added three new functional test invocations to the L1 GPU tests; created a new Megatron LoRA GRPO test suite script; registered two nightly test entries under the GRPO and SFT sections.
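To make the remapping row above concrete, here is a hedged sketch of what mapping "LoRA-unwrapped" names involves: with PEFT enabled, the wrapped model nests the original linears under adapter modules and adds lora_A/lora_B parameters, so plain reference keys need to be renamed to their wrapped counterparts, while adapter-only keys have no counterpart at all. The base_layer infix and the function shape below are assumptions, not the actual _remap_reference_state_dict() implementation.

def remap_reference_state_dict(ref_sd, wrapped_keys, infix="base_layer."):
    """Rename plain reference keys so they line up with a LoRA-wrapped model.

    Hypothetical example: 'layers.0.linear_qkv.weight' becomes
    'layers.0.linear_qkv.base_layer.weight' when the wrapped model nests the
    original linear under an adapter wrapper. lora_A / lora_B keys simply have
    no entry in the reference dict and are skipped by the caller.
    """
    remapped = {}
    for name, tensor in ref_sd.items():
        module, _, param = name.rpartition(".")
        candidate = f"{module}.{infix}{param}" if module else f"{infix}{param}"
        remapped[candidate if candidate in wrapped_keys else name] = tensor
    return remapped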

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested labels

CI:L1, CI, documentation

Suggested reviewers

  • yaoyu-33
  • terrykong
  • yuki-97
🚥 Pre-merge checks | ✅ 1 | ❌ 1

❌ Failed checks (1 warning)

  • Test Results For Major Changes: ⚠️ Warning. The PR introduces major GRPO + LoRA + Megatron features with new test scripts but lacks documented test results, convergence metrics, or performance benchmarks in the PR description. Resolution: include actual test results showing convergence metrics, training reward values, and weight-merging validation across the sync, async, and non-colocated scenarios.

✅ Passed checks (1 passed)

  • Description Check: ✅ Passed. Check skipped because CodeRabbit’s high-level summary is enabled.


coderabbitai bot left a comment

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
nemo_rl/models/policy/workers/megatron_policy_worker.py (1)

1565-1588: ⚠️ Potential issue | 🟠 Major

Manual state_dict swap will KeyError on LoRA-only keys.

When use_peft=True, self.model.state_dict() contains LoRA adapter keys (e.g., lora_A, lora_B) that don't exist in self.reference_state_dict. Line 1569 (self.reference_state_dict[k]) and line 1588 (model_state_dict[k]) will raise KeyError for these keys.

The commented-out load_state_dict with strict=True (lines 1565, 1585) had the same problem, which is presumably why it was replaced. But the manual approach needs to handle missing keys too.

Proposed fix
                 # Swap reference model state_dict to self.model
                 for k, v in self.model.state_dict().items():
                     if isinstance(v, torch.Tensor):
-                        v.copy_(self.reference_state_dict[k])
+                        if k in self.reference_state_dict:
+                            v.copy_(self.reference_state_dict[k])
-                for k, v in self.model.state_dict().items():
-                    if isinstance(v, torch.Tensor):
-                        v.copy_(model_state_dict[k])
+                for k, v in self.model.state_dict().items():
+                    if isinstance(v, torch.Tensor) and k in model_state_dict:
+                        v.copy_(model_state_dict[k])
🤖 Fix all issues with AI agents
In @.gitmodules:
- Around line 3-4: Update the submodule declaration that currently sets "url =
https://github.com/yaoyu-33/Megatron-LM.git" (with "branch = main") so it points
to the official upstream "https://github.com/NVIDIA/Megatron-LM.git"; if the
fork is intentionally required instead, replace the URL only after adding a
short justification in the repo docs (e.g., SECURITY.md or README) explaining
why the fork is needed and noting any maintained diffs, and include a maintainer
sign-off in the justification so reviewers can accept the deviation.

In `@3rdparty/Megatron-LM-workspace/Megatron-LM`:
- Line 1: The submodule 3rdparty/Megatron-LM-workspace/Megatron-LM points at a
non-existent commit (11dcbaca317133cc5c77c8bc4f54ed71d3b5d656); update the
submodule to a valid commit/branch on the upstream Megatron-LM remote by
entering the submodule (cd 3rdparty/Megatron-LM-workspace/Megatron-LM), running
git fetch origin, checking out a known-good commit or branch (e.g., origin/main
or a specific existing SHA), then git add the submodule change in the
superproject, commit the update, and push the branch so the PR references a
valid submodule commit.

In `@examples/configs/grpo_math_1B_megatron_lora.yaml`:
- Line 114: Replace the YAML value that sets lora_dtype so it yields a true null
rather than the string "None": change the mapping key/value where lora_dtype is
defined (currently `lora_dtype: None`) to use YAML null (`lora_dtype: null` or
`lora_dtype: ~`) so that downstream code constructing LoRA (e.g.,
LoRA(lora_dtype=...)) receives a null/None value instead of the string "None".

In `@examples/configs/grpo_math_1B_megatron.yaml`:
- Around line 100-111: The base Megatron config currently enables LoRA by
default and sets lora_dtype to the literal string "None"; change peft.enabled to
false so downstream non‑LoRA runs (e.g., grpo_megatron.sh) don't inadvertently
enable LoRA, and have LoRA-specific configs or the grpo_megatron_lora.sh
override set peft.enabled=true when needed; also replace lora_dtype: None with a
YAML null (e.g., lora_dtype: null or lora_dtype: ~) so it parses as null rather
than the string "None".

In `@nemo_rl/models/policy/workers/megatron_policy_worker.py`:
- Around line 925-933: The current check uses "if ref_megatron_cfg is not None"
which is always true because ref_megatron_cfg is always created; change the
guard to verify PEFT is enabled (e.g. if self.use_peft and ref_megatron_cfg is
not None) before creating and registering the PEFT pre-wrap hook via
_create_peft_pre_wrap_hook(ref_megatron_cfg, ref_state), calling
ref_megatron_cfg.model.register_pre_wrap_hook(pre_peft_hook), composing
composed_peft_hook, and extending ref_pre_wrap_hooks so LoRA wrapping only
applies when self.use_peft is true.
- Around line 946-960: When self.use_peft is true the current
should_load_checkpoint only checks ref_megatron_cfg.checkpoint.load and ignores
ref_megatron_cfg.checkpoint.pretrained_checkpoint; update the PEFT branch in
megatron_policy_worker.py so should_load_checkpoint mirrors the non-PEFT logic
by checking both ref_megatron_cfg.checkpoint.load and
ref_megatron_cfg.checkpoint.pretrained_checkpoint with checkpoint_exists, and
preserve the existing ref_megatron_cfg.checkpoint.finetune toggling behavior
(still set finetune=False when loading a checkpoint) so the reference model
loads pretrained weights in PEFT scenarios.
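Two hedged sketches of what the worker-side suggestions above amount to (attribute and helper names are taken from the comments, not verified against the file):

# 1. Only build/register the reference-model LoRA hook when PEFT is on.
def maybe_register_ref_peft_hook(use_peft, ref_megatron_cfg, ref_state,
                                 create_hook, ref_pre_wrap_hooks):
    if use_peft and ref_megatron_cfg is not None:
        pre_peft_hook = create_hook(ref_megatron_cfg, ref_state)
        ref_megatron_cfg.model.register_pre_wrap_hook(pre_peft_hook)
        ref_pre_wrap_hooks.append(pre_peft_hook)

# 2. Mirror the non-PEFT load condition: either a resumable checkpoint or a
#    pretrained checkpoint should trigger loading.
def should_load_ref_checkpoint(ckpt_cfg, checkpoint_exists):
    has_resume = ckpt_cfg.load is not None and checkpoint_exists(ckpt_cfg.load)
    return has_resume or ckpt_cfg.pretrained_checkpoint is not None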
🧹 Nitpick comments (2)
nemo_rl/models/policy/workers/megatron_policy_worker.py (2)

1571-1573: Commented-out code without explanation.

Per coding guidelines, commented-out code should include a comment describing why it is retained, or be removed before merging. Lines 1571-1573 and 1565 have commented-out load_state_dict calls with no rationale.


904-943: Duplicate LoRA construction — extract a helper.

The LoRA instantiation block (lines 906-919) is nearly identical to lines 308-320 in setup_megatron_model. Consider extracting a shared helper to avoid copy-paste divergence.
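A possible shape for that shared helper, sketched under the assumption that the peft block is passed around as a plain dict; only the two keyword arguments already visible in the diff are shown, and the helper name is hypothetical (LoRA itself is the Megatron-Bridge transform the worker already imports):

def build_lora_from_peft_cfg(peft_cfg):
    """Single construction site for the LoRA transform, shared by
    setup_megatron_model and the reference-model path (sketch only)."""
    return LoRA(
        target_modules=peft_cfg["target_modules"],
        exclude_modules=peft_cfg["exclude_modules"],
        # remaining LoRA options (rank/dim, alpha, dropout, init method, ...)
        # would be forwarded here; exact keyword names depend on Megatron-Bridge.
    )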

@vadam5 vadam5 force-pushed the vadams/sync-lora-grpo-megatron branch from 7c9e021 to 4c0e216 on February 11, 2026 00:50
@github-actions

❌ Submodule Fast-Forward Check Failed

Check based on commit: 4c0e216 (PR #1889 from vadams/sync-lora-grpo-megatron)

✅ Submodules that are properly updated:

Megatron-Bridge: ✅ PR branch is ahead of main branch (fast-forward)

❌ Submodules that need attention:

Megatron-LM: ❌ Commits have DIVERGED from a common ancestor
TARGET (main branch): https://github.com/yaoyu-33/Megatron-LM/commits/b73ae5cdab9d409fcface2b2f3c375710abe6911/
CURRENT (PR #1889 from vadams/sync-lora-grpo-megatron): https://github.com/yaoyu-33/Megatron-LM/commits/11dcbaca317133cc5c77c8bc4f54ed71d3b5d656/

Please ensure all submodule commits are fast-forwards of the main branch before merging.

@github-actions

❌ Submodule Fast-Forward Check Failed

Check based on commit: 2b335f7 (PR #1889 from vadams/sync-lora-grpo-megatron)

✅ Submodules that are properly updated:

Megatron-Bridge: ✅ PR branch is ahead of main branch (fast-forward)

❌ Submodules that need attention:

Megatron-LM: ❌ Commits have DIVERGED from a common ancestor
TARGET (main branch): https://github.com/yaoyu-33/Megatron-LM/commits/b73ae5cdab9d409fcface2b2f3c375710abe6911/
CURRENT (PR #1889 from vadams/sync-lora-grpo-megatron): https://github.com/yaoyu-33/Megatron-LM/commits/11dcbaca317133cc5c77c8bc4f54ed71d3b5d656/

Please ensure all submodule commits are fast-forwards of the main branch before merging.

@vadam5 vadam5 force-pushed the vadams/sync-lora-grpo-megatron branch from 2a926f7 to 8168855 on February 17, 2026 20:13
@terrykong terrykong left a comment

great work @vadam5 !

could you also include lp error and gen kl error in the plots?

@cuichenx @yaoyu-33 to review the megatron worker part

if policy_cfg["megatron_cfg"]["freeze_moe_router"]:
    if use_peft:
        raise ValueError(
            "Freezing the MOE router is not currently supported when using PEFT"
Contributor

OOC, what change allowed this?

Author

I don't think we ever really needed to raise an error here. In the SFT LoRA PR, I raised an error because I thought adding the two different pre-wrap hooks (one for PEFT and one for freezing the MOE router) would conflict with each other in mbridge.

In this PR I did several runs with "freeze_moe_router" set to true and PEFT enabled, without any errors or crashes. I believe @yfw did similar runs with SFT LoRA in the past and saw no adverse effects either, so I removed that check here.

@terrykong terrykong left a comment

another round of review

else:
    peft = None

ref_megatron_cfg.peft = peft
@terrykong terrykong Feb 19, 2026

What is the reason to do PEFT on the reference model? Usually the reference model is only used for a forward pass to get logprobs. If we add a PEFT adapter, won't the LoRA weights be initialized, changing the reference model even though we expect it to be frozen?

Assuming it is for compatibility with the policy so you can swap the state dicts, do the reference LoRA weights need to be initialized to all zeros to avoid changing the forward pass?

@vadam5 vadam5 Feb 19, 2026

Yes, I added it to the reference model for compatibility with the policy so the state dicts can be swapped easily. Since the LoRA B weights are always initialized to zero, the LoRA weights shouldn't affect the forward pass. I can explicitly force the LoRA A weights to zero too to make this clearer, but LoRA B being zero should already cover it.

Author

Updated the code a bit to force both of these to zero regardless of what's set in the config.
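For anyone reading along, the reason a zero B (or zeroing both A and B) leaves the reference forward pass untouched is that the adapter contributes B @ A @ x, which is identically zero whenever B is zero. A toy check:

import torch

torch.manual_seed(0)
w = torch.randn(8, 8)        # frozen base weight
x = torch.randn(8)
a = torch.randn(4, 8)        # lora_A (any init)
b = torch.zeros(8, 4)        # lora_B initialized to zero

base_out = w @ x
lora_out = base_out + b @ (a @ x)   # adapter term is exactly zero
assert torch.equal(base_out, lora_out)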

  overlap_param_gather_with_optimizer_step=megatron_cfg.optimizer.overlap_param_gather_with_optimizer_step,
- pre_wrap_hook=megatron_cfg.rng.data_parallel_random_init,
+ data_parallel_random_init=megatron_cfg.rng.data_parallel_random_init,
+ pre_wrap_hook=ref_pre_wrap_hooks,
Contributor

Was this a bug? I see the setup for the ref model and the setup for the policy handling data_parallel_random_init differently in the get_model call.

@cuichenx thoughts on this?

Author

I thought it was a bug. I made the reference model setup match the policy model setup here.


    ref_megatron_cfg.checkpoint.finetune = False
else:
    should_load_checkpoint = (
        ref_checkpoint_config.pretrained_checkpoint is not None
Contributor

Why is this condition slightly different from the one above?

should_load_checkpoint = ref_checkpoint_config.load is not None and checkpoint_exists(
            ref_checkpoint_config.load
        )

Author

tbh, this logic was copied over from within mbridge and I wasn't sure why the logic difference was there. Let me check on this.

Author

I was trying to mirror the logic from Megatron-Bridge here: https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/b8475660c0ad2dce2d43e0a2f96bd0d719ad85bf/src/megatron/bridge/training/setup.py#L246

Not sure why it's set like this in mbridge.

Contributor

@yaoyu-33 @ananthsub can you comment?

)

if use_peft:
    should_load_checkpoint = ref_checkpoint_config.load is not None and checkpoint_exists(
Contributor

On line 842, it looks like

ref_checkpoint_config = CheckpointConfig(
        pretrained_checkpoint=pretrained_path,
        save=None,
        load=None,
        fully_parallel_load=True,
        load_rng=False,
    )

So isn't ref_checkpoint_config.load == None? Does that mean that if use_peft == True, we'll never load the checkpoint?

@vadam5 vadam5 Feb 20, 2026

Hmm, this is true. I just reverted the logic back to loading a checkpoint if pretrained_checkpoint exists, regardless of whether PEFT is used or not, but kept this check:

if should_load_checkpoint and use_peft:
        # The finetune toggle is explicitly set to True in order to avoid loading optimizer and RNG states
        # This is switched off here in order to load these states from the checkpoint
        ref_megatron_cfg.checkpoint.finetune = False

I was trying to mirror the logic used for setting up the policy model and the mbridge setup, but I think something different is required here.
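To spell out the reverted behavior described above (my paraphrase of the thread, not the literal diff, and complementing the earlier sketch in the CodeRabbit section): loading happens whenever a pretrained checkpoint path is set, independent of PEFT, and the finetune flag is only flipped off when a checkpoint is actually being loaded under PEFT.

def resolve_ref_checkpoint_loading(ckpt_cfg, use_peft, checkpoint_exists):
    should_load = ckpt_cfg.pretrained_checkpoint is not None or (
        ckpt_cfg.load is not None and checkpoint_exists(ckpt_cfg.load)
    )
    if should_load and use_peft:
        # finetune=True skips optimizer/RNG state; turn it off so those states load.
        ckpt_cfg.finetune = False
    return should_load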

@terrykong

also, DCO and lint need to be resolved before final merge

@vadam5 vadam5 force-pushed the vadams/sync-lora-grpo-megatron branch from 3bb86a6 to c799cdf on February 20, 2026 00:34
@vadam5 vadam5 requested a review from a team as a code owner February 20, 2026 00:52
ashors1 and others added 7 commits February 19, 2026 16:54
vadam5 and others added 13 commits February 19, 2026 16:54
@vadam5 vadam5 force-pushed the vadams/sync-lora-grpo-megatron branch from 9c7e4e2 to d5fa658 on February 20, 2026 00:55

vadam5 commented Feb 20, 2026

Fixed DCO and ran the linter. Some files I didn't touch for this PR also had small lint fixes.

@vadam5 vadam5 requested review from cuichenx and terrykong February 20, 2026 01:07
@terrykong terrykong left a comment

@yaoyu-33 @ananthsub can you take a pass? some of the megatron changes could use your expertise

  peft = LoRA(
-     target_modules=peft_cfg["target_modules"],
-     exclude_modules=peft_cfg["exclude_modules"],
+     target_modules=peft_cfg.get("target_modules", []),
Contributor

Using a default-value fallback here doesn't seem to match the heuristics that nemo-rl follows? @terrykong to check

Contributor

@vadam5 can you avoid the fallback and just rely on the one in the config (before the change)?

@terrykong

Hey @vadam5. @yaoyu-33 raised some good points offline. Give me some time to review some of the API changes to confirm these changes; I'll circle back on this today or tomorrow.
