mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2024-11-29 13:04:02 +03:00
Add async mode for the provider
This commit is contained in:
parent
32c4c6e29a
commit
b446dab559
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -14,6 +14,7 @@ cimport cython
|
|||
cdef class Provider(object):
|
||||
cdef tuple __overridden
|
||||
cdef Provider __last_overriding
|
||||
cdef int __async_mode
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs)
|
||||
cpdef void _copy_overridings(self, Provider copied, dict memo)
|
||||
|
|
|
@ -169,7 +169,7 @@ class Configuration(Object):
|
|||
class Factory(Provider, Generic[T]):
|
||||
provided_type: Optional[Type]
|
||||
def __init__(self, provides: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ...
|
||||
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
|
||||
def __call__(self, *args: Injection, **kwargs: Injection) -> Union[T, Awaitable[T]]: ...
|
||||
@property
|
||||
def cls(self) -> T: ...
|
||||
@property
|
||||
|
|
|
@ -89,6 +89,11 @@ else:
|
|||
return parser
|
||||
|
||||
|
||||
cdef int ASYNC_MODE_UNDEFINED = 0
|
||||
cdef int ASYNC_MODE_ENABLED = 1
|
||||
cdef int ASYNC_MODE_DISABLED = 2
|
||||
|
||||
|
||||
cdef class Provider(object):
|
||||
"""Base provider class.
|
||||
|
||||
|
@ -149,6 +154,7 @@ cdef class Provider(object):
|
|||
"""Initializer."""
|
||||
self.__overridden = tuple()
|
||||
self.__last_overriding = None
|
||||
self.__async_mode = ASYNC_MODE_UNDEFINED
|
||||
super(Provider, self).__init__()
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
|
@ -157,8 +163,24 @@ cdef class Provider(object):
|
|||
Callable interface implementation.
|
||||
"""
|
||||
if self.__last_overriding is not None:
|
||||
return self.__last_overriding(*args, **kwargs)
|
||||
return self._provide(args, kwargs)
|
||||
result = self.__last_overriding(*args, **kwargs)
|
||||
else:
|
||||
result = self._provide(args, kwargs)
|
||||
|
||||
if self.__async_mode == ASYNC_MODE_DISABLED:
|
||||
return result
|
||||
elif self.__async_mode == ASYNC_MODE_ENABLED:
|
||||
if not __isawaitable(result):
|
||||
future_result = asyncio.Future()
|
||||
future_result.set_result(result)
|
||||
return future_result
|
||||
return result
|
||||
elif self.__async_mode == ASYNC_MODE_UNDEFINED:
|
||||
if __isawaitable(result):
|
||||
self.__async_mode = ASYNC_MODE_ENABLED
|
||||
else:
|
||||
self.__async_mode = ASYNC_MODE_DISABLED
|
||||
return result
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
"""Create and return full copy of provider."""
|
||||
|
|
|
@ -511,3 +511,62 @@ class DependencyTests(AsyncTestCase):
|
|||
dependency4 = self._run(provider())
|
||||
self.assertEqual(dependency3, dependency)
|
||||
self.assertEqual(dependency4, dependency)
|
||||
|
||||
|
||||
class OverrideTests(AsyncTestCase):
|
||||
|
||||
def test_provider(self):
|
||||
dependency = object()
|
||||
|
||||
async def _get_dependency_async():
|
||||
return dependency
|
||||
|
||||
def _get_dependency_sync():
|
||||
return dependency
|
||||
|
||||
provider = providers.Provider()
|
||||
|
||||
provider.override(providers.Callable(_get_dependency_async))
|
||||
dependency1 = self._run(provider())
|
||||
|
||||
provider.override(providers.Callable(_get_dependency_sync))
|
||||
dependency2 = self._run(provider())
|
||||
|
||||
self.assertIs(dependency1, dependency)
|
||||
self.assertIs(dependency2, dependency)
|
||||
|
||||
def test_callable(self):
|
||||
dependency = object()
|
||||
|
||||
async def _get_dependency_async():
|
||||
return dependency
|
||||
|
||||
def _get_dependency_sync():
|
||||
return dependency
|
||||
|
||||
provider = providers.Callable(_get_dependency_async)
|
||||
dependency1 = self._run(provider())
|
||||
|
||||
provider.override(providers.Callable(_get_dependency_sync))
|
||||
dependency2 = self._run(provider())
|
||||
|
||||
self.assertIs(dependency1, dependency)
|
||||
self.assertIs(dependency2, dependency)
|
||||
|
||||
def test_factory(self):
|
||||
dependency = object()
|
||||
|
||||
async def _get_dependency_async():
|
||||
return dependency
|
||||
|
||||
def _get_dependency_sync():
|
||||
return dependency
|
||||
|
||||
provider = providers.Factory(_get_dependency_async)
|
||||
dependency1 = self._run(provider())
|
||||
|
||||
provider.override(providers.Callable(_get_dependency_sync))
|
||||
dependency2 = self._run(provider())
|
||||
|
||||
self.assertIs(dependency1, dependency)
|
||||
self.assertIs(dependency2, dependency)
|
||||
|
|
Loading…
Reference in New Issue
Block a user