@@ -75,6 +75,12 @@ def maybe_swap(i, j):
7575 return (a , b , c )
7676
7777
78+ def is_ampere_or_newer ():
79+ arch = tvm .contrib .nvcc .get_target_compute_version ()
80+ major , minor = tvm .contrib .nvcc .parse_compute_version (arch )
81+ return major * 10 + minor >= 80
82+
83+
7884def run_test (
7985 k_inner ,
8086 in_dtype ,
@@ -182,6 +188,10 @@ def tile_wmma_fragment(block_read, height, width):
182188 sch .tensorize (sch .get_loops (C_warp )[- 2 ], mma_store_intrin )
183189
184190 f = tvm .build (sch .mod ["main" ], target = "cuda" , name = "dense" )
191+
192+ if not is_ampere_or_newer ():
193+ return None
194+
185195 dev = tvm .device ("cuda" , 0 )
186196
187197 if in_dtype == "float16" :
@@ -221,16 +231,8 @@ def tile_wmma_fragment(block_read, height, width):
221231 return lambda : f .time_evaluator (f .entry_name , dev , number = 500 )(a , b , c )
222232
223233
224- def is_ampere_or_newer ():
225- arch = tvm .contrib .nvcc .get_target_compute_version ()
226- major , minor = tvm .contrib .nvcc .parse_compute_version (arch )
227- return major * 10 + minor >= 80
228-
229-
234+ @tvm .testing .requires_cuda
230235def test_f16f16f32_m16n16k16 ():
231- if not is_ampere_or_newer ():
232- return
233-
234236 def index_map (i , j ):
235237 return (
236238 i // 16 ,
@@ -261,7 +263,7 @@ def index_map(i, j):
261263 MMA_store_16x16_f32_global_INTRIN ,
262264 )
263265
264- if measure_perf :
266+ if measure_perf and timer :
265267 print ("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer ().mean )))
266268
267269 timer = run_test (
@@ -282,14 +284,12 @@ def index_map(i, j):
282284 MMA_store_16x16_f32_global_INTRIN ,
283285 )
284286
285- if measure_perf :
287+ if measure_perf and timer :
286288 print ("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer ().mean )))
287289
288290
291+ @tvm .testing .requires_cuda
289292def test_f16f16f16_m16n16k16 ():
290- if not is_ampere_or_newer ():
291- return
292-
293293 def index_map (i , j ):
294294 return (
295295 i // 16 ,
@@ -320,7 +320,7 @@ def index_map(i, j):
320320 MMA_store_16x16_f16_global_INTRIN ,
321321 )
322322
323- if measure_perf :
323+ if measure_perf and timer :
324324 print ("f16f16f16_m16n16k16: %f GFLOPS" % (gflops / (timer ().mean )))
325325
326326 timer = run_test (
@@ -341,14 +341,12 @@ def index_map(i, j):
341341 MMA_store_16x16_f16_global_INTRIN ,
342342 )
343343
344- if measure_perf :
344+ if measure_perf and timer :
345345 print ("f16f16f16_m16n16k16_trans: %f GFLOPS" % (gflops / (timer ().mean )))
346346
347347
348+ @tvm .testing .requires_cuda
348349def test_i8i8i32_m16n16k32 ():
349- if not is_ampere_or_newer ():
350- return
351-
352350 def index_map_A (i , j ):
353351 return (
354352 i // 16 ,
@@ -393,7 +391,7 @@ def index_map_C(i, j):
393391 MMA_store_16x16_i32_global_INTRIN ,
394392 )
395393
396- if measure_perf :
394+ if measure_perf and timer :
397395 print ("i8i8i32_m16n16k32: %f GOPS" % (gflops / (timer ().mean )))
398396
399397 timer = run_test (
@@ -414,7 +412,7 @@ def index_map_C(i, j):
414412 MMA_store_16x16_i32_global_INTRIN ,
415413 )
416414
417- if measure_perf :
415+ if measure_perf and timer :
418416 print ("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer ().mean )))
419417
420418
0 commit comments