
Commit 1052066

Update on "[wip] add axiswise granularity to Float8Tensor"
Summary:

This PR adds the axiswise scaling granularity to `Float8Tensor` and ensures that basic ops like transpose and `torch._scaled_mm` work as expected. A future PR will add integration with `Float8Linear`.

Test Plan: TODO

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
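For context on the terminology: tensorwise scaling keeps a single scale for the whole tensor, while axiswise scaling keeps one scale per slice along a chosen dimension, so the scale tensor retains the data's rank with the scaled dimension collapsed to size 1. Below is a minimal sketch of that shape bookkeeping, assuming the e4m3 max value of 448 and using illustrative helper names that are not part of float8_experimental's API.

```python
import torch

# Illustrative only: these helpers are not float8_experimental's API.
E4M3_MAX = 448.0  # max magnitude representable by torch.float8_e4m3fn

def tensorwise_scale(x: torch.Tensor) -> torch.Tensor:
    # one scale for the whole tensor -> scalar scale
    return E4M3_MAX / x.abs().max().clamp(min=1e-12)

def axiswise_scale(x: torch.Tensor, dim: int) -> torch.Tensor:
    # one scale per slice along `dim` -> that dim collapses to size 1,
    # so the scale broadcasts against the data
    amax = x.abs().amax(dim=dim, keepdim=True)
    return E4M3_MAX / amax.clamp(min=1e-12)

x = torch.randn(3, 5, 7)
print(tensorwise_scale(x).shape)       # torch.Size([])
print(axiswise_scale(x, dim=0).shape)  # torch.Size([1, 5, 7])
print(axiswise_scale(x, dim=2).shape)  # torch.Size([3, 5, 1])
```

These scale shapes match what the updated test below asserts for `axiswise_dim=0` and `axiswise_dim=2`.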
1 parent c4c9ae8 commit 1052066

File tree

3 files changed: +71 -24 lines changed


float8_experimental/float8_ops.py

Lines changed: 30 additions & 10 deletions
@@ -41,7 +41,6 @@ def decorator(func):
 
 @implements(
     [
-        # aten.view.default,
         aten._unsafe_view.default,
         aten.as_strided.default,
         aten.clone.default,

@@ -79,19 +78,40 @@ def float8_desugar_data_and_scale(aten_op, args, kwargs=None):
         args[0]._gemm_input_role,
     )
 
+
 @implements([aten.view.default])
 def float8_view(aten_op, args, kwargs=None):
     if len(args[0]._scale.shape) < 2:
         # tensorwise scaling
-        return float8_desugar_op(aten_op, *args, **kwargs)
-    print('args', args)
-    print('kwargs', kwargs)
-    tensor, new_shape = args[0], args[1]
-
-    # for now, only support reshaping to [-1, *dims] or [*dims, -1]
-    if len(new_shape) >= 2 and (new_shape[0] == -1 or new_shape[-1] == -1):
-        return float8_desugar_data_and_scale(aten_op, *args, **kwargs)
-    raise AssertionError(f"{aten_op} with axiswise scaling and shape {new_shape} is not supported yet.")
+        return float8_desugar_op(aten_op, args, kwargs)
+
+    t, new_shape = args[0], args[1]
+    # for now, only support reshaping to [-1, dim] or [dim, -1]
+    if len(new_shape) == 2:
+        if new_shape == [t.shape[0], -1] and t._scale.shape[0] == 1:
+            new_data = aten_op(t._data, new_shape, **kwargs)
+            new_scale = aten_op(t._scale, [1, -1], **kwargs)
+            return Float8Tensor(
+                new_data,
+                new_scale,
+                t._orig_dtype,
+                t._linear_mm_config,
+                t._gemm_input_role,
+            )
+        elif new_shape == [-1, t.shape[-1]] and t._scale.shape[-1] == 1:
+            new_data = aten_op(t._data, new_shape, **kwargs)
+            new_scale = aten_op(t._scale, [-1, 1], **kwargs)
+            return Float8Tensor(
+                new_data,
+                new_scale,
+                t._orig_dtype,
+                t._linear_mm_config,
+                t._gemm_input_role,
+            )
+    raise AssertionError(
+        f"{aten_op} with axiswise scaling and t.shape {t.shape} t._scale.shape {t._scale.shape} new_shape {new_shape} is not supported yet."
+    )
+
 
 @implements([aten.split.Tensor])
 def float8_split(aten_op, args, kwargs=None):
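A note on why `float8_view` above only allows the `[dim0, -1]` and `[-1, dim_last]` patterns: the axiswise scale has size 1 along the scaled dimension and must be viewable in the same way as the data while staying broadcastable against it. Here is a quick sketch of that bookkeeping with plain tensors (not `Float8Tensor`):

```python
import torch

# Plain-tensor sketch of the shape bookkeeping performed by float8_view above.
data = torch.randn(3, 5, 7)
scale = torch.randn(1, 5, 7)  # axiswise scale across dim 0 (that dim has size 1)

# Supported: collapse the non-scaled dims; the scale follows the same view
# and remains broadcastable against the data.
new_data = data.view(3, -1)    # shape (3, 35)
new_scale = scale.view(1, -1)  # shape (1, 35)
assert (new_data / new_scale).shape == new_data.shape

# Unsupported: data.view(-1, 7) would need the (1, 5, 7) scale to become
# broadcastable against (15, 7), which a simple view cannot do, so the new
# float8_view raises AssertionError for this case (as the test below exercises).
```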

float8_experimental/float8_python_api.py

Lines changed: 0 additions & 1 deletion
@@ -38,7 +38,6 @@ def addmm_float8_unwrapped(
     """
     a_inverse_scale = a_scale.reciprocal()
     b_inverse_scale = b_scale.reciprocal()
-
     if output_dtype == torch.float32 and bias is not None:
         # Bias is not supported by _scaled_mm when output is fp32
         output = torch._scaled_mm(

test/test_base.py

Lines changed: 41 additions & 13 deletions
@@ -171,27 +171,57 @@ def test_axiswise_dynamic_cast(self, shape, dim_name):
         sqnr = compute_error(a, a_dq)
         assert sqnr >= 25.0
 
-    # TODO(next) make this work
     def test_axiswise_reshape(self):
         a = torch.randn(3, 5, 7, dtype=torch.bfloat16, device="cuda")
         linear_mm_config = LinearMMConfig()
 
-        a_fp8 = hp_tensor_to_float8_dynamic(
+        # if we scale across dim0, we can only reshape to [3, -1]
+        a_fp8_d0 = hp_tensor_to_float8_dynamic(
             a,
             e4m3_dtype,
             linear_mm_config,
             scaling_granularity=ScalingGranularity.AXISWISE,
             axiswise_dim=0,
         )
-        # a_fp8._data.shape is (3, 5, 7)
-        # a_fp8._scale.shape is (1, 5, 7)
-        print(a_fp8._scale.shape)
+        assert list(a_fp8_d0._data.shape) == [3, 5, 7]
+        assert list(a_fp8_d0._scale.shape) == [1, 5, 7]
+
+        a_fp8_d0_r = a_fp8_d0.reshape(3, -1)
+        assert list(a_fp8_d0_r.shape) == [3, 5 * 7]
+        assert list(a_fp8_d0_r._scale.shape) == [1, 5 * 7]
+        # verify numerics did not change
+        assert torch.allclose(
+            a_fp8_d0.to_original_precision(),
+            a_fp8_d0_r.to_original_precision().reshape(3, 5, 7),
+            atol=0,
+            rtol=0,
+        )
+        with pytest.raises(AssertionError):
+            a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7)
 
-        # reshape to (3, 5 * 7)
-        # a_fp8._scale.shape should be (1, 5 * 7)
-        a_fp8_r = a_fp8.reshape(3, -1)
-        print(a_fp8_r._scale.shape)
-
+        # if we scale across dim2, we can only reshape to [-1, 7]
+        a_fp8_d2 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=2,
+        )
+        assert list(a_fp8_d2._data.shape) == [3, 5, 7]
+        assert list(a_fp8_d2._scale.shape) == [3, 5, 1]
+
+        a_fp8_d2_r = a_fp8_d2.reshape(-1, 7)
+        assert list(a_fp8_d2_r.shape) == [3 * 5, 7]
+        assert list(a_fp8_d2_r._scale.shape) == [3 * 5, 1]
+        # verify numerics did not change
+        assert torch.allclose(
+            a_fp8_d2.to_original_precision(),
+            a_fp8_d2_r.to_original_precision().reshape(3, 5, 7),
+            atol=0,
+            rtol=0,
+        )
+        with pytest.raises(AssertionError):
+            a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)
 
     def test_axiswise_gemm(self):
         a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")

@@ -216,11 +246,9 @@ def test_axiswise_gemm(self):
             axiswise_dim=1,
         )
         c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
-        print(c_fp8_compute)
         c_ref = torch.mm(a, b.t())
         sqnr = compute_error(c_ref, c_fp8_compute)
-        print('sqnr', sqnr)
-        # TODO check numerical accuracy
+        assert sqnr >= 25.0
 
 
 class TestFloat8Linear:
