11from __future__ import annotations
22
33import copy
4- from typing import Any , Collection , Iterable , Mapping , TypeVar
4+ import functools
5+ from typing import Any , Callable , TypeVar
56
67from bigtree .dag import construct , export
78from bigtree .node import dagnode
8- from bigtree .utils import exceptions , iterators
9+ from bigtree .utils import iterators
910
1011try :
1112 import pandas as pd
@@ -32,73 +33,44 @@ class DAG:
3233 Do refer to the various modules respectively on the keyword parameters.
3334 """
3435
36+ _plugins : dict [str , Callable [..., Any ]] = {}
3537 construct_kwargs : dict [str , Any ] = dict ()
3638
3739 def __init__ (self , dag : dagnode .DAGNode ):
3840 self .dag = dag
3941
40- # Construct methods
4142 @classmethod
42- def from_dataframe (cls , data : pd .DataFrame , ** kwargs : Any ) -> "DAG" :
43- """See `dataframe_to_dag` for full details.
43+ def register_plugin (
44+ cls , name : str , func : Callable [..., Any ], is_classmethod : bool
45+ ) -> None :
46+ base_func = func .func if isinstance (func , functools .partial ) else func
4447
45- Accepts the same arguments as `dataframe_to_dag`.
46- """
47- construct_kwargs = {** cls .construct_kwargs , ** kwargs }
48- root_node = construct .dataframe_to_dag (data , ** construct_kwargs )
49- return cls (root_node )
48+ if is_classmethod :
5049
51- @classmethod
52- def from_dict (cls , relation_attrs : Mapping [str , Any ], ** kwargs : Any ) -> "DAG" :
53- """See `dict_to_dag` for full details.
50+ def wrapper (cls , * args , ** kwargs ): # type: ignore
51+ construct_kwargs = {** cls .construct_kwargs , ** kwargs }
52+ root_node = func (* args , ** construct_kwargs )
53+ return cls (root_node )
5454
55- Accepts the same arguments as `dict_to_dag`.
56- """
57- construct_kwargs = {** cls .construct_kwargs , ** kwargs }
58- root_node = construct .dict_to_dag (relation_attrs , ** construct_kwargs )
59- return cls (root_node )
60-
61- @classmethod
62- def from_list (cls , relations : Collection [tuple [str , str ]], ** kwargs : Any ) -> "DAG" :
63- """See `list_to_dag` for full details.
55+ else :
6456
65- Accepts the same arguments as `list_to_dag`.
66- """
67- construct_kwargs = {** cls .construct_kwargs , ** kwargs }
68- root_node = construct .list_to_dag (relations , ** construct_kwargs )
69- return cls (root_node )
70-
71- # Export methods
72- def to_dataframe (self , * args : Any , ** kwargs : Any ) -> pd .DataFrame :
73- """See `dag_to_dataframe` for full details.
74-
75- Accepts the same arguments as `dag_to_dataframe`.
76- """
77- return export .dag_to_dataframe (self .dag , * args , ** kwargs )
57+ def wrapper (self , * args , ** kwargs ): # type: ignore
58+ return func (self .dag , * args , ** kwargs )
7859
79- def to_dict (self , * args : Any , ** kwargs : Any ) -> dict [str , Any ]:
80- """See `dag_to_dict` for full details.
60+ functools .update_wrapper (wrapper , base_func )
61+ wrapper .__name__ = name
62+ if is_classmethod :
63+ setattr (cls , name , classmethod (wrapper )) # type: ignore
64+ else :
65+ setattr (cls , name , wrapper )
66+ cls ._plugins [name ] = func
8167
82- Accepts the same arguments as `dag_to_dict`.
83- """
84- return export .dag_to_dict (self .dag , * args , ** kwargs )
85-
86- def to_list (self ) -> list [tuple [str , str ]]:
87- """See `dag_to_list` for full details."""
88- return export .dag_to_list (self .dag )
89-
90- @exceptions .optional_dependencies_image ("pydot" )
91- def to_dot (self , * args : Any , ** kwargs : Any ) -> pydot .Dot :
92- """See `dag_to_dot` for full details.
93-
94- Accepts the same arguments as `dag_to_dot`.
95- """
96- return export .dag_to_dot (self .dag , * args , ** kwargs )
97-
98- # Iterator methods
99- def iterate (self ) -> Iterable [tuple [dagnode .DAGNode , dagnode .DAGNode ]]:
100- """See `dag_iterator` for full details."""
101- return iterators .dag_iterator (self .dag )
68+ @classmethod
69+ def register_plugins (
70+ cls , mapping : dict [str , Callable [..., Any ]], is_classmethod : bool = False
71+ ) -> None :
72+ for name , func in mapping .items ():
73+ cls .register_plugin (name , func , is_classmethod )
10274
10375 # Magic methods
10476 def __getitem__ (self , child_name : str ) -> "DAG" :
@@ -160,3 +132,24 @@ def __repr__(self) -> str:
160132
161133
162134T = TypeVar ("T" , bound = DAG )
135+
136+ DAG .register_plugins (
137+ {
138+ # Construct methods
139+ "from_dataframe" : construct .dataframe_to_dag ,
140+ "from_dict" : construct .dict_to_dag ,
141+ "from_list" : construct .list_to_dag ,
142+ },
143+ is_classmethod = True ,
144+ )
145+ DAG .register_plugins (
146+ {
147+ # Export methods
148+ "to_dataframe" : export .dag_to_dataframe ,
149+ "to_dict" : export .dag_to_dict ,
150+ "to_list" : export .dag_to_list ,
151+ "to_dot" : export .dag_to_dot ,
152+ # Iterator methods
153+ "iterate" : iterators .dag_iterator ,
154+ },
155+ )
0 commit comments