Fix wiring for @classmethod and @staticmethod

This commit is contained in:
Roman Mogylatov 2020-11-03 15:59:02 -05:00
parent c1cf1bfa1c
commit fb2d927cae
3 changed files with 52 additions and 12 deletions

View File

@ -7,6 +7,11 @@ that were made in every particular version.
From version 0.7.6 *Dependency Injector* framework strictly From version 0.7.6 *Dependency Injector* framework strictly
follows `Semantic versioning`_ follows `Semantic versioning`_
Develop
-------
- Fix a bug in ``wiring`` with improper patching of ``@classmethod`` and ``@staticmethod`` decorated methods
(See issue `#318 <https://github.com/ets-labs/python-dependency-injector/issues/318>`_).
4.3.2 4.3.2
----- -----
- Fix a bug in ``wiring`` with mistakenly initialized and shutdown resource with ``Closing`` - Fix a bug in ``wiring`` with mistakenly initialized and shutdown resource with ``Closing``

View File

@ -5,7 +5,7 @@ import inspect
import pkgutil import pkgutil
import sys import sys
from types import ModuleType from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, cast from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast
if sys.version_info < (3, 7): if sys.version_info < (3, 7):
from typing import GenericMeta from typing import GenericMeta
@ -176,7 +176,7 @@ def wire(
_patch_fn(module, name, member, providers_map) _patch_fn(module, name, member, providers_map)
elif inspect.isclass(member): elif inspect.isclass(member):
for method_name, method in inspect.getmembers(member, _is_method): for method_name, method in inspect.getmembers(member, _is_method):
_patch_fn(member, method_name, method, providers_map) _patch_method(member, method_name, method, providers_map)
def unwire( def unwire(
@ -195,10 +195,10 @@ def unwire(
for module in modules: for module in modules:
for name, member in inspect.getmembers(module): for name, member in inspect.getmembers(module):
if inspect.isfunction(member): if inspect.isfunction(member):
_unpatch_fn(module, name, member) _unpatch(module, name, member)
elif inspect.isclass(member): elif inspect.isclass(member):
for method_name, method in inspect.getmembers(member, inspect.isfunction): for method_name, method in inspect.getmembers(member, inspect.isfunction):
_unpatch_fn(member, method_name, method) _unpatch(member, method_name, method)
def _patch_fn( def _patch_fn(
@ -210,10 +210,41 @@ def _patch_fn(
injections, closing = _resolve_injections(fn, providers_map) injections, closing = _resolve_injections(fn, providers_map)
if not injections: if not injections:
return return
setattr(module, name, _patch_with_injections(fn, injections, closing)) patched = _patch_with_injections(fn, injections, closing)
setattr(module, name, _wrap_patched(patched, fn, injections, closing))
def _unpatch_fn( def _patch_method(
cls: Type,
name: str,
method: Callable[..., Any],
providers_map: ProvidersMap,
) -> None:
injections, closing = _resolve_injections(method, providers_map)
if not injections:
return
if hasattr(cls, '__dict__') \
and name in cls.__dict__ \
and isinstance(cls.__dict__[name], (classmethod, staticmethod)):
method = cls.__dict__[name]
patched = _patch_with_injections(method.__func__, injections, closing)
patched = type(method)(patched)
else:
patched = _patch_with_injections(method, injections, closing)
setattr(cls, name, _wrap_patched(patched, method, injections, closing))
def _wrap_patched(patched: Callable[..., Any], original, injections, closing):
patched.__wired__ = True
patched.__original__ = original
patched.__injections__ = injections
patched.__closing__ = closing
return patched
def _unpatch(
module: ModuleType, module: ModuleType,
name: str, name: str,
fn: Callable[..., Any], fn: Callable[..., Any],
@ -276,12 +307,6 @@ def _patch_with_injections(fn, injections, closing):
_patched = _get_async_patched(fn, injections, closing) _patched = _get_async_patched(fn, injections, closing)
else: else:
_patched = _get_patched(fn, injections, closing) _patched = _get_patched(fn, injections, closing)
_patched.__wired__ = True
_patched.__original__ = fn
_patched.__injections__ = injections
_patched.__closing__ = closing
return _patched return _patched

View File

@ -65,10 +65,20 @@ class WiringTest(unittest.TestCase):
service = module.TestClass.class_method() service = module.TestClass.class_method()
self.assertIsInstance(service, Service) self.assertIsInstance(service, Service)
def test_instance_classmethod_wiring(self):
instance = module.TestClass()
service = instance.class_method()
self.assertIsInstance(service, Service)
def test_class_staticmethod_wiring(self): def test_class_staticmethod_wiring(self):
service = module.TestClass.static_method() service = module.TestClass.static_method()
self.assertIsInstance(service, Service) self.assertIsInstance(service, Service)
def test_instance_staticmethod_wiring(self):
instance = module.TestClass()
service = instance.static_method()
self.assertIsInstance(service, Service)
def test_function_wiring(self): def test_function_wiring(self):
service = module.test_function() service = module.test_function()
self.assertIsInstance(service, Service) self.assertIsInstance(service, Service)