Skip to content

Commit

Permalink
Track only Stateful objects and not classes (#827)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #827

`torchtnt.framework.callbacks.meta.model_store_checkpointer.ModelStoreCheckpointer` fails when a checkpointed unit contains an attribute which is a class (not an object) implementing the `Stateful` interface. This is a typical case when a user specifies the type of an optimizer in an `AutoUnit`, which is instantiated later in `AutoUnit.configure_optimizers_and_lr_scheduler`. The specific reason why the checkpointer fails is that this attribute then gets tracked because `isinstance(torch.optim.Optimizer, Stateful)` returns `True`. `MultiStateful` then tries to call `state_dict` on that attribute, which fails because the attribute is a class rather than an instance of one.

Reviewed By: JKSenthil

Differential Revision: D57159095

fbshipit-source-id: 9224193f63803fa139c26553ff6090cd6ac9886d
  • Loading branch information
zedsdead01 authored and facebook-github-bot committed May 10, 2024
1 parent d76f4f0 commit 735bfbc
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
11 changes: 11 additions & 0 deletions tests/framework/test_app_state_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchtnt.utils.env import init_from_env
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import FSDPOptimizerWrapper
from torchtnt.utils.stateful import MultiStateful


class Dummy(AppStateMixin):
Expand All @@ -32,6 +33,7 @@ def __init__(self) -> None:
self.optimizer_c, step_size=30, gamma=0.1
)
self.grad_scaler_e = torch.cuda.amp.GradScaler()
self.optimizer_class_f = torch.optim.SGD


class AppStateMixinTest(unittest.TestCase):
Expand Down Expand Up @@ -103,6 +105,15 @@ def test_miscellaneous_stateful(self) -> None:
# assert that the grad scaler is stored in the app_state
self.assertEqual(my_unit.app_state()["grad_scaler_e"], my_unit.grad_scaler_e)

# assert that only stateful class objects are being tracked
self.assertFalse("optimizer_class_f" in my_unit.tracked_misc_statefuls())

multi_stateful = MultiStateful(my_unit.tracked_misc_statefuls())
try:
_ = multi_stateful.state_dict()
except TypeError:
self.fail("Not able to get the state dict from my_unit.")

# delete the attribute
# pyre-fixme[8]: Attribute has type `GradScaler`; used as `None`.
my_unit.grad_scaler_e = None
Expand Down
3 changes: 2 additions & 1 deletion torchtnt/framework/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-strict


import inspect
import logging
from abc import ABC, abstractmethod
from typing import Any, cast, Dict, Generic, Iterator, TypeVar, Union
Expand Down Expand Up @@ -148,7 +149,7 @@ def __setattr__(self, name: str, value: object) -> None:
value,
self.__dict__.get("_progress"),
)
elif isinstance(value, Stateful):
elif isinstance(value, Stateful) and not inspect.isclass(value):
self._update_attr(
name,
value,
Expand Down

0 comments on commit 735bfbc

Please sign in to comment.