@@ -83,7 +83,7 @@ def world_size(self) -> int:
     def test_transformer_parity(self):
         self.run_subtests(
             {
-                "enable_fsdp_fp8_all_gather": [False, True],
+                "enable_fsdp_float8_all_gather": [False, True],
                 "precompute": [False, True],
                 "scaling_type_weight": [
                     TensorScalingType.DYNAMIC,
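The change above is the whole PR in miniature: the `Float8LinearConfig` flag `enable_fsdp_fp8_all_gather` is renamed to `enable_fsdp_float8_all_gather`, and every test call site follows suit. A minimal before/after sketch of a call site (the import path is an assumption and may differ between float8_experimental and torchao.float8):

# Hypothetical call-site migration for the renamed config flag.
# The import path below is assumed; match it to the float8 library version in use.
from float8_experimental.config import Float8LinearConfig

# Before this PR:
#   config = Float8LinearConfig(enable_fsdp_fp8_all_gather=True)
# After this PR, the flag spells out "float8":
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)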
@@ -96,12 +96,12 @@ def test_transformer_parity(self):
 
     def _test_transformer_parity(
         self,
-        enable_fsdp_fp8_all_gather: bool,
+        enable_fsdp_float8_all_gather: bool,
         precompute: bool,
         scaling_type_weight: TensorScalingType,
         compile_transformer_block: bool,
     ):
-        if not enable_fsdp_fp8_all_gather and precompute:
+        if not enable_fsdp_float8_all_gather and precompute:
             return
         elif scaling_type_weight is TensorScalingType.DELAYED and precompute:
             return
@@ -110,7 +110,7 @@ def _test_transformer_parity(
         # embedding weight and output linear weight are tied but only the
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
-        weight_tying = not enable_fsdp_fp8_all_gather
+        weight_tying = not enable_fsdp_float8_all_gather
         module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
         float8_linear_config1 = Float8LinearConfig(
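The comment in the hunk above explains why weight tying is disabled whenever float8 all-gather is on: the tied parameter would be pre-cast to fp8 for the output linear, and the embedding would then incorrectly run with the fp8 weight as well. A hypothetical tied-weight module, only to illustrate what is being avoided (the test's real model comes from `self.init_transformer(weight_tying=...)`):

import torch.nn as nn

# Illustrative only: a tiny tied-weight language-model head, not the test's Transformer.
class TinyTiedLM(nn.Module):
    def __init__(self, vocab_size: int = 128, dim: int = 32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        # Weight tying: the embedding lookup and the output projection share storage.
        # If FSDP all-gathers this parameter in float8 for the Linear, the Embedding
        # would see the float8 weight too, which is exactly the case the test avoids.
        self.output.weight = self.embedding.weight

    def forward(self, tokens):
        return self.output(self.embedding(tokens))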
@@ -125,7 +125,7 @@ def _test_transformer_parity(
                 transformer_block = torch.compile(transformer_block, dynamic=False)
             ref_module.layers.register_module(layer_id, transformer_block)
         float8_linear_config2 = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
             cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
         )
         convert_to_float8_training(
@@ -158,10 +158,10 @@ def _test_transformer_parity(
158
158
@skip_if_lt_x_gpu (2 )
159
159
def test_transformer_memory (self ):
160
160
"""Tests peak active memory in the forward and backward passes."""
161
- for enable_fsdp_fp8_all_gather in [False , True ]:
162
- self ._test_transformer_memory (enable_fsdp_fp8_all_gather )
161
+ for enable_fsdp_float8_all_gather in [False , True ]:
162
+ self ._test_transformer_memory (enable_fsdp_float8_all_gather )
163
163
164
- def _test_transformer_memory (self , enable_fsdp_fp8_all_gather : bool ):
164
+ def _test_transformer_memory (self , enable_fsdp_float8_all_gather : bool ):
165
165
torch .manual_seed (42 )
166
166
# Pre-run a linear forward (gemm and bias) and backward (gemm) to
167
167
# allocate the cuBLAS workspaces before measuring the memory usage
@@ -184,7 +184,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
             emulate=True,
         )
         convert_to_float8_training(model, config=float8_linear_config)
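The memory test converts its model with `emulate=True` so that small activation shapes pass the scaled-matmul divisibility requirement. For orientation, a minimal sketch of that conversion path outside the test harness (import paths and the toy model are assumptions; the real test builds its model and applies FSDP through its own helpers):

import torch.nn as nn

# Import paths are assumptions; adjust to the float8 library version in use
# (e.g. float8_experimental or torchao.float8).
from float8_experimental.config import Float8LinearConfig
from float8_experimental.float8_linear_utils import convert_to_float8_training

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,  # renamed flag exercised by these tests
    emulate=True,  # emulate the fp8 matmul so toy shapes work, as in the memory test
)
# Swaps the nn.Linear modules for their float8 training equivalents in place.
convert_to_float8_training(model, config=config)
# The test then applies FSDP sharding (not shown in this diff) before measuring memory.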
@@ -231,7 +231,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # number is kept much smaller than the actual memory usage, which is on
         # the order of 100-200+ MB)
         buffer_mb = 16
-        if enable_fsdp_fp8_all_gather:
+        if enable_fsdp_float8_all_gather:
             # Non-block parameters (fp32), 3x block non-linear-weight
             # parameters (fp32) and block linear-weight parameters (fp8)
             # (current all-gather, copy-out, and next all-gather), and other
@@ -255,7 +255,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Backward:
         loss.sum().backward()
         mem_mb = self._get_peak_active_memory_mb()
-        if enable_fsdp_fp8_all_gather:
+        if enable_fsdp_float8_all_gather:
             # Non-block parameters (fp32), 2x block non-linear weight
             # parameters (fp32) and block linear-weight parameters (fp8)
             # (current copy-out and next all-gather), 1x block gradients (fp32)
@@ -294,7 +294,7 @@ def test_weight_subclass_dynamic(self):
         # Check for a single FSDP paramter group
         module_fp32 = self.init_single_module()
         float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=True,
+            enable_fsdp_float8_all_gather=True,
             emulate=True,
         )
         module = convert_to_float8_training(
@@ -360,7 +360,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=True,
+            enable_fsdp_float8_all_gather=True,
         )
         module_fp32 = convert_to_float8_training(
             module_fp32, config=float8_linear_config
@@ -418,15 +418,15 @@ def test_fp32_fp8_single_module_parity(self):
             [False, True],
             [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED],
         )
-        for enable_fsdp_fp8_all_gather, scaling_type_weight in choices:
+        for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
             float8_linear_config1 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=False,
+                enable_fsdp_float8_all_gather=False,
                 cast_config_weight=Float8TensorCastConfig(
                     scaling_type=scaling_type_weight
                 ),
             )
             float8_linear_config2 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
                 cast_config_weight=Float8TensorCastConfig(
                     scaling_type=scaling_type_weight
                 ),
@@ -466,15 +466,15 @@ def test_fp32_fp8_multi_module_parity(self):
             [False, True],
             [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED],
         )
-        for enable_fsdp_fp8_all_gather, scaling_type_weight in choices:
+        for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
             float8_linear_config1 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=False,
+                enable_fsdp_float8_all_gather=False,
                 cast_config_weight=Float8TensorCastConfig(
                     scaling_type=scaling_type_weight
                 ),
             )
             float8_linear_config2 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
                 cast_config_weight=Float8TensorCastConfig(
                     scaling_type=scaling_type_weight
                 ),
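Both parity tests above sweep the same cross product of the renamed flag and the weight scaling type. A compact sketch of that sweep, assuming `choices` is built with `itertools.product` as the truncated context suggests (the import path for the config types is likewise an assumption):

import itertools

# Names follow this diff; the module they live in is assumed.
from float8_experimental.config import (
    Float8LinearConfig,
    Float8TensorCastConfig,
    TensorScalingType,
)

choices = itertools.product(
    [False, True],
    [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED],
)
for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
    # Reference config: float8 all-gather off, same weight scaling type.
    reference_config = Float8LinearConfig(
        enable_fsdp_float8_all_gather=False,
        cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
    )
    # Config under test: float8 all-gather toggled by the loop variable.
    test_config = Float8LinearConfig(
        enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
        cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
    )
    # The tests convert a reference module with reference_config and an FSDP
    # module with test_config, then check that their numerics match.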
@@ -545,7 +545,7 @@ def test_delayed_scaling_inplace_update(self):
         """
         module = self.init_single_module()
         float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=True,
+            enable_fsdp_float8_all_gather=True,
            cast_config_weight=Float8TensorCastConfig(
                scaling_type=TensorScalingType.DELAYED
            ),