This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit da487a3

vkuzo authored and facebook-github-bot committed
rename top level UX to convert_to_float8_training (#329)
Summary: Pull Request resolved: #329

Old name: `swap_linear_with_float8_linear`
New name: `convert_to_float8_training`

Choosing a more generic name, with the following improvements over the old name:
1. doesn't mention module swaps, which is an implementation detail
2. doesn't mention `Float8Linear`, which is an implementation detail
3. clarifies that this is for training, not to be confused with inference APIs
4. doesn't mention `linear`, which gives more freedom to add other modules later

```
find . -name '*.py' -print0 | xargs -0 sed -i 's/swap_linear_with_float8_linear/convert_to_float8_training/g'
```

Reviewed By: weifengpy

Differential Revision: D60195665

fbshipit-source-id: 8157b3d6f5db36c33370014135cbadfd192ac5b4
1 parent e1c5fe1 commit da487a3

14 files changed: +52 -52 lines changed
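
Before the per-file diffs, here is the rename at a call site. A minimal sketch assuming a toy `nn.Sequential` model; only the name of the top-level API changes in this commit, its keyword arguments do not.

```python
import torch.nn as nn

# old top-level name (removed in this commit):
# from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
# swap_linear_with_float8_linear(model)

# new top-level name (added in this commit):
from float8_experimental.float8_linear_utils import convert_to_float8_training

# toy model, illustrative only
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

# swaps eligible `torch.nn.Linear` children to `Float8Linear`, same behavior as before
convert_to_float8_training(model)
```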

README.md
Lines changed: 4 additions & 4 deletions

@@ -36,7 +36,7 @@ This is the most accurate recipe as every tensor is scaled dynamically.
 
 ```python
 from float8_experimental.float8_linear_utils import (
-    swap_linear_with_float8_linear,
+    convert_to_float8_training,
 )
 from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 
@@ -55,7 +55,7 @@ def module_filter_fn(mod: torch.nn.Module, fqn: str):
     return True
 
 # convert all `torch.nn.Linear` modules to `Float8Linear`
-swap_linear_with_float8_linear(m, module_filter_fn=module_filter_fn)
+convert_to_float8_training(m, module_filter_fn=module_filter_fn)
 
 # optional: use FSDP
 model = FSDP(model, use_orig_params=True)
@@ -83,7 +83,7 @@ This is theoretically the most performant recipe as it minimizes memory reads.
 
 ```python
 from float8_experimental.float8_linear_utils import (
-    swap_linear_with_float8_linear,
+    convert_to_float8_training,
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_linear import TensorScalingType
@@ -106,7 +106,7 @@ config = Float8LinearConfig(
 
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
 # type
-swap_linear_with_float8_linear(
+convert_to_float8_training(
     m,
     config=config,
 )
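
To see the updated dynamic-scaling recipe end to end, a minimal sketch is below. The toy model, its dimensions, the specific `module_filter_fn`, and `emulate=True` (so the sketch does not require FP8-capable hardware) are assumptions for illustration and are not part of the README itself.

```python
import torch
import torch.nn as nn

from float8_experimental import Float8LinearConfig
from float8_experimental.float8_linear_utils import convert_to_float8_training

# toy model; the names and sizes here are illustrative
m = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))

def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
    # example policy: keep the final projection ("2") in high precision
    return fqn != "2"

# convert all eligible `torch.nn.Linear` modules to `Float8Linear` under the new name
convert_to_float8_training(
    m,
    module_filter_fn=module_filter_fn,
    config=Float8LinearConfig(emulate=True),
)

print(m)  # layer "0" is now Float8Linear, layer "2" is left as nn.Linear
```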

benchmarks/bench_multi_gpu.py
Lines changed: 2 additions & 2 deletions

@@ -20,7 +20,7 @@
     TensorScalingType,
 )
 from float8_experimental.float8_linear_utils import (
-    swap_linear_with_float8_linear,
+    convert_to_float8_training,
     sync_float8_amax_and_scale_history,
 )
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -77,7 +77,7 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32):
         modules.append(nn.ReLU())
     m = nn.Sequential(*modules)
     if is_fp8:
-        swap_linear_with_float8_linear(
+        convert_to_float8_training(
            m,
            config=config,
        )

benchmarks/profile_linear_float8.py
Lines changed: 2 additions & 2 deletions

@@ -24,8 +24,8 @@
     TensorScalingType,
 )
 from float8_experimental.float8_linear_utils import (
+    convert_to_float8_training,
     linear_requires_sync,
-    swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
 from torch.profiler import profile, ProfilerActivity, record_function
@@ -268,7 +268,7 @@ def main(
     m_ref = m_ref.to(device).to(ref_dtype)
 
     m_float8 = copy.deepcopy(m_ref)
-    swap_linear_with_float8_linear(m_float8, config=config)
+    convert_to_float8_training(m_float8, config=config)
 
     def ref_forw_backward(x):
         out = m_ref(x)

float8_experimental/__init__.py
Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
     TensorScalingType,
 )
 from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
+from float8_experimental.float8_linear_utils import convert_to_float8_training
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -29,7 +29,7 @@
     "Float8LinearConfig",
     "Float8TensorCastConfig",
     # top level UX
-    "swap_linear_with_float8_linear",
+    "convert_to_float8_training",
     # TODO(future): remove Float8Tensor and Float8Linear from public API
     "Float8Tensor",
     "Float8Linear",

float8_experimental/float8_linear_utils.py
Lines changed: 3 additions & 3 deletions

@@ -61,7 +61,7 @@ def swap_linear_layers(
     from_float_func: Callable[[nn.Linear], nn.Linear],
     *,
     module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
-) -> Optional[nn.Module]:
+) -> nn.Module:
     """
     Generic function to swap linear layers in a module with a new type of linear layer.
 
@@ -122,12 +122,12 @@ def post_order_traversal(
     return root_module
 
 
-def swap_linear_with_float8_linear(
+def convert_to_float8_training(
     module: nn.Module,
     *,
     module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
     config: Float8LinearConfig = None,
-) -> Optional[nn.Module]:
+) -> nn.Module:
     """
     Swaps `torch.nn.Linear` in `module` with `Float8Linear`.

float8_experimental/inference.py
Lines changed: 1 addition & 1 deletion

@@ -215,7 +215,7 @@ def quantize_to_float8(
     *,
     module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
     use_fast_accum: bool = True,
-) -> Optional[nn.Module]:
+) -> nn.Module:
     """
     Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.

test/test_base.py
Lines changed: 6 additions & 6 deletions

@@ -23,8 +23,8 @@
 )
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
+    convert_to_float8_training,
     linear_requires_sync,
-    swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_python_api import addmm_float8_unwrapped
@@ -604,7 +604,7 @@ def test_swap_root_linear(self):
         for emulate in [True, False]:
             module = nn.Linear(3, 3)
             config = Float8LinearConfig(emulate=emulate)
-            module = swap_linear_with_float8_linear(module, config=config)
+            module = convert_to_float8_training(module, config=config)
             self.assertIsInstance(module, Float8Linear)
             self.assertEqual(module.linear_mm_config.y.emulate, emulate)
             self.assertEqual(module.linear_mm_config.y.emulate, emulate)
@@ -618,7 +618,7 @@ def test_swap_root_linear_with_children_raises(self):
                 AssertionError,
                 "Does not support a root nn.Linear with children",
             ):
-                swap_linear_with_float8_linear(module, config=config)
+                convert_to_float8_training(module, config=config)
 
     def test_swap_submodule_linears(self):
         class MLP(nn.Module):
@@ -630,7 +630,7 @@ def __init__(self, dim: int):
         for emulate in [True, False]:
             model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
             config = Float8LinearConfig(emulate=emulate)
-            model = swap_linear_with_float8_linear(model, config=config)
+            model = convert_to_float8_training(model, config=config)
             self.assertIsInstance(model[0].lin1, Float8Linear)
             self.assertIsInstance(model[0].lin2, Float8Linear)
             self.assertIsInstance(model[1], Float8Linear)
@@ -658,7 +658,7 @@ def module_filter_fn(mod, fqn):
         )
 
         config = Float8LinearConfig(emulate=True)
-        model = swap_linear_with_float8_linear(
+        model = convert_to_float8_training(
             model,
             config=config,
             module_filter_fn=module_filter_fn,
@@ -687,7 +687,7 @@ def __init__(self, dim: int):
             "2.lin1",
         ]
         config = Float8LinearConfig(emulate=True)
-        model = swap_linear_with_float8_linear(
+        model = convert_to_float8_training(
             model,
             config=config,
             module_filter_fn=module_filter_fn,

test/test_compile.py
Lines changed: 3 additions & 3 deletions

@@ -20,8 +20,8 @@
 )
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
+    convert_to_float8_training,
     get_float8_layers,
-    swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_tensor import Float8Tensor, LinearMMConfig
@@ -280,7 +280,7 @@ def test_sync_amax_func():
             scaling_type=TensorScalingType.DELAYED
         ),
     )
-    float8_mod = swap_linear_with_float8_linear(
+    float8_mod = convert_to_float8_training(
        module,
        config=config,
    )
@@ -324,7 +324,7 @@ def test_sync_amax_func_cuda_graph_success():
             scaling_type=TensorScalingType.DELAYED
         ),
     )
-    swap_linear_with_float8_linear(
+    convert_to_float8_training(
        my_module,
        config=config,
    )

test/test_dtensor.py
Lines changed: 5 additions & 5 deletions

@@ -16,7 +16,7 @@
 from float8_experimental import Float8LinearConfig
 
 from float8_experimental.float8_dynamic_utils import NoopFwToFloat8E5M2Bw
-from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
+from float8_experimental.float8_linear_utils import convert_to_float8_training
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -187,12 +187,12 @@ def _test_fp8_mlp_tensor_parallelism_base(
     config = Float8LinearConfig(emulate=True)
 
     toy_model = ToyModel().to(device)
-    toy_model_fp8 = swap_linear_with_float8_linear(toy_model, config=config)
+    toy_model_fp8 = convert_to_float8_training(toy_model, config=config)
 
     tp_model = copy.deepcopy(toy_model)
-    tp_model = swap_linear_with_float8_linear(tp_model, config=config)
+    tp_model = convert_to_float8_training(tp_model, config=config)
     sp_model = copy.deepcopy(toy_model)
-    sp_model = swap_linear_with_float8_linear(sp_model, config=config)
+    sp_model = convert_to_float8_training(sp_model, config=config)
 
     # vanilla TP
     tp_model = parallelize_module(
@@ -223,7 +223,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
 
     # PrepareFloat8ModuleInput with specific submodule fqn
     sp_model2 = copy.deepcopy(toy_model)
-    sp_model2 = swap_linear_with_float8_linear(sp_model2, config=config)
+    sp_model2 = convert_to_float8_training(sp_model2, config=config)
 
     sp_model2 = parallelize_module(
         sp_model2,

test/test_fsdp.py
Lines changed: 2 additions & 2 deletions

@@ -27,8 +27,8 @@
     TensorScalingType,
 )
 from float8_experimental.float8_linear_utils import (
+    convert_to_float8_training,
     linear_requires_sync,
-    swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_utils import compute_error
@@ -90,7 +90,7 @@ def fsdp_main(rank, world_size, args):
 
     # Note: we only iterate over `scaling_type_weight` because FSDP only interacts
     # with weights.
-    swap_linear_with_float8_linear(
+    convert_to_float8_training(
         model_fp8,
         config=config,
     )

test/test_fsdp2/test_fsdp2.py
Lines changed: 15 additions & 15 deletions

@@ -13,7 +13,7 @@
     Float8TensorCastConfig,
     TensorScalingType,
 )
-from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
+from float8_experimental.float8_linear_utils import convert_to_float8_training
 from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from test_fsdp2_common import check_parity_bf16_mp, check_parity_no_mp
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
@@ -116,7 +116,7 @@ def _test_transformer_parity(
         float8_linear_config1 = Float8LinearConfig(
             cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
         )
-        swap_linear_with_float8_linear(
+        convert_to_float8_training(
             ref_module,
             config=float8_linear_config1,
         )
@@ -128,7 +128,7 @@ def _test_transformer_parity(
             enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
             cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
         )
-        swap_linear_with_float8_linear(
+        convert_to_float8_training(
             module,
             config=float8_linear_config2,
         )
@@ -187,7 +187,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
             enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
             emulate=True,
         )
-        swap_linear_with_float8_linear(model, config=float8_linear_config)
+        convert_to_float8_training(model, config=float8_linear_config)
         model_unsharded_numel = sum(p.numel() for p in model.parameters())
         model_sharded_numel = (model_unsharded_numel + 1) // 2
         block_lin_weight_numel = 0
@@ -297,7 +297,7 @@ def test_weight_subclass_dynamic(self):
             enable_fsdp_fp8_all_gather=True,
             emulate=True,
         )
-        module = swap_linear_with_float8_linear(
+        module = convert_to_float8_training(
             module_fp32,
             config=float8_linear_config,
        )
@@ -310,7 +310,7 @@ def test_weight_subclass_dynamic(self):
 
         # Check for multiple FSDP paramter groups
         module = self.init_multi_module()
-        module = swap_linear_with_float8_linear(
+        module = convert_to_float8_training(
             module,
             config=float8_linear_config,
         )
@@ -362,7 +362,7 @@ def get_expected_all_gather_size(module: nn.Module):
         float8_linear_config = Float8LinearConfig(
             enable_fsdp_fp8_all_gather=True,
         )
-        module_fp32 = swap_linear_with_float8_linear(
+        module_fp32 = convert_to_float8_training(
             module_fp32, config=float8_linear_config
         )
         module = module_fp32
@@ -392,7 +392,7 @@ def get_expected_all_gather_size(module: nn.Module):
         # - Check for multiple FSDP parameter groups
         module = self.init_multi_module()
         ref_module = copy.deepcopy(module)
-        module = swap_linear_with_float8_linear(module, config=float8_linear_config)
+        module = convert_to_float8_training(module, config=float8_linear_config)
         for submodule in module:
             fully_shard(submodule)
         fully_shard(module)
@@ -433,12 +433,12 @@ def test_fp32_fp8_single_module_parity(self):
         )
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
-        ref_module = swap_linear_with_float8_linear(
+        ref_module = convert_to_float8_training(
             ref_module,
             config=float8_linear_config1,
        )
         ref_module = ref_module.cuda()
-        module = swap_linear_with_float8_linear(
+        module = convert_to_float8_training(
             module_fp32,
             config=float8_linear_config2,
        )
@@ -481,11 +481,11 @@ def test_fp32_fp8_multi_module_parity(self):
         )
         module = self.init_multi_module().cuda()
         ref_module = copy.deepcopy(module)
-        ref_module = swap_linear_with_float8_linear(
+        ref_module = convert_to_float8_training(
             ref_module,
             config=float8_linear_config1,
        )
-        module = swap_linear_with_float8_linear(
+        module = convert_to_float8_training(
             module,
             config=float8_linear_config2,
        )
@@ -518,12 +518,12 @@ def test_bf16_mp_fp8_dynamic_multi_parity(self):
         module = self.init_multi_module()
         ref_module_bf16 = copy.deepcopy(module).to(torch.bfloat16)
         float8_config = Float8LinearConfig(emulate=True)
-        ref_module_bf16 = swap_linear_with_float8_linear(
+        ref_module_bf16 = convert_to_float8_training(
             ref_module_bf16,
             config=float8_config,
        )
         ref_module_fp32 = copy.deepcopy(module).cuda()
-        module = swap_linear_with_float8_linear(module, config=float8_config)
+        module = convert_to_float8_training(module, config=float8_config)
         mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
         for mlp in module:
             fully_shard(mlp, mp_policy=mp_policy)
@@ -550,7 +550,7 @@ def test_delayed_scaling_inplace_update(self):
                 scaling_type=TensorScalingType.DELAYED
             ),
         )
-        m_fp8 = swap_linear_with_float8_linear(
+        m_fp8 = convert_to_float8_training(
             module,
             config=float8_linear_config,
        )
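
The FSDP2 tests above combine the renamed API with float8 all-gather. A hedged sketch of that composition follows; it assumes `torch.distributed` is already initialized (for example via `torchrun`), a CUDA device is selected, and the stack of linears stands in for the test's transformer blocks.

```python
import torch
import torch.nn as nn

from float8_experimental import Float8LinearConfig
from float8_experimental.float8_linear_utils import convert_to_float8_training
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

# assumption: run under torchrun with an initialized process group and CUDA device
model = nn.Sequential(*[nn.Linear(256, 256) for _ in range(3)]).cuda()

# enable_fsdp_fp8_all_gather is the flag exercised by the tests in this diff
config = Float8LinearConfig(enable_fsdp_fp8_all_gather=True)
convert_to_float8_training(model, config=config)

# shard each block first, then the root, mirroring the pattern in the tests
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
for block in model:
    fully_shard(block, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)
```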
