diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 7806a06c..a8971082 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -158,7 +158,8 @@ def wire( if inspect.isfunction(member): _patch_fn(module, name, member, providers_map) elif inspect.isclass(member): - _patch_cls(member, providers_map) + for method_name, method in inspect.getmembers(member, inspect.isfunction): + _patch_fn(member, method_name, method, providers_map) def unwire( @@ -179,29 +180,8 @@ def unwire( if inspect.isfunction(member): _unpatch_fn(module, name, member) elif inspect.isclass(member): - _unpatch_cls(member,) - - -def _patch_cls( - cls: Type[Any], - providers_map: ProvidersMap, -) -> None: - if not hasattr(cls, '__init__'): - return - init_method = getattr(cls, '__init__') - injections = _resolve_injections(init_method, providers_map) - if not injections: - return - setattr(cls, '__init__', _patch_with_injections(init_method, injections)) - - -def _unpatch_cls(cls: Type[Any]) -> None: - if not hasattr(cls, '__init__'): - return - init_method = getattr(cls, '__init__') - if not _is_patched(init_method): - return - setattr(cls, '__init__', _get_original_from_patched(init_method)) + for method_name, method in inspect.getmembers(member, inspect.isfunction): + _unpatch_fn(member, method_name, method) def _patch_fn( diff --git a/tests/unit/wiring/module.py b/tests/unit/wiring/module.py index 8ac15f8c..d29a0ec9 100644 --- a/tests/unit/wiring/module.py +++ b/tests/unit/wiring/module.py @@ -14,6 +14,9 @@ class TestClass: def __init__(self, service: Service = Provide[Container.service]): self.service = service + def method(self, service: Service = Provide[Container.service]): + return service + def test_function(service: Service = Provide[Container.service]): return service diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py index fddafba2..862a8386 100644 --- a/tests/unit/wiring/test_wiring_py36.py +++ b/tests/unit/wiring/test_wiring_py36.py @@ -1,7 +1,7 @@ from decimal import Decimal import unittest -from dependency_injector.wiring import wire +from dependency_injector.wiring import wire, Provide from . import module, package from .service import Service @@ -35,6 +35,11 @@ class WiringTest(unittest.TestCase): test_class_object = module.TestClass(service=test_service) self.assertIs(test_class_object.service, test_service) + def test_class_method_wiring(self): + test_class_object = module.TestClass() + service = test_class_object.method() + self.assertIsInstance(service, Service) + def test_function_wiring(self): service = module.test_function() self.assertIsInstance(service, Service) @@ -104,3 +109,26 @@ class WiringTest(unittest.TestCase): modules=[module], ) + def test_unwire_function(self): + self.container.unwire() + self.assertIsInstance(module.test_function(), Provide) + + def test_unwire_class(self): + self.container.unwire() + test_class_object = module.TestClass() + self.assertIsInstance(test_class_object.service, Provide) + + def test_unwire_class_method(self): + self.container.unwire() + test_class_object = module.TestClass() + self.assertIsInstance(test_class_object.method(), Provide) + + def test_unwire_package_function(self): + self.container.unwire() + from .package.subpackage.submodule import test_function + self.assertIsInstance(test_function(), Provide) + + def test_unwire_package_function_by_reference(self): + from .package.subpackage import submodule + self.container.unwire() + self.assertIsInstance(submodule.test_function(), Provide)