15 changes: 15 additions & 0 deletions docs/guides/grpo.md
@@ -122,6 +122,21 @@ data:
...
```

To use a specific validation dataset's accuracy for checkpointing, set `task_name` in the validation dataset config to override the default, and set `checkpointing.metric_name` to `val:accuracy_<TaskName>`. Checkpoints will then be saved based on that dataset's accuracy.
```yaml
data:
  ...
  validation:
    - dataset_name: ResponseDataset
      task_name: "dataset1"
    - dataset_name: ResponseDataset
  ...
checkpointing:
  metric_name: "val:accuracy_dataset1" # save checkpoints based on the first dataset's accuracy
  # metric_name: "val:accuracy" # save checkpoints based on the accuracy across all validation datasets
  ...
```

A single dataset can be used for both training and validation: set `split_validation_size` to the fraction of samples to hold out for validation.
[OpenAssistant](../../nemo_rl/data/datasets/response_datasets/oasst.py), [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py), [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py), and [Tulu3SftMixtureDataset](../../nemo_rl/data/datasets/response_datasets/tulu3.py) support this feature.
To enable it for a custom or other built-in dataset, follow the implementation in [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py).
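For example, a single-dataset config might look like the following sketch (the keys surrounding the dataset entry are illustrative and may differ from your setup; `split_validation_size` is the relevant field):

```yaml
data:
  ...
  train:
    - dataset_name: ResponseDataset
      # hold out 10% of the training samples as the validation set
      split_validation_size: 0.1
  # no explicit `validation` entry is needed; the held-out split is used instead
```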
15 changes: 15 additions & 0 deletions docs/guides/sft.md
@@ -138,6 +138,21 @@ data:
...
```

To use a specific validation dataset's val_loss for checkpointing, set `task_name` in the validation dataset config to override the default, and set `checkpointing.metric_name` to `val:val_loss_<TaskName>`. Checkpoints will then be saved based on that dataset's val_loss.
```yaml
data:
  ...
  validation:
    - dataset_name: ResponseDataset
      task_name: "dataset1"
    - dataset_name: ResponseDataset
  ...
checkpointing:
  metric_name: "val:val_loss_dataset1" # save checkpoints based on the first dataset's val_loss
  # metric_name: "val:val_loss" # save checkpoints based on the val_loss across all validation datasets
  ...
```

A single dataset can be used for both training and validation: set `split_validation_size` to the fraction of samples to hold out for validation.
[OpenAssistant](../../nemo_rl/data/datasets/response_datasets/oasst.py), [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py), [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py), and [Tulu3SftMixtureDataset](../../nemo_rl/data/datasets/response_datasets/tulu3.py) support this feature.
To enable it for a custom or other built-in dataset, follow the implementation in [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py).
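As a sketch (surrounding keys are illustrative; `split_validation_size` and `task_name` are the relevant fields), the held-out split can then be tracked for checkpointing by its task name:

```yaml
data:
  ...
  train:
    - dataset_name: ResponseDataset
      task_name: "my_sft_data"      # illustrative name
      split_validation_size: 0.05   # hold out 5% of samples for validation
checkpointing:
  metric_name: "val:val_loss_my_sft_data"
```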
4 changes: 3 additions & 1 deletion examples/configs/distillation_math.yaml
@@ -21,7 +21,9 @@ loss_fn:
checkpointing:
  enabled: true
  checkpoint_dir: "checkpoints/distillation-${policy.model_name}"
  metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name
  # one of "val:" or "train:" followed by the metric name
  # to track a specific dataset's accuracy, use the format "val:accuracy_<TaskName>"; `task_name` is set in the dataset config
  metric_name: "val:accuracy"
  higher_is_better: true
  keep_top_k: 3
  save_period: 10
17 changes: 16 additions & 1 deletion examples/configs/grpo_math_1B.yaml
@@ -72,7 +72,9 @@ loss_fn:
checkpointing:
  enabled: true
  checkpoint_dir: "results/grpo"
  metric_name: "val:accuracy" # one of "val:" or "train:" followed by the metric name
  # one of "val:" or "train:" followed by the metric name
  # to track a specific dataset's accuracy, use the format "val:accuracy_<TaskName>"; `task_name` is set in the dataset config
  metric_name: "val:accuracy"
  higher_is_better: true
  keep_top_k: 3
  save_period: 10
@@ -327,6 +329,19 @@ data:
# env_name: math
# See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#datasets for more details.

# If checkpointing is enabled, `metric_name` should reflect the metric and validation set to track. For example:
# data:
#   ...
#   validation:
#     - dataset_name: ResponseDataset
#       task_name: "dataset1"
#     - dataset_name: ResponseDataset
#   ...
# checkpointing:
#   metric_name: "val:accuracy_dataset1" # save checkpoints based on the first dataset's accuracy
#   # metric_name: "val:accuracy" # save checkpoints based on the accuracy across all validation datasets
#   ...

env:
  math:
    num_workers: 8
17 changes: 16 additions & 1 deletion examples/configs/sft.yaml
@@ -16,7 +16,9 @@ sft:
checkpointing:
  enabled: true
  checkpoint_dir: "results/sft"
  metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name
  # one of "val:" or "train:" followed by the metric name
  # to track a specific dataset's val_loss, use the format "val:val_loss_<TaskName>"; `task_name` is set in the dataset config
  metric_name: "val:val_loss"
  higher_is_better: false
  keep_top_k: 3
  save_period: 10
@@ -232,6 +234,19 @@ data:
# tool_key: "tools" # Key for tools in the data
# use_preserving_dataset: false # If true, uses PreservingDataset to preserve heterogeneous schemas (e.g., tool calls with varying argument structures)

# If checkpointing is enabled, `metric_name` should reflect the metric and validation set to track. For example:
# data:
#   ...
#   validation:
#     - dataset_name: ResponseDataset
#       task_name: "dataset1"
#     - dataset_name: ResponseDataset
#   ...
# checkpointing:
#   metric_name: "val:val_loss_dataset1" # save checkpoints based on the first dataset's val_loss
#   # metric_name: "val:val_loss" # save checkpoints based on the val_loss across all validation datasets
#   ...

logger:
  log_dir: "logs" # Base directory for all logs
  wandb_enabled: true # Make sure you run `wandb login [Your API key]` before running
9 changes: 8 additions & 1 deletion examples/run_grpo_sliding_puzzle.py
@@ -33,6 +33,7 @@
    SlidingPuzzleGameLogic,
    SlidingPuzzleMetadata,
)
from nemo_rl.environments.interfaces import EnvironmentInterface
from nemo_rl.models.generation import configure_generation_config
from nemo_rl.utils.config import (
    load_config,
@@ -157,7 +158,12 @@ def setup_puzzle_data(
    length: int,
    val_length: int,
    add_system_prompt: bool,
) -> tuple[IterableDataset, IterableDataset | None, dict, dict]:
) -> tuple[
    IterableDataset,
    dict[str, IterableDataset],
    dict[str, EnvironmentInterface],
    dict[str, EnvironmentInterface],
]:
    """Sets up the iterable data generator and env map for the sliding puzzle task."""
    print("Setting up Sliding Puzzle iterable data and environment...")
    env_config = env_cfg[task_name]
@@ -186,6 +192,7 @@ def setup_puzzle_data(
        add_system_prompt=add_system_prompt,
        length=val_length,
    )
    validation_dataset = {"sliding_puzzle_game": validation_dataset}
    val_task_to_env = task_to_env

    return training_dataset, validation_dataset, task_to_env, val_task_to_env
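The change above wraps the single validation dataset in a dict keyed by task name, which is what lets checkpointing metrics be tracked per task. A minimal sketch of why this shape matters (plain lists stand in for real dataset objects, and `per_task_metric_names` is an illustrative helper, not part of the library):

```python
# Placeholder validation datasets keyed by task name; real code uses
# IterableDataset objects, plain lists stand in here.
validation_datasets = {"sliding_puzzle_game": ["sample_a", "sample_b"]}


def per_task_metric_names(val_datasets: dict, base: str = "val:accuracy") -> dict:
    """Derive a per-task checkpointing metric name from each task key.

    The "<base>_<task_name>" pattern mirrors the docs above; the helper
    itself is hypothetical.
    """
    return {task: f"{base}_{task}" for task in val_datasets}


print(per_task_metric_names(validation_datasets))
# → {'sliding_puzzle_game': 'val:accuracy_sliding_puzzle_game'}
```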
31 changes: 17 additions & 14 deletions examples/run_sft.py
@@ -98,12 +98,12 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):

    # setup validation dataset
    val_task_data_processors = {}
    val_data_list = []
    val_data_dict = {}

    # validation dataset from the train dataset (when the train dataset's split_validation_size > 0)
    for data in data_list:
        if hasattr(data, "val_dataset") and data.val_dataset is not None:
            val_data_list.append(data.val_dataset)
            val_data_dict[data.task_name] = data.val_dataset
            # bind task_name to task_data_processors
            task_name = data.task_name
            val_task_data_processors[task_name] = task_data_processors[task_name]
@@ -118,7 +118,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
        if "default" in data_config and data_config["default"] is not None:
            update_single_dataset_config(cfg, data_config["default"])
        val_data = load_response_dataset(cfg)
        val_data_list.append(val_data.dataset)
        val_data_dict[val_data.task_name] = val_data.dataset
        # bind task_name to task_data_processors
        val_data_processor = partial(
            val_data.processor,
@@ -131,17 +131,20 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
            val_data_processor,
        )

    val_dataset = None
    if len(val_data_list) > 0:
        merged_val_data = concatenate_datasets(val_data_list)
        val_dataset = AllTaskProcessedDataset(
            merged_val_data,
            tokenizer,
            None,
            val_task_data_processors,
            max_seq_length=data_config["max_input_seq_length"],
        )
        print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.")
    val_dataset = {}
    if len(val_data_dict) > 0:
        val_dataset = {
            task_name: AllTaskProcessedDataset(
                val_data,
                tokenizer,
                None,
                val_task_data_processors,
                max_seq_length=data_config["max_input_seq_length"],
            )
            for task_name, val_data in val_data_dict.items()
        }
        val_sample_count = sum(len(val_data) for val_data in val_data_dict.values())
        print(f" ✓ Validation dataset loaded with {val_sample_count} samples.")

    return dataset, val_dataset
