@@ -379,18 +379,18 @@ enum FaCodePath {
379379};
380380
381381struct vk_fa_pipeline_state {
382- vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc)
383- : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {}
382+ vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
383+ : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}
384384
385385 uint32_t HSK, HSV;
386- bool small_rows;
386+ bool small_rows, small_cache ;
387387 FaCodePath path;
388388 bool aligned;
389389 bool f32acc;
390390
391391 bool operator<(const vk_fa_pipeline_state &b) const {
392- return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) <
393- std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc);
392+ return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
393+ std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b. path, b.aligned, b.f32acc);
394394 }
395395};
396396
@@ -2582,10 +2582,10 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
25822582static constexpr uint32_t flash_attention_num_small_rows = 32;
25832583static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
25842584
2585- static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
2585+ static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache ) {
25862586 if (hsv >= 192) {
25872587 return 2;
2588- } else if ((hsv | hsk) & 8) {
2588+ } else if ((hsv | hsk) & 8 || small_cache ) {
25892589 return 4;
25902590 } else {
25912591 return 8;
@@ -2607,9 +2607,8 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
26072607 }
26082608}
26092609
2610- static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
2610+ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache ) {
26112611 GGML_UNUSED(clamp);
2612- GGML_UNUSED(hsv);
26132612
26142613 if (path == FA_SCALAR) {
26152614 if (small_rows) {
@@ -2618,9 +2617,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
26182617 if ((hsv | hsk) & 8) {
26192618 // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
26202619 // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
2621- return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
2620+ return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache ), 64};
26222621 } else {
2623- return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
2622+ return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache ), 32};
26242623 }
26252624 }
26262625 }
@@ -2649,8 +2648,8 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
26492648 return {64, 64};
26502649}
26512650
2652- static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
2653- return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
2651+ static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache ) {
2652+ return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache )[1];
26542653}
26552654
26562655static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
@@ -2992,11 +2991,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
29922991 align, disable_robustness, require_full_subgroups, required_subgroup_size);
29932992 };
29942993
2995- auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
2996- return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
2994+ auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache ) -> std::array<uint32_t, 3> {
2995+ return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache )[0], 1, 1};
29972996 };
29982997
2999- auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
2998+ auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache ) -> std::vector<uint32_t> {
30002999 // For large number of rows, 128 invocations seems to work best.
30013000 // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
30023001 // can't use 256 for D==80.
@@ -3006,7 +3005,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
30063005 uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
30073006 ? scalar_flash_attention_workgroup_size
30083007 : ((small_rows && (D % 32) == 0) ? 256 : 128);
3009- auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
3008+ auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache );
30103009
30113010 // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
30123011 // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -3021,21 +3020,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
30213020 uint32_t HSK = fa.first.HSK; \
30223021 uint32_t HSV = fa.first.HSV; \
30233022 bool small_rows = fa.first.small_rows; \
3023+ bool small_cache = fa.first.small_cache; \
30243024 FaCodePath path = fa.first.path; \
30253025 bool aligned = fa.first.aligned; \
30263026 bool f32acc = fa.first.f32acc; \
30273027 if (path == FAPATH) { \
30283028 if (aligned) { \
30293029 if (f32acc) { \
3030- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3030+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache ), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache ), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache ), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
30313031 } else { \
3032- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3032+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache ), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache ), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache ), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
30333033 } \
30343034 } else { \
30353035 if (f32acc) { \
3036- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3036+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache ), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache ), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
30373037 } else { \
3038- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3038+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache ), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache ), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
30393039 } \
30403040 } \
30413041 } \
@@ -8008,11 +8008,11 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
80088008 }
80098009}
80108010
8011- static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
8011+ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache ) {
80128012 // Needs to be kept up to date on shader changes
80138013 GGML_UNUSED(hsv);
80148014 const uint32_t wg_size = scalar_flash_attention_workgroup_size;
8015- const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
8015+ const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache );
80168016 const uint32_t Bc = scalar_flash_attention_Bc;
80178017
80188018 const uint32_t tmpsh = wg_size * sizeof(float);
@@ -8136,14 +8136,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
81368136 uint32_t workgroups_y = (uint32_t)neq2;
81378137 uint32_t workgroups_z = (uint32_t)neq3;
81388138
8139+ const bool small_cache = nek1 < 1024;
8140+
81398141 // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
81408142 // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
81418143 uint32_t max_gqa;
81428144 switch (path) {
81438145 case FA_SCALAR:
81448146 case FA_COOPMAT1:
81458147 // We may switch from coopmat1 to scalar, so use the scalar limit for both
8146- max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
8148+ max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache );
81478149 break;
81488150 case FA_COOPMAT2:
81498151 max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
@@ -8177,7 +8179,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
81778179
81788180 // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
81798181 if (path == FA_SCALAR &&
8180- !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
8182+ !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache )) {
81818183 small_rows = true;
81828184 }
81838185
@@ -8193,7 +8195,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
81938195 v_stride /= 4;
81948196 }
81958197
8196- uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
8198+ uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache );
81978199 bool aligned = (KV % alignment) == 0 &&
81988200 // the "aligned" shader variant will forcibly align strides, for performance
81998201 (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
@@ -8205,7 +8207,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
82058207
82068208 bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
82078209
8208- vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc);
8210+ vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);
82098211
82108212 vk_pipeline pipeline = nullptr;
82118213
0 commit comments