File tree 4 files changed +54
-0
lines changed
test/prototype/mx_formats
torchao/prototype/mx_formats
4 files changed +54
-0
lines changed Original file line number Diff line number Diff line change @@ -401,3 +401,21 @@ def test_filter_fn():
401
401
swap_linear_with_mx_inference_linear (m2 , config = config , filter_fn = filter_fn ) # noqa: E501
402
402
assert type (m2 [0 ]) == MXInferenceLinear
403
403
assert type (m2 [1 ]) == torch .nn .Linear
404
+
405
+
406
def test_training_print_str():
    """After swapping to MX training linears, str(model) should include the
    concise MX config summary (block size and gemm kernel choice)."""
    model = nn.Sequential(nn.Linear(32, 32))
    swap_linear_with_mx_linear(model, config=MXLinearConfig())
    printed = str(model)
    for expected in ("bl_sz=32", "kernel=emulated"):
        assert expected in printed
413
+
414
+
415
def test_inference_print_str():
    """After swapping to MX inference linears, str(model) should include the
    concise MX config summary (block size and gemm kernel choice)."""
    model = nn.Sequential(nn.Linear(32, 32))
    swap_linear_with_mx_inference_linear(model, config=MXLinearConfig())
    printed = str(model)
    for expected in ("bl_sz=32", "kernel=emulated"):
        assert expected in printed
Original file line number Diff line number Diff line change 12
12
13
13
from torchao .prototype .mx_formats .constants import (
14
14
DTYPE_FP4 ,
15
+ DTYPE_TO_SHORT_STR ,
15
16
SUPPORTED_ELEM_DTYPES ,
16
17
)
17
18
@@ -143,3 +144,22 @@ def from_recipe_name(
143
144
)
144
145
else :
145
146
raise AssertionError (f"unknown recipe_name { recipe_name } " )
147
+
148
+ def short_str (self ) -> str :
149
+ """
150
+ Returns a concise representation of the current config.
151
+ """
152
+ s = f"bl_sz={ self .block_size } , lp_dtype={ DTYPE_TO_SHORT_STR [self .elem_dtype ]} "
153
+ if self .elem_dtype_weight_override is not None :
154
+ s += (
155
+ f", lp_w_override={ DTYPE_TO_SHORT_STR [self .elem_dtype_weight_override ]} "
156
+ )
157
+ if self .elem_dtype_grad_output_override is not None :
158
+ s += f", lp_go_override={ DTYPE_TO_SHORT_STR [self .elem_dtype_grad_output_override ]} "
159
+ s += f", kernel={ self .gemm_kernel_choice .value } "
160
+ if self .use_fp8_dim1_cast_triton_kernel :
161
+ s += ", use_fp8_dim1_cast_triton_kernel=True"
162
+ if self .use_fp4_custom_triton_dequant_kernel :
163
+ s += ", use_fp4_custom_triton_dequant_kernel=True"
164
+ # TODO(future PR): split training from inference and add fp6 here
165
+ return s
Original file line number Diff line number Diff line change 22
22
DTYPE_FP4 ,
23
23
]
24
24
25
# Map from a supported low-precision element dtype to a short human-readable
# name (e.g. torch.float8_e4m3fn -> "f8e4m3"), used when rendering concise
# config strings such as MXLinearConfig.short_str().
DTYPE_TO_SHORT_STR = {
    torch.float8_e4m3fn: "f8e4m3",
    torch.float8_e5m2: "f8e5m2",
    DTYPE_FP6_E2M3: "f6e2m3",
    DTYPE_FP6_E3M2: "f6e3m2",
    DTYPE_FP4: "f4e2m1",
}
32
+
25
33
F8E4M3_MAX = torch .finfo (torch .float8_e4m3fn ).max # 448.0
26
34
F8E5M2_MAX = torch .finfo (torch .float8_e5m2 ).max # 57344.0
27
35
Original file line number Diff line number Diff line change @@ -213,6 +213,10 @@ def forward(self, x):
213
213
y = y + self .bias
214
214
return y
215
215
216
+ def extra_repr (self ):
217
+ s = f"{ super ().extra_repr ()} , { self .config .short_str ()} "
218
+ return s
219
+
216
220
217
221
class MXInferenceLinear (torch .nn .Linear ):
218
222
"""
@@ -255,6 +259,10 @@ def forward(self, x):
255
259
y = F .linear (x , w_hp , self .bias )
256
260
return y
257
261
262
+ def extra_repr (self ):
263
+ s = f"{ super ().extra_repr ()} , { self .config .short_str ()} "
264
+ return s
265
+
258
266
259
267
def replace_with_custom_fn_if_matches_filter (
260
268
model , replacement_fn , filter_fn , cur_fqn = ""
You can’t perform that action at this time.
0 commit comments