@@ -6,18 +6,20 @@
 
 from collections import deque
 from collections.abc import Mapping
-from copy import copy, deepcopy
+from copy import copy
 from typing import Any, Callable, Iterable, Literal
 
 import torch
 from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase, unravel_key
-from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
+from tensordict.nn import (
+    ProbabilisticTensorDictModule,
+    ProbabilisticTensorDictSequential,
+)
 from tensordict.utils import _zip_strict, is_seq_of_nested_key
-from torch import nn
 
 from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
 from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
-from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
+from torchrl.envs.transforms.utils import _set_missing_tolerance
 from torchrl.envs.utils import make_composite_from_td
 
 
@@ -500,6 +502,10 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
             return self._queue.popleft()
         return out
 
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})"
+
 
 class KLRewardTransform(Transform):
     """A transform to add a KL[pi_current||pi_0] correction term to the reward.
@@ -572,6 +578,8 @@ def __init__(
         in_keys=None,
         out_keys=None,
         requires_grad=False,
+        log_prob_key: NestedKey = "sample_log_prob",
+        action_key: NestedKey = "action",
     ):
         if in_keys is None:
             in_keys = self.DEFAULT_IN_KEYS
@@ -598,35 +606,38 @@ def __init__(
         self.in_keys = self.in_keys + actor.in_keys
 
         # check that the model has parameters
-        params = TensorDict.from_module(actor)
-        with params.apply(
-            _stateless_param, device="meta", filter_empty=False
-        ).to_module(actor):
-            # copy a stateless actor
-            self.__dict__["functional_actor"] = deepcopy(actor)
+        # params = TensorDict.from_module(actor)
+        # with params.apply(
+        #     _stateless_param, device="meta", filter_empty=False
+        # ).to_module(actor):
+        #     # copy a stateless actor
+        #     self.__dict__["functional_actor"] = deepcopy(actor)
+        self.__dict__["functional_actor"] = actor
+
         # we need to register these params as buffer to have `to` and similar
         # methods work properly
 
-        def _make_detached_param(x):
-
-            if isinstance(x, nn.Parameter):
-                # we need an nn.Parameter since some modules (RNN) require nn.Parameters
-                return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
-            elif x.requires_grad:
-                raise ValueError(
-                    "Encountered a value that requires gradients but is not an nn.Parameter instance."
-                )
-            return x.clone()
-
-        self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
-        if requires_grad:
-            # includes the frozen params/buffers in the module parameters/buffers
-            self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
+        # def _make_detached_param(x):
+        #
+        #     if isinstance(x, nn.Parameter):
+        #         # we need an nn.Parameter since some modules (RNN) require nn.Parameters
+        #         return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
+        #     elif x.requires_grad:
+        #         raise ValueError(
+        #             "Encountered a value that requires gradients but is not an nn.Parameter instance."
+        #         )
+        #     return x.clone()
+        # self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
+        # if requires_grad:
+        #     # includes the frozen params/buffers in the module parameters/buffers
+        #     self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
 
         # self._buffers["actor_params"] = params.clone().detach()
 
+        self.action_key = action_key
+
         # find the sample log-prob key
-        self.sample_log_prob_key = "sample_log_prob"
+        self.sample_log_prob_key = log_prob_key
 
         def find_sample_log_prob(module):
             if hasattr(module, "log_prob_key"):
@@ -647,16 +658,25 @@ def _reset(
 
     def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
         # run the actor on the tensordict
-        action = next_tensordict.get("action", None)
+        action = next_tensordict.get(self.action_key, None)
         if action is None:
             # being called after reset or without action, skipping
             if self.out_keys[0] != ("reward",) and self.parent is not None:
                 next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
             return next_tensordict
-        with self.frozen_params.to_module(self.functional_actor):
-            dist = self.functional_actor.get_dist(next_tensordict.clone(False))
-            # get the log_prob given the original model
-            log_prob = dist.log_prob(action)
+        # with self.frozen_params.to_module(self.functional_actor):
+        if isinstance(
+            self.functional_actor,
+            (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
+        ):
+            dist = self.functional_actor.get_dist(next_tensordict.copy())
+            # get the log_prob given the original model
+            log_prob = dist.log_prob(action)
+        else:
+            log_prob = self.functional_actor(next_tensordict.copy()).get(
+                self.sample_log_prob_key
+            )
+
         reward_key = self.in_keys[0]
         reward = next_tensordict.get("next").get(reward_key)
         curr_log_prob = next_tensordict.get(self.sample_log_prob_key)
@@ -679,12 +699,21 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
 
         if in_key == "reward" and out_key == "reward":
             parent = self.parent
+
+            reward_keys = parent.reward_keys
+            if len(reward_keys) == 1:
+                reward_key = reward_keys[0]
+            elif "reward" in reward_keys:
+                reward_key = "reward"
+            else:
+                raise KeyError("Couldn't find the reward key.")
+
             reward_spec = Unbounded(
                 device=output_spec.device,
-                shape=output_spec["full_reward_spec"][parent.reward_key].shape,
+                shape=output_spec["full_reward_spec"][reward_key].shape,
             )
             output_spec["full_reward_spec"] = Composite(
-                {parent.reward_key: reward_spec},
+                {reward_key: reward_spec},
                 shape=output_spec["full_reward_spec"].shape,
             )
         elif in_key == "reward":
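A minimal usage sketch, not part of the commit above: it assumes KLRewardTransform is importable from torchrl.envs.transforms, that its existing coef argument keeps its meaning, and it builds a toy Pendulum actor purely to show how the new log_prob_key and action_key keyword arguments would be passed.

import torch
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import KLRewardTransform
from torchrl.modules import ProbabilisticActor, TanhNormal

# Reference policy pi_0: its log-probs anchor the KL correction added to the reward.
policy_net = TensorDictModule(
    torch.nn.Sequential(torch.nn.LazyLinear(2), NormalParamExtractor()),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
actor_ref = ProbabilisticActor(
    policy_net,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,  # writes "sample_log_prob" during rollouts
)

env = TransformedEnv(
    GymEnv("Pendulum-v1"),
    KLRewardTransform(
        actor_ref,
        coef=0.1,                        # weight of the KL term (assumed value)
        log_prob_key="sample_log_prob",  # new kwarg: key holding the rollout log-prob
        action_key="action",             # new kwarg: key holding the action
    ),
)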