mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2024-11-22 09:36:48 +03:00
Fix wiring for @classmethod and @staticmethod
This commit is contained in:
parent
c1cf1bfa1c
commit
fb2d927cae
|
@ -7,6 +7,11 @@ that were made in every particular version.
|
|||
From version 0.7.6 *Dependency Injector* framework strictly
|
||||
follows `Semantic versioning`_
|
||||
|
||||
Develop
|
||||
-------
|
||||
- Fix a bug in ``wiring`` with improper patching of ``@classmethod`` and ``@staticmethod`` decorated methods
|
||||
(See issue `#318 <https://github.com/ets-labs/python-dependency-injector/issues/318>`_).
|
||||
|
||||
4.3.2
|
||||
-----
|
||||
- Fix a bug in ``wiring`` with mistakenly initialized and shutdown resource with ``Closing``
|
||||
|
|
|
@ -5,7 +5,7 @@ import inspect
|
|||
import pkgutil
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, cast
|
||||
from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast
|
||||
|
||||
if sys.version_info < (3, 7):
|
||||
from typing import GenericMeta
|
||||
|
@ -176,7 +176,7 @@ def wire(
|
|||
_patch_fn(module, name, member, providers_map)
|
||||
elif inspect.isclass(member):
|
||||
for method_name, method in inspect.getmembers(member, _is_method):
|
||||
_patch_fn(member, method_name, method, providers_map)
|
||||
_patch_method(member, method_name, method, providers_map)
|
||||
|
||||
|
||||
def unwire(
|
||||
|
@ -195,10 +195,10 @@ def unwire(
|
|||
for module in modules:
|
||||
for name, member in inspect.getmembers(module):
|
||||
if inspect.isfunction(member):
|
||||
_unpatch_fn(module, name, member)
|
||||
_unpatch(module, name, member)
|
||||
elif inspect.isclass(member):
|
||||
for method_name, method in inspect.getmembers(member, inspect.isfunction):
|
||||
_unpatch_fn(member, method_name, method)
|
||||
_unpatch(member, method_name, method)
|
||||
|
||||
|
||||
def _patch_fn(
|
||||
|
@ -210,10 +210,41 @@ def _patch_fn(
|
|||
injections, closing = _resolve_injections(fn, providers_map)
|
||||
if not injections:
|
||||
return
|
||||
setattr(module, name, _patch_with_injections(fn, injections, closing))
|
||||
patched = _patch_with_injections(fn, injections, closing)
|
||||
setattr(module, name, _wrap_patched(patched, fn, injections, closing))
|
||||
|
||||
|
||||
def _unpatch_fn(
|
||||
def _patch_method(
|
||||
cls: Type,
|
||||
name: str,
|
||||
method: Callable[..., Any],
|
||||
providers_map: ProvidersMap,
|
||||
) -> None:
|
||||
injections, closing = _resolve_injections(method, providers_map)
|
||||
if not injections:
|
||||
return
|
||||
|
||||
if hasattr(cls, '__dict__') \
|
||||
and name in cls.__dict__ \
|
||||
and isinstance(cls.__dict__[name], (classmethod, staticmethod)):
|
||||
method = cls.__dict__[name]
|
||||
patched = _patch_with_injections(method.__func__, injections, closing)
|
||||
patched = type(method)(patched)
|
||||
else:
|
||||
patched = _patch_with_injections(method, injections, closing)
|
||||
|
||||
setattr(cls, name, _wrap_patched(patched, method, injections, closing))
|
||||
|
||||
|
||||
def _wrap_patched(patched: Callable[..., Any], original, injections, closing):
|
||||
patched.__wired__ = True
|
||||
patched.__original__ = original
|
||||
patched.__injections__ = injections
|
||||
patched.__closing__ = closing
|
||||
return patched
|
||||
|
||||
|
||||
def _unpatch(
|
||||
module: ModuleType,
|
||||
name: str,
|
||||
fn: Callable[..., Any],
|
||||
|
@ -276,12 +307,6 @@ def _patch_with_injections(fn, injections, closing):
|
|||
_patched = _get_async_patched(fn, injections, closing)
|
||||
else:
|
||||
_patched = _get_patched(fn, injections, closing)
|
||||
|
||||
_patched.__wired__ = True
|
||||
_patched.__original__ = fn
|
||||
_patched.__injections__ = injections
|
||||
_patched.__closing__ = closing
|
||||
|
||||
return _patched
|
||||
|
||||
|
||||
|
|
|
@ -65,10 +65,20 @@ class WiringTest(unittest.TestCase):
|
|||
service = module.TestClass.class_method()
|
||||
self.assertIsInstance(service, Service)
|
||||
|
||||
def test_instance_classmethod_wiring(self):
|
||||
instance = module.TestClass()
|
||||
service = instance.class_method()
|
||||
self.assertIsInstance(service, Service)
|
||||
|
||||
def test_class_staticmethod_wiring(self):
|
||||
service = module.TestClass.static_method()
|
||||
self.assertIsInstance(service, Service)
|
||||
|
||||
def test_instance_staticmethod_wiring(self):
|
||||
instance = module.TestClass()
|
||||
service = instance.static_method()
|
||||
self.assertIsInstance(service, Service)
|
||||
|
||||
def test_function_wiring(self):
|
||||
service = module.test_function()
|
||||
self.assertIsInstance(service, Service)
|
||||
|
|
Loading…
Reference in New Issue
Block a user