|
12 | 12 | import torch
|
13 | 13 | import torchrl.data.tensor_specs
|
14 | 14 | from scipy.stats import chisquare
|
15 |
| -from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase |
| 15 | +from tensordict import ( |
| 16 | + LazyStackedTensorDict, |
| 17 | + NonTensorData, |
| 18 | + NonTensorStack, |
| 19 | + TensorDict, |
| 20 | + TensorDictBase, |
| 21 | +) |
16 | 22 | from tensordict.utils import _unravel_key_to_tuple
|
17 | 23 | from torchrl._utils import _make_ordinal_device
|
18 | 24 |
|
|
23 | 29 | Bounded,
|
24 | 30 | BoundedTensorSpec,
|
25 | 31 | Categorical,
|
| 32 | + Choice, |
26 | 33 | Composite,
|
27 | 34 | CompositeSpec,
|
28 | 35 | ContinuousBox,
|
@@ -702,6 +709,62 @@ def test_change_batch_size(self, shape, is_complete, device, dtype):
|
702 | 709 | assert ts["nested"].shape == (3,)
|
703 | 710 |
|
704 | 711 |
|
class TestChoiceSpec:
    """Tests for the ``Choice`` spec: sampling, membership, and constructor validation."""

    @pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
    def test_choice(self, input_type):
        """A Choice spec samples within its options and is_in accepts/rejects correctly."""
        # Build the candidate options plus one member and one non-member example.
        if input_type == "spec":
            choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())]
            example_in = torch.tensor(11.0)
            example_out = torch.tensor(9.0)
        elif input_type == "nontensor":
            choices = [NonTensorData("a"), NonTensorData("b")]
            example_in = NonTensorData("b")
            example_out = NonTensorData("c")
        elif input_type == "nontensorstack":
            choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
            example_in = NonTensorStack("a", "b", "c")
            example_out = NonTensorStack("a", "c", "b")

        spec = Choice(choices)
        sample = spec.rand()
        assert spec.is_in(sample)
        assert spec.is_in(example_in)
        assert not spec.is_in(example_out)

    def test_errors(self):
        """The Choice constructor rejects malformed option lists with clear errors."""
        with pytest.raises(TypeError, match="must be a list"):
            Choice("abc")

        with pytest.raises(
            TypeError,
            match="must be either a TensorSpec, NonTensorData, or NonTensorStack",
        ):
            Choice(["abc"])

        with pytest.raises(TypeError, match="must be the same type"):
            Choice([Bounded(0, 1, (1,)), Categorical(10, (1,))])

        with pytest.raises(ValueError, match="must have the same shape"):
            Choice([Categorical(10, (1,)), Categorical(10, (2,))])

        with pytest.raises(ValueError, match="must have the same dtype"):
            Choice(
                [
                    Categorical(10, (2,), dtype=torch.long),
                    Categorical(10, (2,), dtype=torch.float),
                ]
            )

        # A device mismatch can only be constructed when a second device exists.
        if torch.cuda.device_count():
            with pytest.raises(ValueError, match="must have the same device"):
                Choice(
                    [
                        Categorical(10, (2,), device="cpu"),
                        Categorical(10, (2,), device="cuda"),
                    ]
                )
| 767 | + |
705 | 768 | @pytest.mark.parametrize("shape", [(), (2, 3)])
|
706 | 769 | @pytest.mark.parametrize("device", get_default_devices())
|
707 | 770 | def test_create_composite_nested(shape, device):
|
@@ -1498,6 +1561,27 @@ def test_non_tensor(self):
|
1498 | 1561 | )
|
1499 | 1562 | assert spec.expand(2, 3, 4).example_data == "example_data"
|
1500 | 1563 |
|
| 1564 | + @pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"]) |
| 1565 | + def test_choice(self, input_type): |
| 1566 | + if input_type == "spec": |
| 1567 | + choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())] |
| 1568 | + elif input_type == "nontensor": |
| 1569 | + choices = [NonTensorData("a"), NonTensorData("b")] |
| 1570 | + elif input_type == "nontensorstack": |
| 1571 | + choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")] |
| 1572 | + |
| 1573 | + spec = Choice(choices) |
| 1574 | + res = spec.expand( |
| 1575 | + [ |
| 1576 | + 3, |
| 1577 | + ] |
| 1578 | + ) |
| 1579 | + assert res.shape == torch.Size( |
| 1580 | + [ |
| 1581 | + 3, |
| 1582 | + ] |
| 1583 | + ) |
| 1584 | + |
1501 | 1585 | @pytest.mark.parametrize("shape1", [None, (), (5,)])
|
1502 | 1586 | @pytest.mark.parametrize("shape2", [(), (10,)])
|
1503 | 1587 | def test_onehot(self, shape1, shape2):
|
@@ -1701,6 +1785,19 @@ def test_non_tensor(self):
|
1701 | 1785 | assert spec.clone() is not spec
|
1702 | 1786 | assert spec.clone().example_data == "example_data"
|
1703 | 1787 |
|
| 1788 | + @pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"]) |
| 1789 | + def test_choice(self, input_type): |
| 1790 | + if input_type == "spec": |
| 1791 | + choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())] |
| 1792 | + elif input_type == "nontensor": |
| 1793 | + choices = [NonTensorData("a"), NonTensorData("b")] |
| 1794 | + elif input_type == "nontensorstack": |
| 1795 | + choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")] |
| 1796 | + |
| 1797 | + spec = Choice(choices) |
| 1798 | + assert spec.clone() == spec |
| 1799 | + assert spec.clone() is not spec |
| 1800 | + |
1704 | 1801 | @pytest.mark.parametrize("shape1", [None, (), (5,)])
|
1705 | 1802 | def test_onehot(
|
1706 | 1803 | self,
|
@@ -1786,6 +1883,31 @@ def test_non_tensor(self):
|
1786 | 1883 | with pytest.raises(RuntimeError, match="Cannot enumerate a NonTensorSpec."):
|
1787 | 1884 | spec.cardinality()
|
1788 | 1885 |
|
| 1886 | + @pytest.mark.parametrize( |
| 1887 | + "input_type", |
| 1888 | + ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"], |
| 1889 | + ) |
| 1890 | + def test_choice(self, input_type): |
| 1891 | + if input_type == "bounded_spec": |
| 1892 | + choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())] |
| 1893 | + elif input_type == "categorical_spec": |
| 1894 | + choices = [Categorical(10), Categorical(20)] |
| 1895 | + elif input_type == "nontensor": |
| 1896 | + choices = [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")] |
| 1897 | + elif input_type == "nontensorstack": |
| 1898 | + choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")] |
| 1899 | + |
| 1900 | + spec = Choice(choices) |
| 1901 | + |
| 1902 | + if input_type == "bounded_spec": |
| 1903 | + assert spec.cardinality() == float("inf") |
| 1904 | + elif input_type == "categorical_spec": |
| 1905 | + assert spec.cardinality() == 30 |
| 1906 | + elif input_type == "nontensor": |
| 1907 | + assert spec.cardinality() == 3 |
| 1908 | + elif input_type == "nontensorstack": |
| 1909 | + assert spec.cardinality() == 2 |
| 1910 | + |
1789 | 1911 | @pytest.mark.parametrize("shape1", [(5,), (5, 6)])
|
1790 | 1912 | def test_onehot(
|
1791 | 1913 | self,
|
@@ -2096,6 +2218,23 @@ def test_non_tensor(self, device):
|
2096 | 2218 | assert spec.to(device).device == device
|
2097 | 2219 | assert spec.to(device).example_data == "example_data"
|
2098 | 2220 |
|
| 2221 | + @pytest.mark.parametrize( |
| 2222 | + "input_type", |
| 2223 | + ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"], |
| 2224 | + ) |
| 2225 | + def test_choice(self, input_type, device): |
| 2226 | + if input_type == "bounded_spec": |
| 2227 | + choices = [Bounded(0, 2.5, ()), Bounded(10, 12, ())] |
| 2228 | + elif input_type == "categorical_spec": |
| 2229 | + choices = [Categorical(10), Categorical(20)] |
| 2230 | + elif input_type == "nontensor": |
| 2231 | + choices = [NonTensorData("a"), NonTensorData("b")] |
| 2232 | + elif input_type == "nontensorstack": |
| 2233 | + choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")] |
| 2234 | + |
| 2235 | + spec = Choice(choices) |
| 2236 | + assert spec.to(device).device == device |
| 2237 | + |
2099 | 2238 | @pytest.mark.parametrize("shape1", [(5,), (5, 6)])
|
2100 | 2239 | def test_onehot(self, shape1, device):
|
2101 | 2240 | if shape1 is None:
|
@@ -2363,6 +2502,33 @@ def test_stack_non_tensor(self, shape, stack_dim):
|
2363 | 2502 | assert new_spec.device == torch.device("cpu")
|
2364 | 2503 | assert new_spec.example_data == "example_data"
|
2365 | 2504 |
|
| 2505 | + @pytest.mark.parametrize( |
| 2506 | + "input_type", |
| 2507 | + ["bounded_spec", "categorical_spec", "nontensor"], |
| 2508 | + ) |
| 2509 | + def test_stack_choice(self, input_type, shape, stack_dim): |
| 2510 | + if input_type == "bounded_spec": |
| 2511 | + choices = [Bounded(0, 2.5, shape), Bounded(10, 12, shape)] |
| 2512 | + elif input_type == "categorical_spec": |
| 2513 | + choices = [Categorical(10, shape), Categorical(20, shape)] |
| 2514 | + elif input_type == "nontensor": |
| 2515 | + if len(shape) == 0: |
| 2516 | + choices = [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")] |
| 2517 | + else: |
| 2518 | + choices = [ |
| 2519 | + NonTensorStack("a").expand(shape + (1,)).squeeze(-1), |
| 2520 | + NonTensorStack("d").expand(shape + (1,)).squeeze(-1), |
| 2521 | + ] |
| 2522 | + |
| 2523 | + spec0 = Choice(choices) |
| 2524 | + spec1 = Choice(choices) |
| 2525 | + res = torch.stack([spec0, spec1], stack_dim) |
| 2526 | + assert isinstance(res, Choice) |
| 2527 | + assert ( |
| 2528 | + res.shape |
| 2529 | + == torch.stack([torch.empty(shape), torch.empty(shape)], stack_dim).shape |
| 2530 | + ) |
| 2531 | + |
2366 | 2532 | def test_stack_onehot(self, shape, stack_dim):
|
2367 | 2533 | n = 5
|
2368 | 2534 | shape = (*shape, 5)
|
|
0 commit comments