This project allows any researcher to start pre-training a fully aligned RWKV v7 model within 15 minutes. Of course, this does not include the time to download the data :)
All code is sourced from the original RWKV-LM project: https://github.com/BlinkDL/RWKV-LM
This repository is suitable for quickly reproducing small-scale RWKV v7 models (e.g., 191M to 3B parameters) on NVIDIA GPUs, using either the sample data or your own private data. Next, we will focus on the following improvements:
- Provide template code for RWKV series models for tasks such as multimodal applications.
- Provide cross-platform kernel implementations.
- Provide a configurable RWKV Layer class.
- Provide a high-performance PyTorch inference implementation.
- Provide a cluster training framework and scripts suitable for models from 3B to 70B.
We love and give back to the open-source community and appreciate its contributions. If you find any issues in this repository, including but not limited to code quality, code style, readability, or numerical precision errors, you are welcome to open an issue.
Warning
Note: This is a work in progress (very likely correct, and more efficient than the original). You can still use RWKV-LM as the reference implementation.
To prepare the environment, please use a conda-compatible package manager like miniforge to create a new environment.
conda create -n rwkv-lm-v7 python=3.12
conda activate rwkv-lm-v7
Next, install the following dependencies. Please note that pytorch-lightning is fixed at version 1.9.5. This is a specific requirement for this repository; do not upgrade this package.
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
pip3 install -r requirements.txt
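After installation, a quick sanity check can confirm that a CUDA-enabled PyTorch and the pinned pytorch-lightning version are in place (a minimal sketch; the exact CUDA and driver details depend on your machine):

```python
# Quick environment sanity check (illustrative; adjust as needed).
import torch
import pytorch_lightning as pl

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("pytorch-lightning:", pl.__version__)  # expected: 1.9.5
```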
wget --continue -O data/minipile.idx https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.idx
wget --continue -O data/minipile.bin https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.bin
- Initialize an empty RWKV-7 model:
  sh ./demo-training-prepare.sh
- Log in to your WandB account (see the note after this list).
- Start training:
  sh ./demo-training-run.sh
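If you have not authenticated with WandB on this machine before, one way to do it (assuming the standard wandb client is installed via requirements.txt) is:

```python
# Log in to WandB interactively; this will prompt for your API key.
import wandb

wandb.login()
```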
This section contains explanations of model initialization, learning rates, and other details.
RWKV-7 uses initializations that are both theoretically designed with mathematical proof and empirically derived from training results to accelerate model convergence and improve performance.
| params | symbol | 0.1B | 0.4B | 1.5B | 2.9B | 7.2B | 13.3B |
|---|---|---|---|---|---|---|---|
| D_DECAY_LORA | w | 64 | 64 | 96 | 96 | 128 | 192 |
| D_AAA_LORA | a | 64 | 64 | 96 | 96 | 128 | 192 |
| D_MV_LORA | v | 32 | 32 | 64 | 64 | 96 | 128 |
| D_GATE_LORA | g | 128 | 128 | 256 | 320 | 480 | 384 |
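For a model width not listed above, the w/a/v dimensions appear to scale with the square root of the width, rounded to a multiple of 32. The sketch below reproduces the 0.1B through 7.2B columns under the assumption that those models use n_embd of 768, 1024, 2048, 2560, and 4096; treat the constants as inferred approximations and check the model code for the values actually used. The gate dimension does not appear to follow the same scaling, so take it from the table or the code.

```python
# Heuristic LoRA ranks inferred from the table above (an assumption,
# not the authoritative formula). n_embd is the model width C.
def suggested_lora_dims(n_embd: int) -> dict:
    def round32(x: float) -> int:
        return max(32, int(round(x / 32) * 32))
    return {
        "D_DECAY_LORA": round32(1.8 * n_embd ** 0.5),  # w
        "D_AAA_LORA":   round32(1.8 * n_embd ** 0.5),  # a
        "D_MV_LORA":    round32(1.3 * n_embd ** 0.5),  # v
    }

print(suggested_lora_dims(2048))
# 1.5B column: {'D_DECAY_LORA': 96, 'D_AAA_LORA': 96, 'D_MV_LORA': 64}
```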
This type of penalty prevents the model from becoming overconfident, thereby mitigating precision loss in BF16.
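As a sketch of how such a penalty can be implemented (modeled on the L2Wrap trick from the upstream RWKV-LM code; the strength constant here is illustrative), the idea is to leave the loss value untouched while adding a small gradient that pushes the largest logit at each position toward zero:

```python
import torch

class LogitPenalty(torch.autograd.Function):
    """Pass the loss through unchanged, but add a tiny gradient that
    shrinks the largest logit at each position (keeps BF16 logits sane)."""
    @staticmethod
    def forward(ctx, loss, logits):
        ctx.save_for_backward(logits)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (logits,) = ctx.saved_tensors
        # Strength is illustrative; scale by batch*time so it is size-independent.
        factor = 1e-4 / (logits.shape[0] * logits.shape[1])
        maxx, ids = torch.max(logits, dim=-1, keepdim=True)
        grad_logits = torch.zeros_like(logits)
        grad_logits.scatter_(-1, ids, maxx * factor)
        return grad_output, grad_logits

# usage: loss = LogitPenalty.apply(loss, logits)
```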
Please pay close attention to the learning rate and the related settings described in this section.
self.k_k = nn.Parameter(torch.zeros(1, 1, C)+0.71 - linear*0.1)
self.k_a = nn.Parameter(torch.zeros(1, 1, C)+1.02)

RWKV-7 weight example for 1.5B (L24-D2048, vocab 65536):
| name | shape | comment | initialization |
|---|---|---|---|
| emb.weight | [65536, 2048] | wdecay | see code |
| blocks.0.ln0.weight | [2048] | for layer 0 | 1 |
| blocks.0.ln0.bias | [2048] | for layer 0 | 0 |
| blocks.*.ln1.weight | [2048] | | 1 |
| blocks.*.ln1.bias | [2048] | | 0 |
| blocks.*.att.x_r | [1, 1, 2048] | | see code |
| blocks.*.att.x_w | [1, 1, 2048] | | see code |
| blocks.*.att.x_k | [1, 1, 2048] | | see code |
| blocks.*.att.x_v | [1, 1, 2048] | | see code |
| blocks.*.att.x_a | [1, 1, 2048] | | see code |
| blocks.*.att.x_g | [1, 1, 2048] | | see code |
| blocks.*.att.w0 | [1, 1, 2048] | lr 2x | see code |
| blocks.*.att.w1 | [2048, 96] | | 0 |
| blocks.*.att.w2 | [96, 2048] | | see code |
| blocks.*.att.a0 | [1, 1, 2048] | | 0 |
| blocks.*.att.a1 | [2048, 96] | | 0 |
| blocks.*.att.a2 | [96, 2048] | | see code |
| blocks.*.att.v0 | [1, 1, 2048] | for layer 1+ | 1 |
| blocks.*.att.v1 | [2048, 64] | for layer 1+ | 0 |
| blocks.*.att.v2 | [64, 2048] | for layer 1+ | see code |
| blocks.*.att.g1 | [2048, 256] | | 0 |
| blocks.*.att.g2 | [256, 2048] | | see code |
| blocks.*.att.k_k | [1, 1, 2048] | | 1 |
| blocks.*.att.k_a | [1, 1, 2048] | | 1 |
| blocks.*.att.r_k | [32, 64] | | 0 |
| blocks.*.att.receptance.weight | [2048, 2048] | wdecay | see code |
| blocks.*.att.key.weight | [2048, 2048] | wdecay | see code |
| blocks.*.att.value.weight | [2048, 2048] | wdecay | see code |
| blocks.*.att.output.weight | [2048, 2048] | wdecay | 0 |
| blocks.*.att.ln_x.weight | [2048] | | see code |
| blocks.*.att.ln_x.bias | [2048] | | 0 |
| blocks.*.ln2.weight | [2048] | | 1 |
| blocks.*.ln2.bias | [2048] | | 0 |
| blocks.*.ffn.x_k | [1, 1, 2048] | | see code |
| blocks.*.ffn.key.weight | [8192, 2048] | wdecay | see code |
| blocks.*.ffn.value.weight | [2048, 8192] | wdecay | 0 |
| ln_out.weight | [2048] | | 1 |
| ln_out.bias | [2048] | | 0 |
| head.weight | [65536, 2048] | wdecay | see code |
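To check that an initialized or trained checkpoint matches the table above, you can dump its parameter names and shapes (the path below is illustrative; point it at the .pth file inside your own out/ project directory):

```python
import torch

# Illustrative path: use the checkpoint produced by demo-training-prepare.sh.
state_dict = torch.load("out/YOUR_PROJ_DIR/rwkv-init.pth", map_location="cpu")
for name, tensor in state_dict.items():
    print(f"{name:40s} {tuple(tensor.shape)}")
```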
Your out/....../train_log.txt should show losses similar to:
0 4.875856 131.0863 0.00059975 2025-04-24 02:23:42.481256 0
1 4.028621 56.1834 0.00059899 2025-04-24 02:28:16.674463 1
2 3.801625 44.7739 0.00059773 2025-04-24 02:32:51.059568 2
3 3.663070 38.9808 0.00059597 2025-04-24 02:37:25.409892 3
4 3.578974 35.8368 0.00059371 2025-04-24 02:41:59.711315 4
5 3.510906 33.4786 0.00059096 2025-04-24 02:46:33.990839 5
6 3.462345 31.8917 0.00058771 2025-04-24 02:51:08.378331 6
7 3.412196 30.3318 0.00058399 2025-04-24 02:55:42.927474 7
8 3.376724 29.2747 0.00057978 2025-04-24 03:00:17.504665 8
9 3.336911 28.1321 0.00057511 2025-04-24 03:04:52.006063 9
10 3.313411 27.4787 0.00056999 2025-04-24 03:09:27.563336 10
11 3.295895 27.0016 0.00056441 2025-04-24 03:14:01.786079 11
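Judging from the sample above, each line holds the epoch index, training loss, exp(loss), current learning rate, a timestamp, and the epoch index again. A small sketch for pulling the loss curve out of the log (the path is illustrative):

```python
# Parse train_log.txt and print the loss per epoch (illustrative path).
losses = []
with open("out/YOUR_PROJ_DIR/train_log.txt") as f:
    for line in f:
        parts = line.split()
        if len(parts) >= 4:
            epoch, loss = int(parts[0]), float(parts[1])
            losses.append((epoch, loss))

for epoch, loss in losses:
    print(f"epoch {epoch}: loss {loss:.4f}")
```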
Use the data/make_data.py script to convert your training data from .jsonl format to binidx format.
python data/make_data.py [input_file] [n_epoch] [ctx_len]
# example:
cd data/
python make_data.py demo.jsonl 3 4096
This command will:
- shuffle & duplicate demo.jsonl (for 3 epochs)
- load jsonl and tokenize
- save as demo.bin & demo.idx
- compute "magic_prime" for ctxlen 4096
Assume your source jsonl is:
- {"text":"aa"}
- {"text":"bb"}
- {"text":"cc"}
- {"text":"dd"}
The final binidx will be like (here "/" means end_of_doc, which is actually token [0]): bb/aa/dd/cc/dd/aa/bb/cc/dd/bb/cc/aa/
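For reference, a minimal way to produce such a .jsonl input (one JSON object with a "text" field per line; the file name is just an example):

```python
import json

# Write a tiny demo .jsonl: one {"text": ...} object per line.
docs = ["aa", "bb", "cc", "dd"]
with open("demo.jsonl", "w", encoding="utf-8") as f:
    for text in docs:
        f.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
```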
Warning
make_data.py will be very slow for large jsonl files; check json2binidx_tool if you need to process a large jsonl.
The data/compute_magic_prime.py script computes the correct values of --my_exit_tokens and --magic_prime for a specified binidx dataset and context length (ctx_len).
- change the `DATA_NAME` and `CTX_LEN` in `data/compute_magic_prime.py` to match your training dataset and context length
- run the script to get the correct values of `--my_exit_tokens` and `--magic_prime`
cd data/
python compute_magic_prime.py
output will be like:
### Loading /home/rwkv/RWKV-LM-V7/data/demo
### /home/rwkv/RWKV-LM-V7/data/demo.bin/idx has 200499 tokens, 546 items. Dtype <class 'numpy.uint16'>
### magic_prime = 47 (for ctxlen 4096)
--my_exit_tokens 200499 --magic_prime 47 --ctx_len 4096
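For reference, magic_prime appears to be the largest prime p with p % 3 == 2 and p <= total_tokens // ctx_len - 1; that convention reproduces the 47 in the example above. The sketch below computes it under that assumption, but treat data/compute_magic_prime.py as the authoritative source:

```python
def is_prime(n: int) -> bool:
    if n < 2:
        return False
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True

def magic_prime(total_tokens: int, ctx_len: int) -> int:
    """Largest prime p with p % 3 == 2 and p <= total_tokens // ctx_len - 1 (assumed convention)."""
    p = total_tokens // ctx_len - 1
    while p >= 2:
        if p % 3 == 2 and is_prime(p):
            return p
        p -= 1
    raise ValueError("dataset too small for this ctx_len")

print(magic_prime(200499, 4096))  # 47, matching the example output above
```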