mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-07-02 03:13:15 +03:00
Implement async mode for Dependency provider
This commit is contained in:
parent
0c42ff9242
commit
32c4c6e29a
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -33,6 +33,7 @@ cdef class Delegate(Provider):
|
||||||
|
|
||||||
cdef class Dependency(Provider):
|
cdef class Dependency(Provider):
|
||||||
cdef object __instance_of
|
cdef object __instance_of
|
||||||
|
cdef bint __async
|
||||||
|
|
||||||
|
|
||||||
cdef class ExternalDependency(Dependency):
|
cdef class ExternalDependency(Dependency):
|
||||||
|
|
|
@ -452,6 +452,7 @@ cdef class Dependency(Provider):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.__instance_of = instance_of
|
self.__instance_of = instance_of
|
||||||
|
self.__async = False
|
||||||
super(Dependency, self).__init__()
|
super(Dependency, self).__init__()
|
||||||
|
|
||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self, memo):
|
||||||
|
@ -480,9 +481,19 @@ cdef class Dependency(Provider):
|
||||||
|
|
||||||
instance = self.__last_overriding(*args, **kwargs)
|
instance = self.__last_overriding(*args, **kwargs)
|
||||||
|
|
||||||
if not isinstance(instance, self.instance_of):
|
if __isawaitable(instance):
|
||||||
raise Error('{0} is not an '.format(instance) +
|
future_result = asyncio.Future()
|
||||||
'instance of {0}'.format(self.instance_of))
|
instance = asyncio.ensure_future(instance)
|
||||||
|
instance.add_done_callback(functools.partial(self._async_provide, future_result))
|
||||||
|
self.__async = True
|
||||||
|
return future_result
|
||||||
|
|
||||||
|
self._check_instance_type(instance)
|
||||||
|
|
||||||
|
if self.__async:
|
||||||
|
result = asyncio.Future()
|
||||||
|
result.set_result(instance)
|
||||||
|
return result
|
||||||
|
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
@ -515,6 +526,19 @@ cdef class Dependency(Provider):
|
||||||
"""
|
"""
|
||||||
return self.override(provider)
|
return self.override(provider)
|
||||||
|
|
||||||
|
def _async_provide(self, future_result, future):
|
||||||
|
instance = future.result()
|
||||||
|
try:
|
||||||
|
self._check_instance_type(instance)
|
||||||
|
except Error as exception:
|
||||||
|
future_result.set_exception(exception)
|
||||||
|
else:
|
||||||
|
future_result.set_result(instance)
|
||||||
|
|
||||||
|
def _check_instance_type(self, instance):
|
||||||
|
if not isinstance(instance, self.instance_of):
|
||||||
|
raise Error('{0} is not an instance of {1}'.format(instance, self.instance_of))
|
||||||
|
|
||||||
|
|
||||||
cdef class ExternalDependency(Dependency):
|
cdef class ExternalDependency(Dependency):
|
||||||
""":py:class:`ExternalDependency` provider describes dependency interface.
|
""":py:class:`ExternalDependency` provider describes dependency interface.
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from dependency_injector import containers, providers
|
from dependency_injector import containers, providers, errors
|
||||||
|
|
||||||
# Runtime import to get asyncutils module
|
# Runtime import to get asyncutils module
|
||||||
import os
|
import os
|
||||||
|
@ -459,3 +459,55 @@ class ProvidedInstanceTests(AsyncTestCase):
|
||||||
self.assertIs(instance1.resource, RESOURCE1)
|
self.assertIs(instance1.resource, RESOURCE1)
|
||||||
self.assertIs(instance2.resource, RESOURCE1)
|
self.assertIs(instance2.resource, RESOURCE1)
|
||||||
self.assertIs(instance1.resource, instance2.resource)
|
self.assertIs(instance1.resource, instance2.resource)
|
||||||
|
|
||||||
|
|
||||||
|
class DependencyTests(AsyncTestCase):
|
||||||
|
|
||||||
|
def test_isinstance(self):
|
||||||
|
dependency = 1.0
|
||||||
|
|
||||||
|
async def get_async():
|
||||||
|
return dependency
|
||||||
|
|
||||||
|
provider = providers.Dependency(instance_of=float)
|
||||||
|
provider.override(providers.Callable(get_async))
|
||||||
|
|
||||||
|
dependency1 = self._run(provider())
|
||||||
|
dependency2 = self._run(provider())
|
||||||
|
|
||||||
|
self.assertEqual(dependency1, dependency)
|
||||||
|
self.assertEqual(dependency2, dependency)
|
||||||
|
|
||||||
|
def test_isinstance_invalid(self):
|
||||||
|
async def get_async():
|
||||||
|
return {}
|
||||||
|
|
||||||
|
provider = providers.Dependency(instance_of=float)
|
||||||
|
provider.override(providers.Callable(get_async))
|
||||||
|
|
||||||
|
with self.assertRaises(errors.Error):
|
||||||
|
self._run(provider())
|
||||||
|
|
||||||
|
def test_async_mode(self):
|
||||||
|
dependency = 123
|
||||||
|
|
||||||
|
async def get_async():
|
||||||
|
return dependency
|
||||||
|
|
||||||
|
def get_sync():
|
||||||
|
return dependency
|
||||||
|
|
||||||
|
provider = providers.Dependency(instance_of=int)
|
||||||
|
provider.override(providers.Factory(get_async))
|
||||||
|
|
||||||
|
dependency1 = self._run(provider())
|
||||||
|
dependency2 = self._run(provider())
|
||||||
|
self.assertEqual(dependency1, dependency)
|
||||||
|
self.assertEqual(dependency2, dependency)
|
||||||
|
|
||||||
|
provider.override(providers.Factory(get_sync))
|
||||||
|
|
||||||
|
dependency3 = self._run(provider())
|
||||||
|
dependency4 = self._run(provider())
|
||||||
|
self.assertEqual(dependency3, dependency)
|
||||||
|
self.assertEqual(dependency4, dependency)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user