Skip to content

Commit 0bfcb66

Browse files
authored
Support A.power(0) (python-graphblas#518)
1 parent 8a80032 commit 0bfcb66

File tree

3 files changed

+33
-11
lines changed

3 files changed

+33
-11
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ repos:
6666
- id: black
6767
- id: black-jupyter
6868
- repo: https://github.com/astral-sh/ruff-pre-commit
69-
rev: v0.1.3
69+
rev: v0.1.4
7070
hooks:
7171
- id: ruff
7272
args: [--fix-only, --show-fixes]
@@ -94,7 +94,7 @@ repos:
9494
additional_dependencies: [tomli]
9595
files: ^(graphblas|docs)/
9696
- repo: https://github.com/astral-sh/ruff-pre-commit
97-
rev: v0.1.3
97+
rev: v0.1.4
9898
hooks:
9999
- id: ruff
100100
- repo: https://github.com/sphinx-contrib/sphinx-lint

graphblas/core/matrix.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ def _reposition(updater, indices, chunk):
101101

102102
def _power(updater, A, n, op):
103103
opts = updater.opts
104+
if n == 0:
105+
v = Vector.from_scalar(op.binaryop.monoid.identity, A._nrows, A.dtype, name="v_diag")
106+
updater << v.diag(name="M_diag")
107+
return
104108
if n == 1:
105109
updater << A
106110
return
@@ -2895,7 +2899,11 @@ def power(self, n, op=semiring.plus_times):
28952899
Parameters
28962900
----------
28972901
n : int
2898-
The exponent must be a positive integer.
2902+
The exponent must be a nonnegative integer. If n=0, the result will be a diagonal
2903+
matrix with values equal to the identity of the semiring's binary operator.
2904+
For example, ``plus_times`` will have diagonal values of 1, which is the
2905+
identity of ``times``. The binary operator must be associated with a monoid
2906+
when n=0 so the identity can be determined; otherwise, ValueError is raised.
28992907
op : :class:`~graphblas.core.operator.Semiring`
29002908
Semiring used in the computation
29012909
@@ -2923,11 +2931,17 @@ def power(self, n, op=semiring.plus_times):
29232931
if self._nrows != self._ncols:
29242932
raise DimensionMismatch(f"power only works for square Matrix; shape is {self.shape}")
29252933
if (N := maybe_integral(n)) is None:
2926-
raise TypeError(f"n must be a positive integer; got bad type: {type(n)}")
2927-
if N <= 0:
2928-
raise ValueError(f"n must be a positive integer; got: {N}")
2934+
raise TypeError(f"n must be a nonnegative integer; got bad type: {type(n)}")
2935+
if N < 0:
2936+
raise ValueError(f"n must be a nonnegative integer; got: {N}")
29292937
op = get_typed_op(op, self.dtype, kind="semiring")
29302938
self._expect_op(op, "Semiring", within=method_name, argname="op")
2939+
if N == 0 and op.binaryop.monoid is None:
2940+
raise ValueError(
2941+
f"Binary operator of {op} semiring does not have a monoid with an identity. "
2942+
"When n=0, the result is a diagonal matrix with values equal to the "
2943+
"identity of the binaryop, so the binaryop must be associated with a monoid."
2944+
)
29312945
return MatrixExpression(
29322946
"power",
29332947
None,

graphblas/tests/test_matrix.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4402,14 +4402,22 @@ def test_power(A):
44024402
result = A.power(i, semiring.min_plus).new()
44034403
assert result.isequal(expected)
44044404
expected << semiring.min_plus(A @ expected)
4405+
# n == 0
4406+
result = A.power(0).new()
4407+
expected = Vector.from_scalar(1, A.nrows, A.dtype).diag()
4408+
assert result.isequal(expected)
4409+
result = A.power(0, semiring.plus_min).new()
4410+
identity = semiring.plus_min[A.dtype].binaryop.monoid.identity
4411+
assert identity != 1
4412+
expected = Vector.from_scalar(identity, A.nrows, A.dtype).diag()
4413+
assert result.isequal(expected)
44054414
# Exceptional
4406-
with pytest.raises(TypeError, match="must be a positive integer"):
4415+
with pytest.raises(TypeError, match="must be a nonnegative integer"):
44074416
A.power(1.5)
4408-
with pytest.raises(ValueError, match="must be a positive integer"):
4417+
with pytest.raises(ValueError, match="must be a nonnegative integer"):
44094418
A.power(-1)
4410-
with pytest.raises(ValueError, match="must be a positive integer"):
4411-
# Not implemented yet... could create identity matrix
4412-
A.power(0)
4419+
with pytest.raises(ValueError, match="binaryop must be associated with a monoid"):
4420+
A.power(0, semiring.min_first)
44134421
B = A[:2, :3].new()
44144422
with pytest.raises(DimensionMismatch):
44154423
B.power(2)

0 commit comments

Comments
 (0)