Add context manager support to Resource provider

This commit is contained in:
ZipFile 2025-06-01 14:03:41 +00:00
parent 4b3476cb48
commit a322584308
6 changed files with 260 additions and 191 deletions

View File

@ -61,11 +61,12 @@ When you call ``.shutdown()`` method on a resource provider, it will remove the
if any, and switch to uninitialized state. Some of resource initializer types support specifying custom
resource shutdown.
Resource provider supports 3 types of initializers:
Resource provider supports 4 types of initializers:
- Function
- Generator
- Subclass of ``resources.Resource``
- Context Manager
- Generator (legacy)
- Subclass of ``resources.Resource`` (legacy)
Function initializer
--------------------
@ -103,8 +104,44 @@ you configure global resource:
Function initializer does not provide a way to specify custom resource shutdown.
Generator initializer
---------------------
Context Manager initializer
---------------------------
This is an extension to the Function initializer. Resource provider automatically detects if the initializer returns a
context manager and uses it to manage the resource lifecycle.
.. code-block:: python
from dependency_injector import containers, providers
class DatabaseConnection:
def __init__(self, host, port, user, password):
self.host = host
self.port = port
self.user = user
self.password = password
def __enter__(self):
print(f"Connecting to {self.host}:{self.port} as {self.user}")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
print("Closing connection")
class Container(containers.DeclarativeContainer):
config = providers.Configuration()
db = providers.Resource(
DatabaseConnection,
host=config.db.host,
port=config.db.port,
user=config.db.user,
password=config.db.password,
)
Generator initializer (legacy)
------------------------------
Resource provider can use 2-step generators:
@ -154,8 +191,13 @@ object is not mandatory. You can leave ``yield`` statement empty:
argument2=...,
)
Subclass initializer
--------------------
.. note::
Generator initializers are automatically wrapped with ``contextmanager`` or ``asynccontextmanager`` decorator when
provided to a ``Resource`` provider.
Subclass initializer (legacy)
-----------------------------
You can create resource initializer by implementing a subclass of the ``resources.Resource``:
@ -263,10 +305,11 @@ Asynchronous function initializer:
argument2=...,
)
Asynchronous generator initializer:
Asynchronous Context Manager initializer:
.. code-block:: python
@asynccontextmanager
async def init_async_resource(argument1=..., argument2=...):
connection = await connect()
yield connection

View File

@ -3,10 +3,12 @@
import sys
import logging
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dependency_injector import containers, providers
@contextmanager
def init_thread_pool(max_workers: int):
thread_pool = ThreadPoolExecutor(max_workers=max_workers)
yield thread_pool

View File

@ -15,8 +15,11 @@ import re
import sys
import threading
import warnings
from asyncio import ensure_future
from configparser import ConfigParser as IniConfigParser
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from inspect import isasyncgenfunction, isgeneratorfunction
try:
from inspect import _is_coroutine_mark as _is_coroutine_marker
@ -3598,6 +3601,17 @@ cdef class Dict(Provider):
return __provide_keyword_args(kwargs, self._kwargs, self._kwargs_len, self._async_mode)
@cython.no_gc
cdef class NullAwaitable:
def __next__(self):
raise StopIteration from None
def __await__(self):
return self
cdef NullAwaitable NULL_AWAITABLE = NullAwaitable()
cdef class Resource(Provider):
"""Resource provider provides a component with initialization and shutdown."""
@ -3653,6 +3667,12 @@ cdef class Resource(Provider):
def set_provides(self, provides):
"""Set provider provides."""
provides = _resolve_string_import(provides)
if isasyncgenfunction(provides):
provides = asynccontextmanager(provides)
elif isgeneratorfunction(provides):
provides = contextmanager(provides)
self._provides = provides
return self
@ -3753,28 +3773,21 @@ cdef class Resource(Provider):
"""Shutdown resource."""
if not self._initialized:
if self._async_mode == ASYNC_MODE_ENABLED:
result = asyncio.Future()
result.set_result(None)
return result
return NULL_AWAITABLE
return
if self._shutdowner:
try:
shutdown = self._shutdowner(self._resource)
except StopIteration:
pass
else:
if inspect.isawaitable(shutdown):
return self._create_shutdown_future(shutdown)
future = self._shutdowner(None, None, None)
if __is_future_or_coroutine(future):
return ensure_future(self._shutdown_async(future))
self._resource = None
self._initialized = False
self._shutdowner = None
if self._async_mode == ASYNC_MODE_ENABLED:
result = asyncio.Future()
result.set_result(None)
return result
return NULL_AWAITABLE
@property
def related(self):
@ -3784,165 +3797,75 @@ cdef class Resource(Provider):
yield from filter(is_provider, self.kwargs.values())
yield from super().related
async def _shutdown_async(self, future) -> None:
try:
await future
finally:
self._resource = None
self._initialized = False
self._shutdowner = None
async def _handle_async_cm(self, obj) -> None:
try:
self._resource = resource = await obj.__aenter__()
self._shutdowner = obj.__aexit__
return resource
except:
self._initialized = False
raise
async def _provide_async(self, future) -> None:
try:
obj = await future
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
self._resource = await obj.__aenter__()
self._shutdowner = obj.__aexit__
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
self._resource = obj.__enter__()
self._shutdowner = obj.__exit__
else:
self._resource = obj
self._shutdowner = None
return self._resource
except:
self._initialized = False
raise
cpdef object _provide(self, tuple args, dict kwargs):
if self._initialized:
return self._resource
if self._is_resource_subclass(self._provides):
initializer = self._provides()
self._resource = __call(
initializer.init,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
self._shutdowner = initializer.shutdown
elif self._is_async_resource_subclass(self._provides):
initializer = self._provides()
async_init = __call(
initializer.init,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
obj = __call(
self._provides,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
if __is_future_or_coroutine(obj):
self._initialized = True
return self._create_init_future(async_init, initializer.shutdown)
elif inspect.isgeneratorfunction(self._provides):
initializer = __call(
self._provides,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
self._resource = next(initializer)
self._shutdowner = initializer.send
elif iscoroutinefunction(self._provides):
initializer = __call(
self._provides,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
self._resource = resource = ensure_future(self._provide_async(obj))
return resource
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
self._resource = obj.__enter__()
self._shutdowner = obj.__exit__
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
self._initialized = True
return self._create_init_future(initializer)
elif isasyncgenfunction(self._provides):
initializer = __call(
self._provides,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
self._initialized = True
return self._create_async_gen_init_future(initializer)
elif callable(self._provides):
self._resource = __call(
self._provides,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
self._resource = resource = ensure_future(self._handle_async_cm(obj))
return resource
else:
raise Error("Unknown type of resource initializer")
self._resource = obj
self._shutdowner = None
self._initialized = True
return self._resource
def _create_init_future(self, future, shutdowner=None):
callback = self._async_init_callback
if shutdowner:
callback = functools.partial(callback, shutdowner=shutdowner)
future = asyncio.ensure_future(future)
future.add_done_callback(callback)
self._resource = future
return future
def _create_async_gen_init_future(self, initializer):
if inspect.isasyncgen(initializer):
return self._create_init_future(initializer.__anext__(), initializer.asend)
future = asyncio.Future()
create_initializer = asyncio.ensure_future(initializer)
create_initializer.add_done_callback(functools.partial(self._async_create_gen_callback, future))
self._resource = future
return future
def _async_init_callback(self, initializer, shutdowner=None):
try:
resource = initializer.result()
except Exception:
self._initialized = False
else:
self._resource = resource
self._shutdowner = shutdowner
def _async_create_gen_callback(self, future, initializer_future):
initializer = initializer_future.result()
init_future = self._create_init_future(initializer.__anext__(), initializer.asend)
init_future.add_done_callback(functools.partial(self._async_trigger_result, future))
def _async_trigger_result(self, future, future_result):
future.set_result(future_result.result())
def _create_shutdown_future(self, shutdown_future):
future = asyncio.Future()
shutdown_future = asyncio.ensure_future(shutdown_future)
shutdown_future.add_done_callback(functools.partial(self._async_shutdown_callback, future))
return future
def _async_shutdown_callback(self, future_result, shutdowner):
try:
shutdowner.result()
except StopAsyncIteration:
pass
self._resource = None
self._initialized = False
self._shutdowner = None
future_result.set_result(None)
@staticmethod
def _is_resource_subclass(instance):
if not isinstance(instance, type):
return
from . import resources
return issubclass(instance, resources.Resource)
@staticmethod
def _is_async_resource_subclass(instance):
if not isinstance(instance, type):
return
from . import resources
return issubclass(instance, resources.AsyncResource)
cdef class Container(Provider):
"""Container provider provides an instance of declarative container.
@ -4993,14 +4916,6 @@ def iscoroutinefunction(obj):
return False
def isasyncgenfunction(obj):
"""Check if object is an asynchronous generator function."""
try:
return inspect.isasyncgenfunction(obj)
except AttributeError:
return False
def _resolve_string_import(provides):
if provides is None:
return provides

View File

@ -1,23 +1,54 @@
"""Resources module."""
import abc
from typing import TypeVar, Generic, Optional
from abc import ABCMeta, abstractmethod
from typing import Any, ClassVar, Generic, Optional, Tuple, TypeVar
T = TypeVar("T")
class Resource(Generic[T], metaclass=abc.ABCMeta):
class Resource(Generic[T], metaclass=ABCMeta):
__slots__: ClassVar[Tuple[str, ...]] = ("args", "kwargs", "obj")
@abc.abstractmethod
def init(self, *args, **kwargs) -> Optional[T]: ...
obj: Optional[T]
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.args = args
self.kwargs = kwargs
self.obj = None
@abstractmethod
def init(self, *args: Any, **kwargs: Any) -> Optional[T]: ...
def shutdown(self, resource: Optional[T]) -> None: ...
def __enter__(self) -> Optional[T]:
self.obj = obj = self.init(*self.args, **self.kwargs)
return obj
class AsyncResource(Generic[T], metaclass=abc.ABCMeta):
def __exit__(self, *exc_info: Any) -> None:
self.shutdown(self.obj)
self.obj = None
@abc.abstractmethod
async def init(self, *args, **kwargs) -> Optional[T]: ...
class AsyncResource(Generic[T], metaclass=ABCMeta):
__slots__: ClassVar[Tuple[str, ...]] = ("args", "kwargs", "obj")
obj: Optional[T]
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.args = args
self.kwargs = kwargs
self.obj = None
@abstractmethod
async def init(self, *args: Any, **kwargs: Any) -> Optional[T]: ...
async def shutdown(self, resource: Optional[T]) -> None: ...
async def __aenter__(self) -> Optional[T]:
self.obj = obj = await self.init(*self.args, **self.kwargs)
return obj
async def __aexit__(self, *exc_info: Any) -> None:
await self.shutdown(self.obj)
self.obj = None

View File

@ -2,12 +2,13 @@
import asyncio
import inspect
import sys
from contextlib import asynccontextmanager
from typing import Any
from dependency_injector import containers, providers, resources
from pytest import mark, raises
from dependency_injector import containers, providers, resources
@mark.asyncio
async def test_init_async_function():
@ -70,6 +71,46 @@ async def test_init_async_generator():
assert _init.shutdown_counter == 2
@mark.asyncio
async def test_init_async_context_manager() -> None:
resource = object()
init_counter = 0
shutdown_counter = 0
@asynccontextmanager
async def _init():
nonlocal init_counter, shutdown_counter
await asyncio.sleep(0.001)
init_counter += 1
yield resource
await asyncio.sleep(0.001)
shutdown_counter += 1
provider = providers.Resource(_init)
result1 = await provider()
assert result1 is resource
assert init_counter == 1
assert shutdown_counter == 0
await provider.shutdown()
assert init_counter == 1
assert shutdown_counter == 1
result2 = await provider()
assert result2 is resource
assert init_counter == 2
assert shutdown_counter == 1
await provider.shutdown()
assert init_counter == 2
assert shutdown_counter == 2
@mark.asyncio
async def test_init_async_class():
resource = object()

View File

@ -2,10 +2,12 @@
import decimal
import sys
from contextlib import contextmanager
from typing import Any
from dependency_injector import containers, providers, resources, errors
from pytest import raises, mark
from pytest import mark, raises
from dependency_injector import containers, errors, providers, resources
def init_fn(*args, **kwargs):
@ -123,6 +125,41 @@ def test_init_generator():
assert _init.shutdown_counter == 2
def test_init_context_manager() -> None:
init_counter, shutdown_counter = 0, 0
@contextmanager
def _init():
nonlocal init_counter, shutdown_counter
init_counter += 1
yield
shutdown_counter += 1
init_counter = 0
shutdown_counter = 0
provider = providers.Resource(_init)
result1 = provider()
assert result1 is None
assert init_counter == 1
assert shutdown_counter == 0
provider.shutdown()
assert init_counter == 1
assert shutdown_counter == 1
result2 = provider()
assert result2 is None
assert init_counter == 2
assert shutdown_counter == 1
provider.shutdown()
assert init_counter == 2
assert shutdown_counter == 2
def test_init_class():
class TestResource(resources.Resource):
init_counter = 0
@ -190,7 +227,7 @@ def test_init_class_abc_shutdown_definition_is_not_required():
def test_init_not_callable():
provider = providers.Resource(1)
with raises(errors.Error):
with raises(TypeError, match=r"object is not callable"):
provider.init()