from collections import deque
from collections.abc import Mapping
-from copy import copy, deepcopy
+from copy import copy
from typing import Any, Callable, Iterable, Literal

import torch
    TensorDictBase,
    unravel_key,
)
-from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
+from tensordict.nn import (
+    ProbabilisticTensorDictModule,
+    ProbabilisticTensorDictSequential,
+)
from tensordict.utils import _zip_strict, is_seq_of_nested_key
-from torch import nn

from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
-from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
+from torchrl.envs.transforms.utils import _set_missing_tolerance
from torchrl.envs.utils import make_composite_from_td

@@ -506,6 +508,10 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
            return self._queue.popleft()
        return out

+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})"
+

class KLRewardTransform(Transform):
    """A transform to add a KL[pi_current||pi_0] correction term to the reward.
@@ -578,6 +584,8 @@ def __init__(
        in_keys=None,
        out_keys=None,
        requires_grad=False,
+        log_prob_key: NestedKey = "sample_log_prob",
+        action_key: NestedKey = "action",
    ):
        if in_keys is None:
            in_keys = self.DEFAULT_IN_KEYS
@@ -604,35 +612,38 @@ def __init__(
        self.in_keys = self.in_keys + actor.in_keys

        # check that the model has parameters
-        params = TensorDict.from_module(actor)
-        with params.apply(
-            _stateless_param, device="meta", filter_empty=False
-        ).to_module(actor):
-            # copy a stateless actor
-            self.__dict__["functional_actor"] = deepcopy(actor)
+        # params = TensorDict.from_module(actor)
+        # with params.apply(
+        #     _stateless_param, device="meta", filter_empty=False
+        # ).to_module(actor):
+        #     # copy a stateless actor
+        #     self.__dict__["functional_actor"] = deepcopy(actor)
+        self.__dict__["functional_actor"] = actor
+
        # we need to register these params as buffer to have `to` and similar
        # methods work properly

-        def _make_detached_param(x):
-
-            if isinstance(x, nn.Parameter):
-                # we need an nn.Parameter since some modules (RNN) require nn.Parameters
-                return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
-            elif x.requires_grad:
-                raise ValueError(
-                    "Encountered a value that requires gradients but is not an nn.Parameter instance."
-                )
-            return x.clone()
-
-        self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
-        if requires_grad:
-            # includes the frozen params/buffers in the module parameters/buffers
-            self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
+        # def _make_detached_param(x):
+        #
+        #     if isinstance(x, nn.Parameter):
+        #         # we need an nn.Parameter since some modules (RNN) require nn.Parameters
+        #         return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
+        #     elif x.requires_grad:
+        #         raise ValueError(
+        #             "Encountered a value that requires gradients but is not an nn.Parameter instance."
+        #         )
+        #     return x.clone()
+        # self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
+        # if requires_grad:
+        #     # includes the frozen params/buffers in the module parameters/buffers
+        #     self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)

        # self._buffers["actor_params"] = params.clone().detach()

+        self.action_key = action_key
+
        # find the sample log-prob key
-        self.sample_log_prob_key = "sample_log_prob"
+        self.sample_log_prob_key = log_prob_key

        def find_sample_log_prob(module):
            if hasattr(module, "log_prob_key"):
@@ -653,16 +664,25 @@ def _reset(

    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
        # run the actor on the tensordict
-        action = next_tensordict.get("action", None)
+        action = next_tensordict.get(self.action_key, None)
        if action is None:
            # being called after reset or without action, skipping
            if self.out_keys[0] != ("reward",) and self.parent is not None:
                next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
            return next_tensordict
-        with self.frozen_params.to_module(self.functional_actor):
-            dist = self.functional_actor.get_dist(next_tensordict.clone(False))
-            # get the log_prob given the original model
-            log_prob = dist.log_prob(action)
+        # with self.frozen_params.to_module(self.functional_actor):
+        if isinstance(
+            self.functional_actor,
+            (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
+        ):
+            dist = self.functional_actor.get_dist(next_tensordict.copy())
+            # get the log_prob given the original model
+            log_prob = dist.log_prob(action)
+        else:
+            log_prob = self.functional_actor(next_tensordict.copy()).get(
+                self.sample_log_prob_key
+            )
+
        reward_key = self.in_keys[0]
        reward = next_tensordict.get("next").get(reward_key)
        curr_log_prob = next_tensordict.get(self.sample_log_prob_key)
@@ -685,12 +705,21 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:

        if in_key == "reward" and out_key == "reward":
            parent = self.parent
+
+            reward_keys = parent.reward_keys
+            if len(reward_keys) == 1:
+                reward_key = reward_keys[0]
+            elif "reward" in reward_keys:
+                reward_key = "reward"
+            else:
+                raise KeyError("Couldn't find the reward key.")
+
            reward_spec = Unbounded(
                device=output_spec.device,
-                shape=output_spec["full_reward_spec"][parent.reward_key].shape,
+                shape=output_spec["full_reward_spec"][reward_key].shape,
            )
            output_spec["full_reward_spec"] = Composite(
-                {parent.reward_key: reward_spec},
+                {reward_key: reward_spec},
                shape=output_spec["full_reward_spec"].shape,
            )
        elif in_key == "reward":
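Taken together, the hunks above let `KLRewardTransform` work with a plain (non-probabilistic) actor by reading a pre-computed log-probability, and expose the keys it uses as constructor arguments. A minimal construction sketch, assuming the public import path `torchrl.envs.transforms.KLRewardTransform` and a toy Gaussian reference actor (layer sizes, key names and the observation shape are illustrative assumptions, not part of the commit):

from tensordict.nn import (
    NormalParamExtractor,
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
    TensorDictModule,
)
from torch import nn
from torch.distributions import Normal
from torchrl.envs.transforms import KLRewardTransform

# Toy frozen reference policy pi_0: observation -> Normal(loc, scale) -> action.
backbone = TensorDictModule(
    nn.Sequential(nn.Linear(4, 4), NormalParamExtractor()),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
ref_actor = ProbabilisticTensorDictSequential(
    backbone,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=Normal,
        return_log_prob=True,
    ),
)

# The two keyword arguments introduced by this diff; the values shown are the
# defaults and are spelled out only for illustration.
transform = KLRewardTransform(
    ref_actor,
    log_prob_key="sample_log_prob",
    action_key="action",
)

In practice the transform is appended to an environment (for instance via `TransformedEnv`) so that the KL term is folded into the reward at every step.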