mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2024-12-01 22:14:04 +03:00
Add async mode for the provider
This commit is contained in:
parent
32c4c6e29a
commit
b446dab559
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -14,6 +14,7 @@ cimport cython
|
||||||
cdef class Provider(object):
|
cdef class Provider(object):
|
||||||
cdef tuple __overridden
|
cdef tuple __overridden
|
||||||
cdef Provider __last_overriding
|
cdef Provider __last_overriding
|
||||||
|
cdef int __async_mode
|
||||||
|
|
||||||
cpdef object _provide(self, tuple args, dict kwargs)
|
cpdef object _provide(self, tuple args, dict kwargs)
|
||||||
cpdef void _copy_overridings(self, Provider copied, dict memo)
|
cpdef void _copy_overridings(self, Provider copied, dict memo)
|
||||||
|
|
|
@ -169,7 +169,7 @@ class Configuration(Object):
|
||||||
class Factory(Provider, Generic[T]):
|
class Factory(Provider, Generic[T]):
|
||||||
provided_type: Optional[Type]
|
provided_type: Optional[Type]
|
||||||
def __init__(self, provides: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ...
|
def __init__(self, provides: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ...
|
||||||
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
|
def __call__(self, *args: Injection, **kwargs: Injection) -> Union[T, Awaitable[T]]: ...
|
||||||
@property
|
@property
|
||||||
def cls(self) -> T: ...
|
def cls(self) -> T: ...
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -89,6 +89,11 @@ else:
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
cdef int ASYNC_MODE_UNDEFINED = 0
|
||||||
|
cdef int ASYNC_MODE_ENABLED = 1
|
||||||
|
cdef int ASYNC_MODE_DISABLED = 2
|
||||||
|
|
||||||
|
|
||||||
cdef class Provider(object):
|
cdef class Provider(object):
|
||||||
"""Base provider class.
|
"""Base provider class.
|
||||||
|
|
||||||
|
@ -149,6 +154,7 @@ cdef class Provider(object):
|
||||||
"""Initializer."""
|
"""Initializer."""
|
||||||
self.__overridden = tuple()
|
self.__overridden = tuple()
|
||||||
self.__last_overriding = None
|
self.__last_overriding = None
|
||||||
|
self.__async_mode = ASYNC_MODE_UNDEFINED
|
||||||
super(Provider, self).__init__()
|
super(Provider, self).__init__()
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
|
@ -157,8 +163,24 @@ cdef class Provider(object):
|
||||||
Callable interface implementation.
|
Callable interface implementation.
|
||||||
"""
|
"""
|
||||||
if self.__last_overriding is not None:
|
if self.__last_overriding is not None:
|
||||||
return self.__last_overriding(*args, **kwargs)
|
result = self.__last_overriding(*args, **kwargs)
|
||||||
return self._provide(args, kwargs)
|
else:
|
||||||
|
result = self._provide(args, kwargs)
|
||||||
|
|
||||||
|
if self.__async_mode == ASYNC_MODE_DISABLED:
|
||||||
|
return result
|
||||||
|
elif self.__async_mode == ASYNC_MODE_ENABLED:
|
||||||
|
if not __isawaitable(result):
|
||||||
|
future_result = asyncio.Future()
|
||||||
|
future_result.set_result(result)
|
||||||
|
return future_result
|
||||||
|
return result
|
||||||
|
elif self.__async_mode == ASYNC_MODE_UNDEFINED:
|
||||||
|
if __isawaitable(result):
|
||||||
|
self.__async_mode = ASYNC_MODE_ENABLED
|
||||||
|
else:
|
||||||
|
self.__async_mode = ASYNC_MODE_DISABLED
|
||||||
|
return result
|
||||||
|
|
||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self, memo):
|
||||||
"""Create and return full copy of provider."""
|
"""Create and return full copy of provider."""
|
||||||
|
|
|
@ -511,3 +511,62 @@ class DependencyTests(AsyncTestCase):
|
||||||
dependency4 = self._run(provider())
|
dependency4 = self._run(provider())
|
||||||
self.assertEqual(dependency3, dependency)
|
self.assertEqual(dependency3, dependency)
|
||||||
self.assertEqual(dependency4, dependency)
|
self.assertEqual(dependency4, dependency)
|
||||||
|
|
||||||
|
|
||||||
|
class OverrideTests(AsyncTestCase):
|
||||||
|
|
||||||
|
def test_provider(self):
|
||||||
|
dependency = object()
|
||||||
|
|
||||||
|
async def _get_dependency_async():
|
||||||
|
return dependency
|
||||||
|
|
||||||
|
def _get_dependency_sync():
|
||||||
|
return dependency
|
||||||
|
|
||||||
|
provider = providers.Provider()
|
||||||
|
|
||||||
|
provider.override(providers.Callable(_get_dependency_async))
|
||||||
|
dependency1 = self._run(provider())
|
||||||
|
|
||||||
|
provider.override(providers.Callable(_get_dependency_sync))
|
||||||
|
dependency2 = self._run(provider())
|
||||||
|
|
||||||
|
self.assertIs(dependency1, dependency)
|
||||||
|
self.assertIs(dependency2, dependency)
|
||||||
|
|
||||||
|
def test_callable(self):
|
||||||
|
dependency = object()
|
||||||
|
|
||||||
|
async def _get_dependency_async():
|
||||||
|
return dependency
|
||||||
|
|
||||||
|
def _get_dependency_sync():
|
||||||
|
return dependency
|
||||||
|
|
||||||
|
provider = providers.Callable(_get_dependency_async)
|
||||||
|
dependency1 = self._run(provider())
|
||||||
|
|
||||||
|
provider.override(providers.Callable(_get_dependency_sync))
|
||||||
|
dependency2 = self._run(provider())
|
||||||
|
|
||||||
|
self.assertIs(dependency1, dependency)
|
||||||
|
self.assertIs(dependency2, dependency)
|
||||||
|
|
||||||
|
def test_factory(self):
|
||||||
|
dependency = object()
|
||||||
|
|
||||||
|
async def _get_dependency_async():
|
||||||
|
return dependency
|
||||||
|
|
||||||
|
def _get_dependency_sync():
|
||||||
|
return dependency
|
||||||
|
|
||||||
|
provider = providers.Factory(_get_dependency_async)
|
||||||
|
dependency1 = self._run(provider())
|
||||||
|
|
||||||
|
provider.override(providers.Callable(_get_dependency_sync))
|
||||||
|
dependency2 = self._run(provider())
|
||||||
|
|
||||||
|
self.assertIs(dependency1, dependency)
|
||||||
|
self.assertIs(dependency2, dependency)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user