Implement async mode for Dependency provider

This commit is contained in:
Roman Mogylatov 2020-12-21 22:39:01 -05:00
parent 0c42ff9242
commit 32c4c6e29a
5 changed files with 6735 additions and 5943 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -33,6 +33,7 @@ cdef class Delegate(Provider):
cdef class Dependency(Provider):
cdef object __instance_of
cdef bint __async
cdef class ExternalDependency(Dependency):

View File

@ -452,6 +452,7 @@ cdef class Dependency(Provider):
)
self.__instance_of = instance_of
self.__async = False
super(Dependency, self).__init__()
def __deepcopy__(self, memo):
@ -480,9 +481,19 @@ cdef class Dependency(Provider):
instance = self.__last_overriding(*args, **kwargs)
if not isinstance(instance, self.instance_of):
raise Error('{0} is not an '.format(instance) +
'instance of {0}'.format(self.instance_of))
if __isawaitable(instance):
future_result = asyncio.Future()
instance = asyncio.ensure_future(instance)
instance.add_done_callback(functools.partial(self._async_provide, future_result))
self.__async = True
return future_result
self._check_instance_type(instance)
if self.__async:
result = asyncio.Future()
result.set_result(instance)
return result
return instance
@ -515,6 +526,19 @@ cdef class Dependency(Provider):
"""
return self.override(provider)
def _async_provide(self, future_result, future):
instance = future.result()
try:
self._check_instance_type(instance)
except Error as exception:
future_result.set_exception(exception)
else:
future_result.set_result(instance)
def _check_instance_type(self, instance):
if not isinstance(instance, self.instance_of):
raise Error('{0} is not an instance of {1}'.format(instance, self.instance_of))
cdef class ExternalDependency(Dependency):
""":py:class:`ExternalDependency` provider describes dependency interface.

View File

@ -1,7 +1,7 @@
import asyncio
import random
from dependency_injector import containers, providers
from dependency_injector import containers, providers, errors
# Runtime import to get asyncutils module
import os
@ -459,3 +459,55 @@ class ProvidedInstanceTests(AsyncTestCase):
self.assertIs(instance1.resource, RESOURCE1)
self.assertIs(instance2.resource, RESOURCE1)
self.assertIs(instance1.resource, instance2.resource)
class DependencyTests(AsyncTestCase):
def test_isinstance(self):
dependency = 1.0
async def get_async():
return dependency
provider = providers.Dependency(instance_of=float)
provider.override(providers.Callable(get_async))
dependency1 = self._run(provider())
dependency2 = self._run(provider())
self.assertEqual(dependency1, dependency)
self.assertEqual(dependency2, dependency)
def test_isinstance_invalid(self):
async def get_async():
return {}
provider = providers.Dependency(instance_of=float)
provider.override(providers.Callable(get_async))
with self.assertRaises(errors.Error):
self._run(provider())
def test_async_mode(self):
dependency = 123
async def get_async():
return dependency
def get_sync():
return dependency
provider = providers.Dependency(instance_of=int)
provider.override(providers.Factory(get_async))
dependency1 = self._run(provider())
dependency2 = self._run(provider())
self.assertEqual(dependency1, dependency)
self.assertEqual(dependency2, dependency)
provider.override(providers.Factory(get_sync))
dependency3 = self._run(provider())
dependency4 = self._run(provider())
self.assertEqual(dependency3, dependency)
self.assertEqual(dependency4, dependency)