mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-06-16 19:43:15 +03:00
Add support for async generator injections (#900)
This commit is contained in:
parent
c1f14a876a
commit
d8e49f7dd5
|
@ -1,23 +1,18 @@
|
|||
from typing import Any, Awaitable, Callable, Dict, Tuple, TypeVar
|
||||
from typing import Any, Dict
|
||||
|
||||
from .providers import Provider
|
||||
|
||||
T = TypeVar("T")
|
||||
class DependencyResolver:
|
||||
def __init__(
|
||||
self,
|
||||
kwargs: Dict[str, Any],
|
||||
injections: Dict[str, Provider[Any]],
|
||||
closings: Dict[str, Provider[Any]],
|
||||
/,
|
||||
) -> None: ...
|
||||
def __enter__(self) -> Dict[str, Any]: ...
|
||||
def __exit__(self, *exc_info: Any) -> None: ...
|
||||
async def __aenter__(self) -> Dict[str, Any]: ...
|
||||
async def __aexit__(self, *exc_info: Any) -> None: ...
|
||||
|
||||
def _sync_inject(
|
||||
fn: Callable[..., T],
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Dict[str, Any],
|
||||
injections: Dict[str, Provider[Any]],
|
||||
closings: Dict[str, Provider[Any]],
|
||||
/,
|
||||
) -> T: ...
|
||||
async def _async_inject(
|
||||
fn: Callable[..., Awaitable[T]],
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Dict[str, Any],
|
||||
injections: Dict[str, Provider[Any]],
|
||||
closings: Dict[str, Provider[Any]],
|
||||
/,
|
||||
) -> T: ...
|
||||
def _isawaitable(instance: Any) -> bool: ...
|
||||
|
|
|
@ -1,83 +1,110 @@
|
|||
"""Wiring optimizations module."""
|
||||
|
||||
import asyncio
|
||||
import collections.abc
|
||||
import inspect
|
||||
import types
|
||||
from asyncio import gather
|
||||
from collections.abc import Awaitable
|
||||
from inspect import CO_ITERABLE_COROUTINE
|
||||
from types import CoroutineType, GeneratorType
|
||||
|
||||
from .providers cimport Provider, Resource, NULL_AWAITABLE
|
||||
from .wiring import _Marker
|
||||
|
||||
from .providers cimport Provider, Resource
|
||||
cimport cython
|
||||
|
||||
|
||||
def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
|
||||
cdef object result
|
||||
@cython.internal
|
||||
@cython.no_gc
|
||||
cdef class KWPair:
|
||||
cdef str name
|
||||
cdef object value
|
||||
|
||||
def __cinit__(self, str name, object value, /):
|
||||
self.name = name
|
||||
self.value = value
|
||||
|
||||
|
||||
cdef inline bint _is_injectable(dict kwargs, str name):
|
||||
return name not in kwargs or isinstance(kwargs[name], _Marker)
|
||||
|
||||
|
||||
cdef class DependencyResolver:
|
||||
cdef dict kwargs
|
||||
cdef dict to_inject
|
||||
cdef object arg_key
|
||||
cdef Provider provider
|
||||
cdef dict injections
|
||||
cdef dict closings
|
||||
|
||||
to_inject = kwargs.copy()
|
||||
for arg_key, provider in injections.items():
|
||||
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
|
||||
to_inject[arg_key] = provider()
|
||||
def __init__(self, dict kwargs, dict injections, dict closings, /):
|
||||
self.kwargs = kwargs
|
||||
self.to_inject = kwargs.copy()
|
||||
self.injections = injections
|
||||
self.closings = closings
|
||||
|
||||
result = fn(*args, **to_inject)
|
||||
async def _await_injection(self, kw_pair: KWPair, /) -> None:
|
||||
self.to_inject[kw_pair.name] = await kw_pair.value
|
||||
|
||||
if closings:
|
||||
for arg_key, provider in closings.items():
|
||||
if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker):
|
||||
continue
|
||||
if not isinstance(provider, Resource):
|
||||
continue
|
||||
provider.shutdown()
|
||||
cdef object _await_injections(self, to_await: list):
|
||||
return gather(*map(self._await_injection, to_await))
|
||||
|
||||
return result
|
||||
cdef void _handle_injections_sync(self):
|
||||
cdef Provider provider
|
||||
|
||||
for name, provider in self.injections.items():
|
||||
if _is_injectable(self.kwargs, name):
|
||||
self.to_inject[name] = provider()
|
||||
|
||||
async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
|
||||
cdef object result
|
||||
cdef dict to_inject
|
||||
cdef list to_inject_await = []
|
||||
cdef list to_close_await = []
|
||||
cdef object arg_key
|
||||
cdef Provider provider
|
||||
cdef list _handle_injections_async(self):
|
||||
cdef list to_await = []
|
||||
cdef Provider provider
|
||||
|
||||
to_inject = kwargs.copy()
|
||||
for arg_key, provider in injections.items():
|
||||
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
|
||||
provide = provider()
|
||||
if provider.is_async_mode_enabled():
|
||||
to_inject_await.append((arg_key, provide))
|
||||
elif _isawaitable(provide):
|
||||
to_inject_await.append((arg_key, provide))
|
||||
else:
|
||||
to_inject[arg_key] = provide
|
||||
for name, provider in self.injections.items():
|
||||
if _is_injectable(self.kwargs, name):
|
||||
provide = provider()
|
||||
|
||||
if to_inject_await:
|
||||
async_to_inject = await asyncio.gather(*(provide for _, provide in to_inject_await))
|
||||
for provide, (injection, _) in zip(async_to_inject, to_inject_await):
|
||||
to_inject[injection] = provide
|
||||
if provider.is_async_mode_enabled() or _isawaitable(provide):
|
||||
to_await.append(KWPair(name, provide))
|
||||
else:
|
||||
self.to_inject[name] = provide
|
||||
|
||||
result = await fn(*args, **to_inject)
|
||||
return to_await
|
||||
|
||||
if closings:
|
||||
for arg_key, provider in closings.items():
|
||||
if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker):
|
||||
continue
|
||||
if not isinstance(provider, Resource):
|
||||
continue
|
||||
shutdown = provider.shutdown()
|
||||
if _isawaitable(shutdown):
|
||||
to_close_await.append(shutdown)
|
||||
cdef void _handle_closings_sync(self):
|
||||
cdef Provider provider
|
||||
|
||||
await asyncio.gather(*to_close_await)
|
||||
for name, provider in self.closings.items():
|
||||
if _is_injectable(self.kwargs, name) and isinstance(provider, Resource):
|
||||
provider.shutdown()
|
||||
|
||||
return result
|
||||
cdef list _handle_closings_async(self):
|
||||
cdef list to_await = []
|
||||
cdef Provider provider
|
||||
|
||||
for name, provider in self.closings.items():
|
||||
if _is_injectable(self.kwargs, name) and isinstance(provider, Resource):
|
||||
if _isawaitable(shutdown := provider.shutdown()):
|
||||
to_await.append(shutdown)
|
||||
|
||||
return to_await
|
||||
|
||||
def __enter__(self):
|
||||
self._handle_injections_sync()
|
||||
return self.to_inject
|
||||
|
||||
def __exit__(self, *_):
|
||||
self._handle_closings_sync()
|
||||
|
||||
async def __aenter__(self):
|
||||
if to_await := self._handle_injections_async():
|
||||
await self._await_injections(to_await)
|
||||
return self.to_inject
|
||||
|
||||
def __aexit__(self, *_):
|
||||
if to_await := self._handle_closings_async():
|
||||
return gather(*to_await)
|
||||
return NULL_AWAITABLE
|
||||
|
||||
|
||||
cdef bint _isawaitable(object instance):
|
||||
"""Return true if object can be passed to an ``await`` expression."""
|
||||
return (isinstance(instance, types.CoroutineType) or
|
||||
isinstance(instance, types.GeneratorType) and
|
||||
bool(instance.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) or
|
||||
isinstance(instance, collections.abc.Awaitable))
|
||||
return (isinstance(instance, CoroutineType) or
|
||||
isinstance(instance, GeneratorType) and
|
||||
bool(instance.gi_code.co_flags & CO_ITERABLE_COROUTINE) or
|
||||
isinstance(instance, Awaitable))
|
||||
|
|
|
@ -10,6 +10,7 @@ from types import ModuleType
|
|||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
|
@ -720,6 +721,8 @@ def _get_patched(
|
|||
|
||||
if inspect.iscoroutinefunction(fn):
|
||||
patched = _get_async_patched(fn, patched_object)
|
||||
elif inspect.isasyncgenfunction(fn):
|
||||
patched = _get_async_gen_patched(fn, patched_object)
|
||||
else:
|
||||
patched = _get_sync_patched(fn, patched_object)
|
||||
|
||||
|
@ -1035,36 +1038,41 @@ _inspect_filter = InspectFilter()
|
|||
_loader = AutoLoader()
|
||||
|
||||
# Optimizations
|
||||
from ._cwiring import _async_inject # noqa
|
||||
from ._cwiring import _sync_inject # noqa
|
||||
from ._cwiring import DependencyResolver # noqa: E402
|
||||
|
||||
|
||||
# Wiring uses the following Python wrapper because there is
|
||||
# no possibility to compile a first-type citizen coroutine in Cython.
|
||||
def _get_async_patched(fn: F, patched: PatchedCallable) -> F:
|
||||
@functools.wraps(fn)
|
||||
async def _patched(*args, **kwargs):
|
||||
return await _async_inject(
|
||||
fn,
|
||||
args,
|
||||
kwargs,
|
||||
patched.injections,
|
||||
patched.closing,
|
||||
)
|
||||
async def _patched(*args: Any, **raw_kwargs: Any) -> Any:
|
||||
resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
|
||||
|
||||
async with resolver as kwargs:
|
||||
return await fn(*args, **kwargs)
|
||||
|
||||
return cast(F, _patched)
|
||||
|
||||
|
||||
def _get_async_gen_patched(fn: F, patched: PatchedCallable) -> F:
|
||||
@functools.wraps(fn)
|
||||
async def _patched(*args: Any, **raw_kwargs: Any) -> AsyncIterator[Any]:
|
||||
resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
|
||||
|
||||
async with resolver as kwargs:
|
||||
async for obj in fn(*args, **kwargs):
|
||||
yield obj
|
||||
|
||||
return cast(F, _patched)
|
||||
|
||||
|
||||
def _get_sync_patched(fn: F, patched: PatchedCallable) -> F:
|
||||
@functools.wraps(fn)
|
||||
def _patched(*args, **kwargs):
|
||||
return _sync_inject(
|
||||
fn,
|
||||
args,
|
||||
kwargs,
|
||||
patched.injections,
|
||||
patched.closing,
|
||||
)
|
||||
def _patched(*args: Any, **raw_kwargs: Any) -> Any:
|
||||
resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing)
|
||||
|
||||
with resolver as kwargs:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return cast(F, _patched)
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import asyncio
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from dependency_injector import containers, providers
|
||||
from dependency_injector.wiring import inject, Provide, Closing
|
||||
from dependency_injector.wiring import Closing, Provide, inject
|
||||
|
||||
|
||||
class TestResource:
|
||||
|
@ -42,6 +44,15 @@ async def async_injection(
|
|||
return resource1, resource2
|
||||
|
||||
|
||||
@inject
|
||||
async def async_generator_injection(
|
||||
resource1: object = Provide[Container.resource1],
|
||||
resource2: object = Closing[Provide[Container.resource2]],
|
||||
):
|
||||
yield resource1
|
||||
yield resource2
|
||||
|
||||
|
||||
@inject
|
||||
async def async_injection_with_closing(
|
||||
resource1: object = Closing[Provide[Container.resource1]],
|
||||
|
|
|
@ -32,6 +32,23 @@ async def test_async_injections():
|
|||
assert asyncinjections.resource2.shutdown_counter == 0
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_async_generator_injections() -> None:
|
||||
resources = []
|
||||
|
||||
async for resource in asyncinjections.async_generator_injection():
|
||||
resources.append(resource)
|
||||
|
||||
assert len(resources) == 2
|
||||
assert resources[0] is asyncinjections.resource1
|
||||
assert asyncinjections.resource1.init_counter == 1
|
||||
assert asyncinjections.resource1.shutdown_counter == 0
|
||||
|
||||
assert resources[1] is asyncinjections.resource2
|
||||
assert asyncinjections.resource2.init_counter == 1
|
||||
assert asyncinjections.resource2.shutdown_counter == 1
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_async_injections_with_closing():
|
||||
resource1, resource2 = await asyncinjections.async_injection_with_closing()
|
||||
|
|
Loading…
Reference in New Issue
Block a user