Skip to content

Commit b56aad1

Browse files
committed
add more detail to a receiving a bad type
1 parent db12ffc commit b56aad1

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

src/ExtractTileOperations.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -384,12 +384,12 @@ Matmul convert_to_matmul(const Store *op, const string &new_name, AMXOpType op_t
384384
}
385385

386386
if (rhs_cast) {
387-
if (op_type == AMXOpType::Int8) {
388-
if (!(rhs_cast->value.type().element_of() == Int(8) || rhs_cast->value.type().element_of() == UInt(8))) {
389-
user_error << "Expected rhs cast of i8/u8, got " << rhs_cast->value.type();
390-
}
391-
} else { // AMXOpType::Bfloat16
392-
user_assert(rhs_cast->value.type().element_of() == BFloat(16)) << "Expected rhs cast of bf16";
387+
bool is_i8_u8 = rhs_cast->value.type().element_of() == Int(8) || rhs_cast->value.type().element_of() == UInt(8);
388+
bool is_bf16 = rhs_cast->value.type().element_of() == BFloat(16);
389+
390+
if ((op_type == AMXOpType::Int8 && !is_i8_u8) || (op_type == AMXOpType::Bfloat16 && !is_bf16)) {
391+
user_error << "Expected rhs type of " << (op_type == AMXOpType::Int8 ? "i8/u8" : "bf16")
392+
<< ", got " << rhs_cast->value.type() << " instead.\nIn Expression: " << Expr(rhs_cast);
393393
}
394394
} else {
395395
return {};

0 commit comments

Comments
 (0)