Add wiring of class methods

This commit is contained in:
Roman Mogylatov 2020-09-29 23:40:13 -04:00
parent c4b5494b6b
commit 949a91b657
3 changed files with 36 additions and 25 deletions

View File

@ -158,7 +158,8 @@ def wire(
if inspect.isfunction(member): if inspect.isfunction(member):
_patch_fn(module, name, member, providers_map) _patch_fn(module, name, member, providers_map)
elif inspect.isclass(member): elif inspect.isclass(member):
_patch_cls(member, providers_map) for method_name, method in inspect.getmembers(member, inspect.isfunction):
_patch_fn(member, method_name, method, providers_map)
def unwire( def unwire(
@ -179,29 +180,8 @@ def unwire(
if inspect.isfunction(member): if inspect.isfunction(member):
_unpatch_fn(module, name, member) _unpatch_fn(module, name, member)
elif inspect.isclass(member): elif inspect.isclass(member):
_unpatch_cls(member,) for method_name, method in inspect.getmembers(member, inspect.isfunction):
_unpatch_fn(member, method_name, method)
def _patch_cls(
cls: Type[Any],
providers_map: ProvidersMap,
) -> None:
if not hasattr(cls, '__init__'):
return
init_method = getattr(cls, '__init__')
injections = _resolve_injections(init_method, providers_map)
if not injections:
return
setattr(cls, '__init__', _patch_with_injections(init_method, injections))
def _unpatch_cls(cls: Type[Any]) -> None:
if not hasattr(cls, '__init__'):
return
init_method = getattr(cls, '__init__')
if not _is_patched(init_method):
return
setattr(cls, '__init__', _get_original_from_patched(init_method))
def _patch_fn( def _patch_fn(

View File

@ -14,6 +14,9 @@ class TestClass:
def __init__(self, service: Service = Provide[Container.service]): def __init__(self, service: Service = Provide[Container.service]):
self.service = service self.service = service
def method(self, service: Service = Provide[Container.service]):
return service
def test_function(service: Service = Provide[Container.service]): def test_function(service: Service = Provide[Container.service]):
return service return service

View File

@ -1,7 +1,7 @@
from decimal import Decimal from decimal import Decimal
import unittest import unittest
from dependency_injector.wiring import wire from dependency_injector.wiring import wire, Provide
from . import module, package from . import module, package
from .service import Service from .service import Service
@ -35,6 +35,11 @@ class WiringTest(unittest.TestCase):
test_class_object = module.TestClass(service=test_service) test_class_object = module.TestClass(service=test_service)
self.assertIs(test_class_object.service, test_service) self.assertIs(test_class_object.service, test_service)
def test_class_method_wiring(self):
test_class_object = module.TestClass()
service = test_class_object.method()
self.assertIsInstance(service, Service)
def test_function_wiring(self): def test_function_wiring(self):
service = module.test_function() service = module.test_function()
self.assertIsInstance(service, Service) self.assertIsInstance(service, Service)
@ -104,3 +109,26 @@ class WiringTest(unittest.TestCase):
modules=[module], modules=[module],
) )
def test_unwire_function(self):
self.container.unwire()
self.assertIsInstance(module.test_function(), Provide)
def test_unwire_class(self):
self.container.unwire()
test_class_object = module.TestClass()
self.assertIsInstance(test_class_object.service, Provide)
def test_unwire_class_method(self):
self.container.unwire()
test_class_object = module.TestClass()
self.assertIsInstance(test_class_object.method(), Provide)
def test_unwire_package_function(self):
self.container.unwire()
from .package.subpackage.submodule import test_function
self.assertIsInstance(test_function(), Provide)
def test_unwire_package_function_by_reference(self):
from .package.subpackage import submodule
self.container.unwire()
self.assertIsInstance(submodule.test_function(), Provide)