-
Notifications
You must be signed in to change notification settings - Fork 39
Expand file tree
/
Copy pathprompt_dataset.py
More file actions
96 lines (87 loc) · 3.09 KB
/
prompt_dataset.py
File metadata and controls
96 lines (87 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""prompt dataset"""
from typing import List, Dict
from torch.utils.data import Dataset
from transformers import AutoTokenizer
class PromptPipeline(Dataset):
    """Dataset that converts raw RL samples into per-prompt training records.

    Each element of ``data_list`` is expected to look like::

        {
            "data_source": data_source,
            "prompt": [{"role": "user", "content": question}],
            "ability": "math",
            "reward_model": {"style": "rule", "ground_truth": solution},
            "extra_info": {
                "split": split,
                "index": idx,
                "answer": answer_raw,
                "question": question_raw,
            },
        }

    When ``raw_chat`` is False, each stored item has the form
    ``{"input_ids": ..., "prompt": ..., "prompt_token_length": ...,
    "prompt_token_ids": ...}`` plus metadata, and prompts whose encoded
    length reaches ``max_prompt_tokens_length`` are filtered out.
    When ``raw_chat`` is True, the chat messages are stored untokenized
    under the ``"messages"`` key and nothing is filtered.
    """

    def __init__(
        self,
        data_list: List[Dict],
        max_prompt_tokens_length: int,
        tokenizer: "AutoTokenizer" = None,
        enable_thinking=False,
        raw_chat=True,
    ):
        """Build the dataset.

        Args:
            data_list: Raw samples in the format described above.
            max_prompt_tokens_length: Token budget; prompts with this many
                tokens or more are dropped (only when ``raw_chat`` is False).
            tokenizer: Tokenizer used to apply the chat template and encode
                prompts. Required only when ``raw_chat`` is False.
            enable_thinking: Forwarded to ``tokenizer.apply_chat_template``.
            raw_chat: If True, keep chat messages untokenized.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.data = []
        # Longest encoded prompt seen across the raw input, including
        # prompts that were subsequently filtered out.
        self.max_prompt = 0
        for data_item in data_list:
            prompt = data_item["prompt"]
            processed_data = {
                "data_source": data_item.get("data_source", ""),
                "ground_truth": data_item["reward_model"]["ground_truth"],
                "agent_name": data_item.get("agent_name", None),
                "agent_cfg_path": data_item.get("agent_cfg_path", None),
            }
            if not raw_chat:
                if isinstance(prompt, list):
                    # Render the chat messages into a single prompt string.
                    prompt = self.tokenizer.apply_chat_template(
                        prompt,
                        tokenize=False,
                        add_generation_prompt=True,
                        enable_thinking=enable_thinking,
                    )
                input_ids = self.tokenizer.encode(prompt)
                # When partial rollout is enabled:
                # "input_ids" may change (contain response tokens from
                # previous rollouts); "prompt_token_ids" always remains the
                # original prompt tokens.
                processed_data.update({
                    "input_ids": input_ids,
                    "prompt": prompt,
                    "prompt_token_length": len(input_ids),
                    "prompt_token_ids": input_ids,
                })
                if len(input_ids) > self.max_prompt:
                    self.max_prompt = len(input_ids)
                # Filter out prompts that do not fit the token budget.
                if max_prompt_tokens_length > len(input_ids):
                    self.data.append(processed_data)
            else:
                processed_data.update({
                    "messages": prompt
                })
                self.data.append(processed_data)
        # Fraction of raw samples that survived filtering. Guard against an
        # empty data_list, which previously raised ZeroDivisionError; with no
        # input nothing was filtered, so report 1.0.
        self.valid_ratio = len(self.data) / len(data_list) if data_list else 1.0

    def __getitem__(self, ix: int):
        """Return the processed sample at index ``ix``."""
        return self.data[ix]

    def __len__(self) -> int:
        """Number of samples that passed filtering."""
        return len(self.data)

    def collate_fn(self, samples):
        """Identity collate: return the batch of samples unchanged."""
        return samples