Add support for async generator injections

This commit is contained in:
ZipFile 2025-06-02 23:07:20 +00:00
parent c1f14a876a
commit c82cc343dd
5 changed files with 154 additions and 96 deletions

View File

@ -1,23 +1,18 @@
from typing import Any, Awaitable, Callable, Dict, Tuple, TypeVar from typing import Any, Dict
from .providers import Provider from .providers import Provider
T = TypeVar("T") class DependencyResolver:
def __init__(
self,
kwargs: Dict[str, Any],
injections: Dict[str, Provider[Any]],
closings: Dict[str, Provider[Any]],
/,
) -> None: ...
def __enter__(self) -> Dict[str, Any]: ...
def __exit__(self, *exc_info: Any) -> None: ...
async def __aenter__(self) -> Dict[str, Any]: ...
async def __aexit__(self, *exc_info: Any) -> None: ...
def _sync_inject(
fn: Callable[..., T],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
injections: Dict[str, Provider[Any]],
closings: Dict[str, Provider[Any]],
/,
) -> T: ...
async def _async_inject(
fn: Callable[..., Awaitable[T]],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
injections: Dict[str, Provider[Any]],
closings: Dict[str, Provider[Any]],
/,
) -> T: ...
def _isawaitable(instance: Any) -> bool: ... def _isawaitable(instance: Any) -> bool: ...

View File

@ -1,83 +1,109 @@
"""Wiring optimizations module.""" """Wiring optimizations module."""
import asyncio from asyncio import gather
import collections.abc from collections.abc import Awaitable
import inspect from inspect import CO_ITERABLE_COROUTINE
import types from types import CoroutineType, GeneratorType
from .providers cimport Provider, Resource, NULL_AWAITABLE
from .wiring import _Marker from .wiring import _Marker
from .providers cimport Provider, Resource cimport cython
def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): @cython.no_gc
cdef object result cdef class KWPair:
cdef str name
cdef object value
def __cinit__(self, str name, object value, /):
self.name = name
self.value = value
cdef inline bint _is_injectable(dict kwargs, str name):
return name not in kwargs or isinstance(kwargs[name], _Marker)
cdef class DependencyResolver:
cdef dict kwargs
cdef dict to_inject cdef dict to_inject
cdef object arg_key cdef dict injections
cdef Provider provider cdef dict closings
to_inject = kwargs.copy() def __init__(self, dict kwargs, dict injections, dict closings, /):
for arg_key, provider in injections.items(): self.kwargs = kwargs
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): self.to_inject = kwargs.copy()
to_inject[arg_key] = provider() self.injections = injections
self.closings = closings
result = fn(*args, **to_inject) async def _await_injection(self, p: KWPair, /) -> None:
self.to_inject[p.name] = await p.value
if closings: cdef object _await_injections(self, to_await: list):
for arg_key, provider in closings.items(): return gather(*map(self._await_injection, to_await))
if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker):
continue
if not isinstance(provider, Resource):
continue
provider.shutdown()
return result cdef void _handle_injections_sync(self):
cdef Provider provider
for name, provider in self.injections.items():
if _is_injectable(self.kwargs, name):
self.to_inject[name] = provider()
async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): cdef list _handle_injections_async(self):
cdef object result cdef list to_await = []
cdef dict to_inject cdef Provider provider
cdef list to_inject_await = []
cdef list to_close_await = []
cdef object arg_key
cdef Provider provider
to_inject = kwargs.copy() for name, provider in self.injections.items():
for arg_key, provider in injections.items(): if _is_injectable(self.kwargs, name):
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): provide = provider()
provide = provider()
if provider.is_async_mode_enabled():
to_inject_await.append((arg_key, provide))
elif _isawaitable(provide):
to_inject_await.append((arg_key, provide))
else:
to_inject[arg_key] = provide
if to_inject_await: if provider.is_async_mode_enabled() or _isawaitable(provide):
async_to_inject = await asyncio.gather(*(provide for _, provide in to_inject_await)) to_await.append(KWPair(name, provide))
for provide, (injection, _) in zip(async_to_inject, to_inject_await): else:
to_inject[injection] = provide self.to_inject[name] = provide
result = await fn(*args, **to_inject) return to_await
if closings: cdef void _handle_closings_sync(self):
for arg_key, provider in closings.items(): cdef Provider provider
if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker):
continue
if not isinstance(provider, Resource):
continue
shutdown = provider.shutdown()
if _isawaitable(shutdown):
to_close_await.append(shutdown)
await asyncio.gather(*to_close_await) for name, provider in self.closings.items():
if _is_injectable(self.kwargs, name) and isinstance(provider, Resource):
provider.shutdown()
return result cdef list _handle_closings_async(self):
cdef list to_await = []
cdef Provider provider
for name, provider in self.closings.items():
if _is_injectable(self.kwargs, name) and isinstance(provider, Resource):
if _isawaitable(shutdown := provider.shutdown()):
to_await.append(shutdown)
return to_await
def __enter__(self):
self._handle_injections_sync()
return self.to_inject
def __exit__(self, *_):
self._handle_closings_sync()
async def __aenter__(self):
if to_await := self._handle_injections_async():
await self._await_injections(to_await)
return self.to_inject
def __aexit__(self, *_):
if to_await := self._handle_closings_async():
return gather(*to_await)
return NULL_AWAITABLE
cdef bint _isawaitable(object instance): cdef bint _isawaitable(object instance):
"""Return true if object can be passed to an ``await`` expression.""" """Return true if object can be passed to an ``await`` expression."""
return (isinstance(instance, types.CoroutineType) or return (isinstance(instance, CoroutineType) or
isinstance(instance, types.GeneratorType) and isinstance(instance, GeneratorType) and
bool(instance.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) or bool(instance.gi_code.co_flags & CO_ITERABLE_COROUTINE) or
isinstance(instance, collections.abc.Awaitable)) isinstance(instance, Awaitable))

View File

@ -10,6 +10,7 @@ from types import ModuleType
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncIterator,
Callable, Callable,
Dict, Dict,
Iterable, Iterable,
@ -720,6 +721,8 @@ def _get_patched(
if inspect.iscoroutinefunction(fn): if inspect.iscoroutinefunction(fn):
patched = _get_async_patched(fn, patched_object) patched = _get_async_patched(fn, patched_object)
elif inspect.isasyncgenfunction(fn):
patched = _get_async_gen_patched(fn, patched_object)
else: else:
patched = _get_sync_patched(fn, patched_object) patched = _get_sync_patched(fn, patched_object)
@ -1035,36 +1038,42 @@ _inspect_filter = InspectFilter()
_loader = AutoLoader() _loader = AutoLoader()
# Optimizations # Optimizations
from ._cwiring import _async_inject # noqa from ._cwiring import DependencyResolver # noqa: E402
from ._cwiring import _sync_inject # noqa
# Wiring uses the following Python wrapper because there is # Wiring uses the following Python wrapper because there is
# no possibility to compile a first-type citizen coroutine in Cython. # no possibility to compile a first-type citizen coroutine in Cython.
def _get_async_patched(fn: F, patched: PatchedCallable) -> F: def _get_async_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn) @functools.wraps(fn)
async def _patched(*args, **kwargs): async def _patched(*args: Any, **raw_kwargs: Any) -> Any:
return await _async_inject( dr = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
fn,
args, async with dr as kwargs:
kwargs, return await fn(*args, **kwargs)
patched.injections,
patched.closing, return cast(F, _patched)
)
# Async generators too...
def _get_async_gen_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn)
async def _patched(*args: Any, **raw_kwargs: Any) -> AsyncIterator[Any]:
dr = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
async with dr as kwargs:
async for obj in fn(*args, **kwargs):
yield obj
return cast(F, _patched) return cast(F, _patched)
def _get_sync_patched(fn: F, patched: PatchedCallable) -> F: def _get_sync_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn) @functools.wraps(fn)
def _patched(*args, **kwargs): def _patched(*args: Any, **raw_kwargs: Any) -> Any:
return _sync_inject( dr = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
fn,
args, with dr as kwargs:
kwargs, return fn(*args, **kwargs)
patched.injections,
patched.closing,
)
return cast(F, _patched) return cast(F, _patched)

View File

@ -1,7 +1,9 @@
import asyncio import asyncio
from typing_extensions import Annotated
from dependency_injector import containers, providers from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide, Closing from dependency_injector.wiring import Closing, Provide, inject
class TestResource: class TestResource:
@ -42,6 +44,15 @@ async def async_injection(
return resource1, resource2 return resource1, resource2
@inject
async def async_generator_injection(
resource1: object = Provide[Container.resource1],
resource2: object = Closing[Provide[Container.resource2]],
):
yield resource1
yield resource2
@inject @inject
async def async_injection_with_closing( async def async_injection_with_closing(
resource1: object = Closing[Provide[Container.resource1]], resource1: object = Closing[Provide[Container.resource1]],

View File

@ -32,6 +32,23 @@ async def test_async_injections():
assert asyncinjections.resource2.shutdown_counter == 0 assert asyncinjections.resource2.shutdown_counter == 0
@mark.asyncio
async def test_async_generator_injections() -> None:
resources = []
async for resource in asyncinjections.async_generator_injection():
resources.append(resource)
assert len(resources) == 2
assert resources[0] is asyncinjections.resource1
assert asyncinjections.resource1.init_counter == 1
assert asyncinjections.resource1.shutdown_counter == 0
assert resources[1] is asyncinjections.resource2
assert asyncinjections.resource2.init_counter == 1
assert asyncinjections.resource2.shutdown_counter == 1
@mark.asyncio @mark.asyncio
async def test_async_injections_with_closing(): async def test_async_injections_with_closing():
resource1, resource2 = await asyncinjections.async_injection_with_closing() resource1, resource2 = await asyncinjections.async_injection_with_closing()