-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsetup_pytorch_dev.py
More file actions
executable file
·346 lines (292 loc) · 14.1 KB
/
setup_pytorch_dev.py
File metadata and controls
executable file
·346 lines (292 loc) · 14.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
#!/usr/bin/env -S uv run --python python3.11
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "click>=8.1.0",
# "rich>=13.0.0",
# ]
# ///
"""PyTorch Development Environment Setup Script"""
import click
import shutil
import subprocess
import os
from pathlib import Path
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.prompt import Confirm
from rich.progress import track
# Shared rich console; every step in this script prints its status through it.
console = Console()
def run(cmd, capture=False, silent=False, cwd=None, env=None):
    """Run an external command, converting failures into click errors.

    Args:
        cmd: Argument list handed to ``subprocess.run`` (no shell is involved).
        capture: When true, capture the command's output and return its
            stdout as stripped text.
        silent: When true (and ``capture`` is false), discard the command's
            stdout; its stderr is still shown on the terminal.
        cwd: Working directory for the command, or ``None`` for the current one.
        env: Extra environment variables layered over a copy of ``os.environ``.

    Returns:
        The stripped stdout when ``capture`` is true, otherwise ``None``.

    Raises:
        click.ClickException: If the command exits with a non-zero status.
    """
    run_env = os.environ.copy()
    if env:
        run_env.update(env)
    try:
        if capture:
            completed = subprocess.run(cmd, check=True, capture_output=True,
                                       text=True, cwd=cwd, env=run_env)
            return completed.stdout.strip()
        subprocess.run(cmd, check=True, cwd=cwd, env=run_env,
                       stdout=subprocess.DEVNULL if silent else None)
    except subprocess.CalledProcessError as e:
        console.print(f"[red]Command failed:[/red] {' '.join(map(str, cmd))}")
        # In capture mode stderr went into a pipe instead of the terminal, so
        # replay it here; otherwise the reason for the failure would be lost.
        captured_stderr = (e.stderr or '').strip()
        if captured_stderr:
            # markup=False: tool output may contain square brackets.
            console.print(captured_stderr, style="dim", markup=False, highlight=False)
        raise click.ClickException(str(e)) from e
    return None
def check_prerequisites():
    """Verify that uv, direnv and git can all be found on PATH.

    Raises:
        click.ClickException: For the first tool that is missing; the message
            includes the URL to install it from.
    """
    required_tools = (
        ('uv', 'https://docs.astral.sh/uv/'),
        ('direnv', 'https://direnv.net/'),
        ('git', 'https://git-scm.com/'),
    )
    for executable, install_url in required_tools:
        if shutil.which(executable):
            continue
        raise click.ClickException(f"{executable} not found. Install from: {install_url}")
    console.print("✓ All prerequisites found", style="green")
def handle_directory(target_dir, force):
    """Make sure ``target_dir`` exists and is empty.

    A non-empty directory is deleted and recreated, after asking the user for
    confirmation unless ``force`` is set.

    Args:
        target_dir: ``pathlib.Path`` of the environment root.
        force: When true, delete existing contents without prompting.

    Raises:
        click.ClickException: If ``target_dir`` exists but is not a directory.
        click.Abort: If the directory is non-empty and the user declines to
            delete its contents.
    """
    if target_dir.exists():
        # A plain file at the target path would otherwise surface as a raw
        # NotADirectoryError from iterdir().
        if not target_dir.is_dir():
            raise click.ClickException(f"{target_dir} exists and is not a directory")
        # List the directory once and reuse the result for the check and the count.
        existing_items = list(target_dir.iterdir())
        if existing_items:
            console.print(f"[yellow]Directory exists with {len(existing_items)} items[/yellow]")
            if not force and not Confirm.ask("Delete contents?", default=False):
                raise click.Abort()
            shutil.rmtree(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    console.print(f"✓ Prepared {target_dir}", style="green")
def setup_python(version, no_gil):
    """Install the requested Python through uv and pick the interpreter to use.

    Args:
        version: Python version string such as ``'3.11'``.
        no_gil: When true, request the free-threaded build by appending ``t``
            to the version (e.g. ``'3.13t'``).

    Returns:
        The first column of the matching ``uv python list`` line for a
        uv-managed interpreter, or the plain version spec when no managed
        interpreter could be identified.

    Raises:
        click.BadParameter: If ``no_gil`` is requested for a version older
            than 3.13 (or one that cannot be parsed).
    """
    if no_gil:
        # Compare the numeric (major, minor) instead of matching a fixed list
        # of prefixes, so 3.16 and later are accepted as "3.13+" too.
        import re
        parsed = re.match(r'(\d+)\.(\d+)', version)
        if parsed is None or (int(parsed.group(1)), int(parsed.group(2))) < (3, 13):
            raise click.BadParameter("--no-gil requires Python 3.13+")
    # Ensure uv has this Python version installed (not just system Python)
    python_spec = f"{version}t" if no_gil else version
    with console.status(f"Ensuring Python {python_spec} is available..."):
        # Install the Python version via uv (this ensures a managed version, not system)
        run(['uv', 'python', 'install', python_spec], silent=True)
        # Verify it's installed and get the full version
        python_list = run(['uv', 'python', 'list'], capture=True, silent=True)
    # Find the uv-managed Python (not system ones).
    # NOTE(review): the path fragment below is the default uv data directory on
    # Linux/macOS; a relocated uv data dir would fall through to the warning.
    managed_python = None
    for line in python_list.split('\n'):
        if python_spec in line and '.local/share/uv/python/' in line:
            managed_python = line.split()[0]  # Get the full cpython-X.Y.Z specification
            break
    if not managed_python:
        # Fallback: just use the version spec and hope for the best
        console.print(f"[yellow]Warning: Using version spec {python_spec}, may use system Python[/yellow]")
        managed_python = python_spec
    else:
        console.print(f"✓ Using managed Python: {managed_python}", style="dim")
    console.print(f"✓ Python {version} ready", style="green")
    return managed_python
def create_venv(target_dir, python_spec, dry_run):
    """Create the ``.venv`` virtual environment inside ``target_dir`` with uv.

    Args:
        target_dir: ``pathlib.Path`` of the environment root.
        python_spec: Value passed to ``uv venv --python``.
        dry_run: When true, only report the path that would be created.
    """
    venv_location = target_dir / '.venv'
    if not dry_run:
        uv_command = ['uv', 'venv', str(venv_location), '--python', python_spec]
        with console.status("Creating virtual environment..."):
            run(uv_command)
        console.print("✓ Virtual environment created", style="green")
        return
    console.print(f"[dim]Would create venv: {venv_location}[/dim]")
def setup_direnv(target_dir, is_source, dry_run):
    """Write a ``.envrc`` into ``target_dir`` and approve it with ``direnv allow``.

    Args:
        target_dir: ``pathlib.Path`` of the environment root; ``.envrc`` is
            written directly inside it.
        is_source: When true, the generated file also prepends
            ``$PWD/pytorch`` to ``PYTHONPATH``.
        dry_run: When true, only print what would happen; nothing is written
            and direnv is not invoked.
    """
    # The generated file activates the virtual environment at .venv/.
    envrc = "# Auto-generated by setup_pytorch_dev.py\nsource .venv/bin/activate\n"
    if is_source:
        envrc += "export PYTHONPATH=$PWD/pytorch:$PYTHONPATH\n"
    # Add touch test.py for compatibility with old scripts
    # NOTE(review): assumed to apply to both install modes - confirm the
    # nesting of this line against the original file.
    envrc += "touch test.py\n"
    if dry_run:
        console.print(f"[dim]Would create .envrc[/dim]")
        return
    (target_dir / '.envrc').write_text(envrc)
    # direnv will not load an .envrc until it has been explicitly allowed.
    run(['direnv', 'allow', str(target_dir)])
    console.print("✓ Direnv configured", style="green")
def install_binary(target_dir, cuda, dry_run):
    """Install the nightly PyTorch wheel and basic dev tools into the venv.

    Also writes a small ``test.py`` smoke script into ``target_dir`` that
    prints the torch version and CUDA availability.

    Args:
        target_dir: ``pathlib.Path`` of the environment root (contains ``.venv``).
        cuda: Build variant: ``'cpu'``, a PyTorch tag such as ``'cu121'``, or
            a plain CUDA version such as ``'12.1'`` (converted to ``'cu121'``).
        dry_run: When true, only list what would be installed.
    """
    # The nightly index is organised by variant tag; accept a bare CUDA
    # version as well, since that is what the --cuda help text advertises.
    variant = f"cu{cuda.replace('.', '')}" if cuda[:1].isdigit() else cuda
    index = f'https://download.pytorch.org/whl/nightly/{variant}'
    venv = str(target_dir / '.venv')
    packages = [
        # The URL is a package index root, not a flat page of wheel links, so
        # it must be given as an index URL rather than via -f/--find-links.
        (['--pre', 'torch', '--index-url', index], 'torch'),
        (['ipython', 'numpy', 'pytest'], 'dev tools'),
    ]
    for args, name in track(packages, description="Installing packages..."):
        if dry_run:
            console.print(f"[dim]Would install {name}[/dim]")
        else:
            # VIRTUAL_ENV tells uv which environment to install into.
            run(['uv', 'pip', 'install'] + args, env={'VIRTUAL_ENV': venv})
    if dry_run:
        # Nothing was installed, so skip the smoke script and success message.
        return
    (target_dir / 'test.py').write_text(
        'import torch\nprint(f"PyTorch {torch.__version__}")\n'
        'print(f"CUDA: {torch.cuda.is_available()}")\n'
    )
    console.print("✓ PyTorch binary installed", style="green")
def _prepare_reference_repo(reference_repo, dry_run):
    """Create the local reference clone, or refresh it if it already exists."""
    if reference_repo.exists():
        # Update reference repo to get latest objects
        with console.status("Updating reference repository..."):
            if not dry_run:
                run(['git', 'pull'], cwd=reference_repo, silent=True)
                run(['git', 'submodule', 'update', '--init', '--recursive'],
                    cwd=reference_repo, silent=True)
                # Ensure git maintenance is enabled
                run(['git', 'maintenance', 'start'], cwd=reference_repo, silent=True)
        console.print(f"✓ Using reference repository: {reference_repo}", style="dim")
        return
    # Create the reference repository on first clone
    console.print("[yellow]First-time setup: creating reference repository...[/yellow]")
    if dry_run:
        return
    reference_repo.mkdir(parents=True, exist_ok=True)
    with console.status("Creating reference repository (one-time operation)..."):
        # Clone with all submodules
        run(['git', 'clone', '--recurse-submodules',
             'git@github.com:pytorch/pytorch.git', str(reference_repo)])
        # Enable git maintenance
        run(['git', 'maintenance', 'start'], cwd=reference_repo)


def _read_build_config():
    """Return the text of the BUILD_CONFIG shell alias, or '' if it is undefined."""
    # Both probes finish with status 0 even when the alias is missing. run()
    # turns any non-zero status into a generic ClickException, which would
    # pre-empt the specific "alias not found" error raised by the caller.
    result = run(
        ['bash', '-i', '-c',
         'shopt -s expand_aliases; alias BUILD_CONFIG &>/dev/null && echo "$BUILD_CONFIG" || true'],
        capture=True,
        silent=True
    )
    if not result or result == '$BUILD_CONFIG':
        # Nothing came back from the variable; print the alias body itself.
        alias_probe = '\n'.join([
            'shopt -s expand_aliases',
            'alias_def=$(alias BUILD_CONFIG 2>/dev/null)',
            'if [ -z "$alias_def" ]; then',
            '    exit 0',
            'fi',
            '# Extract command from alias definition',
            'eval "echo ${alias_def#*=}"',
        ])
        result = run(['bash', '-i', '-c', alias_probe], capture=True, silent=True)
    return result


def _build_environment():
    """Return a copy of os.environ extended with the KEY=VALUE pairs of BUILD_CONFIG."""
    import shlex

    build_env = os.environ.copy()
    result = _read_build_config()
    if not result:
        raise click.ClickException(
            "BUILD_CONFIG alias not found. Please define it in your shell configuration."
        )
    # The result is a string like: CFLAGS='...' USE_CUDA=0 BUILD_TEST=0 ...
    try:
        parts = shlex.split(result)
    except ValueError:
        # shlex rejects unbalanced quotes; fall back to whitespace splitting.
        console.print("[yellow]Warning: Complex alias parsing, using simple mode[/yellow]")
        for part in result.split():
            if '=' in part and not part.startswith("'") and not part.startswith('"'):
                key, value = part.split('=', 1)
                build_env[key] = value.strip("'\"")
    else:
        for part in parts:
            if '=' in part:
                key, value = part.split('=', 1)
                build_env[key] = value
    return build_env


def install_source(target_dir, cuda, personal_remote, remote_name, dry_run):
    """Clone PyTorch into ``target_dir/pytorch`` and build it in develop mode.

    Steps: refresh (or create) the reference clone under
    ``~/local/pytorch/reference``, clone PyTorch against it, add the personal
    remote, install the Python requirements into the venv, then run
    ``setup.py develop`` with the environment described by the user's
    ``BUILD_CONFIG`` shell alias.

    Args:
        target_dir: ``pathlib.Path`` of the environment root (contains ``.venv``).
        cuda: Accepted for symmetry with ``install_binary``; not used by this
            function.
        personal_remote: URL of an extra git remote to add; skipped when falsy.
        remote_name: Name given to the personal remote.
        dry_run: When true, nothing is cloned, installed or built.

    Raises:
        click.ClickException: If a command fails or the ``BUILD_CONFIG`` alias
            is not defined.
    """
    pytorch_dir = target_dir / 'pytorch'
    venv = str(target_dir / '.venv')
    # Use a reference repository for faster cloning
    reference_repo = Path.home() / 'local' / 'pytorch' / 'reference'

    _prepare_reference_repo(reference_repo, dry_run)
    if dry_run:
        # Every remaining step changes the filesystem, so stop here rather
        # than report a build that never happened.
        console.print("[dim]Would clone and build PyTorch from source[/dim]")
        return

    # Clone main repository with reference to share all objects (including submodules)
    with console.status("Cloning PyTorch with submodules..."):
        run(['git', 'clone', '--recurse-submodules', '--reference', str(reference_repo),
             'git@github.com:pytorch/pytorch.git', str(pytorch_dir)])
    if personal_remote:
        run(['git', 'remote', 'add', remote_name, personal_remote], cwd=pytorch_dir)

    # Enable git maintenance for faster operations
    run(['git', 'maintenance', 'start'], cwd=pytorch_dir)
    console.print("✓ Git maintenance enabled", style="green")

    # Requirements
    with console.status("Installing dependencies..."):
        run(['uv', 'pip', 'install', '-r', str(pytorch_dir / 'requirements.txt')],
            env={'VIRTUAL_ENV': venv})
        run(['uv', 'pip', 'install', 'ipython', 'hypothesis', 'ninja', 'pytest'],
            env={'VIRTUAL_ENV': venv})

    # Build
    with console.status("Building PyTorch (this will take a while)..."):
        build_env = _build_environment()
        # Print build environment
        console.print("\n[cyan]Build environment:[/cyan]")
        build_vars = {k: v for k, v in build_env.items()
                      if k.startswith(('USE_', 'BUILD_', 'MAX_JOBS', 'CFLAGS'))}
        for key in sorted(build_vars):
            console.print(f"  {key}={build_vars[key]}", style="dim")
        console.print()
        run([str(target_dir / '.venv' / 'bin' / 'python'), 'setup.py', 'develop'],
            cwd=pytorch_dir, env=build_env)
    console.print("✓ PyTorch built from source", style="green")
def validate(target_dir, dry_run):
    """Query the new venv for its Python and torch versions and print a summary.

    Args:
        target_dir: ``pathlib.Path`` of the environment root (contains ``.venv``).
        dry_run: When true, only announce that validation would run.
    """
    if dry_run:
        console.print("[dim]Would validate installation[/dim]")
        return

    interpreter = str(target_dir / '.venv' / 'bin' / 'python')

    def query(snippet):
        # Evaluate a one-liner with the venv's interpreter and return its output.
        return run([interpreter, '-c', snippet], capture=True, silent=True)

    summary_rows = (
        ("Location", str(target_dir)),
        ("Python", query('import sys; print(sys.version.split()[0])')),
        ("PyTorch", query('import torch; print(torch.__version__)')),
    )

    summary = Table(show_header=False, title="Installation Summary")
    for column_style in ("cyan", "green"):
        summary.add_column(style=column_style)
    for label, value in summary_rows:
        summary.add_row(label, value)

    console.print()
    console.print(summary)
    console.print()
    next_steps = Panel(f"cd {target_dir}\n# Auto-activates via direnv",
                       title="Next Steps", border_style="green")
    console.print(next_steps)
@click.command()
@click.argument('target_dir', type=click.Path())
@click.option('-v', '--version', 'python_version', default='3.11', help='Python version')
@click.option('--mode', type=click.Choice(['release', 'debug']), default='release')
@click.option('--debug', is_flag=True, help='Debug mode (shortcut for --mode debug)')
@click.option('--binary', is_flag=True, help='Use binary install (default: source)')
@click.option('-c', '--cuda', default='cpu', help='CUDA version or "cpu"')
@click.option('--no-gil', is_flag=True, help='Disable GIL (Python 3.13+ only)')
@click.option('--personal-remote', default='git@github.com:albanD/pytorch.git', help='Personal git remote URL')
@click.option('--remote-name', default='alban', help='Personal remote name')
@click.option('--force', is_flag=True, help='Force cleanup without prompt')
@click.option('--dry-run', is_flag=True, help='Show what would be done')
def main(target_dir, python_version, mode, debug, binary, cuda, no_gil,
         personal_remote, remote_name, force, dry_run):
    """Setup PyTorch development environment with uv and direnv."""
    # The docstring above is shown by click as the --help text, so the
    # developer-facing notes live in comments instead.
    #
    # Runs the six setup steps in order, then validation. On any failure the
    # user is offered removal of the partially populated target directory
    # (never with --force or --dry-run) and the error is re-raised for click.
    target = Path(target_dir).resolve()
    # Handle debug shortcut
    if debug:
        mode = 'debug'
    # NOTE(review): `mode` is not passed to any step below - confirm whether
    # the build is meant to depend on it.
    # Invert binary flag to get source
    source = not binary
    try:
        console.print("\n[bold cyan]PyTorch Dev Environment Setup[/bold cyan]\n")
        console.rule("[1/6] Prerequisites")
        check_prerequisites()
        console.rule("[2/6] Directory")
        if dry_run:
            # Preparing the directory deletes existing contents and creates
            # the folder; a dry run must leave the filesystem untouched.
            console.print(f"[dim]Would prepare {target}[/dim]")
        else:
            handle_directory(target, force)
        console.rule("[3/6] Python")
        python_spec = setup_python(python_version, no_gil)
        console.rule("[4/6] Virtual Environment")
        create_venv(target, python_spec, dry_run)
        console.rule("[5/6] Direnv")
        setup_direnv(target, source, dry_run)
        console.rule("[6/6] PyTorch")
        if source:
            install_source(target, cuda, personal_remote, remote_name, dry_run)
        else:
            install_binary(target, cuda, dry_run)
        console.rule("Validation")
        validate(target, dry_run)
        console.print("\n[bold green]✓ Setup complete![/bold green]\n")
    except click.Abort:
        console.print("[yellow]Cancelled by user[/yellow]")
        raise
    except Exception as e:
        console.print(f"\n[red]Error: {e}[/red]")
        if target.exists() and not force and not dry_run:
            if Confirm.ask("Remove partial installation?", default=False):
                shutil.rmtree(target)
        raise
# Invoke the click command only when this file is executed as a script.
if __name__ == '__main__':
    main()