Skip to content

Commit

Permalink
[onert-micro] Fix building with DIS_QUANT flag (#12507)
Browse files Browse the repository at this point in the history
This commit fixes building with enabled DIS_QUANT flag.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>

Co-authored-by: Artem Balyshev <[email protected]>
  • Loading branch information
BalyshevArtem and Artem Balyshev authored Jan 26, 2024
1 parent 2f5c40a commit 9ece1bb
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 9 deletions.
2 changes: 2 additions & 0 deletions onert-micro/luci-interpreter/src/kernels/Add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,15 @@ void execute_kernel_CircleAdd(const circle::Operator *cur_op, BaseRuntimeGraph *
}
}
break;
#ifndef DIS_QUANT
case DataType::S8:
case DataType::S16:
{
evalQuantized(kernel.input1(), kernel.input2(), kernel.output(), options, runtime_graph,
type);
}
break;
#endif
default:
assert(false && "Unsupported type.");
}
Expand Down
12 changes: 8 additions & 4 deletions onert-micro/luci-interpreter/src/kernels/Dequantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
* limitations under the License.
*/

#ifndef DIS_QUANT

#include "Builders.h"
#include "kernels/Utils.h"
#include "SISOKernel.h"
Expand All @@ -28,18 +26,23 @@ namespace luci_interpreter
void configure_kernel_CircleDequantize(const circle::Operator *cur_op,
BaseRuntimeGraph *runtime_graph)
{
#ifndef DIS_QUANT
kernels::SISOKernel kernel(cur_op, runtime_graph);

LUCI_INTERPRETER_CHECK(Tensor::num_elements(kernel.input()) ==
Tensor::num_elements(kernel.output()));
LUCI_INTERPRETER_CHECK(Tensor::num_dims(kernel.input()) == Tensor::num_dims(kernel.output()));
LUCI_INTERPRETER_CHECK(!Tensor::scales(kernel.input()).empty());
LUCI_INTERPRETER_CHECK(!Tensor::zero_points(kernel.input()).empty());
#else
assert(false && "Remove DIS_QUANT flag");
#endif // DIS_QUANT
}

void execute_kernel_CircleDequantize(const circle::Operator *cur_op,
BaseRuntimeGraph *runtime_graph)
{
#ifndef DIS_QUANT
kernels::SISOKernel kernel(cur_op, runtime_graph);

const auto *input_data = runtime_graph->getDataByTensor(kernel.input());
Expand Down Expand Up @@ -84,8 +87,9 @@ void execute_kernel_CircleDequantize(const circle::Operator *cur_op,
default:
assert(false && "Unsupported type");
}
#else
assert(false && "Remove DIS_QUANT flag");
#endif // DIS_QUANT
}

} // namespace luci_interpreter

#endif // DIS_QUANT
9 changes: 8 additions & 1 deletion onert-micro/luci-interpreter/src/kernels/FullyConnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,14 @@ void configure_kernel_CircleFullyConnected(const circle::Operator *cur_op,
assert(output != nullptr);

#ifndef DIS_FLOAT
if (Tensor::element_type(weights) == DataType::FLOAT32)
if (Tensor::element_type(weights) == DataType::S8 and
Tensor::element_type(input) == DataType::FLOAT32)
{
// hybrid mode
LUCI_INTERPRETER_CHECK(Tensor::element_type(output) == DataType::FLOAT32);
LUCI_INTERPRETER_CHECK(!bias || Tensor::element_type(bias) == DataType::FLOAT32)
}
else if (Tensor::element_type(weights) == DataType::FLOAT32)
{
LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == DataType::FLOAT32);
LUCI_INTERPRETER_CHECK(Tensor::element_type(output) == DataType::FLOAT32);
Expand Down
12 changes: 8 additions & 4 deletions onert-micro/luci-interpreter/src/kernels/Quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
* limitations under the License.
*/

#ifndef DIS_QUANT

#include "Builders.h"
#include "kernels/Utils.h"
#include "SISOKernel.h"
Expand All @@ -28,17 +26,22 @@ namespace luci_interpreter
void configure_kernel_CircleQuantize(const circle::Operator *cur_op,
BaseRuntimeGraph *runtime_graph)
{
#ifndef DIS_QUANT
kernels::SISOKernel kernel(cur_op, runtime_graph);

LUCI_INTERPRETER_CHECK(Tensor::num_elements(kernel.input()) ==
Tensor::num_elements(kernel.output()));
LUCI_INTERPRETER_CHECK(Tensor::num_dims(kernel.input()) == Tensor::num_dims(kernel.output()));
LUCI_INTERPRETER_CHECK(!Tensor::scales(kernel.output()).empty());
LUCI_INTERPRETER_CHECK(!Tensor::zero_points(kernel.output()).empty());
#else
assert(false && "Remove DIS_QUANT flag");
#endif // DIS_QUANT
}

void execute_kernel_CircleQuantize(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
#ifndef DIS_QUANT
kernels::SISOKernel kernel(cur_op, runtime_graph);

const auto *input_data = runtime_graph->getDataByTensor(kernel.input());
Expand Down Expand Up @@ -83,8 +86,9 @@ void execute_kernel_CircleQuantize(const circle::Operator *cur_op, BaseRuntimeGr
default:
assert(false && "Unsupported type");
}
#else
assert(false && "Remove DIS_QUANT flag");
#endif // DIS_QUANT
}

} // namespace luci_interpreter

#endif // DIS_QUANT

0 comments on commit 9ece1bb

Please sign in to comment.