Commit 685e5b0

NhwcFusedConv: Add before Activation (#15837)
### Description
Fp16 FusedConv and NhwcFusedConv: the fused Add operator should be performed BEFORE the activation operator.

### Motivation and Context
The previous understanding of the fused conv semantics was incorrect.
1 parent 003c7d3 commit 685e5b0
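A minimal scalar sketch of the corrected epilogue ordering, using fp32 and ReLU as a stand-in for the fused activation (the names below are illustrative, not the onnxruntime API):

```cpp
#include <algorithm>
#include <cstddef>

static float Relu(float v) { return std::max(v, 0.0f); }

// Previously the kernels effectively computed y[i] = Relu(conv_out[i]) + sum[i].
// With this change the fused Add is applied first, then the activation:
void FusedConvEpilogue(float* y, const float* conv_out, const float* sum, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    y[i] = Relu(conv_out[i] + sum[i]);  // Add BEFORE activation
  }
}
```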

File tree

4 files changed: +37 -34 lines changed

onnxruntime/core/mlas/inc/mlas.h (+1 -1)

@@ -1439,7 +1439,7 @@ class MLAS_HALF_GEMM_POSTPROCESSOR {
 /**
  * @brief Half precision activation functions, with optional sum tensor.
  *        Supplied sum tensor must be the same layout as the GEMM output tensor.
- *        And the supplied sum tensor will be added to the final result.
+ *        And the supplied sum tensor will be added to the tensor before activation.
  */
 class MLAS_HALF_GEMM_ACTIVATION_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR
 {
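For context, a hedged usage sketch of the post-processor documented above, pieced together from the constructor and Process calls that appear elsewhere in this commit (fp16_conv.cc and the unit test); treat it as an illustration rather than authoritative API documentation:

```cpp
#include "mlas.h"

// With a non-null sum tensor, Process() now computes, per element,
//   C = Activation(C + SumTensor)
// i.e. the sum is added before the activation runs.
void ApplySumThenActivation(MLAS_FP16* C, const MLAS_FP16* SumTensor,
                            size_t CountM, size_t CountN, size_t ldc) {
  MLAS_ACTIVATION Activation;
  Activation.ActivationKind = MlasReluActivation;  // assumption: any supported kind works here

  MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(Activation, SumTensor);
  proc.Process(C, /*StartM=*/0, /*StartN=*/0, CountM, CountN, ldc);
}
```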

onnxruntime/core/mlas/lib/activate_fp16.cpp (+11 -11)

@@ -689,35 +689,35 @@ MlasActivationKernel(
         size_t n = CountN;
 
         while (n >= 8) {
-            MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer);
             MLAS_FLOAT16X8 AVec = MlasLoadFloat16x8(addsrc);
+            MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer);
             addsrc += 8;
-            Vector = ActivationFunction.Activate(Vector);
             Vector = MlasAddFloat16x8(Vector, AVec);
+            Vector = ActivationFunction.Activate(Vector);
             MlasStoreFloat16x8(buffer, Vector);
             buffer += 8;
             n -= 8;
         }
 
         if (n >= 4) {
-            MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer);
             MLAS_FLOAT16X4 AVec = MlasLoadFloat16x4(addsrc);
+            MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer);
             addsrc += 4;
-            Vector = ActivationFunction.Activate(Vector);
             Vector = MlasAddFloat16x4(Vector, AVec);
+            Vector = ActivationFunction.Activate(Vector);
             MlasStoreFloat16x4(buffer, Vector);
             buffer += 4;
             n -= 4;
         }
 
         if (n > 0) {
-            MLAS_FLOAT16X4 buf;
-            std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_));
             MLAS_FLOAT16X4 addbuf;
+            MLAS_FLOAT16X4 buf;
             std::memcpy(&addbuf, addsrc, n * sizeof(_mlas_fp16_));
-            MLAS_FLOAT16X4 res = ActivationFunction.Activate(buf);
-            res = MlasAddFloat16x4(res, addbuf);
-            MlasStorePartialFloat16x4(buffer, res, n);
+            std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_));
+            buf = MlasAddFloat16x4(buf, addbuf);
+            buf = ActivationFunction.Activate(buf);
+            MlasStorePartialFloat16x4(buffer, buf, n);
         }
 
         CRow += ldc;

@@ -858,8 +858,6 @@ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process(
     ) const
 {
     std::vector<float> buffer(CountM*CountN);
-    MLAS_HALF_GEMM_2FLOAT_PROCESSOR proc(this->Activation_, buffer.data(), CountN);
-    proc.Process(C, StartM, StartN, CountM, CountN, ldc);
 
     _mlas_fp16_* Output = reinterpret_cast<_mlas_fp16_*>(C);
     auto* CRow = buffer.data();

@@ -876,6 +874,8 @@ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process(
             }
             CAdd += ldc;
         }
+        MlasActivation(&this->Activation_, CRow, nullptr, 1, CountN, CountN);
+
         CvtFloat2Half(Output, CRow, CountN);
         CRow += CountN;
         Output += ldc;
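The same reordering, shown as a plain fp32 analogue of the kernel above (full blocks of 8 and 4, then a memcpy-staged partial block so the tail also adds before activating); this is an illustrative sketch, not the MLAS NEON code:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstring>

static float Relu(float v) { return std::max(v, 0.0f); }

void AddThenActivateRow(float* buffer, const float* addsrc, size_t n) {
  while (n >= 8) {
    for (int i = 0; i < 8; ++i) buffer[i] = Relu(buffer[i] + addsrc[i]);
    buffer += 8; addsrc += 8; n -= 8;
  }
  if (n >= 4) {
    for (int i = 0; i < 4; ++i) buffer[i] = Relu(buffer[i] + addsrc[i]);
    buffer += 4; addsrc += 4; n -= 4;
  }
  if (n > 0) {
    float buf[4] = {0.0f}, addbuf[4] = {0.0f};
    std::memcpy(buf, buffer, n * sizeof(float));     // partial load of the tail
    std::memcpy(addbuf, addsrc, n * sizeof(float));
    for (size_t i = 0; i < n; ++i) buf[i] = Relu(buf[i] + addbuf[i]);  // Add, then activate
    std::memcpy(buffer, buf, n * sizeof(float));     // partial store back
  }
}
```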

onnxruntime/core/providers/cpu/fp16/fp16_conv.cc (+13 -17)

@@ -32,7 +32,7 @@ using ConvPadVector = ConvAttributes::ConvPadVector;
  * 2. Activation
  * It takes an operator attribute 'activation', which supplies the activation info.
  *
- * Add is performed AFTER activation.
+ * Add is performed BEFORE activation.
  *
  * The implementation supports both NCHW and NHWC. It runs faster with NHWC.
  *

@@ -281,12 +281,10 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
   if (Y->Shape().Size() == 0) {
     return Status::OK();
   }
-  if (Sum) {
-    if (Sum->Shape() != Y->Shape()) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Z shape does not match output shape.",
-                             " Z: ", Sum->Shape().ToString().c_str(),
-                             " Output: ", Y->Shape().ToString().c_str());
-    }
+  if (Sum && Sum->Shape() != Y->Shape()) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Z shape does not match output shape.",
+                           " Z: ", Sum->Shape().ToString().c_str(),
+                           " Output: ", Y->Shape().ToString().c_str());
   }
 
   const int64_t input_image_size = input_shape.Size();

@@ -338,7 +336,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
   const auto* Xdata = X->Data<MLFloat16>();
   const auto* Bdata = B != nullptr ? B->Data<MLFloat16>() : nullptr;
   auto* Ydata = Y->MutableData<MLFloat16>();
-  const auto* SumData = Sum != nullptr ? Sum->Data<MLFloat16>() : nullptr;
+  const auto* sum_data = Sum != nullptr ? Sum->Data<MLFloat16>() : nullptr;
 
   BufferUniquePtr transpose_input_buffer;
   BufferUniquePtr transpose_output_buffer;

@@ -409,7 +407,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
   for (int64_t image_id = 0; image_id < N; ++image_id) {
     const auto* input_data = Xdata;
     auto* output_data = Ydata;
-    const auto* add_src = SumData;
+    const auto* add_src = sum_data;
 
     if (!channels_last_) {
       // Transpose the input from channels first (CHW) to channels last (HWC).

@@ -478,7 +476,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
           static_cast<size_t>(M),
           static_cast<size_t>(output_count),
           static_cast<size_t>(kernel_size),
-          &act);
+          (!channels_last_ && sum_data) ? nullptr : &act);
     } else {
       for (int64_t group_id = 0; group_id < group_count; ++group_id) {
         // Prepare the im2col transformation or use the input buffer directly for

@@ -554,7 +552,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
         gemm_params.C = worker_output + group_id * group_output_channels;
         gemm_params.ldc = static_cast<size_t>(M);
         gemm_params.Bias = Bdata;
-        gemm_params.OutputProcessor = &act;  // process fused activation and add
+        gemm_params.OutputProcessor = (!channels_last_ && sum_data) ? nullptr : &act;  // process fused activation and add
 
         MlasHalfGemmBatch(
             static_cast<size_t>(output_count),

@@ -574,10 +572,8 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
           Ydata,
           static_cast<size_t>(output_image_size),
           static_cast<size_t>(M));
-      if (SumData != nullptr) {
-        MLAS_ACTIVATION activation;
-        activation.ActivationKind = MlasIdentityActivation;
-        MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(activation, SumData);
+      if (sum_data != nullptr) {
+        MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(activation_, sum_data);
         proc.Process(Ydata, 0, 0, static_cast<size_t>(M),
                      static_cast<size_t>(output_image_size),
                      static_cast<size_t>(output_image_size));

@@ -586,8 +582,8 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
 
     Xdata += X_offset;
     Ydata += Y_offset;
-    if (SumData != nullptr) {
-      SumData += Y_offset;
+    if (sum_data != nullptr) {
+      sum_data += Y_offset;
     }
   }
 
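In the channels-first path the Sum operand apparently cannot be consumed inside the GEMM epilogue (the GEMM output has not been transposed back yet), so the OutputProcessor is disabled and the add+activation runs afterwards over the transposed output. A small self-contained check of the assumption this deferral relies on, namely that the element-wise epilogue gives the same values whether it is fused into the GEMM output processing or applied as a later pass; ReLU and the buffers below are illustrative only:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

static float Relu(float v) { return std::max(v, 0.0f); }

int main() {
  const float gemm_out[6] = {-1.5f, 0.25f, 2.0f, -0.5f, 3.0f, -2.0f};
  const float sum[6]      = { 1.0f, -1.0f, 0.5f,  2.0f, -4.0f,  1.0f};

  // Fused path: the epilogue runs element by element as the GEMM finishes.
  float fused[6];
  for (size_t i = 0; i < 6; ++i) fused[i] = Relu(gemm_out[i] + sum[i]);

  // Deferred path: keep the raw GEMM output, then run the same
  // add-before-activation pass over the whole buffer afterwards.
  float deferred[6];
  for (size_t i = 0; i < 6; ++i) deferred[i] = gemm_out[i];
  for (size_t i = 0; i < 6; ++i) deferred[i] = Relu(deferred[i] + sum[i]);

  for (size_t i = 0; i < 6; ++i) assert(fused[i] == deferred[i]);
  return 0;
}
```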

onnxruntime/test/mlas/unittest/test_fp16_activation.cpp (+12 -5)

@@ -70,6 +70,8 @@ class MlasFp16ActivationTest : public MlasTestBase {
     auto addonData = AddonBuffer.GetBuffer(M * N, true);
     MatrixGuardBuffer<float> FloatBuffer;
     auto* fpBuffer = FloatBuffer.GetBuffer(M * N, true);
+    MatrixGuardBuffer<float> FloatBuffer1;
+    auto* fpAddBuffer = FloatBuffer1.GetBuffer(M * N, true);
 
     size_t o = 3;
     for (size_t i = 0; i < M * N; i++) {

@@ -88,7 +90,6 @@ class MlasFp16ActivationTest : public MlasTestBase {
 
     MLAS_ACTIVATION Activation;
     MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(Activation, nullptr);
-    MLAS_HALF_GEMM_2FLOAT_PROCESSOR converter(Activation, fpBuffer, N);
     MLAS_HALF_GEMM_ACTIVATION_PROCESSOR addon(Activation, reinterpret_cast<const MLAS_FP16*>(addonData));
     for (auto kind : acts) {
       Activation.ActivationKind = MLAS_ACTIVATION_KIND(kind);

@@ -111,17 +112,23 @@ class MlasFp16ActivationTest : public MlasTestBase {
         testData1[i] = TestData[i].f;
         testData2[i] = TestData[i].f;
         testData3[i] = TestData[i].f;
+        fpBuffer[i] = TestData[i].f;
+        fpAddBuffer[i] = TestData[i].f + addonData[i].ToFloat();
       }
       size_t offset = 7;
       for (size_t i = _countof(TestData); i < M * N; i++) {
         offset = (offset + 19) % 23;
-        testData1[i] = (MinimumFillValue + offset) / 16.0f;
+        float f = (MinimumFillValue + offset) / 16.0f;
+        testData1[i] = f;
         testData2[i] = testData1[i];
         testData3[i] = testData1[i];
+        fpBuffer[i] = f;
+        fpAddBuffer[i] = f + addonData[i].ToFloat();
       }
 
       proc.Process(reinterpret_cast<MLAS_FP16*>(testData1), 0, 0, M, N, N);
-      converter.Process(reinterpret_cast<MLAS_FP16*>(testData2), 0, 0, M, N, N);
+      MlasActivation(&Activation, fpBuffer, nullptr, M, N, N);
+      MlasActivation(&Activation, fpAddBuffer, nullptr, M, N, N);
       addon.Process(reinterpret_cast<MLAS_FP16*>(testData3), 0, 0, M, N, N);
 
       for (size_t i = 0; i < M * N; i++) {

@@ -131,8 +138,8 @@ class MlasFp16ActivationTest : public MlasTestBase {
             << std::setw(8) << std::setfill('0') << std::hex << actual << ", expecting:"
             << std::setw(8) << std::setfill('0') << std::hex << fpBuffer[i];
 
-        float addonActual = testData3[i].ToFloat() - addonData[i].ToFloat();
-        EXPECT_TRUE(check_equal(addonActual, fpBuffer[i]))
+        float addonActual = testData3[i].ToFloat();
+        EXPECT_TRUE(check_equal(addonActual, fpAddBuffer[i]))
             << ", Vector + Activation Kind:" << (int)kind << ", i=" << i << ", value:"
             << std::setw(8) << std::setfill('0') << std::hex << actual << ", expecting:"
            << std::setw(8) << std::setfill('0') << std::hex << fpBuffer[i];
