"""x86 batch_matmul operators"""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas

from .. import generic, nn
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor

@autotvm.register_topi_compute(nn.batch_matmul, "cpu", "direct")
def _declaration_batch_matmul_nopack(cfg, x, y):
    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
    data in batch.

    Parameters
    ----------
    cfg : ConfigSpace
        Autotvm tuning space config file
    x : tvm.Tensor
        3-D with shape [batch, M, K]
    y : tvm.Tensor
        3-D with shape [batch, N, K]

    Returns
    -------
    output : tvm.Tensor
        3-D with shape [batch, M, N]
    """
    # Dispatch to cblas when the target enables it; cblas takes y in
    # [batch, N, K] layout, hence transa=False, transb=True.
    target = tvm.target.current_target()
    if "cblas" in target.libs:
        return cblas.batch_matmul(x, y, False, True)

    assert len(x.shape) == 3 and len(
        y.shape) == 3, "only support 3-dim batch_matmul"
    XB, M, XK = get_const_tuple(x.shape)
    YB, N, YK = get_const_tuple(y.shape)
    assert XB == YB, "batch dimension doesn't match"
    assert XK == YK, "shapes of x and y is inconsistent"
    B = XB
    K = XK
    # When no tuned config exists, fill in a workable default split so the
    # schedule below can still apply cfg["tile_*"].
    if cfg.is_fallback:
        _default_batch_matmul_nopack_config(cfg, M, N, K)

    k = tvm.reduce_axis((0, K), name='k')
    # y is indexed [b, j, k]: the second operand is implicitly transposed.
    C = tvm.compute(
        (B, M, N),
        lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k),
        tag='batch_matmul')
    return C
67+

@autotvm.register_topi_schedule(generic.schedule_batch_matmul, "cpu", "direct")
def schedule_batch_matmul(cfg, outs):
    """Schedule for batch_matmul

    Parameters
    ----------
    cfg : ConfigSpace
        AutoTVM tuning space config file.
    outs : Array of Tensor
        The computation graph description of batch_matmul
        in the format of an array of tensors.

    Returns
    -------
    s : Schedule
        The computation schedule for the op.
    """
    s = tvm.create_schedule([x.op for x in outs])

    def _callback(op):
        if "batch_matmul" in op.tag:
            C = op.output(0)
            # Only the first input is needed to read M and K.
            A, _ = s[C].op.input_tensors
            _, M, K = get_const_tuple(A.shape)
            _, _, N = get_const_tuple(C.shape)

            # create tuning space
            cfg.define_split("tile_y", M, num_outputs=2)
            cfg.define_split("tile_x", N, num_outputs=2)
            cfg.define_split("tile_k", K, num_outputs=2)

            k, = s[C].op.reduce_axis

            # Split the reduction and rfactor it so partial sums over the
            # inner chunk can be vectorized independently.
            ko, ki = cfg["tile_k"].apply(s, C, k)
            CC = s.rfactor(C, ki)

            b, y, x = s[C].op.axis
            yo, yi = cfg["tile_y"].apply(s, C, y)
            xo, xi = cfg["tile_x"].apply(s, C, x)
            s[C].reorder(b, yo, xo, yi, xi)
            # Fuse batch with the outer tiles so parallelism scales with
            # batch * number_of_tiles rather than batch alone.
            bxyo = s[C].fuse(b, yo, xo)
            s[C].parallel(bxyo)
            # NOTE(review): the lines below restore unchanged context hidden
            # between diff hunks; reconstructed from the upstream schedule.
            s[C].fuse(yi, xi)

            s[CC].compute_at(s[C], bxyo)
            _, _, y, x = s[CC].op.axis
            s[CC].fuse(y, x)
            s[CC].vectorize(s[CC].op.axis[-1])

    traverse_inline(s, outs[0].op, _callback)
    return s
125+
126+
def _default_batch_matmul_nopack_config(cfg, M, N, K):
    """Fill a fallback split configuration for the nopack batch_matmul.

    Parameters
    ----------
    cfg : ConfigSpace
        Autotvm config to populate with "tile_k"/"tile_x"/"tile_y".
    M : int
        Number of output rows.
    N : int
        Number of output columns.
    K : int
        Reduction length.
    """
    # Use the largest power-of-two factor (capped) for every axis so the
    # split factors always multiply back to the full extent; a fixed inner
    # factor of 16 would yield an invalid split when K % 16 != 0.
    k_bn = get_max_power2_factor(K, 16)
    cfg["tile_k"] = SplitEntity([K // k_bn, k_bn])
    x_bn = get_max_power2_factor(N, 8)
    cfg["tile_x"] = SplitEntity([N // x_bn, x_bn])
    y_bn = get_max_power2_factor(M, 8)
    cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])