Fix multiple containers wiring issue

This commit is contained in:
Roman Mogylatov 2020-10-28 13:44:11 -04:00
parent 2565a1eab0
commit 776cb4eebf
4 changed files with 38 additions and 8 deletions

View File

@ -9,6 +9,9 @@ follows `Semantic versioning`_
Develop
-----
- Fix wiring of multiple containers
(see issue `#313 <https://github.com/ets-labs/python-dependency-injector/issues/313>`_).
Thanks to `iskorini <https://github.com/iskorini>`_ for reporting the issue.
- Fix wiring for ``@classmethod``.
4.1.5

View File

@ -37,7 +37,7 @@ class ProvidersMap:
original_providers=container.declarative_parent.providers,
)
def resolve_provider(self, provider: providers.Provider) -> providers.Provider:
def resolve_provider(self, provider: providers.Provider) -> Optional[providers.Provider]:
if isinstance(provider, providers.Delegate):
return self._resolve_delegate(provider)
elif isinstance(provider, (
@ -54,10 +54,10 @@ class ProvidersMap:
else:
return self._resolve_provider(provider)
def _resolve_delegate(self, original: providers.Delegate) -> providers.Provider:
def _resolve_delegate(self, original: providers.Delegate) -> Optional[providers.Provider]:
return self._resolve_provider(original.provides)
def _resolve_provided_instance(self, original: providers.Provider) -> providers.Provider:
def _resolve_provided_instance(self, original: providers.Provider) -> Optional[providers.Provider]:
modifiers = []
while isinstance(original, (
providers.ProvidedInstance,
@ -69,6 +69,8 @@ class ProvidersMap:
original = original.provides
new = self._resolve_provider(original)
if new is None:
return None
for modifier in modifiers:
if isinstance(modifier, providers.ProvidedInstance):
@ -89,9 +91,11 @@ class ProvidersMap:
self,
original: providers.ConfigurationOption,
as_: Any = None,
) -> providers.Provider:
) -> Optional[providers.Provider]:
original_root = original.root
new = self._resolve_provider(original_root)
if new is None:
return None
new = cast(providers.Configuration, new)
for segment in original.get_name_segments():
@ -106,11 +110,11 @@ class ProvidersMap:
return new
def _resolve_provider(self, original: providers.Provider) -> providers.Provider:
def _resolve_provider(self, original: providers.Provider) -> Optional[providers.Provider]:
try:
return self._map[original]
except KeyError:
raise Exception('Unable to resolve original provider')
pass
@classmethod
def _create_providers_map(
@ -216,6 +220,9 @@ def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) ->
marker = parameter.default
provider = providers_map.resolve_provider(marker.provider)
if provider is None:
continue
if isinstance(marker, Provide):
injections[parameter_name] = provider
elif isinstance(marker, Provider):

View File

@ -5,7 +5,7 @@ from typing import Callable
from dependency_injector.wiring import Provide, Provider
from .container import Container
from .container import Container, SubContainer
from .service import Service
@ -58,3 +58,10 @@ def test_subcontainer_provider(some_value: int = Provide[Container.sub.int_objec
def test_config_invariant(some_value: int = Provide[Container.config.option[Container.config.switch]]):
return some_value
def test_provide_from_different_containers(
service: Service = Provide[Container.service],
some_value: int = Provide[SubContainer.int_object],
):
return service, some_value

View File

@ -16,7 +16,7 @@ sys.path.append(_SAMPLES_DIR)
from wiringsamples import module, package
from wiringsamples.service import Service
from wiringsamples.container import Container
from wiringsamples.container import Container, SubContainer
class WiringTest(unittest.TestCase):
@ -161,3 +161,16 @@ class WiringTest(unittest.TestCase):
from wiringsamples.package.subpackage import submodule
self.container.unwire()
self.assertIsInstance(submodule.test_function(), Provide)
def test_wire_multiple_containers(self):
sub_container = SubContainer()
sub_container.wire(
modules=[module],
packages=[package],
)
self.addCleanup(sub_container.unwire)
service, some_value = module.test_provide_from_different_containers()
self.assertIsInstance(service, Service)
self.assertEqual(some_value, 1)