Skip to content

Commit 74528e1

Browse files
kurtamohler authored and vmoens committed
[Feature] Add Choice spec
ghstack-source-id: 6776395 Pull Request resolved: #2713
1 parent 20a19fe commit 74528e1

File tree

3 files changed

+283
-1
lines changed

3 files changed

+283
-1
lines changed

test/test_specs.py

+167-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@
1212
import torch
1313
import torchrl.data.tensor_specs
1414
from scipy.stats import chisquare
15-
from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
15+
from tensordict import (
16+
LazyStackedTensorDict,
17+
NonTensorData,
18+
NonTensorStack,
19+
TensorDict,
20+
TensorDictBase,
21+
)
1622
from tensordict.utils import _unravel_key_to_tuple
1723
from torchrl._utils import _make_ordinal_device
1824

@@ -23,6 +29,7 @@
2329
Bounded,
2430
BoundedTensorSpec,
2531
Categorical,
32+
Choice,
2633
Composite,
2734
CompositeSpec,
2835
ContinuousBox,
@@ -702,6 +709,62 @@ def test_change_batch_size(self, shape, is_complete, device, dtype):
702709
assert ts["nested"].shape == (3,)
703710

704711

712+
class TestChoiceSpec:
    """Sampling, membership and constructor-validation checks for ``Choice``."""

    @pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
    def test_choice(self, input_type):
        """A random sample and a known member pass ``is_in``; an outsider does not."""
        if input_type == "spec":
            choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
            example_in, example_out = torch.tensor(11.0), torch.tensor(9.0)
        elif input_type == "nontensor":
            choices = [NonTensorData("a"), NonTensorData("b")]
            example_in, example_out = NonTensorData("b"), NonTensorData("c")
        elif input_type == "nontensorstack":
            choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
            example_in = NonTensorStack("a", "b", "c")
            # Same elements as the first choice, but in a different order.
            example_out = NonTensorStack("a", "c", "b")

        spec = Choice(choices)
        sampled = spec.rand()
        assert spec.is_in(sampled)
        assert spec.is_in(example_in)
        assert not spec.is_in(example_out)

    def test_errors(self):
        """Each malformed ``choices`` argument raises the documented error."""
        # Not a list at all.
        with pytest.raises(TypeError, match="must be a list"):
            Choice("abc")

        # A list whose elements are of an unsupported type.
        unsupported_msg = "must be either a TensorSpec, NonTensorData, or NonTensorStack"
        with pytest.raises(TypeError, match=unsupported_msg):
            Choice(["abc"])

        # Elements of differing spec classes.
        mixed_types = [Bounded(0, 1, (1,)), Categorical(10, (1,))]
        with pytest.raises(TypeError, match="must be the same type"):
            Choice(mixed_types)

        # Elements of differing shapes.
        mixed_shapes = [Categorical(10, (1,)), Categorical(10, (2,))]
        with pytest.raises(ValueError, match="must have the same shape"):
            Choice(mixed_shapes)

        # Elements of differing dtypes.
        mixed_dtypes = [
            Categorical(10, (2,), dtype=torch.long),
            Categorical(10, (2,), dtype=torch.float),
        ]
        with pytest.raises(ValueError, match="must have the same dtype"):
            Choice(mixed_dtypes)

        # The device mismatch can only be built when a CUDA device is present.
        if torch.cuda.device_count():
            mixed_devices = [
                Categorical(10, (2,), device="cpu"),
                Categorical(10, (2,), device="cuda"),
            ]
            with pytest.raises(ValueError, match="must have the same device"):
                Choice(mixed_devices)
766+
767+
705768
@pytest.mark.parametrize("shape", [(), (2, 3)])
706769
@pytest.mark.parametrize("device", get_default_devices())
707770
def test_create_composite_nested(shape, device):
@@ -1498,6 +1561,27 @@ def test_non_tensor(self):
14981561
)
14991562
assert spec.expand(2, 3, 4).example_data == "example_data"
15001563

1564+
@pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
def test_choice(self, input_type):
    """Expanding a ``Choice`` spec to ``[3]`` yields a spec of shape ``(3,)``."""
    if input_type == "spec":
        choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
    elif input_type == "nontensor":
        choices = [NonTensorData("a"), NonTensorData("b")]
    elif input_type == "nontensorstack":
        choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]

    expanded = Choice(choices).expand([3])
    assert expanded.shape == torch.Size([3])
1584+
15011585
@pytest.mark.parametrize("shape1", [None, (), (5,)])
15021586
@pytest.mark.parametrize("shape2", [(), (10,)])
15031587
def test_onehot(self, shape1, shape2):
@@ -1701,6 +1785,19 @@ def test_non_tensor(self):
17011785
assert spec.clone() is not spec
17021786
assert spec.clone().example_data == "example_data"
17031787

1788+
@pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
def test_choice(self, input_type):
    """A cloned ``Choice`` spec compares equal to, but is not, the original."""
    if input_type == "spec":
        choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
    elif input_type == "nontensor":
        choices = [NonTensorData("a"), NonTensorData("b")]
    elif input_type == "nontensorstack":
        choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]

    original = Choice(choices)
    assert original.clone() == original
    assert original.clone() is not original
1800+
17041801
@pytest.mark.parametrize("shape1", [None, (), (5,)])
17051802
def test_onehot(
17061803
self,
@@ -1786,6 +1883,31 @@ def test_non_tensor(self):
17861883
with pytest.raises(RuntimeError, match="Cannot enumerate a NonTensorSpec."):
17871884
spec.cardinality()
17881885

1886+
@pytest.mark.parametrize(
    "input_type",
    ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
)
def test_choice(self, input_type):
    """``Choice.cardinality`` matches the expected count for each kind of choice."""
    # Each branch builds the choices and checks the cardinality right away.
    if input_type == "bounded_spec":
        spec = Choice([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
        assert spec.cardinality() == float("inf")
    elif input_type == "categorical_spec":
        spec = Choice([Categorical(10), Categorical(20)])
        assert spec.cardinality() == 30
    elif input_type == "nontensor":
        spec = Choice([NonTensorData("a"), NonTensorData("b"), NonTensorData("c")])
        assert spec.cardinality() == 3
    elif input_type == "nontensorstack":
        spec = Choice(
            [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
        )
        assert spec.cardinality() == 2
1910+
17891911
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
17901912
def test_onehot(
17911913
self,
@@ -2096,6 +2218,23 @@ def test_non_tensor(self, device):
20962218
assert spec.to(device).device == device
20972219
assert spec.to(device).example_data == "example_data"
20982220

2221+
@pytest.mark.parametrize(
    "input_type",
    ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
)
def test_choice(self, input_type, device):
    """Moving a ``Choice`` spec with ``to(device)`` updates its ``device``."""
    if input_type == "bounded_spec":
        choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
    elif input_type == "categorical_spec":
        choices = [Categorical(10), Categorical(20)]
    elif input_type == "nontensor":
        choices = [NonTensorData("a"), NonTensorData("b")]
    elif input_type == "nontensorstack":
        choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]

    moved = Choice(choices).to(device)
    assert moved.device == device
2237+
20992238
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
21002239
def test_onehot(self, shape1, device):
21012240
if shape1 is None:
@@ -2363,6 +2502,33 @@ def test_stack_non_tensor(self, shape, stack_dim):
23632502
assert new_spec.device == torch.device("cpu")
23642503
assert new_spec.example_data == "example_data"
23652504

2505+
@pytest.mark.parametrize(
    "input_type",
    ["bounded_spec", "categorical_spec", "nontensor"],
)
def test_stack_choice(self, input_type, shape, stack_dim):
    """Stacking two ``Choice`` specs gives a ``Choice`` with the stacked shape."""
    if input_type == "bounded_spec":
        choices = [Bounded(0, 2.5, shape), Bounded(10, 12, shape)]
    elif input_type == "categorical_spec":
        choices = [Categorical(10, shape), Categorical(20, shape)]
    elif input_type == "nontensor":
        if len(shape) != 0:
            # Build the batched choices by expanding a one-element stack
            # and dropping the trailing singleton dimension.
            choices = [
                NonTensorStack("a").expand(shape + (1,)).squeeze(-1),
                NonTensorStack("d").expand(shape + (1,)).squeeze(-1),
            ]
        else:
            choices = [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]

    first_spec = Choice(choices)
    second_spec = Choice(choices)
    stacked = torch.stack([first_spec, second_spec], stack_dim)
    assert isinstance(stacked, Choice)

    # Reference shape obtained by stacking plain tensors of the same shape.
    reference = torch.stack([torch.empty(shape), torch.empty(shape)], stack_dim)
    assert stacked.shape == reference.shape
2531+
23662532
def test_stack_onehot(self, shape, stack_dim):
23672533
n = 5
23682534
shape = (*shape, 5)

torchrl/data/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
BoundedContinuous,
7777
BoundedTensorSpec,
7878
Categorical,
79+
Choice,
7980
Composite,
8081
CompositeSpec,
8182
DEVICE_TYPING,

torchrl/data/tensor_specs.py

+115
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
is_tensor_collection,
3939
LazyStackedTensorDict,
4040
NonTensorData,
41+
NonTensorStack,
4142
TensorDict,
4243
TensorDictBase,
4344
unravel_key,
@@ -3732,6 +3733,120 @@ def clone(self) -> Categorical:
37323733
)
37333734

37343735

3736+
class Choice(TensorSpec):
    """A discrete choice spec for either tensor or non-tensor data.

    Args:
        choices (list[:class:`~TensorSpec`, :class:`~tensordict.NonTensorData`, :class:`~tensordict.NonTensorStack`]):
            List of specs or non-tensor data from which to choose during
            sampling. All elements must have the same type, shape, dtype, and
            device. The list must not be empty.

    Raises:
        TypeError: if ``choices`` is not a list, if its first element is not a
            ``TensorSpec``, ``NonTensorData`` or ``NonTensorStack``, or if the
            elements are not all of the same type.
        ValueError: if ``choices`` is empty, or if its elements differ in
            shape, dtype or device.

    Examples:
        >>> import torch
        >>> _ = torch.manual_seed(0)
        >>> from torchrl.data import Choice, Bounded
        >>> spec = Choice([
        ...     Bounded(0, 1, shape=(1,)),
        ...     Bounded(10, 11, shape=(1,))])
        >>> spec.rand()
        tensor([0.7682])
        >>> spec.rand()
        tensor([10.1320])
        >>> from tensordict import NonTensorData
        >>> _ = torch.manual_seed(0)
        >>> spec = Choice([NonTensorData(s) for s in ["a", "b", "c", "d"]])
        >>> spec.rand().data
        'a'
        >>> spec.rand().data
        'd'
    """

    def __init__(
        self,
        choices: List[TensorSpec | NonTensorData | NonTensorStack],
    ):
        if not isinstance(choices, list):
            raise TypeError("'choices' must be a list")

        # Fail with a clear message instead of an IndexError on ``choices[0]``.
        if not choices:
            raise ValueError("'choices' must contain at least one element")

        first = choices[0]
        rest = choices[1:]

        if not isinstance(first, (TensorSpec, NonTensorData, NonTensorStack)):
            raise TypeError(
                (
                    "Each choice must be either a TensorSpec, NonTensorData, or "
                    f"NonTensorStack, but got {type(first)}"
                )
            )

        if not all(isinstance(choice, type(first)) for choice in rest):
            raise TypeError("All choices must be the same type")

        if not all(choice.shape == first.shape for choice in rest):
            raise ValueError("All choices must have the same shape")

        if not all(choice.dtype == first.dtype for choice in rest):
            raise ValueError("All choices must have the same dtype")

        if not all(choice.device == first.device for choice in rest):
            raise ValueError("All choices must have the same device")

        # The spec takes its shape, device and dtype from the (homogeneous) choices.
        super().__init__(
            shape=first.shape,
            space=None,
            device=first.device,
            dtype=first.dtype,
            domain=None,
        )

        # Store clones of the choices rather than the caller's own objects.
        self._choices = [choice.clone() for choice in choices]

    def _rand_idx(self):
        """Return a random integer index into the stored choices."""
        return torch.randint(0, len(self._choices), ()).item()

    def _sample(self, idx, spec_sample_fn) -> TensorDictBase:
        """Return a sample from the choice at position ``idx``.

        If the choice is a ``TensorSpec``, ``spec_sample_fn`` is applied to it
        and its result is returned. Otherwise the stored non-tensor choice is
        returned as is.
        """
        chosen = self._choices[idx]
        if isinstance(chosen, TensorSpec):
            return spec_sample_fn(chosen)
        return chosen

    def zero(self, shape: torch.Size = None) -> TensorDictBase:
        """Return ``zero(shape)`` of the first choice, or the first choice itself if it is non-tensor data."""
        return self._sample(0, lambda x: x.zero(shape))

    def one(self, shape: torch.Size = None) -> TensorDictBase:
        """Return ``one(shape)`` of the second choice, or of the first if there is only one.

        For non-tensor choices the selected choice itself is returned.
        """
        # Index 1 when at least two choices exist, otherwise index 0.
        idx = min(1, len(self._choices) - 1)
        return self._sample(idx, lambda x: x.one(shape))

    def rand(self, shape: torch.Size = None) -> TensorDictBase:
        """Pick a choice at random and return ``rand(shape)`` of it, or the choice itself if it is non-tensor data."""
        return self._sample(self._rand_idx(), lambda x: x.rand(shape))

    def is_in(self, val: torch.Tensor | TensorDictBase) -> bool:
        """Return whether ``val`` belongs to at least one of the choices.

        For spec choices this defers to each choice's ``is_in``; for
        non-tensor choices ``val`` must compare equal to one of them.
        """
        if isinstance(self._choices[0], TensorSpec):
            return any(choice.is_in(val) for choice in self._choices)
        return any((choice == val).all() for choice in self._choices)

    def expand(self, *shape):
        """Return a new spec of the same class built from every choice expanded with ``shape``."""
        return self.__class__([choice.expand(*shape) for choice in self._choices])

    def unsqueeze(self, dim: int):
        """Return a new spec of the same class built from every choice unsqueezed at ``dim``."""
        return self.__class__([choice.unsqueeze(dim) for choice in self._choices])

    def clone(self) -> Choice:
        """Return a new spec of the same class built from clones of the choices."""
        return self.__class__([choice.clone() for choice in self._choices])

    def cardinality(self) -> int:
        """Return the number of distinct values the spec can take.

        For non-tensor choices this is the number of choices; for spec
        choices it is the sum of the choices' cardinalities.
        """
        if isinstance(self._choices[0], (NonTensorData, NonTensorStack)):
            return len(self._choices)
        return (
            torch.tensor([choice.cardinality() for choice in self._choices])
            .sum()
            .item()
        )

    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Choice:
        """Return a new spec of the same class built from every choice passed through ``to(dest)``."""
        return self.__class__([choice.to(dest) for choice in self._choices])
3848+
3849+
37353850
@dataclass(repr=False)
37363851
class Binary(Categorical):
37373852
"""A binary discrete tensor spec.

0 commit comments

Comments
 (0)