|
12 | 12 | import torch
|
13 | 13 | import torchrl.data.tensor_specs
|
14 | 14 | from scipy.stats import chisquare
|
15 |
| -from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase |
| 15 | +from tensordict import ( |
| 16 | + LazyStackedTensorDict, |
| 17 | + NonTensorData, |
| 18 | + NonTensorStack, |
| 19 | + TensorDict, |
| 20 | + TensorDictBase, |
| 21 | +) |
16 | 22 | from tensordict.utils import _unravel_key_to_tuple
|
17 | 23 | from torchrl._utils import _make_ordinal_device
|
18 | 24 |
|
|
23 | 29 | Bounded,
|
24 | 30 | BoundedTensorSpec,
|
25 | 31 | Categorical,
|
| 32 | + Choice, |
26 | 33 | Composite,
|
27 | 34 | CompositeSpec,
|
28 | 35 | ContinuousBox,
|
@@ -678,6 +685,23 @@ def test_change_batch_size(self, shape, is_complete, device, dtype):
|
678 | 685 | assert ts["nested"].shape == (3,)
|
679 | 686 |
|
680 | 687 |
|
class TestChoiceSpec:
    """Tests for the ``Choice`` spec built from stacks of specs or non-tensor data."""

    @pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"])
    def test_choice(self, input_type):
        # Table of stack builders, one per parametrized flavor.
        builders = {
            "spec": lambda: [Bounded(0, 2.5, ()), Bounded(10, 12, ())],
            "nontensor": lambda: [NonTensorData("a"), NonTensorData("b")],
            "nontensorstack": lambda: [
                NonTensorStack("a", "b", "c"),
                NonTensorStack("d", "e", "f"),
            ],
        }
        stack = torch.stack(builders[input_type]())

        # A random draw from the choice must be recognized as a member.
        spec = Choice(stack)
        sample = spec.rand()
        assert spec.is_in(sample)
| 704 | + |
681 | 705 | @pytest.mark.parametrize("shape", [(), (2, 3)])
|
682 | 706 | @pytest.mark.parametrize("device", get_default_devices())
|
683 | 707 | def test_create_composite_nested(shape, device):
|
@@ -1409,6 +1433,21 @@ def test_non_tensor(self):
|
1409 | 1433 | == NonTensor((2, 3, 4), device="cpu")
|
1410 | 1434 | )
|
1411 | 1435 |
|
| 1436 | + @pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"]) |
| 1437 | + def test_choice(self, input_type): |
| 1438 | + if input_type == "spec": |
| 1439 | + stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())]) |
| 1440 | + elif input_type == "nontensor": |
| 1441 | + stack = torch.stack([NonTensorData("a"), NonTensorData("b")]) |
| 1442 | + elif input_type == "nontensorstack": |
| 1443 | + stack = torch.stack( |
| 1444 | + [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")] |
| 1445 | + ) |
| 1446 | + |
| 1447 | + spec = Choice(stack) |
| 1448 | + with pytest.raises(NotImplementedError): |
| 1449 | + spec.expand((3,)) |
| 1450 | + |
1412 | 1451 | @pytest.mark.parametrize("shape1", [None, (), (5,)])
|
1413 | 1452 | @pytest.mark.parametrize("shape2", [(), (10,)])
|
1414 | 1453 | def test_onehot(self, shape1, shape2):
|
@@ -1611,6 +1650,21 @@ def test_non_tensor(self):
|
1611 | 1650 | assert spec.clone() == spec
|
1612 | 1651 | assert spec.clone() is not spec
|
1613 | 1652 |
|
| 1653 | + @pytest.mark.parametrize("input_type", ["spec", "nontensor", "nontensorstack"]) |
| 1654 | + def test_choice(self, input_type): |
| 1655 | + if input_type == "spec": |
| 1656 | + stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())]) |
| 1657 | + elif input_type == "nontensor": |
| 1658 | + stack = torch.stack([NonTensorData("a"), NonTensorData("b")]) |
| 1659 | + elif input_type == "nontensorstack": |
| 1660 | + stack = torch.stack( |
| 1661 | + [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")] |
| 1662 | + ) |
| 1663 | + |
| 1664 | + spec = Choice(stack) |
| 1665 | + assert spec.clone() == spec |
| 1666 | + assert spec.clone() is not spec |
| 1667 | + |
1614 | 1668 | @pytest.mark.parametrize("shape1", [None, (), (5,)])
|
1615 | 1669 | def test_onehot(
|
1616 | 1670 | self,
|
@@ -1696,6 +1750,35 @@ def test_non_tensor(self):
|
1696 | 1750 | with pytest.raises(RuntimeError, match="Cannot enumerate a NonTensorSpec."):
|
1697 | 1751 | spec.cardinality()
|
1698 | 1752 |
|
| 1753 | + @pytest.mark.parametrize( |
| 1754 | + "input_type", |
| 1755 | + ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"], |
| 1756 | + ) |
| 1757 | + def test_choice(self, input_type): |
| 1758 | + if input_type == "bounded_spec": |
| 1759 | + stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())]) |
| 1760 | + elif input_type == "categorical_spec": |
| 1761 | + stack = torch.stack([Categorical(10), Categorical(20)]) |
| 1762 | + elif input_type == "nontensor": |
| 1763 | + stack = torch.stack( |
| 1764 | + [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")] |
| 1765 | + ) |
| 1766 | + elif input_type == "nontensorstack": |
| 1767 | + stack = torch.stack( |
| 1768 | + [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")] |
| 1769 | + ) |
| 1770 | + |
| 1771 | + spec = Choice(stack) |
| 1772 | + |
| 1773 | + if input_type == "bounded_spec": |
| 1774 | + assert spec.cardinality() == float("inf") |
| 1775 | + elif input_type == "categorical_spec": |
| 1776 | + assert spec.cardinality() == 30 |
| 1777 | + elif input_type == "nontensor": |
| 1778 | + assert spec.cardinality() == 3 |
| 1779 | + elif input_type == "nontensorstack": |
| 1780 | + assert spec.cardinality() == 2 |
| 1781 | + |
1699 | 1782 | @pytest.mark.parametrize("shape1", [(5,), (5, 6)])
|
1700 | 1783 | def test_onehot(
|
1701 | 1784 | self,
|
@@ -2004,6 +2087,27 @@ def test_non_tensor(self, device):
|
2004 | 2087 | spec = NonTensor(shape=(3, 4), device="cpu")
|
2005 | 2088 | assert spec.to(device).device == device
|
2006 | 2089 |
|
| 2090 | + @pytest.mark.parametrize( |
| 2091 | + "input_type", |
| 2092 | + ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"], |
| 2093 | + ) |
| 2094 | + def test_choice(self, input_type, device): |
| 2095 | + if input_type == "bounded_spec": |
| 2096 | + stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())]) |
| 2097 | + elif input_type == "categorical_spec": |
| 2098 | + stack = torch.stack([Categorical(10), Categorical(20)]) |
| 2099 | + elif input_type == "nontensor": |
| 2100 | + stack = torch.stack( |
| 2101 | + [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")] |
| 2102 | + ) |
| 2103 | + elif input_type == "nontensorstack": |
| 2104 | + stack = torch.stack( |
| 2105 | + [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")] |
| 2106 | + ) |
| 2107 | + |
| 2108 | + spec = Choice(stack, device="cpu") |
| 2109 | + assert spec.to(device).device == device |
| 2110 | + |
2007 | 2111 | @pytest.mark.parametrize("shape1", [(5,), (5, 6)])
|
2008 | 2112 | def test_onehot(self, shape1, device):
|
2009 | 2113 | if shape1 is None:
|
@@ -2270,6 +2374,29 @@ def test_stack_non_tensor(self, shape, stack_dim):
|
2270 | 2374 | assert new_spec.shape == torch.Size(shape_insert)
|
2271 | 2375 | assert new_spec.device == torch.device("cpu")
|
2272 | 2376 |
|
| 2377 | + @pytest.mark.parametrize( |
| 2378 | + "input_type", |
| 2379 | + ["bounded_spec", "categorical_spec", "nontensor", "nontensorstack"], |
| 2380 | + ) |
| 2381 | + def test_stack_choice(self, input_type, shape, stack_dim): |
| 2382 | + if input_type == "bounded_spec": |
| 2383 | + stack = torch.stack([Bounded(0, 2.5, ()), Bounded(10, 12, ())]) |
| 2384 | + elif input_type == "categorical_spec": |
| 2385 | + stack = torch.stack([Categorical(10), Categorical(20)]) |
| 2386 | + elif input_type == "nontensor": |
| 2387 | + stack = torch.stack( |
| 2388 | + [NonTensorData("a"), NonTensorData("b"), NonTensorData("c")] |
| 2389 | + ) |
| 2390 | + elif input_type == "nontensorstack": |
| 2391 | + stack = torch.stack( |
| 2392 | + [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")] |
| 2393 | + ) |
| 2394 | + |
| 2395 | + spec0 = Choice(stack) |
| 2396 | + spec1 = Choice(stack) |
| 2397 | + with pytest.raises(NotImplementedError): |
| 2398 | + torch.stack([spec0, spec1], 0) |
| 2399 | + |
2273 | 2400 | def test_stack_onehot(self, shape, stack_dim):
|
2274 | 2401 | n = 5
|
2275 | 2402 | shape = (*shape, 5)
|
|
0 commit comments