diff --git a/tests/assets/losses/sft_debugmodel_cuda.txt b/tests/assets/losses/sft_debugmodel_cuda.txt new file mode 100644 index 0000000000..857381502a --- /dev/null +++ b/tests/assets/losses/sft_debugmodel_cuda.txt @@ -0,0 +1,10 @@ +1 8.296810150146484 +2 7.725587844848633 +3 6.295645713806152 +4 4.756094932556152 +5 4.0870537757873535 +6 3.6305880546569824 +7 3.2472989559173584 +8 2.9624862670898438 +9 2.7819108963012695 +10 2.674215316772461 diff --git a/tests/assets/sft_test/data.json b/tests/assets/sft_test/data.json new file mode 100644 index 0000000000..0ef7c9b481 --- /dev/null +++ b/tests/assets/sft_test/data.json @@ -0,0 +1,42 @@ +[ + { + "question": "What is 2 + 3?", + "answer": "2 + 3 = 5. #### 5" + }, + { + "question": "If you have 10 apples and give away 4, how many do you have left?", + "answer": "10 - 4 = 6. #### 6" + }, + { + "question": "What is 7 * 8?", + "answer": "7 * 8 = 56. #### 56" + }, + { + "question": "A store has 25 books. If 12 are sold, how many remain?", + "answer": "25 - 12 = 13. #### 13" + }, + { + "question": "What is 100 / 5?", + "answer": "100 / 5 = 20. #### 20" + }, + { + "question": "Sam has 3 boxes with 6 toys each. How many toys in total?", + "answer": "3 * 6 = 18. #### 18" + }, + { + "question": "What is 15 + 27?", + "answer": "15 + 27 = 42. #### 42" + }, + { + "question": "A class has 30 students. If 5 are absent, how many are present?", + "answer": "30 - 5 = 25. #### 25" + }, + { + "question": "What is 9 * 9?", + "answer": "9 * 9 = 81. #### 81" + }, + { + "question": "If a pizza is cut into 8 slices and you eat 3, how many are left?", + "answer": "8 - 3 = 5. #### 5" + } +] diff --git a/tests/assets/tokenizer/tokenizer_config.json b/tests/assets/tokenizer/tokenizer_config.json index fecad7728a..306ffd18b3 100644 --- a/tests/assets/tokenizer/tokenizer_config.json +++ b/tests/assets/tokenizer/tokenizer_config.json @@ -60,7 +60,7 @@ "input_ids", "attention_mask" ], - "chat_template": "{% for msg in messages %}<|im_start|>{{ msg.role }}\n{{ msg.content }}<|im_end|>\n{% endfor %}", + "chat_template": "{{ bos_token }}{% for msg in messages %}{{ msg.role }}\n{{ msg.content }}{{ eos_token }}{% endfor %}{% if add_generation_prompt %}assistant\n{% endif %}", "model_max_length": 131072, "tokenizer_class": "PreTrainedTokenizerFast" } diff --git a/tests/integration_tests/features.py b/tests/integration_tests/features.py index 0684ff0770..eb8cb2ae7b 100755 --- a/tests/integration_tests/features.py +++ b/tests/integration_tests/features.py @@ -597,6 +597,16 @@ def build_features_test_list() -> list[OverrideDefinitions]: ngpu=8, skip_rocm_test=True, ), + OverrideDefinitions( + [ + [ + "--module llama3 --config sft_debugmodel", + ], + ], + "SFT ChatDataset integration test", + "sft", + ngpu=2, + ), ] return integration_tests_flavors diff --git a/tests/unit_tests/test_chat_dataset.py b/tests/unit_tests/test_chat_dataset.py new file mode 100644 index 0000000000..a315866635 --- /dev/null +++ b/tests/unit_tests/test_chat_dataset.py @@ -0,0 +1,370 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import unittest + +from datasets import Dataset + +from torchtitan.components.loss import IGNORE_INDEX +from torchtitan.components.tokenizer import HuggingFaceTokenizer +from torchtitan.hf_datasets.text_datasets import ChatDataset + +# Path to the test tokenizer and fixture data +_ASSETS_DIR = os.path.join(os.path.dirname(__file__), "..", "assets") +_TOKENIZER_PATH = os.path.join(_ASSETS_DIR, "tokenizer") +_DATA_PATH = os.path.join(_ASSETS_DIR, "sft_test", "data.json") + + +def _process_sample(sample): + """Convert a test data sample into [user, assistant] messages.""" + return [ + {"role": "user", "content": sample["question"]}, + {"role": "assistant", "content": sample["answer"]}, + ] + + +def _load_tokenizer(): + return HuggingFaceTokenizer(tokenizer_path=_TOKENIZER_PATH) + + +def _load_dataset(): + return Dataset.from_json(_DATA_PATH) + + +class TestChatDatasetLabelMasking(unittest.TestCase): + """Prompt tokens should be masked (IGNORE_INDEX), assistant tokens should not.""" + + def test_prompt_masked_response_unmasked(self): + tokenizer = _load_tokenizer() + ds = _load_dataset() + chat_ds = ChatDataset( + dataset=ds, + tokenizer=tokenizer, + sample_processor=_process_sample, + seq_len=2048, + infinite=False, + ) + + batch, labels = next(iter(chat_ds)) + input_ids = batch["input"] + label_ids = labels + + self.assertEqual(input_ids.shape, label_ids.shape) + self.assertEqual(input_ids.shape[0], 2048) + + # Some labels at the start should be IGNORE_INDEX (prompt masking) + masked = (label_ids == IGNORE_INDEX).nonzero(as_tuple=True)[0] + unmasked = (label_ids != IGNORE_INDEX).nonzero(as_tuple=True)[0] + self.assertGreater(len(masked), 0, "Expected some masked prompt labels") + self.assertGreater(len(unmasked), 0, "Expected some unmasked response labels") + + # All masked positions should precede all unmasked non-padding positions. + # The unmasked region is the response, then padding follows with IGNORE_INDEX. + # Find first unmasked position and last contiguous unmasked position. + first_unmasked = unmasked[0].item() + self.assertGreater(first_unmasked, 0, "First token label should be masked") + + +class TestChatDatasetShiftedTokens(unittest.TestCase): + """input_ids = tokens[:-1], label_ids = tokens[1:].""" + + def test_shifted_by_one(self): + tokenizer = _load_tokenizer() + ds = _load_dataset() + chat_ds = ChatDataset( + dataset=ds, + tokenizer=tokenizer, + sample_processor=_process_sample, + seq_len=2048, + infinite=False, + ) + + batch, labels = next(iter(chat_ds)) + input_ids = batch["input"] + label_ids = labels + + # Tokenize the first sample directly to get ground truth tokens + sample = ds[0] + messages = _process_sample(sample) + full_text = tokenizer.apply_chat_template(messages) + # Chat templates already include end tokens, so no add_eos + full_tokens = tokenizer.encode(full_text, add_bos=True, add_eos=False) + + expected_input = full_tokens[:-1] + expected_label = full_tokens[1:] + + # The non-padded portion of input_ids should match expected_input + seq_len_actual = len(expected_input) + self.assertEqual( + input_ids[:seq_len_actual].tolist(), + expected_input, + ) + # The non-masked, non-padded portion of label_ids that corresponds to + # the response should come from full_tokens[1:] + # Just verify the response portion matches + prompt_text = tokenizer.apply_chat_template( + messages[:1], add_generation_prompt=True + ) + prompt_tokens = tokenizer.encode(prompt_text, add_bos=True, add_eos=False) + response_start = len(prompt_tokens) # labels [0, prompt_len) are masked + self.assertEqual( + label_ids[response_start:seq_len_actual].tolist(), + expected_label[response_start:], + ) + + +class TestChatDatasetGreedyPacking(unittest.TestCase): + """Multiple short samples packed into one sequence with small seq_len.""" + + def test_packing_multiple_samples(self): + tokenizer = _load_tokenizer() + ds = _load_dataset() + # seq_len=256 should fit multiple of the shortest samples (effective_len ~79) + seq_len = 256 + chat_ds = ChatDataset( + dataset=ds, + tokenizer=tokenizer, + sample_processor=_process_sample, + seq_len=seq_len, + infinite=False, + ) + + batches = list(chat_ds) + # With 10 samples of lengths 79-123, they should pack into fewer than 10 batches + self.assertGreater(len(batches), 0) + self.assertLess(len(batches), 10) + + # Each batch should have seq_len tokens + for batch, labels in batches: + self.assertEqual(batch["input"].shape[0], seq_len) + self.assertEqual(labels.shape[0], seq_len) + self.assertIn("positions", batch) + self.assertEqual(batch["positions"].shape[0], seq_len) + + +class TestChatDatasetPerDocumentPositions(unittest.TestCase): + """Positions reset to 0 at each document boundary in packed mode.""" + + def test_positions_reset_at_boundaries(self): + tokenizer = _load_tokenizer() + ds = _load_dataset() + seq_len = 256 + chat_ds = ChatDataset( + dataset=ds, + tokenizer=tokenizer, + sample_processor=_process_sample, + seq_len=seq_len, + infinite=False, + ) + + batch, _ = next(iter(chat_ds)) + positions = batch["positions"] + + # Positions should start at 0 + self.assertEqual(positions[0].item(), 0) + + # Find where positions reset to 0 (document boundaries) + resets = (positions[1:] == 0).nonzero(as_tuple=True)[0] + # With seq_len=256 and samples of ~79 tokens, at least one reset + self.assertGreater( + len(resets), 0, "Expected at least one position reset (document boundary)" + ) + + # Between resets, positions should be consecutive (0, 1, 2, ...) + pos_list = positions.tolist() + for i in range(1, len(pos_list)): + if pos_list[i] == 0: + # Document boundary: reset is fine + continue + self.assertEqual( + pos_list[i], + pos_list[i - 1] + 1, + f"Positions should be consecutive at index {i}, " + f"got {pos_list[i - 1]} -> {pos_list[i]}", + ) + + +class TestChatDatasetDropOnOverflow(unittest.TestCase): + """Samples exceeding seq_len are silently dropped.""" + + def test_all_dropped_with_tiny_seq_len(self): + tokenizer = _load_tokenizer() + ds = _load_dataset() + chat_ds = ChatDataset( + dataset=ds, + tokenizer=tokenizer, + sample_processor=_process_sample, + seq_len=32, + infinite=False, + ) + + batches = list(chat_ds) + self.assertEqual(len(batches), 0, "All samples should be dropped at seq_len=32") + + +class TestChatDatasetMessageValidation(unittest.TestCase): + """Non-[user, assistant] messages raise ValueError.""" + + def test_wrong_first_role(self): + tokenizer = _load_tokenizer() + + def bad_processor(sample): + return [ + {"role": "system", "content": "You are helpful."}, + {"role": "assistant", "content": "OK"}, + ] + + ds = Dataset.from_list([{"question": "hi", "answer": "bye"}]) + chat_ds = ChatDataset( + dataset=ds, + tokenizer=tokenizer, + sample_processor=bad_processor, + seq_len=2048, + infinite=False, + ) + + with self.assertRaises(ValueError, msg="system role should raise"): + next(iter(chat_ds)) + + def test_wrong_second_role(self): + tokenizer = _load_tokenizer() + + def bad_processor(sample): + return [ + {"role": "user", "content": "hi"}, + {"role": "user", "content": "hello again"}, + ] + + ds = Dataset.from_list([{"question": "hi", "answer": "bye"}]) + chat_ds = ChatDataset( + dataset=ds, + tokenizer=tokenizer, + sample_processor=bad_processor, + seq_len=2048, + infinite=False, + ) + + with self.assertRaises(ValueError, msg="two user messages should raise"): + next(iter(chat_ds)) + + def test_three_messages(self): + tokenizer = _load_tokenizer() + + def bad_processor(sample): + return [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + {"role": "user", "content": "bye"}, + ] + + ds = Dataset.from_list([{"question": "hi", "answer": "bye"}]) + chat_ds = ChatDataset( + dataset=ds, + tokenizer=tokenizer, + sample_processor=bad_processor, + seq_len=2048, + infinite=False, + ) + + with self.assertRaises(ValueError, msg="3 messages should raise"): + next(iter(chat_ds)) + + +class TestChatDatasetCheckpointing(unittest.TestCase): + """state_dict / load_state_dict round-trips correctly.""" + + def test_state_dict_round_trip(self): + tokenizer = _load_tokenizer() + ds = _load_dataset() + seq_len = 128 + chat_ds = ChatDataset( + dataset=ds, + tokenizer=tokenizer, + sample_processor=_process_sample, + seq_len=seq_len, + infinite=False, + ) + + # Consume one packed batch + it = iter(chat_ds) + next(it) + + state = chat_ds.state_dict() + + # Verify state has expected keys + self.assertIn("sample_idx", state) + self.assertIn("epoch", state) + self.assertIn("inputs_buffer", state) + self.assertIn("labels_buffer", state) + self.assertIn("positions_buffer", state) + self.assertGreater(state["sample_idx"], 0) + self.assertEqual(state["epoch"], 0) + + # Restore and verify the dataset can produce valid packed batches + chat_ds2 = ChatDataset( + dataset=ds, + tokenizer=tokenizer, + sample_processor=_process_sample, + seq_len=seq_len, + infinite=False, + ) + chat_ds2.load_state_dict(state) + + self.assertEqual(chat_ds2._sample_idx, state["sample_idx"]) + self.assertEqual(chat_ds2._epoch, state["epoch"]) + + remaining = list(chat_ds2) + self.assertGreater(len(remaining), 0, "Restored dataset should produce batches") + for batch, labels in remaining: + self.assertEqual(batch["input"].shape[0], seq_len) + self.assertEqual(batch["positions"].shape[0], seq_len) + self.assertEqual(labels.shape[0], seq_len) + + +class TestChatDatasetInfiniteLooping(unittest.TestCase): + """Dataset re-shuffles and continues after exhausting data.""" + + def test_infinite_produces_more_than_dataset_size(self): + tokenizer = _load_tokenizer() + ds = _load_dataset() + chat_ds = ChatDataset( + dataset=ds, + tokenizer=tokenizer, + sample_processor=_process_sample, + seq_len=2048, + infinite=True, + ) + + # The dataset has 10 samples. Consuming 15 requires at least one re-loop. + it = iter(chat_ds) + samples = [next(it) for _ in range(15)] + self.assertEqual(len(samples), 15) + + # After the first 10, the epoch counter should have incremented + self.assertGreaterEqual(chat_ds._epoch, 1) + + def test_infinite_packed(self): + tokenizer = _load_tokenizer() + ds = _load_dataset() + seq_len = 256 + chat_ds = ChatDataset( + dataset=ds, + tokenizer=tokenizer, + sample_processor=_process_sample, + seq_len=seq_len, + infinite=True, + ) + + # Consume enough packed batches to exceed the 10-sample dataset + it = iter(chat_ds) + batches = [next(it) for _ in range(20)] + self.assertEqual(len(batches), 20) + self.assertGreaterEqual(chat_ds._epoch, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchtitan/hf_datasets/__init__.py b/torchtitan/hf_datasets/__init__.py index b3886fcbae..5badd9d617 100644 --- a/torchtitan/hf_datasets/__init__.py +++ b/torchtitan/hf_datasets/__init__.py @@ -7,7 +7,6 @@ from collections.abc import Callable from dataclasses import dataclass - __all__ = ["DatasetConfig"] diff --git a/torchtitan/hf_datasets/text_datasets.py b/torchtitan/hf_datasets/text_datasets.py index aec4f15129..4fa727f012 100644 --- a/torchtitan/hf_datasets/text_datasets.py +++ b/torchtitan/hf_datasets/text_datasets.py @@ -5,17 +5,19 @@ # LICENSE file in the root directory of this source tree. from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import partial -from typing import Any +from typing import Annotated, Any, cast import torch +import tyro from datasets import Dataset, load_dataset from datasets.distributed import split_dataset_by_node from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import IterableDataset from torchtitan.components.dataloader import ParallelAwareDataloader +from torchtitan.components.loss import IGNORE_INDEX from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.hf_datasets import DatasetConfig from torchtitan.tools.logging import logger @@ -95,8 +97,9 @@ def __init__( # Variables for checkpointing self._sample_idx = 0 - self._token_buffer: list[int] = [] - self._position_buffer: list[int] = [] + self._epoch: int = 0 + self._inputs_buffer: list[int] = [] + self._positions_buffer: list[int] = [] def _get_data_iter(self): # For map-style datasets, resume by skipping to the correct index @@ -119,24 +122,28 @@ def __iter__(self): sample_tokens = self._tokenizer.encode( sample_text, add_bos=True, add_eos=True ) - self._token_buffer.extend(sample_tokens) + self._inputs_buffer.extend(sample_tokens) # Per-document positions reset at document boundaries, # matching inference frameworks (e.g. vLLM) that start # positions at 0 per request. Positions wrap at seq_len # to stay within the RoPE cache, effectively chunking # long documents into seq_len-sized segments. # TODO: make overflow policy configurable (chunk / truncate / drop). - self._position_buffer.extend( + self._positions_buffer.extend( i % self.seq_len for i in range(len(sample_tokens)) ) self._sample_idx += 1 - while len(self._token_buffer) >= max_buffer_token_len: - x = torch.LongTensor(self._token_buffer[:max_buffer_token_len]) - pos = torch.LongTensor(self._position_buffer[:max_buffer_token_len]) + while len(self._inputs_buffer) >= max_buffer_token_len: + x = torch.LongTensor(self._inputs_buffer[:max_buffer_token_len]) + pos = torch.LongTensor( + self._positions_buffer[:max_buffer_token_len] + ) # update buffers to the remaining tokens - self._token_buffer = self._token_buffer[max_buffer_token_len:] - self._position_buffer = self._position_buffer[max_buffer_token_len:] + self._inputs_buffer = self._inputs_buffer[max_buffer_token_len:] + self._positions_buffer = self._positions_buffer[ + max_buffer_token_len: + ] input = x[:-1] label = x[1:] positions = pos[:-1] @@ -148,6 +155,7 @@ def __iter__(self): else: # Reset offset for the next iteration self._sample_idx = 0 + self._epoch += 1 logger.warning(f"Dataset {self.dataset_name} is being re-looped") # Ensures re-looping a dataset loaded from a checkpoint works correctly if not isinstance(self._data, Dataset): @@ -157,16 +165,13 @@ def __iter__(self): self._data.set_epoch(self._data.epoch + 1) def load_state_dict(self, state_dict): - self._token_buffer = state_dict["token_buffer"] - if "position_buffer" not in state_dict: + self._inputs_buffer = state_dict["inputs_buffer"] + if "positions_buffer" not in state_dict: logger.warning( - "Checkpoint missing 'position_buffer' key in dataset state. " - "Falling back to empty position buffer. This is expected when " - "resuming from a checkpoint saved before position tracking was " - "added, but may cause incorrect RoPE positions with " - "block_causal attention (document packing)." + "Checkpoint missing 'positions_buffer'. Falling back to empty buffer. " + "RoPE positions may be incorrect with block_causal attention." ) - self._position_buffer = state_dict.get("position_buffer", []) + self._positions_buffer = state_dict.get("positions_buffer", []) if isinstance(self._data, Dataset): self._sample_idx = state_dict["sample_idx"] @@ -176,12 +181,13 @@ def load_state_dict(self, state_dict): def state_dict(self): _state_dict: dict[str, Any] = { - "token_buffer": self._token_buffer, - "position_buffer": self._position_buffer, + "inputs_buffer": self._inputs_buffer, + "positions_buffer": self._positions_buffer, } if isinstance(self._data, Dataset): _state_dict["sample_idx"] = self._sample_idx + _state_dict["epoch"] = self._epoch else: # Save the iterable dataset's state to later efficiently resume from it # https://huggingface.co/docs/datasets/v3.5.0/en/stream#save-a-dataset-checkpoint-and-resume-iteration @@ -240,3 +246,289 @@ def __init__( dp_world_size=dp_world_size, **dataloader_kwargs, ) + + +class ChatDataset(IterableDataset, Stateful): + """Dataset for single-turn chat/instruction-tuning. + + Tokenizes [user, assistant] message pairs, masks prompt tokens with + IGNORE_INDEX in labels, and uses greedy sequence packing with + per-document positions. Implements Stateful for checkpointing. + """ + + def __init__( + self, + dataset: Dataset, + tokenizer: BaseTokenizer, + sample_processor: Callable, + seq_len: int = 2048, + dp_rank: int = 0, + dp_world_size: int = 1, + infinite: bool = False, + ) -> None: + if tokenizer.eos_id is None: + raise ValueError( + "Tokenizer does not have an eos_id set. " + "ChatDataset requires a tokenizer with a valid EOS token." + ) + + self._data = split_dataset_by_node(dataset, dp_rank, dp_world_size) + self._tokenizer = tokenizer + self._eos_id = tokenizer.eos_id + self.seq_len = seq_len + self.infinite = infinite + self._sample_processor = sample_processor + + self._dataset_id = f"{dataset.info.dataset_name}/{dataset.split}" + + # Variables for checkpointing + self._sample_idx = 0 + self._epoch: int = 0 + self._inputs_buffer: list[int] = [] + self._labels_buffer: list[int] = [] + self._positions_buffer: list[int] = [] + + self._logged_first_sample = False + + def _get_data_iter(self): + if isinstance(self._data, Dataset): + if self._sample_idx == len(self._data): + return iter([]) + return iter(self._data.skip(self._sample_idx)) + + return iter(self._data) + + @staticmethod + def _validate_messages(messages: list[dict[str, str]]) -> None: + """Validate that messages are a single-turn [user, assistant] pair.""" + # TODO: expand this to multi-turn + if len(messages) != 2: + raise ValueError( + f"Expected single-turn [user, assistant], got {len(messages)} messages" + ) + if messages[0]["role"] != "user": + raise ValueError( + f"First message must be 'user', got '{messages[0]['role']}'" + ) + if messages[1]["role"] != "assistant": + raise ValueError( + f"Second message must be 'assistant', got '{messages[1]['role']}'" + ) + + def _tokenize_sample( + self, sample: dict[str, Any] + ) -> tuple[list[int], list[int]] | None: + """Tokenize a single-turn sample and create input/label pairs. + + Returns (input_ids, label_ids) where input_ids = tokens[:-1] and + label_ids = tokens[1:] with prompt tokens masked as IGNORE_INDEX. + Returns None if the sample exceeds seq_len (dropped to avoid + training on truncated responses). + + Uses incremental prefix re-tokenization to find the prompt/response + token boundary, avoiding BPE merge errors. + """ + messages = self._sample_processor(sample) + self._validate_messages(messages) + + full_text = self._tokenizer.apply_chat_template(messages) + # Strip extra newline and ensure the sequence ends with EOS without duplicates + full_text = full_text.rstrip("\n") + full_tokens = self._tokenizer.encode(full_text, add_bos=True, add_eos=False) + if full_tokens[-1] != self._eos_id: + full_tokens.append(self._eos_id) + + if not self._logged_first_sample: + logger.info(f"[ChatDataset] First sample full:\n{full_text}") + self._logged_first_sample = True + + # Drop examples exceeding seq_len rather than truncating. + if len(full_tokens) - 1 > self.seq_len: + logger.debug( + f"Dropping sample {self._sample_idx}: " + f"tokens exceeds seq_len {self.seq_len}" + ) + return None + + input_ids = full_tokens[:-1] + label_ids = full_tokens[1:] + + # Find prompt/response boundary by tokenizing just the user message + # with add_generation_prompt=True. + prompt_text = self._tokenizer.apply_chat_template( + messages[:1], add_generation_prompt=True + ) + prompt_tokens = self._tokenizer.encode(prompt_text, add_bos=True, add_eos=False) + prompt_len = len(prompt_tokens) + + # Mask prompt tokens in labels [0, prompt_len). + mask_end = min(prompt_len, len(label_ids)) + label_ids[:mask_end] = [IGNORE_INDEX] * mask_end + + return input_ids, label_ids + + def __iter__(self): + yield from self._iter_greedy_packed() + + def _iter_greedy_packed(self): + """Greedy packing: pack examples sequentially until seq_len is full. + Document boundaries are marked by EOS tokens between packed examples. + The model's flex/varlen attention mask uses these EOS positions to + prevent cross-document attention. + """ + while True: + for sample in self._get_data_iter(): + # pyrefly: ignore [bad-argument-type] + result = self._tokenize_sample(sample) + self._sample_idx += 1 + if result is None: + continue + + input_ids, label_ids = result + remaining = self.seq_len - len(self._inputs_buffer) + + # If the example doesn't fit, pad and yield current buffer + if len(input_ids) > remaining and len(self._inputs_buffer) > 0: + pad_len = remaining + self._inputs_buffer.extend([self._eos_id] * pad_len) + self._labels_buffer.extend([IGNORE_INDEX] * pad_len) + self._positions_buffer.extend(range(pad_len)) + + yield self._flush_buffers() + + # Add example to buffer with positions resetting to 0 + self._inputs_buffer.extend(input_ids) + self._labels_buffer.extend(label_ids) + self._positions_buffer.extend(range(len(input_ids))) + + if len(self._inputs_buffer) == self.seq_len: + yield self._flush_buffers() + + # Flush remaining buffer at end of data + if len(self._inputs_buffer) > 0: + pad_len = self.seq_len - len(self._inputs_buffer) + if pad_len > 0: + self._inputs_buffer.extend([self._eos_id] * pad_len) + self._labels_buffer.extend([IGNORE_INDEX] * pad_len) + self._positions_buffer.extend(range(pad_len)) + + yield self._flush_buffers() + + if not self.infinite: + logger.warning(f"Chat dataset '{self._dataset_id}' has run out of data") + break + else: + self._sample_idx = 0 + self._epoch += 1 + if isinstance(self._data, Dataset): + self._data = cast( + Dataset, self._data.shuffle(seed=42 + self._epoch) + ) + elif hasattr(self._data, "set_epoch"): + self._data.set_epoch(self._epoch) + logger.warning( + f"Chat dataset '{self._dataset_id}' is being re-looped " + f"(epoch {self._epoch})" + ) + + def _flush_buffers(self): + """Convert buffers to tensors, clear them, and return the batch.""" + input_tensor = torch.tensor(self._inputs_buffer, dtype=torch.long) + label_tensor = torch.tensor(self._labels_buffer, dtype=torch.long) + positions_tensor = torch.tensor(self._positions_buffer, dtype=torch.long) + self._inputs_buffer = [] + self._labels_buffer = [] + self._positions_buffer = [] + return {"input": input_tensor, "positions": positions_tensor}, label_tensor + + def state_dict(self): + _state_dict: dict[str, Any] = { + "epoch": self._epoch, + "inputs_buffer": self._inputs_buffer, + "labels_buffer": self._labels_buffer, + "positions_buffer": self._positions_buffer, + } + + if isinstance(self._data, Dataset): + _state_dict["sample_idx"] = self._sample_idx + else: + _state_dict["data"] = self._data.state_dict() + + return _state_dict + + def load_state_dict(self, state_dict): + self._epoch = state_dict["epoch"] + self._inputs_buffer = state_dict["inputs_buffer"] + self._labels_buffer = state_dict["labels_buffer"] + self._positions_buffer = state_dict["positions_buffer"] + + if isinstance(self._data, Dataset): + self._sample_idx = state_dict["sample_idx"] + # Replay shuffles so _data matches the order at checkpoint time + if self._epoch > 0: + self._data = cast(Dataset, self._data.shuffle(seed=42 + self._epoch)) + else: + assert "data" in state_dict + self._data.load_state_dict(state_dict["data"]) + + +class ChatDataLoader(ParallelAwareDataloader): + """Chat dataloader for instruction/conversation datasets.""" + + @dataclass(kw_only=True, slots=True) + class Config(ParallelAwareDataloader.Config): + dataset_path: str | None = None + """HuggingFace dataset path (e.g., 'openai/gsm8k') or local path. Required.""" + + load_dataset_kwargs: dict[str, Any] = field(default_factory=dict) + """Extra kwargs passed to datasets.load_dataset().""" + + sample_processor: Annotated[Callable, tyro.conf.Suppress] + """Callable(sample_dict) -> list[message_dict]. Set in config functions.""" + + infinite: bool = True + """Whether to loop the dataset infinitely. Might hang on multi-GPU.""" + + def __init__( + self, + config: Config, + *, + dp_world_size: int, + dp_rank: int, + tokenizer: BaseTokenizer, + seq_len: int, + local_batch_size: int, + **kwargs, + ): + if not config.dataset_path: + raise ValueError( + "ChatDataLoader requires dataset_path to be set " + "(e.g., 'openai/gsm8k' or 'json')." + ) + + dataset = load_dataset(config.dataset_path, **config.load_dataset_kwargs) + + chat_ds = ChatDataset( + dataset=dataset, + tokenizer=tokenizer, + sample_processor=config.sample_processor, + seq_len=seq_len, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + infinite=config.infinite, + ) + + dataloader_kwargs = { + "num_workers": config.num_workers, + "persistent_workers": config.persistent_workers, + "pin_memory": config.pin_memory, + "prefetch_factor": config.prefetch_factor, + "batch_size": local_batch_size, + } + + super().__init__( + chat_ds, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + **dataloader_kwargs, + ) diff --git a/torchtitan/models/llama3/config_registry.py b/torchtitan/models/llama3/config_registry.py index 41c9d486d5..8bc8e953e7 100644 --- a/torchtitan/models/llama3/config_registry.py +++ b/torchtitan/models/llama3/config_registry.py @@ -19,7 +19,10 @@ ParallelismConfig, TrainingConfig, ) -from torchtitan.hf_datasets.text_datasets import HuggingFaceTextDataLoader +from torchtitan.hf_datasets.text_datasets import ( + ChatDataLoader, + HuggingFaceTextDataLoader, +) from torchtitan.protocols.model_converter import ModelConvertersContainer from torchtitan.tools.profiling import ProfilingConfig from torchtitan.trainer import Trainer @@ -212,3 +215,52 @@ def llama3_405b() -> Trainer.Config: steps=1200, ), ) + + +def sft_debugmodel() -> Trainer.Config: + """SFT debug config with Llama3 debugmodel and local test data.""" + + def process_sample(sample): + return [ + {"role": "user", "content": sample["question"]}, + {"role": "assistant", "content": sample["answer"]}, + ] + + model_spec = model_registry("debugmodel") + # pyrefly: ignore [missing-attribute] + model_spec.model.layer.attention.attn_backend = "flex" + # pyrefly: ignore [missing-attribute] + model_spec.model.layer.attention.attn_mask_type = "block_causal" + + return Trainer.Config( + hf_assets_path="./tests/assets/tokenizer", + model_spec=model_spec, + optimizer=OptimizersContainer.Config(lr=8e-4), + lr_scheduler=LRSchedulersContainer.Config( + warmup_steps=2, + decay_ratio=0.8, + decay_type="linear", + min_lr_factor=0.0, + ), + training=TrainingConfig( + local_batch_size=8, + seq_len=2048, + steps=10, + ), + dataloader=ChatDataLoader.Config( + dataset_path="json", + load_dataset_kwargs={ + "data_files": "tests/assets/sft_test/data.json", + "split": "train", + }, + sample_processor=process_sample, + ), + metrics=MetricsProcessor.Config(log_freq=1), + checkpoint=CheckpointManager.Config( + interval=10, + last_save_model_only=False, + ), + activation_checkpoint=ActivationCheckpointConfig( + mode="selective", + ), + ) diff --git a/torchtitan/models/qwen3/config_registry.py b/torchtitan/models/qwen3/config_registry.py index 2ce9be3a34..80aed1f6bb 100644 --- a/torchtitan/models/qwen3/config_registry.py +++ b/torchtitan/models/qwen3/config_registry.py @@ -13,7 +13,10 @@ ParallelismConfig, TrainingConfig, ) -from torchtitan.hf_datasets.text_datasets import HuggingFaceTextDataLoader +from torchtitan.hf_datasets.text_datasets import ( + ChatDataLoader, + HuggingFaceTextDataLoader, +) from torchtitan.trainer import Trainer from . import model_registry @@ -216,3 +219,52 @@ def qwen3_moe_debug() -> Trainer.Config: mode="selective", ), ) + + +def sft_qwen3_8b_math() -> Trainer.Config: + """Qwen3-8B SFT on GSM8K math dataset.""" + + def process_sample(sample): + answer = sample["answer"] + reasoning, final_answer = answer.rsplit("####", 1) + return [ + {"role": "user", "content": sample["question"]}, + { + "role": "assistant", + "reasoning_content": reasoning.strip(), + "content": final_answer.strip(), + }, + ] + + model_spec = model_registry("8B", attn_backend_override="varlen") + return Trainer.Config( + hf_assets_path="./assets/hf/Qwen3-8B", + model_spec=model_spec, + optimizer=OptimizersContainer.Config(lr=2e-5), + lr_scheduler=LRSchedulersContainer.Config( + warmup_steps=15, + decay_ratio=0.9, + decay_type="cosine", + min_lr_factor=0.1, + ), + training=TrainingConfig( + local_batch_size=1, + seq_len=2048, + steps=180, + ), + dataloader=ChatDataLoader.Config( + dataset_path="openai/gsm8k", + load_dataset_kwargs={"name": "main", "split": "train"}, + sample_processor=process_sample, + ), + metrics=MetricsProcessor.Config( + enable_wandb=True, + ), + checkpoint=CheckpointManager.Config( + enable=True, + initial_load_in_hf=True, + ), + activation_checkpoint=ActivationCheckpointConfig( + mode="selective", + ), + )