Update .provided + fix resource concurent initialization issue

This commit is contained in:
Roman Mogylatov 2020-12-17 23:34:13 -05:00
parent 90a6cb3c6d
commit 0c42ff9242
4 changed files with 2699 additions and 1354 deletions

File diff suppressed because it is too large Load Diff

View File

@ -2794,6 +2794,8 @@ cdef class Resource(Provider):
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
if self.__initialized: if self.__initialized:
if self.__async: if self.__async:
if __isawaitable(self.__resource):
return self.__resource
result = asyncio.Future() result = asyncio.Future()
result.set_result(self.__resource) result.set_result(self.__resource)
return result return result
@ -2887,6 +2889,8 @@ cdef class Resource(Provider):
future = asyncio.ensure_future(future) future = asyncio.ensure_future(future)
future.add_done_callback(callback) future.add_done_callback(callback)
self.__resource = future
return future return future
def _async_init_callback(self, initializer, shutdowner=None): def _async_init_callback(self, initializer, shutdowner=None):
@ -3202,8 +3206,18 @@ cdef class AttributeGetter(Provider):
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
provided = self.__provider(*args, **kwargs) provided = self.__provider(*args, **kwargs)
if __isawaitable(provided):
future_result = asyncio.Future()
provided = asyncio.ensure_future(provided)
provided.add_done_callback(functools.partial(self._async_provide, future_result))
return future_result
return getattr(provided, self.__attribute) return getattr(provided, self.__attribute)
def _async_provide(self, future_result, future):
provided = future.result()
result = getattr(provided, self.__attribute)
future_result.set_result(result)
cdef class ItemGetter(Provider): cdef class ItemGetter(Provider):
"""Provider that returns the item of the injected instance. """Provider that returns the item of the injected instance.
@ -3252,8 +3266,18 @@ cdef class ItemGetter(Provider):
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
provided = self.__provider(*args, **kwargs) provided = self.__provider(*args, **kwargs)
if __isawaitable(provided):
future_result = asyncio.Future()
provided = asyncio.ensure_future(provided)
provided.add_done_callback(functools.partial(self._async_provide, future_result))
return future_result
return provided[self.__item] return provided[self.__item]
def _async_provide(self, future_result, future):
provided = future.result()
result = provided[self.__item]
future_result.set_result(result)
cdef class MethodCaller(Provider): cdef class MethodCaller(Provider):
"""Provider that calls the method of the injected instance. """Provider that calls the method of the injected instance.
@ -3334,6 +3358,11 @@ cdef class MethodCaller(Provider):
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
call = self.__provider() call = self.__provider()
if __isawaitable(call):
future_result = asyncio.Future()
call = asyncio.ensure_future(call)
call.add_done_callback(functools.partial(self._async_provide, future_result, args, kwargs))
return future_result
return __call( return __call(
call, call,
args, args,
@ -3344,6 +3373,19 @@ cdef class MethodCaller(Provider):
self.__kwargs_len, self.__kwargs_len,
) )
def _async_provide(self, future_result, args, kwargs, future):
call = future.result()
result = __call(
call,
args,
self.__args,
self.__args_len,
kwargs,
self.__kwargs,
self.__kwargs_len,
)
future_result.set_result(result)
cdef class Injection(object): cdef class Injection(object):
"""Abstract injection class.""" """Abstract injection class."""

View File

@ -369,3 +369,93 @@ class DelegatedThreadLocalSingletonTests(AsyncTestCase):
self.assertIs(instance1, instance2) self.assertIs(instance1, instance2)
self.assertIs(instance, instance) self.assertIs(instance, instance)
class ProvidedInstanceTests(AsyncTestCase):
def test_provided_attribute(self):
class TestClient:
def __init__(self, resource):
self.resource = resource
class TestService:
def __init__(self, resource):
self.resource = resource
class TestContainer(containers.DeclarativeContainer):
resource = providers.Resource(init_resource, providers.Object(RESOURCE1))
client = providers.Factory(TestClient, resource=resource)
service = providers.Factory(TestService, resource=client.provided.resource)
container = TestContainer()
instance1, instance2 = self._run(
asyncio.gather(
container.service(),
container.service(),
),
)
self.assertIs(instance1.resource, RESOURCE1)
self.assertIs(instance2.resource, RESOURCE1)
self.assertIs(instance1.resource, instance2.resource)
def test_provided_item(self):
class TestClient:
def __init__(self, resource):
self.resource = resource
def __getitem__(self, item):
return getattr(self, item)
class TestService:
def __init__(self, resource):
self.resource = resource
class TestContainer(containers.DeclarativeContainer):
resource = providers.Resource(init_resource, providers.Object(RESOURCE1))
client = providers.Factory(TestClient, resource=resource)
service = providers.Factory(TestService, resource=client.provided['resource'])
container = TestContainer()
instance1, instance2 = self._run(
asyncio.gather(
container.service(),
container.service(),
),
)
self.assertIs(instance1.resource, RESOURCE1)
self.assertIs(instance2.resource, RESOURCE1)
self.assertIs(instance1.resource, instance2.resource)
def test_provided_method_call(self):
class TestClient:
def __init__(self, resource):
self.resource = resource
def get_resource(self):
return self.resource
class TestService:
def __init__(self, resource):
self.resource = resource
class TestContainer(containers.DeclarativeContainer):
resource = providers.Resource(init_resource, providers.Object(RESOURCE1))
client = providers.Factory(TestClient, resource=resource)
service = providers.Factory(TestService, resource=client.provided.get_resource.call())
container = TestContainer()
instance1, instance2 = self._run(
asyncio.gather(
container.service(),
container.service(),
),
)
self.assertIs(instance1.resource, RESOURCE1)
self.assertIs(instance2.resource, RESOURCE1)
self.assertIs(instance1.resource, instance2.resource)

View File

@ -460,3 +460,27 @@ class AsyncResourceTest(AsyncTestCase):
self._run(provider.shutdown()) self._run(provider.shutdown())
self.assertEqual(_init.init_counter, 2) self.assertEqual(_init.init_counter, 2)
self.assertEqual(_init.shutdown_counter, 2) self.assertEqual(_init.shutdown_counter, 2)
def test_concurent_init(self):
resource = object()
async def _init():
await asyncio.sleep(0.001)
_init.counter += 1
return resource
_init.counter = 0
provider = providers.Resource(_init)
result1, result2 = self._run(
asyncio.gather(
provider(),
provider()
),
)
self.assertIs(result1, resource)
self.assertEqual(_init.counter, 1)
self.assertIs(result2, resource)
self.assertEqual(_init.counter, 1)