@@ -75,13 +75,13 @@ def mma_sp_m16n8k16_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata:
7575 "fp16" ,
7676 "fp16" ,
7777 "fp16" ,
78- multi_a ,
78+ multi_a . data ,
7979 0 ,
80- multi_b ,
80+ multi_b . data ,
8181 0 ,
82- accum ,
82+ accum . data ,
8383 0 ,
84- meta_local ,
84+ meta_local . data ,
8585 0 ,
8686 0 ,
8787 False ,
@@ -90,7 +90,7 @@ def mma_sp_m16n8k16_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata:
9090 )
9191
9292 for i in range (4 ):
93- C [i // 2 * 8 + tx // 4 , tx % 4 * 2 + i % 2 ] = T . load ( "float16" , accum , i )
93+ C [i // 2 * 8 + tx // 4 , tx % 4 * 2 + i % 2 ] = accum [ i ]
9494
9595
9696@T .prim_func
@@ -129,13 +129,13 @@ def mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata:
129129 "fp16" ,
130130 "fp16" ,
131131 "fp32" ,
132- multi_a ,
132+ multi_a . data ,
133133 0 ,
134- multi_b ,
134+ multi_b . data ,
135135 0 ,
136- accum ,
136+ accum . data ,
137137 0 ,
138- meta_local ,
138+ meta_local . data ,
139139 0 ,
140140 0 ,
141141 False ,
@@ -144,7 +144,7 @@ def mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata:
144144 )
145145
146146 for i in range (4 ):
147- C [i // 2 * 8 + tx // 4 , tx % 4 * 2 + i % 2 ] = T . load ( "float32" , accum , i )
147+ C [i // 2 * 8 + tx // 4 , tx % 4 * 2 + i % 2 ] = accum [ i ]
148148
149149
150150@T .prim_func
@@ -183,13 +183,13 @@ def mma_sp_m16n8k32_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata:
183183 "fp16" ,
184184 "fp16" ,
185185 "fp16" ,
186- multi_a ,
186+ multi_a . data ,
187187 0 ,
188- multi_b ,
188+ multi_b . data ,
189189 0 ,
190- accum ,
190+ accum . data ,
191191 0 ,
192- meta_local ,
192+ meta_local . data ,
193193 0 ,
194194 0 ,
195195 False ,
@@ -198,7 +198,7 @@ def mma_sp_m16n8k32_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata:
198198 )
199199
200200 for i in range (4 ):
201- C [i // 2 * 8 + tx // 4 , tx % 4 * 2 + i % 2 ] = T . load ( "float16" , accum , i )
201+ C [i // 2 * 8 + tx // 4 , tx % 4 * 2 + i % 2 ] = accum [ i ]
202202
203203
204204@T .prim_func
@@ -237,13 +237,13 @@ def mma_sp_m16n8k32_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata:
237237 "fp16" ,
238238 "fp16" ,
239239 "fp32" ,
240- multi_a ,
240+ multi_a . data ,
241241 0 ,
242- multi_b ,
242+ multi_b . data ,
243243 0 ,
244- accum ,
244+ accum . data ,
245245 0 ,
246- meta_local ,
246+ meta_local . data ,
247247 0 ,
248248 0 ,
249249 False ,
@@ -252,7 +252,7 @@ def mma_sp_m16n8k32_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata:
252252 )
253253
254254 for i in range (4 ):
255- C [i // 2 * 8 + tx // 4 , tx % 4 * 2 + i % 2 ] = T . load ( "float32" , accum , i )
255+ C [i // 2 * 8 + tx // 4 , tx % 4 * 2 + i % 2 ] = accum [ i ]
256256
257257
258258@tvm .testing .requires_cuda
0 commit comments