From 0d6fdb5b78ae7c961e3e1f5d466f9a04b7456ba6 Mon Sep 17 00:00:00 2001 From: Martin Lafrance <20482569+martlaf@users.noreply.github.com> Date: Sun, 23 Feb 2025 11:17:45 -0500 Subject: [PATCH] Fix broken wiring of sync inject-decorated methods (#673) Co-authored-by: Martin Lafrance Co-authored-by: ZipFile --- src/dependency_injector/_cwiring.pyx | 49 +++++++++----------- src/dependency_injector/wiring.py | 17 ++++++- tests/unit/wiring/test_introspection_py36.py | 7 +++ 3 files changed, 44 insertions(+), 29 deletions(-) diff --git a/src/dependency_injector/_cwiring.pyx b/src/dependency_injector/_cwiring.pyx index 88b6bc5a..84a5485f 100644 --- a/src/dependency_injector/_cwiring.pyx +++ b/src/dependency_injector/_cwiring.pyx @@ -2,44 +2,39 @@ import asyncio import collections.abc -import functools import inspect import types -from . import providers -from .wiring import _Marker, PatchedCallable +from .wiring import _Marker -from .providers cimport Provider +from .providers cimport Provider, Resource -def _get_sync_patched(fn, patched: PatchedCallable): - @functools.wraps(fn) - def _patched(*args, **kwargs): - cdef object result - cdef dict to_inject - cdef object arg_key - cdef Provider provider +def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): + cdef object result + cdef dict to_inject + cdef object arg_key + cdef Provider provider - to_inject = kwargs.copy() - for arg_key, provider in patched.injections.items(): - if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): - to_inject[arg_key] = provider() + to_inject = kwargs.copy() + for arg_key, provider in injections.items(): + if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): + to_inject[arg_key] = provider() - result = fn(*args, **to_inject) + result = fn(*args, **to_inject) - if patched.closing: - for arg_key, provider in patched.closing.items(): - if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker): - continue - if not isinstance(provider, providers.Resource): - continue - provider.shutdown() + if closings: + for arg_key, provider in closings.items(): + if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker): + continue + if not isinstance(provider, Resource): + continue + provider.shutdown() - return result - return _patched + return result -async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings): +async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): cdef object result cdef dict to_inject cdef list to_inject_await = [] @@ -69,7 +64,7 @@ async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dic for arg_key, provider in closings.items(): if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker): continue - if not isinstance(provider, providers.Resource): + if not isinstance(provider, Resource): continue shutdown = provider.shutdown() if _isawaitable(shutdown): diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 67d56f29..9bb990ab 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -1030,7 +1030,7 @@ _inspect_filter = InspectFilter() _loader = AutoLoader() # Optimizations -from ._cwiring import _get_sync_patched # noqa +from ._cwiring import _sync_inject # noqa from ._cwiring import _async_inject # noqa @@ -1047,4 +1047,17 @@ def _get_async_patched(fn: F, patched: PatchedCallable) -> F: patched.closing, ) - return _patched + return cast(F, _patched) + + +def _get_sync_patched(fn: F, patched: PatchedCallable) -> F: + @functools.wraps(fn) + def _patched(*args, **kwargs): + return _sync_inject( + fn, + args, + kwargs, + patched.injections, + patched.closing, + ) + return cast(F, _patched) diff --git a/tests/unit/wiring/test_introspection_py36.py b/tests/unit/wiring/test_introspection_py36.py index c7149602..66b36a80 100644 --- a/tests/unit/wiring/test_introspection_py36.py +++ b/tests/unit/wiring/test_introspection_py36.py @@ -6,6 +6,13 @@ import inspect from dependency_injector.wiring import inject +def test_isfunction(): + @inject + def foo(): ... + + assert inspect.isfunction(foo) + + def test_asyncio_iscoroutinefunction(): @inject async def foo():