feat: Add ROCm and device-agnostic support #23
Conversation
This change adds support for ROCm and makes the codebase device-agnostic, allowing it to run on different hardware backends including ROCm, CUDA, and CPU. The key changes are:

- Modified `pyproject.toml` to use ROCm-compatible PyTorch wheels and added the `pytorch-triton-rocm` dependency.
- Refactored `nanochat/common.py` to dynamically detect the available hardware and set the device and distributed backend accordingly.
- Updated all training, evaluation, and inference scripts to be device-agnostic, removing hardcoded CUDA references.
- Adapted `speedrun.sh` for single-device execution by replacing `torchrun` with `python`.
- Updated `nanochat/report.py` to provide more generic GPU information.
So far it seems to be working. I'm running it on a Strix Halo 128 GB machine.
Nvm, more work needs to be done.
This commit addresses several runtime errors encountered while running `speedrun.sh` and improves the overall project configuration. The key changes are:

- Patched `nanochat/configurator.py` to be more robust by handling flag-like arguments and ignoring unknown arguments, which resolves the `AssertionError`.
- Fixed the argument handling for `chat_eval.py` in `speedrun.sh` to prevent argument parsing errors.
- Updated `pyproject.toml` to correctly define optional dependencies for development.
This commit fixes a `torch.AcceleratorError: HIP error: invalid device function` that occurred during weight initialization on ROCm devices. It also improves the device detection logic to correctly identify and prioritize the ROCm backend. The key changes are:

- Patched `nanochat/gpt.py` to initialize weights on the CPU before moving them to the target device, which avoids the HIP kernel error (see the sketch below).
- Simplified and corrected the device detection logic in `nanochat/common.py` to ensure the ROCm backend is properly selected when available.
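A minimal sketch of that init-on-CPU-then-move pattern (illustrative only; the function names and init scheme here are assumptions, not the actual code in `nanochat/gpt.py`):

```python
import torch.nn as nn

def _init_weights(module: nn.Module) -> None:
    # Example init scheme; the real nanochat init may differ.
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

def build_model(model_cls, config, device: str) -> nn.Module:
    # Construct and initialize weights on the CPU first; some ROCm setups hit
    # "HIP error: invalid device function" when init kernels run directly on the GPU.
    model = model_cls(config)      # parameters are created on CPU by default
    model.apply(_init_weights)     # run weight init against CPU tensors
    return model.to(device)        # move the fully initialized model to the target device
```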
This commit adds the `HSA_OVERRIDE_GFX_VERSION` environment variable to the `speedrun.sh` script. This is a workaround to enable support for newer AMD GPU architectures (e.g., gfx1151) that are not yet officially supported in the pre-compiled PyTorch ROCm builds. This change also includes an update to the `README.md` to explain this workaround to users.
This change adds the `PYTORCH_CUDA_ALLOC_CONF` environment variable to the main `speedrun.sh` execution script. Setting `expandable_segments:True` is recommended by PyTorch to manage memory more efficiently and prevent fragmentation, addressing a `UserWarning` observed during execution.
Set PYTORCH_CUDA_ALLOC_CONF to prevent memory fragmentation
This commit re-adds the `PYTORCH_CUDA_ALLOC_CONF` environment variable to the training scripts. This setting helps prevent memory fragmentation and is beneficial for both CUDA and ROCm environments. This change was inadvertently removed during a previous refactoring.
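Concretely, the setting from these commits looks like this in a shell script (the variable name and value come from the commits above; the placement is up to the script):

```bash
# Ask PyTorch's caching allocator to use expandable segments,
# which reduces fragmentation on long training runs.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
```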
```bash
# For newer AMD GPUs that are not yet officially supported by PyTorch ROCm builds,
# we can override the detected GPU architecture to a compatible one.
# For example, for a gfx1151 GPU, we can use gfx1100 (11.0.0).
export HSA_OVERRIDE_GFX_VERSION=11.0.0
```
PyTorch should already support gfx1151 via TheRock. This shouldn't be enabled by default IMO.
This could also cause issues when running on gfx9, which @jon-hotaisle was running on.
```toml
[tool.uv.sources]
torch = [
    { index = "pytorch-cu128" },
    { index = "pytorch-rocm63" },
```
ROCm 6.4 is the latest supported one, and there's also ROCm 7. I suggest switching to TheRock wheels for best compatibility https://github.com/ROCm/TheRock/blob/main/RELEASES.md#installing-pytorch-python-packages - these also install ROCm itself as pip wheels, so you don't need to worry about whether the system contains a specific version of ROCm.
EDIT: Currently TheRock wheels are device-specific, so maybe it's best to stick with 6.4/nightly 7.0 wheels from the official pytorch index until TheRock releases with multi-device wheels.
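For reference, switching to the ROCm 6.4 wheels from the official PyTorch index might look roughly like this in `pyproject.toml` (a sketch only; the index name is made up and the URL follows the standard PyTorch wheel-index layout, so verify it before relying on it):

```toml
# Illustrative uv configuration; not the exact contents of this PR's pyproject.toml.
[[tool.uv.index]]
name = "pytorch-rocm64"
url = "https://download.pytorch.org/whl/rocm6.4"
explicit = true

[tool.uv.sources]
torch = [{ index = "pytorch-rocm64" }]
pytorch-triton-rocm = [{ index = "pytorch-rocm64" }]
```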
It seems like it's not supported properly for ROCm, but this shouldn't be a blocker for initial support, I'd say.
@LokiMetaSmith I've pushed a fix for the issues that @jon-hotaisle faced in jammm@e3e21e2
```python
# Detect hardware
if hasattr(torch.version, 'hip') and torch.version.hip and torch.cuda.is_available():
    device_type = "cuda"  # ROCm uses cuda naming in torch
    backend = "rccl"
```
This will throw an error because the backend name is still "nccl" for ROCm; PyTorch automatically maps it to RCCL.
Just edited this to nccl on an MI300X node and it works fine.
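Putting the two comments together (keep the `cuda` device naming, keep the `nccl` backend name), the detection might look roughly like this; a sketch, not the exact code in `nanochat/common.py`:

```python
import torch

def detect_device_and_backend() -> tuple[str, str]:
    """Return (device_type, distributed_backend) for the available hardware. Sketch only."""
    if torch.cuda.is_available():
        # ROCm builds of PyTorch also report GPUs through torch.cuda and keep the
        # "cuda" device naming; torch.version.hip is set only on ROCm builds.
        if getattr(torch.version, "hip", None) is not None:
            print("ROCm build detected; using the 'nccl' backend (mapped to RCCL by PyTorch)")
        return "cuda", "nccl"
    return "cpu", "gloo"
```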
It looks like part of the memory issue is that the Linux kernel doesn't automatically let you assign the full 128 GB to the GPU. I ran this and it cleared the out-of-memory error, but there are still other problems. This is Strix Halo specific. https://www.reddit.com/r/LocalLLaMA/comments/1mib7l9/pytorch_on_rocm_v650rc_gfx1151_amd_strix_halo/
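For context, raising the GPU-accessible (GTT) memory limit is usually done through kernel module parameters at boot; the sketch below shows one documented knob, but the parameter choice and value are illustrative assumptions, so take the actual settings from the linked thread:

```bash
# /etc/default/grub (then run update-grub and reboot).
# amdgpu.gttsize is given in MiB; 110000 (~107 GiB) is an illustrative value only.
GRUB_CMDLINE_LINUX_DEFAULT="quiet splash amdgpu.gttsize=110000"
```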
Also compare the work done by @indianspeedster at master...indianspeedster:nanochat:master. It runs on MI300X but not on MI355X.
I helped build those wheels in the link. They're quite old. Please use the ones from https://github.com/ROCm/TheRock/blob/main/RELEASES.md#torch-for-gfx1151 instead.
The one you used from master...indianspeedster:nanochat:master has a hard-coded `export HSA_OVERRIDE_GFX_VERSION=9.4.2`, which is why it's not running on your MI355. My branch removes that line altogether from `speedrun.sh`. Make sure to do a fresh clone and also delete the `~/.cargo` folder. Also, can you confirm which ROCm version you have installed? If you're on 7.0, there's a good chance it won't work, because the changes here use 6.3/6.4 (my branch uses 6.4); it'll need a minor tweak to `pyproject.toml` to use the nightly wheels that support ROCm 7.
@jammm Thanks, yeah, your branch doesn't work for me. I'm on MI355X, which pretty much requires ROCm 7 for everything. There are a few things to fix:
I'm trying to run in a recent container: `docker.io/rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0`. Since I'm in a container, I don't need to remove `~/.cargo`, as it isn't kept between runs. This is how I start the container; keeping `~/.cache` is definitely helpful.
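The launch command itself wasn't captured above; a hypothetical invocation along these lines would cover the points mentioned (GPU passthrough plus a persistent `~/.cache` mount). Every flag here is an assumption, not the original command:

```bash
# Hypothetical launch command: expose the AMD GPUs to the container and
# bind-mount ~/.cache so downloads persist across container runs.
docker run -it --rm \
  --device=/dev/kfd --device=/dev/dri \
  --group-add video --ipc=host \
  --security-opt seccomp=unconfined \
  -v "$HOME/.cache:/root/.cache" \
  docker.io/rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0
```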
@jon-hotaisle I see. Can you try changing this line https://github.com/jammm/nanochat/blob/rocm-support/pyproject.toml#L37 to …? Delete the …
That's because of the changes in `speedrun.sh`, which replaced the `torchrun` launch with a plain `python` invocation. You could undo those changes to get it working with all 8 GPUs; I did that on the MI300X. Ideally there should be a flag to toggle this.
Awesome. Once that's done (takes a couple of hours or so), you can …
@jon-hotaisle Your VRAM usage is too low: ~84 GB out of 290 GB. You need to tweak values as much as possible so it reaches max VRAM without OOM; otherwise you can't get close to max MFU. The easiest change is to switch the batch size from 32 (the nanochat default) to 64 or 128. If 64 works but 128 OOMs while VRAM is still not full, you can then try increasing the per-batch token count.
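With nanochat's `--key=value` configurator style, bumping the per-device batch size might look like this; the exact flag name is an assumption, so check the config variables in `scripts/base_train.py`:

```bash
# Assumed flag name; verify against the config variables in scripts/base_train.py.
python -m scripts.base_train -- --depth=20 --device_batch_size=64 --run=$WANDB_RUN
```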
@Qubitium Helpful feedback, thank you! The first goal was just to get it running; now we can focus on tuning. Should this be something that is added to the PR? Have it set sane defaults on different architectures too?
Hi @LokiMetaSmith! It looks like a lot of the edits from this PR have already been addressed by #88. There are quite a few merge conflicts as well. I suggest closing this one and opening new PR(s) with any remaining issues you encounter with the latest version of master.

PS: Discussions about different setups and recommendations are nice, and it's good to see how this is being run on different platforms and setups. Please use the forum to continue such conversations!
```diff
 # pretrain the d20 model
-torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
+python -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
```
I don't think we'll want this edit either way - the assumption being that this usually will be run on more than one GPU, and those that have only a single one can just edit this on their own fork.
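If a toggle were still wanted, one lightweight option in `speedrun.sh` would be an environment variable that defaults to the multi-GPU launch; this is a sketch and the variable name is made up:

```bash
# Default to 8 processes; set NPROC_PER_NODE=1 for a single-GPU / single-device run.
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"
if [ "$NPROC_PER_NODE" -gt 1 ]; then
  torchrun --standalone --nproc_per_node="$NPROC_PER_NODE" -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
else
  python -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
fi
```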







