Skip to content

Commit e6b4d77

Browse files
pearupytorchmergebot
authored andcommitted
Sparse Compressed tensor factory function 2
Pull Request resolved: pytorch#76623 Approved by: https://github.com/cpuhrsch
1 parent 9fae076 commit e6b4d77

File tree

6 files changed

+172
-84
lines changed

6 files changed

+172
-84
lines changed

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5394,8 +5394,10 @@
53945394
# FIXME: would be nicer if TensorOptions was optional based; not adding default arguments for options given
53955395
# the default would never make sense.
53965396

5397+
- func: sparse_compressed_tensor.crow_col_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
53975398
- func: sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
53985399

5400+
- func: sparse_compressed_tensor.crow_col_value(Tensor compressed_indices, Tensor plain_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
53995401
- func: sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
54005402

54015403
- func: _sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

aten/src/ATen/native/sparse/SparseCsrTensor.cpp

Lines changed: 85 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <ATen/ops/resize_as_sparse_native.h>
3333
#include <ATen/ops/resize_native.h>
3434
#include <ATen/ops/select_native.h>
35+
#include <ATen/ops/sparse_compressed_tensor_native.h>
3536
#include <ATen/ops/sparse_csr_tensor_native.h>
3637
#include <ATen/ops/values_native.h>
3738
#endif
@@ -298,26 +299,54 @@ Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indice
298299

299300
SPARSE_COMPRESSED_TENSOR_UNSAFE(csr, kSparseCsr);
300301

302+
DimVector _estimate_sparse_compressed_tensor_size(
303+
const Tensor& compressed_indices,
304+
const Tensor& plain_indices,
305+
const Tensor& values,
306+
Layout layout) {
307+
DimVector size = DimVector(IntArrayRef(plain_indices.sizes().data(), plain_indices.dim() - 1));
308+
int64_t compressed_dim = (plain_indices.size(-1) > 0 ? compressed_indices.size(-1) - 1 : 0);
309+
int64_t plain_dim = AT_DISPATCH_INTEGRAL_TYPES(plain_indices.scalar_type(), "estimate_sparse_compressed_tensor_size",
310+
[&]() -> int64_t { return plain_indices.max().item<scalar_t>() + 1; });
311+
AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout, "estimate_sparse_compressed_tensor_size",
312+
[&]{
313+
size.push_back(compressed_dim);
314+
size.push_back(plain_dim);
315+
},
316+
[&]{
317+
size.push_back(plain_dim);
318+
size.push_back(compressed_dim);
319+
});
320+
return size;
321+
}
322+
301323
// TODO: This constructor should probably use an ATen abstract method in order
302324
// to make autograd dispatch available for the CSR constructor. See the relevant
303325
// note in native_functions.yaml.
304-
Tensor sparse_csr_tensor(
305-
const Tensor& crow_indices,
306-
const Tensor& col_indices,
326+
Tensor sparse_compressed_tensor(
327+
const Tensor& compressed_indices,
328+
const Tensor& plain_indices,
307329
const Tensor& values,
308330
IntArrayRef size,
309331
c10::optional<ScalarType> dtype,
310332
c10::optional<Layout> layout,
311333
c10::optional<Device> device,
312334
c10::optional<bool> pin_memory) {
335+
336+
if (!layout) {
337+
AT_ERROR("sparse_compressed_tensor expected sparse compressed tensor layout but got none");
338+
}
339+
Layout layout_ = layout.value();
340+
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor", [&]{});
341+
313342
// See [Note: hacky wrapper removal for TensorOptions]
314-
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
343+
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
315344

316-
at::native::_validate_sparse_csr_tensor_args(crow_indices, col_indices, values, size);
345+
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
317346

318-
return at::native::_sparse_csr_tensor_unsafe(
319-
crow_indices,
320-
col_indices,
347+
return at::native::_sparse_compressed_tensor_unsafe(
348+
compressed_indices,
349+
plain_indices,
321350
values,
322351
size,
323352
optTypeMetaToScalarType(options.dtype_opt()),
@@ -326,26 +355,31 @@ Tensor sparse_csr_tensor(
326355
options.pinned_memory_opt());
327356
}
328357

329-
Tensor sparse_csr_tensor(
330-
const Tensor& crow_indices,
331-
const Tensor& col_indices,
358+
Tensor sparse_compressed_tensor(
359+
const Tensor& compressed_indices,
360+
const Tensor& plain_indices,
332361
const Tensor& values,
333362
c10::optional<ScalarType> dtype,
334363
c10::optional<Layout> layout,
335364
c10::optional<Device> device,
336365
c10::optional<bool> pin_memory) {
366+
367+
if (!layout) {
368+
AT_ERROR("sparse_compressed_tensor expected sparse compressed tensor layout but got none");
369+
}
370+
Layout layout_ = layout.value();
371+
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor", [&]{});
372+
373+
DimVector size = _estimate_sparse_compressed_tensor_size(compressed_indices, plain_indices, values, layout_);
374+
337375
// See [Note: hacky wrapper removal for TensorOptions]
338-
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
339-
// std::array<int64_t, 2> size = {0, 0};
340-
auto size = DimVector(IntArrayRef(col_indices.sizes().data(), col_indices.dim() - 1));
341-
size.push_back(crow_indices.size(-1) - 1);
342-
size.push_back(col_indices.max().item<int64_t>() + 1);
376+
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
343377

344-
at::native::_validate_sparse_csr_tensor_args(crow_indices, col_indices, values, size);
378+
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
345379

346-
return at::native::_sparse_csr_tensor_unsafe(
347-
crow_indices,
348-
col_indices,
380+
return at::native::_sparse_compressed_tensor_unsafe(
381+
compressed_indices,
382+
plain_indices,
349383
values,
350384
size,
351385
optTypeMetaToScalarType(options.dtype_opt()),
@@ -354,6 +388,37 @@ Tensor sparse_csr_tensor(
354388
options.pinned_memory_opt());
355389
}
356390

391+
#define SPARSE_COMPRESSED_TENSOR(KIND, REQUIRED_LAYOUT) \
392+
Tensor sparse_##KIND##_tensor(const Tensor& compressed_indices, \
393+
const Tensor& plain_indices, \
394+
const Tensor& values, \
395+
c10::optional<ScalarType> dtype, \
396+
c10::optional<Layout> layout, \
397+
c10::optional<Device> device, \
398+
c10::optional<bool> pin_memory) { \
399+
if (layout) { \
400+
TORCH_CHECK(layout.value() == REQUIRED_LAYOUT, "sparse " # KIND " layout must be ", REQUIRED_LAYOUT, " but got ", layout.value()); \
401+
} \
402+
c10::optional<Layout> layout_(REQUIRED_LAYOUT); \
403+
return at::native::sparse_compressed_tensor(compressed_indices, plain_indices, values, dtype, layout_, device, pin_memory); \
404+
} \
405+
Tensor sparse_##KIND##_tensor(const Tensor& compressed_indices, \
406+
const Tensor& plain_indices, \
407+
const Tensor& values, \
408+
IntArrayRef size, \
409+
c10::optional<ScalarType> dtype, \
410+
c10::optional<Layout> layout, \
411+
c10::optional<Device> device, \
412+
c10::optional<bool> pin_memory) { \
413+
if (layout) { \
414+
TORCH_CHECK(layout.value() == REQUIRED_LAYOUT, "sparse " # KIND " layout must be ", REQUIRED_LAYOUT, " but got ", layout.value()); \
415+
} \
416+
c10::optional<Layout> layout_(REQUIRED_LAYOUT); \
417+
return at::native::sparse_compressed_tensor(compressed_indices, plain_indices, values, size, dtype, layout_, device, pin_memory); \
418+
}
419+
420+
SPARSE_COMPRESSED_TENSOR(csr, kSparseCsr)
421+
357422
Tensor empty_sparse_csr(
358423
IntArrayRef size,
359424
c10::optional<ScalarType> dtype,

torch/csrc/autograd/python_torch_functions_manual.cpp

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -407,28 +407,6 @@ static std::vector<Tensor> dispatch_nonzero_numpy(const Tensor & self) {
407407

408408
static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs);
409409

410-
static PyObject * THPVariable_sparse_csr_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
411-
{
412-
HANDLE_TH_ERRORS
413-
static PythonArgParser parser({
414-
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
415-
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
416-
});
417-
418-
ParsedArgs<9> parsed_args;
419-
auto r = parser.parse(args, kwargs, parsed_args);
420-
if (r.has_torch_function()) {
421-
return handle_torch_function(
422-
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
423-
}
424-
jit::tracer::warn("torch.sparse_csr_tensor", jit::tracer::WARN_CONSTRUCTOR);
425-
return THPVariable_Wrap(torch::utils::sparse_csr_tensor_ctor(
426-
torch::tensors::get_default_dispatch_key(),
427-
torch::tensors::get_default_scalar_type(),
428-
r));
429-
END_HANDLE_TH_ERRORS
430-
}
431-
432410
#define THPVARIABLE_SPARSE_COMPRESSED_CTOR(NAME, NARGS, SIGNATURES) \
433411
static PyObject * THPVariable_ ## NAME(PyObject* self, PyObject* args, PyObject* kwargs) \
434412
{ \
@@ -444,6 +422,13 @@ static PyObject * THPVariable_ ## NAME(PyObject* self, PyObject* args, PyObject*
444422
END_HANDLE_TH_ERRORS \
445423
}
446424

425+
THPVARIABLE_SPARSE_COMPRESSED_CTOR(sparse_compressed_tensor, 9,
426+
({"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
427+
"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
428+
THPVARIABLE_SPARSE_COMPRESSED_CTOR(sparse_csr_tensor, 9,
429+
({"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
430+
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
431+
447432
THPVARIABLE_SPARSE_COMPRESSED_CTOR(_sparse_compressed_tensor_unsafe, 8,
448433
({"_sparse_compressed_tensor_unsafe(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool requires_grad=False)"}))
449434
THPVARIABLE_SPARSE_COMPRESSED_CTOR(_sparse_csr_tensor_unsafe, 7,
@@ -796,6 +781,7 @@ static PyMethodDef torch_functions_manual[] = {
796781
{"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
797782
{"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
798783
{"_sparse_compressed_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_compressed_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
784+
{"sparse_compressed_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_compressed_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
799785
{"sparse_csr_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
800786
{"_sparse_csr_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_csr_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
801787
{"tensor", castPyCFunctionWithKeywords(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},

0 commit comments

Comments
 (0)