@@ -339,10 +339,12 @@ class SyncDataCollector(DataCollectorBase):
339
339
instances) it will be wrapped in a `nn.Module` first.
340
340
Then, the collector will try to assess if these
341
341
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
342
+
342
343
- If the policy forward signature matches any of ``forward(self, tensordict)``,
343
344
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
344
345
any typing with a single argument typed as a subclass of ``TensorDictBase``)
345
346
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
347
+
346
348
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
347
349
348
350
Keyword Args:
@@ -1462,6 +1464,7 @@ class _MultiDataCollector(DataCollectorBase):
1462
1464
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
1463
1465
any typing with a single argument typed as a subclass of ``TensorDictBase``)
1464
1466
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
1467
+
1465
1468
- In all other cases an attempt to wrap it will be undergone as such:
1466
1469
``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
1467
1470
@@ -1548,7 +1551,7 @@ class _MultiDataCollector(DataCollectorBase):
1548
1551
reset_when_done (bool, optional): if ``True`` (default), an environment
1549
1552
that return a ``True`` value in its ``"done"`` or ``"truncated"``
1550
1553
entry will be reset at the corresponding indices.
1551
- update_at_each_batch (boolm optional): if ``True``, :meth:`~. update_policy_weight_()`
1554
+ update_at_each_batch (boolm optional): if ``True``, :meth:`update_policy_weight_()`
1552
1555
will be called before (sync) or after (async) each data collection.
1553
1556
Defaults to ``False``.
1554
1557
preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
@@ -2774,10 +2777,12 @@ class aSyncDataCollector(MultiaSyncDataCollector):
2774
2777
instances) it will be wrapped in a `nn.Module` first.
2775
2778
Then, the collector will try to assess if these
2776
2779
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
2780
+
2777
2781
- If the policy forward signature matches any of ``forward(self, tensordict)``,
2778
2782
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
2779
2783
any typing with a single argument typed as a subclass of ``TensorDictBase``)
2780
2784
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
2785
+
2781
2786
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
2782
2787
2783
2788
Keyword Args:
@@ -2863,7 +2868,7 @@ class aSyncDataCollector(MultiaSyncDataCollector):
2863
2868
reset_when_done (bool, optional): if ``True`` (default), an environment
2864
2869
that return a ``True`` value in its ``"done"`` or ``"truncated"``
2865
2870
entry will be reset at the corresponding indices.
2866
- update_at_each_batch (boolm optional): if ``True``, :meth:`~. update_policy_weight_()`
2871
+ update_at_each_batch (boolm optional): if ``True``, :meth:`update_policy_weight_()`
2867
2872
will be called before (sync) or after (async) each data collection.
2868
2873
Defaults to ``False``.
2869
2874
preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
0 commit comments