@@ -1765,8 +1765,9 @@ class _MultiDataCollector(DataCollectorBase):
1765
1765
.. warning:: `policy_factory` is currently not compatible with multiprocessed data
1766
1766
collectors.
1767
1767
1768
- frames_per_batch (int): A keyword-only argument representing the
1769
- total number of elements in a batch.
1768
+ frames_per_batch (int, Sequence[int]): A keyword-only argument representing the
1769
+ total number of elements in a batch. If a sequence is provided, represents the number of elements in a
1770
+ batch per worker. Total number of elements in a batch is then the sum over the sequence.
1770
1771
total_frames (int, optional): A keyword-only argument representing the
1771
1772
total number of frames returned by the collector
1772
1773
during its lifespan. If the ``total_frames`` is not divisible by
@@ -1923,7 +1924,7 @@ def __init__(
1923
1924
policy_factory : Callable [[], Callable ]
1924
1925
| list [Callable [[], Callable ]]
1925
1926
| None = None ,
1926
- frames_per_batch : int ,
1927
+ frames_per_batch : int | Sequence [ int ] ,
1927
1928
total_frames : int | None = - 1 ,
1928
1929
device : DEVICE_TYPING | Sequence [DEVICE_TYPING ] | None = None ,
1929
1930
storing_device : DEVICE_TYPING | Sequence [DEVICE_TYPING ] | None = None ,
@@ -1959,6 +1960,22 @@ def __init__(
1959
1960
self .closed = True
1960
1961
self .num_workers = len (create_env_fn )
1961
1962
1963
+ if (
1964
+ isinstance (frames_per_batch , Sequence )
1965
+ and len (frames_per_batch ) != self .num_workers
1966
+ ):
1967
+ raise ValueError (
1968
+ "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker."
1969
+ f"Got { len (frames_per_batch )} values for { self .num_workers } workers."
1970
+ )
1971
+
1972
+ self ._frames_per_batch = frames_per_batch
1973
+ total_frames_per_batch = (
1974
+ sum (frames_per_batch )
1975
+ if isinstance (frames_per_batch , Sequence )
1976
+ else frames_per_batch
1977
+ )
1978
+
1962
1979
self .set_truncated = set_truncated
1963
1980
self .num_sub_threads = num_sub_threads
1964
1981
self .num_threads = num_threads
@@ -2076,11 +2093,11 @@ def __init__(
2076
2093
if total_frames is None or total_frames < 0 :
2077
2094
total_frames = float ("inf" )
2078
2095
else :
2079
- remainder = total_frames % frames_per_batch
2096
+ remainder = total_frames % total_frames_per_batch
2080
2097
if remainder != 0 and RL_WARNINGS :
2081
2098
warnings .warn (
2082
- f"total_frames ({ total_frames } ) is not exactly divisible by frames_per_batch ({ frames_per_batch } ). "
2083
- f"This means { frames_per_batch - remainder } additional frames will be collected. "
2099
+ f"total_frames ({ total_frames } ) is not exactly divisible by frames_per_batch ({ total_frames_per_batch } ). "
2100
+ f"This means { total_frames_per_batch - remainder } additional frames will be collected. "
2084
2101
"To silence this message, set the environment variable RL_WARNINGS to False."
2085
2102
)
2086
2103
self .total_frames = (
@@ -2091,7 +2108,8 @@ def __init__(
2091
2108
self .max_frames_per_traj = (
2092
2109
int (max_frames_per_traj ) if max_frames_per_traj is not None else 0
2093
2110
)
2094
- self .requested_frames_per_batch = int (frames_per_batch )
2111
+
2112
+ self .requested_frames_per_batch = total_frames_per_batch
2095
2113
self .reset_when_done = reset_when_done
2096
2114
if split_trajs is None :
2097
2115
split_trajs = False
@@ -2221,8 +2239,7 @@ def _get_devices(
2221
2239
)
2222
2240
return storing_device , policy_device , env_device
2223
2241
2224
- @property
2225
- def frames_per_batch_worker (self ):
2242
+ def frames_per_batch_worker (self , worker_idx : int | None = None ) -> int :
2226
2243
raise NotImplementedError
2227
2244
2228
2245
@property
@@ -2281,7 +2298,7 @@ def _run_processes(self) -> None:
2281
2298
"create_env_kwargs" : env_fun_kwargs ,
2282
2299
"policy" : policy ,
2283
2300
"max_frames_per_traj" : self .max_frames_per_traj ,
2284
- "frames_per_batch" : self .frames_per_batch_worker ,
2301
+ "frames_per_batch" : self .frames_per_batch_worker ( worker_idx = i ) ,
2285
2302
"reset_at_each_iter" : self .reset_at_each_iter ,
2286
2303
"policy_device" : policy_device ,
2287
2304
"storing_device" : storing_device ,
@@ -2773,8 +2790,9 @@ def update_policy_weights_(
2773
2790
policy_or_weights = policy_or_weights , worker_ids = worker_ids , ** kwargs
2774
2791
)
2775
2792
2776
- @property
2777
- def frames_per_batch_worker (self ):
2793
+ def frames_per_batch_worker (self , worker_idx : int | None ) -> int :
2794
+ if worker_idx is not None and isinstance (self ._frames_per_batch , Sequence ):
2795
+ return self ._frames_per_batch [worker_idx ]
2778
2796
if self .requested_frames_per_batch % self .num_workers != 0 and RL_WARNINGS :
2779
2797
warnings .warn (
2780
2798
f"frames_per_batch { self .requested_frames_per_batch } is not exactly divisible by the number of collector workers { self .num_workers } ,"
@@ -2855,9 +2873,9 @@ def iterator(self) -> Iterator[TensorDictBase]:
2855
2873
use_buffers = self ._use_buffers
2856
2874
if self .replay_buffer is not None :
2857
2875
idx = new_data
2858
- workers_frames [idx ] = (
2859
- workers_frames [ idx ] + self . frames_per_batch_worker
2860
- )
2876
+ workers_frames [idx ] = workers_frames [
2877
+ idx
2878
+ ] + self . frames_per_batch_worker ( worker_idx = idx )
2861
2879
continue
2862
2880
elif j == 0 or not use_buffers :
2863
2881
try :
@@ -2903,7 +2921,12 @@ def iterator(self) -> Iterator[TensorDictBase]:
2903
2921
2904
2922
if self .replay_buffer is not None :
2905
2923
yield
2906
- self ._frames += self .frames_per_batch_worker * self .num_workers
2924
+ self ._frames += sum (
2925
+ [
2926
+ self .frames_per_batch_worker (worker_idx )
2927
+ for worker_idx in range (self .num_workers )
2928
+ ]
2929
+ )
2907
2930
continue
2908
2931
2909
2932
# we have to correct the traj_ids to make sure that they don't overlap
@@ -3156,8 +3179,7 @@ def update_policy_weights_(
3156
3179
policy_or_weights = policy_or_weights , worker_ids = worker_ids , ** kwargs
3157
3180
)
3158
3181
3159
- @property
3160
- def frames_per_batch_worker (self ):
3182
+ def frames_per_batch_worker (self , worker_idx : int | None = None ) -> int :
3161
3183
return self .requested_frames_per_batch
3162
3184
3163
3185
def _get_from_queue (self , timeout = None ) -> tuple [int , int , TensorDictBase ]:
@@ -3221,7 +3243,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
3221
3243
if self .split_trajs :
3222
3244
out = split_trajectories (out , prefix = "collector" )
3223
3245
else :
3224
- worker_frames = self .frames_per_batch_worker
3246
+ worker_frames = self .frames_per_batch_worker ()
3225
3247
self ._frames += worker_frames
3226
3248
workers_frames [idx ] = workers_frames [idx ] + worker_frames
3227
3249
if self .postprocs :
0 commit comments