55import pkgutil
66import sys
77from types import ModuleType
8- from typing import Optional , Iterable , Callable , Any , Tuple , Dict , Generic , TypeVar , cast
8+ from typing import Optional , Iterable , Callable , Any , Tuple , Dict , Generic , TypeVar , Type , cast
99
1010if sys .version_info < (3 , 7 ):
1111 from typing import GenericMeta
@@ -176,7 +176,7 @@ def wire(
176176 _patch_fn (module , name , member , providers_map )
177177 elif inspect .isclass (member ):
178178 for method_name , method in inspect .getmembers (member , _is_method ):
179- _patch_fn (member , method_name , method , providers_map )
179+ _patch_method (member , method_name , method , providers_map )
180180
181181
182182def unwire (
@@ -195,10 +195,10 @@ def unwire(
195195 for module in modules :
196196 for name , member in inspect .getmembers (module ):
197197 if inspect .isfunction (member ):
198- _unpatch_fn (module , name , member )
198+ _unpatch (module , name , member )
199199 elif inspect .isclass (member ):
200200 for method_name , method in inspect .getmembers (member , inspect .isfunction ):
201- _unpatch_fn (member , method_name , method )
201+ _unpatch (member , method_name , method )
202202
203203
204204def _patch_fn (
@@ -210,10 +210,41 @@ def _patch_fn(
210210 injections , closing = _resolve_injections (fn , providers_map )
211211 if not injections :
212212 return
213- setattr (module , name , _patch_with_injections (fn , injections , closing ))
213+ patched = _patch_with_injections (fn , injections , closing )
214+ setattr (module , name , _wrap_patched (patched , fn , injections , closing ))
214215
215216
216- def _unpatch_fn (
217+ def _patch_method (
218+ cls : Type ,
219+ name : str ,
220+ method : Callable [..., Any ],
221+ providers_map : ProvidersMap ,
222+ ) -> None :
223+ injections , closing = _resolve_injections (method , providers_map )
224+ if not injections :
225+ return
226+
227+ if hasattr (cls , '__dict__' ) \
228+ and name in cls .__dict__ \
229+ and isinstance (cls .__dict__ [name ], (classmethod , staticmethod )):
230+ method = cls .__dict__ [name ]
231+ patched = _patch_with_injections (method .__func__ , injections , closing )
232+ patched = type (method )(patched )
233+ else :
234+ patched = _patch_with_injections (method , injections , closing )
235+
236+ setattr (cls , name , _wrap_patched (patched , method , injections , closing ))
237+
238+
239+ def _wrap_patched (patched : Callable [..., Any ], original , injections , closing ):
240+ patched .__wired__ = True
241+ patched .__original__ = original
242+ patched .__injections__ = injections
243+ patched .__closing__ = closing
244+ return patched
245+
246+
247+ def _unpatch (
217248 module : ModuleType ,
218249 name : str ,
219250 fn : Callable [..., Any ],
@@ -276,12 +307,6 @@ def _patch_with_injections(fn, injections, closing):
276307 _patched = _get_async_patched (fn , injections , closing )
277308 else :
278309 _patched = _get_patched (fn , injections , closing )
279-
280- _patched .__wired__ = True
281- _patched .__original__ = fn
282- _patched .__injections__ = injections
283- _patched .__closing__ = closing
284-
285310 return _patched
286311
287312
0 commit comments