Skip to content

Commit 9999114

Browse files
[Codegen][OpenCL] fix ambiguous selection operator call (#14833)
* fix
* fix
* Update test_target_codegen_opencl.py
* Update test_target_codegen_opencl.py
* Update test_target_codegen_opencl.py
* Update codegen_opencl.cc

---------

Co-authored-by: Tianqi Chen <tqchen@users.noreply.github.com>
1 parent 28c85f0 commit 9999114

2 files changed

Lines changed: 24 additions & 14 deletions

File tree

src/target/source/codegen_opencl.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -567,11 +567,7 @@ void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) {
567567
os << ", ";
568568
PrintExpr(op->condition, oss);
569569
if (op->dtype.is_float()) {
570-
if (op->condition.dtype().is_uint() || op->condition.dtype().is_int()) {
571-
os << oss.str();
572-
} else {
573-
os << CastTo(oss.str(), DataType::Int(op->dtype.bits(), op->dtype.lanes()));
574-
}
570+
os << CastTo(oss.str(), DataType::Int(op->dtype.bits(), op->dtype.lanes()));
575571
} else {
576572
os << CastFromTo(oss.str(), op->condition.dtype(), op->dtype);
577573
}

tests/python/unittest/test_target_codegen_opencl.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -168,20 +168,34 @@ def check_type_casting(ctx, n, dtype):
168168

169169
c = tvm.nd.empty((n,), dtype, ctx)
170170
assembly = fun.imported_modules[0].get_source()
171-
false_branch = "((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f))"
172-
true_branch = "((float4)(1.000000e+00f, 1.000000e+00f, 1.000000e+00f, 1.000000e+00f))"
173-
lcond = "(convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))"
174-
rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))"
175-
cond = "({} && {})".format(lcond, rcond)
176-
select = "select({}, {}, {})".format(false_branch, true_branch, cond)
177-
count = assembly.count(select)
178-
assert count == 1
179171

180-
fun(c)
172+
if dtype == "float32":
173+
false_branch = "((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f))"
174+
true_branch = "((float4)(1.000000e+00f, 1.000000e+00f, 1.000000e+00f, 1.000000e+00f))"
175+
lcond = "convert_int4(((convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))"
176+
rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))"
177+
cond = "({} && {})".format(lcond, rcond)
178+
select = "select({}, {}, {})".format(false_branch, true_branch, cond)
179+
count = assembly.count(select)
180+
assert count == 1
181+
fun(c)
182+
183+
elif dtype == "float16":
184+
false_branch = "((half4)((half)0.000000e+00f, (half)0.000000e+00f, (half)0.000000e+00f, (half)0.000000e+00f))"
185+
true_branch = "((half4)((half)1.000000e+00f, (half)1.000000e+00f, (half)1.000000e+00f, (half)1.000000e+00f))"
186+
lcond = "convert_short4(((convert_uint4(((uint4)((((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3), (((int)get_local_id(0)) == 3)))))"
187+
rcond = "(convert_uint4((((int4)((0)+(1*0), (0)+(1*1), (0)+(1*2), (0)+(1*3))) == ((int4)(3, 3, 3, 3)))))))"
188+
cond = "({} && {})".format(lcond, rcond)
189+
select = "select({}, {}, {})".format(false_branch, true_branch, cond)
190+
count = assembly.count(select)
191+
assert count == 1
192+
fun(c)
181193

182194
dev = tvm.device(target, 0)
183195

184196
check_type_casting(dev, 16, "float32")
197+
# fp16 is not yet supported in ci
198+
# check_type_casting(dev, 16, "float16")
185199

186200

187201
if __name__ == "__main__":

0 commit comments

Comments
 (0)