Skip to content

Commit b3f5751

Browse files
committed
compatible #9727
1 parent 7567b48 commit b3f5751

1 file changed

Lines changed: 20 additions & 20 deletions

File tree

tests/python/unittest/test_tir_ptx_mma_sp.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,13 @@ def mma_sp_m16n8k16_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata:
7575
"fp16",
7676
"fp16",
7777
"fp16",
78-
multi_a,
78+
multi_a.data,
7979
0,
80-
multi_b,
80+
multi_b.data,
8181
0,
82-
accum,
82+
accum.data,
8383
0,
84-
meta_local,
84+
meta_local.data,
8585
0,
8686
0,
8787
False,
@@ -90,7 +90,7 @@ def mma_sp_m16n8k16_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata:
9090
)
9191

9292
for i in range(4):
93-
C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float16", accum, i)
93+
C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
9494

9595

9696
@T.prim_func
@@ -129,13 +129,13 @@ def mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata:
129129
"fp16",
130130
"fp16",
131131
"fp32",
132-
multi_a,
132+
multi_a.data,
133133
0,
134-
multi_b,
134+
multi_b.data,
135135
0,
136-
accum,
136+
accum.data,
137137
0,
138-
meta_local,
138+
meta_local.data,
139139
0,
140140
0,
141141
False,
@@ -144,7 +144,7 @@ def mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata:
144144
)
145145

146146
for i in range(4):
147-
C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float32", accum, i)
147+
C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
148148

149149

150150
@T.prim_func
@@ -183,13 +183,13 @@ def mma_sp_m16n8k32_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata:
183183
"fp16",
184184
"fp16",
185185
"fp16",
186-
multi_a,
186+
multi_a.data,
187187
0,
188-
multi_b,
188+
multi_b.data,
189189
0,
190-
accum,
190+
accum.data,
191191
0,
192-
meta_local,
192+
meta_local.data,
193193
0,
194194
0,
195195
False,
@@ -198,7 +198,7 @@ def mma_sp_m16n8k32_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata:
198198
)
199199

200200
for i in range(4):
201-
C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float16", accum, i)
201+
C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
202202

203203

204204
@T.prim_func
@@ -237,13 +237,13 @@ def mma_sp_m16n8k32_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata:
237237
"fp16",
238238
"fp16",
239239
"fp32",
240-
multi_a,
240+
multi_a.data,
241241
0,
242-
multi_b,
242+
multi_b.data,
243243
0,
244-
accum,
244+
accum.data,
245245
0,
246-
meta_local,
246+
meta_local.data,
247247
0,
248248
0,
249249
False,
@@ -252,7 +252,7 @@ def mma_sp_m16n8k32_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata:
252252
)
253253

254254
for i in range(4):
255-
C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float32", accum, i)
255+
C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
256256

257257

258258
@tvm.testing.requires_cuda

0 commit comments

Comments
 (0)