3131#include " buffer_size.h"
3232#include " compiler_attrs.h"
3333#include " convolutions.h"
34+ #include " compute_luts.h"
3435
3536namespace tvm {
3637namespace relay {
@@ -89,26 +90,42 @@ class RelayToTIRVisitor : public MixedModeMutator {
8990 private:
9091 inline IntImm ToArg (int32_t value) { return IntImm (DataType::Int (32 ), value); }
9192
93+ // struct used to allocated const NDArray
94+ struct user_const {
95+ tir::Var buffer_var;
96+ int num_bits;
97+ Array<PrimExpr> extents;
98+ tvm::runtime::NDArray ndarray;
99+ };
100+
92101 void CreatePrimFuncForExtern (const GlobalVar& global_var, Array<tir::Var> func_signature,
93102 const Map<tir::Var, tir::Buffer>& buffer_map,
94103 tvm::Array<PrimExpr> call_extern_args,
95104 PrimExpr context_buffer_var = PrimExpr(),
96- int context_buffer_size = 0, int num_bits = 8) {
105+ int context_buffer_size = 0,
106+ int num_bits = 8,
107+ std::vector<user_const> context_const_buffer_vars = {}) {
97108 Map<String, ObjectRef> dict_attrs;
98109 dict_attrs.Set (tvm::attr::kGlobalSymbol , global_var->name_hint );
99110 dict_attrs.Set (tvm::attr::kTarget , target_);
100111 dict_attrs.Set (" tir.noalias" , Bool (true ));
101112
102113 tir::Stmt body = tir::Evaluate (
103114 tvm::tir::Call (DataType::Int (num_bits), tir::builtin::call_extern (), call_extern_args));
104-
115+
105116 if (context_buffer_size) {
106117 body = tir::Allocate (Downcast<tir::Var>(context_buffer_var), DataType::Int (num_bits),
107118 {context_buffer_size}, tir::const_true (), body);
108119 }
109-
120+
121+ for (int i = 0 ; i < int (context_const_buffer_vars.size ()); i++){
122+ body = tir::AllocateConst (Downcast<tir::Var>(context_const_buffer_vars[i].buffer_var ), DataType::Int (context_const_buffer_vars[i].num_bits ),
123+ context_const_buffer_vars[i].extents , context_const_buffer_vars[i].ndarray , body);
124+ }
125+
110126 tir::PrimFunc replacement_func (func_signature, body, VoidType (), buffer_map,
111127 DictAttrs (dict_attrs));
128+
112129 ir_module_->Add (global_var, replacement_func);
113130 }
114131
@@ -505,6 +522,10 @@ class RelayToTIRVisitor : public MixedModeMutator {
505522 const CallNode* softmax_call = quantize_call->args [0 ].as <CallNode>();
506523 const CallNode* dequant_call = softmax_call->args [0 ].as <CallNode>();
507524 const float quant_scale = GetScalarFromConstant<float >(dequant_call->args [1 ]);
525+ const auto bit_width = quantize_call->type_as <TensorTypeNode>()->dtype .bits ();
526+ LOG (INFO) << PrettyPrint (quantize_call->args [0 ]);
527+ LOG (INFO) << PrettyPrint (softmax_call->args [0 ]);
528+ LOG (INFO) << PrettyPrint (dequant_call->args [0 ]);
508529
509530 // assuming layout as NHWC
510531 auto shape = quantize_call->type_as <TensorTypeNode>()->shape ;
@@ -517,36 +538,103 @@ class RelayToTIRVisitor : public MixedModeMutator {
517538
518539 // calculate multiplier and shift for CMSIS-NN softmax API
519540 // Note: TensorFlow Lite Micro assumptions
520- // Output zero point and scale are fixed to -128 and 1 / 256
541+ // Output zero point and scale are fixed to -128 and 1 / 256 in the case of an int8 operator or to 0 and 1 / 32768
521542 // kScaledDiffIntegerBits, kInputBits, kBeta are described on the following github page
522- // https://github.com/tensorflow/tflite-micro/blob/d97cd0908d8cf5021e9d86f05a49888bee28c2a4/tensorflow/lite/micro/kernels/softmax_common.cc#L47
523- double beta_multiplier = (kBeta * quant_scale * (1 << (31 - kInputBits )));
524- beta_multiplier = std::min<double >(beta_multiplier, (1ll << 31 ) - 1.0 );
525- auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift (beta_multiplier);
526- int32_t mult = std::get<0 >(mult_shift_pair);
527- int32_t shift = std::get<1 >(mult_shift_pair);
528- int32_t diff_min = (1 << kScaledDiffIntegerBits ) - 1 ;
529- diff_min <<= (31 - kScaledDiffIntegerBits );
530- diff_min >>= shift;
531- diff_min *= -1 ;
532-
543+ // https://github.com/tensorflow/tflite-micro/blob/d97cd0908d8cf5021e9d86f05a49888bee28c2a4/tensorflow/lite/exp_zero_pointmicro/kernels/softmax_common.cc#L47
544+
545+ int32_t mult;
546+ int32_t shift;
547+ int32_t diff_min = 0 ;
548+
549+ std::vector<user_const> softmax_params (2 );
550+ Device dev{DLDeviceType::kDLCPU , 0 };
551+
552+ if (bit_width == 8 ){
553+ double beta_multiplier = (kBeta * quant_scale * (1 << (31 - kInputBits )));
554+ beta_multiplier = std::min<double >(beta_multiplier, (1ll << 31 ) - 1.0 );
555+ auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift (beta_multiplier);
556+ mult = std::get<0 >(mult_shift_pair);
557+ shift = std::get<1 >(mult_shift_pair);
558+ diff_min = (1 << kScaledDiffIntegerBits ) - 1 ;
559+ diff_min <<= (31 - kScaledDiffIntegerBits );
560+ diff_min >>= shift;
561+ diff_min *= -1 ;
562+ }
563+ else { // bit_width == 16
564+ double scale_beta_rescale = quant_scale * kBeta / (10.0 / 65535.0 );
565+ auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift (scale_beta_rescale);
566+ mult = std::get<0 >(mult_shift_pair);
567+ shift = std::get<1 >(mult_shift_pair);
568+
569+ int lut_entries = 513 ;
570+ int16_t softmax_s16_exp_lut[lut_entries];
571+ int16_t softmax_s16_one_by_one_lut[lut_entries];
572+
573+ const int range_int16 = std::numeric_limits<int16_t >::max () - std::numeric_limits<int16_t >::min ();
574+ int exp_zero_point = std::numeric_limits<int16_t >::max ();
575+ float exp_scale = 10 .0f / range_int16;
576+
577+ int one_by_one_zero_point = std::numeric_limits<int16_t >::min ();
578+ float one_by_one_scale = 1 .0f / range_int16;
579+
580+ int lut_value_zero_point = 0 ;
581+ float lut_value_scale = 2 .0f / range_int16;
582+
583+ CalculateLUTInt16 (exp_zero_point, exp_scale, lut_value_zero_point, lut_value_scale,
584+ [](float key){ return std::exp (key); }, lut_entries, softmax_s16_exp_lut);
585+ CalculateLUTInt16 (one_by_one_zero_point, one_by_one_scale, lut_value_zero_point, lut_value_scale,
586+ [](float key){ return 1 .0f / (1 .0f + key); }, lut_entries, softmax_s16_one_by_one_lut);
587+
588+ // first LUT
589+ softmax_params[0 ].buffer_var = tir::Var (" exp_lut" , PointerType (PrimType (DataType::Int (bit_width)), " global.workspace" ));
590+ softmax_params[0 ].ndarray = runtime::NDArray::Empty ({lut_entries}, DataType::Int (bit_width), dev);
591+ softmax_params[0 ].ndarray .CopyFromBytes (softmax_s16_exp_lut, sizeof (int16_t )*lut_entries);
592+ softmax_params[0 ].extents = {lut_entries};
593+ softmax_params[0 ].num_bits = 16 ;
594+
595+ // second LUT
596+ softmax_params[1 ].buffer_var = tir::Var (" one_by_one_lut" , PointerType (PrimType (DataType::Int (bit_width)), " global.workspace" ));
597+ softmax_params[1 ].ndarray = runtime::NDArray::Empty ({lut_entries}, DataType::Int (bit_width), dev);
598+ softmax_params[1 ].ndarray .CopyFromBytes (softmax_s16_one_by_one_lut, sizeof (int16_t )*lut_entries);
599+ softmax_params[1 ].extents = {lut_entries};
600+ softmax_params[1 ].num_bits = 16 ;
601+ }
602+
533603 BufferCreator buffer_creator;
534- tir::Var in_var = buffer_creator.CreateBufferVar (" input" , DataType::Handle (8 ));
535- tir::Var out_var = buffer_creator.CreateBufferVar (" output" , DataType::Handle (8 ));
604+ tir::Var in_var = buffer_creator.CreateBufferVar (" input" , DataType::Handle (bit_width ));
605+ tir::Var out_var = buffer_creator.CreateBufferVar (" output" , DataType::Handle (bit_width ));
536606
537- tvm::Array<PrimExpr> args = {
538- tir::StringImm (" arm_softmax_s8" ),
607+ if (bit_width == 8 ) {
608+ tvm::Array<PrimExpr> args = {
609+ tir::StringImm (" arm_softmax_s" + std::to_string (bit_width)),
539610 in_var,
540611 ToArg (num_rows),
541612 ToArg (row_size),
542613 ToArg (mult),
543614 ToArg (shift),
544615 ToArg (diff_min),
545616 out_var,
546- };
617+ };
547618
548- CreatePrimFuncForExtern (global_var, buffer_creator.GetPrimFuncParams (),
619+ CreatePrimFuncForExtern (global_var, buffer_creator.GetPrimFuncParams (),
549620 buffer_creator.GetBufferMap (), args);
621+ } else { // bit_width == 16
622+ tvm::Array<PrimExpr> args = {
623+ tir::StringImm (" arm_softmax_s" + std::to_string (bit_width)),
624+ in_var,
625+ ToArg (num_rows),
626+ ToArg (row_size),
627+ ToArg (mult),
628+ ToArg (shift),
629+ softmax_params[0 ].buffer_var ,
630+ softmax_params[1 ].buffer_var ,
631+ out_var,
632+ };
633+
634+ CreatePrimFuncForExtern (global_var, buffer_creator.GetPrimFuncParams (),
635+ buffer_creator.GetBufferMap (), args, PrimExpr (),
636+ 0 , 8 , softmax_params);
637+ }
550638 }
551639
552640 struct BinaryElementwiseClipPattern {
0 commit comments