diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 8fffd56d..6a5f176e 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -6,7 +6,7 @@ import importlib import pkgutil import sys from types import ModuleType -from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast +from typing import Optional, Iterable, Iterator, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast if sys.version_info < (3, 7): from typing import GenericMeta @@ -32,6 +32,24 @@ F = TypeVar('F', bound=Callable[..., Any]) Container = Any +class Registry: + + def __init__(self): + self._storage = set() + + def add(self, patched: Callable[..., Any]) -> None: + self._storage.add(patched) + + def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]: + for patched in self._storage: + if patched.__module__ != module.__name__: + continue + yield patched + + +_patched_registry = Registry() + + class ProvidersMap: def __init__(self, container): @@ -181,6 +199,9 @@ def wire( for method_name, method in inspect.getmembers(member, _is_method): _patch_method(member, method_name, method, providers_map) + for patched in _patched_registry.get_from_module(module): + _bind_injections(patched, providers_map) + def unwire( *, @@ -203,11 +224,15 @@ def unwire( for method_name, method in inspect.getmembers(member, inspect.isfunction): _unpatch(member, method_name, method) + for patched in _patched_registry.get_from_module(module): + _unbind_injections(patched) + def inject(fn: F) -> F: """Decorate callable with injecting decorator.""" reference_injections, reference_closing = _fetch_reference_injections(fn) patched = _get_patched(fn, reference_injections, reference_closing) + _patched_registry.add(patched) return cast(F, patched) @@ -222,6 +247,7 @@ def _patch_fn( if not reference_injections: return fn = _get_patched(fn, reference_injections, reference_closing) + _patched_registry.add(fn) _bind_injections(fn, providers_map) @@ -247,6 +273,7 @@ def _patch_method( if not reference_injections: return fn = _get_patched(fn, reference_injections, reference_closing) + _patched_registry.add(fn) _bind_injections(fn, providers_map) diff --git a/tests/unit/samples/wiringsamples/module.py b/tests/unit/samples/wiringsamples/module.py index c39d1414..ac8517ea 100644 --- a/tests/unit/samples/wiringsamples/module.py +++ b/tests/unit/samples/wiringsamples/module.py @@ -3,7 +3,7 @@ from decimal import Decimal from typing import Callable -from dependency_injector.wiring import Provide, Provider +from dependency_injector.wiring import inject, Provide, Provider from .container import Container, SubContainer from .service import Service @@ -65,3 +65,17 @@ def test_provide_from_different_containers( some_value: int = Provide[SubContainer.int_object], ): return service, some_value + + +class ClassDecorator: + def __init__(self, fn): + self._fn = fn + + def __call__(self, *args, **kwargs): + return self._fn(*args, **kwargs) + + +@ClassDecorator +@inject +def test_class_decorator(service: Service = Provide[Container.service]): + return service diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py index b1886ea7..7c061ad9 100644 --- a/tests/unit/wiring/test_wiring_py36.py +++ b/tests/unit/wiring/test_wiring_py36.py @@ -226,6 +226,10 @@ class WiringTest(unittest.TestCase): self.assertEqual(result_2.init_counter, 0) self.assertEqual(result_2.shutdown_counter, 0) + def test_class_decorator(self): + service = module.test_class_decorator() + self.assertIsInstance(service, Service) + class WiringAndFastAPITest(unittest.TestCase):