diff --git a/src/dependency_injector/_cwiring.pyx b/src/dependency_injector/_cwiring.pyx index 3e2775c7..01871243 100644 --- a/src/dependency_injector/_cwiring.pyx +++ b/src/dependency_injector/_cwiring.pyx @@ -5,24 +5,11 @@ from collections.abc import Awaitable from inspect import CO_ITERABLE_COROUTINE from types import CoroutineType, GeneratorType -from .providers cimport Provider, Resource, NULL_AWAITABLE +from .providers cimport Provider, Resource from .wiring import _Marker -cimport cython - -@cython.internal -@cython.no_gc -cdef class KWPair: - cdef str name - cdef object value - - def __cinit__(self, str name, object value, /): - self.name = name - self.value = value - - -cdef inline bint _is_injectable(dict kwargs, str name): +cdef inline bint _is_injectable(dict kwargs, object name): return name not in kwargs or isinstance(kwargs[name], _Marker) @@ -38,11 +25,8 @@ cdef class DependencyResolver: self.injections = injections self.closings = closings - async def _await_injection(self, kw_pair: KWPair, /) -> None: - self.to_inject[kw_pair.name] = await kw_pair.value - - cdef object _await_injections(self, to_await: list): - return gather(*map(self._await_injection, to_await)) + async def _await_injection(self, name: str, value: object, /) -> None: + self.to_inject[name] = await value cdef void _handle_injections_sync(self): cdef Provider provider @@ -60,7 +44,7 @@ cdef class DependencyResolver: provide = provider() if provider.is_async_mode_enabled() or _isawaitable(provide): - to_await.append(KWPair(name, provide)) + to_await.append(self._await_injection(name, provide)) else: self.to_inject[name] = provide @@ -93,13 +77,12 @@ cdef class DependencyResolver: async def __aenter__(self): if to_await := self._handle_injections_async(): - await self._await_injections(to_await) + await gather(*to_await) return self.to_inject - def __aexit__(self, *_): + async def __aexit__(self, *_): if to_await := self._handle_closings_async(): - return gather(*to_await) - return NULL_AWAITABLE + await gather(*to_await) cdef bint _isawaitable(object instance):