Skip to content

Commit 1d475cb

Browse files
ekaldaRishabh Jain
authored andcommitted
[microNPU] Add unary elementwise operator infrastructure with ABS (apache#9530)
* [microNPU] Add unary elementwise operator infrastructure with ABS * Added unary elementwise ABS legalization support and tests * Added unary_elementwise Relay to TIR lowering and tests * Added TIR to Vela translation and tests * Added codegen tests Co-authored-by: Rishabh Jain <rishabh.jain2@arm.com>
1 parent 37e4d75 commit 1d475cb

20 files changed

Lines changed: 1236 additions & 27 deletions

python/tvm/relay/backend/contrib/ethosu/legalize.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,96 @@ def __call__(self, *args, **kwargs):
741741
pass
742742

743743

744+
class UnaryElementwiseRewriter(DFPatternCallback):
745+
"""
746+
Convert ethosu unary elementwise composite function to
747+
ethosu_unary_elementwise operators
748+
"""
749+
750+
def __init__(self, params_class: Type, pattern: CallPattern):
751+
super().__init__(require_type=True)
752+
self.params_class = params_class
753+
self.pattern = pattern
754+
755+
def callback(
756+
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
757+
) -> tvm.relay.Expr:
758+
params = self.params_class(post.op.body)
759+
params.ifm.tensor = post.args[0]
760+
761+
if str(params.ofm.layout) != "NHWC":
762+
raise UnsupportedLayout(str(params.ofm.layout))
763+
764+
activation_map = {"clip": "CLIP"}
765+
if params.activation:
766+
activation = activation_map[params.activation.op.name]
767+
clip_min = int(params.activation.attrs.a_min)
768+
clip_max = int(params.activation.attrs.a_max)
769+
else:
770+
activation = "NONE"
771+
clip_min = 0
772+
clip_max = 0
773+
774+
# We don't yet support activation functions that use LUT.
775+
lut = relay.const([], dtype="int8")
776+
777+
unary_input_shape = params.ifm.shape
778+
# If the input tensor is not 4D, enter reshapes before and after the unary operator
779+
if len(params.ifm.shape) == 4:
780+
unary_input = params.ifm.tensor
781+
else:
782+
pad_size = 4 - len(unary_input_shape)
783+
unary_input_shape = ([1] * pad_size) + unary_input_shape
784+
unary_input = relay.op.reshape(params.ifm.tensor, newshape=unary_input_shape)
785+
786+
ethosu_unary_elementwise = ethosu_ops.ethosu_unary_elementwise(
787+
ifm=unary_input,
788+
lut=lut,
789+
operator_type=params.operator_type,
790+
ifm_scale=float(params.ifm.q_params.scale_f32),
791+
ifm_zero_point=int(params.ifm.q_params.zero_point),
792+
ofm_scale=float(params.ofm.q_params.scale_f32),
793+
ofm_zero_point=int(params.ofm.q_params.zero_point),
794+
ofm_channels=unary_input_shape[3],
795+
activation=activation,
796+
clip_min=clip_min,
797+
clip_max=clip_max,
798+
ifm_layout=str(params.ifm.layout),
799+
ofm_layout=str(params.ofm.layout),
800+
)
801+
if len(params.ifm.shape) == 4:
802+
op = ethosu_unary_elementwise
803+
else:
804+
op = relay.op.reshape(ethosu_unary_elementwise, newshape=params.ifm.shape)
805+
return op
806+
807+
808+
class AbsRewriter(UnaryElementwiseRewriter):
809+
def __init__(self):
810+
super().__init__(
811+
params_class=ethosu_patterns.AbsParams,
812+
pattern=(wildcard().has_attr({"Composite": ethosu_patterns.AbsParams.composite_name}))(
813+
wildcard()
814+
),
815+
)
816+
817+
818+
@ir.transform.module_pass(opt_level=1)
819+
class LegalizeAbs:
820+
"""This is the pass that wraps the AbsRewriter"""
821+
822+
def transform_module(
823+
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
824+
) -> tvm.ir.IRModule:
825+
for global_var, func in mod.functions.items():
826+
func = rewrite(AbsRewriter(), func)
827+
mod.update_func(global_var, func)
828+
return mod
829+
830+
def __call__(self, *args, **kwargs):
831+
pass
832+
833+
744834
@ir.transform.module_pass(opt_level=1)
745835
class LegalizeEthosU:
746836
"""This is the pass to call graph-rewrites to perform graph transformation
@@ -765,6 +855,7 @@ def transform_module(
765855
mod = LegalizeMin()(mod)
766856
mod = LegalizeMax()(mod)
767857
mod = LegalizeShl()(mod)
858+
mod = LegalizeAbs()(mod)
768859
mod = LegalizeReshape()(mod)
769860
mod = LegalizeStridedSlice()(mod)
770861
mod = LegalizeNoOps()(mod)

python/tvm/relay/backend/contrib/ethosu/op/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@
2121
from .pooling import ethosu_pooling
2222
from .binary_elementwise import ethosu_binary_elementwise
2323
from .identity import ethosu_identity
24+
from .unary_elementwise import ethosu_unary_elementwise
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=unused-argument
18+
"""Relay operator for unary elementwise operations for Arm(R) Ethos(TM)-U NPU"""
19+
from typing import Optional
20+
import tvm
21+
from tvm.relay.op import _make
22+
from tvm.topi.generic import schedule_injective
23+
from tvm.relay.op.op import OpStrategy
24+
from tvm.relay.op import strategy as _strategy
25+
26+
from ..te import unary_elementwise_compute
27+
28+
29+
def _extract_ethosu_unary_elementwise_params(attrs, args):
30+
"""Get the parameters necessary to construct a ethosu_unary_elementwise compute TE
31+
from a ethosu_unary_elementwise Relay call."""
32+
ifm = args[0]
33+
lut = args[1]
34+
operator_type = attrs.operator_type
35+
ifm_scale = attrs.ifm_scale
36+
ifm_zero_point = attrs.ifm_zero_point
37+
ofm_scale = attrs.ofm_scale
38+
ofm_zero_point = attrs.ofm_zero_point
39+
ofm_channels = attrs.ofm_channels
40+
activation = attrs.activation
41+
clip_min = attrs.clip_min
42+
clip_max = attrs.clip_max
43+
rounding_mode = attrs.rounding_mode
44+
ifm_layout = attrs.ifm_layout
45+
ofm_layout = attrs.ofm_layout
46+
47+
return (
48+
ifm,
49+
lut,
50+
operator_type,
51+
ifm_scale,
52+
ifm_zero_point,
53+
ofm_scale,
54+
ofm_zero_point,
55+
ofm_channels,
56+
activation,
57+
clip_min,
58+
clip_max,
59+
rounding_mode,
60+
ifm_layout,
61+
ofm_layout,
62+
)
63+
64+
65+
@tvm.ir.register_op_attr("contrib.ethosu.unary_elementwise", "FTVMCompute")
66+
def create_ethosu_unary_elementwise_compute(attrs, args, out_type):
67+
"""Create an ethosu_unary_elementwise compute op."""
68+
params = _extract_ethosu_unary_elementwise_params(attrs, args)
69+
op = unary_elementwise_compute(*params)
70+
return [op]
71+
72+
73+
@tvm.ir.register_op_attr("contrib.ethosu.unary_elementwise", "FTVMStrategy")
74+
def unary_elementwise_strategy_ethosu(attrs, inputs, out_type, target):
75+
strategy = OpStrategy()
76+
strategy.add_implementation(
77+
create_ethosu_unary_elementwise_compute,
78+
_strategy.wrap_topi_schedule(schedule_injective),
79+
name="ethosu_unary_elementwise",
80+
)
81+
return strategy
82+
83+
84+
def ethosu_unary_elementwise(
85+
ifm: tvm.relay.Expr,
86+
lut: tvm.relay.Expr,
87+
operator_type: str,
88+
ifm_scale: float,
89+
ifm_zero_point: int,
90+
ofm_scale: float,
91+
ofm_zero_point: int,
92+
ofm_channels: int,
93+
activation: Optional[str] = "NONE",
94+
clip_min: Optional[int] = 0,
95+
clip_max: Optional[int] = 0,
96+
rounding_mode: Optional[str] = "TFL",
97+
ifm_layout: Optional[str] = "NHWC",
98+
ofm_layout: Optional[str] = "NHWC",
99+
) -> tvm.relay.Call:
100+
"""This is a quantized unary elementwise operation as supported by the
101+
NPU. It accepts either NHWC or NHCWB16 format for the input data.
102+
103+
Parameters
104+
----------
105+
ifm : tvm.relay.Expr
106+
The Input Feature Map tensor (IFM).
107+
lut : tvm.relay.Expr
108+
The look-up table values to use if activation = "LUT".
109+
operator_type: str
110+
The type of the unary elementwise operator.
111+
"ABS"
112+
ifm_scale : float
113+
The quantization scale for the Input Feature Map tensor.
114+
ifm_zero_point : int
115+
The quantization zero point for the Input Feature Map tensor.
116+
ofm_scale : float
117+
The quantization scale for the Output Feature Map tensor.
118+
ofm_zero_point : int
119+
The quantization zero point for the Output Feature Map tensor.
120+
ofm_channels : int
121+
The number of OFM channels.
122+
activation : str, optional
123+
The activation function to use.
124+
"NONE" - no activation function.
125+
"CLIP" - clip the output between clip_min and clip_max.
126+
"TANH" - tanh activation function.
127+
"SIGMOID" - sigmoid activation function.
128+
"LUT" - use a look-up table to perform the activation function.
129+
clip_min : int, optional
130+
The minimum clipping value if activation = "CLIP".
131+
clip_max : int, optional
132+
The maximum clipping value if activation = "CLIP".
133+
rounding_mode : str, optional
134+
The rounding mode to apply to the Output Feature Map tensor.
135+
"TFL" - Tensorflow Lite rounding scheme.
136+
"TRUNCATE" - Truncate towards zero.
137+
"NATURAL" - Round to nearest value, with x.5 rounded up towards +infinity.
138+
ifm_layout : str, optional
139+
The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
140+
ofm_layout : str, optional
141+
The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16".
142+
143+
Returns
144+
-------
145+
out : tvm.relay.Call
146+
A call to the ethosu_binary_elementwise op.
147+
"""
148+
return _make.ethosu_unary_elementwise(
149+
ifm,
150+
lut,
151+
operator_type,
152+
ifm_scale,
153+
ifm_zero_point,
154+
ofm_scale,
155+
ofm_zero_point,
156+
ofm_channels,
157+
activation,
158+
clip_min,
159+
clip_max,
160+
rounding_mode,
161+
ifm_layout,
162+
ofm_layout,
163+
)

python/tvm/relay/backend/contrib/ethosu/te/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@
2121
from .pooling import *
2222
from .binary_elementwise import *
2323
from .identity import *
24+
from .unary_elementwise import *

0 commit comments

Comments
 (0)