Skip to content

Commit b6648c1

Browse files
Michael Antonovfacebook-github-bot
Michael Antonov
authored andcommitted
Update ATen internals to use int64_t for dimension indexing (pytorch#16739)
Summary: Pull Request resolved: pytorch#16739 Some code ATen locations seemed to use int, etc. inclorrectly where either int64_t or size_t was required. Update them to use int64_t for dimension indexing where necessary. Reviewed By: ezyang Differential Revision: D13950124 fbshipit-source-id: aaf1cef783bf3c657aa03490f2616c35c816679f
1 parent 1aa9019 commit b6648c1

File tree

7 files changed

+23
-22
lines changed

7 files changed

+23
-22
lines changed

aten/src/ATen/ExpandUtils.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,18 @@
33
namespace at {
44

55
std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b) {
6-
auto dimsA = a.size();
7-
auto dimsB = b.size();
8-
ptrdiff_t ndim = dimsA > dimsB ? dimsA : dimsB;
6+
size_t dimsA = a.size();
7+
size_t dimsB = b.size();
8+
size_t ndim = dimsA > dimsB ? dimsA : dimsB;
99
std::vector<int64_t> expandedSizes(ndim);
1010

11-
for (long i = ndim - 1; i >= 0; --i) {
12-
long offset = ndim - 1 - i;
13-
long dimA = dimsA - 1 - offset;
14-
long dimB = dimsB - 1 - offset;
15-
long sizeA = (dimA >= 0) ? a[dimA] : 1;
16-
long sizeB = (dimB >= 0) ? b[dimB] : 1;
11+
// Use ptrdiff_t to ensure signed comparison.
12+
for (ptrdiff_t i = (ptrdiff_t)ndim - 1; i >= 0; --i) {
13+
ptrdiff_t offset = ndim - 1 - i;
14+
ptrdiff_t dimA = dimsA - 1 - offset;
15+
ptrdiff_t dimB = dimsB - 1 - offset;
16+
int64_t sizeA = (dimA >= 0) ? a[dimA] : 1;
17+
int64_t sizeB = (dimB >= 0) ? b[dimB] : 1;
1718

1819
AT_CHECK(
1920
sizeA == sizeB || sizeA == 1 || sizeB == 1,

aten/src/ATen/ExpandUtils.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,12 @@ static inline Tensor sum_to(Tensor tensor, const IntArrayRef shape) {
159159

160160
// True if `shape` can be broadcasted to `desired`
161161
static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
162-
int ndim = shape.size();
163-
int target_dim = desired.size();
162+
size_t ndim = shape.size();
163+
size_t target_dim = desired.size();
164164
if (ndim > target_dim) {
165165
return false;
166166
}
167-
for (int i = 0; i < ndim; i++) {
167+
for (size_t i = 0; i < ndim; i++) {
168168
int64_t size = shape[ndim - i - 1];
169169
int64_t target = desired[target_dim - i - 1];
170170
if (size != target && size != 1) {

aten/src/ATen/SparseTensorImpl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,15 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl {
106106
bool shrinking_dense_dim = false;
107107
auto sparse_size_original = sizes().slice(0, sparse_dim);
108108
auto sparse_size_new = size.slice(0, sparse_dim);
109-
for (int i = 0; i < sparse_dim; i++) {
109+
for (int64_t i = 0; i < sparse_dim; i++) {
110110
if (sparse_size_new[i] < sparse_size_original[i]) {
111111
shrinking_sparse_dims = true;
112112
break;
113113
}
114114
}
115115
auto dense_size_original = sizes().slice(sparse_dim);
116116
auto dense_size_new = size.slice(sparse_dim);
117-
for (int i = 0; i < dense_dim; i++) {
117+
for (int64_t i = 0; i < dense_dim; i++) {
118118
if (dense_size_new[i] < dense_size_original[i]) {
119119
shrinking_dense_dim = true;
120120
break;

aten/src/ATen/core/type.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
2525
out << ")";
2626
} else if (auto value = t.cast<DimensionedTensorType>()) {
2727
out << toString(value->scalarType()) << "(";
28-
for (int i = 0; i < value->dim(); ++i) {
28+
for (int64_t i = 0; i < value->dim(); ++i) {
2929
if (i > 0) {
3030
out << ", ";
3131
}

aten/src/ATen/native/Indexing.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ static Tensor reshape_indexer(const Tensor& index, int64_t dims_before, int64_t
326326
AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
327327
{
328328
int64_t element_size_bytes = src.type().elementSizeInBytes();
329-
int dims_before = 0, dims_after = 0, dims_indexed = 0;
329+
int64_t dims_before = 0, dims_after = 0, dims_indexed = 0;
330330
IntArrayRef replacement_shape;
331331
for (size_t dim = 0; dim < indices_list.size(); dim++) {
332332
if (!indices_list[dim].defined()) {

aten/src/ATen/native/ReduceOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dty
3838

3939
using DimMask = TensorIterator::DimMask;
4040

41-
static DimMask make_dim_mask(IntArrayRef dims, int ndim) {
41+
static DimMask make_dim_mask(IntArrayRef dims, int64_t ndim) {
4242
auto mask = DimMask();
4343
if (dims.empty()) {
4444
mask.flip();
4545
} else {
46-
for (int dim : dims) {
46+
for (int64_t dim : dims) {
4747
mask.set(maybe_wrap_dim(dim, ndim));
4848
}
4949
}
@@ -98,7 +98,7 @@ static std::unique_ptr<TensorIterator> make_reduction(
9898
" and ",
9999
toString(dtype),
100100
".");
101-
int ndim = self.dim();
101+
int64_t ndim = self.dim();
102102
auto mask = make_dim_mask(dim, ndim);
103103
allocate_reduction_result(result, self, mask, keepdim, dtype);
104104
auto viewed_result = review_reduce_result(result, ndim, mask, keepdim);

aten/src/ATen/native/TensorIterator.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ void TensorIterator::reorder_dimensions() {
2222

2323
// returns 1 if the dim0 should come after dim1, -1 if dim0 should come
2424
// before dim1, and 0 if the comparison is ambiguous.
25-
auto should_swap = [&](int dim0, int dim1) {
25+
auto should_swap = [&](size_t dim0, size_t dim1) {
2626
int ret = 0;
2727
for (int arg = 0; arg < ntensors(); arg++) {
2828
if (operands_[arg].stride_bytes.empty()) {
2929
continue;
3030
}
31-
int stride0 = operands_[arg].stride_bytes[dim0];
32-
int stride1 = operands_[arg].stride_bytes[dim1];
31+
int64_t stride0 = operands_[arg].stride_bytes[dim0];
32+
int64_t stride1 = operands_[arg].stride_bytes[dim1];
3333
if (operands_[arg].is_output) {
3434
// move reduced dimensions to the front
3535
if ((stride0 == 0) != (stride1 == 0)) {

0 commit comments

Comments
 (0)