From fb2d927caea56874d897fc38167841632b06e2c5 Mon Sep 17 00:00:00 2001 From: Roman Mogylatov Date: Tue, 3 Nov 2020 15:59:02 -0500 Subject: [PATCH] Fix wiring for @classmethod and @staticmethod --- docs/main/changelog.rst | 5 +++ src/dependency_injector/wiring.py | 49 ++++++++++++++++++++------- tests/unit/wiring/test_wiring_py36.py | 10 ++++++ 3 files changed, 52 insertions(+), 12 deletions(-) diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index e22993d1..6fc6a0c5 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -7,6 +7,11 @@ that were made in every particular version. From version 0.7.6 *Dependency Injector* framework strictly follows `Semantic versioning`_ +Develop +------- +- Fix a bug in ``wiring`` with improper patching of ``@classmethod`` and ``@staticmethod`` decorated methods + (See issue `#318 `_). + 4.3.2 ----- - Fix a bug in ``wiring`` with mistakenly initialized and shutdown resource with ``Closing`` diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index f169a7e8..1c4fcbf4 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -5,7 +5,7 @@ import inspect import pkgutil import sys from types import ModuleType -from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, cast +from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast if sys.version_info < (3, 7): from typing import GenericMeta @@ -176,7 +176,7 @@ def wire( _patch_fn(module, name, member, providers_map) elif inspect.isclass(member): for method_name, method in inspect.getmembers(member, _is_method): - _patch_fn(member, method_name, method, providers_map) + _patch_method(member, method_name, method, providers_map) def unwire( @@ -195,10 +195,10 @@ def unwire( for module in modules: for name, member in inspect.getmembers(module): if inspect.isfunction(member): - _unpatch_fn(module, name, member) + _unpatch(module, name, member) elif inspect.isclass(member): for method_name, method in inspect.getmembers(member, inspect.isfunction): - _unpatch_fn(member, method_name, method) + _unpatch(member, method_name, method) def _patch_fn( @@ -210,10 +210,41 @@ def _patch_fn( injections, closing = _resolve_injections(fn, providers_map) if not injections: return - setattr(module, name, _patch_with_injections(fn, injections, closing)) + patched = _patch_with_injections(fn, injections, closing) + setattr(module, name, _wrap_patched(patched, fn, injections, closing)) -def _unpatch_fn( +def _patch_method( + cls: Type, + name: str, + method: Callable[..., Any], + providers_map: ProvidersMap, +) -> None: + injections, closing = _resolve_injections(method, providers_map) + if not injections: + return + + if hasattr(cls, '__dict__') \ + and name in cls.__dict__ \ + and isinstance(cls.__dict__[name], (classmethod, staticmethod)): + method = cls.__dict__[name] + patched = _patch_with_injections(method.__func__, injections, closing) + patched = type(method)(patched) + else: + patched = _patch_with_injections(method, injections, closing) + + setattr(cls, name, _wrap_patched(patched, method, injections, closing)) + + +def _wrap_patched(patched: Callable[..., Any], original, injections, closing): + patched.__wired__ = True + patched.__original__ = original + patched.__injections__ = injections + patched.__closing__ = closing + return patched + + +def _unpatch( module: ModuleType, name: str, fn: Callable[..., Any], @@ -276,12 +307,6 @@ def _patch_with_injections(fn, injections, closing): _patched = _get_async_patched(fn, injections, closing) else: _patched = _get_patched(fn, injections, closing) - - _patched.__wired__ = True - _patched.__original__ = fn - _patched.__injections__ = injections - _patched.__closing__ = closing - return _patched diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py index d8e78c1d..753c88b4 100644 --- a/tests/unit/wiring/test_wiring_py36.py +++ b/tests/unit/wiring/test_wiring_py36.py @@ -65,10 +65,20 @@ class WiringTest(unittest.TestCase): service = module.TestClass.class_method() self.assertIsInstance(service, Service) + def test_instance_classmethod_wiring(self): + instance = module.TestClass() + service = instance.class_method() + self.assertIsInstance(service, Service) + def test_class_staticmethod_wiring(self): service = module.TestClass.static_method() self.assertIsInstance(service, Service) + def test_instance_staticmethod_wiring(self): + instance = module.TestClass() + service = instance.static_method() + self.assertIsInstance(service, Service) + def test_function_wiring(self): service = module.test_function() self.assertIsInstance(service, Service)