Skip to content

Commit 1567dae

Browse files
Author Codrut-Grigore IrimieCodrut-Grigore Irimie
authored and committed
[CMSIS-NN] Support for Softmax Int16 operator
1 parent cd45513 commit 1567dae

6 files changed

Lines changed: 343 additions & 28 deletions

File tree

python/tvm/relay/op/contrib/cmsisnn.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,17 @@ def check_qnn_softmax(pattern):
8686
zero_point = pattern.args[2].data.numpy().item(0)
8787

8888
# check for dtypes of quantize and dequantize
89-
return (
90-
(scale == 1.0 / 256 and zero_point == -128)
89+
if ((scale == 1.0 / 256 and zero_point == -128)
9190
and pattern.attrs.out_dtype == "int8"
92-
and dequantize_call.args[0].checked_type.dtype == "int8"
93-
)
91+
and dequantize_call.args[0].checked_type.dtype == "int8"):
92+
return True
93+
94+
if ((scale == 1.0 / 32768 and zero_point == 0)
95+
and pattern.attrs.out_dtype == "int16"
96+
and dequantize_call.args[0].checked_type.dtype == "int16"):
97+
return True
98+
99+
return False
94100

95101
def qnn_conv2d_pattern(with_pad):
96102
"""Create pattern for qnn.conv2D with optional pad and/or optional fused relu."""
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include "compute_luts.h"

#include <algorithm>
#include <cmath>
#include <limits>
24+
25+
namespace tvm {
26+
namespace relay {
27+
namespace contrib {
28+
namespace cmsisnn {
29+
30+
void CalculateLUTInt16(int key_zero_point, float key_scale, int value_zero_point,
                       float value_scale, float (*func)(float), const int steps, int16_t* lut) {
  // Representable range of an int16 table entry, used for clamping below.
  const float value_min = static_cast<float>(std::numeric_limits<int16_t>::min());
  const float value_max = static_cast<float>(std::numeric_limits<int16_t>::max());

  // Dequantized (real-valued) ranges spanned by the table's keys and values.
  const float key_min_deq = key_scale * (std::numeric_limits<int16_t>::min() - key_zero_point);
  const float key_max_deq = key_scale * (std::numeric_limits<int16_t>::max() - key_zero_point);
  const float value_min_deq =
      value_scale * (std::numeric_limits<int16_t>::min() - value_zero_point);
  const float value_max_deq =
      value_scale * (std::numeric_limits<int16_t>::max() - value_zero_point);

  // Distance between two consecutive table keys in the dequantized domain.
  const float step_size_deq = (key_max_deq - key_min_deq) / (steps - 1);
  const float half_step_size_deq = step_size_deq / 2;

  // Factor mapping a dequantized value back onto the full int16 grid.
  const float value_inv_quantizing =
      (std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min() + 1) /
      (value_max_deq - value_min_deq);

  for (int i = 0; i < steps - 1; i++) {
    float value_deq = func(key_min_deq + i * step_size_deq);
    float mid_value_deq = func(key_min_deq + i * step_size_deq + half_step_size_deq);
    float next_value_deq = func(key_min_deq + (i + 1) * step_size_deq);

    float value = std::round(value_deq * value_inv_quantizing);
    float mid_value = std::round(mid_value_deq * value_inv_quantizing);
    float next_value = std::round(next_value_deq * value_inv_quantizing);
    // Value the consumer would obtain by linearly interpolating between the
    // two adjacent table entries at the segment midpoint.
    float mid_iterp_value = std::round((value + next_value) / 2);

    float mid_err = mid_iterp_value - mid_value;
    // Bias the stored entry by half the midpoint error so the interpolation
    // error is spread across the segment instead of peaking at its middle.
    float bias = std::round(mid_err / 2);

    lut[i] = static_cast<int16_t>(std::max(std::min(value - bias, value_max), value_min));
  }

  // Last entry: evaluate func at the last *key* (key_max_deq — the original
  // mistakenly used value_max_deq, which belongs to the output domain), and
  // round before clamping, consistent with the entries above.
  lut[steps - 1] = static_cast<int16_t>(
      std::max(std::min(std::round(func(key_max_deq) * value_inv_quantizing), value_max),
               value_min));
}
66+
67+
} // namespace cmsisnn
68+
} // namespace contrib
69+
} // namespace relay
70+
} // namespace tvm
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
 * \file src/relay/backend/contrib/cmsisnn/compute_luts.h
 * \brief CMSIS-NN LUTs calculation functions
 */

#ifndef TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPUTE_LUTS_H_
#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPUTE_LUTS_H_

#include <cstdint>

namespace tvm {
namespace relay {
namespace contrib {
namespace cmsisnn {

/*!
 * \brief Populates an int16 LUT based on the quantization parameters of its
 * keys and values and the respective transformation function.
 *
 * \param key_zero_point - zero point of the table's keys
 * \param key_scale - scale of the table's keys
 * \param value_zero_point - zero point of the table's values
 * \param value_scale - scale of the table's values
 * \param func - function pointer of the transformation performed by the LUT
 * \param steps - number of total values inside the table
 * \param lut - int16_t array storing the values of the LUT
 */
void CalculateLUTInt16(int key_zero_point, float key_scale, int value_zero_point,
                       float value_scale, float (*func)(float), const int steps, int16_t* lut);

}  // namespace cmsisnn
}  // namespace contrib
}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPUTE_LUTS_H_
54+

src/relay/backend/contrib/cmsisnn/relay_to_tir.cc

Lines changed: 109 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "buffer_size.h"
3232
#include "compiler_attrs.h"
3333
#include "convolutions.h"
34+
#include "compute_luts.h"
3435

3536
namespace tvm {
3637
namespace relay {
@@ -89,26 +90,42 @@ class RelayToTIRVisitor : public MixedModeMutator {
8990
private:
9091
inline IntImm ToArg(int32_t value) { return IntImm(DataType::Int(32), value); }
9192

93+
//struct used to allocated const NDArray
94+
struct user_const {
95+
tir::Var buffer_var;
96+
int num_bits;
97+
Array<PrimExpr> extents;
98+
tvm::runtime::NDArray ndarray;
99+
};
100+
92101
void CreatePrimFuncForExtern(const GlobalVar& global_var, Array<tir::Var> func_signature,
93102
const Map<tir::Var, tir::Buffer>& buffer_map,
94103
tvm::Array<PrimExpr> call_extern_args,
95104
PrimExpr context_buffer_var = PrimExpr(),
96-
int context_buffer_size = 0, int num_bits = 8) {
105+
int context_buffer_size = 0,
106+
int num_bits = 8,
107+
std::vector<user_const> context_const_buffer_vars = {}) {
97108
Map<String, ObjectRef> dict_attrs;
98109
dict_attrs.Set(tvm::attr::kGlobalSymbol, global_var->name_hint);
99110
dict_attrs.Set(tvm::attr::kTarget, target_);
100111
dict_attrs.Set("tir.noalias", Bool(true));
101112

102113
tir::Stmt body = tir::Evaluate(
103114
tvm::tir::Call(DataType::Int(num_bits), tir::builtin::call_extern(), call_extern_args));
104-
115+
105116
if (context_buffer_size) {
106117
body = tir::Allocate(Downcast<tir::Var>(context_buffer_var), DataType::Int(num_bits),
107118
{context_buffer_size}, tir::const_true(), body);
108119
}
109-
120+
121+
for (int i = 0; i < int(context_const_buffer_vars.size()); i++){
122+
body = tir::AllocateConst(Downcast<tir::Var>(context_const_buffer_vars[i].buffer_var), DataType::Int(context_const_buffer_vars[i].num_bits),
123+
context_const_buffer_vars[i].extents, context_const_buffer_vars[i].ndarray, body);
124+
}
125+
110126
tir::PrimFunc replacement_func(func_signature, body, VoidType(), buffer_map,
111127
DictAttrs(dict_attrs));
128+
112129
ir_module_->Add(global_var, replacement_func);
113130
}
114131

@@ -505,6 +522,10 @@ class RelayToTIRVisitor : public MixedModeMutator {
505522
const CallNode* softmax_call = quantize_call->args[0].as<CallNode>();
506523
const CallNode* dequant_call = softmax_call->args[0].as<CallNode>();
507524
const float quant_scale = GetScalarFromConstant<float>(dequant_call->args[1]);
525+
const auto bit_width = quantize_call->type_as<TensorTypeNode>()->dtype.bits();
526+
LOG(INFO) << PrettyPrint(quantize_call->args[0]);
527+
LOG(INFO) << PrettyPrint(softmax_call->args[0]);
528+
LOG(INFO) << PrettyPrint(dequant_call->args[0]);
508529

509530
// assuming layout as NHWC
510531
auto shape = quantize_call->type_as<TensorTypeNode>()->shape;
@@ -517,36 +538,103 @@ class RelayToTIRVisitor : public MixedModeMutator {
517538

518539
// calculate multiplier and shift for CMSIS-NN softmax API
519540
// Note: TensorFlow Lite Micro assumptions
520-
// Output zero point and scale are fixed to -128 and 1 / 256
541+
// Output zero point and scale are fixed to -128 and 1 / 256 in the case of an int8 operator, or to 0 and 1 / 32768 in the case of an int16 operator
521542
// kScaledDiffIntegerBits, kInputBits, kBeta are described on the following github page
522-
// https://github.com/tensorflow/tflite-micro/blob/d97cd0908d8cf5021e9d86f05a49888bee28c2a4/tensorflow/lite/micro/kernels/softmax_common.cc#L47
523-
double beta_multiplier = (kBeta * quant_scale * (1 << (31 - kInputBits)));
524-
beta_multiplier = std::min<double>(beta_multiplier, (1ll << 31) - 1.0);
525-
auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(beta_multiplier);
526-
int32_t mult = std::get<0>(mult_shift_pair);
527-
int32_t shift = std::get<1>(mult_shift_pair);
528-
int32_t diff_min = (1 << kScaledDiffIntegerBits) - 1;
529-
diff_min <<= (31 - kScaledDiffIntegerBits);
530-
diff_min >>= shift;
531-
diff_min *= -1;
532-
543+
// https://github.com/tensorflow/tflite-micro/blob/d97cd0908d8cf5021e9d86f05a49888bee28c2a4/tensorflow/lite/exp_zero_pointmicro/kernels/softmax_common.cc#L47
544+
545+
int32_t mult;
546+
int32_t shift;
547+
int32_t diff_min = 0;
548+
549+
std::vector<user_const> softmax_params(2);
550+
Device dev{DLDeviceType::kDLCPU, 0};
551+
552+
if (bit_width == 8){
553+
double beta_multiplier = (kBeta * quant_scale * (1 << (31 - kInputBits)));
554+
beta_multiplier = std::min<double>(beta_multiplier, (1ll << 31) - 1.0);
555+
auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(beta_multiplier);
556+
mult = std::get<0>(mult_shift_pair);
557+
shift = std::get<1>(mult_shift_pair);
558+
diff_min = (1 << kScaledDiffIntegerBits) - 1;
559+
diff_min <<= (31 - kScaledDiffIntegerBits);
560+
diff_min >>= shift;
561+
diff_min *= -1;
562+
}
563+
else { //bit_width == 16
564+
double scale_beta_rescale = quant_scale * kBeta / (10.0 / 65535.0);
565+
auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(scale_beta_rescale);
566+
mult = std::get<0>(mult_shift_pair);
567+
shift = std::get<1>(mult_shift_pair);
568+
569+
int lut_entries = 513;
570+
int16_t softmax_s16_exp_lut[lut_entries];
571+
int16_t softmax_s16_one_by_one_lut[lut_entries];
572+
573+
const int range_int16 = std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min();
574+
int exp_zero_point = std::numeric_limits<int16_t>::max();
575+
float exp_scale = 10.0f / range_int16;
576+
577+
int one_by_one_zero_point = std::numeric_limits<int16_t>::min();
578+
float one_by_one_scale = 1.0f / range_int16;
579+
580+
int lut_value_zero_point = 0;
581+
float lut_value_scale = 2.0f / range_int16;
582+
583+
CalculateLUTInt16(exp_zero_point, exp_scale, lut_value_zero_point, lut_value_scale,
584+
[](float key){ return std::exp(key); }, lut_entries, softmax_s16_exp_lut);
585+
CalculateLUTInt16(one_by_one_zero_point, one_by_one_scale, lut_value_zero_point, lut_value_scale,
586+
[](float key){ return 1.0f / (1.0f + key); }, lut_entries, softmax_s16_one_by_one_lut);
587+
588+
//first LUT
589+
softmax_params[0].buffer_var = tir::Var("exp_lut", PointerType(PrimType(DataType::Int(bit_width)), "global.workspace"));
590+
softmax_params[0].ndarray = runtime::NDArray::Empty({lut_entries}, DataType::Int(bit_width), dev);
591+
softmax_params[0].ndarray.CopyFromBytes(softmax_s16_exp_lut, sizeof(int16_t)*lut_entries);
592+
softmax_params[0].extents = {lut_entries};
593+
softmax_params[0].num_bits = 16;
594+
595+
//second LUT
596+
softmax_params[1].buffer_var = tir::Var("one_by_one_lut", PointerType(PrimType(DataType::Int(bit_width)), "global.workspace"));
597+
softmax_params[1].ndarray = runtime::NDArray::Empty({lut_entries}, DataType::Int(bit_width), dev);
598+
softmax_params[1].ndarray.CopyFromBytes(softmax_s16_one_by_one_lut, sizeof(int16_t)*lut_entries);
599+
softmax_params[1].extents = {lut_entries};
600+
softmax_params[1].num_bits = 16;
601+
}
602+
533603
BufferCreator buffer_creator;
534-
tir::Var in_var = buffer_creator.CreateBufferVar("input", DataType::Handle(8));
535-
tir::Var out_var = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
604+
tir::Var in_var = buffer_creator.CreateBufferVar("input", DataType::Handle(bit_width));
605+
tir::Var out_var = buffer_creator.CreateBufferVar("output", DataType::Handle(bit_width));
536606

537-
tvm::Array<PrimExpr> args = {
538-
tir::StringImm("arm_softmax_s8"),
607+
if (bit_width == 8) {
608+
tvm::Array<PrimExpr> args = {
609+
tir::StringImm("arm_softmax_s" + std::to_string(bit_width)),
539610
in_var,
540611
ToArg(num_rows),
541612
ToArg(row_size),
542613
ToArg(mult),
543614
ToArg(shift),
544615
ToArg(diff_min),
545616
out_var,
546-
};
617+
};
547618

548-
CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(),
619+
CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(),
549620
buffer_creator.GetBufferMap(), args);
621+
} else { //bit_width == 16
622+
tvm::Array<PrimExpr> args = {
623+
tir::StringImm("arm_softmax_s" + std::to_string(bit_width)),
624+
in_var,
625+
ToArg(num_rows),
626+
ToArg(row_size),
627+
ToArg(mult),
628+
ToArg(shift),
629+
softmax_params[0].buffer_var,
630+
softmax_params[1].buffer_var,
631+
out_var,
632+
};
633+
634+
CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(),
635+
buffer_creator.GetBufferMap(), args, PrimExpr(),
636+
0, 8, softmax_params);
637+
}
550638
}
551639

552640
struct BinaryElementwiseClipPattern {

0 commit comments

Comments
 (0)