@@ -295,16 +295,19 @@ def main(x):
295
295
def ewise_add (op : BinaryOp , left : SparseTensorBase , right : SparseTensorBase ):
296
296
assert left .ndims == right .ndims
297
297
assert left .dtype == right .dtype
298
+
299
+ if left ._obj is None :
300
+ return right
301
+ if right ._obj is None :
302
+ return left
303
+
298
304
assert left ._sparsity == right ._sparsity
299
305
300
306
rank = left .ndims
301
307
if rank == 0 : # Scalar
302
308
# TODO: implement this
303
309
raise NotImplementedError ("doesn't yet work for Scalar" )
304
310
305
- # TODO: handle case of either left or right not having an _obj -> result will be other for ewise_add
306
- # or have a utility to build an empty MLIRSparseTensor for all input tensors?
307
-
308
311
# Build and compile if needed
309
312
key = ('ewise_add' , op .name , * left .get_loop_key (), * right .get_loop_key ())
310
313
if key not in engine_cache :
@@ -363,15 +366,19 @@ def main(x, y):
363
366
def ewise_mult (op : BinaryOp , left : SparseTensorBase , right : SparseTensorBase ):
364
367
assert left .ndims == right .ndims
365
368
assert left .dtype == right .dtype
369
+
370
+ if left ._obj is None :
371
+ return left
372
+ if right ._obj is None :
373
+ return right
374
+
366
375
assert left ._sparsity == right ._sparsity
367
376
368
377
rank = left .ndims
369
378
if rank == 0 : # Scalar
370
379
# TODO: implement this
371
380
raise NotImplementedError ("doesn't yet work for Scalar" )
372
381
373
- # TODO: handle case of either left or right not having an _obj -> result will be empty for ewise_mult
374
-
375
382
# Build and compile if needed
376
383
key = ('ewise_mult' , op .name , * left .get_loop_key (), * right .get_loop_key ())
377
384
if key not in engine_cache :
@@ -433,9 +440,12 @@ def main(x, y):
433
440
def mxm (op : Semiring , left : Union [Matrix , TransposedMatrix ], right : Union [Matrix , TransposedMatrix ]):
434
441
assert left .ndims == right .ndims == 2
435
442
assert left .dtype == right .dtype
436
- assert left ._sparsity == right ._sparsity
437
443
438
- # TODO: handle case of either left or right not having an _obj -> result will be empty for mxm
444
+ optype = op .binop .get_output_type (left .dtype , right .dtype )
445
+ if left ._obj is None or right ._obj is None :
446
+ return Matrix .new (optype , left .shape [0 ], right .shape [1 ])
447
+
448
+ assert left ._sparsity == right ._sparsity
439
449
440
450
# Build and compile if needed
441
451
key = ('mxm' , op .name , * left .get_loop_key (), * right .get_loop_key ())
@@ -446,7 +456,7 @@ def mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union[Matrix
446
456
mem_out = get_sparse_output_pointer ()
447
457
arg_pointers = [left ._obj , right ._obj , mem_out ]
448
458
engine_cache [key ].invoke ('main' , * arg_pointers )
449
- return Matrix (op . binop . get_output_type ( left . dtype , right . dtype ) , [left .shape [0 ], right .shape [1 ]], mem_out ,
459
+ return Matrix (optype , [left .shape [0 ], right .shape [1 ]], mem_out ,
450
460
left ._sparsity , left .perceived_ordering , intermediate_result = True )
451
461
452
462
@@ -509,9 +519,10 @@ def main(x, y):
509
519
def mxv (op : Semiring , left : Union [Matrix , TransposedMatrix ], right : Vector ):
510
520
assert left .ndims == 2
511
521
assert right .ndims == 1
512
- assert left .dtype == right .dtype
513
522
514
- # TODO: handle case of either left or right not having an _obj -> result will be empty for mxv
523
+ optype = op .binop .get_output_type (left .dtype , right .dtype )
524
+ if left ._obj is None or right ._obj is None :
525
+ return Vector .new (optype , left .shape [0 ])
515
526
516
527
# Build and compile if needed
517
528
key = ('mxv' , op .name , * left .get_loop_key (), * right .get_loop_key ())
@@ -522,7 +533,7 @@ def mxv(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Vector):
522
533
mem_out = get_sparse_output_pointer ()
523
534
arg_pointers = [left ._obj , right ._obj , mem_out ]
524
535
engine_cache [key ].invoke ('main' , * arg_pointers )
525
- return Vector (op . binop . get_output_type ( left . dtype , right . dtype ) , [left .shape [0 ]], mem_out ,
536
+ return Vector (optype , [left .shape [0 ]], mem_out ,
526
537
right ._sparsity , right .perceived_ordering , intermediate_result = True )
527
538
528
539
@@ -583,9 +594,10 @@ def main(x, y):
583
594
def vxm (op : Semiring , left : Vector , right : Union [Matrix , TransposedMatrix ]):
584
595
assert left .ndims == 1
585
596
assert right .ndims == 2
586
- assert left .dtype == right .dtype
587
597
588
- # TODO: handle case of either left or right not having an _obj -> result will be empty for vxm
598
+ optype = op .binop .get_output_type (left .dtype , right .dtype )
599
+ if left ._obj is None or right ._obj is None :
600
+ return Vector .new (optype , right .shape [1 ])
589
601
590
602
# Build and compile if needed
591
603
key = ('vxm' , op .name , * left .get_loop_key (), * right .get_loop_key ())
@@ -596,7 +608,7 @@ def vxm(op: Semiring, left: Vector, right: Union[Matrix, TransposedMatrix]):
596
608
mem_out = get_sparse_output_pointer ()
597
609
arg_pointers = [left ._obj , right ._obj , mem_out ]
598
610
engine_cache [key ].invoke ('main' , * arg_pointers )
599
- return Vector (op . binop . get_output_type ( left . dtype , right . dtype ) , [right .shape [1 ]], mem_out ,
611
+ return Vector (optype , [right .shape [1 ]], mem_out ,
600
612
left ._sparsity , left .perceived_ordering , intermediate_result = True )
601
613
602
614
@@ -664,26 +676,34 @@ def apply(op: Union[UnaryOp, BinaryOp, IndexUnaryOp],
664
676
# TODO: implement this
665
677
raise NotImplementedError ("doesn't yet work for Scalar" )
666
678
667
- # TODO: handle case of empty input (must figure out correct output dtype)
668
-
669
- # Build and compile if needed
670
- # Note that Scalars are included in the key because they are inlined in the compiled code
679
+ # Find output dtype
671
680
optype = type (op )
672
681
if optype is UnaryOp :
673
- key = ('apply_unary' , op .name , * sp .get_loop_key (), inplace )
674
682
output_dtype = op .get_output_type (sp .dtype )
675
683
elif optype is BinaryOp :
676
684
if left is not None :
677
- key = ('apply_bind_first' , op .name , * sp .get_loop_key (), left ._obj , inplace )
678
685
output_dtype = op .get_output_type (left .dtype , sp .dtype )
679
686
else :
680
- key = ('apply_bind_second' , op .name , * sp .get_loop_key (), right ._obj , inplace )
681
687
output_dtype = op .get_output_type (sp .dtype , right .dtype )
682
688
else :
683
689
if inplace :
684
690
raise TypeError ("apply inplace not supported for IndexUnaryOp" )
685
- key = ('apply_indexunary' , op .name , * sp .get_loop_key (), thunk ._obj )
686
691
output_dtype = op .get_output_type (sp .dtype , thunk .dtype )
692
+
693
+ if sp ._obj is None :
694
+ return sp .baseclass (output_dtype , sp .shape )
695
+
696
+ # Build and compile if needed
697
+ # Note that Scalars are included in the key because they are inlined in the compiled code
698
+ if optype is UnaryOp :
699
+ key = ('apply_unary' , op .name , * sp .get_loop_key (), inplace )
700
+ elif optype is BinaryOp :
701
+ if left is not None :
702
+ key = ('apply_bind_first' , op .name , * sp .get_loop_key (), left ._obj , inplace )
703
+ else :
704
+ key = ('apply_bind_second' , op .name , * sp .get_loop_key (), right ._obj , inplace )
705
+ else :
706
+ key = ('apply_indexunary' , op .name , * sp .get_loop_key (), thunk ._obj )
687
707
if key not in engine_cache :
688
708
if inplace :
689
709
engine_cache [key ] = _build_apply_inplace (op , sp , left , right )
@@ -887,7 +907,8 @@ def main(x):
887
907
888
908
889
909
def reduce_to_vector (op : Monoid , mat : Union [Matrix , TransposedMatrix ]):
890
- # TODO: handle case of mat not having an _obj -> result will be empty vector
910
+ if mat ._obj is None :
911
+ return Vector .new (mat .dtype , mat .shape [0 ])
891
912
892
913
# Build and compile if needed
893
914
key = ('reduce_to_vector' , op .name , * mat .get_loop_key ())
@@ -944,7 +965,8 @@ def main(x):
944
965
945
966
946
967
def reduce_to_scalar (op : Monoid , sp : SparseTensorBase ):
947
- # TODO: handle case of sp not having an _obj -> result will be empty scalar
968
+ if sp ._obj is None :
969
+ return Scalar .new (sp .dtype )
948
970
949
971
# Build and compile if needed
950
972
key = ('reduce_to_scalar' , op .name , * sp .get_loop_key ())
0 commit comments