@@ -6,7 +6,7 @@
 
 from collections import deque
 from collections.abc import Mapping
-from copy import copy, deepcopy
+from copy import copy
 from typing import Any, Callable, Iterable, Literal
 
 import torch
@@ -17,13 +17,15 @@
     TensorDictBase,
     unravel_key,
 )
-from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
+from tensordict.nn import (
+    ProbabilisticTensorDictModule,
+    ProbabilisticTensorDictSequential,
+)
 from tensordict.utils import _zip_strict, is_seq_of_nested_key
-from torch import nn
 
 from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
 from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
-from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
+from torchrl.envs.transforms.utils import _set_missing_tolerance
 from torchrl.envs.utils import make_composite_from_td
 
 
@@ -506,6 +508,10 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
             return self._queue.popleft()
         return out
 
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})"
+
 
 class KLRewardTransform(Transform):
     """A transform to add a KL[pi_current||pi_0] correction term to the reward.
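For intuition about the correction named in this docstring: the transform scores the logged action under a reference copy of the policy and uses the log-probability gap as a single-sample estimate of KL[pi_current||pi_0]. A minimal sketch with illustrative names and a hypothetical helper, not the transform's actual API (the sign follows the usual RLHF-style penalty convention):

import torch

def kl_corrected_reward(reward, curr_log_prob, ref_log_prob, coef=0.1):
    # Single-sample Monte-Carlo estimate of KL[pi_current || pi_0] at the taken action:
    # log pi_current(a) - log pi_0(a).
    kl = curr_log_prob - ref_log_prob
    # Penalise the reward for drifting away from the reference policy pi_0.
    return reward - coef * kl

# Example: the current policy assigns a higher log-prob to the action than the
# reference, so the estimated KL is positive and the reward is reduced.
r = kl_corrected_reward(torch.tensor(1.0), torch.tensor(-1.2), torch.tensor(-1.5))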
@@ -578,6 +584,8 @@ def __init__(
         in_keys=None,
         out_keys=None,
         requires_grad=False,
+        log_prob_key: NestedKey = "sample_log_prob",
+        action_key: NestedKey = "action",
     ):
         if in_keys is None:
             in_keys = self.DEFAULT_IN_KEYS
@@ -604,35 +612,38 @@ def __init__(
         self.in_keys = self.in_keys + actor.in_keys
 
         # check that the model has parameters
-        params = TensorDict.from_module(actor)
-        with params.apply(
-            _stateless_param, device="meta", filter_empty=False
-        ).to_module(actor):
-            # copy a stateless actor
-            self.__dict__["functional_actor"] = deepcopy(actor)
+        # params = TensorDict.from_module(actor)
+        # with params.apply(
+        #     _stateless_param, device="meta", filter_empty=False
+        # ).to_module(actor):
+        #     # copy a stateless actor
+        #     self.__dict__["functional_actor"] = deepcopy(actor)
+        self.__dict__["functional_actor"] = actor
+
         # we need to register these params as buffer to have `to` and similar
         # methods work properly
 
-        def _make_detached_param(x):
-
-            if isinstance(x, nn.Parameter):
-                # we need an nn.Parameter since some modules (RNN) require nn.Parameters
-                return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
-            elif x.requires_grad:
-                raise ValueError(
-                    "Encountered a value that requires gradients but is not an nn.Parameter instance."
-                )
-            return x.clone()
-
-        self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
-        if requires_grad:
-            # includes the frozen params/buffers in the module parameters/buffers
-            self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
+        # def _make_detached_param(x):
+        #
+        #     if isinstance(x, nn.Parameter):
+        #         # we need an nn.Parameter since some modules (RNN) require nn.Parameters
+        #         return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
+        #     elif x.requires_grad:
+        #         raise ValueError(
+        #             "Encountered a value that requires gradients but is not an nn.Parameter instance."
+        #         )
+        #     return x.clone()
+        # self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
+        # if requires_grad:
+        #     # includes the frozen params/buffers in the module parameters/buffers
+        #     self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
 
         # self._buffers["actor_params"] = params.clone().detach()
 
+        self.action_key = action_key
+
         # find the sample log-prob key
-        self.sample_log_prob_key = "sample_log_prob"
+        self.sample_log_prob_key = log_prob_key
 
         def find_sample_log_prob(module):
             if hasattr(module, "log_prob_key"):
@@ -653,16 +664,25 @@ def _reset(
 
     def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
         # run the actor on the tensordict
-        action = next_tensordict.get("action", None)
+        action = next_tensordict.get(self.action_key, None)
         if action is None:
             # being called after reset or without action, skipping
             if self.out_keys[0] != ("reward",) and self.parent is not None:
                 next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
             return next_tensordict
-        with self.frozen_params.to_module(self.functional_actor):
-            dist = self.functional_actor.get_dist(next_tensordict.clone(False))
-            # get the log_prob given the original model
-            log_prob = dist.log_prob(action)
+        # with self.frozen_params.to_module(self.functional_actor):
+        if isinstance(
+            self.functional_actor,
+            (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
+        ):
+            dist = self.functional_actor.get_dist(next_tensordict.copy())
+            # get the log_prob given the original model
+            log_prob = dist.log_prob(action)
+        else:
+            log_prob = self.functional_actor(next_tensordict.copy()).get(
+                self.sample_log_prob_key
+            )
+
         reward_key = self.in_keys[0]
         reward = next_tensordict.get("next").get(reward_key)
         curr_log_prob = next_tensordict.get(self.sample_log_prob_key)
@@ -685,12 +705,21 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
 
         if in_key == "reward" and out_key == "reward":
             parent = self.parent
+
+            reward_keys = parent.reward_keys
+            if len(reward_keys) == 1:
+                reward_key = reward_keys[0]
+            elif "reward" in reward_keys:
+                reward_key = "reward"
+            else:
+                raise KeyError("Couldn't find the reward key.")
+
             reward_spec = Unbounded(
                 device=output_spec.device,
-                shape=output_spec["full_reward_spec"][parent.reward_key].shape,
+                shape=output_spec["full_reward_spec"][reward_key].shape,
             )
             output_spec["full_reward_spec"] = Composite(
-                {parent.reward_key: reward_spec},
+                {reward_key: reward_spec},
                 shape=output_spec["full_reward_spec"].shape,
             )
         elif in_key == "reward":
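The updated `_call` above distinguishes two kinds of reference actors: probabilistic modules that expose a distribution via `get_dist`, and plain modules that are expected to write the log-probability themselves under the configured `sample_log_prob_key`. A rough, self-contained sketch of those two paths with toy modules and values (an assumption for illustration, not the transform's code):

import torch
from tensordict import TensorDict
from tensordict.nn import ProbabilisticTensorDictModule, TensorDictModule
from torch.distributions import Normal

# Path 1: a probabilistic actor builds a distribution, and the stored action is scored with it.
dist_actor = ProbabilisticTensorDictModule(
    in_keys=["loc", "scale"], out_keys=["action"], distribution_class=Normal
)
td = TensorDict(
    {"loc": torch.zeros(3), "scale": torch.ones(3), "action": torch.zeros(3)}, []
)
ref_log_prob = dist_actor.get_dist(td).log_prob(td["action"])

# Path 2: a non-probabilistic actor writes the log-prob itself under the
# configured key (here the default "sample_log_prob").
class LogProbWriter(torch.nn.Module):
    def forward(self, obs):
        # toy stand-in: a real module would compute log pi_0(a | obs)
        return obs.sum(-1)

plain_actor = TensorDictModule(
    LogProbWriter(), in_keys=["observation"], out_keys=["sample_log_prob"]
)
td2 = plain_actor(TensorDict({"observation": torch.randn(3)}, []))
ref_log_prob2 = td2.get("sample_log_prob")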