Skip to content

Commit eeeeb2b

Browse files
committed
Update (base update)
[ghstack-poisoned]
2 parents c076c0c + f1c42e0 commit eeeeb2b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+590
-294
lines changed

.github/workflows/docs.yml

+6-6
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@ jobs:
2626
build-docs:
2727
strategy:
2828
matrix:
29-
python_version: ["3.9"]
30-
cuda_arch_version: ["12.4"]
29+
python_version: [ "3.9" ]
30+
cuda_arch_version: [ "12.4" ]
3131
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
3232
with:
33+
runner: linux.g5.4xlarge.nvidia.gpu
3334
repository: pytorch/rl
3435
upload-artifact: docs
3536
timeout: 120
@@ -38,7 +39,6 @@ jobs:
3839
set -v
3940
# apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils
4041
yum makecache
41-
# yum install -y glfw glew mesa-libGL mesa-libGL-devel mesa-libOSMesa-devel egl-utils freeglut
4242
# Install Mesa and OpenGL Libraries:
4343
yum install -y glfw mesa-libGL mesa-libGL-devel egl-utils freeglut mesa-libGLU mesa-libEGL
4444
# Install DRI Drivers:
@@ -112,7 +112,7 @@ jobs:
112112
cd ./docs
113113
# timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
114114
# bash -ic "PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
115-
PYOPENGL_PLATFORM=egl MUJOCO_GL=egl TORCHRL_CONSOLE_STREAM=stdout sphinx-build ./source _local_build
115+
PYOPENGL_PLATFORM=egl MUJOCO_GL=egl TORCHRL_CONSOLE_STREAM=stdout sphinx-build ./source _local_build -v -j 4
116116
cd ..
117117
118118
cp -r docs/_local_build/* "${RUNNER_ARTIFACT_DIR}"
@@ -123,8 +123,8 @@ jobs:
123123
124124
upload:
125125
needs: build-docs
126-
if: github.repository == 'pytorch/rl' && github.event_name == 'push' &&
127-
((github.ref_type == 'branch' && github.ref_name == 'main') || github.ref_type == 'tag')
126+
if: github.repository == 'pytorch/rl' && github.event_name == 'push' &&
127+
((github.ref_type == 'branch' && github.ref_name == 'main') || github.ref_type == 'tag')
128128
permissions:
129129
contents: write
130130
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main

docs/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,4 @@ vmas
2828
onnxscript
2929
onnxruntime
3030
onnx
31+
psutil

docs/source/conf.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828
import pytorch_sphinx_theme
2929
import torchrl
3030

31-
# Suppress warnings - TODO
32-
# suppress_warnings = [ 'misc.highlighting_failure' ]
31+
# Suppress warnings
3332
warnings.filterwarnings("ignore", category=UserWarning)
3433

3534
project = "torchrl"
@@ -86,6 +85,21 @@
8685
"torchvision": ("https://pytorch.org/vision/stable/", None),
8786
}
8887

88+
89+
def kill_procs(gallery_conf, fname):
90+
import os
91+
92+
import psutil
93+
94+
# Get the current process
95+
current_proc = psutil.Process(os.getpid())
96+
# Iterate over all child processes
97+
for child in current_proc.children(recursive=True):
98+
# Kill the child process
99+
child.terminate()
100+
print(f"Killed child process with PID {child.pid}") # noqa: T201
101+
102+
89103
sphinx_gallery_conf = {
90104
"examples_dirs": "reference/generated/tutorials/", # path to your example scripts
91105
"gallery_dirs": "tutorials", # path to where to save gallery generated output
@@ -95,9 +109,12 @@
95109
"notebook_images": "reference/generated/tutorials/media/", # images to parse
96110
"download_all_examples": True,
97111
"abort_on_example_error": True,
98-
"show_memory": True,
112+
# "show_memory": True,
113+
"plot_gallery": "False",
99114
"capture_repr": ("_repr_html_", "__repr__"), # capture representations
100115
"write_computation_times": True,
116+
# "compress_images": ("images", "thumbnails"),
117+
"reset_modules": (kill_procs, "matplotlib", "seaborn"),
101118
}
102119

103120
napoleon_use_ivar = True

docs/source/reference/envs.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -976,7 +976,7 @@ to be able to create this other composition:
976976
Hash
977977
InitTracker
978978
KLRewardTransform
979-
LineariseReward
979+
LineariseRewards
980980
MultiAction
981981
NoopResetEnv
982982
ObservationNorm

test/mocking_classes.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,9 @@ def _step(
714714
while done.shape != tensordict.shape:
715715
done = done.any(-1)
716716
done = reward = done.unsqueeze(-1)
717-
tensordict.set("reward", reward.to(torch.get_default_dtype()))
717+
tensordict.set(
718+
"reward", reward.to(self.reward_spec.dtype).expand(self.reward_spec.shape)
719+
)
718720
tensordict.set("done", done)
719721
tensordict.set("terminated", done)
720722
return tensordict

torchrl/_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ def reset(cls, setters_dict: Dict[str, implement_for] = None):
513513
"""Resets the setters in setter_dict.
514514
515515
``setter_dict`` is a copy of implementations. We just need to iterate through its
516-
values and call :meth:`~.module_set` for each.
516+
values and call :meth:`module_set` for each.
517517
518518
"""
519519
if VERBOSE:
@@ -888,7 +888,7 @@ def _standardize(
888888
exclude_dims (Tuple[int]): dimensions to exclude from the statistics, can be negative. Default: ().
889889
mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
890890
std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
891-
eps (float): epsilon to be used for numerical stability. Default: float32 resolution.
891+
eps (:obj:`float`): epsilon to be used for numerical stability. Default: float32 resolution.
892892
893893
"""
894894
if eps is None:

torchrl/collectors/collectors.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -339,10 +339,12 @@ class SyncDataCollector(DataCollectorBase):
339339
instances) it will be wrapped in a `nn.Module` first.
340340
Then, the collector will try to assess if these
341341
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
342+
342343
- If the policy forward signature matches any of ``forward(self, tensordict)``,
343344
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
344345
any typing with a single argument typed as a subclass of ``TensorDictBase``)
345346
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
347+
346348
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
347349
348350
Keyword Args:
@@ -1462,6 +1464,7 @@ class _MultiDataCollector(DataCollectorBase):
14621464
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
14631465
any typing with a single argument typed as a subclass of ``TensorDictBase``)
14641466
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
1467+
14651468
- In all other cases an attempt to wrap it will be undergone as such:
14661469
``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
14671470
@@ -1548,7 +1551,7 @@ class _MultiDataCollector(DataCollectorBase):
15481551
reset_when_done (bool, optional): if ``True`` (default), an environment
15491552
that returns a ``True`` value in its ``"done"`` or ``"truncated"``
15501553
entry will be reset at the corresponding indices.
1551-
update_at_each_batch (bool, optional): if ``True``, :meth:`~.update_policy_weight_()`
1554+
update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weight_()`
15521555
will be called before (sync) or after (async) each data collection.
15531556
Defaults to ``False``.
15541557
preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
@@ -2774,10 +2777,12 @@ class aSyncDataCollector(MultiaSyncDataCollector):
27742777
instances) it will be wrapped in a `nn.Module` first.
27752778
Then, the collector will try to assess if these
27762779
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
2780+
27772781
- If the policy forward signature matches any of ``forward(self, tensordict)``,
27782782
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
27792783
any typing with a single argument typed as a subclass of ``TensorDictBase``)
27802784
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
2785+
27812786
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
27822787
27832788
Keyword Args:
@@ -2863,7 +2868,7 @@ class aSyncDataCollector(MultiaSyncDataCollector):
28632868
reset_when_done (bool, optional): if ``True`` (default), an environment
28642869
that returns a ``True`` value in its ``"done"`` or ``"truncated"``
28652870
entry will be reset at the corresponding indices.
2866-
update_at_each_batch (bool, optional): if ``True``, :meth:`~.update_policy_weight_()`
2871+
update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weight_()`
28672872
will be called before (sync) or after (async) each data collection.
28682873
Defaults to ``False``.
28692874
preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers

torchrl/collectors/distributed/generic.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,12 @@ class DistributedDataCollector(DataCollectorBase):
262262
instances) it will be wrapped in a `nn.Module` first.
263263
Then, the collector will try to assess if these
264264
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
265+
265266
- If the policy forward signature matches any of ``forward(self, tensordict)``,
266267
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
267268
any typing with a single argument typed as a subclass of ``TensorDictBase``)
268269
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
270+
269271
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
270272
271273
Keyword Args:
@@ -341,7 +343,7 @@ class DistributedDataCollector(DataCollectorBase):
341343
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
342344
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
343345
or ``torchrl.envs.utils.ExplorationType.MEAN``.
344-
collector_class (type or str, optional): a collector class for the remote node. Can be
346+
collector_class (Type or str, optional): a collector class for the remote node. Can be
345347
:class:`~torchrl.collectors.SyncDataCollector`,
346348
:class:`~torchrl.collectors.MultiSyncDataCollector`,
347349
:class:`~torchrl.collectors.MultiaSyncDataCollector`

torchrl/collectors/distributed/ray.py

+2
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,12 @@ class RayCollector(DataCollectorBase):
135135
instances) it will be wrapped in a `nn.Module` first.
136136
Then, the collector will try to assess if these
137137
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
138+
138139
- If the policy forward signature matches any of ``forward(self, tensordict)``,
139140
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
140141
any typing with a single argument typed as a subclass of ``TensorDictBase``)
141142
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
143+
142144
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
143145
144146
Keyword Args:

torchrl/collectors/distributed/rpc.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,12 @@ class RPCDataCollector(DataCollectorBase):
110110
instances) it will be wrapped in a `nn.Module` first.
111111
Then, the collector will try to assess if these
112112
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
113+
113114
- If the policy forward signature matches any of ``forward(self, tensordict)``,
114115
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
115116
any typing with a single argument typed as a subclass of ``TensorDictBase``)
116117
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
118+
117119
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
118120
119121
Keyword Args:
@@ -190,7 +192,7 @@ class RPCDataCollector(DataCollectorBase):
190192
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
191193
or ``torchrl.envs.utils.ExplorationType.MEAN``.
192194
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
193-
collector_class (type or str, optional): a collector class for the remote node. Can be
195+
collector_class (Type or str, optional): a collector class for the remote node. Can be
194196
:class:`~torchrl.collectors.SyncDataCollector`,
195197
:class:`~torchrl.collectors.MultiSyncDataCollector`,
196198
:class:`~torchrl.collectors.MultiaSyncDataCollector`

torchrl/collectors/distributed/sync.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,12 @@ class DistributedSyncDataCollector(DataCollectorBase):
143143
instances) it will be wrapped in a `nn.Module` first.
144144
Then, the collector will try to assess if these
145145
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
146+
146147
- If the policy forward signature matches any of ``forward(self, tensordict)``,
147148
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
148149
any typing with a single argument typed as a subclass of ``TensorDictBase``)
149150
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
151+
150152
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
151153
152154
Keyword Args:
@@ -222,7 +224,7 @@ class DistributedSyncDataCollector(DataCollectorBase):
222224
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
223225
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
224226
or ``torchrl.envs.utils.ExplorationType.MEAN``.
225-
collector_class (type or str, optional): a collector class for the remote node. Can be
227+
collector_class (Type or str, optional): a collector class for the remote node. Can be
226228
:class:`~torchrl.collectors.SyncDataCollector`,
227229
:class:`~torchrl.collectors.MultiSyncDataCollector`,
228230
:class:`~torchrl.collectors.MultiaSyncDataCollector`

torchrl/data/datasets/common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def preprocess(
7272
7373
Args and Keyword Args are forwarded to :meth:`~tensordict.TensorDictBase.map`.
7474
75-
The dataset can subsequently be deleted using :meth:`~.delete`.
75+
The dataset can subsequently be deleted using :meth:`delete`.
7676
7777
Keyword Args:
7878
dest (path or equivalent): a path to the location of the new dataset.

torchrl/data/datasets/openx.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class for more information on how to interact with non-tensor data
6666
sampling strategy.
6767
If the ``batch_size`` is ``None`` (default), iterating over the
6868
dataset will deliver trajectories one at a time *whereas* calling
69-
:meth:`~.sample` will *still* require a batch-size to be provided.
69+
:meth:`sample` will *still* require a batch-size to be provided.
7070
7171
Keyword Args:
7272
shuffle (bool, optional): if ``True``, trajectories are delivered in a
@@ -115,7 +115,7 @@ class for more information on how to interact with non-tensor data
115115
replacement (bool, optional): if ``False``, sampling will be done
116116
without replacement. Defaults to ``True`` for downloaded datasets,
117117
``False`` for streamed datasets.
118-
pad (bool, float or None): if ``True``, trajectories of insufficient length
118+
pad (bool, :obj:`float` or None): if ``True``, trajectories of insufficient length
119119
given the `slice_len` or `num_slices` arguments will be padded with
120120
0s. If another value is provided, it will be used for padding. If
121121
``False`` or ``None`` (default) any encounter with a trajectory of

torchrl/data/map/tdstorage.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def from_tensordict_pair(
193193
in the storage. Defaults to ``None`` (all keys are registered).
194194
max_size (int, optional): the maximum number of elements in the storage. Ignored if the
195195
``storage_constructor`` is passed. Defaults to ``1000``.
196-
storage_constructor (type, optional): a type of tensor storage.
196+
storage_constructor (Type, optional): a type of tensor storage.
197197
Defaults to :class:`~tensordict.nn.storage.LazyDynamicStorage`.
198198
Other options include :class:`~tensordict.nn.storage.FixedStorage`.
199199
hash_module (Callable, optional): a hash function to use in the :class:`~torchrl.data.map.QueryModule`.

0 commit comments

Comments
 (0)