feat: split validation statistics by task name #2019

Open

yuki-97 wants to merge 14 commits into main from yukih/validation-task-name

Conversation


@yuki-97 yuki-97 commented Feb 24, 2026

  1. Support splitting validation statistics by task name in SFT/GRPO/Distillation by using multiple validation dataloaders, as RM/DPO already do.
  2. Support setting a custom task name, which is useful for saving checkpoints based on a specific dataset's validation statistic (e.g., accuracy, val_loss).
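For illustration, here is how the metric keys compose under the new scheme (a minimal sketch; the task name "MathVal" is made up, while the "val:" prefix and the suffix rules come from this PR's docs changes):

    # Illustrative only: metric-key composition for checkpointing.
    # "MathVal" is a hypothetical task_name set in the data config.
    task_name = "MathVal"
    aggregate_metric = "val:val_loss"                 # aggregate over all tasks
    per_task_loss = f"val:val_loss_{task_name}"       # -> "val:val_loss_MathVal"
    per_task_accuracy = f"val:accuracy_{task_name}"   # -> "val:accuracy_MathVal"
    # checkpointing.metric_name may be set to any of these keys.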

Summary by CodeRabbit

  • New Features

    • Multi-task validation support with per-dataset accuracy and loss tracking for checkpointing
    • Per-dataset task naming for validation metrics (e.g., val:accuracy_<TaskName>)
  • Documentation

    • Added guidance for configuring dataset-specific validation metrics in training configs
    • Included YAML examples for multi-dataset validation setups
  • Tests

    • Updated validation metric checks to include per-task accuracy metrics

@github-actions github-actions bot added the documentation Improvements or additions to documentation label Feb 24, 2026
@yuki-97 yuki-97 marked this pull request as ready for review February 24, 2026 15:31
@yuki-97 yuki-97 requested review from a team as code owners February 24, 2026 15:31
@yuki-97 yuki-97 added the CI:L1 Run doctests, unit tests, and functional tests label Feb 24, 2026
@yuki-97 yuki-97 requested a review from terrykong February 24, 2026 15:32

coderabbitai bot commented Feb 24, 2026

📝 Walkthrough

This PR extends validation support from single merged datasets to per-task validation dictionaries across training algorithms. It updates data setup, algorithm training loops, dataset initialization patterns, configuration examples, and checkpointing logic to handle per-dataset validation metrics and enable dataset-specific checkpoint triggers.

Changes

  • Documentation & Guidance (docs/guides/grpo.md, docs/guides/sft.md): Added guidance on using dataset-specific validation metrics for checkpointing via the task_name config and metric_name set to val:accuracy_<TaskName> or val:val_loss_<TaskName>.
  • Configuration Examples (examples/configs/distillation_math.yaml, grpo_math_1B.yaml, sft.yaml): Updated config examples with comments explaining per-dataset metric configuration and added new checkpointing guidance blocks for multi-dataset setups.
  • Example Training Script (examples/run_sft.py): Refactored validation dataset collection from list-based to dict-based (keyed by task_name); changed the setup_data return type from a single val_dataset to dict[str, AllTaskProcessedDataset].
  • Algorithm: Data Setup & Validation (nemo_rl/algorithms/sft.py, grpo.py, distillation.py): Updated setup signatures to accept val_dataset: dict[str, AllTaskProcessedDataset] and return val_dataloader: dict[str, StatefulDataLoader]; refactored validation loops to iterate per task, accumulating per-task metrics and computing totals.
  • Algorithm: Checkpoint & Metric Handling (nemo_rl/algorithms/dpo.py, rm.py): Removed the strict runtime assertion on metric_name format (train:/val: prefix); adjusted checkpointing logic to split metric_name without prior validation, with fallback handling for missing metrics.
  • Data Utilities & Dataset Base (nemo_rl/data/utils.py, nemo_rl/data/datasets/raw_dataset.py): Updated the setup_response_data return type to dict[str, AllTaskProcessedDataset] for validation; introduced a common_init method on RawDataset for centralized task initialization, removing the legacy set_processor and set_task_spec methods.
  • Dataset Registry & Loading (nemo_rl/data/datasets/preference_datasets/__init__.py, response_datasets/__init__.py): Added a skip_set_processor=True/False parameter when instantiating datasets; removed explicit set_task_spec and set_processor calls, delegating to common_init.
  • Response & Preference Dataset Classes (nemo_rl/data/datasets/response_datasets/*.py (13 files), nemo_rl/data/datasets/preference_datasets/*.py (5 files)): Replaced direct self.task_name assignment with self.common_init(default_task_name=..., **kwargs) calls; some datasets now add a task_name column post-load.
  • Checkpoint Validation (nemo_rl/utils/checkpoint.py): Added runtime validation in CheckpointManager.__init__ to enforce metric_name format (must start with "train:" or "val:" if provided).
  • Test Metric Checks (tests/functional/distillation.sh, grpo_multiple_datasets.sh, sft.sh): Extended metric validation checks to include per-dataset accuracy/loss metrics (e.g., validation/accuracy_<DatasetName>) in addition to aggregate metrics.
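The setup-signature change in the algorithms cohort is easiest to see as code. A minimal sketch with stand-in types (the real setup functions in nemo_rl/algorithms/*.py take many more arguments and build actual StatefulDataLoaders):

    from typing import Any, Iterable

    AllTaskProcessedDataset = Any  # stand-in for the real dataset type

    def make_dataloader(dataset: AllTaskProcessedDataset, batch_size: int) -> Iterable:
        """Stand-in for StatefulDataLoader construction (illustrative only)."""
        return iter(())

    def setup_val_dataloaders(
        val_dataset: dict[str, AllTaskProcessedDataset],  # keyed by task_name
        batch_size: int,
    ) -> dict[str, Iterable]:
        # One dataloader per validation task, replacing the single merged loader.
        return {name: make_dataloader(ds, batch_size) for name, ds in val_dataset.items()}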

Sequence Diagram(s)

sequenceDiagram
    participant Trainer
    participant DataLoader as Per-Task<br/>DataLoaders
    participant Validator
    participant Metrics as Per-Task<br/>Metrics
    participant Checkpoint as Checkpoint<br/>Manager

    rect rgba(0, 100, 200, 0.5)
    Note over Trainer,Checkpoint: Old Flow: Single Validation Dataset
    Trainer->>DataLoader: Load val_dataloader (merged)
    DataLoader->>Validator: Single batch stream
    Validator->>Metrics: Aggregate results
    Metrics->>Checkpoint: global_accuracy/loss
    Checkpoint->>Checkpoint: Save if metric improves
    end

    rect rgba(0, 150, 100, 0.5)
    Note over Trainer,Checkpoint: New Flow: Per-Task Validation Datasets
    Trainer->>DataLoader: Load val_dataloaders: dict[task_name]
    loop For each task in dict
        DataLoader->>Validator: Per-task batch stream
        Validator->>Metrics: Per-task results
    end
    Metrics->>Metrics: Accumulate per-task metrics
    Metrics->>Metrics: Compute total_accuracy/loss
    Checkpoint->>Checkpoint: Select metric by task_name<br/>(val:accuracy_Task1)
    Checkpoint->>Checkpoint: Save if metric improves
    end
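To make the new flow concrete, here is a minimal sketch of a per-task validation loop matching the diagram (hypothetical batch fields loss_sum and num_valid_tokens; the real loops live in nemo_rl/algorithms/sft.py, grpo.py, and distillation.py):

    def validate(val_dataloaders: dict[str, list[dict]]) -> dict[str, float]:
        metrics: dict[str, float] = {}
        total_loss, total_tokens = 0.0, 0
        for task_name, loader in val_dataloaders.items():
            task_loss, task_tokens = 0.0, 0
            for batch in loader:
                task_loss += batch["loss_sum"]
                task_tokens += batch["num_valid_tokens"]
            if task_tokens > 0:  # zero-token guard flagged later in the review
                metrics[f"val_loss_{task_name}"] = task_loss / task_tokens
            total_loss += task_loss
            total_tokens += task_tokens
        if total_tokens > 0:
            metrics["val_loss"] = total_loss / total_tokens  # aggregate metric
        return metrics

Note the guard on both the per-task and aggregate divisions; the review comments below flag exactly this divide-by-zero case.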

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • PR #1763: Continues the dataset split/refactor; both move validation from a single merged dataset to per-task dicts across setup_response_data, examples/run_sft.py, and the algorithm signatures.
  • PR #1649: Both introduce per-dataset task names and per-task validation datasets in setup functions, config structures, and checkpointing metric naming.
  • PR #1291: Both modify checkpointing metric_name parsing and handling of train:/val: prefixes in checkpoint logic.

Suggested labels

CI:L1

Suggested reviewers

  • terrykong
  • ashors1
  • yfw
🚥 Pre-merge checks | ✅ 3 passed | ❌ 1 warning

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 35.90%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (3 passed)

  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The PR title 'feat: split validation statistics by task name' directly and clearly summarizes the main objective of the changeset, which is to enable per-task validation statistics tracking across multiple algorithms.
  • Test Results For Major Changes ✅ Passed: The PR includes validation accuracy metric comparison graphs and updated functional tests demonstrating that per-task validation statistics work correctly without regression.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 10

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
examples/run_sft.py (1)

104-121: ⚠️ Potential issue | 🟡 Minor

Silent overwrite when the same task_name appears in both validation sources.

If data.task_name at line 106 equals val_data.task_name at line 121 (a task that appears in both the train-split validation and the explicit validation: config), the train-split dataset is silently replaced. Add a warning:

⚠️ Proposed guard
+        if val_data.task_name in val_data_dict:
+            warnings.warn(
+                f"task_name '{val_data.task_name}' already exists in val_data_dict "
+                "(from train split). Overwriting with config-defined validation dataset."
+            )
         val_data_dict[val_data.task_name] = val_data.dataset
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/run_sft.py` around lines 104 - 121, The current logic populates
val_data_dict from two sources (the data_list loop using data.task_name and the
data_config["validation"] loop using val_data.task_name) and silently overwrites
entries when task names collide; update the code so before assigning into
val_data_dict in the second validation-source loop (after val_data =
load_response_dataset(cfg)) you check if val_data.task_name already exists in
val_data_dict and, if so, emit a warning (e.g., using logger.warning or print)
that the validation dataset for that task_name will be overwritten, including
which source is being overridden; ensure the check references val_data_dict,
val_data.task_name, and the loading path around load_response_dataset so the
warning is clear and only emitted on duplicates.
nemo_rl/algorithms/dpo.py (1)

663-682: ⚠️ Potential issue | 🟡 Minor

Same missing-colon validation as rm.py — cryptic ValueError for misconfigured metric_name.

Line 665 (prefix, metric_name = full_metric_name.split(":", 1)) is the same unguarded split as in rm.py. Apply the same fix:

🛡️ Proposed fix
     if full_metric_name is not None:
+        if ":" not in full_metric_name:
+            raise ValueError(
+                f"checkpointing.metric_name must be in '<prefix>:<metric>' format, got '{full_metric_name}'"
+            )
         prefix, metric_name = full_metric_name.split(":", 1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/algorithms/dpo.py` around lines 663 - 682, The code does an unguarded
split of full_metric_name into prefix and metric_name (prefix, metric_name =
full_metric_name.split(":", 1)) which raises a cryptic ValueError for
misconfigured strings; update the validation so you first check that
full_metric_name is a non-empty string containing a ":" (or use str.partition
and verify the separator was present) and raise a clear ValueError like
"full_metric_name must be '<prefix>:<metric>'" when missing; then proceed to set
prefix, metric_name and the rest of the logic that chooses metrics_source and
updates dpo_save_state (keep the existing warnings, deletion from
dpo_save_state, and the metric-not-found error path intact).
nemo_rl/algorithms/rm.py (1)

590-610: ⚠️ Potential issue | 🟡 Minor

Removed format assertion leaves a cryptic ValueError when metric_name lacks a colon.

With the assertion gone, a misconfigured metric_name: accuracy (no prefix: part) reaches line 592:

prefix, metric_name = full_metric_name.split(":", 1)
# → ValueError: not enough values to unpack (expected 2, got 1)

Add an explicit guard before the split:

🛡️ Proposed fix
 if full_metric_name is not None:
+    if ":" not in full_metric_name:
+        raise ValueError(
+            f"checkpointing.metric_name must be in '<prefix>:<metric>' format, got '{full_metric_name}'"
+        )
     prefix, metric_name = full_metric_name.split(":", 1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/algorithms/rm.py` around lines 590 - 610, The code currently does an
unchecked split of full_metric_name into prefix and metric_name causing a
cryptic ValueError when the colon is missing; before calling
full_metric_name.split(":", 1) in the checkpointing block that updates
rm_save_state, add an explicit guard that checks that full_metric_name contains
exactly one ':' (or at least contains ':') and if not raise a clear ValueError
(or warnings.warn then skip) that explains the expected format like
"checkpointing.metric_name must be 'prefix:metric' (e.g. 'train:accuracy')" so
callers see a helpful error instead of the unpacking exception.
nemo_rl/algorithms/distillation.py (1)

945-1078: ⚠️ Potential issue | 🟡 Minor

Remove unused rewards variable assignment.

Line 1011 assigns val_batch["total_reward"] to rewards, but the variable is never used; the next line accesses val_batch["total_reward"].tolist() directly. Delete the unused assignment.

Diff
-                rewards = val_batch["total_reward"]

                 task_rewards.extend(val_batch["total_reward"].tolist())
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/algorithms/distillation.py` around lines 945 - 1078, In function
validate (in nemo_rl/algorithms/distillation.py) remove the unused local
assignment "rewards = val_batch['total_reward']" (the variable rewards is never
referenced afterward); simply delete that line so the code uses
val_batch['total_reward'].tolist() directly and avoids an unused variable in the
loop that processes val_batch within validate.
nemo_rl/data/utils.py (1)

116-140: ⚠️ Potential issue | 🟠 Major

Guard against duplicate task_name overwrites in validation data loading.

Direct dict assignment silently drops earlier validation datasets if the same task_name appears in multiple configs. The training path handles this correctly with concatenate_datasets() (line 105), but validation does not. When duplicate task_names occur, later assignments overwrite earlier ones in val_data_dict, val_task_data_processors, and val_task_to_env, causing silent data loss.

Concatenate validation datasets with matching task_names using the same pattern as training data, or assert that task_names are unique across validation configs.

💡 Suggested fix (merge duplicates like training data)
-            val_data_dict[data.task_name] = data.val_dataset
+            if data.task_name in val_data_dict:
+                val_data_dict[data.task_name] = concatenate_datasets(
+                    [val_data_dict[data.task_name], data.val_dataset]
+                )
+            else:
+                val_data_dict[data.task_name] = data.val_dataset
...
-            val_data_dict[val_data.task_name] = val_data.dataset
+            if val_data.task_name in val_data_dict:
+                val_data_dict[val_data.task_name] = concatenate_datasets(
+                    [val_data_dict[val_data.task_name], val_data.dataset]
+                )
+            else:
+                val_data_dict[val_data.task_name] = val_data.dataset
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/data/utils.py` around lines 116 - 140, The validation loader
currently overwrites earlier entries when multiple validation configs share the
same task_name; update the block that assigns to val_data_dict,
val_task_data_processors, and val_task_to_env to merge duplicates like the
training path: after loading val_data = load_response_dataset(cfg), check if
val_data.task_name already exists in val_data_dict and if so call
concatenate_datasets(existing_dataset, val_data.dataset) (same helper used on
line ~105) and replace the entry, otherwise set it; ensure
val_task_data_processors[val_data.task_name] and
val_task_to_env[val_data.task_name] are only set once (or validated to be
consistent) to avoid inconsistent processor/env mappings.
♻️ Duplicate comments (7)
nemo_rl/data/datasets/preference_datasets/tulu3.py (1)

27-28: Same common_init argument concern as refcoco.

Please ensure skip_set_processor is passed or has a default in common_init (see earlier note).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/data/datasets/preference_datasets/tulu3.py` around lines 27 - 28, The
call to self.common_init(...) in Tulu3Preference must either pass
skip_set_processor=True or rely on common_init having a default for
skip_set_processor; update the call site in tulu3.py to include
skip_set_processor=<appropriate boolean> (e.g., True) or modify the common_init
signature to provide a default value for skip_set_processor so the processor
behavior matches other datasets like refcoco; reference the common_init function
and the Tulu3Preference initializer to make the change consistently.
nemo_rl/data/datasets/response_datasets/dapo_math.py (2)

55-60: Same common_init argument concern as refcoco.

Please ensure skip_set_processor is passed or has a default in common_init (see earlier note).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/data/datasets/response_datasets/dapo_math.py` around lines 55 - 60,
The constructor calls self.common_init(default_task_name="DAPOMathAIME2024",
**kwargs) but does not pass skip_set_processor (and common_init may not provide
a default), so ensure skip_set_processor is explicitly handled: either pass
skip_set_processor from kwargs into common_init (e.g., include
skip_set_processor=kwargs.get("skip_set_processor", <desired default>)) or
update common_init to define a safe default for skip_set_processor; locate the
call site in the DAPOMathAIME2024 dataset class __init__ and adjust the
common_init invocation or add the default in the common_init implementation so
skip_set_processor is always defined.

26-27: Same common_init argument concern as refcoco.

Please ensure skip_set_processor is passed or has a default in common_init (see earlier note).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/data/datasets/response_datasets/dapo_math.py` around lines 26 - 27,
common_init is being called without the skip_set_processor argument here in the
DAPOMath17K dataset initializer; update the call in the constructor to pass
skip_set_processor (same approach used in refcoco) or ensure common_init defines
a default for skip_set_processor. Specifically, modify the call to
self.common_init(default_task_name="DAPOMath17K", skip_set_processor=...,
**kwargs) or add skip_set_processor=False/True as a default parameter in the
common_init signature so the processor behavior is explicit (refer to the
common_init function and the class initializer in dapo_math.py).
nemo_rl/data/datasets/response_datasets/nemogym_dataset.py (1)

29-35: Same common_init argument concern as refcoco.

Please ensure skip_set_processor is passed or has a default in common_init (see earlier note).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/data/datasets/response_datasets/nemogym_dataset.py` around lines 29 -
35, The call to self.common_init(...) in nemogym_dataset.py uses common_init
without supplying skip_set_processor, which can cause the same bug noted in
refcoco; update the invocation in the constructor where default_task_name is
computed so it explicitly passes skip_set_processor (e.g.,
skip_set_processor=kwargs.get("skip_set_processor", False)) or ensure
common_init has a default for skip_set_processor, referencing the common_init
method and the place where default_task_name is computed and passed to
self.common_init.
nemo_rl/data/datasets/response_datasets/squad.py (1)

30-31: Same common_init argument concern as refcoco.

Please ensure skip_set_processor is passed or has a default in common_init (see earlier note).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/data/datasets/response_datasets/squad.py` around lines 30 - 31, The
call to self.common_init in the SQuAD dataset should explicitly handle the
skip_set_processor flag: either pass skip_set_processor from the SQuAD
constructor into self.common_init (e.g.,
self.common_init(default_task_name="squad",
skip_set_processor=skip_set_processor, **kwargs)) or add a default for
skip_set_processor inside the common_init signature so callers that omit it
(like squad) behave correctly; update the SQuAD constructor to pass through the
parameter or update common_init to set a sensible default for skip_set_processor
to avoid the missing-argument issue.
nemo_rl/data/datasets/response_datasets/openmathinstruct2.py (1)

48-50: Same common_init argument concern as refcoco.

Please ensure skip_set_processor is passed or has a default in common_init (see earlier note).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/data/datasets/response_datasets/openmathinstruct2.py` around lines 48
- 50, The call to self.common_init in OpenMathInstruct2 doesn't pass
skip_set_processor and relies on common_init to provide a safe default; either
update the call in the OpenMathInstruct2 constructor to explicitly pass
skip_set_processor (e.g., skip_set_processor=True or False as appropriate for
this dataset) or change the common_init signature to include a default value for
skip_set_processor so callers like OpenMathInstruct2 can omit it; locate the
common_init definition and add a default (or update the OpenMathInstruct2 call)
to ensure skip_set_processor is always defined.
nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py (1)

57-63: Duplicate task-name derivation logic — same issue as oai_format_dataset.py.

The default_task_name derivation (lines 58–60) is identical to oai_format_dataset.py lines 138–140. Please extract to the shared helper described in the oai_format_dataset.py comment — this is the second occurrence that confirms the need for the refactor.

The same empty-string edge case and skip_set_processor concern flagged in oai_format_dataset.py apply here as well.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py`
around lines 57 - 63, The default_task_name derivation code (building
default_task_name from data_path and trimming a leading '-') is duplicated here
and should be extracted into a shared helper (e.g., a new function
get_default_task_name(data_path) used by both BinaryPreferenceDataset and
OaiFormatDataset); update
nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py to call
that helper instead of repeating the logic around default_task_name, ensure the
helper returns a non-empty string (handle the empty-string edge case by falling
back to a safe name or raising a clear error), and preserve the existing call to
self.common_init(default_task_name=..., **kwargs) while respecting the existing
skip_set_processor behavior (i.e., do not change how skip_set_processor is
passed through to common_init).
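All seven duplicate comments above converge on the same fix. A minimal sketch of a common_init with a defaulted skip_set_processor (illustrative only; the real method lives in nemo_rl/data/datasets/raw_dataset.py and does more work, including processor registry validation):

    class RawDataset:
        def common_init(
            self,
            default_task_name: str,
            skip_set_processor: bool = False,  # default avoids TypeErrors at call sites
            **kwargs,
        ) -> None:
            # Prefer an explicit task_name from config, else the dataset's default.
            self.task_name = kwargs.get("task_name") or default_task_name
            if not skip_set_processor:
                # Processor binding elided; see the raw_dataset.py comments above.
                self.processor = kwargs.get("processor")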
🧹 Nitpick comments (4)
examples/configs/sft.yaml (1)

237-249: Consider moving the checkpointing guidance comment closer to the checkpointing: block.

The guidance block at lines 237–248 explains how task_name in the data: section relates to metric_name in checkpointing:, but it sits at the end of the data: section, far from the checkpointing: block (lines 16–25). Readers are more likely to look for this guidance when editing checkpointing.metric_name. A cross-reference comment near line 20 (the metric_name entry) would be more discoverable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/configs/sft.yaml` around lines 237 - 249, Move the explanatory
comment about how data.validation.task_name maps to checkpointing.metric_name
from the end of the data: section into/near the checkpointing:
block—specifically adjacent to the checkpointing.metric_name entry—so users
editing metric_name can immediately see the guidance; reference
data.validation.task_name in the moved comment and keep the example lines (e.g.,
metric_name: "val:val_loss_dataset1" and metric_name: "val:val_loss") so the
relationship and examples remain intact.
nemo_rl/utils/checkpoint.py (1)

109-118: Prefer a ValueError over assert for config validation, and fix the message typo.
assert can be stripped in optimized runs, and the example string currently reads awkwardly ('val_reward --> 'val:reward'). Consider a regular exception with a clearer message.

♻️ Proposed refactor
-        if self.metric_name is not None:
-            assert self.metric_name.startswith("train:") or self.metric_name.startswith(
-                "val:"
-            ), (
-                f"metric_name={self.metric_name} must start with 'val:' or 'train:',\n"
-                f'followed by the corresponding name in the "val" or "train" metrics dictionary.'
-                f"  If you are using an old config, please updated checkpointing.metric_name to the new format, "
-                f" e.g. 'val_reward --> 'val:reward'"
-            )
+        if self.metric_name is not None:
+            if not (
+                self.metric_name.startswith("train:")
+                or self.metric_name.startswith("val:")
+            ):
+                raise ValueError(
+                    f"metric_name={self.metric_name} must start with 'val:' or 'train:',\n"
+                    f'followed by the corresponding name in the "val" or "train" metrics dictionary. '
+                    "If you are using an old config, please update checkpointing.metric_name to the new format, "
+                    "e.g. 'val_reward' -> 'val:reward'."
+                )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/utils/checkpoint.py` around lines 109 - 118, Replace the assert-based
config validation in the checkpoint logic with an explicit ValueError: locate
the block referencing self.metric_name (in nemo_rl/utils/checkpoint.py, around
the check in the checkpoint class where metric_name is validated) and change the
assert to an if-check that raises ValueError when metric_name doesn't start with
"train:" or "val:"; update the exception message to be clear and fix the example
typo (e.g. use "val_reward -> 'val:reward'" or "e.g. 'val:reward' for old
'val_reward'") and include guidance about updating checkpointing.metric_name to
the new format.
nemo_rl/data/datasets/preference_datasets/preference_dataset.py (1)

54-60: Same duplicated task-name derivation as response_dataset.py lines 56–59.

The body of lines 54–57 is verbatim copy of response_dataset.py lines 56–59. See the suggested task_name_from_path utility in the response_dataset.py comment above.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/data/datasets/preference_datasets/preference_dataset.py` around lines
54 - 60, The task-name derivation in preference_dataset.py duplicates logic from
response_dataset.py: replace the inline derivation (the default_task_name
computation just before the call to self.common_init) with a call to a shared
utility (e.g., task_name_from_path) or move that logic into a new helper
function and use it from both modules; update preference_dataset.py to compute
default_task_name by calling the shared helper and then call
self.common_init(default_task_name=default_task_name, **kwargs) so both files
use the same centralized function instead of duplicating the code.
nemo_rl/data/datasets/response_datasets/response_dataset.py (1)

56-62: Extract duplicated task-name derivation to a shared utility.

The exact same 4-line block ("-".join(...).split(".")[0] + leading-dash strip) is copy-pasted verbatim into preference_dataset.py (lines 54–57). Extraction to a helper keeps a single definition.

♻️ Suggested utility in nemo_rl/data/datasets/utils.py
+ def task_name_from_path(data_path: str) -> str:
+     """Derive a default task name from a dataset file path."""
+     name = "-".join(data_path.split("/")[-2:]).split(".")[0]
+     return name.lstrip("-")

Then in both response_dataset.py and preference_dataset.py:

- default_task_name = "-".join(data_path.split("/")[-2:]).split(".")[0]
- if default_task_name[0] == "-":
-     default_task_name = default_task_name[1:]
+ default_task_name = task_name_from_path(data_path)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/data/datasets/response_datasets/response_dataset.py` around lines 56
- 62, Extract the duplicated task-name derivation into a shared utility function
(e.g., task_name_from_path) and replace the 4-line block in response_dataset.py
and preference_dataset.py with a call to that helper; specifically, add a
function task_name_from_path(data_path: str) in nemo_rl/data/datasets/utils.py
that returns the derived name with leading dashes stripped, then in the
constructors where you currently compute default_task_name (the block before
calling self.common_init), call default_task_name =
task_name_from_path(data_path) and keep the subsequent
self.common_init(default_task_name=default_task_name, **kwargs).
🤖 Prompt for AI agents (inline review comments)
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@nemo_rl/algorithms/sft.py`:
- Around line 324-362: Two warning calls in the validation loop (the one inside
the per-batch check that warns "No validation metrics were collected for this
batch." and the final one that warns "No validation metrics were collected.")
need a stacklevel so callers see the true call site; update both
warnings.warn(...) invocations in the SFT validation code path (where
val_results, val_data, val_loss, total_val_loss, total_num_valid_tokens are
used) to pass stacklevel=2.
- Around line 272-342: The task-level loss calculation may divide by zero when
task_num_valid_tokens is 0; update the block that sets val_loss[task_name] so it
checks task_num_valid_tokens before dividing (e.g., if task_num_valid_tokens ==
0 assign a safe default like float("nan") or 0.0 and optionally emit a warning)
instead of unconditionally doing task_val_loss / task_num_valid_tokens; change
the assignment around the symbols task_val_loss, task_num_valid_tokens, and
val_loss[task_name] to perform this guard.

In `@nemo_rl/data/datasets/raw_dataset.py`:
- Around line 51-58: Remove the implicit fallback processor="default" in the raw
dataset binding code: if processor is None, do not set a code-level default;
instead raise a clear error (e.g., ValueError or AssertionError) instructing the
caller to supply a processor in configuration/YAML. Keep the existing validation
against PROCESSOR_REGISTRY and assignment to self.processor (use
PROCESSOR_REGISTRY[processor]) so that when a non-None processor is provided
it's still validated and bound; reference symbols: skip_set_processor,
processor, PROCESSOR_REGISTRY, self.processor.
- Around line 21-40: Add a Google-style docstring to common_init describing
parameters and behavior; silence the unused kwargs by renaming **kwargs to
**_kwargs (or explicitly using it) to avoid ARG002; and remove the hardcoded
assignment that forces processor = "default" inside common_init so the processor
remains None unless explicitly provided (let configuration/YAML be the source of
the default). Update references to processor and skip_set_processor inside
common_init to respect a None processor value and only set the processor when
skip_set_processor is False and processor is explicitly provided.

In `@nemo_rl/data/datasets/response_datasets/__init__.py`:
- Around line 66-69: The failure comes from common_init requiring
skip_set_processor with no default, causing direct instantiation of
ResponseDataset-derived classes to raise a TypeError; update the common_init
function signature in raw_dataset.py to give skip_set_processor a default (e.g.,
change common_init(..., skip_set_processor: bool = False, ...)) so constructors
like ResponseDataset, Tulu3SftMixtureDataset, OpenMathInstruct2Dataset no longer
must receive that kwarg, and ensure any internal uses still respect the default;
run the unit tests to confirm the fix.

In `@nemo_rl/data/datasets/response_datasets/geometry3k.py`:
- Around line 69-70: The call to common_init in geometry3k currently omits the
skip_set_processor flag causing a TypeError; update the initializer call in
geometry3k (where self.common_init is invoked) to explicitly pass
skip_set_processor (either forward it from kwargs or supply a default like
False) so common_init receives that parameter (e.g., use
kwargs.get('skip_set_processor', False) or pass kwargs['skip_set_processor']
when present) and avoid the runtime error.

In `@nemo_rl/data/datasets/response_datasets/helpsteer3.py`:
- Around line 30-32: The __init__ currently calls self.common_init(...) without
passing skip_set_processor which common_init expects; update the __init__
signature to accept skip_set_processor: bool = False (preserving prior default
behavior) and forward it into the common_init call
(self.common_init(default_task_name="HelpSteer3",
skip_set_processor=skip_set_processor, **kwargs)) so callers can opt out of
processor setup and avoid the runtime TypeError in HelpSteer3.__init__.

In `@nemo_rl/data/datasets/response_datasets/oai_format_dataset.py`:
- Around line 137-141: Extract the duplicate default task-name derivation into a
single helper named _task_name_from_path(data_path: str) -> str (place it in
nemo_rl/data/datasets/raw_dataset.py or a shared utils module), replace the
duplicated logic in oai_format_dataset.py and binary_preference_dataset.py to
call this helper, and add a guard in the helper that raises a ValueError when
the derived name is empty (e.g., when data_path == "/") so callers (e.g., the
code that previously used default_task_name) must pass an explicit task_name
instead of producing empty metric keys.
- Around line 142-146: The common_init call is missing the required
skip_set_processor argument (declared as a positional parameter in
raw_dataset.py), causing TypeError when these classes are instantiated without
that kwarg; fix by extracting skip_set_processor =
kwargs.pop("skip_set_processor") (or kwargs.get and validate presence) and pass
it explicitly to self.common_init (e.g.,
self.common_init(default_task_name=default_task_name,
skip_set_processor=skip_set_processor, **kwargs)) in OpenAIFormatDataset (and
likewise in OasstDataset, HelpSteer3Dataset, DeepScalerDataset, AIME2024Dataset,
BinaryPreferenceDataset) so the required parameter is supplied.

In `@nemo_rl/data/datasets/response_datasets/refcoco.py`:
- Around line 192-193: The call to common_init(...) in the class (in refcoco.py)
is missing the required skip_set_processor parameter and will raise a TypeError;
update the call to self.common_init(default_task_name="refcoco",
skip_set_processor=kwargs.get("skip_set_processor"), **kwargs) (or explicitly
pass the appropriate boolean value) so the required skip_set_processor argument
is provided when invoking common_init.


ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9148186 and 581cbd9.

📒 Files selected for processing (37)
  • docs/guides/grpo.md
  • docs/guides/sft.md
  • examples/configs/distillation_math.yaml
  • examples/configs/grpo_math_1B.yaml
  • examples/configs/sft.yaml
  • examples/run_sft.py
  • nemo_rl/algorithms/distillation.py
  • nemo_rl/algorithms/dpo.py
  • nemo_rl/algorithms/grpo.py
  • nemo_rl/algorithms/rm.py
  • nemo_rl/algorithms/sft.py
  • nemo_rl/data/datasets/preference_datasets/__init__.py
  • nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py
  • nemo_rl/data/datasets/preference_datasets/helpsteer3.py
  • nemo_rl/data/datasets/preference_datasets/preference_dataset.py
  • nemo_rl/data/datasets/preference_datasets/tulu3.py
  • nemo_rl/data/datasets/raw_dataset.py
  • nemo_rl/data/datasets/response_datasets/__init__.py
  • nemo_rl/data/datasets/response_datasets/aime24.py
  • nemo_rl/data/datasets/response_datasets/clevr.py
  • nemo_rl/data/datasets/response_datasets/dapo_math.py
  • nemo_rl/data/datasets/response_datasets/deepscaler.py
  • nemo_rl/data/datasets/response_datasets/geometry3k.py
  • nemo_rl/data/datasets/response_datasets/helpsteer3.py
  • nemo_rl/data/datasets/response_datasets/nemogym_dataset.py
  • nemo_rl/data/datasets/response_datasets/oai_format_dataset.py
  • nemo_rl/data/datasets/response_datasets/oasst.py
  • nemo_rl/data/datasets/response_datasets/openmathinstruct2.py
  • nemo_rl/data/datasets/response_datasets/refcoco.py
  • nemo_rl/data/datasets/response_datasets/response_dataset.py
  • nemo_rl/data/datasets/response_datasets/squad.py
  • nemo_rl/data/datasets/response_datasets/tulu3.py
  • nemo_rl/data/utils.py
  • nemo_rl/utils/checkpoint.py
  • tests/functional/distillation.sh
  • tests/functional/grpo_multiple_datasets.sh
  • tests/functional/sft.sh

@terrykong terrykong left a comment

mostly lgtm. small comments

if sum_num_valid_tokens > 0:
    val_metrics["val_loss"] /= sum_num_valid_tokens
# Calculate validation metrics
if total_num_valid_tokens > 0:

nit:

Suggested change:

-if total_num_valid_tokens > 0:
+if total_num_valid_tokens > 0:
+    assert "total" not in val_loss, "total is a reserved task_name since it is used in the metrics as the aggregate label"

full_metric_name = master_config["checkpointing"]["metric_name"]
if full_metric_name is not None:
    assert full_metric_name.startswith(

why remove this assert?

@yuki-97 (author) replied:

Not removed, just moved to nemo_rl/utils/checkpoint.py :) so that we can assert it once at initialization and avoid duplicating the check across the different algorithms.
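For reference, a condensed sketch of where that check now lives, shown here with the ValueError variant the nitpick above suggests (the real CheckpointManager.__init__ takes more arguments):

    class CheckpointManager:
        def __init__(self, metric_name: str | None = None) -> None:
            # Validated once here, instead of per-algorithm in dpo.py/rm.py.
            if metric_name is not None and not metric_name.startswith(("train:", "val:")):
                raise ValueError(
                    f"metric_name={metric_name!r} must start with 'train:' or 'val:', "
                    "e.g. old 'val_reward' becomes 'val:reward'."
                )
            self.metric_name = metric_name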

@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 25, 2026
@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 26, 2026
@yuki-97 yuki-97 force-pushed the yukih/validation-task-name branch from 74ca707 to fb99598 February 26, 2026 07:48
@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 26, 2026
@yuki-97 yuki-97 force-pushed the yukih/validation-task-name branch from fb99598 to f2d4720 February 26, 2026 10:22
@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 26, 2026
@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 26, 2026

Labels

  • CI:L1 Run doctests, unit tests, and functional tests
  • documentation Improvements or additions to documentation
