Skip to content

Commit 3ee1744

Browse files
authored
Migrate more functions to plugin (#415)
* feat: nested_key_dict to tree to support child_key=None * feat: tree to nested_key_dict to support child_key=None * refactor: enable plugin - retains docstrings, help, and suggestions * docs: update CHANGELOG * refactor: migrate more methods to plugin * test: fix codecov * refactor: more clean up
1 parent 06f35a3 commit 3ee1744

File tree

3 files changed

+55
-113
lines changed

3 files changed

+55
-113
lines changed

bigtree/binarytree/binarytree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(self, root: binarynode.BinaryNode):
2929
# Append methods
3030
"from_heapq_list": construct.list_to_binarytree,
3131
},
32-
is_classmethod=True,
32+
method="class",
3333
)
3434
BinaryTree.register_plugins(
3535
{

bigtree/dag/dag.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,12 @@
22

33
import copy
44
import functools
5-
from typing import Any, Callable, TypeVar
5+
from typing import Any, Callable, Literal, TypeVar
66

77
from bigtree.dag import construct, export
88
from bigtree.node import dagnode
99
from bigtree.utils import iterators
1010

11-
try:
12-
import pandas as pd
13-
except ImportError: # pragma: no cover
14-
from unittest.mock import MagicMock
15-
16-
pd = MagicMock()
17-
18-
try:
19-
import pydot
20-
except ImportError: # pragma: no cover
21-
from unittest.mock import MagicMock
22-
23-
pydot = MagicMock()
24-
2511

2612
class DAG:
2713
"""
@@ -41,36 +27,38 @@ def __init__(self, dag: dagnode.DAGNode):
4127

4228
@classmethod
4329
def register_plugin(
44-
cls, name: str, func: Callable[..., Any], is_classmethod: bool
30+
cls, name: str, func: Callable[..., Any], method: Literal["default", "class"]
4531
) -> None:
4632
base_func = func.func if isinstance(func, functools.partial) else func
4733

48-
if is_classmethod:
34+
if method == "default":
35+
36+
def wrapper(self, *args, **kwargs): # type: ignore
37+
return func(self.dag, *args, **kwargs)
38+
39+
else:
4940

5041
def wrapper(cls, *args, **kwargs): # type: ignore
5142
construct_kwargs = {**cls.construct_kwargs, **kwargs}
5243
root_node = func(*args, **construct_kwargs)
5344
return cls(root_node)
5445

55-
else:
56-
57-
def wrapper(self, *args, **kwargs): # type: ignore
58-
return func(self.dag, *args, **kwargs)
59-
6046
functools.update_wrapper(wrapper, base_func)
6147
wrapper.__name__ = name
62-
if is_classmethod:
48+
if method == "class":
6349
setattr(cls, name, classmethod(wrapper)) # type: ignore
6450
else:
6551
setattr(cls, name, wrapper)
6652
cls._plugins[name] = func
6753

6854
@classmethod
6955
def register_plugins(
70-
cls, mapping: dict[str, Callable[..., Any]], is_classmethod: bool = False
56+
cls,
57+
mapping: dict[str, Callable[..., Any]],
58+
method: Literal["default", "class"] = "default",
7159
) -> None:
7260
for name, func in mapping.items():
73-
cls.register_plugin(name, func, is_classmethod)
61+
cls.register_plugin(name, func, method)
7462

7563
# Magic methods
7664
def __getitem__(self, child_name: str) -> "DAG":
@@ -140,7 +128,7 @@ def __repr__(self) -> str:
140128
"from_dict": construct.dict_to_dag,
141129
"from_list": construct.list_to_dag,
142130
},
143-
is_classmethod=True,
131+
method="class",
144132
)
145133
DAG.register_plugins(
146134
{

bigtree/tree/tree.py

Lines changed: 40 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -2,57 +2,12 @@
22

33
import copy
44
import functools
5-
from typing import Any, Callable, TypeVar
5+
from typing import Any, Callable, Literal, TypeVar
66

77
from bigtree.node import basenode, binarynode, node
88
from bigtree.tree import construct, export, helper, query, search
99
from bigtree.utils import iterators
1010

11-
try:
12-
import pandas as pd
13-
except ImportError: # pragma: no cover
14-
from unittest.mock import MagicMock
15-
16-
pd = MagicMock()
17-
18-
try:
19-
import polars as pl
20-
except ImportError: # pragma: no cover
21-
from unittest.mock import MagicMock
22-
23-
pl = MagicMock()
24-
25-
26-
try:
27-
import pydot
28-
except ImportError: # pragma: no cover
29-
from unittest.mock import MagicMock
30-
31-
pydot = MagicMock()
32-
33-
try:
34-
from PIL import Image, ImageDraw, ImageFont
35-
except ImportError: # pragma: no cover
36-
from unittest.mock import MagicMock
37-
38-
Image = ImageDraw = ImageFont = MagicMock()
39-
40-
try:
41-
import matplotlib as mpl
42-
from matplotlib.colors import Normalize
43-
except ImportError: # pragma: no cover
44-
from unittest.mock import MagicMock
45-
46-
mpl = MagicMock()
47-
Normalize = MagicMock()
48-
49-
try:
50-
import pyvis
51-
except ImportError: # pragma: no cover
52-
from unittest.mock import MagicMock
53-
54-
pyvis = MagicMock()
55-
5611
try:
5712
import matplotlib.pyplot as plt
5813
except ImportError: # pragma: no cover
@@ -85,36 +40,51 @@ def __init__(self, root: node.Node):
8540

8641
@classmethod
8742
def register_plugin(
88-
cls, name: str, func: Callable[..., Any], is_classmethod: bool
43+
cls,
44+
name: str,
45+
func: Callable[..., Any],
46+
method: Literal["default", "class", "helper", "diff"],
8947
) -> None:
9048
base_func = func.func if isinstance(func, functools.partial) else func
9149

92-
if is_classmethod:
50+
if method == "default":
51+
52+
def wrapper(self, *args, **kwargs): # type: ignore
53+
return func(self.node, *args, **kwargs)
54+
55+
elif method == "class":
9356

9457
def wrapper(cls, *args, **kwargs): # type: ignore
9558
construct_kwargs = {**cls.construct_kwargs, **kwargs}
9659
root_node = func(*args, **construct_kwargs)
9760
return cls(root_node)
9861

99-
else:
62+
elif method == "helper":
10063

10164
def wrapper(self, *args, **kwargs): # type: ignore
102-
return func(self.node, *args, **kwargs)
65+
return type(self)(func(self.node, *args, **kwargs))
66+
67+
else:
68+
69+
def wrapper(self, other_tree: T, *args, **kwargs): # type: ignore
70+
return func(self.node, other_tree.node, *args, **kwargs)
10371

10472
functools.update_wrapper(wrapper, base_func)
10573
wrapper.__name__ = name
106-
if is_classmethod:
74+
if method == "class":
10775
setattr(cls, name, classmethod(wrapper)) # type: ignore
10876
else:
10977
setattr(cls, name, wrapper)
11078
cls._plugins[name] = func
11179

11280
@classmethod
11381
def register_plugins(
114-
cls, mapping: dict[str, Callable[..., Any]], is_classmethod: bool = False
82+
cls,
83+
mapping: dict[str, Callable[..., Any]],
84+
method: Literal["default", "class", "helper", "diff"] = "default",
11585
) -> None:
11686
for name, func in mapping.items():
117-
cls.register_plugin(name, func, is_classmethod)
87+
cls.register_plugin(name, func, method)
11888

11989
def show(self, **kwargs: Any) -> None:
12090
self.node.show(**kwargs)
@@ -143,38 +113,6 @@ def depth(self) -> int:
143113
"""
144114
return self.node.max_depth
145115

146-
# Helper methods
147-
def clone(self, node_type: type[BaseNodeT]) -> "Tree":
148-
"""See `clone_tree` for full details.
149-
150-
Accepts the same arguments as `clone_tree`.
151-
"""
152-
return type(self)(helper.clone_tree(self.node, node_type)) # type: ignore
153-
154-
def prune(self, *args: Any, **kwargs: Any) -> "Tree":
155-
"""See `prune_tree` for full details.
156-
157-
Accepts the same arguments as `prune_tree`.
158-
"""
159-
return type(self)(helper.prune_tree(self.node, *args, **kwargs))
160-
161-
def diff_dataframe(self, other_tree: T, *args: Any, **kwargs: Any) -> pd.DataFrame:
162-
"""See `get_tree_diff_dataframe` for full details.
163-
164-
Accepts the same arguments as `get_tree_diff_dataframe`.
165-
"""
166-
return helper.get_tree_diff_dataframe(
167-
self.node, other_tree.node, *args, **kwargs
168-
)
169-
170-
def diff(self, other_tree: T, *args: Any, **kwargs: Any) -> node.Node:
171-
"""See `get_tree_diff` for full details.
172-
173-
Accepts the same arguments as `get_tree_diff`.
174-
"""
175-
return helper.get_tree_diff(self.node, other_tree.node, *args, **kwargs)
176-
177-
# Plot methods
178116
def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
179117
"""Plot tree in line form. Accepts args and kwargs for matplotlib.pyplot.plot() function.
180118
@@ -259,7 +197,7 @@ def __repr__(self) -> str:
259197
"from_str": construct.str_to_tree,
260198
"from_newick": construct.newick_to_tree,
261199
},
262-
is_classmethod=True,
200+
method="class",
263201
)
264202

265203
Tree.register_plugins(
@@ -309,3 +247,19 @@ def __repr__(self) -> str:
309247
"zigzaggroup_iter": iterators.zigzaggroup_iter,
310248
}
311249
)
250+
Tree.register_plugins(
251+
{
252+
# Helper methods
253+
"clone": helper.clone_tree,
254+
"prune": helper.prune_tree,
255+
},
256+
method="helper",
257+
)
258+
Tree.register_plugins(
259+
{
260+
# Helper methods
261+
"diff_dataframe": helper.get_tree_diff_dataframe,
262+
"diff": helper.get_tree_diff,
263+
},
264+
method="diff",
265+
)

0 commit comments

Comments
 (0)