Skip to content

Commit c275988

Browse files
authored
[Feature] Enabling worker level control on frames_per_batch (#3020)
1 parent 6e38458 commit c275988

File tree

3 files changed

+129
-20
lines changed

3 files changed

+129
-20
lines changed

test/test_collector.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,6 +1553,94 @@ def create_env():
15531553
collector.shutdown()
15541554
del collector
15551555

1556+
@pytest.mark.parametrize("num_env", [1, 2])
1557+
@pytest.mark.parametrize("env_name", ["vec"])
1558+
@pytest.mark.parametrize("frames_per_batch_worker", [[10, 10], [15, 5]])
1559+
def test_collector_frames_per_batch_worker(
1560+
self,
1561+
num_env,
1562+
env_name,
1563+
frames_per_batch_worker,
1564+
seed=100,
1565+
num_workers=2,
1566+
):
1567+
"""Tests that there are 'sum(frames_per_batch_worker)' frames in each batch of a collection."""
1568+
if num_env == 1:
1569+
1570+
def env_fn():
1571+
env = make_make_env(env_name)()
1572+
return env
1573+
1574+
else:
1575+
1576+
def env_fn():
1577+
# 1226: For efficiency, we don't use Parallel but Serial
1578+
# env = ParallelEnv(
1579+
env = SerialEnv(
1580+
num_workers=num_env, create_env_fn=make_make_env(env_name)
1581+
)
1582+
return env
1583+
1584+
policy = make_policy(env_name)
1585+
1586+
torch.manual_seed(0)
1587+
np.random.seed(0)
1588+
1589+
frames_per_batch = sum(frames_per_batch_worker)
1590+
1591+
collector = MultiaSyncDataCollector(
1592+
create_env_fn=[env_fn for _ in range(num_workers)],
1593+
policy=policy,
1594+
frames_per_batch=frames_per_batch_worker,
1595+
max_frames_per_traj=1000,
1596+
total_frames=frames_per_batch * 100,
1597+
)
1598+
try:
1599+
collector.set_seed(seed)
1600+
for i, b in enumerate(collector):
1601+
assert b.numel() == -(-frames_per_batch // num_env) * num_env
1602+
if i == 5:
1603+
break
1604+
assert b.names[-1] == "time"
1605+
finally:
1606+
collector.shutdown()
1607+
1608+
collector = MultiSyncDataCollector(
1609+
create_env_fn=[env_fn for _ in range(num_workers)],
1610+
policy=policy,
1611+
frames_per_batch=frames_per_batch,
1612+
max_frames_per_traj=1000,
1613+
total_frames=frames_per_batch * 100,
1614+
cat_results="stack",
1615+
)
1616+
try:
1617+
collector.set_seed(seed)
1618+
for i, b in enumerate(collector):
1619+
assert (
1620+
b.numel()
1621+
== -(-frames_per_batch // num_env // num_workers)
1622+
* num_env
1623+
* num_workers
1624+
)
1625+
if i == 5:
1626+
break
1627+
assert b.names[-1] == "time"
1628+
finally:
1629+
collector.shutdown()
1630+
del collector
1631+
1632+
with pytest.raises(
1633+
ValueError,
1634+
match="If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker.",
1635+
):
1636+
collector = MultiSyncDataCollector(
1637+
create_env_fn=[env_fn for _ in range(num_workers)],
1638+
policy=policy,
1639+
frames_per_batch=frames_per_batch_worker[:-1],
1640+
max_frames_per_traj=1000,
1641+
total_frames=frames_per_batch * 100,
1642+
)
1643+
15561644

15571645
class TestCollectorDevices:
15581646
class DeviceLessEnv(EnvBase):

torchrl/collectors/collectors.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1765,8 +1765,9 @@ class _MultiDataCollector(DataCollectorBase):
17651765
.. warning:: `policy_factory` is currently not compatible with multiprocessed data
17661766
collectors.
17671767
1768-
frames_per_batch (int): A keyword-only argument representing the
1769-
total number of elements in a batch.
1768+
frames_per_batch (int, Sequence[int]): A keyword-only argument representing the
1769+
total number of elements in a batch. If a sequence is provided, represents the number of elements in a
1770+
batch per worker. Total number of elements in a batch is then the sum over the sequence.
17701771
total_frames (int, optional): A keyword-only argument representing the
17711772
total number of frames returned by the collector
17721773
during its lifespan. If the ``total_frames`` is not divisible by
@@ -1923,7 +1924,7 @@ def __init__(
19231924
policy_factory: Callable[[], Callable]
19241925
| list[Callable[[], Callable]]
19251926
| None = None,
1926-
frames_per_batch: int,
1927+
frames_per_batch: int | Sequence[int],
19271928
total_frames: int | None = -1,
19281929
device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
19291930
storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
@@ -1959,6 +1960,22 @@ def __init__(
19591960
self.closed = True
19601961
self.num_workers = len(create_env_fn)
19611962

1963+
if (
1964+
isinstance(frames_per_batch, Sequence)
1965+
and len(frames_per_batch) != self.num_workers
1966+
):
1967+
raise ValueError(
1968+
"If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker."
1969+
f"Got {len(frames_per_batch)} values for {self.num_workers} workers."
1970+
)
1971+
1972+
self._frames_per_batch = frames_per_batch
1973+
total_frames_per_batch = (
1974+
sum(frames_per_batch)
1975+
if isinstance(frames_per_batch, Sequence)
1976+
else frames_per_batch
1977+
)
1978+
19621979
self.set_truncated = set_truncated
19631980
self.num_sub_threads = num_sub_threads
19641981
self.num_threads = num_threads
@@ -2076,11 +2093,11 @@ def __init__(
20762093
if total_frames is None or total_frames < 0:
20772094
total_frames = float("inf")
20782095
else:
2079-
remainder = total_frames % frames_per_batch
2096+
remainder = total_frames % total_frames_per_batch
20802097
if remainder != 0 and RL_WARNINGS:
20812098
warnings.warn(
2082-
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
2083-
f"This means {frames_per_batch - remainder} additional frames will be collected. "
2099+
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). "
2100+
f"This means {total_frames_per_batch - remainder} additional frames will be collected. "
20842101
"To silence this message, set the environment variable RL_WARNINGS to False."
20852102
)
20862103
self.total_frames = (
@@ -2091,7 +2108,8 @@ def __init__(
20912108
self.max_frames_per_traj = (
20922109
int(max_frames_per_traj) if max_frames_per_traj is not None else 0
20932110
)
2094-
self.requested_frames_per_batch = int(frames_per_batch)
2111+
2112+
self.requested_frames_per_batch = total_frames_per_batch
20952113
self.reset_when_done = reset_when_done
20962114
if split_trajs is None:
20972115
split_trajs = False
@@ -2221,8 +2239,7 @@ def _get_devices(
22212239
)
22222240
return storing_device, policy_device, env_device
22232241

2224-
@property
2225-
def frames_per_batch_worker(self):
2242+
def frames_per_batch_worker(self, worker_idx: int | None = None) -> int:
22262243
raise NotImplementedError
22272244

22282245
@property
@@ -2281,7 +2298,7 @@ def _run_processes(self) -> None:
22812298
"create_env_kwargs": env_fun_kwargs,
22822299
"policy": policy,
22832300
"max_frames_per_traj": self.max_frames_per_traj,
2284-
"frames_per_batch": self.frames_per_batch_worker,
2301+
"frames_per_batch": self.frames_per_batch_worker(worker_idx=i),
22852302
"reset_at_each_iter": self.reset_at_each_iter,
22862303
"policy_device": policy_device,
22872304
"storing_device": storing_device,
@@ -2773,8 +2790,9 @@ def update_policy_weights_(
27732790
policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
27742791
)
27752792

2776-
@property
2777-
def frames_per_batch_worker(self):
2793+
def frames_per_batch_worker(self, worker_idx: int | None) -> int:
2794+
if worker_idx is not None and isinstance(self._frames_per_batch, Sequence):
2795+
return self._frames_per_batch[worker_idx]
27782796
if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS:
27792797
warnings.warn(
27802798
f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers},"
@@ -2855,9 +2873,9 @@ def iterator(self) -> Iterator[TensorDictBase]:
28552873
use_buffers = self._use_buffers
28562874
if self.replay_buffer is not None:
28572875
idx = new_data
2858-
workers_frames[idx] = (
2859-
workers_frames[idx] + self.frames_per_batch_worker
2860-
)
2876+
workers_frames[idx] = workers_frames[
2877+
idx
2878+
] + self.frames_per_batch_worker(worker_idx=idx)
28612879
continue
28622880
elif j == 0 or not use_buffers:
28632881
try:
@@ -2903,7 +2921,12 @@ def iterator(self) -> Iterator[TensorDictBase]:
29032921

29042922
if self.replay_buffer is not None:
29052923
yield
2906-
self._frames += self.frames_per_batch_worker * self.num_workers
2924+
self._frames += sum(
2925+
[
2926+
self.frames_per_batch_worker(worker_idx)
2927+
for worker_idx in range(self.num_workers)
2928+
]
2929+
)
29072930
continue
29082931

29092932
# we have to correct the traj_ids to make sure that they don't overlap
@@ -3156,8 +3179,7 @@ def update_policy_weights_(
31563179
policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
31573180
)
31583181

3159-
@property
3160-
def frames_per_batch_worker(self):
3182+
def frames_per_batch_worker(self, worker_idx: int | None = None) -> int:
    # NOTE(review): this variant ignores `worker_idx` and always returns the
    # total requested frames per batch — presumably because each worker of this
    # collector delivers a full batch on its own (confirm against the sibling
    # implementation, which divides the total across workers).
    return self.requested_frames_per_batch
31623184

31633185
def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]:
@@ -3221,7 +3243,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
32213243
if self.split_trajs:
32223244
out = split_trajectories(out, prefix="collector")
32233245
else:
3224-
worker_frames = self.frames_per_batch_worker
3246+
worker_frames = self.frames_per_batch_worker()
32253247
self._frames += worker_frames
32263248
workers_frames[idx] = workers_frames[idx] + worker_frames
32273249
if self.postprocs:

torchrl/data/llm/chat.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import torch
1313

14-
1514
from tensordict import lazy_stack, LazyStackedTensorDict, list_to_stack, TensorClass
1615
from tensordict.utils import _maybe_correct_neg_dim
1716

0 commit comments

Comments
 (0)