Fix mistakenly processed awaitable objects

This commit is contained in:
Roman Mogylatov 2021-02-17 08:58:32 -05:00
parent 6e59b4ab6f
commit 7c124024b4
4 changed files with 7581 additions and 7374 deletions

File diff suppressed because it is too large Load Diff

View File

@ -18,6 +18,7 @@ cdef class Provider(object):
cdef int __async_mode
cpdef object _provide(self, tuple args, dict kwargs)
cpdef object _process_result(self, object result)
cpdef void _copy_overridings(self, Provider copied, dict memo)
@ -544,14 +545,9 @@ cdef inline object __call(
if args_awaitable or kwargs_awaitable:
if not args_awaitable:
future = asyncio.Future()
future.set_result(args)
args = future
args = __future_result(args)
if not kwargs_awaitable:
future = asyncio.Future()
future.set_result(kwargs)
kwargs = future
kwargs = __future_result(kwargs)
future_result = asyncio.Future()
@ -618,15 +614,9 @@ cdef inline object __factory_call(Factory self, tuple args, dict kwargs):
if instance_awaitable or attributes_awaitable:
if not instance_awaitable:
future = asyncio.Future()
future.set_result(instance)
instance = future
instance = __future_result(instance)
if not attributes_awaitable:
future = asyncio.Future()
future.set_result(attributes)
attributes = future
attributes = __future_result(attributes)
return __async_inject_attributes(instance, attributes)
__inject_attributes(instance, attributes)
@ -648,3 +638,15 @@ cdef inline bint __isawaitable(object instance):
return inspect.isawaitable(instance)
return False
cdef inline bint __is_future_or_coroutine(object instance):
if asyncio is None:
return False
return asyncio.isfuture(instance) or asyncio.iscoroutine(instance)
cdef inline object __future_result(object instance):
future_result = asyncio.Future()
future_result.set_result(instance)
return future_result

View File

@ -188,21 +188,7 @@ cdef class Provider(object):
result = self.__last_overriding(*args, **kwargs)
else:
result = self._provide(args, kwargs)
if self.is_async_mode_disabled():
return result
elif self.is_async_mode_enabled():
if not __isawaitable(result):
future_result = asyncio.Future()
future_result.set_result(result)
return future_result
return result
elif self.is_async_mode_undefined():
if __isawaitable(result):
self.enable_async_mode()
else:
self.disable_async_mode()
return result
return self._process_result(result)
def __deepcopy__(self, memo):
"""Create and return full copy of provider."""
@ -386,6 +372,20 @@ cdef class Provider(object):
"""
raise NotImplementedError()
cpdef object _process_result(self, object result):
if self.is_async_mode_disabled():
return result
elif self.is_async_mode_enabled():
if __is_future_or_coroutine(result):
return result
return __future_result(result)
elif self.is_async_mode_undefined():
if __is_future_or_coroutine(result):
self.enable_async_mode()
else:
self.disable_async_mode()
return result
cpdef void _copy_overridings(self, Provider copied, dict memo):
"""Copy provider overridings to a newly copied provider."""
copied.__overridden = deepcopy(self.__overridden, memo)
@ -661,18 +661,16 @@ cdef class Dependency(Provider):
self._check_instance_type(result)
return result
elif self.is_async_mode_enabled():
if __isawaitable(result):
if __is_future_or_coroutine(result):
future_result = asyncio.Future()
result = asyncio.ensure_future(result)
result.add_done_callback(functools.partial(self._async_provide, future_result))
return future_result
else:
self._check_instance_type(result)
future_result = asyncio.Future()
future_result.set_result(result)
return future_result
return __future_result(result)
elif self.is_async_mode_undefined():
if __isawaitable(result):
if __is_future_or_coroutine(result):
self.enable_async_mode()
future_result = asyncio.Future()
@ -2701,7 +2699,7 @@ cdef class Singleton(BaseSingleton):
:rtype: None
"""
if __isawaitable(self.__storage):
if __is_future_or_coroutine(self.__storage):
asyncio.ensure_future(self.__storage).cancel()
self.__storage = None
@ -2710,7 +2708,7 @@ cdef class Singleton(BaseSingleton):
if self.__storage is None:
instance = __factory_call(self.__instantiator, args, kwargs)
if __isawaitable(instance):
if __is_future_or_coroutine(instance):
future_result = asyncio.Future()
instance = asyncio.ensure_future(instance)
instance.add_done_callback(functools.partial(self._async_init_instance, future_result))
@ -2769,7 +2767,7 @@ cdef class ThreadSafeSingleton(BaseSingleton):
:rtype: None
"""
with self.__storage_lock:
if __isawaitable(self.__storage):
if __is_future_or_coroutine(self.__storage):
asyncio.ensure_future(self.__storage).cancel()
self.__storage = None
@ -2783,7 +2781,7 @@ cdef class ThreadSafeSingleton(BaseSingleton):
if self.__storage is None:
instance = __factory_call(self.__instantiator, args, kwargs)
if __isawaitable(instance):
if __is_future_or_coroutine(instance):
future_result = asyncio.Future()
instance = asyncio.ensure_future(instance)
instance.add_done_callback(functools.partial(self._async_init_instance, future_result))
@ -2850,7 +2848,7 @@ cdef class ThreadLocalSingleton(BaseSingleton):
:rtype: None
"""
if __isawaitable(self.__storage.instance):
if __is_future_or_coroutine(self.__storage.instance):
asyncio.ensure_future(self.__storage.instance).cancel()
del self.__storage.instance
@ -2863,7 +2861,7 @@ cdef class ThreadLocalSingleton(BaseSingleton):
except AttributeError:
instance = __factory_call(self.__instantiator, args, kwargs)
if __isawaitable(instance):
if __is_future_or_coroutine(instance):
future_result = asyncio.Future()
instance = asyncio.ensure_future(instance)
instance.add_done_callback(functools.partial(self._async_init_instance, future_result))
@ -3884,7 +3882,7 @@ cdef class AttributeGetter(Provider):
cpdef object _provide(self, tuple args, dict kwargs):
provided = self.__provider(*args, **kwargs)
if __isawaitable(provided):
if __is_future_or_coroutine(provided):
future_result = asyncio.Future()
provided = asyncio.ensure_future(provided)
provided.add_done_callback(functools.partial(self._async_provide, future_result))
@ -3892,9 +3890,13 @@ cdef class AttributeGetter(Provider):
return getattr(provided, self.__attribute)
def _async_provide(self, future_result, future):
provided = future.result()
result = getattr(provided, self.__attribute)
future_result.set_result(result)
try:
provided = future.result()
result = getattr(provided, self.__attribute)
except Exception:
pass
else:
future_result.set_result(result)
cdef class ItemGetter(Provider):
@ -3950,7 +3952,7 @@ cdef class ItemGetter(Provider):
cpdef object _provide(self, tuple args, dict kwargs):
provided = self.__provider(*args, **kwargs)
if __isawaitable(provided):
if __is_future_or_coroutine(provided):
future_result = asyncio.Future()
provided = asyncio.ensure_future(provided)
provided.add_done_callback(functools.partial(self._async_provide, future_result))
@ -4050,7 +4052,7 @@ cdef class MethodCaller(Provider):
cpdef object _provide(self, tuple args, dict kwargs):
call = self.__provider()
if __isawaitable(call):
if __is_future_or_coroutine(call):
future_result = asyncio.Future()
call = asyncio.ensure_future(call)
call.add_done_callback(functools.partial(self._async_provide, future_result, args, kwargs))

View File

@ -987,3 +987,65 @@ class AsyncProvidersWithAsyncDependenciesTests(AsyncTestCase):
service = self._run(container.service())
self.assertEquals(service, {'service': 'ok', 'db': {'db': 'ok'}})
class AsyncProviderWithAwaitableObjectTests(AsyncTestCase):
def test(self):
class SomeResource:
def __await__(self):
raise RuntimeError('Should never happen')
async def init_resource():
pool = SomeResource()
yield pool
class Service:
def __init__(self, resource) -> None:
self.resource = resource
class Container(containers.DeclarativeContainer):
resource = providers.Resource(init_resource)
service = providers.Singleton(Service, resource=resource)
container = Container()
self._run(container.init_resources())
self.assertIsInstance(container.service(), asyncio.Future)
self.assertIsInstance(container.resource(), asyncio.Future)
resource = self._run(container.resource())
service = self._run(container.service())
self.assertIsInstance(resource, SomeResource)
self.assertIsInstance(service.resource, SomeResource)
self.assertIs(service.resource, resource)
def test_without_init_resources(self):
class SomeResource:
def __await__(self):
raise RuntimeError('Should never happen')
async def init_resource():
pool = SomeResource()
yield pool
class Service:
def __init__(self, resource) -> None:
self.resource = resource
class Container(containers.DeclarativeContainer):
resource = providers.Resource(init_resource)
service = providers.Singleton(Service, resource=resource)
container = Container()
self.assertIsInstance(container.service(), asyncio.Future)
self.assertIsInstance(container.resource(), asyncio.Future)
resource = self._run(container.resource())
service = self._run(container.service())
self.assertIsInstance(resource, SomeResource)
self.assertIsInstance(service.resource, SomeResource)
self.assertIs(service.resource, resource)