This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit a6cef5a

vkuzo authored and facebook-github-bot committed
clarify public API of float8_experimental (#330)
Summary:

Pull Request resolved: #330

Makes the following functions public:
* convert_to_float8_training and all of its configuration
* linear_requires_sync
* sync_float8_amax_and_scale_history
* precompute_float8_dynamic_scale_for_fsdp

Everything else is private. The fbsource counterpart of this PR will remove usage of private APIs.

Reviewed By: weifengpy

Differential Revision: D60195666

fbshipit-source-id: 2e99475cf7f852b91b4c96687a7f229a2c8b3adf
1 parent da487a3 commit a6cef5a
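To see what the change means in practice, here is a minimal sketch of the newly public surface; `Model`, `x`, `optimizer`, and `N_ITER` are placeholders for a user's own training setup (not part of this commit), and `convert_to_float8_training`'s optional configuration is omitted:

```python
# Minimal sketch, not from this commit: Model, x, optimizer, and N_ITER
# are placeholders for a user's own training setup.
from float8_experimental import (
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
)

m = Model(...)  # any module tree containing torch.nn.Linear layers

# swap eligible nn.Linear modules for float8 training linears
convert_to_float8_training(m)

for _ in range(N_ITER):
    optimizer.zero_grad()
    m(x).sum().backward()
    # needed once per step when a delayed-scaling config is used;
    # linear_requires_sync is also public after this change as a guard
    sync_float8_amax_and_scale_history(m)
    optimizer.step()
```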

File tree

4 files changed, +22 -8 lines changed


README.md

Lines changed: 4 additions & 4 deletions
````diff
@@ -35,10 +35,10 @@ We provide two per-tensor scaling strategies: dynamic and delayed. See https://
 This is the most accurate recipe as every tensor is scaled dynamically.
 
 ```python
-from float8_experimental.float8_linear_utils import (
+from float8_experimental import (
     convert_to_float8_training,
+    precompute_float8_dynamic_scale_for_fsdp,
 )
-from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 
 # create model
 m = Model(...)
````
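For reference, the dynamic-scaling snippet that results from this hunk would read roughly as below; the FSDP wrapping and training loop are hedged placeholders based on the surrounding README, not part of the diff:

```python
from float8_experimental import (
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

# create model (placeholder)
m = Model(...)

# convert eligible torch.nn.Linear modules to float8 with dynamic scaling
convert_to_float8_training(m)

# wrap with FSDP and train as usual (details assumed, not shown in the diff)
m = FSDP(m, use_orig_params=True)
for _ in range(N_ITER):
    optimizer.zero_grad()
    m(x).sum().backward()
    optimizer.step()
    # optional performance optimization: precompute fp8 scales for FSDP
    # weights once per step instead of per-layer
    precompute_float8_dynamic_scale_for_fsdp(m)
```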
````diff
@@ -82,11 +82,11 @@ for _ in range(N_ITER):
 This is theoretically the most performant recipe as it minimizes memory reads.
 
 ```python
-from float8_experimental.float8_linear_utils import (
+from float8_experimental import (
     convert_to_float8_training,
     sync_float8_amax_and_scale_history,
+    TensorScalingType,
 )
-from float8_experimental.float8_linear import TensorScalingType
 
 # create model
 m = Model(...)
````
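Similarly, the delayed-scaling snippet after this hunk would look roughly like the following; the keyword names passed to `convert_to_float8_training` are an assumption inferred from `TensorScalingType` joining this import block and may not match the repo's exact signature:

```python
from float8_experimental import (
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
    TensorScalingType,
)

# create model (placeholder)
m = Model(...)

# assumed configuration keywords; the real signature lives in
# float8_experimental.float8_linear_utils at this commit
convert_to_float8_training(
    m,
    scaling_type_x=TensorScalingType.DELAYED,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DELAYED,
)

for _ in range(N_ITER):
    optimizer.zero_grad()
    m(x).sum().backward()
    # with delayed scaling, amax/scale history must be synced every step
    sync_float8_amax_and_scale_history(m)
    optimizer.step()
```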

float8_experimental/__init__.py

Lines changed: 10 additions & 4 deletions
```diff
@@ -10,13 +10,18 @@
     TensorScalingType,
 )
 from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_linear_utils import convert_to_float8_training
+from float8_experimental.float8_linear_utils import (
+    convert_to_float8_training,
+    linear_requires_sync,
+    sync_float8_amax_and_scale_history,
+)
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,
 )
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 
 # Needed to load Float8Tensor with weights_only = True
 from torch.serialization import add_safe_globals
@@ -30,7 +35,8 @@
     "Float8TensorCastConfig",
     # top level UX
     "convert_to_float8_training",
-    # TODO(future): remove Float8Tensor and Float8Linear from public API
-    "Float8Tensor",
-    "Float8Linear",
+    "linear_requires_sync",
+    "sync_float8_amax_and_scale_history",
+    "precompute_float8_dynamic_scale_for_fsdp",
+    # note: Float8Tensor and Float8Linear are not public APIs
 ]
```
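A quick way to check the effect of this `__all__` change, sketched using only names visible in the diff above:

```python
import float8_experimental

# the four entry points this commit makes public
for name in (
    "convert_to_float8_training",
    "linear_requires_sync",
    "sync_float8_amax_and_scale_history",
    "precompute_float8_dynamic_scale_for_fsdp",
):
    assert name in float8_experimental.__all__
```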

float8_experimental/float8_linear.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -150,6 +150,10 @@ def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
 
 class Float8Linear(torch.nn.Linear):
     """
+    Note: this is **not** a public API and is only intended to be used
+    inside of this repository. Please file an issue if you would benefit
+    from this being a public API.
+
     A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks
     scales in way friendly to delayed scaling.
     """
```

float8_experimental/float8_tensor.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -252,6 +252,10 @@ def backward(ctx, g):
 
 class Float8Tensor(torch.Tensor):
     """
+    Note: this is **not** a public API and is only intended to be used
+    inside of this repository. Please file an issue if you would benefit
+    from this being a public API.
+
     A Python-only Float8 tensor subclass. Contains:
     * `_data`: the underlying e4m3 or e5m2 data
     * `_scale`: the scale used to scale the original fp32 tensor. We multiply
```
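To make the `_data`/`_scale` pairing concrete, here is a standalone sketch of the scaled cast the docstring describes, written against plain PyTorch float8 dtypes rather than this repo's internal helpers (which may clamp or round differently):

```python
import torch

x = torch.randn(16, 16, dtype=torch.float32)

# pick a scale that maps the tensor's absolute max near the e4m3 max (448.0)
amax = x.abs().max()
scale = torch.tensor(448.0) / torch.clamp(amax, min=1e-12)

# quantize: multiply the fp32 tensor by the scale, then cast to float8
data = (x * scale).to(torch.float8_e4m3fn)  # plays the role of `_data`

# dequantize: cast back to fp32 and divide by the same scale
x_approx = data.to(torch.float32) / scale   # `_scale` undoes the cast

print((x - x_approx).abs().max())  # small quantization error
```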
