-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Enable CircleGRU operation in luci-interpreter - Overall implementation is borrowed from onert-micro 2.0 ONE-DCO-1.0-Signed-off-by: Chunseok Lee <[email protected]>
- Loading branch information
1 parent
153edec
commit 2b9d460
Showing
7 changed files
with
543 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef LUCI_INTERPRETER_PAL_GRU_H | ||
#define LUCI_INTERPRETER_PAL_GRU_H | ||
|
||
#include <tensorflow/lite/kernels/internal/reference/fully_connected.h> | ||
#include "PALreference_ops.h" | ||
namespace luci_interpreter_pal | ||
{ | ||
|
||
// tflite's Logistic does not provide inplace Logistic kernel
void Logistic(const int flat_size, const float *input_data, float *output_data)
{
  // Saturation thresholds kept from the original kernel: above the upper
  // cutoff a float sigmoid rounds to exactly 1.0f; below the lower cutoff
  // sigmoid(x) is approximated by exp(x).
  const float kCutoffUpper = 16.619047164916992188f;
  const float kCutoffLower = -9.f;

  // Rationale for using this approximation (unchanged from reference kernel):
  // 0. It gives enough precision for float.
  // 1. It works around an embedded chipset where exp() does not return inf
  //    when overflown (IEEE 754 defines a representation for inf) but a
  //    finite value such as 1.701417.
  // 2. It speeds up the calculation and matches the optimized kernels
  //    (see the definition of scalar_logistic_op<float>).
  for (int idx = 0; idx < flat_size; ++idx)
  {
    const float x = input_data[idx];
    if (x > kCutoffUpper)
    {
      output_data[idx] = 1.0f;
    }
    else if (x < kCutoffLower)
    {
      output_data[idx] = std::exp(x);
    }
    else
    {
      output_data[idx] = 1.f / (1.f + std::exp(-x));
    }
  }
}
|
||
// Performs one GRU time step, combining the two FullyConnected gate
// projections and the element-wise gate math, writing the new hidden state
// into output_data in place.
//
// Layout assumptions (from the pointer arithmetic below):
//   - weight_input_shape.Dims(0) == weight_hidden_shape.Dims(0) == 3 * hidden_size,
//     i.e. the update/reset/new gate weights are concatenated along dim 0.
//   - output_input_data / output_hidden_data are scratch buffers of at least
//     output_shape_fc.FlatSize() floats each.
//   - output_data holds the previous hidden state on entry and the new one
//     on exit.
// NOTE(review): intermediate_buffer is accepted but never used here —
// presumably reserved for a memory-planned variant; confirm against callers.
void calculateGRU(const float *input_data, const float *weight_input_data,
                  const float *weight_hidden_data, const float *bias_input_data,
                  const float *bias_hidden_data, float *output_data,
                  const tflite::RuntimeShape &input_shape, const tflite::RuntimeShape &output_shape,
                  const tflite::RuntimeShape &weight_input_shape,
                  const tflite::RuntimeShape &weight_hidden_shape, float *output_input_data,
                  float *output_hidden_data, const tflite::RuntimeShape &output_shape_fc,
                  float *intermediate_buffer)
{
  tflite::FullyConnectedParams op_params{};
  // As FC nodes doesn't have any activations inside GRU, let' use just numeric limits
  op_params.float_activation_min = std::numeric_limits<float>::lowest();
  op_params.float_activation_max = std::numeric_limits<float>::max();

  // FC Input: project the previous hidden state (held in output_data)
  // through the recurrent weights into the scratch buffer.
  tflite::RuntimeShape bias_input_shape{weight_input_shape.Dims(0)};
  tflite::reference_ops::FullyConnected(op_params, output_shape, output_data, weight_input_shape,
                                        weight_input_data, bias_input_shape, bias_input_data,
                                        output_shape_fc, output_input_data);

  // FC Hidden: project the current input through the input weights.
  tflite::RuntimeShape bias_hidden_shape{weight_hidden_shape.Dims(0)};
  // Note: input for this FC node will be saved without intermediate buffer
  tflite::reference_ops::FullyConnected(op_params, input_shape, input_data, weight_hidden_shape,
                                        weight_hidden_data, bias_hidden_shape, bias_hidden_data,
                                        output_shape_fc, output_hidden_data);

  // The FC output concatenates three gates; each gate has this many elements.
  int num_elements = output_shape_fc.Dims(1) / 3;

  float *second_hidden_part = output_hidden_data + num_elements;
  float *second_input_part = output_input_data + num_elements;

  float *third_hidden_part = second_hidden_part + num_elements;
  float *third_input_part = second_input_part + num_elements;

  // Calculate Left part: update gate pre-activation = input + hidden parts.
  for (int i = 0; i < num_elements; ++i)
  {
    output_input_data[i] += output_hidden_data[i];
  }

  // In-place sigmoid of the update gate.
  Logistic(num_elements, output_input_data, output_input_data);

  // Calculate most left mul: scale previous state by the gate, and keep
  // (1 - gate) for blending in the candidate state at the end.
  float *most_left_part_final = output_input_data;
  float *first_part = output_input_data;
  for (int i = 0; i < num_elements; ++i)
  {
    output_data[i] *= most_left_part_final[i];
    first_part[i] = 1.0f - first_part[i];
  }

  // Calc second part: reset gate pre-activation.
  for (int i = 0; i < num_elements; ++i)
  {
    second_hidden_part[i] += second_input_part[i];
  }

  // In-place sigmoid of the reset gate.
  Logistic(num_elements, second_hidden_part, second_hidden_part);

  // Candidate pre-activation: reset_gate * input_part + hidden_part.
  for (int i = 0; i < num_elements; ++i)
  {
    second_hidden_part[i] *= third_input_part[i];
    second_hidden_part[i] += third_hidden_part[i];
  }

  // tanh with saturation: beyond |19| float tanh is exactly +/-1, and this
  // also avoids relying on std::tanh behavior for huge inputs.
  for (int i = 0; i < num_elements; ++i)
  {
    if (second_hidden_part[i] > 19)
    {
      second_hidden_part[i] = 1;
    }
    else if (second_hidden_part[i] < -19)
    {
      second_hidden_part[i] = -1;
    }
    else
    {
      second_hidden_part[i] = std::tanh(second_hidden_part[i]);
    }
  }

  // Blend: new_state = gate * old_state + (1 - gate) * candidate.
  for (int i = 0; i < num_elements; ++i)
  {
    second_hidden_part[i] *= first_part[i];
    output_data[i] += second_hidden_part[i];
  }
}
|
||
// Runs the GRU over all time steps.
//
// The hidden state is seeded from hidden_state_data and kept in output_data,
// which calculateGRU updates in place each step. input_data is advanced by
// input_shape.Dims(2) per step — i.e. this assumes a [time, batch, feature]
// input with batch handled inside the FC shapes (TODO confirm batch == 1).
// NOTE(review): intermediate_buffer / intermediate_buffer_size are forwarded
// but unused by calculateGRU in this implementation.
void GRU(const float *input_data, const float *weight_input_data, const float *weight_hidden_data,
         const float *bias_input_data, const float *bias_hidden_data,
         const float *hidden_state_data, float *output_data, float *output_input_data,
         float *output_hidden_data, const tflite::RuntimeShape &input_shape,
         const tflite::RuntimeShape &output_shape, const tflite::RuntimeShape &weight_input_shape,
         const tflite::RuntimeShape &weight_hidden_shape, const size_t intermediate_buffer_size,
         float *intermediate_buffer)
{
  const int32_t time = input_shape.Dims(0);

  // Shape of each FullyConnected result: one row of 3 * hidden_size gates.
  tflite::RuntimeShape output_shape_fc(2);
  output_shape_fc.SetDim(0, 1);
  output_shape_fc.SetDim(1, weight_hidden_shape.Dims(0));

  // Seed the running hidden state with the provided initial state.
  std::memcpy(output_data, hidden_state_data, output_shape.FlatSize() * sizeof(float));

  for (int i = 0; i < time; ++i)
  {
    calculateGRU(input_data, weight_input_data, weight_hidden_data, bias_input_data,
                 bias_hidden_data, output_data, input_shape, output_shape, weight_input_shape,
                 weight_hidden_shape, output_input_data, output_hidden_data, output_shape_fc,
                 intermediate_buffer);
    input_data += input_shape.Dims(2);
  }
}
|
||
} // namespace luci_interpreter_pal | ||
|
||
#endif // LUCI_INTERPRETER_PAL_GRU_H
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "kernels/GRU.h"

#include "kernels/Utils.h"

#include "PALFullyConnected.h"
#include "PALGRU.h"

#include <vector>
|
||
namespace luci_interpreter | ||
{ | ||
namespace kernels | ||
{ | ||
// Wires the six input tensors (input data, the two concatenated gate weight
// tensors with their biases, and the initial hidden state) plus the single
// output tensor into the kernel base class. No validation happens here;
// shape/type checks are done in configure().
GRU::GRU(const Tensor *input, const Tensor *hidden_hidden, const Tensor *hidden_hidden_bias,
         const Tensor *hidden_input, const Tensor *hidden_input_bias, const Tensor *state,
         Tensor *output, const GRUParams &params)
  : KernelWithParams<GRUParams>(
      {input, hidden_hidden, hidden_hidden_bias, hidden_input, hidden_input_bias, state}, {output},
      params)
{
}
|
||
void GRU::configure() | ||
{ | ||
auto hidden_hidden_shape = getTensorShape(hidden_hidden()); | ||
auto hidden_input_shape = getTensorShape(hidden_input()); | ||
LUCI_INTERPRETER_CHECK(hidden_hidden_shape.Dims(0) == hidden_input_shape.Dims(0)); | ||
|
||
const int32_t div_factor = 3; | ||
|
||
auto output_shape = getTensorShape(output()); | ||
auto state_shape = getTensorShape(state()); | ||
|
||
output()->resize(state()->shape()); | ||
|
||
LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type()); | ||
} | ||
|
||
void GRU::execute() const | ||
{ | ||
switch (input()->element_type()) | ||
{ | ||
case DataType::FLOAT32: | ||
evalFloat(); | ||
break; | ||
default: | ||
throw std::runtime_error("luci-GRU Unsupported data type."); | ||
} | ||
} | ||
|
||
void GRU::evalFloat() const | ||
{ | ||
uint8_t *output_hidden_data; | ||
uint8_t *output_input_data; | ||
|
||
// allocate output datas above | ||
output_hidden_data = new uint8_t[getTensorShape(hidden_hidden()).FlatSize() * sizeof(float)]; | ||
output_input_data = new uint8_t[getTensorShape(hidden_input()).FlatSize() * sizeof(float)]; | ||
|
||
luci_interpreter_pal::GRU( | ||
getTensorData<float>(input()), getTensorData<float>(hidden_input()), | ||
getTensorData<float>(hidden_hidden()), getTensorData<float>(hidden_input_bias()), | ||
getTensorData<float>(hidden_hidden_bias()), getTensorData<float>(state()), | ||
getTensorData<float>(output()), reinterpret_cast<float *>(output_input_data), | ||
reinterpret_cast<float *>(output_hidden_data), getTensorShape(input()), | ||
getTensorShape(output()), getTensorShape(hidden_input()), getTensorShape(hidden_hidden()), 0, | ||
nullptr); | ||
|
||
delete output_hidden_data; | ||
delete output_input_data; | ||
} | ||
|
||
} // namespace kernels | ||
} // namespace luci_interpreter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef LUCI_INTERPRETER_KERNELS_GRU_H | ||
#define LUCI_INTERPRETER_KERNELS_GRU_H | ||
|
||
#include "core/Kernel.h" | ||
#include "core/KernelParams.h" | ||
|
||
namespace luci_interpreter | ||
{ | ||
namespace kernels | ||
{ | ||
|
||
// Kernel for the CircleGRU operation.
// Input tensor order matches the constructor/_inputs layout below:
//   0: input sequence, 1: hidden_hidden weights, 2: hidden_hidden bias,
//   3: hidden_input weights, 4: hidden_input bias, 5: initial hidden state.
class GRU : public KernelWithParams<GRUParams>
{
public:
  GRU(const Tensor *input, const Tensor *hidden_hidden, const Tensor *hidden_hidden_bias,
      const Tensor *hidden_input, const Tensor *hidden_input_bias, const Tensor *state,
      Tensor *output, const GRUParams &params);

  // Accessors map fixed _inputs/_outputs slots to named tensors.
  const Tensor *input() const { return _inputs[0]; }
  const Tensor *hidden_hidden() const { return _inputs[1]; }
  const Tensor *hidden_hidden_bias() const { return _inputs[2]; }
  const Tensor *hidden_input() const { return _inputs[3]; }
  const Tensor *hidden_input_bias() const { return _inputs[4]; }
  const Tensor *state() const { return _inputs[5]; }
  Tensor *output() const { return _outputs[0]; }

  void configure() override;
  void execute() const override;

private:
  // Float32 path; execute() rejects all other element types.
  void evalFloat() const;
};
|
||
} // namespace kernels | ||
} // namespace luci_interpreter | ||
|
||
#endif // LUCI_INTERPRETER_KERNELS_GRU_H
Oops, something went wrong.