Skip to content

Commit 76bf772

Browse files
committed
[Feature] Add Choice spec
ghstack-source-id: e0092df Pull Request resolved: #2713
1 parent a901064 commit 76bf772

File tree

3 files changed

+223
-1
lines changed

3 files changed

+223
-1
lines changed

test/test_specs.py

+128-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@
1212
import torch
1313
import torchrl.data.tensor_specs
1414
from scipy.stats import chisquare
15-
from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
15+
from tensordict import (
16+
LazyStackedTensorDict,
17+
NonTensorData,
18+
NonTensorStack,
19+
TensorDict,
20+
TensorDictBase,
21+
)
1622
from tensordict.utils import _unravel_key_to_tuple
1723
from torchrl._utils import _make_ordinal_device
1824

@@ -23,6 +29,7 @@
2329
Bounded,
2430
BoundedTensorSpec,
2531
Categorical,
32+
Choice,
2633
Composite,
2734
CompositeSpec,
2835
ContinuousBox,
@@ -678,6 +685,23 @@ def test_change_batch_size(self, shape, is_complete, device, dtype):
678685
assert ts["nested"].shape == (3,)
679686

680687

688+
class TestChoiceSpec:
689+
@pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
690+
def test_choice(self, input_type):
691+
if input_type == "spec":
692+
stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
693+
elif input_type == "nontensor":
694+
stack = torch.stack([NonTensorData("a"), NonTensorData("b")])
695+
elif input_type == "nontensorstack":
696+
stack = torch.stack(
697+
[NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
698+
)
699+
700+
spec = Choice(stack)
701+
res = spec.rand()
702+
assert spec.is_in(res)
703+
704+
681705
@pytest.mark.parametrize("shape", [(), (2, 3)])
682706
@pytest.mark.parametrize("device", get_default_devices())
683707
def test_create_composite_nested(shape, device):
@@ -1409,6 +1433,21 @@ def test_non_tensor(self):
14091433
== NonTensor((2, 3, 4), device="cpu")
14101434
)
14111435

1436+
@pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
1437+
def test_choice(self, input_type):
1438+
if input_type == "spec":
1439+
stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
1440+
elif input_type == "nontensor":
1441+
stack = torch.stack([NonTensorData("a"), NonTensorData("b")])
1442+
elif input_type == "nontensorstack":
1443+
stack = torch.stack(
1444+
[NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
1445+
)
1446+
1447+
spec = Choice(stack)
1448+
with pytest.raises(NotImplementedError):
1449+
spec.expand((3,))
1450+
14121451
@pytest.mark.parametrize("shape1", [None, (), (5,)])
14131452
@pytest.mark.parametrize("shape2", [(), (10,)])
14141453
def test_onehot(self, shape1, shape2):
@@ -1611,6 +1650,21 @@ def test_non_tensor(self):
16111650
assert spec.clone() == spec
16121651
assert spec.clone() is not spec
16131652

1653+
@pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
1654+
def test_choice(self, input_type):
1655+
if input_type == "spec":
1656+
stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
1657+
elif input_type == "nontensor":
1658+
stack = torch.stack([NonTensorData("a"), NonTensorData("b")])
1659+
elif input_type == "nontensorstack":
1660+
stack = torch.stack(
1661+
[NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
1662+
)
1663+
1664+
spec = Choice(stack)
1665+
assert spec.clone() == spec
1666+
assert spec.clone() is not spec
1667+
16141668
@pytest.mark.parametrize("shape1", [None, (), (5,)])
16151669
def test_onehot(
16161670
self,
@@ -1696,6 +1750,35 @@ def test_non_tensor(self):
16961750
with pytest.raises(RuntimeError, match="Cannot enumerate a NonTensorSpec."):
16971751
spec.cardinality()
16981752

1753+
@pytest.mark.parametrize(
1754+
"input_type",
1755+
["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
1756+
)
1757+
def test_choice(self, input_type):
1758+
if input_type == "bounded_spec":
1759+
stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
1760+
elif input_type == "categorical_spec":
1761+
stack = torch.stack([Categorical(10), Categorical(20)])
1762+
elif input_type == "nontensor":
1763+
stack = torch.stack(
1764+
[NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
1765+
)
1766+
elif input_type == "nontensorstack":
1767+
stack = torch.stack(
1768+
[NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
1769+
)
1770+
1771+
spec = Choice(stack)
1772+
1773+
if input_type == "bounded_spec":
1774+
assert spec.cardinality() == float("inf")
1775+
elif input_type == "categorical_spec":
1776+
assert spec.cardinality() == 30
1777+
elif input_type == "nontensor":
1778+
assert spec.cardinality() == 3
1779+
elif input_type == "nontensorstack":
1780+
assert spec.cardinality() == 2
1781+
16991782
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
17001783
def test_onehot(
17011784
self,
@@ -2004,6 +2087,27 @@ def test_non_tensor(self, device):
20042087
spec = NonTensor(shape=(3, 4), device="cpu")
20052088
assert spec.to(device).device == device
20062089

2090+
@pytest.mark.parametrize(
2091+
"input_type",
2092+
["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
2093+
)
2094+
def test_choice(self, input_type, device):
2095+
if input_type == "bounded_spec":
2096+
stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
2097+
elif input_type == "categorical_spec":
2098+
stack = torch.stack([Categorical(10), Categorical(20)])
2099+
elif input_type == "nontensor":
2100+
stack = torch.stack(
2101+
[NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
2102+
)
2103+
elif input_type == "nontensorstack":
2104+
stack = torch.stack(
2105+
[NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
2106+
)
2107+
2108+
spec = Choice(stack, device="cpu")
2109+
assert spec.to(device).device == device
2110+
20072111
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
20082112
def test_onehot(self, shape1, device):
20092113
if shape1 is None:
@@ -2270,6 +2374,29 @@ def test_stack_non_tensor(self, shape, stack_dim):
22702374
assert new_spec.shape == torch.Size(shape_insert)
22712375
assert new_spec.device == torch.device("cpu")
22722376

2377+
@pytest.mark.parametrize(
2378+
"input_type",
2379+
["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
2380+
)
2381+
def test_stack_choice(self, input_type, shape, stack_dim):
2382+
if input_type == "bounded_spec":
2383+
stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
2384+
elif input_type == "categorical_spec":
2385+
stack = torch.stack([Categorical(10), Categorical(20)])
2386+
elif input_type == "nontensor":
2387+
stack = torch.stack(
2388+
[NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
2389+
)
2390+
elif input_type == "nontensorstack":
2391+
stack = torch.stack(
2392+
[NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
2393+
)
2394+
2395+
spec0 = Choice(stack)
2396+
spec1 = Choice(stack)
2397+
with pytest.raises(NotImplementedError):
2398+
torch.stack([spec0, spec1], 0)
2399+
22732400
def test_stack_onehot(self, shape, stack_dim):
22742401
n = 5
22752402
shape = (*shape, 5)

torchrl/data/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
BoundedContinuous,
7676
BoundedTensorSpec,
7777
Categorical,
78+
Choice,
7879
Composite,
7980
CompositeSpec,
8081
DEVICE_TYPING,

torchrl/data/tensor_specs.py

+94
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
is_tensor_collection,
3737
LazyStackedTensorDict,
3838
NonTensorData,
39+
NonTensorStack,
3940
TensorDict,
4041
TensorDictBase,
4142
unravel_key,
@@ -3678,6 +3679,99 @@ def clone(self) -> Categorical:
36783679
)
36793680

36803681

3682+
class Choice(TensorSpec):
3683+
"""A discrete choice spec for either tensor or non-tensor data.
3684+
3685+
Args:
3686+
stack (:class:`~Stacked`, :class:`~StackedComposite`, or :class:`~tensordict.NonTensorStack`):
3687+
Stack of specs or non-tensor data from which to choose during
3688+
sampling.
3689+
device (str, int or torch.device, optional): device of the tensors.
3690+
3691+
Examples:
3692+
>>> import torch
3693+
>>> _ = torch.manual_seed(0)
3694+
>>> from torchrl.data import Choice, Categorical
3695+
>>> spec = Choice(torch.stack([
3696+
... Categorical(n=4, shape=(1,)),
3697+
... Categorical(n=4, shape=(2,))]))
3698+
>>> spec.shape
3699+
torch.Size([2, -1])
3700+
>>> spec.rand()
3701+
tensor([3])
3702+
>>> spec.rand()
3703+
tensor([0, 3])
3704+
"""
3705+
3706+
def __init__(
3707+
self,
3708+
stack: Stacked | StackedComposite | NonTensorStack,
3709+
device: Optional[DEVICE_TYPING] = None,
3710+
):
3711+
assert isinstance(stack, (Stacked, StackedComposite, NonTensorStack))
3712+
stack = stack.clone()
3713+
if device is not None:
3714+
self._stack = stack.to(device)
3715+
else:
3716+
self._stack = stack
3717+
device = stack.device
3718+
3719+
shape = stack.shape
3720+
dtype = stack.dtype
3721+
3722+
domain = None
3723+
super().__init__(
3724+
shape=shape, space=None, device=device, dtype=dtype, domain=domain
3725+
)
3726+
3727+
def _rand_idx(self):
3728+
return torch.randint(0, len(self._stack), ()).item()
3729+
3730+
def _sample(self, idx, spec_sample_fn) -> TensorDictBase:
3731+
res = self._stack[idx]
3732+
if isinstance(res, TensorSpec):
3733+
return spec_sample_fn(res)
3734+
else:
3735+
return res
3736+
3737+
def zero(self, shape: torch.Size = None) -> TensorDictBase:
3738+
return self._sample(0, lambda x: x.zero(shape))
3739+
3740+
def one(self, shape: torch.Size = None) -> TensorDictBase:
3741+
return self._sample(min(1, len(self - 1)), lambda x: x.one(shape))
3742+
3743+
def rand(self, shape: torch.Size = None) -> TensorDictBase:
3744+
return self._sample(self._rand_idx(), lambda x: x.rand(shape))
3745+
3746+
def is_in(self, val: torch.Tensor | TensorDictBase) -> bool:
3747+
if isinstance(self._stack, (Stacked, StackedComposite)):
3748+
return any([stack_elem.is_in(val) for stack_elem in self._stack])
3749+
else:
3750+
return any([(stack_elem == val).all() for stack_elem in self._stack])
3751+
3752+
def expand(self, *shape):
3753+
raise NotImplementedError
3754+
3755+
def unsqueeze(self, dim: int):
3756+
raise NotImplementedError
3757+
3758+
def clone(self) -> Choice:
3759+
return self.__class__(self._stack)
3760+
3761+
def cardinality(self) -> int:
3762+
if isinstance(self._stack, NonTensorStack):
3763+
return len(self._stack)
3764+
else:
3765+
return (
3766+
torch.tensor([stack_elem.cardinality() for stack_elem in self._stack])
3767+
.sum()
3768+
.item()
3769+
)
3770+
3771+
def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Choice:
3772+
return self.__class__(self._stack.to(dest))
3773+
3774+
36813775
@dataclass(repr=False)
36823776
class Binary(Categorical):
36833777
"""A binary discrete tensor spec.

0 commit comments

Comments
 (0)