Skip to content

[Feature] Add Choice spec #2713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/ecosystem/gym_env_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

if __name__ == "__main__":
avail_devices = ("cpu",)
if torch.cuda.device_count():
if torch.cuda.is_available():
avail_devices = avail_devices + ("cuda:0",)

for envname in [
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
"slurm_gpus_per_task": args.slurm_gpus_per_task,
}
device_str = "device" if num_workers <= 1 else "devices"
if torch.cuda.device_count():
if torch.cuda.is_available():
collector_kwargs = {device_str: "cuda:0", f"storing_{device_str}": "cuda:0"}
else:
collector_kwargs = {device_str: "cpu", "storing_{device_str}": "cpu"}
Expand Down
179 changes: 169 additions & 10 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
import torch
import torchrl.data.tensor_specs
from scipy.stats import chisquare
from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
from tensordict import (
LazyStackedTensorDict,
NonTensorData,
NonTensorStack,
TensorDict,
TensorDictBase,
)
from tensordict.utils import _unravel_key_to_tuple
from torchrl._utils import _make_ordinal_device

Expand All @@ -23,6 +29,7 @@
Bounded,
BoundedTensorSpec,
Categorical,
Choice,
Composite,
CompositeSpec,
ContinuousBox,
Expand Down Expand Up @@ -702,6 +709,63 @@ def test_change_batch_size(self, shape, is_complete, device, dtype):
assert ts["nested"].shape == (3,)


class TestChoiceSpec:
    """Sampling/membership and constructor validation for :class:`Choice`."""

    @pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
    def test_choice(self, input_type):
        # Build the option list plus one known member and one known non-member
        # for the requested option kind.
        if input_type == "spec":
            options = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
            member = torch.tensor(11.0)
            non_member = torch.tensor(9.0)
        elif input_type == "nontensor":
            options = [NonTensorData("a"), NonTensorData("b")]
            member = NonTensorData("b")
            non_member = NonTensorData("c")
        else:  # "nontensorstack"
            options = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
            member = NonTensorStack("a", "b", "c")
            non_member = NonTensorStack("a", "c", "b")

        torch.manual_seed(0)
        # Repeat to exercise several random draws from the choice set.
        for _ in range(10):
            spec = Choice(options)
            sample = spec.rand()
            # A random draw must always be contained in the spec ...
            assert spec.is_in(sample)
            # ... a known member is accepted and a non-member is rejected.
            assert spec.is_in(member)
            assert not spec.is_in(non_member)

    def test_errors(self):
        """Constructor rejects non-list input and heterogeneous option lists."""
        with pytest.raises(TypeError, match="must be a list"):
            Choice("abc")

        with pytest.raises(
            TypeError,
            match="must be either a TensorSpec, NonTensorData, or NonTensorStack",
        ):
            Choice(["abc"])

        with pytest.raises(TypeError, match="must be the same type"):
            Choice([Bounded(0, 1, (1,)), Categorical(10, (1,))])

        with pytest.raises(ValueError, match="must have the same shape"):
            Choice([Categorical(10, (1,)), Categorical(10, (2,))])

        with pytest.raises(ValueError, match="must have the same dtype"):
            Choice(
                [
                    Categorical(10, (2,), dtype=torch.long),
                    Categorical(10, (2,), dtype=torch.float),
                ]
            )

        # The device-mismatch branch needs a second device to exist.
        if torch.cuda.is_available():
            with pytest.raises(ValueError, match="must have the same device"):
                Choice(
                    [
                        Categorical(10, (2,), device="cpu"),
                        Categorical(10, (2,), device="cuda"),
                    ]
                )


@pytest.mark.parametrize("shape", [(), (2, 3)])
@pytest.mark.parametrize("device", get_default_devices())
def test_create_composite_nested(shape, device):
Expand Down Expand Up @@ -851,7 +915,7 @@ def test_equality_bounded(self):

ts_other = Bounded(minimum, maximum + 1, torch.Size((1,)), device, dtype)
assert ts != ts_other
if torch.cuda.device_count():
if torch.cuda.is_available():
ts_other = Bounded(minimum, maximum, torch.Size((1,)), "cuda:0", dtype)
assert ts != ts_other

Expand Down Expand Up @@ -879,7 +943,7 @@ def test_equality_onehot(self):
)
assert ts != ts_other

if torch.cuda.device_count():
if torch.cuda.is_available():
ts_other = OneHot(
n=n, device="cuda:0", dtype=dtype, use_register=use_register
)
Expand Down Expand Up @@ -909,7 +973,7 @@ def test_equality_unbounded(self):
ts_same = Unbounded(device=device, dtype=dtype)
assert ts == ts_same

if torch.cuda.device_count():
if torch.cuda.is_available():
ts_other = Unbounded(device="cuda:0", dtype=dtype)
assert ts != ts_other

Expand Down Expand Up @@ -942,7 +1006,7 @@ def test_equality_ndbounded(self):
ts_other = Bounded(low=minimum, high=maximum + 1, device=device, dtype=dtype)
assert ts != ts_other

if torch.cuda.device_count():
if torch.cuda.is_available():
ts_other = Bounded(low=minimum, high=maximum, device="cuda:0", dtype=dtype)
assert ts != ts_other

Expand Down Expand Up @@ -970,7 +1034,7 @@ def test_equality_discrete(self):
ts_other = Categorical(n=n + 1, shape=shape, device=device, dtype=dtype)
assert ts != ts_other

if torch.cuda.device_count():
if torch.cuda.is_available():
ts_other = Categorical(n=n, shape=shape, device="cuda:0", dtype=dtype)
assert ts != ts_other

Expand Down Expand Up @@ -1008,7 +1072,7 @@ def test_equality_ndunbounded(self, shape):
ts_other = Unbounded(shape=other_shape, device=device, dtype=dtype)
assert ts != ts_other

if torch.cuda.device_count():
if torch.cuda.is_available():
ts_other = Unbounded(shape=shape, device="cuda:0", dtype=dtype)
assert ts != ts_other

Expand All @@ -1034,7 +1098,7 @@ def test_equality_binary(self):
ts_other = Binary(n=n + 5, device=device, dtype=dtype)
assert ts != ts_other

if torch.cuda.device_count():
if torch.cuda.is_available():
ts_other = Binary(n=n, device="cuda:0", dtype=dtype)
assert ts != ts_other

Expand Down Expand Up @@ -1068,7 +1132,7 @@ def test_equality_multi_onehot(self, nvec):
ts_other = MultiOneHot(nvec=other_nvec, device=device, dtype=dtype)
assert ts != ts_other

if torch.cuda.device_count():
if torch.cuda.is_available():
ts_other = MultiOneHot(nvec=nvec, device="cuda:0", dtype=dtype)
assert ts != ts_other

Expand Down Expand Up @@ -1102,7 +1166,7 @@ def test_equality_multi_discrete(self, nvec):
ts_other = MultiCategorical(nvec=other_nvec, device=device, dtype=dtype)
assert ts != ts_other

if torch.cuda.device_count():
if torch.cuda.is_available():
ts_other = MultiCategorical(nvec=nvec, device="cuda:0", dtype=dtype)
assert ts != ts_other

Expand Down Expand Up @@ -1498,6 +1562,19 @@ def test_non_tensor(self):
)
assert spec.expand(2, 3, 4).example_data == "example_data"

@pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
def test_choice(self, input_type):
    """``Choice.expand`` returns a spec carrying the requested batch shape."""
    option_builders = {
        "spec": lambda: [Bounded(0, 2.5, ()), Bounded(10, 12, ())],
        "nontensor": lambda: [NonTensorData("a"), NonTensorData("b")],
        "nontensorstack": lambda: [
            NonTensorStack("a", "b", "c"),
            NonTensorStack("d", "e", "f"),
        ],
    }
    spec = Choice(option_builders[input_type]())
    expanded = spec.expand([3])
    assert expanded.shape == torch.Size([3])

@pytest.mark.parametrize("shape1", [None, (), (5,)])
@pytest.mark.parametrize("shape2", [(), (10,)])
def test_onehot(self, shape1, shape2):
Expand Down Expand Up @@ -1701,6 +1778,19 @@ def test_non_tensor(self):
assert spec.clone() is not spec
assert spec.clone().example_data == "example_data"

@pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
def test_choice(self, input_type):
    """Cloning a ``Choice`` yields an equal object that is a distinct instance."""
    if input_type == "spec":
        options = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
    elif input_type == "nontensor":
        options = [NonTensorData("a"), NonTensorData("b")]
    else:  # "nontensorstack"
        options = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]

    original = Choice(options)
    duplicate = original.clone()
    # Deep equality without identity: the clone is a separate object.
    assert duplicate == original
    assert duplicate is not original

@pytest.mark.parametrize("shape1", [None, (), (5,)])
def test_onehot(
self,
Expand Down Expand Up @@ -1786,6 +1876,31 @@ def test_non_tensor(self):
with pytest.raises(RuntimeError, match="Cannot enumerate a NonTensorSpec."):
spec.cardinality()

@pytest.mark.parametrize(
    "input_type",
    ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
)
def test_choice(self, input_type):
    """``Choice.cardinality`` reflects the total number of representable values."""
    cases = {
        # Continuous options -> uncountably many values.
        "bounded_spec": (
            [Bounded(0, 2.5, ()), Bounded(10, 12, ())],
            float("inf"),
        ),
        # Discrete options sum their cardinalities: 10 + 20.
        "categorical_spec": ([Categorical(10), Categorical(20)], 30),
        # One value per NonTensorData option.
        "nontensor": (
            [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")],
            3,
        ),
        # One value per NonTensorStack option.
        "nontensorstack": (
            [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")],
            2,
        ),
    }
    options, expected = cases[input_type]
    assert Choice(options).cardinality() == expected

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_onehot(
self,
Expand Down Expand Up @@ -2096,6 +2211,23 @@ def test_non_tensor(self, device):
assert spec.to(device).device == device
assert spec.to(device).example_data == "example_data"

@pytest.mark.parametrize(
    "input_type",
    ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
)
def test_choice(self, input_type, device):
    """Moving a ``Choice`` with ``.to(device)`` updates its reported device."""
    if input_type == "bounded_spec":
        options = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
    elif input_type == "categorical_spec":
        options = [Categorical(10), Categorical(20)]
    elif input_type == "nontensor":
        options = [NonTensorData("a"), NonTensorData("b")]
    else:  # "nontensorstack"
        options = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]

    assert Choice(options).to(device).device == device

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_onehot(self, shape1, device):
if shape1 is None:
Expand Down Expand Up @@ -2363,6 +2495,33 @@ def test_stack_non_tensor(self, shape, stack_dim):
assert new_spec.device == torch.device("cpu")
assert new_spec.example_data == "example_data"

@pytest.mark.parametrize(
    "input_type",
    ["bounded_spec", "categorical_spec", "nontensor"],
)
def test_stack_choice(self, input_type, shape, stack_dim):
    """``torch.stack`` on two identical ``Choice`` specs returns a ``Choice``
    whose shape matches stacking two tensors of the option shape."""
    if input_type == "bounded_spec":
        options = [Bounded(0, 2.5, shape), Bounded(10, 12, shape)]
    elif input_type == "categorical_spec":
        options = [Categorical(10, shape), Categorical(20, shape)]
    else:  # "nontensor"
        if len(shape) == 0:
            options = [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
        else:
            # Broadcast a scalar NonTensorStack to `shape` via expand + squeeze.
            options = [
                NonTensorStack("a").expand(shape + (1,)).squeeze(-1),
                NonTensorStack("d").expand(shape + (1,)).squeeze(-1),
            ]

    stacked = torch.stack([Choice(options), Choice(options)], stack_dim)
    assert isinstance(stacked, Choice)
    # Reference shape: stacking two dummy tensors of the same shape.
    reference = torch.stack([torch.empty(shape), torch.empty(shape)], stack_dim)
    assert stacked.shape == reference.shape

def test_stack_onehot(self, shape, stack_dim):
n = 5
shape = (*shape, 5)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/collectors/distributed/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def _init_master_rpc(
):
"""Init RPC on main node."""
options = rpc.TensorPipeRpcBackendOptions(**self.tensorpipe_options)
if torch.cuda.device_count():
if torch.cuda.is_available():
if self.visible_devices:
for i in range(self.num_workers):
rank = i + 1
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
BoundedContinuous,
BoundedTensorSpec,
Categorical,
Choice,
Composite,
CompositeSpec,
DEVICE_TYPING,
Expand Down
Loading
Loading