@@ -32,7 +32,7 @@ using ConvPadVector = ConvAttributes::ConvPadVector;
  * 2. Activation
  * It takes an operator attribute 'activation', which supplies the activation info.
  *
- * Add is performed AFTER activation.
+ * Add is performed BEFORE activation.
  *
  * The implementation supports both NCHW and NHWC. It runs faster with NHWC.
  *
@@ -281,12 +281,10 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
   if (Y->Shape().Size() == 0) {
     return Status::OK();
   }
-  if (Sum) {
-    if (Sum->Shape() != Y->Shape()) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Z shape does not match output shape.",
-                             " Z: ", Sum->Shape().ToString().c_str(),
-                             " Output: ", Y->Shape().ToString().c_str());
-    }
+  if (Sum && Sum->Shape() != Y->Shape()) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Z shape does not match output shape.",
+                           " Z: ", Sum->Shape().ToString().c_str(),
+                           " Output: ", Y->Shape().ToString().c_str());
   }
 
   const int64_t input_image_size = input_shape.Size();
@@ -338,7 +336,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
   const auto* Xdata = X->Data<MLFloat16>();
   const auto* Bdata = B != nullptr ? B->Data<MLFloat16>() : nullptr;
   auto* Ydata = Y->MutableData<MLFloat16>();
-  const auto* SumData = Sum != nullptr ? Sum->Data<MLFloat16>() : nullptr;
+  const auto* sum_data = Sum != nullptr ? Sum->Data<MLFloat16>() : nullptr;
 
   BufferUniquePtr transpose_input_buffer;
   BufferUniquePtr transpose_output_buffer;
@@ -409,7 +407,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
   for (int64_t image_id = 0; image_id < N; ++image_id) {
     const auto* input_data = Xdata;
     auto* output_data = Ydata;
-    const auto* add_src = SumData;
+    const auto* add_src = sum_data;
 
     if (!channels_last_) {
       // Transpose the input from channels first (CHW) to channels last (HWC).
@@ -478,7 +476,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
           static_cast<size_t>(M),
           static_cast<size_t>(output_count),
           static_cast<size_t>(kernel_size),
-          &act);
+          (!channels_last_ && sum_data) ? nullptr : &act);
     } else {
       for (int64_t group_id = 0; group_id < group_count; ++group_id) {
         // Prepare the im2col transformation or use the input buffer directly for
@@ -554,7 +552,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
         gemm_params.C = worker_output + group_id * group_output_channels;
         gemm_params.ldc = static_cast<size_t>(M);
         gemm_params.Bias = Bdata;
-        gemm_params.OutputProcessor = &act;  // process fused activation and add
+        gemm_params.OutputProcessor = (!channels_last_ && sum_data) ? nullptr : &act;  // process fused activation and add
 
       MlasHalfGemmBatch(
           static_cast<size_t>(output_count),
@@ -574,10 +572,8 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
           Ydata,
           static_cast<size_t>(output_image_size),
           static_cast<size_t>(M));
-      if (SumData != nullptr) {
-        MLAS_ACTIVATION activation;
-        activation.ActivationKind = MlasIdentityActivation;
-        MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(activation, SumData);
+      if (sum_data != nullptr) {
+        MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(activation_, sum_data);
         proc.Process(Ydata, 0, 0, static_cast<size_t>(M),
                      static_cast<size_t>(output_image_size),
                      static_cast<size_t>(output_image_size));
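Note: taken together with the two OutputProcessor changes above, this hunk changes how the NCHW (!channels_last_) path handles a Sum input. The GEMM post-processor is disabled (nullptr) so neither the add nor the activation touches the channels-last intermediate; after the output is transposed back, a single MLAS_HALF_GEMM_ACTIVATION_PROCESSOR built from activation_ and sum_data applies the residual add followed by the real activation in one pass, replacing the old identity-activation pass that added Sum after the activation had already run. A rough per-element sketch of what that deferred pass computes, assuming ReLU as the configured activation (the names are hypothetical, not the MLAS API):

    #include <algorithm>
    #include <cstddef>

    // Sketch only: the deferred "add, then activate" pass over the final
    // NCHW output buffer; y and sum hold the same number of elements.
    void ApplyAddThenActivation(float* y, const float* sum, size_t count) {
      for (size_t i = 0; i < count; ++i) {
        float v = y[i] + sum[i];   // fused Add first...
        y[i] = std::max(v, 0.0f);  // ...then the activation (ReLU stand-in)
      }
    }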
@@ -586,8 +582,8 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
 
     Xdata += X_offset;
     Ydata += Y_offset;
-    if (SumData != nullptr) {
-      SumData += Y_offset;
+    if (sum_data != nullptr) {
+      sum_data += Y_offset;
     }
   }
 