Skip to content

Commit 5d3b7a8

Browse files
committed
Fix the return type of e.g. agg.count to be INT64 by default
This should fix python-graphblas/graphblas-algorithms#82
1 parent 0bfcb66 commit 5d3b7a8

File tree

5 files changed

+26
-8
lines changed

5 files changed

+26
-8
lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ repos:
4646
# We can probably remove `isort` if we come to trust `ruff --fix`,
4747
# but we'll need to figure out the configuration to do this in `ruff`
4848
- repo: https://github.com/pycqa/isort
49-
rev: 5.12.0
49+
rev: 5.13.1
5050
hooks:
5151
- id: isort
5252
# Let's keep `pyupgrade` even though `ruff --fix` probably does most of it
@@ -61,12 +61,12 @@ repos:
6161
- id: auto-walrus
6262
args: [--line-length, "100"]
6363
- repo: https://github.com/psf/black
64-
rev: 23.10.1
64+
rev: 23.12.0
6565
hooks:
6666
- id: black
6767
- id: black-jupyter
6868
- repo: https://github.com/astral-sh/ruff-pre-commit
69-
rev: v0.1.4
69+
rev: v0.1.7
7070
hooks:
7171
- id: ruff
7272
args: [--fix-only, --show-fixes]
@@ -79,7 +79,7 @@ repos:
7979
additional_dependencies: &flake8_dependencies
8080
# These versions need updated manually
8181
- flake8==6.1.0
82-
- flake8-bugbear==23.9.16
82+
- flake8-bugbear==23.12.2
8383
- flake8-simplify==0.21.0
8484
- repo: https://github.com/asottile/yesqa
8585
rev: v1.5.0
@@ -94,11 +94,11 @@ repos:
9494
additional_dependencies: [tomli]
9595
files: ^(graphblas|docs)/
9696
- repo: https://github.com/astral-sh/ruff-pre-commit
97-
rev: v0.1.4
97+
rev: v0.1.7
9898
hooks:
9999
- id: ruff
100100
- repo: https://github.com/sphinx-contrib/sphinx-lint
101-
rev: v0.8.1
101+
rev: v0.9.1
102102
hooks:
103103
- id: sphinx-lint
104104
args: [--enable, all, "--disable=line-too-long,leaked-markup"]

graphblas/core/operator/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def get_typed_op(op, dtype, dtype2=None, *, is_left_scalar=False, is_right_scala
7575
from .agg import Aggregator, TypedAggregator
7676

7777
if isinstance(op, Aggregator):
78+
# agg._any_dtype basically serves the same purpose as op._custom_dtype
79+
if op._any_dtype is not None and op._any_dtype is not True:
80+
return op[op._any_dtype]
7881
return op[dtype]
7982
if isinstance(op, TypedAggregator):
8083
return op

graphblas/tests/test_scalar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def test_update(s):
250250

251251
def test_not_hashable(s):
252252
with pytest.raises(TypeError, match="unhashable type"):
253-
{s}
253+
_ = {s}
254254
with pytest.raises(TypeError, match="unhashable type"):
255255
hash(s)
256256

graphblas/tests/test_vector.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,21 @@ def test_reduce_agg(v):
948948
assert s.is_empty
949949

950950

951+
def test_reduce_agg_count_is_int64(v):
952+
"""Aggregators that count should default to INT64 return dtype."""
953+
assert v.dtype == dtypes.INT64
954+
res = v.reduce(agg.count).new()
955+
assert res.dtype == dtypes.INT64
956+
assert res == 4
957+
res = v.dup(dtypes.INT8).reduce(agg.count).new()
958+
assert res.dtype == dtypes.INT64
959+
assert res == 4
960+
# Allow return dtype to be specified
961+
res = v.dup(dtypes.INT8).reduce(agg.count[dtypes.INT16]).new()
962+
assert res.dtype == dtypes.INT16
963+
assert res == 4
964+
965+
951966
@pytest.mark.skipif("not suitesparse")
952967
def test_reduce_agg_argminmax(v):
953968
assert v.reduce(agg.ss.argmin).new() == 6

scripts/check_versions.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Use, adjust, copy/paste, etc. as necessary to answer your questions.
44
# This may be helpful when updating dependency versions in CI.
55
# Tip: add `--json` for more information.
6-
conda search 'flake8-bugbear[channel=conda-forge]>=23.9.16'
6+
conda search 'flake8-bugbear[channel=conda-forge]>=23.12.2'
77
conda search 'flake8-simplify[channel=conda-forge]>=0.21.0'
88
conda search 'numpy[channel=conda-forge]>=1.26.0'
99
conda search 'pandas[channel=conda-forge]>=2.1.2'

0 commit comments

Comments
 (0)