Commit 4d8998d
print MX config when printing MXLinear and MXInferenceLinear
Summary:

Adds the relevant MX config options to the string representation of MX linear objects, to make debugging easier. Example:

```
MXLinear(in_features=4096, out_features=4096, bias=False, bl_sz=32, lp_dtype=f8e4m3, kernel=cublas, use_fp8_dim1_cast_triton_kernel=True)
```

Test Plan: CI

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: ce0d3fc
ghstack-comment-id: 2749522655
Pull Request resolved: #1947
1 parent 9cb48b5 commit 4d8998d
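For illustration, a minimal sketch of the debugging flow this enables, mirroring the tests added below. The import paths are assumptions based on how the test file uses these helpers, and the exact `lp_dtype` shown depends on `MXLinearConfig`'s default `elem_dtype`; with the default config, the tests only pin down `bl_sz=32` and `kernel=emulated`:

```python
import torch.nn as nn

# Assumed import locations; the diff below only shows config.py and mx_linear.py
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

# Swap the float linear for an MXLinear, then print the model; the new
# extra_repr() surfaces the MX config inline with the usual Linear fields.
m = nn.Sequential(nn.Linear(32, 32))
swap_linear_with_mx_linear(m, config=MXLinearConfig())
print(m)
# e.g. MXLinear(in_features=32, out_features=32, bias=True, bl_sz=32,
#      lp_dtype=f8e4m3, kernel=emulated)
```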

File tree: 4 files changed, +54 −0 lines changed

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 18 additions & 0 deletions

```diff
@@ -401,3 +401,21 @@ def test_filter_fn():
     swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn)  # noqa: E501
     assert type(m2[0]) == MXInferenceLinear
     assert type(m2[1]) == torch.nn.Linear
+
+
+def test_training_print_str():
+    m = nn.Sequential(nn.Linear(32, 32))
+    config = MXLinearConfig()
+    swap_linear_with_mx_linear(m, config=config)
+    s = str(m)
+    assert "bl_sz=32" in s
+    assert "kernel=emulated" in s
+
+
+def test_inference_print_str():
+    m = nn.Sequential(nn.Linear(32, 32))
+    config = MXLinearConfig()
+    swap_linear_with_mx_inference_linear(m, config=config)
+    s = str(m)
+    assert "bl_sz=32" in s
+    assert "kernel=emulated" in s
```

torchao/prototype/mx_formats/config.py

Lines changed: 20 additions & 0 deletions

```diff
@@ -12,6 +12,7 @@

 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP4,
+    DTYPE_TO_SHORT_STR,
     SUPPORTED_ELEM_DTYPES,
 )

@@ -143,3 +144,22 @@ def from_recipe_name(
             )
         else:
             raise AssertionError(f"unknown recipe_name {recipe_name}")
+
+    def short_str(self) -> str:
+        """
+        Returns a concise representation of the current config.
+        """
+        s = f"bl_sz={self.block_size}, lp_dtype={DTYPE_TO_SHORT_STR[self.elem_dtype]}"
+        if self.elem_dtype_weight_override is not None:
+            s += (
+                f", lp_w_override={DTYPE_TO_SHORT_STR[self.elem_dtype_weight_override]}"
+            )
+        if self.elem_dtype_grad_output_override is not None:
+            s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
+        s += f", kernel={self.gemm_kernel_choice.value}"
+        if self.use_fp8_dim1_cast_triton_kernel:
+            s += ", use_fp8_dim1_cast_triton_kernel=True"
+        if self.use_fp4_custom_triton_dequant_kernel:
+            s += ", use_fp4_custom_triton_dequant_kernel=True"
+        # TODO(future PR): split training from inference and add fp6 here
+        return s
```
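As a quick usage sketch of `short_str()` on its own (hedged: the `lp_dtype` short name shown assumes the default `elem_dtype` is `torch.float8_e4m3fn`, which this diff does not confirm):

```python
from torchao.prototype.mx_formats.config import MXLinearConfig

config = MXLinearConfig()
# The tests in this commit pin down "bl_sz=32" and "kernel=emulated"
# for the default config; the lp_dtype field is an assumption here.
print(config.short_str())
# e.g. "bl_sz=32, lp_dtype=f8e4m3, kernel=emulated"
```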

torchao/prototype/mx_formats/constants.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -22,6 +22,14 @@
     DTYPE_FP4,
 ]

+DTYPE_TO_SHORT_STR = {
+    torch.float8_e4m3fn: "f8e4m3",
+    torch.float8_e5m2: "f8e5m2",
+    DTYPE_FP6_E2M3: "f6e2m3",
+    DTYPE_FP6_E3M2: "f6e3m2",
+    DTYPE_FP4: "f4e2m1",
+}
+
 F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
 F8E5M2_MAX = torch.finfo(torch.float8_e5m2).max  # 57344.0
```
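`DTYPE_TO_SHORT_STR` simply maps each supported element dtype to the compact token used by `short_str()`; a trivial check, taken directly from the table above:

```python
import torch

from torchao.prototype.mx_formats.constants import DTYPE_TO_SHORT_STR

# torch.float8_e4m3fn renders as "f8e4m3", matching the commit message example
assert DTYPE_TO_SHORT_STR[torch.float8_e4m3fn] == "f8e4m3"
```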

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -213,6 +213,10 @@ def forward(self, x):
         y = y + self.bias
         return y

+    def extra_repr(self):
+        s = f"{super().extra_repr()}, {self.config.short_str()}"
+        return s
+

 class MXInferenceLinear(torch.nn.Linear):
     """

@@ -255,6 +259,10 @@ def forward(self, x):
         y = F.linear(x, w_hp, self.bias)
         return y

+    def extra_repr(self):
+        s = f"{super().extra_repr()}, {self.config.short_str()}"
+        return s
+

 def replace_with_custom_fn_if_matches_filter(
     model, replacement_fn, filter_fn, cur_fqn=""
```
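These overrides work because `torch.nn.Module.__repr__` appends the output of `extra_repr()` inside the module's printed parentheses, so no custom `__repr__` is needed. A standalone sketch with a hypothetical `Demo` subclass (not part of this PR):

```python
import torch

class Demo(torch.nn.Linear):
    # nn.Linear.extra_repr() already returns "in_features=..., out_features=...,
    # bias=..."; appending to it is exactly the pattern MXLinear and
    # MXInferenceLinear use with config.short_str().
    def extra_repr(self):
        return f"{super().extra_repr()}, bl_sz=32, kernel=emulated"

print(Demo(4, 4, bias=False))
# Demo(in_features=4, out_features=4, bias=False, bl_sz=32, kernel=emulated)
```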
