mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-11-24 11:55:47 +03:00
fix shutdowner default none value, add more tests
This commit is contained in:
parent
4838ea6aa8
commit
78cea35db9
|
|
@ -3905,7 +3905,7 @@ cdef class ContextLocalResource(Resource):
|
|||
if self._async_mode == ASYNC_MODE_ENABLED:
|
||||
return NULL_AWAITABLE
|
||||
return
|
||||
if self._shutdowner_context_var.get():
|
||||
if self._shutdowner_context_var.get() != self._none:
|
||||
future = self._shutdowner_context_var.get()(None, None, None)
|
||||
if __is_future_or_coroutine(future):
|
||||
self._reset_all_contex_vars()
|
||||
|
|
@ -3977,7 +3977,7 @@ cdef class ContextLocalResource(Resource):
|
|||
return resource
|
||||
else:
|
||||
self._resource_context_var.set(obj)
|
||||
self._shutdowner_context_var.set(None)
|
||||
self._shutdowner_context_var.set(self._none)
|
||||
|
||||
return self._resource_context_var.get()
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@ import asyncio
|
|||
import decimal
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from pytest import mark, raises
|
||||
|
||||
from dependency_injector import containers, errors, providers, resources
|
||||
|
||||
|
||||
def init_fn(*args, **kwargs):
|
||||
return args, kwargs
|
||||
|
||||
|
|
@ -76,22 +76,21 @@ def test_injection():
|
|||
assert _init.counter == 1
|
||||
|
||||
|
||||
def test_injection_in_different_context():
|
||||
@mark.asyncio
|
||||
async def test_injection_in_different_context():
|
||||
def _init():
|
||||
return object()
|
||||
|
||||
async def _async_init():
|
||||
return object()
|
||||
|
||||
|
||||
class Container(containers.DeclarativeContainer):
|
||||
context_local_resource = providers.ContextLocalResource(_init)
|
||||
async_context_local_resource = providers.ContextLocalResource(_async_init)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
container = Container()
|
||||
obj1 = loop.run_until_complete(container.async_context_local_resource())
|
||||
obj2 = loop.run_until_complete(container.async_context_local_resource())
|
||||
obj1 = await container.async_context_local_resource()
|
||||
obj2 = await container.async_context_local_resource()
|
||||
assert obj1 != obj2
|
||||
|
||||
obj3 = container.context_local_resource()
|
||||
|
|
@ -100,8 +99,6 @@ def test_injection_in_different_context():
|
|||
assert obj3 == obj4
|
||||
|
||||
|
||||
|
||||
|
||||
def test_init_function():
|
||||
def _init():
|
||||
_init.counter += 1
|
||||
|
|
@ -121,10 +118,10 @@ def test_init_function():
|
|||
provider.shutdown()
|
||||
|
||||
|
||||
def test_init_generator():
|
||||
def test_init_generator_in_one_context():
|
||||
def _init():
|
||||
_init.init_counter += 1
|
||||
yield
|
||||
yield object()
|
||||
_init.shutdown_counter += 1
|
||||
|
||||
_init.init_counter = 0
|
||||
|
|
@ -133,7 +130,10 @@ def test_init_generator():
|
|||
provider = providers.ContextLocalResource(_init)
|
||||
|
||||
result1 = provider()
|
||||
assert result1 is None
|
||||
result2 = provider()
|
||||
|
||||
assert result1 == result2
|
||||
|
||||
assert _init.init_counter == 1
|
||||
assert _init.shutdown_counter == 0
|
||||
|
||||
|
|
@ -141,17 +141,12 @@ def test_init_generator():
|
|||
assert _init.init_counter == 1
|
||||
assert _init.shutdown_counter == 1
|
||||
|
||||
result2 = provider()
|
||||
assert result2 is None
|
||||
assert _init.init_counter == 2
|
||||
provider.shutdown()
|
||||
assert _init.init_counter == 1
|
||||
assert _init.shutdown_counter == 1
|
||||
|
||||
provider.shutdown()
|
||||
assert _init.init_counter == 2
|
||||
assert _init.shutdown_counter == 2
|
||||
|
||||
|
||||
def test_init_context_manager() -> None:
|
||||
def test_init_context_manager_in_one_context() -> None:
|
||||
init_counter, shutdown_counter = 0, 0
|
||||
|
||||
@contextmanager
|
||||
|
|
@ -159,7 +154,7 @@ def test_init_context_manager() -> None:
|
|||
nonlocal init_counter, shutdown_counter
|
||||
|
||||
init_counter += 1
|
||||
yield
|
||||
yield object()
|
||||
shutdown_counter += 1
|
||||
|
||||
init_counter = 0
|
||||
|
|
@ -168,24 +163,77 @@ def test_init_context_manager() -> None:
|
|||
provider = providers.ContextLocalResource(_init)
|
||||
|
||||
result1 = provider()
|
||||
assert result1 is None
|
||||
result2 = provider()
|
||||
assert result1 == result2
|
||||
|
||||
assert init_counter == 1
|
||||
assert shutdown_counter == 0
|
||||
|
||||
provider.shutdown()
|
||||
|
||||
assert init_counter == 1
|
||||
assert shutdown_counter == 1
|
||||
|
||||
result2 = provider()
|
||||
assert result2 is None
|
||||
assert init_counter == 2
|
||||
provider.shutdown()
|
||||
assert init_counter == 1
|
||||
assert shutdown_counter == 1
|
||||
|
||||
provider.shutdown()
|
||||
|
||||
@mark.asyncio
|
||||
async def test_async_init_context_manager_in_different_contexts() -> None:
|
||||
init_counter, shutdown_counter = 0, 0
|
||||
|
||||
async def _init():
|
||||
nonlocal init_counter, shutdown_counter
|
||||
init_counter += 1
|
||||
yield object()
|
||||
shutdown_counter += 1
|
||||
|
||||
init_counter = 0
|
||||
shutdown_counter = 0
|
||||
|
||||
provider = providers.ContextLocalResource(_init)
|
||||
|
||||
async def run_in_context():
|
||||
resource = await provider()
|
||||
await provider.shutdown()
|
||||
return resource
|
||||
|
||||
result1, result2 = await asyncio.gather(run_in_context(), run_in_context())
|
||||
|
||||
assert result1 != result2
|
||||
assert init_counter == 2
|
||||
assert shutdown_counter == 2
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_async_init_context_manager_in_one_context() -> None:
|
||||
init_counter, shutdown_counter = 0, 0
|
||||
|
||||
async def _init():
|
||||
nonlocal init_counter, shutdown_counter
|
||||
init_counter += 1
|
||||
yield object()
|
||||
shutdown_counter += 1
|
||||
|
||||
init_counter = 0
|
||||
shutdown_counter = 0
|
||||
|
||||
provider = providers.ContextLocalResource(_init)
|
||||
|
||||
async def run_in_context():
|
||||
resource_1 = await provider()
|
||||
resource_2 = await provider()
|
||||
await provider.shutdown()
|
||||
return resource_1, resource_2
|
||||
|
||||
result1, result2 = await run_in_context()
|
||||
|
||||
assert result1 == result2
|
||||
assert init_counter == 1
|
||||
assert shutdown_counter == 1
|
||||
|
||||
|
||||
def test_init_class():
|
||||
class TestResource(resources.Resource):
|
||||
init_counter = 0
|
||||
|
|
@ -218,7 +266,6 @@ def test_init_class():
|
|||
assert TestResource.shutdown_counter == 2
|
||||
|
||||
|
||||
|
||||
def test_init_not_callable():
|
||||
provider = providers.ContextLocalResource(1)
|
||||
with raises(TypeError, match=r"object is not callable"):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user