diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 9bb990ab..5cded9f5 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -1,26 +1,26 @@ """Wiring module.""" import functools -import inspect import importlib import importlib.machinery +import inspect import pkgutil -import warnings import sys +import warnings from types import ModuleType from typing import ( - Optional, - Iterable, - Iterator, - Callable, Any, - Tuple, + Callable, Dict, Generic, - TypeVar, - Type, - Union, + Iterable, + Iterator, + Optional, Set, + Tuple, + Type, + TypeVar, + Union, cast, ) @@ -643,21 +643,18 @@ def _fetch_reference_injections( # noqa: C901 def _locate_dependent_closing_args( - provider: providers.Provider, + provider: providers.Provider, closing_deps: Dict[str, providers.Provider] ) -> Dict[str, providers.Provider]: - if not hasattr(provider, "args"): - return {} - - closing_deps = {} - for arg in provider.args: - if not isinstance(arg, providers.Provider) or not hasattr(arg, "args"): + for arg in [ + *getattr(provider, "args", []), + *getattr(provider, "kwargs", {}).values(), + ]: + if not isinstance(arg, providers.Provider): continue + if isinstance(arg, providers.Resource): + closing_deps[str(id(arg))] = arg - if not arg.args and isinstance(arg, providers.Resource): - return {str(id(arg)): arg} - else: - closing_deps += _locate_dependent_closing_args(arg) - return closing_deps + _locate_dependent_closing_args(arg, closing_deps) def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None: @@ -681,7 +678,8 @@ def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Non if injection in patched_callable.reference_closing: patched_callable.add_closing(injection, provider) - deps = _locate_dependent_closing_args(provider) + deps = {} + _locate_dependent_closing_args(provider, deps) for key, dep in deps.items(): patched_callable.add_closing(key, dep) diff --git a/tests/unit/samples/wiringstringids/resourceclosing.py b/tests/unit/samples/wiringstringids/resourceclosing.py index 6360e15c..c4d1f20f 100644 --- a/tests/unit/samples/wiringstringids/resourceclosing.py +++ b/tests/unit/samples/wiringstringids/resourceclosing.py @@ -1,41 +1,80 @@ +from typing import Any, Dict, List, Optional + from dependency_injector import containers, providers -from dependency_injector.wiring import inject, Provide, Closing +from dependency_injector.wiring import Closing, Provide, inject + + +class Counter: + def __init__(self) -> None: + self._init = 0 + self._shutdown = 0 + + def init(self) -> None: + self._init += 1 + + def shutdown(self) -> None: + self._shutdown += 1 + + def reset(self) -> None: + self._init = 0 + self._shutdown = 0 class Service: - init_counter: int = 0 - shutdown_counter: int = 0 + def __init__(self, counter: Optional[Counter] = None, **dependencies: Any) -> None: + self.counter = counter or Counter() + self.dependencies = dependencies - @classmethod - def reset_counter(cls): - cls.init_counter = 0 - cls.shutdown_counter = 0 + def init(self) -> None: + self.counter.init() - @classmethod - def init(cls): - cls.init_counter += 1 + def shutdown(self) -> None: + self.counter.shutdown() - @classmethod - def shutdown(cls): - cls.shutdown_counter += 1 + @property + def init_counter(self) -> int: + return self.counter._init + + @property + def shutdown_counter(self) -> int: + return self.counter._shutdown class FactoryService: - def __init__(self, service: Service): + def __init__(self, service: Service, service2: Service): self.service = service + self.service2 = service2 -def init_service(): - service = Service() +class NestedService: + def __init__(self, factory_service: FactoryService): + self.factory_service = factory_service + + +def init_service(counter: Counter, _list: List[int], _dict: Dict[str, int]): + service = Service(counter, _list=_list, _dict=_dict) service.init() yield service service.shutdown() class Container(containers.DeclarativeContainer): - - service = providers.Resource(init_service) - factory_service = providers.Factory(FactoryService, service) + counter = providers.Singleton(Counter) + _list = providers.List( + providers.Callable(lambda a: a, a=1), providers.Callable(lambda b: b, 2) + ) + _dict = providers.Dict( + a=providers.Callable(lambda a: a, a=3), b=providers.Callable(lambda b: b, 4) + ) + service = providers.Resource(init_service, counter, _list, _dict=_dict) + service2 = providers.Resource(init_service, counter, _list, _dict=_dict) + factory_service = providers.Factory(FactoryService, service, service2) + factory_service_kwargs = providers.Factory( + FactoryService, + service=service, + service2=service2, + ) + nested_service = providers.Factory(NestedService, factory_service) @inject @@ -44,5 +83,21 @@ def test_function(service: Service = Closing[Provide["service"]]): @inject -def test_function_dependency(factory: FactoryService = Closing[Provide["factory_service"]]): +def test_function_dependency( + factory: FactoryService = Closing[Provide["factory_service"]], +): return factory + + +@inject +def test_function_dependency_kwargs( + factory: FactoryService = Closing[Provide["factory_service_kwargs"]], +): + return factory + + +@inject +def test_function_nested_dependency( + nested: NestedService = Closing[Provide["nested_service"]], +): + return nested diff --git a/tests/unit/wiring/string_ids/test_main_py36.py b/tests/unit/wiring/string_ids/test_main_py36.py index d4c49fe8..8125481a 100644 --- a/tests/unit/wiring/string_ids/test_main_py36.py +++ b/tests/unit/wiring/string_ids/test_main_py36.py @@ -2,13 +2,13 @@ from decimal import Decimal +from pytest import fixture, mark, raises +from samples.wiringstringids import module, package, resourceclosing +from samples.wiringstringids.container import Container, SubContainer +from samples.wiringstringids.service import Service + from dependency_injector import errors from dependency_injector.wiring import Closing, Provide, Provider, wire -from pytest import fixture, mark, raises - -from samples.wiringstringids import module, package, resourceclosing -from samples.wiringstringids.service import Service -from samples.wiringstringids.container import Container, SubContainer @fixture(autouse=True) @@ -34,10 +34,11 @@ def subcontainer(): @fixture -def resourceclosing_container(): +def resourceclosing_container(request): container = resourceclosing.Container() container.wire(modules=[resourceclosing]) - yield container + with container.reset_singletons(): + yield container container.unwire() @@ -274,42 +275,65 @@ def test_wire_multiple_containers(): @mark.usefixtures("resourceclosing_container") def test_closing_resource(): - resourceclosing.Service.reset_counter() - result_1 = resourceclosing.test_function() assert isinstance(result_1, resourceclosing.Service) assert result_1.init_counter == 1 assert result_1.shutdown_counter == 1 + assert result_1.dependencies == {"_list": [1, 2], "_dict": {"a": 3, "b": 4}} result_2 = resourceclosing.test_function() assert isinstance(result_2, resourceclosing.Service) assert result_2.init_counter == 2 assert result_2.shutdown_counter == 2 + assert result_1.dependencies == {"_list": [1, 2], "_dict": {"a": 3, "b": 4}} assert result_1 is not result_2 @mark.usefixtures("resourceclosing_container") def test_closing_dependency_resource(): - resourceclosing.Service.reset_counter() - result_1 = resourceclosing.test_function_dependency() assert isinstance(result_1, resourceclosing.FactoryService) - assert result_1.service.init_counter == 1 - assert result_1.service.shutdown_counter == 1 + assert result_1.service.init_counter == 2 + assert result_1.service.shutdown_counter == 2 result_2 = resourceclosing.test_function_dependency() + assert isinstance(result_2, resourceclosing.FactoryService) - assert result_2.service.init_counter == 2 - assert result_2.service.shutdown_counter == 2 + assert result_2.service.init_counter == 4 + assert result_2.service.shutdown_counter == 4 + + +@mark.usefixtures("resourceclosing_container") +def test_closing_dependency_resource_kwargs(): + result_1 = resourceclosing.test_function_dependency_kwargs() + assert isinstance(result_1, resourceclosing.FactoryService) + assert result_1.service.init_counter == 2 + assert result_1.service.shutdown_counter == 2 + + result_2 = resourceclosing.test_function_dependency_kwargs() + assert isinstance(result_2, resourceclosing.FactoryService) + assert result_2.service.init_counter == 4 + assert result_2.service.shutdown_counter == 4 + + +@mark.usefixtures("resourceclosing_container") +def test_closing_nested_dependency_resource(): + result_1 = resourceclosing.test_function_nested_dependency() + assert isinstance(result_1, resourceclosing.NestedService) + assert result_1.factory_service.service.init_counter == 2 + assert result_1.factory_service.service.shutdown_counter == 2 + + result_2 = resourceclosing.test_function_nested_dependency() + assert isinstance(result_2, resourceclosing.NestedService) + assert result_2.factory_service.service.init_counter == 4 + assert result_2.factory_service.service.shutdown_counter == 4 assert result_1 is not result_2 @mark.usefixtures("resourceclosing_container") def test_closing_resource_bypass_marker_injection(): - resourceclosing.Service.reset_counter() - result_1 = resourceclosing.test_function(service=Closing[Provide["service"]]) assert isinstance(result_1, resourceclosing.Service) assert result_1.init_counter == 1 @@ -325,7 +349,6 @@ def test_closing_resource_bypass_marker_injection(): @mark.usefixtures("resourceclosing_container") def test_closing_resource_context(): - resourceclosing.Service.reset_counter() service = resourceclosing.Service() result_1 = resourceclosing.test_function(service=service)