mirror of
				https://github.com/ets-labs/python-dependency-injector.git
				synced 2025-11-01 00:17:55 +03:00 
			
		
		
		
	Add support for async generator injections
This commit is contained in:
		
							parent
							
								
									c1f14a876a
								
							
						
					
					
						commit
						c82cc343dd
					
				|  | @ -1,23 +1,18 @@ | |||
| from typing import Any, Awaitable, Callable, Dict, Tuple, TypeVar | ||||
| from typing import Any, Dict | ||||
| 
 | ||||
| from .providers import Provider | ||||
| 
 | ||||
| T = TypeVar("T") | ||||
| class DependencyResolver: | ||||
|     def __init__( | ||||
|         self, | ||||
|         kwargs: Dict[str, Any], | ||||
|         injections: Dict[str, Provider[Any]], | ||||
|         closings: Dict[str, Provider[Any]], | ||||
|         /, | ||||
|     ) -> None: ... | ||||
|     def __enter__(self) -> Dict[str, Any]: ... | ||||
|     def __exit__(self, *exc_info: Any) -> None: ... | ||||
|     async def __aenter__(self) -> Dict[str, Any]: ... | ||||
|     async def __aexit__(self, *exc_info: Any) -> None: ... | ||||
| 
 | ||||
| def _sync_inject( | ||||
|     fn: Callable[..., T], | ||||
|     args: Tuple[Any, ...], | ||||
|     kwargs: Dict[str, Any], | ||||
|     injections: Dict[str, Provider[Any]], | ||||
|     closings: Dict[str, Provider[Any]], | ||||
|     /, | ||||
| ) -> T: ... | ||||
| async def _async_inject( | ||||
|     fn: Callable[..., Awaitable[T]], | ||||
|     args: Tuple[Any, ...], | ||||
|     kwargs: Dict[str, Any], | ||||
|     injections: Dict[str, Provider[Any]], | ||||
|     closings: Dict[str, Provider[Any]], | ||||
|     /, | ||||
| ) -> T: ... | ||||
| def _isawaitable(instance: Any) -> bool: ... | ||||
|  |  | |||
|  | @ -1,83 +1,109 @@ | |||
| """Wiring optimizations module.""" | ||||
| 
 | ||||
| import asyncio | ||||
| import collections.abc | ||||
| import inspect | ||||
| import types | ||||
| from asyncio import gather | ||||
| from collections.abc import Awaitable | ||||
| from inspect import CO_ITERABLE_COROUTINE | ||||
| from types import CoroutineType, GeneratorType | ||||
| 
 | ||||
| from .providers cimport Provider, Resource, NULL_AWAITABLE | ||||
| from .wiring import _Marker | ||||
| 
 | ||||
| from .providers cimport Provider, Resource | ||||
| cimport cython | ||||
| 
 | ||||
| 
 | ||||
| def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): | ||||
|     cdef object result | ||||
| @cython.no_gc | ||||
| cdef class KWPair: | ||||
|     cdef str name | ||||
|     cdef object value | ||||
| 
 | ||||
|     def __cinit__(self, str name, object value, /): | ||||
|         self.name = name | ||||
|         self.value = value | ||||
| 
 | ||||
| 
 | ||||
| cdef inline bint _is_injectable(dict kwargs, str name): | ||||
|     return name not in kwargs or isinstance(kwargs[name], _Marker) | ||||
| 
 | ||||
| 
 | ||||
| cdef class DependencyResolver: | ||||
|     cdef dict kwargs | ||||
|     cdef dict to_inject | ||||
|     cdef object arg_key | ||||
|     cdef Provider provider | ||||
|     cdef dict injections | ||||
|     cdef dict closings | ||||
| 
 | ||||
|     to_inject = kwargs.copy() | ||||
|     for arg_key, provider in injections.items(): | ||||
|         if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): | ||||
|             to_inject[arg_key] = provider() | ||||
|     def __init__(self, dict kwargs, dict injections, dict closings, /): | ||||
|         self.kwargs = kwargs | ||||
|         self.to_inject = kwargs.copy() | ||||
|         self.injections = injections | ||||
|         self.closings = closings | ||||
| 
 | ||||
|     result = fn(*args, **to_inject) | ||||
|     async def _await_injection(self, p: KWPair, /) -> None: | ||||
|         self.to_inject[p.name] = await p.value | ||||
| 
 | ||||
|     if closings: | ||||
|         for arg_key, provider in closings.items(): | ||||
|             if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker): | ||||
|                 continue | ||||
|             if not isinstance(provider, Resource): | ||||
|                 continue | ||||
|             provider.shutdown() | ||||
|     cdef object _await_injections(self, to_await: list): | ||||
|         return gather(*map(self._await_injection, to_await)) | ||||
| 
 | ||||
|     return result | ||||
|     cdef void _handle_injections_sync(self): | ||||
|         cdef Provider provider | ||||
| 
 | ||||
|         for name, provider in self.injections.items(): | ||||
|             if _is_injectable(self.kwargs, name): | ||||
|                 self.to_inject[name] = provider() | ||||
| 
 | ||||
| async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): | ||||
|     cdef object result | ||||
|     cdef dict to_inject | ||||
|     cdef list to_inject_await = [] | ||||
|     cdef list to_close_await = [] | ||||
|     cdef object arg_key | ||||
|     cdef Provider provider | ||||
|     cdef list _handle_injections_async(self): | ||||
|         cdef list to_await = [] | ||||
|         cdef Provider provider | ||||
| 
 | ||||
|     to_inject = kwargs.copy() | ||||
|     for arg_key, provider in injections.items(): | ||||
|         if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): | ||||
|             provide = provider() | ||||
|             if provider.is_async_mode_enabled(): | ||||
|                 to_inject_await.append((arg_key, provide)) | ||||
|             elif _isawaitable(provide): | ||||
|                 to_inject_await.append((arg_key, provide)) | ||||
|             else: | ||||
|                 to_inject[arg_key] = provide | ||||
|         for name, provider in self.injections.items(): | ||||
|             if _is_injectable(self.kwargs, name): | ||||
|                 provide = provider() | ||||
| 
 | ||||
|     if to_inject_await: | ||||
|         async_to_inject = await asyncio.gather(*(provide for _, provide in to_inject_await)) | ||||
|         for provide, (injection, _) in zip(async_to_inject, to_inject_await): | ||||
|             to_inject[injection] = provide | ||||
|                 if provider.is_async_mode_enabled() or _isawaitable(provide): | ||||
|                     to_await.append(KWPair(name, provide)) | ||||
|                 else: | ||||
|                     self.to_inject[name] = provide | ||||
| 
 | ||||
|     result = await fn(*args, **to_inject) | ||||
|         return to_await | ||||
| 
 | ||||
|     if closings: | ||||
|         for arg_key, provider in closings.items(): | ||||
|             if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker): | ||||
|                 continue | ||||
|             if not isinstance(provider, Resource): | ||||
|                 continue | ||||
|             shutdown = provider.shutdown() | ||||
|             if _isawaitable(shutdown): | ||||
|                 to_close_await.append(shutdown) | ||||
|     cdef void _handle_closings_sync(self): | ||||
|         cdef Provider provider | ||||
| 
 | ||||
|         await asyncio.gather(*to_close_await) | ||||
|         for name, provider in self.closings.items(): | ||||
|             if _is_injectable(self.kwargs, name) and isinstance(provider, Resource): | ||||
|                 provider.shutdown() | ||||
| 
 | ||||
|     return result | ||||
|     cdef list _handle_closings_async(self): | ||||
|         cdef list to_await = [] | ||||
|         cdef Provider provider | ||||
| 
 | ||||
|         for name, provider in self.closings.items(): | ||||
|             if _is_injectable(self.kwargs, name) and isinstance(provider, Resource): | ||||
|                 if _isawaitable(shutdown := provider.shutdown()): | ||||
|                     to_await.append(shutdown) | ||||
| 
 | ||||
|         return to_await | ||||
| 
 | ||||
|     def __enter__(self): | ||||
|         self._handle_injections_sync() | ||||
|         return self.to_inject | ||||
| 
 | ||||
|     def __exit__(self, *_): | ||||
|         self._handle_closings_sync() | ||||
| 
 | ||||
|     async def __aenter__(self): | ||||
|         if to_await := self._handle_injections_async(): | ||||
|             await self._await_injections(to_await) | ||||
|         return self.to_inject | ||||
| 
 | ||||
|     def __aexit__(self, *_): | ||||
|         if to_await := self._handle_closings_async(): | ||||
|             return gather(*to_await) | ||||
|         return NULL_AWAITABLE | ||||
| 
 | ||||
| 
 | ||||
| cdef bint _isawaitable(object instance): | ||||
|     """Return true if object can be passed to an ``await`` expression.""" | ||||
|     return (isinstance(instance, types.CoroutineType) or | ||||
|             isinstance(instance, types.GeneratorType) and | ||||
|             bool(instance.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) or | ||||
|             isinstance(instance, collections.abc.Awaitable)) | ||||
|     return (isinstance(instance, CoroutineType) or | ||||
|             isinstance(instance, GeneratorType) and | ||||
|             bool(instance.gi_code.co_flags & CO_ITERABLE_COROUTINE) or | ||||
|             isinstance(instance, Awaitable)) | ||||
|  |  | |||
|  | @ -10,6 +10,7 @@ from types import ModuleType | |||
| from typing import ( | ||||
|     TYPE_CHECKING, | ||||
|     Any, | ||||
|     AsyncIterator, | ||||
|     Callable, | ||||
|     Dict, | ||||
|     Iterable, | ||||
|  | @ -720,6 +721,8 @@ def _get_patched( | |||
| 
 | ||||
|     if inspect.iscoroutinefunction(fn): | ||||
|         patched = _get_async_patched(fn, patched_object) | ||||
|     elif inspect.isasyncgenfunction(fn): | ||||
|         patched = _get_async_gen_patched(fn, patched_object) | ||||
|     else: | ||||
|         patched = _get_sync_patched(fn, patched_object) | ||||
| 
 | ||||
|  | @ -1035,36 +1038,42 @@ _inspect_filter = InspectFilter() | |||
| _loader = AutoLoader() | ||||
| 
 | ||||
| # Optimizations | ||||
| from ._cwiring import _async_inject  # noqa | ||||
| from ._cwiring import _sync_inject  # noqa | ||||
| from ._cwiring import DependencyResolver  # noqa: E402 | ||||
| 
 | ||||
| 
 | ||||
| # Wiring uses the following Python wrapper because there is | ||||
| # no possibility to compile a first-type citizen coroutine in Cython. | ||||
| def _get_async_patched(fn: F, patched: PatchedCallable) -> F: | ||||
|     @functools.wraps(fn) | ||||
|     async def _patched(*args, **kwargs): | ||||
|         return await _async_inject( | ||||
|             fn, | ||||
|             args, | ||||
|             kwargs, | ||||
|             patched.injections, | ||||
|             patched.closing, | ||||
|         ) | ||||
|     async def _patched(*args: Any, **raw_kwargs: Any) -> Any: | ||||
|         dr = DependencyResolver(raw_kwargs, patched.injections, patched.closing) | ||||
| 
 | ||||
|         async with dr as kwargs: | ||||
|             return await fn(*args, **kwargs) | ||||
| 
 | ||||
|     return cast(F, _patched) | ||||
| 
 | ||||
| 
 | ||||
| # Async generators too... | ||||
| def _get_async_gen_patched(fn: F, patched: PatchedCallable) -> F: | ||||
|     @functools.wraps(fn) | ||||
|     async def _patched(*args: Any, **raw_kwargs: Any) -> AsyncIterator[Any]: | ||||
|         dr = DependencyResolver(raw_kwargs, patched.injections, patched.closing) | ||||
| 
 | ||||
|         async with dr as kwargs: | ||||
|             async for obj in fn(*args, **kwargs): | ||||
|                 yield obj | ||||
| 
 | ||||
|     return cast(F, _patched) | ||||
| 
 | ||||
| 
 | ||||
| def _get_sync_patched(fn: F, patched: PatchedCallable) -> F: | ||||
|     @functools.wraps(fn) | ||||
|     def _patched(*args, **kwargs): | ||||
|         return _sync_inject( | ||||
|             fn, | ||||
|             args, | ||||
|             kwargs, | ||||
|             patched.injections, | ||||
|             patched.closing, | ||||
|         ) | ||||
|     def _patched(*args: Any, **raw_kwargs: Any) -> Any: | ||||
|         dr = DependencyResolver(raw_kwargs, patched.injections, patched.closing) | ||||
| 
 | ||||
|         with dr as kwargs: | ||||
|             return fn(*args, **kwargs) | ||||
| 
 | ||||
|     return cast(F, _patched) | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,7 +1,9 @@ | |||
| import asyncio | ||||
| 
 | ||||
| from typing_extensions import Annotated | ||||
| 
 | ||||
| from dependency_injector import containers, providers | ||||
| from dependency_injector.wiring import inject, Provide, Closing | ||||
| from dependency_injector.wiring import Closing, Provide, inject | ||||
| 
 | ||||
| 
 | ||||
| class TestResource: | ||||
|  | @ -42,6 +44,15 @@ async def async_injection( | |||
|     return resource1, resource2 | ||||
| 
 | ||||
| 
 | ||||
| @inject | ||||
| async def async_generator_injection( | ||||
|         resource1: object = Provide[Container.resource1], | ||||
|         resource2: object = Closing[Provide[Container.resource2]], | ||||
| ): | ||||
|     yield resource1 | ||||
|     yield resource2 | ||||
| 
 | ||||
| 
 | ||||
| @inject | ||||
| async def async_injection_with_closing( | ||||
|         resource1: object = Closing[Provide[Container.resource1]], | ||||
|  |  | |||
|  | @ -32,6 +32,23 @@ async def test_async_injections(): | |||
|     assert asyncinjections.resource2.shutdown_counter == 0 | ||||
| 
 | ||||
| 
 | ||||
| @mark.asyncio | ||||
| async def test_async_generator_injections() -> None: | ||||
|     resources = [] | ||||
| 
 | ||||
|     async for resource in asyncinjections.async_generator_injection(): | ||||
|         resources.append(resource) | ||||
| 
 | ||||
|     assert len(resources) == 2 | ||||
|     assert resources[0] is asyncinjections.resource1 | ||||
|     assert asyncinjections.resource1.init_counter == 1 | ||||
|     assert asyncinjections.resource1.shutdown_counter == 0 | ||||
| 
 | ||||
|     assert resources[1] is asyncinjections.resource2 | ||||
|     assert asyncinjections.resource2.init_counter == 1 | ||||
|     assert asyncinjections.resource2.shutdown_counter == 1 | ||||
| 
 | ||||
| 
 | ||||
| @mark.asyncio | ||||
| async def test_async_injections_with_closing(): | ||||
|     resource1, resource2 = await asyncinjections.async_injection_with_closing() | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user