Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions tools/tk/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,22 @@ def fp8_gemm(a, b):
from ._C import mha_backward, mha_forward

__all__ = ["mha_backward", "mha_forward"]
""",
"bf16_b300_mha_causal/__init__.py": """from .._runtime import preload_torch_deps

preload_torch_deps()

from ._C import forward, forward_persistent

__all__ = ["forward", "forward_persistent"]
""",
"bf16_b300_mha_noncausal/__init__.py": """from .._runtime import preload_torch_deps

preload_torch_deps()

from ._C import forward

__all__ = ["forward"]
""",
}

Expand Down Expand Up @@ -127,6 +143,18 @@ def _prepare_package_layout():
target.write_text(content)


def _nvcc_supports_sm103a():
try:
result = subprocess.run(
["nvcc", "--list-gpu-arch"],
capture_output=True,
text=True,
)
return "sm_103a" in result.stdout
except Exception:
return False


def test_tk_attn_h100_fwd():
cmd = [
sys.executable,
Expand All @@ -150,4 +178,13 @@ def install_tk():
makefile=TK_TOOLS_PATH.joinpath("bf16_b200_gemm.Makefile"),
output_dir=TK_PACKAGE_PATH.joinpath("bf16_b200"),
)
if _nvcc_supports_sm103a():
_build_extension(
makefile=TK_TOOLS_PATH.joinpath("bf16_b300_mha_causal.Makefile"),
output_dir=TK_PACKAGE_PATH.joinpath("bf16_b300_mha_causal"),
)
_build_extension(
makefile=TK_TOOLS_PATH.joinpath("bf16_b300_mha_noncausal.Makefile"),
output_dir=TK_PACKAGE_PATH.joinpath("bf16_b300_mha_noncausal"),
)
test_tk_attn_h100_fwd()
45 changes: 44 additions & 1 deletion tritonbench/operators/blackwell_attentions/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@

import logging

from tritonbench.utils.env_utils import IS_BLACKWELL, is_blackwell
from tritonbench.utils.env_utils import IS_BLACKWELL, is_blackwell, IS_GB300

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -155,6 +155,23 @@
attention as tlx_blackwell,
)

# [Optional] TK B300 MHA kernels
try:
from tritonbench.utils.path_utils import ensure_build_subdir_on_sys_path

with ensure_build_subdir_on_sys_path():
from thunderkittens.bf16_b300_mha_causal import (
forward as tk_b300_causal_fwd,
forward_persistent as tk_b300_causal_persistent_fwd,
)
from thunderkittens.bf16_b300_mha_noncausal import (
forward as tk_b300_noncausal_fwd,
)

HAS_TK_B300 = True
except (ImportError, IOError, AttributeError):
HAS_TK_B300 = False

with try_import("HAS_TILELANG"):
from .tilelang import tilelang_blackwell_attention

Expand Down Expand Up @@ -757,6 +774,32 @@ def fn(q, k, v):

return preproc_permute, fn

@register_benchmark(enabled=IS_GB300 and HAS_TK_B300, fwd_only=True)
@multi_input_wrapper
def tk_bf16_b300_mha(self, *args) -> Tuple[Callable, Callable]:
def preproc(q, k, v):
q, k, v = [t.contiguous() for t in permute_qkv(q, k, v, perm=(0, 2, 1, 3))]
B, S, H, D = q.shape
o = torch.zeros_like(v)
lse = torch.zeros(B, H, 1, S, dtype=torch.float32, device=q.device)
return [q, k, v, o, lse]

if self.causal:

def fn(q, k, v, o, lse):
S = q.shape[1]
fwd = tk_b300_causal_persistent_fwd if S <= 4096 else tk_b300_causal_fwd
fwd(q, k, v, o, lse)
return o

else:

def fn(q, k, v, o, lse):
tk_b300_noncausal_fwd(q, k, v, o, lse)
return o

return preproc, fn

@register_metric(x_only=True)
def flops(
self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
Expand Down
11 changes: 11 additions & 0 deletions tritonbench/utils/env_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,17 @@ def is_b200() -> bool:
return "B200" in gpu_model


def is_gb300() -> bool:
"""Check if running on an NVIDIA GB300 GPU."""
if not is_cuda_available():
return False
gpu_model = get_nvidia_gpu_model()
return "B300" in gpu_model


IS_GB300 = is_gb300()


def supports_tma():
if not is_cuda_available():
return False
Expand Down
Loading