diff --git a/runtime/compute/cker/include/cker/Shape.h b/runtime/compute/cker/include/cker/Shape.h index 8e604a945d7..248e3a618eb 100644 --- a/runtime/compute/cker/include/cker/Shape.h +++ b/runtime/compute/cker/include/cker/Shape.h @@ -267,27 +267,28 @@ inline int FlatSizeSkipDim(const Shape &shape, int skip_dim) // arrays. template inline bool checkMatching(const Shape &shape, Ts... check_shapes) { - const Shape check_shapes_array[sizeof...(Ts)] = {std::forward(check_shapes)...}; - for (const auto &check_shape : check_shapes_array) - { - // Check matching of shapes except the case of that two shapes can be scalar - if (shape.DimensionsCount() > 1 || check_shape.DimensionsCount() > 1 || shape.FlatSize() != 1 || - check_shape.FlatSize() != 1) + auto match = [&shape](const Shape &s) -> bool { + // Check matching of shapes except the case that both shapes are scalars. + if (shape.DimensionsCount() > 1 || s.DimensionsCount() > 1 || shape.FlatSize() != 1 || + s.FlatSize() != 1) { - if (shape.DimensionsCount() != check_shape.DimensionsCount()) + if (shape.DimensionsCount() != s.DimensionsCount()) { return false; } for (int i = 0; i < shape.DimensionsCount(); ++i) { - if (shape.Dims(i) != check_shape.Dims(i)) + if (shape.Dims(i) != s.Dims(i)) { return false; } } } - } - return true; + return true; + }; + + // Apply the lambda to each check shape and combine with && + return (match(check_shapes) && ...); } struct UNUSED_ALL diff --git a/runtime/compute/ruy/include/ruy/Shape.h b/runtime/compute/ruy/include/ruy/Shape.h index c8b3dc69966..5ddd53ec39e 100644 --- a/runtime/compute/ruy/include/ruy/Shape.h +++ b/runtime/compute/ruy/include/ruy/Shape.h @@ -268,27 +268,28 @@ inline int FlatSizeSkipDim(const Shape &shape, int skip_dim) // arrays. template inline bool checkMatching(const Shape &shape, Ts... check_shapes) { - const Shape check_shapes_array[sizeof...(Ts)] = {std::forward(check_shapes)...}; - for (const auto &check_shape : check_shapes_array) - { - // Check matching of shapes except the case of that two shapes can be scalar - if (shape.DimensionsCount() > 1 || check_shape.DimensionsCount() > 1 || shape.FlatSize() != 1 || - check_shape.FlatSize() != 1) + auto match = [&shape](const Shape &s) -> bool { + // Check matching of shapes except the case that both shapes are scalars. + if (shape.DimensionsCount() > 1 || s.DimensionsCount() > 1 || shape.FlatSize() != 1 || + s.FlatSize() != 1) { - if (shape.DimensionsCount() != check_shape.DimensionsCount()) + if (shape.DimensionsCount() != s.DimensionsCount()) { return false; } for (int i = 0; i < shape.DimensionsCount(); ++i) { - if (shape.Dims(i) != check_shape.Dims(i)) + if (shape.Dims(i) != s.Dims(i)) { return false; } } } - } - return true; + return true; + }; + + // Apply the lambda to each check shape and combine with && + return (match(check_shapes) && ...); } struct UNUSED_ALL