mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-06-16 11:33:13 +03:00
Add context manager support to Resource provider
This commit is contained in:
parent
4b3476cb48
commit
a322584308
|
@ -61,11 +61,12 @@ When you call ``.shutdown()`` method on a resource provider, it will remove the
|
|||
if any, and switch to uninitialized state. Some of resource initializer types support specifying custom
|
||||
resource shutdown.
|
||||
|
||||
Resource provider supports 3 types of initializers:
|
||||
Resource provider supports 4 types of initializers:
|
||||
|
||||
- Function
|
||||
- Generator
|
||||
- Subclass of ``resources.Resource``
|
||||
- Context Manager
|
||||
- Generator (legacy)
|
||||
- Subclass of ``resources.Resource`` (legacy)
|
||||
|
||||
Function initializer
|
||||
--------------------
|
||||
|
@ -103,8 +104,44 @@ you configure global resource:
|
|||
|
||||
Function initializer does not provide a way to specify custom resource shutdown.
|
||||
|
||||
Generator initializer
|
||||
---------------------
|
||||
Context Manager initializer
|
||||
---------------------------
|
||||
|
||||
This is an extension to the Function initializer. Resource provider automatically detects if the initializer returns a
|
||||
context manager and uses it to manage the resource lifecycle.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dependency_injector import containers, providers
|
||||
|
||||
class DatabaseConnection:
|
||||
def __init__(self, host, port, user, password):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.user = user
|
||||
self.password = password
|
||||
|
||||
def __enter__(self):
|
||||
print(f"Connecting to {self.host}:{self.port} as {self.user}")
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
print("Closing connection")
|
||||
|
||||
|
||||
class Container(containers.DeclarativeContainer):
|
||||
|
||||
config = providers.Configuration()
|
||||
db = providers.Resource(
|
||||
DatabaseConnection,
|
||||
host=config.db.host,
|
||||
port=config.db.port,
|
||||
user=config.db.user,
|
||||
password=config.db.password,
|
||||
)
|
||||
|
||||
Generator initializer (legacy)
|
||||
------------------------------
|
||||
|
||||
Resource provider can use 2-step generators:
|
||||
|
||||
|
@ -154,8 +191,13 @@ object is not mandatory. You can leave ``yield`` statement empty:
|
|||
argument2=...,
|
||||
)
|
||||
|
||||
Subclass initializer
|
||||
--------------------
|
||||
.. note::
|
||||
|
||||
Generator initializers are automatically wrapped with ``contextmanager`` or ``asynccontextmanager`` decorator when
|
||||
provided to a ``Resource`` provider.
|
||||
|
||||
Subclass initializer (legacy)
|
||||
-----------------------------
|
||||
|
||||
You can create resource initializer by implementing a subclass of the ``resources.Resource``:
|
||||
|
||||
|
@ -263,10 +305,11 @@ Asynchronous function initializer:
|
|||
argument2=...,
|
||||
)
|
||||
|
||||
Asynchronous generator initializer:
|
||||
Asynchronous Context Manager initializer:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@asynccontextmanager
|
||||
async def init_async_resource(argument1=..., argument2=...):
|
||||
connection = await connect()
|
||||
yield connection
|
||||
|
|
|
@ -3,10 +3,12 @@
|
|||
import sys
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
|
||||
from dependency_injector import containers, providers
|
||||
|
||||
|
||||
@contextmanager
|
||||
def init_thread_pool(max_workers: int):
|
||||
thread_pool = ThreadPoolExecutor(max_workers=max_workers)
|
||||
yield thread_pool
|
||||
|
|
|
@ -15,8 +15,11 @@ import re
|
|||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
from asyncio import ensure_future
|
||||
from configparser import ConfigParser as IniConfigParser
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextvars import ContextVar
|
||||
from inspect import isasyncgenfunction, isgeneratorfunction
|
||||
|
||||
try:
|
||||
from inspect import _is_coroutine_mark as _is_coroutine_marker
|
||||
|
@ -3598,6 +3601,17 @@ cdef class Dict(Provider):
|
|||
return __provide_keyword_args(kwargs, self._kwargs, self._kwargs_len, self._async_mode)
|
||||
|
||||
|
||||
@cython.no_gc
|
||||
cdef class NullAwaitable:
|
||||
def __next__(self):
|
||||
raise StopIteration from None
|
||||
|
||||
def __await__(self):
|
||||
return self
|
||||
|
||||
|
||||
cdef NullAwaitable NULL_AWAITABLE = NullAwaitable()
|
||||
|
||||
|
||||
cdef class Resource(Provider):
|
||||
"""Resource provider provides a component with initialization and shutdown."""
|
||||
|
@ -3653,6 +3667,12 @@ cdef class Resource(Provider):
|
|||
def set_provides(self, provides):
|
||||
"""Set provider provides."""
|
||||
provides = _resolve_string_import(provides)
|
||||
|
||||
if isasyncgenfunction(provides):
|
||||
provides = asynccontextmanager(provides)
|
||||
elif isgeneratorfunction(provides):
|
||||
provides = contextmanager(provides)
|
||||
|
||||
self._provides = provides
|
||||
return self
|
||||
|
||||
|
@ -3753,28 +3773,21 @@ cdef class Resource(Provider):
|
|||
"""Shutdown resource."""
|
||||
if not self._initialized:
|
||||
if self._async_mode == ASYNC_MODE_ENABLED:
|
||||
result = asyncio.Future()
|
||||
result.set_result(None)
|
||||
return result
|
||||
return NULL_AWAITABLE
|
||||
return
|
||||
|
||||
if self._shutdowner:
|
||||
try:
|
||||
shutdown = self._shutdowner(self._resource)
|
||||
except StopIteration:
|
||||
pass
|
||||
else:
|
||||
if inspect.isawaitable(shutdown):
|
||||
return self._create_shutdown_future(shutdown)
|
||||
future = self._shutdowner(None, None, None)
|
||||
|
||||
if __is_future_or_coroutine(future):
|
||||
return ensure_future(self._shutdown_async(future))
|
||||
|
||||
self._resource = None
|
||||
self._initialized = False
|
||||
self._shutdowner = None
|
||||
|
||||
if self._async_mode == ASYNC_MODE_ENABLED:
|
||||
result = asyncio.Future()
|
||||
result.set_result(None)
|
||||
return result
|
||||
return NULL_AWAITABLE
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
|
@ -3784,165 +3797,75 @@ cdef class Resource(Provider):
|
|||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from super().related
|
||||
|
||||
async def _shutdown_async(self, future) -> None:
|
||||
try:
|
||||
await future
|
||||
finally:
|
||||
self._resource = None
|
||||
self._initialized = False
|
||||
self._shutdowner = None
|
||||
|
||||
async def _handle_async_cm(self, obj) -> None:
|
||||
try:
|
||||
self._resource = resource = await obj.__aenter__()
|
||||
self._shutdowner = obj.__aexit__
|
||||
return resource
|
||||
except:
|
||||
self._initialized = False
|
||||
raise
|
||||
|
||||
async def _provide_async(self, future) -> None:
|
||||
try:
|
||||
obj = await future
|
||||
|
||||
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
|
||||
self._resource = await obj.__aenter__()
|
||||
self._shutdowner = obj.__aexit__
|
||||
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
|
||||
self._resource = obj.__enter__()
|
||||
self._shutdowner = obj.__exit__
|
||||
else:
|
||||
self._resource = obj
|
||||
self._shutdowner = None
|
||||
|
||||
return self._resource
|
||||
except:
|
||||
self._initialized = False
|
||||
raise
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
if self._initialized:
|
||||
return self._resource
|
||||
|
||||
if self._is_resource_subclass(self._provides):
|
||||
initializer = self._provides()
|
||||
self._resource = __call(
|
||||
initializer.init,
|
||||
args,
|
||||
self._args,
|
||||
self._args_len,
|
||||
kwargs,
|
||||
self._kwargs,
|
||||
self._kwargs_len,
|
||||
self._async_mode,
|
||||
)
|
||||
self._shutdowner = initializer.shutdown
|
||||
elif self._is_async_resource_subclass(self._provides):
|
||||
initializer = self._provides()
|
||||
async_init = __call(
|
||||
initializer.init,
|
||||
args,
|
||||
self._args,
|
||||
self._args_len,
|
||||
kwargs,
|
||||
self._kwargs,
|
||||
self._kwargs_len,
|
||||
self._async_mode,
|
||||
)
|
||||
obj = __call(
|
||||
self._provides,
|
||||
args,
|
||||
self._args,
|
||||
self._args_len,
|
||||
kwargs,
|
||||
self._kwargs,
|
||||
self._kwargs_len,
|
||||
self._async_mode,
|
||||
)
|
||||
|
||||
if __is_future_or_coroutine(obj):
|
||||
self._initialized = True
|
||||
return self._create_init_future(async_init, initializer.shutdown)
|
||||
elif inspect.isgeneratorfunction(self._provides):
|
||||
initializer = __call(
|
||||
self._provides,
|
||||
args,
|
||||
self._args,
|
||||
self._args_len,
|
||||
kwargs,
|
||||
self._kwargs,
|
||||
self._kwargs_len,
|
||||
self._async_mode,
|
||||
)
|
||||
self._resource = next(initializer)
|
||||
self._shutdowner = initializer.send
|
||||
elif iscoroutinefunction(self._provides):
|
||||
initializer = __call(
|
||||
self._provides,
|
||||
args,
|
||||
self._args,
|
||||
self._args_len,
|
||||
kwargs,
|
||||
self._kwargs,
|
||||
self._kwargs_len,
|
||||
self._async_mode,
|
||||
)
|
||||
self._resource = resource = ensure_future(self._provide_async(obj))
|
||||
return resource
|
||||
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
|
||||
self._resource = obj.__enter__()
|
||||
self._shutdowner = obj.__exit__
|
||||
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
|
||||
self._initialized = True
|
||||
return self._create_init_future(initializer)
|
||||
elif isasyncgenfunction(self._provides):
|
||||
initializer = __call(
|
||||
self._provides,
|
||||
args,
|
||||
self._args,
|
||||
self._args_len,
|
||||
kwargs,
|
||||
self._kwargs,
|
||||
self._kwargs_len,
|
||||
self._async_mode,
|
||||
)
|
||||
self._initialized = True
|
||||
return self._create_async_gen_init_future(initializer)
|
||||
elif callable(self._provides):
|
||||
self._resource = __call(
|
||||
self._provides,
|
||||
args,
|
||||
self._args,
|
||||
self._args_len,
|
||||
kwargs,
|
||||
self._kwargs,
|
||||
self._kwargs_len,
|
||||
self._async_mode,
|
||||
)
|
||||
self._resource = resource = ensure_future(self._handle_async_cm(obj))
|
||||
return resource
|
||||
else:
|
||||
raise Error("Unknown type of resource initializer")
|
||||
self._resource = obj
|
||||
self._shutdowner = None
|
||||
|
||||
self._initialized = True
|
||||
return self._resource
|
||||
|
||||
def _create_init_future(self, future, shutdowner=None):
|
||||
callback = self._async_init_callback
|
||||
if shutdowner:
|
||||
callback = functools.partial(callback, shutdowner=shutdowner)
|
||||
|
||||
future = asyncio.ensure_future(future)
|
||||
future.add_done_callback(callback)
|
||||
self._resource = future
|
||||
|
||||
return future
|
||||
|
||||
def _create_async_gen_init_future(self, initializer):
|
||||
if inspect.isasyncgen(initializer):
|
||||
return self._create_init_future(initializer.__anext__(), initializer.asend)
|
||||
|
||||
future = asyncio.Future()
|
||||
|
||||
create_initializer = asyncio.ensure_future(initializer)
|
||||
create_initializer.add_done_callback(functools.partial(self._async_create_gen_callback, future))
|
||||
self._resource = future
|
||||
|
||||
return future
|
||||
|
||||
def _async_init_callback(self, initializer, shutdowner=None):
|
||||
try:
|
||||
resource = initializer.result()
|
||||
except Exception:
|
||||
self._initialized = False
|
||||
else:
|
||||
self._resource = resource
|
||||
self._shutdowner = shutdowner
|
||||
|
||||
def _async_create_gen_callback(self, future, initializer_future):
|
||||
initializer = initializer_future.result()
|
||||
init_future = self._create_init_future(initializer.__anext__(), initializer.asend)
|
||||
init_future.add_done_callback(functools.partial(self._async_trigger_result, future))
|
||||
|
||||
def _async_trigger_result(self, future, future_result):
|
||||
future.set_result(future_result.result())
|
||||
|
||||
def _create_shutdown_future(self, shutdown_future):
|
||||
future = asyncio.Future()
|
||||
shutdown_future = asyncio.ensure_future(shutdown_future)
|
||||
shutdown_future.add_done_callback(functools.partial(self._async_shutdown_callback, future))
|
||||
return future
|
||||
|
||||
def _async_shutdown_callback(self, future_result, shutdowner):
|
||||
try:
|
||||
shutdowner.result()
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
|
||||
self._resource = None
|
||||
self._initialized = False
|
||||
self._shutdowner = None
|
||||
|
||||
future_result.set_result(None)
|
||||
|
||||
@staticmethod
|
||||
def _is_resource_subclass(instance):
|
||||
if not isinstance(instance, type):
|
||||
return
|
||||
from . import resources
|
||||
return issubclass(instance, resources.Resource)
|
||||
|
||||
@staticmethod
|
||||
def _is_async_resource_subclass(instance):
|
||||
if not isinstance(instance, type):
|
||||
return
|
||||
from . import resources
|
||||
return issubclass(instance, resources.AsyncResource)
|
||||
|
||||
|
||||
cdef class Container(Provider):
|
||||
"""Container provider provides an instance of declarative container.
|
||||
|
@ -4993,14 +4916,6 @@ def iscoroutinefunction(obj):
|
|||
return False
|
||||
|
||||
|
||||
def isasyncgenfunction(obj):
|
||||
"""Check if object is an asynchronous generator function."""
|
||||
try:
|
||||
return inspect.isasyncgenfunction(obj)
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
||||
def _resolve_string_import(provides):
|
||||
if provides is None:
|
||||
return provides
|
||||
|
|
|
@ -1,23 +1,54 @@
|
|||
"""Resources module."""
|
||||
|
||||
import abc
|
||||
from typing import TypeVar, Generic, Optional
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, ClassVar, Generic, Optional, Tuple, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Resource(Generic[T], metaclass=abc.ABCMeta):
|
||||
class Resource(Generic[T], metaclass=ABCMeta):
|
||||
__slots__: ClassVar[Tuple[str, ...]] = ("args", "kwargs", "obj")
|
||||
|
||||
@abc.abstractmethod
|
||||
def init(self, *args, **kwargs) -> Optional[T]: ...
|
||||
obj: Optional[T]
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.obj = None
|
||||
|
||||
@abstractmethod
|
||||
def init(self, *args: Any, **kwargs: Any) -> Optional[T]: ...
|
||||
|
||||
def shutdown(self, resource: Optional[T]) -> None: ...
|
||||
|
||||
def __enter__(self) -> Optional[T]:
|
||||
self.obj = obj = self.init(*self.args, **self.kwargs)
|
||||
return obj
|
||||
|
||||
class AsyncResource(Generic[T], metaclass=abc.ABCMeta):
|
||||
def __exit__(self, *exc_info: Any) -> None:
|
||||
self.shutdown(self.obj)
|
||||
self.obj = None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def init(self, *args, **kwargs) -> Optional[T]: ...
|
||||
|
||||
class AsyncResource(Generic[T], metaclass=ABCMeta):
|
||||
__slots__: ClassVar[Tuple[str, ...]] = ("args", "kwargs", "obj")
|
||||
|
||||
obj: Optional[T]
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.obj = None
|
||||
|
||||
@abstractmethod
|
||||
async def init(self, *args: Any, **kwargs: Any) -> Optional[T]: ...
|
||||
|
||||
async def shutdown(self, resource: Optional[T]) -> None: ...
|
||||
|
||||
async def __aenter__(self) -> Optional[T]:
|
||||
self.obj = obj = await self.init(*self.args, **self.kwargs)
|
||||
return obj
|
||||
|
||||
async def __aexit__(self, *exc_info: Any) -> None:
|
||||
await self.shutdown(self.obj)
|
||||
self.obj = None
|
||||
|
|
|
@ -2,12 +2,13 @@
|
|||
|
||||
import asyncio
|
||||
import inspect
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from dependency_injector import containers, providers, resources
|
||||
from pytest import mark, raises
|
||||
|
||||
from dependency_injector import containers, providers, resources
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_init_async_function():
|
||||
|
@ -70,6 +71,46 @@ async def test_init_async_generator():
|
|||
assert _init.shutdown_counter == 2
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_init_async_context_manager() -> None:
|
||||
resource = object()
|
||||
|
||||
init_counter = 0
|
||||
shutdown_counter = 0
|
||||
|
||||
@asynccontextmanager
|
||||
async def _init():
|
||||
nonlocal init_counter, shutdown_counter
|
||||
|
||||
await asyncio.sleep(0.001)
|
||||
init_counter += 1
|
||||
|
||||
yield resource
|
||||
|
||||
await asyncio.sleep(0.001)
|
||||
shutdown_counter += 1
|
||||
|
||||
provider = providers.Resource(_init)
|
||||
|
||||
result1 = await provider()
|
||||
assert result1 is resource
|
||||
assert init_counter == 1
|
||||
assert shutdown_counter == 0
|
||||
|
||||
await provider.shutdown()
|
||||
assert init_counter == 1
|
||||
assert shutdown_counter == 1
|
||||
|
||||
result2 = await provider()
|
||||
assert result2 is resource
|
||||
assert init_counter == 2
|
||||
assert shutdown_counter == 1
|
||||
|
||||
await provider.shutdown()
|
||||
assert init_counter == 2
|
||||
assert shutdown_counter == 2
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_init_async_class():
|
||||
resource = object()
|
||||
|
|
|
@ -2,10 +2,12 @@
|
|||
|
||||
import decimal
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from dependency_injector import containers, providers, resources, errors
|
||||
from pytest import raises, mark
|
||||
from pytest import mark, raises
|
||||
|
||||
from dependency_injector import containers, errors, providers, resources
|
||||
|
||||
|
||||
def init_fn(*args, **kwargs):
|
||||
|
@ -123,6 +125,41 @@ def test_init_generator():
|
|||
assert _init.shutdown_counter == 2
|
||||
|
||||
|
||||
def test_init_context_manager() -> None:
|
||||
init_counter, shutdown_counter = 0, 0
|
||||
|
||||
@contextmanager
|
||||
def _init():
|
||||
nonlocal init_counter, shutdown_counter
|
||||
|
||||
init_counter += 1
|
||||
yield
|
||||
shutdown_counter += 1
|
||||
|
||||
init_counter = 0
|
||||
shutdown_counter = 0
|
||||
|
||||
provider = providers.Resource(_init)
|
||||
|
||||
result1 = provider()
|
||||
assert result1 is None
|
||||
assert init_counter == 1
|
||||
assert shutdown_counter == 0
|
||||
|
||||
provider.shutdown()
|
||||
assert init_counter == 1
|
||||
assert shutdown_counter == 1
|
||||
|
||||
result2 = provider()
|
||||
assert result2 is None
|
||||
assert init_counter == 2
|
||||
assert shutdown_counter == 1
|
||||
|
||||
provider.shutdown()
|
||||
assert init_counter == 2
|
||||
assert shutdown_counter == 2
|
||||
|
||||
|
||||
def test_init_class():
|
||||
class TestResource(resources.Resource):
|
||||
init_counter = 0
|
||||
|
@ -190,7 +227,7 @@ def test_init_class_abc_shutdown_definition_is_not_required():
|
|||
|
||||
def test_init_not_callable():
|
||||
provider = providers.Resource(1)
|
||||
with raises(errors.Error):
|
||||
with raises(TypeError, match=r"object is not callable"):
|
||||
provider.init()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user