Conversation
📝 Walkthrough

This PR updates Megatron-LM and Megatron-Bridge submodule references to new repositories and commits, updates dependencies in setup.py files to support newer versions, introduces PEFT/LoRA configuration blocks for GRPO, adds a state dict remapping utility method to MegatronPolicyWorker, and introduces new functional test scripts for Megatron-LoRA GRPO training experiments.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/models/policy/workers/megatron_policy_worker.py (1)
1565-1588: ⚠️ Potential issue | 🟠 Major: Manual state_dict swap will KeyError on LoRA-only keys.
When `use_peft=True`, `self.model.state_dict()` contains LoRA adapter keys (e.g., `lora_A`, `lora_B`) that don't exist in `self.reference_state_dict`. Line 1569 (`self.reference_state_dict[k]`) and line 1588 (`model_state_dict[k]`) will raise `KeyError` for these keys.

The commented-out `load_state_dict` with `strict=True` (lines 1565, 1585) had the same problem, which is presumably why it was replaced. But the manual approach needs to handle missing keys too.

Proposed fix
```diff
 # Swap reference model state_dict to self.model
 for k, v in self.model.state_dict().items():
     if isinstance(v, torch.Tensor):
-        v.copy_(self.reference_state_dict[k])
+        if k in self.reference_state_dict:
+            v.copy_(self.reference_state_dict[k])
```

```diff
-for k, v in self.model.state_dict().items():
-    if isinstance(v, torch.Tensor):
-        v.copy_(model_state_dict[k])
+for k, v in self.model.state_dict().items():
+    if isinstance(v, torch.Tensor) and k in model_state_dict:
+        v.copy_(model_state_dict[k])
```
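A minimal sketch of the guarded swap, using plain dicts of lists as stand-ins for the real tensor state dicts so it runs anywhere (the key names here are illustrative):

```python
# Stand-ins: the policy state dict has LoRA-only keys the reference lacks.
model_sd = {"w": [1.0, 1.0], "lora_A": [1.0, 1.0]}
reference_sd = {"w": [0.0, 0.0]}

for k, v in model_sd.items():
    if k in reference_sd:       # skip LoRA-only keys -> no KeyError
        v[:] = reference_sd[k]  # in-place update, like tensor.copy_()

assert model_sd["w"] == [0.0, 0.0]          # swapped in from the reference
assert model_sd["lora_A"] == [1.0, 1.0]     # adapter key left untouched
```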
🤖 Fix all issues with AI agents
In @.gitmodules:
- Around line 3-4: Update the submodule declaration that currently sets "url =
https://github.com/yaoyu-33/Megatron-LM.git" (with "branch = main") so it points
to the official upstream "https://github.com/NVIDIA/Megatron-LM.git"; if the
fork is intentionally required instead, replace the URL only after adding a
short justification in the repo docs (e.g., SECURITY.md or README) explaining
why the fork is needed and noting any maintained diffs, and include a maintainer
sign-off in the justification so reviewers can accept the deviation.
In `@3rdparty/Megatron-LM-workspace/Megatron-LM`:
- Line 1: The submodule 3rdparty/Megatron-LM-workspace/Megatron-LM points at a
non-existent commit (11dcbaca317133cc5c77c8bc4f54ed71d3b5d656); update the
submodule to a valid commit/branch on the upstream Megatron-LM remote by
entering the submodule (cd 3rdparty/Megatron-LM-workspace/Megatron-LM), running
git fetch origin, checking out a known-good commit or branch (e.g., origin/main
or a specific existing SHA), then git add the submodule change in the
superproject, commit the update, and push the branch so the PR references a
valid submodule commit.
In `@examples/configs/grpo_math_1B_megatron_lora.yaml`:
- Line 114: Replace the YAML value that sets lora_dtype so it yields a true null
rather than the string "None": change the mapping key/value where lora_dtype is
defined (currently `lora_dtype: None`) to use YAML null (`lora_dtype: null` or
`lora_dtype: ~`) so that downstream code constructing LoRA (e.g.,
LoRA(lora_dtype=...)) receives a null/None value instead of the string "None".
In `@examples/configs/grpo_math_1B_megatron.yaml`:
- Around line 100-111: The base Megatron config currently enables LoRA by
default and sets lora_dtype to the literal string "None"; change peft.enabled to
false so downstream non‑LoRA runs (e.g., grpo_megatron.sh) don't inadvertently
enable LoRA, and have LoRA-specific configs or the grpo_megatron_lora.sh
override set peft.enabled=true when needed; also replace lora_dtype: None with a
YAML null (e.g., lora_dtype: null or lora_dtype: ~) so it parses as null rather
than the string "None".
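For illustration, a corrected fragment per the two comments above might read as follows (only the `peft.enabled` and `lora_dtype` keys come from the comments; the surrounding nesting is an assumption about the config layout):

```yaml
policy:
  megatron_cfg:
    peft:
      enabled: false      # keep LoRA off in the base config; LoRA configs override
      lora_dtype: null    # YAML null parses to Python None, not the string "None"
```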
In `@nemo_rl/models/policy/workers/megatron_policy_worker.py`:
- Around line 925-933: The current check uses "if ref_megatron_cfg is not None"
which is always true because ref_megatron_cfg is always created; change the
guard to verify PEFT is enabled (e.g. if self.use_peft and ref_megatron_cfg is
not None) before creating and registering the PEFT pre-wrap hook via
_create_peft_pre_wrap_hook(ref_megatron_cfg, ref_state), calling
ref_megatron_cfg.model.register_pre_wrap_hook(pre_peft_hook), composing
composed_peft_hook, and extending ref_pre_wrap_hooks so LoRA wrapping only
applies when self.use_peft is true.
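A small sketch of the corrected guard described above (function and variable names are hypothetical stand-ins for the worker's logic):

```python
def should_register_peft_hook(use_peft, ref_megatron_cfg):
    # Corrected guard: require PEFT to actually be enabled, not just
    # that the config object exists (it is always created, so it is
    # never None and the old check was always true).
    return use_peft and ref_megatron_cfg is not None

cfg = object()  # ref_megatron_cfg is always constructed
assert should_register_peft_hook(True, cfg) is True
assert not should_register_peft_hook(False, cfg)  # old guard would have fired here
```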
- Around line 946-960: When self.use_peft is true the current
should_load_checkpoint only checks ref_megatron_cfg.checkpoint.load and ignores
ref_megatron_cfg.checkpoint.pretrained_checkpoint; update the PEFT branch in
megatron_policy_worker.py so should_load_checkpoint mirrors the non-PEFT logic
by checking both ref_megatron_cfg.checkpoint.load and
ref_megatron_cfg.checkpoint.pretrained_checkpoint with checkpoint_exists, and
preserve the existing ref_megatron_cfg.checkpoint.finetune toggling behavior
(still set finetune=False when loading a checkpoint) so the reference model
loads pretrained weights in PEFT scenarios.
🧹 Nitpick comments (2)
nemo_rl/models/policy/workers/megatron_policy_worker.py (2)
1571-1573: Commented-out code without explanation.

Per coding guidelines, commented-out code should include a comment describing why it is retained, or be removed before merging. Lines 1571-1573 and 1565 have commented-out `load_state_dict` calls with no rationale.
904-943: Duplicate LoRA construction — extract a helper.

The LoRA instantiation block (lines 906-919) is nearly identical to lines 308-320 in `setup_megatron_model`. Consider extracting a shared helper to avoid copy-paste divergence.
❌ Submodule Fast-Forward Check Failed. Check based on commit: 4c0e216 (PR #1889)

✅ Submodules that are properly updated:
- Megatron-Bridge: ✅ PR branch is ahead of main branch (fast-forward)

❌ Submodules that need attention:
- Megatron-LM: ❌ Commits have DIVERGED from a common ancestor

Please ensure all submodule commits are fast-forwards of the main branch before merging.
examples/configs/recipes/llm/grpo-nanov3-30BA3B-1n8g-megatron-lora.yaml
```python
if policy_cfg["megatron_cfg"]["freeze_moe_router"]:
    if use_peft:
        raise ValueError(
            "Freezing the MOE router is not currently supported when using PEFT"
```
OOC, what change allowed this?
I don't think we ever really needed to raise an error here. In the SFT LoRA PR, I raised an error because I thought adding the two different pre-wrap hooks (one for PEFT and one for freezing the MOE router) would have conflicts with each other in mbridge.
In this PR I did several runs where "freeze_moe_router" was set to true and PEFT was enabled without any errors or crashing. I believe in the past @yfw did similar runs with SFT LoRA and saw no adverse effects as well. So I removed that check here.
terrykong left a comment
another round of review
```python
else:
    peft = None

ref_megatron_cfg.peft = peft
```
what is the reason to do peft on the reference model? usually the reference model is just for FW pass for logprobs. if we add a peft adapter, won't the lora weights be initialized and the reference model will be changed despite us expecting it to be frozen?
assuming it is for compatibility with the policy so you can swap the state dicts, does the ref lora weights need to be initialized to all 0 to avoid changing the fwd pass?
yes, I added it to the reference model for compatibility with the policy so the state dicts can be swapped easily. Since the lora B weights are always init to zero, the lora weights shouldn't affect the fwd pass. I can explicitly force the lora A weights to zero too to make this more clear, but lora B being zero should cover it already.
updated the code a bit to force these both to zero regardless of what's set in the config.
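The "B = 0 makes the adapter a no-op" point from this exchange can be checked numerically. A minimal sketch (shapes and rank are illustrative, not taken from the PR):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16))       # a batch of activations
W = rng.normal(size=(16, 16))      # frozen base weight

r = 4                              # LoRA rank (illustrative)
lora_A = rng.normal(size=(16, r))  # A may be randomly initialized
lora_B = np.zeros((r, 16))         # B starts at zero

base_out = x @ W
lora_out = x @ (W + lora_A @ lora_B)  # effective weight with the adapter

# Because B == 0, the low-rank update A @ B is exactly zero, so the
# forward pass is unchanged regardless of how A is initialized.
assert np.allclose(base_out, lora_out)
```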
```diff
 overlap_param_gather_with_optimizer_step=megatron_cfg.optimizer.overlap_param_gather_with_optimizer_step,
-pre_wrap_hook=megatron_cfg.rng.data_parallel_random_init,
+data_parallel_random_init=megatron_cfg.rng.data_parallel_random_init,
+pre_wrap_hook=ref_pre_wrap_hooks,
```
was this a bug? i see the setup for ref model and setup for policy having diff ways of dealing with data_parallel_random_init in the get_model call
@cuichenx thoughts on this?
I thought it was a bug. I made the reference model setup match the policy model setup here.
nemo_rl/models/megatron/setup.py
```python
    ref_megatron_cfg.checkpoint.finetune = False
else:
    should_load_checkpoint = (
        ref_checkpoint_config.pretrained_checkpoint is not None
```
why is this condition slightly diff than the one above?

```python
should_load_checkpoint = ref_checkpoint_config.load is not None and checkpoint_exists(
    ref_checkpoint_config.load
)
```
tbh, this logic was copied over from within mbridge and I wasn't sure why the logic difference was there. Let me check on this.
was trying to mirror the logic from megatron bridge here: https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/b8475660c0ad2dce2d43e0a2f96bd0d719ad85bf/src/megatron/bridge/training/setup.py#L246
Not sure why it's set like this in mbridge
nemo_rl/models/megatron/setup.py
```python
)

if use_peft:
    should_load_checkpoint = ref_checkpoint_config.load is not None and checkpoint_exists(
```
on line 842, looks like

```python
ref_checkpoint_config = CheckpointConfig(
    pretrained_checkpoint=pretrained_path,
    save=None,
    load=None,
    fully_parallel_load=True,
    load_rng=False,
)
```

so isn't ref_checkpoint_config.load == None? so does that mean if use_peft==true, we'll never load the ckpt?
hmm, this is true. I just reverted the logic back to loading a checkpoint if pretrained_checkpoint exists regardless of if PEFT is used or not, but kept this check.

```python
if should_load_checkpoint and use_peft:
    # The finetune toggle is explicitly set to True in order to avoid loading optimizer and RNG states
    # This is switched off here in order to load these states from the checkpoint
    ref_megatron_cfg.checkpoint.finetune = False
```

I was trying to mirror the logic used for setting up the policy model and setup in mbridge, but I think something different is required here.
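Consolidating this thread, the resulting decision could be sketched as follows (the function and attribute names mirror the snippets above but the helper itself is hypothetical, not the PR's code):

```python
from types import SimpleNamespace

def decide_checkpoint_load(ckpt_cfg, use_peft, checkpoint_exists):
    # Load when an explicit load path exists, or when a pretrained
    # checkpoint is configured -- regardless of whether PEFT is enabled.
    should_load = (
        ckpt_cfg.load is not None and checkpoint_exists(ckpt_cfg.load)
    ) or ckpt_cfg.pretrained_checkpoint is not None
    finetune = ckpt_cfg.finetune
    if should_load and use_peft:
        # Switch finetune off so optimizer/RNG states load from the checkpoint.
        finetune = False
    return should_load, finetune

cfg = SimpleNamespace(pretrained_checkpoint="/path/to/ckpt", load=None, finetune=True)
never_exists = lambda p: False
# With load=None, the pretrained checkpoint alone triggers loading, PEFT or not.
assert decide_checkpoint_load(cfg, use_peft=True, checkpoint_exists=never_exists) == (True, False)
assert decide_checkpoint_load(cfg, use_peft=False, checkpoint_exists=never_exists) == (True, True)
```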
also, DCO and lint need to be resolved before final merge
Signed-off-by: Anna Shors <ashors@nvidia.com>
Signed-off-by: Virginia Wu <vadams@nvidia.com>
Signed-off-by: Virginia Wu <78445382+vadam5@users.noreply.github.com>
Fixed DCO and ran the linter. Some files I didn't touch for this PR also had small lint fixes.
terrykong left a comment
@yaoyu-33 @ananthsub can you take a pass? some of the megatron changes could use your expertise
```diff
 peft = LoRA(
-    target_modules=peft_cfg["target_modules"],
-    exclude_modules=peft_cfg["exclude_modules"],
+    target_modules=peft_cfg.get("target_modules", []),
```
using a default-value fallback here doesn't seem to follow the heuristics that nemo-rl uses? @terrykong to check
@vadam5 can you avoid the fallback and just rely on the one in the config (before the change)?
What does this PR do ?
Supports sync, async, and non-colocated LoRA GRPO via the megatron path with weight merging for rollouts. This PR merges lora adapter weights into model weights before exporting to VLLM for rollouts.
Issues
closes: #1372
Usage
Before your PR is "Ready for review"
Pre checks:
Additional Information
https://wandb.ai/nvidia/nemo-rl?nw=s1m0n39d4le
Summary by CodeRabbit
New Features
Tests
Chores