diff --git a/tests/unit/samples/wiringsamples/module.py b/tests/unit/samples/wiringsamples/module.py index 4c02b553..a1fb992e 100644 --- a/tests/unit/samples/wiringsamples/module.py +++ b/tests/unit/samples/wiringsamples/module.py @@ -9,8 +9,15 @@ from .container import Container, SubContainer from .service import Service +service = Provide[Container.service] +service_provider = Provider[Container.service] + + class TestClass: + service = Provide[Container.service] + service_provider = Provider[Container.service] + @inject def __init__(self, service: Service = Provide[Container.service]): self.service = service diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py index fd08799b..62136386 100644 --- a/tests/unit/wiring/test_wiring_py36.py +++ b/tests/unit/wiring/test_wiring_py36.py @@ -6,6 +6,7 @@ import unittest from dependency_injector.wiring import ( wire, Provide, + Provider, Closing, register_loader_containers, unregister_loader_containers, @@ -64,6 +65,10 @@ class WiringTest(unittest.TestCase): service = test_function() self.assertIsInstance(service, Service) + def test_module_attributes_wiring(self): + self.assertIsInstance(module.service, Service) + self.assertIsInstance(module.service_provider(), Service) + def test_class_wiring(self): test_class_object = module.TestClass() self.assertIsInstance(test_class_object.service, Service) @@ -97,6 +102,10 @@ class WiringTest(unittest.TestCase): service = instance.static_method() self.assertIsInstance(service, Service) + def test_class_attribute_wiring(self): + self.assertIsInstance(module.TestClass.service, Service) + self.assertIsInstance(module.TestClass.service_provider(), Service) + def test_function_wiring(self): service = module.test_function() self.assertIsInstance(service, Service) @@ -215,6 +224,16 @@ class WiringTest(unittest.TestCase): self.container.unwire() self.assertIsInstance(submodule.test_function(), Provide) + def test_unwire_module_attributes(self): + self.container.unwire() + self.assertIsInstance(module.service, Provide) + self.assertIsInstance(module.service_provider, Provider) + + def test_unwire_class_attributes(self): + self.container.unwire() + self.assertIsInstance(module.TestClass.service, Provide) + self.assertIsInstance(module.TestClass.service_provider, Provider) + def test_wire_multiple_containers(self): sub_container = SubContainer() sub_container.wire(