Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Feb 10, 2025
1 parent 7a3ae4f commit 4148190
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4818,7 +4818,7 @@ def get(self, item, default=NO_DEFAULT):
def __setitem__(self, key, value):
dest = self
if isinstance(key, tuple) and len(key) > 1:
while key[0] not in self.keys():
while key[0] not in dest.keys():
dest[key[0]] = dest = Composite(shape=self.shape, device=self.device)
if len(key) > 2:
key = key[1:]
Expand Down
10 changes: 7 additions & 3 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8037,9 +8037,11 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
input_spec["full_action_spec"][out_key] = input_spec[
"full_action_spec"
][action_key].clone()
if not self.create_copy:
if not self.create_copy:
for action_key in self.parent.action_keys:
if action_key in self.in_keys_inv:
del input_spec["full_action_spec"][action_key]
for state_key in self.parent.full_state_spec.keys(True):
for state_key in self.parent.full_state_spec.keys(True, True):
if state_key in self.in_keys_inv:
for i, out_key in enumerate(self.out_keys_inv): # noqa: B007
if self.in_keys_inv[i] == state_key:
Expand All @@ -8050,7 +8052,9 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
input_spec["full_state_spec"][out_key] = input_spec["full_state_spec"][
state_key
].clone()
if not self.create_copy:
if not self.create_copy:
for state_key in self.parent.full_state_spec.keys(True, True):
if state_key in self.in_keys_inv:
del input_spec["full_state_spec"][state_key]
return input_spec

Expand Down

0 comments on commit 4148190

Please sign in to comment.