Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class DataType {
bool is_fixed_length_vector() const { return static_cast<int16_t>(data_.lanes) > 1; }
/*! \return Whether the type is a scalable vector. */
bool is_scalable_vector() const { return static_cast<int16_t>(data_.lanes) < -1; }
/*! \return whether type is a vector type. */
bool is_vector() const { return lanes() > 1; }
/*! \return whether type is a bool vector type. */
bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; }
/*! \return whether type is a Void type. */
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast",
if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) {
if (expr.dtype().lanes() == type.lanes()) {
return expr;
} else if (expr.dtype().lanes() == 1 && type.lanes() > 1) {
} else if (expr.dtype().lanes() == 1 && type.is_vector()) {
return tvm::tir::Broadcast(expr, type.lanes());
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1737,7 +1737,7 @@ void CodeGenLLVM::BufferAccessHelper(
if (const RampNode* ramp = last_index.as<RampNode>()) {
PrimExpr offset = ramp->base + (ramp->stride * i);
last_index_value = MakeValue(offset);
} else if (last_index.dtype().lanes() > 1) {
} else if (last_index.dtype().is_vector()) {
if (i == 0) {
cached_vector_index = MakeValue(last_index);
}
Expand Down
8 changes: 4 additions & 4 deletions src/target/llvm/intrin_rule_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) {

// Enable QHL library for FP16 data type
const PrimExpr& x = call->args[0];
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
return TVMExternCall(call, tvm_wrapper);
}
#endif
Expand Down Expand Up @@ -116,7 +116,7 @@ TVM_REGISTER_OP("tir.tanh")
}

// Enable QHL library for FP16 data type
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf");
return TVMExternCall(call, tvm_wrapper);
}
Expand Down Expand Up @@ -152,7 +152,7 @@ TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>(
}

// Enable QHL library for FP16 data type
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf");
return TVMExternCall(call, tvm_wrapper);
}
Expand Down Expand Up @@ -191,7 +191,7 @@ TVM_REGISTER_OP("tir.sigmoid")
const tir::Call new_call = tir::Call(call->dtype, call->op, new_args);

// Enable QHL library for FP16 data type
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf");
return TVMExternCall(new_call.get(), tvm_wrapper);
}
Expand Down
8 changes: 4 additions & 4 deletions src/tir/analysis/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
size_t size = static_cast<size_t>(op->ConstantAllocationSize());
shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
}
if (op->dtype.lanes() > 1) {
if (op->dtype.is_vector()) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
Expand Down Expand Up @@ -202,7 +202,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}

void VisitExpr_(const CastNode* op) {
if (op->dtype.lanes() > 1) {
if (op->dtype.is_vector()) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
Expand All @@ -215,7 +215,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}

void VisitExpr_(const BufferLoadNode* op) {
if (op->dtype.lanes() > 1) {
if (op->dtype.is_vector()) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
Expand All @@ -229,7 +229,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}

void VisitStmt_(const BufferStoreNode* op) {
if (op->value->dtype.lanes() > 1) {
if (op->value->dtype.is_vector()) {
if (static_cast<size_t>(op->value->dtype.lanes() * op->value->dtype.bytes()) >
max_vector_bytes_) {
std::stringstream s;
Expand Down