Skip to content

Commit a7d1408

Browse files
authored
Add tests for empty input (#7)
1 parent 1d24a87 commit a7d1408

File tree

5 files changed

+225
-39
lines changed

5 files changed

+225
-39
lines changed

Diff for: mlir_graphblas/implementations.py

+46-24
Original file line numberDiff line numberDiff line change
@@ -295,16 +295,19 @@ def main(x):
295295
def ewise_add(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
296296
assert left.ndims == right.ndims
297297
assert left.dtype == right.dtype
298+
299+
if left._obj is None:
300+
return right
301+
if right._obj is None:
302+
return left
303+
298304
assert left._sparsity == right._sparsity
299305

300306
rank = left.ndims
301307
if rank == 0: # Scalar
302308
# TODO: implement this
303309
raise NotImplementedError("doesn't yet work for Scalar")
304310

305-
# TODO: handle case of either left or right not having an _obj -> result will be other for ewise_add
306-
# or have a utility to build an empty MLIRSparseTensor for all input tensors?
307-
308311
# Build and compile if needed
309312
key = ('ewise_add', op.name, *left.get_loop_key(), *right.get_loop_key())
310313
if key not in engine_cache:
@@ -363,15 +366,19 @@ def main(x, y):
363366
def ewise_mult(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
364367
assert left.ndims == right.ndims
365368
assert left.dtype == right.dtype
369+
370+
if left._obj is None:
371+
return left
372+
if right._obj is None:
373+
return right
374+
366375
assert left._sparsity == right._sparsity
367376

368377
rank = left.ndims
369378
if rank == 0: # Scalar
370379
# TODO: implement this
371380
raise NotImplementedError("doesn't yet work for Scalar")
372381

373-
# TODO: handle case of either left or right not having an _obj -> result will be empty for ewise_mult
374-
375382
# Build and compile if needed
376383
key = ('ewise_mult', op.name, *left.get_loop_key(), *right.get_loop_key())
377384
if key not in engine_cache:
@@ -433,9 +440,12 @@ def main(x, y):
433440
def mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union[Matrix, TransposedMatrix]):
434441
assert left.ndims == right.ndims == 2
435442
assert left.dtype == right.dtype
436-
assert left._sparsity == right._sparsity
437443

438-
# TODO: handle case of either left or right not having an _obj -> result will be empty for mxm
444+
optype = op.binop.get_output_type(left.dtype, right.dtype)
445+
if left._obj is None or right._obj is None:
446+
return Matrix.new(optype, left.shape[0], right.shape[1])
447+
448+
assert left._sparsity == right._sparsity
439449

440450
# Build and compile if needed
441451
key = ('mxm', op.name, *left.get_loop_key(), *right.get_loop_key())
@@ -446,7 +456,7 @@ def mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union[Matrix
446456
mem_out = get_sparse_output_pointer()
447457
arg_pointers = [left._obj, right._obj, mem_out]
448458
engine_cache[key].invoke('main', *arg_pointers)
449-
return Matrix(op.binop.get_output_type(left.dtype, right.dtype), [left.shape[0], right.shape[1]], mem_out,
459+
return Matrix(optype, [left.shape[0], right.shape[1]], mem_out,
450460
left._sparsity, left.perceived_ordering, intermediate_result=True)
451461

452462

@@ -509,9 +519,10 @@ def main(x, y):
509519
def mxv(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Vector):
510520
assert left.ndims == 2
511521
assert right.ndims == 1
512-
assert left.dtype == right.dtype
513522

514-
# TODO: handle case of either left or right not having an _obj -> result will be empty for mxv
523+
optype = op.binop.get_output_type(left.dtype, right.dtype)
524+
if left._obj is None or right._obj is None:
525+
return Vector.new(optype, left.shape[0])
515526

516527
# Build and compile if needed
517528
key = ('mxv', op.name, *left.get_loop_key(), *right.get_loop_key())
@@ -522,7 +533,7 @@ def mxv(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Vector):
522533
mem_out = get_sparse_output_pointer()
523534
arg_pointers = [left._obj, right._obj, mem_out]
524535
engine_cache[key].invoke('main', *arg_pointers)
525-
return Vector(op.binop.get_output_type(left.dtype, right.dtype), [left.shape[0]], mem_out,
536+
return Vector(optype, [left.shape[0]], mem_out,
526537
right._sparsity, right.perceived_ordering, intermediate_result=True)
527538

528539

@@ -583,9 +594,10 @@ def main(x, y):
583594
def vxm(op: Semiring, left: Vector, right: Union[Matrix, TransposedMatrix]):
584595
assert left.ndims == 1
585596
assert right.ndims == 2
586-
assert left.dtype == right.dtype
587597

588-
# TODO: handle case of either left or right not having an _obj -> result will be empty for vxm
598+
optype = op.binop.get_output_type(left.dtype, right.dtype)
599+
if left._obj is None or right._obj is None:
600+
return Vector.new(optype, right.shape[1])
589601

590602
# Build and compile if needed
591603
key = ('vxm', op.name, *left.get_loop_key(), *right.get_loop_key())
@@ -596,7 +608,7 @@ def vxm(op: Semiring, left: Vector, right: Union[Matrix, TransposedMatrix]):
596608
mem_out = get_sparse_output_pointer()
597609
arg_pointers = [left._obj, right._obj, mem_out]
598610
engine_cache[key].invoke('main', *arg_pointers)
599-
return Vector(op.binop.get_output_type(left.dtype, right.dtype), [right.shape[1]], mem_out,
611+
return Vector(optype, [right.shape[1]], mem_out,
600612
left._sparsity, left.perceived_ordering, intermediate_result=True)
601613

602614

@@ -664,26 +676,34 @@ def apply(op: Union[UnaryOp, BinaryOp, IndexUnaryOp],
664676
# TODO: implement this
665677
raise NotImplementedError("doesn't yet work for Scalar")
666678

667-
# TODO: handle case of empty input (must figure out correct output dtype)
668-
669-
# Build and compile if needed
670-
# Note that Scalars are included in the key because they are inlined in the compiled code
679+
# Find output dtype
671680
optype = type(op)
672681
if optype is UnaryOp:
673-
key = ('apply_unary', op.name, *sp.get_loop_key(), inplace)
674682
output_dtype = op.get_output_type(sp.dtype)
675683
elif optype is BinaryOp:
676684
if left is not None:
677-
key = ('apply_bind_first', op.name, *sp.get_loop_key(), left._obj, inplace)
678685
output_dtype = op.get_output_type(left.dtype, sp.dtype)
679686
else:
680-
key = ('apply_bind_second', op.name, *sp.get_loop_key(), right._obj, inplace)
681687
output_dtype = op.get_output_type(sp.dtype, right.dtype)
682688
else:
683689
if inplace:
684690
raise TypeError("apply inplace not supported for IndexUnaryOp")
685-
key = ('apply_indexunary', op.name, *sp.get_loop_key(), thunk._obj)
686691
output_dtype = op.get_output_type(sp.dtype, thunk.dtype)
692+
693+
if sp._obj is None:
694+
return sp.baseclass(output_dtype, sp.shape)
695+
696+
# Build and compile if needed
697+
# Note that Scalars are included in the key because they are inlined in the compiled code
698+
if optype is UnaryOp:
699+
key = ('apply_unary', op.name, *sp.get_loop_key(), inplace)
700+
elif optype is BinaryOp:
701+
if left is not None:
702+
key = ('apply_bind_first', op.name, *sp.get_loop_key(), left._obj, inplace)
703+
else:
704+
key = ('apply_bind_second', op.name, *sp.get_loop_key(), right._obj, inplace)
705+
else:
706+
key = ('apply_indexunary', op.name, *sp.get_loop_key(), thunk._obj)
687707
if key not in engine_cache:
688708
if inplace:
689709
engine_cache[key] = _build_apply_inplace(op, sp, left, right)
@@ -887,7 +907,8 @@ def main(x):
887907

888908

889909
def reduce_to_vector(op: Monoid, mat: Union[Matrix, TransposedMatrix]):
890-
# TODO: handle case of mat not having an _obj -> result will be empty vector
910+
if mat._obj is None:
911+
return Vector.new(mat.dtype, mat.shape[0])
891912

892913
# Build and compile if needed
893914
key = ('reduce_to_vector', op.name, *mat.get_loop_key())
@@ -944,7 +965,8 @@ def main(x):
944965

945966

946967
def reduce_to_scalar(op: Monoid, sp: SparseTensorBase):
947-
# TODO: handle case of sp not having an _obj -> result will be empty scalar
968+
if sp._obj is None:
969+
return Scalar.new(sp.dtype)
948970

949971
# Build and compile if needed
950972
key = ('reduce_to_scalar', op.name, *sp.get_loop_key())

Diff for: mlir_graphblas/operations.py

+4
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ def update(output: SparseObject,
116116
output._replace(impl.select_by_mask(output, mask, desc))
117117
result = impl.ewise_add(accum, output, tensor)
118118

119+
if result is output:
120+
# This can happen if empty tensors are used as input
121+
return output
122+
119123
# If not an intermediate result, make a copy
120124
if not result._intermediate_result:
121125
result = impl.dup(result)

Diff for: mlir_graphblas/tensor.py

+22-13
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ def clear(self):
132132
def _replace(self, tensor: SparseObject):
133133
if not tensor._intermediate_result and tensor._obj is not None:
134134
raise ValueError("Can only replace using intermediate values")
135+
if tensor.shape != self.shape:
136+
raise GrbDimensionMismatch("Replace only allowed with same shape object")
137+
if tensor.dtype != self.dtype:
138+
raise GrbDomainMismatch("Replace only allowed with same dtype")
135139
self.clear()
136140
self._obj = tensor._obj
137141
self._sparsity = tensor._sparsity
@@ -150,18 +154,19 @@ def _to_sparse_tensor(self, np_indices, np_values, sparsity, ordering):
150154
np_perm = np.array(ordering, dtype=np.uint64)
151155

152156
# Validate indices are within range
153-
if np_indices.min() < 0:
154-
raise GrbIndexOutOfBounds(f"negative indices not allowed: {np_indices.min()}")
155-
if self.ndims == 2:
156-
max_row = np_indices[:, 0].max()
157-
max_col = np_indices[:, 1].max()
158-
if max_row >= self.shape[0]:
159-
raise GrbIndexOutOfBounds(f"row index out of bounds: {max_row} >= {self.shape[0]}")
160-
if max_col >= self.shape[1]:
161-
raise GrbIndexOutOfBounds(f"col index out of bound: {max_col} >= {self.shape[1]}")
162-
else:
163-
if np_indices.max() >= self.shape[0]:
164-
raise GrbIndexOutOfBounds(f"index out of bounds: {np_indices.max()} >= {self.shape[0]}")
157+
if len(np_indices) > 0:
158+
if np_indices.min() < 0:
159+
raise GrbIndexOutOfBounds(f"negative indices not allowed: {np_indices.min()}")
160+
if self.ndims == 2:
161+
max_row = np_indices[:, 0].max()
162+
max_col = np_indices[:, 1].max()
163+
if max_row >= self.shape[0]:
164+
raise GrbIndexOutOfBounds(f"row index out of bounds: {max_row} >= {self.shape[0]}")
165+
if max_col >= self.shape[1]:
166+
raise GrbIndexOutOfBounds(f"col index out of bound: {max_col} >= {self.shape[1]}")
167+
else:
168+
if np_indices.max() >= self.shape[0]:
169+
raise GrbIndexOutOfBounds(f"index out of bounds: {np_indices.max()} >= {self.shape[0]}")
165170

166171
rank = ctypes.c_ulonglong(len(np_shape))
167172
nse = ctypes.c_ulonglong(len(np_values))
@@ -230,7 +235,7 @@ def nvals(self):
230235
return 1
231236

232237
def set_element(self, val):
233-
self._obj = self.dtype.np_type(val)
238+
self._obj = val if val is None else self.dtype.np_type(val)
234239

235240
def extract_element(self):
236241
if self._obj is None:
@@ -468,9 +473,13 @@ def _compute_format(self):
468473
return '?'
469474

470475
def is_rowwise(self):
476+
if self._ordering is None:
477+
return True
471478
return tuple(self._ordering) == self.permutation
472479

473480
def is_colwise(self):
481+
if self._ordering is None:
482+
return True
474483
return tuple(self._ordering) != self.permutation
475484

476485
def nrows(self):

0 commit comments

Comments
 (0)