Skip to content
Merged
26 changes: 22 additions & 4 deletions cuda_core/cuda/core/experimental/_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional, Union

from cuda import cuda
from cuda.core.experimental._device import Device
from cuda.core.experimental._kernel_arg_handler import ParamHolder
from cuda.core.experimental._module import Kernel
from cuda.core.experimental._stream import Stream
Expand Down Expand Up @@ -38,10 +39,14 @@ class LaunchConfig:
----------
grid : Union[tuple, int]
Collection of threads that will execute a kernel function.
cluster : Union[tuple, int]
Group of blocks (Thread Block Cluster) that will execute on the same
GPU Processing Cluster (GPC). Blocks within a cluster have access to
distributed shared memory and can be explicitly synchronized.
block : Union[tuple, int]
Group of threads (Thread Block) that will execute on the same
multiprocessor. Threads within a thread blocks have access to
shared memory and can be explicitly synchronized.
streaming multiprocessor (SM). Threads within a thread block have
access to shared memory and can be explicitly synchronized.
stream : :obj:`Stream`
The stream establishing the stream ordering semantic of a
launch.
Expand All @@ -53,13 +58,22 @@ class LaunchConfig:

# TODO: expand LaunchConfig to include other attributes
grid: Union[tuple, int] = None
cluster: Union[tuple, int] = None
block: Union[tuple, int] = None
stream: Stream = None
shmem_size: Optional[int] = None

def __post_init__(self):
_lazy_init()
self.grid = self._cast_to_3_tuple(self.grid)
self.block = self._cast_to_3_tuple(self.block)
# thread block clusters are supported starting H100
if self.cluster is not None:
if not _use_ex:
raise CUDAError("thread block clusters require cuda.bindings & driver 11.8+")
if Device().compute_capability < (9, 0):
raise CUDAError("thread block clusters are not supported below Hopper")
Comment thread
leofang marked this conversation as resolved.
Outdated
self.cluster = self._cast_to_3_tuple(self.cluster)
# we handle "stream=None" in the launch API
if self.stream is not None and not isinstance(self.stream, Stream):
try:
Expand All @@ -69,8 +83,6 @@ def __post_init__(self):
if self.shmem_size is None:
self.shmem_size = 0

_lazy_init()

def _cast_to_3_tuple(self, cfg):
if isinstance(cfg, int):
if cfg < 1:
Expand Down Expand Up @@ -134,6 +146,12 @@ def launch(kernel, config, *kernel_args):
drv_cfg.hStream = config.stream.handle
drv_cfg.sharedMemBytes = config.shmem_size
drv_cfg.numAttrs = 0 # TODO
if config.cluster:
drv_cfg.numAttrs += 1
attr = cuda.CUlaunchAttribute()
attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = config.cluster
drv_cfg.attrs = [attr] # TODO: WHAT!!
Comment thread
leofang marked this conversation as resolved.
Outdated
handle_return(cuda.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0))
else:
# TODO: check if config has any unsupported attrs
Expand Down
61 changes: 61 additions & 0 deletions cuda_core/examples/thread_block_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

# Example: launch a kernel with a thread block cluster and print the
# grid/cluster/block dimensions observed from inside the kernel.

import os
import sys

from cuda.core.experimental import Device, LaunchConfig, Program, launch


# prepare include path so NVRTC can locate the cooperative_groups headers
cuda_path = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME"))
if cuda_path is None:
    print("this demo requires a valid CUDA_PATH environment variable set")
    sys.exit(0)
cuda_include_path = os.path.join(cuda_path, "include")

# kernel source: the first thread of the first block of the first cluster
# prints the launch geometry
code = r"""
#include <cooperative_groups.h>

namespace cg = cooperative_groups;

extern "C"
__global__ void check_cluster_info() {
auto g = cg::this_grid();
auto b = cg::this_thread_block();
if (g.cluster_rank() == 0 && g.block_rank() == 0 && g.thread_rank() == 0) {
printf("grid dim: (%u, %u, %u)\n", g.dim_blocks().x, g.dim_blocks().y, g.dim_blocks().z);
printf("cluster dim: (%u, %u, %u)\n", g.dim_clusters().x, g.dim_clusters().y, g.dim_clusters().z);
printf("block dim: (%u, %u, %u)\n", b.dim_threads().x, b.dim_threads().y, b.dim_threads().z);
}
}
"""

dev = Device()
dev.set_current()

# thread block clusters need Hopper (compute capability 9.0) or newer;
# exit gracefully (like the CUDA_PATH check above) instead of letting
# LaunchConfig raise a CUDAError on older GPUs
if dev.compute_capability < (9, 0):
    print("this demo requires a GPU with compute capability >= 9.0")
    sys.exit(0)

arch = "".join(f"{i}" for i in dev.compute_capability)

# prepare program
prog = Program(code, code_type="c++")
mod = prog.compile(
    target_type="cubin",
    # TODO: update this after NVIDIA/cuda-python#237 is merged
    options=(f"-arch=sm_{arch}", "-std=c++17", f"-I{cuda_include_path}"),
)

# retrieve the compiled kernel
ker = mod.get_kernel("check_cluster_info")

# prepare launch: 4 blocks total, grouped into clusters of 2 blocks each
grid = 4
cluster = 2
block = 32
config = LaunchConfig(grid=grid, cluster=cluster, block=block, stream=dev.default_stream)

# launch kernel on the default stream
launch(ker, config)
dev.sync()

print("done!")