@@ -24,8 +24,8 @@
     MXInferenceLinear,
     MXLinear,
     swap_linear_with_mx_inference_linear,
-    swap_linear_with_mx_linear,
 )
+from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
@@ -98,7 +98,7 @@ def test_linear_eager_vs_hp(
         elem_dtype_grad_output_override=elem_dtype[2],
         use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel,
     )
-    swap_linear_with_mx_linear(m_mx, config=config)
+    quantize_(m_mx, config)

     x_ref = torch.randn(
         *input_shape, device="cuda", dtype=torch.bfloat16
@@ -159,8 +159,8 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
     config_emulated = MXLinearConfig(block_size=32, elem_dtype=elem_dtype)
     config_real = MXLinearConfig.from_recipe_name(recipe_name)

-    swap_linear_with_mx_linear(m_emulated, config=config_emulated)
-    swap_linear_with_mx_linear(m_real, config=config_real)
+    quantize_(m_emulated, config=config_emulated)
+    quantize_(m_real, config=config_real)

     y_emulated = m_emulated(x)
     y_emulated.backward(g)
@@ -189,7 +189,7 @@ def test_activation_checkpointing():
         nn.Linear(8, 8, bias=True, device="cuda"),
     )
     config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
-    swap_linear_with_mx_linear(m, config=config)
+    quantize_(m, config=config)

     x = torch.randn(*input_shape, device="cuda").requires_grad_()
     g = torch.randn(*grad_shape, device="cuda")
@@ -252,7 +252,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
     config = MXLinearConfig.from_recipe_name(recipe_name)
     config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel

-    swap_linear_with_mx_linear(m_mx, config=config)
+    quantize_(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")

@@ -339,10 +339,12 @@ def test_filter_fn():
         nn.Linear(32, 32),
     )
     m2 = copy.deepcopy(m1)
-    filter_fn = lambda mod, fqn: fqn != "1"  # noqa: E731
+    filter_fn = lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn != "1"  # noqa: E731

     config = MXLinearConfig(block_size=32)
-    swap_linear_with_mx_linear(m1, config=config, filter_fn=filter_fn)
+    print("before", m1)
+    quantize_(m1, config=config, filter_fn=filter_fn)
+    print("after", m1)
     assert type(m1[0]) == MXLinear
     assert type(m1[1]) == torch.nn.Linear

@@ -354,7 +356,7 @@ def test_filter_fn():
 def test_training_print_str():
     m = nn.Sequential(nn.Linear(32, 32))
     config = MXLinearConfig()
-    swap_linear_with_mx_linear(m, config=config)
+    quantize_(m, config=config)
     s = str(m)
     assert "bl_sz=32" in s
     assert "kernel=emulated" in s