Skip to content

Commit f6f8d13

Browse files
authored
Remove restriction on operation arguments having the same sparsity (#10)
1 parent 4175b01 commit f6f8d13

File tree

3 files changed

+50
-21
lines changed

3 files changed

+50
-21
lines changed

Diff for: mlir_graphblas/implementations.py

+21-20
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .compiler import compile, engine_cache
1818
from . descriptor import Descriptor, NULL as NULL_DESC
1919
from .utils import (get_sparse_output_pointer, get_scalar_output_pointer,
20-
get_scalar_input_arg, pick_and_renumber_indices)
20+
get_scalar_input_arg, pick_and_renumber_indices, determine_sparsity)
2121
from .types import RankedTensorType, BOOL, INT64, FP64
2222
from .exceptions import GrbError, GrbIndexOutOfBounds, GrbDimensionMismatch
2323

@@ -34,7 +34,6 @@ def select_by_mask(sp: SparseTensorBase, mask: SparseTensor, desc: Descriptor =
3434
in `sp` correspond to missing or "falsy" elements in the mask.
3535
"""
3636
assert mask.ndims == sp.ndims
37-
assert mask._sparsity == sp._sparsity
3837
if mask.shape != sp.shape:
3938
raise GrbDimensionMismatch(f"Mask shape mismatch: {mask.shape} != {sp.shape}")
4039

@@ -62,7 +61,7 @@ def select_by_mask(sp: SparseTensorBase, mask: SparseTensor, desc: Descriptor =
6261
mem_out = get_sparse_output_pointer()
6362
arg_pointers = [mask._obj, sp._obj, mem_out]
6463
engine_cache[key].invoke('main', *arg_pointers)
65-
return mask.baseclass(sp.dtype, mask.shape, mem_out, mask._sparsity,
64+
return mask.baseclass(sp.dtype, mask.shape, mem_out, determine_sparsity(mask, sp),
6665
mask.perceived_ordering, intermediate_result=True)
6766

6867

@@ -80,7 +79,8 @@ def _build_select_by_mask(mask: SparseTensor, sp: SparseTensorBase, complement:
8079
perm_out = ir.AffineMap.get_permutation(range(rank))
8180
rtt_sp = sp.rtt.as_mlir_type()
8281
rtt_mask = mask.rtt.as_mlir_type()
83-
rtt_out = mask.rtt.copy(dtype=sp.dtype).as_mlir_type()
82+
rtt_out = mask.rtt.copy(dtype=sp.dtype,
83+
sparsity=determine_sparsity(mask, sp)).as_mlir_type()
8484

8585
@func.FuncOp.from_py_func(rtt_mask, rtt_sp)
8686
def main(msk, x):
@@ -368,8 +368,6 @@ def ewise_add(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
368368
engine_cache[key].invoke('main', *arg_pointers)
369369
return Scalar(left.dtype, (), left.dtype.np_type(mem_out.contents.value))
370370

371-
assert left._sparsity == right._sparsity
372-
373371
# Build and compile if needed
374372
key = ('ewise_add', op.name, *left.get_loop_key(), *right.get_loop_key())
375373
if key not in engine_cache:
@@ -380,7 +378,8 @@ def ewise_add(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
380378
arg_pointers = [left._obj, right._obj, mem_out]
381379
engine_cache[key].invoke('main', *arg_pointers)
382380
return left.baseclass(op.get_output_type(left.dtype, right.dtype), left.shape, mem_out,
383-
left._sparsity, left.perceived_ordering, intermediate_result=True)
381+
determine_sparsity(left, right, union=True), left.perceived_ordering,
382+
intermediate_result=True)
384383

385384

386385
def _build_ewise_add(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
@@ -395,7 +394,8 @@ def _build_ewise_add(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBa
395394
perm_out = ir.AffineMap.get_permutation(range(rank))
396395
rtt_left = left.rtt.as_mlir_type()
397396
rtt_right = right.rtt.as_mlir_type()
398-
rtt_out = left.rtt.copy(ordering=left.perceived_ordering).as_mlir_type()
397+
rtt_out = left.rtt.copy(ordering=left.perceived_ordering,
398+
sparsity=determine_sparsity(left, right, union=True)).as_mlir_type()
399399

400400
@func.FuncOp.from_py_func(rtt_left, rtt_right)
401401
def main(x, y):
@@ -443,8 +443,6 @@ def ewise_mult(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
443443
engine_cache[key].invoke('main', *arg_pointers)
444444
return Scalar(output_dtype, (), output_dtype.np_type(mem_out.contents.value))
445445

446-
assert left._sparsity == right._sparsity
447-
448446
# Build and compile if needed
449447
key = ('ewise_mult', op.name, *left.get_loop_key(), *right.get_loop_key())
450448
if key not in engine_cache:
@@ -455,7 +453,8 @@ def ewise_mult(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
455453
arg_pointers = [left._obj, right._obj, mem_out]
456454
engine_cache[key].invoke('main', *arg_pointers)
457455
return left.baseclass(output_dtype, left.shape, mem_out,
458-
left._sparsity, left.perceived_ordering, intermediate_result=True)
456+
determine_sparsity(left, right), left.perceived_ordering,
457+
intermediate_result=True)
459458

460459

461460
def _build_ewise_mult(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
@@ -472,7 +471,9 @@ def _build_ewise_mult(op: BinaryOp, left: SparseTensorBase, right: SparseTensorB
472471
perm_out = ir.AffineMap.get_permutation(range(rank))
473472
rtt_left = left.rtt.as_mlir_type()
474473
rtt_right = right.rtt.as_mlir_type()
475-
rtt_out = left.rtt.copy(dtype=op_result_dtype, ordering=left.perceived_ordering).as_mlir_type()
474+
rtt_out = RankedTensorType(dtype=op_result_dtype,
475+
sparsity=determine_sparsity(left, right),
476+
ordering=left.perceived_ordering).as_mlir_type()
476477

477478
@func.FuncOp.from_py_func(rtt_left, rtt_right)
478479
def main(x, y):
@@ -511,8 +512,6 @@ def mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union[Matrix
511512
if left._obj is None or right._obj is None:
512513
return Matrix.new(optype, left.shape[0], right.shape[1])
513514

514-
assert left._sparsity == right._sparsity
515-
516515
# Build and compile if needed
517516
key = ('mxm', op.name, *left.get_loop_key(), *right.get_loop_key())
518517
if key not in engine_cache:
@@ -523,7 +522,7 @@ def mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union[Matrix
523522
arg_pointers = [left._obj, right._obj, mem_out]
524523
engine_cache[key].invoke('main', *arg_pointers)
525524
return Matrix(optype, [left.shape[0], right.shape[1]], mem_out,
526-
left._sparsity, left.perceived_ordering, intermediate_result=True)
525+
determine_sparsity(left, right), left.perceived_ordering, intermediate_result=True)
527526

528527

529528
def _build_mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union[Matrix, TransposedMatrix]):
@@ -539,7 +538,9 @@ def _build_mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union
539538
perm_out = ir.AffineMap.get(3, 0, [ir.AffineDimExpr.get(0), ir.AffineDimExpr.get(1)])
540539
rtt_left = left.rtt.as_mlir_type()
541540
rtt_right = right.rtt.as_mlir_type()
542-
rtt_out = left.rtt.copy(dtype=op_result_dtype, ordering=left.perceived_ordering).as_mlir_type()
541+
rtt_out = RankedTensorType(dtype=op_result_dtype,
542+
sparsity=determine_sparsity(left, right),
543+
ordering=left.perceived_ordering).as_mlir_type()
543544

544545
@func.FuncOp.from_py_func(rtt_left, rtt_right)
545546
def main(x, y):
@@ -1223,18 +1224,18 @@ def assign(tensor: SparseTensorBase, row_indices, col_indices, row_size, col_siz
12231224
v = Vector.new(tensor.dtype, row_size)
12241225
# Map idx to output indices
12251226
idx = np.array(row_indices, dtype=np.uint64)[idx]
1226-
v.build(idx, vals)
1227+
v.build(idx, vals, sparsity=tensor._sparsity)
12271228
return v
12281229
# Assign Vector as row or column of Matrix
12291230
m = Matrix.new(tensor.dtype, row_size, col_size)
12301231
if type(row_indices) is int:
12311232
# Map idx to output cols
12321233
colidx = idx if col_indices is None else np.array(col_indices, dtype=np.uint64)[idx]
1233-
m.build([row_indices]*len(vals), colidx, vals)
1234+
m.build([row_indices]*len(vals), colidx, vals, sparsity=["compressed", "compressed"])
12341235
if type(col_indices) is int:
12351236
# Map idx to output rows
12361237
rowidx = idx if row_indices is None else np.array(row_indices, dtype=np.uint64)[idx]
1237-
m.build(rowidx, [col_indices]*len(vals), vals)
1238+
m.build(rowidx, [col_indices]*len(vals), vals, sparsity=["compressed", "compressed"])
12381239
return m
12391240

12401241
# Matrix input
@@ -1249,5 +1250,5 @@ def assign(tensor: SparseTensorBase, row_indices, col_indices, row_size, col_siz
12491250
if col_indices is not None:
12501251
colidx = np.array(col_indices, dtype=np.uint64)[colidx]
12511252
m = Matrix.new(tensor.dtype, row_size, col_size)
1252-
m.build(rowidx, colidx, vals)
1253+
m.build(rowidx, colidx, vals, sparsity=["compressed", "compressed"])
12531254
return m

Diff for: mlir_graphblas/operations.py

-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from .utils import ensure_scalar_of_type, ensure_unique
1212
from .types import BOOL, INT64
1313
from .operators import BinaryOp, SelectOp
14-
from mlir.dialects.sparse_tensor import DimLevelType
1514

1615

1716
__all__ = ["transpose", "ewise_add", "ewise_mult", "mxm", "apply", "select",

Diff for: mlir_graphblas/utils.py

+29
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
from enum import Enum
44
from mlir import ir
5+
from mlir.dialects.sparse_tensor import DimLevelType
56
from .exceptions import (
67
GrbNullPointer, GrbInvalidValue, GrbInvalidIndex, GrbDomainMismatch,
78
GrbDimensionMismatch, GrbOutputNotEmpty, GrbIndexOutOfBounds, GrbEmptyObject
@@ -49,6 +50,34 @@ def ensure_unique(indices, name=None):
4950
raise ValueError(f"Found duplicate indices{name_str}: {unique[counts > 1]}")
5051

5152

53+
def determine_sparsity(left, right, union=False):
54+
"""
55+
Returns the sparsity appropriate for the two inputs based on `union`.
56+
If union == True, finds a sparsity that anticipates more values than either input.
57+
If union == False, finds a sparsity that anticipates the same or fewer values than either input.
58+
"""
59+
assert left.ndims == right.ndims
60+
assert left.ndims > 0, "Not allowed for Scalars"
61+
if left._sparsity == right._sparsity:
62+
return left._sparsity
63+
64+
dense = DimLevelType.dense
65+
comp = DimLevelType.compressed
66+
67+
if left.ndims == 1: # Vector
68+
levels = ([comp], [dense])
69+
else: # Matrix
70+
levels = ([comp, comp], [comp, dense], [dense, comp], [dense, dense])
71+
72+
if union:
73+
levels = reversed(levels)
74+
for lvl in levels:
75+
if left._sparsity == lvl or right._sparsity == lvl:
76+
return lvl
77+
78+
raise Exception("something went wrong finding the sparsity")
79+
80+
5281
def pick_and_renumber_indices(selected, indices, *related):
5382
"""
5483
This function is used by the `extract` operation.

0 commit comments

Comments
 (0)