Skip to content

Commit 456f84c

Browse files
committed
Fix intersect of modular set
1 parent 84cb712 commit 456f84c

File tree

2 files changed

+113
-62
lines changed

2 files changed

+113
-62
lines changed

src/arithmetic/modular_set.cc

Lines changed: 96 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ struct ModularSetAnalyzer::Entry {
3636
int64_t coeff{1};
3737
int64_t base{0};
3838

39+
Entry() = default;
40+
41+
Entry(int64_t coeff, int64_t base) {
42+
CHECK_GE(coeff, 0);
43+
this->coeff = coeff;
44+
if (coeff != 0) {
45+
base = base % coeff;
46+
if (base < 0) base += coeff;
47+
}
48+
this->base = base;
49+
}
50+
3951
bool is_const() const {
4052
return coeff == 0;
4153
}
@@ -53,10 +65,7 @@ class ModularSetAnalyzer::Impl :
5365
if (!override) {
5466
CHECK(!var_map_.count(var));
5567
}
56-
Entry e;
57-
e.coeff = info->coeff;
58-
e.base = info->base;
59-
var_map_[var] = e;
68+
var_map_[var] = Entry(info->coeff, info->base);
6069
}
6170

6271
// Detect useful constraints and use them in the analysis scope.
@@ -65,9 +74,7 @@ class ModularSetAnalyzer::Impl :
6574
PVar<Integer> coeff, base;
6675
// pattern match interesting constraints
6776
if (((var % coeff) == base).Match(constraint)) {
68-
Entry entry;
69-
entry.coeff = coeff.Eval()->value;
70-
entry.base = base.Eval()->value;
77+
Entry entry(coeff.Eval()->value, base.Eval()->value);
7178
return UpdateByIntersect(var.Eval(), entry);
7279
}
7380
return nullptr;
@@ -83,18 +90,12 @@ class ModularSetAnalyzer::Impl :
8390
}
8491

8592
Entry VisitExpr_(const IntImm* op) final {
86-
Entry ret;
87-
ret.base = op->value;
88-
ret.coeff = 0;
89-
return ret;
93+
return Entry(0, op->value);
9094
}
9195

9296
Entry VisitExpr_(const UIntImm* op) final {
9397
if (op->value < std::numeric_limits<int64_t>::max()) {
94-
Entry ret;
95-
ret.base = static_cast<int>(op->value);
96-
ret.coeff = 0;
97-
return ret;
98+
return Entry(0, static_cast<int>(op->value));
9899
} else {
99100
return Everything();
100101
}
@@ -103,19 +104,15 @@ class ModularSetAnalyzer::Impl :
103104
Entry VisitExpr_(const Add* op) final {
104105
Entry a = VisitExpr(op->a);
105106
Entry b = VisitExpr(op->b);
106-
Entry ret;
107-
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
108-
ret.base = BaseSimplify(a.base + b.base, ret.coeff);
109-
return ret;
107+
int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
108+
return Entry(coeff, a.base + b.base);
110109
}
111110

112111
Entry VisitExpr_(const Sub* op) final {
113112
Entry a = VisitExpr(op->a);
114113
Entry b = VisitExpr(op->b);
115-
Entry ret;
116-
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
117-
ret.base = BaseSimplify(a.base - b.base, ret.coeff);
118-
return ret;
114+
int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
115+
return Entry(coeff, a.base - b.base);
119116
}
120117

121118
Entry VisitExpr_(const Mul* op) final {
@@ -128,10 +125,8 @@ class ModularSetAnalyzer::Impl :
128125
int64_t pq = a.coeff * b.coeff;
129126
int64_t pm = a.coeff * b.base;
130127
int64_t qn = a.base * b.coeff;
131-
Entry ret;
132-
ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
133-
ret.base = BaseSimplify(a.base * b.base, ret.coeff);
134-
return ret;
128+
int64_t coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
129+
return Entry(coeff, a.base * b.base);
135130
}
136131

137132
Entry DivByConst(const Expr& lhs,
@@ -140,20 +135,15 @@ class ModularSetAnalyzer::Impl :
140135
Entry a = VisitExpr(lhs);
141136
CHECK_NE(val, 0);
142137
if (a.coeff % val == 0) {
143-
Entry ret;
144138
if (a.base == 0) {
145139
// a c x / c -> a x
146-
ret.coeff = std::abs(a.coeff / val);
147-
ret.base = 0;
148-
return ret;
140+
return Entry(std::abs(a.coeff / val), 0);
149141
}
150142
// positive division have a clear rounding mode.
151143
// Only handle case where we clearly know we need to round down.
152144
if (a.base > 0 && val > 0 &&
153145
(round_down || parent_->CanProveGreaterEqual(lhs, 0))) {
154-
ret.coeff = a.coeff / val;
155-
ret.base = a.base / val;
156-
return ret;
146+
return Entry(a.coeff / val, a.base / val);
157147
}
158148
}
159149
return Everything();
@@ -251,41 +241,80 @@ class ModularSetAnalyzer::Impl :
251241
}
252242
int64_t base0 = a.base % coeff;
253243
int64_t base1 = b.base % coeff;
254-
Entry ret;
255244
if (base0 == base1) {
256-
ret.coeff = coeff;
257-
ret.base = base0;
258-
return ret;
245+
return Entry(coeff, base0);
259246
} else {
260-
ret.coeff = ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff);
261-
ret.base = 0;
262-
return ret;
247+
return Entry(ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff), base0);
263248
}
264249
}
250+
/*!
251+
* \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b)
252+
* \param a The first coefficient.
253+
* \param b The second coefficient.
254+
* \param x The solution of x.
255+
* \param y The solution of y.
256+
* \return The GCD of a and b.
257+
*/
258+
static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t& x, int64_t& y) {
259+
// Extended Euclidean algorithm
260+
// if a < 0, the problem can be convert into
261+
// |a|* (-x) + b * y = gcd(|a|, b)
262+
//
263+
// initial condition:
264+
// a * 0 + b * 1 = b
265+
// a * 1 + b * 0 = a
266+
int64_t s = 0, old_s = 1;
267+
int64_t r = b, old_r = a >= 0 ? a : -a;
268+
// Iteration (r2 < r1):
269+
// a * x1 + b * y1 = r1
270+
// a * x2 + b * y2 = r2
271+
// The above two eqs can derive the following eq (q = r2 / r1)
272+
// a * (x1 - x2 * q) + b * (y1 - y2 * q) = r2 - r1 * q = r3
273+
// Because r3 < r2, the iteration can eventually terminate
274+
while (r != 0) {
275+
int64_t q = old_r / r;
276+
int64_t tmp = old_r;
277+
old_r = r;
278+
r = tmp - q * r;
279+
tmp = old_s;
280+
old_s = s;
281+
s = tmp - q * s;
282+
}
283+
284+
x = a >= 0 ? old_s : -old_s;
285+
if (b != 0) {
286+
y = (old_r - x * a) / b;
287+
} else {
288+
y = 1;
289+
}
290+
291+
return old_r;
292+
}
265293
/*!
266294
* \brief Create interect of two sets.
267295
* \param a The left operand.
268296
* \param b the right operand.
269297
*/
270298
static Entry Intersect(Entry a, Entry b) {
271-
// simple rule for now: pick higher constraints.
272-
// TODO(team-team): Use extended euclidean algorithm.
273-
if (a.coeff == 0) return a;
274-
if (b.coeff == 0) return b;
275-
if (a.coeff >= b.coeff) return a;
276-
return b;
277-
}
278-
/*!
279-
* \brief Simplify base so that it is in [0, coeff) when coeff != 0.
280-
* \param base The base value.
281-
* \param coeff The coeff value.
282-
* \return The simplified base.
283-
*/
284-
static int64_t BaseSimplify(int64_t base, int64_t coeff) {
285-
if (coeff == 0) return base;
286-
base = base % coeff;
287-
if (base < 0) base += coeff;
288-
return base;
299+
int64_t x, y;
300+
int64_t c1 = a.coeff, b1 = a.base, c2 = b.coeff, b2 = b.base;
301+
// z = c1 * p + b1
302+
// z = c2 * q + b2
303+
// c1 * x + c2 * y = gcd(c1, c2)
304+
// -> c1 * p - c2 * q = b2 - b1
305+
// -> p = (b2 - b1) / gcd * x
306+
// -> q = (b2 - b1) / gcd * (-y)
307+
// -> z = LCM(x, y) * k + (c1 * p + b1)
308+
int64_t gcd = ExtendedEuclidean(c1, c2, x, y);
309+
int64_t v = b2 - b1;
310+
if (v % gcd == 0) {
311+
x = v / gcd * x;
312+
y = v / gcd * (-y);
313+
int64_t coeff = c1 / gcd * c2;
314+
return Entry(coeff, x * c1 + b1);
315+
} else {
316+
return Nothing();
317+
}
289318
}
290319
/*!
291320
* \brief Take GCD of a and b.
@@ -311,9 +340,14 @@ class ModularSetAnalyzer::Impl :
311340
* \return Bound that represent everything dtype can represent.
312341
*/
313342
static Entry Everything() {
314-
Entry ret;
315-
ret.coeff = 1; ret.base = 0;
316-
return ret;
343+
return Entry(1, 0);
344+
}
345+
/*!
346+
* \brief return an empty set
347+
* \return Bound that represent everything dtype can represent.
348+
*/
349+
static Entry Nothing() {
350+
return Entry(0, 1);
317351
}
318352
};
319353

tests/python/unittest/test_arith_modular_set.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,22 @@ def test_constraint_scope():
117117
assert m.coeff == 1
118118
assert m.base == 0
119119

120+
def test_intersect():
121+
a = tvm.var("a")
122+
analyzer = tvm.arith.Analyzer()
123+
with analyzer.constraint_scope(a % 4 == 1):
124+
with analyzer.constraint_scope(a % 3 == 1):
125+
m = analyzer.modular_set(a)
126+
assert m.coeff == 12
127+
assert m.base == 1
128+
129+
with analyzer.constraint_scope(a % 3 == 2):
130+
with analyzer.constraint_scope(a % 5 == 3):
131+
with analyzer.constraint_scope(a % 7 == 2):
132+
m = analyzer.modular_set(a)
133+
assert m.coeff == 105
134+
assert m.base == 23
135+
120136

121137
if __name__ == "__main__":
122138
test_cast()
@@ -126,3 +142,4 @@ def test_constraint_scope():
126142
test_min_max_select()
127143
test_mix_index()
128144
test_constraint_scope()
145+
test_intersect()

0 commit comments

Comments
 (0)