Add support for async generator injections (#900)

This commit is contained in:
ZipFile 2025-06-03 21:45:43 +03:00 committed by GitHub
parent c1f14a876a
commit d8e49f7dd5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 154 additions and 96 deletions

View File

@ -1,23 +1,18 @@
from typing import Any, Awaitable, Callable, Dict, Tuple, TypeVar
from typing import Any, Dict
from .providers import Provider
T = TypeVar("T")
class DependencyResolver:
def __init__(
self,
kwargs: Dict[str, Any],
injections: Dict[str, Provider[Any]],
closings: Dict[str, Provider[Any]],
/,
) -> None: ...
def __enter__(self) -> Dict[str, Any]: ...
def __exit__(self, *exc_info: Any) -> None: ...
async def __aenter__(self) -> Dict[str, Any]: ...
async def __aexit__(self, *exc_info: Any) -> None: ...
def _sync_inject(
fn: Callable[..., T],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
injections: Dict[str, Provider[Any]],
closings: Dict[str, Provider[Any]],
/,
) -> T: ...
async def _async_inject(
fn: Callable[..., Awaitable[T]],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
injections: Dict[str, Provider[Any]],
closings: Dict[str, Provider[Any]],
/,
) -> T: ...
def _isawaitable(instance: Any) -> bool: ...

View File

@ -1,83 +1,110 @@
"""Wiring optimizations module."""
import asyncio
import collections.abc
import inspect
import types
from asyncio import gather
from collections.abc import Awaitable
from inspect import CO_ITERABLE_COROUTINE
from types import CoroutineType, GeneratorType
from .providers cimport Provider, Resource, NULL_AWAITABLE
from .wiring import _Marker
from .providers cimport Provider, Resource
cimport cython
def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
cdef object result
@cython.internal
@cython.no_gc
cdef class KWPair:
cdef str name
cdef object value
def __cinit__(self, str name, object value, /):
self.name = name
self.value = value
cdef inline bint _is_injectable(dict kwargs, str name):
return name not in kwargs or isinstance(kwargs[name], _Marker)
cdef class DependencyResolver:
cdef dict kwargs
cdef dict to_inject
cdef object arg_key
cdef Provider provider
cdef dict injections
cdef dict closings
to_inject = kwargs.copy()
for arg_key, provider in injections.items():
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
to_inject[arg_key] = provider()
def __init__(self, dict kwargs, dict injections, dict closings, /):
self.kwargs = kwargs
self.to_inject = kwargs.copy()
self.injections = injections
self.closings = closings
result = fn(*args, **to_inject)
async def _await_injection(self, kw_pair: KWPair, /) -> None:
self.to_inject[kw_pair.name] = await kw_pair.value
if closings:
for arg_key, provider in closings.items():
if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker):
continue
if not isinstance(provider, Resource):
continue
provider.shutdown()
cdef object _await_injections(self, to_await: list):
return gather(*map(self._await_injection, to_await))
return result
cdef void _handle_injections_sync(self):
cdef Provider provider
for name, provider in self.injections.items():
if _is_injectable(self.kwargs, name):
self.to_inject[name] = provider()
async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
cdef object result
cdef dict to_inject
cdef list to_inject_await = []
cdef list to_close_await = []
cdef object arg_key
cdef Provider provider
cdef list _handle_injections_async(self):
cdef list to_await = []
cdef Provider provider
to_inject = kwargs.copy()
for arg_key, provider in injections.items():
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
provide = provider()
if provider.is_async_mode_enabled():
to_inject_await.append((arg_key, provide))
elif _isawaitable(provide):
to_inject_await.append((arg_key, provide))
else:
to_inject[arg_key] = provide
for name, provider in self.injections.items():
if _is_injectable(self.kwargs, name):
provide = provider()
if to_inject_await:
async_to_inject = await asyncio.gather(*(provide for _, provide in to_inject_await))
for provide, (injection, _) in zip(async_to_inject, to_inject_await):
to_inject[injection] = provide
if provider.is_async_mode_enabled() or _isawaitable(provide):
to_await.append(KWPair(name, provide))
else:
self.to_inject[name] = provide
result = await fn(*args, **to_inject)
return to_await
if closings:
for arg_key, provider in closings.items():
if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker):
continue
if not isinstance(provider, Resource):
continue
shutdown = provider.shutdown()
if _isawaitable(shutdown):
to_close_await.append(shutdown)
cdef void _handle_closings_sync(self):
cdef Provider provider
await asyncio.gather(*to_close_await)
for name, provider in self.closings.items():
if _is_injectable(self.kwargs, name) and isinstance(provider, Resource):
provider.shutdown()
return result
cdef list _handle_closings_async(self):
cdef list to_await = []
cdef Provider provider
for name, provider in self.closings.items():
if _is_injectable(self.kwargs, name) and isinstance(provider, Resource):
if _isawaitable(shutdown := provider.shutdown()):
to_await.append(shutdown)
return to_await
def __enter__(self):
self._handle_injections_sync()
return self.to_inject
def __exit__(self, *_):
self._handle_closings_sync()
async def __aenter__(self):
if to_await := self._handle_injections_async():
await self._await_injections(to_await)
return self.to_inject
def __aexit__(self, *_):
if to_await := self._handle_closings_async():
return gather(*to_await)
return NULL_AWAITABLE
cdef bint _isawaitable(object instance):
"""Return true if object can be passed to an ``await`` expression."""
return (isinstance(instance, types.CoroutineType) or
isinstance(instance, types.GeneratorType) and
bool(instance.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) or
isinstance(instance, collections.abc.Awaitable))
return (isinstance(instance, CoroutineType) or
isinstance(instance, GeneratorType) and
bool(instance.gi_code.co_flags & CO_ITERABLE_COROUTINE) or
isinstance(instance, Awaitable))

View File

@ -10,6 +10,7 @@ from types import ModuleType
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Iterable,
@ -720,6 +721,8 @@ def _get_patched(
if inspect.iscoroutinefunction(fn):
patched = _get_async_patched(fn, patched_object)
elif inspect.isasyncgenfunction(fn):
patched = _get_async_gen_patched(fn, patched_object)
else:
patched = _get_sync_patched(fn, patched_object)
@ -1035,36 +1038,41 @@ _inspect_filter = InspectFilter()
_loader = AutoLoader()
# Optimizations
from ._cwiring import _async_inject # noqa
from ._cwiring import _sync_inject # noqa
from ._cwiring import DependencyResolver # noqa: E402
# Wiring uses the following Python wrapper because there is
# no possibility to compile a first-type citizen coroutine in Cython.
def _get_async_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn)
async def _patched(*args, **kwargs):
return await _async_inject(
fn,
args,
kwargs,
patched.injections,
patched.closing,
)
async def _patched(*args: Any, **raw_kwargs: Any) -> Any:
resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
async with resolver as kwargs:
return await fn(*args, **kwargs)
return cast(F, _patched)
def _get_async_gen_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn)
async def _patched(*args: Any, **raw_kwargs: Any) -> AsyncIterator[Any]:
resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
async with resolver as kwargs:
async for obj in fn(*args, **kwargs):
yield obj
return cast(F, _patched)
def _get_sync_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn)
def _patched(*args, **kwargs):
return _sync_inject(
fn,
args,
kwargs,
patched.injections,
patched.closing,
)
def _patched(*args: Any, **raw_kwargs: Any) -> Any:
resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
with resolver as kwargs:
return fn(*args, **kwargs)
return cast(F, _patched)

View File

@ -1,7 +1,9 @@
import asyncio
from typing_extensions import Annotated
from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide, Closing
from dependency_injector.wiring import Closing, Provide, inject
class TestResource:
@ -42,6 +44,15 @@ async def async_injection(
return resource1, resource2
@inject
async def async_generator_injection(
resource1: object = Provide[Container.resource1],
resource2: object = Closing[Provide[Container.resource2]],
):
yield resource1
yield resource2
@inject
async def async_injection_with_closing(
resource1: object = Closing[Provide[Container.resource1]],

View File

@ -32,6 +32,23 @@ async def test_async_injections():
assert asyncinjections.resource2.shutdown_counter == 0
@mark.asyncio
async def test_async_generator_injections() -> None:
resources = []
async for resource in asyncinjections.async_generator_injection():
resources.append(resource)
assert len(resources) == 2
assert resources[0] is asyncinjections.resource1
assert asyncinjections.resource1.init_counter == 1
assert asyncinjections.resource1.shutdown_counter == 0
assert resources[1] is asyncinjections.resource2
assert asyncinjections.resource2.init_counter == 1
assert asyncinjections.resource2.shutdown_counter == 1
@mark.asyncio
async def test_async_injections_with_closing():
resource1, resource2 = await asyncinjections.async_injection_with_closing()