Commit 59bf666

use torch.float8_e8m0fnu in mx_formats (#1966)
1 parent 7e3978c commit 59bf666
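
For context on the dtype this commit adopts: torch.float8_e8m0fnu stores only an 8-bit biased exponent (no sign or mantissa bits), so a value decodes to 2**(bits - 127) and the all-ones bit pattern encodes NaN. A minimal standalone sketch, assuming a PyTorch build that ships this dtype (illustrative only, not code from the commit):

```python
import torch

# Raw biased exponents; 255 (all ones) would be the e8m0 NaN encoding.
bits = torch.tensor([126, 127, 130], dtype=torch.uint8)

# Reinterpret the same storage as e8m0 scales -- a zero-copy dtype view.
scale = bits.view(torch.float8_e8m0fnu)

# Decode on the raw bits: each scale is 2 ** (bits - 127).
decoded = 2.0 ** (bits.to(torch.float32) - 127)
print(decoded)                  # tensor([0.5000, 1.0000, 8.0000])

# Round-trip back to the exponent bits, as the updated triton wrappers do.
print(scale.view(torch.uint8))  # tensor([126, 127, 130], dtype=torch.uint8)
```

This is also why the tests below can replace comparisons against E8M0_EXPONENT_NAN_VAL with torch.isnan on the scale tensor.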

File tree

4 files changed: +19 -10 lines changed

test/prototype/mx_formats/test_mx_tensor.py

+8 -6
@@ -19,7 +19,6 @@
 )
 from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
 from torchao.prototype.mx_formats.mx_tensor import (
-    E8M0_EXPONENT_NAN_VAL,
     MXTensor,
     ScaleCalculationMode,
     to_dtype,
@@ -321,8 +320,8 @@ def test_exponent_nan_in(elem_dtype):
     )
     block_size = 4
     tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
-    assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL)
-    assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL)
+    assert torch.all(torch.isnan(tensor_mx._scale_e8m0[0]))
+    assert not torch.any(torch.isnan(tensor_mx._scale_e8m0[1:]))


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -332,8 +331,11 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     """
     If block exponent value is NaN, the MX tensor block value is NaN
     """
-    scale_e8m0_bits = torch.tensor(
-        [E8M0_EXPONENT_NAN_VAL, 23], dtype=torch.uint8, device="cuda"
+    if pack_fp6 and elem_dtype not in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
+        pytest.skip("invalid configuration")
+
+    scale_e8m0 = torch.tensor(
+        [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device="cuda"
     )

     block_size = 4
@@ -359,7 +361,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     block_size = 4
     use_fp4_custom_triton_dequant_kernel = False
     tensor_mx = MXTensor(
-        scale_e8m0_bits,
+        scale_e8m0,
         data_bits,
         elem_dtype,
         block_size,
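
The old and new assertions above check the same bit pattern: a scale whose raw uint8 value is E8M0_EXPONENT_NAN_VAL (255) is exactly an e8m0 NaN. A hedged sketch of that equivalence (the constant is re-declared here because the diff removes it from the test imports):

```python
import torch

E8M0_EXPONENT_NAN_VAL = 255  # all-ones exponent byte

scale_e8m0 = torch.tensor([255, 130, 127], dtype=torch.uint8).view(torch.float8_e8m0fnu)

# Old-style check on the raw exponent bits...
old_style = scale_e8m0.view(torch.uint8) == E8M0_EXPONENT_NAN_VAL
print(old_style)  # tensor([ True, False, False])

# ...which the updated test expresses as torch.isnan(scale_e8m0) directly on
# the e8m0 tensor (exercised on CUDA in the test above).
```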

torchao/prototype/mx_formats/custom_cast.py

+3
@@ -745,6 +745,7 @@ def triton_f4_to_scaled_bf16(
     size is currently assumed to be 32.
     Output: a tensor of bfloat16 values, multiplied by the encoded scale
     """
+    s_e8m0 = s_e8m0.view(torch.uint8)
     assert TORCH_VERSION_AT_LEAST_2_4, "unsupported"
     new_shape = (*x.shape[:-1], x.shape[-1] * 2)
     output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
@@ -861,6 +862,7 @@ def triton_f6_e2m3_to_scaled_bf16(
     size is currently assumed to be 32.
     Output: a tensor of bfloat16 values, multiplied by the encoded scale
     """
+    s_e8m0 = s_e8m0.view(torch.uint8)

     packed_mx_block_size = 3 * mx_block_size // 4

@@ -902,6 +904,7 @@ def triton_f6_e3m2_to_scaled_bf16(
     size is currently assumed to be 32.
     Output: a tensor of bfloat16 values, multiplied by the encoded scale
     """
+    s_e8m0 = s_e8m0.view(torch.uint8)

     packed_mx_block_size = 3 * mx_block_size // 4
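
The Triton kernels behind these wrappers were written against the raw exponent bits, so each wrapper now reinterprets the incoming e8m0 scale as uint8 up front; .view only retags the dtype, it does not copy or convert. A standalone sketch of the pattern (the function name and shapes here are hypothetical, not torchao APIs):

```python
import torch

def scaled_dequant_sketch(x: torch.Tensor, s_e8m0: torch.Tensor) -> torch.Tensor:
    # Accept an e8m0 scale at the API boundary, but do the decode on raw bits,
    # mirroring the s_e8m0.view(torch.uint8) lines added above.
    s_bits = s_e8m0.view(torch.uint8)                # zero-copy reinterpretation
    scale = 2.0 ** (s_bits.to(torch.float32) - 127)  # 2 ** (bits - bias)
    return (x.to(torch.float32) * scale).to(torch.bfloat16)

x = torch.randn(4)
s = torch.tensor([128], dtype=torch.uint8).view(torch.float8_e8m0fnu)  # scale of 2.0
print(scaled_dequant_sketch(x, s))
```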

torchao/prototype/mx_formats/mx_linear.py

+3 -3
@@ -92,7 +92,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             weight_hp, block_size
         )
         weight_mx_dim1 = MXTensor(
-            weight_mx_dim1_scale.view(torch.uint8).reshape(-1),
+            weight_mx_dim1_scale.reshape(-1),
             weight_mx_dim1_data.t(),
             w_elem_dtype,
             block_size,
@@ -121,7 +121,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             grad_output_hp_r, block_size
         )
         grad_output_mx_dim1 = MXTensor(
-            grad_output_mx_dim1_scale.view(torch.uint8).reshape(-1),
+            grad_output_mx_dim1_scale.reshape(-1),
             grad_output_mx_dim1_data.t(),
             grad_elem_dtype,
             block_size,
@@ -143,7 +143,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             input_hp_r, block_size
         )
         input_t_mx_dim0_tmp = MXTensor(
-            input_t_mx_dim0_tmp_scale.view(torch.uint8).reshape(-1),
+            input_t_mx_dim0_tmp_scale.reshape(-1),
             input_t_mx_dim0_tmp_data.t(),
             in_elem_dtype,
             block_size,
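
The .view(torch.uint8) calls can be dropped here because to_mx (see the mx_tensor.py change below) now returns the scale already tagged as torch.float8_e8m0fnu, which is exactly what the updated MXTensor constructor assert expects. A hedged usage sketch, assuming CUDA and a torchao build that includes this commit:

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32)  # elem_dtype, block_size

# The scale comes back as e8m0 directly; only a reshape is needed before
# passing it to another MXTensor constructor call, no uint8 round-trip.
assert x_mx._scale_e8m0.dtype == torch.float8_e8m0fnu
```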

torchao/prototype/mx_formats/mx_tensor.py

+5 -1
@@ -326,10 +326,12 @@ def to_mx(
     else:
         raise AssertionError("unsupported")

+    scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
     return scale_e8m0_biased, data_lp


 def get_fp_scale(scale_e8m0):
+    scale_e8m0 = scale_e8m0.view(torch.uint8)
     s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
     # TODO(later): it would be nice if there was a way to do the 2^x operation
     # in PyTorch without creating a tensor of twos
@@ -562,7 +564,9 @@ def __new__(
             dtype=orig_dtype,
             device=data_bits.device,
         )
-        assert scale_e8m0_bits.dtype == torch.uint8, "unsupported"
+        assert (
+            scale_e8m0_bits.dtype == torch.float8_e8m0fnu
+        ), f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}"
         assert len(scale_e8m0_bits.shape) == 1, "unsupported"
         assert data_bits.dtype in (
             torch.float8_e4m3fn,
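
Putting the mx_tensor.py pieces together: to_mx now hands back an e8m0-typed scale, and get_fp_scale views it back to uint8 before doing the exponent arithmetic. A standalone sketch of that decode path, reconstructed from the visible lines (the twos-tensor step is assumed, following the TODO comment in the diff):

```python
import torch

E8M0_EXPONENT_BIAS = 127

def get_fp_scale_sketch(scale_e8m0: torch.Tensor) -> torch.Tensor:
    # View the e8m0 scale as raw exponent bits, then decode 2 ** (bits - bias).
    bits = scale_e8m0.view(torch.uint8)
    s_offset = bits.to(torch.int16) - E8M0_EXPONENT_BIAS
    # 2 ** s_offset computed via a tensor of twos, per the TODO above.
    twos = torch.full_like(s_offset, 2, dtype=torch.float32)
    return torch.pow(twos, s_offset.to(torch.float32))

scale = torch.tensor([124, 127, 131], dtype=torch.uint8).view(torch.float8_e8m0fnu)
print(get_fp_scale_sketch(scale))  # tensor([ 0.1250,  1.0000, 16.0000])
```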

0 commit comments
