Add tests for async resources

This commit is contained in:
Roman Mogylatov 2020-12-01 18:36:50 -05:00
parent dea1033371
commit 31b03243a4
6 changed files with 3062 additions and 2580 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -196,6 +196,7 @@ cdef class Resource(Provider):
cdef bint __initialized cdef bint __initialized
cdef object __shutdowner cdef object __shutdowner
cdef object __resource cdef object __resource
cdef bint __async
cdef tuple __args cdef tuple __args
cdef int __args_len cdef int __args_len

View File

@ -300,7 +300,7 @@ class Resource(Provider, Generic[T]):
def clear_kwargs(self) -> Resource: ... def clear_kwargs(self) -> Resource: ...
@property @property
def initialized(self) -> bool: ... def initialized(self) -> bool: ...
def init(self) -> T: ... def init(self) -> Optional[Awaitable[T]]: ...
def shutdown(self) -> Optional[Awaitable]: ... def shutdown(self) -> Optional[Awaitable]: ...

View File

@ -2574,6 +2574,7 @@ cdef class Resource(Provider):
self.__initialized = False self.__initialized = False
self.__resource = None self.__resource = None
self.__shutdowner = None self.__shutdowner = None
self.__async = False
self.__args = tuple() self.__args = tuple()
self.__args_len = 0 self.__args_len = 0
@ -2705,6 +2706,10 @@ cdef class Resource(Provider):
def shutdown(self): def shutdown(self):
"""Shutdown resource.""" """Shutdown resource."""
if not self.__initialized: if not self.__initialized:
if self.__async:
result = asyncio.Future()
result.set_result(None)
return result
return return
if self.__shutdowner: if self.__shutdowner:
@ -2720,8 +2725,18 @@ cdef class Resource(Provider):
self.__initialized = False self.__initialized = False
self.__shutdowner = None self.__shutdowner = None
if self.__async:
result = asyncio.Future()
result.set_result(None)
return result
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
if self.__initialized: if self.__initialized:
if self.__async:
result = asyncio.Future()
result.set_result(self.__resource)
return result
return self.__resource return self.__resource
if self._is_resource_subclass(self.__initializer): if self._is_resource_subclass(self.__initializer):
@ -2748,6 +2763,7 @@ cdef class Resource(Provider):
self.__kwargs_len, self.__kwargs_len,
) )
self.__initialized = True self.__initialized = True
self.__async = True
return __async_resource_init(self, async_init, initializer.shutdown) return __async_resource_init(self, async_init, initializer.shutdown)
elif inspect.isgeneratorfunction(self.__initializer): elif inspect.isgeneratorfunction(self.__initializer):
initializer = __call( initializer = __call(
@ -2772,6 +2788,7 @@ cdef class Resource(Provider):
self.__kwargs_len, self.__kwargs_len,
) )
self.__initialized = True self.__initialized = True
self.__async = True
return __async_resource_init(self, initializer) return __async_resource_init(self, initializer)
elif isasyncgenfunction(self.__initializer): elif isasyncgenfunction(self.__initializer):
initializer = __call( initializer = __call(
@ -2784,6 +2801,7 @@ cdef class Resource(Provider):
self.__kwargs_len, self.__kwargs_len,
) )
self.__initialized = True self.__initialized = True
self.__async = True
return __async_resource_init(self, initializer.__anext__(), initializer.asend) return __async_resource_init(self, initializer.__anext__(), initializer.asend)
elif callable(self.__initializer): elif callable(self.__initializer):
self.__resource = __call( self.__resource = __call(

View File

@ -1,11 +1,24 @@
"""Dependency injector resource provider unit tests.""" """Dependency injector resource provider unit tests."""
import sys import asyncio
import unittest2 as unittest import unittest2 as unittest
from dependency_injector import containers, providers, resources, errors from dependency_injector import containers, providers, resources, errors
# Runtime import to get asyncutils module
import os
_TOP_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
'../',
)),
)
import sys
sys.path.append(_TOP_DIR)
from asyncutils import AsyncTestCase
def init_fn(*args, **kwargs): def init_fn(*args, **kwargs):
return args, kwargs return args, kwargs
@ -320,3 +333,128 @@ class ResourceTests(unittest.TestCase):
provider.initialized, provider.initialized,
) )
) )
class AsyncResourceTest(AsyncTestCase):
def test_init_async_function(self):
resource = object()
async def _init():
await asyncio.sleep(0.001)
_init.counter += 1
return resource
_init.counter = 0
provider = providers.Resource(_init)
result1 = self._run(provider())
self.assertIs(result1, resource)
self.assertEqual(_init.counter, 1)
result2 = self._run(provider())
self.assertIs(result2, resource)
self.assertEqual(_init.counter, 1)
self._run(provider.shutdown())
def test_init_async_generator(self):
resource = object()
async def _init():
await asyncio.sleep(0.001)
_init.init_counter += 1
yield resource
await asyncio.sleep(0.001)
_init.shutdown_counter += 1
_init.init_counter = 0
_init.shutdown_counter = 0
provider = providers.Resource(_init)
result1 = self._run(provider())
self.assertIs(result1, resource)
self.assertEqual(_init.init_counter, 1)
self.assertEqual(_init.shutdown_counter, 0)
self._run(provider.shutdown())
self.assertEqual(_init.init_counter, 1)
self.assertEqual(_init.shutdown_counter, 1)
result2 = self._run(provider())
self.assertIs(result2, resource)
self.assertEqual(_init.init_counter, 2)
self.assertEqual(_init.shutdown_counter, 1)
self._run(provider.shutdown())
self.assertEqual(_init.init_counter, 2)
self.assertEqual(_init.shutdown_counter, 2)
def test_init_async_class(self):
resource = object()
class TestResource(resources.AsyncResource):
init_counter = 0
shutdown_counter = 0
async def init(self):
self.__class__.init_counter += 1
return resource
async def shutdown(self, resource_):
self.__class__.shutdown_counter += 1
assert resource_ is resource
provider = providers.Resource(TestResource)
result1 = self._run(provider())
self.assertIs(result1, resource)
self.assertEqual(TestResource.init_counter, 1)
self.assertEqual(TestResource.shutdown_counter, 0)
self._run(provider.shutdown())
self.assertEqual(TestResource.init_counter, 1)
self.assertEqual(TestResource.shutdown_counter, 1)
result2 = self._run(provider())
self.assertIs(result2, resource)
self.assertEqual(TestResource.init_counter, 2)
self.assertEqual(TestResource.shutdown_counter, 1)
self._run(provider.shutdown())
self.assertEqual(TestResource.init_counter, 2)
self.assertEqual(TestResource.shutdown_counter, 2)
def test_init_and_shutdown_methods(self):
async def _init():
await asyncio.sleep(0.001)
_init.init_counter += 1
yield
await asyncio.sleep(0.001)
_init.shutdown_counter += 1
_init.init_counter = 0
_init.shutdown_counter = 0
provider = providers.Resource(_init)
self._run(provider.init())
self.assertEqual(_init.init_counter, 1)
self.assertEqual(_init.shutdown_counter, 0)
self._run(provider.shutdown())
self.assertEqual(_init.init_counter, 1)
self.assertEqual(_init.shutdown_counter, 1)
self._run(provider.init())
self.assertEqual(_init.init_counter, 2)
self.assertEqual(_init.shutdown_counter, 1)
self._run(provider.shutdown())
self.assertEqual(_init.init_counter, 2)
self.assertEqual(_init.shutdown_counter, 2)