Skip to content

Commit c3d5d54

Browse files
authored
Improve OptimizerWrapper composability (#85)
* Improve OptimizerWrapper composability OptimizerWrapper currently miss several attributes that are required for training integration. This PR adds the missing gap.
1 parent 6e4ae38 commit c3d5d54

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

Diff for: torchft/optim.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
1313
"""
1414

15-
from typing import TYPE_CHECKING, Optional
15+
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
1616

17+
import torch
1718
from torch.optim import Optimizer
1819

1920
if TYPE_CHECKING:
@@ -52,3 +53,11 @@ def step(self, closure: Optional[object] = None) -> None:
5253
assert closure is None, "optimizers that use closures are not supported"
5354
if self.manager.should_commit():
5455
self.optim.step()
56+
57+
@property
58+
def param_groups(self) -> List[Dict[str, Any]]:
59+
return self.optim.param_groups
60+
61+
@property
62+
def state(self) -> Mapping[torch.Tensor, Any]: # pyre-fixme[3]
63+
return self.optim.state

Diff for: torchft/optim_test.py

+8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from unittest import TestCase
88
from unittest.mock import MagicMock, create_autospec
99

10+
import torch
1011
from torch.nn import Linear
1112
from torch.optim import AdamW
1213

@@ -34,9 +35,16 @@ def test_optimizer_wrapper(self) -> None:
3435
optim.zero_grad()
3536
self.assertEqual(manager.start_quorum.call_count, 1)
3637

38+
b = torch.rand(3)
39+
m(b).sum().backward()
40+
3741
manager.should_commit.return_value = True
3842
optim.step()
3943
manager.should_commit.return_value = False
4044
optim.step()
45+
self.assertEqual(len(optim.param_groups), 2)
46+
self.assertEqual(optim.param_groups[1]["lr"], 1e-4)
47+
self.assertEqual(optim.param_groups[1]["params"], [])
48+
self.assertEqual(len(optim.state), len(list(m.parameters())))
4149

4250
self.assertEqual(manager.should_commit.call_count, 2)

0 commit comments

Comments
 (0)