@@ -36,6 +36,18 @@ struct ModularSetAnalyzer::Entry {
3636 int64_t coeff{1 };
3737 int64_t base{0 };
3838
39+ Entry () = default ;
40+
41+ Entry (int64_t coeff, int64_t base) {
42+ CHECK_GE (coeff, 0 );
43+ this ->coeff = coeff;
44+ if (coeff != 0 ) {
45+ base = base % coeff;
46+ if (base < 0 ) base += coeff;
47+ }
48+ this ->base = base;
49+ }
50+
// Returns true when coeff is zero. The IntImm visitor in this file builds
// Entry(0, value) for integer literals, so this presumably marks an entry
// that stands for one constant value.
3951 bool is_const () const {
4052 return coeff == 0 ;
4153 }
@@ -53,10 +65,7 @@ class ModularSetAnalyzer::Impl :
5365 if (!override ) {
5466 CHECK (!var_map_.count (var));
5567 }
56- Entry e;
57- e.coeff = info->coeff ;
58- e.base = info->base ;
59- var_map_[var] = e;
68+ var_map_[var] = Entry (info->coeff , info->base );
6069 }
6170
6271 // Detect useful constraints and use them in the analysis scope.
@@ -65,9 +74,7 @@ class ModularSetAnalyzer::Impl :
6574 PVar<Integer> coeff, base;
6675 // pattern match interesting constraints
6776 if (((var % coeff) == base).Match (constraint)) {
68- Entry entry;
69- entry.coeff = coeff.Eval ()->value ;
70- entry.base = base.Eval ()->value ;
77+ Entry entry (coeff.Eval ()->value , base.Eval ()->value );
7178 return UpdateByIntersect (var.Eval (), entry);
7279 }
7380 return nullptr ;
@@ -83,18 +90,12 @@ class ModularSetAnalyzer::Impl :
8390 }
8491
8592 Entry VisitExpr_ (const IntImm* op) final {
86- Entry ret;
87- ret.base = op->value ;
88- ret.coeff = 0 ;
89- return ret;
93+ return Entry (0 , op->value );
9094 }
9195
9296 Entry VisitExpr_ (const UIntImm* op) final {
9397 if (op->value < std::numeric_limits<int64_t >::max ()) {
94- Entry ret;
95- ret.base = static_cast <int >(op->value );
96- ret.coeff = 0 ;
97- return ret;
98+ return Entry (0 , static_cast <int >(op->value ));
9899 } else {
99100 return Everything ();
100101 }
@@ -103,19 +104,15 @@ class ModularSetAnalyzer::Impl :
103104 Entry VisitExpr_ (const Add* op) final {
104105 Entry a = VisitExpr (op->a );
105106 Entry b = VisitExpr (op->b );
106- Entry ret;
107- ret.coeff = ZeroAwareGCD (a.coeff , b.coeff );
108- ret.base = BaseSimplify (a.base + b.base , ret.coeff );
109- return ret;
107+ int64_t coeff = ZeroAwareGCD (a.coeff , b.coeff );
108+ return Entry (coeff, a.base + b.base );
110109 }
111110
112111 Entry VisitExpr_ (const Sub* op) final {
113112 Entry a = VisitExpr (op->a );
114113 Entry b = VisitExpr (op->b );
115- Entry ret;
116- ret.coeff = ZeroAwareGCD (a.coeff , b.coeff );
117- ret.base = BaseSimplify (a.base - b.base , ret.coeff );
118- return ret;
114+ int64_t coeff = ZeroAwareGCD (a.coeff , b.coeff );
115+ return Entry (coeff, a.base - b.base );
119116 }
120117
121118 Entry VisitExpr_ (const Mul* op) final {
@@ -128,10 +125,8 @@ class ModularSetAnalyzer::Impl :
128125 int64_t pq = a.coeff * b.coeff ;
129126 int64_t pm = a.coeff * b.base ;
130127 int64_t qn = a.base * b.coeff ;
131- Entry ret;
132- ret.coeff = ZeroAwareGCD (pq, ZeroAwareGCD (pm, qn));
133- ret.base = BaseSimplify (a.base * b.base , ret.coeff );
134- return ret;
128+ int64_t coeff = ZeroAwareGCD (pq, ZeroAwareGCD (pm, qn));
129+ return Entry (coeff, a.base * b.base );
135130 }
136131
137132 Entry DivByConst (const Expr& lhs,
@@ -140,20 +135,15 @@ class ModularSetAnalyzer::Impl :
140135 Entry a = VisitExpr (lhs);
141136 CHECK_NE (val, 0 );
142137 if (a.coeff % val == 0 ) {
143- Entry ret;
144138 if (a.base == 0 ) {
145139 // a c x / c -> a x
146- ret.coeff = std::abs (a.coeff / val);
147- ret.base = 0 ;
148- return ret;
140+ return Entry (std::abs (a.coeff / val), 0 );
149141 }
150142 // positive division have a clear rounding mode.
151143 // Only handle case where we clearly know we need to round down.
152144 if (a.base > 0 && val > 0 &&
153145 (round_down || parent_->CanProveGreaterEqual (lhs, 0 ))) {
154- ret.coeff = a.coeff / val;
155- ret.base = a.base / val;
156- return ret;
146+ return Entry (a.coeff / val, a.base / val);
157147 }
158148 }
159149 return Everything ();
@@ -251,41 +241,80 @@ class ModularSetAnalyzer::Impl :
251241 }
252242 int64_t base0 = a.base % coeff;
253243 int64_t base1 = b.base % coeff;
254- Entry ret;
255244 if (base0 == base1) {
256- ret.coeff = coeff;
257- ret.base = base0;
258- return ret;
245+ return Entry (coeff, base0);
259246 } else {
260- ret.coeff = ZeroAwareGCD (ZeroAwareGCD (base0, base1), coeff);
261- ret.base = 0 ;
262- return ret;
247+ return Entry (ZeroAwareGCD (ZeroAwareGCD (base0, base1), coeff), base0);
263248 }
264249 }
/*!
 * \brief Use Extended Euclidean algorithm to solve a * x + b * y = gcd(a, b).
 *
 *  The returned GCD is non-negative for any sign of a and b.
 *  When b == 0 the solution is x = 1 (or -1 for negative a) and y = 1;
 *  in particular a == 0 and b == 0 gives x = 1, y = 1 and a GCD of 0.
 *
 * \param a The first coefficient.
 * \param b The second coefficient.
 * \param x The solution of x.
 * \param y The solution of y.
 * \return The GCD of a and b.
 */
static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t& x, int64_t& y) {
  // The iteration runs on |a| and |b| so that every remainder, and therefore
  // the returned GCD, stays non-negative. The signs are restored afterwards:
  //   |a| * s + |b| * t = g   ->   a * (sign(a) * s) + b * (sign(b) * t) = g
  //
  // initial condition:
  //   |a| * 1 + |b| * 0 = |a|
  //   |a| * 0 + |b| * 1 = |b|
  int64_t s = 0;
  int64_t old_s = 1;
  int64_t r = b >= 0 ? b : -b;
  int64_t old_r = a >= 0 ? a : -a;
  // Iteration (r2 < r1), with q = r1 / r2:
  //   |a| * x1 + |b| * y1 = r1
  //   |a| * x2 + |b| * y2 = r2
  //   -> |a| * (x1 - x2 * q) + |b| * (y1 - y2 * q) = r1 - r2 * q = r3
  // Because r3 < r2, the iteration eventually terminates.
  while (r != 0) {
    const int64_t q = old_r / r;
    int64_t tmp = old_r;
    old_r = r;
    r = tmp - q * r;
    tmp = old_s;
    old_s = s;
    s = tmp - q * s;
  }

  x = a >= 0 ? old_s : -old_s;
  if (b != 0) {
    // old_r - x * a equals |b| * t, so the division is exact and dividing by
    // the signed b restores the sign of y.
    y = (old_r - x * a) / b;
  } else {
    y = 1;
  }

  return old_r;
}
265293 /* !
266294 * \brief Create interect of two sets.
267295 * \param a The left operand.
268296 * \param b the right operand.
269297 */
270298 static Entry Intersect (Entry a, Entry b) {
271- // simple rule for now: pick higher constraints.
272- // TODO(team-team): Use extended euclidean algorithm.
273- if (a.coeff == 0 ) return a;
274- if (b.coeff == 0 ) return b;
275- if (a.coeff >= b.coeff ) return a;
276- return b;
277- }
278- /* !
279- * \brief Simplify base so that it is in [0, coeff) when coeff != 0.
280- * \param base The base value.
281- * \param coeff The coeff value.
282- * \return The simplified base.
283- */
284- static int64_t BaseSimplify (int64_t base, int64_t coeff) {
285- if (coeff == 0 ) return base;
286- base = base % coeff;
287- if (base < 0 ) base += coeff;
288- return base;
299+ int64_t x, y;
300+ int64_t c1 = a.coeff , b1 = a.base , c2 = b.coeff , b2 = b.base ;
301+ // z = c1 * p + b1
302+ // z = c2 * q + b2
303+ // c1 * x + c2 * y = gcd(c1, c2)
304+ // -> c1 * p - c2 * q = b2 - b1
305+ // -> p = (b2 - b1) / gcd * x
306+ // -> q = (b2 - b1) / gcd * (-y)
307+ // -> z = LCM(x, y) * k + (c1 * p + b1)
308+ int64_t gcd = ExtendedEuclidean (c1, c2, x, y);
309+ int64_t v = b2 - b1;
310+ if (v % gcd == 0 ) {
311+ x = v / gcd * x;
312+ y = v / gcd * (-y);
313+ int64_t coeff = c1 / gcd * c2;
314+ return Entry (coeff, x * c1 + b1);
315+ } else {
316+ return Nothing ();
317+ }
289318 }
290319 /* !
291320 * \brief Take GCD of a and b.
@@ -311,9 +340,14 @@ class ModularSetAnalyzer::Impl :
311340 * \return Bound that represent everything dtype can represent.
312341 */
313342 static Entry Everything () {
314- Entry ret;
315- ret.coeff = 1 ; ret.base = 0 ;
316- return ret;
343+ return Entry (1 , 0 );
344+ }
345+ /* !
346+ * \brief Return the entry this analyzer uses to signal an empty set.
347+ * \return Entry(0, 1). NOTE(review): same fields as the constant 1 - confirm callers can tell them apart.
348+ */
349+ static Entry Nothing () {
350+ return Entry (0 , 1 );
317351 }
318352};
319353
0 commit comments