Add resource type parameter to init and shutdown resources using specialized providers (#858)

This commit is contained in:
Aran Moncusí Ramírez 2025-06-16 10:34:02 +02:00 committed by GitHub
parent b411807572
commit 4bfe64563e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 212 additions and 7 deletions

View File

@ -252,6 +252,72 @@ first argument.
.. _resource-provider-wiring-closing:
Scoping Resources using specialized subclasses
----------------------------------------------
You can use specialized subclasses of ``Resource`` provider to initialize and shutdown resources by type.
Allowing for example to only initialize a subgroup of resources.
.. code-block:: python
class ScopedResource(resources.Resource):
pass
def init_service(name) -> Service:
print(f"Init {name}")
yield Service()
print(f"Shutdown {name}")
class Container(containers.DeclarativeContainer):
scoped = ScopedResource(
init_service,
"scoped",
)
generic = providers.Resource(
init_service,
"generic",
)
To initialize resources by type you can use ``init_resources(resource_type)`` and ``shutdown_resources(resource_type)``
methods adding the resource type as an argument:
.. code-block:: python
def main():
container = Container()
container.init_resources(ScopedResource)
# Generates:
# >>> Init scoped
container.shutdown_resources(ScopedResource)
# Generates:
# >>> Shutdown scoped
And to initialize all resources you can use ``init_resources()`` and ``shutdown_resources()`` without arguments:
.. code-block:: python
def main():
container = Container()
container.init_resources()
# Generates:
# >>> Init scoped
# >>> Init generic
container.shutdown_resources()
# Generates:
# >>> Shutdown scoped
# >>> Shutdown generic
It works using the :ref:`traverse` method to find all resources of the specified type, selecting all resources
which are instances of the specified type.
Resources, wiring, and per-function execution scope
---------------------------------------------------

View File

@ -22,7 +22,7 @@ try:
except ImportError:
from typing_extensions import Self as _Self
from .providers import Provider, ProviderParent, Self
from .providers import Provider, Resource, Self, ProviderParent
C_Base = TypeVar("C_Base", bound="Container")
C = TypeVar("C", bound="DeclarativeContainer")
@ -74,8 +74,8 @@ class Container:
from_package: Optional[str] = None,
) -> None: ...
def unwire(self) -> None: ...
def init_resources(self) -> Optional[Awaitable[None]]: ...
def shutdown_resources(self) -> Optional[Awaitable[None]]: ...
def init_resources(self, resource_type: Type[Resource[Any]] = Resource) -> Optional[Awaitable[None]]: ...
def shutdown_resources(self, resource_type: Type[Resource[Any]] = Resource) -> Optional[Awaitable[None]]: ...
def load_config(self) -> None: ...
def apply_container_providers_overridings(self) -> None: ...
def reset_singletons(self) -> SingletonResetContext[C_Base]: ...

View File

@ -315,11 +315,15 @@ class DynamicContainer(Container):
self.wired_to_modules.clear()
self.wired_to_packages.clear()
def init_resources(self):
def init_resources(self, resource_type=providers.Resource):
"""Initialize all container resources."""
if not issubclass(resource_type, providers.Resource):
raise TypeError("resource_type must be a subclass of Resource provider")
futures = []
for provider in self.traverse(types=[providers.Resource]):
for provider in self.traverse(types=[resource_type]):
resource = provider.init()
if __is_future_or_coroutine(resource):
@ -328,8 +332,12 @@ class DynamicContainer(Container):
if futures:
return asyncio.gather(*futures)
def shutdown_resources(self):
def shutdown_resources(self, resource_type=providers.Resource):
"""Shutdown all container resources."""
if not issubclass(resource_type, providers.Resource):
raise TypeError("resource_type must be a subclass of Resource provider")
def _independent_resources(resources):
for resource in resources:
for other_resource in resources:
@ -360,7 +368,7 @@ class DynamicContainer(Container):
for resource in resources_to_shutdown:
resource.shutdown()
resources = list(self.traverse(types=[providers.Resource]))
resources = list(self.traverse(types=[resource_type]))
if any(resource.is_async_mode_enabled() for resource in resources):
return _async_ordered_shutdown(resources)
else:

View File

@ -145,3 +145,121 @@ async def test_shutdown_sync_and_async_ordering():
await container.shutdown_resources()
assert initialized_resources == ["r1", "r2", "r3", "r1", "r2", "r3"]
assert shutdown_resources == ["r3", "r2", "r1", "r3", "r2", "r1"]
@mark.asyncio
async def test_init_and_shutdown_scoped_resources():
initialized_resources = []
shutdown_resources = []
def _sync_resource(name, **_):
initialized_resources.append(name)
yield name
shutdown_resources.append(name)
async def _async_resource(name, **_):
initialized_resources.append(name)
yield name
shutdown_resources.append(name)
class ResourceA(providers.Resource):
pass
class ResourceB(providers.Resource):
pass
class Container(containers.DeclarativeContainer):
resource_a = ResourceA(
_sync_resource,
name="ra1",
)
resource_b1 = ResourceB(
_sync_resource,
name="rb1",
r1=resource_a,
)
resource_b2 = ResourceB(
_async_resource,
name="rb2",
r2=resource_b1,
)
container = Container()
container.init_resources(resource_type=ResourceA)
assert initialized_resources == ["ra1"]
assert shutdown_resources == []
container.shutdown_resources(resource_type=ResourceA)
assert initialized_resources == ["ra1"]
assert shutdown_resources == ["ra1"]
await container.init_resources(resource_type=ResourceB)
assert initialized_resources == ["ra1", "ra1", "rb1", "rb2"]
assert shutdown_resources == ["ra1"]
await container.shutdown_resources(resource_type=ResourceB)
assert initialized_resources == ["ra1", "ra1", "rb1", "rb2"]
assert shutdown_resources == ["ra1", "rb2", "rb1"]
@mark.asyncio
async def test_init_and_shutdown_all_scoped_resources_using_default_value():
initialized_resources = []
shutdown_resources = []
def _sync_resource(name, **_):
initialized_resources.append(name)
yield name
shutdown_resources.append(name)
async def _async_resource(name, **_):
initialized_resources.append(name)
yield name
shutdown_resources.append(name)
class ResourceA(providers.Resource):
pass
class ResourceB(providers.Resource):
pass
class Container(containers.DeclarativeContainer):
resource_a = ResourceA(
_sync_resource,
name="r1",
)
resource_b1 = ResourceB(
_sync_resource,
name="r2",
r1=resource_a,
)
resource_b2 = ResourceB(
_async_resource,
name="r3",
r2=resource_b1,
)
container = Container()
await container.init_resources()
assert initialized_resources == ["r1", "r2", "r3"]
assert shutdown_resources == []
await container.shutdown_resources()
assert initialized_resources == ["r1", "r2", "r3"]
assert shutdown_resources == ["r3", "r2", "r1"]
await container.init_resources()
assert initialized_resources == ["r1", "r2", "r3", "r1", "r2", "r3"]
assert shutdown_resources == ["r3", "r2", "r1"]
await container.shutdown_resources()
assert initialized_resources == ["r1", "r2", "r3", "r1", "r2", "r3"]
assert shutdown_resources == ["r3", "r2", "r1", "r3", "r2", "r1"]

View File

@ -325,6 +325,19 @@ def test_init_shutdown_nested_resources():
assert _init2.shutdown_counter == 2
def test_init_shutdown_resources_wrong_type() -> None:
class Container(containers.DeclarativeContainer):
pass
c = Container()
with raises(TypeError, match=r"resource_type must be a subclass of Resource provider"):
c.init_resources(int) # type: ignore[arg-type]
with raises(TypeError, match=r"resource_type must be a subclass of Resource provider"):
c.shutdown_resources(int) # type: ignore[arg-type]
def test_reset_singletons():
class SubSubContainer(containers.DeclarativeContainer):
singleton = providers.Singleton(object)