
Commit 339ad29

use torch.float8_e8m0fnu in mx_formats

Summary:

After pytorch/pytorch#148461 lands, we can use `torch.float8_e8m0fnu` throughout our codebase and compile will still work, removing the workarounds.

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 278117b
ghstack-comment-id: 2755728114
Pull Request resolved: #1966

1 parent 6877686 commit 339ad29
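Background for the change (editorial sketch, not part of the commit): `torch.float8_e8m0fnu` stores only the 8-bit biased exponent of a float32 value (bias 127, with the all-ones byte reserved for NaN), so an MX block scale can live directly in that dtype instead of being carried around as raw `uint8` bits. A minimal round-trip sketch, assuming a PyTorch build where casts and `torch.isnan` are supported for this dtype:

```
import torch

# 8.0 = 2^3 is exactly representable; the stored byte is the biased exponent 127 + 3.
scale = torch.tensor([8.0, float("nan")]).to(torch.float8_e8m0fnu)

# view() reinterprets the same bytes; nothing is copied or numerically converted.
bits = scale.view(torch.uint8)
print(bits)                # expected: tensor([130, 255], dtype=torch.uint8)
print(torch.isnan(scale))  # expected: tensor([False, True])
```

This is why the uint8-based NaN sentinel (`E8M0_EXPONENT_NAN_VAL`) checks in the tests below can become plain `torch.isnan` checks.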

4 files changed: +19 −10 lines changed

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 8 additions & 6 deletions

@@ -18,7 +18,6 @@
 )
 from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
 from torchao.prototype.mx_formats.mx_tensor import (
-    E8M0_EXPONENT_NAN_VAL,
     MXTensor,
     ScaleCalculationMode,
     to_dtype,
@@ -117,8 +116,8 @@ def test_exponent_nan_in(elem_dtype):
     )
     block_size = 4
     tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
-    assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL)
-    assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL)
+    assert torch.all(torch.isnan(tensor_mx._scale_e8m0[0]))
+    assert not torch.any(torch.isnan(tensor_mx._scale_e8m0[1:]))


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -128,8 +127,11 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     """
     If block exponent value is NaN, the MX tensor block value is NaN
     """
-    scale_e8m0_bits = torch.tensor(
-        [E8M0_EXPONENT_NAN_VAL, 23], dtype=torch.uint8, device="cuda"
+    if pack_fp6 and elem_dtype not in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
+        pytest.skip("invalid configuration")
+
+    scale_e8m0 = torch.tensor(
+        [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device="cuda"
     )

     block_size = 4
@@ -155,7 +157,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     block_size = 4
     use_fp4_custom_triton_dequant_kernel = False
     tensor_mx = MXTensor(
-        scale_e8m0_bits,
+        scale_e8m0,
         data_bits,
         elem_dtype,
         block_size,

torchao/prototype/mx_formats/custom_cast.py

Lines changed: 3 additions & 0 deletions

@@ -745,6 +745,7 @@ def triton_f4_to_scaled_bf16(
     size is currently assumed to be 32.
     Output: a tensor of bfloat16 values, multiplied by the encoded scale
     """
+    s_e8m0 = s_e8m0.view(torch.uint8)
     assert TORCH_VERSION_AT_LEAST_2_4, "unsupported"
     new_shape = (*x.shape[:-1], x.shape[-1] * 2)
     output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
@@ -861,6 +862,7 @@ def triton_f6_e2m3_to_scaled_bf16(
     size is currently assumed to be 32.
     Output: a tensor of bfloat16 values, multiplied by the encoded scale
     """
+    s_e8m0 = s_e8m0.view(torch.uint8)

     packed_mx_block_size = 3 * mx_block_size // 4

@@ -902,6 +904,7 @@ def triton_f6_e3m2_to_scaled_bf16(
     size is currently assumed to be 32.
     Output: a tensor of bfloat16 values, multiplied by the encoded scale
     """
+    s_e8m0 = s_e8m0.view(torch.uint8)

     packed_mx_block_size = 3 * mx_block_size // 4
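Why these wrappers add `s_e8m0 = s_e8m0.view(torch.uint8)`: the triton kernels consume the raw biased-exponent bytes, so the e8m0-typed scale is bit-reinterpreted at the kernel boundary rather than numerically converted. A hedged sketch of that distinction (the helper name is illustrative, not from this diff, and assumes tensor creation with this dtype is supported on the device used):

```
import torch

def as_exponent_bytes(s_e8m0: torch.Tensor) -> torch.Tensor:
    # Zero-copy reinterpretation: both dtypes are one byte wide, so view()
    # exposes the stored biased exponents directly. Note that .to() would
    # attempt a numeric conversion of the values instead of a bit view.
    if s_e8m0.dtype == torch.float8_e8m0fnu:
        return s_e8m0.view(torch.uint8)
    return s_e8m0

s = torch.tensor([1.0, 4.0], dtype=torch.float8_e8m0fnu)
print(as_exponent_bytes(s))  # expected: tensor([127, 129], dtype=torch.uint8)
```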

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 3 additions & 3 deletions

@@ -92,7 +92,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             weight_hp, block_size
         )
         weight_mx_dim1 = MXTensor(
-            weight_mx_dim1_scale.view(torch.uint8).reshape(-1),
+            weight_mx_dim1_scale.reshape(-1),
             weight_mx_dim1_data.t(),
             w_elem_dtype,
             block_size,
@@ -121,7 +121,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             grad_output_hp_r, block_size
         )
         grad_output_mx_dim1 = MXTensor(
-            grad_output_mx_dim1_scale.view(torch.uint8).reshape(-1),
+            grad_output_mx_dim1_scale.reshape(-1),
             grad_output_mx_dim1_data.t(),
             grad_elem_dtype,
             block_size,
@@ -143,7 +143,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             input_hp_r, block_size
         )
         input_t_mx_dim0_tmp = MXTensor(
-            input_t_mx_dim0_tmp_scale.view(torch.uint8).reshape(-1),
+            input_t_mx_dim0_tmp_scale.reshape(-1),
             input_t_mx_dim0_tmp_data.t(),
             in_elem_dtype,
             block_size,

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 5 additions & 1 deletion

@@ -271,10 +271,12 @@ def to_mx(
     else:
         raise AssertionError("unsupported")

+    scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
     return scale_e8m0_biased, data_lp


 def get_fp_scale(scale_e8m0):
+    scale_e8m0 = scale_e8m0.view(torch.uint8)
     s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
     # TODO(later): it would be nice if there was a way to do the 2^x operation
     # in PyTorch without creating a tensor of twos
@@ -507,7 +509,9 @@ def __new__(
             dtype=orig_dtype,
             device=data_bits.device,
         )
-        assert scale_e8m0_bits.dtype == torch.uint8, "unsupported"
+        assert (
+            scale_e8m0_bits.dtype == torch.float8_e8m0fnu
+        ), f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}"
         assert len(scale_e8m0_bits.shape) == 1, "unsupported"
         assert data_bits.dtype in (
             torch.float8_e4m3fn,
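For reference, `get_fp_scale` above now normalizes back to `uint8` before computing 2^(exponent − bias). A self-contained sketch of that decode step, using the standard E8M0 bias of 127 (simplified relative to the real implementation, which builds a tensor of twos and uses `torch.pow`; `torch.exp2` here is an editorial substitution):

```
import torch

E8M0_EXPONENT_BIAS = 127

def decode_e8m0_scale(scale_e8m0: torch.Tensor) -> torch.Tensor:
    # Reinterpret the e8m0 bytes, then compute 2^(stored_exponent - bias).
    bits = scale_e8m0.view(torch.uint8).to(torch.int16)
    return torch.exp2((bits - E8M0_EXPONENT_BIAS).to(torch.float32))

s = torch.tensor([0.5, 1.0, 256.0]).to(torch.float8_e8m0fnu)
print(decode_e8m0_scale(s))  # expected values: [0.5, 1.0, 256.0]
```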
