Skip to content

Commit 6c7a831

Browse files
authored
Fix Tensor.type(dtype) not preserving device (pytorch#7474)
Note that Tensor.cuda() will still copy the tensor to the current device if it's a CUDA tensor on a different device. Fixes pytorch#7441
1 parent 43264c3 commit 6c7a831

File tree

4 files changed

+30
-12
lines changed

4 files changed

+30
-12
lines changed

test/test_cuda.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,8 @@ def test_type_conversions(self):
817817
def test_type_conversions_same_gpu(self):
818818
x = torch.randn(5, 5).cuda(1)
819819
self.assertEqual(x.int().get_device(), 1)
820+
self.assertEqual(x.type(torch.int).get_device(), 1)
821+
self.assertEqual(x.to(torch.int).get_device(), 1)
820822

821823
def test_neg(self):
822824
TestTorch._test_neg(self, lambda t: t.cuda())

tools/autograd/templates/python_variable_methods.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa
628628
auto self_device_type = torch::getDeviceType(self_.type());
629629
auto& type = is_dtype ? torch::getType(r.scalartype(0), *torch::getLayout(self_.type().backend()), self_device_type) :
630630
torch::utils::type_from_string(type_name);
631-
return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type, -1, r.toBool(1)));
631+
return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type, at::nullopt, r.toBool(1)));
632632
END_HANDLE_TH_ERRORS
633633
}
634634

torch/csrc/utils/tensor_conversion_dispatch.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@
88

99
namespace torch { namespace utils {
1010

11-
at::Tensor dispatch_type_conversion(const at::Tensor & self, const at::Type & type) {
12-
int64_t device = self.is_cuda() ? self.get_device() : -1;
13-
return dispatch_type_conversion(self, type, device, false);
14-
}
15-
16-
at::Tensor dispatch_type_conversion(const at::Tensor & self, const at::Type & type,
17-
int device, bool non_blocking) {
11+
at::Tensor dispatch_type_conversion(
12+
const at::Tensor & self,
13+
const at::Type & type,
14+
at::optional<int> device,
15+
bool non_blocking) {
1816
if (type.is_cuda()) {
1917
torch::utils::cuda_lazy_init();
2018
}
2119
AutoNoGIL no_gil;
22-
AutoGPU auto_gpu(device);
20+
2321
int64_t tensor_device = self.is_cuda() ? self.get_device() : -1;
22+
AutoGPU auto_gpu(device.value_or(tensor_device));
23+
2424
if (self.is_cuda() && type.is_cuda() && tensor_device != at::current_device()) {
2525
// copy if the devices are different even if the types are the same
2626
return type.copy(self, non_blocking);

torch/csrc/utils/tensor_conversion_dispatch.h

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,24 @@
66

77
namespace torch { namespace utils {
88

9-
at::Tensor dispatch_type_conversion(const at::Tensor & self, const at::Type & type);
10-
at::Tensor dispatch_type_conversion(const at::Tensor & self, const at::Type & type,
11-
int device, bool non_blocking);
9+
// Returns a tensor with the same data as `self` and the specified type and
10+
// device. Returns `self` unmodified if neither the type nor device change;
11+
// otherwise a copy is made.
12+
//
13+
// The `device` argument is only relevant if `type` is a CUDA type. There are
14+
// a few special cases for device:
15+
//
16+
// - if device is -1 then the returned tensor will be on the current device
17+
// - if device is nullopt then the returned tensor will be on the same device
18+
// as `self` if possible; otherwise it will be on the current device.
19+
//
20+
// If `non_blocking` is true, then the copy may be performed asynchronously
21+
// w.r.t the host if `self` is a CPU tensor in pinned memory and `type` is a
22+
// CUDA type. Note that copies between CUDA devices are always asynchronous
23+
// w.r.t the host.
24+
at::Tensor dispatch_type_conversion(const at::Tensor & self,
25+
const at::Type & type,
26+
at::optional<int> device=at::nullopt,
27+
bool non_blocking=false);
1228

1329
}} // namespace torch::utils

0 commit comments

Comments
 (0)