diff --git a/docs/providers/resource.rst b/docs/providers/resource.rst index 918dfa66..6bea458a 100644 --- a/docs/providers/resource.rst +++ b/docs/providers/resource.rst @@ -210,6 +210,72 @@ first argument. .. _resource-provider-wiring-closing: +Scoping Resources using specialized subclasses +---------------------------------------------- + +You can use specialized subclasses of ``Resource`` provider to initialize and shutdown resources by type. +Allowing for example to only initialize a subgroup of resources. + +.. code-block:: python + + class ScopedResource(resources.Resource): + pass + + def init_service(name) -> Service: + print(f"Init {name}") + yield Service() + print(f"Shutdown {name}") + + class Container(containers.DeclarativeContainer): + + scoped = ScopedResource( + init_service, + "scoped", + ) + + generic = providers.Resource( + init_service, + "generic", + ) + + +To initialize resources by type you can use ``init_resources(resource_type)`` and ``shutdown_resources(resource_type)`` +methods adding the resource type as an argument: + +.. code-block:: python + + def main(): + container = Container() + container.init_resources(ScopedResource) + # Generates: + # >>> Init scoped + + container.shutdown_resources(ScopedResource) + # Generates: + # >>> Shutdown scoped + + +And to initialize all resources you can use ``init_resources()`` and ``shutdown_resources()`` without arguments: + +.. code-block:: python + + def main(): + container = Container() + container.init_resources() + # Generates: + # >>> Init scoped + # >>> Init generic + + container.shutdown_resources() + # Generates: + # >>> Shutdown scoped + # >>> Shutdown generic + + +It works using the :ref:`traverse` method to find all resources of the specified type, selecting all resources +which are instances of the specified type. + + Resources, wiring, and per-function execution scope --------------------------------------------------- diff --git a/src/dependency_injector/containers.pyi b/src/dependency_injector/containers.pyi index 4b40fbba..9cced661 100644 --- a/src/dependency_injector/containers.pyi +++ b/src/dependency_injector/containers.pyi @@ -17,7 +17,7 @@ from typing import ( overload, ) -from .providers import Provider, Self, ProviderParent +from .providers import Provider, Resource, Self, ProviderParent C_Base = TypeVar("C_Base", bound="Container") C = TypeVar("C", bound="DeclarativeContainer") @@ -67,8 +67,8 @@ class Container: from_package: Optional[str] = None, ) -> None: ... def unwire(self) -> None: ... - def init_resources(self) -> Optional[Awaitable]: ... - def shutdown_resources(self) -> Optional[Awaitable]: ... + def init_resources(self, resource_type: Type[Resource]=None) -> Optional[Awaitable]: ... + def shutdown_resources(self, resource_type: Type[Resource]=None) -> Optional[Awaitable]: ... def load_config(self) -> None: ... def apply_container_providers_overridings(self) -> None: ... def reset_singletons(self) -> SingletonResetContext[C_Base]: ... diff --git a/src/dependency_injector/containers.pyx b/src/dependency_injector/containers.pyx index 2f4c4af5..24e2aefc 100644 --- a/src/dependency_injector/containers.pyx +++ b/src/dependency_injector/containers.pyx @@ -310,11 +310,11 @@ class DynamicContainer(Container): self.wired_to_modules.clear() self.wired_to_packages.clear() - def init_resources(self): + def init_resources(self, resource_type=providers.Resource): """Initialize all container resources.""" futures = [] - for provider in self.traverse(types=[providers.Resource]): + for provider in self.traverse(types=[resource_type]): resource = provider.init() if __is_future_or_coroutine(resource): @@ -323,7 +323,7 @@ class DynamicContainer(Container): if futures: return asyncio.gather(*futures) - def shutdown_resources(self): + def shutdown_resources(self, resource_type=providers.Resource): """Shutdown all container resources.""" def _independent_resources(resources): for resource in resources: @@ -355,7 +355,7 @@ class DynamicContainer(Container): for resource in resources_to_shutdown: resource.shutdown() - resources = list(self.traverse(types=[providers.Resource])) + resources = list(self.traverse(types=[resource_type])) if any(resource.is_async_mode_enabled() for resource in resources): return _async_ordered_shutdown(resources) else: diff --git a/tests/unit/containers/instance/test_async_resources_py36.py b/tests/unit/containers/instance/test_async_resources_py36.py index b365b60d..47fd03e7 100644 --- a/tests/unit/containers/instance/test_async_resources_py36.py +++ b/tests/unit/containers/instance/test_async_resources_py36.py @@ -145,3 +145,121 @@ async def test_shutdown_sync_and_async_ordering(): await container.shutdown_resources() assert initialized_resources == ["r1", "r2", "r3", "r1", "r2", "r3"] assert shutdown_resources == ["r3", "r2", "r1", "r3", "r2", "r1"] + + +@mark.asyncio +async def test_init_and_shutdown_scoped_resources(): + initialized_resources = [] + shutdown_resources = [] + + def _sync_resource(name, **_): + initialized_resources.append(name) + yield name + shutdown_resources.append(name) + + async def _async_resource(name, **_): + initialized_resources.append(name) + yield name + shutdown_resources.append(name) + + + class ResourceA(providers.Resource): + pass + + + class ResourceB(providers.Resource): + pass + + + class Container(containers.DeclarativeContainer): + resource_a = ResourceA( + _sync_resource, + name="ra1", + ) + resource_b1 = ResourceB( + _sync_resource, + name="rb1", + r1=resource_a, + ) + resource_b2 = ResourceB( + _async_resource, + name="rb2", + r2=resource_b1, + ) + + container = Container() + + container.init_resources(resource_type=ResourceA) + assert initialized_resources == ["ra1"] + assert shutdown_resources == [] + + container.shutdown_resources(resource_type=ResourceA) + assert initialized_resources == ["ra1"] + assert shutdown_resources == ["ra1"] + + await container.init_resources(resource_type=ResourceB) + assert initialized_resources == ["ra1", "ra1", "rb1", "rb2"] + assert shutdown_resources == ["ra1"] + + await container.shutdown_resources(resource_type=ResourceB) + assert initialized_resources == ["ra1", "ra1", "rb1", "rb2"] + assert shutdown_resources == ["ra1", "rb2", "rb1"] + + +@mark.asyncio +async def test_init_and_shutdown_all_scoped_resources_using_default_value(): + initialized_resources = [] + shutdown_resources = [] + + def _sync_resource(name, **_): + initialized_resources.append(name) + yield name + shutdown_resources.append(name) + + async def _async_resource(name, **_): + initialized_resources.append(name) + yield name + shutdown_resources.append(name) + + + class ResourceA(providers.Resource): + pass + + + class ResourceB(providers.Resource): + pass + + + class Container(containers.DeclarativeContainer): + resource_a = ResourceA( + _sync_resource, + name="r1", + ) + resource_b1 = ResourceB( + _sync_resource, + name="r2", + r1=resource_a, + ) + resource_b2 = ResourceB( + _async_resource, + name="r3", + r2=resource_b1, + ) + + container = Container() + + await container.init_resources() + assert initialized_resources == ["r1", "r2", "r3"] + assert shutdown_resources == [] + + await container.shutdown_resources() + assert initialized_resources == ["r1", "r2", "r3"] + assert shutdown_resources == ["r3", "r2", "r1"] + + await container.init_resources() + assert initialized_resources == ["r1", "r2", "r3", "r1", "r2", "r3"] + assert shutdown_resources == ["r3", "r2", "r1"] + + await container.shutdown_resources() + assert initialized_resources == ["r1", "r2", "r3", "r1", "r2", "r3"] + assert shutdown_resources == ["r3", "r2", "r1", "r3", "r2", "r1"]