This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit c4c9ae8

Update on "[wip] add axiswise granularity to Float8Tensor"
Summary: This PR adds the axiswise scaling granularity to `Float8Tensor` and ensures that basic ops like transpose and `torch._scaled_mm` work as expected. A future PR will add integration with `Float8Linear`. Test Plan: TODO Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
1 parent e87f005 commit c4c9ae8
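
For context on the granularity being added: tensorwise scaling derives a single scale from the whole tensor's amax, while axiswise scaling keeps one scale per slice along a chosen dim so that the scale broadcasts against the data. A minimal standalone sketch of the difference (plain PyTorch, not this repo's API; the helper name `amax_to_scale_sketch` is made up, and a PyTorch build with float8 dtypes is assumed):

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def amax_to_scale_sketch(amax: torch.Tensor) -> torch.Tensor:
    # map the observed max magnitude onto the float8 representable max
    return E4M3_MAX / amax.clamp(min=1e-12)

x = torch.randn(16, 32)

# tensorwise: one scalar scale for the whole tensor
scale_tensorwise = amax_to_scale_sketch(x.abs().amax())

# axiswise: one scale per row, reducing over dim 1 with keepdim=True
# so the (16, 1) scale broadcasts against the (16, 32) data
scale_axiswise = amax_to_scale_sketch(x.abs().amax(dim=1, keepdim=True))

x_fp8_rowwise = (x * scale_axiswise).to(torch.float8_e4m3fn)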

File tree: 4 files changed (+62, -25 lines)

float8_experimental/float8_ops.py

Lines changed: 14 additions & 1 deletion
@@ -41,7 +41,7 @@ def decorator(func):
 
 @implements(
     [
-        aten.view.default,
+        # aten.view.default,
         aten._unsafe_view.default,
         aten.as_strided.default,
         aten.clone.default,
@@ -79,6 +79,19 @@ def float8_desugar_data_and_scale(aten_op, args, kwargs=None):
         args[0]._gemm_input_role,
     )
 
+@implements([aten.view.default])
+def float8_view(aten_op, args, kwargs=None):
+    if len(args[0]._scale.shape) < 2:
+        # tensorwise scaling
+        return float8_desugar_op(aten_op, *args, **kwargs)
+    print('args', args)
+    print('kwargs', kwargs)
+    tensor, new_shape = args[0], args[1]
+
+    # for now, only support reshaping to [-1, *dims] or [*dims, -1]
+    if len(new_shape) >= 2 and (new_shape[0] == -1 or new_shape[-1] == -1):
+        return float8_desugar_data_and_scale(aten_op, *args, **kwargs)
+    raise AssertionError(f"{aten_op} with axiswise scaling and shape {new_shape} is not supported yet.")
 
 @implements([aten.split.Tensor])
 def float8_split(aten_op, args, kwargs=None):
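
A note on why `float8_view` only allows reshapes of the form [-1, *dims] or [*dims, -1]: an axiswise scale stays aligned with the data only if the reshape collapses dimensions away from the scaled axis, so data and scale can be reshaped with the same pattern. A standalone shape-bookkeeping sketch (plain tensors, not the repo's `Float8Tensor`):

import torch

# data with an axiswise reduction over dim 0: scale-like tensor of shape (1, 5, 7)
data = torch.randn(3, 5, 7)
scale = data.abs().amax(dim=0, keepdim=True)

# collapsing the trailing dims keeps data and scale aligned:
# data (3, 5, 7) -> (3, 35), scale (1, 5, 7) -> (1, 35)
data_r = data.reshape(3, -1)
scale_r = scale.reshape(1, -1)
torch.testing.assert_close(data_r / scale_r, (data / scale).reshape(3, -1))

# a reshape that mixes the scaled axis with another dim, e.g. (3, 5, 7) -> (15, 7),
# has no matching reshape of the (1, 5, 7) scale, hence the AssertionError path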

float8_experimental/float8_python_api.py

Lines changed: 0 additions & 7 deletions
@@ -39,13 +39,6 @@ def addmm_float8_unwrapped(
     a_inverse_scale = a_scale.reciprocal()
     b_inverse_scale = b_scale.reciprocal()
 
-    # TODO: should we change torch._scaled_mm?
-    # torch._scaled_mm expects rowwise scaled scales to be of rank 1, not rank
-    # 2. Translate to this format.
-    # TODO: audit if we need to make this more generic for various shapes.
-    a_inverse_scale = a_inverse_scale.squeeze()
-    b_inverse_scale = b_inverse_scale.squeeze()
-
     if output_dtype == torch.float32 and bias is not None:
         # Bias is not supported by _scaled_mm when output is fp32
         output = torch._scaled_mm(
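
With the squeeze removed, the inverse scales reach `torch._scaled_mm` at their keepdim rank. As a plain-PyTorch reference (a high-precision simulation, not the `_scaled_mm` call itself), those (M, 1) and (1, N) shapes are exactly what broadcasting needs for a rowwise-scaled matmul:

import torch

M, K, N = 16, 32, 8
a = torch.randn(M, K)
b_t = torch.randn(K, N)  # second operand, already laid out for mm

# per-row scale for a, per-column scale for b_t, both with keepdim shapes
a_scale = a.abs().amax(dim=1, keepdim=True)    # (M, 1)
b_scale = b_t.abs().amax(dim=0, keepdim=True)  # (1, N)

# scale, multiply, then undo the scaling on the (M, N) output via broadcasting
out = ((a / a_scale) @ (b_t / b_scale)) * (a_scale * b_scale)
torch.testing.assert_close(out, a @ b_t, rtol=1e-4, atol=1e-4)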

float8_experimental/float8_utils.py

Lines changed: 2 additions & 2 deletions
@@ -115,9 +115,9 @@ def tensor_to_amax(
 
         # convert from axiswise_dim (dim to keep) to
         # dim as the input to the `torch.amax` function (tuple of dims to reduce)
-        dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim)
+        # dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim)
 
-        amax = torch.amax(torch.abs(x), dim=dim_to_reduce, keepdim=True)
+        amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True)
 
     # If the user asked for distributed reduction, do it.
     # If the user did not ask for it, assume that it will
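
This change reinterprets `axiswise_dim`: instead of "the dim to keep" (with the reduction taken over every other dim), it is now passed straight to `torch.amax` as the dim to reduce, with `keepdim=True`. A quick standalone comparison of the resulting shapes:

import torch

x = torch.randn(3, 5, 7)
axiswise_dim = 0

# old interpretation: axiswise_dim is the dim to keep, reduce over the rest
dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim)
amax_keep = torch.amax(torch.abs(x), dim=dim_to_reduce, keepdim=True)
print(amax_keep.shape)    # torch.Size([3, 1, 1])

# new interpretation: axiswise_dim is handed to torch.amax, i.e. it is reduced
amax_reduce = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True)
print(amax_reduce.shape)  # torch.Size([1, 5, 7]) -- matches the test's expected _scale.shape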

test/test_base.py

Lines changed: 46 additions & 15 deletions
@@ -63,7 +63,7 @@ def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     return True
 
 
-class TestFloat8Tensor(unittest.TestCase):
+class TestFloat8Tensor:
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
@@ -73,7 +73,7 @@ def test_preserves_dtype(self) -> None:
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
             x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
             x3_hp = x2_lp.to_original_precision()
-            self.assertTrue(x3_hp.dtype == hp_dtype)
+            assert x3_hp.dtype == hp_dtype
 
     def test_differentiable_casts(self) -> None:
         lp_dtypes = (e4m3_dtype, e5m2_dtype)
@@ -108,7 +108,7 @@ def test_index_put(self):
         fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
         fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)
 
-        with self.assertRaises(AssertionError):
+        with pytest.raises(AssertionError):
             b[index] = fp8_a
             fp8_b[index] = a
             fp8_b_bad[index] = fp8_a
@@ -122,7 +122,7 @@ def test_copy_(self):
         b = torch.empty(16, dtype=torch.bfloat16)
         b.copy_(fp8_a)  # Should work
         torch.testing.assert_close(b, fp8_a.to_original_precision())
-        with self.assertRaises(RuntimeError):
+        with pytest.raises(RuntimeError):
             fp8_a.copy_(b)  # Should fail
 
         fp8_b = Float8Tensor(
@@ -149,21 +149,49 @@ def test_weights_only_load(self):
         buffer.seek(0)
         _ = torch.load(buffer, weights_only=True)
 
-    def test_axiswise_dynamic_cast(self):
-        a = torch.randn(16, 32, dtype=torch.bfloat16)
+    @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
+    @pytest.mark.parametrize("dim_name", ["first", "last"])
+    def test_axiswise_dynamic_cast(self, shape, dim_name):
+        a = torch.randn(*shape, dtype=torch.bfloat16)
+
+        if dim_name == "first":
+            dim = 0
+        elif dim_name == "last":
+            dim = len(a.shape) - 1
+
         linear_mm_config = LinearMMConfig()
+        a_fp8 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=dim,
+        )
+        a_dq = a_fp8.to_original_precision()
+        sqnr = compute_error(a, a_dq)
+        assert sqnr >= 25.0
+
+    # TODO(next) make this work
+    def test_axiswise_reshape(self):
+        a = torch.randn(3, 5, 7, dtype=torch.bfloat16, device="cuda")
+        linear_mm_config = LinearMMConfig()
+
         a_fp8 = hp_tensor_to_float8_dynamic(
             a,
             e4m3_dtype,
             linear_mm_config,
             scaling_granularity=ScalingGranularity.AXISWISE,
             axiswise_dim=0,
         )
-        # print(a_fp8)
-        # print(a_fp8.to_original_precision())
-        # print(a_fp8.t())
-        b = a_fp8.t()
-        # TODO check numerical accuracy
+        # a_fp8._data.shape is (3, 5, 7)
+        # a_fp8._scale.shape is (1, 5, 7)
+        print(a_fp8._scale.shape)
+
+        # reshape to (3, 5 * 7)
+        # a_fp8._scale.shape should be (1, 5 * 7)
+        a_fp8_r = a_fp8.reshape(3, -1)
+        print(a_fp8_r._scale.shape)
+
 
     def test_axiswise_gemm(self):
         a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
@@ -177,18 +205,21 @@ def test_axiswise_gemm(self):
             linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
             scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=0,
+            axiswise_dim=1,
         )
         b_fp8 = hp_tensor_to_float8_dynamic(
             b,
             e4m3_dtype,
             linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
             scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=0,
+            axiswise_dim=1,
         )
-        c = torch.mm(a_fp8, b_fp8.t())
-        print(c)
+        c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
+        print(c_fp8_compute)
+        c_ref = torch.mm(a, b.t())
+        sqnr = compute_error(c_ref, c_fp8_compute)
+        print('sqnr', sqnr)
         # TODO check numerical accuracy
 
 
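The new tests gate on `compute_error`, the repo's accuracy helper; assuming it returns an SQNR-style metric in decibels, `assert sqnr >= 25.0` requires roughly 25 dB of signal over quantization noise after the float8 round trip. A generic sketch of such a metric (not necessarily the exact helper):

import torch

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in dB: 20 * log10(||ref|| / ||ref - approx||)
    signal = torch.linalg.norm(ref.float())
    noise = torch.linalg.norm(ref.float() - approx.float())
    return 20 * torch.log10(signal / noise)

# example: a bfloat16 round trip keeps well above the 25 dB threshold
x = torch.randn(16, 32)
print(sqnr_db(x, x.to(torch.bfloat16).to(torch.float32)))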
