Merge branch 'release/4.1.6' into master

This commit is contained in:
Roman Mogylatov 2020-10-28 14:23:14 -04:00
commit c0e5eac016
5 changed files with 77 additions and 11 deletions

View File

@ -7,6 +7,13 @@ that were made in every particular version.
From version 0.7.6 *Dependency Injector* framework strictly From version 0.7.6 *Dependency Injector* framework strictly
follows `Semantic versioning`_ follows `Semantic versioning`_
4.1.6
-----
- Fix wiring of multiple containers
(see issue `#313 <https://github.com/ets-labs/python-dependency-injector/issues/313>`_).
Thanks to `iskorini <https://github.com/iskorini>`_ for reporting the issue.
- Fix wiring for ``@classmethod``.
4.1.5 4.1.5
----- -----
- Fix Travis CI windows and MacOS builds. - Fix Travis CI windows and MacOS builds.

View File

@ -1,6 +1,6 @@
"""Top-level package.""" """Top-level package."""
__version__ = '4.1.5' __version__ = '4.1.6'
"""Version number. """Version number.
:type: str :type: str

View File

@ -37,7 +37,10 @@ class ProvidersMap:
original_providers=container.declarative_parent.providers, original_providers=container.declarative_parent.providers,
) )
def resolve_provider(self, provider: providers.Provider) -> providers.Provider: def resolve_provider(
self,
provider: providers.Provider,
) -> Optional[providers.Provider]:
if isinstance(provider, providers.Delegate): if isinstance(provider, providers.Delegate):
return self._resolve_delegate(provider) return self._resolve_delegate(provider)
elif isinstance(provider, ( elif isinstance(provider, (
@ -54,10 +57,16 @@ class ProvidersMap:
else: else:
return self._resolve_provider(provider) return self._resolve_provider(provider)
def _resolve_delegate(self, original: providers.Delegate) -> providers.Provider: def _resolve_delegate(
self,
original: providers.Delegate,
) -> Optional[providers.Provider]:
return self._resolve_provider(original.provides) return self._resolve_provider(original.provides)
def _resolve_provided_instance(self, original: providers.Provider) -> providers.Provider: def _resolve_provided_instance(
self,
original: providers.Provider,
) -> Optional[providers.Provider]:
modifiers = [] modifiers = []
while isinstance(original, ( while isinstance(original, (
providers.ProvidedInstance, providers.ProvidedInstance,
@ -69,6 +78,8 @@ class ProvidersMap:
original = original.provides original = original.provides
new = self._resolve_provider(original) new = self._resolve_provider(original)
if new is None:
return None
for modifier in modifiers: for modifier in modifiers:
if isinstance(modifier, providers.ProvidedInstance): if isinstance(modifier, providers.ProvidedInstance):
@ -89,9 +100,11 @@ class ProvidersMap:
self, self,
original: providers.ConfigurationOption, original: providers.ConfigurationOption,
as_: Any = None, as_: Any = None,
) -> providers.Provider: ) -> Optional[providers.Provider]:
original_root = original.root original_root = original.root
new = self._resolve_provider(original_root) new = self._resolve_provider(original_root)
if new is None:
return None
new = cast(providers.Configuration, new) new = cast(providers.Configuration, new)
for segment in original.get_name_segments(): for segment in original.get_name_segments():
@ -106,11 +119,14 @@ class ProvidersMap:
return new return new
def _resolve_provider(self, original: providers.Provider) -> providers.Provider: def _resolve_provider(
self,
original: providers.Provider,
) -> Optional[providers.Provider]:
try: try:
return self._map[original] return self._map[original]
except KeyError: except KeyError:
raise Exception('Unable to resolve original provider') pass
@classmethod @classmethod
def _create_providers_map( def _create_providers_map(
@ -158,7 +174,7 @@ def wire(
if inspect.isfunction(member): if inspect.isfunction(member):
_patch_fn(module, name, member, providers_map) _patch_fn(module, name, member, providers_map)
elif inspect.isclass(member): elif inspect.isclass(member):
for method_name, method in inspect.getmembers(member, inspect.isfunction): for method_name, method in inspect.getmembers(member, _is_method):
_patch_fn(member, method_name, method, providers_map) _patch_fn(member, method_name, method, providers_map)
@ -206,7 +222,7 @@ def _unpatch_fn(
setattr(module, name, _get_original_from_patched(fn)) setattr(module, name, _get_original_from_patched(fn))
def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Dict[str, Any]: # noqa def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Dict[str, Any]:
signature = inspect.signature(fn) signature = inspect.signature(fn)
injections = {} injections = {}
@ -216,6 +232,9 @@ def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) ->
marker = parameter.default marker = parameter.default
provider = providers_map.resolve_provider(marker.provider) provider = providers_map.resolve_provider(marker.provider)
if provider is None:
continue
if isinstance(marker, Provide): if isinstance(marker, Provide):
injections[parameter_name] = provider injections[parameter_name] = provider
elif isinstance(marker, Provider): elif isinstance(marker, Provider):
@ -235,6 +254,10 @@ def _fetch_modules(package):
return modules return modules
def _is_method(member):
return inspect.ismethod(member) or inspect.isfunction(member)
def _patch_with_injections(fn, injections): def _patch_with_injections(fn, injections):
if inspect.iscoroutinefunction(fn): if inspect.iscoroutinefunction(fn):
@functools.wraps(fn) @functools.wraps(fn)

View File

@ -5,7 +5,7 @@ from typing import Callable
from dependency_injector.wiring import Provide, Provider from dependency_injector.wiring import Provide, Provider
from .container import Container from .container import Container, SubContainer
from .service import Service from .service import Service
@ -17,6 +17,14 @@ class TestClass:
def method(self, service: Service = Provide[Container.service]): def method(self, service: Service = Provide[Container.service]):
return service return service
@classmethod
def class_method(cls, service: Service = Provide[Container.service]):
return service
@staticmethod
def static_method(service: Service = Provide[Container.service]):
return service
def test_function(service: Service = Provide[Container.service]): def test_function(service: Service = Provide[Container.service]):
return service return service
@ -50,3 +58,10 @@ def test_subcontainer_provider(some_value: int = Provide[Container.sub.int_objec
def test_config_invariant(some_value: int = Provide[Container.config.option[Container.config.switch]]): def test_config_invariant(some_value: int = Provide[Container.config.option[Container.config.switch]]):
return some_value return some_value
def test_provide_from_different_containers(
service: Service = Provide[Container.service],
some_value: int = Provide[SubContainer.int_object],
):
return service, some_value

View File

@ -16,7 +16,7 @@ sys.path.append(_SAMPLES_DIR)
from wiringsamples import module, package from wiringsamples import module, package
from wiringsamples.service import Service from wiringsamples.service import Service
from wiringsamples.container import Container from wiringsamples.container import Container, SubContainer
class WiringTest(unittest.TestCase): class WiringTest(unittest.TestCase):
@ -61,6 +61,14 @@ class WiringTest(unittest.TestCase):
service = test_class_object.method() service = test_class_object.method()
self.assertIsInstance(service, Service) self.assertIsInstance(service, Service)
def test_class_classmethod_wiring(self):
service = module.TestClass.class_method()
self.assertIsInstance(service, Service)
def test_class_staticmethod_wiring(self):
service = module.TestClass.static_method()
self.assertIsInstance(service, Service)
def test_function_wiring(self): def test_function_wiring(self):
service = module.test_function() service = module.test_function()
self.assertIsInstance(service, Service) self.assertIsInstance(service, Service)
@ -153,3 +161,16 @@ class WiringTest(unittest.TestCase):
from wiringsamples.package.subpackage import submodule from wiringsamples.package.subpackage import submodule
self.container.unwire() self.container.unwire()
self.assertIsInstance(submodule.test_function(), Provide) self.assertIsInstance(submodule.test_function(), Provide)
def test_wire_multiple_containers(self):
sub_container = SubContainer()
sub_container.wire(
modules=[module],
packages=[package],
)
self.addCleanup(sub_container.unwire)
service, some_value = module.test_provide_from_different_containers()
self.assertIsInstance(service, Service)
self.assertEqual(some_value, 1)