mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2024-11-24 10:34:01 +03:00
da13341453
* Add implementation * Add tests for module and class * Add tests for module and class for string ids * Update tests with typing * Add tests for invalid type of marker * Add docs and the example * Update changelog * Fix Python 3.6 tests and flake8
415 lines
14 KiB
Python
415 lines
14 KiB
Python
from decimal import Decimal
|
|
import unittest
|
|
|
|
from dependency_injector.wiring import (
|
|
wire,
|
|
Provide,
|
|
Provider,
|
|
Closing,
|
|
)
|
|
from dependency_injector import errors
|
|
|
|
# Runtime import to avoid syntax errors in samples on Python < 3.5:
# the sample packages imported further down only become importable once the
# directories computed here have been appended to sys.path.
import os
import sys

# Resolve __file__ to an absolute path before taking its parent.  Joining the
# raw dirname with a leading os.path.sep (the previous approach) turned an
# empty dirname -- i.e. a bare relative __file__ -- into '/../', which
# abspath() collapses to the filesystem root instead of the parent directory.
_TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
_TOP_DIR = os.path.abspath(os.path.join(_TESTS_DIR, os.pardir))
_SAMPLES_DIR = os.path.join(_TOP_DIR, 'samples')

sys.path.append(_TOP_DIR)
sys.path.append(_SAMPLES_DIR)
|
|
|
|
from asyncutils import AsyncTestCase
|
|
|
|
from wiringstringidssamples import module, package
|
|
from wiringstringidssamples.service import Service
|
|
from wiringstringidssamples.container import Container, SubContainer
|
|
|
|
|
|
class WiringTest(unittest.TestCase):
    """Wiring tests for the ``wiringstringidssamples`` samples.

    ``setUp`` wires a fresh ``Container`` into the ``module`` sample and the
    ``package`` sample package, and registers ``unwire`` as a cleanup so every
    test starts from an unwired state.
    """

    container: Container

    def setUp(self) -> None:
        self.container = Container(config={'a': {'b': {'c': 10}}})
        self.container.wire(
            modules=[module],
            packages=[package],
        )
        self.addCleanup(self.container.unwire)

    # -- package lookup ---------------------------------------------------

    def test_package_lookup(self):
        from wiringstringidssamples.package import test_package_function
        service = test_package_function()
        self.assertIsInstance(service, Service)

    def test_package_subpackage_lookup(self):
        from wiringstringidssamples.package.subpackage import test_package_function
        service = test_package_function()
        self.assertIsInstance(service, Service)

    def test_package_submodule_lookup(self):
        from wiringstringidssamples.package.subpackage.submodule import test_function
        service = test_function()
        self.assertIsInstance(service, Service)

    # -- module and class wiring ------------------------------------------

    def test_module_attributes_wiring(self):
        self.assertIsInstance(module.service, Service)
        self.assertIsInstance(module.service_provider(), Service)
        # ``undefined`` keeps its marker after wiring (presumably because its
        # id does not resolve to a provider -- see the sample module).
        self.assertIsInstance(module.undefined, Provide)

    def test_class_wiring(self):
        test_class_object = module.TestClass()
        self.assertIsInstance(test_class_object.service, Service)

    def test_class_wiring_context_arg(self):
        # An explicitly passed argument takes precedence over injection.
        test_service = self.container.service()

        test_class_object = module.TestClass(service=test_service)
        self.assertIs(test_class_object.service, test_service)

    def test_class_method_wiring(self):
        test_class_object = module.TestClass()
        service = test_class_object.method()
        self.assertIsInstance(service, Service)

    def test_class_classmethod_wiring(self):
        service = module.TestClass.class_method()
        self.assertIsInstance(service, Service)

    def test_instance_classmethod_wiring(self):
        instance = module.TestClass()
        service = instance.class_method()
        self.assertIsInstance(service, Service)

    def test_class_staticmethod_wiring(self):
        service = module.TestClass.static_method()
        self.assertIsInstance(service, Service)

    def test_instance_staticmethod_wiring(self):
        instance = module.TestClass()
        service = instance.static_method()
        self.assertIsInstance(service, Service)

    def test_class_attribute_wiring(self):
        self.assertIsInstance(module.TestClass.service, Service)
        self.assertIsInstance(module.TestClass.service_provider(), Service)
        self.assertIsInstance(module.TestClass.undefined, Provide)

    # -- function wiring --------------------------------------------------

    def test_function_wiring(self):
        service = module.test_function()
        self.assertIsInstance(service, Service)

    def test_function_wiring_context_arg(self):
        # An explicitly passed argument takes precedence over injection.
        test_service = self.container.service()

        service = module.test_function(service=test_service)
        self.assertIs(service, test_service)

    def test_function_wiring_provider(self):
        service = module.test_function_provider()
        self.assertIsInstance(service, Service)

    def test_function_wiring_provider_context_arg(self):
        test_service = self.container.service()

        service = module.test_function_provider(service_provider=lambda: test_service)
        self.assertIs(service, test_service)

    # -- configuration ----------------------------------------------------

    def test_configuration_option(self):
        (
            value_int,
            value_float,
            value_str,
            value_decimal,
            value_required,
            value_required_int,
            value_required_float,
            value_required_str,
            value_required_decimal,
        ) = module.test_config_value()

        # setUp configured config.a.b.c == 10; each typed option converts it.
        self.assertEqual(value_int, 10)
        self.assertEqual(value_float, 10.0)
        self.assertEqual(value_str, '10')
        self.assertEqual(value_decimal, Decimal(10))

        self.assertEqual(value_required, 10)
        self.assertEqual(value_required_int, 10)
        self.assertEqual(value_required_float, 10.0)
        self.assertEqual(value_required_str, '10')
        self.assertEqual(value_required_decimal, Decimal(10))

    def test_configuration_option_required_undefined(self):
        import re

        self.container.config.reset_override()
        # The expected message contains '.' characters; escape it so they are
        # matched literally rather than as "any character" regex wildcards.
        expected_message = re.escape('Undefined configuration option "config.a.b.c"')
        with self.assertRaisesRegex(errors.Error, expected_message):
            module.test_config_value_required_undefined()

    def test_provide_provider(self):
        service = module.test_provide_provider()
        self.assertIsInstance(service, Service)

    def test_provided_instance(self):
        # Stand-in whose ``foo['bar']()`` evaluates to 10; the sample function
        # is expected to reach that value through the overridden provider.
        class TestService:
            foo = {
                'bar': lambda: 10,
            }

        with self.container.service.override(TestService()):
            some_value = module.test_provided_instance()
        self.assertEqual(some_value, 10)

    def test_subcontainer(self):
        some_value = module.test_subcontainer_provider()
        self.assertEqual(some_value, 1)

    def test_config_invariant(self):
        # The injected value is selected from config.option by config.switch.
        config = {
            'option': {
                'a': 1,
                'b': 2,
            },
            'switch': 'a',
        }
        self.container.config.from_dict(config)

        value_default = module.test_config_invariant()
        self.assertEqual(value_default, 1)

        with self.container.config.switch.override('a'):
            value_a = module.test_config_invariant()
        self.assertEqual(value_a, 1)

        with self.container.config.switch.override('b'):
            value_b = module.test_config_invariant()
        self.assertEqual(value_b, 2)

    # -- wire / unwire ----------------------------------------------------

    def test_wire_with_class_error(self):
        # Wiring needs a container instance; passing the class must fail.
        with self.assertRaises(Exception):
            wire(
                container=Container,
                modules=[module],
            )

    def test_unwire_function(self):
        # After unwire() the original markers are visible again.
        self.container.unwire()
        self.assertIsInstance(module.test_function(), Provide)

    def test_unwire_class(self):
        self.container.unwire()
        test_class_object = module.TestClass()
        self.assertIsInstance(test_class_object.service, Provide)

    def test_unwire_class_method(self):
        self.container.unwire()
        test_class_object = module.TestClass()
        self.assertIsInstance(test_class_object.method(), Provide)

    def test_unwire_package_function(self):
        self.container.unwire()
        from wiringstringidssamples.package.subpackage.submodule import test_function
        self.assertIsInstance(test_function(), Provide)

    def test_unwire_package_function_by_reference(self):
        # The submodule is referenced before unwiring, so unwire() must patch
        # the already-imported module object.
        from wiringstringidssamples.package.subpackage import submodule
        self.container.unwire()
        self.assertIsInstance(submodule.test_function(), Provide)

    def test_unwire_module_attributes(self):
        self.container.unwire()
        self.assertIsInstance(module.service, Provide)
        self.assertIsInstance(module.service_provider, Provider)
        self.assertIsInstance(module.undefined, Provide)

    def test_unwire_class_attributes(self):
        self.container.unwire()
        self.assertIsInstance(module.TestClass.service, Provide)
        self.assertIsInstance(module.TestClass.service_provider, Provider)
        self.assertIsInstance(module.TestClass.undefined, Provide)

    def test_wire_multiple_containers(self):
        sub_container = SubContainer()
        sub_container.wire(
            modules=[module],
            packages=[package],
        )
        self.addCleanup(sub_container.unwire)

        service, some_value = module.test_provide_from_different_containers()

        self.assertIsInstance(service, Service)
        self.assertEqual(some_value, 1)

    # -- resource closing -------------------------------------------------

    def test_closing_resource(self):
        from wiringstringidssamples import resourceclosing

        resourceclosing.Service.reset_counter()

        container = resourceclosing.Container()
        container.wire(modules=[resourceclosing])
        self.addCleanup(container.unwire)

        # Each call initializes a new resource and shuts it down afterwards.
        result_1 = resourceclosing.test_function()
        self.assertIsInstance(result_1, resourceclosing.Service)
        self.assertEqual(result_1.init_counter, 1)
        self.assertEqual(result_1.shutdown_counter, 1)

        result_2 = resourceclosing.test_function()
        self.assertIsInstance(result_2, resourceclosing.Service)
        self.assertEqual(result_2.init_counter, 2)
        self.assertEqual(result_2.shutdown_counter, 2)

        self.assertIsNot(result_1, result_2)

    def test_closing_resource_context(self):
        from wiringstringidssamples import resourceclosing

        resourceclosing.Service.reset_counter()
        service = resourceclosing.Service()

        container = resourceclosing.Container()
        container.wire(modules=[resourceclosing])
        self.addCleanup(container.unwire)

        # An explicitly passed service bypasses the resource entirely: it is
        # neither initialized nor shut down.
        result_1 = resourceclosing.test_function(service=service)
        self.assertIs(result_1, service)
        self.assertEqual(result_1.init_counter, 0)
        self.assertEqual(result_1.shutdown_counter, 0)

        result_2 = resourceclosing.test_function(service=service)
        self.assertIs(result_2, service)
        self.assertEqual(result_2.init_counter, 0)
        self.assertEqual(result_2.shutdown_counter, 0)

    # -- misc -------------------------------------------------------------

    def test_class_decorator(self):
        service = module.test_class_decorator()
        self.assertIsInstance(service, Service)

    def test_container(self):
        service = module.test_container()
        self.assertIsInstance(service, Service)
|
|
|
|
|
|
class WiringAndFastAPITest(unittest.TestCase):
    """Injection still happens when a marker object is passed explicitly.

    Each test builds, wires and (via cleanup) unwires its own container, so
    no container is stored on the test instance.

    NOTE(review): the class name suggests this mirrors how FastAPI passes
    declared parameter defaults (the markers) explicitly -- confirm.
    """

    def test_bypass_marker_injection(self):
        container = Container()
        container.wire(modules=[module])
        self.addCleanup(container.unwire)

        # Passing the marker itself must still resolve to a real Service.
        service = module.test_function(service=Provide[Container.service])
        self.assertIsInstance(service, Service)

    def test_closing_resource_bypass_marker_injection(self):
        from wiringstringidssamples import resourceclosing

        resourceclosing.Service.reset_counter()

        container = resourceclosing.Container()
        container.wire(modules=[resourceclosing])
        self.addCleanup(container.unwire)

        # An explicitly passed Closing[...] marker behaves like the wired
        # default: every call initializes and then shuts down a new resource.
        result_1 = resourceclosing.test_function(
            service=Closing[Provide[resourceclosing.Container.service]],
        )
        self.assertIsInstance(result_1, resourceclosing.Service)
        self.assertEqual(result_1.init_counter, 1)
        self.assertEqual(result_1.shutdown_counter, 1)

        result_2 = resourceclosing.test_function(
            service=Closing[Provide[resourceclosing.Container.service]],
        )
        self.assertIsInstance(result_2, resourceclosing.Service)
        self.assertEqual(result_2.init_counter, 2)
        self.assertEqual(result_2.shutdown_counter, 2)

        self.assertIsNot(result_1, result_2)
|
|
|
|
|
|
class WiringAsyncInjectionsTest(AsyncTestCase):
    """Wiring of coroutine functions that receive async resources."""

    def _wired_asyncinjections(self):
        """Wire a fresh container into the async sample and reset its counters.

        Unwiring is registered as a cleanup. Returns the ``asyncinjections``
        sample module.
        """
        from wiringstringidssamples import asyncinjections

        container = asyncinjections.Container()
        container.wire(modules=[asyncinjections])
        self.addCleanup(container.unwire)

        for tracked in (asyncinjections.resource1, asyncinjections.resource2):
            tracked.reset_counters()
        return asyncinjections

    def _check_resource(self, injected, expected, inits, shutdowns):
        """Assert *injected* is *expected* and its counters have the given values."""
        self.assertIs(injected, expected)
        self.assertEqual(expected.init_counter, inits)
        self.assertEqual(expected.shutdown_counter, shutdowns)

    def test_async_injections(self):
        sample = self._wired_asyncinjections()

        first, second = self._run(sample.async_injection())

        # Resources are initialized once and stay open after the call.
        self._check_resource(first, sample.resource1, 1, 0)
        self._check_resource(second, sample.resource2, 1, 0)

    def test_async_injections_with_closing(self):
        sample = self._wired_asyncinjections()

        # The closing variant shuts both resources down after every call, so
        # their shutdown counters keep pace with their init counters.
        for call_count in (1, 2):
            first, second = self._run(sample.async_injection_with_closing())

            self._check_resource(first, sample.resource1, call_count, call_count)
            self._check_resource(second, sample.resource2, call_count, call_count)
|
|
|
|
|
|
# class AutoLoaderTest(unittest.TestCase):
|
|
#
|
|
# container: Container
|
|
#
|
|
# def setUp(self) -> None:
|
|
# self.container = Container(config={'a': {'b': {'c': 10}}})
|
|
# importlib.reload(module)
|
|
#
|
|
# def tearDown(self) -> None:
|
|
# with contextlib.suppress(ValueError):
|
|
# unregister_loader_containers(self.container)
|
|
#
|
|
# self.container.unwire()
|
|
#
|
|
# @classmethod
|
|
# def tearDownClass(cls) -> None:
|
|
# importlib.reload(module)
|
|
#
|
|
# def test_register_container(self):
|
|
# register_loader_containers(self.container)
|
|
# importlib.reload(module)
|
|
#
|
|
# service = module.test_function()
|
|
# self.assertIsInstance(service, Service)
|