Commit 2271b63

Delete delayed scaling (#812)
Torchao plans to deprecate delayed scaling, so delete it from torchtitan.

Fixes #654.

Here are the logs from runs with `enable_float8_linear = true`:

1. `compile = false`

```
[rank0]:2025-01-31 10:12:50,551 - root - INFO - Float8 training active
[rank0]:2025-01-31 10:12:50,571 - root - INFO - Swapped to Float8Linear layers with enable_fsdp_float8_all_gather=False
[rank0]:2025-01-31 10:12:50,572 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2025-01-31 10:12:50,572 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:2025-01-31 10:12:50,635 - root - INFO - Applied FSDP to the model
[rank0]:2025-01-31 10:12:50,835 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:2025-01-31 10:12:50,835 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank0]:2025-01-31 10:12:50,837 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250131-1012
[rank0]:2025-01-31 10:12:50,837 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 5 (warmup 200)
[rank0]:2025-01-31 10:12:50,837 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:2025-01-31 10:13:02,460 - root - INFO - step: 1  loss: 12.2581  memory: 74.27GiB(78.18%)  tps: 705  mfu: 4.13%
[rank0]:2025-01-31 10:13:02,460 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-01-31 10:13:04,973 - root - INFO - step: 2  loss: 12.0754  memory: 81.77GiB(86.07%)  tps: 3,262  mfu: 19.10%
[rank0]:2025-01-31 10:13:07,033 - root - INFO - step: 3  loss: 11.7432  memory: 81.77GiB(86.07%)  tps: 3,980  mfu: 23.30%
[rank0]:2025-01-31 10:13:09,089 - root - INFO - step: 4  loss: 11.3079  memory: 81.77GiB(86.07%)  tps: 3,986  mfu: 23.34%
[rank0]:2025-01-31 10:13:11,146 - root - INFO - step: 5  loss: 10.9303  memory: 81.77GiB(86.07%)  tps: 3,985  mfu: 23.33%
[rank0]:2025-01-31 10:13:11,147 - root - INFO - Saving a full checkpoint at last step, step 5.
[rank0]:2025-01-31 10:13:31,549 - root - INFO - Finished saving the checkpoint (or staging if async is enabled) in 20.40 seconds.
[rank0]:2025-01-31 10:13:31,549 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2025-01-31 10:13:33,551 - root - INFO - Training completed
```

2. `compile = true`

```
[rank0]:2025-01-31 10:18:55,527 - root - INFO - Float8 training active
[rank0]:2025-01-31 10:18:55,547 - root - INFO - Swapped to Float8Linear layers with enable_fsdp_float8_all_gather=False
[rank0]:2025-01-31 10:18:55,548 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2025-01-31 10:18:55,549 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:2025-01-31 10:18:55,591 - root - INFO - Compiling each TransformerBlock with torch.compile
[rank0]:2025-01-31 10:18:55,656 - root - INFO - Applied FSDP to the model
[rank0]:2025-01-31 10:18:56,530 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:2025-01-31 10:18:56,532 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250131-1018
[rank0]:2025-01-31 10:18:56,533 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 5 (warmup 200)
[rank0]:2025-01-31 10:18:56,533 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:W0131 10:19:01.052000 1427728 torch/_logging/_internal.py:1093] [0/0] Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
[rank0]:W0131 10:19:01.052000 1427728 torch/_logging/_internal.py:1093] [0/0] Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
[rank0]:/data/users/yifanmao/pytorch/torch/_inductor/lowering.py:1903: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2025-01-31 10:19:15,619 - root - INFO - step: 1  loss: 12.2476  memory: 40.21GiB(42.32%)  tps: 429  mfu: 2.51%
[rank0]:2025-01-31 10:19:15,619 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-01-31 10:19:16,747 - root - INFO - step: 2  loss: 12.0860  memory: 47.77GiB(50.28%)  tps: 7,267  mfu: 42.55%
[rank0]:2025-01-31 10:19:17,852 - root - INFO - step: 3  loss: 11.7620  memory: 47.77GiB(50.28%)  tps: 7,420  mfu: 43.45%
[rank0]:2025-01-31 10:19:18,953 - root - INFO - step: 4  loss: 11.3075  memory: 47.77GiB(50.28%)  tps: 7,449  mfu: 43.62%
[rank0]:2025-01-31 10:19:20,054 - root - INFO - step: 5  loss: 10.9359  memory: 47.77GiB(50.28%)  tps: 7,448  mfu: 43.61%
[rank0]:2025-01-31 10:19:20,054 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2025-01-31 10:19:22,056 - root - INFO - Training completed
```
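For intuition on why deleting delayed scaling removes the per-step sync call: with dynamic scaling the float8 scale is derived from the current tensor on every cast, so there is no amax history to maintain or synchronize across steps. A back-of-the-envelope sketch in plain PyTorch (not torchao internals; the helper name and constant are illustrative):

```python
import torch

E4M3_MAX = 448.0  # largest magnitude representable in torch.float8_e4m3fn

def dynamic_cast_to_float8(t: torch.Tensor):
    # Dynamic scaling: derive the scale from the *current* tensor's amax,
    # so nothing has to be remembered (or synchronized) between steps.
    scale = E4M3_MAX / t.abs().max().clamp(min=1e-12)
    return (t * scale).to(torch.float8_e4m3fn), scale

x = torch.randn(16, 16)
x_fp8, scale = dynamic_cast_to_float8(x)
x_roundtrip = x_fp8.to(torch.float32) / scale
print((x - x_roundtrip).abs().max())  # small quantization error
```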
1 parent d4c86e3 commit 2271b63

File tree

4 files changed: +6, -69 lines

scripts/estimate/estimation.py (+5, -3)
```diff
@@ -116,7 +116,11 @@ def loss_fn(pred, labels):
     model_config.vocab_size = tokenizer.n_words
     model_config.max_seq_len = job_config.training.seq_len
 
-    with FakeTensorMode() if not job_config.memory_estimation.disable_fake_mode else contextlib.nullcontext():
+    with (
+        FakeTensorMode()
+        if not job_config.memory_estimation.disable_fake_mode
+        else contextlib.nullcontext()
+    ):
 
         logger.info(
             f"Building {model_name} {job_config.model.flavor} with {model_config}"
@@ -174,8 +178,6 @@ def loss_fn(pred, labels):
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), job_config.training.max_norm, foreach=True
            )
-           # sync float8 amaxes and scales
-           float8_handler.sync_float8_amax_and_scale_history(model)
            # optimizer step
            optimizers.step()
            lr_schedulers.step()
```
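As a side note, the reformatted `with (...)` statement above relies on a conditional expression that evaluates to one of two context managers. A standalone sketch of the same pattern (stand-in variable name, not estimation.py itself):

```python
import contextlib

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Stand-in for job_config.memory_estimation.disable_fake_mode.
disable_fake_mode = False

# The parentheses let the conditional expression span multiple lines;
# whichever context manager the expression yields is the one that is entered.
with (
    FakeTensorMode()
    if not disable_fake_mode
    else contextlib.nullcontext()
):
    # Under FakeTensorMode this allocates no real storage; with the flag set
    # to True it would be an ordinary dense tensor.
    t = torch.empty(1024, 1024)
    print(type(t).__name__)
```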

torchtitan/config_manager.py (-19)
```diff
@@ -548,25 +548,6 @@ def __init__(self):
             action="store_true",
             help="Whether precompute float8 scales dynamically for FSDP",
         )
-        self.parser.add_argument(
-            "--float8.scaling_type_input",
-            type=str,
-            default="dynamic",
-            help="float8 scaling for input, dynamic (default) or delayed",
-            choices=["dynamic", "delayed"],
-        )
-        self.parser.add_argument(
-            "--float8.scaling_type_weight",
-            type=str,
-            default="dynamic",
-            help="float8 scaling for input, dynamic (default) or delayed",
-        )
-        self.parser.add_argument(
-            "--float8.scaling_type_grad_output",
-            type=str,
-            default="dynamic",
-            help="float8 scaling for input, dynamic (default) or delayed",
-        )
 
         # communications library settings
         self.parser.add_argument(
```
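To make the user-facing effect concrete, here is a tiny standalone argparse demo (not torchtitan's parser; only the flag names are taken from the diff above): once the `--float8.scaling_type_*` arguments are gone, passing them is rejected.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--float8.enable_float8_linear", action="store_true")
parser.add_argument("--float8.enable_fsdp_float8_all_gather", action="store_true")
# The three --float8.scaling_type_* arguments deleted above are intentionally absent.

# Still accepted: dynamic scaling is the only (implicit) option left.
args = parser.parse_args(["--float8.enable_float8_linear"])
print(getattr(args, "float8.enable_float8_linear"))  # True

# No longer accepted: argparse exits with "unrecognized arguments".
try:
    parser.parse_args(["--float8.scaling_type_input", "delayed"])
except SystemExit:
    print("--float8.scaling_type_input is no longer a valid flag")
```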

torchtitan/float8.py (+1, -44)
```diff
@@ -41,7 +41,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             )
             return
         try:
-            from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
+            from torchao.float8 import Float8LinearConfig
         except ImportError as e:
             raise ImportError(
                 "torchao is not installed. Please install it to use float8 linear layers."
@@ -52,14 +52,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             parallel_dims.dp_shard_enabled
             and float8_config.enable_fsdp_float8_all_gather
         )
-        scaling_type_input = ScalingType(float8_config.scaling_type_input)
-        scaling_type_weight = ScalingType(float8_config.scaling_type_weight)
-        scaling_type_grad_output = ScalingType(float8_config.scaling_type_grad_output)
         self.config = Float8LinearConfig(
             enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
-            cast_config_input=CastConfig(scaling_type=scaling_type_input),
-            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
-            cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
         )
 
         self.enabled = True
@@ -70,15 +64,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             and float8_config.precompute_float8_dynamic_scale_for_fsdp
         )
 
-        # for sync_float8_amax_and_scale_history
-        self.delayed_scaling = (
-            scaling_type_input is ScalingType.DELAYED
-            or scaling_type_weight is ScalingType.DELAYED
-            or scaling_type_grad_output is ScalingType.DELAYED
-        )
-        self._sync_float8_amax_and_scale_history = None
-        self.compile = job_config.training.compile
-
         logger.info("Float8 training active")
 
     def convert_to_float8_training(self, model: nn.Module):
@@ -117,31 +102,3 @@ def precompute_float8_dynamic_scale_for_fsdp(
         models = [model] if isinstance(model, nn.Module) else model
         for m in models:
             precompute_float8_dynamic_scale_for_fsdp(m)
-
-    def sync_float8_amax_and_scale_history(
-        self, model: Union[nn.Module, List[nn.Module]]
-    ):
-        if not self.enabled:
-            return
-
-        if not self.delayed_scaling:
-            return
-
-        from torchao.float8 import sync_float8_amax_and_scale_history
-
-        # TODO(vkuzo): see if precalculating the modules to sync over is going to
-        # meaningfully help performance
-
-        if self._sync_float8_amax_and_scale_history is None:
-            if self.compile:
-                self._sync_float8_amax_and_scale_history = torch.compile(
-                    sync_float8_amax_and_scale_history
-                )
-            else:
-                self._sync_float8_amax_and_scale_history = (
-                    sync_float8_amax_and_scale_history
-                )
-
-        models = [model] if isinstance(model, nn.Module) else model
-        for m in models:
-            self._sync_float8_amax_and_scale_history(m)
```
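For context, a minimal sketch of the code path that survives this commit, assuming torchao is installed (illustrative, not the Float8Handler itself): a `Float8LinearConfig` is built without any `CastConfig`/`ScalingType` arguments, and eligible `nn.Linear` layers are swapped for `Float8Linear`.

```python
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Defaults now mean dynamic scaling for input, weight, and grad_output.
config = Float8LinearConfig(enable_fsdp_float8_all_gather=False)

model = nn.Sequential(
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Linear(4096, 4096),
)

# Swaps nn.Linear modules for Float8Linear in place; with dynamic scaling there
# is no amax/scale history, so nothing needs to be synchronized each step.
convert_to_float8_training(model, config=config)
print(model)
```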

train.py (-3)
```diff
@@ -321,9 +321,6 @@ def loss_fn(pred, labels):
                 pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
             )
 
-            # sync float8 amaxes and scales
-            float8_handler.sync_float8_amax_and_scale_history(model_parts)
-
             # optimizer step
             checkpoint.maybe_wait_for_staging()
             optimizers.step()
```
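And a minimal sketch of a training step after this change (illustrative shapes and optimizer, requires a float8-capable GPU such as H100; this is not train.py itself): the step goes straight from backward and gradient clipping to the optimizer, with no `sync_float8_amax_and_scale_history` call in between.

```python
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
model = model.to(device="cuda", dtype=torch.bfloat16)
convert_to_float8_training(model, config=Float8LinearConfig())
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# With dynamic scaling there is no per-step float8 sync before the optimizer.
optimizer.step()
optimizer.zero_grad()
```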
