diff --git a/src/dependency_injector/ext/starlette.py b/src/dependency_injector/ext/starlette.py index becadf0a..620d7d26 100644 --- a/src/dependency_injector/ext/starlette.py +++ b/src/dependency_injector/ext/starlette.py @@ -1,5 +1,5 @@ import sys -from typing import Any +from typing import Any, Type if sys.version_info >= (3, 11): # pragma: no cover from typing import Self @@ -7,6 +7,7 @@ else: # pragma: no cover from typing_extensions import Self from dependency_injector.containers import Container +from dependency_injector.providers import Resource class Lifespan: @@ -29,24 +30,32 @@ class Lifespan: app = Factory(Starlette, lifespan=lifespan) :param container: container instance + :param resource_type: A :py:class:`~dependency_injector.resources.Resource` + subclass. Limits the resources to be initialized and shutdown. """ container: Container + resource_type: Type[Resource[Any]] - def __init__(self, container: Container) -> None: + def __init__( + self, + container: Container, + resource_type: Type[Resource[Any]] = Resource, + ) -> None: self.container = container + self.resource_type = resource_type def __call__(self, app: Any) -> Self: return self async def __aenter__(self) -> None: - result = self.container.init_resources() + result = self.container.init_resources(self.resource_type) if result is not None: await result async def __aexit__(self, *exc_info: Any) -> None: - result = self.container.shutdown_resources() + result = self.container.shutdown_resources(self.resource_type) if result is not None: await result diff --git a/tests/unit/ext/test_starlette.py b/tests/unit/ext/test_starlette.py index e569a382..f50d6f46 100644 --- a/tests/unit/ext/test_starlette.py +++ b/tests/unit/ext/test_starlette.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, Iterator +from typing import AsyncIterator, Iterator, TypeVar from unittest.mock import ANY from pytest import mark @@ -7,6 +7,12 @@ from dependency_injector.containers import DeclarativeContainer from dependency_injector.ext.starlette import Lifespan from dependency_injector.providers import Resource +T = TypeVar("T") + + +class XResource(Resource[T]): + """A test provider""" + class TestLifespan: @mark.parametrize("sync", [False, True]) @@ -28,11 +34,15 @@ class TestLifespan: yield shutdown = True + def nope(): + assert False, "should not be called" + class Container(DeclarativeContainer): - x = Resource(sync_resource if sync else async_resource) + x = XResource(sync_resource if sync else async_resource) + y = Resource(nope) container = Container() - lifespan = Lifespan(container) + lifespan = Lifespan(container, resource_type=XResource) async with lifespan(ANY) as scope: assert scope is None