Commit 912993f
authored
[ARM] Fix int8 NCHWc compute and alter layout (#10839)
This PR fixes a bug in TE ARM int8 compute for NCHWc conv2d, introduced in #10310. The compute itself, not the schedule, is broken for the following reasons:
* We are using `n_elems = 8` in https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_alter_op.py#L350. Thus, the innermost axis of the transformed kernel has extent 8: https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_alter_op.py#L375
* In the TE compute, we iterate over the innermost axis `ic_s_inner` of the kernel at https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L577. `ic_s_inner` has extent `n_elems` according to https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L566. `n_elems` is 4 by default according to https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L478
* The ARM code that calls this compute does not explicitly pass `n_elems`, according to https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_int8.py#L106-L108
* Thus, even though the innermost axis of the kernel has extent 8, the TE compute only loops over `n_elems = 4` of the input channel dimension.
Initially, I tried to keep `n_elems = 8` in alter layout and fix the intrinsic definition. But `n_elems = 8` breaks tensorization pattern matching, since now the compute is doing 4x8 innermost loop but this intrinsic is supposed to do 4x4 dot product, see https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/tensor_intrin.py#L467-L479. Setting `num_int8_elements = 8` there does fix the tensorize pattern matching, but the result was still incorrect.
Rather than fixing the intrin implementation in https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/tensor_intrin.py#L492 to adapt for 4x8 dot product, I settled on setting `n_elems = 4` in alter layout. It turned out this change is enough to get the correct output. Moreover, `n_elems = 8` is simply wrong for the dot product path in https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/conv2d_int8.py#L154-L155 which computes 4x4 dot product in one instruction.
@tkonolige I suggest doing perf benchmark again, since the numbers in #10310 are invalid.
cc @mbrookhart @Mousius @junrushao1994 @vinx131 parent 63bb3b9 commit 912993f
6 files changed
Lines changed: 26 additions & 23 deletions
File tree
- python/tvm/topi
- arm_cpu
- nn
- x86
- tests/python/topi/python
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
347 | 347 | | |
348 | 348 | | |
349 | 349 | | |
350 | | - | |
| 350 | + | |
351 | 351 | | |
352 | 352 | | |
353 | 353 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
57 | 57 | | |
58 | 58 | | |
59 | 59 | | |
60 | | - | |
| 60 | + | |
61 | 61 | | |
62 | 62 | | |
63 | 63 | | |
| |||
103 | 103 | | |
104 | 104 | | |
105 | 105 | | |
| 106 | + | |
| 107 | + | |
106 | 108 | | |
107 | | - | |
| 109 | + | |
108 | 110 | | |
109 | 111 | | |
110 | 112 | | |
| |||
149 | 151 | | |
150 | 152 | | |
151 | 153 | | |
152 | | - | |
| 154 | + | |
| 155 | + | |
153 | 156 | | |
154 | 157 | | |
155 | 158 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
614 | 614 | | |
615 | 615 | | |
616 | 616 | | |
617 | | - | |
618 | | - | |
619 | | - | |
620 | | - | |
621 | | - | |
622 | | - | |
| 617 | + | |
| 618 | + | |
| 619 | + | |
| 620 | + | |
| 621 | + | |
623 | 622 | | |
624 | | - | |
| 623 | + | |
625 | 624 | | |
| 625 | + | |
| 626 | + | |
626 | 627 | | |
627 | 628 | | |
628 | 629 | | |
629 | 630 | | |
630 | 631 | | |
631 | | - | |
| 632 | + | |
632 | 633 | | |
633 | 634 | | |
634 | 635 | | |
| |||
638 | 639 | | |
639 | 640 | | |
640 | 641 | | |
641 | | - | |
642 | | - | |
| 642 | + | |
| 643 | + | |
643 | 644 | | |
644 | 645 | | |
645 | 646 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
486 | 486 | | |
487 | 487 | | |
488 | 488 | | |
489 | | - | |
490 | 489 | | |
491 | 490 | | |
492 | 491 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
120 | 120 | | |
121 | 121 | | |
122 | 122 | | |
123 | | - | |
| 123 | + | |
124 | 124 | | |
125 | 125 | | |
126 | 126 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
21 | 21 | | |
22 | 22 | | |
23 | 23 | | |
24 | | - | |
25 | 24 | | |
26 | 25 | | |
27 | 26 | | |
| |||
34 | 33 | | |
35 | 34 | | |
36 | 35 | | |
| 36 | + | |
37 | 37 | | |
38 | 38 | | |
39 | 39 | | |
| |||
299 | 299 | | |
300 | 300 | | |
301 | 301 | | |
302 | | - | |
303 | 302 | | |
304 | 303 | | |
305 | 304 | | |
| |||
311 | 310 | | |
312 | 311 | | |
313 | 312 | | |
314 | | - | |
315 | | - | |
316 | 313 | | |
317 | 314 | | |
318 | 315 | | |
| |||
342 | 339 | | |
343 | 340 | | |
344 | 341 | | |
| 342 | + | |
| 343 | + | |
345 | 344 | | |
346 | 345 | | |
347 | 346 | | |
| |||
364 | 363 | | |
365 | 364 | | |
366 | 365 | | |
367 | | - | |
| 366 | + | |
| 367 | + | |
368 | 368 | | |
369 | 369 | | |
370 | 370 | | |
371 | 371 | | |
372 | 372 | | |
373 | 373 | | |
374 | | - | |
| 374 | + | |
375 | 375 | | |
376 | 376 | | |
377 | 377 | | |
| |||
382 | 382 | | |
383 | 383 | | |
384 | 384 | | |
385 | | - | |
| 385 | + | |
386 | 386 | | |
387 | 387 | | |
388 | 388 | | |
| |||
0 commit comments