diff --git a/csrc/composable_kernel b/csrc/composable_kernel index e8709c24f40..13f6d635653 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit e8709c24f403173ad21a2da907d1347957e324fb +Subproject commit 13f6d635653bd5ffbfcac8577f1ef09590c23d78 diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp index bb879453680..083494f5b0c 100644 --- a/csrc/flash_attn_ck/mha_bwd.cpp +++ b/csrc/flash_attn_ck/mha_bwd.cpp @@ -133,9 +133,12 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, dv.data_ptr(), nullptr, // dbias dq_acc.data_ptr(), // dq_acc - nullptr, // seqstart_q - nullptr, // seqstart_k + nullptr, // seqstart_q_ptr + nullptr, // seqstart_k_ptr + nullptr, // seqlen_q_ptr nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr seqlen_q, seqlen_k, b, diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp index 4d7d5bd655e..0229e777cd5 100644 --- a/csrc/flash_attn_ck/mha_fwd.cpp +++ b/csrc/flash_attn_ck/mha_fwd.cpp @@ -24,7 +24,7 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, has_lse, has_dropout, - false}; // do_fp8_static_quant + quant_scale_enum::no_scale}; // qscale_type } fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, @@ -95,12 +95,18 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, k.data_ptr(), v.data_ptr(), alibi_slopes_ptr, // bias + nullptr, // q_descale_ptr + nullptr, // k_descale_ptr + nullptr, // v_descale_ptr has_dropout_randval ? dropout_randval.data_ptr() : nullptr, has_lse ? 
softmax_lse.data_ptr() : nullptr, out.data_ptr(), - nullptr, // seqstart_q - nullptr, // seqstart_k - nullptr, + nullptr, // seqstart_q_ptr + nullptr, // seqstart_k_ptr + nullptr, // seqlen_q_ptr + nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr seqlen_q, seqlen_k, b, @@ -110,8 +116,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, h, // nhead h_k, // nhead_k softmax_scale, // scale_s - 1, // scale_p - 1, // scale_o 0.0f, // logits_soft_cap stride_q, stride_k, diff --git a/csrc/flash_attn_ck/mha_varlen_bwd.cpp b/csrc/flash_attn_ck/mha_varlen_bwd.cpp index bfeb3b770d0..3cd01c32d48 100644 --- a/csrc/flash_attn_ck/mha_varlen_bwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_bwd.cpp @@ -139,9 +139,12 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, dv.data_ptr(), nullptr, // dbias dq_acc.data_ptr(), // dq_acc - seqlens_q.data_ptr(), // seqstart_q - seqlens_k.data_ptr(), // seqstart_k + seqlens_q.data_ptr(), // seqstart_q_ptr + seqlens_k.data_ptr(), // seqstart_k_ptr + nullptr, // seqlen_q_ptr nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr total_q, total_k, b, diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index 07cfa9a8f90..00b0fcd5738 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -24,7 +24,7 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, has_lse, has_dropout, - false}; // do_fp8_static_quant + quant_scale_enum::no_scale}; // qscale_type } fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &mask, @@ -116,12 +116,18 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, k.data_ptr(), v.data_ptr(), alibi_slopes_ptr, // bias + nullptr, // q_descale_ptr + nullptr, // k_descale_ptr + nullptr, // v_descale_ptr has_dropout_randval ? dropout_randval.data_ptr() : nullptr, has_lse ? 
softmax_lse.data_ptr() : nullptr, out.data_ptr(), - seqlens_q.data_ptr(), // seqstart_q - seqlens_k.data_ptr(), // seqstart_k - nullptr, // seqlen_kpads + seqlens_q.data_ptr(), // seqstart_q_ptr + seqlens_k.data_ptr(), // seqstart_k_ptr + nullptr, // seqlen_q_ptr + nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr total_q, total_k, b, @@ -131,8 +137,6 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, h, // nhead h_k, // nhead_k softmax_scale, // scale_s - 1, // scale_p - 1, // scale_o 0.0f, // logits_soft_cap stride_q, stride_k, diff --git a/setup.py b/setup.py index f0b476255ba..730a190a876 100644 --- a/setup.py +++ b/setup.py @@ -145,7 +145,7 @@ def add_cuda_gencodes(cc_flag, archs, bare_metal_version): cc_flag += ["-gencode", f"arch=compute_{newest},code=compute_{newest}"] return cc_flag - + def get_hip_version(): return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+')) @@ -436,7 +436,7 @@ def validate_and_update_archs(archs): "csrc/flash_attn_ck/mha_varlen_bwd.cu", "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu") - cc_flag += ["-O3","-std=c++17", + cc_flag += ["-O3","-std=c++20", "-DCK_TILE_FMHA_FWD_FAST_EXP2=1", "-fgpu-flush-denormals-to-zero", "-DCK_ENABLE_BF16", @@ -468,7 +468,7 @@ def validate_and_update_archs(archs): cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"] extra_compile_args = { - "cxx": ["-O3", "-std=c++17"] + generator_flag + maybe_hipify_v2_flag, + "cxx": ["-O3", "-std=c++20"] + generator_flag + maybe_hipify_v2_flag, "nvcc": cc_flag + generator_flag + maybe_hipify_v2_flag, }