|
2 | 2 |
|
3 | 3 | import copy |
4 | 4 | import functools |
5 | | -from typing import Any, Callable, TypeVar |
| 5 | +from typing import Any, Callable, Literal, TypeVar |
6 | 6 |
|
7 | 7 | from bigtree.node import basenode, binarynode, node |
8 | 8 | from bigtree.tree import construct, export, helper, query, search |
9 | 9 | from bigtree.utils import iterators |
10 | 10 |
|
11 | | -try: |
12 | | - import pandas as pd |
13 | | -except ImportError: # pragma: no cover |
14 | | - from unittest.mock import MagicMock |
15 | | - |
16 | | - pd = MagicMock() |
17 | | - |
18 | | -try: |
19 | | - import polars as pl |
20 | | -except ImportError: # pragma: no cover |
21 | | - from unittest.mock import MagicMock |
22 | | - |
23 | | - pl = MagicMock() |
24 | | - |
25 | | - |
26 | | -try: |
27 | | - import pydot |
28 | | -except ImportError: # pragma: no cover |
29 | | - from unittest.mock import MagicMock |
30 | | - |
31 | | - pydot = MagicMock() |
32 | | - |
33 | | -try: |
34 | | - from PIL import Image, ImageDraw, ImageFont |
35 | | -except ImportError: # pragma: no cover |
36 | | - from unittest.mock import MagicMock |
37 | | - |
38 | | - Image = ImageDraw = ImageFont = MagicMock() |
39 | | - |
40 | | -try: |
41 | | - import matplotlib as mpl |
42 | | - from matplotlib.colors import Normalize |
43 | | -except ImportError: # pragma: no cover |
44 | | - from unittest.mock import MagicMock |
45 | | - |
46 | | - mpl = MagicMock() |
47 | | - Normalize = MagicMock() |
48 | | - |
49 | | -try: |
50 | | - import pyvis |
51 | | -except ImportError: # pragma: no cover |
52 | | - from unittest.mock import MagicMock |
53 | | - |
54 | | - pyvis = MagicMock() |
55 | | - |
56 | 11 | try: |
57 | 12 | import matplotlib.pyplot as plt |
58 | 13 | except ImportError: # pragma: no cover |
@@ -85,36 +40,51 @@ def __init__(self, root: node.Node): |
85 | 40 |
|
86 | 41 | @classmethod |
87 | 42 | def register_plugin( |
88 | | - cls, name: str, func: Callable[..., Any], is_classmethod: bool |
| 43 | + cls, |
| 44 | + name: str, |
| 45 | + func: Callable[..., Any], |
| 46 | + method: Literal["default", "class", "helper", "diff"], |
89 | 47 | ) -> None: |
90 | 48 | base_func = func.func if isinstance(func, functools.partial) else func |
91 | 49 |
|
92 | | - if is_classmethod: |
| 50 | + if method == "default": |
| 51 | + |
| 52 | + def wrapper(self, *args, **kwargs): # type: ignore |
| 53 | + return func(self.node, *args, **kwargs) |
| 54 | + |
| 55 | + elif method == "class": |
93 | 56 |
|
94 | 57 | def wrapper(cls, *args, **kwargs): # type: ignore |
95 | 58 | construct_kwargs = {**cls.construct_kwargs, **kwargs} |
96 | 59 | root_node = func(*args, **construct_kwargs) |
97 | 60 | return cls(root_node) |
98 | 61 |
|
99 | | - else: |
| 62 | + elif method == "helper": |
100 | 63 |
|
101 | 64 | def wrapper(self, *args, **kwargs): # type: ignore |
102 | | - return func(self.node, *args, **kwargs) |
| 65 | + return type(self)(func(self.node, *args, **kwargs)) |
| 66 | + |
| 67 | + else: |
| 68 | + |
| 69 | + def wrapper(self, other_tree: T, *args, **kwargs): # type: ignore |
| 70 | + return func(self.node, other_tree.node, *args, **kwargs) |
103 | 71 |
|
104 | 72 | functools.update_wrapper(wrapper, base_func) |
105 | 73 | wrapper.__name__ = name |
106 | | - if is_classmethod: |
| 74 | + if method == "class": |
107 | 75 | setattr(cls, name, classmethod(wrapper)) # type: ignore |
108 | 76 | else: |
109 | 77 | setattr(cls, name, wrapper) |
110 | 78 | cls._plugins[name] = func |
111 | 79 |
|
112 | 80 | @classmethod |
113 | 81 | def register_plugins( |
114 | | - cls, mapping: dict[str, Callable[..., Any]], is_classmethod: bool = False |
| 82 | + cls, |
| 83 | + mapping: dict[str, Callable[..., Any]], |
| 84 | + method: Literal["default", "class", "helper", "diff"] = "default", |
115 | 85 | ) -> None: |
116 | 86 | for name, func in mapping.items(): |
117 | | - cls.register_plugin(name, func, is_classmethod) |
| 87 | + cls.register_plugin(name, func, method) |
118 | 88 |
|
119 | 89 | def show(self, **kwargs: Any) -> None: |
120 | 90 | self.node.show(**kwargs) |
@@ -143,38 +113,6 @@ def depth(self) -> int: |
143 | 113 | """ |
144 | 114 | return self.node.max_depth |
145 | 115 |
|
146 | | - # Helper methods |
147 | | - def clone(self, node_type: type[BaseNodeT]) -> "Tree": |
148 | | - """See `clone_tree` for full details. |
149 | | -
|
150 | | - Accepts the same arguments as `clone_tree`. |
151 | | - """ |
152 | | - return type(self)(helper.clone_tree(self.node, node_type)) # type: ignore |
153 | | - |
154 | | - def prune(self, *args: Any, **kwargs: Any) -> "Tree": |
155 | | - """See `prune_tree` for full details. |
156 | | -
|
157 | | - Accepts the same arguments as `prune_tree`. |
158 | | - """ |
159 | | - return type(self)(helper.prune_tree(self.node, *args, **kwargs)) |
160 | | - |
161 | | - def diff_dataframe(self, other_tree: T, *args: Any, **kwargs: Any) -> pd.DataFrame: |
162 | | - """See `get_tree_diff_dataframe` for full details. |
163 | | -
|
164 | | - Accepts the same arguments as `get_tree_diff_dataframe`. |
165 | | - """ |
166 | | - return helper.get_tree_diff_dataframe( |
167 | | - self.node, other_tree.node, *args, **kwargs |
168 | | - ) |
169 | | - |
170 | | - def diff(self, other_tree: T, *args: Any, **kwargs: Any) -> node.Node: |
171 | | - """See `get_tree_diff` for full details. |
172 | | -
|
173 | | - Accepts the same arguments as `get_tree_diff`. |
174 | | - """ |
175 | | - return helper.get_tree_diff(self.node, other_tree.node, *args, **kwargs) |
176 | | - |
177 | | - # Plot methods |
178 | 116 | def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: |
179 | 117 | """Plot tree in line form. Accepts args and kwargs for matplotlib.pyplot.plot() function. |
180 | 118 |
|
@@ -259,7 +197,7 @@ def __repr__(self) -> str: |
259 | 197 | "from_str": construct.str_to_tree, |
260 | 198 | "from_newick": construct.newick_to_tree, |
261 | 199 | }, |
262 | | - is_classmethod=True, |
| 200 | + method="class", |
263 | 201 | ) |
264 | 202 |
|
265 | 203 | Tree.register_plugins( |
@@ -309,3 +247,19 @@ def __repr__(self) -> str: |
309 | 247 | "zigzaggroup_iter": iterators.zigzaggroup_iter, |
310 | 248 | } |
311 | 249 | ) |
| 250 | +Tree.register_plugins( |
| 251 | + { |
| 252 | + # Helper methods |
| 253 | + "clone": helper.clone_tree, |
| 254 | + "prune": helper.prune_tree, |
| 255 | + }, |
| 256 | + method="helper", |
| 257 | +) |
| 258 | +Tree.register_plugins( |
| 259 | + { |
| 260 | + # Helper methods |
| 261 | + "diff_dataframe": helper.get_tree_diff_dataframe, |
| 262 | + "diff": helper.get_tree_diff, |
| 263 | + }, |
| 264 | + method="diff", |
| 265 | +) |
0 commit comments