Skip to content

Commit 735bfbc

Browse files
zedsdead01facebook-github-bot
authored and committed
Track only Stateful objects and not classes (#827)
Summary: Pull Request resolved: #827 `torchtnt.framework.callbacks.meta.model_store_checkpointer.ModelStoreCheckpointer` fails when a checkpointed unit contains an attribute which is a class (not an object) implementing the `Stateful` interface. This is a typical case when a user specifies a type of an optimizer in an `AutoUnit` which is instantiated later in `AutoUnit.configure_optimizers_and_lr_scheduler`. The specific reason why the checkpointer fails is that this attribute then gets tracked because `isinstance(torch.optim.Optimizer, Stateful)` returns `True`. `MultiStateful` then tries to call `state_dict` on that attribute, which fails because the attribute is a class, not an instance of one. Reviewed By: JKSenthil Differential Revision: D57159095 fbshipit-source-id: 9224193f63803fa139c26553ff6090cd6ac9886d
1 parent d76f4f0 commit 735bfbc

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

tests/framework/test_app_state_mixin.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from torchtnt.utils.env import init_from_env
2121
from torchtnt.utils.lr_scheduler import TLRScheduler
2222
from torchtnt.utils.prepare_module import FSDPOptimizerWrapper
23+
from torchtnt.utils.stateful import MultiStateful
2324

2425

2526
class Dummy(AppStateMixin):
@@ -32,6 +33,7 @@ def __init__(self) -> None:
3233
self.optimizer_c, step_size=30, gamma=0.1
3334
)
3435
self.grad_scaler_e = torch.cuda.amp.GradScaler()
36+
self.optimizer_class_f = torch.optim.SGD
3537

3638

3739
class AppStateMixinTest(unittest.TestCase):
@@ -103,6 +105,15 @@ def test_miscellaneous_stateful(self) -> None:
103105
# assert that the grad scaler is stored in the app_state
104106
self.assertEqual(my_unit.app_state()["grad_scaler_e"], my_unit.grad_scaler_e)
105107

108+
# assert that only stateful class objects are being tracked
109+
self.assertFalse("optimizer_class_f" in my_unit.tracked_misc_statefuls())
110+
111+
multi_stateful = MultiStateful(my_unit.tracked_misc_statefuls())
112+
try:
113+
_ = multi_stateful.state_dict()
114+
except TypeError:
115+
self.fail("Not able to get the state dict from my_unit.")
116+
106117
# delete the attribute
107118
# pyre-fixme[8]: Attribute has type `GradScaler`; used as `None`.
108119
my_unit.grad_scaler_e = None

torchtnt/framework/unit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88

99

10+
import inspect
1011
import logging
1112
from abc import ABC, abstractmethod
1213
from typing import Any, cast, Dict, Generic, Iterator, TypeVar, Union
@@ -148,7 +149,7 @@ def __setattr__(self, name: str, value: object) -> None:
148149
value,
149150
self.__dict__.get("_progress"),
150151
)
151-
elif isinstance(value, Stateful):
152+
elif isinstance(value, Stateful) and not inspect.isclass(value):
152153
self._update_attr(
153154
name,
154155
value,

0 commit comments

Comments
 (0)