diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index b1db6c6712a..358ef2006d2 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -2371,7 +2371,10 @@ def forward( action_entry = parent_td.get(action_key_orig[-1], None) if action_entry is None: raise self._NO_INIT_ERR - if self.n_steps is not None and action_entry.shape[parent_td.ndim] != self.n_steps: + if ( + self.n_steps is not None + and action_entry.shape[parent_td.ndim] != self.n_steps + ): raise RuntimeError( f"The action's time dimension (dim={parent_td.ndim}) doesn't match the n_steps argument ({self.n_steps}). " f"The action shape was {action_entry.shape}."