RWKV-LM-V7

Project Introduction

This project allows any researcher to start pre-training a fully aligned RWKV v7 model within 15 minutes. Of course, this does not include the time to download the data :)

All code is sourced from the original RWKV-LM project: https://github.com/BlinkDL/RWKV-LM

This repository is suitable for quickly reproducing small-scale RWKV v7 series models (e.g., 191M to 3B) on NVIDIA GPUs using either sample data or private data. We will focus on the following improvements next:

  • Provide template code for RWKV series models for tasks such as multimodal applications.
  • Provide cross-platform kernel implementations.
  • Provide a configurable RWKV Layer class.
  • Provide a high-performance PyTorch inference implementation.
  • Provide a cluster training framework and scripts suitable for models from 3B to 70B.

We love and give back to the open-source community and appreciate any implementations from it. If you find any issues in our code repository, including but not limited to code quality, code style, code interpretability, or numerical precision errors, you are welcome to submit an issue.

Warning

Note: This repository is a work in progress (very likely correct, and more efficient). You can still use RWKV-LM as the reference implementation.

How to start?

Prepare Environment

To prepare the environment, please use a conda-compatible package manager like miniforge to create a new environment.

conda create -n rwkv-lm-v7 python=3.12
conda activate rwkv-lm-v7

Next, install the following dependencies. Please note that pytorch-lightning is fixed at version 1.9.5. This is a specific requirement for this repository; do not upgrade this package.

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
pip3 install -r requirements.txt
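
Because the pin matters, it can help to confirm the installed version before launching a run. The following is a minimal sketch using only the Python standard library; the helper names `is_pinned` and `installed_lightning_version` are hypothetical and not part of this repository.

```python
from importlib import metadata

REQUIRED = "1.9.5"  # version stated in this README


def is_pinned(installed: str, required: str = REQUIRED) -> bool:
    """Return True when the installed version string matches the pin exactly."""
    return installed.strip() == required


def installed_lightning_version():
    """Look up pytorch-lightning in the active environment, or None if absent."""
    try:
        return metadata.version("pytorch-lightning")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_lightning_version()
    if found is None:
        print("pytorch-lightning is not installed")
    elif is_pinned(found):
        print(f"pytorch-lightning {found} matches the pin")
    else:
        print(f"pytorch-lightning {found} found, expected {REQUIRED}")
```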

Download Data

wget --continue -O data/minipile.idx https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.idx
wget --continue -O data/minipile.bin https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.bin

Start Training

  1. Initialize an empty RWKV-7 model

sh ./demo-training-prepare.sh

  2. Log in to your WandB account

  3. Start training

sh ./demo-training-run.sh

Detailed Explanation

This section contains explanations of model initialization, learning rates, and other details.

RWKV-7 uses initializations that are both theoretically designed with mathematical proof and empirically derived from training results to accelerate model convergence and improve performance.

RWKV G1 model LoRA dimensions

params            0.1B  0.4B  1.5B  2.9B  7.2B  13.3B
D_DECAY_LORA (w)  64    64    96    96    128   192
D_AAA_LORA (a)    64    64    96    96    128   192
D_MV_LORA (v)     32    32    64    64    96    128
D_GATE_LORA (g)   128   128   256   320   480   384
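
When scripting experiments across model sizes, the table can be kept as a small lookup. This is only an illustrative sketch: the values are copied from the table above, while the names `LORA_DIMS` and `lora_dims` are hypothetical and do not come from this repository.

```python
# LoRA dimensions per model size, copied from the table above.
LORA_DIMS = {
    "0.1B":  {"D_DECAY_LORA": 64,  "D_AAA_LORA": 64,  "D_MV_LORA": 32,  "D_GATE_LORA": 128},
    "0.4B":  {"D_DECAY_LORA": 64,  "D_AAA_LORA": 64,  "D_MV_LORA": 32,  "D_GATE_LORA": 128},
    "1.5B":  {"D_DECAY_LORA": 96,  "D_AAA_LORA": 96,  "D_MV_LORA": 64,  "D_GATE_LORA": 256},
    "2.9B":  {"D_DECAY_LORA": 96,  "D_AAA_LORA": 96,  "D_MV_LORA": 64,  "D_GATE_LORA": 320},
    "7.2B":  {"D_DECAY_LORA": 128, "D_AAA_LORA": 128, "D_MV_LORA": 96,  "D_GATE_LORA": 480},
    "13.3B": {"D_DECAY_LORA": 192, "D_AAA_LORA": 192, "D_MV_LORA": 128, "D_GATE_LORA": 384},
}


def lora_dims(size: str) -> dict:
    """Return the LoRA dimensions for a model size such as "1.5B"."""
    try:
        return LORA_DIMS[size]
    except KeyError:
        known = ", ".join(LORA_DIMS)
        raise ValueError(f"unknown model size {size!r}; expected one of: {known}")


if __name__ == "__main__":
    print(lora_dims("1.5B"))
```

The 1.5B entry agrees with the shapes in the 1.5B weight example in this README (w1 is [2048, 96], v1 is [2048, 64], g1 is [2048, 256]).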

L2Warp

L2Warp adds a penalty on large output logits. This type of penalty prevents the model from becoming overconfident, thereby mitigating precision loss in BF16.
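
To make the mechanism concrete, here is a dependency-free sketch of the extra gradient such a penalty contributes. It assumes the behaviour of the corresponding helper in upstream RWKV-LM (named L2Wrap there): during the backward pass, the largest logit at each position receives an additional gradient of `scale / (B * T)` times its own value, with `scale = 1e-4`. Both the scale and the function name `l2warp_extra_grad` are assumptions for illustration; check this repository's model code for the actual implementation.

```python
def l2warp_extra_grad(logits, scale=1e-4):
    """Extra gradient added to the logits by the penalty (sketch).

    logits: nested list shaped [B][T][V] of floats.
    Returns a nested list of the same shape that is zero everywhere
    except at the arg-max of each position, where it equals
    max_logit * scale / (B * T). Adding this to the gradient pulls
    the largest logit toward zero.
    """
    batch = len(logits)
    steps = len(logits[0])
    factor = scale / (batch * steps)
    grad = []
    for seq in logits:
        grad_seq = []
        for row in seq:
            top = max(range(len(row)), key=row.__getitem__)  # index of the largest logit
            g = [0.0] * len(row)
            g[top] = row[top] * factor
            grad_seq.append(g)
        grad.append(grad_seq)
    return grad


if __name__ == "__main__":
    demo = [[[1.0, 5.0, 2.0], [-1.0, -3.0, 0.5]]]  # B=1, T=2, V=3
    print(l2warp_extra_grad(demo))
```

Because the extra term is tiny relative to the cross-entropy gradient, it mainly acts on logits that drift to large magnitudes, which is where BF16 loses the most precision.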

Weights and Initialization Example

Pay close attention to the learning rate and related settings noted in the comment column of the table below (for example, wdecay and lr 2x).

self.k_k = nn.Parameter(torch.zeros(1, 1, C)+0.71 - linear*0.1)
self.k_a = nn.Parameter(torch.zeros(1, 1, C)+1.02)
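
The two lines above rely on a per-channel tensor `linear` that is defined elsewhere in the model code. As a rough illustration of the resulting values, the sketch below assumes `linear[n] = n / (C - 1) - 0.5` (a ramp from -0.5 to 0.5, as in upstream RWKV-LM) and uses plain Python lists instead of tensors. The function names are hypothetical.

```python
def linear_ramp(C):
    """Assumed definition of `linear`: n / (C - 1) - 0.5 for channel n."""
    return [n / (C - 1) - 0.5 for n in range(C)]


def k_k_init(C):
    """Per-channel initial values of k_k: 0.71 - linear * 0.1."""
    return [0.71 - x * 0.1 for x in linear_ramp(C)]


def k_a_init(C):
    """Per-channel initial values of k_a: a constant 1.02."""
    return [1.02] * C


if __name__ == "__main__":
    C = 5  # tiny channel count, for display only
    print([round(v, 4) for v in k_k_init(C)])
    print(k_a_init(C))
```

Under that assumption, k_k starts at about 0.76 in the first channel and falls linearly to about 0.66 in the last, while k_a is uniform.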

RWKV-7 weight example for 1.5B (L24-D2048, vocab 65536):

name                            shape          comment       initialization
emb.weight                      [65536, 2048]  wdecay        see code
blocks.0.ln0.weight             [2048]         for layer 0   1
blocks.0.ln0.bias               [2048]         for layer 0   0
blocks.*.ln1.weight             [2048]                       1
blocks.*.ln1.bias               [2048]                       0
blocks.*.att.x_r                [1, 1, 2048]                 see code
blocks.*.att.x_w                [1, 1, 2048]                 see code
blocks.*.att.x_k                [1, 1, 2048]                 see code
blocks.*.att.x_v                [1, 1, 2048]                 see code
blocks.*.att.x_a                [1, 1, 2048]                 see code
blocks.*.att.x_g                [1, 1, 2048]                 see code
blocks.*.att.w0                 [1, 1, 2048]   lr 2x         see code
blocks.*.att.w1                 [2048, 96]                   0
blocks.*.att.w2                 [96, 2048]                   see code
blocks.*.att.a0                 [1, 1, 2048]                 0
blocks.*.att.a1                 [2048, 96]                   0
blocks.*.att.a2                 [96, 2048]                   see code
blocks.*.att.v0                 [1, 1, 2048]   for layer 1+  1
blocks.*.att.v1                 [2048, 64]     for layer 1+  0
blocks.*.att.v2                 [64, 2048]     for layer 1+  see code
blocks.*.att.g1                 [2048, 256]                  0
blocks.*.att.g2                 [256, 2048]                  see code
blocks.*.att.k_k                [1, 1, 2048]                 1
blocks.*.att.k_a                [1, 1, 2048]                 1
blocks.*.att.r_k                [32, 64]                     0
blocks.*.att.receptance.weight  [2048, 2048]   wdecay        see code
blocks.*.att.key.weight         [2048, 2048]   wdecay        see code
blocks.*.att.value.weight       [2048, 2048]   wdecay        see code
blocks.*.att.output.weight      [2048, 2048]   wdecay        0
blocks.*.att.ln_x.weight        [2048]                       see code
blocks.*.att.ln_x.bias          [2048]                       0
blocks.*.ln2.weight             [2048]                       1
blocks.*.ln2.bias               [2048]                       0
blocks.*.ffn.x_k                [1, 1, 2048]                 see code
blocks.*.ffn.key.weight         [8192, 2048]   wdecay        see code
blocks.*.ffn.value.weight       [2048, 8192]   wdecay        0
ln_out.weight                   [2048]                       1
ln_out.bias                     [2048]                       0
head.weight                     [65536, 2048]  wdecay        see code
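
As a sanity check on the table, the shapes can be summed to see how they add up to the stated model size. The sketch below assumes the table is exhaustive, that ln0 exists only in layer 0, and that the v0/v1/v2 tensors exist only in layers 1 and above (as the comment column suggests). The function names are hypothetical.

```python
from math import prod

L, D, V = 24, 2048, 65536  # layers, width, vocab size (L24-D2048, vocab 65536)


def block_shapes(layer):
    """Shapes of one block, following the table row by row."""
    shapes = []
    if layer == 0:
        shapes += [(D,), (D,)]                   # ln0.weight, ln0.bias
    shapes += [(D,), (D,)]                       # ln1.weight, ln1.bias
    shapes += [(1, 1, D)] * 6                    # att.x_r, x_w, x_k, x_v, x_a, x_g
    shapes += [(1, 1, D), (D, 96), (96, D)]      # att.w0, w1, w2
    shapes += [(1, 1, D), (D, 96), (96, D)]      # att.a0, a1, a2
    if layer >= 1:
        shapes += [(1, 1, D), (D, 64), (64, D)]  # att.v0, v1, v2
    shapes += [(D, 256), (256, D)]               # att.g1, g2
    shapes += [(1, 1, D), (1, 1, D), (32, 64)]   # att.k_k, k_a, r_k
    shapes += [(D, D)] * 4                       # receptance, key, value, output
    shapes += [(D,), (D,)]                       # att.ln_x.weight, att.ln_x.bias
    shapes += [(D,), (D,)]                       # ln2.weight, ln2.bias
    shapes += [(1, 1, D)]                        # ffn.x_k
    shapes += [(4 * D, D), (D, 4 * D)]           # ffn.key.weight, ffn.value.weight
    return shapes


def count_params():
    """Total number of parameters implied by the table."""
    shapes = [(V, D)]                            # emb.weight
    for layer in range(L):
        shapes += block_shapes(layer)
    shapes += [(D,), (D,), (V, D)]               # ln_out.weight, ln_out.bias, head.weight
    return sum(prod(s) for s in shapes)


if __name__ == "__main__":
    print(f"{count_params():,} parameters")
```

Under these assumptions the total comes to roughly 1.53 billion parameters, consistent with the "1.5B" label.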

Check Result

Your out/....../train_log.txt should show losses similar to:

0 4.875856 131.0863 0.00059975 2025-04-24 02:23:42.481256 0
1 4.028621 56.1834 0.00059899 2025-04-24 02:28:16.674463 1
2 3.801625 44.7739 0.00059773 2025-04-24 02:32:51.059568 2
3 3.663070 38.9808 0.00059597 2025-04-24 02:37:25.409892 3
4 3.578974 35.8368 0.00059371 2025-04-24 02:41:59.711315 4
5 3.510906 33.4786 0.00059096 2025-04-24 02:46:33.990839 5
6 3.462345 31.8917 0.00058771 2025-04-24 02:51:08.378331 6
7 3.412196 30.3318 0.00058399 2025-04-24 02:55:42.927474 7
8 3.376724 29.2747 0.00057978 2025-04-24 03:00:17.504665 8
9 3.336911 28.1321 0.00057511 2025-04-24 03:04:52.006063 9
10 3.313411 27.4787 0.00056999 2025-04-24 03:09:27.563336 10
11 3.295895 27.0016 0.00056441 2025-04-24 03:14:01.786079 11
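
The columns are not labelled in the log. Judging from the values, they appear to be: epoch, loss, perplexity (exp of the loss), learning rate, timestamp, and epoch again. The sketch below parses a line under that assumed layout and checks that the third column really is exp(loss); the function names are hypothetical.

```python
import math


def parse_log_line(line):
    """Split one train_log.txt line into named fields (assumed layout)."""
    parts = line.split()
    return {
        "epoch": int(parts[0]),
        "loss": float(parts[1]),
        "perplexity": float(parts[2]),
        "lr": float(parts[3]),
        "timestamp": parts[4] + " " + parts[5],
    }


def is_consistent(row, rel_tol=1e-3):
    """True when the perplexity column matches exp(loss) within tolerance."""
    return math.isclose(math.exp(row["loss"]), row["perplexity"], rel_tol=rel_tol)


def losses_decrease(rows):
    """True when the loss falls strictly from each logged epoch to the next."""
    return all(b["loss"] < a["loss"] for a, b in zip(rows, rows[1:]))


if __name__ == "__main__":
    sample = [
        "0 4.875856 131.0863 0.00059975 2025-04-24 02:23:42.481256 0",
        "1 4.028621 56.1834 0.00059899 2025-04-24 02:28:16.674463 1",
    ]
    rows = [parse_log_line(s) for s in sample]
    print(all(is_consistent(r) for r in rows), losses_decrease(rows))
```

A run whose loss is far above these values after the first few epochs, or that stops decreasing, is worth investigating before spending more compute.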

Processing training data

Convert jsonl to binidx format

Use the data/make_data.py script to convert your training data from .jsonl format to binidx format.

python data/make_data.py [input_file] [n_epoch] [ctx_len]
# example:
cd data/
python make_data.py demo.jsonl 3 4096

This command will:

  • shuffle & duplicate demo.jsonl (for 3 epochs)
  • load jsonl and tokenize
  • save as demo.bin & demo.idx
  • compute "magic_prime" for ctxlen 4096

Assume your source jsonl is:

  • {"text":"aa"}
  • {"text":"bb"}
  • {"text":"cc"}
  • {"text":"dd"}

The final binidx will look like this (here "/" means end_of_doc, which is actually token [0]): bb/aa/dd/cc/dd/aa/bb/cc/dd/bb/cc/aa/
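
The shuffle-and-duplicate step can be pictured with the following sketch. It is not the code in make_data.py: it works on already tokenized documents, uses toy token ids, and the function name `build_token_stream` is hypothetical. It only illustrates the layout described above, where each epoch is an independent permutation of the documents and every document is followed by token 0.

```python
import random

END_OF_DOC = 0  # separator token, per the description above


def build_token_stream(docs, n_epoch, seed=None):
    """Concatenate n_epoch independently shuffled copies of docs.

    docs: list of token-id lists, one per document.
    Returns a flat list of token ids with END_OF_DOC after each document.
    """
    rng = random.Random(seed)
    stream = []
    for _ in range(n_epoch):
        order = list(docs)   # copy so the caller's list is not reordered
        rng.shuffle(order)   # fresh permutation for every epoch
        for tokens in order:
            stream.extend(tokens)
            stream.append(END_OF_DOC)
    return stream


if __name__ == "__main__":
    toy_docs = [[11], [22], [33], [44]]  # stand-ins for "aa", "bb", "cc", "dd"
    print(build_token_stream(toy_docs, n_epoch=3, seed=0))
```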

Warning

make_data.py is very slow on large jsonl files; use json2binidx_tool if you need to process them.

Compute magic_prime for specified binidx dataset

The data/compute_magic_prime.py script computes the correct values of --my_exit_tokens and --magic_prime for a specified binidx dataset and context length (ctx_len).

  • Change DATA_NAME and CTX_LEN in data/compute_magic_prime.py to match your training dataset and context length.
  • Run the script to get the correct values of --my_exit_tokens and --magic_prime:

cd data/
python compute_magic_prime.py

The output will look like:

### Loading /home/rwkv/RWKV-LM-V7/data/demo

### /home/rwkv/RWKV-LM-V7/data/demo.bin/idx has 200499 tokens, 546 items. Dtype <class 'numpy.uint16'>

### magic_prime = 47 (for ctxlen 4096)

--my_exit_tokens 200499 --magic_prime 47 --ctx_len 4096
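
For intuition about where the number comes from, here is a sketch of the selection rule as used in upstream RWKV-LM: take the largest prime of the form 3n + 2 that does not exceed (number of tokens // ctx_len) - 1. Treat the rule as an assumption about this repository's script; the function names are hypothetical. It does reproduce the value in the sample output above.

```python
def is_prime(n):
    """Trial-division primality test, sufficient for chunk counts."""
    if n < 2:
        return False
    if n < 4:
        return True            # 2 and 3
    if n % 2 == 0 or n % 3 == 0:
        return False
    i = 5
    while i * i <= n:
        if n % i == 0 or n % (i + 2) == 0:
            return False
        i += 6
    return True


def magic_prime(n_tokens, ctx_len):
    """Largest prime p with p % 3 == 2 and p <= n_tokens // ctx_len - 1.

    Returns None when the dataset is too small for any such prime.
    """
    n_chunk = n_tokens // ctx_len - 1
    for candidate in range(n_chunk, 0, -1):
        if candidate % 3 == 2 and is_prime(candidate):
            return candidate
    return None


if __name__ == "__main__":
    print(magic_prime(200499, 4096))
```

With 200499 tokens and a context length of 4096 there are 48 full chunks, so the search starts at 47, which is prime and leaves remainder 2 when divided by 3.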

About

RWKV-LM-V7 (https://github.com/BlinkDL/RWKV-LM) under the Lightning framework
