Skip to content

Commit 688d12b

Browse files
committed
Update
[ghstack-poisoned]
1 parent b1a9c44 commit 688d12b

File tree

8 files changed

+60
-36
lines changed

8 files changed

+60
-36
lines changed

examples/rlhf/models/actor_critic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
57
from torchrl.modules.tensordict_module.actors import LMHeadActorValueOperator
68
from torchrl.modules.tensordict_module.common import VmapModule
79

torchrl/envs/transforms/gym_transforms.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import warnings
1111

1212
import torch
13-
import torchrl.objectives.common
1413
from tensordict import TensorDictBase
1514
from tensordict.utils import expand_as_right, NestedKey
1615
from torchrl.data.tensor_specs import Unbounded
@@ -189,7 +188,9 @@ def transform_observation_spec(self, observation_spec):
189188
)
190189
return observation_spec
191190

192-
def register_keys(self, loss_or_advantage: torchrl.objectives.common.LossModule):
191+
def register_keys(
192+
self, loss_or_advantage: torchrl.objectives.common.LossModule # noqa
193+
):
193194
"""Registers the end-of-life key at appropriate places within the loss.
194195
195196
Args:

torchrl/envs/transforms/transforms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@
8787
make_composite_from_td,
8888
step_mdp,
8989
)
90-
from torchrl.objectives.value.functional import reward2go
9190

9291
_has_tv = importlib.util.find_spec("torchvision", None) is not None
9392

@@ -8539,6 +8538,8 @@ def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
85398538
def _inv_apply_transform(
85408539
self, reward: torch.Tensor, done: torch.Tensor
85418540
) -> torch.Tensor:
8541+
from torchrl.objectives.value.functional import reward2go
8542+
85428543
return reward2go(reward, done, self.gamma)
85438544

85448545
def set_container(self, container):

torchrl/modules/tensordict_module/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from .actors import (
6+
from torchrl.modules.tensordict_module.actors import (
77
Actor,
88
ActorCriticOperator,
99
ActorCriticWrapper,
@@ -21,20 +21,20 @@
2121
TanhModule,
2222
ValueOperator,
2323
)
24-
from .common import SafeModule, VmapModule
25-
from .exploration import (
24+
from torchrl.modules.tensordict_module.common import SafeModule, VmapModule
25+
from torchrl.modules.tensordict_module.exploration import (
2626
AdditiveGaussianModule,
2727
AdditiveGaussianWrapper,
2828
EGreedyModule,
2929
EGreedyWrapper,
3030
OrnsteinUhlenbeckProcessModule,
3131
OrnsteinUhlenbeckProcessWrapper,
3232
)
33-
from .probabilistic import (
33+
from torchrl.modules.tensordict_module.probabilistic import (
3434
SafeProbabilisticModule,
3535
SafeProbabilisticTensorDictSequential,
3636
)
37-
from .rnn import (
37+
from torchrl.modules.tensordict_module.rnn import (
3838
GRU,
3939
GRUCell,
4040
GRUModule,
@@ -44,8 +44,8 @@
4444
recurrent_mode,
4545
set_recurrent_mode,
4646
)
47-
from .sequence import SafeSequential
48-
from .world_models import WorldModelWrapper
47+
from torchrl.modules.tensordict_module.sequence import SafeSequential
48+
from torchrl.modules.tensordict_module.world_models import WorldModelWrapper
4949

5050
__all__ = [
5151
"Actor",

torchrl/modules/tensordict_module/rnn.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,6 @@
1919

2020
from torchrl._utils import _ContextManager, _DecoratorContextManager
2121
from torchrl.data.tensor_specs import Unbounded
22-
from torchrl.objectives.value.functional import (
23-
_inv_pad_sequence,
24-
_split_and_pad_sequence,
25-
)
26-
from torchrl.objectives.value.utils import _get_num_per_traj_init
2722

2823

2924
class LSTMCell(RNNCellBase):
@@ -718,6 +713,11 @@ def set_recurrent_mode(self, mode: bool = True):
718713

719714
@dispatch
720715
def forward(self, tensordict: TensorDictBase):
716+
from torchrl.objectives.value.functional import (
717+
_inv_pad_sequence,
718+
_split_and_pad_sequence,
719+
)
720+
721721
# we want to get an error if the value input is missing, but not the hidden states
722722
defaults = [NO_DEFAULT, None, None]
723723
shape = tensordict.shape
@@ -742,6 +742,8 @@ def forward(self, tensordict: TensorDictBase):
742742
is_init = tensordict_shaped["is_init"].squeeze(-1)
743743
splits = None
744744
if self.recurrent_mode and is_init[..., 1:].any():
745+
from torchrl.objectives.value.utils import _get_num_per_traj_init
746+
745747
# if we have consecutive trajectories, things get a little more complicated
746748
# we have a tensordict of shape [B, T]
747749
# we will split / pad things such that we get a tensordict of shape
@@ -1533,6 +1535,11 @@ def set_recurrent_mode(self, mode: bool = True):
15331535
@dispatch
15341536
@set_lazy_legacy(False)
15351537
def forward(self, tensordict: TensorDictBase):
1538+
from torchrl.objectives.value.functional import (
1539+
_inv_pad_sequence,
1540+
_split_and_pad_sequence,
1541+
)
1542+
15361543
# we want to get an error if the value input is missing, but not the hidden states
15371544
defaults = [NO_DEFAULT, None]
15381545
shape = tensordict.shape
@@ -1557,6 +1564,8 @@ def forward(self, tensordict: TensorDictBase):
15571564
is_init = tensordict_shaped["is_init"].squeeze(-1)
15581565
splits = None
15591566
if self.recurrent_mode and is_init[..., 1:].any():
1567+
from torchrl.objectives.value.utils import _get_num_per_traj_init
1568+
15601569
# if we have consecutive trajectories, things get a little more complicated
15611570
# we have a tensordict of shape [B, T]
15621571
# we will split / pad things such that we get a tensordict of shape

torchrl/objectives/__init__.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,28 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from .a2c import A2CLoss
7-
from .common import LossModule
8-
from .cql import CQLLoss, DiscreteCQLLoss
9-
from .crossq import CrossQLoss
10-
from .ddpg import DDPGLoss
11-
from .decision_transformer import DTLoss, OnlineDTLoss
12-
from .dqn import DistributionalDQNLoss, DQNLoss
13-
from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss
14-
from .gail import GAILLoss
15-
from .iql import DiscreteIQLLoss, IQLLoss
16-
from .multiagent import QMixerLoss
17-
from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss
18-
from .redq import REDQLoss
19-
from .reinforce import ReinforceLoss
20-
from .sac import DiscreteSACLoss, SACLoss
21-
from .td3 import TD3Loss
22-
from .td3_bc import TD3BCLoss
23-
from .utils import (
6+
from torchrl.objectives.a2c import A2CLoss
7+
from torchrl.objectives.common import LossModule
8+
from torchrl.objectives.cql import CQLLoss, DiscreteCQLLoss
9+
from torchrl.objectives.crossq import CrossQLoss
10+
from torchrl.objectives.ddpg import DDPGLoss
11+
from torchrl.objectives.decision_transformer import DTLoss, OnlineDTLoss
12+
from torchrl.objectives.dqn import DistributionalDQNLoss, DQNLoss
13+
from torchrl.objectives.dreamer import (
14+
DreamerActorLoss,
15+
DreamerModelLoss,
16+
DreamerValueLoss,
17+
)
18+
from torchrl.objectives.gail import GAILLoss
19+
from torchrl.objectives.iql import DiscreteIQLLoss, IQLLoss
20+
from torchrl.objectives.multiagent import QMixerLoss
21+
from torchrl.objectives.ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss
22+
from torchrl.objectives.redq import REDQLoss
23+
from torchrl.objectives.reinforce import ReinforceLoss
24+
from torchrl.objectives.sac import DiscreteSACLoss, SACLoss
25+
from torchrl.objectives.td3 import TD3Loss
26+
from torchrl.objectives.td3_bc import TD3BCLoss
27+
from torchrl.objectives.utils import (
2428
default_value_kwargs,
2529
distance_loss,
2630
group_optimizers,

torchrl/objectives/cql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from torchrl.data.tensor_specs import Composite
2121
from torchrl.data.utils import _find_action_space
2222
from torchrl.envs.utils import ExplorationType, set_exploration_type
23-
from torchrl.modules import ProbabilisticActor, QValueActor
23+
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, QValueActor
2424
from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible
2525
from torchrl.objectives.common import LossModule
2626
from torchrl.objectives.utils import (

torchrl/trainers/helpers/envs.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77
# Therefore we need Optional and Union
88
# from __future__ import annotations
99

10+
import importlib.util
1011
from copy import copy
1112
from dataclasses import dataclass, field as dataclass_field
1213
from typing import Any, Callable, Optional, Sequence, Union
1314

1415
import torch
15-
from omegaconf import DictConfig
16-
1716
from torchrl._utils import logger as torchrl_logger, VERBOSE
1817
from torchrl.envs import ParallelEnv
1918
from torchrl.envs.common import EnvBase
@@ -49,6 +48,14 @@
4948
"dm_control": DMControlEnv,
5049
}
5150

51+
_has_omegaconf = importlib.util.find_spec("omegaconf") is not None
52+
if _has_omegaconf:
53+
from omegaconf import DictConfig
54+
else:
55+
56+
class DictConfig: # noqa
57+
...
58+
5259

5360
def correct_for_frame_skip(cfg: DictConfig) -> DictConfig: # noqa: F821
5461
"""Correct the arguments for the input frame_skip, by dividing all the arguments that reflect a count of frames by the frame_skip.

0 commit comments

Comments
 (0)