diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 056814d2..a3c36d4a 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -1,5 +1,6 @@ """Wiring module.""" +import asyncio import functools import inspect import importlib @@ -426,10 +427,20 @@ def _get_async_patched(fn): @functools.wraps(fn) async def _patched(*args, **kwargs): to_inject = kwargs.copy() + to_inject_await = [] + to_close_await = [] for injection, provider in _patched.__injections__.items(): if injection not in kwargs \ or _is_fastapi_default_arg_injection(injection, kwargs): - to_inject[injection] = provider() + provide = provider() + if inspect.isawaitable(provide): + to_inject_await.append((injection, provide)) + else: + to_inject[injection] = provide + + async_to_inject = await asyncio.gather(*[provide for _, provide in to_inject_await]) + for provide, (injection, _) in zip(async_to_inject, to_inject_await): + to_inject[injection] = provide result = await fn(*args, **to_inject) @@ -439,7 +450,11 @@ def _get_async_patched(fn): continue if not isinstance(provider, providers.Resource): continue - provider.shutdown() + shutdown = provider.shutdown() + if inspect.isawaitable(shutdown): + to_close_await.append(shutdown) + + await asyncio.gather(*to_close_await) return result return _patched