Skip to content

Commit

Permalink
[luci-interpreter] GRU kernel
Browse files Browse the repository at this point in the history
- Enable CircleGRU operation in luci-interpreter
- Overall implementation is borrowed from onert-micro 2.0

ONE-DCO-1.0-Signed-off-by: Chunseok Lee <[email protected]>
  • Loading branch information
chunseoklee committed Oct 25, 2024
1 parent 153edec commit 2b9d460
Show file tree
Hide file tree
Showing 7 changed files with 543 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/luci-interpreter/pal/linux/KernelsToBuild.lst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ REGISTER_KERNEL(FullyConnected)
REGISTER_KERNEL(Gather)
REGISTER_KERNEL(Gelu)
REGISTER_KERNEL(Greater)
REGISTER_KERNEL(GreaterEqual)
REGISTER_KERNEL(GRU)
REGISTER_KERNEL(HardSwish)
REGISTER_KERNEL(If)
Expand Down
176 changes: 176 additions & 0 deletions compiler/luci-interpreter/pal/linux/PALGRU.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LUCI_INTERPRETER_PAL_GRU_H
#define LUCI_INTERPRETER_PAL_GRU_H

#include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
#include "PALreference_ops.h"
namespace luci_interpreter_pal
{

// tflite's Logistic does not provide inplace Logistic kernel
// tflite's Logistic does not provide inplace Logistic kernel
void Logistic(const int flat_size, const float *input_data, float *output_data)
{
  // Saturation thresholds: above the upper cutoff the sigmoid is
  // indistinguishable from 1.0f in float precision; below the lower cutoff
  // sigmoid(x) is numerically equal to exp(x).
  constexpr float kCutoffUpper = 16.619047164916992188f;
  constexpr float kCutoffLower = -9.f;

  // Rationale for using this approximation in the reference kernel:
  // 0. It gives enough precision for float.
  // 1. It works around an issue on an embedded chipset where exp() does not
  //    return inf when overflown as IEEE 754 requires (it returned 1.701417).
  // 2. It speeds up the calculation and matches the behavior of the
  //    optimized kernels (see the definition of scalar_logistic_op<float>).
  for (int idx = 0; idx < flat_size; ++idx)
  {
    const float x = input_data[idx];
    if (x > kCutoffUpper)
      output_data[idx] = 1.0f;
    else if (x < kCutoffLower)
      output_data[idx] = std::exp(x);
    else
      output_data[idx] = 1.f / (1.f + std::exp(-x));
  }
}

// Computes a single GRU time step, updating output_data (the current hidden
// state) in place.
//
// output_input_data / output_hidden_data are caller-provided scratch buffers
// that receive the two fully-connected results; each must hold
// output_shape_fc.FlatSize() floats (three stacked gate slices).
// intermediate_buffer is unused here — kept for interface parity with the
// onert-micro 2.0 code this was ported from.
//
// NOTE(review): the "input" weights are multiplied with the hidden state
// (output_data) and the "hidden" weights with the time-step input_data —
// the naming follows onert-micro; confirm against the CircleGRU importer.
void calculateGRU(const float *input_data, const float *weight_input_data,
                  const float *weight_hidden_data, const float *bias_input_data,
                  const float *bias_hidden_data, float *output_data,
                  const tflite::RuntimeShape &input_shape, const tflite::RuntimeShape &output_shape,
                  const tflite::RuntimeShape &weight_input_shape,
                  const tflite::RuntimeShape &weight_hidden_shape, float *output_input_data,
                  float *output_hidden_data, const tflite::RuntimeShape &output_shape_fc,
                  float *intermediate_buffer)
{
  tflite::FullyConnectedParams op_params{};
  // The FC nodes inside GRU have no fused activation, so use the full float
  // range (i.e. no clamping).
  op_params.float_activation_min = std::numeric_limits<float>::lowest();
  op_params.float_activation_max = std::numeric_limits<float>::max();

  // FC Input: hidden state (output_data) x weight_input (+ bias).
  tflite::RuntimeShape bias_input_shape{weight_input_shape.Dims(0)};
  tflite::reference_ops::FullyConnected(op_params, output_shape, output_data, weight_input_shape,
                                        weight_input_data, bias_input_shape, bias_input_data,
                                        output_shape_fc, output_input_data);

  // FC Hidden: time-step input x weight_hidden (+ bias).
  tflite::RuntimeShape bias_hidden_shape{weight_hidden_shape.Dims(0)};
  // Note: input for this FC node will be saved without intermediate buffer
  tflite::reference_ops::FullyConnected(op_params, input_shape, input_data, weight_hidden_shape,
                                        weight_hidden_data, bias_hidden_shape, bias_hidden_data,
                                        output_shape_fc, output_hidden_data);

  // Each FC produced 3 stacked gate slices of num_elements values each.
  int num_elements = output_shape_fc.Dims(1) / 3;

  float *second_hidden_part = output_hidden_data + num_elements;
  float *second_input_part = output_input_data + num_elements;

  float *third_hidden_part = second_hidden_part + num_elements;
  float *third_input_part = second_input_part + num_elements;

  // Calculate Left part: first gate g1 = sigmoid(slice0_input + slice0_hidden),
  // computed in place in output_input_data[0 .. num_elements).
  for (int i = 0; i < num_elements; ++i)
  {
    output_input_data[i] += output_hidden_data[i];
  }

  Logistic(num_elements, output_input_data, output_input_data);

  // Calculate most left mul: h *= g1, then keep (1 - g1) for the final blend.
  float *most_left_part_final = output_input_data;
  float *first_part = output_input_data;
  for (int i = 0; i < num_elements; ++i)
  {
    output_data[i] *= most_left_part_final[i];
    first_part[i] = 1.0f - first_part[i];
  }

  // Calc second part: second gate g2 = sigmoid(slice1_hidden + slice1_input),
  // computed in place in second_hidden_part.
  for (int i = 0; i < num_elements; ++i)
  {
    second_hidden_part[i] += second_input_part[i];
  }

  Logistic(num_elements, second_hidden_part, second_hidden_part);

  // Candidate pre-activation: g2 * slice2_input + slice2_hidden.
  for (int i = 0; i < num_elements; ++i)
  {
    second_hidden_part[i] *= third_input_part[i];
    second_hidden_part[i] += third_hidden_part[i];
  }

  // Saturating tanh: |x| > 19 is indistinguishable from +/-1 in float, and
  // clamping avoids tanh() cost on saturated values.
  for (int i = 0; i < num_elements; ++i)
  {
    if (second_hidden_part[i] > 19)
    {
      second_hidden_part[i] = 1;
    }
    else if (second_hidden_part[i] < -19)
    {
      second_hidden_part[i] = -1;
    }
    else
    {
      second_hidden_part[i] = std::tanh(second_hidden_part[i]);
    }
  }

  // Final blend: h = g1 * h + (1 - g1) * tanh(candidate).
  for (int i = 0; i < num_elements; ++i)
  {
    second_hidden_part[i] *= first_part[i];
    output_data[i] += second_hidden_part[i];
  }
}

// Runs the GRU over all time steps.
//
// output_data is seeded from hidden_state_data and updated in place once per
// time step; after the loop it holds only the final hidden state (per-step
// sequences are not emitted). output_input_data / output_hidden_data are
// caller-provided scratch buffers for the per-step FC results.
// intermediate_buffer_size and intermediate_buffer are unused — kept for
// interface parity with onert-micro.
//
// NOTE(review): input_data advances by input_shape.Dims(2) per step, which
// assumes a [time, batch=1, feature] layout — confirm behavior for batch > 1.
// NOTE(review): std::memcpy needs <cstring>; this header currently relies on
// a transitive include.
void GRU(const float *input_data, const float *weight_input_data, const float *weight_hidden_data,
         const float *bias_input_data, const float *bias_hidden_data,
         const float *hidden_state_data, float *output_data, float *output_input_data,
         float *output_hidden_data, const tflite::RuntimeShape &input_shape,
         const tflite::RuntimeShape &output_shape, const tflite::RuntimeShape &weight_input_shape,
         const tflite::RuntimeShape &weight_hidden_shape, const size_t intermediate_buffer_size,
         float *intermediate_buffer)
{
  // Number of time steps is the outermost input dimension.
  const int32_t time = input_shape.Dims(0);

  // Both FC nodes inside a step produce a [1, 3 * num_units] result.
  tflite::RuntimeShape output_shape_fc(2);
  output_shape_fc.SetDim(0, 1);
  output_shape_fc.SetDim(1, weight_hidden_shape.Dims(0));

  // Seed the running hidden state with the initial state tensor.
  std::memcpy(output_data, hidden_state_data, output_shape.FlatSize() * sizeof(float));

  for (int i = 0; i < time; ++i)
  {
    calculateGRU(input_data, weight_input_data, weight_hidden_data, bias_input_data,
                 bias_hidden_data, output_data, input_shape, output_shape, weight_input_shape,
                 weight_hidden_shape, output_input_data, output_hidden_data, output_shape_fc,
                 intermediate_buffer);
    // Advance to the next time step's input slice.
    input_data += input_shape.Dims(2);
  }
}

} // namespace luci_interpreter_pal

#endif // LUCI_INTERPRETER_PAL_GRU_H
7 changes: 7 additions & 0 deletions compiler/luci-interpreter/src/core/KernelParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ struct GeluParams
bool approximate;
};

// Parameters of the CircleGRU operation.
// NOTE(review): none of these fields are consulted by kernels/GRU.cpp in this
// commit — the float path ignores the fused activation and emits only the
// final state; confirm intended semantics before relying on them.
struct GRUParams
{
  // Activation fused into the op (defaults to none).
  Activation fused_act_function = Activation::NONE;
  // Presumably mirrors Keras GRU's return_sequences (emit state per step).
  bool return_sequences = false;
  // Presumably mirrors Keras GRU's time_major ([time, batch, ...] layout).
  bool time_major = false;
};

struct InstanceNormParams
{
float epsilon;
Expand Down
88 changes: 88 additions & 0 deletions compiler/luci-interpreter/src/kernels/GRU.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kernels/GRU.h"

#include "kernels/Utils.h"

#include "PALFullyConnected.h"
#include "PALGRU.h"

#include <vector>

namespace luci_interpreter
{
namespace kernels
{
// The order of the inputs vector below defines the accessor indices declared
// in kernels/GRU.h: {input, hidden_hidden, hidden_hidden_bias, hidden_input,
// hidden_input_bias, state}. Keep the two in sync.
GRU::GRU(const Tensor *input, const Tensor *hidden_hidden, const Tensor *hidden_hidden_bias,
         const Tensor *hidden_input, const Tensor *hidden_input_bias, const Tensor *state,
         Tensor *output, const GRUParams &params)
  : KernelWithParams<GRUParams>(
      {input, hidden_hidden, hidden_hidden_bias, hidden_input, hidden_input_bias, state}, {output},
      params)
{
}

void GRU::configure()
{
auto hidden_hidden_shape = getTensorShape(hidden_hidden());
auto hidden_input_shape = getTensorShape(hidden_input());
LUCI_INTERPRETER_CHECK(hidden_hidden_shape.Dims(0) == hidden_input_shape.Dims(0));

const int32_t div_factor = 3;

auto output_shape = getTensorShape(output());
auto state_shape = getTensorShape(state());

output()->resize(state()->shape());

LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
}

// Dispatches to the implementation matching the input element type.
// Only FLOAT32 is supported; anything else is rejected at runtime.
void GRU::execute() const
{
  if (input()->element_type() == DataType::FLOAT32)
  {
    evalFloat();
  }
  else
  {
    throw std::runtime_error("luci-GRU Unsupported data type.");
  }
}

// Float32 evaluation path: allocates scratch for the two per-step FC results
// and delegates to the PAL GRU implementation.
//
// Fix vs. original: the scratch buffers were allocated with new[] but freed
// with scalar delete (undefined behavior), and leaked if the PAL call threw.
// std::vector gives correctly-typed, exception-safe storage.
void GRU::evalFloat() const
{
  // One scratch buffer per weight tensor, sized in floats (matches the
  // original FlatSize() * sizeof(float) byte allocations).
  std::vector<float> output_hidden_data(getTensorShape(hidden_hidden()).FlatSize());
  std::vector<float> output_input_data(getTensorShape(hidden_input()).FlatSize());

  luci_interpreter_pal::GRU(
    getTensorData<float>(input()), getTensorData<float>(hidden_input()),
    getTensorData<float>(hidden_hidden()), getTensorData<float>(hidden_input_bias()),
    getTensorData<float>(hidden_hidden_bias()), getTensorData<float>(state()),
    getTensorData<float>(output()), output_input_data.data(), output_hidden_data.data(),
    getTensorShape(input()), getTensorShape(output()), getTensorShape(hidden_input()),
    getTensorShape(hidden_hidden()), /*intermediate_buffer_size=*/0,
    /*intermediate_buffer=*/nullptr);
}

} // namespace kernels
} // namespace luci_interpreter
53 changes: 53 additions & 0 deletions compiler/luci-interpreter/src/kernels/GRU.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LUCI_INTERPRETER_KERNELS_GRU_H
#define LUCI_INTERPRETER_KERNELS_GRU_H

#include "core/Kernel.h"
#include "core/KernelParams.h"

namespace luci_interpreter
{
namespace kernels
{

// Kernel for the CircleGRU operation (float32 only — see execute()).
class GRU : public KernelWithParams<GRUParams>
{
public:
  // Tensor order here must match the inputs vector built in the constructor
  // (kernels/GRU.cpp) and the accessor indices below.
  GRU(const Tensor *input, const Tensor *hidden_hidden, const Tensor *hidden_hidden_bias,
      const Tensor *hidden_input, const Tensor *hidden_input_bias, const Tensor *state,
      Tensor *output, const GRUParams &params);

  const Tensor *input() const { return _inputs[0]; }
  // "hidden_hidden" / "hidden_input" name the two weight tensors (plus their
  // optional biases); "state" is the initial hidden state.
  const Tensor *hidden_hidden() const { return _inputs[1]; }
  const Tensor *hidden_hidden_bias() const { return _inputs[2]; }
  const Tensor *hidden_input() const { return _inputs[3]; }
  const Tensor *hidden_input_bias() const { return _inputs[4]; }
  const Tensor *state() const { return _inputs[5]; }
  Tensor *output() const { return _outputs[0]; }

  void configure() override;
  void execute() const override;

private:
  // FLOAT32 implementation — the only type execute() dispatches to.
  void evalFloat() const;
};

} // namespace kernels
} // namespace luci_interpreter

#endif // LUCI_INTERPRETER_KERNELS_GRU_H
Loading

0 comments on commit 2b9d460

Please sign in to comment.