Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 22 additions & 25 deletions src/dependency_injector/_cwiring.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,28 @@ from .wiring import _Marker, PatchedCallable
from .providers cimport Provider


def _get_sync_patched(fn, patched: PatchedCallable):
@functools.wraps(fn)
def _patched(*args, **kwargs):
cdef object result
cdef dict to_inject
cdef object arg_key
cdef Provider provider

to_inject = kwargs.copy()
for arg_key, provider in patched.injections.items():
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
to_inject[arg_key] = provider()

result = fn(*args, **to_inject)

if patched.closing:
for arg_key, provider in patched.closing.items():
if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker):
continue
if not isinstance(provider, providers.Resource):
continue
provider.shutdown()

return result
return _patched
def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings):
cdef object result
cdef dict to_inject
cdef object arg_key
cdef Provider provider

to_inject = kwargs.copy()
for arg_key, provider in injections.items():
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
to_inject[arg_key] = provider()

result = fn(*args, **to_inject)

if closings:
for arg_key, provider in closings.items():
if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker):
continue
if not isinstance(provider, providers.Resource):
continue
provider.shutdown()

return result


async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings):
Expand Down
17 changes: 15 additions & 2 deletions src/dependency_injector/wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ def is_loader_installed() -> bool:
_loader = AutoLoader()

# Optimizations
from ._cwiring import _get_sync_patched # noqa
from ._cwiring import _sync_inject # noqa
from ._cwiring import _async_inject # noqa


Expand All @@ -1047,4 +1047,17 @@ async def _patched(*args, **kwargs):
patched.closing,
)

return _patched
return cast(F, _patched)


def _get_sync_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn)
def _patched(*args, **kwargs):
return _sync_inject(
fn,
args,
kwargs,
patched.injections,
patched.closing,
)
return cast(F, _patched)