Skip to content

Commit 3caced2

Browse files
authored
Update to work with new networkx dispatching (#68)
1 parent 1f5ccb6 commit 3caced2

File tree

9 files changed

+113
-32
lines changed

9 files changed

+113
-32
lines changed

Diff for: .github/workflows/publish_pypi.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
- name: Check with twine
3636
run: python -m twine check --strict dist/*
3737
- name: Publish to PyPI
38-
uses: pypa/[email protected].6
38+
uses: pypa/[email protected].10
3939
with:
4040
user: __token__
4141
password: ${{ secrets.PYPI_TOKEN }}

Diff for: .github/workflows/test.yml

+3-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
activate-environment: testing
3131
- name: Install dependencies
3232
run: |
33-
conda install -c conda-forge python-graphblas scipy pandas pytest-cov pytest-randomly
33+
conda install -c conda-forge python-graphblas scipy pandas pytest-cov pytest-randomly pytest-mpl
3434
# matplotlib lxml pygraphviz pydot sympy # Extra networkx deps we don't need yet
3535
pip install git+https://github.com/networkx/networkx.git@main --no-deps
3636
pip install -e . --no-deps
@@ -39,7 +39,8 @@ jobs:
3939
python -c 'import sys, graphblas_algorithms; assert "networkx" not in sys.modules'
4040
coverage run --branch -m pytest --color=yes -v --check-structure
4141
coverage report
42-
NETWORKX_GRAPH_CONVERT=graphblas pytest --color=yes --pyargs networkx --cov --cov-append
42+
# NETWORKX_GRAPH_CONVERT=graphblas pytest --color=yes --pyargs networkx --cov --cov-append
43+
./run_nx_tests.sh --color=yes --cov --cov-append
4344
coverage report
4445
coverage xml
4546
- name: Coverage

Diff for: .pre-commit-config.yaml

+18-12
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
ci:
88
# See: https://pre-commit.ci/#configuration
99
autofix_prs: false
10-
autoupdate_schedule: monthly
10+
autoupdate_schedule: quarterly
1111
skip: [no-commit-to-branch]
1212
fail_fast: true
1313
default_language_version:
@@ -17,21 +17,27 @@ repos:
1717
rev: v4.4.0
1818
hooks:
1919
- id: check-added-large-files
20+
- id: check-case-conflict
21+
- id: check-merge-conflict
22+
- id: check-symlinks
2023
- id: check-ast
2124
- id: check-toml
2225
- id: check-yaml
2326
- id: debug-statements
2427
- id: end-of-file-fixer
28+
exclude_types: [svg]
2529
- id: mixed-line-ending
2630
- id: trailing-whitespace
31+
- id: name-tests-test
32+
args: ["--pytest-test-first"]
2733
- repo: https://github.com/abravalheri/validate-pyproject
28-
rev: v0.13
34+
rev: v0.14
2935
hooks:
3036
- id: validate-pyproject
3137
name: Validate pyproject.toml
3238
# I don't yet trust ruff to do what autoflake does
3339
- repo: https://github.com/PyCQA/autoflake
34-
rev: v2.1.1
40+
rev: v2.2.0
3541
hooks:
3642
- id: autoflake
3743
args: [--in-place]
@@ -40,7 +46,7 @@ repos:
4046
hooks:
4147
- id: isort
4248
- repo: https://github.com/asottile/pyupgrade
43-
rev: v3.4.0
49+
rev: v3.10.1
4450
hooks:
4551
- id: pyupgrade
4652
args: [--py38-plus]
@@ -50,38 +56,38 @@ repos:
5056
- id: auto-walrus
5157
args: [--line-length, "100"]
5258
- repo: https://github.com/psf/black
53-
rev: 23.3.0
59+
rev: 23.7.0
5460
hooks:
5561
- id: black
5662
# - id: black-jupyter
5763
- repo: https://github.com/charliermarsh/ruff-pre-commit
58-
rev: v0.0.270
64+
rev: v0.0.285
5965
hooks:
6066
- id: ruff
6167
args: [--fix-only, --show-fixes]
6268
- repo: https://github.com/PyCQA/flake8
63-
rev: 6.0.0
69+
rev: 6.1.0
6470
hooks:
6571
- id: flake8
6672
additional_dependencies: &flake8_dependencies
6773
# These versions need updated manually
68-
- flake8==6.0.0
69-
- flake8-bugbear==23.5.9
74+
- flake8==6.1.0
75+
- flake8-bugbear==23.7.10
7076
- flake8-simplify==0.20.0
7177
- repo: https://github.com/asottile/yesqa
72-
rev: v1.4.0
78+
rev: v1.5.0
7379
hooks:
7480
- id: yesqa
7581
additional_dependencies: *flake8_dependencies
7682
- repo: https://github.com/codespell-project/codespell
77-
rev: v2.2.4
83+
rev: v2.2.5
7884
hooks:
7985
- id: codespell
8086
types_or: [python, rst, markdown]
8187
additional_dependencies: [tomli]
8288
files: ^(graphblas_algorithms|docs)/
8389
- repo: https://github.com/charliermarsh/ruff-pre-commit
84-
rev: v0.0.270
90+
rev: v0.0.285
8591
hooks:
8692
- id: ruff
8793
# `pyroma` may help keep our package standards up to date if best practices change.

Diff for: graphblas_algorithms/interface.py

+60-9
Original file line numberDiff line numberDiff line change
@@ -171,20 +171,71 @@ class Dispatcher:
171171
# End auto-generated code: dispatch
172172

173173
@staticmethod
174-
def convert_from_nx(graph, weight=None, *, name=None):
174+
def convert_from_nx(
175+
graph,
176+
edge_attrs=None,
177+
node_attrs=None,
178+
preserve_edge_attrs=False,
179+
preserve_node_attrs=False,
180+
preserve_graph_attrs=False,
181+
name=None,
182+
graph_name=None,
183+
*,
184+
weight=None, # For nx.__version__ <= 3.1
185+
):
175186
import networkx as nx
176187

177188
from .classes import DiGraph, Graph, MultiDiGraph, MultiGraph
178189

190+
if preserve_edge_attrs:
191+
if graph.is_multigraph():
192+
attrs = set().union(
193+
*(
194+
datadict
195+
for nbrs in graph._adj.values()
196+
for keydict in nbrs.values()
197+
for datadict in keydict.values()
198+
)
199+
)
200+
else:
201+
attrs = set().union(
202+
*(datadict for nbrs in graph._adj.values() for datadict in nbrs.values())
203+
)
204+
if len(attrs) == 1:
205+
[attr] = attrs
206+
edge_attrs = {attr: None}
207+
elif attrs:
208+
raise NotImplementedError("`preserve_edge_attrs=True` is not fully implemented")
209+
if node_attrs:
210+
raise NotImplementedError("non-None `node_attrs` is not yet implemented")
211+
if preserve_node_attrs:
212+
attrs = set().union(*(datadict for node, datadict in graph.nodes(data=True)))
213+
if attrs:
214+
raise NotImplementedError("`preserve_node_attrs=True` is not implemented")
215+
if edge_attrs:
216+
if len(edge_attrs) > 1:
217+
raise NotImplementedError(
218+
"Multiple edge attributes is not implemented (bad value for edge_attrs)"
219+
)
220+
if weight is not None:
221+
raise TypeError("edge_attrs and weight both given")
222+
[[weight, default]] = edge_attrs.items()
223+
if default is not None and default != 1:
224+
raise NotImplementedError(f"edge default != 1 is not implemented; got {default}")
225+
179226
if isinstance(graph, nx.MultiDiGraph):
180-
return MultiDiGraph.from_networkx(graph, weight=weight)
181-
if isinstance(graph, nx.MultiGraph):
182-
return MultiGraph.from_networkx(graph, weight=weight)
183-
if isinstance(graph, nx.DiGraph):
184-
return DiGraph.from_networkx(graph, weight=weight)
185-
if isinstance(graph, nx.Graph):
186-
return Graph.from_networkx(graph, weight=weight)
187-
raise TypeError(f"Unsupported type of graph: {type(graph)}")
227+
G = MultiDiGraph.from_networkx(graph, weight=weight)
228+
elif isinstance(graph, nx.MultiGraph):
229+
G = MultiGraph.from_networkx(graph, weight=weight)
230+
elif isinstance(graph, nx.DiGraph):
231+
G = DiGraph.from_networkx(graph, weight=weight)
232+
elif isinstance(graph, nx.Graph):
233+
G = Graph.from_networkx(graph, weight=weight)
234+
else:
235+
raise TypeError(f"Unsupported type of graph: {type(graph)}")
236+
if preserve_graph_attrs:
237+
G.graph.update(graph.graph)
238+
return G
188239

189240
@staticmethod
190241
def convert_to_nx(obj, *, name=None):

Diff for: graphblas_algorithms/tests/test_match_nx.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,29 @@
2222
"Matching networkx namespace requires networkx to be installed", allow_module_level=True
2323
)
2424
else:
25-
from networkx.classes import backends # noqa: F401
25+
try:
26+
from networkx.utils import backends
27+
28+
IS_NX_30_OR_31 = False
29+
except ImportError: # pragma: no cover (import)
30+
# This is the location in nx 3.1
31+
from networkx.classes import backends # noqa: F401
32+
33+
IS_NX_30_OR_31 = True
2634

2735

2836
def isdispatched(func):
2937
"""Can this NetworkX function dispatch to other backends?"""
38+
if IS_NX_30_OR_31:
39+
return (
40+
callable(func)
41+
and hasattr(func, "dispatchname")
42+
and func.__module__.startswith("networkx")
43+
)
3044
return (
31-
callable(func) and hasattr(func, "dispatchname") and func.__module__.startswith("networkx")
45+
callable(func)
46+
and hasattr(func, "preserve_edge_attrs")
47+
and func.__module__.startswith("networkx")
3248
)
3349

3450

@@ -37,7 +53,9 @@ def dispatchname(func):
3753
# Haha, there should be a better way to get this
3854
if not isdispatched(func):
3955
raise ValueError(f"Function is not dispatched in NetworkX: {func.__name__}")
40-
return func.dispatchname
56+
if IS_NX_30_OR_31:
57+
return func.dispatchname
58+
return func.name
4159

4260

4361
def fullname(func):

Diff for: pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -214,12 +214,14 @@ ignore = [
214214
"RET502", # Do not implicitly `return None` in function able to return non-`None` value
215215
"RET503", # Missing explicit `return` at the end of function able to return non-`None` value
216216
"RET504", # Unnecessary variable assignment before `return` statement
217+
"RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` (Note: no annotations yet)
217218
"S110", # `try`-`except`-`pass` detected, consider logging the exception (Note: good advice, but we don't log)
218219
"S112", # `try`-`except`-`continue` detected, consider logging the exception (Note: good advice, but we don't log)
219220
"SIM102", # Use a single `if` statement instead of nested `if` statements (Note: often necessary)
220221
"SIM105", # Use contextlib.suppress(...) instead of try-except-pass (Note: try-except-pass is much faster)
221222
"SIM108", # Use ternary operator ... instead of if-else-block (Note: if-else better for coverage and sometimes clearer)
222223
"TRY003", # Avoid specifying long messages outside the exception class (Note: why?)
224+
"FIX001", "FIX002", "FIX003", "FIX004", # flake8-fixme (like flake8-todos)
223225

224226
# Ignored categories
225227
"C90", # mccabe (Too strict, but maybe we should make things less complex)

Diff for: run_nx_tests.sh

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
#!/bin/bash
2-
NETWORKX_GRAPH_CONVERT=graphblas pytest --pyargs networkx "$@"
3-
# NETWORKX_GRAPH_CONVERT=graphblas pytest --pyargs networkx --cov --cov-report term-missing "$@"
2+
NETWORKX_GRAPH_CONVERT=graphblas \
3+
NETWORKX_TEST_BACKEND=graphblas \
4+
NETWORKX_FALLBACK_TO_NX=True \
5+
pytest --pyargs networkx "$@"
6+
# pytest --pyargs networkx --cov --cov-report term-missing "$@"

Diff for: scripts/bench.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
datapaths = [
2121
Path(__file__).parent / ".." / "data",
22-
Path("."),
22+
Path(),
2323
]
2424

2525

@@ -37,7 +37,7 @@ def find_data(dataname):
3737
if dataname not in download_data.data_urls:
3838
raise FileNotFoundError(f"Unable to find data file for {dataname}")
3939
curpath = Path(download_data.main([dataname])[0])
40-
return curpath.resolve().relative_to(Path(".").resolve())
40+
return curpath.resolve().relative_to(Path().resolve())
4141

4242

4343
def get_symmetry(file_or_mminfo):

Diff for: scripts/download_data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def main(datanames, overwrite=False):
4747
for name in datanames:
4848
target = datapath / f"{name}.mtx"
4949
filenames.append(target)
50-
relpath = target.resolve().relative_to(Path(".").resolve())
50+
relpath = target.resolve().relative_to(Path().resolve())
5151
if not overwrite and target.exists():
5252
print(f"{relpath} already exists; skipping", file=sys.stderr)
5353
continue

0 commit comments

Comments
 (0)