
[Doc] Solve ref issues in docstrings #2776

Merged (27 commits) on Feb 11, 2025
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
@@ -112,7 +112,7 @@ jobs:
cd ./docs
# timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
# bash -ic "PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
PYOPENGL_PLATFORM=egl MUJOCO_GL=egl TORCHRL_CONSOLE_STREAM=stdout sphinx-build ./source _local_build
PYOPENGL_PLATFORM=egl MUJOCO_GL=egl TORCHRL_CONSOLE_STREAM=stdout sphinx-build ./source _local_build -v
cd ..

cp -r docs/_local_build/* "${RUNNER_ARTIFACT_DIR}"
4 changes: 4 additions & 0 deletions docs/source/content_generation.py
@@ -83,6 +83,10 @@ def generate_tutorial_references(tutorial_path: str, file_type: str) -> None:
for f in os.listdir(tutorial_path)
if f.endswith((".py", ".rst", ".png"))
]
# Make rb_tutorial.py the first one
file_paths = [p for p in file_paths if p.endswith("rb_tutorial.py")] + [
p for p in file_paths if not p.endswith("rb_tutorial.py")
]
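The partition idiom added above can be shown in isolation; the file names below are hypothetical stand-ins for the tutorial files:

```python
# Hypothetical file list; mirrors the partition trick above that moves
# rb_tutorial.py to the front while preserving the order of the rest.
file_paths = ["a_tutorial.py", "rb_tutorial.py", "z_tutorial.py"]
file_paths = [p for p in file_paths if p.endswith("rb_tutorial.py")] + [
    p for p in file_paths if not p.endswith("rb_tutorial.py")
]
print(file_paths)  # ['rb_tutorial.py', 'a_tutorial.py', 'z_tutorial.py']
```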

for file_path in file_paths:
shutil.copyfile(file_path, os.path.join(target_path, Path(file_path).name))
3 changes: 2 additions & 1 deletion docs/source/reference/envs.rst
@@ -212,6 +212,7 @@ component (sub-environments or agents) should be reset.
This allows resetting some but not all of the components.

The ``"_reset"`` key has two distinct functionalities:

1. During a call to :meth:`~.EnvBase._reset`, the ``"_reset"`` key may or may
not be present in the input tensordict. TorchRL's convention is that the
absence of the ``"_reset"`` key at a given ``"done"`` level indicates
@@ -899,7 +900,7 @@ to be able to create this other composition:
Hash
InitTracker
KLRewardTransform
LineariseReward
LineariseRewards
NoopResetEnv
ObservationNorm
ObservationTransform
4 changes: 2 additions & 2 deletions torchrl/_utils.py
@@ -513,7 +513,7 @@ def reset(cls, setters_dict: Dict[str, implement_for] = None):
"""Resets the setters in setter_dict.

``setter_dict`` is a copy of implementations. We just need to iterate through its
values and call :meth:`~.module_set` for each.
values and call :meth:`module_set` for each.

"""
if VERBOSE:
@@ -888,7 +888,7 @@ def _standardize(
exclude_dims (Tuple[int]): dimensions to exclude from the statistics, can be negative. Default: ().
mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
eps (float): epsilon to be used for numerical stability. Default: float32 resolution.
eps (:obj:`float`): epsilon to be used for numerical stability. Default: float32 resolution.

"""
if eps is None:
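The signature documented above can be sketched generically. This is an illustrative re-implementation under stated assumptions (reduce over all dims except `exclude_dims`, default `eps` at float32 resolution), not TorchRL's exact `_standardize` code; the name `standardize_sketch` is hypothetical:

```python
import torch

def standardize_sketch(x, exclude_dims=(), mean=None, std=None, eps=None):
    # Default eps: float32 resolution, as the docstring above states.
    if eps is None:
        eps = torch.finfo(torch.float32).resolution
    # Normalize negative dims, then reduce over every non-excluded dim.
    excluded = {d % x.ndim for d in exclude_dims}
    dims = tuple(d for d in range(x.ndim) if d not in excluded)
    if mean is None:
        mean = x.mean(dim=dims, keepdim=True)
    if std is None:
        std = x.std(dim=dims, keepdim=True)
    # Clamp the denominator for numerical stability.
    return (x - mean) / std.clamp_min(eps)
```

When `mean` or `std` are supplied, they only need to be broadcastable to `x`, matching the documented contract.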
9 changes: 7 additions & 2 deletions torchrl/collectors/collectors.py
@@ -339,10 +339,12 @@ class SyncDataCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.

- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.

- In all other cases, an attempt will be made to wrap it as follows: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

Keyword Args:
@@ -1462,6 +1464,7 @@ class _MultiDataCollector(DataCollectorBase):
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.

- In all other cases, an attempt will be made to wrap it as follows:
``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

@@ -1548,7 +1551,7 @@ class _MultiDataCollector(DataCollectorBase):
reset_when_done (bool, optional): if ``True`` (default), an environment
that returns a ``True`` value in its ``"done"`` or ``"truncated"``
entry will be reset at the corresponding indices.
update_at_each_batch (bool, optional): if ``True``, :meth:`~.update_policy_weight_()`
update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weight_()`
will be called before (sync) or after (async) each data collection.
Defaults to ``False``.
preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
@@ -2774,10 +2777,12 @@ class aSyncDataCollector(MultiaSyncDataCollector):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.

- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.

- In all other cases, an attempt will be made to wrap it as follows: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

Keyword Args:
@@ -2863,7 +2868,7 @@ class aSyncDataCollector(MultiaSyncDataCollector):
reset_when_done (bool, optional): if ``True`` (default), an environment
that returns a ``True`` value in its ``"done"`` or ``"truncated"``
entry will be reset at the corresponding indices.
update_at_each_batch (bool, optional): if ``True``, :meth:`~.update_policy_weight_()`
update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weight_()`
will be called before (sync) or after (async) each data collection.
Defaults to ``False``.
preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
4 changes: 3 additions & 1 deletion torchrl/collectors/distributed/generic.py
@@ -262,10 +262,12 @@ class DistributedDataCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.

- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.

- In all other cases, an attempt will be made to wrap it as follows: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

Keyword Args:
@@ -341,7 +343,7 @@ class DistributedDataCollector(DataCollectorBase):
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
collector_class (type or str, optional): a collector class for the remote node. Can be
collector_class (Type or str, optional): a collector class for the remote node. Can be
:class:`~torchrl.collectors.SyncDataCollector`,
:class:`~torchrl.collectors.MultiSyncDataCollector`,
:class:`~torchrl.collectors.MultiaSyncDataCollector`
2 changes: 2 additions & 0 deletions torchrl/collectors/distributed/ray.py
@@ -135,10 +135,12 @@ class RayCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.

- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.

- In all other cases, an attempt will be made to wrap it as follows: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

Keyword Args:
4 changes: 3 additions & 1 deletion torchrl/collectors/distributed/rpc.py
@@ -110,10 +110,12 @@ class RPCDataCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.

- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.

- In all other cases, an attempt will be made to wrap it as follows: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

Keyword Args:
@@ -190,7 +192,7 @@ class RPCDataCollector(DataCollectorBase):
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
collector_class (type or str, optional): a collector class for the remote node. Can be
collector_class (Type or str, optional): a collector class for the remote node. Can be
:class:`~torchrl.collectors.SyncDataCollector`,
:class:`~torchrl.collectors.MultiSyncDataCollector`,
:class:`~torchrl.collectors.MultiaSyncDataCollector`
4 changes: 3 additions & 1 deletion torchrl/collectors/distributed/sync.py
@@ -143,10 +143,12 @@ class DistributedSyncDataCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.

- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.

- In all other cases, an attempt will be made to wrap it as follows: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

Keyword Args:
@@ -222,7 +224,7 @@ class DistributedSyncDataCollector(DataCollectorBase):
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
collector_class (type or str, optional): a collector class for the remote node. Can be
collector_class (Type or str, optional): a collector class for the remote node. Can be
:class:`~torchrl.collectors.SyncDataCollector`,
:class:`~torchrl.collectors.MultiSyncDataCollector`,
:class:`~torchrl.collectors.MultiaSyncDataCollector`
2 changes: 1 addition & 1 deletion torchrl/data/datasets/common.py
@@ -72,7 +72,7 @@ def preprocess(

Args and Keyword Args are forwarded to :meth:`~tensordict.TensorDictBase.map`.

The dataset can subsequently be deleted using :meth:`~.delete`.
The dataset can subsequently be deleted using :meth:`delete`.

Keyword Args:
dest (path or equivalent): a path to the location of the new dataset.
4 changes: 2 additions & 2 deletions torchrl/data/datasets/openx.py
@@ -66,7 +66,7 @@ class for more information on how to interact with non-tensor data
sampling strategy.
If the ``batch_size`` is ``None`` (default), iterating over the
dataset will deliver trajectories one at a time *whereas* calling
:meth:`~.sample` will *still* require a batch-size to be provided.
:meth:`sample` will *still* require a batch-size to be provided.

Keyword Args:
shuffle (bool, optional): if ``True``, trajectories are delivered in a
@@ -115,7 +115,7 @@ class for more information on how to interact with non-tensor data
replacement (bool, optional): if ``False``, sampling will be done
without replacement. Defaults to ``True`` for downloaded datasets,
``False`` for streamed datasets.
pad (bool, float or None): if ``True``, trajectories of insufficient length
pad (bool, :obj:`float` or None): if ``True``, trajectories of insufficient length
given the `slice_len` or `num_slices` arguments will be padded with
0s. If another value is provided, it will be used for padding. If
``False`` or ``None`` (default) any encounter with a trajectory of
2 changes: 1 addition & 1 deletion torchrl/data/map/tdstorage.py
@@ -193,7 +193,7 @@ def from_tensordict_pair(
in the storage. Defaults to ``None`` (all keys are registered).
max_size (int, optional): the maximum number of elements in the storage. Ignored if the
``storage_constructor`` is passed. Defaults to ``1000``.
storage_constructor (type, optional): a type of tensor storage.
storage_constructor (Type, optional): a type of tensor storage.
Defaults to :class:`~tensordict.nn.storage.LazyDynamicStorage`.
Other options include :class:`~tensordict.nn.storage.FixedStorage`.
hash_module (Callable, optional): a hash function to use in the :class:`~torchrl.data.map.QueryModule`.
19 changes: 11 additions & 8 deletions torchrl/data/map/tree.py
@@ -50,7 +50,7 @@ class Tree(TensorClass["nocast"]):
node_id (int): A unique identifier for this node.
rollout (TensorDict): Rollout data following the observation encoded in this node, in a TED format.
If there are multiple actions taken at this node, subtrees are stored in the corresponding
entry. Rollouts can be reconstructed using the :meth:`~.rollout_from_path` method.
entry. Rollouts can be reconstructed using the :meth:`rollout_from_path` method.
node (TensorDict): Data defining this node (e.g., observations) before the next branching.
Entries usually match the ``in_keys`` in ``MCTSForest.node_map``.
subtree (Tree): A stack of subtrees produced when actions are taken.
@@ -215,7 +215,7 @@ def node_observation(self) -> torch.Tensor | TensorDictBase:
"""Returns the observation associated with this particular node.

This is the observation (or bag of observations) that defines the node before a branching occurs.
If the node contains a :attr:`~.rollout` attribute, the node observation is typically identical to the
If the node contains a :meth:`rollout` attribute, the node observation is typically identical to the
observation resulting from the last action undertaken, i.e., ``node.rollout[..., -1]["next", "observation"]``.

If more than one observation key is associated with the tree specs, a :class:`~tensordict.TensorDict` instance
@@ -232,7 +232,7 @@ def node_observations(self) -> torch.Tensor | TensorDictBase:
"""Returns the observations associated with this particular node in a TensorDict format.

This is the observation (or bag of observations) that defines the node before a branching occurs.
If the node contains a :attr:`~.rollout` attribute, the node observation is typically identical to the
If the node contains a :meth:`rollout` attribute, the node observation is typically identical to the
observation resulting from the last action undertaken, i.e., ``node.rollout[..., -1]["next", "observation"]``.

If more than one observation key is associated with the tree specs, a :class:`~tensordict.TensorDict` instance
@@ -442,8 +442,11 @@ def num_vertices(self, *, count_repeat: bool = False) -> int:
"""Returns the number of unique vertices in the Tree.

Keyword Args:
count_repeat (bool, optional): Determines whether to count repeated vertices.
count_repeat (bool, optional): Determines whether to count repeated
vertices.

- If ``False``, counts each unique vertex only once.

- If ``True``, counts vertices multiple times if they appear in different paths.
Defaults to ``False``.

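The two counting modes described above can be illustrated with a toy example; the root-to-leaf `paths` representation and the standalone `num_vertices` helper below are hypothetical, not the Tree API:

```python
# Hypothetical tree given as root-to-leaf paths of vertex ids.
# Vertices 0 and 1 are shared between the two paths.
paths = [(0, 1, 2), (0, 1, 3)]

def num_vertices(paths, count_repeat=False):
    """count_repeat=False counts each unique vertex once;
    count_repeat=True counts a vertex once per path it appears in."""
    if count_repeat:
        return sum(len(p) for p in paths)
    return len({v for p in paths for v in p})

print(num_vertices(paths))                     # 4  (vertices 0, 1, 2, 3)
print(num_vertices(paths, count_repeat=True))  # 6  (0 and 1 counted twice)
```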
@@ -629,16 +632,16 @@ class MCTSForest:
``node_map.max_size``. If none of these are provided, defaults to `1000`.
done_keys (list of NestedKey, optional): the done keys of the environment. If not provided,
defaults to ``("done", "terminated", "truncated")``.
The :meth:`~.get_keys_from_env` can be used to automatically determine the keys.
The :meth:`get_keys_from_env` can be used to automatically determine the keys.
action_keys (list of NestedKey, optional): the action keys of the environment. If not provided,
defaults to ``("action",)``.
The :meth:`~.get_keys_from_env` can be used to automatically determine the keys.
The :meth:`get_keys_from_env` can be used to automatically determine the keys.
reward_keys (list of NestedKey, optional): the reward keys of the environment. If not provided,
defaults to ``("reward",)``.
The :meth:`~.get_keys_from_env` can be used to automatically determine the keys.
The :meth:`get_keys_from_env` can be used to automatically determine the keys.
observation_keys (list of NestedKey, optional): the observation keys of the environment. If not provided,
defaults to ``("observation",)``.
The :meth:`~.get_keys_from_env` can be used to automatically determine the keys.
The :meth:`get_keys_from_env` can be used to automatically determine the keys.
excluded_keys (list of NestedKey, optional): a list of keys to exclude from the data storage.
consolidated (bool, optional): if ``True``, the data_map storage will be consolidated on disk.
Defaults to ``False``.
1 change: 1 addition & 0 deletions torchrl/data/postprocs/postprocs.py
@@ -182,6 +182,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
environment (i.e. before multi-step);
- The "reward" values will be replaced by the newly computed
rewards.

The ``"done"`` key can have either the shape of the tensordict
OR the shape of the tensordict followed by a singleton
dimension OR the shape of the tensordict followed by other