32
32
#include < ATen/ops/resize_as_sparse_native.h>
33
33
#include < ATen/ops/resize_native.h>
34
34
#include < ATen/ops/select_native.h>
35
+ #include < ATen/ops/sparse_compressed_tensor_native.h>
35
36
#include < ATen/ops/sparse_csr_tensor_native.h>
36
37
#include < ATen/ops/values_native.h>
37
38
#endif
@@ -298,26 +299,54 @@ Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indice
298
299
299
300
SPARSE_COMPRESSED_TENSOR_UNSAFE (csr, kSparseCsr );
300
301
302
+ DimVector _estimate_sparse_compressed_tensor_size (
303
+ const Tensor& compressed_indices,
304
+ const Tensor& plain_indices,
305
+ const Tensor& values,
306
+ Layout layout) {
307
+ DimVector size = DimVector (IntArrayRef (plain_indices.sizes ().data (), plain_indices.dim () - 1 ));
308
+ int64_t compressed_dim = (plain_indices.size (-1 ) > 0 ? compressed_indices.size (-1 ) - 1 : 0 );
309
+ int64_t plain_dim = AT_DISPATCH_INTEGRAL_TYPES (plain_indices.scalar_type (), " estimate_sparse_compressed_tensor_size" ,
310
+ [&]() -> int64_t { return plain_indices.max ().item <scalar_t >() + 1 ; });
311
+ AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS (layout, " estimate_sparse_compressed_tensor_size" ,
312
+ [&]{
313
+ size.push_back (compressed_dim);
314
+ size.push_back (plain_dim);
315
+ },
316
+ [&]{
317
+ size.push_back (plain_dim);
318
+ size.push_back (compressed_dim);
319
+ });
320
+ return size;
321
+ }
322
+
301
323
// TODO: This constructor should probably use an ATen abstract method in order
302
324
// to make autograd dispatch available for the CSR constructor. See the relevant
303
325
// note in native_functions.yaml.
304
- Tensor sparse_csr_tensor (
305
- const Tensor& crow_indices ,
306
- const Tensor& col_indices ,
326
+ Tensor sparse_compressed_tensor (
327
+ const Tensor& compressed_indices ,
328
+ const Tensor& plain_indices ,
307
329
const Tensor& values,
308
330
IntArrayRef size,
309
331
c10::optional<ScalarType> dtype,
310
332
c10::optional<Layout> layout,
311
333
c10::optional<Device> device,
312
334
c10::optional<bool > pin_memory) {
335
+
336
+ if (!layout) {
337
+ AT_ERROR (" sparse_compressed_tensor expected sparse compressed tensor layout but got none" );
338
+ }
339
+ Layout layout_ = layout.value ();
340
+ AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS (layout_, " sparse_compressed_tensor" , [&]{});
341
+
313
342
// See [Note: hacky wrapper removal for TensorOptions]
314
- TensorOptions options = TensorOptions ().dtype (dtype).layout (layout ).device (device).pinned_memory (pin_memory);
343
+ TensorOptions options = TensorOptions ().dtype (dtype).layout (layout_ ).device (device).pinned_memory (pin_memory);
315
344
316
- at::native::_validate_sparse_csr_tensor_args (crow_indices, col_indices , values, size);
345
+ _validate_sparse_compressed_tensor_args_worker (compressed_indices, plain_indices , values, size, layout_ );
317
346
318
- return at::native::_sparse_csr_tensor_unsafe (
319
- crow_indices ,
320
- col_indices ,
347
+ return at::native::_sparse_compressed_tensor_unsafe (
348
+ compressed_indices ,
349
+ plain_indices ,
321
350
values,
322
351
size,
323
352
optTypeMetaToScalarType (options.dtype_opt ()),
@@ -326,26 +355,31 @@ Tensor sparse_csr_tensor(
326
355
options.pinned_memory_opt ());
327
356
}
328
357
329
- Tensor sparse_csr_tensor (
330
- const Tensor& crow_indices ,
331
- const Tensor& col_indices ,
358
+ Tensor sparse_compressed_tensor (
359
+ const Tensor& compressed_indices ,
360
+ const Tensor& plain_indices ,
332
361
const Tensor& values,
333
362
c10::optional<ScalarType> dtype,
334
363
c10::optional<Layout> layout,
335
364
c10::optional<Device> device,
336
365
c10::optional<bool > pin_memory) {
366
+
367
+ if (!layout) {
368
+ AT_ERROR (" sparse_compressed_tensor expected sparse compressed tensor layout but got none" );
369
+ }
370
+ Layout layout_ = layout.value ();
371
+ AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS (layout_, " sparse_compressed_tensor" , [&]{});
372
+
373
+ DimVector size = _estimate_sparse_compressed_tensor_size (compressed_indices, plain_indices, values, layout_);
374
+
337
375
// See [Note: hacky wrapper removal for TensorOptions]
338
- TensorOptions options = TensorOptions ().dtype (dtype).layout (layout).device (device).pinned_memory (pin_memory);
339
- // std::array<int64_t, 2> size = {0, 0};
340
- auto size = DimVector (IntArrayRef (col_indices.sizes ().data (), col_indices.dim () - 1 ));
341
- size.push_back (crow_indices.size (-1 ) - 1 );
342
- size.push_back (col_indices.max ().item <int64_t >() + 1 );
376
+ TensorOptions options = TensorOptions ().dtype (dtype).layout (layout_).device (device).pinned_memory (pin_memory);
343
377
344
- at::native::_validate_sparse_csr_tensor_args (crow_indices, col_indices , values, size);
378
+ _validate_sparse_compressed_tensor_args_worker (compressed_indices, plain_indices , values, size, layout_ );
345
379
346
- return at::native::_sparse_csr_tensor_unsafe (
347
- crow_indices ,
348
- col_indices ,
380
+ return at::native::_sparse_compressed_tensor_unsafe (
381
+ compressed_indices ,
382
+ plain_indices ,
349
383
values,
350
384
size,
351
385
optTypeMetaToScalarType (options.dtype_opt ()),
@@ -354,6 +388,37 @@ Tensor sparse_csr_tensor(
354
388
options.pinned_memory_opt ());
355
389
}
356
390
391
+ #define SPARSE_COMPRESSED_TENSOR (KIND, REQUIRED_LAYOUT ) \
392
+ Tensor sparse_##KIND##_tensor(const Tensor& compressed_indices, \
393
+ const Tensor& plain_indices, \
394
+ const Tensor& values, \
395
+ c10::optional<ScalarType> dtype, \
396
+ c10::optional<Layout> layout, \
397
+ c10::optional<Device> device, \
398
+ c10::optional<bool > pin_memory) { \
399
+ if (layout) { \
400
+ TORCH_CHECK (layout.value () == REQUIRED_LAYOUT, " sparse " # KIND " layout must be " , REQUIRED_LAYOUT, " but got " , layout.value ()); \
401
+ } \
402
+ c10::optional<Layout> layout_ (REQUIRED_LAYOUT); \
403
+ return at::native::sparse_compressed_tensor (compressed_indices, plain_indices, values, dtype, layout_, device, pin_memory); \
404
+ } \
405
+ Tensor sparse_##KIND##_tensor(const Tensor& compressed_indices, \
406
+ const Tensor& plain_indices, \
407
+ const Tensor& values, \
408
+ IntArrayRef size, \
409
+ c10::optional<ScalarType> dtype, \
410
+ c10::optional<Layout> layout, \
411
+ c10::optional<Device> device, \
412
+ c10::optional<bool > pin_memory) { \
413
+ if (layout) { \
414
+ TORCH_CHECK (layout.value () == REQUIRED_LAYOUT, " sparse " # KIND " layout must be " , REQUIRED_LAYOUT, " but got " , layout.value ()); \
415
+ } \
416
+ c10::optional<Layout> layout_ (REQUIRED_LAYOUT); \
417
+ return at::native::sparse_compressed_tensor (compressed_indices, plain_indices, values, size, dtype, layout_, device, pin_memory); \
418
+ }
419
+
420
+ SPARSE_COMPRESSED_TENSOR (csr, kSparseCsr )
421
+
357
422
Tensor empty_sparse_csr (
358
423
IntArrayRef size,
359
424
c10::optional<ScalarType> dtype,
0 commit comments