@@ -724,9 +724,10 @@ def test_choice(self, input_type):
         choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
         example_in = NonTensorStack("a", "b", "c")
         example_out = NonTensorStack("a", "c", "b")
-
-        spec = Choice(choices)
-        res = spec.rand()
+        torch.manual_seed(0)
+        for _ in range(10):
+            spec = Choice(choices)
+            res = spec.rand()
         assert spec.is_in(res)
         assert spec.is_in(example_in)
         assert not spec.is_in(example_out)
@@ -755,7 +756,7 @@ def test_errors(self):
                 ]
             )

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             with pytest.raises(ValueError, match="must have the same device"):
                 Choice(
                     [
@@ -914,7 +915,7 @@ def test_equality_bounded(self):

         ts_other = Bounded(minimum, maximum + 1, torch.Size((1,)), device, dtype)
         assert ts != ts_other
-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = Bounded(minimum, maximum, torch.Size((1,)), "cuda:0", dtype)
             assert ts != ts_other

@@ -942,7 +943,7 @@ def test_equality_onehot(self):
         )
         assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = OneHot(
                 n=n, device="cuda:0", dtype=dtype, use_register=use_register
             )
@@ -972,7 +973,7 @@ def test_equality_unbounded(self):
         ts_same = Unbounded(device=device, dtype=dtype)
         assert ts == ts_same

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = Unbounded(device="cuda:0", dtype=dtype)
             assert ts != ts_other

@@ -1005,7 +1006,7 @@ def test_equality_ndbounded(self):
         ts_other = Bounded(low=minimum, high=maximum + 1, device=device, dtype=dtype)
         assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = Bounded(low=minimum, high=maximum, device="cuda:0", dtype=dtype)
             assert ts != ts_other

@@ -1033,7 +1034,7 @@ def test_equality_discrete(self):
         ts_other = Categorical(n=n + 1, shape=shape, device=device, dtype=dtype)
         assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = Categorical(n=n, shape=shape, device="cuda:0", dtype=dtype)
             assert ts != ts_other

@@ -1071,7 +1072,7 @@ def test_equality_ndunbounded(self, shape):
         ts_other = Unbounded(shape=other_shape, device=device, dtype=dtype)
         assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = Unbounded(shape=shape, device="cuda:0", dtype=dtype)
             assert ts != ts_other

@@ -1097,7 +1098,7 @@ def test_equality_binary(self):
         ts_other = Binary(n=n + 5, device=device, dtype=dtype)
         assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = Binary(n=n, device="cuda:0", dtype=dtype)
             assert ts != ts_other

@@ -1131,7 +1132,7 @@ def test_equality_multi_onehot(self, nvec):
         ts_other = MultiOneHot(nvec=other_nvec, device=device, dtype=dtype)
         assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = MultiOneHot(nvec=nvec, device="cuda:0", dtype=dtype)
             assert ts != ts_other

@@ -1165,7 +1166,7 @@ def test_equality_multi_discrete(self, nvec):
         ts_other = MultiCategorical(nvec=other_nvec, device=device, dtype=dtype)
         assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = MultiCategorical(nvec=nvec, device="cuda:0", dtype=dtype)
             assert ts != ts_other

@@ -1571,16 +1572,8 @@ def test_choice(self, input_type):
         choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]

         spec = Choice(choices)
-        res = spec.expand(
-            [
-                3,
-            ]
-        )
-        assert res.shape == torch.Size(
-            [
-                3,
-            ]
-        )
+        res = spec.expand([3])
+        assert res.shape == torch.Size([3])

     @pytest.mark.parametrize("shape1", [None, (), (5,)])
     @pytest.mark.parametrize("shape2", [(), (10,)])