diff --git a/.github/workflows/publish_pypi.yml b/.github/workflows/publish_pypi.yml index 7841b5b..8ff188b 100644 --- a/.github/workflows/publish_pypi.yml +++ b/.github/workflows/publish_pypi.yml @@ -35,7 +35,7 @@ jobs: - name: Check with twine run: python -m twine check --strict dist/* - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@v1.8.6 + uses: pypa/gh-action-pypi-publish@v1.8.10 with: user: __token__ password: ${{ secrets.PYPI_TOKEN }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6a1b79e..103821c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -30,7 +30,7 @@ jobs: activate-environment: testing - name: Install dependencies run: | - conda install -c conda-forge python-graphblas scipy pandas pytest-cov pytest-randomly + conda install -c conda-forge python-graphblas scipy pandas pytest-cov pytest-randomly pytest-mpl # matplotlib lxml pygraphviz pydot sympy # Extra networkx deps we don't need yet pip install git+https://github.com/networkx/networkx.git@main --no-deps pip install -e . --no-deps @@ -39,7 +39,8 @@ jobs: python -c 'import sys, graphblas_algorithms; assert "networkx" not in sys.modules' coverage run --branch -m pytest --color=yes -v --check-structure coverage report - NETWORKX_GRAPH_CONVERT=graphblas pytest --color=yes --pyargs networkx --cov --cov-append + # NETWORKX_GRAPH_CONVERT=graphblas pytest --color=yes --pyargs networkx --cov --cov-append + ./run_nx_tests.sh --color=yes --cov --cov-append coverage report coverage xml - name: Coverage diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c0c02d4..6b08ce0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ ci: # See: https://pre-commit.ci/#configuration autofix_prs: false - autoupdate_schedule: monthly + autoupdate_schedule: quarterly skip: [no-commit-to-branch] fail_fast: true default_language_version: @@ -17,21 +17,27 @@ repos: rev: v4.4.0 hooks: - id: check-added-large-files + - id: check-case-conflict + - id: check-merge-conflict + - id: check-symlinks - id: check-ast - id: check-toml - id: check-yaml - id: debug-statements - id: end-of-file-fixer + exclude_types: [svg] - id: mixed-line-ending - id: trailing-whitespace + - id: name-tests-test + args: ["--pytest-test-first"] - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.13 + rev: v0.14 hooks: - id: validate-pyproject name: Validate pyproject.toml # I don't yet trust ruff to do what autoflake does - repo: https://github.com/PyCQA/autoflake - rev: v2.1.1 + rev: v2.2.0 hooks: - id: autoflake args: [--in-place] @@ -40,7 +46,7 @@ repos: hooks: - id: isort - repo: https://github.com/asottile/pyupgrade - rev: v3.4.0 + rev: v3.10.1 hooks: - id: pyupgrade args: [--py38-plus] @@ -50,38 +56,38 @@ repos: - id: auto-walrus args: [--line-length, "100"] - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 23.7.0 hooks: - id: black # - id: black-jupyter - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.270 + rev: v0.0.285 hooks: - id: ruff args: [--fix-only, --show-fixes] - repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 + rev: 6.1.0 hooks: - id: flake8 additional_dependencies: &flake8_dependencies # These versions need updated manually - - flake8==6.0.0 - - flake8-bugbear==23.5.9 + - flake8==6.1.0 + - flake8-bugbear==23.7.10 - flake8-simplify==0.20.0 - repo: https://github.com/asottile/yesqa - rev: v1.4.0 + rev: v1.5.0 hooks: - id: yesqa additional_dependencies: *flake8_dependencies - repo: https://github.com/codespell-project/codespell - rev: v2.2.4 + rev: v2.2.5 hooks: - id: codespell types_or: [python, rst, markdown] additional_dependencies: [tomli] files: ^(graphblas_algorithms|docs)/ - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.270 + rev: v0.0.285 hooks: - id: ruff # `pyroma` may help keep our package standards up to date if best practices change. diff --git a/graphblas_algorithms/interface.py b/graphblas_algorithms/interface.py index a43b520..1d8c283 100644 --- a/graphblas_algorithms/interface.py +++ b/graphblas_algorithms/interface.py @@ -171,20 +171,71 @@ class Dispatcher: # End auto-generated code: dispatch @staticmethod - def convert_from_nx(graph, weight=None, *, name=None): + def convert_from_nx( + graph, + edge_attrs=None, + node_attrs=None, + preserve_edge_attrs=False, + preserve_node_attrs=False, + preserve_graph_attrs=False, + name=None, + graph_name=None, + *, + weight=None, # For nx.__version__ <= 3.1 + ): import networkx as nx from .classes import DiGraph, Graph, MultiDiGraph, MultiGraph + if preserve_edge_attrs: + if graph.is_multigraph(): + attrs = set().union( + *( + datadict + for nbrs in graph._adj.values() + for keydict in nbrs.values() + for datadict in keydict.values() + ) + ) + else: + attrs = set().union( + *(datadict for nbrs in graph._adj.values() for datadict in nbrs.values()) + ) + if len(attrs) == 1: + [attr] = attrs + edge_attrs = {attr: None} + elif attrs: + raise NotImplementedError("`preserve_edge_attrs=True` is not fully implemented") + if node_attrs: + raise NotImplementedError("non-None `node_attrs` is not yet implemented") + if preserve_node_attrs: + attrs = set().union(*(datadict for node, datadict in graph.nodes(data=True))) + if attrs: + raise NotImplementedError("`preserve_node_attrs=True` is not implemented") + if edge_attrs: + if len(edge_attrs) > 1: + raise NotImplementedError( + "Multiple edge attributes is not implemented (bad value for edge_attrs)" + ) + if weight is not None: + raise TypeError("edge_attrs and weight both given") + [[weight, default]] = edge_attrs.items() + if default is not None and default != 1: + raise NotImplementedError(f"edge default != 1 is not implemented; got {default}") + if isinstance(graph, nx.MultiDiGraph): - return MultiDiGraph.from_networkx(graph, weight=weight) - if isinstance(graph, nx.MultiGraph): - return MultiGraph.from_networkx(graph, weight=weight) - if isinstance(graph, nx.DiGraph): - return DiGraph.from_networkx(graph, weight=weight) - if isinstance(graph, nx.Graph): - return Graph.from_networkx(graph, weight=weight) - raise TypeError(f"Unsupported type of graph: {type(graph)}") + G = MultiDiGraph.from_networkx(graph, weight=weight) + elif isinstance(graph, nx.MultiGraph): + G = MultiGraph.from_networkx(graph, weight=weight) + elif isinstance(graph, nx.DiGraph): + G = DiGraph.from_networkx(graph, weight=weight) + elif isinstance(graph, nx.Graph): + G = Graph.from_networkx(graph, weight=weight) + else: + raise TypeError(f"Unsupported type of graph: {type(graph)}") + if preserve_graph_attrs: + G.graph.update(graph.graph) + return G @staticmethod def convert_to_nx(obj, *, name=None): diff --git a/graphblas_algorithms/tests/test_match_nx.py b/graphblas_algorithms/tests/test_match_nx.py index 225c970..1924ff7 100644 --- a/graphblas_algorithms/tests/test_match_nx.py +++ b/graphblas_algorithms/tests/test_match_nx.py @@ -22,13 +22,29 @@ "Matching networkx namespace requires networkx to be installed", allow_module_level=True ) else: - from networkx.classes import backends # noqa: F401 + try: + from networkx.utils import backends + + IS_NX_30_OR_31 = False + except ImportError: # pragma: no cover (import) + # This is the location in nx 3.1 + from networkx.classes import backends # noqa: F401 + + IS_NX_30_OR_31 = True def isdispatched(func): """Can this NetworkX function dispatch to other backends?""" + if IS_NX_30_OR_31: + return ( + callable(func) + and hasattr(func, "dispatchname") + and func.__module__.startswith("networkx") + ) return ( - callable(func) and hasattr(func, "dispatchname") and func.__module__.startswith("networkx") + callable(func) + and hasattr(func, "preserve_edge_attrs") + and func.__module__.startswith("networkx") ) @@ -37,7 +53,9 @@ def dispatchname(func): # Haha, there should be a better way to get this if not isdispatched(func): raise ValueError(f"Function is not dispatched in NetworkX: {func.__name__}") - return func.dispatchname + if IS_NX_30_OR_31: + return func.dispatchname + return func.name def fullname(func): diff --git a/pyproject.toml b/pyproject.toml index 36afd28..8fb2ffc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -214,12 +214,14 @@ ignore = [ "RET502", # Do not implicitly `return None` in function able to return non-`None` value "RET503", # Missing explicit `return` at the end of function able to return non-`None` value "RET504", # Unnecessary variable assignment before `return` statement + "RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` (Note: no annotations yet) "S110", # `try`-`except`-`pass` detected, consider logging the exception (Note: good advice, but we don't log) "S112", # `try`-`except`-`continue` detected, consider logging the exception (Note: good advice, but we don't log) "SIM102", # Use a single `if` statement instead of nested `if` statements (Note: often necessary) "SIM105", # Use contextlib.suppress(...) instead of try-except-pass (Note: try-except-pass is much faster) "SIM108", # Use ternary operator ... instead of if-else-block (Note: if-else better for coverage and sometimes clearer) "TRY003", # Avoid specifying long messages outside the exception class (Note: why?) + "FIX001", "FIX002", "FIX003", "FIX004", # flake8-fixme (like flake8-todos) # Ignored categories "C90", # mccabe (Too strict, but maybe we should make things less complex) diff --git a/run_nx_tests.sh b/run_nx_tests.sh index 08a5582..740ab26 100755 --- a/run_nx_tests.sh +++ b/run_nx_tests.sh @@ -1,3 +1,6 @@ #!/bin/bash -NETWORKX_GRAPH_CONVERT=graphblas pytest --pyargs networkx "$@" -# NETWORKX_GRAPH_CONVERT=graphblas pytest --pyargs networkx --cov --cov-report term-missing "$@" +NETWORKX_GRAPH_CONVERT=graphblas \ +NETWORKX_TEST_BACKEND=graphblas \ +NETWORKX_FALLBACK_TO_NX=True \ + pytest --pyargs networkx "$@" +# pytest --pyargs networkx --cov --cov-report term-missing "$@" diff --git a/scripts/bench.py b/scripts/bench.py index ba61300..3b3f4dc 100755 --- a/scripts/bench.py +++ b/scripts/bench.py @@ -19,7 +19,7 @@ datapaths = [ Path(__file__).parent / ".." / "data", - Path("."), + Path(), ] @@ -37,7 +37,7 @@ def find_data(dataname): if dataname not in download_data.data_urls: raise FileNotFoundError(f"Unable to find data file for {dataname}") curpath = Path(download_data.main([dataname])[0]) - return curpath.resolve().relative_to(Path(".").resolve()) + return curpath.resolve().relative_to(Path().resolve()) def get_symmetry(file_or_mminfo): diff --git a/scripts/download_data.py b/scripts/download_data.py index 009ebf0..b01626c 100755 --- a/scripts/download_data.py +++ b/scripts/download_data.py @@ -47,7 +47,7 @@ def main(datanames, overwrite=False): for name in datanames: target = datapath / f"{name}.mtx" filenames.append(target) - relpath = target.resolve().relative_to(Path(".").resolve()) + relpath = target.resolve().relative_to(Path().resolve()) if not overwrite and target.exists(): print(f"{relpath} already exists; skipping", file=sys.stderr) continue