17
17
from .compiler import compile , engine_cache
18
18
from . descriptor import Descriptor , NULL as NULL_DESC
19
19
from .utils import (get_sparse_output_pointer , get_scalar_output_pointer ,
20
- get_scalar_input_arg , pick_and_renumber_indices )
20
+ get_scalar_input_arg , pick_and_renumber_indices , determine_sparsity )
21
21
from .types import RankedTensorType , BOOL , INT64 , FP64
22
22
from .exceptions import GrbError , GrbIndexOutOfBounds , GrbDimensionMismatch
23
23
@@ -34,7 +34,6 @@ def select_by_mask(sp: SparseTensorBase, mask: SparseTensor, desc: Descriptor =
34
34
in `sp` correspond to missing or "falsy" elements in the mask.
35
35
"""
36
36
assert mask .ndims == sp .ndims
37
- assert mask ._sparsity == sp ._sparsity
38
37
if mask .shape != sp .shape :
39
38
raise GrbDimensionMismatch (f"Mask shape mismatch: { mask .shape } != { sp .shape } " )
40
39
@@ -62,7 +61,7 @@ def select_by_mask(sp: SparseTensorBase, mask: SparseTensor, desc: Descriptor =
62
61
mem_out = get_sparse_output_pointer ()
63
62
arg_pointers = [mask ._obj , sp ._obj , mem_out ]
64
63
engine_cache [key ].invoke ('main' , * arg_pointers )
65
- return mask .baseclass (sp .dtype , mask .shape , mem_out , mask . _sparsity ,
64
+ return mask .baseclass (sp .dtype , mask .shape , mem_out , determine_sparsity ( mask , sp ) ,
66
65
mask .perceived_ordering , intermediate_result = True )
67
66
68
67
@@ -80,7 +79,8 @@ def _build_select_by_mask(mask: SparseTensor, sp: SparseTensorBase, complement:
80
79
perm_out = ir .AffineMap .get_permutation (range (rank ))
81
80
rtt_sp = sp .rtt .as_mlir_type ()
82
81
rtt_mask = mask .rtt .as_mlir_type ()
83
- rtt_out = mask .rtt .copy (dtype = sp .dtype ).as_mlir_type ()
82
+ rtt_out = mask .rtt .copy (dtype = sp .dtype ,
83
+ sparsity = determine_sparsity (mask , sp )).as_mlir_type ()
84
84
85
85
@func .FuncOp .from_py_func (rtt_mask , rtt_sp )
86
86
def main (msk , x ):
@@ -368,8 +368,6 @@ def ewise_add(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
368
368
engine_cache [key ].invoke ('main' , * arg_pointers )
369
369
return Scalar (left .dtype , (), left .dtype .np_type (mem_out .contents .value ))
370
370
371
- assert left ._sparsity == right ._sparsity
372
-
373
371
# Build and compile if needed
374
372
key = ('ewise_add' , op .name , * left .get_loop_key (), * right .get_loop_key ())
375
373
if key not in engine_cache :
@@ -380,7 +378,8 @@ def ewise_add(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
380
378
arg_pointers = [left ._obj , right ._obj , mem_out ]
381
379
engine_cache [key ].invoke ('main' , * arg_pointers )
382
380
return left .baseclass (op .get_output_type (left .dtype , right .dtype ), left .shape , mem_out ,
383
- left ._sparsity , left .perceived_ordering , intermediate_result = True )
381
+ determine_sparsity (left , right , union = True ), left .perceived_ordering ,
382
+ intermediate_result = True )
384
383
385
384
386
385
def _build_ewise_add (op : BinaryOp , left : SparseTensorBase , right : SparseTensorBase ):
@@ -395,7 +394,8 @@ def _build_ewise_add(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBa
395
394
perm_out = ir .AffineMap .get_permutation (range (rank ))
396
395
rtt_left = left .rtt .as_mlir_type ()
397
396
rtt_right = right .rtt .as_mlir_type ()
398
- rtt_out = left .rtt .copy (ordering = left .perceived_ordering ).as_mlir_type ()
397
+ rtt_out = left .rtt .copy (ordering = left .perceived_ordering ,
398
+ sparsity = determine_sparsity (left , right , union = True )).as_mlir_type ()
399
399
400
400
@func .FuncOp .from_py_func (rtt_left , rtt_right )
401
401
def main (x , y ):
@@ -443,8 +443,6 @@ def ewise_mult(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
443
443
engine_cache [key ].invoke ('main' , * arg_pointers )
444
444
return Scalar (output_dtype , (), output_dtype .np_type (mem_out .contents .value ))
445
445
446
- assert left ._sparsity == right ._sparsity
447
-
448
446
# Build and compile if needed
449
447
key = ('ewise_mult' , op .name , * left .get_loop_key (), * right .get_loop_key ())
450
448
if key not in engine_cache :
@@ -455,7 +453,8 @@ def ewise_mult(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
455
453
arg_pointers = [left ._obj , right ._obj , mem_out ]
456
454
engine_cache [key ].invoke ('main' , * arg_pointers )
457
455
return left .baseclass (output_dtype , left .shape , mem_out ,
458
- left ._sparsity , left .perceived_ordering , intermediate_result = True )
456
+ determine_sparsity (left , right ), left .perceived_ordering ,
457
+ intermediate_result = True )
459
458
460
459
461
460
def _build_ewise_mult (op : BinaryOp , left : SparseTensorBase , right : SparseTensorBase ):
@@ -472,7 +471,9 @@ def _build_ewise_mult(op: BinaryOp, left: SparseTensorBase, right: SparseTensorB
472
471
perm_out = ir .AffineMap .get_permutation (range (rank ))
473
472
rtt_left = left .rtt .as_mlir_type ()
474
473
rtt_right = right .rtt .as_mlir_type ()
475
- rtt_out = left .rtt .copy (dtype = op_result_dtype , ordering = left .perceived_ordering ).as_mlir_type ()
474
+ rtt_out = RankedTensorType (dtype = op_result_dtype ,
475
+ sparsity = determine_sparsity (left , right ),
476
+ ordering = left .perceived_ordering ).as_mlir_type ()
476
477
477
478
@func .FuncOp .from_py_func (rtt_left , rtt_right )
478
479
def main (x , y ):
@@ -511,8 +512,6 @@ def mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union[Matrix
511
512
if left ._obj is None or right ._obj is None :
512
513
return Matrix .new (optype , left .shape [0 ], right .shape [1 ])
513
514
514
- assert left ._sparsity == right ._sparsity
515
-
516
515
# Build and compile if needed
517
516
key = ('mxm' , op .name , * left .get_loop_key (), * right .get_loop_key ())
518
517
if key not in engine_cache :
@@ -523,7 +522,7 @@ def mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union[Matrix
523
522
arg_pointers = [left ._obj , right ._obj , mem_out ]
524
523
engine_cache [key ].invoke ('main' , * arg_pointers )
525
524
return Matrix (optype , [left .shape [0 ], right .shape [1 ]], mem_out ,
526
- left . _sparsity , left .perceived_ordering , intermediate_result = True )
525
+ determine_sparsity ( left , right ) , left .perceived_ordering , intermediate_result = True )
527
526
528
527
529
528
def _build_mxm (op : Semiring , left : Union [Matrix , TransposedMatrix ], right : Union [Matrix , TransposedMatrix ]):
@@ -539,7 +538,9 @@ def _build_mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union
539
538
perm_out = ir .AffineMap .get (3 , 0 , [ir .AffineDimExpr .get (0 ), ir .AffineDimExpr .get (1 )])
540
539
rtt_left = left .rtt .as_mlir_type ()
541
540
rtt_right = right .rtt .as_mlir_type ()
542
- rtt_out = left .rtt .copy (dtype = op_result_dtype , ordering = left .perceived_ordering ).as_mlir_type ()
541
+ rtt_out = RankedTensorType (dtype = op_result_dtype ,
542
+ sparsity = determine_sparsity (left , right ),
543
+ ordering = left .perceived_ordering ).as_mlir_type ()
543
544
544
545
@func .FuncOp .from_py_func (rtt_left , rtt_right )
545
546
def main (x , y ):
@@ -1223,18 +1224,18 @@ def assign(tensor: SparseTensorBase, row_indices, col_indices, row_size, col_siz
1223
1224
v = Vector .new (tensor .dtype , row_size )
1224
1225
# Map idx to output indices
1225
1226
idx = np .array (row_indices , dtype = np .uint64 )[idx ]
1226
- v .build (idx , vals )
1227
+ v .build (idx , vals , sparsity = tensor . _sparsity )
1227
1228
return v
1228
1229
# Assign Vector as row or column of Matrix
1229
1230
m = Matrix .new (tensor .dtype , row_size , col_size )
1230
1231
if type (row_indices ) is int :
1231
1232
# Map idx to output cols
1232
1233
colidx = idx if col_indices is None else np .array (col_indices , dtype = np .uint64 )[idx ]
1233
- m .build ([row_indices ]* len (vals ), colidx , vals )
1234
+ m .build ([row_indices ]* len (vals ), colidx , vals , sparsity = [ "compressed" , "compressed" ] )
1234
1235
if type (col_indices ) is int :
1235
1236
# Map idx to output rows
1236
1237
rowidx = idx if row_indices is None else np .array (row_indices , dtype = np .uint64 )[idx ]
1237
- m .build (rowidx , [col_indices ]* len (vals ), vals )
1238
+ m .build (rowidx , [col_indices ]* len (vals ), vals , sparsity = [ "compressed" , "compressed" ] )
1238
1239
return m
1239
1240
1240
1241
# Matrix input
@@ -1249,5 +1250,5 @@ def assign(tensor: SparseTensorBase, row_indices, col_indices, row_size, col_siz
1249
1250
if col_indices is not None :
1250
1251
colidx = np .array (col_indices , dtype = np .uint64 )[colidx ]
1251
1252
m = Matrix .new (tensor .dtype , row_size , col_size )
1252
- m .build (rowidx , colidx , vals )
1253
+ m .build (rowidx , colidx , vals , sparsity = [ "compressed" , "compressed" ] )
1253
1254
return m
0 commit comments