diff --git a/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.cpp b/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.cpp new file mode 100644 index 00000000000..002f3c2cae0 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.cpp @@ -0,0 +1,531 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "QuantizeOnnxDequantizeLinearPass.h" +#include "QuantizationUtils.h" + +#include +#include + +#include + +namespace +{ + +using namespace luci; + +// Return true if all values of node are within value_range +// value_range: [min, max] +template +bool value_range(const luci::CircleConst *node, const std::pair &value_range) +{ + const auto min = value_range.first; + const auto max = value_range.second; + + auto size = node->size
(); + for (uint32_t i = 0; i < size; i++) + { + const auto val = static_cast(node->at
(i)); + if (val < min or val > max) + return false; + } + + return true; +} + +std::vector get_scales(const luci::CircleConst *node) +{ + assert(node); // FIX_CALLER_UNLESS + + const auto num_scales = node->size(); + std::vector scales(num_scales); + for (uint32_t i = 0; i < num_scales; ++i) + { + scales[i] = node->at(i); + } + + return scales; +} + +template std::vector get_zerops(const luci::CircleConst *node) +{ + assert(node); // FIX_CALLER_UNLESS + + const auto num_zerops = node->size
(); + std::vector zerops(num_zerops); + for (uint32_t i = 0; i < num_zerops; ++i) + { + zerops[i] = node->at
(i); + } + + return zerops; +} + +int32_t get_axis(const luci::CircleCustom *node) +{ + assert(node); // FIX_CALLER_UNLESS + + const auto custom_options = node->custom_options(); + const auto map = flexbuffers::GetRoot(custom_options).AsMap(); + + return map["axis"].IsNull() ? 0 : map["axis"].AsInt32(); +} + +class OnnxDequantizeLinearPattern final +{ +public: + OnnxDequantizeLinearPattern(luci::CircleCustomOut *candidate) { custom_out = candidate; } + +public: + bool matched() + { + if (not custom_out) + return false; + + dequantize = loco::must_cast(custom_out->input()); + if (not is_onnx_dequantize_linear(dequantize)) + return false; + + input = dynamic_cast(dequantize->inputs(0)); + if (not input) + return false; + + scale = dynamic_cast(dequantize->inputs(1)); + if (not scale) + return false; + + zerop = dynamic_cast(dequantize->inputs(2)); + if (not zerop) + return false; + + const auto input_dtype = input->dtype(); + const auto scale_dtype = scale->dtype(); + const auto zerop_dtype = zerop->dtype(); + + if (scale_dtype != loco::DataType::FLOAT32) + return false; + + // Invariant from onnx DequantizeLinear operator + if (input_dtype != zerop_dtype) + return false; + + return true; + } + +public: + luci::CircleCustomOut *custom_out = nullptr; + luci::CircleCustom *dequantize = nullptr; + luci::CircleConst *input = nullptr; + luci::CircleConst *scale = nullptr; + luci::CircleConst *zerop = nullptr; +}; + +// Temporary class for our in-house model +// This is for per-tensor quantized LN const +// uint8 weight, int16 zerop, fp32 scale +// NOTE weight dtype != zerop dtype breaks invariant of +// onnx DequantizeLinear. That's why this class is a hack. +class OnnxDequantizeLinearPatternV2 final +{ +public: + OnnxDequantizeLinearPatternV2(luci::CircleCustomOut *candidate) { custom_out = candidate; } + +public: + bool matched() + { + if (not custom_out) + return false; + + dequantize = loco::must_cast(custom_out->input()); + if (not is_onnx_dequantize_linear(dequantize)) + return false; + + input = dynamic_cast(dequantize->inputs(0)); + if (not input) + return false; + + scale = dynamic_cast(dequantize->inputs(1)); + if (not scale) + return false; + + zerop = dynamic_cast(dequantize->inputs(2)); + if (not zerop) + return false; + + const auto input_dtype = input->dtype(); + const auto scale_dtype = scale->dtype(); + const auto zerop_dtype = zerop->dtype(); + + if (scale_dtype != loco::DataType::FLOAT32) + return false; + + if (input_dtype != loco::DataType::U8) + return false; + + if (zerop_dtype != loco::DataType::S16) + return false; + + return true; + } + +public: + luci::CircleCustomOut *custom_out = nullptr; + luci::CircleCustom *dequantize = nullptr; + luci::CircleConst *input = nullptr; + luci::CircleConst *scale = nullptr; + luci::CircleConst *zerop = nullptr; +}; + +class QuantizeOnnxDequantizeLinear final +{ +public: + QuantizeOnnxDequantizeLinear(const OnnxDequantizeLinearPattern &p) : _p(p) {} + +public: + void apply(void) + { + // The final const's dtype is the same with input_dtype by default + auto const_dtype = _p.input->dtype(); + if (const_dtype == loco::DataType::U8) + { + // Onnx does not support int4/uint4 as of writing. We assume uint8 + // tensor is quantized in int4/uint4 if values are within [0,15] + if (value_range(_p.input, {0, 15})) + { + if (value_range(_p.zerop, {8, 8})) + { + const_dtype = loco::DataType::S4; + } + else if (value_range(_p.zerop, {0, 15})) + { + const_dtype = loco::DataType::U4; + } + } + } + + luci::CircleConst *quant_const = nullptr; + switch (const_dtype) + { + case loco::DataType::S4: + quant_const = gen_s4_quant(); + break; + case loco::DataType::U4: + quant_const = gen_u4_quant(); + break; + case loco::DataType::U8: + quant_const = gen_u8_quant(); + break; + default: + throw std::runtime_error("Unsupported quantized dtype"); + } + + assert(quant_const); // FIX_ME_UNLESS + + // set origin + std::vector> origin_vec{ + luci::get_origin(_p.dequantize), luci::get_origin(_p.input), luci::get_origin(_p.scale), + luci::get_origin(_p.zerop)}; + + luci::add_origin(quant_const, luci::composite_origin(origin_vec)); + + replace(_p.custom_out).with(quant_const); + } + +private: + luci::CircleConst *gen_s4_quant(void) + { + assert(_p.input->dtype() == loco::DataType::U8); // FIX_CALLER_UNLESS + assert(_p.scale->dtype() == loco::DataType::FLOAT32); // FIX_CALLER_UNLESS + assert(_p.zerop->dtype() == loco::DataType::U8); // FIX_CALLER_UNLESS + + auto quantized_node = _p.dequantize->graph()->nodes()->create(); + quantized_node->dtype(loco::DataType::S4); + quantized_node->rank(_p.input->rank()); + for (uint32_t i = 0; i < _p.input->rank(); ++i) + { + quantized_node->dim(i) = _p.input->dim(i); + } + quantized_node->shape_status(luci::ShapeStatus::VALID); + + // Create S4 CircleConst + // NOTE S4 is saved as S8 in luci::CircleConst + const auto num_elems = _p.input->size(); + quantized_node->size(num_elems); + for (uint32_t i = 0; i < num_elems; i++) + { + const uint8_t u8_val = _p.input->at(i); + assert(u8_val <= 15); // FIX_CALLER_UNLESS + quantized_node->at(i) = static_cast(u8_val) - 8; + } + + auto qparam = std::make_unique(); + { + const std::vector scale_vector = get_scales(_p.scale); + const std::vector zerop_vector = get_zerops(_p.zerop); + + if (scale_vector.size() != zerop_vector.size()) + throw std::runtime_error("Scale/Zerop size mismatches in " + _p.dequantize->name()); + + const int32_t qdim = get_axis(_p.dequantize); + + qparam->scale = scale_vector; + qparam->zerop = zerop_vector; + qparam->quantized_dimension = qdim; + } + + quantized_node->quantparam(std::move(qparam)); + + quantized_node->name(_p.input->name()); + + return quantized_node; + } + + luci::CircleConst *gen_u4_quant(void) + { + assert(_p.input->dtype() == loco::DataType::U8); // FIX_CALLER_UNLESS + assert(_p.scale->dtype() == loco::DataType::FLOAT32); // FIX_CALLER_UNLESS + assert(_p.zerop->dtype() == loco::DataType::U8); // FIX_CALLER_UNLESS + + auto quantized_node = _p.dequantize->graph()->nodes()->create(); + quantized_node->dtype(loco::DataType::U4); + quantized_node->rank(_p.input->rank()); + for (uint32_t i = 0; i < _p.input->rank(); ++i) + { + quantized_node->dim(i) = _p.input->dim(i); + } + quantized_node->shape_status(luci::ShapeStatus::VALID); + + // Create U4 CircleConst + // NOTE U4 is saved as U8 in luci::CircleConst + const auto num_elems = _p.input->size(); + quantized_node->size(num_elems); + for (uint32_t i = 0; i < num_elems; i++) + { + const uint8_t u8_val = _p.input->at(i); + assert(u8_val <= 15); // FIX_CALLER_UNLESS + quantized_node->at(i) = u8_val; + } + + auto qparam = std::make_unique(); + { + const std::vector scale_vector = get_scales(_p.scale); + const std::vector zerop_vector = get_zerops(_p.zerop); + + if (scale_vector.size() != zerop_vector.size()) + throw std::runtime_error("Scale/Zerop size mismatches in " + _p.dequantize->name()); + + const int32_t qdim = get_axis(_p.dequantize); + + qparam->scale = scale_vector; + qparam->zerop = zerop_vector; + qparam->quantized_dimension = qdim; + } + + quantized_node->quantparam(std::move(qparam)); + + quantized_node->name(_p.input->name()); + + return quantized_node; + } + + luci::CircleConst *gen_u8_quant(void) + { + assert(_p.input->dtype() == loco::DataType::U8); // FIX_CALLER_UNLESS + assert(_p.scale->dtype() == loco::DataType::FLOAT32); // FIX_CALLER_UNLESS + assert(_p.zerop->dtype() == loco::DataType::U8); // FIX_CALLER_UNLESS + + auto quantized_node = _p.dequantize->graph()->nodes()->create(); + quantized_node->dtype(loco::DataType::U8); + quantized_node->rank(_p.input->rank()); + for (uint32_t i = 0; i < _p.input->rank(); ++i) + { + quantized_node->dim(i) = _p.input->dim(i); + } + quantized_node->shape_status(luci::ShapeStatus::VALID); + + // Create U8 CircleConst + const auto num_elems = _p.input->size(); + quantized_node->size(num_elems); + for (uint32_t i = 0; i < num_elems; i++) + { + const uint8_t u8_val = _p.input->at(i); + quantized_node->at(i) = u8_val; + } + + auto qparam = std::make_unique(); + { + const std::vector scale_vector = get_scales(_p.scale); + const std::vector zerop_vector = get_zerops(_p.zerop); + + if (scale_vector.size() != zerop_vector.size()) + throw std::runtime_error("Scale/Zerop size mismatches in " + _p.dequantize->name()); + + const int32_t qdim = get_axis(_p.dequantize); + + qparam->scale = scale_vector; + qparam->zerop = zerop_vector; + qparam->quantized_dimension = qdim; + } + + quantized_node->quantparam(std::move(qparam)); + + quantized_node->name(_p.input->name()); + + return quantized_node; + } + +private: + const OnnxDequantizeLinearPattern &_p; +}; + +// Temporary class to handle our in-house model +class QuantizeOnnxDequantizeLinearV2 final +{ +public: + QuantizeOnnxDequantizeLinearV2(const OnnxDequantizeLinearPatternV2 &p) : _p(p) {} + +public: + void apply(void) + { + auto const_dtype = _p.zerop->dtype(); + + luci::CircleConst *quant_const = nullptr; + switch (const_dtype) + { + case loco::DataType::S16: + quant_const = gen_s16_quant(); + break; + default: + throw std::runtime_error("Unsupported quantized dtype"); + } + + assert(quant_const); // FIX_ME_UNLESS + + // set origin + std::vector> origin_vec{ + luci::get_origin(_p.dequantize), luci::get_origin(_p.input), luci::get_origin(_p.scale), + luci::get_origin(_p.zerop)}; + + luci::add_origin(quant_const, luci::composite_origin(origin_vec)); + + replace(_p.custom_out).with(quant_const); + } + +private: + luci::CircleConst *gen_s16_quant(void) + { + assert(_p.input->dtype() == loco::DataType::U8); // FIX_CALLER_UNLESS + assert(_p.scale->dtype() == loco::DataType::FLOAT32); // FIX_CALLER_UNLESS + assert(_p.zerop->dtype() == loco::DataType::S16); // FIX_CALLER_UNLESS + + auto quantized_node = _p.dequantize->graph()->nodes()->create(); + quantized_node->dtype(loco::DataType::S16); + quantized_node->rank(_p.input->rank()); + for (uint32_t i = 0; i < _p.input->rank(); ++i) + { + quantized_node->dim(i) = _p.input->dim(i); + } + quantized_node->shape_status(luci::ShapeStatus::VALID); + + // Create S16 CircleConst + const auto num_elems = _p.input->size(); + quantized_node->size(num_elems); + for (uint32_t i = 0; i < num_elems; i++) + { + const uint8_t u8_val = _p.input->at(i); + quantized_node->at(i) = static_cast(u8_val); + } + + auto qparam = std::make_unique(); + { + const std::vector scale_vector = get_scales(_p.scale); + const std::vector zerop_vector = get_zerops(_p.zerop); + + if (scale_vector.size() != zerop_vector.size()) + throw std::runtime_error("Scale/Zerop size mismatches in " + _p.dequantize->name()); + + const int32_t qdim = get_axis(_p.dequantize); + + qparam->scale = scale_vector; + qparam->zerop = zerop_vector; + qparam->quantized_dimension = qdim; + } + + quantized_node->quantparam(std::move(qparam)); + + quantized_node->name(_p.input->name()); + + return quantized_node; + } + +private: + const OnnxDequantizeLinearPatternV2 &_p; +}; + +} // namespace + +namespace luci +{ + +/** + * + * Quantize pattern + * + * [Before] + * + * [CircleConst(quantized)] + * | + * [CircleCustom(OnnxDequantizeLinear)] + * | + * [CircleNode] + * + * [After] + * + * [CircleConst(quantized)] + * | + * [CircleNode] + */ +bool QuantizeOnnxDequantizeLinearPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto circle_custom_out = dynamic_cast(node)) + { + OnnxDequantizeLinearPattern p(circle_custom_out); + if (p.matched()) + { + QuantizeOnnxDequantizeLinear quantize(p); + quantize.apply(); + changed = true; + } + + // TODO Remove V2 classes + OnnxDequantizeLinearPatternV2 p2(circle_custom_out); + if (p2.matched()) + { + QuantizeOnnxDequantizeLinearV2 quantize(p2); + quantize.apply(); + changed = true; + } + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.h b/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.h new file mode 100644 index 00000000000..17436672b9f --- /dev/null +++ b/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_QUANTIZE_ONNX_DEQUANTIZE_LINEAR_PASS_H__ +#define __LUCI_QUANTIZE_ONNX_DEQUANTIZE_LINEAR_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to quantize ONNXDequantizeLinear operator + * + */ +struct QuantizeOnnxDequantizeLinearPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::QuantizeOnnxDequantizeLinear"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_QUANTIZE_ONNX_DEQUANTIZE_LINEAR_PASS_H__ diff --git a/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.test.cpp b/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.test.cpp new file mode 100644 index 00000000000..17627219a48 --- /dev/null +++ b/compiler/luci/pass/src/QuantizeOnnxDequantizeLinearPass.test.cpp @@ -0,0 +1,230 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "QuantizeOnnxDequantizeLinearPass.h" +#include "PassTestGraphs.h" + +#include + +#include + +namespace +{ + +template +class QuantizeOnnxDequantizeLinearTest : public luci::ConstantFoldingAddTestGraph, + public ::testing::Test +{ +public: + QuantizeOnnxDequantizeLinearTest() : luci::ConstantFoldingAddTestGraph({2, 2, 2}, DT) {} + + virtual void SetUp() { init(); } + + loco::Node *createFoldedPattern() override + { + _dequantize = _g.nodes()->template create(3, 1); + _dequantize_out = _g.nodes()->template create(); + _input = _g.nodes()->template create(); + _scale = _g.nodes()->template create(); + _zerop = _g.nodes()->template create(); + + _dequantize->dtype(loco::DataType::FLOAT32); + _dequantize_out->dtype(loco::DataType::FLOAT32); + _input->dtype(DT); + _scale->dtype(loco::DataType::FLOAT32); + _zerop->dtype(DT); + + _input->shape({2, 2, 2}); + _scale->shape({2}); + _zerop->shape({2}); + + _input->size
(8); + + _scale->size(2); + _scale->at(0) = 5.0; + _scale->at(1) = 10.0; + + _zerop->size
(2); + + // custom option + auto flex_buffers = std::make_unique(); + size_t map_start = flex_buffers->StartMap(); + flex_buffers->Int("axis", 1); + flex_buffers->EndMap(map_start); + flex_buffers->Finish(); + + _dequantize->inputs(0, _input); + _dequantize->inputs(1, _scale); + _dequantize->inputs(2, _zerop); + _dequantize->custom_code("ONNXDequantizeLinear"); + _dequantize->custom_options(flex_buffers->GetBuffer()); + + _dequantize_out->input(_dequantize); + _dequantize_out->index(0); + + _input->name("input"); + _dequantize->name("dequantize"); + _dequantize_out->name("dequantize_out"); + + return _dequantize_out; + } + + void createNotQuantizablePattern() { _input->dtype(loco::DataType::FLOAT32); } + +protected: + luci::CircleCustom *_dequantize = nullptr; + luci::CircleCustomOut *_dequantize_out = nullptr; + luci::CircleConst *_input = nullptr; + luci::CircleConst *_scale = nullptr; + luci::CircleConst *_zerop = nullptr; +}; + +class S4QuantizeOnnxDequantizeLinearTest + : public QuantizeOnnxDequantizeLinearTest +{ + virtual void SetUp() override + { + init(); + + // Input range [0, 15] + for (uint32_t i = 0; i < _input->size(); i++) + { + _input->at(i) = 1; + } + + // Zerop = 8 + for (uint32_t i = 0; i < _zerop->size(); i++) + { + _zerop->at(i) = 8; + } + } +}; + +class U4QuantizeOnnxDequantizeLinearTest + : public QuantizeOnnxDequantizeLinearTest +{ + virtual void SetUp() override + { + init(); + + // Input range [0, 15] + for (uint32_t i = 0; i < _input->size(); i++) + { + _input->at(i) = 1; + } + + // Zerop = [0, 15] + for (uint32_t i = 0; i < _zerop->size(); i++) + { + _zerop->at(i) = 1; + } + } +}; + +class U8QuantizeOnnxDequantizeLinearTest + : public QuantizeOnnxDequantizeLinearTest +{ + virtual void SetUp() override + { + init(); + + // Input range [0, 255] + for (uint32_t i = 0; i < _input->size(); i++) + { + _input->at(i) = 255; + } + + // Zerop = [0, 255] + for (uint32_t i = 0; i < _zerop->size(); i++) + { + _zerop->at(i) = 128; + } + } +}; + +} // namespace + +TEST_F(S4QuantizeOnnxDequantizeLinearTest, quantize_onnx_dq_linear_basic) +{ + luci::QuantizeOnnxDequantizeLinearPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + EXPECT_EQ(loco::DataType::S4, folded_const->dtype()); +} + +TEST_F(S4QuantizeOnnxDequantizeLinearTest, quantize_onnx_dq_linear_basic_NEG) +{ + createNotQuantizablePattern(); + + luci::QuantizeOnnxDequantizeLinearPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} + +TEST_F(U4QuantizeOnnxDequantizeLinearTest, quantize_onnx_dq_linear_basic) +{ + luci::QuantizeOnnxDequantizeLinearPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + EXPECT_EQ(loco::DataType::U4, folded_const->dtype()); +} + +TEST_F(U4QuantizeOnnxDequantizeLinearTest, quantize_onnx_dq_linear_basic_NEG) +{ + createNotQuantizablePattern(); + + luci::QuantizeOnnxDequantizeLinearPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +} + +TEST_F(U8QuantizeOnnxDequantizeLinearTest, quantize_onnx_dq_linear_basic) +{ + luci::QuantizeOnnxDequantizeLinearPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_NE(nullptr, folded_const); + + EXPECT_EQ(loco::DataType::U8, folded_const->dtype()); +} + +TEST_F(U8QuantizeOnnxDequantizeLinearTest, quantize_onnx_dq_linear_basic_NEG) +{ + createNotQuantizablePattern(); + + luci::QuantizeOnnxDequantizeLinearPass pass; + while (pass.run(graph())) + ; + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(nullptr, folded_const); +}