From 776cb4eebfb420007af30358d166f89b7ef04854 Mon Sep 17 00:00:00 2001 From: Roman Mogylatov Date: Wed, 28 Oct 2020 13:44:11 -0400 Subject: [PATCH] Fix multiple containers wiring issue --- docs/main/changelog.rst | 3 +++ src/dependency_injector/wiring.py | 19 +++++++++++++------ tests/unit/samples/wiringsamples/module.py | 9 ++++++++- tests/unit/wiring/test_wiring_py36.py | 15 ++++++++++++++- 4 files changed, 38 insertions(+), 8 deletions(-) diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index 01a71c21..0650aedf 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -9,6 +9,9 @@ follows `Semantic versioning`_ Develop ----- +- Fix wiring of multiple containers + (see issue `#313 `_). + Thanks to `iskorini `_ for reporting the issue. - Fix wiring for ``@classmethod``. 4.1.5 diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 5e8378bd..2b8af542 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -37,7 +37,7 @@ class ProvidersMap: original_providers=container.declarative_parent.providers, ) - def resolve_provider(self, provider: providers.Provider) -> providers.Provider: + def resolve_provider(self, provider: providers.Provider) -> Optional[providers.Provider]: if isinstance(provider, providers.Delegate): return self._resolve_delegate(provider) elif isinstance(provider, ( @@ -54,10 +54,10 @@ class ProvidersMap: else: return self._resolve_provider(provider) - def _resolve_delegate(self, original: providers.Delegate) -> providers.Provider: + def _resolve_delegate(self, original: providers.Delegate) -> Optional[providers.Provider]: return self._resolve_provider(original.provides) - def _resolve_provided_instance(self, original: providers.Provider) -> providers.Provider: + def _resolve_provided_instance(self, original: providers.Provider) -> Optional[providers.Provider]: modifiers = [] while isinstance(original, ( providers.ProvidedInstance, @@ -69,6 +69,8 @@ class ProvidersMap: original = original.provides new = self._resolve_provider(original) + if new is None: + return None for modifier in modifiers: if isinstance(modifier, providers.ProvidedInstance): @@ -89,9 +91,11 @@ class ProvidersMap: self, original: providers.ConfigurationOption, as_: Any = None, - ) -> providers.Provider: + ) -> Optional[providers.Provider]: original_root = original.root new = self._resolve_provider(original_root) + if new is None: + return None new = cast(providers.Configuration, new) for segment in original.get_name_segments(): @@ -106,11 +110,11 @@ class ProvidersMap: return new - def _resolve_provider(self, original: providers.Provider) -> providers.Provider: + def _resolve_provider(self, original: providers.Provider) -> Optional[providers.Provider]: try: return self._map[original] except KeyError: - raise Exception('Unable to resolve original provider') + pass @classmethod def _create_providers_map( @@ -216,6 +220,9 @@ def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> marker = parameter.default provider = providers_map.resolve_provider(marker.provider) + if provider is None: + continue + if isinstance(marker, Provide): injections[parameter_name] = provider elif isinstance(marker, Provider): diff --git a/tests/unit/samples/wiringsamples/module.py b/tests/unit/samples/wiringsamples/module.py index 917b7691..c39d1414 100644 --- a/tests/unit/samples/wiringsamples/module.py +++ b/tests/unit/samples/wiringsamples/module.py @@ -5,7 +5,7 @@ from typing import Callable from dependency_injector.wiring import Provide, Provider -from .container import Container +from .container import Container, SubContainer from .service import Service @@ -58,3 +58,10 @@ def test_subcontainer_provider(some_value: int = Provide[Container.sub.int_objec def test_config_invariant(some_value: int = Provide[Container.config.option[Container.config.switch]]): return some_value + + +def test_provide_from_different_containers( + service: Service = Provide[Container.service], + some_value: int = Provide[SubContainer.int_object], +): + return service, some_value diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py index ff9925a5..14bdad84 100644 --- a/tests/unit/wiring/test_wiring_py36.py +++ b/tests/unit/wiring/test_wiring_py36.py @@ -16,7 +16,7 @@ sys.path.append(_SAMPLES_DIR) from wiringsamples import module, package from wiringsamples.service import Service -from wiringsamples.container import Container +from wiringsamples.container import Container, SubContainer class WiringTest(unittest.TestCase): @@ -161,3 +161,16 @@ class WiringTest(unittest.TestCase): from wiringsamples.package.subpackage import submodule self.container.unwire() self.assertIsInstance(submodule.test_function(), Provide) + + def test_wire_multiple_containers(self): + sub_container = SubContainer() + sub_container.wire( + modules=[module], + packages=[package], + ) + self.addCleanup(sub_container.unwire) + + service, some_value = module.test_provide_from_different_containers() + + self.assertIsInstance(service, Service) + self.assertEqual(some_value, 1)