fix shutdowner default none value, add more tests

This commit is contained in:
elina-israyelyan 2025-10-23 00:32:52 +04:00
parent 4838ea6aa8
commit 78cea35db9
2 changed files with 77 additions and 30 deletions

View File

@ -3905,7 +3905,7 @@ cdef class ContextLocalResource(Resource):
if self._async_mode == ASYNC_MODE_ENABLED:
return NULL_AWAITABLE
return
if self._shutdowner_context_var.get():
if self._shutdowner_context_var.get() != self._none:
future = self._shutdowner_context_var.get()(None, None, None)
if __is_future_or_coroutine(future):
self._reset_all_contex_vars()
@ -3977,7 +3977,7 @@ cdef class ContextLocalResource(Resource):
return resource
else:
self._resource_context_var.set(obj)
self._shutdowner_context_var.set(None)
self._shutdowner_context_var.set(self._none)
return self._resource_context_var.get()

View File

@ -4,12 +4,12 @@ import asyncio
import decimal
import sys
from contextlib import contextmanager
from typing import Any
from pytest import mark, raises
from dependency_injector import containers, errors, providers, resources
def init_fn(*args, **kwargs):
return args, kwargs
@ -76,22 +76,21 @@ def test_injection():
assert _init.counter == 1
def test_injection_in_different_context():
@mark.asyncio
async def test_injection_in_different_context():
def _init():
return object()
async def _async_init():
return object()
class Container(containers.DeclarativeContainer):
context_local_resource = providers.ContextLocalResource(_init)
async_context_local_resource = providers.ContextLocalResource(_async_init)
loop = asyncio.get_event_loop()
container = Container()
obj1 = loop.run_until_complete(container.async_context_local_resource())
obj2 = loop.run_until_complete(container.async_context_local_resource())
obj1 = await container.async_context_local_resource()
obj2 = await container.async_context_local_resource()
assert obj1 != obj2
obj3 = container.context_local_resource()
@ -100,8 +99,6 @@ def test_injection_in_different_context():
assert obj3 == obj4
def test_init_function():
def _init():
_init.counter += 1
@ -121,10 +118,10 @@ def test_init_function():
provider.shutdown()
def test_init_generator():
def test_init_generator_in_one_context():
def _init():
_init.init_counter += 1
yield
yield object()
_init.shutdown_counter += 1
_init.init_counter = 0
@ -133,7 +130,10 @@ def test_init_generator():
provider = providers.ContextLocalResource(_init)
result1 = provider()
assert result1 is None
result2 = provider()
assert result1 == result2
assert _init.init_counter == 1
assert _init.shutdown_counter == 0
@ -141,17 +141,12 @@ def test_init_generator():
assert _init.init_counter == 1
assert _init.shutdown_counter == 1
result2 = provider()
assert result2 is None
assert _init.init_counter == 2
provider.shutdown()
assert _init.init_counter == 1
assert _init.shutdown_counter == 1
provider.shutdown()
assert _init.init_counter == 2
assert _init.shutdown_counter == 2
def test_init_context_manager() -> None:
def test_init_context_manager_in_one_context() -> None:
init_counter, shutdown_counter = 0, 0
@contextmanager
@ -159,7 +154,7 @@ def test_init_context_manager() -> None:
nonlocal init_counter, shutdown_counter
init_counter += 1
yield
yield object()
shutdown_counter += 1
init_counter = 0
@ -168,24 +163,77 @@ def test_init_context_manager() -> None:
provider = providers.ContextLocalResource(_init)
result1 = provider()
assert result1 is None
result2 = provider()
assert result1 == result2
assert init_counter == 1
assert shutdown_counter == 0
provider.shutdown()
assert init_counter == 1
assert shutdown_counter == 1
result2 = provider()
assert result2 is None
assert init_counter == 2
provider.shutdown()
assert init_counter == 1
assert shutdown_counter == 1
provider.shutdown()
@mark.asyncio
async def test_async_init_context_manager_in_different_contexts() -> None:
init_counter, shutdown_counter = 0, 0
async def _init():
nonlocal init_counter, shutdown_counter
init_counter += 1
yield object()
shutdown_counter += 1
init_counter = 0
shutdown_counter = 0
provider = providers.ContextLocalResource(_init)
async def run_in_context():
resource = await provider()
await provider.shutdown()
return resource
result1, result2 = await asyncio.gather(run_in_context(), run_in_context())
assert result1 != result2
assert init_counter == 2
assert shutdown_counter == 2
@mark.asyncio
async def test_async_init_context_manager_in_one_context() -> None:
init_counter, shutdown_counter = 0, 0
async def _init():
nonlocal init_counter, shutdown_counter
init_counter += 1
yield object()
shutdown_counter += 1
init_counter = 0
shutdown_counter = 0
provider = providers.ContextLocalResource(_init)
async def run_in_context():
resource_1 = await provider()
resource_2 = await provider()
await provider.shutdown()
return resource_1, resource_2
result1, result2 = await run_in_context()
assert result1 == result2
assert init_counter == 1
assert shutdown_counter == 1
def test_init_class():
class TestResource(resources.Resource):
init_counter = 0
@ -218,7 +266,6 @@ def test_init_class():
assert TestResource.shutdown_counter == 2
def test_init_not_callable():
provider = providers.ContextLocalResource(1)
with raises(TypeError, match=r"object is not callable"):