|
1 |
| -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. |
| 1 | +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. |
2 | 2 |
|
3 | 3 | Licensed under the Apache License, Version 2.0 (the "License");
|
4 | 4 | you may not use this file except in compliance with the License.
|
@@ -238,25 +238,97 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
|
238 | 238 | case kTfLiteInt16: {
|
239 | 239 | switch (filter->type) {
|
240 | 240 | case kTfLiteInt8: {
|
241 |
| - tflite::reference_integer_ops::FullyConnected( |
242 |
| - FullyConnectedParamsQuantized(data), |
243 |
| - tflite::micro::GetTensorShape(input), |
244 |
| - tflite::micro::GetTensorData<int16_t>(input), |
245 |
| - tflite::micro::GetTensorShape(filter), |
| 241 | + if (bias == nullptr || bias->type == kTfLiteInt32) { |
| 242 | + data.is_per_channel |
| 243 | + ? tflite::reference_integer_ops::FullyConnectedPerChannel( |
| 244 | + FullyConnectedParamsQuantized(data), |
| 245 | + data.per_channel_output_multiplier, |
| 246 | + reinterpret_cast<const int*>( |
| 247 | + data.per_channel_output_shift), |
| 248 | + tflite::micro::GetTensorShape(input), |
| 249 | + tflite::micro::GetTensorData<int16_t>(input), |
| 250 | + tflite::micro::GetTensorShape(filter), |
246 | 251 | #ifdef USE_TFLM_COMPRESSION
|
247 |
| - tflite::micro::GetTensorData<int8_t>(micro_context, filter, |
248 |
| - weights_comp_td, |
249 |
| - data.weights_scratch_index), |
250 |
| - tflite::micro::GetTensorShape(bias), |
251 |
| - tflite::micro::GetOptionalTensorData<int64_t>( |
252 |
| - micro_context, bias, bias_comp_td, data.bias_scratch_index), |
| 252 | + tflite::micro::GetTensorData<int8_t>( |
| 253 | + micro_context, filter, weights_comp_td, |
| 254 | + data.weights_scratch_index), |
| 255 | + tflite::micro::GetTensorShape(bias), |
| 256 | + tflite::micro::GetOptionalTensorData<int32_t>( |
| 257 | + micro_context, bias, bias_comp_td, |
| 258 | + data.bias_scratch_index), |
253 | 259 | #else // USE_TFLM_COMPRESSION
|
254 |
| - tflite::micro::GetTensorData<int8_t>(filter), |
255 |
| - tflite::micro::GetTensorShape(bias), |
256 |
| - tflite::micro::GetOptionalTensorData<int64_t>(bias), |
| 260 | + tflite::micro::GetTensorData<int8_t>(filter), |
| 261 | + tflite::micro::GetTensorShape(bias), |
| 262 | + tflite::micro::GetOptionalTensorData<int32_t>(bias), |
257 | 263 | #endif // USE_TFLM_COMPRESSION
|
258 |
| - tflite::micro::GetTensorShape(output), |
259 |
| - tflite::micro::GetTensorData<int16_t>(output)); |
| 264 | + tflite::micro::GetTensorShape(output), |
| 265 | + tflite::micro::GetTensorData<int16_t>(output)) |
| 266 | + : tflite::reference_integer_ops::FullyConnected( |
| 267 | + FullyConnectedParamsQuantized(data), |
| 268 | + tflite::micro::GetTensorShape(input), |
| 269 | + tflite::micro::GetTensorData<int16_t>(input), |
| 270 | + tflite::micro::GetTensorShape(filter), |
| 271 | +#ifdef USE_TFLM_COMPRESSION |
| 272 | + tflite::micro::GetTensorData<int8_t>( |
| 273 | + micro_context, filter, weights_comp_td, |
| 274 | + data.weights_scratch_index), |
| 275 | + tflite::micro::GetTensorShape(bias), |
| 276 | + tflite::micro::GetOptionalTensorData<int32_t>( |
| 277 | + micro_context, bias, bias_comp_td, |
| 278 | + data.bias_scratch_index), |
| 279 | +#else // USE_TFLM_COMPRESSION |
| 280 | + tflite::micro::GetTensorData<int8_t>(filter), |
| 281 | + tflite::micro::GetTensorShape(bias), |
| 282 | + tflite::micro::GetOptionalTensorData<int32_t>(bias), |
| 283 | +#endif // USE_TFLM_COMPRESSION |
| 284 | + tflite::micro::GetTensorShape(output), |
| 285 | + tflite::micro::GetTensorData<int16_t>(output)); |
| 286 | + } else if (bias->type == kTfLiteInt64) { |
| 287 | + data.is_per_channel |
| 288 | + ? tflite::reference_integer_ops::FullyConnectedPerChannel( |
| 289 | + FullyConnectedParamsQuantized(data), |
| 290 | + data.per_channel_output_multiplier, |
| 291 | + reinterpret_cast<const int*>( |
| 292 | + data.per_channel_output_shift), |
| 293 | + tflite::micro::GetTensorShape(input), |
| 294 | + tflite::micro::GetTensorData<int16_t>(input), |
| 295 | + tflite::micro::GetTensorShape(filter), |
| 296 | +#ifdef USE_TFLM_COMPRESSION |
| 297 | + tflite::micro::GetTensorData<int8_t>( |
| 298 | + micro_context, filter, weights_comp_td, |
| 299 | + data.weights_scratch_index), |
| 300 | + tflite::micro::GetTensorShape(bias), |
| 301 | + tflite::micro::GetOptionalTensorData<int64_t>( |
| 302 | + micro_context, bias, bias_comp_td, |
| 303 | + data.bias_scratch_index), |
| 304 | +#else // USE_TFLM_COMPRESSION |
| 305 | + tflite::micro::GetTensorData<int8_t>(filter), |
| 306 | + tflite::micro::GetTensorShape(bias), |
| 307 | + tflite::micro::GetOptionalTensorData<int64_t>(bias), |
| 308 | +#endif // USE_TFLM_COMPRESSION |
| 309 | + tflite::micro::GetTensorShape(output), |
| 310 | + tflite::micro::GetTensorData<int16_t>(output)) |
| 311 | + : tflite::reference_integer_ops::FullyConnected( |
| 312 | + FullyConnectedParamsQuantized(data), |
| 313 | + tflite::micro::GetTensorShape(input), |
| 314 | + tflite::micro::GetTensorData<int16_t>(input), |
| 315 | + tflite::micro::GetTensorShape(filter), |
| 316 | +#ifdef USE_TFLM_COMPRESSION |
| 317 | + tflite::micro::GetTensorData<int8_t>( |
| 318 | + micro_context, filter, weights_comp_td, |
| 319 | + data.weights_scratch_index), |
| 320 | + tflite::micro::GetTensorShape(bias), |
| 321 | + tflite::micro::GetOptionalTensorData<int64_t>( |
| 322 | + micro_context, bias, bias_comp_td, |
| 323 | + data.bias_scratch_index), |
| 324 | +#else // USE_TFLM_COMPRESSION |
| 325 | + tflite::micro::GetTensorData<int8_t>(filter), |
| 326 | + tflite::micro::GetTensorShape(bias), |
| 327 | + tflite::micro::GetOptionalTensorData<int64_t>(bias), |
| 328 | +#endif // USE_TFLM_COMPRESSION |
| 329 | + tflite::micro::GetTensorShape(output), |
| 330 | + tflite::micro::GetTensorData<int16_t>(output)); |
| 331 | + } |
260 | 332 | break;
|
261 | 333 | }
|
262 | 334 | default: {
|
|
0 commit comments