diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 49f9edb3..d78c37e4 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -593,6 +593,22 @@ def _fetch_reference_injections( # noqa: C901 return injections, closing +def _locate_dependent_closing_args(provider: providers.Provider) -> dict[str, providers.Provider]: + if not hasattr(provider, "args"): + return {} + + closing_deps = {} + for arg in provider.args: + if not isinstance(arg, providers.Provider) or not hasattr(arg, "args"): + continue + + if not arg.args and isinstance(arg, providers.Resource): + return {str(id(arg)): arg} + else: + closing_deps += _locate_dependent_closing_args(arg) + return closing_deps + + def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None: patched_callable = _patched_registry.get_callable(fn) if patched_callable is None: @@ -614,6 +630,9 @@ def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Non if injection in patched_callable.reference_closing: patched_callable.add_closing(injection, provider) + deps = _locate_dependent_closing_args(provider) + for key, dep in deps.items(): + patched_callable.add_closing(key, dep) def _unbind_injections(fn: Callable[..., Any]) -> None: diff --git a/tests/unit/samples/wiringstringids/resourceclosing.py b/tests/unit/samples/wiringstringids/resourceclosing.py index 2faf0efd..6360e15c 100644 --- a/tests/unit/samples/wiringstringids/resourceclosing.py +++ b/tests/unit/samples/wiringstringids/resourceclosing.py @@ -20,6 +20,11 @@ class Service: cls.shutdown_counter += 1 +class FactoryService: + def __init__(self, service: Service): + self.service = service + + def init_service(): service = Service() service.init() @@ -30,8 +35,14 @@ def init_service(): class Container(containers.DeclarativeContainer): service = providers.Resource(init_service) + factory_service = providers.Factory(FactoryService, service) @inject def test_function(service: Service = Closing[Provide["service"]]): return service + + +@inject +def test_function_dependency(factory: FactoryService = Closing[Provide["factory_service"]]): + return factory diff --git a/tests/unit/wiring/string_ids/test_main_py36.py b/tests/unit/wiring/string_ids/test_main_py36.py index 4c8f2e55..d4c49fe8 100644 --- a/tests/unit/wiring/string_ids/test_main_py36.py +++ b/tests/unit/wiring/string_ids/test_main_py36.py @@ -289,6 +289,23 @@ def test_closing_resource(): assert result_1 is not result_2 +@mark.usefixtures("resourceclosing_container") +def test_closing_dependency_resource(): + resourceclosing.Service.reset_counter() + + result_1 = resourceclosing.test_function_dependency() + assert isinstance(result_1, resourceclosing.FactoryService) + assert result_1.service.init_counter == 1 + assert result_1.service.shutdown_counter == 1 + + result_2 = resourceclosing.test_function_dependency() + assert isinstance(result_2, resourceclosing.FactoryService) + assert result_2.service.init_counter == 2 + assert result_2.service.shutdown_counter == 2 + + assert result_1 is not result_2 + + @mark.usefixtures("resourceclosing_container") def test_closing_resource_bypass_marker_injection(): resourceclosing.Service.reset_counter()