diff --git a/src/dependency_injector/providers.pyx b/src/dependency_injector/providers.pyx index 4e8ff165..b773cce0 100644 --- a/src/dependency_injector/providers.pyx +++ b/src/dependency_injector/providers.pyx @@ -3837,32 +3837,28 @@ cdef class Resource(Provider): async def _handle_async_cm(self, obj) -> None: try: - self._resource = resource = await obj.__aenter__() - self._shutdowner = obj.__aexit__ + resource = await obj.__aenter__() return resource except: self._initialized = False raise - async def _provide_async(self, future) -> None: - try: - obj = await future + async def _provide_async(self, future): + obj = await future - if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): - self._resource = await obj.__aenter__() - self._shutdowner = obj.__aexit__ - elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): - self._resource = obj.__enter__() - self._shutdowner = obj.__exit__ - else: - self._resource = obj - self._shutdowner = None + if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): + resource = await obj.__aenter__() + shutdowner = obj.__aexit__ + elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): + resource = obj.__enter__() + shutdowner = obj.__exit__ + else: + resource = obj + shutdowner = None - return self._resource - except: - self._initialized = False - raise + return resource, shutdowner + cpdef object _provide(self, tuple args, dict kwargs): if self._initialized: return self._resource @@ -3880,14 +3876,18 @@ cdef class Resource(Provider): if __is_future_or_coroutine(obj): self._initialized = True - self._resource = resource = ensure_future(self._provide_async(obj)) - return resource + future_result = asyncio.Future() + future = ensure_future(self._provide_async(obj)) + future.add_done_callback(functools.partial(self._async_init_instance, future_result)) + self._resource = future_result + return self._resource elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): self._resource = obj.__enter__() self._shutdowner = obj.__exit__ elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): self._initialized = True self._resource = resource = ensure_future(self._handle_async_cm(obj)) + self._shutdowner = obj.__aexit__ return resource else: self._resource = obj @@ -3896,14 +3896,27 @@ cdef class Resource(Provider): self._initialized = True return self._resource + def _async_init_instance(self, future_result, result): + try: + resource, shutdowner = result.result() + except Exception as exception: + self._resource = None + self._shutdowner = None + self._initialized = False + future_result.set_exception(exception) + else: + self._resource = resource + self._shutdowner = shutdowner + future_result.set_result(resource) + cdef class ContextLocalResource(Resource): _none = object() def __init__(self, provides=None, *args, **kwargs): self._initialized_context_var = ContextVar("_initialized_context_var", default=False) - self._resource_context_var = ContextVar("_resource_context_var", default=self._none) - self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=self._none) + self._resource_context_var = ContextVar("_resource_context_var", default=None) + self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=None) super().__init__(provides, *args, **kwargs) @property @@ -3945,7 +3958,7 @@ cdef class ContextLocalResource(Resource): return NULL_AWAITABLE return - if self._shutdowner != self._none: + if self._shutdowner != None: future = self._shutdowner(None, None, None) if __is_future_or_coroutine(future): self._reset_all_contex_vars() @@ -3958,79 +3971,8 @@ cdef class ContextLocalResource(Resource): def _reset_all_contex_vars(self): self._initialized=False - self._resource = self._none - self._shutdowner = self._none - - async def _handle_async_cm(self, obj) -> None: - resource = await obj.__aenter__() - return resource - - async def _provide_async(self, future): - obj = await future - - if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): - resource = await obj.__aenter__() - shutdowner = obj.__aexit__ - elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): - resource = obj.__enter__() - shutdowner = obj.__exit__ - else: - resource = obj - shutdowner = self._none - - return resource, shutdowner - - - cpdef object _provide(self, tuple args, dict kwargs): - if self._initialized: - return self._resource - obj = __call( - self._provides, - args, - self._args, - self._args_len, - kwargs, - self._kwargs, - self._kwargs_len, - self._async_mode, - ) - - if __is_future_or_coroutine(obj): - future_result = asyncio.Future() - future = ensure_future(self._provide_async(obj)) - future.add_done_callback(functools.partial(self._async_init_instance, future_result)) - return future_result - elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): - resource = obj.__enter__() - self._resource = resource - self._initialized = True - self._shutdowner = obj.__exit__ - elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): - resource = ensure_future(self._handle_async_cm(obj)) - self._resource = resource - self._initialized = True - self._shutdowner = obj.__aexit__ - return resource - else: - self._resource = obj - self._initialized = True - self._shutdowner = self._none - - return self._resource - - def _async_init_instance(self, future_result, result): - try: - resource, shutdowner = result.result() - except Exception as exception: - self._resource = self._none - self._shutdowner = self._none - self._initialized = False - future_result.set_exception(exception) - else: - self._resource = resource - self._initialized = True - self._shutdowner = shutdowner - future_result.set_result(resource) + self._resource = None + self._shutdowner = None cdef class Container(Provider): diff --git a/tests/unit/providers/resource/test_context_local_resource_py38.py b/tests/unit/providers/resource/test_context_local_resource_py38.py index 2bcc0b9e..3a0452b9 100644 --- a/tests/unit/providers/resource/test_context_local_resource_py38.py +++ b/tests/unit/providers/resource/test_context_local_resource_py38.py @@ -88,16 +88,27 @@ async def test_injection_in_different_context(): context_local_resource = providers.ContextLocalResource(_init) async_context_local_resource = providers.ContextLocalResource(_async_init) + async def run_in_context(): + obj = await container.async_context_local_resource() + return obj + container = Container() - obj1 = await container.async_context_local_resource() - obj2 = await container.async_context_local_resource() + + obj1, obj2 = await asyncio.gather(run_in_context(), run_in_context()) assert obj1 != obj2 - obj3 = container.context_local_resource() - obj4 = container.context_local_resource() - + obj3 = await container.async_context_local_resource() + obj4 = await container.async_context_local_resource() assert obj3 == obj4 + obj5, obj6 = await asyncio.gather(run_in_context(), run_in_context()) + assert obj5 == obj6 # as context is copied from the current one where async_context_local_resource was initialized + + obj7 = container.context_local_resource() + obj8 = container.context_local_resource() + + assert obj7 == obj8 + def test_init_function(): def _init(): @@ -329,37 +340,27 @@ def test_call_with_context_args(): def test_fluent_interface(): - provider = providers.ContextLocalResource(init_fn) \ - .add_args(1, 2) \ - .add_kwargs(a3=3, a4=4) + provider = providers.ContextLocalResource(init_fn).add_args(1, 2).add_kwargs(a3=3, a4=4) assert provider() == ((1, 2), {"a3": 3, "a4": 4}) def test_set_args(): - provider = providers.ContextLocalResource(init_fn) \ - .add_args(1, 2) \ - .set_args(3, 4) + provider = providers.ContextLocalResource(init_fn).add_args(1, 2).set_args(3, 4) assert provider.args == (3, 4) def test_clear_args(): - provider = providers.ContextLocalResource(init_fn) \ - .add_args(1, 2) \ - .clear_args() + provider = providers.ContextLocalResource(init_fn).add_args(1, 2).clear_args() assert provider.args == tuple() def test_set_kwargs(): - provider = providers.ContextLocalResource(init_fn) \ - .add_kwargs(a1="i1", a2="i2") \ - .set_kwargs(a3="i3", a4="i4") + provider = providers.ContextLocalResource(init_fn).add_kwargs(a1="i1", a2="i2").set_kwargs(a3="i3", a4="i4") assert provider.kwargs == {"a3": "i3", "a4": "i4"} def test_clear_kwargs(): - provider = providers.ContextLocalResource(init_fn) \ - .add_kwargs(a1="i1", a2="i2") \ - .clear_kwargs() + provider = providers.ContextLocalResource(init_fn).add_kwargs(a1="i1", a2="i2").clear_kwargs() assert provider.kwargs == {}