Add async mode for the provider

This commit is contained in:
Roman Mogylatov 2020-12-23 21:09:07 -05:00
parent 32c4c6e29a
commit b446dab559
6 changed files with 12689 additions and 11731 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,7 @@ cimport cython
cdef class Provider(object): cdef class Provider(object):
cdef tuple __overridden cdef tuple __overridden
cdef Provider __last_overriding cdef Provider __last_overriding
cdef int __async_mode
cpdef object _provide(self, tuple args, dict kwargs) cpdef object _provide(self, tuple args, dict kwargs)
cpdef void _copy_overridings(self, Provider copied, dict memo) cpdef void _copy_overridings(self, Provider copied, dict memo)

View File

@ -169,7 +169,7 @@ class Configuration(Object):
class Factory(Provider, Generic[T]): class Factory(Provider, Generic[T]):
provided_type: Optional[Type] provided_type: Optional[Type]
def __init__(self, provides: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ... def __init__(self, provides: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ... def __call__(self, *args: Injection, **kwargs: Injection) -> Union[T, Awaitable[T]]: ...
@property @property
def cls(self) -> T: ... def cls(self) -> T: ...
@property @property

View File

@ -89,6 +89,11 @@ else:
return parser return parser
cdef int ASYNC_MODE_UNDEFINED = 0
cdef int ASYNC_MODE_ENABLED = 1
cdef int ASYNC_MODE_DISABLED = 2
cdef class Provider(object): cdef class Provider(object):
"""Base provider class. """Base provider class.
@ -149,6 +154,7 @@ cdef class Provider(object):
"""Initializer.""" """Initializer."""
self.__overridden = tuple() self.__overridden = tuple()
self.__last_overriding = None self.__last_overriding = None
self.__async_mode = ASYNC_MODE_UNDEFINED
super(Provider, self).__init__() super(Provider, self).__init__()
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
@ -157,8 +163,24 @@ cdef class Provider(object):
Callable interface implementation. Callable interface implementation.
""" """
if self.__last_overriding is not None: if self.__last_overriding is not None:
return self.__last_overriding(*args, **kwargs) result = self.__last_overriding(*args, **kwargs)
return self._provide(args, kwargs) else:
result = self._provide(args, kwargs)
if self.__async_mode == ASYNC_MODE_DISABLED:
return result
elif self.__async_mode == ASYNC_MODE_ENABLED:
if not __isawaitable(result):
future_result = asyncio.Future()
future_result.set_result(result)
return future_result
return result
elif self.__async_mode == ASYNC_MODE_UNDEFINED:
if __isawaitable(result):
self.__async_mode = ASYNC_MODE_ENABLED
else:
self.__async_mode = ASYNC_MODE_DISABLED
return result
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
"""Create and return full copy of provider.""" """Create and return full copy of provider."""

View File

@ -511,3 +511,62 @@ class DependencyTests(AsyncTestCase):
dependency4 = self._run(provider()) dependency4 = self._run(provider())
self.assertEqual(dependency3, dependency) self.assertEqual(dependency3, dependency)
self.assertEqual(dependency4, dependency) self.assertEqual(dependency4, dependency)
class OverrideTests(AsyncTestCase):
def test_provider(self):
dependency = object()
async def _get_dependency_async():
return dependency
def _get_dependency_sync():
return dependency
provider = providers.Provider()
provider.override(providers.Callable(_get_dependency_async))
dependency1 = self._run(provider())
provider.override(providers.Callable(_get_dependency_sync))
dependency2 = self._run(provider())
self.assertIs(dependency1, dependency)
self.assertIs(dependency2, dependency)
def test_callable(self):
dependency = object()
async def _get_dependency_async():
return dependency
def _get_dependency_sync():
return dependency
provider = providers.Callable(_get_dependency_async)
dependency1 = self._run(provider())
provider.override(providers.Callable(_get_dependency_sync))
dependency2 = self._run(provider())
self.assertIs(dependency1, dependency)
self.assertIs(dependency2, dependency)
def test_factory(self):
dependency = object()
async def _get_dependency_async():
return dependency
def _get_dependency_sync():
return dependency
provider = providers.Factory(_get_dependency_async)
dependency1 = self._run(provider())
provider.override(providers.Callable(_get_dependency_sync))
dependency2 = self._run(provider())
self.assertIs(dependency1, dependency)
self.assertIs(dependency2, dependency)