Fix wiring for @classmethod and @staticmethod

This commit is contained in:
Roman Mogylatov 2020-11-03 15:59:02 -05:00
parent c1cf1bfa1c
commit fb2d927cae
3 changed files with 52 additions and 12 deletions

View File

@ -7,6 +7,11 @@ that were made in every particular version.
From version 0.7.6 *Dependency Injector* framework strictly
follows `Semantic versioning`_
Develop
-------
- Fix a bug in ``wiring`` with improper patching of ``@classmethod`` and ``@staticmethod`` decorated methods
(See issue `#318 <https://github.com/ets-labs/python-dependency-injector/issues/318>`_).
4.3.2
-----
- Fix a bug in ``wiring`` with mistakenly initialized and shutdown resource with ``Closing``

View File

@ -5,7 +5,7 @@ import inspect
import pkgutil
import sys
from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, cast
from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast
if sys.version_info < (3, 7):
from typing import GenericMeta
@ -176,7 +176,7 @@ def wire(
_patch_fn(module, name, member, providers_map)
elif inspect.isclass(member):
for method_name, method in inspect.getmembers(member, _is_method):
_patch_fn(member, method_name, method, providers_map)
_patch_method(member, method_name, method, providers_map)
def unwire(
@ -195,10 +195,10 @@ def unwire(
for module in modules:
for name, member in inspect.getmembers(module):
if inspect.isfunction(member):
_unpatch_fn(module, name, member)
_unpatch(module, name, member)
elif inspect.isclass(member):
for method_name, method in inspect.getmembers(member, inspect.isfunction):
_unpatch_fn(member, method_name, method)
_unpatch(member, method_name, method)
def _patch_fn(
@ -210,10 +210,41 @@ def _patch_fn(
injections, closing = _resolve_injections(fn, providers_map)
if not injections:
return
setattr(module, name, _patch_with_injections(fn, injections, closing))
patched = _patch_with_injections(fn, injections, closing)
setattr(module, name, _wrap_patched(patched, fn, injections, closing))
def _unpatch_fn(
def _patch_method(
cls: Type,
name: str,
method: Callable[..., Any],
providers_map: ProvidersMap,
) -> None:
injections, closing = _resolve_injections(method, providers_map)
if not injections:
return
if hasattr(cls, '__dict__') \
and name in cls.__dict__ \
and isinstance(cls.__dict__[name], (classmethod, staticmethod)):
method = cls.__dict__[name]
patched = _patch_with_injections(method.__func__, injections, closing)
patched = type(method)(patched)
else:
patched = _patch_with_injections(method, injections, closing)
setattr(cls, name, _wrap_patched(patched, method, injections, closing))
def _wrap_patched(patched: Callable[..., Any], original, injections, closing):
patched.__wired__ = True
patched.__original__ = original
patched.__injections__ = injections
patched.__closing__ = closing
return patched
def _unpatch(
module: ModuleType,
name: str,
fn: Callable[..., Any],
@ -276,12 +307,6 @@ def _patch_with_injections(fn, injections, closing):
_patched = _get_async_patched(fn, injections, closing)
else:
_patched = _get_patched(fn, injections, closing)
_patched.__wired__ = True
_patched.__original__ = fn
_patched.__injections__ = injections
_patched.__closing__ = closing
return _patched

View File

@ -65,10 +65,20 @@ class WiringTest(unittest.TestCase):
service = module.TestClass.class_method()
self.assertIsInstance(service, Service)
def test_instance_classmethod_wiring(self):
instance = module.TestClass()
service = instance.class_method()
self.assertIsInstance(service, Service)
def test_class_staticmethod_wiring(self):
service = module.TestClass.static_method()
self.assertIsInstance(service, Service)
def test_instance_staticmethod_wiring(self):
instance = module.TestClass()
service = instance.static_method()
self.assertIsInstance(service, Service)
def test_function_wiring(self):
service = module.test_function()
self.assertIsInstance(service, Service)