Commit 46162cc

SsnL authored and facebook-github-bot committed
Autograd indices/values and sparse_coo ctor (pytorch#13001)
Summary: Reopen of pytorch#11253 after fixing a bug in index_select.

Pull Request resolved: pytorch#13001
Differential Revision: D10514987
Pulled By: SsnL
fbshipit-source-id: 399a83a1d3246877a3523baf99aaf1ce8066f33f
1 parent e0f21a4 commit 46162cc


Showing 46 changed files with 1,462 additions and 995 deletions.
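As context for the diffs below (this note is not part of the commit page): the change makes the sparse COO constructor, and the new indices()/values() accessors, participate in autograd. A minimal usage sketch, assuming a PyTorch build that includes this change and using only documented APIs; the tensors are illustrative.

import torch

# Hedged sketch (not from the diff): differentiate through the sparse COO constructor.
i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])                          # sparse_dim x nnz indices
v = torch.tensor([3.0, 4.0, 5.0], requires_grad=True)  # nnz values (dense_dim = 0)
s = torch.sparse_coo_tensor(i, v, (2, 3))

# Gradients flow from the dense result back to the dense `v` used to build `s`.
loss = s.to_dense().sum()
loss.backward()
print(v.grad)   # tensor([1., 1., 1.])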

aten/src/ATen/Declarations.cwrap (+3 -1)

@@ -3266,7 +3266,9 @@
     name: alias
     return: THTensor*
     cpu_half: True
-    variants: [function]
+    variants:
+      - method
+      - function
     options:
       - cname: newWithTensor
         arguments:
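A side note (not part of the diff): exposing `alias` as a method variant, together with the `Tensor alias() const;` declaration added in Tensor.h below, presumably lets the new indices()/values() accessors hand back views that share storage rather than copies. A hedged sketch of the aliasing semantics, assuming a recent torch build where the aten op is reachable via torch.ops:

import torch

t = torch.arange(6.0)
a = torch.ops.aten.alias(t)   # view of t: same storage, no copy
a[0] = 42.0
print(t[0])                   # tensor(42.) -- the write through the alias is visible in t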

aten/src/ATen/SparseTensorImpl.cpp (+16 -11)

@@ -22,18 +22,18 @@ namespace {
 // a scalar and have one element)
 //
 // Thus, an empty sparse tensor should be a 1-dimensional tensor of size [0].
-// Furthermore, we have dim == sparseDims + denseDims; since this is a sparse
-// tensor, let us say that an empty sparse tensor has sparseDims == 1 and
-// denseDims == 0. (There is a degree of freedom here, but given that this
-// is a sparse dimension, it seems reasonable to demand that sparseDims > 0).
+// Furthermore, we have dim == sparse_dim + dense_dim; since this is a sparse
+// tensor, let us say that an empty sparse tensor has sparse_dim == 1 and
+// dense_dim == 0. (There is a degree of freedom here, but given that this
+// is a sparse dimension, it seems reasonable to demand that sparse_dim > 0).
 //
 // This means that we allocate a [1,0] size indices tensor and a [0] size
 // values tensor for such an empty tensor.
 SparseTensorImpl::SparseTensorImpl(at::TensorTypeId type_id, const caffe2::TypeMeta& data_type)
     : TensorImpl(type_id, data_type, nullptr, false)
     , size_{0}
-    , sparseDims_(1)
-    , denseDims_(0)
+    , sparse_dim_(1)
+    , dense_dim_(0)
     , indices_(at::empty({1, 0}, at::initialTensorOptions().device(sparseTensorIdToDeviceType(type_id)).dtype(ScalarType::Long)))
     , values_(at::empty({0}, at::initialTensorOptions().device(sparseTensorIdToDeviceType(type_id)).dtype(dataTypeToScalarType(data_type.id())))) {}

@@ -67,7 +67,7 @@ void SparseTensorImpl::set_storage_offset(int64_t storage_offset) {
 }

 int64_t SparseTensorImpl::dim() const {
-  return sparseDims_ + denseDims_;
+  return sparse_dim_ + dense_dim_;
 }
 TensorImpl* SparseTensorImpl::maybe_zero_dim(bool condition_when_zero_dim) {
   AT_CHECK(condition_when_zero_dim == (dim() == 0),
@@ -83,17 +83,22 @@ int64_t SparseTensorImpl::storage_offset() const {
   AT_ERROR("sparse tensors do not have storage");
 }
 void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, const Tensor& values) {
+  AT_ASSERT(!indices.is_variable() && !values.is_variable());  // They should be plain tensors!
+
+  AT_CHECK(!indices.is_sparse(), "expected indices to be a dense tensor, but got indices of layout ", indices.layout());
+  AT_CHECK(!values.is_sparse(), "expected values to be a dense tensor, but got values of layout ", values.layout());
+
   AT_CHECK(values.type().toSparse() == type(), "values type must match sparse tensor type");
   AT_CHECK(indices.type().scalarType() == kLong, "indices must be an int64 tensor");
   AT_CHECK(indices.type().backend() == values.type().backend(), "backend of indices (", indices.type().backend(), ") must match backend of values (", values.type().backend(), ")");
   AT_CHECK(!indices.is_cuda() || indices.get_device() == values.get_device(), "device of indices (", indices.get_device(), ") must match device of values (", values.get_device(), ")");

-  AT_CHECK(indices.dim() == 2, "indices must be nDim x nnz, but got: ", indices.sizes());
+  AT_CHECK(indices.dim() == 2, "indices must be sparse_dim x nnz, but got: ", indices.sizes());
   AT_CHECK(indices.size(1) == values.size(0), "indices and values must have same nnz, but got nnz from indices: ", indices.size(1), ", nnz from values: ", values.size(0));
-  AT_CHECK(indices.size(0) == sparseDims_, "indices has incorrect first dimension, expected ", sparseDims_, ", got ", indices.size(0));
-  AT_CHECK(values.dim() == denseDims_ + 1, "values has incorrect number of dimensions, expected ", denseDims_ + 1, ", got ", values.dim());
+  AT_CHECK(indices.size(0) == sparse_dim_, "indices has incorrect first dimension, expected ", sparse_dim_, ", got ", indices.size(0));
+  AT_CHECK(values.dim() == dense_dim_ + 1, "values has incorrect number of dimensions, expected ", dense_dim_ + 1, ", got ", values.dim());

-  auto dense_size_original = sizes().slice(sparseDims_);
+  auto dense_size_original = sizes().slice(sparse_dim_);
   std::vector<int64_t> expected_values_size_vec = {values.size(0)};
   expected_values_size_vec.insert(expected_values_size_vec.end(), dense_size_original.begin(), dense_size_original.end());
   IntList expected_values_size(expected_values_size_vec);

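The checks above encode the COO invariants: indices must be a 2-D LongTensor of shape (sparse_dim, nnz), and values must have dense_dim + 1 dimensions, shaped (nnz, followed by the dense part of the size). A hedged Python sketch of a hybrid sparse tensor (sparse_dim = 2, dense_dim = 1) that satisfies these checks; the example is illustrative and not taken from the commit, and it assumes the Python-level sparse_dim()/dense_dim() accessors this change introduces.

import torch

# 2 sparse dims over a 3x4 grid; each nonzero entry carries a length-5 dense vector.
i = torch.tensor([[0, 2],
                  [1, 3]])          # shape (sparse_dim, nnz) = (2, 2)
v = torch.randn(2, 5)               # shape (nnz, dense part) = (2, 5)
s = torch.sparse_coo_tensor(i, v, (3, 4, 5))

print(s.sparse_dim(), s.dense_dim())   # 2 1
print(s._indices().shape)              # torch.Size([2, 2])
print(s._values().shape)               # torch.Size([2, 5])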
aten/src/ATen/SparseTensorImpl.h (+46 -47)

@@ -9,18 +9,18 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl {
   // Stored in COO format, indices + values.

   // INVARIANTS:
-  // _sparseDims: range [0, len(shape)]; _sparseDims + _denseDims = len(shape)
-  // _denseDims : range [0, len(shape)]; _sparseDims + _denseDims = len(shape)
-  // _indices.shape: dimensionality: 2, shape: (_sparseDims, nnz)
-  // _values.shape: dimensionality: 1 + _denseDims. shape: (nnz, shape[_sparseDims:])
+  // sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
+  // dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
+  // _indices.shape: dimensionality: 2, shape: (sparse_dim, nnz)
+  // _values.shape: dimensionality: 1 + dense_dim. shape: (nnz, shape[sparse_dim:])

   // The true size of the sparse tensor (e.g., if you called to_dense()
   // on it). When THTensor merges into TensorImpl, this field
   // should move to the parent class.
   std::vector<int64_t> size_;

-  int64_t sparseDims_ = 0; // number of sparse dimensions
-  int64_t denseDims_ = 0;  // number of dense dimensions
+  int64_t sparse_dim_ = 0; // number of sparse dimensions
+  int64_t dense_dim_ = 0;  // number of dense dimensions

   Tensor indices_; // always a LongTensor
   Tensor values_;
@@ -39,8 +39,8 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl {
   explicit SparseTensorImpl(at::TensorTypeId, const caffe2::TypeMeta&);

   int64_t nnz() const { return values_.size(0); }
-  int64_t sparseDims() const { return sparseDims_; }
-  int64_t denseDims() const { return denseDims_; }
+  int64_t sparse_dim() const { return sparse_dim_; }
+  int64_t dense_dim() const { return dense_dim_; }
   bool coalesced() const { return coalesced_; }
   Tensor indices() const { return indices_; }
   Tensor values() const { return values_; }
@@ -60,16 +60,16 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl {
   const Storage& storage() const override;
   int64_t storage_offset() const override;

-  // WARNING: This function does NOT preserve invariants of sparseDims/denseDims with
+  // WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim with
   // respect to indices and values
-  void raw_resize_(int64_t sparseDims, int64_t denseDims, IntList size) {
+  void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntList size) {
     size_ = size.vec();
-    sparseDims_ = sparseDims;
-    denseDims_ = denseDims;
+    sparse_dim_ = sparse_dim;
+    dense_dim_ = dense_dim;
     refresh_numel();
   }

-  // NOTE: This function preserves invariants of sparseDims/denseDims with respect to
+  // NOTE: This function preserves invariants of sparse_dim/dense_dim with respect to
   // indices and values.
   //
   // NOTE: This function supports the following cases:
@@ -91,75 +91,73 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl {
   //    and for API consistency we don't support it).
   // 4. When we attempt to shrink the size of any of the sparse dimensions on a non-empty sparse tensor
   //    (this could make some of the stored indices out-of-bound and thus unsafe).
-  void resize_(int64_t sparseDims, int64_t denseDims, IntList size) {
-    AT_CHECK(sparseDims + denseDims == size.size(), "number of dimensions must be sparseDims (", sparseDims, ") + denseDims (", denseDims, "), but got ", size.size());
+  void resize_(int64_t sparse_dim, int64_t dense_dim, IntList size) {
+    AT_CHECK(sparse_dim + dense_dim == size.size(), "number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size());
     if (nnz() > 0) {
       auto alt_options_msg = "You could try the following options:\n\
-1. If you need an empty sparse tensor of this size, call `x=torch.sparse_coo_tensor(size)`.\n\
+1. If you need an empty sparse tensor of this size, call `x = torch.sparse_coo_tensor(size)`.\n\
 2. If you need to resize this tensor, you have the following options:\n\
     1. For both sparse and dense dimensions, keep the number of them constant and the size of them non-shrinking, and then try the same call again.\n\
     2. Or, create a new sparse tensor with the correct indices and values from this sparse tensor.";

-      AT_CHECK(sparseDims == sparseDims_,
-        "changing the number of sparse dimensions (from ", sparseDims_, " to ", sparseDims, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg);
+      AT_CHECK(sparse_dim == sparse_dim_,
+        "changing the number of sparse dimensions (from ", sparse_dim_, " to ", sparse_dim, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg);

-      AT_CHECK(denseDims == denseDims_,
-        "changing the number of dense dimensions (from ", denseDims_, " to ", denseDims, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg);
+      AT_CHECK(dense_dim == dense_dim_,
+        "changing the number of dense dimensions (from ", dense_dim_, " to ", dense_dim, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg);

       bool shrinking_sparse_dims = false;
-      bool shrinking_dense_dims = false;
-      auto sparse_size_original = sizes().slice(0, sparseDims);
-      auto sparse_size_new = size.slice(0, sparseDims);
-      for (int i = 0; i < sparseDims; i++) {
+      bool shrinking_dense_dim = false;
+      auto sparse_size_original = sizes().slice(0, sparse_dim);
+      auto sparse_size_new = size.slice(0, sparse_dim);
+      for (int i = 0; i < sparse_dim; i++) {
         if (sparse_size_new[i] < sparse_size_original[i]) {
           shrinking_sparse_dims = true;
           break;
         }
       }
-      auto dense_size_original = sizes().slice(sparseDims);
-      auto dense_size_new = size.slice(sparseDims);
-      for (int i = 0; i < denseDims; i++) {
+      auto dense_size_original = sizes().slice(sparse_dim);
+      auto dense_size_new = size.slice(sparse_dim);
+      for (int i = 0; i < dense_dim; i++) {
         if (dense_size_new[i] < dense_size_original[i]) {
-          shrinking_dense_dims = true;
+          shrinking_dense_dim = true;
           break;
         }
       }

       AT_CHECK(!shrinking_sparse_dims,
         "shrinking the size of sparse dimensions (from ", sparse_size_original, " to ", sparse_size_new, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg);

-      AT_CHECK(!shrinking_dense_dims,
+      AT_CHECK(!shrinking_dense_dim,
         "shrinking the size of dense dimensions (from ", dense_size_original, " to ", dense_size_new, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg);
     }

-    if ((!size.equals(size_)) || (sparseDims != sparseDims_) || (denseDims != denseDims_)) {
-      std::vector<int64_t> values_size = {values().size(0)};
-      auto dense_size = size.slice(sparseDims);
+    if ((!size.equals(size_)) || (sparse_dim != sparse_dim_) || (dense_dim != dense_dim_)) {
+      auto nnz = values().size(0);
+      std::vector<int64_t> values_size = {nnz};
+      auto dense_size = size.slice(sparse_dim);
       values_size.insert(values_size.end(), dense_size.begin(), dense_size.end());
       values_.resize_(values_size);
-
-      std::vector<int64_t> indices_size = indices().sizes().vec();
-      indices_size[0] = sparseDims;
-      indices_.resize_(indices_size);
+      indices_.resize_({sparse_dim, nnz});
     }

     size_ = size.vec();
-    sparseDims_ = sparseDims;
-    denseDims_ = denseDims;
+    sparse_dim_ = sparse_dim;
+    dense_dim_ = dense_dim;
     refresh_numel();
   }

   // NOTE: this function will resize the sparse tensor and also set `indices` and `values` to empty.
-  void resize_and_clear_(int64_t sparseDims, int64_t denseDims, IntList size) {
-    AT_CHECK(sparseDims + denseDims == size.size(), "number of dimensions must be sparseDims (", sparseDims, ") + denseDims (", denseDims, "), but got ", size.size());
+  void resize_and_clear_(int64_t sparse_dim, int64_t dense_dim, IntList size) {
+    AT_CHECK(sparse_dim + dense_dim == size.size(), "number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size());

     size_ = size.vec();
-    sparseDims_ = sparseDims;
-    denseDims_ = denseDims;
+    sparse_dim_ = sparse_dim;
+    dense_dim_ = dense_dim;

-    auto empty_indices = at::empty({sparseDims, 0}, indices().options());
+    auto empty_indices = at::empty({sparse_dim, 0}, indices().options());
     std::vector<int64_t> values_size = {0};
-    auto dense_size = sizes().slice(sparseDims);
+    auto dense_size = sizes().slice(sparse_dim);
     values_size.insert(values_size.end(), dense_size.begin(), dense_size.end());
     auto empty_values = at::empty(values_size, values().options());
     set_indices_and_values_unsafe(empty_indices, empty_values);
@@ -169,9 +167,10 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl {
   void set_coalesced(bool coalesced) { coalesced_ = coalesced; }

   // NOTE: this function is only used internally and not exposed to Python frontend
-  void set_nnz_and_narrow(int64_t nnz) {
-    indices_ = indices_.narrow(1, 0, nnz);
-    values_ = values_.narrow(0, 0, nnz);
+  void set_nnz_and_narrow(int64_t new_nnz) {
+    AT_ASSERT(new_nnz <= nnz());
+    indices_ = indices_.narrow(1, 0, new_nnz);
+    values_ = values_.narrow(0, 0, new_nnz);
   }

   // Takes indices and values and directly puts them into the sparse tensor, no copy.

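The resize_ rules above (keep sparse_dim/dense_dim fixed and never shrink a dimension once nnz > 0) surface in Python through Tensor.sparse_resize_. A hedged sketch, assuming current torch behavior; it is illustrative and not taken from the commit.

import torch

i = torch.tensor([[0, 1], [1, 0]])
v = torch.tensor([1.0, 2.0])
s = torch.sparse_coo_tensor(i, v, (2, 2))

# Growing the sizes while keeping sparse_dim=2 and dense_dim=0 is allowed.
s.sparse_resize_((4, 4), 2, 0)
print(s.shape)    # torch.Size([4, 4])

# Shrinking a sparse dimension of a non-empty sparse tensor is refused,
# as the checks above describe.
try:
    s.sparse_resize_((1, 4), 2, 0)
except RuntimeError as e:
    print("refused:", str(e).splitlines()[0])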
aten/src/ATen/SparseTensorUtils.h (new file, +111)

@@ -0,0 +1,111 @@
+#include <ATen/ATen.h>
+#include <ATen/SparseTensorImpl.h>
+
+namespace at { namespace sparse {
+
+// Just for documentary purposes
+using SparseTensor = Tensor;
+using LongTensor = Tensor;
+using IntTensor = Tensor;
+using SparseType = Type;
+
+// This is an internal utility function for getting at the SparseTensorImpl,
+// so that we can write sparse tensor specific accessors for special fields
+// in SparseTensor. You should only use this for writing low level
+// setters/getters for SparseTensorImpl fields; otherwise, you should use
+// the low level setters/getters that were implemented using this.
+//
+// This may be called repeatedly, so make sure it's pretty cheap.
+inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) {
+  AT_ASSERTM(!self.is_variable(), "_internal_get_SparseTensorImpl: should not be a variable");
+  AT_ASSERTM(self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor");
+  return static_cast<SparseTensorImpl*>(self.unsafeGetTensorImpl());
+}
+
+// Port of the old THCSTensor_(checkGPU), but it doesn't really belong here
+// because it is more general
+// NB: I dropped kernelP2PEnabled support
+// NB: This only works if the tensors are KNOWN to be CUDA.
+// TODO: Generalize it so it works on CPU as well
+inline bool check_device(ArrayRef<Tensor> ts) {
+  if (ts.empty()) {
+    return true;
+  }
+  int64_t curDevice = current_device();
+  for (const Tensor& t : ts) {
+    if (t.get_device() != curDevice) return false;
+  }
+  return true;
+}
+
+// Takes indices and values and directly puts them into the sparse tensor, no
+// copy. This used to be called THSTensor_(_move)
+inline void alias_into_sparse(const SparseTensor& self, const LongTensor& indices, const Tensor& values) {
+  get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values);
+}
+
+// Takes indices and values and makes a (data) copy of them to put into the sparse
+// indices/values. This used to be called THSTensor_(_set)
+inline void copy_into_sparse(const SparseTensor& self, const LongTensor& indices, const Tensor& values, bool non_blocking) {
+  alias_into_sparse(self, self._indices().type().copy(indices, non_blocking), self._values().type().copy(values, non_blocking));
+}
+
+// TODO: put this into the public API
+inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) {
+  return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl();
+}
+
+inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) {
+  return self.sparse_dim() == src.sparse_dim() && self.dense_dim() == src.dense_dim();
+}
+
+// Gives us a new values tensor with the same dimensionality
+// as 'values' but with a new number of non-zero elements.
+// TODO: Expose this for real in ATen, some day?
+// NB: Doesn't preserve data.
+inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) {
+  std::vector<int64_t> size = values.sizes().vec();
+  size[0] = nnz;
+  return at::empty(size, values.options());
+}
+
+// This helper function flattens a sparse indices tensor (a LongTensor) into a 1D
+// indices tensor. E.g.,
+//   input = [[2, 4, 0],
+//            [3, 1, 10]]
+//   full_size = [2, 12]
+//   output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
+//
+// In other words, assuming that each column `indices[:, i]` is a valid index into a
+// tensor `t` of shape `full_size`, this returns the corresponding indices into
+// the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
+// If force_clone is true, the result will be forced to be a clone of self.
+inline LongTensor flatten_indices(const Tensor& indices, IntList full_size, bool force_clone = false) {
+  int64_t sparse_dim = indices.size(0);
+  if (sparse_dim == 1) {
+    if (force_clone) {
+      return indices.squeeze(0).clone();
+    } else {
+      return indices.squeeze(0);
+    }
+  } else {
+    std::vector<int64_t> indices_mult_cpu_vec;
+    indices_mult_cpu_vec.reserve(sparse_dim);
+    int64_t mult = 1;
+    for (int64_t i = sparse_dim - 1; i >= 0; i--) {
+      indices_mult_cpu_vec[i] = mult;
+      mult *= full_size[i];
+    }
+    auto indices_mult_cpu = indices.type().cpu()
+                                .tensorFromBlob(indices_mult_cpu_vec.data(), /*size=*/{sparse_dim, 1});
+    // NB: must be blocking because this blob may be freed after this closure,
+    // and non_blocking copy will see garbage.
+    auto indices_mult = indices_mult_cpu.to(indices.device(), /*non_blocking=*/false);
+    // Ideally we want matmul but matmul is slow on CPU Long and not implemented
+    // on CUDA Long. So mul is faster.
+    return indices.mul(indices_mult).sum(0);
+  }
+}
+
+}} // namespace at::sparse

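The flattening in flatten_indices is plain row-major index arithmetic: each column of indices is weighted by stride-like multipliers built from full_size (innermost dimension gets 1). A hedged Python rendering of the example from the comment above; the C++ helper itself is internal and not exposed, so this only reproduces the arithmetic.

import torch

indices = torch.tensor([[2, 4, 0],
                        [3, 1, 10]])
full_size = [2, 12]

# Multipliers built back-to-front: [12, 1] for full_size [2, 12].
mult = torch.tensor([[full_size[1]],
                     [1]])
flat = (indices * mult).sum(0)
print(flat)    # tensor([27, 49, 10]) == [2*12+3, 4*12+1, 0*12+10]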
aten/src/ATen/core/Tensor.h (+10 -4)

@@ -404,6 +404,7 @@ class CAFFE2_API Tensor {
   Tensor & log_normal_(double mean=1, double std=2, Generator * generator=nullptr);
   Tensor & exponential_(double lambd=1, Generator * generator=nullptr);
   Tensor & geometric_(double p, Generator * generator=nullptr);
+  Tensor alias() const;
   Tensor abs() const;
   Tensor & abs_();
   Tensor acos() const;
@@ -621,17 +622,22 @@
   Tensor & sub_(Scalar other, Scalar alpha=1);
   Tensor addmm(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1) const;
   Tensor & addmm_(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1);
-  Tensor & sparse_resize_(IntList size, int64_t sparseDims, int64_t denseDims);
-  Tensor & sparse_resize_and_clear_(IntList size, int64_t sparseDims, int64_t denseDims);
+  Tensor & sparse_resize_(IntList size, int64_t sparse_dim, int64_t dense_dim);
+  Tensor & sparse_resize_and_clear_(IntList size, int64_t sparse_dim, int64_t dense_dim);
   Tensor sparse_mask(SparseTensorRef mask) const;
   Tensor to_dense() const;
-  int64_t _sparseDims() const;
-  int64_t _denseDims() const;
+  int64_t sparse_dim() const;
+  int64_t _dimI() const;
+  int64_t dense_dim() const;
+  int64_t _dimV() const;
   int64_t _nnz() const;
   Tensor coalesce() const;
   bool is_coalesced() const;
   Tensor _indices() const;
   Tensor _values() const;
+  Tensor & _coalesced_(bool coalesced);
+  Tensor indices() const;
+  Tensor values() const;
   int64_t numel() const;
   std::vector<Tensor> unbind(int64_t dim=0) const;
   int64_t get_device() const;

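The new indices() and values() declarations above are the user-facing counterparts of _indices()/_values(); per the commit title they are autograd-aware, and they are expected to require a coalesced tensor. A hedged sketch of the accessor behavior, assuming current torch semantics; the data is illustrative.

import torch

i = torch.tensor([[0, 0, 1],
                  [0, 0, 2]])        # duplicate (0, 0) entry on purpose
v = torch.tensor([1.0, 2.0, 3.0])
s = torch.sparse_coo_tensor(i, v, (2, 3))

# _indices()/_values() expose the raw storage; indices()/values() expect a coalesced tensor.
print(s._indices().shape, s._values().shape)   # torch.Size([2, 3]) torch.Size([3])
c = s.coalesce()
print(c.is_coalesced())   # True
print(c.indices())        # tensor([[0, 1], [0, 2]]) -- duplicates merged
print(c.values())         # tensor([3., 3.])         -- 1 + 2 summed into one entry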