2020 TypeVar ,
2121 Type ,
2222 Union ,
23+ Set ,
2324 cast ,
2425)
2526
@@ -82,22 +83,53 @@ class GenericMeta(type):
8283Container = Any
8384
8485
85- class Registry :
86+ class PatchedRegistry :
8687
8788 def __init__ (self ):
88- self ._storage = set ()
89+ self ._callables : Set [Callable [..., Any ]] = set ()
90+ self ._attributes : Set [PatchedAttribute ] = set ()
8991
90- def add (self , patched : Callable [..., Any ]) -> None :
91- self ._storage .add (patched )
92+ def add_callable (self , patched : Callable [..., Any ]) -> None :
93+ self ._callables .add (patched )
9294
93- def get_from_module (self , module : ModuleType ) -> Iterator [Callable [..., Any ]]:
94- for patched in self ._storage :
95+ def get_callables_from_module (self , module : ModuleType ) -> Iterator [Callable [..., Any ]]:
96+ for patched in self ._callables :
9597 if patched .__module__ != module .__name__ :
9698 continue
9799 yield patched
98100
101+ def add_attribute (self , patched : 'PatchedAttribute' ):
102+ self ._attributes .add (patched )
99103
100- _patched_registry = Registry ()
104+ def get_attributes_from_module (self , module : ModuleType ) -> Iterator ['PatchedAttribute' ]:
105+ for attribute in self ._attributes :
106+ if not attribute .is_in_module (module ):
107+ continue
108+ yield attribute
109+
110+ def clear_module_attributes (self , module : ModuleType ):
111+ for attribute in self ._attributes .copy ():
112+ if not attribute .is_in_module (module ):
113+ continue
114+ self ._attributes .remove (attribute )
115+
116+
117+ class PatchedAttribute :
118+
119+ def __init__ (self , member : Any , name : str , marker : '_Marker' ):
120+ self .member = member
121+ self .name = name
122+ self .marker = marker
123+
124+ @property
125+ def module_name (self ) -> str :
126+ if isinstance (self .member , ModuleType ):
127+ return self .member .__name__
128+ else :
129+ return self .member .__module__
130+
131+ def is_in_module (self , module : ModuleType ) -> bool :
132+ return self .module_name == module .__name__
101133
102134
103135class ProvidersMap :
@@ -278,9 +310,6 @@ def _is_starlette_request_cls(self, instance: object) -> bool:
278310 and issubclass (instance , starlette .requests .Request )
279311
280312
281- inspect_filter = InspectFilter ()
282-
283-
284313def wire ( # noqa: C901
285314 container : Container ,
286315 * ,
@@ -301,20 +330,27 @@ def wire( # noqa: C901
301330 providers_map = ProvidersMap (container )
302331
303332 for module in modules :
304- for name , member in inspect .getmembers (module ):
305- if inspect_filter .is_excluded (member ):
333+ for member_name , member in inspect .getmembers (module ):
334+ if _inspect_filter .is_excluded (member ):
306335 continue
307- if inspect .isfunction (member ):
308- _patch_fn (module , name , member , providers_map )
309- elif inspect .isclass (member ):
310- for method_name , method in inspect .getmembers (member , _is_method ):
311- _patch_method (member , method_name , method , providers_map )
312336
313- for patched in _patched_registry .get_from_module (module ):
337+ if _is_marker (member ):
338+ _patch_attribute (module , member_name , member , providers_map )
339+ elif inspect .isfunction (member ):
340+ _patch_fn (module , member_name , member , providers_map )
341+ elif inspect .isclass (member ):
342+ cls = member
343+ for cls_member_name , cls_member in inspect .getmembers (cls ):
344+ if _is_marker (cls_member ):
345+ _patch_attribute (cls , cls_member_name , cls_member , providers_map )
346+ elif _is_method (cls_member ):
347+ _patch_method (cls , cls_member_name , cls_member , providers_map )
348+
349+ for patched in _patched_registry .get_callables_from_module (module ):
314350 _bind_injections (patched , providers_map )
315351
316352
317- def unwire (
353+ def unwire ( # noqa: C901
318354 * ,
319355 modules : Optional [Iterable [ModuleType ]] = None ,
320356 packages : Optional [Iterable [ModuleType ]] = None ,
@@ -335,15 +371,19 @@ def unwire(
335371 for method_name , method in inspect .getmembers (member , inspect .isfunction ):
336372 _unpatch (member , method_name , method )
337373
338- for patched in _patched_registry .get_from_module (module ):
374+ for patched in _patched_registry .get_callables_from_module (module ):
339375 _unbind_injections (patched )
340376
377+ for patched_attribute in _patched_registry .get_attributes_from_module (module ):
378+ _unpatch_attribute (patched_attribute )
379+ _patched_registry .clear_module_attributes (module )
380+
341381
342382def inject (fn : F ) -> F :
343383 """Decorate callable with injecting decorator."""
344384 reference_injections , reference_closing = _fetch_reference_injections (fn )
345385 patched = _get_patched (fn , reference_injections , reference_closing )
346- _patched_registry .add (patched )
386+ _patched_registry .add_callable (patched )
347387 return cast (F , patched )
348388
349389
@@ -358,7 +398,7 @@ def _patch_fn(
358398 if not reference_injections :
359399 return
360400 fn = _get_patched (fn , reference_injections , reference_closing )
361- _patched_registry .add (fn )
401+ _patched_registry .add_callable (fn )
362402
363403 _bind_injections (fn , providers_map )
364404
@@ -384,7 +424,7 @@ def _patch_method(
384424 if not reference_injections :
385425 return
386426 fn = _get_patched (fn , reference_injections , reference_closing )
387- _patched_registry .add (fn )
427+ _patched_registry .add_callable (fn )
388428
389429 _bind_injections (fn , providers_map )
390430
@@ -411,6 +451,31 @@ def _unpatch(
411451 _unbind_injections (fn )
412452
413453
454+ def _patch_attribute (
455+ member : Any ,
456+ name : str ,
457+ marker : '_Marker' ,
458+ providers_map : ProvidersMap ,
459+ ) -> None :
460+ provider = providers_map .resolve_provider (marker .provider , marker .modifier )
461+ if provider is None :
462+ return
463+
464+ _patched_registry .add_attribute (PatchedAttribute (member , name , marker ))
465+
466+ if isinstance (marker , Provide ):
467+ instance = provider ()
468+ setattr (member , name , instance )
469+ elif isinstance (marker , Provider ):
470+ setattr (member , name , provider )
471+ else :
472+ raise Exception (f'Unknown type of marker { marker } ' )
473+
474+
475+ def _unpatch_attribute (patched : PatchedAttribute ) -> None :
476+ setattr (patched .member , patched .name , patched .marker )
477+
478+
414479def _fetch_reference_injections (
415480 fn : Callable [..., Any ],
416481) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
@@ -484,6 +549,10 @@ def _is_method(member):
484549 return inspect .ismethod (member ) or inspect .isfunction (member )
485550
486551
552+ def _is_marker (member ):
553+ return isinstance (member , _Marker )
554+
555+
487556def _get_patched (fn , reference_injections , reference_closing ):
488557 if inspect .iscoroutinefunction (fn ):
489558 patched = _get_async_patched (fn )
@@ -825,9 +894,6 @@ def uninstall(self):
825894 importlib .invalidate_caches ()
826895
827896
828- _loader = AutoLoader ()
829-
830-
831897def register_loader_containers (* containers : Container ) -> None :
832898 """Register containers in auto-wiring module loader."""
833899 _loader .register_containers (* containers )
@@ -851,3 +917,8 @@ def uninstall_loader() -> None:
851917def is_loader_installed () -> bool :
852918 """Check if auto-wiring module loader hook is installed."""
853919 return _loader .installed
920+
921+
922+ _patched_registry = PatchedRegistry ()
923+ _inspect_filter = InspectFilter ()
924+ _loader = AutoLoader ()
0 commit comments