Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
309 changes: 309 additions & 0 deletions custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,315 @@ __global__ void append_decode_cache_T_quant_neox_rope_kernel(
#endif
}

template <typename T,
int VecSize = 4,
int RoundType = 0,
int HeadDim = 128,
bool is_scale_channel_wise = false,
bool IsFP8 = true,
bool IsDynamic = true>
__global__ void append_decode_cache_T_int8_neox_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads,
// block_size, head_size // 2]
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
T* __restrict__ cache_k_scale,
T* __restrict__ cache_v_scale,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d,
const float rms_norm_eps) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
const int tid = threadIdx.x;
const int wid = tid / 32;
const int lane_id = tid % 32;
const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid;
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
const int* block_table_now = nullptr;

block_table_now = block_tables + bid * max_blocks_per_seq;
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;

float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
if (head_idx < num_heads) {
// q
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, VecSize>;

LoadT src_vec;
LoadT src_vec_right;
LoadBiasT out_vec;
LoadBiasT out_vec_right;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
#pragma unroll
for (uint32_t head_bias = lane_id * VecSize; head_bias < half_head_size;
head_bias += 32 * VecSize) {
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
Load<T, VecSize>(&qkv_now[bias_idx + half_head_size], &src_vec_right);
// q rope
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
const uint32_t new_emb_idx =
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[i]);
float input_right = static_cast<float>(src_vec_right[i]);

const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec[i] = static_cast<T>(tmp1);
out_vec_right[i] = static_cast<T>(tmp2);
}
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
Store<T, VecSize>(out_vec_right, &qkv_out_now[bias_idx + half_head_size]);
}
} else if (head_idx < num_heads + 2 * kv_num_heads) {
// k
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads;
if (block_offset == 0) {
// pad zero for this kv_head_idx for this block
LoadPadKVT pad_cache_vec;
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
if (head_idx < num_heads + kv_num_heads) {
constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE;
constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
const uint32_t tgt_idx =
(block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim +
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
for (int block_i = lane_id / num_vecs_per_head_dim;
block_i < block_size;
block_i += num_token_each_time) {
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
&key_cache[tgt_idx + block_i * HeadDim]);
}
} else {
const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE;
const int num_token_each_time = 32 / num_vecs_per_head_dim;
const uint32_t tgt_idx =
(block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size +
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
block_i += num_token_each_time) {
Store<uint8_t, KV_VEC_SIZE>(
pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]);
}
}
__syncwarp();
}

constexpr int K_VEC_SIZE = 4;
constexpr int HALF_K_VEC_SIZE = 2;
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
using LoadKVT = AlignedVector<uint8_t, HALF_K_VEC_SIZE>;
using LoadT = AlignedVector<T, HALF_K_VEC_SIZE>;
using LoadBiasT = AlignedVector<T, HALF_K_VEC_SIZE>;
using LoadEmbT = AlignedVector<float, HALF_K_VEC_SIZE>;
LoadKVResT cache_vec;
LoadT src_vec1, src_vec1_right, src_vec2, src_vec2_right;
LoadBiasT out_vec1, out_vec2;
LoadEmbT cos_emb_vec1, cos_emb_vec2;
LoadEmbT sin_emb_vec1, sin_emb_vec2;

const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
T scale = T(1.0f);
const int k_head_idx = head_idx - num_heads;
const int v_head_idx = head_idx - num_heads - kv_num_heads;
if (head_idx < num_heads + kv_num_heads) {
Load<T, HALF_K_VEC_SIZE>(
&qkv_now[head_idx * HeadDim + (head_bias + half_head_size) % HeadDim],
&src_vec1_right);
Load<T, HALF_K_VEC_SIZE>(
&qkv_now[head_idx * HeadDim +
(head_bias + 8 + half_head_size) % HeadDim],
&src_vec2_right);

const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
const uint32_t new_emb_idx =
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
}

if (head_idx < num_heads + kv_num_heads) {
float input_left = static_cast<float>(src_vec1[0]);
float input_right = static_cast<float>(src_vec1_right[0]);
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
float tmp1 = 0;
if (head_bias < half_head_size) {
tmp1 = input_left * cos_tmp - input_right * sin_tmp;
} else {
tmp1 = input_left * cos_tmp + input_right * sin_tmp;
}
out_vec1[0] = static_cast<T>(tmp1);
input_left = static_cast<float>(src_vec1[1]);
input_right = static_cast<float>(src_vec1_right[1]);
cos_tmp = cos_emb_vec1[1];
sin_tmp = sin_emb_vec1[1];
if (head_bias < half_head_size) {
tmp1 = input_left * cos_tmp - input_right * sin_tmp;
} else {
tmp1 = input_left * cos_tmp + input_right * sin_tmp;
}
out_vec1[1] = static_cast<T>(tmp1);
} else {
out_vec1[0] = src_vec1[0];
out_vec1[1] = src_vec1[1];
}

// rope
if (head_idx < num_heads + kv_num_heads) {
float input_left = static_cast<float>(src_vec2[0]);
float input_right = static_cast<float>(src_vec2_right[0]);
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
float tmp1 = 0;
if (head_bias < half_head_size) {
tmp1 = input_left * cos_tmp - input_right * sin_tmp;
} else {
tmp1 = input_left * cos_tmp + input_right * sin_tmp;
}
out_vec2[0] = static_cast<T>(tmp1);
input_left = static_cast<float>(src_vec2[1]);
input_right = static_cast<float>(src_vec2_right[1]);
cos_tmp = cos_emb_vec2[1];
sin_tmp = sin_emb_vec2[1];
if (head_bias < half_head_size) {
tmp1 = input_left * cos_tmp - input_right * sin_tmp;
} else {
tmp1 = input_left * cos_tmp + input_right * sin_tmp;
}
out_vec2[1] = static_cast<T>(tmp1);
} else {
out_vec2[0] = src_vec2[0];
out_vec2[1] = src_vec2[1];
}
if constexpr (IsDynamic) {
// reduce max, 1 head per warp
T local_max = -INFINITY;
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
local_max = __hmax(local_max, __habs(out_vec1[i]));
local_max = __hmax(local_max, __habs(out_vec2[i]));
}
#pragma unroll
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
local_max =
__hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
}
scale = __hdiv(448, local_max);

int cache_offset;
if (head_idx < num_heads) {
cache_offset = 0;
} else if (head_idx < num_heads + 2 * kv_num_heads) {
cache_offset = block_idx * kv_num_heads * block_size +
(head_idx - num_heads) % kv_num_heads * block_size +
block_offset;
}
T* cache_k_scale_now = cache_k_scale + cache_offset;
T* cache_v_scale_now = cache_v_scale + cache_offset;
if (lane_id == 0) {
if (head_idx < num_heads + kv_num_heads) {
cache_k_scale_now[0] = __hdiv(1, scale);
} else {
cache_v_scale_now[0] = __hdiv(1, scale);
}
}
} else {
if (head_idx < num_heads + kv_num_heads) {
scale = __ldg(&cache_k_scale[kv_head_idx]);
} else {
scale = __ldg(&cache_v_scale[kv_head_idx]);
}
}

#pragma unroll
for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
cache_vec[i] = QuantToC8<T, true, IsFP8, RoundType>(
scale, out_vec1[i], max_bound, min_bound);
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T, true, IsFP8, RoundType>(
scale, out_vec2[i], max_bound, min_bound);
}
if (head_idx < num_heads + kv_num_heads) {
const int start_block_16 =
block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8;
const uint32_t tgt_cache_idx =
block_idx * kv_num_heads * block_size * HeadDim +
kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim +
lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4;
Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
} else {
const uint32_t base_tgt_cache_idx =
block_idx * kv_num_heads * HeadDim * block_size +
kv_head_idx * HeadDim * block_size +
(lane_id / 4 * 16 + lane_id % 4 * 2) * block_size +
block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32;
const uint32_t tgt_cache_idx1 = base_tgt_cache_idx +
block_offset % 8 / 2 * 4 // per 4
+ block_offset % 16 / 8 * 2 // per 2
+ block_offset % 2; // per 1
const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size;
const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16;
const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size;
value_cache[tgt_cache_idx1] = cache_vec[0];
value_cache[tgt_cache_idx2] = cache_vec[1];
value_cache[tgt_cache_idx3] = cache_vec[2];
value_cache[tgt_cache_idx4] = cache_vec[3];
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

template <typename T,
int VecSize = 4,
int RoundType = 0,
Expand Down
Loading
Loading