Commit fcd873d

Authored by Yuki Huang

feat: support multiple datasets for response dataset (#1691)

Signed-off-by: Yuki Huang <yukih@nvidia.com>

1 parent: 7b0b5a5

12 files changed: 266 additions & 60 deletions

docs/guides/grpo.md

Lines changed: 25 additions & 0 deletions

````diff
@@ -68,6 +68,31 @@ data:
   env_name: "math"
 ```
+We support using multiple datasets for train and validation. You can refer to `examples/configs/grpo_multiple_datasets.yaml` for a full configuration example. Here's an example configuration:
+```yaml
+data:
+  _override_: true # override the data config instead of merging with it
+  # other data settings, see `examples/configs/sft.yaml` for more details
+  ...
+  # dataset settings
+  train:
+    # train dataset 1
+    - dataset_name: OpenMathInstruct-2
+      split_validation_size: 0.05 # use 5% of the training data as validation data
+      seed: 42 # seed for train/validation split when split_validation_size > 0
+    # train dataset 2
+    - dataset_name: DeepScaler
+  validation:
+    # validation dataset 1
+    - dataset_name: AIME2024
+      repeat: 16
+    # validation dataset 2
+    - dataset_name: DAPOMathAIME2024
+  # default settings for all datasets
+  default:
+    ...
+```
+
 We support using a single dataset for both train and validation by using `split_validation_size` to set the validation ratio.
 [OpenAssistant](../../nemo_rl/data/datasets/response_datasets/oasst.py), [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py), [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py), [Tulu3SftMixtureDataset](../../nemo_rl/data/datasets/response_datasets/tulu3.py) are supported for this feature.
 If you want to support this feature for your custom datasets or other built-in datasets, you can simply add the code to the dataset like [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py).
````
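The docs above allow `train` and `validation` to be either a single dataset config or a list of them. A minimal sketch of how such a config can be normalized before loading (`normalize_dataset_config` is a hypothetical helper name for illustration, not part of NeMo RL's API):

```python
def normalize_dataset_config(cfg):
    """Accept either a single dataset config (a dict) or a list of them,
    and always return a list. Hypothetical helper, for illustration only."""
    if isinstance(cfg, dict):
        return [cfg]
    return list(cfg)

single = {"dataset_name": "OpenMathInstruct-2", "split_validation_size": 0.05}
multi = [
    {"dataset_name": "OpenMathInstruct-2", "split_validation_size": 0.05},
    {"dataset_name": "DeepScaler"},
]

print(normalize_dataset_config(single) == [single])  # True
print(normalize_dataset_config(multi) == multi)      # True
```

This is the same dict-to-list promotion the changed setup code performs, so downstream loading only ever deals with a list.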

docs/guides/sft.md

Lines changed: 25 additions & 0 deletions

````diff
@@ -100,6 +100,31 @@ data:
   processor: "sft_processor"
 ```
+We support using multiple datasets for train and validation. You can refer to `examples/configs/grpo_multiple_datasets.yaml` for a full configuration example. Here's an example configuration:
+```yaml
+data:
+  _override_: true # override the data config instead of merging with it
+  # other data settings, see `examples/configs/sft.yaml` for more details
+  ...
+  # dataset settings
+  train:
+    # train dataset 1
+    - dataset_name: OpenMathInstruct-2
+      split_validation_size: 0.05 # use 5% of the training data as validation data
+      seed: 42 # seed for train/validation split when split_validation_size > 0
+    # train dataset 2
+    - dataset_name: DeepScaler
+  validation:
+    # validation dataset 1
+    - dataset_name: AIME2024
+      repeat: 16
+    # validation dataset 2
+    - dataset_name: DAPOMathAIME2024
+  # default settings for all datasets
+  default:
+    ...
+```
+
 We support using a single dataset for both train and validation by using `split_validation_size` to set the ratio of validation.
 [OpenAssistant](../../nemo_rl/data/datasets/response_datasets/oasst.py), [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py), [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py), [Tulu3SftMixtureDataset](../../nemo_rl/data/datasets/response_datasets/tulu3.py) are supported for this feature.
 If you want to support this feature for your custom datasets or other built-in datasets, you can simply add the code to the dataset like [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py).
````

examples/configs/grpo_math_1B.yaml

Lines changed: 4 additions & 0 deletions

```diff
@@ -286,6 +286,10 @@ data:
   system_prompt_file: null
   processor: "math_hf_data_processor"
   env_name: "math"
+
+  # You can also use multiple datasets by using a list of datasets.
+  # See `examples/configs/grpo_multiple_datasets.yaml` for a full configuration example.
+
   # You can use custom response datasets for training and validation. For example:
   # train:
   #   # this dataset will override input_key and use the default values for other vars
```
examples/configs/grpo_multiple_datasets.yaml

Lines changed: 28 additions & 0 deletions

```diff
@@ -0,0 +1,28 @@
+# GRPO Algorithm Configuration
+defaults: "grpo_math_1B.yaml"
+
+data:
+  _override_: true # override the data config instead of merging with it
+
+  max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
+  shuffle: true
+  num_workers: 1
+
+  # dataset
+  # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details.
+  train:
+    - dataset_name: OpenMathInstruct-2
+      split_validation_size: 0.05 # use 5% of the training data as validation data
+      seed: ${grpo.seed} # seed for train/validation split when split_validation_size > 0
+    - dataset_name: DeepScaler
+  validation:
+    - dataset_name: AIME2024
+      repeat: 16
+    - dataset_name: DAPOMathAIME2024
+
+  # default settings for all datasets
+  default:
+    prompt_file: "examples/prompts/cot.txt"
+    system_prompt_file: null
+    processor: "math_hf_data_processor"
+    env_name: "math"
```

examples/configs/sft.yaml

Lines changed: 5 additions & 2 deletions

```diff
@@ -194,6 +194,10 @@ data:
   prompt_file: null
   system_prompt_file: null
   processor: "sft_processor"
+
+  # You can also use multiple datasets by using a list of datasets.
+  # See `examples/configs/grpo_multiple_datasets.yaml` for a full configuration example.
+
   # You can use custom response datasets for training and validation. For example:
   # train:
   #   # this dataset will override input_key and use the default values for other vars
@@ -212,8 +216,7 @@ data:
   #   processor: "sft_processor"
   # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details.
 
-
-## OpenAI format specific configs
+  # OpenAI format specific configs
   # train_data_path: "/path/to/train.jsonl" # Path to training data
   # val_data_path: "/path/to/val.jsonl" # Path to validation data
   # chat_key: "messages" # Key for messages in the data
```

examples/run_sft.py

Lines changed: 47 additions & 29 deletions

```diff
@@ -64,19 +64,30 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
 
     print("\n▶ Setting up data...")
     # setup train dataset
-    if "default" in data_config:
-        update_single_dataset_config(data_config["train"], data_config["default"])
-    data = load_response_dataset(data_config["train"])
-    data_processor = partial(
-        data.processor,
-        add_bos=data_config["add_bos"],
-        add_eos=data_config["add_eos"],
-        add_generation_prompt=data_config["add_generation_prompt"],
-    )
-    task_data_processors = {data.task_name: (data.task_spec, data_processor)}
+    task_data_processors = {}
+    data_list = []
+
+    if isinstance(data_config["train"], dict):
+        data_config["train"] = [data_config["train"]]
+
+    for cfg in data_config["train"]:
+        # load dataset
+        if "default" in data_config and data_config["default"] is not None:
+            update_single_dataset_config(cfg, data_config["default"])
+        data = load_response_dataset(cfg)
+        data_list.append(data)
+        # bind task_name to task_data_processors
+        data_processor = partial(
+            data.processor,
+            add_bos=data_config["add_bos"],
+            add_eos=data_config["add_eos"],
+            add_generation_prompt=data_config["add_generation_prompt"],
+        )
+        task_data_processors[data.task_name] = (data.task_spec, data_processor)
 
+    merged_data = concatenate_datasets([data.dataset for data in data_list])
     dataset = AllTaskProcessedDataset(
-        data.dataset,
+        merged_data,
         tokenizer,
         None,
         task_data_processors,
@@ -89,28 +100,35 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
     val_data_list = []
 
     # validation dataset from train dataset (when train dataset's split_validation_size > 0)
-    if hasattr(data, "val_dataset") and data.val_dataset is not None:
-        val_data_list.append(data.val_dataset)
-    val_task_data_processors = task_data_processors.copy()
+    for data in data_list:
+        if hasattr(data, "val_dataset") and data.val_dataset is not None:
+            val_data_list.append(data.val_dataset)
+            # bind task_name to task_data_processors
+            task_name = data.task_name
+            val_task_data_processors[task_name] = task_data_processors[task_name]
 
     # validation dataset from config
     if "validation" in data_config and data_config["validation"] is not None:
-        if "default" in data_config:
-            update_single_dataset_config(
-                data_config["validation"], data_config["default"]
+        if isinstance(data_config["validation"], dict):
+            data_config["validation"] = [data_config["validation"]]
+
+        for cfg in data_config["validation"]:
+            # load dataset
+            if "default" in data_config and data_config["default"] is not None:
+                update_single_dataset_config(cfg, data_config["default"])
+            val_data = load_response_dataset(cfg)
+            val_data_list.append(val_data.dataset)
+            # bind task_name to task_data_processors
+            val_data_processor = partial(
+                val_data.processor,
+                add_bos=data_config["add_bos"],
+                add_eos=data_config["add_eos"],
+                add_generation_prompt=data_config["add_generation_prompt"],
+            )
+            val_task_data_processors[val_data.task_name] = (
+                val_data.task_spec,
+                val_data_processor,
             )
-        val_data = load_response_dataset(data_config["validation"])
-        val_data_list.append(val_data.dataset)
-        val_data_processor = partial(
-            val_data.processor,
-            add_bos=data_config["add_bos"],
-            add_eos=data_config["add_eos"],
-            add_generation_prompt=data_config["add_generation_prompt"],
-        )
-        val_task_data_processors[val_data.task_name] = (
-            val_data.task_spec,
-            val_data_processor,
-        )
 
     val_dataset = None
     if len(val_data_list) > 0:
```
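The refactored loop in `setup_data` follows a simple pattern: load each dataset, register a processor under its task name, then concatenate the rows. A toy sketch of that pattern with stand-in objects (`FakeDataset` and the local `concatenate` helper are illustrative, not NeMo RL or Hugging Face APIs, and the registry here stores just the processor rather than a `(task_spec, processor)` pair):

```python
from functools import partial

class FakeDataset:
    """Stand-in for a loaded response dataset."""
    def __init__(self, task_name, rows):
        self.task_name = task_name
        self.dataset = rows

    def processor(self, row, add_bos):
        # toy processor: tag each row with its task and bos flag
        return (self.task_name, row, add_bos)

def concatenate(datasets):
    """Local helper standing in for concatenate_datasets: merge row lists."""
    merged = []
    for rows in datasets:
        merged.extend(rows)
    return merged

data_list = [FakeDataset("math", [1, 2]), FakeDataset("code", [3])]

# one processor entry per task name, bound with shared config, as in the diff above
task_data_processors = {}
for data in data_list:
    task_data_processors[data.task_name] = partial(data.processor, add_bos=True)

merged = concatenate([d.dataset for d in data_list])
print(merged)                           # [1, 2, 3]
print(task_data_processors["math"](5))  # ('math', 5, True)
```

Keying processors by task name is what lets `AllTaskProcessedDataset` route each merged row back to the right per-dataset processor.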

nemo_rl/data/__init__.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -49,8 +49,8 @@ class DataConfig(TypedDict):
     num_workers: NotRequired[int]
     # dataset configs
     # TODO: remove NotRequired once preference dataset is refactored
-    train: NotRequired[ResponseDatasetConfig]
-    validation: NotRequired[ResponseDatasetConfig | None]
+    train: NotRequired[ResponseDatasetConfig | list[ResponseDatasetConfig]]
+    validation: NotRequired[ResponseDatasetConfig | list[ResponseDatasetConfig] | None]
     default: NotRequired[ResponseDatasetConfig | None]
     # TODO: remove once preference dataset is refactored
     dataset_name: NotRequired[str]
```
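With this widened union, a `DataConfig` value type-checks (and works at runtime, since TypedDicts are plain dicts) with either shape. A stripped-down illustration; the `ResponseDatasetConfig` below is a simplified stand-in with only two fields, not the real class:

```python
from typing import TypedDict

try:
    from typing import NotRequired  # Python 3.11+
except ImportError:
    from typing_extensions import NotRequired

class ResponseDatasetConfig(TypedDict, total=False):
    # simplified stand-in for nemo_rl's ResponseDatasetConfig
    dataset_name: str
    repeat: int

class DataConfig(TypedDict):
    train: NotRequired["ResponseDatasetConfig | list[ResponseDatasetConfig]"]
    validation: NotRequired["ResponseDatasetConfig | list[ResponseDatasetConfig] | None"]

# single-dataset shape (backward compatible)
single: DataConfig = {"train": {"dataset_name": "OpenMathInstruct-2"}}

# multi-dataset shape (new)
multi: DataConfig = {
    "train": [
        {"dataset_name": "OpenMathInstruct-2"},
        {"dataset_name": "DeepScaler"},
    ],
    "validation": [{"dataset_name": "AIME2024", "repeat": 16}],
}
```

Keeping the plain `ResponseDatasetConfig` in the union preserves existing single-dataset configs unchanged.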

nemo_rl/data/utils.py

Lines changed: 42 additions & 23 deletions

```diff
@@ -68,14 +68,27 @@ def setup_data_with_envs(
 
     print("\n▶ Setting up data...")
     # setup train dataset
-    if "default" in data_config:
-        update_single_dataset_config(data_config["train"], data_config["default"])
-    data = load_response_dataset(data_config["train"])
-    task_data_processors = {data.task_name: (data.task_spec, data.processor)}
-    task_to_env = {data.task_name: envs[data_config["train"]["env_name"]]}
-
+    task_data_processors = {}
+    task_to_env = {}
+    data_list = []
+
+    if isinstance(data_config["train"], dict):
+        data_config["train"] = [data_config["train"]]
+
+    for cfg in data_config["train"]:
+        # load dataset
+        if "default" in data_config and data_config["default"] is not None:
+            update_single_dataset_config(cfg, data_config["default"])
+        data = load_response_dataset(cfg)
+        data_list.append(data)
+        # bind task_name to task_data_processors and task_to_env
+        task_name = data.task_name
+        task_data_processors[task_name] = (data.task_spec, data.processor)
+        task_to_env[task_name] = envs[cfg["env_name"]]
+
+    merged_data = concatenate_datasets([data.dataset for data in data_list])
     dataset = AllTaskProcessedDataset(
-        data.dataset,
+        merged_data,
         tokenizer,
         None,
         task_data_processors,
@@ -89,26 +102,32 @@ def setup_data_with_envs(
     val_data_list = []
 
     # validation dataset from train dataset (when train dataset's split_validation_size > 0)
-    if hasattr(data, "val_dataset") and data.val_dataset is not None:
-        val_data_list.append(data.val_dataset)
-    val_task_data_processors = task_data_processors.copy()
-    val_task_to_env = task_to_env.copy()
+    for data in data_list:
+        if hasattr(data, "val_dataset") and data.val_dataset is not None:
+            val_data_list.append(data.val_dataset)
+            # bind task_name to task_data_processors and task_to_env
+            task_name = data.task_name
+            val_task_data_processors[task_name] = task_data_processors[task_name]
+            val_task_to_env[task_name] = task_to_env[task_name]
 
     # validation dataset from config
     if "validation" in data_config and data_config["validation"] is not None:
-        if "default" in data_config:
-            update_single_dataset_config(
-                data_config["validation"], data_config["default"]
+        if isinstance(data_config["validation"], dict):
+            data_config["validation"] = [data_config["validation"]]
+
+        for cfg in data_config["validation"]:
+            # load dataset
+            if "default" in data_config and data_config["default"] is not None:
+                update_single_dataset_config(cfg, data_config["default"])
+            val_data = load_response_dataset(cfg)
+            val_data_list.append(val_data.dataset)
+            # bind task_name to task_data_processors and task_to_env
+            task_name = val_data.task_name
+            val_task_data_processors[task_name] = (
+                val_data.task_spec,
+                val_data.processor,
             )
-        val_data = load_response_dataset(data_config["validation"])
-        val_data_list.append(val_data.dataset)
-        val_task_data_processors[val_data.task_name] = (
-            val_data.task_spec,
-            val_data.processor,
-        )
-        val_task_to_env[val_data.task_name] = envs[
-            data_config["validation"]["env_name"]
-        ]
+            val_task_to_env[task_name] = envs[cfg["env_name"]]
 
     val_dataset = None
     if len(val_data_list) > 0:
```

nemo_rl/utils/config.py

Lines changed: 21 additions & 2 deletions

```diff
@@ -27,6 +27,23 @@ def resolve_path(base_path: Path, path: str) -> Path:
     return base_path / path
 
 
+def merge_with_override(
+    base_config: DictConfig, override_config: DictConfig
+) -> DictConfig:
+    """Merge configs with support for _override_ marker to completely override sections."""
+    for key in list(override_config.keys()):
+        if isinstance(override_config[key], DictConfig):
+            if override_config[key].get("_override_", False):
+                # remove the _override_ marker
+                override_config[key].pop("_override_")
+                # remove the key from base_config so it won't be merged
+                if key in base_config:
+                    base_config.pop(key)
+
+    merged_config = cast(DictConfig, OmegaConf.merge(base_config, override_config))
+    return merged_config
+
+
 def load_config_with_inheritance(
     config_path: Union[str, Path],
     base_dir: Optional[Union[str, Path]] = None,
@@ -63,10 +80,12 @@ def load_config_with_inheritance(
     for default in defaults:
         parent_path = resolve_path(base_dir, str(default))
         parent_config = load_config_with_inheritance(parent_path, base_dir)
-        base_config = cast(DictConfig, OmegaConf.merge(base_config, parent_config))
+        base_config = cast(
+            DictConfig, merge_with_override(base_config, parent_config)
+        )
 
     # Merge with current config
-    config = cast(DictConfig, OmegaConf.merge(base_config, config))
+    config = cast(DictConfig, merge_with_override(base_config, config))
 
     return config
```
tests/functional/L1_Functional_Tests_GPU.sh

Lines changed: 1 addition & 0 deletions

```diff
@@ -35,6 +35,7 @@ time uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh
 time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh
 time uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh
 time uv run --no-sync bash ./tests/functional/grpo_sglang.sh
+time uv run --no-sync bash ./tests/functional/grpo_multiple_datasets.sh
 time uv run --no-sync bash ./tests/functional/dpo.sh
 time uv run --no-sync bash ./tests/functional/rm.sh
 time uv run --no-sync bash ./tests/functional/eval.sh
```
