From da133414531ef8031255c2240d41ff80c407f62e Mon Sep 17 00:00:00 2001 From: Roman Mogylatov Date: Sun, 28 Feb 2021 21:07:50 -0500 Subject: [PATCH] Wiring: attribute injections (#414) * Add implementation * Add tests for module and class * Add tests for module and class for string ids * Update tests with typing * Add tests for invalid type of marker * Add docs and the example * Update changelog * Fix Python 3.6 tests and flake8 --- docs/main/changelog.rst | 7 + docs/wiring.rst | 23 ++++ examples/wiring/example_attribute.py | 31 +++++ src/dependency_injector/wiring.py | 123 ++++++++++++++---- tests/unit/samples/wiringsamples/module.py | 10 ++ .../module_invalid_attr_injection.py | 8 ++ .../samples/wiringstringidssamples/module.py | 9 ++ tests/unit/wiring/test_wiring_py36.py | 32 +++++ .../wiring/test_wiring_string_ids_py36.py | 26 +++- 9 files changed, 242 insertions(+), 27 deletions(-) create mode 100644 examples/wiring/example_attribute.py create mode 100644 tests/unit/samples/wiringsamples/module_invalid_attr_injection.py diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index b2143e90..c74f716c 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -7,6 +7,13 @@ that were made in every particular version. From version 0.7.6 *Dependency Injector* framework strictly follows `Semantic versioning`_ +Development version +------------------- +- Add wiring injections into modules and class attributes. + See issue: `#411 `_. + Many thanks to `@brunopereira27 `_ for submitting + the use case. + 4.27.0 ------ - Introduce wiring inspect filter to filter out ``flask.request`` and other local proxy objects diff --git a/docs/wiring.rst b/docs/wiring.rst index d4e2d6a6..551720f5 100644 --- a/docs/wiring.rst +++ b/docs/wiring.rst @@ -164,6 +164,29 @@ To inject a container use special identifier ````: def foo(container: Container = Provide['']) -> None: ... + +Making injections into modules and class attributes +--------------------------------------------------- + +You can use wiring to make injections into modules and class attributes. + +.. literalinclude:: ../examples/wiring/example_attribute.py + :language: python + :lines: 3- + :emphasize-lines: 16,21 + +You could also use string identifiers to avoid a dependency on a container: + +.. code-block:: python + :emphasize-lines: 1,6 + + service: Service = Provide['service'] + + + class Main: + + service: Service = Provide['service'] + Wiring with modules and packages -------------------------------- diff --git a/examples/wiring/example_attribute.py b/examples/wiring/example_attribute.py new file mode 100644 index 00000000..9868ae66 --- /dev/null +++ b/examples/wiring/example_attribute.py @@ -0,0 +1,31 @@ +"""Wiring attribute example.""" + +import sys + +from dependency_injector import containers, providers +from dependency_injector.wiring import Provide + + +class Service: + ... + + +class Container(containers.DeclarativeContainer): + + service = providers.Factory(Service) + + +service: Service = Provide[Container.service] + + +class Main: + + service: Service = Provide[Container.service] + + +if __name__ == '__main__': + container = Container() + container.wire(modules=[sys.modules[__name__]]) + + assert isinstance(service, Service) + assert isinstance(Main.service, Service) diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index c5a591da..222e0d63 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -20,6 +20,7 @@ from typing import ( TypeVar, Type, Union, + Set, cast, ) @@ -82,22 +83,53 @@ F = TypeVar('F', bound=Callable[..., Any]) Container = Any -class Registry: +class PatchedRegistry: def __init__(self): - self._storage = set() + self._callables: Set[Callable[..., Any]] = set() + self._attributes: Set[PatchedAttribute] = set() - def add(self, patched: Callable[..., Any]) -> None: - self._storage.add(patched) + def add_callable(self, patched: Callable[..., Any]) -> None: + self._callables.add(patched) - def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]: - for patched in self._storage: + def get_callables_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]: + for patched in self._callables: if patched.__module__ != module.__name__: continue yield patched + def add_attribute(self, patched: 'PatchedAttribute'): + self._attributes.add(patched) -_patched_registry = Registry() + def get_attributes_from_module(self, module: ModuleType) -> Iterator['PatchedAttribute']: + for attribute in self._attributes: + if not attribute.is_in_module(module): + continue + yield attribute + + def clear_module_attributes(self, module: ModuleType): + for attribute in self._attributes.copy(): + if not attribute.is_in_module(module): + continue + self._attributes.remove(attribute) + + +class PatchedAttribute: + + def __init__(self, member: Any, name: str, marker: '_Marker'): + self.member = member + self.name = name + self.marker = marker + + @property + def module_name(self) -> str: + if isinstance(self.member, ModuleType): + return self.member.__name__ + else: + return self.member.__module__ + + def is_in_module(self, module: ModuleType) -> bool: + return self.module_name == module.__name__ class ProvidersMap: @@ -278,9 +310,6 @@ class InspectFilter: and issubclass(instance, starlette.requests.Request) -inspect_filter = InspectFilter() - - def wire( # noqa: C901 container: Container, *, @@ -301,20 +330,27 @@ def wire( # noqa: C901 providers_map = ProvidersMap(container) for module in modules: - for name, member in inspect.getmembers(module): - if inspect_filter.is_excluded(member): + for member_name, member in inspect.getmembers(module): + if _inspect_filter.is_excluded(member): continue - if inspect.isfunction(member): - _patch_fn(module, name, member, providers_map) - elif inspect.isclass(member): - for method_name, method in inspect.getmembers(member, _is_method): - _patch_method(member, method_name, method, providers_map) - for patched in _patched_registry.get_from_module(module): + if _is_marker(member): + _patch_attribute(module, member_name, member, providers_map) + elif inspect.isfunction(member): + _patch_fn(module, member_name, member, providers_map) + elif inspect.isclass(member): + cls = member + for cls_member_name, cls_member in inspect.getmembers(cls): + if _is_marker(cls_member): + _patch_attribute(cls, cls_member_name, cls_member, providers_map) + elif _is_method(cls_member): + _patch_method(cls, cls_member_name, cls_member, providers_map) + + for patched in _patched_registry.get_callables_from_module(module): _bind_injections(patched, providers_map) -def unwire( +def unwire( # noqa: C901 *, modules: Optional[Iterable[ModuleType]] = None, packages: Optional[Iterable[ModuleType]] = None, @@ -335,15 +371,19 @@ def unwire( for method_name, method in inspect.getmembers(member, inspect.isfunction): _unpatch(member, method_name, method) - for patched in _patched_registry.get_from_module(module): + for patched in _patched_registry.get_callables_from_module(module): _unbind_injections(patched) + for patched_attribute in _patched_registry.get_attributes_from_module(module): + _unpatch_attribute(patched_attribute) + _patched_registry.clear_module_attributes(module) + def inject(fn: F) -> F: """Decorate callable with injecting decorator.""" reference_injections, reference_closing = _fetch_reference_injections(fn) patched = _get_patched(fn, reference_injections, reference_closing) - _patched_registry.add(patched) + _patched_registry.add_callable(patched) return cast(F, patched) @@ -358,7 +398,7 @@ def _patch_fn( if not reference_injections: return fn = _get_patched(fn, reference_injections, reference_closing) - _patched_registry.add(fn) + _patched_registry.add_callable(fn) _bind_injections(fn, providers_map) @@ -384,7 +424,7 @@ def _patch_method( if not reference_injections: return fn = _get_patched(fn, reference_injections, reference_closing) - _patched_registry.add(fn) + _patched_registry.add_callable(fn) _bind_injections(fn, providers_map) @@ -411,6 +451,31 @@ def _unpatch( _unbind_injections(fn) +def _patch_attribute( + member: Any, + name: str, + marker: '_Marker', + providers_map: ProvidersMap, +) -> None: + provider = providers_map.resolve_provider(marker.provider, marker.modifier) + if provider is None: + return + + _patched_registry.add_attribute(PatchedAttribute(member, name, marker)) + + if isinstance(marker, Provide): + instance = provider() + setattr(member, name, instance) + elif isinstance(marker, Provider): + setattr(member, name, provider) + else: + raise Exception(f'Unknown type of marker {marker}') + + +def _unpatch_attribute(patched: PatchedAttribute) -> None: + setattr(patched.member, patched.name, patched.marker) + + def _fetch_reference_injections( fn: Callable[..., Any], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -484,6 +549,10 @@ def _is_method(member): return inspect.ismethod(member) or inspect.isfunction(member) +def _is_marker(member): + return isinstance(member, _Marker) + + def _get_patched(fn, reference_injections, reference_closing): if inspect.iscoroutinefunction(fn): patched = _get_async_patched(fn) @@ -825,9 +894,6 @@ class AutoLoader: importlib.invalidate_caches() -_loader = AutoLoader() - - def register_loader_containers(*containers: Container) -> None: """Register containers in auto-wiring module loader.""" _loader.register_containers(*containers) @@ -851,3 +917,8 @@ def uninstall_loader() -> None: def is_loader_installed() -> bool: """Check if auto-wiring module loader hook is installed.""" return _loader.installed + + +_patched_registry = PatchedRegistry() +_inspect_filter = InspectFilter() +_loader = AutoLoader() diff --git a/tests/unit/samples/wiringsamples/module.py b/tests/unit/samples/wiringsamples/module.py index 4c02b553..333de332 100644 --- a/tests/unit/samples/wiringsamples/module.py +++ b/tests/unit/samples/wiringsamples/module.py @@ -3,14 +3,24 @@ from decimal import Decimal from typing import Callable +from dependency_injector import providers from dependency_injector.wiring import inject, Provide, Provider from .container import Container, SubContainer from .service import Service +service: Service = Provide[Container.service] +service_provider: Callable[..., Service] = Provider[Container.service] +undefined: Callable = Provide[providers.Provider()] + + class TestClass: + service: Service = Provide[Container.service] + service_provider: Callable[..., Service] = Provider[Container.service] + undefined: Callable = Provide[providers.Provider()] + @inject def __init__(self, service: Service = Provide[Container.service]): self.service = service diff --git a/tests/unit/samples/wiringsamples/module_invalid_attr_injection.py b/tests/unit/samples/wiringsamples/module_invalid_attr_injection.py new file mode 100644 index 00000000..20f20846 --- /dev/null +++ b/tests/unit/samples/wiringsamples/module_invalid_attr_injection.py @@ -0,0 +1,8 @@ +"""Test module for wiring with invalid type of marker for attribute injection.""" + +from dependency_injector.wiring import Closing + +from .container import Container + + +service = Closing[Container.service] diff --git a/tests/unit/samples/wiringstringidssamples/module.py b/tests/unit/samples/wiringstringidssamples/module.py index 019e290b..e56b7892 100644 --- a/tests/unit/samples/wiringstringidssamples/module.py +++ b/tests/unit/samples/wiringstringidssamples/module.py @@ -19,8 +19,17 @@ from .container import Container from .service import Service +service: Service = Provide['service'] +service_provider: Callable[..., Service] = Provider['service'] +undefined: Callable = Provide['undefined'] + + class TestClass: + service: Service = Provide['service'] + service_provider: Callable[..., Service] = Provider['service'] + undefined: Callable = Provide['undefined'] + @inject def __init__(self, service: Service = Provide['service']): self.service = service diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py index fd08799b..91e7df5e 100644 --- a/tests/unit/wiring/test_wiring_py36.py +++ b/tests/unit/wiring/test_wiring_py36.py @@ -6,6 +6,7 @@ import unittest from dependency_injector.wiring import ( wire, Provide, + Provider, Closing, register_loader_containers, unregister_loader_containers, @@ -64,6 +65,20 @@ class WiringTest(unittest.TestCase): service = test_function() self.assertIsInstance(service, Service) + def test_module_attributes_wiring(self): + self.assertIsInstance(module.service, Service) + self.assertIsInstance(module.service_provider(), Service) + self.assertIsInstance(module.undefined, Provide) + + def test_module_attribute_wiring_with_invalid_marker(self): + from wiringsamples import module_invalid_attr_injection + with self.assertRaises(Exception) as context: + self.container.wire(modules=[module_invalid_attr_injection]) + self.assertEqual( + str(context.exception), + 'Unknown type of marker {0}'.format(module_invalid_attr_injection.service), + ) + def test_class_wiring(self): test_class_object = module.TestClass() self.assertIsInstance(test_class_object.service, Service) @@ -97,6 +112,11 @@ class WiringTest(unittest.TestCase): service = instance.static_method() self.assertIsInstance(service, Service) + def test_class_attribute_wiring(self): + self.assertIsInstance(module.TestClass.service, Service) + self.assertIsInstance(module.TestClass.service_provider(), Service) + self.assertIsInstance(module.TestClass.undefined, Provide) + def test_function_wiring(self): service = module.test_function() self.assertIsInstance(service, Service) @@ -215,6 +235,18 @@ class WiringTest(unittest.TestCase): self.container.unwire() self.assertIsInstance(submodule.test_function(), Provide) + def test_unwire_module_attributes(self): + self.container.unwire() + self.assertIsInstance(module.service, Provide) + self.assertIsInstance(module.service_provider, Provider) + self.assertIsInstance(module.undefined, Provide) + + def test_unwire_class_attributes(self): + self.container.unwire() + self.assertIsInstance(module.TestClass.service, Provide) + self.assertIsInstance(module.TestClass.service_provider, Provider) + self.assertIsInstance(module.TestClass.undefined, Provide) + def test_wire_multiple_containers(self): sub_container = SubContainer() sub_container.wire( diff --git a/tests/unit/wiring/test_wiring_string_ids_py36.py b/tests/unit/wiring/test_wiring_string_ids_py36.py index 42002372..0ecc5d7c 100644 --- a/tests/unit/wiring/test_wiring_string_ids_py36.py +++ b/tests/unit/wiring/test_wiring_string_ids_py36.py @@ -4,7 +4,9 @@ import unittest from dependency_injector.wiring import ( wire, Provide, - Closing) + Provider, + Closing, +) from dependency_injector import errors # Runtime import to avoid syntax errors in samples on Python < 3.5 @@ -59,6 +61,11 @@ class WiringTest(unittest.TestCase): service = test_function() self.assertIsInstance(service, Service) + def test_module_attributes_wiring(self): + self.assertIsInstance(module.service, Service) + self.assertIsInstance(module.service_provider(), Service) + self.assertIsInstance(module.undefined, Provide) + def test_class_wiring(self): test_class_object = module.TestClass() self.assertIsInstance(test_class_object.service, Service) @@ -92,6 +99,11 @@ class WiringTest(unittest.TestCase): service = instance.static_method() self.assertIsInstance(service, Service) + def test_class_attribute_wiring(self): + self.assertIsInstance(module.TestClass.service, Service) + self.assertIsInstance(module.TestClass.service_provider(), Service) + self.assertIsInstance(module.TestClass.undefined, Provide) + def test_function_wiring(self): service = module.test_function() self.assertIsInstance(service, Service) @@ -210,6 +222,18 @@ class WiringTest(unittest.TestCase): self.container.unwire() self.assertIsInstance(submodule.test_function(), Provide) + def test_unwire_module_attributes(self): + self.container.unwire() + self.assertIsInstance(module.service, Provide) + self.assertIsInstance(module.service_provider, Provider) + self.assertIsInstance(module.undefined, Provide) + + def test_unwire_class_attributes(self): + self.container.unwire() + self.assertIsInstance(module.TestClass.service, Provide) + self.assertIsInstance(module.TestClass.service_provider, Provider) + self.assertIsInstance(module.TestClass.undefined, Provide) + def test_wire_multiple_containers(self): sub_container = SubContainer() sub_container.wire(