Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tests/assets/losses/sft_debugmodel_cuda.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
1 8.296810150146484
2 7.725587844848633
3 6.295645713806152
4 4.756094932556152
5 4.0870537757873535
6 3.6305880546569824
7 3.2472989559173584
8 2.9624862670898438
9 2.7819108963012695
10 2.674215316772461
42 changes: 42 additions & 0 deletions tests/assets/sft_test/data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
[
{
"question": "What is 2 + 3?",
"answer": "2 + 3 = 5. #### 5"
},
{
"question": "If you have 10 apples and give away 4, how many do you have left?",
"answer": "10 - 4 = 6. #### 6"
},
{
"question": "What is 7 * 8?",
"answer": "7 * 8 = 56. #### 56"
},
{
"question": "A store has 25 books. If 12 are sold, how many remain?",
"answer": "25 - 12 = 13. #### 13"
},
{
"question": "What is 100 / 5?",
"answer": "100 / 5 = 20. #### 20"
},
{
"question": "Sam has 3 boxes with 6 toys each. How many toys in total?",
"answer": "3 * 6 = 18. #### 18"
},
{
"question": "What is 15 + 27?",
"answer": "15 + 27 = 42. #### 42"
},
{
"question": "A class has 30 students. If 5 are absent, how many are present?",
"answer": "30 - 5 = 25. #### 25"
},
{
"question": "What is 9 * 9?",
"answer": "9 * 9 = 81. #### 81"
},
{
"question": "If a pizza is cut into 8 slices and you eat 3, how many are left?",
"answer": "8 - 3 = 5. #### 5"
}
]
2 changes: 1 addition & 1 deletion tests/assets/tokenizer/tokenizer_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"input_ids",
"attention_mask"
],
"chat_template": "{% for msg in messages %}<|im_start|>{{ msg.role }}\n{{ msg.content }}<|im_end|>\n{% endfor %}",
"chat_template": "{{ bos_token }}{% for msg in messages %}{{ msg.role }}\n{{ msg.content }}{{ eos_token }}{% endfor %}{% if add_generation_prompt %}assistant\n{% endif %}",
"model_max_length": 131072,
"tokenizer_class": "PreTrainedTokenizerFast"
}
10 changes: 10 additions & 0 deletions tests/integration_tests/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,16 @@ def build_features_test_list() -> list[OverrideDefinitions]:
ngpu=8,
skip_rocm_test=True,
),
OverrideDefinitions(
[
[
"--module llama3 --config sft_debugmodel",
],
],
"SFT ChatDataset integration test",
"sft",
ngpu=2,
),
]

return integration_tests_flavors
Loading
Loading