This commit is contained in:
Foxells 2025-11-15 11:02:18 +02:00 committed by GitHub
commit cba782363d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 775 additions and 50 deletions

View File

@ -0,0 +1,32 @@
.. _context-local-resource-provider:
Context Local Resource provider
================================
.. meta::
:keywords: Python,DI,Dependency injection,IoC,Inversion of Control,Resource,Context Local,
Context Variables,Singleton,Per-context
:description: Context Local Resource provider provides a component with initialization and shutdown
that is scoped to execution context using contextvars. This page demonstrates how to
use context local resource provider.
.. currentmodule:: dependency_injector.providers
``ContextLocalResource`` inherits from :ref:`resource-provider` and uses the same initialization and shutdown logic
as the standard ``Resource`` provider.
It extends it with context-local storage using Python's ``contextvars`` module.
This means that objects are context local singletons - the same context will
receive the same instance, but different execution contexts will have their own separate instances.
This is particularly useful in asynchronous applications where you need per-request resource instances
(such as database sessions) that are automatically cleaned up when the request context ends.
Example:
.. literalinclude:: ../../examples/providers/context_local_resource.py
:language: python
:lines: 3-
.. disqus::

View File

@ -46,6 +46,7 @@ Providers module API docs - :py:mod:`dependency_injector.providers`
dict dict
configuration configuration
resource resource
context_local_resource
aggregate aggregate
selector selector
dependency dependency

View File

@ -21,6 +21,9 @@ Resource provider
Resource providers help to initialize and configure logging, event loop, thread or process pool, etc. Resource providers help to initialize and configure logging, event loop, thread or process pool, etc.
Resource provider is similar to ``Singleton``. Resource initialization happens only once. Resource provider is similar to ``Singleton``. Resource initialization happens only once.
If you need a context local singleton (where each execution context has its own instance),
see :ref:`context-local-resource-provider`.
You can make injections and use provided instance the same way like you do with any other provider. You can make injections and use provided instance the same way like you do with any other provider.
.. code-block:: python .. code-block:: python

View File

@ -0,0 +1,50 @@
from uuid import uuid4
from fastapi import Depends, FastAPI
from dependency_injector import containers, providers
from dependency_injector.wiring import Closing, Provide, inject
global_list = []
class AsyncSessionLocal:
def __init__(self):
self.id = uuid4()
async def __aenter__(self):
print("Entering session !")
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
print("Closing session !")
async def execute(self, user_input):
return f"Executing {user_input} in session {self.id}"
app = FastAPI()
class Container(containers.DeclarativeContainer):
db_session = providers.ContextLocalResource(AsyncSessionLocal)
@app.get("/")
@inject
async def index(db: AsyncSessionLocal = Depends(Closing[Provide["db_session"]])):
global global_list
if db.id in global_list:
raise Exception("The db session is already used") # never reaches here
global_list.append(db.id)
res = await db.execute("SELECT 1")
return str(res)
if __name__ == "__main__":
import uvicorn
container = Container()
container.wire(modules=["__main__"])
uvicorn.run(app, host="localhost", port=8000)
container.unwire()

View File

@ -226,9 +226,9 @@ cdef class Dict(Provider):
cdef class Resource(Provider): cdef class Resource(Provider):
cdef object _provides cdef object _provides
cdef bint _initialized cdef bint __initialized
cdef object _shutdowner cdef object __shutdowner
cdef object _resource cdef object __resource
cdef tuple _args cdef tuple _args
cdef int _args_len cdef int _args_len
@ -239,6 +239,12 @@ cdef class Resource(Provider):
cpdef object _provide(self, tuple args, dict kwargs) cpdef object _provide(self, tuple args, dict kwargs)
cdef class ContextLocalResource(Resource):
cdef object _resource_context_var
cdef object _initialized_context_var
cdef object _shutdowner_context_var
cdef class Container(Provider): cdef class Container(Provider):
cdef object _container_cls cdef object _container_cls
cdef dict _overriding_providers cdef dict _overriding_providers

View File

@ -524,6 +524,8 @@ class Resource(Provider[T]):
def init(self) -> Optional[Awaitable[T]]: ... def init(self) -> Optional[Awaitable[T]]: ...
def shutdown(self) -> Optional[Awaitable]: ... def shutdown(self) -> Optional[Awaitable]: ...
class ContextLocalResource(Resource[T]):...
class Container(Provider[T]): class Container(Provider[T]):
def __init__( def __init__(
self, self,

View File

@ -3620,9 +3620,9 @@ cdef class Resource(Provider):
self._provides = None self._provides = None
self.set_provides(provides) self.set_provides(provides)
self._initialized = False self.__initialized = False
self._resource = None self.__resource = None
self._shutdowner = None self.__shutdowner = None
self._args = tuple() self._args = tuple()
self._args_len = 0 self._args_len = 0
@ -3760,6 +3760,36 @@ cdef class Resource(Provider):
self._kwargs_len = len(self._kwargs) self._kwargs_len = len(self._kwargs)
return self return self
@property
def _initialized(self):
"""Get initialized state."""
return self.__initialized
@_initialized.setter
def _initialized(self, value):
"""Set initialized state."""
self.__initialized = value
@property
def _resource(self):
"""Get resource."""
return self.__resource
@_resource.setter
def _resource(self, value):
"""Set resource."""
self.__resource = value
@property
def _shutdowner(self):
"""Get shutdowner."""
return self.__shutdowner
@_shutdowner.setter
def _shutdowner(self, value):
"""Set shutdowner."""
self.__shutdowner = value
@property @property
def initialized(self): def initialized(self):
"""Check if resource is initialized.""" """Check if resource is initialized."""
@ -3771,24 +3801,27 @@ cdef class Resource(Provider):
def shutdown(self): def shutdown(self):
"""Shutdown resource.""" """Shutdown resource."""
if not self._initialized: if not self._initialized :
self._reset_all_contex_vars()
if self._async_mode == ASYNC_MODE_ENABLED: if self._async_mode == ASYNC_MODE_ENABLED:
return NULL_AWAITABLE return NULL_AWAITABLE
return return
if self._shutdowner: if self._shutdowner:
future = self._shutdowner(None, None, None) future = self._shutdowner(None, None, None)
if __is_future_or_coroutine(future): if __is_future_or_coroutine(future):
return ensure_future(self._shutdown_async(future)) self._reset_all_contex_vars()
return ensure_future(future)
self._resource = None
self._initialized = False
self._shutdowner = None
self._reset_all_contex_vars()
if self._async_mode == ASYNC_MODE_ENABLED: if self._async_mode == ASYNC_MODE_ENABLED:
return NULL_AWAITABLE return NULL_AWAITABLE
def _reset_all_contex_vars(self):
self._initialized = False
self._resource = None
self._shutdowner = None
@property @property
def related(self): def related(self):
"""Return related providers generator.""" """Return related providers generator."""
@ -3797,41 +3830,28 @@ cdef class Resource(Provider):
yield from filter(is_provider, self.kwargs.values()) yield from filter(is_provider, self.kwargs.values())
yield from super().related yield from super().related
async def _shutdown_async(self, future) -> None:
try:
await future
finally:
self._resource = None
self._initialized = False
self._shutdowner = None
async def _handle_async_cm(self, obj) -> None: async def _handle_async_cm(self, obj) -> None:
try: try:
self._resource = resource = await obj.__aenter__() resource = await obj.__aenter__()
self._shutdowner = obj.__aexit__
return resource return resource
except: except:
self._initialized = False self._initialized = False
raise raise
async def _provide_async(self, future) -> None: async def _provide_async(self, future):
try:
obj = await future obj = await future
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
self._resource = await obj.__aenter__() resource = await obj.__aenter__()
self._shutdowner = obj.__aexit__ shutdowner = obj.__aexit__
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
self._resource = obj.__enter__() resource = obj.__enter__()
self._shutdowner = obj.__exit__ shutdowner = obj.__exit__
else: else:
self._resource = obj resource = obj
self._shutdowner = None shutdowner = None
return self._resource return resource, shutdowner
except:
self._initialized = False
raise
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
if self._initialized: if self._initialized:
@ -3850,14 +3870,18 @@ cdef class Resource(Provider):
if __is_future_or_coroutine(obj): if __is_future_or_coroutine(obj):
self._initialized = True self._initialized = True
self._resource = resource = ensure_future(self._provide_async(obj)) future_result = asyncio.Future()
return resource future = ensure_future(self._provide_async(obj))
future.add_done_callback(functools.partial(self._async_init_instance, future_result))
self._resource = future_result
return self._resource
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
self._resource = obj.__enter__() self._resource = obj.__enter__()
self._shutdowner = obj.__exit__ self._shutdowner = obj.__exit__
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
self._initialized = True self._initialized = True
self._resource = resource = ensure_future(self._handle_async_cm(obj)) self._resource = resource = ensure_future(self._handle_async_cm(obj))
self._shutdowner = obj.__aexit__
return resource return resource
else: else:
self._resource = obj self._resource = obj
@ -3866,6 +3890,57 @@ cdef class Resource(Provider):
self._initialized = True self._initialized = True
return self._resource return self._resource
def _async_init_instance(self, future_result, result):
try:
resource, shutdowner = result.result()
except Exception as exception:
self._resource = None
self._shutdowner = None
self._initialized = False
future_result.set_exception(exception)
else:
self._resource = resource
self._shutdowner = shutdowner
future_result.set_result(resource)
cdef class ContextLocalResource(Resource):
def __init__(self, provides=None, *args, **kwargs):
self._initialized_context_var = ContextVar("_initialized_context_var", default=False)
self._resource_context_var = ContextVar("_resource_context_var", default=None)
self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=None)
super().__init__(provides, *args, **kwargs)
@property
def _initialized(self):
"""Get initialized state."""
return self._initialized_context_var.get()
@_initialized.setter
def _initialized(self, value):
"""Set initialized state."""
self._initialized_context_var.set(value)
@property
def _resource(self):
"""Get resource."""
return self._resource_context_var.get()
@_resource.setter
def _resource(self, value):
"""Set resource."""
self._resource_context_var.set(value)
@property
def _shutdowner(self):
"""Get shutdowner."""
return self._shutdowner_context_var.get()
@_shutdowner.setter
def _shutdowner(self, value):
"""Set shutdowner."""
self._shutdowner_context_var.set(value)
cdef class Container(Provider): cdef class Container(Provider):
"""Container provider provides an instance of declarative container. """Container provider provides an instance of declarative container.

View File

@ -0,0 +1,492 @@
"""Resource provider tests."""
import asyncio
import decimal
import sys
from contextlib import contextmanager
from pytest import mark, raises
from dependency_injector import containers, errors, providers, resources
def init_fn(*args, **kwargs):
return args, kwargs
def test_is_provider():
assert providers.is_provider(providers.ContextLocalResource(init_fn)) is True
def test_init_optional_provides():
provider = providers.ContextLocalResource()
provider.set_provides(init_fn)
assert provider.provides is init_fn
assert provider() == (tuple(), dict())
def test_set_provides_returns_():
provider = providers.ContextLocalResource()
assert provider.set_provides(init_fn) is provider
@mark.parametrize(
"str_name,cls",
[
("dependency_injector.providers.Factory", providers.Factory),
("decimal.Decimal", decimal.Decimal),
("list", list),
(".test_context_local_resource_py38.test_is_provider", test_is_provider),
("test_is_provider", test_is_provider),
],
)
def test_set_provides_string_imports(str_name, cls):
assert providers.ContextLocalResource(str_name).provides is cls
def test_provided_instance_provider():
provider = providers.ContextLocalResource(init_fn)
assert isinstance(provider.provided, providers.ProvidedInstance)
def test_injection():
resource = object()
def _init():
_init.counter += 1
return resource
_init.counter = 0
class Container(containers.DeclarativeContainer):
context_local_resource = providers.ContextLocalResource(_init)
dependency1 = providers.List(context_local_resource)
dependency2 = providers.List(context_local_resource)
container = Container()
list1 = container.dependency1()
list2 = container.dependency2()
assert list1 == [resource]
assert list1[0] is resource
assert list2 == [resource]
assert list2[0] is resource
assert _init.counter == 1
@mark.asyncio
async def test_injection_in_different_context():
def _init():
return object()
async def _async_init():
return object()
class Container(containers.DeclarativeContainer):
context_local_resource = providers.ContextLocalResource(_init)
async_context_local_resource = providers.ContextLocalResource(_async_init)
async def run_in_context():
obj = await container.async_context_local_resource()
return obj
container = Container()
obj1, obj2 = await asyncio.gather(run_in_context(), run_in_context())
assert obj1 != obj2
obj3 = await container.async_context_local_resource()
obj4 = await container.async_context_local_resource()
assert obj3 == obj4
obj5, obj6 = await asyncio.gather(run_in_context(), run_in_context())
assert obj5 == obj6 # as context is copied from the current one where async_context_local_resource was initialized
obj7 = container.context_local_resource()
obj8 = container.context_local_resource()
assert obj7 == obj8
def test_init_function():
def _init():
_init.counter += 1
_init.counter = 0
provider = providers.ContextLocalResource(_init)
result1 = provider()
assert result1 is None
assert _init.counter == 1
result2 = provider()
assert result2 is None
assert _init.counter == 1
provider.shutdown()
def test_init_generator_in_one_context():
def _init():
_init.init_counter += 1
yield object()
_init.shutdown_counter += 1
_init.init_counter = 0
_init.shutdown_counter = 0
provider = providers.ContextLocalResource(_init)
result1 = provider()
result2 = provider()
assert result1 == result2
assert _init.init_counter == 1
assert _init.shutdown_counter == 0
provider.shutdown()
assert _init.init_counter == 1
assert _init.shutdown_counter == 1
provider.shutdown()
assert _init.init_counter == 1
assert _init.shutdown_counter == 1
def test_init_context_manager_in_one_context() -> None:
init_counter, shutdown_counter = 0, 0
@contextmanager
def _init():
nonlocal init_counter, shutdown_counter
init_counter += 1
yield object()
shutdown_counter += 1
init_counter = 0
shutdown_counter = 0
provider = providers.ContextLocalResource(_init)
result1 = provider()
result2 = provider()
assert result1 == result2
assert init_counter == 1
assert shutdown_counter == 0
provider.shutdown()
assert init_counter == 1
assert shutdown_counter == 1
provider.shutdown()
assert init_counter == 1
assert shutdown_counter == 1
@mark.asyncio
async def test_async_init_context_manager_in_different_contexts() -> None:
init_counter, shutdown_counter = 0, 0
async def _init():
nonlocal init_counter, shutdown_counter
init_counter += 1
yield object()
shutdown_counter += 1
init_counter = 0
shutdown_counter = 0
provider = providers.ContextLocalResource(_init)
async def run_in_context():
resource = await provider()
await provider.shutdown()
return resource
result1, result2 = await asyncio.gather(run_in_context(), run_in_context())
assert result1 != result2
assert init_counter == 2
assert shutdown_counter == 2
@mark.asyncio
async def test_async_init_context_manager_in_one_context() -> None:
init_counter, shutdown_counter = 0, 0
async def _init():
nonlocal init_counter, shutdown_counter
init_counter += 1
yield object()
shutdown_counter += 1
init_counter = 0
shutdown_counter = 0
provider = providers.ContextLocalResource(_init)
async def run_in_context():
resource_1 = await provider()
resource_2 = await provider()
await provider.shutdown()
return resource_1, resource_2
result1, result2 = await run_in_context()
assert result1 == result2
assert init_counter == 1
assert shutdown_counter == 1
def test_init_class():
class TestResource(resources.Resource):
init_counter = 0
shutdown_counter = 0
def init(self):
self.__class__.init_counter += 1
def shutdown(self, _):
self.__class__.shutdown_counter += 1
provider = providers.ContextLocalResource(TestResource)
result1 = provider()
assert result1 is None
assert TestResource.init_counter == 1
assert TestResource.shutdown_counter == 0
provider.shutdown()
assert TestResource.init_counter == 1
assert TestResource.shutdown_counter == 1
result2 = provider()
assert result2 is None
assert TestResource.init_counter == 2
assert TestResource.shutdown_counter == 1
provider.shutdown()
assert TestResource.init_counter == 2
assert TestResource.shutdown_counter == 2
def test_init_not_callable():
provider = providers.ContextLocalResource(1)
with raises(TypeError, match=r"object is not callable"):
provider.init()
def test_init_and_shutdown():
def _init():
_init.init_counter += 1
yield
_init.shutdown_counter += 1
_init.init_counter = 0
_init.shutdown_counter = 0
provider = providers.ContextLocalResource(_init)
result1 = provider.init()
assert result1 is None
assert _init.init_counter == 1
assert _init.shutdown_counter == 0
provider.shutdown()
assert _init.init_counter == 1
assert _init.shutdown_counter == 1
result2 = provider.init()
assert result2 is None
assert _init.init_counter == 2
assert _init.shutdown_counter == 1
provider.shutdown()
assert _init.init_counter == 2
assert _init.shutdown_counter == 2
def test_shutdown_of_not_initialized():
def _init():
yield
provider = providers.ContextLocalResource(_init)
result = provider.shutdown()
assert result is None
def test_initialized():
provider = providers.ContextLocalResource(init_fn)
assert provider.initialized is False
provider.init()
assert provider.initialized is True
provider.shutdown()
assert provider.initialized is False
def test_call_with_context_args():
provider = providers.ContextLocalResource(init_fn, "i1", "i2")
assert provider("i3", i4=4) == (("i1", "i2", "i3"), {"i4": 4})
def test_fluent_interface():
provider = providers.ContextLocalResource(init_fn).add_args(1, 2).add_kwargs(a3=3, a4=4)
assert provider() == ((1, 2), {"a3": 3, "a4": 4})
def test_set_args():
provider = providers.ContextLocalResource(init_fn).add_args(1, 2).set_args(3, 4)
assert provider.args == (3, 4)
def test_clear_args():
provider = providers.ContextLocalResource(init_fn).add_args(1, 2).clear_args()
assert provider.args == tuple()
def test_set_kwargs():
provider = providers.ContextLocalResource(init_fn).add_kwargs(a1="i1", a2="i2").set_kwargs(a3="i3", a4="i4")
assert provider.kwargs == {"a3": "i3", "a4": "i4"}
def test_clear_kwargs():
provider = providers.ContextLocalResource(init_fn).add_kwargs(a1="i1", a2="i2").clear_kwargs()
assert provider.kwargs == {}
def test_call_overridden():
provider = providers.ContextLocalResource(init_fn, 1)
overriding_provider1 = providers.ContextLocalResource(init_fn, 2)
overriding_provider2 = providers.ContextLocalResource(init_fn, 3)
provider.override(overriding_provider1)
provider.override(overriding_provider2)
instance1 = provider()
instance2 = provider()
assert instance1 is instance2
assert instance1 == ((3,), {})
assert instance2 == ((3,), {})
def test_deepcopy():
provider = providers.ContextLocalResource(init_fn, 1, 2, a3=3, a4=4)
provider_copy = providers.deepcopy(provider)
assert provider is not provider_copy
assert provider.args == provider_copy.args
assert provider.kwargs == provider_copy.kwargs
assert isinstance(provider, providers.ContextLocalResource)
def test_deepcopy_initialized():
provider = providers.ContextLocalResource(init_fn)
provider.init()
with raises(errors.Error):
providers.deepcopy(provider)
def test_deepcopy_from_memo():
provider = providers.ContextLocalResource(init_fn)
provider_copy_memo = providers.ContextLocalResource(init_fn)
provider_copy = providers.deepcopy(
provider,
memo={id(provider): provider_copy_memo},
)
assert provider_copy is provider_copy_memo
def test_deepcopy_args():
provider = providers.ContextLocalResource(init_fn)
dependent_provider1 = providers.Factory(list)
dependent_provider2 = providers.Factory(dict)
provider.add_args(dependent_provider1, dependent_provider2)
provider_copy = providers.deepcopy(provider)
dependent_provider_copy1 = provider_copy.args[0]
dependent_provider_copy2 = provider_copy.args[1]
assert provider.args != provider_copy.args
assert dependent_provider1.cls is dependent_provider_copy1.cls
assert dependent_provider1 is not dependent_provider_copy1
assert dependent_provider2.cls is dependent_provider_copy2.cls
assert dependent_provider2 is not dependent_provider_copy2
def test_deepcopy_kwargs():
provider = providers.ContextLocalResource(init_fn)
dependent_provider1 = providers.Factory(list)
dependent_provider2 = providers.Factory(dict)
provider.add_kwargs(d1=dependent_provider1, d2=dependent_provider2)
provider_copy = providers.deepcopy(provider)
dependent_provider_copy1 = provider_copy.kwargs["d1"]
dependent_provider_copy2 = provider_copy.kwargs["d2"]
assert provider.kwargs != provider_copy.kwargs
assert dependent_provider1.cls is dependent_provider_copy1.cls
assert dependent_provider1 is not dependent_provider_copy1
assert dependent_provider2.cls is dependent_provider_copy2.cls
assert dependent_provider2 is not dependent_provider_copy2
def test_deepcopy_overridden():
provider = providers.ContextLocalResource(init_fn)
object_provider = providers.Object(object())
provider.override(object_provider)
provider_copy = providers.deepcopy(provider)
object_provider_copy = provider_copy.overridden[0]
assert provider is not provider_copy
assert provider.args == provider_copy.args
assert isinstance(provider, providers.ContextLocalResource)
assert object_provider is not object_provider_copy
assert isinstance(object_provider_copy, providers.Object)
def test_deepcopy_with_sys_streams():
provider = providers.ContextLocalResource(init_fn)
provider.add_args(sys.stdin, sys.stdout, sys.stderr)
provider_copy = providers.deepcopy(provider)
assert provider is not provider_copy
assert isinstance(provider_copy, providers.ContextLocalResource)
assert provider.args[0] is sys.stdin
assert provider.args[1] is sys.stdout
assert provider.args[2] is sys.stderr
def test_repr():
provider = providers.ContextLocalResource(init_fn)
assert repr(provider) == (
"<dependency_injector.providers.ContextLocalResource({0}) at {1}>".format(
repr(init_fn),
hex(id(provider)),
)
)

View File

@ -18,6 +18,7 @@ class TestResource:
resource1 = TestResource() resource1 = TestResource()
resource2 = TestResource() resource2 = TestResource()
resource3 = TestResource()
async def async_resource(resource): async def async_resource(resource):
@ -34,6 +35,8 @@ class Container(containers.DeclarativeContainer):
resource1 = providers.Resource(async_resource, providers.Object(resource1)) resource1 = providers.Resource(async_resource, providers.Object(resource1))
resource2 = providers.Resource(async_resource, providers.Object(resource2)) resource2 = providers.Resource(async_resource, providers.Object(resource2))
context_local_resource = providers.ContextLocalResource(async_resource, providers.Object(resource3))
context_local_resource_with_factory_object = providers.ContextLocalResource(async_resource, providers.Factory(TestResource))
@inject @inject
@ -57,5 +60,13 @@ async def async_generator_injection(
async def async_injection_with_closing( async def async_injection_with_closing(
resource1: object = Closing[Provide[Container.resource1]], resource1: object = Closing[Provide[Container.resource1]],
resource2: object = Closing[Provide[Container.resource2]], resource2: object = Closing[Provide[Container.resource2]],
context_local_resource: object = Closing[Provide[Container.context_local_resource]],
): ):
return resource1, resource2 return resource1, resource2, context_local_resource
@inject
async def async_injection_with_closing_context_local_resources(
context_local_resource1: object = Closing[Provide[Container.context_local_resource_with_factory_object]],
):
return context_local_resource1

View File

@ -16,6 +16,7 @@ class TestResource:
resource1 = TestResource() resource1 = TestResource()
resource2 = TestResource() resource2 = TestResource()
resource3 = TestResource()
async def async_resource(resource): async def async_resource(resource):
@ -32,6 +33,8 @@ class Container(containers.DeclarativeContainer):
resource1 = providers.Resource(async_resource, providers.Object(resource1)) resource1 = providers.Resource(async_resource, providers.Object(resource1))
resource2 = providers.Resource(async_resource, providers.Object(resource2)) resource2 = providers.Resource(async_resource, providers.Object(resource2))
context_local_resource = providers.ContextLocalResource(async_resource, providers.Object(resource3))
context_local_resource_with_factory_object = providers.ContextLocalResource(async_resource, providers.Factory(TestResource))
@inject @inject
@ -46,5 +49,13 @@ async def async_injection(
async def async_injection_with_closing( async def async_injection_with_closing(
resource1: object = Closing[Provide["resource1"]], resource1: object = Closing[Provide["resource1"]],
resource2: object = Closing[Provide["resource2"]], resource2: object = Closing[Provide["resource2"]],
context_local_resource: object = Closing[Provide["context_local_resource"]],
): ):
return resource1, resource2 return resource1, resource2, context_local_resource
@inject
async def async_injection_with_closing_context_local_resources(
context_local_resource1: object = Closing[Provide["context_local_resource_with_factory_object"]]
):
return context_local_resource1

View File

@ -1,7 +1,8 @@
"""Async injection tests.""" """Async injection tests."""
from pytest import fixture, mark import asyncio
from pytest import fixture, mark
from samples.wiring import asyncinjections from samples.wiring import asyncinjections
@ -51,7 +52,7 @@ async def test_async_generator_injections() -> None:
@mark.asyncio @mark.asyncio
async def test_async_injections_with_closing(): async def test_async_injections_with_closing():
resource1, resource2 = await asyncinjections.async_injection_with_closing() resource1, resource2, context_local_resource = await asyncinjections.async_injection_with_closing()
assert resource1 is asyncinjections.resource1 assert resource1 is asyncinjections.resource1
assert asyncinjections.resource1.init_counter == 1 assert asyncinjections.resource1.init_counter == 1
@ -61,7 +62,11 @@ async def test_async_injections_with_closing():
assert asyncinjections.resource2.init_counter == 1 assert asyncinjections.resource2.init_counter == 1
assert asyncinjections.resource2.shutdown_counter == 1 assert asyncinjections.resource2.shutdown_counter == 1
resource1, resource2 = await asyncinjections.async_injection_with_closing() assert context_local_resource is asyncinjections.resource3
assert asyncinjections.resource3.init_counter == 1
assert asyncinjections.resource3.shutdown_counter == 1
resource1, resource2, context_local_resource = await asyncinjections.async_injection_with_closing()
assert resource1 is asyncinjections.resource1 assert resource1 is asyncinjections.resource1
assert asyncinjections.resource1.init_counter == 2 assert asyncinjections.resource1.init_counter == 2
@ -70,3 +75,19 @@ async def test_async_injections_with_closing():
assert resource2 is asyncinjections.resource2 assert resource2 is asyncinjections.resource2
assert asyncinjections.resource2.init_counter == 2 assert asyncinjections.resource2.init_counter == 2
assert asyncinjections.resource2.shutdown_counter == 2 assert asyncinjections.resource2.shutdown_counter == 2
assert context_local_resource is asyncinjections.resource3
assert asyncinjections.resource3.init_counter == 2
assert asyncinjections.resource3.shutdown_counter == 2
@mark.asyncio
async def test_async_injections_with_closing_concurrently():
resource1, resource2 = await asyncio.gather(asyncinjections.async_injection_with_closing_context_local_resources(),
asyncinjections.async_injection_with_closing_context_local_resources())
assert resource1 != resource2
resource1 = await asyncinjections.Container.context_local_resource_with_factory_object()
resource2 = await asyncinjections.Container.context_local_resource_with_factory_object()
assert resource1 == resource2

View File

@ -1,7 +1,8 @@
"""Async injection tests.""" """Async injection tests."""
from pytest import fixture, mark import asyncio
from pytest import fixture, mark
from samples.wiringstringids import asyncinjections from samples.wiringstringids import asyncinjections
@ -34,7 +35,7 @@ async def test_async_injections():
@mark.asyncio @mark.asyncio
async def test_async_injections_with_closing(): async def test_async_injections_with_closing():
resource1, resource2 = await asyncinjections.async_injection_with_closing() resource1, resource2, context_local_resource = await asyncinjections.async_injection_with_closing()
assert resource1 is asyncinjections.resource1 assert resource1 is asyncinjections.resource1
assert asyncinjections.resource1.init_counter == 1 assert asyncinjections.resource1.init_counter == 1
@ -44,7 +45,11 @@ async def test_async_injections_with_closing():
assert asyncinjections.resource2.init_counter == 1 assert asyncinjections.resource2.init_counter == 1
assert asyncinjections.resource2.shutdown_counter == 1 assert asyncinjections.resource2.shutdown_counter == 1
resource1, resource2 = await asyncinjections.async_injection_with_closing() assert context_local_resource is asyncinjections.resource3
assert asyncinjections.resource3.init_counter == 1
assert asyncinjections.resource3.shutdown_counter == 1
resource1, resource2, context_local_resource = await asyncinjections.async_injection_with_closing()
assert resource1 is asyncinjections.resource1 assert resource1 is asyncinjections.resource1
assert asyncinjections.resource1.init_counter == 2 assert asyncinjections.resource1.init_counter == 2
@ -53,3 +58,19 @@ async def test_async_injections_with_closing():
assert resource2 is asyncinjections.resource2 assert resource2 is asyncinjections.resource2
assert asyncinjections.resource2.init_counter == 2 assert asyncinjections.resource2.init_counter == 2
assert asyncinjections.resource2.shutdown_counter == 2 assert asyncinjections.resource2.shutdown_counter == 2
assert context_local_resource is asyncinjections.resource3
assert asyncinjections.resource3.init_counter == 2
assert asyncinjections.resource3.shutdown_counter == 2
@mark.asyncio
async def test_async_injections_with_closing_concurrently():
resource1, resource2 = await asyncio.gather(asyncinjections.async_injection_with_closing_context_local_resources(),
asyncinjections.async_injection_with_closing_context_local_resources())
assert resource1 != resource2
resource1 = await asyncinjections.Container.context_local_resource_with_factory_object()
resource2 = await asyncinjections.Container.context_local_resource_with_factory_object()
assert resource1 == resource2