Commit 39209b4

Add 16bit support for unpack, transpose and comparisons operators (#3028)
Added 16-bit support for the following operators: unpack, transpose, and comparison.
BUG=fixes #3044
1 parent: 6fc405d

10 files changed: +418 −33 lines

tensorflow/lite/micro/kernels/comparisons.cc (+14 −1)

@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -286,6 +286,19 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
          tflite::micro::GetTensorData<int8_t>(input2), output_shape,
          output_data);
      break;
+    case kTfLiteInt16:
+      requires_broadcast
+          ? reference_ops::Broadcast4DSlowGreaterWithScaling(
+                data->params, input1_shape,
+                tflite::micro::GetTensorData<int16_t>(input1), input2_shape,
+                tflite::micro::GetTensorData<int16_t>(input2), output_shape,
+                output_data)
+          : reference_ops::GreaterWithScaling(
+                data->params, input1_shape,
+                tflite::micro::GetTensorData<int16_t>(input1), input2_shape,
+                tflite::micro::GetTensorData<int16_t>(input2), output_shape,
+                output_data);
+      break;
    default:
      MicroPrintf("Type %s (%d) not supported.",
                  TfLiteTypeGetName(input1->type), input1->type);
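
For context on the reference ops used above: the "WithScaling" comparison kernels handle quantized inputs whose scale and zero point may differ, rescaling each element into a shared domain before comparing. Below is a minimal standalone sketch of the semantics in plain float math rather than the integer multiplier/shift pipeline the real kernel drives from data->params; all names here are local to the sketch, not TFLM APIs.

#include <cstdint>

// Dequantize a quantized value back to real units: r = scale * (q - zp).
inline float DequantizeSketch(int16_t q, float scale, int zero_point) {
  return scale * static_cast<float>(q - zero_point);
}

// Per-element semantics of a scaled "greater" comparison: rescale both
// sides to real units, then compare. The TFLM reference op computes the
// same ordering with integer multipliers and shifts so that no float
// math is needed on device.
bool GreaterWithScalingSketch(int16_t a, float a_scale, int a_zp,
                              int16_t b, float b_scale, int b_zp) {
  return DequantizeSketch(a, a_scale, a_zp) >
         DequantizeSketch(b, b_scale, b_zp);
}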

tensorflow/lite/micro/kernels/comparisons_test.cc (+53 −1)

@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -126,6 +126,29 @@ void TestComparisonQuantizedInt8(const TFLMRegistration& registration,
   TestComparison(registration, tensors, expected_output_data, output_data);
 }
 
+void TestComparisonQuantizedInt16(const TFLMRegistration& registration,
+                                  int* input1_dims_data, float* input1_data,
+                                  int16_t* input1_quantized, float input1_scale,
+                                  int input1_zero_point, int* input2_dims_data,
+                                  float* input2_data, int16_t* input2_quantized,
+                                  float input2_scale, int input2_zero_point,
+                                  bool* expected_output_data,
+                                  int* output_dims_data, bool* output_data) {
+  TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
+  TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+
+  TfLiteTensor tensors[tensors_size] = {
+      CreateQuantizedTensor(input1_data, input1_quantized, input1_dims,
+                            input1_scale, input1_zero_point),
+      CreateQuantizedTensor(input2_data, input2_quantized, input2_dims,
+                            input2_scale, input2_zero_point),
+      CreateTensor(output_data, output_dims),
+  };
+
+  TestComparison(registration, tensors, expected_output_data, output_data);
+}
+
 }  // namespace
 }  // namespace testing
 }  // namespace tflite
@@ -656,6 +679,35 @@ TF_LITE_MICRO_TEST(GreaterQuantizedInt8WithBroadcast) {
   }
 }
 
+TF_LITE_MICRO_TEST(GreaterQuantizedInt16WithBroadcast) {
+  const int num_shapes = 4;
+  const int max_shape_size = 5;
+  int test_shapes[num_shapes][max_shape_size] = {
+      {1, 6}, {2, 2, 3}, {3, 2, 1, 3}, {4, 1, 3, 1, 2}};
+
+  for (int i = 0; i < num_shapes; ++i) {
+    int* input1_dim = test_shapes[i];
+    int input2_dim[] = {1, 1};
+    float input1_data[] = {20, -2, -71, 8, 11, 20};
+    float input2_data[] = {8};
+
+    bool expected_data[] = {true, false, false, false, true, true};
+    int* expected_dim = input1_dim;
+
+    const float input1_scale = 0.5;
+    const int input1_zero_point = -9;
+    int16_t input1_quantized[6];
+    int16_t input2_quantized[6];
+
+    bool output_data[6];
+    tflite::testing::TestComparisonQuantizedInt16(
+        tflite::Register_GREATER(), input1_dim, input1_data, input1_quantized,
+        input1_scale, input1_zero_point, input2_dim, input2_data,
+        input2_quantized, input1_scale, input1_zero_point, expected_data,
+        expected_dim, output_data);
+  }
+}
+
 TF_LITE_MICRO_TEST(GreaterEqualQuantizedInt8WithBroadcast) {
   const int num_shapes = 4;
   const int max_shape_size = 5;
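
As a sanity check on the test data above: with input1_scale = 0.5 and input1_zero_point = -9, and assuming the standard affine mapping q = round(r / scale) + zero_point (an assumption about what CreateQuantizedTensor computes, not quoted from it), the float inputs 20 and 8 quantize to 31 and 7, so the quantized comparison 31 > 7 reproduces the expected true for 20 > 8. A standalone snippet that checks this arithmetic:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Assumed standard affine quantization: q = round(r / scale) + zero_point.
int16_t QuantizeSketch(float r, float scale, int zero_point) {
  return static_cast<int16_t>(std::lround(r / scale) + zero_point);
}

int main() {
  const float scale = 0.5f;   // input1_scale from the test above
  const int zero_point = -9;  // input1_zero_point from the test above
  // 20 -> 31 and 8 -> 7; both tensors share scale/zero point in this test,
  // so ordering is preserved under quantization.
  std::printf("%d %d\n", QuantizeSketch(20.0f, scale, zero_point),
              QuantizeSketch(8.0f, scale, zero_point));
  return 0;
}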

tensorflow/lite/micro/kernels/fully_connected.cc (+89 −17)

@@ -1,4 +1,4 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -238,25 +238,97 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteInt16: {
       switch (filter->type) {
         case kTfLiteInt8: {
-          tflite::reference_integer_ops::FullyConnected(
-              FullyConnectedParamsQuantized(data),
-              tflite::micro::GetTensorShape(input),
-              tflite::micro::GetTensorData<int16_t>(input),
-              tflite::micro::GetTensorShape(filter),
+          if (bias == nullptr || bias->type == kTfLiteInt32) {
+            data.is_per_channel
+                ? tflite::reference_integer_ops::FullyConnectedPerChannel(
+                      FullyConnectedParamsQuantized(data),
+                      data.per_channel_output_multiplier,
+                      reinterpret_cast<const int*>(
+                          data.per_channel_output_shift),
+                      tflite::micro::GetTensorShape(input),
+                      tflite::micro::GetTensorData<int16_t>(input),
+                      tflite::micro::GetTensorShape(filter),
 #ifdef USE_TFLM_COMPRESSION
-              tflite::micro::GetTensorData<int8_t>(micro_context, filter,
-                                                   weights_comp_td,
-                                                   data.weights_scratch_index),
-              tflite::micro::GetTensorShape(bias),
-              tflite::micro::GetOptionalTensorData<int64_t>(
-                  micro_context, bias, bias_comp_td, data.bias_scratch_index),
+                      tflite::micro::GetTensorData<int8_t>(
+                          micro_context, filter, weights_comp_td,
+                          data.weights_scratch_index),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int32_t>(
+                          micro_context, bias, bias_comp_td,
+                          data.bias_scratch_index),
 #else   // USE_TFLM_COMPRESSION
-              tflite::micro::GetTensorData<int8_t>(filter),
-              tflite::micro::GetTensorShape(bias),
-              tflite::micro::GetOptionalTensorData<int64_t>(bias),
+                      tflite::micro::GetTensorData<int8_t>(filter),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int32_t>(bias),
 #endif  // USE_TFLM_COMPRESSION
-              tflite::micro::GetTensorShape(output),
-              tflite::micro::GetTensorData<int16_t>(output));
+                      tflite::micro::GetTensorShape(output),
+                      tflite::micro::GetTensorData<int16_t>(output))
+                : tflite::reference_integer_ops::FullyConnected(
+                      FullyConnectedParamsQuantized(data),
+                      tflite::micro::GetTensorShape(input),
+                      tflite::micro::GetTensorData<int16_t>(input),
+                      tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorData<int8_t>(
+                          micro_context, filter, weights_comp_td,
+                          data.weights_scratch_index),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int32_t>(
+                          micro_context, bias, bias_comp_td,
+                          data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorData<int8_t>(filter),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorShape(output),
+                      tflite::micro::GetTensorData<int16_t>(output));
+          } else if (bias->type == kTfLiteInt64) {
+            data.is_per_channel
+                ? tflite::reference_integer_ops::FullyConnectedPerChannel(
+                      FullyConnectedParamsQuantized(data),
+                      data.per_channel_output_multiplier,
+                      reinterpret_cast<const int*>(
+                          data.per_channel_output_shift),
+                      tflite::micro::GetTensorShape(input),
+                      tflite::micro::GetTensorData<int16_t>(input),
+                      tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorData<int8_t>(
+                          micro_context, filter, weights_comp_td,
+                          data.weights_scratch_index),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int64_t>(
+                          micro_context, bias, bias_comp_td,
+                          data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorData<int8_t>(filter),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorShape(output),
+                      tflite::micro::GetTensorData<int16_t>(output))
+                : tflite::reference_integer_ops::FullyConnected(
+                      FullyConnectedParamsQuantized(data),
+                      tflite::micro::GetTensorShape(input),
+                      tflite::micro::GetTensorData<int16_t>(input),
+                      tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorData<int8_t>(
+                          micro_context, filter, weights_comp_td,
+                          data.weights_scratch_index),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int64_t>(
+                          micro_context, bias, bias_comp_td,
+                          data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorData<int8_t>(filter),
+                      tflite::micro::GetTensorShape(bias),
+                      tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
+                      tflite::micro::GetTensorShape(output),
+                      tflite::micro::GetTensorData<int16_t>(output));
+          }
           break;
         }
         default: {
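
Most of this diff is the same kernel call expressed four ways: the int16 path now forks on bias element type (int32 vs. int64) and on data.is_per_channel. The per-channel variant exists because each output channel carries its own requantization scale, which is why per_channel_output_multiplier and per_channel_output_shift are threaded through. A compact float-math sketch of that idea follows (illustrative only; the real reference kernels are integer-only, and these names are local to the sketch):

#include <cstdint>
#include <vector>

// Per-channel fully connected, sketched with float requantization. TFLM's
// integer kernels express per_channel_scale[o] as a fixed-point multiplier
// plus shift (per_channel_output_multiplier / per_channel_output_shift).
std::vector<int16_t> FullyConnectedPerChannelSketch(
    const std::vector<int16_t>& input,                // activations, zero point 0
    const std::vector<std::vector<int8_t>>& weights,  // [out][in]
    const std::vector<int64_t>& bias,                 // may be empty (no bias)
    const std::vector<float>& per_channel_scale) {    // one scale per output
  std::vector<int16_t> out(weights.size());
  for (size_t o = 0; o < weights.size(); ++o) {
    int64_t acc = bias.empty() ? 0 : bias[o];  // bias is optional, as in TFLM
    for (size_t i = 0; i < input.size(); ++i) {
      acc += static_cast<int64_t>(input[i]) * weights[o][i];
    }
    // Requantize with this channel's own scale, clamping to the int16 range.
    float scaled = static_cast<float>(acc) * per_channel_scale[o];
    if (scaled > 32767.0f) scaled = 32767.0f;
    if (scaled < -32768.0f) scaled = -32768.0f;
    out[o] = static_cast<int16_t>(scaled);
  }
  return out;
}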

tensorflow/lite/micro/kernels/fully_connected_common.cc (+9 −4)

@@ -1,4 +1,4 @@
-/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -95,9 +95,14 @@ TfLiteStatus CalculateOpDataFullyConnected(
           filter->quantization.params);
   const int per_channel_quantization_size = affine_quantization->scale->size;
 
-  // Currently only Int8 is supported for per channel quantization.
-  TF_LITE_ENSURE(context,
-                 input->type == kTfLiteInt8 && filter->type != kTfLiteInt4);
+  // Currently only Int8/Int16 are supported for per channel quantization.
+  TF_LITE_ENSURE(
+      context,
+      (input->type == kTfLiteInt8 && filter->type != kTfLiteInt4) ||
+          (input->type == kTfLiteInt16 && filter->type != kTfLiteInt4));
+
+  TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
+                    per_channel_quantization_size);
 
   TF_LITE_ENSURE_EQ(
       context, per_channel_quantization_size,
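
The widened TF_LITE_ENSURE above encodes one invariant: per-channel quantization is accepted for int8 or int16 inputs as long as the filter is not int4. A standalone restatement of that predicate, with booleans standing in for the TfLite type enums (sketch only):

#include <cassert>

// Mirrors: (input == int8 && filter != int4) ||
//          (input == int16 && filter != int4)
bool PerChannelSupportedSketch(bool input_is_int8, bool input_is_int16,
                               bool filter_is_int4) {
  return (input_is_int8 || input_is_int16) && !filter_is_int4;
}

int main() {
  assert(PerChannelSupportedSketch(true, false, false));    // int8 input: ok
  assert(PerChannelSupportedSketch(false, true, false));    // int16: now ok
  assert(!PerChannelSupportedSketch(false, true, true));    // int4 filter: no
  assert(!PerChannelSupportedSketch(false, false, false));  // float input: no
  return 0;
}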
