Fix broken wiring of sync inject-decorated methods (#673)

Co-authored-by: Martin Lafrance <mlafrance@cae.com>
Co-authored-by: ZipFile <zipfile.d@protonmail.com>
This commit is contained in:
Martin Lafrance 2025-02-23 11:17:45 -05:00 committed by GitHub
parent 2330122de6
commit 0d6fdb5b78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 44 additions and 29 deletions

View File

@ -2,44 +2,39 @@
import asyncio import asyncio
import collections.abc import collections.abc
import functools
import inspect import inspect
import types import types
from . import providers from .wiring import _Marker
from .wiring import _Marker, PatchedCallable
from .providers cimport Provider from .providers cimport Provider, Resource
def _get_sync_patched(fn, patched: PatchedCallable): def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
@functools.wraps(fn)
def _patched(*args, **kwargs):
cdef object result cdef object result
cdef dict to_inject cdef dict to_inject
cdef object arg_key cdef object arg_key
cdef Provider provider cdef Provider provider
to_inject = kwargs.copy() to_inject = kwargs.copy()
for arg_key, provider in patched.injections.items(): for arg_key, provider in injections.items():
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
to_inject[arg_key] = provider() to_inject[arg_key] = provider()
result = fn(*args, **to_inject) result = fn(*args, **to_inject)
if patched.closing: if closings:
for arg_key, provider in patched.closing.items(): for arg_key, provider in closings.items():
if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker): if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker):
continue continue
if not isinstance(provider, providers.Resource): if not isinstance(provider, Resource):
continue continue
provider.shutdown() provider.shutdown()
return result return result
return _patched
async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings): async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
cdef object result cdef object result
cdef dict to_inject cdef dict to_inject
cdef list to_inject_await = [] cdef list to_inject_await = []
@ -69,7 +64,7 @@ async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dic
for arg_key, provider in closings.items(): for arg_key, provider in closings.items():
if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker): if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker):
continue continue
if not isinstance(provider, providers.Resource): if not isinstance(provider, Resource):
continue continue
shutdown = provider.shutdown() shutdown = provider.shutdown()
if _isawaitable(shutdown): if _isawaitable(shutdown):

View File

@ -1030,7 +1030,7 @@ _inspect_filter = InspectFilter()
_loader = AutoLoader() _loader = AutoLoader()
# Optimizations # Optimizations
from ._cwiring import _get_sync_patched # noqa from ._cwiring import _sync_inject # noqa
from ._cwiring import _async_inject # noqa from ._cwiring import _async_inject # noqa
@ -1047,4 +1047,17 @@ def _get_async_patched(fn: F, patched: PatchedCallable) -> F:
patched.closing, patched.closing,
) )
return _patched return cast(F, _patched)
def _get_sync_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn)
def _patched(*args, **kwargs):
return _sync_inject(
fn,
args,
kwargs,
patched.injections,
patched.closing,
)
return cast(F, _patched)

View File

@ -6,6 +6,13 @@ import inspect
from dependency_injector.wiring import inject from dependency_injector.wiring import inject
def test_isfunction():
@inject
def foo(): ...
assert inspect.isfunction(foo)
def test_asyncio_iscoroutinefunction(): def test_asyncio_iscoroutinefunction():
@inject @inject
async def foo(): async def foo():