mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-02-06 06:30:51 +03:00
Add wiring of class methods
This commit is contained in:
parent
c4b5494b6b
commit
949a91b657
|
@ -158,7 +158,8 @@ def wire(
|
||||||
if inspect.isfunction(member):
|
if inspect.isfunction(member):
|
||||||
_patch_fn(module, name, member, providers_map)
|
_patch_fn(module, name, member, providers_map)
|
||||||
elif inspect.isclass(member):
|
elif inspect.isclass(member):
|
||||||
_patch_cls(member, providers_map)
|
for method_name, method in inspect.getmembers(member, inspect.isfunction):
|
||||||
|
_patch_fn(member, method_name, method, providers_map)
|
||||||
|
|
||||||
|
|
||||||
def unwire(
|
def unwire(
|
||||||
|
@ -179,29 +180,8 @@ def unwire(
|
||||||
if inspect.isfunction(member):
|
if inspect.isfunction(member):
|
||||||
_unpatch_fn(module, name, member)
|
_unpatch_fn(module, name, member)
|
||||||
elif inspect.isclass(member):
|
elif inspect.isclass(member):
|
||||||
_unpatch_cls(member,)
|
for method_name, method in inspect.getmembers(member, inspect.isfunction):
|
||||||
|
_unpatch_fn(member, method_name, method)
|
||||||
|
|
||||||
def _patch_cls(
|
|
||||||
cls: Type[Any],
|
|
||||||
providers_map: ProvidersMap,
|
|
||||||
) -> None:
|
|
||||||
if not hasattr(cls, '__init__'):
|
|
||||||
return
|
|
||||||
init_method = getattr(cls, '__init__')
|
|
||||||
injections = _resolve_injections(init_method, providers_map)
|
|
||||||
if not injections:
|
|
||||||
return
|
|
||||||
setattr(cls, '__init__', _patch_with_injections(init_method, injections))
|
|
||||||
|
|
||||||
|
|
||||||
def _unpatch_cls(cls: Type[Any]) -> None:
|
|
||||||
if not hasattr(cls, '__init__'):
|
|
||||||
return
|
|
||||||
init_method = getattr(cls, '__init__')
|
|
||||||
if not _is_patched(init_method):
|
|
||||||
return
|
|
||||||
setattr(cls, '__init__', _get_original_from_patched(init_method))
|
|
||||||
|
|
||||||
|
|
||||||
def _patch_fn(
|
def _patch_fn(
|
||||||
|
|
|
@ -14,6 +14,9 @@ class TestClass:
|
||||||
def __init__(self, service: Service = Provide[Container.service]):
|
def __init__(self, service: Service = Provide[Container.service]):
|
||||||
self.service = service
|
self.service = service
|
||||||
|
|
||||||
|
def method(self, service: Service = Provide[Container.service]):
|
||||||
|
return service
|
||||||
|
|
||||||
|
|
||||||
def test_function(service: Service = Provide[Container.service]):
|
def test_function(service: Service = Provide[Container.service]):
|
||||||
return service
|
return service
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from dependency_injector.wiring import wire
|
from dependency_injector.wiring import wire, Provide
|
||||||
|
|
||||||
from . import module, package
|
from . import module, package
|
||||||
from .service import Service
|
from .service import Service
|
||||||
|
@ -35,6 +35,11 @@ class WiringTest(unittest.TestCase):
|
||||||
test_class_object = module.TestClass(service=test_service)
|
test_class_object = module.TestClass(service=test_service)
|
||||||
self.assertIs(test_class_object.service, test_service)
|
self.assertIs(test_class_object.service, test_service)
|
||||||
|
|
||||||
|
def test_class_method_wiring(self):
|
||||||
|
test_class_object = module.TestClass()
|
||||||
|
service = test_class_object.method()
|
||||||
|
self.assertIsInstance(service, Service)
|
||||||
|
|
||||||
def test_function_wiring(self):
|
def test_function_wiring(self):
|
||||||
service = module.test_function()
|
service = module.test_function()
|
||||||
self.assertIsInstance(service, Service)
|
self.assertIsInstance(service, Service)
|
||||||
|
@ -104,3 +109,26 @@ class WiringTest(unittest.TestCase):
|
||||||
modules=[module],
|
modules=[module],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_unwire_function(self):
|
||||||
|
self.container.unwire()
|
||||||
|
self.assertIsInstance(module.test_function(), Provide)
|
||||||
|
|
||||||
|
def test_unwire_class(self):
|
||||||
|
self.container.unwire()
|
||||||
|
test_class_object = module.TestClass()
|
||||||
|
self.assertIsInstance(test_class_object.service, Provide)
|
||||||
|
|
||||||
|
def test_unwire_class_method(self):
|
||||||
|
self.container.unwire()
|
||||||
|
test_class_object = module.TestClass()
|
||||||
|
self.assertIsInstance(test_class_object.method(), Provide)
|
||||||
|
|
||||||
|
def test_unwire_package_function(self):
|
||||||
|
self.container.unwire()
|
||||||
|
from .package.subpackage.submodule import test_function
|
||||||
|
self.assertIsInstance(test_function(), Provide)
|
||||||
|
|
||||||
|
def test_unwire_package_function_by_reference(self):
|
||||||
|
from .package.subpackage import submodule
|
||||||
|
self.container.unwire()
|
||||||
|
self.assertIsInstance(submodule.test_function(), Provide)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user