Commit 60df6e2

update float8 integration after UX changes
Summary: float8_experimental landed various BC-breaking UX changes last week. This PR updates torchtitan to work with the version of float8_experimental after pytorch-labs/float8_experimental#332.

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 0f70507 commit 60df6e2

File tree

1 file changed (+15, -24 lines)

torchtitan/float8_linear.py

Lines changed: 15 additions & 24 deletions
```diff
@@ -24,20 +24,6 @@
 from torchtitan.logging_utils import logger
 
 
-@contextlib.contextmanager
-def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool):
-    import float8_experimental.config as config
-
-    prev = config.enable_fsdp_fp8_all_gather
-    torch.distributed.barrier()
-    config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
-    try:
-        yield
-    finally:
-        torch.distributed.barrier()
-        config.enable_fsdp_fp8_all_gather = prev
-
-
 @functools.lru_cache(None)
 def is_sm90_or_later():
     # Float8 is only supported on H100+ GPUs
```
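For context, the helper deleted above toggled a process-global flag in float8_experimental.config, bracketed by barriers to keep ranks in sync. After this change the same setting travels with the conversion call, so no global state or synchronization is needed. A minimal sketch, assuming the post-change float8_experimental API shown in the next hunk:

```python
from float8_experimental import Float8LinearConfig

# The FSDP fp8 all-gather behavior is now an explicit per-conversion config
# field rather than a mutation of module-level global state.
float8_config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
```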
```diff
@@ -63,21 +49,26 @@ def maybe_build_fp8_linear(
         )
         return
     try:
-        from float8_experimental.float8_linear import TensorScalingType
-        from float8_experimental.float8_linear_utils import (
-            swap_linear_with_float8_linear,
+        from float8_experimental import (
+            Float8LinearConfig,
+            Float8TensorCastConfig,
+            TensorScalingType,
+            convert_to_float8_training,
         )
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
             job_config.training.enable_fsdp_float8_all_gather and dp_enabled
         )
-        with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
-            swap_linear_with_float8_linear(
-                model,
-                scaling_type_w=TensorScalingType.DYNAMIC,
-                skip_fqn_list=["output"],
-            )
+        float8_config = Float8LinearConfig(
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+            cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DYNAMIC),
+        )
+        convert_to_float8_training(
+            model,
+            config=float8_config,
+            module_filter_fn=lambda mod, fqn: fqn != "output",
+        )
         logger.info(
             f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
         )
```
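Put together, the new conversion path looks roughly like the sketch below. The API names mirror the hunk above; the toy model and its "output" submodule are illustrative assumptions standing in for torchtitan's real model and output layer.

```python
import torch.nn as nn

from float8_experimental import (
    Float8LinearConfig,
    Float8TensorCastConfig,
    TensorScalingType,
    convert_to_float8_training,
)

# Toy stand-in for the real model; the "output" child mirrors the layer
# torchtitan excludes via module_filter_fn.
model = nn.Sequential()
model.add_module("layers", nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256)))
model.add_module("output", nn.Linear(256, 128))

float8_config = Float8LinearConfig(
    # In torchtitan this is only enabled when FSDP data parallelism is active.
    enable_fsdp_float8_all_gather=False,
    cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DYNAMIC),
)

# Swaps eligible nn.Linear modules for Float8Linear in place, skipping "output".
# Actually running the converted model in fp8 still requires an H100+ GPU (SM90).
convert_to_float8_training(
    model,
    config=float8_config,
    module_filter_fn=lambda mod, fqn: fqn != "output",
)
```

Note that `module_filter_fn` subsumes the old `skip_fqn_list=["output"]`, and `cast_config_weight` replaces the old `scaling_type_w` keyword.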
```diff
@@ -102,6 +93,6 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
             "Skipped precomputing fp8 scales because SM90 or later is not available",
         )
         return
-    from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
+    from float8_experimental import precompute_float8_dynamic_scale_for_fsdp
 
     precompute_float8_dynamic_scale_for_fsdp(model)
```
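The companion precompute utility is now imported from the package root as well. A rough sketch of how it could slot into a training step; the surrounding function and the placement after the optimizer step are assumptions for illustration, not shown in this diff.

```python
from float8_experimental import precompute_float8_dynamic_scale_for_fsdp

def train_step(model, optimizer, batch):
    # Hypothetical step over an FSDP-wrapped, float8-converted model.
    loss = model(batch).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Assumed placement: once per iteration after the optimizer step, so the
    # next iteration's fp8 all-gather can use freshly computed dynamic scales
    # (only relevant when enable_fsdp_float8_all_gather is on).
    precompute_float8_dynamic_scale_for_fsdp(model)
```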
