fix shutdowner default none value, add more tests

This commit is contained in:
elina-israyelyan 2025-10-23 00:32:52 +04:00
parent 4838ea6aa8
commit 78cea35db9
2 changed files with 77 additions and 30 deletions

View File

@ -3905,7 +3905,7 @@ cdef class ContextLocalResource(Resource):
if self._async_mode == ASYNC_MODE_ENABLED: if self._async_mode == ASYNC_MODE_ENABLED:
return NULL_AWAITABLE return NULL_AWAITABLE
return return
if self._shutdowner_context_var.get(): if self._shutdowner_context_var.get() != self._none:
future = self._shutdowner_context_var.get()(None, None, None) future = self._shutdowner_context_var.get()(None, None, None)
if __is_future_or_coroutine(future): if __is_future_or_coroutine(future):
self._reset_all_contex_vars() self._reset_all_contex_vars()
@ -3977,7 +3977,7 @@ cdef class ContextLocalResource(Resource):
return resource return resource
else: else:
self._resource_context_var.set(obj) self._resource_context_var.set(obj)
self._shutdowner_context_var.set(None) self._shutdowner_context_var.set(self._none)
return self._resource_context_var.get() return self._resource_context_var.get()

View File

@ -4,12 +4,12 @@ import asyncio
import decimal import decimal
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any
from pytest import mark, raises from pytest import mark, raises
from dependency_injector import containers, errors, providers, resources from dependency_injector import containers, errors, providers, resources
def init_fn(*args, **kwargs): def init_fn(*args, **kwargs):
return args, kwargs return args, kwargs
@ -76,30 +76,27 @@ def test_injection():
assert _init.counter == 1 assert _init.counter == 1
def test_injection_in_different_context(): @mark.asyncio
async def test_injection_in_different_context():
def _init(): def _init():
return object() return object()
async def _async_init(): async def _async_init():
return object() return object()
class Container(containers.DeclarativeContainer): class Container(containers.DeclarativeContainer):
context_local_resource = providers.ContextLocalResource(_init) context_local_resource = providers.ContextLocalResource(_init)
async_context_local_resource = providers.ContextLocalResource(_async_init) async_context_local_resource = providers.ContextLocalResource(_async_init)
loop = asyncio.get_event_loop()
container = Container() container = Container()
obj1 = loop.run_until_complete(container.async_context_local_resource()) obj1 = await container.async_context_local_resource()
obj2 = loop.run_until_complete(container.async_context_local_resource()) obj2 = await container.async_context_local_resource()
assert obj1!=obj2 assert obj1 != obj2
obj3 = container.context_local_resource() obj3 = container.context_local_resource()
obj4 = container.context_local_resource() obj4 = container.context_local_resource()
assert obj3==obj4 assert obj3 == obj4
def test_init_function(): def test_init_function():
@ -121,10 +118,10 @@ def test_init_function():
provider.shutdown() provider.shutdown()
def test_init_generator(): def test_init_generator_in_one_context():
def _init(): def _init():
_init.init_counter += 1 _init.init_counter += 1
yield yield object()
_init.shutdown_counter += 1 _init.shutdown_counter += 1
_init.init_counter = 0 _init.init_counter = 0
@ -133,7 +130,10 @@ def test_init_generator():
provider = providers.ContextLocalResource(_init) provider = providers.ContextLocalResource(_init)
result1 = provider() result1 = provider()
assert result1 is None result2 = provider()
assert result1 == result2
assert _init.init_counter == 1 assert _init.init_counter == 1
assert _init.shutdown_counter == 0 assert _init.shutdown_counter == 0
@ -141,17 +141,12 @@ def test_init_generator():
assert _init.init_counter == 1 assert _init.init_counter == 1
assert _init.shutdown_counter == 1 assert _init.shutdown_counter == 1
result2 = provider() provider.shutdown()
assert result2 is None assert _init.init_counter == 1
assert _init.init_counter == 2
assert _init.shutdown_counter == 1 assert _init.shutdown_counter == 1
provider.shutdown()
assert _init.init_counter == 2
assert _init.shutdown_counter == 2
def test_init_context_manager_in_one_context() -> None:
def test_init_context_manager() -> None:
init_counter, shutdown_counter = 0, 0 init_counter, shutdown_counter = 0, 0
@contextmanager @contextmanager
@ -159,7 +154,7 @@ def test_init_context_manager() -> None:
nonlocal init_counter, shutdown_counter nonlocal init_counter, shutdown_counter
init_counter += 1 init_counter += 1
yield yield object()
shutdown_counter += 1 shutdown_counter += 1
init_counter = 0 init_counter = 0
@ -168,24 +163,77 @@ def test_init_context_manager() -> None:
provider = providers.ContextLocalResource(_init) provider = providers.ContextLocalResource(_init)
result1 = provider() result1 = provider()
assert result1 is None result2 = provider()
assert result1 == result2
assert init_counter == 1 assert init_counter == 1
assert shutdown_counter == 0 assert shutdown_counter == 0
provider.shutdown() provider.shutdown()
assert init_counter == 1 assert init_counter == 1
assert shutdown_counter == 1 assert shutdown_counter == 1
result2 = provider() provider.shutdown()
assert result2 is None assert init_counter == 1
assert init_counter == 2
assert shutdown_counter == 1 assert shutdown_counter == 1
provider.shutdown()
@mark.asyncio
async def test_async_init_context_manager_in_different_contexts() -> None:
init_counter, shutdown_counter = 0, 0
async def _init():
nonlocal init_counter, shutdown_counter
init_counter += 1
yield object()
shutdown_counter += 1
init_counter = 0
shutdown_counter = 0
provider = providers.ContextLocalResource(_init)
async def run_in_context():
resource = await provider()
await provider.shutdown()
return resource
result1, result2 = await asyncio.gather(run_in_context(), run_in_context())
assert result1 != result2
assert init_counter == 2 assert init_counter == 2
assert shutdown_counter == 2 assert shutdown_counter == 2
@mark.asyncio
async def test_async_init_context_manager_in_one_context() -> None:
init_counter, shutdown_counter = 0, 0
async def _init():
nonlocal init_counter, shutdown_counter
init_counter += 1
yield object()
shutdown_counter += 1
init_counter = 0
shutdown_counter = 0
provider = providers.ContextLocalResource(_init)
async def run_in_context():
resource_1 = await provider()
resource_2 = await provider()
await provider.shutdown()
return resource_1, resource_2
result1, result2 = await run_in_context()
assert result1 == result2
assert init_counter == 1
assert shutdown_counter == 1
def test_init_class(): def test_init_class():
class TestResource(resources.Resource): class TestResource(resources.Resource):
init_counter = 0 init_counter = 0
@ -218,7 +266,6 @@ def test_init_class():
assert TestResource.shutdown_counter == 2 assert TestResource.shutdown_counter == 2
def test_init_not_callable(): def test_init_not_callable():
provider = providers.ContextLocalResource(1) provider = providers.ContextLocalResource(1)
with raises(TypeError, match=r"object is not callable"): with raises(TypeError, match=r"object is not callable"):