Commit af0da1a

add requires_gpu decorator in tests, always test build on non-ampere
1 parent f2adca9

File tree: 1 file changed (+19, -21 lines)


tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py

Lines changed: 19 additions & 21 deletions
@@ -75,6 +75,12 @@ def maybe_swap(i, j):
     return (a, b, c)


+def is_ampere_or_newer():
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
+    return major * 10 + minor >= 80
+
+
 def run_test(
     k_inner,
     in_dtype,
@@ -182,6 +188,10 @@ def tile_wmma_fragment(block_read, height, width):
     sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)

     f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+
+    if not is_ampere_or_newer():
+        return None
+
     dev = tvm.device("cuda", 0)

     if in_dtype == "float16":
@@ -221,16 +231,8 @@ def tile_wmma_fragment(block_read, height, width):
     return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c)


-def is_ampere_or_newer():
-    arch = tvm.contrib.nvcc.get_target_compute_version()
-    major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
-    return major * 10 + minor >= 80
-
-
+@tvm.testing.requires_cuda
 def test_f16f16f32_m16n16k16():
-    if not is_ampere_or_newer():
-        return
-
     def index_map(i, j):
         return (
             i // 16,
@@ -261,7 +263,7 @@ def index_map(i, j):
         MMA_store_16x16_f32_global_INTRIN,
     )

-    if measure_perf:
+    if measure_perf and timer:
         print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))

     timer = run_test(
@@ -282,14 +284,12 @@ def index_map(i, j):
         MMA_store_16x16_f32_global_INTRIN,
     )

-    if measure_perf:
+    if measure_perf and timer:
         print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean)))


+@tvm.testing.requires_cuda
 def test_f16f16f16_m16n16k16():
-    if not is_ampere_or_newer():
-        return
-
     def index_map(i, j):
         return (
             i // 16,
@@ -320,7 +320,7 @@ def index_map(i, j):
         MMA_store_16x16_f16_global_INTRIN,
     )

-    if measure_perf:
+    if measure_perf and timer:
         print("f16f16f16_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))

     timer = run_test(
@@ -341,14 +341,12 @@ def index_map(i, j):
         MMA_store_16x16_f16_global_INTRIN,
     )

-    if measure_perf:
+    if measure_perf and timer:
         print("f16f16f16_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean)))


+@tvm.testing.requires_cuda
 def test_i8i8i32_m16n16k32():
-    if not is_ampere_or_newer():
-        return
-
     def index_map_A(i, j):
         return (
             i // 16,
@@ -393,7 +391,7 @@ def index_map_C(i, j):
         MMA_store_16x16_i32_global_INTRIN,
     )

-    if measure_perf:
+    if measure_perf and timer:
         print("i8i8i32_m16n16k32: %f GOPS" % (gflops / (timer().mean)))

     timer = run_test(
@@ -414,7 +412,7 @@ def index_map_C(i, j):
         MMA_store_16x16_i32_global_INTRIN,
     )

-    if measure_perf:
+    if measure_perf and timer:
         print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean)))

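For readers unfamiliar with the gating pattern, the following is a minimal, self-contained sketch of what this commit converges on: @tvm.testing.requires_cuda skips the test when no CUDA device or toolchain is present, tvm.build is always exercised, and execution is only reached on compute capability 8.0 (Ampere) or newer. The is_ampere_or_newer helper and the decorator are taken from the diff above; the element-wise add-one kernel, the test name, and the TE scheduling calls are placeholder assumptions standing in for the real tensorized ldmatrix/mma.sync schedules in the test file.

import numpy as np
import tvm
import tvm.testing
from tvm import te
from tvm.contrib import nvcc


def is_ampere_or_newer():
    # Compute capability of the local GPU, e.g. (8, 6) on an RTX 30-series card.
    arch = nvcc.get_target_compute_version()
    major, minor = nvcc.parse_compute_version(arch)
    return major * 10 + minor >= 80


@tvm.testing.requires_cuda  # collected only when a CUDA device and toolchain exist
def test_build_always_run_on_ampere():
    # Hypothetical stand-in kernel: element-wise add-one, scheduled for CUDA.
    n = 1024
    A = te.placeholder((n,), dtype="float32", name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=128)
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(tx, te.thread_axis("threadIdx.x"))

    # The build step runs on every CUDA machine, Ampere or not.
    f = tvm.build(s, [A, B], target="cuda", name="add_one")

    # Execution is gated, mirroring the early `return None` this commit
    # adds to run_test.
    if not is_ampere_or_newer():
        return

    dev = tvm.device("cuda", 0)
    a = tvm.nd.array(np.random.rand(n).astype("float32"), dev)
    b = tvm.nd.array(np.zeros(n, dtype="float32"), dev)
    f(a, b)
    np.testing.assert_allclose(b.numpy(), a.numpy() + 1.0, rtol=1e-5)

In the diff above the same split lives inside run_test, which returns None after the build on pre-Ampere GPUs; that is why the perf-printing call sites now check `if measure_perf and timer:` before dereferencing the timer.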