From 3b76a0d091f3edbc2d5393a5ac9722a5f2f76562 Mon Sep 17 00:00:00 2001 From: Jamie Stumme <3059647+StummeJ@users.noreply.github.com> Date: Sun, 18 Dec 2022 19:49:23 -0700 Subject: [PATCH] Allow `Closing` to detect dependent resources (#636) --- src/dependency_injector/wiring.py | 19 +++++++++++++++++++ .../wiringstringids/resourceclosing.py | 11 +++++++++++ .../unit/wiring/string_ids/test_main_py36.py | 17 +++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 49f9edb3..d78c37e4 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -593,6 +593,22 @@ def _fetch_reference_injections( # noqa: C901 return injections, closing +def _locate_dependent_closing_args(provider: providers.Provider) -> dict[str, providers.Provider]: + if not hasattr(provider, "args"): + return {} + + closing_deps = {} + for arg in provider.args: + if not isinstance(arg, providers.Provider) or not hasattr(arg, "args"): + continue + + if not arg.args and isinstance(arg, providers.Resource): + return {str(id(arg)): arg} + else: + closing_deps += _locate_dependent_closing_args(arg) + return closing_deps + + def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None: patched_callable = _patched_registry.get_callable(fn) if patched_callable is None: @@ -614,6 +630,9 @@ def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Non if injection in patched_callable.reference_closing: patched_callable.add_closing(injection, provider) + deps = _locate_dependent_closing_args(provider) + for key, dep in deps.items(): + patched_callable.add_closing(key, dep) def _unbind_injections(fn: Callable[..., Any]) -> None: diff --git a/tests/unit/samples/wiringstringids/resourceclosing.py b/tests/unit/samples/wiringstringids/resourceclosing.py index 2faf0efd..6360e15c 100644 --- a/tests/unit/samples/wiringstringids/resourceclosing.py +++ b/tests/unit/samples/wiringstringids/resourceclosing.py @@ -20,6 +20,11 @@ class Service: cls.shutdown_counter += 1 +class FactoryService: + def __init__(self, service: Service): + self.service = service + + def init_service(): service = Service() service.init() @@ -30,8 +35,14 @@ def init_service(): class Container(containers.DeclarativeContainer): service = providers.Resource(init_service) + factory_service = providers.Factory(FactoryService, service) @inject def test_function(service: Service = Closing[Provide["service"]]): return service + + +@inject +def test_function_dependency(factory: FactoryService = Closing[Provide["factory_service"]]): + return factory diff --git a/tests/unit/wiring/string_ids/test_main_py36.py b/tests/unit/wiring/string_ids/test_main_py36.py index 4c8f2e55..d4c49fe8 100644 --- a/tests/unit/wiring/string_ids/test_main_py36.py +++ b/tests/unit/wiring/string_ids/test_main_py36.py @@ -289,6 +289,23 @@ def test_closing_resource(): assert result_1 is not result_2 +@mark.usefixtures("resourceclosing_container") +def test_closing_dependency_resource(): + resourceclosing.Service.reset_counter() + + result_1 = resourceclosing.test_function_dependency() + assert isinstance(result_1, resourceclosing.FactoryService) + assert result_1.service.init_counter == 1 + assert result_1.service.shutdown_counter == 1 + + result_2 = resourceclosing.test_function_dependency() + assert isinstance(result_2, resourceclosing.FactoryService) + assert result_2.service.init_counter == 2 + assert result_2.service.shutdown_counter == 2 + + assert result_1 is not result_2 + + @mark.usefixtures("resourceclosing_container") def test_closing_resource_bypass_marker_injection(): resourceclosing.Service.reset_counter()