diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index c4dc1298..555ebb1d 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -273,41 +273,9 @@ def _is_method(member): def _patch_with_injections(fn, injections, closing): if inspect.iscoroutinefunction(fn): - @functools.wraps(fn) - async def _patched(*args, **kwargs): - to_inject = kwargs.copy() - for injection, provider in injections.items(): - if injection not in kwargs: - to_inject[injection] = provider() - - result = await fn(*args, **to_inject) - - for injection, provider in closing.items(): - if injection in kwargs: - continue - if not isinstance(provider, providers.Resource): - continue - provider.shutdown() - - return result + _patched = _get_async_patched(fn, injections, closing) else: - @functools.wraps(fn) - def _patched(*args, **kwargs): - to_inject = kwargs.copy() - for injection, provider in injections.items(): - if injection not in kwargs: - to_inject[injection] = provider() - - result = fn(*args, **to_inject) - - for injection, provider in closing.items(): - if injection in kwargs: - continue - if not isinstance(provider, providers.Resource): - continue - provider.shutdown() - - return result + _patched = _get_patched(fn, injections, closing) _patched.__wired__ = True _patched.__original__ = fn @@ -317,6 +285,47 @@ def _patch_with_injections(fn, injections, closing): return _patched +def _get_patched(fn, injections, closing): + def _patched(*args, **kwargs): + to_inject = kwargs.copy() + for injection, provider in injections.items(): + if injection not in kwargs: + to_inject[injection] = provider() + + result = fn(*args, **to_inject) + + for injection, provider in closing.items(): + if injection in kwargs: + continue + if not isinstance(provider, providers.Resource): + continue + provider.shutdown() + + return result + return _patched + + +def _get_async_patched(fn, injections, closing): + @functools.wraps(fn) + async def _patched(*args, **kwargs): + to_inject = kwargs.copy() + for injection, provider in injections.items(): + if injection not in kwargs: + to_inject[injection] = provider() + + result = await fn(*args, **to_inject) + + for injection, provider in closing.items(): + if injection in kwargs: + continue + if not isinstance(provider, providers.Resource): + continue + provider.shutdown() + + return result + return _patched + + def _is_patched(fn): return getattr(fn, '__wired__', False) is True