
Commit ff2c42f

[Feature] ConditionalPolicySwitch transform

ghstack-source-id: eee6f67835ff8fc221e3096490c0577f240bc541
Pull Request resolved: #2711
1 parent 5b459ba

6 files changed: +229 −18 lines changed
docs/source/reference/envs.rst (+1)

@@ -816,6 +816,7 @@ to be able to create this other composition:
     CenterCrop
     ClipTransform
     Compose
+    ConditionalPolicySwitch
     Crop
     DTypeCastTransform
     DeviceCastTransform

examples/agents/ppo-chess.py (+54 −18)

@@ -5,20 +5,24 @@
 import tensordict.nn
 import torch
 import tqdm
-from tensordict.nn import TensorDictSequential as TDSeq, TensorDictModule as TDMod, \
-    ProbabilisticTensorDictModule as TDProb, ProbabilisticTensorDictSequential as TDProbSeq
+from tensordict.nn import (
+    ProbabilisticTensorDictModule as TDProb,
+    ProbabilisticTensorDictSequential as TDProbSeq,
+    TensorDictModule as TDMod,
+    TensorDictSequential as TDSeq,
+)
 from torch import nn
 from torch.nn.utils import clip_grad_norm_
 from torch.optim import Adam

 from torchrl.collectors import SyncDataCollector
+from torchrl.data import LazyTensorStorage, ReplayBuffer, SamplerWithoutReplacement

 from torchrl.envs import ChessEnv, Tokenizer
 from torchrl.modules import MLP
 from torchrl.modules.distributions import MaskedCategorical
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value import GAE
-from torchrl.data import ReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement

 tensordict.nn.set_composite_lp_aggregate(False)

@@ -39,7 +43,9 @@
 embedding_moves = nn.Embedding(num_embeddings=n + 1, embedding_dim=64)

 # Embedding for the fen
-embedding_fen = nn.Embedding(num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64)
+embedding_fen = nn.Embedding(
+    num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64
+)

 backbone = MLP(out_features=512, num_cells=[512] * 8, activation_class=nn.ReLU)

@@ -49,20 +55,30 @@
 critic_head = nn.Linear(512, 1)
 critic_head.bias.data.fill_(0)

-prob = TDProb(in_keys=["logits", "mask"], out_keys=["action"], distribution_class=MaskedCategorical, return_log_prob=True)
+prob = TDProb(
+    in_keys=["logits", "mask"],
+    out_keys=["action"],
+    distribution_class=MaskedCategorical,
+    return_log_prob=True,
+)
+

 def make_mask(idx):
     mask = idx.new_zeros((*idx.shape[:-1], n + 1), dtype=torch.bool)
     return mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))[..., :-1]

+
 actor = TDProbSeq(
-    TDMod(
-        make_mask,
-        in_keys=["legal_moves"], out_keys=["mask"]),
+    TDMod(make_mask, in_keys=["legal_moves"], out_keys=["mask"]),
     TDMod(embedding_moves, in_keys=["legal_moves"], out_keys=["embedded_legal_moves"]),
     TDMod(embedding_fen, in_keys=["fen_tokenized"], out_keys=["embedded_fen"]),
-    TDMod(lambda *args: torch.cat([arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1), in_keys=["embedded_legal_moves", "embedded_fen"],
-        out_keys=["features"]),
+    TDMod(
+        lambda *args: torch.cat(
+            [arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1
+        ),
+        in_keys=["embedded_legal_moves", "embedded_fen"],
+        out_keys=["features"],
+    ),
     TDMod(backbone, in_keys=["features"], out_keys=["hidden"]),
     TDMod(actor_head, in_keys=["hidden"], out_keys=["logits"]),
     prob,
@@ -78,7 +94,9 @@ def make_mask(idx):

 optim = Adam(loss.parameters())

-gae = GAE(value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True)
+gae = GAE(
+    value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True
+)

 # Create a data collector
 collector = SyncDataCollector(
@@ -88,12 +106,20 @@ def make_mask(idx):
     total_frames=1_000_000,
 )

-replay_buffer0 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
-replay_buffer1 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
+replay_buffer0 = ReplayBuffer(
+    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
+    batch_size=batch_size,
+    sampler=SamplerWithoutReplacement(),
+)
+replay_buffer1 = ReplayBuffer(
+    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
+    batch_size=batch_size,
+    sampler=SamplerWithoutReplacement(),
+)

 for data in tqdm.tqdm(collector):
     data = data.filter_non_tensor_data()
-    print('data', data[0::2])
+    print("data", data[0::2])
     for i in range(num_epochs):
         replay_buffer0.empty()
         replay_buffer1.empty()
@@ -103,14 +129,24 @@ def make_mask(idx):
         # player 1
         data1 = gae(data[1::2])
         if i == 0:
-            print('win rate for 0', data0["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
-            print('win rate for 1', data1["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
+            print(
+                "win rate for 0",
+                data0["next", "reward"].sum()
+                / data["next", "done"].sum().clamp_min(1e-6),
+            )
+            print(
+                "win rate for 1",
+                data1["next", "reward"].sum()
+                / data["next", "done"].sum().clamp_min(1e-6),
+            )

         replay_buffer0.extend(data0)
         replay_buffer1.extend(data1)

-        n_iter = collector.frames_per_batch//(2 * batch_size)
-        for (d0, d1) in tqdm.tqdm(zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter):
+        n_iter = collector.frames_per_batch // (2 * batch_size)
+        for (d0, d1) in tqdm.tqdm(
+            zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter
+        ):
             loss_vals = (loss(d0) + loss(d1)) / 2
             loss_vals.sum(reduce=True).backward()
             gn = clip_grad_norm_(loss.parameters(), 100.0)

test/test_transforms.py (+122)

@@ -106,6 +106,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalPolicySwitch,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,
@@ -13192,6 +13193,127 @@ def test_composite_reward_spec(self) -> None:
         assert transform.transform_reward_spec(reward_spec) == expected_reward_spec


+class TestConditionalPolicySwitch(TransformBase):
+    def test_single_trans_env_check(self):
+        base_env = CountingEnv(max_steps=15)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
+        # Player 0
+        policy_odd = lambda td: td.set("action", env.action_spec.zero())
+        policy_even = lambda td: td.set("action", env.action_spec.one())
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        env = base_env.append_transform(transforms)
+        r = env.rollout(1000, policy_odd, break_when_all_done=True)
+        assert r.shape[0] == 15
+        assert (r["action"] == 0).all()
+        assert (r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)).all()
+        assert r["next", "done"].any()
+
+        # Player 1
+        condition = lambda td: ((td.get("step_count") % 2) == 1).all()
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_odd),
+        )
+        env = base_env.append_transform(transforms)
+        r = env.rollout(1000, policy_even, break_when_all_done=True)
+        assert r.shape[0] == 16
+        assert (r["action"] == 1).all()
+        assert (r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)).all()
+        assert r["next", "done"].any()
+
+    def test_trans_serial_env_check(self):
+        def make_env(max_count):
+            def make():
+                base_env = CountingEnv(max_steps=max_count)
+                transforms = Compose(
+                    StepCounter(),
+                    ConditionalPolicySwitch(condition=condition, policy=policy_even),
+                )
+                return base_env.append_transform(transforms)
+
+            return make
+
+        base_env = SerialEnv(
+            3, [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)]
+        )
+        condition = lambda td: ((td.get("step_count") % 2) == 0)
+        policy_odd = lambda td, base_env=base_env: td.set("action", base_env.action_spec.zero())
+        policy_even = lambda td, base_env=base_env: td.set("action", base_env.action_spec.one())
+        env = base_env.append_transform(
+            Compose(
+                StepCounter(),
+                ConditionalPolicySwitch(condition=condition, policy=policy_even),
+            )
+        )
+        r = env.rollout(100, break_when_all_done=False)
+        print(r["step_count"].squeeze())
+
+    def test_trans_parallel_env_check(self):
+        """tests that a transformed parallel env (TransformedEnv(ParallelEnv(N, lambda: env()), transform)) passes the check_env_specs test."""
+        raise NotImplementedError
+
+    def test_serial_trans_env_check(self):
+        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
+        # Player 0
+        policy_odd = lambda td: td.set("action", env.action_spec.zero())
+        policy_even = lambda td: td.set("action", env.action_spec.one())
+
+        def make_env(max_count):
+            def make():
+                base_env = CountingEnv(max_steps=max_count)
+                transforms = Compose(
+                    StepCounter(),
+                    ConditionalPolicySwitch(condition=condition, policy=policy_even),
+                )
+                return base_env.append_transform(transforms)
+
+            return make
+
+        env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
+        r = env.rollout(100, break_when_all_done=False)
+        print(r["step_count"].squeeze())
+
+    def test_parallel_trans_env_check(self):
+        """tests that a parallel transformed env (ParallelEnv(N, lambda: TransformedEnv(env, transform))) passes the check_env_specs test."""
+        raise NotImplementedError
+
+    def test_transform_no_env(self):
+        """tests the transform on dummy data, without an env."""
+        raise NotImplementedError
+
+    def test_transform_compose(self):
+        """tests the transform on dummy data, without an env but inside a Compose."""
+        raise NotImplementedError
+
+    def test_transform_env(self):
+        """tests the transform on a real env.
+
+        If possible, do not use a mock env, as bugs may go unnoticed if the dynamic is too
+        simplistic. A call to reset() and step() should be tested independently, i.e.
+        a check that reset produces the desired output and that step() does too.
+        """
+        raise NotImplementedError
+
+    def test_transform_model(self):
+        """tests the transform before an nn.Module that reads the output."""
+        raise NotImplementedError
+
+    def test_transform_rb(self):
+        """tests the transform when used with a replay buffer.
+
+        If your transform is not supposed to work with a replay buffer, test that
+        an error will be raised when called or appended to a RB.
+        """
+        raise NotImplementedError
+
+    def test_transform_inverse(self):
+        """tests the inverse transform. If not applicable, simply skip this test.
+
+        If your transform is not supposed to work offline, test that
+        an error will be raised when called in a nn.Module.
+        """
+        raise NotImplementedError
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/__init__.py (+1)

@@ -55,6 +55,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalPolicySwitch,
    Crop,
     DeviceCastTransform,
     DiscreteActionProjection,

torchrl/envs/transforms/__init__.py (+1)

@@ -20,6 +20,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalPolicySwitch,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,

torchrl/envs/transforms/transforms.py (+50)

@@ -9974,3 +9974,53 @@ def _apply_transform(self, reward: Tensor) -> TensorDictBase:
         )

         return (self.weights * reward).sum(dim=-1)
+
+
+class ConditionalPolicySwitch(Transform):
+    def __init__(
+        self,
+        policy: Callable[[TensorDictBase], TensorDictBase],
+        condition: Callable[[TensorDictBase], bool],
+    ):
+        super().__init__([], [])
+        self.__dict__["policy"] = policy
+        self.condition = condition
+
+    def _step(
+        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+    ) -> TensorDictBase:
+        cond = self.condition(next_tensordict)
+        if not isinstance(cond, (bool, torch.Tensor)):
+            raise RuntimeError(
+                "Calling the condition function should return a boolean or a tensor."
+            )
+        if isinstance(cond, torch.Tensor) and cond.shape not in ((1,), (), tensordict.shape):
+            raise RuntimeError(
+                "Tensor outputs must have the shape of the tensordict, or contain a single element."
+            )
+        if cond.any():
+            parent: TransformedEnv = self.parent
+            done = next_tensordict.get("done")
+            next_td_save = None
+            if done.any():
+                if next_tensordict.numel() == 1 or done.all():
+                    return next_tensordict
+                if parent.base_env.batch_locked:
+                    raise RuntimeError(
+                        "Cannot run partial steps in a batch-locked environment"
+                    )
+                done = done.view(next_tensordict.shape)
+                next_td_save = next_tensordict[done]
+                next_tensordict = next_tensordict[~done]
+                tensordict = tensordict[~done]
+            td = self.policy(
+                parent.step_mdp(tensordict.copy().set("next", next_tensordict))
+            )
+            next_tensordict = parent._step(td)
+            if next_td_save is not None:
+                return torch.where(done, next_td_save, next_tensordict)
+            return next_tensordict
+        return next_tensordict
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        if self.condition(tensordict_reset):
+            parent: TransformedEnv = self.parent
+            td = self.policy(tensordict_reset)
+            return parent._step(td).exclude(*parent.reward_keys)
+        return tensordict_reset
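To make the control flow of `_step` above easier to follow, here is a simplified restatement for a single, unbatched environment. It is a sketch only: it drops the type checks and the partial-done branch that the real code handles when the base env is not batch-locked.

# Simplified sketch of the _step logic above (single env, no partial-done handling).
def _step(self, tensordict, next_tensordict):
    if self.condition(next_tensordict):
        if next_tensordict.get("done").any():
            # Episode over: nothing left for the switched-in policy to play.
            return next_tensordict
        # Build the next root state, let the switched-in policy choose an action,
        # then advance the parent env once more before handing control back.
        td = self.policy(
            self.parent.step_mdp(tensordict.copy().set("next", next_tensordict))
        )
        next_tensordict = self.parent._step(td)
    return next_tensordict

The net effect, as exercised in the tests, is that the caller's policy only ever sees turns where the condition does not hold, while `_reset` lets the switched-in policy play immediately after a reset when the condition already holds, excluding the reward keys so the reset output keeps its usual spec.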
