
Commit 6ac2f82

vkuzo authored and facebook-github-bot committed
clean up casting: ToFloat8ConstrFunc -> hp_tensor_and_scale_to_float8 (#348)
Summary: Pull Request resolved: #348

Moves `ToFloat8ConstrFunc` to private, and creates `hp_tensor_and_scale_to_float8` as the official wrapper which clearly describes what this function is doing. A future PR will rename the scaling-aware functions to match this naming.

Reviewed By: drisspg

Differential Revision: D60310240

fbshipit-source-id: 954e7c910cee36f2ea0b0d1984fe163862b47ee5
1 parent 7e0182f commit 6ac2f82
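
For orientation, here is a minimal sketch of the call-site change this commit makes across the codebase. Only the import paths and the wrapper's signature come from the diff below; the input shape, the scale computation, and the round-trip at the end are illustrative assumptions.

    import torch
    from float8_experimental.float8_tensor import (
        GemmInputRole,
        hp_tensor_and_scale_to_float8,
    )
    from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

    # Before this commit, call sites invoked the autograd.Function directly:
    #     x_fp8 = ToFloat8ConstrFunc.apply(x, scale, e4m3_dtype)
    # After this commit, the named wrapper is the entry point:
    x = torch.randn(16, 16, dtype=torch.bfloat16)  # illustrative input
    scale = tensor_to_scale(x, e4m3_dtype)  # precalculated scale, as the wrapper expects
    x_fp8 = hp_tensor_and_scale_to_float8(
        x,
        scale,
        e4m3_dtype,
        None,                 # linear_mm_config (optional)
        GemmInputRole.INPUT,  # role of this tensor in the 3 fwd/bwd gemms of linear
    )
    x_hp = x_fp8.to_original_precision()  # differentiable round-trip back to bfloat16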

7 files changed: +72 -53 lines changed


benchmarks/bench_padding.py

Lines changed: 3 additions & 3 deletions
@@ -6,9 +6,9 @@
 import torch
 from float8_experimental.float8_tensor import (
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
-    ToFloat8ConstrFunc,
 )
 from float8_experimental.float8_utils import pad_tensor_for_matmul
 from tabulate import tabulate
@@ -58,14 +58,14 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
     a_config = LinearMMConfig(a_config, a_config, a_config)
     b_config = LinearMMConfig(b_config, b_config, b_config)

-    a_fp8 = ToFloat8ConstrFunc.apply(
+    a_fp8 = hp_tensor_and_scale_to_float8(
         A,
         scale_a,
         fp8_dtype,
         a_config,
         GemmInputRole.INPUT,
     )
-    b_fp8 = ToFloat8ConstrFunc.apply(
+    b_fp8 = hp_tensor_and_scale_to_float8(
         B,
         scale_b,
         fp8_dtype,

float8_experimental/float8_scaling_utils.py

Lines changed: 5 additions & 5 deletions
@@ -15,10 +15,10 @@
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
     tensor_already_casted_to_fp8,
-    ToFloat8ConstrFunc,
 )

 from float8_experimental.float8_utils import (
@@ -39,7 +39,7 @@ def cast_to_float8_e4m3_dynamic(
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
     scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return ToFloat8ConstrFunc.apply(
+    return hp_tensor_and_scale_to_float8(
         inpt_tensor,
         scale,
         e4m3_dtype,
@@ -58,7 +58,7 @@ def cast_to_float8_delayed(
     gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
 ):
     amax_buffer.fill_(tensor_to_amax(tensor))
-    return ToFloat8ConstrFunc.apply(
+    return hp_tensor_and_scale_to_float8(
         tensor,
         scale,
         float8_dtype,
@@ -145,7 +145,7 @@ def backward(ctx, go):

         fp8_amax_grad_output.fill_(tensor_to_amax(go))

-        res = ToFloat8ConstrFunc.apply(
+        res = hp_tensor_and_scale_to_float8(
             go,
             fp8_scale_grad_output,
             e5m2_dtype,
@@ -177,7 +177,7 @@ def backward(ctx, gradY):
         if tensor_already_casted_to_fp8(gradY):
             return gradY, None
         gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
-        fp8_tensor = ToFloat8ConstrFunc.apply(
+        fp8_tensor = hp_tensor_and_scale_to_float8(
             gradY,
             gradY_scale,
             e5m2_dtype,

float8_experimental/float8_tensor.py

Lines changed: 31 additions & 12 deletions
@@ -129,7 +129,7 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:


 @torch._dynamo.allow_in_graph
-class ToFloat8ConstrFunc(torch.autograd.Function):
+class _ToFloat8ConstrFunc(torch.autograd.Function):
     """
     A differentiable conversion to fp8.
     * forward: convert from high precision to float8
@@ -154,15 +154,6 @@ def forward(
         with that composing with FakeTensor, so we special case here.

         DTensor Invariant: DTensor must always be the outer most tensor subclass
-
-        Args:
-            tensor: the tensor to convert
-            scale: the scale to use to convert the tensor
-            float8_dtype: the float8 dtype to use
-            linear_mm_config: Defines the configuration for the scaled_mm for
-              the 3 fwd/bwd gemms of linear
-            gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
-              the 3 fwd/bwd gemms of linear
         """
         tensor_scaled = tensor * scale
         bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
@@ -205,7 +196,7 @@ def backward(ctx, g):


 @torch._dynamo.allow_in_graph
-class FromFloat8ConstrFunc(torch.autograd.Function):
+class _FromFloat8ConstrFunc(torch.autograd.Function):
     """
     A differentiable conversion from fp8.
     * forward: convert from float8 to high precision
@@ -221,6 +212,34 @@ def backward(ctx, g):
         return g, None, None


+def hp_tensor_and_scale_to_float8(
+    hp_tensor: torch.Tensor,
+    s: torch.Tensor,
+    float8_dtype=e4m3_dtype,
+    linear_mm_config: Optional[LinearMMConfig] = None,
+    gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
+):
+    """
+    Given a high precision tensor `hp_tensor` and a precalculated scale `s`,
+    scales `hp_tensor` by `s` and returns a `Float8Tensor` of the result.
+
+    Autograd-aware, the derivative is pass-through.
+    DTensor-aware, if the input is a DTensor the output will be DTensor(Float8Tensor).
+
+    Args:
+        hp_tensor: the tensor to convert
+        s: the scale to use to convert the tensor
+        float8_dtype: the float8 dtype to use
+        linear_mm_config: Defines the configuration for the scaled_mm for
+          the 3 fwd/bwd gemms of linear
+        gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
+          the 3 fwd/bwd gemms of linear
+    """
+    return _ToFloat8ConstrFunc.apply(
+        hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role
+    )
+
+
 class Float8Tensor(torch.Tensor):
     """
     Note: this is **not** a public API and is only intended to be used
@@ -309,7 +328,7 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride
         )

     def to_original_precision(self):
-        return FromFloat8ConstrFunc.apply(self)
+        return _FromFloat8ConstrFunc.apply(self)

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):

float8_experimental/fsdp_utils.py

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
-    ToFloat8ConstrFunc,
 )

 from float8_experimental.float8_utils import e4m3_dtype, EPS
@@ -167,7 +167,7 @@ def __repr__(self):

     def fsdp_pre_all_gather(self, mesh):
         if self._precomputed_scale is not None:
-            float8_tensor = ToFloat8ConstrFunc.apply(
+            float8_tensor = hp_tensor_and_scale_to_float8(
                 self._tensor,
                 self._precomputed_scale,
                 torch.float8_e4m3fn,

float8_experimental/inference.py

Lines changed: 3 additions & 3 deletions
@@ -19,10 +19,10 @@
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
     tensor_already_casted_to_fp8,
-    ToFloat8ConstrFunc,
 )
 from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

@@ -127,7 +127,7 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
             self.weight, Float8Tensor
         ), "Weight has already been quantized, cannot quantize again."
         scale = tensor_to_scale(self.weight, dtype)
-        quantized_weight = ToFloat8ConstrFunc.apply(
+        quantized_weight = hp_tensor_and_scale_to_float8(
             self.weight,
             scale,
             dtype,
@@ -200,7 +200,7 @@ def cast_to_float8_e4m3_inference(
         if static_quantization_scale is not None
         else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
     )
-    return ToFloat8ConstrFunc.apply(
+    return hp_tensor_and_scale_to_float8(
        inpt_tensor,
        scale,
        e4m3_dtype,

test/test_base.py

Lines changed: 21 additions & 21 deletions
@@ -28,9 +28,9 @@
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
-    ToFloat8ConstrFunc,
 )
 from float8_experimental.float8_utils import (
     compute_error,
@@ -66,7 +66,7 @@ def test_preserves_dtype(self) -> None:
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.randn(4, 4, dtype=hp_dtype)
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
-            x2_lp = ToFloat8ConstrFunc.apply(x1_hp, x1_s, lp_dtype)
+            x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
             x3_hp = x2_lp.to_original_precision()
             self.assertTrue(x3_hp.dtype == hp_dtype)

@@ -76,7 +76,7 @@ def test_differentiable_casts(self) -> None:
             x = torch.randn(1).requires_grad_()
             grad = torch.randn(1)
             x_s = tensor_to_scale(x, f8_dtype)
-            x_f8 = ToFloat8ConstrFunc.apply(x, x_s, f8_dtype)
+            x_f8 = hp_tensor_and_scale_to_float8(x, x_s, f8_dtype)
             x_f8_hp = x_f8.to_original_precision()
             x_f8_hp.backward(grad)
             # the gradient should be unchanged through both casts
@@ -85,7 +85,7 @@ def test_differentiable_casts(self) -> None:
     def test_split_cat(self):
         a = torch.rand(16, 16, dtype=torch.bfloat16)
         scale = tensor_to_scale(a, e4m3_dtype)
-        fp8_a = ToFloat8ConstrFunc.apply(a, scale, e4m3_dtype)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype)

         splits = torch.split(fp8_a, 16)
         catted = torch.cat(splits, dim=0)
@@ -94,14 +94,14 @@ def test_split_cat(self):
     def test_index_put(self):
         a = torch.rand(16, dtype=torch.bfloat16)
         scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
-        fp8_a = ToFloat8ConstrFunc.apply(a, scale_a, torch.float8_e4m3fn)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)

         index = torch.randint(0, 15, (16,), dtype=torch.long)

         b = torch.rand(16, 16, dtype=torch.bfloat16)
         scale_b = tensor_to_scale(b, torch.float8_e4m3fn)
-        fp8_b = ToFloat8ConstrFunc.apply(b, scale_a, torch.float8_e4m3fn)
-        fp8_b_bad = ToFloat8ConstrFunc.apply(b, scale_b, torch.float8_e4m3fn)
+        fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
+        fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)

         with self.assertRaises(AssertionError):
             b[index] = fp8_a
@@ -112,7 +112,7 @@ def test_index_put(self):
     def test_copy_(self):
         a = torch.rand(16, dtype=torch.bfloat16)
         scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
-        fp8_a = ToFloat8ConstrFunc.apply(a, scale_a, torch.float8_e4m3fn)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)

         b = torch.empty(16, dtype=torch.bfloat16)
         b.copy_(fp8_a)  # Should work
@@ -407,8 +407,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()

-        a_fp8 = ToFloat8ConstrFunc.apply(a, a_scale, input_dtype)
-        b_fp8 = ToFloat8ConstrFunc.apply(b, b_scale, input_dtype)
+        a_fp8 = hp_tensor_and_scale_to_float8(a, a_scale, input_dtype)
+        b_fp8 = hp_tensor_and_scale_to_float8(b, b_scale, input_dtype)

         out_scaled_mm = addmm_float8_unwrapped(
             a_fp8._data,
@@ -447,14 +447,14 @@ def test_different_configs_error(self):
             ScaledMMConfig(True, False, False, False),
             ScaledMMConfig(True, False, False, False),
         )
-        a = ToFloat8ConstrFunc.apply(
+        a = hp_tensor_and_scale_to_float8(
             x_fp32,
             x_scale,
             fp8_dtype,
             linear_config_a,
             GemmInputRole.INPUT,
         )
-        b = ToFloat8ConstrFunc.apply(
+        b = hp_tensor_and_scale_to_float8(
             x_fp32,
             x_scale,
             fp8_dtype,
@@ -486,10 +486,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()

-        a_fp8 = ToFloat8ConstrFunc.apply(
+        a_fp8 = hp_tensor_and_scale_to_float8(
             a, a_scale, input_dtype, None, GemmInputRole.INPUT
         )
-        b_fp8 = ToFloat8ConstrFunc.apply(
+        b_fp8 = hp_tensor_and_scale_to_float8(
             b, b_scale, input_dtype, None, GemmInputRole.WEIGHT
         )

@@ -506,14 +506,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             scaled_mm_config, scaled_mm_config, scaled_mm_config
         )

-        a_fp8 = ToFloat8ConstrFunc.apply(
+        a_fp8 = hp_tensor_and_scale_to_float8(
             a,
             a_scale,
             input_dtype,
             pad_config,
             GemmInputRole.INPUT,
         )
-        b_fp8 = ToFloat8ConstrFunc.apply(
+        b_fp8 = hp_tensor_and_scale_to_float8(
             b,
             b_scale,
             input_dtype,
@@ -529,14 +529,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             emulated_scaled_mm_config,
             emulated_scaled_mm_config,
         )
-        a_fp8 = ToFloat8ConstrFunc.apply(
+        a_fp8 = hp_tensor_and_scale_to_float8(
             a,
             a_scale,
             input_dtype,
             emulated_config,
             GemmInputRole.INPUT,
         )
-        b_fp8 = ToFloat8ConstrFunc.apply(
+        b_fp8 = hp_tensor_and_scale_to_float8(
             b,
             b_scale,
             input_dtype,
@@ -695,19 +695,19 @@ def test_fp8_tensor_statistics(self):

             # Overflow caused by a too large scaling factor
             s_overflow = torch.tensor(1e9)
-            fp8_overflow = ToFloat8ConstrFunc.apply(x1_hp, s_overflow, lp_dtype)
+            fp8_overflow = hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype)
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype)
             self.assertEqual((zero_cnt, max_cnt), (0, tensor_len))

             # Underflow caused by a too small scaling factor
             s_underflow = torch.tensor(1e-9)
-            fp8_underflow = ToFloat8ConstrFunc.apply(x1_hp, s_underflow, lp_dtype)
+            fp8_underflow = hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype)
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype)
             self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0))

             # Both overflow and underflow
             x2_hp = torch.cat((x1_hp * 1e9, x1_hp * 1.0, x1_hp * 1e-9), 0)
-            fp8_over_underflow = ToFloat8ConstrFunc.apply(
+            fp8_over_underflow = hp_tensor_and_scale_to_float8(
                 x2_hp, torch.tensor(1.0), lp_dtype
             )
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype)
