@@ -168,20 +168,34 @@ def check_type_casting(ctx, n, dtype):
168168
169169 c = tvm .nd .empty ((n ,), dtype , ctx )
170170 assembly = fun .imported_modules [0 ].get_source ()
171- false_branch = "((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f))"
172- true_branch = "((float4)(1.000000e+00f, 1.000000e+00f, 1.000000e+00f, 1.000000e+00f))"
173- lcond = "(convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))"
174- rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))"
175- cond = "({} && {})" .format (lcond , rcond )
176- select = "select({}, {}, {})" .format (false_branch , true_branch , cond )
177- count = assembly .count (select )
178- assert count == 1
179171
180- fun (c )
172+ if dtype == "float32" :
173+ false_branch = "((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f))"
174+ true_branch = "((float4)(1.000000e+00f, 1.000000e+00f, 1.000000e+00f, 1.000000e+00f))"
175+ lcond = "convert_int4(((convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))"
176+ rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))"
177+ cond = "({} && {})" .format (lcond , rcond )
178+ select = "select({}, {}, {})" .format (false_branch , true_branch , cond )
179+ count = assembly .count (select )
180+ assert count == 1
181+ fun (c )
182+
183+ elif dtype == "float16" :
184+ false_branch = "((half4)((half)0.000000e+00f, (half)0.000000e+00f, (half)0.000000e+00f, (half)0.000000e+00f))"
185+ true_branch = "((half4)((half)1.000000e+00f, (half)1.000000e+00f, (half)1.000000e+00f, (half)1.000000e+00f))"
186+ lcond = "convert_short4(((convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))"
187+ rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))))"
188+ cond = "({} && {})" .format (lcond , rcond )
189+ select = "select({}, {}, {})" .format (false_branch , true_branch , cond )
190+ count = assembly .count (select )
191+ assert count == 1
192+ fun (c )
181193
182194 dev = tvm .device (target , 0 )
183195
184196 check_type_casting (dev , 16 , "float32" )
197+ # fp16 is not yet supported in CI
198+ # check_type_casting(dev, 16, "float16")
185199
186200
187201if __name__ == "__main__" :
0 commit comments