@@ -20,7 +20,6 @@ Module Name:
20
20
#include "softmax.h"
21
21
#include "softmax_kernel_neon.h"
22
22
23
- // TODO(fajin): intra-loop parallelism
24
23
namespace softmax_neon {
25
24
26
25
template <typename T>
@@ -44,7 +43,7 @@ struct MlasExpConstants {
44
43
T MaximumExponent;
45
44
};
46
45
47
- const MlasExpConstants<_mlas_fp16_> ExpConstantsFp16 = {
46
+ constexpr MlasExpConstants<_mlas_fp16_> ExpConstantsFp16 = {
48
47
0xcc55 , // -25 * ln2
49
48
0x498c , // 16 * ln2
50
49
0xc95f , // -15.5 * ln2
@@ -64,59 +63,57 @@ const MlasExpConstants<_mlas_fp16_> ExpConstantsFp16 = {
64
63
0x3C00 , // 15
65
64
};
66
65
67
- const MlasExpConstants<float16x8_t > ExpConstantsFp16x8 = {
68
- MlasBroadcastFloat16x8 (ExpConstantsFp16.LowerRange ),
69
- MlasBroadcastFloat16x8 (ExpConstantsFp16.UpperRange ),
70
- MlasBroadcastFloat16x8 (ExpConstantsFp16.LowerRangeSumExp ),
71
- MlasBroadcastFloat16x8 (ExpConstantsFp16.UpperRangeSumExp ),
72
- MlasBroadcastFloat16x8 (ExpConstantsFp16.RoundingBias ),
73
- MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2Reciprocal ),
74
- MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2High ),
75
- MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2Mid ),
76
- MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2Low ),
77
- MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_0 ),
78
- MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_1 ),
79
- MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_2 ),
80
- MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_3 ),
81
- MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_4 ),
82
- MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_56 ),
83
- MlasBroadcastFloat16x8 (ExpConstantsFp16.MinimumExponent ),
84
- MlasBroadcastFloat16x8 (ExpConstantsFp16.MaximumExponent ),
85
- };
86
-
87
- const MlasExpConstants<float16x4_t > ExpConstantsFp16x4 = {
88
- MlasBroadcastFloat16x4 (ExpConstantsFp16.LowerRange ),
89
- MlasBroadcastFloat16x4 (ExpConstantsFp16.UpperRange ),
90
- MlasBroadcastFloat16x4 (ExpConstantsFp16.LowerRangeSumExp ),
91
- MlasBroadcastFloat16x4 (ExpConstantsFp16.UpperRangeSumExp ),
92
- MlasBroadcastFloat16x4 (ExpConstantsFp16.RoundingBias ),
93
- MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2Reciprocal ),
94
- MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2High ),
95
- MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2Mid ),
96
- MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2Low ),
97
- MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_0 ),
98
- MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_1 ),
99
- MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_2 ),
100
- MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_3 ),
101
- MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_4 ),
102
- MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_56 ),
103
- MlasBroadcastFloat16x4 (ExpConstantsFp16.MinimumExponent ),
104
- MlasBroadcastFloat16x4 (ExpConstantsFp16.MaximumExponent ),
105
- };
106
-
107
66
template <typename T>
108
67
MLAS_FORCEINLINE
109
68
MlasExpConstants<T> Get_Exp_Constants ();
110
69
111
70
template <>
112
71
MLAS_FORCEINLINE
113
72
MlasExpConstants<float16x8_t > Get_Exp_Constants<float16x8_t >() {
73
+ const static MlasExpConstants<float16x8_t > ExpConstantsFp16x8 = {
74
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.LowerRange ),
75
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.UpperRange ),
76
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.LowerRangeSumExp ),
77
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.UpperRangeSumExp ),
78
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.RoundingBias ),
79
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2Reciprocal ),
80
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2High ),
81
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2Mid ),
82
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2Low ),
83
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_0 ),
84
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_1 ),
85
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_2 ),
86
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_3 ),
87
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_4 ),
88
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_56 ),
89
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.MinimumExponent ),
90
+ MlasBroadcastFloat16x8 (ExpConstantsFp16.MaximumExponent ),
91
+ };
114
92
return ExpConstantsFp16x8;
115
93
}
116
94
117
95
template <>
118
96
MLAS_FORCEINLINE
119
97
MlasExpConstants<float16x4_t > Get_Exp_Constants<float16x4_t >() {
98
+ const static MlasExpConstants<float16x4_t > ExpConstantsFp16x4 = {
99
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.LowerRange ),
100
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.UpperRange ),
101
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.LowerRangeSumExp ),
102
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.UpperRangeSumExp ),
103
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.RoundingBias ),
104
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2Reciprocal ),
105
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2High ),
106
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2Mid ),
107
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2Low ),
108
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_0 ),
109
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_1 ),
110
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_2 ),
111
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_3 ),
112
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_4 ),
113
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_56 ),
114
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.MinimumExponent ),
115
+ MlasBroadcastFloat16x4 (ExpConstantsFp16.MaximumExponent ),
116
+ };
120
117
return ExpConstantsFp16x4;
121
118
}
122
119
@@ -419,7 +416,7 @@ struct MlasTanhConstants {
419
416
T beta_0;
420
417
};
421
418
422
- const MlasTanhConstants<_mlas_fp16_> TanhConstantsFp16 = {
419
+ constexpr MlasTanhConstants<_mlas_fp16_> TanhConstantsFp16 = {
423
420
0xc308 , // -3.51562
424
421
0x4308 , // 3.51562
425
422
0x0001 ,
@@ -432,45 +429,43 @@ const MlasTanhConstants<_mlas_fp16_> TanhConstantsFp16 = {
432
429
0x1d03 ,
433
430
};
434
431
435
- const MlasTanhConstants<float16x8_t > TanhConstantsFp16x8 = {
436
- MlasBroadcastFloat16x8 (TanhConstantsFp16.LowerRange ),
437
- MlasBroadcastFloat16x8 (TanhConstantsFp16.UpperRange ),
438
- MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_7 ),
439
- MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_5 ),
440
- MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_3 ),
441
- MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_1 ),
442
- MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_6 ),
443
- MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_4 ),
444
- MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_2 ),
445
- MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_0 ),
446
- };
447
-
448
- const MlasTanhConstants<float16x4_t > TanhConstantsFp16x4 = {
449
- MlasBroadcastFloat16x4 (TanhConstantsFp16.LowerRange ),
450
- MlasBroadcastFloat16x4 (TanhConstantsFp16.UpperRange ),
451
- MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_7 ),
452
- MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_5 ),
453
- MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_3 ),
454
- MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_1 ),
455
- MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_6 ),
456
- MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_4 ),
457
- MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_2 ),
458
- MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_0 ),
459
- };
460
-
461
432
template <typename T>
462
433
MLAS_FORCEINLINE
463
434
MlasTanhConstants<T> Get_Tanh_Constants ();
464
435
465
436
template <>
466
437
MLAS_FORCEINLINE
467
438
MlasTanhConstants<float16x8_t > Get_Tanh_Constants<float16x8_t >() {
439
+ const static MlasTanhConstants<float16x8_t > TanhConstantsFp16x8 = {
440
+ MlasBroadcastFloat16x8 (TanhConstantsFp16.LowerRange ),
441
+ MlasBroadcastFloat16x8 (TanhConstantsFp16.UpperRange ),
442
+ MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_7 ),
443
+ MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_5 ),
444
+ MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_3 ),
445
+ MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_1 ),
446
+ MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_6 ),
447
+ MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_4 ),
448
+ MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_2 ),
449
+ MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_0 ),
450
+ };
468
451
return TanhConstantsFp16x8;
469
452
}
470
453
471
454
template <>
472
455
MLAS_FORCEINLINE
473
456
MlasTanhConstants<float16x4_t > Get_Tanh_Constants<float16x4_t >() {
457
+ const static MlasTanhConstants<float16x4_t > TanhConstantsFp16x4 = {
458
+ MlasBroadcastFloat16x4 (TanhConstantsFp16.LowerRange ),
459
+ MlasBroadcastFloat16x4 (TanhConstantsFp16.UpperRange ),
460
+ MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_7 ),
461
+ MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_5 ),
462
+ MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_3 ),
463
+ MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_1 ),
464
+ MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_6 ),
465
+ MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_4 ),
466
+ MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_2 ),
467
+ MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_0 ),
468
+ };
474
469
return TanhConstantsFp16x4;
475
470
}
476
471
0 commit comments