diff --git a/src/dependency_injector/providers.pyx b/src/dependency_injector/providers.pyx index 045b8dc7..f829bfba 100644 --- a/src/dependency_injector/providers.pyx +++ b/src/dependency_injector/providers.pyx @@ -3905,7 +3905,7 @@ cdef class ContextLocalResource(Resource): if self._async_mode == ASYNC_MODE_ENABLED: return NULL_AWAITABLE return - if self._shutdowner_context_var.get(): + if self._shutdowner_context_var.get() != self._none: future = self._shutdowner_context_var.get()(None, None, None) if __is_future_or_coroutine(future): self._reset_all_contex_vars() @@ -3977,7 +3977,7 @@ cdef class ContextLocalResource(Resource): return resource else: self._resource_context_var.set(obj) - self._shutdowner_context_var.set(None) + self._shutdowner_context_var.set(self._none) return self._resource_context_var.get() diff --git a/tests/unit/providers/resource/test_context_local_resource_py38.py b/tests/unit/providers/resource/test_context_local_resource_py38.py index 6fb85aed..2bcc0b9e 100644 --- a/tests/unit/providers/resource/test_context_local_resource_py38.py +++ b/tests/unit/providers/resource/test_context_local_resource_py38.py @@ -4,12 +4,12 @@ import asyncio import decimal import sys from contextlib import contextmanager -from typing import Any from pytest import mark, raises from dependency_injector import containers, errors, providers, resources + def init_fn(*args, **kwargs): return args, kwargs @@ -76,30 +76,27 @@ def test_injection(): assert _init.counter == 1 -def test_injection_in_different_context(): +@mark.asyncio +async def test_injection_in_different_context(): def _init(): return object() async def _async_init(): return object() - class Container(containers.DeclarativeContainer): context_local_resource = providers.ContextLocalResource(_init) async_context_local_resource = providers.ContextLocalResource(_async_init) - loop = asyncio.get_event_loop() container = Container() - obj1 = loop.run_until_complete(container.async_context_local_resource()) - obj2 = loop.run_until_complete(container.async_context_local_resource()) - assert obj1!=obj2 + obj1 = await container.async_context_local_resource() + obj2 = await container.async_context_local_resource() + assert obj1 != obj2 obj3 = container.context_local_resource() obj4 = container.context_local_resource() - assert obj3==obj4 - - + assert obj3 == obj4 def test_init_function(): @@ -121,10 +118,10 @@ def test_init_function(): provider.shutdown() -def test_init_generator(): +def test_init_generator_in_one_context(): def _init(): _init.init_counter += 1 - yield + yield object() _init.shutdown_counter += 1 _init.init_counter = 0 @@ -133,7 +130,10 @@ def test_init_generator(): provider = providers.ContextLocalResource(_init) result1 = provider() - assert result1 is None + result2 = provider() + + assert result1 == result2 + assert _init.init_counter == 1 assert _init.shutdown_counter == 0 @@ -141,17 +141,12 @@ def test_init_generator(): assert _init.init_counter == 1 assert _init.shutdown_counter == 1 - result2 = provider() - assert result2 is None - assert _init.init_counter == 2 + provider.shutdown() + assert _init.init_counter == 1 assert _init.shutdown_counter == 1 - provider.shutdown() - assert _init.init_counter == 2 - assert _init.shutdown_counter == 2 - -def test_init_context_manager() -> None: +def test_init_context_manager_in_one_context() -> None: init_counter, shutdown_counter = 0, 0 @contextmanager @@ -159,7 +154,7 @@ def test_init_context_manager() -> None: nonlocal init_counter, shutdown_counter init_counter += 1 - yield + yield object() shutdown_counter += 1 init_counter = 0 @@ -168,24 +163,77 @@ def test_init_context_manager() -> None: provider = providers.ContextLocalResource(_init) result1 = provider() - assert result1 is None + result2 = provider() + assert result1 == result2 + assert init_counter == 1 assert shutdown_counter == 0 provider.shutdown() + assert init_counter == 1 assert shutdown_counter == 1 - result2 = provider() - assert result2 is None - assert init_counter == 2 + provider.shutdown() + assert init_counter == 1 assert shutdown_counter == 1 - provider.shutdown() + +@mark.asyncio +async def test_async_init_context_manager_in_different_contexts() -> None: + init_counter, shutdown_counter = 0, 0 + + async def _init(): + nonlocal init_counter, shutdown_counter + init_counter += 1 + yield object() + shutdown_counter += 1 + + init_counter = 0 + shutdown_counter = 0 + + provider = providers.ContextLocalResource(_init) + + async def run_in_context(): + resource = await provider() + await provider.shutdown() + return resource + + result1, result2 = await asyncio.gather(run_in_context(), run_in_context()) + + assert result1 != result2 assert init_counter == 2 assert shutdown_counter == 2 +@mark.asyncio +async def test_async_init_context_manager_in_one_context() -> None: + init_counter, shutdown_counter = 0, 0 + + async def _init(): + nonlocal init_counter, shutdown_counter + init_counter += 1 + yield object() + shutdown_counter += 1 + + init_counter = 0 + shutdown_counter = 0 + + provider = providers.ContextLocalResource(_init) + + async def run_in_context(): + resource_1 = await provider() + resource_2 = await provider() + await provider.shutdown() + return resource_1, resource_2 + + result1, result2 = await run_in_context() + + assert result1 == result2 + assert init_counter == 1 + assert shutdown_counter == 1 + + def test_init_class(): class TestResource(resources.Resource): init_counter = 0 @@ -218,7 +266,6 @@ def test_init_class(): assert TestResource.shutdown_counter == 2 - def test_init_not_callable(): provider = providers.ContextLocalResource(1) with raises(TypeError, match=r"object is not callable"):