
Commit 33e8a27

Merge karpathy/cpu-mps-dev, adding the ability to run on CPU, on MPS, or on CUDA, with autodetect. Gnarly PR, nonzero chance I broke something.
add cpu|mps support
2 parents bb71c64 + 50bea28 commit 33e8a27

19 files changed, +266 -93 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ And a bit more about computing environments that will run nanochat:
 
 ## Running on CPU / MPS
 
-If you'd like to tinker with nanochat on your Macbook or a CPU machine, there is a work in progress [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) up here. If you're on Macbook, use `--device_type=mps` when running `base_train.py`. See the PR and its diff for more. You're not going to get too far without GPU nodes, but at least you'll be able to run the code and maybe train a very tiny LLM with some patience.
+nanochat can be run on CPU or on MPS (if you're on a Macbook), and will automatically try to detect which device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to the [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, to run for a shorter number of iterations, etc. This functionality is new, slightly gnarly (it touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025.
 
 ## Customization
 
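
As a quick sanity check before trying the CPU/MPS path, you can ask PyTorch directly which backends it sees. This is a small illustrative sketch (not part of the commit) that mirrors the CUDA-then-MPS-then-CPU preference order described above.

import torch

# mirrors nanochat's autodetect preference order: CUDA first, then MPS, then CPU
if torch.cuda.is_available():
    print("CUDA available: 'cuda' would be selected")
elif torch.backends.mps.is_available():
    print("MPS available: 'mps' would be selected")
else:
    print("no accelerator found: falling back to 'cpu'")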

dev/runcpu.sh

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+#!/bin/bash
+
+# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
+# Run as:
+# bash dev/runcpu.sh
+
+# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
+# Think of this run as an educational/fun demo, not something you should expect to work well.
+# This is also why I hide this script away in dev/
+
+# all the setup stuff
+export OMP_NUM_THREADS=1
+NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
+mkdir -p $NANOCHAT_BASE_DIR
+command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
+[ -d ".venv" ] || uv venv
+uv sync
+source .venv/bin/activate
+if [ -z "$WANDB_RUN" ]; then
+    WANDB_RUN=dummy
+fi
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+source "$HOME/.cargo/env"
+uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
+EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
+if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
+    curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
+    unzip -q eval_bundle.zip
+    rm eval_bundle.zip
+    mv eval_bundle $NANOCHAT_BASE_DIR
+fi
+
+# wipe the report
+python -m nanochat.report reset
+
+# train tokenizer on ~1B characters
+python -m nanochat.dataset -n 4
+python -m scripts.tok_train --max_chars=1000000000
+python -m scripts.tok_eval
+
+# train a very small 4 layer model on the CPU
+# each optimization step processes a single sequence of 1024 tokens
+# we only run 50 steps of optimization (bump this to get better results)
+python -m scripts.base_train \
+    --depth=4 \
+    --max_seq_len=1024 \
+    --device_batch_size=1 \
+    --total_batch_size=1024 \
+    --eval_every=50 \
+    --eval_tokens=4096 \
+    --core_metric_every=50 \
+    --core_metric_max_per_task=12 \
+    --sample_every=50 \
+    --num_iterations=50
+python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
+python -m scripts.base_eval --max-per-task=5
+
+# midtraining
+python -m scripts.mid_train \
+    --max_seq_len=1024 \
+    --device_batch_size=1 \
+    --eval_every=50 \
+    --eval_tokens=4096 \
+    --total_batch_size=1024 \
+    --num_iterations=100
+# eval results will be terrible, this is just to execute the code paths.
+# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
+python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
+
+# SFT
+python -m scripts.chat_sft \
+    --device_batch_size=1 \
+    --target_examples_per_step=4 \
+    --num_iterations=100 \
+    --eval_steps=4 \
+    --eval_metrics_max_problems=16
+
+# Chat CLI
+# python -m scripts.chat_cli -p "Why is the sky blue?"
+
+# Chat Web
+# python -m scripts.chat_web
+
+python -m nanochat.report generate
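
To put the demo settings above in perspective, here is a rough back-of-the-envelope count of how many tokens this run actually trains on, a sketch using only the flags shown in the script:

# base_train: --total_batch_size=1024 tokens per step, --num_iterations=50
base_tokens = 1024 * 50
# mid_train: --total_batch_size=1024 tokens per step, --num_iterations=100
mid_tokens = 1024 * 100
print(base_tokens, mid_tokens)  # 51200 and 102400 tokens, a tiny fraction of a real pretraining run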

nanochat/common.py

Lines changed: 24 additions & 10 deletions
@@ -89,32 +89,46 @@ def get_dist_info():
     else:
         return False, 0, 0, 1
 
-def compute_init():
+def autodetect_device_type():
+    # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
+    if torch.cuda.is_available():
+        device_type = "cuda"
+    elif torch.backends.mps.is_available():
+        device_type = "mps"
+    else:
+        device_type = "cpu"
+    print0(f"Autodetected device type: {device_type}")
+    return device_type
+
+def compute_init(device_type="cuda"): # cuda|cpu|mps
     """Basic initialization that we keep doing over and over, so make common."""
 
-    # CUDA is currently required
-    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
+    assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
+    if device_type == "cuda":
+        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
+    if device_type == "mps":
+        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
 
     # Reproducibility
     torch.manual_seed(42)
-    torch.cuda.manual_seed(42)
+    if device_type == "cuda":
+        torch.cuda.manual_seed(42)
     # skipping full reproducibility for now, possibly investigate slowdown later
     # torch.use_deterministic_algorithms(True)
-    # torch.backends.cudnn.deterministic = True
-    # torch.backends.cudnn.benchmark = False
 
     # Precision
-    torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
+    if device_type == "cuda":
+        torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
 
-    # Distributed setup: Distributed Data Parallel (DDP), optional
+    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp:
+    if ddp and device_type == "cuda":
         device = torch.device("cuda", ddp_local_rank)
         torch.cuda.set_device(device) # make "cuda" default to this device
         dist.init_process_group(backend="nccl", device_id=device)
         dist.barrier()
     else:
-        device = torch.device("cuda")
+        device = torch.device(device_type) # mps|cpu
 
     if ddp_rank == 0:
         logger.info(f"Distributed world size: {ddp_world_size}")

nanochat/dataloader.py

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,7 @@
 from nanochat.dataset import parquets_iter_batched
 from nanochat.tokenizer import get_tokenizer
 
-def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
+def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
     """Stream pretraining text from parquet files, tokenize, yield training batches."""
     assert split in ["train", "val"], "split must be 'train' or 'val'"
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
@@ -43,6 +43,6 @@ def document_batches():
         inputs_cpu = scratch[:-1].to(dtype=torch.int32)
         targets_cpu = scratch[1:]
         # Reshape to 2D and move to GPU async
-        inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
-        targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
+        inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
+        targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
         yield inputs, targets
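
The loader change simply threads a device string through to the final .to() calls. Below is a hedged sketch of that tail end with a hypothetical make_batch helper (the helper name and the random scratch buffer are just for illustration); non_blocking=True only buys copy overlap on CUDA with pinned host memory and is harmless on cpu/mps, which is why the same line works for all three backends.

import torch

def make_batch(scratch: torch.Tensor, B: int, T: int, device: str):
    # hypothetical helper mirroring the tail of tokenizing_distributed_data_loader:
    # inputs are tokens [0..n-1], targets are tokens [1..n], then both are copied to `device`
    inputs_cpu = scratch[:-1].to(dtype=torch.int32)
    targets_cpu = scratch[1:]
    inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
    targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
    return inputs, targets

B, T = 1, 1024
scratch = torch.randint(0, 50000, (B * T + 1,))  # stand-in for a tokenized stream
x, y = make_batch(scratch, B, T, device="cpu")
print(x.shape, y.shape, x.device)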

nanochat/execution.py

Lines changed: 5 additions & 4 deletions
@@ -146,13 +146,12 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
     with caution.
     """
 
-    if maximum_memory_bytes is not None:
+    if platform.uname().system != "Darwin":
+        # These resource limit calls seem to fail on macOS (Darwin), skip?
         import resource
-
         resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
         resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
-        if not platform.uname().system == "Darwin":
-            resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
+        resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
 
     faulthandler.disable()
 
@@ -225,6 +224,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
     rmtree = shutil.rmtree
     rmdir = os.rmdir
     chdir = os.chdir
+    unlink = os.unlink
 
     # Disable functionalities that can make destructive changes to the test.
     reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
@@ -282,6 +282,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
     shutil.rmtree = rmtree
     os.rmdir = rmdir
     os.chdir = chdir
+    os.unlink = unlink
 
 
 def execute_code(
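
Two small things happen in this file: os.unlink joins the list of functions that are saved, disabled, and restored around the untrusted code, and the rlimit calls are skipped on macOS. Here is a sketch of that save/patch/restore pattern with an arbitrary 1 GB cap; the run_sandboxed helper and the limit value are illustrative, not the commit's API.

import os
import platform

def run_sandboxed(fn):
    unlink = os.unlink  # save the real function before disabling it
    try:
        os.unlink = None  # any destructive call inside fn would now fail fast
        if platform.uname().system != "Darwin":
            # the rlimit calls are skipped on macOS, mirroring the diff above
            import resource  # Unix-only module
            one_gb = 1024 ** 3
            resource.setrlimit(resource.RLIMIT_AS, (one_gb, one_gb))
        return fn()
    finally:
        os.unlink = unlink  # always restore, as _unsafe_execute does at the end

print(run_sandboxed(lambda: "ok"))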

nanochat/gpt.py

Lines changed: 3 additions & 2 deletions
@@ -169,8 +169,6 @@ def __init__(self, config):
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)
-        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
-        self.transformer.wte.to(dtype=torch.bfloat16)
 
     def init_weights(self):
         self.apply(self._init_weights)
@@ -184,6 +182,9 @@ def init_weights(self):
         head_dim = self.config.n_embd // self.config.n_head
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.cos, self.sin = cos, sin
+        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
+        if self.transformer.wte.weight.device.type == "cuda":
+            self.transformer.wte.to(dtype=torch.bfloat16)
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
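
The bf16 cast of the token embedding now lives in init_weights and only fires when the weights are already on a CUDA device, so CPU and MPS runs keep the fp32 table. A minimal standalone sketch of the same conditional cast (the vocabulary and embedding sizes are arbitrary):

import torch
import torch.nn as nn

wte = nn.Embedding(num_embeddings=65536, embedding_dim=768)  # fp32 by default
if torch.cuda.is_available():
    wte = wte.to("cuda")

# cast to bf16 only on CUDA: it halves the memory of the table and its activations,
# while CPU/MPS stay in fp32, where bf16 support is spottier
if wte.weight.device.type == "cuda":
    wte = wte.to(dtype=torch.bfloat16)

print(wte.weight.device, wte.weight.dtype)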

nanochat/loss_eval.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
         loss2d = model(x, y, loss_reduction='none') # (B, T)
         loss2d = loss2d.view(-1) # flatten
         y = y.view(-1) # flatten
-        if (y < 0).any():
+        if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
             # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
             # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
             valid = y >= 0
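
The only change here is that the ignore_index check compares an int32 view of the targets, working around the missing int64 `< 0` kernel on MPS noted in the inline comment; the subsequent mask is left untouched. A tiny sketch of that check on a toy target vector:

import torch

y = torch.tensor([12, -1, 7, -1, 3], dtype=torch.int64)  # -1 marks ignore_index targets
# token ids comfortably fit in int32, so the downcast is safe and only used for the test
has_ignored = (y.int() < 0).any()
valid = y >= 0  # boolean mask of the real targets, computed on the original tensor
print(bool(has_ignored), valid)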

nanochat/report.py

Lines changed: 4 additions & 0 deletions
@@ -283,6 +283,10 @@ def generate(self):
             # capture bloat data for summary later (the stuff after Bloat header and until \n\n)
             bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
             bloat_data = bloat_data.group(1) if bloat_data else ""
+        else:
+            start_time = None # will cause us to not write the total wall clock time
+            bloat_data = "[bloat data missing]"
+            print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
         # process all the individual sections
         for file_name in EXPECTED_FILES:
             section_file = os.path.join(report_dir, file_name)

pyproject.toml

Lines changed: 18 additions & 11 deletions
@@ -11,6 +11,7 @@ dependencies = [
     "numpy==1.26.4",
     "psutil>=7.1.0",
     "regex>=2025.9.1",
+    "setuptools>=80.9.0",
     "tiktoken>=0.11.0",
     "tokenizers>=0.22.0",
     "torch>=2.8.0",
@@ -22,17 +23,6 @@ dependencies = [
 requires = ["maturin>=1.7,<2.0"]
 build-backend = "maturin"
 
-# target torch to cuda 12.8
-[tool.uv.sources]
-torch = [
-    { index = "pytorch-cu128" },
-]
-
-[[tool.uv.index]]
-name = "pytorch-cu128"
-url = "https://download.pytorch.org/whl/cu128"
-explicit = true
-
 [tool.maturin]
 module-name = "rustbpe"
 bindings = "pyo3"
@@ -53,3 +43,20 @@ testpaths = ["tests"]
 python_files = ["test_*.py"]
 python_classes = ["Test*"]
 python_functions = ["test_*"]
+
+# target torch to cuda 12.8
+[tool.uv.sources]
+torch = [
+    { index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
+    { index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
+]
+
+[[tool.uv.index]]
+name = "pytorch-cpu"
+url = "https://download.pytorch.org/whl/cpu"
+explicit = true
+
+[[tool.uv.index]]
+name = "pytorch-cu128"
+url = "https://download.pytorch.org/whl/cu128"
+explicit = true
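
With the markers above, `uv sync` resolves CUDA 12.8 wheels on Linux and the plain CPU wheels everywhere else (the default macOS wheels are what carry MPS support). A quick way to confirm which torch build actually landed in the venv; this is a sketch, not part of the commit:

import torch

print(torch.__version__)                  # Linux wheels carry a local tag such as "+cu128" or "+cpu"
print(torch.version.cuda)                 # the CUDA toolkit version, or None on a CPU-only build
print(torch.backends.mps.is_built())      # True if this build was compiled with MPS support
print(torch.backends.mps.is_available())  # True on a Mac that can actually use MPS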

run1000.sh

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+#!/bin/bash
+
 # The $1000 tier of nanochat
 # Designed to run end-to-end for $1000/24 ~= 41.6 hours on an 8XH100 node
 # A bit sparser on comments, see speedrun.sh for more detail
