Skip to content
Prev Previous commit
Next Next commit
More updates to work with latest networkx dispatching
  • Loading branch information
eriknw committed Aug 25, 2023
commit 3347d019fef45b4b86ceb25058b7aeddf3e119e4
8 changes: 4 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@ jobs:
activate-environment: testing
- name: Install dependencies
run: |
conda install -c conda-forge python-graphblas scipy pandas pytest-cov pytest-randomly
conda install -c conda-forge python-graphblas scipy pandas pytest-cov pytest-randomly pytest-mpl
# matplotlib lxml pygraphviz pydot sympy # Extra networkx deps we don't need yet
# pip install git+https://github.com/networkx/networkx.git@main --no-deps
pip install git+https://github.com/eriknw/networkx.git@dispatch_many --no-deps # XXX: temporary
pip install git+https://github.com/networkx/networkx.git@main --no-deps
pip install -e . --no-deps
- name: PyTest
run: |
python -c 'import sys, graphblas_algorithms; assert "networkx" not in sys.modules'
coverage run --branch -m pytest --color=yes -v --check-structure
coverage report
NETWORKX_GRAPH_CONVERT=graphblas pytest --color=yes --pyargs networkx --cov --cov-append
# NETWORKX_GRAPH_CONVERT=graphblas pytest --color=yes --pyargs networkx --cov --cov-append
./run_nx_tests.sh --color=yes --cov --cov-append
coverage report
coverage xml
- name: Coverage
Expand Down
55 changes: 37 additions & 18 deletions graphblas_algorithms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ def convert_from_nx(
graph,
edge_attrs=None,
node_attrs=None,
preserve_edge_attrs=None,
preserve_node_attrs=None,
preserve_graph_attrs=None,
preserve_edge_attrs=False,
preserve_node_attrs=False,
preserve_graph_attrs=False,
name=None,
graph_name=None,
*,
Expand All @@ -188,15 +188,30 @@ def convert_from_nx(
from .classes import DiGraph, Graph, MultiDiGraph, MultiGraph

if preserve_edge_attrs:
raise NotImplementedError("`preserve_edge_attrs=True` is not implemented")
if graph.is_multigraph():
attrs = set().union(
*(
datadict
for nbrs in graph._adj.values()
for keydict in nbrs.values()
for datadict in keydict.values()
)
)
else:
attrs = set().union(
*(datadict for nbrs in graph._adj.values() for datadict in nbrs.values())
)
if len(attrs) == 1:
[attr] = attrs
edge_attrs = {attr: None}
elif attrs:
raise NotImplementedError("`preserve_edge_attrs=True` is not fully implemented")
if node_attrs:
raise NotImplementedError("non-None `node_attrs` is not implemented")
raise NotImplementedError("non-None `node_attrs` is not yet implemented")
if preserve_node_attrs:
raise NotImplementedError("`preserve_node_attrs=True` is not implemented")
if preserve_graph_attrs:
raise NotImplementedError("`preserve_graphs_attrs=True` is not implemented")
if graph_name:
raise NotImplementedError("Not possible to set a graph name")
attrs = set().union(*(datadict for node, datadict in graph.nodes(data=True)))
if attrs:
raise NotImplementedError("`preserve_node_attrs=True` is not implemented")
if edge_attrs:
if len(edge_attrs) > 1:
raise NotImplementedError(
Expand All @@ -209,14 +224,18 @@ def convert_from_nx(
raise NotImplementedError(f"edge default != 1 is not implemented; got {default}")

if isinstance(graph, nx.MultiDiGraph):
return MultiDiGraph.from_networkx(graph, weight=weight)
if isinstance(graph, nx.MultiGraph):
return MultiGraph.from_networkx(graph, weight=weight)
if isinstance(graph, nx.DiGraph):
return DiGraph.from_networkx(graph, weight=weight)
if isinstance(graph, nx.Graph):
return Graph.from_networkx(graph, weight=weight)
raise TypeError(f"Unsupported type of graph: {type(graph)}")
G = MultiDiGraph.from_networkx(graph, weight=weight)
elif isinstance(graph, nx.MultiGraph):
G = MultiGraph.from_networkx(graph, weight=weight)
elif isinstance(graph, nx.DiGraph):
G = DiGraph.from_networkx(graph, weight=weight)
elif isinstance(graph, nx.Graph):
G = Graph.from_networkx(graph, weight=weight)
else:
raise TypeError(f"Unsupported type of graph: {type(graph)}")
if preserve_graph_attrs:
G.graph.update(graph.graph)
return G

@staticmethod
def convert_to_nx(obj, *, name=None):
Expand Down
18 changes: 16 additions & 2 deletions graphblas_algorithms/tests/test_match_nx.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,27 @@
else:
try:
from networkx.utils import backends

IS_NX_30_OR_31 = False
except ImportError: # pragma: no cover (import)
# This is the location in nx 3.1
from networkx.classes import backends # noqa: F401

IS_NX_30_OR_31 = True


def isdispatched(func):
"""Can this NetworkX function dispatch to other backends?"""
if IS_NX_30_OR_31:
return (
callable(func)
and hasattr(func, "dispatchname")
and func.__module__.startswith("networkx")
)
return (
callable(func) and hasattr(func, "dispatchname") and func.__module__.startswith("networkx")
callable(func)
and hasattr(func, "preserve_edge_attrs")
and func.__module__.startswith("networkx")
)


Expand All @@ -41,7 +53,9 @@ def dispatchname(func):
# Haha, there should be a better way to get this
if not isdispatched(func):
raise ValueError(f"Function is not dispatched in NetworkX: {func.__name__}")
return func.dispatchname
if IS_NX_30_OR_31:
return func.dispatchname
return func.name


def fullname(func):
Expand Down
7 changes: 5 additions & 2 deletions run_nx_tests.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#!/bin/bash
NETWORKX_GRAPH_CONVERT=graphblas pytest --pyargs networkx "$@"
# NETWORKX_GRAPH_CONVERT=graphblas pytest --pyargs networkx --cov --cov-report term-missing "$@"
NETWORKX_GRAPH_CONVERT=graphblas \
NETWORKX_TEST_BACKEND=graphblas \
NETWORKX_FALLBACK_TO_NX=True \
pytest --pyargs networkx "$@"
# pytest --pyargs networkx --cov --cov-report term-missing "$@"