mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-07-01 19:03:19 +03:00
Implement async mode for Dependency provider
This commit is contained in:
parent
0c42ff9242
commit
32c4c6e29a
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -33,6 +33,7 @@ cdef class Delegate(Provider):
|
|||
|
||||
cdef class Dependency(Provider):
|
||||
cdef object __instance_of
|
||||
cdef bint __async
|
||||
|
||||
|
||||
cdef class ExternalDependency(Dependency):
|
||||
|
|
|
@ -452,6 +452,7 @@ cdef class Dependency(Provider):
|
|||
)
|
||||
|
||||
self.__instance_of = instance_of
|
||||
self.__async = False
|
||||
super(Dependency, self).__init__()
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
|
@ -480,9 +481,19 @@ cdef class Dependency(Provider):
|
|||
|
||||
instance = self.__last_overriding(*args, **kwargs)
|
||||
|
||||
if not isinstance(instance, self.instance_of):
|
||||
raise Error('{0} is not an '.format(instance) +
|
||||
'instance of {0}'.format(self.instance_of))
|
||||
if __isawaitable(instance):
|
||||
future_result = asyncio.Future()
|
||||
instance = asyncio.ensure_future(instance)
|
||||
instance.add_done_callback(functools.partial(self._async_provide, future_result))
|
||||
self.__async = True
|
||||
return future_result
|
||||
|
||||
self._check_instance_type(instance)
|
||||
|
||||
if self.__async:
|
||||
result = asyncio.Future()
|
||||
result.set_result(instance)
|
||||
return result
|
||||
|
||||
return instance
|
||||
|
||||
|
@ -515,6 +526,19 @@ cdef class Dependency(Provider):
|
|||
"""
|
||||
return self.override(provider)
|
||||
|
||||
def _async_provide(self, future_result, future):
|
||||
instance = future.result()
|
||||
try:
|
||||
self._check_instance_type(instance)
|
||||
except Error as exception:
|
||||
future_result.set_exception(exception)
|
||||
else:
|
||||
future_result.set_result(instance)
|
||||
|
||||
def _check_instance_type(self, instance):
|
||||
if not isinstance(instance, self.instance_of):
|
||||
raise Error('{0} is not an instance of {1}'.format(instance, self.instance_of))
|
||||
|
||||
|
||||
cdef class ExternalDependency(Dependency):
|
||||
""":py:class:`ExternalDependency` provider describes dependency interface.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import random
|
||||
|
||||
from dependency_injector import containers, providers
|
||||
from dependency_injector import containers, providers, errors
|
||||
|
||||
# Runtime import to get asyncutils module
|
||||
import os
|
||||
|
@ -459,3 +459,55 @@ class ProvidedInstanceTests(AsyncTestCase):
|
|||
self.assertIs(instance1.resource, RESOURCE1)
|
||||
self.assertIs(instance2.resource, RESOURCE1)
|
||||
self.assertIs(instance1.resource, instance2.resource)
|
||||
|
||||
|
||||
class DependencyTests(AsyncTestCase):
|
||||
|
||||
def test_isinstance(self):
|
||||
dependency = 1.0
|
||||
|
||||
async def get_async():
|
||||
return dependency
|
||||
|
||||
provider = providers.Dependency(instance_of=float)
|
||||
provider.override(providers.Callable(get_async))
|
||||
|
||||
dependency1 = self._run(provider())
|
||||
dependency2 = self._run(provider())
|
||||
|
||||
self.assertEqual(dependency1, dependency)
|
||||
self.assertEqual(dependency2, dependency)
|
||||
|
||||
def test_isinstance_invalid(self):
|
||||
async def get_async():
|
||||
return {}
|
||||
|
||||
provider = providers.Dependency(instance_of=float)
|
||||
provider.override(providers.Callable(get_async))
|
||||
|
||||
with self.assertRaises(errors.Error):
|
||||
self._run(provider())
|
||||
|
||||
def test_async_mode(self):
|
||||
dependency = 123
|
||||
|
||||
async def get_async():
|
||||
return dependency
|
||||
|
||||
def get_sync():
|
||||
return dependency
|
||||
|
||||
provider = providers.Dependency(instance_of=int)
|
||||
provider.override(providers.Factory(get_async))
|
||||
|
||||
dependency1 = self._run(provider())
|
||||
dependency2 = self._run(provider())
|
||||
self.assertEqual(dependency1, dependency)
|
||||
self.assertEqual(dependency2, dependency)
|
||||
|
||||
provider.override(providers.Factory(get_sync))
|
||||
|
||||
dependency3 = self._run(provider())
|
||||
dependency4 = self._run(provider())
|
||||
self.assertEqual(dependency3, dependency)
|
||||
self.assertEqual(dependency4, dependency)
|
||||
|
|
Loading…
Reference in New Issue
Block a user