bwd v3 remove ck dependency #2250

Open
JaxChen29 wants to merge 2 commits into ROCm:main from JaxChen29:remove_ck_dependency

@JaxChen29
Contributor

Motivation

Remove the Composable Kernel (CK) dependency from the FMHA backward pass library when building with ONLY_FAV3=1. The V3 ASM kernels are self-contained .co binaries that don't need CK at compile time, but the build previously required CK headers (fmha_bwd.hpp, ck_tile/core.hpp, mask.hpp, etc.) even when only the ASM path was used. This decouples the V3 backward build from CK, enabling faster builds and removing the hard dependency on the CK submodule for ASM-only configurations.

Technical Details

New file: csrc/include/ck_tile_shim.h

Minimal shim providing the ck_tile:: names (stream_config, index_t, long_index_t, log2e_v, get_warp_size, launch_kernel) needed to compile without CK.
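
For illustration, a shim of this kind could look roughly like the sketch below. This is not the content of ck_tile_shim.h: the namespace ck_tile_sketch, the demo_stream_t placeholder (a real shim would presumably use hipStream_t), the member names, the wave size of 64, and the launch_kernel behaviour are all assumptions chosen so the sketch compiles on a host without HIP.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical placeholder so the sketch builds without HIP headers.
// A real shim would presumably use hipStream_t here.
using demo_stream_t = void*;

// Deliberately NOT named ck_tile: this is an illustration, not the PR's shim.
namespace ck_tile_sketch {

using index_t      = std::int32_t; // assumed 32-bit index type
using long_index_t = std::int64_t; // assumed 64-bit index type

// log2(e), used when converting exp() into exp2() with a pre-scaled argument.
template <typename T>
constexpr T log2e_v = static_cast<T>(1.4426950408889634);

// Assumed minimal launch configuration: a stream handle and a timing flag.
struct stream_config
{
    demo_stream_t stream_id_ = nullptr;
    bool time_kernel_        = false;
};

// Assumes a wave64 target; this is an assumption of the sketch.
constexpr index_t get_warp_size() { return 64; }

// Invokes each callable once with the stream config, in order.
// This sketch does no timing and always returns 0 (milliseconds).
template <typename... Callables>
float launch_kernel(const stream_config& s, Callables&&... callables)
{
    (callables(s), ...); // C++17 fold over the comma operator
    return 0.0f;
}

} // namespace ck_tile_sketch
```

Keeping the names identical to the CK ones inside the real shim is what lets existing call sites compile unchanged; the sketch uses a different namespace only to avoid suggesting it is the actual header.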

csrc/include/aiter_hip_common.h

Conditional include: when ONLY_FAV3 is nonzero, the header includes ck_tile_shim.h; otherwise it includes ck_tile/core.hpp. This is safe for all other modules, since #if ONLY_FAV3 evaluates to 0 when the macro is undefined.
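
The "undefined macro evaluates to 0" behaviour is standard preprocessor semantics: any identifier left in an #if expression after macro expansion is replaced by 0. A small self-contained check, using hypothetical macro names (DEMO_FLAG_ON, DEMO_FLAG_UNSET) rather than the real ONLY_FAV3:

```cpp
#include <cassert>

// Hypothetical macro names, used only for this demonstration.
#define DEMO_FLAG_ON 1
// DEMO_FLAG_UNSET is intentionally never defined.

#if DEMO_FLAG_ON
constexpr bool kOnBranchTaken = true;   // selected: macro expands to 1
#else
constexpr bool kOnBranchTaken = false;
#endif

#if DEMO_FLAG_UNSET                     // undefined identifier is treated as 0
constexpr bool kUnsetBranchTaken = true;
#else
constexpr bool kUnsetBranchTaken = false; // selected
#endif

static_assert(kOnBranchTaken, "defined-to-1 macro selects the #if branch");
static_assert(!kUnsetBranchTaken, "undefined macro selects the #else branch");
```

One caveat: GCC and Clang warn about this pattern when -Wundef is enabled, so a build that uses -Wundef together with -Werror would need the macro defined explicitly (for example to 0) in non-V3 modules.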

csrc/include/mha_bwd.h

Guarded #include "fmha_bwd.hpp" with #if !ONLY_FAV3 to skip the heavy CK FMHA headers when building V3-only.

csrc/cpp_itfs/mha_bwd.cu

Defined local mask_enum and compute_mask_coordinates() under #if ONLY_FAV3 to replace CK's mask_enum (from mask.hpp) and ck_tile::make_generic_attention_mask_coordinates_from_lr_window().
The SWA mask coordinate computation block is dual-pathed: #if ONLY_FAV3 uses the local function, #else uses the original CK function.
The CK fallback path in mha_bwd() (fmha_bwd(traits, ck_args, s)) remains guarded by the existing #if ONLY_FAV3 / #else block -- no change needed there.
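
To make the local replacement concrete, here is a rough sketch of what a standalone mask_enum and compute_mask_coordinates() might look like. The PR's actual code is not shown in this description, so the signature, the mask_coordinates struct, the enumerator list, and the arithmetic are assumptions. The sketch assumes the common left/right window convention in which a negative size means "unbounded" on that side and a bottom-right-aligned mask is shifted by the difference between the two sequence lengths.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch: names follow the PR description, bodies are assumed.
namespace mask_sketch {

using index_t = std::int32_t;

// Assumed enumerators for the supported mask kinds.
enum class mask_enum { no_mask = 0, mask_top_left, mask_bottom_right, window_generic };

// Assumed result type: y and x extents of the generic attention window.
struct mask_coordinates
{
    index_t y;
    index_t x;
};

// Converts a (left, right) sliding-window size into (y, x) coordinates.
// Assumed convention: a negative window size means "unbounded" on that side.
inline mask_coordinates compute_mask_coordinates(index_t left,
                                                 index_t right,
                                                 index_t seqlen_q,
                                                 index_t seqlen_k,
                                                 bool is_top_left)
{
    const index_t left_unbounded  = is_top_left ? seqlen_q - 1 : seqlen_k - 1;
    const index_t right_unbounded = is_top_left ? seqlen_k - 1 : seqlen_q - 1;
    if(left < 0) left = left_unbounded;
    if(right < 0) right = right_unbounded;

    // Bottom-right alignment shifts the window by the length difference.
    const index_t x_shift = is_top_left ? 0 : seqlen_k - seqlen_q;
    const index_t y_shift = is_top_left ? 0 : seqlen_q - seqlen_k;

    return {left + 1 + y_shift, right + 1 + x_shift};
}

} // namespace mask_sketch
```

Under these assumptions, a causal mask (left = -1, right = 0) with seqlen_q = 4 and seqlen_k = 6 yields (y, x) = (4, 1) when top-left aligned and (4, 3) when bottom-right aligned. At the call site, a function like this would sit in the #if ONLY_FAV3 branch, with the CK function kept in the #else branch.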

aiter/jit/optCompilerConfig.json

For module_fmha_v3_bwd and module_fmha_v3_varlen_bwd:

Removed extra_include: [CK_DIR/example/ck_tile/01_fmha]
Removed flags_extra_hip: ['-DCK_TILE_FMHA_FWD_FAST_EXP2=1']
Kept flags_extra_cc: ['-DONLY_FAV3=1'] and blob_gen_cmd (ASM codegen, not CK-related).

op_tests/cpp/mha/build_mha.sh

Updated bwd_v3 benchmark linking: the benchmark binary (bwd.exe) still uses CK include paths for its reference computation (CK is removed only from the library, not from the test harness).

Test Plan

Test Result

Submission Checklist

JaxChen29 requested a review from a team on March 11, 2026 08:28
JaxChen29 force-pushed the remove_ck_dependency branch from c296bca to b5e099a on March 11, 2026 08:32