Add context manager support to Resource provider

This commit is contained in:
ZipFile 2025-06-01 14:03:41 +00:00
parent 4b3476cb48
commit a322584308
6 changed files with 260 additions and 191 deletions

View File

@ -61,11 +61,12 @@ When you call ``.shutdown()`` method on a resource provider, it will remove the
if any, and switch to uninitialized state. Some of resource initializer types support specifying custom if any, and switch to uninitialized state. Some of resource initializer types support specifying custom
resource shutdown. resource shutdown.
Resource provider supports 3 types of initializers: Resource provider supports 4 types of initializers:
- Function - Function
- Generator - Context Manager
- Subclass of ``resources.Resource`` - Generator (legacy)
- Subclass of ``resources.Resource`` (legacy)
Function initializer Function initializer
-------------------- --------------------
@ -103,8 +104,44 @@ you configure global resource:
Function initializer does not provide a way to specify custom resource shutdown. Function initializer does not provide a way to specify custom resource shutdown.
Generator initializer Context Manager initializer
--------------------- ---------------------------
This is an extension to the Function initializer. Resource provider automatically detects if the initializer returns a
context manager and uses it to manage the resource lifecycle.
.. code-block:: python
from dependency_injector import containers, providers
class DatabaseConnection:
def __init__(self, host, port, user, password):
self.host = host
self.port = port
self.user = user
self.password = password
def __enter__(self):
print(f"Connecting to {self.host}:{self.port} as {self.user}")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
print("Closing connection")
class Container(containers.DeclarativeContainer):
config = providers.Configuration()
db = providers.Resource(
DatabaseConnection,
host=config.db.host,
port=config.db.port,
user=config.db.user,
password=config.db.password,
)
Generator initializer (legacy)
------------------------------
Resource provider can use 2-step generators: Resource provider can use 2-step generators:
@ -154,8 +191,13 @@ object is not mandatory. You can leave ``yield`` statement empty:
argument2=..., argument2=...,
) )
Subclass initializer .. note::
--------------------
Generator initializers are automatically wrapped with ``contextmanager`` or ``asynccontextmanager`` decorator when
provided to a ``Resource`` provider.
Subclass initializer (legacy)
-----------------------------
You can create resource initializer by implementing a subclass of the ``resources.Resource``: You can create resource initializer by implementing a subclass of the ``resources.Resource``:
@ -263,10 +305,11 @@ Asynchronous function initializer:
argument2=..., argument2=...,
) )
Asynchronous generator initializer: Asynchronous Context Manager initializer:
.. code-block:: python .. code-block:: python
@asynccontextmanager
async def init_async_resource(argument1=..., argument2=...): async def init_async_resource(argument1=..., argument2=...):
connection = await connect() connection = await connect()
yield connection yield connection

View File

@ -3,10 +3,12 @@
import sys import sys
import logging import logging
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dependency_injector import containers, providers from dependency_injector import containers, providers
@contextmanager
def init_thread_pool(max_workers: int): def init_thread_pool(max_workers: int):
thread_pool = ThreadPoolExecutor(max_workers=max_workers) thread_pool = ThreadPoolExecutor(max_workers=max_workers)
yield thread_pool yield thread_pool

View File

@ -15,8 +15,11 @@ import re
import sys import sys
import threading import threading
import warnings import warnings
from asyncio import ensure_future
from configparser import ConfigParser as IniConfigParser from configparser import ConfigParser as IniConfigParser
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar from contextvars import ContextVar
from inspect import isasyncgenfunction, isgeneratorfunction
try: try:
from inspect import _is_coroutine_mark as _is_coroutine_marker from inspect import _is_coroutine_mark as _is_coroutine_marker
@ -3598,6 +3601,17 @@ cdef class Dict(Provider):
return __provide_keyword_args(kwargs, self._kwargs, self._kwargs_len, self._async_mode) return __provide_keyword_args(kwargs, self._kwargs, self._kwargs_len, self._async_mode)
@cython.no_gc
cdef class NullAwaitable:
def __next__(self):
raise StopIteration from None
def __await__(self):
return self
cdef NullAwaitable NULL_AWAITABLE = NullAwaitable()
cdef class Resource(Provider): cdef class Resource(Provider):
"""Resource provider provides a component with initialization and shutdown.""" """Resource provider provides a component with initialization and shutdown."""
@ -3653,6 +3667,12 @@ cdef class Resource(Provider):
def set_provides(self, provides): def set_provides(self, provides):
"""Set provider provides.""" """Set provider provides."""
provides = _resolve_string_import(provides) provides = _resolve_string_import(provides)
if isasyncgenfunction(provides):
provides = asynccontextmanager(provides)
elif isgeneratorfunction(provides):
provides = contextmanager(provides)
self._provides = provides self._provides = provides
return self return self
@ -3753,28 +3773,21 @@ cdef class Resource(Provider):
"""Shutdown resource.""" """Shutdown resource."""
if not self._initialized: if not self._initialized:
if self._async_mode == ASYNC_MODE_ENABLED: if self._async_mode == ASYNC_MODE_ENABLED:
result = asyncio.Future() return NULL_AWAITABLE
result.set_result(None)
return result
return return
if self._shutdowner: if self._shutdowner:
try: future = self._shutdowner(None, None, None)
shutdown = self._shutdowner(self._resource)
except StopIteration: if __is_future_or_coroutine(future):
pass return ensure_future(self._shutdown_async(future))
else:
if inspect.isawaitable(shutdown):
return self._create_shutdown_future(shutdown)
self._resource = None self._resource = None
self._initialized = False self._initialized = False
self._shutdowner = None self._shutdowner = None
if self._async_mode == ASYNC_MODE_ENABLED: if self._async_mode == ASYNC_MODE_ENABLED:
result = asyncio.Future() return NULL_AWAITABLE
result.set_result(None)
return result
@property @property
def related(self): def related(self):
@ -3784,164 +3797,74 @@ cdef class Resource(Provider):
yield from filter(is_provider, self.kwargs.values()) yield from filter(is_provider, self.kwargs.values())
yield from super().related yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): async def _shutdown_async(self, future) -> None:
if self._initialized:
return self._resource
if self._is_resource_subclass(self._provides):
initializer = self._provides()
self._resource = __call(
initializer.init,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
self._shutdowner = initializer.shutdown
elif self._is_async_resource_subclass(self._provides):
initializer = self._provides()
async_init = __call(
initializer.init,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
self._initialized = True
return self._create_init_future(async_init, initializer.shutdown)
elif inspect.isgeneratorfunction(self._provides):
initializer = __call(
self._provides,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
self._resource = next(initializer)
self._shutdowner = initializer.send
elif iscoroutinefunction(self._provides):
initializer = __call(
self._provides,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
self._initialized = True
return self._create_init_future(initializer)
elif isasyncgenfunction(self._provides):
initializer = __call(
self._provides,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
self._initialized = True
return self._create_async_gen_init_future(initializer)
elif callable(self._provides):
self._resource = __call(
self._provides,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
else:
raise Error("Unknown type of resource initializer")
self._initialized = True
return self._resource
def _create_init_future(self, future, shutdowner=None):
callback = self._async_init_callback
if shutdowner:
callback = functools.partial(callback, shutdowner=shutdowner)
future = asyncio.ensure_future(future)
future.add_done_callback(callback)
self._resource = future
return future
def _create_async_gen_init_future(self, initializer):
if inspect.isasyncgen(initializer):
return self._create_init_future(initializer.__anext__(), initializer.asend)
future = asyncio.Future()
create_initializer = asyncio.ensure_future(initializer)
create_initializer.add_done_callback(functools.partial(self._async_create_gen_callback, future))
self._resource = future
return future
def _async_init_callback(self, initializer, shutdowner=None):
try: try:
resource = initializer.result() await future
except Exception: finally:
self._initialized = False
else:
self._resource = resource
self._shutdowner = shutdowner
def _async_create_gen_callback(self, future, initializer_future):
initializer = initializer_future.result()
init_future = self._create_init_future(initializer.__anext__(), initializer.asend)
init_future.add_done_callback(functools.partial(self._async_trigger_result, future))
def _async_trigger_result(self, future, future_result):
future.set_result(future_result.result())
def _create_shutdown_future(self, shutdown_future):
future = asyncio.Future()
shutdown_future = asyncio.ensure_future(shutdown_future)
shutdown_future.add_done_callback(functools.partial(self._async_shutdown_callback, future))
return future
def _async_shutdown_callback(self, future_result, shutdowner):
try:
shutdowner.result()
except StopAsyncIteration:
pass
self._resource = None self._resource = None
self._initialized = False self._initialized = False
self._shutdowner = None self._shutdowner = None
future_result.set_result(None) async def _handle_async_cm(self, obj) -> None:
try:
self._resource = resource = await obj.__aenter__()
self._shutdowner = obj.__aexit__
return resource
except:
self._initialized = False
raise
@staticmethod async def _provide_async(self, future) -> None:
def _is_resource_subclass(instance): try:
if not isinstance(instance, type): obj = await future
return
from . import resources
return issubclass(instance, resources.Resource)
@staticmethod if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
def _is_async_resource_subclass(instance): self._resource = await obj.__aenter__()
if not isinstance(instance, type): self._shutdowner = obj.__aexit__
return elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
from . import resources self._resource = obj.__enter__()
return issubclass(instance, resources.AsyncResource) self._shutdowner = obj.__exit__
else:
self._resource = obj
self._shutdowner = None
return self._resource
except:
self._initialized = False
raise
cpdef object _provide(self, tuple args, dict kwargs):
if self._initialized:
return self._resource
obj = __call(
self._provides,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
if __is_future_or_coroutine(obj):
self._initialized = True
self._resource = resource = ensure_future(self._provide_async(obj))
return resource
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
self._resource = obj.__enter__()
self._shutdowner = obj.__exit__
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
self._initialized = True
self._resource = resource = ensure_future(self._handle_async_cm(obj))
return resource
else:
self._resource = obj
self._shutdowner = None
self._initialized = True
return self._resource
cdef class Container(Provider): cdef class Container(Provider):
@ -4993,14 +4916,6 @@ def iscoroutinefunction(obj):
return False return False
def isasyncgenfunction(obj):
"""Check if object is an asynchronous generator function."""
try:
return inspect.isasyncgenfunction(obj)
except AttributeError:
return False
def _resolve_string_import(provides): def _resolve_string_import(provides):
if provides is None: if provides is None:
return provides return provides

View File

@ -1,23 +1,54 @@
"""Resources module.""" """Resources module."""
import abc from abc import ABCMeta, abstractmethod
from typing import TypeVar, Generic, Optional from typing import Any, ClassVar, Generic, Optional, Tuple, TypeVar
T = TypeVar("T") T = TypeVar("T")
class Resource(Generic[T], metaclass=abc.ABCMeta): class Resource(Generic[T], metaclass=ABCMeta):
__slots__: ClassVar[Tuple[str, ...]] = ("args", "kwargs", "obj")
@abc.abstractmethod obj: Optional[T]
def init(self, *args, **kwargs) -> Optional[T]: ...
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.args = args
self.kwargs = kwargs
self.obj = None
@abstractmethod
def init(self, *args: Any, **kwargs: Any) -> Optional[T]: ...
def shutdown(self, resource: Optional[T]) -> None: ... def shutdown(self, resource: Optional[T]) -> None: ...
def __enter__(self) -> Optional[T]:
self.obj = obj = self.init(*self.args, **self.kwargs)
return obj
class AsyncResource(Generic[T], metaclass=abc.ABCMeta): def __exit__(self, *exc_info: Any) -> None:
self.shutdown(self.obj)
self.obj = None
@abc.abstractmethod
async def init(self, *args, **kwargs) -> Optional[T]: ... class AsyncResource(Generic[T], metaclass=ABCMeta):
__slots__: ClassVar[Tuple[str, ...]] = ("args", "kwargs", "obj")
obj: Optional[T]
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.args = args
self.kwargs = kwargs
self.obj = None
@abstractmethod
async def init(self, *args: Any, **kwargs: Any) -> Optional[T]: ...
async def shutdown(self, resource: Optional[T]) -> None: ... async def shutdown(self, resource: Optional[T]) -> None: ...
async def __aenter__(self) -> Optional[T]:
self.obj = obj = await self.init(*self.args, **self.kwargs)
return obj
async def __aexit__(self, *exc_info: Any) -> None:
await self.shutdown(self.obj)
self.obj = None

View File

@ -2,12 +2,13 @@
import asyncio import asyncio
import inspect import inspect
import sys from contextlib import asynccontextmanager
from typing import Any from typing import Any
from dependency_injector import containers, providers, resources
from pytest import mark, raises from pytest import mark, raises
from dependency_injector import containers, providers, resources
@mark.asyncio @mark.asyncio
async def test_init_async_function(): async def test_init_async_function():
@ -70,6 +71,46 @@ async def test_init_async_generator():
assert _init.shutdown_counter == 2 assert _init.shutdown_counter == 2
@mark.asyncio
async def test_init_async_context_manager() -> None:
resource = object()
init_counter = 0
shutdown_counter = 0
@asynccontextmanager
async def _init():
nonlocal init_counter, shutdown_counter
await asyncio.sleep(0.001)
init_counter += 1
yield resource
await asyncio.sleep(0.001)
shutdown_counter += 1
provider = providers.Resource(_init)
result1 = await provider()
assert result1 is resource
assert init_counter == 1
assert shutdown_counter == 0
await provider.shutdown()
assert init_counter == 1
assert shutdown_counter == 1
result2 = await provider()
assert result2 is resource
assert init_counter == 2
assert shutdown_counter == 1
await provider.shutdown()
assert init_counter == 2
assert shutdown_counter == 2
@mark.asyncio @mark.asyncio
async def test_init_async_class(): async def test_init_async_class():
resource = object() resource = object()

View File

@ -2,10 +2,12 @@
import decimal import decimal
import sys import sys
from contextlib import contextmanager
from typing import Any from typing import Any
from dependency_injector import containers, providers, resources, errors from pytest import mark, raises
from pytest import raises, mark
from dependency_injector import containers, errors, providers, resources
def init_fn(*args, **kwargs): def init_fn(*args, **kwargs):
@ -123,6 +125,41 @@ def test_init_generator():
assert _init.shutdown_counter == 2 assert _init.shutdown_counter == 2
def test_init_context_manager() -> None:
init_counter, shutdown_counter = 0, 0
@contextmanager
def _init():
nonlocal init_counter, shutdown_counter
init_counter += 1
yield
shutdown_counter += 1
init_counter = 0
shutdown_counter = 0
provider = providers.Resource(_init)
result1 = provider()
assert result1 is None
assert init_counter == 1
assert shutdown_counter == 0
provider.shutdown()
assert init_counter == 1
assert shutdown_counter == 1
result2 = provider()
assert result2 is None
assert init_counter == 2
assert shutdown_counter == 1
provider.shutdown()
assert init_counter == 2
assert shutdown_counter == 2
def test_init_class(): def test_init_class():
class TestResource(resources.Resource): class TestResource(resources.Resource):
init_counter = 0 init_counter = 0
@ -190,7 +227,7 @@ def test_init_class_abc_shutdown_definition_is_not_required():
def test_init_not_callable(): def test_init_not_callable():
provider = providers.Resource(1) provider = providers.Resource(1)
with raises(errors.Error): with raises(TypeError, match=r"object is not callable"):
provider.init() provider.init()