Skip to content

Commit d7db6a7

Browse files
IvanYashchukpytorchmergebot
authored andcommitted
Sparse CSR: Add backward for torch.sparse.sampled_addmm
Pull Request resolved: pytorch#68084 Approved by: https://github.com/cpuhrsch
1 parent e6b4d77 commit d7db6a7

File tree

8 files changed

+154
-6
lines changed

8 files changed

+154
-6
lines changed

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2133,6 +2133,7 @@
21332133
CPU, CUDA: fill_
21342134
QuantizedCPU, QuantizedCUDA: fill_quantized_
21352135
Meta: fill_meta_
2136+
SparseCsrCPU, SparseCsrCUDA: fill_sparse_csr_
21362137

21372138
- func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)
21382139
device_check: NoCheck # TensorIterator
@@ -3245,6 +3246,7 @@
32453246
variants: function, method
32463247
dispatch:
32473248
CompositeExplicitAutograd: mul
3249+
SparseCsrCPU, SparseCsrCUDA: mul_scalar_sparse_csr
32483250

32493251
- func: mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
32503252
device_check: NoCheck # TensorIterator
@@ -5443,6 +5445,7 @@
54435445
dispatch:
54445446
SparseCPU: sparse_mask_cpu
54455447
SparseCUDA: sparse_mask_cuda
5448+
SparseCsrCPU, SparseCsrCUDA: sparse_mask_sparse_csr
54465449

54475450
- func: _to_cpu(Tensor[] tensors) -> Tensor[]
54485451
variants: function

aten/src/ATen/native/sparse/SparseCsrTensor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -567,9 +567,9 @@ Tensor empty_like_sparse_csr(
567567
self.col_indices().clone(),
568568
at::empty(self.values().sizes(), options.layout(kStrided)),
569569
self.sizes(),
570-
dtype,
570+
optTypeMetaToScalarType(options.dtype()),
571571
self.layout(),
572-
device);
572+
options.device());
573573
return result;
574574
} else if (options.layout() == kStrided) {
575575
return at::native::empty_like(self, dtype, layout, device, pin_memory, optional_memory_format);

aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include <ATen/ops/erfinv_native.h>
5353
#include <ATen/ops/expm1.h>
5454
#include <ATen/ops/expm1_native.h>
55+
#include <ATen/ops/fill_native.h>
5556
#include <ATen/ops/floor.h>
5657
#include <ATen/ops/floor_native.h>
5758
#include <ATen/ops/isinf.h>
@@ -65,9 +66,11 @@
6566
#include <ATen/ops/log1p.h>
6667
#include <ATen/ops/log1p_native.h>
6768
#include <ATen/ops/mm_native.h>
69+
#include <ATen/ops/mul_native.h>
6870
#include <ATen/ops/neg.h>
6971
#include <ATen/ops/neg_native.h>
7072
#include <ATen/ops/normal_native.h>
73+
#include <ATen/ops/ones_like.h>
7174
#include <ATen/ops/rad2deg.h>
7275
#include <ATen/ops/rad2deg_native.h>
7376
#include <ATen/ops/resize_as_sparse_native.h>
@@ -87,6 +90,8 @@
8790
#include <ATen/ops/sinh_native.h>
8891
#include <ATen/ops/sqrt.h>
8992
#include <ATen/ops/sqrt_native.h>
93+
#include <ATen/ops/sparse_mask.h>
94+
#include <ATen/ops/sparse_mask_native.h>
9095
#include <ATen/ops/tan.h>
9196
#include <ATen/ops/tan_native.h>
9297
#include <ATen/ops/tanh.h>
@@ -280,6 +285,39 @@ Tensor& normal_sparse_csr_(
280285
return unary_op_inplace(self, &Tensor::normal_, mean, std, gen);
281286
}
282287

288+
Tensor& fill_sparse_csr_(Tensor& self, const Scalar& value) {
289+
return unary_op_inplace(self, &TensorBase::fill_, value);
290+
}
291+
292+
Tensor sparse_mask_sparse_csr(
293+
const Tensor& self,
294+
const Tensor& sparse_mask) {
295+
TORCH_CHECK(sparse_mask.is_sparse_csr(), "sparse_mask_sparse_csr expects mask to be sparse csr");
296+
TORCH_CHECK(self.dim() == 2, "sparse_mask_sparse_csr expects self to be 2D");
297+
TORCH_CHECK(sparse_mask.dim() == 2, "sparse_mask_sparse_csr expects mask to be 2D");
298+
299+
// We are computing self.mul(at::ones_like(sparse_mask))
300+
// But mul(dense, sparse_csr) is not implemented yet
301+
if (self.layout() == sparse_mask.layout()) {
302+
// Both inputs are CSR
303+
return self.mul(at::ones_like(sparse_mask));
304+
} else {
305+
return self.sparse_mask(sparse_mask.to_sparse()).to_sparse_csr();
306+
}
307+
}
308+
309+
Tensor mul_scalar_sparse_csr(const Tensor& self, const Scalar& other) {
310+
auto result_values = self.values().mul(other);
311+
return at::native::_sparse_csr_tensor_unsafe(
312+
self.crow_indices().clone(),
313+
self.col_indices().clone(),
314+
result_values,
315+
self.sizes(),
316+
result_values.scalar_type(),
317+
self.layout(),
318+
result_values.device());
319+
}
320+
283321
/* Implementation of Unary Ufuncs, those supported for Sparse CSR Layout
284322
* Only simple funcs, with 0->0 correspondence are currently supported. */
285323

aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,17 +1282,16 @@ void sampled_addmm_out_sparse_csr(
12821282
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(B.layout() == Layout::Strided);
12831283
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(C.is_sparse_csr());
12841284

1285-
auto descA = at::cuda::sparse::CuSparseDnMatDescriptor(A);
1286-
auto descB = at::cuda::sparse::CuSparseDnMatDescriptor(B);
1287-
auto descC = at::cuda::sparse::CuSparseSpMatCsrDescriptor(C);
1288-
12891285
cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
12901286
cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE;
12911287

12921288
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
12931289
C.scalar_type(),
12941290
"sampled_addmm_out_sparse_csr",
12951291
[&] {
1292+
auto descA = at::cuda::sparse::CuSparseDnMatDescriptor(A);
1293+
auto descB = at::cuda::sparse::CuSparseDnMatDescriptor(B);
1294+
auto descC = at::cuda::sparse::CuSparseSpMatCsrDescriptor(C);
12961295
auto beta_ = beta.to<scalar_t>();
12971296
auto alpha_ = alpha.to<scalar_t>();
12981297
auto compute_type = at::cuda::getCudaDataType<scalar_t>();

test/test_sparse_csr.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,38 @@ def run_test(c, a, b, op_a, op_b, *, alpha=None, beta=None):
14541454
for op_a, op_b in itertools.product([True, False], repeat=2):
14551455
run_test(c, a, b, op_a, op_b)
14561456

1457+
@skipCUDAIfRocm
1458+
@onlyCUDA
1459+
@skipCUDAIf(
1460+
not _check_cusparse_sddmm_available(),
1461+
"cuSparse Generic API SDDMM is not available"
1462+
)
1463+
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
1464+
def test_sampled_addmm_autograd(self, device, dtype):
1465+
from torch.testing._internal.common_methods_invocations import sample_inputs_sparse_sampled_addmm
1466+
1467+
samples = list(sample_inputs_sparse_sampled_addmm(None, device, dtype, requires_grad=True))
1468+
1469+
for sample, dense_covector in zip(samples, [True, False]):
1470+
c = sample.input
1471+
a = sample.args[0]
1472+
b = sample.args[1]
1473+
1474+
# Compute sparse result
1475+
output = torch.sparse.sampled_addmm(c, a, b, **sample.kwargs)
1476+
covector = torch.randn_like(output).to_dense() if dense_covector else torch.randn_like(output)
1477+
output.backward(covector)
1478+
1479+
# Compute dense result and compare with sparse result
1480+
c1, a1, b1 = map(lambda x: x.detach().to_dense().requires_grad_(True), [c, a, b])
1481+
dense_output = sample.kwargs['alpha'] * (a1 @ b1) * torch.ones_like(c).to_dense() + sample.kwargs['beta'] * c1
1482+
self.assertEqual(output, dense_output)
1483+
dense_covector = covector.to_dense()
1484+
dense_output.backward(dense_covector)
1485+
self.assertEqual(c.grad, c1.grad)
1486+
self.assertEqual(a.grad, a1.grad)
1487+
self.assertEqual(b.grad, b1.grad)
1488+
14571489
@skipCUDAIfRocm
14581490
@onlyCUDA
14591491
@skipCUDAIf(True, "Causes CUDA memory exception, see https://github.com/pytorch/pytorch/issues/72177")

tools/autograd/derivatives.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2327,6 +2327,11 @@
23272327
self: zeros_like(self)
23282328
result: replication_pad3d_backward(grad_output_t, self_p, padding)
23292329

2330+
- name: sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
2331+
self: maybe_multiply(grad, beta.conj())
2332+
mat1: maybe_multiply(grad.sparse_mask(self).mm(mat2.mH()), alpha.conj())
2333+
mat2: maybe_multiply(mat1.mH().mm(grad.sparse_mask(self)), alpha.conj())
2334+
23302335
- name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor
23312336
grad_output: smooth_l1_loss_double_backward_grad_output(grad, grad_output, self, target, reduction, beta)
23322337
self: smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta)

tools/autograd/gen_variable_type.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,8 @@
329329
"im2col",
330330
"im2col_backward",
331331
"cholesky_inverse",
332+
"to_sparse",
333+
"sparse_sampled_addmm",
332334
}
333335

334336
GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {

torch/testing/_internal/common_methods_invocations.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2850,6 +2850,36 @@ def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs):
28502850
kwargs={'alpha': alpha_val, 'beta': beta_val},))
28512851
return sample_inputs
28522852

2853+
def sample_inputs_sparse_sampled_addmm(op_info, device, dtype, requires_grad, **kwargs):
2854+
alpha = 2 + 3j if dtype.is_complex else 0.6
2855+
beta = 1 + 2j if dtype.is_complex else 0.2
2856+
2857+
def generator():
2858+
# sparse.sampled_addmm performs: alpha * (A @ B) * sparse_ones_like(C) + beta * C
2859+
for m, n, k in itertools.product([0, 5], repeat=3):
2860+
yield SampleInput(
2861+
torch.eye(m, n, device=device, dtype=dtype)
2862+
.to_sparse_csr()
2863+
.requires_grad_(requires_grad),
2864+
args=(
2865+
make_tensor(
2866+
(m, k),
2867+
device=device,
2868+
dtype=dtype,
2869+
requires_grad=requires_grad,
2870+
),
2871+
make_tensor(
2872+
(k, n),
2873+
device=device,
2874+
dtype=dtype,
2875+
requires_grad=requires_grad,
2876+
),
2877+
),
2878+
kwargs={"alpha": alpha, "beta": beta},
2879+
)
2880+
2881+
return list(generator())
2882+
28532883
def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
28542884
return (
28552885
SampleInput(
@@ -10689,6 +10719,45 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
1068910719
supports_forward_ad=True,
1069010720
supports_fwgrad_bwgrad=True,
1069110721
supports_out=False),
10722+
OpInfo('sparse.sampled_addmm',
10723+
dtypes=floating_and_complex_types(),
10724+
supports_autograd=True,
10725+
sample_inputs_func=sample_inputs_sparse_sampled_addmm,
10726+
decorators=[
10727+
onlyCUDA,
10728+
skipCUDAIf(_get_torch_cuda_version() < (11, 3), "cusparseSDDMM was added in 11.2.1"), ],
10729+
skips=(
10730+
# NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous
10731+
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
10732+
# RuntimeError: Sparse CSR tensors do not have strides.
10733+
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
10734+
# RuntimeError: sampled_addmm: Expected result to have sparse csr layout, but got Strided
10735+
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'),
10736+
# RuntimeError: Sparse CSR tensors do not have strides
10737+
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
10738+
# RuntimeError: Sparse CSR tensors do not have strides
10739+
DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'),
10740+
# RuntimeError: Sparse CSR tensors do not have strides
10741+
DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'),
10742+
# RuntimeError: Sparse CSR tensors do not have strides
10743+
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
10744+
# RuntimeError: Sparse CSR tensors do not have strides
10745+
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
10746+
# RuntimeError: Sparse CSR tensors do not have strides
10747+
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
10748+
# RuntimeError: Sparse CSR tensors do not have strides
10749+
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
10750+
# RuntimeError: unsupported memory format option Preserve
10751+
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
10752+
# GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
10753+
DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad'),
10754+
# GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
10755+
DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad'),
10756+
# GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
10757+
DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad'),
10758+
# GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
10759+
DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'),
10760+
)),
1069210761
UnaryUfuncInfo('i0',
1069310762
ref=np_unary_ufunc_integer_promotion_wrapper(
1069410763
scipy.special.i0) if TEST_SCIPY else _NOTHING,

0 commit comments

Comments
 (0)