diff --git a/onnxruntime/core/mlas/lib/softmax_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/softmax_kernel_neon_fp16.cpp index b4e88111b3c58..dfd65d9d55fbb 100644 --- a/onnxruntime/core/mlas/lib/softmax_kernel_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/softmax_kernel_neon_fp16.cpp @@ -20,7 +20,6 @@ Module Name: #include "softmax.h" #include "softmax_kernel_neon.h" -// TODO(fajin): intra-loop parallelism namespace softmax_neon { template @@ -44,7 +43,7 @@ struct MlasExpConstants { T MaximumExponent; }; -const MlasExpConstants<_mlas_fp16_> ExpConstantsFp16 = { +constexpr MlasExpConstants<_mlas_fp16_> ExpConstantsFp16 = { 0xcc55, // -25 * ln2 0x498c, // 16 * ln2 0xc95f, // -15.5 * ln2 @@ -64,59 +63,57 @@ const MlasExpConstants<_mlas_fp16_> ExpConstantsFp16 = { 0x3C00, // 15 }; -const MlasExpConstants ExpConstantsFp16x8 = { - MlasBroadcastFloat16x8(ExpConstantsFp16.LowerRange), - MlasBroadcastFloat16x8(ExpConstantsFp16.UpperRange), - MlasBroadcastFloat16x8(ExpConstantsFp16.LowerRangeSumExp), - MlasBroadcastFloat16x8(ExpConstantsFp16.UpperRangeSumExp), - MlasBroadcastFloat16x8(ExpConstantsFp16.RoundingBias), - MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Reciprocal), - MlasBroadcastFloat16x8(ExpConstantsFp16.Log2High), - MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Mid), - MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Low), - MlasBroadcastFloat16x8(ExpConstantsFp16.poly_0), - MlasBroadcastFloat16x8(ExpConstantsFp16.poly_1), - MlasBroadcastFloat16x8(ExpConstantsFp16.poly_2), - MlasBroadcastFloat16x8(ExpConstantsFp16.poly_3), - MlasBroadcastFloat16x8(ExpConstantsFp16.poly_4), - MlasBroadcastFloat16x8(ExpConstantsFp16.poly_56), - MlasBroadcastFloat16x8(ExpConstantsFp16.MinimumExponent), - MlasBroadcastFloat16x8(ExpConstantsFp16.MaximumExponent), -}; - -const MlasExpConstants ExpConstantsFp16x4 = { - MlasBroadcastFloat16x4(ExpConstantsFp16.LowerRange), - MlasBroadcastFloat16x4(ExpConstantsFp16.UpperRange), - MlasBroadcastFloat16x4(ExpConstantsFp16.LowerRangeSumExp), - MlasBroadcastFloat16x4(ExpConstantsFp16.UpperRangeSumExp), - MlasBroadcastFloat16x4(ExpConstantsFp16.RoundingBias), - MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Reciprocal), - MlasBroadcastFloat16x4(ExpConstantsFp16.Log2High), - MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Mid), - MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Low), - MlasBroadcastFloat16x4(ExpConstantsFp16.poly_0), - MlasBroadcastFloat16x4(ExpConstantsFp16.poly_1), - MlasBroadcastFloat16x4(ExpConstantsFp16.poly_2), - MlasBroadcastFloat16x4(ExpConstantsFp16.poly_3), - MlasBroadcastFloat16x4(ExpConstantsFp16.poly_4), - MlasBroadcastFloat16x4(ExpConstantsFp16.poly_56), - MlasBroadcastFloat16x4(ExpConstantsFp16.MinimumExponent), - MlasBroadcastFloat16x4(ExpConstantsFp16.MaximumExponent), -}; - template MLAS_FORCEINLINE -MlasExpConstants Get_Exp_Constants(); +const MlasExpConstants& Get_Exp_Constants(); template <> MLAS_FORCEINLINE -MlasExpConstants Get_Exp_Constants() { +const MlasExpConstants& Get_Exp_Constants() { + const static MlasExpConstants ExpConstantsFp16x8 = { + MlasBroadcastFloat16x8(ExpConstantsFp16.LowerRange), + MlasBroadcastFloat16x8(ExpConstantsFp16.UpperRange), + MlasBroadcastFloat16x8(ExpConstantsFp16.LowerRangeSumExp), + MlasBroadcastFloat16x8(ExpConstantsFp16.UpperRangeSumExp), + MlasBroadcastFloat16x8(ExpConstantsFp16.RoundingBias), + MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Reciprocal), + MlasBroadcastFloat16x8(ExpConstantsFp16.Log2High), + MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Mid), + MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Low), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_0), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_1), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_2), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_3), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_4), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_56), + MlasBroadcastFloat16x8(ExpConstantsFp16.MinimumExponent), + MlasBroadcastFloat16x8(ExpConstantsFp16.MaximumExponent), + }; return ExpConstantsFp16x8; } template <> MLAS_FORCEINLINE -MlasExpConstants Get_Exp_Constants() { +const MlasExpConstants& Get_Exp_Constants() { + const static MlasExpConstants ExpConstantsFp16x4 = { + MlasBroadcastFloat16x4(ExpConstantsFp16.LowerRange), + MlasBroadcastFloat16x4(ExpConstantsFp16.UpperRange), + MlasBroadcastFloat16x4(ExpConstantsFp16.LowerRangeSumExp), + MlasBroadcastFloat16x4(ExpConstantsFp16.UpperRangeSumExp), + MlasBroadcastFloat16x4(ExpConstantsFp16.RoundingBias), + MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Reciprocal), + MlasBroadcastFloat16x4(ExpConstantsFp16.Log2High), + MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Mid), + MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Low), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_0), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_1), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_2), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_3), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_4), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_56), + MlasBroadcastFloat16x4(ExpConstantsFp16.MinimumExponent), + MlasBroadcastFloat16x4(ExpConstantsFp16.MaximumExponent), + }; return ExpConstantsFp16x4; } @@ -124,7 +121,7 @@ MlasExpConstants Get_Exp_Constants() { template MLAS_FORCEINLINE T Exp_Vector_Fp16(T x) { - const auto constants = Get_Exp_Constants(); + const auto& constants = Get_Exp_Constants(); auto clamped_x = MlasClampFloat16(x, constants.LowerRange, constants.UpperRange); // integral @@ -242,7 +239,7 @@ void Exp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N) { template MLAS_FORCEINLINE T SumExp_Vector_Fp16(T x, T negative_maximum) { - const auto constants = Get_Exp_Constants(); + const auto& constants = Get_Exp_Constants(); auto clamped_x = MlasMaximumFloat16(MlasAddFloat16(x, negative_maximum), constants.LowerRangeSumExp); // integral @@ -419,7 +416,7 @@ struct MlasTanhConstants { T beta_0; }; -const MlasTanhConstants<_mlas_fp16_> TanhConstantsFp16 = { +constexpr MlasTanhConstants<_mlas_fp16_> TanhConstantsFp16 = { 0xc308, // -3.51562 0x4308, // 3.51562 0x0001, @@ -432,45 +429,43 @@ const MlasTanhConstants<_mlas_fp16_> TanhConstantsFp16 = { 0x1d03, }; -const MlasTanhConstants TanhConstantsFp16x8 = { - MlasBroadcastFloat16x8(TanhConstantsFp16.LowerRange), - MlasBroadcastFloat16x8(TanhConstantsFp16.UpperRange), - MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_7), - MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_5), - MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_3), - MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_1), - MlasBroadcastFloat16x8(TanhConstantsFp16.beta_6), - MlasBroadcastFloat16x8(TanhConstantsFp16.beta_4), - MlasBroadcastFloat16x8(TanhConstantsFp16.beta_2), - MlasBroadcastFloat16x8(TanhConstantsFp16.beta_0), -}; - -const MlasTanhConstants TanhConstantsFp16x4 = { - MlasBroadcastFloat16x4(TanhConstantsFp16.LowerRange), - MlasBroadcastFloat16x4(TanhConstantsFp16.UpperRange), - MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_7), - MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_5), - MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_3), - MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_1), - MlasBroadcastFloat16x4(TanhConstantsFp16.beta_6), - MlasBroadcastFloat16x4(TanhConstantsFp16.beta_4), - MlasBroadcastFloat16x4(TanhConstantsFp16.beta_2), - MlasBroadcastFloat16x4(TanhConstantsFp16.beta_0), -}; - template MLAS_FORCEINLINE -MlasTanhConstants Get_Tanh_Constants(); +const MlasTanhConstants& Get_Tanh_Constants(); template <> MLAS_FORCEINLINE -MlasTanhConstants Get_Tanh_Constants() { +const MlasTanhConstants& Get_Tanh_Constants() { + const static MlasTanhConstants TanhConstantsFp16x8 = { + MlasBroadcastFloat16x8(TanhConstantsFp16.LowerRange), + MlasBroadcastFloat16x8(TanhConstantsFp16.UpperRange), + MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_7), + MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_5), + MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_3), + MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_1), + MlasBroadcastFloat16x8(TanhConstantsFp16.beta_6), + MlasBroadcastFloat16x8(TanhConstantsFp16.beta_4), + MlasBroadcastFloat16x8(TanhConstantsFp16.beta_2), + MlasBroadcastFloat16x8(TanhConstantsFp16.beta_0), + }; return TanhConstantsFp16x8; } template <> MLAS_FORCEINLINE -MlasTanhConstants Get_Tanh_Constants() { +const MlasTanhConstants& Get_Tanh_Constants() { + const static MlasTanhConstants TanhConstantsFp16x4 = { + MlasBroadcastFloat16x4(TanhConstantsFp16.LowerRange), + MlasBroadcastFloat16x4(TanhConstantsFp16.UpperRange), + MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_7), + MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_5), + MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_3), + MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_1), + MlasBroadcastFloat16x4(TanhConstantsFp16.beta_6), + MlasBroadcastFloat16x4(TanhConstantsFp16.beta_4), + MlasBroadcastFloat16x4(TanhConstantsFp16.beta_2), + MlasBroadcastFloat16x4(TanhConstantsFp16.beta_0), + }; return TanhConstantsFp16x4; } @@ -478,7 +473,7 @@ MlasTanhConstants Get_Tanh_Constants() { template MLAS_FORCEINLINE T Tanh_Vector_Fp16(T x) { - const auto constants = Get_Tanh_Constants(); + const auto& constants = Get_Tanh_Constants(); x = MlasClampFloat16(x, constants.LowerRange, constants.UpperRange); T x_2 = MlasMultiplyFloat16(x, x);