Skip to content

Commit df25008

Browse files
Fixes L2Pool implementation to not average pooling region squares
See discussion here: webmachinelearning/webnn#278 PiperOrigin-RevId: 664908697
1 parent 8782d90 commit df25008

File tree

3 files changed

+9
-14
lines changed

3 files changed

+9
-14
lines changed

tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ TEST(NNAPIDelegate, L2PoolWithNoActivation) {
613613
3, 2, 10, 7, //
614614
});
615615
ASSERT_EQ(m.Invoke(), kTfLiteOk);
616-
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5}));
616+
EXPECT_THAT(m.GetOutput(), ElementsAreArray({7.0, 13.0}));
617617
}
618618

619619
class ConvolutionOpModel : public SingleOpModelWithNNAPI {

tensorflow/lite/kernels/internal/optimized/optimized_ops.h

+1-6
Original file line numberDiff line numberDiff line change
@@ -3299,8 +3299,6 @@ inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
32993299
const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
33003300
auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
33013301
Eigen::VectorXf in_square(in_mat.rows());
3302-
Eigen::VectorXf out_count(out_mat.cols());
3303-
out_count.setZero();
33043302
// Prefill the output to 0.
33053303
out_mat.setZero();
33063304
for (int b = 0; b < batches; ++b) {
@@ -3329,16 +3327,13 @@ inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
33293327
for (int pw = w_start; pw < w_end; ++pw) {
33303328
const int out_offset = pw + output_width * (ph + output_height * b);
33313329
out_mat.col(out_offset) += in_square;
3332-
out_count(out_offset)++;
33333330
}
33343331
}
33353332
}
33363333
}
33373334
}
33383335

3339-
out_count = out_count.array().inverse();
3340-
out_mat =
3341-
(out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
3336+
out_mat = out_mat.cwiseSqrt();
33423337

33433338
const int flat_size = output_shape.FlatSize();
33443339
for (int i = 0; i < flat_size; ++i) {

tensorflow/lite/kernels/pooling_test.cc

+7-7
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,7 @@ TEST(FloatPoolingOpTest, L2Pool) {
10621062
3, 2, 10, 7, //
10631063
});
10641064
ASSERT_EQ(m.Invoke(), kTfLiteOk);
1065-
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5}));
1065+
EXPECT_THAT(m.GetOutput(), ElementsAreArray({7.0, 13.0}));
10661066
}
10671067

10681068
TEST(FloatPoolingOpTest, L2PoolActivationRelu) {
@@ -1076,7 +1076,7 @@ TEST(FloatPoolingOpTest, L2PoolActivationRelu) {
10761076
-3, -2, 10, 7, //
10771077
});
10781078
ASSERT_EQ(m.Invoke(), kTfLiteOk);
1079-
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3.53553, 6.5})));
1079+
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({7.07107, 13.0})));
10801080
}
10811081

10821082
TEST(FloatPoolingOpTest, L2PoolActivationRelu1) {
@@ -1090,7 +1090,7 @@ TEST(FloatPoolingOpTest, L2PoolActivationRelu1) {
10901090
-0.3, -0.2, 10, 7, //
10911091
});
10921092
ASSERT_EQ(m.Invoke(), kTfLiteOk);
1093-
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({0.353553, 1.0})));
1093+
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({0.707107, 1.0})));
10941094
}
10951095

10961096
TEST(FloatPoolingOpTest, L2PoolActivationRelu6) {
@@ -1104,7 +1104,7 @@ TEST(FloatPoolingOpTest, L2PoolActivationRelu6) {
11041104
-0.3, -0.2, 10, 7, //
11051105
});
11061106
ASSERT_EQ(m.Invoke(), kTfLiteOk);
1107-
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({0.353553, 6.0})));
1107+
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({0.707107, 6.0})));
11081108
}
11091109

11101110
TEST(FloatPoolingOpTest, L2PoolPaddingSame) {
@@ -1117,7 +1117,7 @@ TEST(FloatPoolingOpTest, L2PoolPaddingSame) {
11171117
3, 2, 10, 7, //
11181118
});
11191119
ASSERT_EQ(m.Invoke(), kTfLiteOk);
1120-
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5}));
1120+
EXPECT_THAT(m.GetOutput(), ElementsAreArray({7.0, 13.0}));
11211121
}
11221122

11231123
TEST(FloatPoolingOpTest, L2PoolPaddingSameSlide1) {
@@ -1133,7 +1133,7 @@ TEST(FloatPoolingOpTest, L2PoolPaddingSameSlide1) {
11331133
ASSERT_EQ(m.Invoke(), kTfLiteOk);
11341134
EXPECT_THAT(m.GetOutput(),
11351135
ElementsAreArray(ArrayFloatNear(
1136-
{3.5, 6.0, 6.5, 5.70088, 2.54951, 7.2111, 8.63134, 7.0},
1136+
{7.0, 12.0, 13.0, 8.06226, 3.60555, 10.19804, 12.20656, 7.0},
11371137
/*max_abs_error=*/1e-4)));
11381138
}
11391139

@@ -1148,7 +1148,7 @@ TEST(FloatPoolingOpTest, L2PoolPaddingValidSlide1) {
11481148
3, 2, 10, 7, //
11491149
});
11501150
ASSERT_EQ(m.Invoke(), kTfLiteOk);
1151-
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.0, 6.5}));
1151+
EXPECT_THAT(m.GetOutput(), ElementsAreArray({7.0, 12.0, 13.0}));
11521152
}
11531153

11541154
#if GTEST_HAS_DEATH_TEST

0 commit comments

Comments
 (0)