@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn
 
+from torchao.prototype.mx_formats.config import MXLinearConfig
 from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
 from torchao.prototype.mx_formats.mx_linear import (
     MXInferenceLinear,
@@ -59,8 +60,13 @@ def test_linear_eager(elem_dtype, bias, input_shape):
         nn.Linear(8, 6, bias=bias, device="cuda"),
     )
     m_mx = copy.deepcopy(m)
-    block_size = 2
-    swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size)
+    config = MXLinearConfig(
+        block_size=2,
+        elem_dtype=elem_dtype[0],
+        elem_dtype_weight_override=elem_dtype[1],
+        elem_dtype_grad_output_override=elem_dtype[2],
+    )
+    swap_linear_with_mx_linear(m_mx, config=config)
 
     x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
     x = copy.deepcopy(x_ref)
@@ -97,8 +103,8 @@ def test_activation_checkpointing():
         nn.Linear(4, 6, bias=True, device="cuda"),
         nn.Linear(6, 6, bias=True, device="cuda"),
     )
-    block_size = 2
-    swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_linear(m, config=config)
 
     x = torch.randn(*input_shape, device="cuda").requires_grad_()
     g = torch.randn(*grad_shape, device="cuda")
@@ -133,8 +139,8 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
     m_mx = nn.Sequential(
         nn.Linear(K, N, bias=bias, device="cuda"),
     )
-    block_size = 2
-    swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_linear(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
 
@@ -181,8 +187,8 @@ def test_inference_linear(elem_dtype, bias, input_shape):
     m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
-    block_size = 2
-    swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_inference_linear(m_mx, config=config)
 
     x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
     y_ref = m(x)
@@ -209,8 +215,8 @@ def test_inference_compile_simple(elem_dtype):
     m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
-    block_size = 2
-    swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_inference_linear(m_mx, config=config)
     m_mx = torch.compile(m_mx, fullgraph="true")
 
     x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16)
@@ -223,20 +229,6 @@ def test_inference_compile_simple(elem_dtype):
     assert sqnr >= 13.5
 
 
-def test_mx_linear_input_weight_gradient_dtypes():
-    m = nn.Sequential(nn.Linear(32, 32))
-    swap_linear_with_mx_linear(m, *SUPPORTED_ELEM_DTYPES[:3], block_size=32)
-    assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0]
-    assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1]
-    assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2]
-
-    m = nn.Sequential(nn.Linear(32, 32))
-    swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)
-    assert m[0].in_elem_dtype == torch.float8_e4m3fn
-    assert m[0].w_elem_dtype == torch.float8_e4m3fn
-    assert m[0].grad_elem_dtype == torch.float8_e4m3fn
-
-
 def test_filter_fn():
     m1 = nn.Sequential(
         nn.Linear(32, 32),
@@ -245,12 +237,11 @@ def test_filter_fn():
     m2 = copy.deepcopy(m1)
     filter_fn = lambda mod, fqn: fqn != "1"  # noqa: E731
 
-    swap_linear_with_mx_linear(
-        m1, torch.float8_e4m3fn, block_size=32, filter_fn=filter_fn
-    )
+    config = MXLinearConfig(block_size=32)
+    swap_linear_with_mx_linear(m1, config=config, filter_fn=filter_fn)
     assert type(m1[0]) == MXLinear
     assert type(m1[1]) == torch.nn.Linear
 
-    swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn)  # noqa: E501
+    swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn)  # noqa: E501
     assert type(m2[0]) == MXInferenceLinear
     assert type(m2[1]) == torch.nn.Linear
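
For reference, a minimal sketch of the config-based API this diff migrates the tests to (assumptions: a CUDA device, and the torchao prototype import paths exactly as shown above; the float8 elem_dtype is just an illustrative choice):

import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import (
    swap_linear_with_mx_inference_linear,
    swap_linear_with_mx_linear,
)

m = nn.Sequential(nn.Linear(32, 32, device="cuda"))
m_inference = copy.deepcopy(m)

# A single MXLinearConfig now carries block_size plus the element dtype(s),
# replacing the old positional (elem_dtype..., block_size) arguments.
config = MXLinearConfig(block_size=32, elem_dtype=torch.float8_e4m3fn)

# Training: swaps nn.Linear modules for MXLinear in place.
swap_linear_with_mx_linear(m, config=config)

# Inference: same config object; an optional filter_fn limits which modules swap.
swap_linear_with_mx_inference_linear(m_inference, config=config)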