@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
-import itertools
 
 import pytest
 import torch
@@ -16,7 +15,12 @@
     MXLinearConfig,
     MXLinearRecipeName,
 )
-from torchao.prototype.mx_formats.constants import DTYPE_FP4, SUPPORTED_ELEM_DTYPES
+from torchao.prototype.mx_formats.constants import (
+    DTYPE_FP4,
+    DTYPE_FP6_E2M3,
+    DTYPE_FP6_E3M2,
+    SUPPORTED_ELEM_DTYPES,
+)
 from torchao.prototype.mx_formats.mx_linear import (
     MXInferenceLinear,
     MXLinear,
@@ -48,38 +52,65 @@ def run_around_tests():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize(
-    "elem_dtype", itertools.product(SUPPORTED_ELEM_DTYPES, repeat=3)
+    "elem_dtype",
+    (
+        # test each dtype
+        (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn),
+        (DTYPE_FP6_E3M2, DTYPE_FP6_E3M2, DTYPE_FP6_E3M2),
+        (DTYPE_FP6_E2M3, DTYPE_FP6_E2M3, DTYPE_FP6_E2M3),
+        (DTYPE_FP4, DTYPE_FP4, DTYPE_FP4),
+        # only test one type of mixed-dtype overrides, to save testing time
+        (torch.float8_e4m3fn, DTYPE_FP4, DTYPE_FP4),
+    ),
 )
 @pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)])
-def test_linear_eager(elem_dtype, bias, input_shape):
+@pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
+@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
+def test_linear_eager_vs_hp(
+    elem_dtype, bias, input_shape, use_fp8_dim1_cast_triton_kernel
+):
     """
     Smoke test for training linear module with mx weight, compares the following:
     * baseline: float32
     * experiment: emulated MX
     """
+    if use_fp8_dim1_cast_triton_kernel:
+        if elem_dtype != (
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fn,
+        ):
+            pytest.skip("unsupported configuration")
+        elif not is_sm_at_least_89():
+            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+
     # elem_dtype is a tuple of (input, weight, gradient) dtypes.
     grad_shape = list(input_shape)
-    grad_shape[-1] = 8
+    grad_shape[-1] = 256
 
     m = nn.Sequential(
-        nn.Linear(8, 8, bias=bias, device="cuda"),
+        nn.Linear(256, 256, bias=bias, device="cuda", dtype=torch.bfloat16),
     )
     m_mx = copy.deepcopy(m)
     config = MXLinearConfig(
         block_size=4,
         elem_dtype=elem_dtype[0],
         elem_dtype_weight_override=elem_dtype[1],
         elem_dtype_grad_output_override=elem_dtype[2],
+        use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel,
     )
     swap_linear_with_mx_linear(m_mx, config=config)
 
-    x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
+    x_ref = torch.randn(
+        *input_shape, device="cuda", dtype=torch.bfloat16
+    ).requires_grad_()
     x = copy.deepcopy(x_ref)
     g = torch.randn(*grad_shape, device="cuda")
-    with torch.autocast("cuda", dtype=torch.bfloat16):
-        y_ref = m(x_ref)
-        y_mx = m_mx(x)
+
+    y_ref = m(x_ref)
+    y_mx = m_mx(x)
+
+    assert y_mx.dtype == x.dtype
 
     y_ref.backward(g)
     y_mx.backward(g)
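
# Editorial note (not part of the diff): each dtype triple in the parametrization
# above maps onto the three element-dtype fields of MXLinearConfig that the test
# fills in, so the single mixed-dtype case exercises the input/weight/grad-output
# override paths. A minimal sketch of that mapping, using only names shown in the
# hunk (block_size=4 is simply the value this test uses):
#
#     input_dtype, weight_dtype, grad_dtype = (torch.float8_e4m3fn, DTYPE_FP4, DTYPE_FP4)
#     config = MXLinearConfig(
#         block_size=4,
#         elem_dtype=input_dtype,                      # cast dtype for the input
#         elem_dtype_weight_override=weight_dtype,     # cast dtype for the weight
#         elem_dtype_grad_output_override=grad_dtype,  # cast dtype for grad_output
#     )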
@@ -112,7 +143,6 @@ def test_linear_eager(elem_dtype, bias, input_shape):
 )
 @pytest.mark.parametrize("mkn", [(128, 256, 512), (256, 512, 128), (512, 128, 256)])
 def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
-    M, K, N = 128, 128, 128
     M, K, N = mkn
 
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda").requires_grad_()
@@ -143,9 +173,9 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
     y_sqnr = compute_error(y_real, y_emulated)
     w_sqnr = compute_error(m_real[0].weight.grad, m_emulated[0].weight.grad)
     g_sqnr = compute_error(x_copy.grad, x.grad)
-    assert y_sqnr > 100.0, f"y_sqnr {y_sqnr} too low!"
-    assert w_sqnr > 100.0, f"w_sqnr {w_sqnr} too low!"
-    assert g_sqnr > 100.0, f"g_sqnr {g_sqnr} too low!"
+    assert y_sqnr > 90.0, f"y_sqnr {y_sqnr} too low!"
+    assert w_sqnr > 90.0, f"w_sqnr {w_sqnr} too low!"
+    assert g_sqnr > 90.0, f"g_sqnr {g_sqnr} too low!"
 
 
 # TODO(future): enable compile support
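
# Editorial note (not part of the diff): the thresholds above are in dB, assuming
# compute_error reports an SQNR-style metric; loosening 100.0 to 90.0 still requires
# the emulated and real gemm paths to agree to a signal-to-noise power ratio of
# roughly 10^9. A standalone sketch of the standard SQNR definition these assertions
# rest on (an assumption, not a copy of torchao's compute_error implementation):
#
#     def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
#         signal_power = torch.mean(ref.float() ** 2)
#         noise_power = torch.mean((ref.float() - approx.float()) ** 2)
#         return 10 * torch.log10(signal_power / noise_power)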
@@ -169,6 +199,7 @@ def test_activation_checkpointing():
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize(
     "recipe_name",
     [
@@ -182,7 +213,8 @@ def test_activation_checkpointing():
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
 # autocast is on
-def test_linear_compile(recipe_name, bias):
+@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
+def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel):
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
@@ -198,20 +230,36 @@ def test_linear_compile(recipe_name, bias):
         # TODO(future PR): fix this, things are clearly broken with bias=True
         pytest.skip("this test is broken for non-emulated recipes with bias=True")
 
+    if use_fp8_dim1_cast_triton_kernel:
+        if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas", "mxfp8_cutlass"):
+            pytest.skip("unsupported configuration")
+        if not is_sm_at_least_89():
+            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+        if hp_dtype != torch.bfloat16:
+            pytest.skip("unsupported configuration")
+
+    if hp_dtype == torch.bfloat16 and recipe_name != "mxfp8_cublas":
+        # TODO(future PR): properly enable float32 + bfloat16 for every
+        # recipe, this needs a cleanup of out_dtype (needs to match in-hp-dtype, even
+        # if the underlying gemm kernel only supports bf16 output)
+        pytest.skip("unsupported configuration")
+
     M, K, N = 128, 256, 512
     input_shape = (M, K)
     grad_shape = (M, N)
     m_mx = nn.Sequential(
-        nn.Linear(K, N, bias=bias, device="cuda"),
+        nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype),
     )
     config = MXLinearConfig.from_recipe_name(recipe_name)
+    config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+
     swap_linear_with_mx_linear(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
 
-    x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
+    x_ref = torch.randn(*input_shape, device="cuda", dtype=hp_dtype).requires_grad_()
     x = copy.deepcopy(x_ref)
-    g = torch.randn(*grad_shape, device="cuda")
+    g = torch.randn(*grad_shape, device="cuda", dtype=hp_dtype)
 
     y_ref = m_mx(x_ref)
     y = m_mx_c(x)
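
# Editorial note (not part of the diff): the displayed hunk stops at y = m_mx_c(x);
# the eager-vs-compiled numerics check the docstring describes typically concludes
# outside this hunk along the lines sketched below. The tolerances are placeholders,
# not the values the test actually uses.
#
#     y_ref.backward(g)
#     y.backward(g)
#     torch.testing.assert_close(y, y_ref)
#     torch.testing.assert_close(x.grad, x_ref.grad)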
@@ -283,7 +331,7 @@ def test_inference_compile_simple(elem_dtype):
     if elem_dtype is torch.float8_e4m3fn:
         assert sqnr >= 20.0
     else:
-        assert sqnr >= 13.5
+        assert sqnr >= 11.5
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")