@@ -160,19 +160,19 @@ def test_full_like():
160160 @tvm .script .ir_module
161161 class FullLike :
162162 @R .function
163- def main (x : R .Tensor ((2 , 3 ), "int32" ), v : R .Tensor ((), "float32" )) -> R .Tensor ((2 , 3 ), "float32 " ):
164- gv : R .Tensor ((2 , 3 ), "float32 " ) = R .full_like (x , v )
163+ def main (x : R .Tensor ((2 , 3 ), "int32" ), v : R .Tensor ((), "float32" )) -> R .Tensor ((2 , 3 ), "int32 " ):
164+ gv : R .Tensor ((2 , 3 ), "int32 " ) = R .full_like (x , v )
165165 return gv
166166
167167 @tvm .script .ir_module
168168 class Expected :
169169 @R .function
170- def main (x : R .Tensor ((2 , 3 ), "int32" ), v : R .Tensor ((), "float32" )) -> R .Tensor ((2 , 3 ), "float32 " ):
171- gv = R .call_tir (Expected .full , (v ,), R .Tensor ((2 , 3 ), dtype = "float32 " ))
170+ def main (x : R .Tensor ((2 , 3 ), "int32" ), v : R .Tensor ((), "float32" )) -> R .Tensor ((2 , 3 ), "int32 " ):
171+ gv = R .call_tir (Expected .full , (v ,), R .Tensor ((2 , 3 ), dtype = "int32 " ))
172172 return gv
173173
174174 @T .prim_func (private = True )
175- def full (rxplaceholder : T .Buffer ((), "float32" ), T_full : T .Buffer ((T .int64 (2 ), T .int64 (3 )), "float32 " )):
175+ def full (rxplaceholder : T .Buffer ((), "float32" ), T_full : T .Buffer ((T .int64 (2 ), T .int64 (3 )), "int32 " )):
176176 T .func_attr ({"tir.noalias" : True })
177177 for i0 , i1 in T .grid (T .int64 (2 ), T .int64 (3 )):
178178 with T .block ("T_full" ):
@@ -191,26 +191,26 @@ def test_full_like_constant_scalar_fill_value():
191191 @tvm .script .ir_module
192192 class FullLike :
193193 @R .function
194- def main (x : R .Tensor ((2 , 3 ), "int32" )) -> R .Tensor ((2 , 3 ), "float32 " ):
195- gv : R .Tensor ((2 , 3 ), "float32 " ) = R .full_like (x , R .const (- 5 , "float32" ))
194+ def main (x : R .Tensor ((2 , 3 ), "int32" )) -> R .Tensor ((2 , 3 ), "int32 " ):
195+ gv : R .Tensor ((2 , 3 ), "int32 " ) = R .full_like (x , R .const (- 5 , "float32" ))
196196 return gv
197197
198198 @tvm .script .ir_module
199199 class Expected :
200200 @R .function
201- def main (x : R .Tensor ((2 , 3 ), "int32" )) -> R .Tensor ((2 , 3 ), "float32 " ):
202- gv = R .call_tir (Expected .full , R .tuple (), R .Tensor ((2 , 3 ), dtype = "float32 " ))
201+ def main (x : R .Tensor ((2 , 3 ), "int32" )) -> R .Tensor ((2 , 3 ), "int32 " ):
202+ gv = R .call_tir (Expected .full , R .tuple (), R .Tensor ((2 , 3 ), dtype = "int32 " ))
203203 return gv
204204
205205 @T .prim_func (private = True )
206- def full (T_full : T .Buffer ((T .int64 (2 ), T .int64 (3 )), "float32 " )):
206+ def full (T_full : T .Buffer ((T .int64 (2 ), T .int64 (3 )), "int32 " )):
207207 T .func_attr ({"tir.noalias" : True })
208208 for i0 , i1 in T .grid (T .int64 (2 ), T .int64 (3 )):
209209 with T .block ("T_full" ):
210210 ax0 , ax1 = T .axis .remap ("SS" , [i0 , i1 ])
211211 T .reads ()
212212 T .writes (T_full [ax0 , ax1 ])
213- T_full [ax0 , ax1 ] = T .float32 (- 5 )
213+ T_full [ax0 , ax1 ] = T .int32 (- 5 )
214214 # fmt: on
215215
216216 mod = LegalizeOps ()(FullLike )
@@ -253,33 +253,33 @@ def test_full_like_symbolic():
253253 @tvm .script .ir_module
254254 class FullLike :
255255 @R .function
256- def main (x : R .Tensor (("m" , "n" ), "int32" ), v : R .Tensor ((), "float32" )) -> R .Tensor (("m" , "n" ), "float32 " ):
256+ def main (x : R .Tensor (("m" , "n" ), "int32" ), v : R .Tensor ((), "float32" )) -> R .Tensor (("m" , "n" ), "int32 " ):
257257 m = T .int64 ()
258258 n = T .int64 ()
259- gv : R .Tensor ((m , n ), "float32 " ) = R .full_like (x , v )
259+ gv : R .Tensor ((m , n ), "int32 " ) = R .full_like (x , v )
260260 return gv
261261
262262 @tvm .script .ir_module
263263 class Expected :
264264 @R .function
265- def main (x : R .Tensor (("m" , "n" ), "int32" ), v : R .Tensor ((), "float32" )) -> R .Tensor (("m" , "n" ), "float32 " ):
265+ def main (x : R .Tensor (("m" , "n" ), "int32" ), v : R .Tensor ((), "float32" )) -> R .Tensor (("m" , "n" ), "int32 " ):
266266 m = T .int64 ()
267267 n = T .int64 ()
268- gv = R .call_tir (Expected .full , (v ,), R .Tensor ((m , n ), dtype = "float32 " ))
268+ gv = R .call_tir (Expected .full , (v ,), R .Tensor ((m , n ), dtype = "int32 " ))
269269 return gv
270270
271271 @T .prim_func (private = True )
272272 def full (rxplaceholder : T .Buffer ((), "float32" ), var_T_full : T .handle ):
273273 T .func_attr ({"tir.noalias" : True })
274274 m = T .int64 ()
275275 n = T .int64 ()
276- T_full = T .match_buffer (var_T_full , [m , n ], dtype = "float32 " )
276+ T_full = T .match_buffer (var_T_full , [m , n ], dtype = "int32 " )
277277 for i0 , i1 in T .grid (m , n ):
278278 with T .block ("T_full" ):
279279 ax0 , ax1 = T .axis .remap ("SS" , [i0 , i1 ])
280280 T .reads (rxplaceholder [()])
281281 T .writes (T_full [ax0 , ax1 ])
282- T_full [ax0 , ax1 ] = rxplaceholder [()]
282+ T_full [ax0 , ax1 ] = T . int32 ( rxplaceholder [()])
283283 # fmt: on
284284
285285 mod = LegalizeOps ()(FullLike )
0 commit comments