@@ -50,20 +50,23 @@ class TensorIntrin(NodeBase):
5050 decl_tensor_intrin: Construct a TensorIntrin
5151 """
5252 def __call__ (self , * args , ** kwargs ):
53- tensors = [x .tensor for x in args ]
54- regions = [_get_region (x ) for x in args ]
53+ tensors = [x .tensor for x in args if isinstance (x , _tensor .TensorSlice )]
54+ scalar_inputs = [x for x in args if not isinstance (x , _tensor .TensorSlice )]
55+ regions = [_get_region (x ) for x in args if isinstance (x , _tensor .TensorSlice )]
5556 reduce_axis = []
5657 if "reduce_axis" in kwargs :
5758 reduce_axis = kwargs ["reduce_axis" ]
5859 if not isinstance (reduce_axis , (list , tuple )):
5960 reduce_axis = [reduce_axis ]
6061 reduce_axis = _api .convert (reduce_axis )
61- return _api_internal ._TensorIntrinCall (self , tensors , regions , reduce_axis )
62+ if scalar_inputs :
63+ scalar_inputs = _api .convert (scalar_inputs )
64+ return _api_internal ._TensorIntrinCall (self , tensors , regions , reduce_axis , scalar_inputs )
6265
6366def decl_tensor_intrin (op ,
6467 fcompute ,
6568 name = "tensor_intrin" ,
66- binds = None ):
69+ binds = None , scalar_params = None ):
6770 """Declare a tensor intrinsic function.
6871
6972 Parameters
@@ -96,6 +99,9 @@ def decl_tensor_intrin(op,
9699 requirement of the function. By default, a new compact buffer is created
97100 for each tensor in the argument.
98101
102+ scalar_params: a list of variables used by op, whose values will be passed
103+ as scalar_inputs when the tensor intrinsic is called.
104+
99105 Returns
100106 -------
101107 intrin: TensorIntrin
@@ -122,11 +128,15 @@ def decl_tensor_intrin(op,
122128 offset_factor = cfg .offset_factor ))
123129 binds_list .append (buf )
124130
125- body = fcompute (binds_list [:len (inputs )], binds_list [len (inputs ):])
131+ if scalar_params :
132+ body = fcompute (binds_list [:len (inputs )], binds_list [len (inputs ):], scalar_params )
133+ else :
134+ body = fcompute (binds_list [:len (inputs )], binds_list [len (inputs ):])
135+ scalar_params = []
126136 if isinstance (body , (_expr .Expr , _stmt .Stmt )):
127137 body = [body ]
128138 body = [_make .Evaluate (x ) if isinstance (x , _expr .Expr ) else x for x in body ]
129139 if len (body ) < 3 :
130140 body += [None ] * (3 - len (body ))
131141 return _api_internal ._TensorIntrin (
132- name , op , inputs , binds_list , * body )
142+ name , op , inputs , binds_list , scalar_params , * body )
0 commit comments