| 
1 | 1 | """Wiring optimizations module."""  | 
2 | 2 | 
 
  | 
3 |  | -import asyncio  | 
4 |  | -import collections.abc  | 
5 |  | -import inspect  | 
6 |  | -import types  | 
 | 3 | +from asyncio import gather  | 
 | 4 | +from collections.abc import Awaitable  | 
 | 5 | +from inspect import CO_ITERABLE_COROUTINE  | 
 | 6 | +from types import CoroutineType, GeneratorType  | 
7 | 7 | 
 
  | 
 | 8 | +from .providers cimport Provider, Resource, NULL_AWAITABLE  | 
8 | 9 | from .wiring import _Marker  | 
9 | 10 | 
 
  | 
10 |  | -from .providers cimport Provider, Resource  | 
 | 11 | +cimport cython  | 
11 | 12 | 
 
  | 
12 | 13 | 
 
  | 
13 |  | -def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):  | 
14 |  | - cdef object result  | 
 | 14 | +@cython.internal  | 
 | 15 | +@cython.no_gc  | 
 | 16 | +cdef class KWPair:  | 
 | 17 | + cdef str name  | 
 | 18 | + cdef object value  | 
 | 19 | + | 
 | 20 | + def __cinit__(self, str name, object value, /):  | 
 | 21 | + self.name = name  | 
 | 22 | + self.value = value  | 
 | 23 | + | 
 | 24 | + | 
 | 25 | +cdef inline bint _is_injectable(dict kwargs, str name):  | 
 | 26 | + return name not in kwargs or isinstance(kwargs[name], _Marker)  | 
 | 27 | + | 
 | 28 | + | 
 | 29 | +cdef class DependencyResolver:  | 
 | 30 | + cdef dict kwargs  | 
15 | 31 |  cdef dict to_inject  | 
16 |  | - cdef object arg_key  | 
17 |  | - cdef Provider provider  | 
 | 32 | + cdef dict injections  | 
 | 33 | + cdef dict closings  | 
18 | 34 | 
 
  | 
19 |  | - to_inject = kwargs.copy()  | 
20 |  | - for arg_key, provider in injections.items():  | 
21 |  | - if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):  | 
22 |  | - to_inject[arg_key] = provider()  | 
 | 35 | + def __init__(self, dict kwargs, dict injections, dict closings, /):  | 
 | 36 | + self.kwargs = kwargs  | 
 | 37 | + self.to_inject = kwargs.copy()  | 
 | 38 | + self.injections = injections  | 
 | 39 | + self.closings = closings  | 
23 | 40 | 
 
  | 
24 |  | - result = fn(*args, **to_inject)  | 
 | 41 | + async def _await_injection(self, kw_pair: KWPair, /) -> None:  | 
 | 42 | + self.to_inject[kw_pair.name] = await kw_pair.value  | 
25 | 43 | 
 
  | 
26 |  | - if closings:  | 
27 |  | - for arg_key, provider in closings.items():  | 
28 |  | - if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker):  | 
29 |  | - continue  | 
30 |  | - if not isinstance(provider, Resource):  | 
31 |  | - continue  | 
32 |  | - provider.shutdown()  | 
 | 44 | + cdef object _await_injections(self, to_await: list):  | 
 | 45 | + return gather(*map(self._await_injection, to_await))  | 
33 | 46 | 
 
  | 
34 |  | - return result  | 
 | 47 | + cdef void _handle_injections_sync(self):  | 
 | 48 | + cdef Provider provider  | 
35 | 49 | 
 
  | 
 | 50 | + for name, provider in self.injections.items():  | 
 | 51 | + if _is_injectable(self.kwargs, name):  | 
 | 52 | + self.to_inject[name] = provider()  | 
36 | 53 | 
 
  | 
37 |  | -async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):  | 
38 |  | - cdef object result  | 
39 |  | - cdef dict to_inject  | 
40 |  | - cdef list to_inject_await = []  | 
41 |  | - cdef list to_close_await = []  | 
42 |  | - cdef object arg_key  | 
43 |  | - cdef Provider provider  | 
44 |  | - | 
45 |  | - to_inject = kwargs.copy()  | 
46 |  | - for arg_key, provider in injections.items():  | 
47 |  | - if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):  | 
48 |  | - provide = provider()  | 
49 |  | - if provider.is_async_mode_enabled():  | 
50 |  | - to_inject_await.append((arg_key, provide))  | 
51 |  | - elif _isawaitable(provide):  | 
52 |  | - to_inject_await.append((arg_key, provide))  | 
53 |  | - else:  | 
54 |  | - to_inject[arg_key] = provide  | 
55 |  | - | 
56 |  | - if to_inject_await:  | 
57 |  | - async_to_inject = await asyncio.gather(*(provide for _, provide in to_inject_await))  | 
58 |  | - for provide, (injection, _) in zip(async_to_inject, to_inject_await):  | 
59 |  | - to_inject[injection] = provide  | 
60 |  | - | 
61 |  | - result = await fn(*args, **to_inject)  | 
62 |  | - | 
63 |  | - if closings:  | 
64 |  | - for arg_key, provider in closings.items():  | 
65 |  | - if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker):  | 
66 |  | - continue  | 
67 |  | - if not isinstance(provider, Resource):  | 
68 |  | - continue  | 
69 |  | - shutdown = provider.shutdown()  | 
70 |  | - if _isawaitable(shutdown):  | 
71 |  | - to_close_await.append(shutdown)  | 
72 |  | - | 
73 |  | - await asyncio.gather(*to_close_await)  | 
74 |  | - | 
75 |  | - return result  | 
 | 54 | + cdef list _handle_injections_async(self):  | 
 | 55 | + cdef list to_await = []  | 
 | 56 | + cdef Provider provider  | 
 | 57 | + | 
 | 58 | + for name, provider in self.injections.items():  | 
 | 59 | + if _is_injectable(self.kwargs, name):  | 
 | 60 | + provide = provider()  | 
 | 61 | + | 
 | 62 | + if provider.is_async_mode_enabled() or _isawaitable(provide):  | 
 | 63 | + to_await.append(KWPair(name, provide))  | 
 | 64 | + else:  | 
 | 65 | + self.to_inject[name] = provide  | 
 | 66 | + | 
 | 67 | + return to_await  | 
 | 68 | + | 
 | 69 | + cdef void _handle_closings_sync(self):  | 
 | 70 | + cdef Provider provider  | 
 | 71 | + | 
 | 72 | + for name, provider in self.closings.items():  | 
 | 73 | + if _is_injectable(self.kwargs, name) and isinstance(provider, Resource):  | 
 | 74 | + provider.shutdown()  | 
 | 75 | + | 
 | 76 | + cdef list _handle_closings_async(self):  | 
 | 77 | + cdef list to_await = []  | 
 | 78 | + cdef Provider provider  | 
 | 79 | + | 
 | 80 | + for name, provider in self.closings.items():  | 
 | 81 | + if _is_injectable(self.kwargs, name) and isinstance(provider, Resource):  | 
 | 82 | + if _isawaitable(shutdown := provider.shutdown()):  | 
 | 83 | + to_await.append(shutdown)  | 
 | 84 | + | 
 | 85 | + return to_await  | 
 | 86 | + | 
 | 87 | + def __enter__(self):  | 
 | 88 | + self._handle_injections_sync()  | 
 | 89 | + return self.to_inject  | 
 | 90 | + | 
 | 91 | + def __exit__(self, *_):  | 
 | 92 | + self._handle_closings_sync()  | 
 | 93 | + | 
 | 94 | + async def __aenter__(self):  | 
 | 95 | + if to_await := self._handle_injections_async():  | 
 | 96 | + await self._await_injections(to_await)  | 
 | 97 | + return self.to_inject  | 
 | 98 | + | 
 | 99 | + def __aexit__(self, *_):  | 
 | 100 | + if to_await := self._handle_closings_async():  | 
 | 101 | + return gather(*to_await)  | 
 | 102 | + return NULL_AWAITABLE  | 
76 | 103 | 
 
  | 
77 | 104 | 
 
  | 
78 | 105 | cdef bint _isawaitable(object instance):  | 
79 | 106 |  """Return true if object can be passed to an ``await`` expression."""  | 
80 |  | - return (isinstance(instance, types.CoroutineType) or  | 
81 |  | - isinstance(instance, types.GeneratorType) and  | 
82 |  | - bool(instance.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) or  | 
83 |  | - isinstance(instance, collections.abc.Awaitable))  | 
 | 107 | + return (isinstance(instance, CoroutineType) or  | 
 | 108 | + isinstance(instance, GeneratorType) and  | 
 | 109 | + bool(instance.gi_code.co_flags & CO_ITERABLE_COROUTINE) or  | 
 | 110 | + isinstance(instance, Awaitable))  | 
0 commit comments