From 4bbb255fcd13db35094ab0936011f013a1ec9ddf Mon Sep 17 00:00:00 2001
From: Balyshev Artem <43214667+BalyshevArtem@users.noreply.github.com>
Date: Thu, 21 Sep 2023 11:28:24 +0300
Subject: [PATCH] [onert-micro] Add cmsis-nn Mul kernel (#11566)

This commit adds the cmsis-nn Mul kernel: quantized S8/S16 multiplication
is dispatched to CMSIS-NN's arm_elementwise_mul_s8 and
arm_elementwise_mul_s16 through a new evalQuantized path.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev

Co-authored-by: Artem Balyshev
---
 .../pal/cmsisnn/KernelsToBuild.lst         |  1 +
 .../luci-interpreter/pal/cmsisnn/PALMul.h  | 36 ++++---
 .../luci-interpreter/src/kernels/Mul.cpp   | 95 ++++++++++++-------
 3 files changed, 85 insertions(+), 47 deletions(-)

diff --git a/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst b/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst
index 29ec817899c..930fa0d5c72 100644
--- a/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst
+++ b/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst
@@ -27,6 +27,7 @@ REGISTER_KERNEL(LESS_EQUAL, LessEqual)
 REGISTER_KERNEL(LOGICAL_AND, LogicalAnd)
 REGISTER_KERNEL(LOGICAL_OR, LogicalOr)
 REGISTER_KERNEL(LEAKY_RELU, LeakyRelu)
+REGISTER_KERNEL(MUL, Mul)
 REGISTER_KERNEL(CONCATENATION, Concatenation)
 REGISTER_KERNEL(SHAPE, Shape)
 REGISTER_KERNEL(NOT_EQUAL, NotEqual)
diff --git a/onert-micro/luci-interpreter/pal/cmsisnn/PALMul.h b/onert-micro/luci-interpreter/pal/cmsisnn/PALMul.h
index 347a97a831d..e640c35c567 100644
--- a/onert-micro/luci-interpreter/pal/cmsisnn/PALMul.h
+++ b/onert-micro/luci-interpreter/pal/cmsisnn/PALMul.h
@@ -17,29 +17,35 @@
 #ifndef LUCI_INTERPRETER_PAL_MUL_H
 #define LUCI_INTERPRETER_PAL_MUL_H
 
-#include
+#include "PALMulCommon.h"
+#include "arm_nnfunctions.h"
 
 namespace luci_interpreter_pal
 {
-template <typename T>
-static inline void Mul(tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
-                       const T *input1_data, const tflite::RuntimeShape &input2_shape,
-                       const T *input2_data, const tflite::RuntimeShape &output_shape,
-                       T *output_data)
+
+template <>
+inline void Mul(const ArithmeticParams &params, const int flat_size,
+                const int8_t *input1_data, const int8_t *input2_data, int8_t *output_data)
 {
-  tflite::reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
-                                            input2_data, output_shape, output_data);
+  auto status = arm_elementwise_mul_s8(
+    input1_data, input2_data, params.input1_offset, params.input2_offset, output_data,
+    params.output_offset, params.output_multiplier, params.output_shift,
+    params.quantized_activation_min, params.quantized_activation_max, flat_size);
+  assert(status == ARM_CMSIS_NN_SUCCESS);
 }
 
-template <typename T>
-static inline void
-BroadcastMul4DSlow(tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
-                   const T *input1_data, const tflite::RuntimeShape &input2_shape,
-                   const T *input2_data, const tflite::RuntimeShape &output_shape, T *output_data)
+template <>
+inline void Mul(const ArithmeticParams &params, const int flat_size,
+                const int16_t *input1_data, const int16_t *input2_data,
+                int16_t *output_data)
 {
-  tflite::reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
-                                            input2_data, output_shape, output_data);
+  auto status = arm_elementwise_mul_s16(
+    input1_data, input2_data, params.input1_offset, params.input2_offset, output_data,
+    params.output_offset, params.output_multiplier, params.output_shift,
+    params.quantized_activation_min, params.quantized_activation_max, flat_size);
+  assert(status == ARM_CMSIS_NN_SUCCESS);
 }
+
 } // namespace luci_interpreter_pal
 
 #endif // LUCI_INTERPRETER_PAL_MUL_H
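Note on the PAL contract above: both specializations expect a fully populated
ArithmeticParams plus a flattened element count. A minimal sketch of a direct
call, assuming PALMul.h and its transitive includes are on the include path;
every numeric value below is a made-up placeholder, not something taken from
this patch (real values come from the tensors' quantization parameters, as
evalQuantized in Mul.cpp below shows):

#include <cstdint>

#include "PALMul.h"

// Placeholder quantization values for illustration only.
void callMulS8(const int8_t *in1, const int8_t *in2, int8_t *out, int flat_size)
{
  luci_interpreter_pal::ArithmeticParams params{};
  params.input1_offset = 5;               // -zero_point(input1)
  params.input2_offset = -3;              // -zero_point(input2)
  params.output_offset = 10;              // zero_point(output)
  params.output_multiplier = 1073741824;  // Q31 encoding of s1 * s2 / s_out
  params.output_shift = -3;
  params.quantized_activation_min = -128; // full int8 range (no fused activation)
  params.quantized_activation_max = 127;
  luci_interpreter_pal::Mul(params, flat_size, in1, in2, out); // dispatches to arm_elementwise_mul_s8
}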
diff --git a/onert-micro/luci-interpreter/src/kernels/Mul.cpp b/onert-micro/luci-interpreter/src/kernels/Mul.cpp
index 75f9b904773..86fb7ef635b 100644
--- a/onert-micro/luci-interpreter/src/kernels/Mul.cpp
+++ b/onert-micro/luci-interpreter/src/kernels/Mul.cpp
@@ -25,6 +25,63 @@
 namespace luci_interpreter
 {
 
+namespace
+{
+
+#ifndef DIS_QUANT
+void evalQuantized(const circle::Tensor *input1, const circle::Tensor *input2,
+                   const circle::Tensor *output, const circle::MulOptions *options,
+                   BaseRuntimeGraph *runtime_graph, DataType type)
+{
+  assert(type == DataType::S16 or type == DataType::S8 && "Wrong Type");
+
+  luci_interpreter_pal::ArithmeticParams params{};
+  luci_interpreter::RuntimeShape input_shape1 =
+    kernels::getTensorRuntimeShape(input1, runtime_graph);
+  luci_interpreter::RuntimeShape input_shape2 =
+    kernels::getTensorRuntimeShape(input2, runtime_graph);
+
+  const bool need_broadcast =
+    luci_interpreter_pal::ProcessBroadcastShapes(input_shape1, input_shape2, &params);
+
+  assert(need_broadcast == false && "Broadcast for INT8 and INT16 not supported now");
+
+  params.input1_offset = -Tensor::zero_point(input1);
+  params.input2_offset = -Tensor::zero_point(input2);
+  params.output_offset = Tensor::zero_point(output);
+
+  const auto input1_scale = static_cast<double>(Tensor::scale(input1));
+  const auto input2_scale = static_cast<double>(Tensor::scale(input2));
+  const auto output_scale = static_cast<double>(Tensor::scale(output));
+
+  double real_multiplier = input1_scale * input2_scale / output_scale;
+
+  kernels::quantizeMultiplier(real_multiplier, &params.output_multiplier, &params.output_shift);
+
+  kernels::calculateActivationRangeQuantized(luci_actfunc(options->fused_activation_function()),
+                                             output, &params.quantized_activation_min,
+                                             &params.quantized_activation_max);
+  if (type == DataType::S8)
+  {
+    luci_interpreter_pal::Mul(
+      params, input_shape1.flatSize(),
+      kernels::getTensorData<int8_t>(runtime_graph->getDataByTensor(input1)),
+      kernels::getTensorData<int8_t>(runtime_graph->getDataByTensor(input2)),
+      kernels::getTensorData<int8_t>(runtime_graph->getDataByTensor(output)));
+  }
+  else
+  {
+    luci_interpreter_pal::Mul(
+      params, input_shape1.flatSize(),
+      kernels::getTensorData<int16_t>(runtime_graph->getDataByTensor(input1)),
+      kernels::getTensorData<int16_t>(runtime_graph->getDataByTensor(input2)),
+      kernels::getTensorData<int16_t>(runtime_graph->getDataByTensor(output)));
+  }
+}
+#endif // DIS_QUANT
+
+} // namespace
+
 void configure_kernel_CircleMul(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
   kernels::TISOKernel kernel(cur_op, runtime_graph);
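Note on kernels::quantizeMultiplier above: the double real_multiplier =
input1_scale * input2_scale / output_scale is decomposed into the integer
output_multiplier/output_shift pair that CMSIS-NN consumes. The sketch below
shows the usual TFLite-style decomposition (an assumption about this helper,
not code from this repo): the real multiplier is represented as a Q31
multiplier times a power of two.

#include <cassert>
#include <cmath>
#include <cstdint>

// Sketch: decompose real_multiplier so that
// real_multiplier ~= quantized_multiplier * 2^shift / 2^31.
void quantizeMultiplierSketch(double real_multiplier, int32_t *quantized_multiplier, int *shift)
{
  if (real_multiplier == 0.0)
  {
    *quantized_multiplier = 0;
    *shift = 0;
    return;
  }
  const double q = std::frexp(real_multiplier, shift); // q is in [0.5, 1)
  auto q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
  assert(q_fixed <= (1LL << 31));
  if (q_fixed == (1LL << 31)) // rounding carried q up to exactly 1.0
  {
    q_fixed /= 2;
    ++*shift;
  }
  *quantized_multiplier = static_cast<int32_t>(q_fixed);
}

Worked example: with input1_scale = 0.5, input2_scale = 0.25 and
output_scale = 2.0, real_multiplier = 0.0625 = 0.5 * 2^-3, so the sketch
yields quantized_multiplier = 1073741824 (0.5 in Q31) and shift = -3.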
@@ -57,8 +114,8 @@ void execute_kernel_CircleMul(const circle::Operator *cur_op, BaseRuntimeGraph *
     kernels::getTensorRuntimeShape(kernel.input2(), runtime_graph);
 
   bool is_inplace = runtime_graph->is_inplace_op(cur_op);
-
-  switch (Tensor::element_type(kernel.input1()))
+  const auto type = Tensor::element_type(kernel.input1());
+  switch (type)
   {
 #ifndef DIS_FLOAT
     case DataType::FLOAT32:
@@ -113,41 +170,15 @@ void execute_kernel_CircleMul(const circle::Operator *cur_op, BaseRuntimeGraph *
       }
     }
     break;
-#if 0
 #ifndef DIS_QUANT
-    // TODO: check quantize Mul
-    case DataType::U8:
+    case DataType::S8:
+    case DataType::S16:
     {
-      auto tiso_func = [](const luci_interpreter_pal::ArithmeticParams &params,
-                          const luci_interpreter::RuntimeShape &input1_shape, const uint8_t *input1_data,
-                          const luci_interpreter::RuntimeShape &input2_shape, const uint8_t *input2_data,
-                          const luci_interpreter::RuntimeShape &output_shape, uint8_t *output_data) {
-        luci_interpreter_pal::Mul(params, input1_shape, input1_data, input2_shape, input2_data,
-                                  output_shape, output_data);
-      };
-      auto broadcast_tiso_func =
-        [](const luci_interpreter_pal::ArithmeticParams &params, const luci_interpreter::RuntimeShape &input1_shape,
-           const uint8_t *input1_data, const luci_interpreter::RuntimeShape &input2_shape,
-           const uint8_t *input2_data, const luci_interpreter::RuntimeShape &output_shape,
-           uint8_t *output_data) {
-          luci_interpreter_pal::BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
-                                                   input2_data, output_shape, output_data);
-        };
-      if (is_inplace)
-      {
-        kernels::evalTISOInplaceQuantizedKernel<uint8_t>(tiso_func, broadcast_tiso_func, &kernel,
-                                                         options);
-      }
-      else
-      {
-        kernels::TISOData kernel_data = kernel.readData();
-        kernels::evalTISOQuantizedKernel<uint8_t>(tiso_func, broadcast_tiso_func, &kernel,
-                                                  &kernel_data, options);
-      }
+      evalQuantized(kernel.input1(), kernel.input2(), kernel.output(), options, runtime_graph,
+                    type);
     }
     break;
 #endif // DIS_QUANT
-#endif // 0
     default:
       assert(false && "Unsupported type.");
   }
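Note on the quantized arithmetic as a whole: with the input offsets negated in
evalQuantized (input1_offset = -zero_point) and the output offset added back,
the CMSIS-NN kernels realize the standard affine model
q = round(r / scale) + zero_point. A floating-point reference for one int8
element, for intuition only; the real kernel uses the integer
multiplier/shift pair and CMSIS-NN's own rounding, so results may differ by
one ULP (z1/z2/z_out are zero points, s1/s2/s_out the scales):

#include <algorithm>
#include <cmath>
#include <cstdint>

int8_t mulOneElementReference(int8_t q1, int8_t q2, double s1, double s2, double s_out,
                              int32_t z1, int32_t z2, int32_t z_out)
{
  // params.input1_offset = -z1 in evalQuantized, so the kernel forms (q1 - z1).
  const double r = (s1 * (q1 - z1)) * (s2 * (q2 - z2));
  const auto q = static_cast<int32_t>(std::lround(r / s_out)) + z_out;
  // quantized_activation_min/max clamp; full int8 range when no activation is fused.
  return static_cast<int8_t>(std::clamp<int32_t>(q, -128, 127));
}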