# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0

import os
import sys

from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch

# prepare include path
cuda_path = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME"))
if cuda_path is None:
    print("this demo requires the CUDA_PATH (or CUDA_HOME) environment variable to be set", file=sys.stderr)
    sys.exit(0)
cuda_include_path = os.path.join(cuda_path, "include")
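# Note: cooperative_groups.h ships with the CUDA Toolkit headers; the include
# path computed above is passed to the compiler below so NVRTC can resolve the
# #include in the kernel source.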

# print cluster info using a kernel
code = r"""
#include <cooperative_groups.h>

namespace cg = cooperative_groups;

extern "C"
__global__ void check_cluster_info() {
    auto g = cg::this_grid();
    auto b = cg::this_thread_block();
    if (g.cluster_rank() == 0 && g.block_rank() == 0 && g.thread_rank() == 0) {
        printf("grid dim: (%u, %u, %u)\n", g.dim_blocks().x, g.dim_blocks().y, g.dim_blocks().z);
        printf("cluster dim: (%u, %u, %u)\n", g.dim_clusters().x, g.dim_clusters().y, g.dim_clusters().z);
        printf("block dim: (%u, %u, %u)\n", b.dim_threads().x, b.dim_threads().y, b.dim_threads().z);
    }
}
"""

dev = Device()
arch = dev.compute_capability
if arch < (9, 0):
    print(
        "this demo requires compute capability >= 9.0 (since thread block clusters are a hardware feature)",
        file=sys.stderr,
    )
    sys.exit(0)
arch = "".join(f"{i}" for i in arch)

# prepare program & compile kernel
dev.set_current()
prog = Program(
    code,
    code_type="c++",
    options=ProgramOptions(arch=f"sm_{arch}", std="c++17", include_path=cuda_include_path),
)
mod = prog.compile(target_type="cubin")
ker = mod.get_kernel("check_cluster_info")

# prepare launch config
grid = 4
cluster = 2
block = 32
config = LaunchConfig(grid=grid, cluster=cluster, block=block)
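
# Illustrative sanity check (an assumption added for this sketch): CUDA
# requires the number of thread blocks in each grid dimension to be an exact
# multiple of the cluster size in that dimension.
assert grid % cluster == 0, "cluster size must evenly divide the grid"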

# launch kernel on the default stream
launch(dev.default_stream, config, ker)
dev.sync()

print("done!")