Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS
file(GLOB TOPI_SRCS
topi/src/*.cc
)
file(GLOB_RECURSE HALIDEIR_SRCS 3rdparty/HalideIR/src/*.cpp)
file(GLOB_RECURSE HALIDEIR_SRCS
3rdparty/HalideIR/src/base/*.cpp
3rdparty/HalideIR/src/ir/*.cpp
3rdparty/HalideIR/src/tvm/*.cpp
)
list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
file(GLOB RUNTIME_SRCS
src/runtime/*.cc
Expand Down
9 changes: 6 additions & 3 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -623,12 +623,15 @@ IntSet Intersect(const Array<IntSet>& sets);
* give the domain of each variables. Return undefined IntSet to
* represent failure.
*
* \note The returned set may be smaller than set that
* contains all possible values of v that satisfies the bound.
*
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that can cover all the possible values.
* The deduce bound must implies e for all value in relax_map
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
const Map<Var, IntSet>& hint_map,
Expand All @@ -641,7 +644,7 @@ IntSet DeduceBound(Expr v, Expr cond,
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that can cover all the possible values.
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
const std::unordered_map<const Variable*, IntSet>& hint_map,
Expand Down
1 change: 0 additions & 1 deletion include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_

#include <arithmetic/Simplify.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expand Down
1 change: 1 addition & 0 deletions src/arithmetic/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ bool Analyzer::CanProve(const Expr& expr) {
Expr Analyzer::Simplify(const Expr& expr) {
if (is_const(expr)) return expr;
auto res = this->rewrite_simplify(expr);
if (is_const(res)) return res;
res = this->canonical_simplify(res);
return res;
}
Expand Down
145 changes: 84 additions & 61 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ class BoundDeducer: public IRVisitor {
void Deduce();

void Visit(const NodeRef& e) final {
if (!success) return;
if (!success_) return;
if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e);
} else {
success = false;
success_ = false;
return;
}
}
Expand All @@ -111,62 +111,84 @@ class BoundDeducer: public IRVisitor {

void Visit_(const Add* op) final {
bool left = op->a.get() == path_[iter_];
result -= left ? op->b : op->a;
result_ -= left ? op->b : op->a;
Visit(left ? op->a : op->b);
}

void Visit_(const Sub* op) final {
bool left = op->a.get() == path_[iter_];
if (left) {
result += op->b;
result_ += op->b;
} else {
result -= op->a;
result = - result;
is_greater = !is_greater;
result_ -= op->a;
result_ = - result_;
is_greater_ = !is_greater_;
}
Visit(left ? op->a : op->b);
}

void Visit_(const Mul* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;
Expr target_var = left ? op->a : op->b;

SignType sign;
SignType sign_operand;
if (operand.type().is_uint()) {
sign = kPositive;
sign_operand = kPositive;
} else {
sign = expr_map_[operand].sign_type();
sign_operand = expr_map_[operand].sign_type();
}

if (sign == SignType::kNegative) {
is_greater = !is_greater;
} else if (sign == SignType::kUnknown) {
if (sign_operand == SignType::kNegative) {
is_greater_ = !is_greater_;
} else if (sign_operand == SignType::kUnknown) {
// unable to get the sign of operand
success = false;
success_ = false;
return;
}

// always use relax bound
bool divided = can_prove(result % operand == 0);
result = result / operand;
// since system will round down when not divided
// eg. 2/4 -> 0; -2/4 -> -1
// no need fix for !is_greater:
// eg. a <= 2/4 -> a <= 0
// eg. a <= 0/4 -> a <= 0
// so just fix for not divided and is_greater
// eg. a >= 2/4 -> a >= 0 + 1
// eg. a >= 0/4 -> a >= 0
if (is_greater && !divided) {
result += 1;
bool divided = analyzer_.CanProve(result_ % operand == 0);

result_ = result_ / operand;

if (!divided) {
// Handle non-divisible case
// NOTE: this accounts for truc div behavior.
bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative();

if (is_greater_) {
result_ += 1;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this may be wrong: if we consider x*2 >= -1, then it will be transformed into x >= 0 + 1 which excludes x = 0.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bound deduction may allow relaxation and is not guaranteed to be tight

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I agree we could start to think about generating tighter bound if possible(by looking into signs)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is wrong, because in this case it is not a relaxation, but rather an overtightening, which is unsound.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I meant overtightening, as documented by the API note, we want to find a set in which the condition is always satisfied, so we can eliminate the condition by constraining the loop within that range.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, ok, that makes sense. However I think the DeduceBound function may return an approximation which is not a subset of the true set, because it also performs relaxation. I'm not sure, but I think it might break loop partitioning in some cases.

} else {
// NOTE: this is a bit sutble hack.
//
// condition:
// - x * operand <= result
// - operand > 0
// - x >= 0
//
// Then it is fine to deduce that x <= result / operand.
// - if result > 0, this division round down
// - if result < 0, (result / operand) rounds up and may violate the constraint
// however, given that x is always non-negative,
// it is fine to have this relaxed bound, given that the user of deduce bound
// will respect the bound of x
//
// TODO(tvm-team): think about a better API to incorporate constraint of x.
// e.g. specify an interval of x and return a bound
// that is in the interval and satisfies the condition.
if (target_is_non_neg && sign_operand == kPositive) {
// do nothing
} else {
result_ -= 1;
}
}
}

Visit(left ? op->a : op->b);
}

Expr result;
bool is_greater{true};
bool success{true};
Expr result_;
bool is_greater_{true};
bool success_{true};

private:
void Init();
Expand All @@ -180,6 +202,8 @@ class BoundDeducer: public IRVisitor {
ExprIntSetMap expr_map_;
std::vector<const Node*> path_;
size_t iter_{0};
// internal analzyer
Analyzer analyzer_;
};

class BoundDeduceInputChecker: public IRVisitor {
Expand All @@ -202,7 +226,7 @@ class BoundDeduceInputChecker: public IRVisitor {

void BoundDeducer::Init() {
BoundDeduceInputChecker checker;
if (!checker.Check(this)) success = false;
if (!checker.Check(this)) success_ = false;
Transform();
}

Expand All @@ -211,93 +235,92 @@ void BoundDeducer::Transform() {
if (const LT* op = expr_.as<LT>()) {
if (GetPath(target_, op->a).empty()) {
// a < b -> b >= a + 1
is_greater = true;
is_greater_ = true;
expr_ = op->b;
result = op->a + 1;
result_ = op->a + 1;
} else {
// a < b -> a <= b - 1
is_greater = false;
is_greater_ = false;
expr_ = op->a;
result = op->b - 1;
result_ = op->b - 1;
}
} else if (const LE* op = expr_.as<LE>()) {
if (GetPath(target_, op->a).empty()) {
// a <= b -> b >= a
is_greater = true;
is_greater_ = true;
expr_ = op->b;
result = op->a;
result_ = op->a;
} else {
is_greater = false;
is_greater_ = false;
expr_ = op->a;
result = op->b;
result_ = op->b;
}
} else if (const GT* op = expr_.as<GT>()) {
if (GetPath(target_, op->a).empty()) {
// a > b -> b <= a - 1
is_greater = false;
is_greater_ = false;
expr_ = op->b;
result = op->a - 1;
result_ = op->a - 1;
} else {
// a > b -> a >= b + 1
is_greater = true;
is_greater_ = true;
expr_ = op->a;
result = op->b + 1;
result_ = op->b + 1;
}
} else if (const GE* op = expr_.as<GE>()) {
if (GetPath(target_, op->a).empty()) {
// a >= b -> b <= a
is_greater = false;
is_greater_ = false;
expr_ = op->b;
result = op->a;
result_ = op->a;
} else {
is_greater = true;
is_greater_ = true;
expr_ = op->a;
result = op->b;
result_ = op->b;
}
} else {
success = false;
success_ = false;
}
}

void BoundDeducer::Deduce() {
Init();
if (!success) return;
if (!success_) return;
Relax();
if (!success) return;
if (!success_) return;
// get the path
path_ = GetPath(target_, expr_);
if (!path_.size()) {
success = false;
success_ = false;
return;
}

expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);

Visit(expr_);
}

void BoundDeducer::Relax() {
IntSet a = EvalSet(expr_, relax_map_);
IntSet b = EvalSet(result, relax_map_);
IntSet b = EvalSet(result_, relax_map_);
if (a.is_everything() || b.is_everything()) {
success = false;
success_ = false;
return;
}
expr_ = is_greater ? a.min() : a.max();
result = is_greater ? b.max() : b.min();
expr_ = is_greater_ ? a.min() : a.max();
result_ = is_greater_ ? b.max() : b.min();
}

IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
if (!d.success) return IntSet::nothing();
if (!d.success_) return IntSet::nothing();
Expr min = neg_inf(), max = pos_inf();
if (d.is_greater) {
min = d.result;
if (d.is_greater_) {
min = d.result_;
} else {
max = d.result;
max = d.result_;
}
return IntSet::interval(min, max);
}
Expand Down
7 changes: 4 additions & 3 deletions src/arithmetic/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ template<>
inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if (pa && pb && pa->value >= 0 && pb->value > 0) {
if (pa && pb) {
// due to division and mod can have different modes
// NOTE: this will assumes truc div.
CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm::make(rtype, pa->value / pb->value);
}
if (pa) {
Expand Down
1 change: 0 additions & 1 deletion src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y));
TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z));


TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y),
Expand Down
Loading