Skip to content

Commit 3121a93

Browse files
committed
Update
[ghstack-poisoned]
1 parent 078077a commit 3121a93

File tree

2 files changed

+149
-89
lines changed

2 files changed

+149
-89
lines changed

test/test_specs.py

Lines changed: 91 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -689,17 +689,56 @@ class TestChoiceSpec:
689689
@pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
690690
def test_choice(self, input_type):
691691
if input_type == "spec":
692-
stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
692+
choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
693+
example_in = torch.tensor(11.0)
694+
example_out = torch.tensor(9.0)
693695
elif input_type == "nontensor":
694-
stack = torch.stack([NonTensorData("a"), NonTensorData("b")])
696+
choices = [NonTensorData("a"), NonTensorData("b")]
697+
example_in = NonTensorData("b")
698+
example_out = NonTensorData("c")
695699
elif input_type == "nontensorstack":
696-
stack = torch.stack(
697-
[NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
698-
)
700+
choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
701+
example_in = NonTensorStack("a", "b", "c")
702+
example_out = NonTensorStack("a", "c", "b")
699703

700-
spec = Choice(stack)
704+
spec = Choice(choices)
701705
res = spec.rand()
702706
assert spec.is_in(res)
707+
assert spec.is_in(example_in)
708+
assert not spec.is_in(example_out)
709+
710+
def test_errors(self):
711+
with pytest.raises(TypeError, match="must be a list"):
712+
Choice("abc")
713+
714+
with pytest.raises(
715+
TypeError,
716+
match="must be either a TensorSpec, NonTensorData, or NonTensorStack",
717+
):
718+
Choice(["abc"])
719+
720+
with pytest.raises(TypeError, match="must be the same type"):
721+
Choice([Bounded(0, 1, (1,)), Categorical(10, (1,))])
722+
723+
with pytest.raises(ValueError, match="must have the same shape"):
724+
Choice([Categorical(10, (1,)), Categorical(10, (2,))])
725+
726+
with pytest.raises(ValueError, match="must have the same dtype"):
727+
Choice(
728+
[
729+
Categorical(10, (2,), dtype=torch.long),
730+
Categorical(10, (2,), dtype=torch.float),
731+
]
732+
)
733+
734+
if torch.cuda.device_count():
735+
with pytest.raises(ValueError, match="must have the same device"):
736+
Choice(
737+
[
738+
Categorical(10, (2,), device="cpu"),
739+
Categorical(10, (2,), device="cuda"),
740+
]
741+
)
703742

704743

705744
@pytest.mark.parametrize("shape", [(), (2, 3)])
@@ -1436,17 +1475,23 @@ def test_non_tensor(self):
14361475
@pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
14371476
def test_choice(self, input_type):
14381477
if input_type == "spec":
1439-
stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
1478+
choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
14401479
elif input_type == "nontensor":
1441-
stack = torch.stack([NonTensorData("a"), NonTensorData("b")])
1480+
choices = [NonTensorData("a"), NonTensorData("b")]
14421481
elif input_type == "nontensorstack":
1443-
stack = torch.stack(
1444-
[NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
1445-
)
1482+
choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
14461483

1447-
spec = Choice(stack)
1448-
with pytest.raises(NotImplementedError):
1449-
spec.expand((3,))
1484+
spec = Choice(choices)
1485+
res = spec.expand(
1486+
[
1487+
3,
1488+
]
1489+
)
1490+
assert res.shape == torch.Size(
1491+
[
1492+
3,
1493+
]
1494+
)
14501495

14511496
@pytest.mark.parametrize("shape1", [None, (), (5,)])
14521497
@pytest.mark.parametrize("shape2", [(), (10,)])
@@ -1653,15 +1698,13 @@ def test_non_tensor(self):
16531698
@pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
16541699
def test_choice(self, input_type):
16551700
if input_type == "spec":
1656-
stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
1701+
choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
16571702
elif input_type == "nontensor":
1658-
stack = torch.stack([NonTensorData("a"), NonTensorData("b")])
1703+
choices = [NonTensorData("a"), NonTensorData("b")]
16591704
elif input_type == "nontensorstack":
1660-
stack = torch.stack(
1661-
[NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
1662-
)
1705+
choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
16631706

1664-
spec = Choice(stack)
1707+
spec = Choice(choices)
16651708
assert spec.clone() == spec
16661709
assert spec.clone() is not spec
16671710

@@ -1756,19 +1799,15 @@ def test_non_tensor(self):
17561799
)
17571800
def test_choice(self, input_type):
17581801
if input_type == "bounded_spec":
1759-
stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
1802+
choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
17601803
elif input_type == "categorical_spec":
1761-
stack = torch.stack([Categorical(10), Categorical(20)])
1804+
choices = [Categorical(10), Categorical(20)]
17621805
elif input_type == "nontensor":
1763-
stack = torch.stack(
1764-
[NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
1765-
)
1806+
choices = [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
17661807
elif input_type == "nontensorstack":
1767-
stack = torch.stack(
1768-
[NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
1769-
)
1808+
choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
17701809

1771-
spec = Choice(stack)
1810+
spec = Choice(choices)
17721811

17731812
if input_type == "bounded_spec":
17741813
assert spec.cardinality() == float("inf")
@@ -2093,19 +2132,15 @@ def test_non_tensor(self, device):
20932132
)
20942133
def test_choice(self, input_type, device):
20952134
if input_type == "bounded_spec":
2096-
stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
2135+
choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
20972136
elif input_type == "categorical_spec":
2098-
stack = torch.stack([Categorical(10), Categorical(20)])
2137+
choices = [Categorical(10), Categorical(20)]
20992138
elif input_type == "nontensor":
2100-
stack = torch.stack(
2101-
[NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
2102-
)
2139+
choices = [NonTensorData("a"), NonTensorData("b")]
21032140
elif input_type == "nontensorstack":
2104-
stack = torch.stack(
2105-
[NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
2106-
)
2141+
choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
21072142

2108-
spec = Choice(stack, device="cpu")
2143+
spec = Choice(choices)
21092144
assert spec.to(device).device == device
21102145

21112146
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
@@ -2376,26 +2411,30 @@ def test_stack_non_tensor(self, shape, stack_dim):
23762411

23772412
@pytest.mark.parametrize(
23782413
"input_type",
2379-
["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"],
2414+
["bounded_spec", "categorical_spec", "nontensor"],
23802415
)
23812416
def test_stack_choice(self, input_type, shape, stack_dim):
23822417
if input_type == "bounded_spec":
2383-
stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())])
2418+
choices = [Bounded(0, 2.5, shape), Bounded(10, 12, shape)]
23842419
elif input_type == "categorical_spec":
2385-
stack = torch.stack([Categorical(10), Categorical(20)])
2420+
choices = [Categorical(10, shape), Categorical(20, shape)]
23862421
elif input_type == "nontensor":
2387-
stack = torch.stack(
2388-
[NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
2389-
)
2390-
elif input_type == "nontensorstack":
2391-
stack = torch.stack(
2392-
[NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
2393-
)
2422+
if len(shape) == 0:
2423+
choices = [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")]
2424+
else:
2425+
choices = [
2426+
NonTensorStack("a").expand(shape + (1,)).squeeze(-1),
2427+
NonTensorStack("d").expand(shape + (1,)).squeeze(-1),
2428+
]
23942429

2395-
spec0 = Choice(stack)
2396-
spec1 = Choice(stack)
2397-
with pytest.raises(NotImplementedError):
2398-
torch.stack([spec0, spec1], 0)
2430+
spec0 = Choice(choices)
2431+
spec1 = Choice(choices)
2432+
res = torch.stack([spec0, spec1], stack_dim)
2433+
assert isinstance(res, Choice)
2434+
assert (
2435+
res.shape
2436+
== torch.stack([torch.empty(shape), torch.empty(shape)], stack_dim).shape
2437+
)
23992438

24002439
def test_stack_onehot(self, shape, stack_dim):
24012440
n = 5

torchrl/data/tensor_specs.py

Lines changed: 58 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3683,52 +3683,73 @@ class Choice(TensorSpec):
36833683
"""A discrete choice spec for either tensor or non-tensor data.
36843684
36853685
Args:
3686-
stack (:class:`~Stacked`, :class:`~StackedComposite`, or :class:`~tensordict.NonTensorStack`):
3687-
Stack of specs or non-tensor data from which to choose during
3688-
sampling.
3689-
device (str, int or torch.device, optional): device of the tensors.
3686+
choices (list[:class:`~TensorSpec`, :class:`~tensordict.NonTensorData`, :class:`~tensordict.NonTensorStack`]):
3687+
List of specs or non-tensor data from which to choose during
3688+
sampling. All elements must have the same type, shape, dtype, and
3689+
device.
36903690
36913691
Examples:
36923692
>>> import torch
36933693
>>> _ = torch.manual_seed(0)
3694-
>>> from torchrl.data import Choice, Categorical
3695-
>>> spec = Choice(torch.stack([
3696-
... Categorical(n=4, shape=(1,)),
3697-
... Categorical(n=4, shape=(2,))]))
3698-
>>> spec.shape
3699-
torch.Size([2, -1])
3694+
>>> from torchrl.data import Choice, Bounded
3695+
>>> spec = Choice([
3696+
... Bounded(0, 1, shape=(1,)),
3697+
... Bounded(10, 11, shape=(1,))])
37003698
>>> spec.rand()
3701-
tensor([3])
3699+
tensor([0.7682])
37023700
>>> spec.rand()
3703-
tensor([0, 3])
3701+
tensor([10.1320])
3702+
>>> from tensordict import NonTensorData
3703+
>>> _ = torch.manual_seed(0)
3704+
>>> spec = Choice([NonTensorData(s) for s in ["a", "b", "c", "d"]])
3705+
>>> spec.rand().data
3706+
'a'
3707+
>>> spec.rand().data
3708+
'd'
37043709
"""
37053710

37063711
def __init__(
37073712
self,
3708-
stack: Stacked | StackedComposite | NonTensorStack,
3709-
device: Optional[DEVICE_TYPING] = None,
3713+
choices: List[TensorSpec | NonTensorData | NonTensorStack],
37103714
):
3711-
assert isinstance(stack, (Stacked, StackedComposite, NonTensorStack))
3712-
stack = stack.clone()
3713-
if device is not None:
3714-
self._stack = stack.to(device)
3715-
else:
3716-
self._stack = stack
3717-
device = stack.device
3715+
if not isinstance(choices, list):
3716+
raise TypeError("'choices' must be a list")
37183717

3719-
shape = stack.shape
3720-
dtype = stack.dtype
3718+
if not isinstance(choices[0], (TensorSpec, NonTensorData, NonTensorStack)):
3719+
raise TypeError(
3720+
(
3721+
"Each choice must be either a TensorSpec, NonTensorData, or "
3722+
f"NonTensorStack, but got {type(choices[0])}"
3723+
)
3724+
)
3725+
3726+
if not all([isinstance(choice, type(choices[0])) for choice in choices[1:]]):
3727+
raise TypeError("All choices must be the same type")
3728+
3729+
if not all([choice.shape == choices[0].shape for choice in choices[1:]]):
3730+
raise ValueError("All choices must have the same shape")
3731+
3732+
if not all([choice.dtype == choices[0].dtype for choice in choices[1:]]):
3733+
raise ValueError("All choices must have the same dtype")
3734+
3735+
if not all([choice.device == choices[0].device for choice in choices[1:]]):
3736+
raise ValueError("All choices must have the same device")
3737+
3738+
shape = choices[0].shape
3739+
device = choices[0].device
3740+
dtype = choices[0].dtype
37213741

3722-
domain = None
37233742
super().__init__(
3724-
shape=shape, space=None, device=device, dtype=dtype, domain=domain
3743+
shape=shape, space=None, device=device, dtype=dtype, domain=None
37253744
)
37263745

3746+
self._choices = [choice.clone() for choice in choices]
3747+
37273748
def _rand_idx(self):
3728-
return torch.randint(0, len(self._stack), ()).item()
3749+
return torch.randint(0, len(self._choices), ()).item()
37293750

37303751
def _sample(self, idx, spec_sample_fn) -> TensorDictBase:
3731-
res = self._stack[idx]
3752+
res = self._choices[idx]
37323753
if isinstance(res, TensorSpec):
37333754
return spec_sample_fn(res)
37343755
else:
@@ -3744,32 +3765,32 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:
37443765
return self._sample(self._rand_idx(), lambda x: x.rand(shape))
37453766

37463767
def is_in(self, val: torch.Tensor | TensorDictBase) -> bool:
3747-
if isinstance(self._stack, (Stacked, StackedComposite)):
3748-
return any([stack_elem.is_in(val) for stack_elem in self._stack])
3768+
if isinstance(self._choices[0], TensorSpec):
3769+
return any([choice.is_in(val) for choice in self._choices])
37493770
else:
3750-
return any([(stack_elem == val).all() for stack_elem in self._stack])
3771+
return any([(choice == val).all() for choice in self._choices])
37513772

37523773
def expand(self, *shape):
3753-
raise NotImplementedError
3774+
return self.__class__([choice.expand(*shape) for choice in self._choices])
37543775

37553776
def unsqueeze(self, dim: int):
3756-
raise NotImplementedError
3777+
return self.__class__([choice.unsqueeze(dim) for choice in self._choices])
37573778

37583779
def clone(self) -> Choice:
3759-
return self.__class__(self._stack)
3780+
return self.__class__([choice.clone() for choice in self._choices])
37603781

37613782
def cardinality(self) -> int:
3762-
if isinstance(self._stack, NonTensorStack):
3763-
return len(self._stack)
3783+
if isinstance(self._choices[0], (NonTensorData, NonTensorStack)):
3784+
return len(self._choices)
37643785
else:
37653786
return (
3766-
torch.tensor([stack_elem.cardinality() for stack_elem in self._stack])
3787+
torch.tensor([choice.cardinality() for choice in self._choices])
37673788
.sum()
37683789
.item()
37693790
)
37703791

37713792
def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Choice:
3772-
return self.__class__(self._stack.to(dest))
3793+
return self.__class__([choice.to(dest) for choice in self._choices])
37733794

37743795

37753796
@dataclass(repr=False)

0 commit comments

Comments
 (0)