Skip to content

Commit 3b19264

Browse files
committed
fix const init
1 parent 333fbdb commit 3b19264

File tree

1 file changed

+64
-69
lines changed

1 file changed

+64
-69
lines changed

onnxruntime/core/mlas/lib/softmax_kernel_neon_fp16.cpp

+64-69
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ Module Name:
2020
#include "softmax.h"
2121
#include "softmax_kernel_neon.h"
2222

23-
// TODO(fajin): intra-loop parallelism
2423
namespace softmax_neon {
2524

2625
template <typename T>
@@ -44,7 +43,7 @@ struct MlasExpConstants {
4443
T MaximumExponent;
4544
};
4645

47-
const MlasExpConstants<_mlas_fp16_> ExpConstantsFp16 = {
46+
constexpr MlasExpConstants<_mlas_fp16_> ExpConstantsFp16 = {
4847
0xcc55, // -25 * ln2
4948
0x498c, // 16 * ln2
5049
0xc95f, // -15.5 * ln2
@@ -64,59 +63,57 @@ const MlasExpConstants<_mlas_fp16_> ExpConstantsFp16 = {
6463
0x3C00, // 15
6564
};
6665

67-
const MlasExpConstants<float16x8_t> ExpConstantsFp16x8 = {
68-
MlasBroadcastFloat16x8(ExpConstantsFp16.LowerRange),
69-
MlasBroadcastFloat16x8(ExpConstantsFp16.UpperRange),
70-
MlasBroadcastFloat16x8(ExpConstantsFp16.LowerRangeSumExp),
71-
MlasBroadcastFloat16x8(ExpConstantsFp16.UpperRangeSumExp),
72-
MlasBroadcastFloat16x8(ExpConstantsFp16.RoundingBias),
73-
MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Reciprocal),
74-
MlasBroadcastFloat16x8(ExpConstantsFp16.Log2High),
75-
MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Mid),
76-
MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Low),
77-
MlasBroadcastFloat16x8(ExpConstantsFp16.poly_0),
78-
MlasBroadcastFloat16x8(ExpConstantsFp16.poly_1),
79-
MlasBroadcastFloat16x8(ExpConstantsFp16.poly_2),
80-
MlasBroadcastFloat16x8(ExpConstantsFp16.poly_3),
81-
MlasBroadcastFloat16x8(ExpConstantsFp16.poly_4),
82-
MlasBroadcastFloat16x8(ExpConstantsFp16.poly_56),
83-
MlasBroadcastFloat16x8(ExpConstantsFp16.MinimumExponent),
84-
MlasBroadcastFloat16x8(ExpConstantsFp16.MaximumExponent),
85-
};
86-
87-
const MlasExpConstants<float16x4_t> ExpConstantsFp16x4 = {
88-
MlasBroadcastFloat16x4(ExpConstantsFp16.LowerRange),
89-
MlasBroadcastFloat16x4(ExpConstantsFp16.UpperRange),
90-
MlasBroadcastFloat16x4(ExpConstantsFp16.LowerRangeSumExp),
91-
MlasBroadcastFloat16x4(ExpConstantsFp16.UpperRangeSumExp),
92-
MlasBroadcastFloat16x4(ExpConstantsFp16.RoundingBias),
93-
MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Reciprocal),
94-
MlasBroadcastFloat16x4(ExpConstantsFp16.Log2High),
95-
MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Mid),
96-
MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Low),
97-
MlasBroadcastFloat16x4(ExpConstantsFp16.poly_0),
98-
MlasBroadcastFloat16x4(ExpConstantsFp16.poly_1),
99-
MlasBroadcastFloat16x4(ExpConstantsFp16.poly_2),
100-
MlasBroadcastFloat16x4(ExpConstantsFp16.poly_3),
101-
MlasBroadcastFloat16x4(ExpConstantsFp16.poly_4),
102-
MlasBroadcastFloat16x4(ExpConstantsFp16.poly_56),
103-
MlasBroadcastFloat16x4(ExpConstantsFp16.MinimumExponent),
104-
MlasBroadcastFloat16x4(ExpConstantsFp16.MaximumExponent),
105-
};
106-
10766
template <typename T>
10867
MLAS_FORCEINLINE
10968
MlasExpConstants<T> Get_Exp_Constants();
11069

11170
template <>
11271
MLAS_FORCEINLINE
11372
MlasExpConstants<float16x8_t> Get_Exp_Constants<float16x8_t>() {
73+
const static MlasExpConstants<float16x8_t> ExpConstantsFp16x8 = {
74+
MlasBroadcastFloat16x8(ExpConstantsFp16.LowerRange),
75+
MlasBroadcastFloat16x8(ExpConstantsFp16.UpperRange),
76+
MlasBroadcastFloat16x8(ExpConstantsFp16.LowerRangeSumExp),
77+
MlasBroadcastFloat16x8(ExpConstantsFp16.UpperRangeSumExp),
78+
MlasBroadcastFloat16x8(ExpConstantsFp16.RoundingBias),
79+
MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Reciprocal),
80+
MlasBroadcastFloat16x8(ExpConstantsFp16.Log2High),
81+
MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Mid),
82+
MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Low),
83+
MlasBroadcastFloat16x8(ExpConstantsFp16.poly_0),
84+
MlasBroadcastFloat16x8(ExpConstantsFp16.poly_1),
85+
MlasBroadcastFloat16x8(ExpConstantsFp16.poly_2),
86+
MlasBroadcastFloat16x8(ExpConstantsFp16.poly_3),
87+
MlasBroadcastFloat16x8(ExpConstantsFp16.poly_4),
88+
MlasBroadcastFloat16x8(ExpConstantsFp16.poly_56),
89+
MlasBroadcastFloat16x8(ExpConstantsFp16.MinimumExponent),
90+
MlasBroadcastFloat16x8(ExpConstantsFp16.MaximumExponent),
91+
};
11492
return ExpConstantsFp16x8;
11593
}
11694

11795
template <>
11896
MLAS_FORCEINLINE
11997
MlasExpConstants<float16x4_t> Get_Exp_Constants<float16x4_t>() {
98+
const static MlasExpConstants<float16x4_t> ExpConstantsFp16x4 = {
99+
MlasBroadcastFloat16x4(ExpConstantsFp16.LowerRange),
100+
MlasBroadcastFloat16x4(ExpConstantsFp16.UpperRange),
101+
MlasBroadcastFloat16x4(ExpConstantsFp16.LowerRangeSumExp),
102+
MlasBroadcastFloat16x4(ExpConstantsFp16.UpperRangeSumExp),
103+
MlasBroadcastFloat16x4(ExpConstantsFp16.RoundingBias),
104+
MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Reciprocal),
105+
MlasBroadcastFloat16x4(ExpConstantsFp16.Log2High),
106+
MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Mid),
107+
MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Low),
108+
MlasBroadcastFloat16x4(ExpConstantsFp16.poly_0),
109+
MlasBroadcastFloat16x4(ExpConstantsFp16.poly_1),
110+
MlasBroadcastFloat16x4(ExpConstantsFp16.poly_2),
111+
MlasBroadcastFloat16x4(ExpConstantsFp16.poly_3),
112+
MlasBroadcastFloat16x4(ExpConstantsFp16.poly_4),
113+
MlasBroadcastFloat16x4(ExpConstantsFp16.poly_56),
114+
MlasBroadcastFloat16x4(ExpConstantsFp16.MinimumExponent),
115+
MlasBroadcastFloat16x4(ExpConstantsFp16.MaximumExponent),
116+
};
120117
return ExpConstantsFp16x4;
121118
}
122119

@@ -419,7 +416,7 @@ struct MlasTanhConstants {
419416
T beta_0;
420417
};
421418

422-
const MlasTanhConstants<_mlas_fp16_> TanhConstantsFp16 = {
419+
constexpr MlasTanhConstants<_mlas_fp16_> TanhConstantsFp16 = {
423420
0xc308, // -3.51562
424421
0x4308, // 3.51562
425422
0x0001,
@@ -432,45 +429,43 @@ const MlasTanhConstants<_mlas_fp16_> TanhConstantsFp16 = {
432429
0x1d03,
433430
};
434431

435-
const MlasTanhConstants<float16x8_t> TanhConstantsFp16x8 = {
436-
MlasBroadcastFloat16x8(TanhConstantsFp16.LowerRange),
437-
MlasBroadcastFloat16x8(TanhConstantsFp16.UpperRange),
438-
MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_7),
439-
MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_5),
440-
MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_3),
441-
MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_1),
442-
MlasBroadcastFloat16x8(TanhConstantsFp16.beta_6),
443-
MlasBroadcastFloat16x8(TanhConstantsFp16.beta_4),
444-
MlasBroadcastFloat16x8(TanhConstantsFp16.beta_2),
445-
MlasBroadcastFloat16x8(TanhConstantsFp16.beta_0),
446-
};
447-
448-
const MlasTanhConstants<float16x4_t> TanhConstantsFp16x4 = {
449-
MlasBroadcastFloat16x4(TanhConstantsFp16.LowerRange),
450-
MlasBroadcastFloat16x4(TanhConstantsFp16.UpperRange),
451-
MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_7),
452-
MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_5),
453-
MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_3),
454-
MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_1),
455-
MlasBroadcastFloat16x4(TanhConstantsFp16.beta_6),
456-
MlasBroadcastFloat16x4(TanhConstantsFp16.beta_4),
457-
MlasBroadcastFloat16x4(TanhConstantsFp16.beta_2),
458-
MlasBroadcastFloat16x4(TanhConstantsFp16.beta_0),
459-
};
460-
461432
template <typename T>
462433
MLAS_FORCEINLINE
463434
MlasTanhConstants<T> Get_Tanh_Constants();
464435

465436
template <>
466437
MLAS_FORCEINLINE
467438
MlasTanhConstants<float16x8_t> Get_Tanh_Constants<float16x8_t>() {
439+
const static MlasTanhConstants<float16x8_t> TanhConstantsFp16x8 = {
440+
MlasBroadcastFloat16x8(TanhConstantsFp16.LowerRange),
441+
MlasBroadcastFloat16x8(TanhConstantsFp16.UpperRange),
442+
MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_7),
443+
MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_5),
444+
MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_3),
445+
MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_1),
446+
MlasBroadcastFloat16x8(TanhConstantsFp16.beta_6),
447+
MlasBroadcastFloat16x8(TanhConstantsFp16.beta_4),
448+
MlasBroadcastFloat16x8(TanhConstantsFp16.beta_2),
449+
MlasBroadcastFloat16x8(TanhConstantsFp16.beta_0),
450+
};
468451
return TanhConstantsFp16x8;
469452
}
470453

471454
template <>
472455
MLAS_FORCEINLINE
473456
MlasTanhConstants<float16x4_t> Get_Tanh_Constants<float16x4_t>() {
457+
const static MlasTanhConstants<float16x4_t> TanhConstantsFp16x4 = {
458+
MlasBroadcastFloat16x4(TanhConstantsFp16.LowerRange),
459+
MlasBroadcastFloat16x4(TanhConstantsFp16.UpperRange),
460+
MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_7),
461+
MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_5),
462+
MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_3),
463+
MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_1),
464+
MlasBroadcastFloat16x4(TanhConstantsFp16.beta_6),
465+
MlasBroadcastFloat16x4(TanhConstantsFp16.beta_4),
466+
MlasBroadcastFloat16x4(TanhConstantsFp16.beta_2),
467+
MlasBroadcastFloat16x4(TanhConstantsFp16.beta_0),
468+
};
474469
return TanhConstantsFp16x4;
475470
}
476471

0 commit comments

Comments
 (0)