2424 MXInferenceLinear ,
2525 MXLinear ,
2626 swap_linear_with_mx_inference_linear ,
27- swap_linear_with_mx_linear ,
2827)
28+ from torchao .quantization import quantize_
2929from torchao .quantization .utils import compute_error
3030from torchao .utils import (
3131 TORCH_VERSION_AT_LEAST_2_8 ,
@@ -98,7 +98,7 @@ def test_linear_eager_vs_hp(
9898 elem_dtype_grad_output_override = elem_dtype [2 ],
9999 use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel ,
100100 )
101- swap_linear_with_mx_linear (m_mx , config = config )
101+ quantize_ (m_mx , config )
102102
103103 x_ref = torch .randn (
104104 * input_shape , device = "cuda" , dtype = torch .bfloat16
@@ -159,8 +159,8 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
159159 config_emulated = MXLinearConfig (block_size = 32 , elem_dtype = elem_dtype )
160160 config_real = MXLinearConfig .from_recipe_name (recipe_name )
161161
162- swap_linear_with_mx_linear (m_emulated , config = config_emulated )
163- swap_linear_with_mx_linear (m_real , config = config_real )
162+ quantize_ (m_emulated , config = config_emulated )
163+ quantize_ (m_real , config = config_real )
164164
165165 y_emulated = m_emulated (x )
166166 y_emulated .backward (g )
@@ -189,7 +189,7 @@ def test_activation_checkpointing():
189189 nn .Linear (8 , 8 , bias = True , device = "cuda" ),
190190 )
191191 config = MXLinearConfig (block_size = 4 , elem_dtype = elem_dtype )
192- swap_linear_with_mx_linear (m , config = config )
192+ quantize_ (m , config = config )
193193
194194 x = torch .randn (* input_shape , device = "cuda" ).requires_grad_ ()
195195 g = torch .randn (* grad_shape , device = "cuda" )
@@ -252,7 +252,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
252252 config = MXLinearConfig .from_recipe_name (recipe_name )
253253 config .use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
254254
255- swap_linear_with_mx_linear (m_mx , config = config )
255+ quantize_ (m_mx , config = config )
256256 m_mx_c = copy .deepcopy (m_mx )
257257 m_mx_c = torch .compile (m_mx_c , fullgraph = True , backend = "inductor" )
258258
@@ -339,10 +339,10 @@ def test_filter_fn():
339339 nn .Linear (32 , 32 ),
340340 )
341341 m2 = copy .deepcopy (m1 )
342- filter_fn = lambda mod , fqn : fqn != "1" # noqa: E731
342+ filter_fn = lambda mod , fqn : isinstance ( mod , torch . nn . Linear ) and fqn != "1" # noqa: E731
343343
344344 config = MXLinearConfig (block_size = 32 )
345- swap_linear_with_mx_linear (m1 , config = config , filter_fn = filter_fn )
345+ quantize_ (m1 , config = config , filter_fn = filter_fn )
346346 assert type (m1 [0 ]) == MXLinear
347347 assert type (m1 [1 ]) == torch .nn .Linear
348348
@@ -354,7 +354,7 @@ def test_filter_fn():
354354def test_training_print_str ():
355355 m = nn .Sequential (nn .Linear (32 , 32 ))
356356 config = MXLinearConfig ()
357- swap_linear_with_mx_linear (m , config = config )
357+ quantize_ (m , config = config )
358358 s = str (m )
359359 assert "bl_sz=32" in s
360360 assert "kernel=emulated" in s
0 commit comments