|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# pylint: disable=unused-argument |
| 18 | +"""Relay operator for unary elementwise operations for Arm(R) Ethos(TM)-U NPU""" |
| 19 | +from typing import Optional |
| 20 | +import tvm |
| 21 | +from tvm.relay.op import _make |
| 22 | +from tvm.topi.generic import schedule_injective |
| 23 | +from tvm.relay.op.op import OpStrategy |
| 24 | +from tvm.relay.op import strategy as _strategy |
| 25 | + |
| 26 | +from ..te import unary_elementwise_compute |
| 27 | + |
| 28 | + |
| 29 | +def _extract_ethosu_unary_elementwise_params(attrs, args): |
| 30 | + """Get the parameters necessary to construct a ethosu_unary_elementwise compute TE |
| 31 | + from a ethosu_unary_elementwise Relay call.""" |
| 32 | + ifm = args[0] |
| 33 | + lut = args[1] |
| 34 | + operator_type = attrs.operator_type |
| 35 | + ifm_scale = attrs.ifm_scale |
| 36 | + ifm_zero_point = attrs.ifm_zero_point |
| 37 | + ofm_scale = attrs.ofm_scale |
| 38 | + ofm_zero_point = attrs.ofm_zero_point |
| 39 | + ofm_channels = attrs.ofm_channels |
| 40 | + activation = attrs.activation |
| 41 | + clip_min = attrs.clip_min |
| 42 | + clip_max = attrs.clip_max |
| 43 | + rounding_mode = attrs.rounding_mode |
| 44 | + ifm_layout = attrs.ifm_layout |
| 45 | + ofm_layout = attrs.ofm_layout |
| 46 | + |
| 47 | + return ( |
| 48 | + ifm, |
| 49 | + lut, |
| 50 | + operator_type, |
| 51 | + ifm_scale, |
| 52 | + ifm_zero_point, |
| 53 | + ofm_scale, |
| 54 | + ofm_zero_point, |
| 55 | + ofm_channels, |
| 56 | + activation, |
| 57 | + clip_min, |
| 58 | + clip_max, |
| 59 | + rounding_mode, |
| 60 | + ifm_layout, |
| 61 | + ofm_layout, |
| 62 | + ) |
| 63 | + |
| 64 | + |
| 65 | +@tvm.ir.register_op_attr("contrib.ethosu.unary_elementwise", "FTVMCompute") |
| 66 | +def create_ethosu_unary_elementwise_compute(attrs, args, out_type): |
| 67 | + """Create an ethosu_unary_elementwise compute op.""" |
| 68 | + params = _extract_ethosu_unary_elementwise_params(attrs, args) |
| 69 | + op = unary_elementwise_compute(*params) |
| 70 | + return [op] |
| 71 | + |
| 72 | + |
| 73 | +@tvm.ir.register_op_attr("contrib.ethosu.unary_elementwise", "FTVMStrategy") |
| 74 | +def unary_elementwise_strategy_ethosu(attrs, inputs, out_type, target): |
| 75 | + strategy = OpStrategy() |
| 76 | + strategy.add_implementation( |
| 77 | + create_ethosu_unary_elementwise_compute, |
| 78 | + _strategy.wrap_topi_schedule(schedule_injective), |
| 79 | + name="ethosu_unary_elementwise", |
| 80 | + ) |
| 81 | + return strategy |
| 82 | + |
| 83 | + |
| 84 | +def ethosu_unary_elementwise( |
| 85 | + ifm: tvm.relay.Expr, |
| 86 | + lut: tvm.relay.Expr, |
| 87 | + operator_type: str, |
| 88 | + ifm_scale: float, |
| 89 | + ifm_zero_point: int, |
| 90 | + ofm_scale: float, |
| 91 | + ofm_zero_point: int, |
| 92 | + ofm_channels: int, |
| 93 | + activation: Optional[str] = "NONE", |
| 94 | + clip_min: Optional[int] = 0, |
| 95 | + clip_max: Optional[int] = 0, |
| 96 | + rounding_mode: Optional[str] = "TFL", |
| 97 | + ifm_layout: Optional[str] = "NHWC", |
| 98 | + ofm_layout: Optional[str] = "NHWC", |
| 99 | +) -> tvm.relay.Call: |
| 100 | + """This is a quantized unary elementwise operation as supported by the |
| 101 | + NPU. It accepts either NHWC or NHCWB16 format for the input data. |
| 102 | +
|
| 103 | + Parameters |
| 104 | + ---------- |
| 105 | + ifm : tvm.relay.Expr |
| 106 | + The Input Feature Map tensor (IFM). |
| 107 | + lut : tvm.relay.Expr |
| 108 | + The look-up table values to use if activation = "LUT". |
| 109 | + operator_type: str |
| 110 | + The type of the unary elementwise operator. |
| 111 | + "ABS" |
| 112 | + ifm_scale : float |
| 113 | + The quantization scale for the Input Feature Map tensor. |
| 114 | + ifm_zero_point : int |
| 115 | + The quantization zero point for the Input Feature Map tensor. |
| 116 | + ofm_scale : float |
| 117 | + The quantization scale for the Output Feature Map tensor. |
| 118 | + ofm_zero_point : int |
| 119 | + The quantization zero point for the Output Feature Map tensor. |
| 120 | + ofm_channels : int |
| 121 | + The number of OFM channels. |
| 122 | + activation : str, optional |
| 123 | + The activation function to use. |
| 124 | + "NONE" - no activation function. |
| 125 | + "CLIP" - clip the output between clip_min and clip_max. |
| 126 | + "TANH" - tanh activation function. |
| 127 | + "SIGMOID" - sigmoid activation function. |
| 128 | + "LUT" - use a look-up table to perform the activation function. |
| 129 | + clip_min : int, optional |
| 130 | + The minimum clipping value if activation = "CLIP". |
| 131 | + clip_max : int, optional |
| 132 | + The maximum clipping value if activation = "CLIP". |
| 133 | + rounding_mode : str, optional |
| 134 | + The rounding mode to apply to the Output Feature Map tensor. |
| 135 | + "TFL" - Tensorflow Lite rounding scheme. |
| 136 | + "TRUNCATE" - Truncate towards zero. |
| 137 | + "NATURAL" - Round to nearest value, with x.5 rounded up towards +infinity. |
| 138 | + ifm_layout : str, optional |
| 139 | + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". |
| 140 | + ofm_layout : str, optional |
| 141 | + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". |
| 142 | +
|
| 143 | + Returns |
| 144 | + ------- |
| 145 | + out : tvm.relay.Call |
| 146 | + A call to the ethosu_binary_elementwise op. |
| 147 | + """ |
| 148 | + return _make.ethosu_unary_elementwise( |
| 149 | + ifm, |
| 150 | + lut, |
| 151 | + operator_type, |
| 152 | + ifm_scale, |
| 153 | + ifm_zero_point, |
| 154 | + ofm_scale, |
| 155 | + ofm_zero_point, |
| 156 | + ofm_channels, |
| 157 | + activation, |
| 158 | + clip_min, |
| 159 | + clip_max, |
| 160 | + rounding_mode, |
| 161 | + ifm_layout, |
| 162 | + ofm_layout, |
| 163 | + ) |
0 commit comments