Skip to content

Commit f16655f

Browse files
committed
Update
[ghstack-poisoned]
1 parent 4136fb1 commit f16655f

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

Diff for: torchrl/envs/custom/llm.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,15 @@ def _make_next_obs(
325325
if self.attention_key is not None:
326326
attention_mask = tensordict.get(self.attention_key)
327327
n = action.shape[-1] - attention_mask.shape[-1]
328-
attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[:-1] + (n,))], -1)
328+
if n > 0:
329+
# It can happen that there's only one action (eg rand_action)
330+
attention_mask = torch.cat(
331+
[
332+
attention_mask,
333+
attention_mask.new_ones(attention_mask.shape[:-1] + (n,)),
334+
],
335+
-1,
336+
)
329337
nex_td.set(self.attention_key, attention_mask)
330338
return nex_td
331339

Diff for: torchrl/envs/transforms/transforms.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -6361,10 +6361,13 @@ def _reset_func(
63616361

63626362
def __repr__(self) -> str:
63636363
class_name = self.__class__.__name__
6364-
default_value = {
6365-
key: value if isinstance(value, float) else "Callable"
6366-
for key, value in self.default_value.items()
6367-
}
6364+
if callable(self.default_value):
6365+
default_value = self.default_value
6366+
else:
6367+
default_value = {
6368+
key: value if isinstance(value, float) else "Callable"
6369+
for key, value in self.default_value.items()
6370+
}
63686371
return f"{class_name}(primers={self.primers}, default_value={default_value}, random={self.random})"
63696372

63706373

0 commit comments

Comments
 (0)