@@ -86,84 +86,84 @@ using acc_type = typename AccumulateType<T, is_cuda>::type;
8686#define CUDA_ACC_TYPE (t, acc_t ) ACC_TYPE(t, acc_t , c10::DeviceType::CUDA)
8787#define CPU_ACC_TYPE (t, acc_t ) ACC_TYPE(t, acc_t , c10::DeviceType::CPU)
8888
89- MPS_ACC_TYPE (BFloat16, float );
90- MPS_ACC_TYPE (Half, float );
91- MPS_ACC_TYPE (Float8_e5m2, float );
92- MPS_ACC_TYPE (Float8_e4m3fn, float );
93- MPS_ACC_TYPE (Float8_e5m2fnuz, float );
94- MPS_ACC_TYPE (Float8_e4m3fnuz, float );
95- MPS_ACC_TYPE (float , float );
96- MPS_ACC_TYPE (double , float );
97- MPS_ACC_TYPE (int8_t , int64_t );
98- MPS_ACC_TYPE (uint8_t , int64_t );
99- MPS_ACC_TYPE (char , int64_t );
100- MPS_ACC_TYPE (int16_t , int64_t );
101- MPS_ACC_TYPE (int32_t , int64_t );
102- MPS_ACC_TYPE (int64_t , int64_t );
103- MPS_ACC_TYPE (bool , bool );
104- MPS_ACC_TYPE (c10::complex <Half>, c10::complex <float >);
105- MPS_ACC_TYPE (c10::complex <float >, c10::complex <float >);
106- MPS_ACC_TYPE (c10::complex <double >, c10::complex <float >);
107-
108- XPU_ACC_TYPE (BFloat16, float );
109- XPU_ACC_TYPE (Half, float );
110- XPU_ACC_TYPE (Float8_e5m2, float );
111- XPU_ACC_TYPE (Float8_e4m3fn, float );
112- XPU_ACC_TYPE (Float8_e5m2fnuz, float );
113- XPU_ACC_TYPE (Float8_e4m3fnuz, float );
114- XPU_ACC_TYPE (float , float );
115- XPU_ACC_TYPE (double , double );
116- XPU_ACC_TYPE (int8_t , int64_t );
117- XPU_ACC_TYPE (uint8_t , int64_t );
118- XPU_ACC_TYPE (char , int64_t );
119- XPU_ACC_TYPE (int16_t , int64_t );
120- XPU_ACC_TYPE (int32_t , int64_t );
121- XPU_ACC_TYPE (int64_t , int64_t );
122- XPU_ACC_TYPE (bool , bool );
123- XPU_ACC_TYPE (c10::complex <Half>, c10::complex <float >);
124- XPU_ACC_TYPE (c10::complex <float >, c10::complex <float >);
125- XPU_ACC_TYPE (c10::complex <double >, c10::complex <double >);
89+ MPS_ACC_TYPE (BFloat16, float )
90+ MPS_ACC_TYPE (Half, float )
91+ MPS_ACC_TYPE (Float8_e5m2, float )
92+ MPS_ACC_TYPE (Float8_e4m3fn, float )
93+ MPS_ACC_TYPE (Float8_e5m2fnuz, float )
94+ MPS_ACC_TYPE (Float8_e4m3fnuz, float )
95+ MPS_ACC_TYPE (float , float )
96+ MPS_ACC_TYPE (double , float )
97+ MPS_ACC_TYPE (int8_t , int64_t )
98+ MPS_ACC_TYPE (uint8_t , int64_t )
99+ MPS_ACC_TYPE (char , int64_t )
100+ MPS_ACC_TYPE (int16_t , int64_t )
101+ MPS_ACC_TYPE (int32_t , int64_t )
102+ MPS_ACC_TYPE (int64_t , int64_t )
103+ MPS_ACC_TYPE (bool , bool )
104+ MPS_ACC_TYPE (c10::complex <Half>, c10::complex <float >)
105+ MPS_ACC_TYPE (c10::complex <float >, c10::complex <float >)
106+ MPS_ACC_TYPE (c10::complex <double >, c10::complex <float >)
107+
108+ XPU_ACC_TYPE (BFloat16, float )
109+ XPU_ACC_TYPE (Half, float )
110+ XPU_ACC_TYPE (Float8_e5m2, float )
111+ XPU_ACC_TYPE (Float8_e4m3fn, float )
112+ XPU_ACC_TYPE (Float8_e5m2fnuz, float )
113+ XPU_ACC_TYPE (Float8_e4m3fnuz, float )
114+ XPU_ACC_TYPE (float , float )
115+ XPU_ACC_TYPE (double , double )
116+ XPU_ACC_TYPE (int8_t , int64_t )
117+ XPU_ACC_TYPE (uint8_t , int64_t )
118+ XPU_ACC_TYPE (char , int64_t )
119+ XPU_ACC_TYPE (int16_t , int64_t )
120+ XPU_ACC_TYPE (int32_t , int64_t )
121+ XPU_ACC_TYPE (int64_t , int64_t )
122+ XPU_ACC_TYPE (bool , bool )
123+ XPU_ACC_TYPE (c10::complex <Half>, c10::complex <float >)
124+ XPU_ACC_TYPE (c10::complex <float >, c10::complex <float >)
125+ XPU_ACC_TYPE (c10::complex <double >, c10::complex <double >)
126126
127127#if defined(__CUDACC__) || defined(__HIPCC__)
128- CUDA_ACC_TYPE (half, float );
128+ CUDA_ACC_TYPE (half, float )
129129#endif
130- CUDA_ACC_TYPE (BFloat16, float );
131- CUDA_ACC_TYPE (Half, float );
132- CUDA_ACC_TYPE (Float8_e5m2, float );
133- CUDA_ACC_TYPE (Float8_e4m3fn, float );
134- CUDA_ACC_TYPE (Float8_e5m2fnuz, float );
135- CUDA_ACC_TYPE (Float8_e4m3fnuz, float );
136- CUDA_ACC_TYPE (float , float );
137- CUDA_ACC_TYPE (double , double );
138- CUDA_ACC_TYPE (int8_t , int64_t );
139- CUDA_ACC_TYPE (uint8_t , int64_t );
140- CUDA_ACC_TYPE (char , int64_t );
141- CUDA_ACC_TYPE (int16_t , int64_t );
142- CUDA_ACC_TYPE (int32_t , int64_t );
143- CUDA_ACC_TYPE (int64_t , int64_t );
144- CUDA_ACC_TYPE (bool , bool );
145- CUDA_ACC_TYPE (c10::complex <Half>, c10::complex <float >);
146- CUDA_ACC_TYPE (c10::complex <float >, c10::complex <float >);
147- CUDA_ACC_TYPE (c10::complex <double >, c10::complex <double >);
148-
149- CPU_ACC_TYPE (BFloat16, float );
150- CPU_ACC_TYPE (Half, float );
151- CPU_ACC_TYPE (Float8_e5m2, float );
152- CPU_ACC_TYPE (Float8_e4m3fn, float );
153- CPU_ACC_TYPE (Float8_e5m2fnuz, float );
154- CPU_ACC_TYPE (Float8_e4m3fnuz, float );
155- CPU_ACC_TYPE (float , double );
156- CPU_ACC_TYPE (double , double );
157- CPU_ACC_TYPE (int8_t , int64_t );
158- CPU_ACC_TYPE (uint8_t , int64_t );
159- CPU_ACC_TYPE (char , int64_t );
160- CPU_ACC_TYPE (int16_t , int64_t );
161- CPU_ACC_TYPE (int32_t , int64_t );
162- CPU_ACC_TYPE (int64_t , int64_t );
163- CPU_ACC_TYPE (bool , bool );
164- CPU_ACC_TYPE (c10::complex <Half>, c10::complex <float >);
165- CPU_ACC_TYPE (c10::complex <float >, c10::complex <double >);
166- CPU_ACC_TYPE (c10::complex <double >, c10::complex <double >);
130+ CUDA_ACC_TYPE (BFloat16, float )
131+ CUDA_ACC_TYPE (Half, float )
132+ CUDA_ACC_TYPE (Float8_e5m2, float )
133+ CUDA_ACC_TYPE (Float8_e4m3fn, float )
134+ CUDA_ACC_TYPE (Float8_e5m2fnuz, float )
135+ CUDA_ACC_TYPE (Float8_e4m3fnuz, float )
136+ CUDA_ACC_TYPE (float , float )
137+ CUDA_ACC_TYPE (double , double )
138+ CUDA_ACC_TYPE (int8_t , int64_t )
139+ CUDA_ACC_TYPE (uint8_t , int64_t )
140+ CUDA_ACC_TYPE (char , int64_t )
141+ CUDA_ACC_TYPE (int16_t , int64_t )
142+ CUDA_ACC_TYPE (int32_t , int64_t )
143+ CUDA_ACC_TYPE (int64_t , int64_t )
144+ CUDA_ACC_TYPE (bool , bool )
145+ CUDA_ACC_TYPE (c10::complex <Half>, c10::complex <float >)
146+ CUDA_ACC_TYPE (c10::complex <float >, c10::complex <float >)
147+ CUDA_ACC_TYPE (c10::complex <double >, c10::complex <double >)
148+
149+ CPU_ACC_TYPE (BFloat16, float )
150+ CPU_ACC_TYPE (Half, float )
151+ CPU_ACC_TYPE (Float8_e5m2, float )
152+ CPU_ACC_TYPE (Float8_e4m3fn, float )
153+ CPU_ACC_TYPE (Float8_e5m2fnuz, float )
154+ CPU_ACC_TYPE (Float8_e4m3fnuz, float )
155+ CPU_ACC_TYPE (float , double )
156+ CPU_ACC_TYPE (double , double )
157+ CPU_ACC_TYPE (int8_t , int64_t )
158+ CPU_ACC_TYPE (uint8_t , int64_t )
159+ CPU_ACC_TYPE (char , int64_t )
160+ CPU_ACC_TYPE (int16_t , int64_t )
161+ CPU_ACC_TYPE (int32_t , int64_t )
162+ CPU_ACC_TYPE (int64_t , int64_t )
163+ CPU_ACC_TYPE (bool , bool )
164+ CPU_ACC_TYPE (c10::complex <Half>, c10::complex <float >)
165+ CPU_ACC_TYPE (c10::complex <float >, c10::complex <double >)
166+ CPU_ACC_TYPE (c10::complex <double >, c10::complex <double >)
167167
168168TORCH_API c10::ScalarType toAccumulateType (
169169 c10::ScalarType type,
0 commit comments