This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit eff4ba6

vkuzo authored and facebook-github-bot committed
rename config.enable_fsdp_fp8_all_gather to use float8 (#332)
Summary:
Pull Request resolved: #332

old: `enable_fsdp_fp8_all_gather`
new: `enable_fsdp_float8_all_gather`

this is to match the `float8` naming elsewhere

Reviewed By: weifengpy

Differential Revision: D60252072

fbshipit-source-id: 5e240f0a97b647aa4f43a63dab3f03f68fd3b405
1 parent 701647b commit eff4ba6
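For reference, a minimal sketch of the renamed flag in use, based on the `Float8LinearConfig` and `convert_to_float8_training` names that appear in the diffs below; the import paths and module setup are assumptions, not confirmed by this commit:

# Sketch only: the float8_experimental import paths below are assumed.
import torch.nn as nn

from float8_experimental.config import Float8LinearConfig  # path assumed
from float8_experimental.float8_linear_utils import convert_to_float8_training  # path assumed

# old (removed by this commit): Float8LinearConfig(enable_fsdp_fp8_all_gather=True)
# new:
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
model = convert_to_float8_training(model, config=config)
# Wrap with FSDP2 (fully_shard) afterwards so the float8 all-gather path is exercised.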

3 files changed: +24 −25 lines

float8_experimental/config.py

Lines changed: 3 additions & 4 deletions
@@ -58,10 +58,9 @@ class Float8LinearConfig:
     # option is useful for safety, but not strictly necessary.
     enable_pre_and_post_forward: bool = True

-    # If True, then uses a tensor subclass for the fp8 linear module's weight that
-    # implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
-    # Only dynamic scaling is supported for now.
-    enable_fsdp_fp8_all_gather: bool = False
+    # If True, then uses a tensor subclass for the float8 linear module's weight that
+    # implements pre/post-all-gather methods to do float8 all-gather with FSDP2.
+    enable_fsdp_float8_all_gather: bool = False

     # If True, then prior to performing the fp8 scaled mamtmul we will pad the
     # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls

float8_experimental/float8_linear.py

Lines changed: 1 addition & 1 deletion
@@ -467,7 +467,7 @@ def from_float(
         # 1. weight needs to be on the correct device to create the buffers
         # 2. buffers need to be already created for the delayed scaling version
         # of the weight wrapper to be initialized
-        if config.enable_fsdp_fp8_all_gather:
+        if config.enable_fsdp_float8_all_gather:
             if config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC:
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDynamicFloat8CastTensor(

test/test_fsdp2/test_fsdp2.py

Lines changed: 20 additions & 20 deletions
@@ -83,7 +83,7 @@ def world_size(self) -> int:
     def test_transformer_parity(self):
         self.run_subtests(
             {
-                "enable_fsdp_fp8_all_gather": [False, True],
+                "enable_fsdp_float8_all_gather": [False, True],
                 "precompute": [False, True],
                 "scaling_type_weight": [
                     TensorScalingType.DYNAMIC,
@@ -96,12 +96,12 @@ def test_transformer_parity(self):

     def _test_transformer_parity(
         self,
-        enable_fsdp_fp8_all_gather: bool,
+        enable_fsdp_float8_all_gather: bool,
         precompute: bool,
         scaling_type_weight: TensorScalingType,
         compile_transformer_block: bool,
     ):
-        if not enable_fsdp_fp8_all_gather and precompute:
+        if not enable_fsdp_float8_all_gather and precompute:
             return
         elif scaling_type_weight is TensorScalingType.DELAYED and precompute:
             return
@@ -110,7 +110,7 @@ def _test_transformer_parity(
         # embedding weight and output linear weight are tied but only the
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
-        weight_tying = not enable_fsdp_fp8_all_gather
+        weight_tying = not enable_fsdp_float8_all_gather
         module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
         float8_linear_config1 = Float8LinearConfig(
@@ -125,7 +125,7 @@ def _test_transformer_parity(
                 transformer_block = torch.compile(transformer_block, dynamic=False)
                 ref_module.layers.register_module(layer_id, transformer_block)
         float8_linear_config2 = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
             cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
         )
         convert_to_float8_training(
@@ -158,10 +158,10 @@ def _test_transformer_parity(
     @skip_if_lt_x_gpu(2)
     def test_transformer_memory(self):
         """Tests peak active memory in the forward and backward passes."""
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_memory(enable_fsdp_fp8_all_gather)
+        for enable_fsdp_float8_all_gather in [False, True]:
+            self._test_transformer_memory(enable_fsdp_float8_all_gather)

-    def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
+    def _test_transformer_memory(self, enable_fsdp_float8_all_gather: bool):
         torch.manual_seed(42)
         # Pre-run a linear forward (gemm and bias) and backward (gemm) to
         # allocate the cuBLAS workspaces before measuring the memory usage
@@ -184,7 +184,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
             emulate=True,
         )
         convert_to_float8_training(model, config=float8_linear_config)
@@ -231,7 +231,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # number is kept much smaller than the actual memory usage, which is on
         # the order of 100-200+ MB)
         buffer_mb = 16
-        if enable_fsdp_fp8_all_gather:
+        if enable_fsdp_float8_all_gather:
             # Non-block parameters (fp32), 3x block non-linear-weight
             # parameters (fp32) and block linear-weight parameters (fp8)
             # (current all-gather, copy-out, and next all-gather), and other
@@ -255,7 +255,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Backward:
         loss.sum().backward()
         mem_mb = self._get_peak_active_memory_mb()
-        if enable_fsdp_fp8_all_gather:
+        if enable_fsdp_float8_all_gather:
             # Non-block parameters (fp32), 2x block non-linear weight
             # parameters (fp32) and block linear-weight parameters (fp8)
             # (current copy-out and next all-gather), 1x block gradients (fp32)
@@ -294,7 +294,7 @@ def test_weight_subclass_dynamic(self):
        # Check for a single FSDP paramter group
        module_fp32 = self.init_single_module()
        float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=True,
+            enable_fsdp_float8_all_gather=True,
            emulate=True,
        )
        module = convert_to_float8_training(
@@ -360,7 +360,7 @@ def get_expected_all_gather_size(module: nn.Module):
        module_fp32 = self.init_single_module()
        ref_module = copy.deepcopy(module_fp32)
        float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=True,
+            enable_fsdp_float8_all_gather=True,
        )
        module_fp32 = convert_to_float8_training(
            module_fp32, config=float8_linear_config
@@ -418,15 +418,15 @@ def test_fp32_fp8_single_module_parity(self):
            [False, True],
            [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED],
        )
-        for enable_fsdp_fp8_all_gather, scaling_type_weight in choices:
+        for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
            float8_linear_config1 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=False,
+                enable_fsdp_float8_all_gather=False,
                cast_config_weight=Float8TensorCastConfig(
                    scaling_type=scaling_type_weight
                ),
            )
            float8_linear_config2 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
                cast_config_weight=Float8TensorCastConfig(
                    scaling_type=scaling_type_weight
                ),
@@ -466,15 +466,15 @@ def test_fp32_fp8_multi_module_parity(self):
            [False, True],
            [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED],
        )
-        for enable_fsdp_fp8_all_gather, scaling_type_weight in choices:
+        for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
            float8_linear_config1 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=False,
+                enable_fsdp_float8_all_gather=False,
                cast_config_weight=Float8TensorCastConfig(
                    scaling_type=scaling_type_weight
                ),
            )
            float8_linear_config2 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
                cast_config_weight=Float8TensorCastConfig(
                    scaling_type=scaling_type_weight
                ),
@@ -545,7 +545,7 @@ def test_delayed_scaling_inplace_update(self):
        """
        module = self.init_single_module()
        float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=True,
+            enable_fsdp_float8_all_gather=True,
            cast_config_weight=Float8TensorCastConfig(
                scaling_type=TensorScalingType.DELAYED
            ),
