Commit 96c11f9

[float8] move inductor config setting to float8.py (#1110)
1 parent 90c60f3 · commit 96c11f9

2 files changed: +8 −5 lines

torchtitan/components/float8.py

Lines changed: 8 additions & 0 deletions
@@ -13,6 +13,7 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
 
+import torch
 import torch.nn as nn
 
 from torchtitan.config_manager import JobConfig
@@ -67,6 +68,13 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
                 f"Float8 training active with recipe {float8_config.recipe_name}"
             )
 
+            # short-term solution for https://github.com/pytorch/pytorch/issues/150859
+            if float8_config.recipe_name == "rowwise":
+                torch._inductor.config.emulate_precision_casts = True
+                logger.debug(
+                    "Set torch._inductor.config.emulate_precision_casts to True"
+                )
+
         else:
             # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
             enable_fsdp_float8_all_gather = (
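
For context, the relocated block boils down to flipping one inductor flag before `torch.compile` runs. Below is a minimal standalone sketch (not part of the commit) of that behavior; the `recipe_name` variable and the toy `nn.Linear` model are placeholders for torchtitan's `float8_config.recipe_name` and the real training model.

```python
import torch

# Hypothetical stand-in for float8_config.recipe_name; in torchtitan this value
# comes from the job config rather than a local variable.
recipe_name = "rowwise"

if recipe_name == "rowwise":
    # Work around https://github.com/pytorch/pytorch/issues/150859: have inductor
    # emulate eager-mode precision casts so rowwise float8 numerics match eager.
    torch._inductor.config.emulate_precision_casts = True

# The flag must be set before compilation so inductor picks it up.
model = torch.nn.Linear(16, 16)
compiled = torch.compile(model)
```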

torchtitan/train.py

Lines changed: 0 additions & 5 deletions
@@ -71,11 +71,6 @@ def __init__(self, job_config: JobConfig):
         if job_config.job.print_args:
             logger.info(f"Running with args: {job_config.to_dict()}")
 
-        # short-term solution for https://github.com/pytorch/pytorch/issues/150859
-        if job_config.float8.recipe_name == "rowwise":
-            torch._inductor.config.emulate_precision_casts = True
-            logger.debug("Set torch._inductor.config.emulate_precision_casts to True")
-
         # take control of garbage collection to avoid stragglers
         self.gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
