Skip to content

Commit bb71c64

Browse files
committed
Fix a silly issue in the dataloader; this version is much faster and more portable to MPS too
1 parent c9ea7a9 commit bb71c64

1 file changed

Lines changed: 2 additions & 3 deletions

File tree

nanochat/dataloader.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
1616
bos_token = tokenizer.get_bos_token_id()
1717
# scratch buffer holds the tokens for one iteration
1818
token_buffer = deque() # we stream tokens on the right and pop from the left
19-
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
2019

2120
# infinite iterator over document batches
2221
def document_batches():
@@ -38,8 +37,8 @@ def document_batches():
3837
token_buffer.extend(tokens)
3938
batch_index += 1
4039
# Move tokens from the deque into the scratch buffer
41-
for i in range(needed_tokens):
42-
scratch[i] = token_buffer.popleft()
40+
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
41+
scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=True)
4342
# Create the inputs/targets as 1D tensors
4443
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
4544
targets_cpu = scratch[1:]

0 commit comments

Comments (0)