Fix LR scheduler issue with CPU offload optimizer (#1649)
* synchronize param H2D

* let CPU offload inherit Optimizer

* add scheduler to test
gau-nernst authored Feb 2, 2025
1 parent 122eb73 commit 6ffe236
Showing 2 changed files with 10 additions and 1 deletion.
5 changes: 5 additions & 0 deletions test/prototype/test_low_bit_optim.py
@@ -287,6 +287,9 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
offload_gradients=offload_grad,
)

scheduler1 = torch.optim.lr_scheduler.CosineAnnealingLR(optim1, 100)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim2, 100)

rng = torch.Generator(device=device)
rng.manual_seed(42)

@@ -299,6 +302,7 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):

optim1.step()
optim1.zero_grad()
scheduler1.step()

# reset the rng
rng.manual_seed(42)
@@ -309,6 +313,7 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):

optim2.step()
optim2.zero_grad()
scheduler2.step()

for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1)
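For reference, a minimal usage sketch (not part of the commit) of what the updated test exercises: pairing CPUOffloadOptimizer with a stock PyTorch LR scheduler. The import path and constructor arguments follow torchao's prototype API and are assumptions here; a CUDA device is also assumed.

```python
# Minimal sketch (assumed API): CPUOffloadOptimizer wraps a regular optimizer
# class and, after this fix, can be passed directly to an LR scheduler.
import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = torch.nn.Linear(128, 128, device="cuda")  # assumes a CUDA device
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=100)

for _ in range(100):
    loss = model(torch.randn(4, 128, device="cuda")).sum()
    loss.backward()
    optim.step()       # optimizer step runs on CPU; updated params are copied back to GPU
    optim.zero_grad()
    scheduler.step()   # works because CPUOffloadOptimizer is now an Optimizer subclass
```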
6 changes: 5 additions & 1 deletion torchao/prototype/low_bit_optim/cpu_offload.py
@@ -6,7 +6,11 @@
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, get_available_devices


class CPUOffloadOptimizer:
# NOTE: We make this inherit Optimizer so that it works with PyTorch's built-in LR
# schedulers (those schedulers specifically check for instances of Optimizer).
# However, it won't behave exactly like Optimizer, e.g. we don't call
# Optimizer.__init__() and there is no self.defaults.
class CPUOffloadOptimizer(Optimizer):
def __init__(
self,
params: ParamsT,
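To illustrate the NOTE above (not part of the commit): PyTorch's built-in LR schedulers reject anything that is not an Optimizer instance, so the previous duck-typed CPUOffloadOptimizer could not be wrapped by them. A simplified sketch of that check (exact wording differs across PyTorch versions):

```python
# Simplified sketch of the validation in torch.optim.lr_scheduler; not the
# actual implementation, shown only to explain why subclassing Optimizer
# is required.
from torch.optim import Optimizer

class LRScheduler:
    def __init__(self, optimizer, last_epoch=-1):
        # A duck-typed wrapper fails this check even if it exposes
        # param_groups, step(), and zero_grad().
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
        self.optimizer = optimizer
        self.last_epoch = last_epoch
```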
