
Commit 95ab5ad

Support non-spatial mode in BatchNormalization (microsoft#2092)
* Initial commit
* Update
* Update
* Fix build break
* Update
* More changes
* Update type
* Exclude Nuphar for non-spatial tests
* Update
* Resolve PR comments
1 parent 2536553 commit 95ab5ad

File tree: 8 files changed, +181 −48 lines


csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs

Lines changed: 1 addition & 1 deletion
@@ -319,7 +319,7 @@ private void TestMultiThreads()
     private static Dictionary<string, string> GetSkippedModels()
     {
       var skipModels = new Dictionary<string, string>() {
-        { "mxnet_arcface", "Model not supported by CPU execution provider" } ,
+        { "mxnet_arcface", "Model is an invalid ONNX model"},
         { "tf_inception_v2", "TODO: Debug failing model, skipping for now" },
         { "fp16_inception_v1", "16-bit float not supported type in C#." },
         { "fp16_shufflenet", "16-bit float not supported type in C#." },

onnxruntime/core/providers/cpu/nn/batch_norm.cc

Lines changed: 24 additions & 11 deletions
@@ -52,7 +52,7 @@ Status BatchNorm<float>::Compute(OpKernelContext* p_op_kernel_context) const {
   const auto* mean = p_op_kernel_context->Input<Tensor>(3);
   const auto* var = p_op_kernel_context->Input<Tensor>(4);
 
-  ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var));
+  ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var, is_spatial_));
 
   const TensorShape& x_shape = X->Shape();
   Tensor* Y = p_op_kernel_context->Output(0, x_shape);
@@ -61,33 +61,46 @@ Status BatchNorm<float>::Compute(OpKernelContext* p_op_kernel_context) const {
   const size_t N = dims_vec[0];
   const size_t C = dims_vec[1];  // assume NCHW as per the spec
 
-  // calculate sample_size
+  // calculate sample_size (per individual channel)
   size_t sample_size = 1;
   for (size_t i = 2; i < dims_vec.size(); ++i) {
     sample_size *= dims_vec[i];
   }
 
-  ConstEigenVectorArrayMap<float> scale_arr(scale->template Data<float>(), C);
-  ConstEigenVectorArrayMap<float> bias_arr(B->template Data<float>(), C);
+  // calculate sample_size (including all channels)
+  size_t sample_size_incl_all_channels = sample_size * C;
+
+  ConstEigenVectorArrayMap<float> scale_arr(scale->template Data<float>(), is_spatial_ ? C : sample_size_incl_all_channels);
+  ConstEigenVectorArrayMap<float> bias_arr(B->template Data<float>(), is_spatial_ ? C : sample_size_incl_all_channels);
 
   // Regardless of training or testing, we will apply the estimated mean
   // and standard deviation to the input. For testing, they are
   // specified directly by the input, and for training, they are computed
   // by the op.
-  Eigen::Array<float, Eigen::Dynamic, 1> inv_std(C);
-  ConstEigenVectorArrayMap<float> var_arr(var->template Data<float>(), C);
+  Eigen::Array<float, Eigen::Dynamic, 1> inv_std(is_spatial_ ? C : sample_size_incl_all_channels);
+  ConstEigenVectorArrayMap<float> var_arr(var->template Data<float>(), is_spatial_ ? C : sample_size_incl_all_channels);
   inv_std = (var_arr + epsilon_).sqrt().inverse();
-  ConstEigenVectorArrayMap<float> mean_arr(mean->template Data<float>(), C);
+  ConstEigenVectorArrayMap<float> mean_arr(mean->template Data<float>(), is_spatial_ ? C : sample_size_incl_all_channels);
   // We can fuse the output computation as follows:
   // ((x - est_mean) * (inv_var) * scale + bias
   // to
   // (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
   Eigen::Array<float, Eigen::Dynamic, 1> new_scale = inv_std * scale_arr;
   Eigen::Array<float, Eigen::Dynamic, 1> new_bias = bias_arr - mean_arr * new_scale;
-  EigenArrayMap<float> Y_arr(Y->template MutableData<float>(), sample_size, N * C);
-  ConstEigenArrayMap<float> X_arr(X->template Data<float>(), sample_size, N * C);
-  for (size_t nc = 0; nc < N * C; ++nc) {
-    Y_arr.col(nc) = X_arr.col(nc) * new_scale(nc % C) + new_bias(nc % C);
+  EigenArrayMap<float> Y_arr(Y->template MutableData<float>(),
+                             is_spatial_ ? sample_size : sample_size_incl_all_channels,
+                             is_spatial_ ? N * C : N);
+  ConstEigenArrayMap<float> X_arr(X->template Data<float>(),
+                                  is_spatial_ ? sample_size : sample_size_incl_all_channels,
+                                  is_spatial_ ? N * C : N);
+  if (is_spatial_) {  // spatial == 1
+    for (size_t nc = 0; nc < N * C; ++nc) {
+      Y_arr.col(nc) = X_arr.col(nc) * new_scale(nc % C) + new_bias(nc % C);
+    }
+  } else {  // spatial == 0
+    for (size_t n = 0; n < N; ++n) {
+      Y_arr.col(n) = X_arr.col(n) * new_scale.col(0) + new_bias.col(0);
+    }
   }
 
   return Status::OK();
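
For orientation: in spatial mode the statistics have shape [C] and are shared across each channel's sample_size elements, while in non-spatial mode there is one statistic per activation, i.e. sample_size_incl_all_channels = C * D1 * ... * Dn entries shared only across the batch dimension. A minimal plain-loop sketch of the computation the fused Eigen kernel above performs (illustrative only; batch_norm_ref and its signature are not part of the commit):

#include <cmath>
#include <cstddef>

// Reference loop: y = (x - mean) / sqrt(var + epsilon) * scale + bias.
// S is the per-channel sample size (D1 * ... * Dn); x and y hold
// N * C * S floats in NC[D...] layout.
void batch_norm_ref(const float* x, float* y,
                    std::size_t N, std::size_t C, std::size_t S,
                    const float* scale, const float* bias,
                    const float* mean, const float* var,
                    float epsilon, bool is_spatial) {
  for (std::size_t n = 0; n < N; ++n) {
    for (std::size_t c = 0; c < C; ++c) {
      for (std::size_t s = 0; s < S; ++s) {
        // spatial: one statistic per channel c (C entries);
        // non-spatial: one statistic per activation (c, s) (C * S entries).
        const std::size_t stat = is_spatial ? c : c * S + s;
        const std::size_t idx = (n * C + c) * S + s;
        y[idx] = (x[idx] - mean[stat]) / std::sqrt(var[stat] + epsilon) * scale[stat] + bias[stat];
      }
    }
  }
}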

onnxruntime/core/providers/cpu/nn/batch_norm.h

Lines changed: 8 additions & 9 deletions
@@ -32,20 +32,19 @@ class BatchNorm : public OpKernel {
   explicit BatchNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) {
     auto st = op_kernel_info.GetAttr<float>("epsilon", &epsilon_);
     ORT_ENFORCE(st.IsOK(), st.ErrorMessage());
-
-    // opset 6-8
-    int64_t spatial;
-    if (op_kernel_info.GetAttr<int64_t>("spatial", &spatial).IsOK()) {
-      ORT_ENFORCE(spatial == 1, "BatchNormalization kernel for CPU provider does not support non-spatial cases");
-    }
+
+    // For opset 6-8, if spatial attribute exists, pick up the value (by default spatial == 1)
+    // From opset 9 onwards, by default, only the spatial case (spatial == 1) is defined per spec
+    is_spatial_ = op_kernel_info.GetAttrOrDefault<int64_t>("spatial", 1) == 1 ? true : false;
 
     //TODO: momentum
   }
 
   Status Compute(OpKernelContext* p_op_kernel_context) const override;
 
- protected:
-  float epsilon_;
-  //int64_t is_test_; ignored in this implementation since we're doing inferencing only.
+ protected:
+  float epsilon_;
+  bool is_spatial_;
+  //int64_t is_test_; ignored in this implementation since we're doing inferencing only.
 };
 }  // namespace onnxruntime

onnxruntime/core/providers/cpu/nn/batch_norm_helper.h

Lines changed: 63 additions & 19 deletions
@@ -14,48 +14,92 @@ class BatchNormHelper {
                                const Tensor* scale,
                                const Tensor* B,
                                const Tensor* mean,
-                               const Tensor* var) {
+                               const Tensor* var,
+                               bool is_spatial = true) {
+    const auto& x_dims = X->Shape().GetDims();
+    if (x_dims.size() < 2) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Invalid input X: The rank of input X must be at least 2. Got rank: ", x_dims.size());
+    }
+
+    int64_t num_channels = x_dims[1];
+    int num_feature_dims = static_cast<int>(X->Shape().NumDimensions() - 2);  // the first 2 are respectively N and C
+
     // defined as per spec and used for validation
-    constexpr int kNumInputScaleDimensions = 1;
-    constexpr int kNumInputBiasDimensions = 1;
-    constexpr int kNumInputMeanDimensions = 1;
-    constexpr int kNumInputVarianceDimensions = 1;
+    int kNumInputScaleDimensions = (is_spatial ? 1 : num_feature_dims + 1);
+    int kNumInputBiasDimensions = (is_spatial ? 1 : num_feature_dims + 1);
+    int kNumInputMeanDimensions = (is_spatial ? 1 : num_feature_dims + 1);
+    int kNumInputVarianceDimensions = (is_spatial ? 1 : num_feature_dims + 1);
     //constexpr int kMinCudaNumDims = 4;
     //constexpr int kMaxCudaNumDims = 5;
 
-    if (X->Shape().GetDims().empty()) {
-      return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid input X: Empty dimensions");
-    }
-
-    int64_t num_channels = X->Shape().GetDims()[1];
-
-    if (scale->Shape().NumDimensions() != kNumInputScaleDimensions) {
+    // validate 'scale' shape
+    const auto& scale_dims = scale->Shape().GetDims();
+    if (static_cast<int>(scale_dims.size()) != kNumInputScaleDimensions) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions);
     }
-    if (scale->Shape().GetDims()[0] != num_channels) {
+    if (scale_dims[0] != num_channels) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: 0th dimension != ", num_channels);
     }
+    // in non-spatial cases, the other dims of 'scale' must be validated
+    if (!is_spatial) {
+      for (int feature = 0; feature < num_feature_dims; ++feature) {
+        if (scale_dims[1 + feature] != x_dims[2 + feature]) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
+        }
+      }
+    }
 
-    if (B->Shape().NumDimensions() != kNumInputBiasDimensions) {
+    // validate 'B' shape
+    const auto& B_dims = B->Shape().GetDims();
+    if (static_cast<int>(B_dims.size()) != kNumInputBiasDimensions) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions);
     }
-    if (B->Shape().GetDims()[0] != num_channels) {
+    if (B_dims[0] != num_channels) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: 0th dimension != ", num_channels);
     }
+    // in non-spatial cases, the other dims of 'B' must be validated
+    if (!is_spatial) {
+      for (int feature = 0; feature < num_feature_dims; ++feature) {
+        if (B_dims[1 + feature] != x_dims[2 + feature]) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
+        }
+      }
+    }
 
-    if (mean->Shape().NumDimensions() != kNumInputMeanDimensions) {
+    // validate 'mean' shape
+    const auto& mean_dims = mean->Shape().GetDims();
+    if (static_cast<int>(mean_dims.size()) != kNumInputMeanDimensions) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions);
     }
-    if (mean->Shape().GetDims()[0] != num_channels) {
+    if (mean_dims[0] != num_channels) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: 0th dimension != ", num_channels);
     }
+    // in non-spatial cases, the other dims of 'mean' must be validated
+    if (!is_spatial) {
+      for (int feature = 0; feature < num_feature_dims; ++feature) {
+        if (mean_dims[1 + feature] != x_dims[2 + feature]) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
+        }
+      }
+    }
 
-    if (var->Shape().NumDimensions() != kNumInputVarianceDimensions) {
+    // validate 'var' shape
+    const auto& var_dims = var->Shape().GetDims();
+    if (static_cast<int>(var_dims.size()) != kNumInputVarianceDimensions) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions);
     }
-    if (var->Shape().GetDims()[0] != num_channels) {
+    if (var_dims[0] != num_channels) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: 0th dimension != ", num_channels);
     }
+    // in non-spatial cases, the other dims of 'var' must be validated
+    if (!is_spatial) {
+      for (int feature = 0; feature < num_feature_dims; ++feature) {
+        if (var_dims[1 + feature] != x_dims[2 + feature]) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
+        }
+      }
+    }
 
     return common::Status::OK();
   }
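
The shape contract this validation enforces: for X of shape [N, C, D1, ..., Dn], spatial mode expects scale, B, mean, and var each of shape [C], while non-spatial mode expects each of shape [C, D1, ..., Dn], with every trailing dimension matching X. A small hypothetical helper (not part of the commit) stating the rule directly:

#include <cstdint>
#include <vector>

// Illustration only: the shape that scale/B/mean/var must have for an
// input X of shape [N, C, D1, ..., Dn] under each mode.
std::vector<int64_t> ExpectedStatsShape(const std::vector<int64_t>& x_dims, bool is_spatial) {
  if (is_spatial) {
    return {x_dims[1]};  // spatial: per-channel stats, shape [C]
  }
  // non-spatial: per-activation stats, shape [C, D1, ..., Dn]
  return std::vector<int64_t>(x_dims.begin() + 1, x_dims.end());
}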

onnxruntime/core/providers/cuda/nn/batch_norm.cc

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ Status BatchNorm<T>::ComputeInternal(OpKernelContext* p_op_kernel_context) const
   const Tensor* mean = p_op_kernel_context->Input<Tensor>(3);
   const Tensor* var = p_op_kernel_context->Input<Tensor>(4);
 
-  ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var));
+  ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, scale, B, mean, var, spatial_ == 1));
 
   const TensorShape& x_shape = X->Shape();
   Tensor* Y = p_op_kernel_context->Output(0, x_shape);

onnxruntime/core/providers/cuda/nn/batch_norm.h

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ class BatchNorm final : public CudaKernel {
     }
 
     if (spatial_ == 0) {
-      cudnn_batch_norm_mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;  // TODO add test case for this when implemented in CPU as well.
+      cudnn_batch_norm_mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
     }
   }
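
Context for the dropped TODO: CUDNN_BATCHNORM_PER_ACTIVATION is cuDNN's per-activation mode (statistics of shape [1, C, D1, ..., Dn]), which corresponds to the ONNX spatial == 0 case; the comment goes away now that this commit adds the CPU implementation and test coverage it was waiting on.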

onnxruntime/test/onnx/main.cc

Lines changed: 1 addition & 2 deletions
@@ -398,7 +398,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
       {"shrink", "test case is wrong", {"onnx141"}},
       {"maxpool_with_argmax_2d_precomputed_strides", "ShapeInferenceError"},
       {"tf_inception_v2", "result mismatch"},
-      {"mxnet_arcface", "result mismatch"},
+      {"mxnet_arcface", "Model is an invalid ONNX model"},
       {"unique_not_sorted_without_axis", "Expected data for 'Y' is incorrect and in sorted order."},
       {"cumsum_1d_reverse_exclusive", "only failing linux GPU CI. Likely build error."},
       {"det_2d", "not implemented yet"},
@@ -508,7 +508,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
 #endif
 
 #ifdef USE_CUDA
-  broken_tests.insert({"mxnet_arcface", "result mismatch"});
   broken_tests.insert({"mask_rcnn_keras", "result mismatch"});
   broken_tests.insert({"mlperf_ssd_mobilenet_300", "unknown error"});
   broken_tests.insert({"mlperf_ssd_resnet34_1200", "unknown error"});

onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc

Lines changed: 82 additions & 4 deletions
@@ -23,19 +23,27 @@ void TestBatchNorm(const InputDataMap& input_data_map,
                    const vector<int64_t>& expected_output_shape,
                    int64_t spatial_mode = 1,
                    OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
-                   const std::string& err_str = "") {
-  OpTester test("BatchNormalization");
+                   const std::string& err_str = "",
+                   int opset_version = 9) {
+  OpTester test("BatchNormalization", opset_version);
   if (epsilon.has_value()) {
     test.AddAttribute("epsilon", epsilon.value());
   }
-  test.AddAttribute("spatial", spatial_mode);
+  if (opset_version < 9) {  // spatial is only defined for opset-8 and below in the spec
+    test.AddAttribute("spatial", spatial_mode);
+  }
   test.AddInput<float>("X", input_shapes_map.at("X"), input_data_map.at("X"));
   test.AddInput<float>("scale", input_shapes_map.at("scale"), input_data_map.at("scale"));
   test.AddInput<float>("B", input_shapes_map.at("B"), input_data_map.at("B"));
   test.AddInput<float>("mean", input_shapes_map.at("mean"), input_data_map.at("mean"));
   test.AddInput<float>("var", input_shapes_map.at("var"), input_data_map.at("var"));
   test.AddOutput<float>("output", expected_output_shape, expected_output);
-  test.Run(expect_result, err_str, {kTensorrtExecutionProvider});  // Weight as input is not supported by TensorRT
+  // Weight as input is not supported by TensorRT and spatial == 0 is not supported by Nuphar
+  std::unordered_set<std::string> excluded_eps = {kTensorrtExecutionProvider};
+  if (spatial_mode == 0) {
+    excluded_eps.insert(kNGraphExecutionProvider);
+  }
+  test.Run(expect_result, err_str, excluded_eps);
 }
 
 TEST(BatchNormTest, PositiveTestCase) {
@@ -513,6 +521,76 @@ TEST(BatchNormTest, InvalidVarDim) {
                 "Invalid input var");
 }
 
+TEST(BatchNormTest, NonSpatial_Simple) {
+  vector<float> X{1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f};
+  vector<float> scale{1.f, 1.f, 1.f, 1.f};
+  vector<float> B{1.f, 0.f, 0.f, 1.f};
+  vector<float> mean{0.f, 0.f, 0.f, 0.f};
+  vector<float> var{1.f, 1.f, 1.f, 1.f};
+
+  InputDataMap input_data_map;
+  input_data_map.insert({"X", X});
+  input_data_map.insert({"scale", scale});
+  input_data_map.insert({"B", B});
+  input_data_map.insert({"mean", mean});
+  input_data_map.insert({"var", var});
+
+  InputShapesMap input_shapes_map;
+  input_shapes_map.insert({"X", {2, 2, 2}});
+  input_shapes_map.insert({"scale", {2, 2}});
+  input_shapes_map.insert({"B", {2, 2}});
+  input_shapes_map.insert({"mean", {2, 2}});
+  input_shapes_map.insert({"var", {2, 2}});
+
+  vector<int64_t> expected_output_shape{2, 2, 2};
+  auto expected_output = {2.f, 2.f, 3.f, 5.f, 2.f, 2.f, 3.f, 5.f};
+  float epsilon = 0.f;
+  TestBatchNorm(input_data_map,
+                input_shapes_map,
+                epsilon,
+                expected_output,
+                expected_output_shape,
+                0,
+                OpTester::ExpectResult::kExpectSuccess,
+                "",
+                7);  // opset-7
+}
+
+TEST(BatchNormTest, NonSpatial_Complicated) {
+  vector<float> X{0.2134f, 0.32434f, 0.5644f, 0.3234f, 0.4545f, 0.3445f};
+  vector<float> scale{0.5f, 0.6f};
+  vector<float> B{0.2f, 0.1f};
+  vector<float> mean{0.034f, 0.342f};
+  vector<float> var{1.f, 1.f};
+
+  InputDataMap input_data_map;
+  input_data_map.insert({"X", X});
+  input_data_map.insert({"scale", scale});
+  input_data_map.insert({"B", B});
+  input_data_map.insert({"mean", mean});
+  input_data_map.insert({"var", var});
+
+  InputShapesMap input_shapes_map;
+  input_shapes_map.insert({"X", {3, 1, 2}});
+  input_shapes_map.insert({"scale", {1, 2}});
+  input_shapes_map.insert({"B", {1, 2}});
+  input_shapes_map.insert({"mean", {1, 2}});
+  input_shapes_map.insert({"var", {1, 2}});
+
+  vector<int64_t> expected_output_shape{3, 1, 2};
+  auto expected_output = {0.2897f, 0.089404f, 0.4652f, 0.08884f, 0.41025f, 0.1015f};
+  float epsilon = 1e-05f;
+  TestBatchNorm(input_data_map,
+                input_shapes_map,
+                epsilon,
+                expected_output,
+                expected_output_shape,
+                0,
+                OpTester::ExpectResult::kExpectSuccess,
+                "",
+                8);  // opset-8
+}
+
 // Only CUDA kernel has float 16 support
 #ifdef USE_CUDA
 TEST(BatchNormTest, BatchNorm2d_fp16) {
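
Sanity check for NonSpatial_Simple: with mean = 0, var = 1, and epsilon = 0, the kernel reduces to y = x * scale + B per (channel, feature) position, so the first sample {1, 2, 3, 4} maps to {1*1+1, 2*1+0, 3*1+0, 4*1+1} = {2, 2, 3, 5}, matching the first half of expected_output.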
