feat: split validation statistics by task name #2019
Conversation
📝 Walkthrough

This PR extends validation support from single merged datasets to per-task validation dictionaries across training algorithms. It updates data setup, algorithm training loops, dataset initialization patterns, configuration examples, and checkpointing logic to handle per-dataset validation metrics and enable dataset-specific checkpoint triggers.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer
    participant DataLoader as Per-Task<br/>DataLoaders
    participant Validator
    participant Metrics as Per-Task<br/>Metrics
    participant Checkpoint as Checkpoint<br/>Manager
    rect rgba(0, 100, 200, 0.5)
        Note over Trainer,Checkpoint: Old Flow: Single Validation Dataset
        Trainer->>DataLoader: Load val_dataloader (merged)
        DataLoader->>Validator: Single batch stream
        Validator->>Metrics: Aggregate results
        Metrics->>Checkpoint: global_accuracy/loss
        Checkpoint->>Checkpoint: Save if metric improves
    end
    rect rgba(0, 150, 100, 0.5)
        Note over Trainer,Checkpoint: New Flow: Per-Task Validation Datasets
        Trainer->>DataLoader: Load val_dataloaders: dict[task_name]
        loop For each task in dict
            DataLoader->>Validator: Per-task batch stream
            Validator->>Metrics: Per-task results
        end
        Metrics->>Metrics: Accumulate per-task metrics
        Metrics->>Metrics: Compute total_accuracy/loss
        Checkpoint->>Checkpoint: Select metric by task_name<br/>(val:accuracy_Task1)
        Checkpoint->>Checkpoint: Save if metric improves
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
examples/run_sft.py (1)
104-121: ⚠️ Potential issue | 🟡 Minor

Silent overwrite when the same `task_name` appears in both validation sources.

If `data.task_name` at line 106 equals `val_data.task_name` at line 121 (a task that appears in both the train-split validation and the explicit `validation:` config), the train-split dataset is silently replaced. Add a warning:

⚠️ Proposed guard

```diff
+ if val_data.task_name in val_data_dict:
+     warnings.warn(
+         f"task_name '{val_data.task_name}' already exists in val_data_dict "
+         "(from train split). Overwriting with config-defined validation dataset."
+     )
  val_data_dict[val_data.task_name] = val_data.dataset
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `examples/run_sft.py` around lines 104-121: the current logic populates `val_data_dict` from two sources (the `data_list` loop using `data.task_name` and the `data_config["validation"]` loop using `val_data.task_name`) and silently overwrites entries when task names collide; update the code so that before assigning into `val_data_dict` in the second validation-source loop (after `val_data = load_response_dataset(cfg)`) you check whether `val_data.task_name` already exists in `val_data_dict` and, if so, emit a warning (e.g., using `logger.warning` or `print`) that the validation dataset for that `task_name` will be overwritten, including which source is being overridden; ensure the check references `val_data_dict`, `val_data.task_name`, and the loading path around `load_response_dataset` so the warning is clear and only emitted on duplicates.

nemo_rl/algorithms/dpo.py (1)
663-682: ⚠️ Potential issue | 🟡 Minor

Same missing-colon validation as `rm.py` — cryptic `ValueError` for misconfigured `metric_name`.

Line 665 (`prefix, metric_name = full_metric_name.split(":", 1)`) is the same unguarded split as in `rm.py`. Apply the same fix:

🛡️ Proposed fix

```diff
  if full_metric_name is not None:
+     if ":" not in full_metric_name:
+         raise ValueError(
+             f"checkpointing.metric_name must be in '<prefix>:<metric>' format, got '{full_metric_name}'"
+         )
      prefix, metric_name = full_metric_name.split(":", 1)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/algorithms/dpo.py` around lines 663-682: the code does an unguarded split of `full_metric_name` into `prefix` and `metric_name` (`prefix, metric_name = full_metric_name.split(":", 1)`), which raises a cryptic `ValueError` for misconfigured strings; first check that `full_metric_name` is a non-empty string containing a ":" (or use `str.partition` and verify the separator was present) and raise a clear `ValueError` like "full_metric_name must be '<prefix>:<metric>'" when it is missing; then proceed to set `prefix` and `metric_name` and the rest of the logic that chooses `metrics_source` and updates `dpo_save_state` (keep the existing warnings, the deletion from `dpo_save_state`, and the metric-not-found error path intact).

nemo_rl/algorithms/rm.py (1)
590-610: ⚠️ Potential issue | 🟡 Minor

Removed format assertion leaves a cryptic `ValueError` when `metric_name` lacks a colon.

With the assertion gone, a misconfigured `metric_name: accuracy` (no `prefix:` part) reaches line 592:

```python
prefix, metric_name = full_metric_name.split(":", 1)
# → ValueError: not enough values to unpack (expected 2, got 1)
```

Add an explicit guard before the split:

🛡️ Proposed fix

```diff
  if full_metric_name is not None:
+     if ":" not in full_metric_name:
+         raise ValueError(
+             f"checkpointing.metric_name must be in '<prefix>:<metric>' format, got '{full_metric_name}'"
+         )
      prefix, metric_name = full_metric_name.split(":", 1)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/algorithms/rm.py` around lines 590-610: the code currently does an unchecked split of `full_metric_name` into `prefix` and `metric_name`, causing a cryptic `ValueError` when the colon is missing; before calling `full_metric_name.split(":", 1)` in the checkpointing block that updates `rm_save_state`, add an explicit guard that checks that `full_metric_name` contains a ':' and, if not, raise a clear `ValueError` (or `warnings.warn` then skip) that explains the expected format, e.g. "checkpointing.metric_name must be 'prefix:metric' (e.g. 'train:accuracy')", so callers see a helpful error instead of the unpacking exception.

nemo_rl/algorithms/distillation.py (1)
945-1078: ⚠️ Potential issue | 🟡 Minor

Remove unused `rewards` variable assignment.

Line 1011 assigns `val_batch["total_reward"]` to `rewards`, but the variable is never used; the next line accesses `val_batch["total_reward"].tolist()` directly. Delete the unused assignment.

Diff

```diff
- rewards = val_batch["total_reward"]
  task_rewards.extend(val_batch["total_reward"].tolist())
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/algorithms/distillation.py` around lines 945-1078: in the `validate` function, remove the unused local assignment `rewards = val_batch["total_reward"]` (the variable `rewards` is never referenced afterward); simply delete that line so the code uses `val_batch["total_reward"].tolist()` directly and avoids an unused variable in the loop that processes `val_batch` within `validate`.

nemo_rl/data/utils.py (1)
116-140: ⚠️ Potential issue | 🟠 Major

Guard against duplicate `task_name` overwrites in validation data loading.

Direct dict assignment silently drops earlier validation datasets if the same `task_name` appears in multiple configs. The training path handles this correctly with `concatenate_datasets()` (line 105), but validation does not. When duplicate task names occur, later assignments overwrite earlier ones in `val_data_dict`, `val_task_data_processors`, and `val_task_to_env`, causing silent data loss.

Concatenate validation datasets with matching task names using the same pattern as training data, or assert that task names are unique across validation configs.

💡 Suggested fix (merge duplicates like training data)

```diff
- val_data_dict[data.task_name] = data.val_dataset
+ if data.task_name in val_data_dict:
+     val_data_dict[data.task_name] = concatenate_datasets(
+         [val_data_dict[data.task_name], data.val_dataset]
+     )
+ else:
+     val_data_dict[data.task_name] = data.val_dataset
  ...
- val_data_dict[val_data.task_name] = val_data.dataset
+ if val_data.task_name in val_data_dict:
+     val_data_dict[val_data.task_name] = concatenate_datasets(
+         [val_data_dict[val_data.task_name], val_data.dataset]
+     )
+ else:
+     val_data_dict[val_data.task_name] = val_data.dataset
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/data/utils.py` around lines 116-140: the validation loader currently overwrites earlier entries when multiple validation configs share the same `task_name`; update the block that assigns to `val_data_dict`, `val_task_data_processors`, and `val_task_to_env` to merge duplicates like the training path does: after loading `val_data = load_response_dataset(cfg)`, check whether `val_data.task_name` already exists in `val_data_dict` and, if so, call `concatenate_datasets` on the existing dataset and `val_data.dataset` (the same helper used on line ~105) and replace the entry, otherwise set it; ensure `val_task_data_processors[val_data.task_name]` and `val_task_to_env[val_data.task_name]` are only set once (or validated to be consistent) to avoid inconsistent processor/env mappings.
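The merge-on-duplicate pattern suggested above can be sketched as a small standalone helper. This is an illustrative sketch, not the repository's code: `register_val_dataset` is a hypothetical name, and plain lists stand in for Hugging Face datasets (the real fix would call `concatenate_datasets([...])` instead of `+`):

```python
def register_val_dataset(val_data_dict, task_name, dataset):
    """Merge validation datasets that share a task_name instead of overwriting.

    Lists stand in for datasets here; swap `+` for
    `concatenate_datasets([existing, new])` with real HF datasets.
    """
    if task_name in val_data_dict:
        # Mirror the training path: concatenate rather than silently replace.
        val_data_dict[task_name] = val_data_dict[task_name] + dataset
    else:
        val_data_dict[task_name] = dataset

val_data_dict = {}
register_val_dataset(val_data_dict, "math", ["ex1", "ex2"])
register_val_dataset(val_data_dict, "math", ["ex3"])   # duplicate task_name: merged
register_val_dataset(val_data_dict, "squad", ["ex4"])  # new task_name: inserted
```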
♻️ Duplicate comments (7)
nemo_rl/data/datasets/preference_datasets/tulu3.py (1)
27-28: Same `common_init` argument concern as refcoco.

Please ensure `skip_set_processor` is passed or has a default in `common_init` (see earlier note).

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/data/datasets/preference_datasets/tulu3.py` around lines 27-28: the call to `self.common_init(...)` in `Tulu3Preference` must either pass `skip_set_processor` or rely on `common_init` having a default for it; update the call site in tulu3.py to include `skip_set_processor=<appropriate boolean>` (e.g., True) or modify the `common_init` signature to provide a default value for `skip_set_processor` so the processor behavior matches other datasets like refcoco; reference the `common_init` function and the `Tulu3Preference` initializer to make the change consistently.

nemo_rl/data/datasets/response_datasets/dapo_math.py (2)
55-60: Same `common_init` argument concern as refcoco.

Please ensure `skip_set_processor` is passed or has a default in `common_init` (see earlier note).

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/data/datasets/response_datasets/dapo_math.py` around lines 55-60: the constructor calls `self.common_init(default_task_name="DAPOMathAIME2024", **kwargs)` but does not pass `skip_set_processor` (and `common_init` may not provide a default); ensure `skip_set_processor` is explicitly handled, either by passing it from kwargs into `common_init` (e.g., `skip_set_processor=kwargs.get("skip_set_processor", <desired default>)`) or by updating `common_init` to define a safe default; locate the call site in the `DAPOMathAIME2024` dataset class `__init__` and adjust the `common_init` invocation or add the default in the `common_init` implementation so `skip_set_processor` is always defined.
26-27: Same `common_init` argument concern as refcoco.

Please ensure `skip_set_processor` is passed or has a default in `common_init` (see earlier note).

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/data/datasets/response_datasets/dapo_math.py` around lines 26-27: `common_init` is being called without the `skip_set_processor` argument in the `DAPOMath17K` dataset initializer; update the call in the constructor to pass `skip_set_processor` (the same approach used in refcoco) or ensure `common_init` defines a default for it. Specifically, modify the call to `self.common_init(default_task_name="DAPOMath17K", skip_set_processor=..., **kwargs)` or add a `skip_set_processor` default in the `common_init` signature so the processor behavior is explicit (refer to the `common_init` function and the class initializer in dapo_math.py).

nemo_rl/data/datasets/response_datasets/nemogym_dataset.py (1)
29-35: Same `common_init` argument concern as refcoco.

Please ensure `skip_set_processor` is passed or has a default in `common_init` (see earlier note).

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/data/datasets/response_datasets/nemogym_dataset.py` around lines 29-35: the call to `self.common_init(...)` omits `skip_set_processor`, which can cause the same bug noted in refcoco; update the invocation in the constructor where `default_task_name` is computed so it explicitly passes `skip_set_processor` (e.g., `skip_set_processor=kwargs.get("skip_set_processor", False)`), or ensure `common_init` has a default for `skip_set_processor`, referencing the `common_init` method and the place where `default_task_name` is computed and passed to `self.common_init`.

nemo_rl/data/datasets/response_datasets/squad.py (1)
30-31: Same `common_init` argument concern as refcoco.

Please ensure `skip_set_processor` is passed or has a default in `common_init` (see earlier note).

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/data/datasets/response_datasets/squad.py` around lines 30-31: the call to `self.common_init` in the SQuAD dataset should explicitly handle the `skip_set_processor` flag; either pass `skip_set_processor` from the SQuAD constructor into `self.common_init` (e.g., `self.common_init(default_task_name="squad", skip_set_processor=skip_set_processor, **kwargs)`) or add a default for `skip_set_processor` inside the `common_init` signature so callers that omit it (like squad) behave correctly; update the SQuAD constructor to pass through the parameter or update `common_init` to set a sensible default to avoid the missing-argument issue.

nemo_rl/data/datasets/response_datasets/openmathinstruct2.py (1)
48-50: Same `common_init` argument concern as refcoco.

Please ensure `skip_set_processor` is passed or has a default in `common_init` (see earlier note).

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/data/datasets/response_datasets/openmathinstruct2.py` around lines 48-50: the call to `self.common_init` in `OpenMathInstruct2` doesn't pass `skip_set_processor` and relies on `common_init` to provide a safe default; either update the call in the `OpenMathInstruct2` constructor to explicitly pass `skip_set_processor` (True or False as appropriate for this dataset) or change the `common_init` signature to include a default value for `skip_set_processor` so callers like `OpenMathInstruct2` can omit it; locate the `common_init` definition and add the default (or update the `OpenMathInstruct2` call) to ensure `skip_set_processor` is always defined.

nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py (1)
57-63: Duplicate task-name derivation logic — same issue as `oai_format_dataset.py`.

The `default_task_name` derivation (lines 58–60) is identical to `oai_format_dataset.py` lines 138–140. Please extract it to the shared helper described in the `oai_format_dataset.py` comment — this is the second occurrence, which confirms the need for the refactor.

The same empty-string edge case and `skip_set_processor` concern flagged in `oai_format_dataset.py` apply here as well.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py` around lines 57-63: the `default_task_name` derivation code (building `default_task_name` from `data_path` and trimming a leading '-') is duplicated here and should be extracted into a shared helper (e.g., a new function `get_default_task_name(data_path)` used by both `BinaryPreferenceDataset` and `OaiFormatDataset`); update this file to call that helper instead of repeating the logic, ensure the helper returns a non-empty string (handle the empty-string edge case by falling back to a safe name or raising a clear error), and preserve the existing call to `self.common_init(default_task_name=..., **kwargs)` while respecting the existing `skip_set_processor` behavior (i.e., do not change how `skip_set_processor` is passed through to `common_init`).
🧹 Nitpick comments (4)
examples/configs/sft.yaml (1)
237-249: Consider moving the checkpointing guidance comment closer to the `checkpointing:` block.

The guidance block at lines 237–248 explains how `task_name` in the `data:` section relates to `metric_name` in `checkpointing:`, but it sits at the end of the `data:` section, far from the `checkpointing:` block (lines 16–25). Readers are more likely to look for this guidance when editing `checkpointing.metric_name`. A cross-reference comment near line 20 (the `metric_name` entry) would be more discoverable.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `examples/configs/sft.yaml` around lines 237-249: move the explanatory comment about how `data.validation.task_name` maps to `checkpointing.metric_name` from the end of the `data:` section into (or near) the `checkpointing:` block — specifically adjacent to the `checkpointing.metric_name` entry — so users editing `metric_name` can immediately see the guidance; reference `data.validation.task_name` in the moved comment and keep the example lines (e.g., `metric_name: "val:val_loss_dataset1"` and `metric_name: "val:val_loss"`) so the relationship and examples remain intact.

nemo_rl/utils/checkpoint.py (1)
109-118: Prefer a `ValueError` over `assert` for config validation, and fix the message typo.

`assert` can be stripped in optimized runs, and the example string currently reads awkwardly (`'val_reward --> 'val:reward'`). Consider a regular exception with a clearer message.

♻️ Proposed refactor

```diff
- if self.metric_name is not None:
-     assert self.metric_name.startswith("train:") or self.metric_name.startswith(
-         "val:"
-     ), (
-         f"metric_name={self.metric_name} must start with 'val:' or 'train:',\n"
-         f'followed by the corresponding name in the "val" or "train" metrics dictionary.'
-         f" If you are using an old config, please updated checkpointing.metric_name to the new format, "
-         f" e.g. 'val_reward --> 'val:reward'"
-     )
+ if self.metric_name is not None:
+     if not (
+         self.metric_name.startswith("train:")
+         or self.metric_name.startswith("val:")
+     ):
+         raise ValueError(
+             f"metric_name={self.metric_name} must start with 'val:' or 'train:',\n"
+             f'followed by the corresponding name in the "val" or "train" metrics dictionary. '
+             "If you are using an old config, please update checkpointing.metric_name to the new format, "
+             "e.g. 'val_reward' -> 'val:reward'."
+         )
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/utils/checkpoint.py` around lines 109-118: replace the assert-based config validation with an explicit `ValueError`: locate the block where `self.metric_name` is validated and change the assert to an if-check that raises `ValueError` when `metric_name` doesn't start with "train:" or "val:"; update the exception message to be clear, fix the example typo (e.g., use "'val_reward' -> 'val:reward'"), and include guidance about updating `checkpointing.metric_name` to the new format.

nemo_rl/data/datasets/preference_datasets/preference_dataset.py (1)
54-60: Same duplicated task-name derivation as `response_dataset.py` lines 56–59.

The body of lines 54–57 is a verbatim copy of `response_dataset.py` lines 56–59. See the suggested `task_name_from_path` utility in the `response_dataset.py` comment above.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/data/datasets/preference_datasets/preference_dataset.py` around lines 54-60: the task-name derivation duplicates logic from `response_dataset.py`; replace the inline derivation (the `default_task_name` computation just before the call to `self.common_init`) with a call to a shared utility (e.g., `task_name_from_path`), or move that logic into a new helper function and use it from both modules; update this file to compute `default_task_name` by calling the shared helper and then call `self.common_init(default_task_name=default_task_name, **kwargs)` so both files use the same centralized function instead of duplicating the code.

nemo_rl/data/datasets/response_datasets/response_dataset.py (1)
56-62: Extract duplicated task-name derivation to a shared utility.

The exact same 4-line block (`"-".join(...).split(".")[0]` + leading-dash strip) is copy-pasted verbatim into `preference_dataset.py` (lines 54–57). Extraction to a helper keeps a single definition.

♻️ Suggested utility in `nemo_rl/data/datasets/utils.py`

```diff
+ def task_name_from_path(data_path: str) -> str:
+     """Derive a default task name from a dataset file path."""
+     name = "-".join(data_path.split("/")[-2:]).split(".")[0]
+     return name.lstrip("-")
```

Then in both `response_dataset.py` and `preference_dataset.py`:

```diff
- default_task_name = "-".join(data_path.split("/")[-2:]).split(".")[0]
- if default_task_name[0] == "-":
-     default_task_name = default_task_name[1:]
+ default_task_name = task_name_from_path(data_path)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/data/datasets/response_datasets/response_dataset.py` around lines 56-62: extract the duplicated task-name derivation into a shared utility function (e.g., `task_name_from_path`) and replace the 4-line block in `response_dataset.py` and `preference_dataset.py` with a call to that helper; specifically, add `task_name_from_path(data_path: str)` in `nemo_rl/data/datasets/utils.py` that returns the derived name with the leading dash stripped, then in the constructors where you currently compute `default_task_name` (the block before calling `self.common_init`), call `default_task_name = task_name_from_path(data_path)` and keep the subsequent `self.common_init(default_task_name=default_task_name, **kwargs)`.
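A runnable sketch of the suggested utility. Note the empty-name guard is an addition layered on for illustration (and `lstrip` strips repeated leading dashes, where the original block strips only one) — it is not part of the proposed diff:

```python
def task_name_from_path(data_path: str) -> str:
    """Derive a default task name from a dataset file path."""
    # Join the last two path components and drop the file extension.
    name = "-".join(data_path.split("/")[-2:]).split(".")[0]
    name = name.lstrip("-")
    if not name:  # e.g. data_path == "/" — assumed guard, not in the original diff
        raise ValueError(
            f"Could not derive a task name from '{data_path}'; pass task_name explicitly."
        )
    return name
```

For example, `task_name_from_path("datasets/math/train.jsonl")` yields `"math-train"`, matching the `<parent-dir>-<file-stem>` convention the diff encodes.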
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@nemo_rl/algorithms/sft.py`:
- Around line 324-362: Two warning calls in the validation loop (the one inside
the per-batch check that warns "No validation metrics were collected for this
batch." and the final one that warns "No validation metrics were collected.")
need a stacklevel so callers see the true call site; update both
warnings.warn(...) invocations in the SFT validation code path (where
val_results, val_data, val_loss, total_val_loss, total_num_valid_tokens are
used) to pass stacklevel=2.
- Around line 272-342: The task-level loss calculation may divide by zero when
task_num_valid_tokens is 0; update the block that sets val_loss[task_name] so it
checks task_num_valid_tokens before dividing (e.g., if task_num_valid_tokens ==
0 assign a safe default like float("nan") or 0.0 and optionally emit a warning)
instead of unconditionally doing task_val_loss / task_num_valid_tokens; change
the assignment around the symbols task_val_loss, task_num_valid_tokens, and
val_loss[task_name] to perform this guard.
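The divide-by-zero guard described above can be sketched as follows (a minimal sketch assuming plain Python numbers; with tensors the division would silently produce inf/NaN rather than raise, which this guard also makes explicit — `per_task_loss` is a hypothetical helper name):

```python
import math

def per_task_loss(task_val_loss, task_num_valid_tokens):
    """Guard the per-task average: return NaN when no valid tokens were seen."""
    if task_num_valid_tokens == 0:
        # Safe default instead of ZeroDivisionError; a warning could be emitted here.
        return float("nan")
    return task_val_loss / task_num_valid_tokens
```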
In `@nemo_rl/data/datasets/raw_dataset.py`:
- Around line 51-58: Remove the implicit fallback processor="default" in the raw
dataset binding code: if processor is None, do not set a code-level default;
instead raise a clear error (e.g., ValueError or AssertionError) instructing the
caller to supply a processor in configuration/YAML. Keep the existing validation
against PROCESSOR_REGISTRY and assignment to self.processor (use
PROCESSOR_REGISTRY[processor]) so that when a non-None processor is provided
it's still validated and bound; reference symbols: skip_set_processor,
processor, PROCESSOR_REGISTRY, self.processor.
- Around line 21-40: Add a Google-style docstring to common_init describing
parameters and behavior; silence the unused kwargs by renaming **kwargs to
**_kwargs (or explicitly using it) to avoid ARG002; and remove the hardcoded
assignment that forces processor = "default" inside common_init so the processor
remains None unless explicitly provided (let configuration/YAML be the source of
the default). Update references to processor and skip_set_processor inside
common_init to respect a None processor value and only set the processor when
skip_set_processor is False and processor is explicitly provided.
In `@nemo_rl/data/datasets/response_datasets/__init__.py`:
- Around line 66-69: The failure comes from common_init requiring
skip_set_processor with no default, causing direct instantiation of
ResponseDataset-derived classes to raise a TypeError; update the common_init
function signature in raw_dataset.py to give skip_set_processor a default (e.g.,
change common_init(..., skip_set_processor: bool = False, ...)) so constructors
like ResponseDataset, Tulu3SftMixtureDataset, OpenMathInstruct2Dataset no longer
must receive that kwarg, and ensure any internal uses still respect the default;
run the unit tests to confirm the fix.
In `@nemo_rl/data/datasets/response_datasets/geometry3k.py`:
- Around line 69-70: The call to common_init in geometry3k currently omits the
skip_set_processor flag causing a TypeError; update the initializer call in
geometry3k (where self.common_init is invoked) to explicitly pass
skip_set_processor (either forward it from kwargs or supply a default like
False) so common_init receives that parameter (e.g., use
kwargs.get('skip_set_processor', False) or pass kwargs['skip_set_processor']
when present) and avoid the runtime error.
In `@nemo_rl/data/datasets/response_datasets/helpsteer3.py`:
- Around line 30-32: The __init__ currently calls self.common_init(...) without
passing skip_set_processor which common_init expects; update the __init__
signature to accept skip_set_processor: bool = False (preserving prior default
behavior) and forward it into the common_init call
(self.common_init(default_task_name="HelpSteer3",
skip_set_processor=skip_set_processor, **kwargs)) so callers can opt out of
processor setup and avoid the runtime TypeError in HelpSteer3.__init__.
In `@nemo_rl/data/datasets/response_datasets/oai_format_dataset.py`:
- Around line 137-141: Extract the duplicate default task-name derivation into a
single helper named _task_name_from_path(data_path: str) -> str (place it in
nemo_rl/data/datasets/raw_dataset.py or a shared utils module), replace the
duplicated logic in oai_format_dataset.py and binary_preference_dataset.py to
call this helper, and add a guard in the helper that raises a ValueError when
the derived name is empty (e.g., when data_path == "/") so callers (e.g., the
code that previously used default_task_name) must pass an explicit task_name
instead of producing empty metric keys.
- Around line 142-146: The common_init call is missing the required
skip_set_processor argument (declared as a positional parameter in
raw_dataset.py), causing TypeError when these classes are instantiated without
that kwarg; fix by extracting skip_set_processor =
kwargs.pop("skip_set_processor") (or kwargs.get and validate presence) and pass
it explicitly to self.common_init (e.g.,
self.common_init(default_task_name=default_task_name,
skip_set_processor=skip_set_processor, **kwargs)) in OpenAIFormatDataset (and
likewise in OasstDataset, HelpSteer3Dataset, DeepScalerDataset, AIME2024Dataset,
BinaryPreferenceDataset) so the required parameter is supplied.
In `@nemo_rl/data/datasets/response_datasets/refcoco.py`:
- Around line 192-193: The call to common_init(...) in the class (in refcoco.py)
is missing the required skip_set_processor parameter and will raise a TypeError;
update the call to self.common_init(default_task_name="refcoco",
skip_set_processor=kwargs.get("skip_set_processor"), **kwargs) (or explicitly
pass the appropriate boolean value) so the required skip_set_processor argument
is provided when invoking common_init.
---
Outside diff comments:
In `@examples/run_sft.py`:
- Around line 104-121: The current logic populates val_data_dict from two
sources (the data_list loop using data.task_name and the
data_config["validation"] loop using val_data.task_name) and silently overwrites
entries when task names collide; update the code so before assigning into
val_data_dict in the second validation-source loop (after val_data =
load_response_dataset(cfg)) you check if val_data.task_name already exists in
val_data_dict and, if so, emit a warning (e.g., using logger.warning or print)
that the validation dataset for that task_name will be overwritten, including
which source is being overridden; ensure the check references val_data_dict,
val_data.task_name, and the loading path around load_response_dataset so the
warning is clear and only emitted on duplicates.
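The duplicate check this prompt describes can be sketched as a hypothetical helper (the real fix would inline the check in the second validation-source loop; the warning text is illustrative):

```python
import warnings

def add_val_dataset(val_data_dict, task_name, dataset):
    """Warn before overwriting a validation entry loaded from the train split."""
    if task_name in val_data_dict:
        warnings.warn(
            f"task_name '{task_name}' already exists in val_data_dict (from train "
            "split); overwriting with the config-defined validation dataset.",
            stacklevel=2,
        )
    val_data_dict[task_name] = dataset

val_data_dict = {"squad": ["train-split val"]}
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    add_val_dataset(val_data_dict, "squad", ["config val"])
```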
In `@nemo_rl/algorithms/distillation.py`:
- Around line 945-1078: In function validate (in
nemo_rl/algorithms/distillation.py) remove the unused local assignment "rewards
= val_batch['total_reward']" (the variable rewards is never referenced
afterward); simply delete that line so the code uses
val_batch['total_reward'].tolist() directly and avoids an unused variable in the
loop that processes val_batch within validate.
In `@nemo_rl/algorithms/dpo.py`:
- Around line 663-682: The code does an unguarded split of full_metric_name into
prefix and metric_name (prefix, metric_name = full_metric_name.split(":", 1))
which raises a cryptic ValueError for misconfigured strings; update the
validation so you first check that full_metric_name is a non-empty string
containing a ":" (or use str.partition and verify the separator was present) and
raise a clear ValueError like "full_metric_name must be '<prefix>:<metric>'"
when missing; then proceed to set prefix, metric_name and the rest of the logic
that chooses metrics_source and updates dpo_save_state (keep the existing
warnings, deletion from dpo_save_state, and the metric-not-found error path
intact).
In `@nemo_rl/algorithms/rm.py`:
- Around line 590-610: The code currently does an unchecked split of
full_metric_name into prefix and metric_name causing a cryptic ValueError when
the colon is missing; before calling full_metric_name.split(":", 1) in the
checkpointing block that updates rm_save_state, add an explicit guard that
checks that full_metric_name contains exactly one ':' (or at least contains ':')
and if not raise a clear ValueError (or warnings.warn then skip) that explains
the expected format like "checkpointing.metric_name must be 'prefix:metric'
(e.g. 'train:accuracy')" so callers see a helpful error instead of the unpacking
exception.
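The `dpo.py` and `rm.py` prompts above ask for the same guard; a runnable sketch using `str.partition`, as one of them suggests (`parse_metric_name` is a hypothetical helper name, not an existing function in the codebase):

```python
def parse_metric_name(full_metric_name):
    """Split 'prefix:metric', failing loudly instead of with an unpacking error."""
    prefix, sep, metric_name = full_metric_name.partition(":")
    if not sep:
        raise ValueError(
            "checkpointing.metric_name must be 'prefix:metric' "
            f"(e.g. 'train:accuracy'), got '{full_metric_name}'"
        )
    return prefix, metric_name
```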
In `@nemo_rl/data/utils.py`:
- Around line 116-140: The validation loader currently overwrites earlier
entries when multiple validation configs share the same task_name; update the
block that assigns to val_data_dict, val_task_data_processors, and
val_task_to_env to merge duplicates like the training path: after loading
val_data = load_response_dataset(cfg), check if val_data.task_name already
exists in val_data_dict and if so call concatenate_datasets(existing_dataset,
val_data.dataset) (same helper used on line ~105) and replace the entry,
otherwise set it; ensure val_task_data_processors[val_data.task_name] and
val_task_to_env[val_data.task_name] are only set once (or validated to be
consistent) to avoid inconsistent processor/env mappings.
---
Duplicate comments:
In `@nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py`:
- Around line 57-63: The default_task_name derivation code (building
default_task_name from data_path and trimming a leading '-') is duplicated here
and should be extracted into a shared helper (e.g., a new function
get_default_task_name(data_path) used by both BinaryPreferenceDataset and
OaiFormatDataset); update
nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py to call
that helper instead of repeating the logic around default_task_name, ensure the
helper returns a non-empty string (handle the empty-string edge case by falling
back to a safe name or raising a clear error), and preserve the existing call to
self.common_init(default_task_name=..., **kwargs) while respecting the existing
skip_set_processor behavior (i.e., do not change how skip_set_processor is
passed through to common_init).
In `@nemo_rl/data/datasets/preference_datasets/tulu3.py`:
- Around line 27-28: The call to self.common_init(...) in Tulu3Preference must
either pass skip_set_processor=True or rely on common_init having a default for
skip_set_processor; update the call site in tulu3.py to include
skip_set_processor=<appropriate boolean> (e.g., True) or modify the common_init
signature to provide a default value for skip_set_processor so the processor
behavior matches other datasets like refcoco; reference the common_init function
and the Tulu3Preference initializer to make the change consistently.
In `@nemo_rl/data/datasets/response_datasets/dapo_math.py`:
- Around line 55-60: The constructor calls
self.common_init(default_task_name="DAPOMathAIME2024", **kwargs) but does not
pass skip_set_processor (and common_init may not provide a default), so ensure
skip_set_processor is explicitly handled: either pass skip_set_processor from
kwargs into common_init (e.g., include
skip_set_processor=kwargs.get("skip_set_processor", <desired default>)) or
update common_init to define a safe default for skip_set_processor; locate the
call site in the DAPOMathAIME2024 dataset class __init__ and adjust the
common_init invocation or add the default in the common_init implementation so
skip_set_processor is always defined.
- Around line 26-27: common_init is being called without the skip_set_processor
argument here in the DAPOMath17K dataset initializer; update the call in the
constructor to pass skip_set_processor (same approach used in refcoco) or ensure
common_init defines a default for skip_set_processor. Specifically, modify the
call to self.common_init(default_task_name="DAPOMath17K",
skip_set_processor=..., **kwargs) or add skip_set_processor=False/True as a
default parameter in the common_init signature so the processor behavior is
explicit (refer to the common_init function and the class initializer in
dapo_math.py).
In `@nemo_rl/data/datasets/response_datasets/nemogym_dataset.py`:
- Around line 29-35: The call to self.common_init(...) in nemogym_dataset.py
uses common_init without supplying skip_set_processor, which can cause the same
bug noted in refcoco; update the invocation in the constructor where
default_task_name is computed so it explicitly passes skip_set_processor (e.g.,
skip_set_processor=kwargs.get("skip_set_processor", False)) or ensure
common_init has a default for skip_set_processor, referencing the common_init
method and the place where default_task_name is computed and passed to
self.common_init.
In `@nemo_rl/data/datasets/response_datasets/openmathinstruct2.py`:
- Around line 48-50: The call to self.common_init in OpenMathInstruct2 doesn't
pass skip_set_processor and relies on common_init to provide a safe default;
either update the call in the OpenMathInstruct2 constructor to explicitly pass
skip_set_processor (e.g., skip_set_processor=True or False as appropriate for
this dataset) or change the common_init signature to include a default value for
skip_set_processor so callers like OpenMathInstruct2 can omit it; locate the
common_init definition and add a default (or update the OpenMathInstruct2 call)
to ensure skip_set_processor is always defined.
In `@nemo_rl/data/datasets/response_datasets/squad.py`:
- Around line 30-31: The call to self.common_init in the SQuAD dataset should
explicitly handle the skip_set_processor flag: either pass skip_set_processor
from the SQuAD constructor into self.common_init (e.g.,
self.common_init(default_task_name="squad",
skip_set_processor=skip_set_processor, **kwargs)) or add a default for
skip_set_processor inside the common_init signature so callers that omit it
(like squad) behave correctly; update the SQuAD constructor to pass through the
parameter or update common_init to set a sensible default for skip_set_processor
to avoid the missing-argument issue.
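All of the `skip_set_processor` comments above point at the same fix: give `common_init` a default so call sites like squad, tulu3, and dapo_math can omit the flag. A minimal sketch of that approach — class names and method bodies are hypothetical placeholders, not the real `RawDataset` implementation:

```python
class RawDatasetSketch:
    def common_init(self, default_task_name, skip_set_processor=False, **kwargs):
        """Defaulting skip_set_processor lets subclasses omit it safely."""
        self.task_name = kwargs.get("task_name", default_task_name)
        if not skip_set_processor:
            self._set_processor()

    def _set_processor(self):
        self.processor = "default-processor"  # placeholder for real processor setup


class SquadDatasetSketch(RawDatasetSketch):
    def __init__(self, **kwargs):
        # No skip_set_processor here; the default in common_init covers it.
        self.common_init(default_task_name="squad", **kwargs)
```

Callers that do need to suppress the processor (as refcoco reportedly does) can still pass `skip_set_processor=True` through `**kwargs`.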
---
Nitpick comments:
In `@examples/configs/sft.yaml`:
- Around line 237-249: Move the explanatory comment about how
data.validation.task_name maps to checkpointing.metric_name from the end of the
data: section into/near the checkpointing: block—specifically adjacent to the
checkpointing.metric_name entry—so users editing metric_name can immediately see
the guidance; reference data.validation.task_name in the moved comment and keep
the example lines (e.g., metric_name: "val:val_loss_dataset1" and metric_name:
"val:val_loss") so the relationship and examples remain intact.
In `@nemo_rl/data/datasets/preference_datasets/preference_dataset.py`:
- Around line 54-60: The task-name derivation in preference_dataset.py
duplicates logic from response_dataset.py: replace the inline derivation (the
default_task_name computation just before the call to self.common_init) with a
call to a shared utility (e.g., task_name_from_path) or move that logic into a
new helper function and use it from both modules; update preference_dataset.py
to compute default_task_name by calling the shared helper and then call
self.common_init(default_task_name=default_task_name, **kwargs) so both files
use the same centralized function instead of duplicating the code.
In `@nemo_rl/data/datasets/response_datasets/response_dataset.py`:
- Around line 56-62: Extract the duplicated task-name derivation into a shared
utility function (e.g., task_name_from_path) and replace the 4-line block in
response_dataset.py and preference_dataset.py with a call to that helper;
specifically, add a function task_name_from_path(data_path: str) in
nemo_rl/data/datasets/utils.py that returns the derived name with leading dashes
stripped, then in the constructors where you currently compute default_task_name
(the block before calling self.common_init), call default_task_name =
task_name_from_path(data_path) and keep the subsequent
self.common_init(default_task_name=default_task_name, **kwargs).
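A minimal version of the proposed shared helper — the name comes from the comment, the internals are assumed (the only behavior stated above is deriving the name from the path, stripping leading dashes, and handling the empty-string edge case):

```python
import os

def task_name_from_path(data_path: str) -> str:
    """Derive a default task name from a dataset path (hypothetical helper).

    Takes the file stem, strips leading dashes, and falls back to a safe
    name when the result would be empty.
    """
    stem = os.path.splitext(os.path.basename(data_path))[0]
    name = stem.lstrip("-")
    return name or "dataset"
```

Both `ResponseDataset` and `PreferenceDataset` constructors would then call `default_task_name = task_name_from_path(data_path)` before `self.common_init(...)`.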
In `@nemo_rl/utils/checkpoint.py`:
- Around line 109-118: Replace the assert-based config validation in the
checkpoint logic with an explicit ValueError: locate the block referencing
self.metric_name (in nemo_rl/utils/checkpoint.py, around the check in the
checkpoint class where metric_name is validated) and change the assert to an
if-check that raises ValueError when metric_name doesn't start with "train:" or
"val:"; update the exception message to be clear and fix the example typo (e.g.
use "val_reward -> 'val:reward'" or "e.g. 'val:reward' for old 'val_reward'")
and include guidance about updating checkpointing.metric_name to the new format.
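A sketch of the assert-to-ValueError change (function name hypothetical; one reason to prefer an explicit raise is that assertions are stripped entirely under `python -O`):

```python
def validate_metric_name(metric_name):
    """Raise a clear configuration error instead of asserting."""
    if metric_name is not None and not metric_name.startswith(("train:", "val:")):
        raise ValueError(
            "checkpointing.metric_name must start with 'train:' or 'val:' "
            f"(e.g. the old 'val_reward' becomes 'val:reward'); got {metric_name!r}"
        )
```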
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (37)
docs/guides/grpo.md
docs/guides/sft.md
examples/configs/distillation_math.yaml
examples/configs/grpo_math_1B.yaml
examples/configs/sft.yaml
examples/run_sft.py
nemo_rl/algorithms/distillation.py
nemo_rl/algorithms/dpo.py
nemo_rl/algorithms/grpo.py
nemo_rl/algorithms/rm.py
nemo_rl/algorithms/sft.py
nemo_rl/data/datasets/preference_datasets/__init__.py
nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py
nemo_rl/data/datasets/preference_datasets/helpsteer3.py
nemo_rl/data/datasets/preference_datasets/preference_dataset.py
nemo_rl/data/datasets/preference_datasets/tulu3.py
nemo_rl/data/datasets/raw_dataset.py
nemo_rl/data/datasets/response_datasets/__init__.py
nemo_rl/data/datasets/response_datasets/aime24.py
nemo_rl/data/datasets/response_datasets/clevr.py
nemo_rl/data/datasets/response_datasets/dapo_math.py
nemo_rl/data/datasets/response_datasets/deepscaler.py
nemo_rl/data/datasets/response_datasets/geometry3k.py
nemo_rl/data/datasets/response_datasets/helpsteer3.py
nemo_rl/data/datasets/response_datasets/nemogym_dataset.py
nemo_rl/data/datasets/response_datasets/oai_format_dataset.py
nemo_rl/data/datasets/response_datasets/oasst.py
nemo_rl/data/datasets/response_datasets/openmathinstruct2.py
nemo_rl/data/datasets/response_datasets/refcoco.py
nemo_rl/data/datasets/response_datasets/response_dataset.py
nemo_rl/data/datasets/response_datasets/squad.py
nemo_rl/data/datasets/response_datasets/tulu3.py
nemo_rl/data/utils.py
nemo_rl/utils/checkpoint.py
tests/functional/distillation.sh
tests/functional/grpo_multiple_datasets.sh
tests/functional/sft.sh
terrykong left a comment:
mostly lgtm. small comments
Review thread on the validation-loss aggregation, where the diff replaces the single merged accumulator:

```diff
-if sum_num_valid_tokens > 0:
-    val_metrics["val_loss"] /= sum_num_valid_tokens
+# Calculate validation metrics
+if total_num_valid_tokens > 0:
```

nit, suggested change:

```python
if total_num_valid_tokens > 0:
    assert "total" not in val_loss, (
        "total is a reserved task_name since it is used in the metrics as the aggregate label"
    )
```
| full_metric_name = master_config["checkpointing"]["metric_name"] | ||
| if full_metric_name is not None: | ||
| assert full_metric_name.startswith( |
There was a problem hiding this comment.
not removed, just moved to nemo_rl/utils/checkpoint.py. :)
so that we could assert this at the beginning and reduce duplicated things in different algos.
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Force-pushed: 74ca707 to fb99598
Force-pushed: fb99598 to f2d4720
Summary by CodeRabbit

New Features
- Per-task validation metric names (e.g., val:accuracy_<TaskName>)

Documentation

Tests