diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index 3faef89b..5816a4e9 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -7,6 +7,13 @@ that were made in every particular version. From version 0.7.6 *Dependency Injector* framework strictly follows `Semantic versioning`_ +4.1.6 +----- +- Fix wiring of multiple containers + (see issue `#313 `_). + Thanks to `iskorini `_ for reporting the issue. +- Fix wiring for ``@classmethod``. + 4.1.5 ----- - Fix Travis CI windows and MacOS builds. diff --git a/src/dependency_injector/__init__.py b/src/dependency_injector/__init__.py index 5d4c95ef..5fc48203 100644 --- a/src/dependency_injector/__init__.py +++ b/src/dependency_injector/__init__.py @@ -1,6 +1,6 @@ """Top-level package.""" -__version__ = '4.1.5' +__version__ = '4.1.6' """Version number. :type: str diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index dbc5f9ba..e58a97cd 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -37,7 +37,10 @@ class ProvidersMap: original_providers=container.declarative_parent.providers, ) - def resolve_provider(self, provider: providers.Provider) -> providers.Provider: + def resolve_provider( + self, + provider: providers.Provider, + ) -> Optional[providers.Provider]: if isinstance(provider, providers.Delegate): return self._resolve_delegate(provider) elif isinstance(provider, ( @@ -54,10 +57,16 @@ class ProvidersMap: else: return self._resolve_provider(provider) - def _resolve_delegate(self, original: providers.Delegate) -> providers.Provider: + def _resolve_delegate( + self, + original: providers.Delegate, + ) -> Optional[providers.Provider]: return self._resolve_provider(original.provides) - def _resolve_provided_instance(self, original: providers.Provider) -> providers.Provider: + def _resolve_provided_instance( + self, + original: providers.Provider, + ) -> Optional[providers.Provider]: modifiers = [] while isinstance(original, ( providers.ProvidedInstance, @@ -69,6 +78,8 @@ class ProvidersMap: original = original.provides new = self._resolve_provider(original) + if new is None: + return None for modifier in modifiers: if isinstance(modifier, providers.ProvidedInstance): @@ -89,9 +100,11 @@ class ProvidersMap: self, original: providers.ConfigurationOption, as_: Any = None, - ) -> providers.Provider: + ) -> Optional[providers.Provider]: original_root = original.root new = self._resolve_provider(original_root) + if new is None: + return None new = cast(providers.Configuration, new) for segment in original.get_name_segments(): @@ -106,11 +119,14 @@ class ProvidersMap: return new - def _resolve_provider(self, original: providers.Provider) -> providers.Provider: + def _resolve_provider( + self, + original: providers.Provider, + ) -> Optional[providers.Provider]: try: return self._map[original] except KeyError: - raise Exception('Unable to resolve original provider') + pass @classmethod def _create_providers_map( @@ -158,7 +174,7 @@ def wire( if inspect.isfunction(member): _patch_fn(module, name, member, providers_map) elif inspect.isclass(member): - for method_name, method in inspect.getmembers(member, inspect.isfunction): + for method_name, method in inspect.getmembers(member, _is_method): _patch_fn(member, method_name, method, providers_map) @@ -206,7 +222,7 @@ def _unpatch_fn( setattr(module, name, _get_original_from_patched(fn)) -def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Dict[str, Any]: # noqa +def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Dict[str, Any]: signature = inspect.signature(fn) injections = {} @@ -216,6 +232,9 @@ def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> marker = parameter.default provider = providers_map.resolve_provider(marker.provider) + if provider is None: + continue + if isinstance(marker, Provide): injections[parameter_name] = provider elif isinstance(marker, Provider): @@ -235,6 +254,10 @@ def _fetch_modules(package): return modules +def _is_method(member): + return inspect.ismethod(member) or inspect.isfunction(member) + + def _patch_with_injections(fn, injections): if inspect.iscoroutinefunction(fn): @functools.wraps(fn) diff --git a/tests/unit/samples/wiringsamples/module.py b/tests/unit/samples/wiringsamples/module.py index d29a0ec9..c39d1414 100644 --- a/tests/unit/samples/wiringsamples/module.py +++ b/tests/unit/samples/wiringsamples/module.py @@ -5,7 +5,7 @@ from typing import Callable from dependency_injector.wiring import Provide, Provider -from .container import Container +from .container import Container, SubContainer from .service import Service @@ -17,6 +17,14 @@ class TestClass: def method(self, service: Service = Provide[Container.service]): return service + @classmethod + def class_method(cls, service: Service = Provide[Container.service]): + return service + + @staticmethod + def static_method(service: Service = Provide[Container.service]): + return service + def test_function(service: Service = Provide[Container.service]): return service @@ -50,3 +58,10 @@ def test_subcontainer_provider(some_value: int = Provide[Container.sub.int_objec def test_config_invariant(some_value: int = Provide[Container.config.option[Container.config.switch]]): return some_value + + +def test_provide_from_different_containers( + service: Service = Provide[Container.service], + some_value: int = Provide[SubContainer.int_object], +): + return service, some_value diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py index 6e88be66..14bdad84 100644 --- a/tests/unit/wiring/test_wiring_py36.py +++ b/tests/unit/wiring/test_wiring_py36.py @@ -16,7 +16,7 @@ sys.path.append(_SAMPLES_DIR) from wiringsamples import module, package from wiringsamples.service import Service -from wiringsamples.container import Container +from wiringsamples.container import Container, SubContainer class WiringTest(unittest.TestCase): @@ -61,6 +61,14 @@ class WiringTest(unittest.TestCase): service = test_class_object.method() self.assertIsInstance(service, Service) + def test_class_classmethod_wiring(self): + service = module.TestClass.class_method() + self.assertIsInstance(service, Service) + + def test_class_staticmethod_wiring(self): + service = module.TestClass.static_method() + self.assertIsInstance(service, Service) + def test_function_wiring(self): service = module.test_function() self.assertIsInstance(service, Service) @@ -153,3 +161,16 @@ class WiringTest(unittest.TestCase): from wiringsamples.package.subpackage import submodule self.container.unwire() self.assertIsInstance(submodule.test_function(), Provide) + + def test_wire_multiple_containers(self): + sub_container = SubContainer() + sub_container.wire( + modules=[module], + packages=[package], + ) + self.addCleanup(sub_container.unwire) + + service, some_value = module.test_provide_from_different_containers() + + self.assertIsInstance(service, Service) + self.assertEqual(some_value, 1)