Commit 8bcdde5

ezyang authored and pytorchmergebot committed

Support uint{16,32,64} deterministic empty fill and scalar Python binding handling (pytorch#116807)

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#116807
Approved by: https://github.com/albanD
ghstack dependencies: pytorch#116805, pytorch#116806

1 parent 43a23a7 commit 8bcdde5

6 files changed: +46 -30 lines changed

aten/src/ATen/ScalarOps.cpp (+4 -3)

@@ -1,5 +1,6 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/EmptyTensor.h>
 #include <ATen/ScalarOps.h>
 
@@ -15,10 +16,10 @@ inline void fill_inplace(Tensor& self, const Scalar& value_scalar) {
 
 namespace detail {
 Tensor& scalar_fill(Tensor& self, const Scalar& value) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-      kComplexHalf, kHalf, kBool, kBFloat16, self.scalar_type(), "fill_out", [&]() {
+  AT_DISPATCH_V2(
+      self.scalar_type(), "fill_out", AT_WRAP([&]() {
         fill_inplace<scalar_t>(self, value);
-      });
+      }), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
   return self;
 }
 
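Note: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4 can only bolt exactly four extra dtypes onto its base set, so covering uint16/32/64 there would have needed yet another _ANDn variant. AT_DISPATCH_V2 instead takes an open-ended trailing type list, which is what lets AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) splice the new unsigned types in. A minimal sketch of the v2 calling shape, assuming it compiles inside ATen where these macros are visible (fill_with_max is a made-up helper, not part of this commit):

    #include <ATen/core/Tensor.h>
    #include <ATen/Dispatch_v2.h>
    #include <limits>

    // The lambda comes first, wrapped in AT_WRAP so commas in its body
    // survive macro expansion; the dtype list trails at the end.
    void fill_with_max(at::Tensor& t) {
      AT_DISPATCH_V2(
          t.scalar_type(), "fill_with_max", AT_WRAP([&]() {
            // scalar_t is bound by the macro to the dispatched dtype
            t.fill_(std::numeric_limits<scalar_t>::max());
          }),
          at::kBool, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
    }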

aten/src/ATen/ScalarOps.h (+2 -20)

@@ -33,27 +33,9 @@ inline at::Tensor scalar_to_tensor(
     const Device device = at::kCPU) {
   // This is the fast track we have for CPU scalar tensors.
   if (device == at::kCPU) {
-    if (s.isFloatingPoint()) {
-      return at::detail::scalar_tensor_static(s, at::kDouble, at::kCPU);
-    } else if (s.isComplex()) {
-      return at::detail::scalar_tensor_static(s, at::kComplexDouble, at::kCPU);
-    } else if (s.isBoolean()) {
-      return at::detail::scalar_tensor_static(s, at::kBool, at::kCPU);
-    } else {
-      AT_ASSERT(s.isIntegral(false));
-      return at::detail::scalar_tensor_static(s, at::kLong, at::kCPU);
-    }
-  }
-  if (s.isFloatingPoint()) {
-    return at::scalar_tensor(s, at::device(device).dtype(at::kDouble));
-  } else if (s.isBoolean()) {
-    return at::scalar_tensor(s, at::device(device).dtype(at::kBool));
-  } else if (s.isComplex()) {
-    return at::scalar_tensor(s, at::device(device).dtype(at::kComplexDouble));
-  } else {
-    AT_ASSERT(s.isIntegral(false));
-    return at::scalar_tensor(s, at::device(device).dtype(at::kLong));
+    return at::detail::scalar_tensor_static(s, s.type(), at::kCPU);
   }
+  return at::scalar_tensor(s, at::device(device).dtype(s.type()));
 }
 
 } // namespace c10
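Note: every deleted branch was re-deriving something the Scalar already knows about itself. Scalar::type() reports the canonical dtype for each payload kind (kDouble for floating point, kComplexDouble for complex, kBool for boolean, kLong for integral, and now UInt64 for values only an unsigned 64-bit type can hold), so the collapsed code also picks up the new unsigned handling for free. A hedged sketch of the equivalence (dtype_of is an illustrative name, not from the commit):

    #include <c10/core/Scalar.h>

    // Under the assumption above, this matches what the deleted if/else
    // chain computed branch by branch.
    c10::ScalarType dtype_of(const c10::Scalar& s) {
      return s.type();
    }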

aten/src/ATen/native/TensorFactories.h (+4 -3)

@@ -4,6 +4,7 @@
 #include <ATen/EmptyTensor.h>
 #include <ATen/TensorIterator.h>
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/native/DispatchStub.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
@@ -107,10 +108,10 @@ inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
       tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
     });
   } else {
-    AT_DISPATCH_INTEGRAL_TYPES_AND(
-        kBool, tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
+    AT_DISPATCH_V2(
+        tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
          tensor.fill_(std::numeric_limits<scalar_t>::max());
-        });
+        }), kBool, AT_EXPAND(AT_INTEGRAL_TYPES_V2));
   }
   return tensor;
 }

test/test_torch.py (+1 -1)

@@ -1268,7 +1268,7 @@ def test_nondeterministic_resize_quantized(self, device, dtype):
 
     @skipXLA
     @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64))
     def test_deterministic_resize(self, device, dtype):
         test_cases = [
             # size, stride, resize_size
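Note: only the @dtypes decorator changes; the existing resize cases now also run for torch.uint16, torch.uint32, and torch.uint64. The behavior under test, sketched at the libtorch level (illustrative only, not from the commit; it assumes at::globalContext().setDeterministicAlgorithms and follows the fill policy above):

    #include <torch/torch.h>

    int main() {
      // With deterministic algorithms enabled, storage newly exposed by
      // resize_ is filled deterministically (dtype max for integral types).
      at::globalContext().setDeterministicAlgorithms(true, /*warn_only=*/false);
      auto t = torch::empty({0}, torch::kUInt16);
      t.resize_({4});  // each new element should read back as 65535
      return 0;
    }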

torch/csrc/utils/pybind.cpp (+5 -1)

@@ -127,7 +127,11 @@ py::handle type_caster<c10::Scalar>::cast(
     if (scalar.isSymbolic()) {
       return py::cast(scalar.toSymInt()).release();
     } else {
-      return py::cast(scalar.toLong()).release();
+      if (scalar.type() == at::ScalarType::UInt64) {
+        return py::cast(scalar.toUInt64()).release();
+      } else {
+        return py::cast(scalar.toLong()).release();
+      }
     }
   } else if (scalar.isFloatingPoint()) {
     // This isn't strictly necessary but we add it for symmetry
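Note: the new branch exists because toLong() cannot represent a uint64 payload above INT64_MAX; routing such Scalars through toUInt64() lets pybind11 build the correct Python int (Python ints are arbitrary precision, so nothing is lost on that side). A hedged round-trip sketch (roundtrip is an illustrative name, not from the commit):

    #include <c10/core/Scalar.h>
    #include <cstdint>

    // A Scalar carrying a value above INT64_MAX is tagged UInt64; asking
    // for it back as a long would overflow, so branch on the tag first.
    uint64_t roundtrip(uint64_t v) {
      c10::Scalar s(v);
      return s.type() == c10::ScalarType::UInt64
          ? s.toUInt64()
          : static_cast<uint64_t>(s.toLong());
    }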

torch/csrc/utils/python_arg_parser.cpp (+30 -2)

@@ -1649,7 +1649,21 @@ at::Tensor PythonArgs::tensor_slow(int i) {
   if (PyBool_Check(obj)) {
     scalar = at::Scalar(THPUtils_unpackBool(obj));
   } else if (THPUtils_checkLong(obj)) {
-    scalar = at::Scalar(THPUtils_unpackLong(obj));
+    int overflow = -1;
+    long long value = PyLong_AsLongLongAndOverflow(obj, &overflow);
+    if (value == -1 && PyErr_Occurred()) {
+      throw python_error();
+    }
+    if (overflow != 0) {
+      // try unsigned
+      unsigned long long value = PyLong_AsUnsignedLongLong(obj);
+      if (value == static_cast<unsigned long long>(-1) && PyErr_Occurred()) {
+        throw python_error();
+      }
+      scalar = at::Scalar(static_cast<uint64_t>(value));
+    } else {
+      scalar = at::Scalar(static_cast<int64_t>(value));
+    }
   } else if (PyComplex_Check(obj)) {
     scalar = at::Scalar(THPUtils_unpackComplexDouble(obj));
   } else if (THPUtils_checkDouble(obj)) {
@@ -1712,7 +1726,21 @@ at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
   }
 
   if (THPUtils_checkLong(arg)) {
-    return at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(arg)));
+    int overflow = -1;
+    long long value = PyLong_AsLongLongAndOverflow(arg, &overflow);
+    if (value == -1 && PyErr_Occurred()) {
+      throw python_error();
+    }
+    if (overflow != 0) {
+      // try unsigned
+      unsigned long long value = PyLong_AsUnsignedLongLong(arg);
+      if (value == static_cast<unsigned long long>(-1) && PyErr_Occurred()) {
+        throw python_error();
+      }
+      return at::Scalar(static_cast<uint64_t>(value));
+    } else {
+      return at::Scalar(static_cast<int64_t>(value));
+    }
   }
 
   if (PyBool_Check(arg)) {
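Note: both hunks apply the same CPython idiom. PyLong_AsLongLongAndOverflow converts without raising on overflow, reporting it through the out-parameter instead; only then does the code retry with PyLong_AsUnsignedLongLong, so every Python int in [INT64_MIN, UINT64_MAX] unpacks losslessly while anything outside that range still raises. (The inner unsigned long long value shadows the signed one.) A standalone sketch of the idiom, assuming only Python.h (unpack_py_int is an illustrative name, not PyTorch code):

    #include <Python.h>
    #include <cstdint>
    #include <stdexcept>

    // Signed fast path first; on overflow, retry unsigned. A value below
    // INT64_MIN fails the unsigned retry too and surfaces as an error.
    uint64_t unpack_py_int(PyObject* obj, bool* is_unsigned) {
      int overflow = 0;
      long long v = PyLong_AsLongLongAndOverflow(obj, &overflow);
      if (v == -1 && PyErr_Occurred()) {
        throw std::runtime_error("expected a Python int");
      }
      if (overflow != 0) {
        unsigned long long u = PyLong_AsUnsignedLongLong(obj);
        if (u == static_cast<unsigned long long>(-1) && PyErr_Occurred()) {
          throw std::runtime_error("Python int out of 64-bit range");
        }
        *is_unsigned = true;
        return static_cast<uint64_t>(u);
      }
      *is_unsigned = false;
      return static_cast<uint64_t>(v);  // caller reinterprets as int64_t
    }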
