diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index a619cab7..8c250b20 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -176,15 +176,26 @@ def _fetch_modules(package): def _patch_with_injections(fn, injections): - @functools.wraps(fn) - def _patched(*args, **kwargs): - to_inject = {} - for injection, provider in injections.items(): - to_inject[injection] = provider() + if inspect.iscoroutinefunction(fn): + @functools.wraps(fn) + async def _patched(*args, **kwargs): + to_inject = {} + for injection, provider in injections.items(): + to_inject[injection] = provider() - to_inject.update(kwargs) + to_inject.update(kwargs) - return fn(*args, **to_inject) + return await fn(*args, **to_inject) + else: + @functools.wraps(fn) + def _patched(*args, **kwargs): + to_inject = {} + for injection, provider in injections.items(): + to_inject[injection] = provider() + + to_inject.update(kwargs) + + return fn(*args, **to_inject) _patched.__wired__ = True _patched.__original__ = fn