@@ -689,17 +689,56 @@ class TestChoiceSpec:
689
689
@pytest .mark .parametrize ("input_type" , ["spec" , "nontensor" , "nontensorstack" ])
690
690
def test_choice (self , input_type ):
691
691
if input_type == "spec" :
692
- stack = torch .stack ([Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())])
692
+ choices = [Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())]
693
+ example_in = torch .tensor (11.0 )
694
+ example_out = torch .tensor (9.0 )
693
695
elif input_type == "nontensor" :
694
- stack = torch .stack ([NonTensorData ("a" ), NonTensorData ("b" )])
696
+ choices = [NonTensorData ("a" ), NonTensorData ("b" )]
697
+ example_in = NonTensorData ("b" )
698
+ example_out = NonTensorData ("c" )
695
699
elif input_type == "nontensorstack" :
696
- stack = torch . stack (
697
- [ NonTensorStack ("a" , "b" , "c" ), NonTensorStack ( "d" , "e" , "f" )]
698
- )
700
+ choices = [ NonTensorStack ( "a" , "b" , "c" ), NonTensorStack ( "d" , "e" , "f" )]
701
+ example_in = NonTensorStack ("a" , "b" , "c" )
702
+ example_out = NonTensorStack ( "a" , "c" , "b" )
699
703
700
- spec = Choice (stack )
704
+ spec = Choice (choices )
701
705
res = spec .rand ()
702
706
assert spec .is_in (res )
707
+ assert spec .is_in (example_in )
708
+ assert not spec .is_in (example_out )
709
+
710
+ def test_errors (self ):
711
+ with pytest .raises (TypeError , match = "must be a list" ):
712
+ Choice ("abc" )
713
+
714
+ with pytest .raises (
715
+ TypeError ,
716
+ match = "must be either a TensorSpec, NonTensorData, or NonTensorStack" ,
717
+ ):
718
+ Choice (["abc" ])
719
+
720
+ with pytest .raises (TypeError , match = "must be the same type" ):
721
+ Choice ([Bounded (0 , 1 , (1 ,)), Categorical (10 , (1 ,))])
722
+
723
+ with pytest .raises (ValueError , match = "must have the same shape" ):
724
+ Choice ([Categorical (10 , (1 ,)), Categorical (10 , (2 ,))])
725
+
726
+ with pytest .raises (ValueError , match = "must have the same dtype" ):
727
+ Choice (
728
+ [
729
+ Categorical (10 , (2 ,), dtype = torch .long ),
730
+ Categorical (10 , (2 ,), dtype = torch .float ),
731
+ ]
732
+ )
733
+
734
+ if torch .cuda .device_count ():
735
+ with pytest .raises (ValueError , match = "must have the same device" ):
736
+ Choice (
737
+ [
738
+ Categorical (10 , (2 ,), device = "cpu" ),
739
+ Categorical (10 , (2 ,), device = "cuda" ),
740
+ ]
741
+ )
703
742
704
743
705
744
@pytest .mark .parametrize ("shape" , [(), (2 , 3 )])
@@ -1436,17 +1475,23 @@ def test_non_tensor(self):
1436
1475
@pytest .mark .parametrize ("input_type" , ["spec" , "nontensor" , "nontensorstack" ])
1437
1476
def test_choice (self , input_type ):
1438
1477
if input_type == "spec" :
1439
- stack = torch . stack ( [Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())])
1478
+ choices = [Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())]
1440
1479
elif input_type == "nontensor" :
1441
- stack = torch . stack ( [NonTensorData ("a" ), NonTensorData ("b" )])
1480
+ choices = [NonTensorData ("a" ), NonTensorData ("b" )]
1442
1481
elif input_type == "nontensorstack" :
1443
- stack = torch .stack (
1444
- [NonTensorStack ("a" , "b" , "c" ), NonTensorStack ("d" , "e" , "f" )]
1445
- )
1482
+ choices = [NonTensorStack ("a" , "b" , "c" ), NonTensorStack ("d" , "e" , "f" )]
1446
1483
1447
- spec = Choice (stack )
1448
- with pytest .raises (NotImplementedError ):
1449
- spec .expand ((3 ,))
1484
+ spec = Choice (choices )
1485
+ res = spec .expand (
1486
+ [
1487
+ 3 ,
1488
+ ]
1489
+ )
1490
+ assert res .shape == torch .Size (
1491
+ [
1492
+ 3 ,
1493
+ ]
1494
+ )
1450
1495
1451
1496
@pytest .mark .parametrize ("shape1" , [None , (), (5 ,)])
1452
1497
@pytest .mark .parametrize ("shape2" , [(), (10 ,)])
@@ -1653,15 +1698,13 @@ def test_non_tensor(self):
1653
1698
@pytest .mark .parametrize ("input_type" , ["spec" , "nontensor" , "nontensorstack" ])
1654
1699
def test_choice (self , input_type ):
1655
1700
if input_type == "spec" :
1656
- stack = torch . stack ( [Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())])
1701
+ choices = [Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())]
1657
1702
elif input_type == "nontensor" :
1658
- stack = torch . stack ( [NonTensorData ("a" ), NonTensorData ("b" )])
1703
+ choices = [NonTensorData ("a" ), NonTensorData ("b" )]
1659
1704
elif input_type == "nontensorstack" :
1660
- stack = torch .stack (
1661
- [NonTensorStack ("a" , "b" , "c" ), NonTensorStack ("d" , "e" , "f" )]
1662
- )
1705
+ choices = [NonTensorStack ("a" , "b" , "c" ), NonTensorStack ("d" , "e" , "f" )]
1663
1706
1664
- spec = Choice (stack )
1707
+ spec = Choice (choices )
1665
1708
assert spec .clone () == spec
1666
1709
assert spec .clone () is not spec
1667
1710
@@ -1756,19 +1799,15 @@ def test_non_tensor(self):
1756
1799
)
1757
1800
def test_choice (self , input_type ):
1758
1801
if input_type == "bounded_spec" :
1759
- stack = torch . stack ( [Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())])
1802
+ choices = [Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())]
1760
1803
elif input_type == "categorical_spec" :
1761
- stack = torch . stack ( [Categorical (10 ), Categorical (20 )])
1804
+ choices = [Categorical (10 ), Categorical (20 )]
1762
1805
elif input_type == "nontensor" :
1763
- stack = torch .stack (
1764
- [NonTensorData ("a" ), NonTensorData ("b" ), NonTensorData ("c" )]
1765
- )
1806
+ choices = [NonTensorData ("a" ), NonTensorData ("b" ), NonTensorData ("c" )]
1766
1807
elif input_type == "nontensorstack" :
1767
- stack = torch .stack (
1768
- [NonTensorStack ("a" , "b" , "c" ), NonTensorStack ("d" , "e" , "f" )]
1769
- )
1808
+ choices = [NonTensorStack ("a" , "b" , "c" ), NonTensorStack ("d" , "e" , "f" )]
1770
1809
1771
- spec = Choice (stack )
1810
+ spec = Choice (choices )
1772
1811
1773
1812
if input_type == "bounded_spec" :
1774
1813
assert spec .cardinality () == float ("inf" )
@@ -2093,19 +2132,15 @@ def test_non_tensor(self, device):
2093
2132
)
2094
2133
def test_choice (self , input_type , device ):
2095
2134
if input_type == "bounded_spec" :
2096
- stack = torch . stack ( [Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())])
2135
+ choices = [Bounded (0 , 2.5 , ()), Bounded (10 , 12 , ())]
2097
2136
elif input_type == "categorical_spec" :
2098
- stack = torch . stack ( [Categorical (10 ), Categorical (20 )])
2137
+ choices = [Categorical (10 ), Categorical (20 )]
2099
2138
elif input_type == "nontensor" :
2100
- stack = torch .stack (
2101
- [NonTensorData ("a" ), NonTensorData ("b" ), NonTensorData ("c" )]
2102
- )
2139
+ choices = [NonTensorData ("a" ), NonTensorData ("b" )]
2103
2140
elif input_type == "nontensorstack" :
2104
- stack = torch .stack (
2105
- [NonTensorStack ("a" , "b" , "c" ), NonTensorStack ("d" , "e" , "f" )]
2106
- )
2141
+ choices = [NonTensorStack ("a" , "b" , "c" ), NonTensorStack ("d" , "e" , "f" )]
2107
2142
2108
- spec = Choice (stack , device = "cpu" )
2143
+ spec = Choice (choices )
2109
2144
assert spec .to (device ).device == device
2110
2145
2111
2146
@pytest .mark .parametrize ("shape1" , [(5 ,), (5 , 6 )])
@@ -2376,26 +2411,30 @@ def test_stack_non_tensor(self, shape, stack_dim):
2376
2411
2377
2412
@pytest .mark .parametrize (
2378
2413
"input_type" ,
2379
- ["bounded_spec" , "categorical_spec" , "nontensor" , "nontensorstack" ],
2414
+ ["bounded_spec" , "categorical_spec" , "nontensor" ],
2380
2415
)
2381
2416
def test_stack_choice (self , input_type , shape , stack_dim ):
2382
2417
if input_type == "bounded_spec" :
2383
- stack = torch . stack ( [Bounded (0 , 2.5 , ()) , Bounded (10 , 12 , ())])
2418
+ choices = [Bounded (0 , 2.5 , shape ) , Bounded (10 , 12 , shape )]
2384
2419
elif input_type == "categorical_spec" :
2385
- stack = torch . stack ( [Categorical (10 ), Categorical (20 )])
2420
+ choices = [Categorical (10 , shape ), Categorical (20 , shape )]
2386
2421
elif input_type == "nontensor" :
2387
- stack = torch . stack (
2388
- [NonTensorData ("a" ), NonTensorData ("b" ), NonTensorData ("c" )]
2389
- )
2390
- elif input_type == "nontensorstack" :
2391
- stack = torch . stack (
2392
- [ NonTensorStack ( "a" , "b" , "c" ), NonTensorStack ("d" , "e" , "f" )]
2393
- )
2422
+ if len ( shape ) == 0 :
2423
+ choices = [NonTensorData ("a" ), NonTensorData ("b" ), NonTensorData ("c" )]
2424
+ else :
2425
+ choices = [
2426
+ NonTensorStack ( "a" ). expand ( shape + ( 1 ,)). squeeze ( - 1 ),
2427
+ NonTensorStack ("d" ). expand ( shape + ( 1 ,)). squeeze ( - 1 ),
2428
+ ]
2394
2429
2395
- spec0 = Choice (stack )
2396
- spec1 = Choice (stack )
2397
- with pytest .raises (NotImplementedError ):
2398
- torch .stack ([spec0 , spec1 ], 0 )
2430
+ spec0 = Choice (choices )
2431
+ spec1 = Choice (choices )
2432
+ res = torch .stack ([spec0 , spec1 ], stack_dim )
2433
+ assert isinstance (res , Choice )
2434
+ assert (
2435
+ res .shape
2436
+ == torch .stack ([torch .empty (shape ), torch .empty (shape )], stack_dim ).shape
2437
+ )
2399
2438
2400
2439
def test_stack_onehot (self , shape , stack_dim ):
2401
2440
n = 5
0 commit comments