
Commit 98654dc

Update
[ghstack-poisoned]
1 parent 157de34 commit 98654dc

File tree: 4 files changed, +19 -26 lines

benchmarks/ecosystem/gym_env_throughput.py (+1 -1)

@@ -31,7 +31,7 @@

 if __name__ == "__main__":
     avail_devices = ("cpu",)
-    if torch.cuda.device_count():
+    if torch.cuda.is_available():
         avail_devices = avail_devices + ("cuda:0",)

     for envname in [
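
The guard changes from torch.cuda.device_count() to torch.cuda.is_available() throughout this commit. Both are falsy on a machine without usable GPUs, but is_available() is the documented boolean check, while device_count() reads as a count rather than a flag. A minimal standalone sketch of the pattern, assuming only a stock PyTorch install:

import torch

# is_available() answers "can this process use CUDA at all?";
# device_count() answers "how many CUDA devices are visible?".
use_cuda = torch.cuda.is_available()
n_gpus = torch.cuda.device_count()

# Mirror of the guard used in this commit: CPU always, first GPU when CUDA is usable.
avail_devices = ("cpu",)
if use_cuda:
    avail_devices = avail_devices + ("cuda:0",)

print(f"CUDA usable: {use_cuda}, visible GPUs: {n_gpus}, devices: {avail_devices}")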

examples/distributed/collectors/multi_nodes/rpc.py (+1 -1)

@@ -74,7 +74,7 @@
         "slurm_gpus_per_task": args.slurm_gpus_per_task,
     }
     device_str = "device" if num_workers <= 1 else "devices"
-    if torch.cuda.device_count():
+    if torch.cuda.is_available():
         collector_kwargs = {device_str: "cuda:0", f"storing_{device_str}": "cuda:0"}
     else:
         collector_kwargs = {device_str: "cpu", "storing_{device_str}": "cpu"}

test/test_specs.py (+16 -23)

@@ -724,9 +724,10 @@ def test_choice(self, input_type):
         choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]
         example_in = NonTensorStack("a", "b", "c")
         example_out = NonTensorStack("a", "c", "b")
-
-        spec = Choice(choices)
-        res = spec.rand()
+        torch.manual_seed(0)
+        for _ in range(10):
+            spec = Choice(choices)
+            res = spec.rand()
         assert spec.is_in(res)
         assert spec.is_in(example_in)
         assert not spec.is_in(example_out)

@@ -755,7 +756,7 @@ def test_errors(self):
             ]
         )

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             with pytest.raises(ValueError, match="must have the same device"):
                 Choice(
                     [

@@ -914,7 +915,7 @@ def test_equality_bounded(self):

         ts_other = Bounded(minimum, maximum + 1, torch.Size((1,)), device, dtype)
         assert ts != ts_other
-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = Bounded(minimum, maximum, torch.Size((1,)), "cuda:0", dtype)
             assert ts != ts_other

@@ -942,7 +943,7 @@ def test_equality_onehot(self):
         )
         assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = OneHot(
                 n=n, device="cuda:0", dtype=dtype, use_register=use_register
             )

@@ -972,7 +973,7 @@ def test_equality_unbounded(self):
         ts_same = Unbounded(device=device, dtype=dtype)
         assert ts == ts_same

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = Unbounded(device="cuda:0", dtype=dtype)
             assert ts != ts_other

@@ -1005,7 +1006,7 @@ def test_equality_ndbounded(self):
         ts_other = Bounded(low=minimum, high=maximum + 1, device=device, dtype=dtype)
         assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = Bounded(low=minimum, high=maximum, device="cuda:0", dtype=dtype)
             assert ts != ts_other

@@ -1033,7 +1034,7 @@ def test_equality_discrete(self):
         ts_other = Categorical(n=n + 1, shape=shape, device=device, dtype=dtype)
         assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = Categorical(n=n, shape=shape, device="cuda:0", dtype=dtype)
             assert ts != ts_other

@@ -1071,7 +1072,7 @@ def test_equality_ndunbounded(self, shape):
         ts_other = Unbounded(shape=other_shape, device=device, dtype=dtype)
         assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = Unbounded(shape=shape, device="cuda:0", dtype=dtype)
             assert ts != ts_other

@@ -1097,7 +1098,7 @@ def test_equality_binary(self):
         ts_other = Binary(n=n + 5, device=device, dtype=dtype)
         assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = Binary(n=n, device="cuda:0", dtype=dtype)
             assert ts != ts_other

@@ -1131,7 +1132,7 @@ def test_equality_multi_onehot(self, nvec):
         ts_other = MultiOneHot(nvec=other_nvec, device=device, dtype=dtype)
         assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = MultiOneHot(nvec=nvec, device="cuda:0", dtype=dtype)
             assert ts != ts_other

@@ -1165,7 +1166,7 @@ def test_equality_multi_discrete(self, nvec):
         ts_other = MultiCategorical(nvec=other_nvec, device=device, dtype=dtype)
         assert ts != ts_other

-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             ts_other = MultiCategorical(nvec=nvec, device="cuda:0", dtype=dtype)
             assert ts != ts_other

@@ -1571,16 +1572,8 @@ def test_choice(self, input_type):
         choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]

         spec = Choice(choices)
-        res = spec.expand(
-            [
-                3,
-            ]
-        )
-        assert res.shape == torch.Size(
-            [
-                3,
-            ]
-        )
+        res = spec.expand([3])
+        assert res.shape == torch.Size([3])

     @pytest.mark.parametrize("shape1", [None, (), (5,)])
     @pytest.mark.parametrize("shape2", [(), (10,)])
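
The test_choice change seeds the RNG and draws several samples so rand() is exercised over more than one pick from the candidate stacks, and the expand test collapses to single-line calls. A rough standalone sketch of the pattern exercised above; the import paths are assumptions (Choice from torchrl.data, NonTensorStack from tensordict), not verified against the test module's imports:

import torch
from tensordict import NonTensorStack
from torchrl.data import Choice

torch.manual_seed(0)
choices = [NonTensorStack("a", "b", "c"), NonTensorStack("d", "e", "f")]

for _ in range(10):
    spec = Choice(choices)
    res = spec.rand()       # picks one of the candidate stacks at random
    assert spec.is_in(res)  # every draw must be one of the declared choices

# expand([3]) yields a spec with a leading batch dimension of 3, as asserted in the test.
assert Choice(choices).expand([3]).shape == torch.Size([3])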

torchrl/collectors/distributed/rpc.py (+1 -1)

@@ -449,7 +449,7 @@ def _init_master_rpc(
     ):
         """Init RPC on main node."""
         options = rpc.TensorPipeRpcBackendOptions(**self.tensorpipe_options)
-        if torch.cuda.device_count():
+        if torch.cuda.is_available():
             if self.visible_devices:
                 for i in range(self.num_workers):
                     rank = i + 1
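
In _init_master_rpc, the guard decides whether TensorPipe device maps are populated for the workers. A hedged sketch of what that setup looks like with only public torch.distributed.rpc calls; the option values, worker count, and worker names are illustrative, not TorchRL's exact ones:

import torch
from torch.distributed import rpc

options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16)
num_workers = 2  # illustrative

if torch.cuda.is_available():
    # Map the master's cuda:0 onto each worker's cuda:0 so tensors sent over RPC
    # avoid an explicit CPU round-trip. Worker names here are illustrative.
    for i in range(num_workers):
        rank = i + 1
        options.set_device_map(f"worker{rank}", {0: 0})

# Initialization would then look like (left commented: it blocks until peers join):
# rpc.init_rpc("master", rank=0, world_size=num_workers + 1, rpc_backend_options=options)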
