bug/feature: Fixing mixed precision training #290

Open · wants to merge 69 commits into base: main

Changes from 4 commits

Commits (69)
e8b084c
Adding grad scaler
isamu-isozaki Feb 16, 2024
dc795ff
Fixed double backward
isamu-isozaki Feb 16, 2024
895c65f
zero grad
isamu-isozaki Feb 16, 2024
6049d40
Merge branch 'grad_scaler' of https://github.com/isamu-isozaki/refine…
isamu-isozaki Feb 16, 2024
7ccb836
Fixed according to review
isamu-isozaki Feb 21, 2024
5e539c1
Fixed styles
isamu-isozaki Feb 21, 2024
35401da
Cleaned up code
isamu-isozaki Feb 21, 2024
4ccca8d
Removed diffs
isamu-isozaki Feb 21, 2024
3c7f8f3
Disable grad scaler for bf16
isamu-isozaki Feb 21, 2024
eebaad7
Fixed mixed precision training
isamu-isozaki Feb 21, 2024
d362686
Added option for None as dtype for mixed precision training
isamu-isozaki Feb 21, 2024
6e5bb46
Remove dtype
isamu-isozaki Feb 22, 2024
e2291b9
Testing with no api change
isamu-isozaki Feb 22, 2024
a45c9a4
Fixed typo
isamu-isozaki Feb 23, 2024
65c4195
Amp config
isamu-isozaki Feb 23, 2024
ba438aa
Style fixes
isamu-isozaki Feb 23, 2024
88f002c
Attempt file fix
isamu-isozaki Feb 23, 2024
e9d857d
Remove diff
isamu-isozaki Feb 23, 2024
0803b86
Remove diffs
isamu-isozaki Feb 23, 2024
ca4ad1b
Fixed always true condition
isamu-isozaki Feb 23, 2024
a3674e4
Removed not implemented error
isamu-isozaki Feb 23, 2024
f0bc382
Remove comment
isamu-isozaki Feb 23, 2024
59cf8ab
Remove overflow grad check
isamu-isozaki Feb 23, 2024
73a4692
Removed grad scaler for non-amp training
isamu-isozaki Feb 26, 2024
ab00953
Remove comments
isamu-isozaki Feb 26, 2024
903fbfe
Cleaner amp
isamu-isozaki Feb 26, 2024
e289d64
Fix default
isamu-isozaki Feb 26, 2024
2959c30
More explicit name
isamu-isozaki Feb 26, 2024
362bd2d
Remove accelerate
isamu-isozaki Feb 26, 2024
2d64832
Linting
isamu-isozaki Feb 26, 2024
05e77b6
Adding grad scaler
isamu-isozaki Feb 16, 2024
be99c30
zero grad
isamu-isozaki Feb 16, 2024
c121dfb
Fixed double backward
isamu-isozaki Feb 16, 2024
b4da3ef
Fixed according to review
isamu-isozaki Feb 21, 2024
aa6f0a4
Fixed styles
isamu-isozaki Feb 21, 2024
e227f0d
Cleaned up code
isamu-isozaki Feb 21, 2024
e1f309d
Removed diffs
isamu-isozaki Feb 21, 2024
ae08009
Disable grad scaler for bf16
isamu-isozaki Feb 21, 2024
ce364e1
Fixed mixed precision training
isamu-isozaki Feb 21, 2024
5dc798a
Added option for None as dtype for mixed precision training
isamu-isozaki Feb 21, 2024
9876caf
Remove dtype
isamu-isozaki Feb 22, 2024
b547ee1
Testing with no api change
isamu-isozaki Feb 22, 2024
d78525d
Fixed typo
isamu-isozaki Feb 23, 2024
a82daf4
Amp config
isamu-isozaki Feb 23, 2024
1cf64e3
Style fixes
isamu-isozaki Feb 23, 2024
bc20875
Attempt file fix
isamu-isozaki Feb 23, 2024
691aeb5
Remove diff
isamu-isozaki Feb 23, 2024
dce3cd4
Remove diffs
isamu-isozaki Feb 23, 2024
486de2a
Fixed always true condition
isamu-isozaki Feb 23, 2024
9e7b65a
Removed not implemented error
isamu-isozaki Feb 23, 2024
5d3cd51
Remove comment
isamu-isozaki Feb 23, 2024
acba9cf
Remove overflow grad check
isamu-isozaki Feb 23, 2024
78e3a52
Removed grad scaler for non-amp training
isamu-isozaki Feb 26, 2024
03f4352
Remove comments
isamu-isozaki Feb 26, 2024
3ed2c29
Cleaner amp
isamu-isozaki Feb 26, 2024
4df3741
Fix default
isamu-isozaki Feb 26, 2024
123b4d4
More explicit name
isamu-isozaki Feb 26, 2024
72e8216
Remove accelerate
isamu-isozaki Feb 26, 2024
e206b7a
Linting
isamu-isozaki Feb 26, 2024
07185be
Merge branch 'grad_scaler' of https://github.com/isamu-isozaki/refine…
isamu-isozaki Feb 26, 2024
7c7c3ff
Merge branch 'main' into grad_scaler
isamu-isozaki Feb 26, 2024
8a9e034
Merge branch 'grad_scaler' of https://github.com/isamu-isozaki/refine…
isamu-isozaki Feb 26, 2024
a5ff578
Remove diffs
isamu-isozaki Feb 26, 2024
da188f7
Did fixes
isamu-isozaki Mar 1, 2024
49ebf48
Fixed import
isamu-isozaki Mar 8, 2024
58fcb15
Resolve merge conflicts
isamu-isozaki Apr 1, 2024
1d648a2
Delete double optimizer step
isamu-isozaki Apr 1, 2024
4efdffb
Fixed dtype
isamu-isozaki Apr 29, 2024
27cbdba
Merge branch 'main' into grad_scaler
isamu-isozaki Apr 29, 2024
25 changes: 21 additions & 4 deletions src/refiners/training_utils/trainer.py
@@ -22,7 +22,7 @@
     StepLR,
 )
 from torch.utils.data import DataLoader, Dataset
-
+from torch.cuda.amp import GradScaler
 from refiners.fluxion import layers as fl
 from refiners.fluxion.utils import no_grad
 from refiners.training_utils.callback import (
@@ -182,7 +182,11 @@ def dtype(self) -> DType:
         assert isinstance(dtype, DType), f"Unknown dtype: {self.config.training.dtype}"
         logger.info(f"Using dtype: {dtype}")
         return dtype
-
+    @cached_property
+    def scaler(self) -> GradScaler | None:
+        if self.config.training.dtype == "float32":
+            return None
+        return GradScaler()
     @property
     def learnable_parameters(self) -> list[nn.Parameter]:
         """Returns a list of learnable parameters in all models"""
@@ -345,11 +349,24 @@ def backward(self) -> None:
         """Backward pass on the loss."""
         self._call_callbacks(event_name="on_backward_begin")
         scaled_loss = self.loss / self.clock.num_step_per_iteration
-        backward(tensors=scaled_loss)
+        if self.scaler is not None:
+            self.scaler.scale(scaled_loss).backward()  # type: ignore
+        else:
+            backward(tensors=scaled_loss)
self._call_callbacks(event_name="on_backward_end")
if self.clock.is_optimizer_step:
self._call_callbacks(event_name="on_optimizer_step_begin")
self.optimizer.step()
if self.scaler is not None:
+                # logic from accelerator
+                scale_before = self.scaler.get_scale()
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+                scale_after = self.scaler.get_scale()
+                # If we reduced the loss scale, it means the optimizer step was skipped because of gradient overflow.
+                if scale_after < scale_before:
+                    logger.info("Overflow in optimizer caused optimizer to skip")
+            else:
+                self.optimizer.step()
             self.optimizer.zero_grad()
             self._call_callbacks(event_name="on_optimizer_step_end")
         if self.clock.is_lr_scheduler_step:
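
For reference, the branch taken when a scaler is present follows PyTorch's standard AMP recipe, and the scale_before/scale_after comparison mirrors the overflow check the diff labels "logic from accelerator". A minimal self-contained sketch of that loop, assuming a CUDA device and placeholder model/optimizer names that are not part of refiners:

    import torch
    from torch.cuda.amp import GradScaler, autocast

    model = torch.nn.Linear(16, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scaler = GradScaler()

    def train_step(batch: torch.Tensor, target: torch.Tensor) -> None:
        with autocast(dtype=torch.float16):
            loss = torch.nn.functional.mse_loss(model(batch), target)
        # Scale the loss so float16 gradients do not underflow, then backprop.
        scaler.scale(loss).backward()
        scale_before = scaler.get_scale()
        scaler.step(optimizer)  # unscales gradients; skips the step on inf/nan
        scaler.update()         # grows or shrinks the scale for the next step
        # A reduced scale means the optimizer step was skipped due to overflow.
        if scaler.get_scale() < scale_before:
            print("optimizer step skipped because of gradient overflow")
        optimizer.zero_grad()
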