From 2565a1eab00610cca80c57a090e2015139f248c4 Mon Sep 17 00:00:00 2001 From: Roman Mogylatov Date: Wed, 28 Oct 2020 13:11:07 -0400 Subject: [PATCH] Fix wiring for @classmethod --- docs/main/changelog.rst | 4 ++++ src/dependency_injector/wiring.py | 6 +++++- tests/unit/samples/wiringsamples/module.py | 8 ++++++++ tests/unit/wiring/test_wiring_py36.py | 8 ++++++++ 4 files changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index 3faef89b..01a71c21 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -7,6 +7,10 @@ that were made in every particular version. From version 0.7.6 *Dependency Injector* framework strictly follows `Semantic versioning`_ +Develop +----- +- Fix wiring for ``@classmethod``. + 4.1.5 ----- - Fix Travis CI windows and MacOS builds. diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index dbc5f9ba..5e8378bd 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -158,7 +158,7 @@ def wire( if inspect.isfunction(member): _patch_fn(module, name, member, providers_map) elif inspect.isclass(member): - for method_name, method in inspect.getmembers(member, inspect.isfunction): + for method_name, method in inspect.getmembers(member, _is_method): _patch_fn(member, method_name, method, providers_map) @@ -235,6 +235,10 @@ def _fetch_modules(package): return modules +def _is_method(member): + return inspect.ismethod(member) or inspect.isfunction(member) + + def _patch_with_injections(fn, injections): if inspect.iscoroutinefunction(fn): @functools.wraps(fn) diff --git a/tests/unit/samples/wiringsamples/module.py b/tests/unit/samples/wiringsamples/module.py index d29a0ec9..917b7691 100644 --- a/tests/unit/samples/wiringsamples/module.py +++ b/tests/unit/samples/wiringsamples/module.py @@ -17,6 +17,14 @@ class TestClass: def method(self, service: Service = Provide[Container.service]): return service + @classmethod + def class_method(cls, service: Service = Provide[Container.service]): + return service + + @staticmethod + def static_method(service: Service = Provide[Container.service]): + return service + def test_function(service: Service = Provide[Container.service]): return service diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py index 6e88be66..ff9925a5 100644 --- a/tests/unit/wiring/test_wiring_py36.py +++ b/tests/unit/wiring/test_wiring_py36.py @@ -61,6 +61,14 @@ class WiringTest(unittest.TestCase): service = test_class_object.method() self.assertIsInstance(service, Service) + def test_class_classmethod_wiring(self): + service = module.TestClass.class_method() + self.assertIsInstance(service, Service) + + def test_class_staticmethod_wiring(self): + service = module.TestClass.static_method() + self.assertIsInstance(service, Service) + def test_function_wiring(self): service = module.test_function() self.assertIsInstance(service, Service)