
Commit f7e4867

Add --allow-tf32 perf tuning argument that can be used to enable tf32.
Defaults to keeping tf32 disabled, because we haven't fully verified training results with tf32 enabled.
Parent: d3a616a
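For context on why the default stays off: TF32 runs FP32 matmuls with a reduced 10-bit mantissa on Ampere-class GPUs, trading a little precision for speed. A minimal illustrative sketch, not part of this commit, that shows the numeric effect (requires a CUDA device; the two results only differ on hardware with TF32 support):

import torch

if torch.cuda.is_available():
    a = torch.randn(4096, 4096, device='cuda')
    b = torch.randn(4096, 4096, device='cuda')

    torch.backends.cuda.matmul.allow_tf32 = False
    full_fp32 = a @ b                               # Full-precision FP32 matmul.

    torch.backends.cuda.matmul.allow_tf32 = True
    tf32 = a @ b                                    # Same matmul with TF32 allowed.

    print('max abs deviation:', (full_fp32 - tf32).abs().max().item())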

3 files changed, 12 insertions(+), 0 deletions(-)

Diff for: docs/train-help.txt (+1 line)
@@ -65,5 +65,6 @@ Options:
   --fp32 BOOL         Disable mixed-precision training
   --nhwc BOOL         Use NHWC memory format with FP16
   --nobench BOOL      Disable cuDNN benchmarking
+  --allow-tf32 BOOL   Allow PyTorch to use TF32 internally
   --workers INT       Override number of DataLoader workers
   --help              Show this message and exit.

Diff for: train.py (+8 lines)
@@ -61,6 +61,7 @@ def setup_training_loop_kwargs(
     # Performance options (not included in desc).
     fp32       = None, # Disable mixed-precision training: <bool>, default = False
     nhwc       = None, # Use NHWC memory format with FP16: <bool>, default = False
+    allow_tf32 = None, # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
     nobench    = None, # Disable cuDNN benchmarking: <bool>, default = False
     workers    = None, # Override number of DataLoader workers: <int>, default = 3
 ):
@@ -343,6 +344,12 @@ def setup_training_loop_kwargs(
     if nobench:
         args.cudnn_benchmark = False

+    if allow_tf32 is None:
+        allow_tf32 = False
+    assert isinstance(allow_tf32, bool)
+    if allow_tf32:
+        args.allow_tf32 = True
+
     if workers is not None:
         assert isinstance(workers, int)
         if not workers >= 1:
@@ -425,6 +432,7 @@ def convert(self, value, param, ctx):
 @click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
 @click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
 @click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
+@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
 @click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')

 def main(ctx, outdir, dry_run, **config_kwargs):
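The pattern above is: click parses --allow-tf32 as a BOOL that arrives as None when omitted, and setup_training_loop_kwargs() normalizes that to False before copying it into the training arguments. A standalone, hypothetical script (the option name matches the diff, everything else is illustrative) showing the same pattern in isolation:

import click

@click.command()
@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
def main(allow_tf32):
    if allow_tf32 is None:      # Flag omitted on the command line.
        allow_tf32 = False      # Default: keep TF32 disabled.
    assert isinstance(allow_tf32, bool)
    click.echo(f'allow_tf32 = {allow_tf32}')

if __name__ == '__main__':
    main()                      # e.g. `python sketch.py --allow-tf32=true`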

Diff for: training/training_loop.py (+3 lines)
@@ -115,6 +115,7 @@ def training_loop(
     network_snapshot_ticks = 50,    # How often to save network snapshots? None = disable.
     resume_pkl             = None,  # Network pickle to resume training from.
     cudnn_benchmark        = True,  # Enable torch.backends.cudnn.benchmark?
+    allow_tf32             = False, # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
     abort_fn               = None,  # Callback function for determining whether to abort training. Must return consistent results across ranks.
     progress_fn            = None,  # Callback function for updating training progress. Called for all ranks.
 ):
@@ -124,6 +125,8 @@ def training_loop(
     np.random.seed(random_seed * num_gpus + rank)
     torch.manual_seed(random_seed * num_gpus + rank)
     torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
+    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for matmul
+    torch.backends.cudnn.allow_tf32 = allow_tf32        # Allow PyTorch to internally use tf32 for convolutions
     conv2d_gradfix.enabled = True                       # Improves training speed.
     grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.
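The two switches toggled above are plain module-level flags in PyTorch, so their effect can be exercised outside the training loop. A small sketch, not from the repo, that mirrors what training_loop() now does at start-up:

import torch

def set_tf32(allow_tf32: bool) -> None:
    # Mirrors the two assignments added to training_loop().
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # TF32 for CUDA matmuls.
    torch.backends.cudnn.allow_tf32 = allow_tf32        # TF32 for cuDNN convolutions.

set_tf32(False)  # The commit's default: stay on full FP32.
print(torch.backends.cuda.matmul.allow_tf32, torch.backends.cudnn.allow_tf32)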
