diff --git a/Makefile b/Makefile index a7ef1e43..0c4bca04 100644 --- a/Makefile +++ b/Makefile @@ -45,17 +45,10 @@ install: uninstall clean cythonize uninstall: - pip uninstall -y -q dependency-injector 2> /dev/null -test-py2: build +test: # Unit tests with coverage report coverage erase - coverage run --rcfile=./.coveragerc -m unittest discover -s tests/unit/ -p test_*_py2_py3.py - coverage report --rcfile=./.coveragerc - coverage html --rcfile=./.coveragerc - -test: build - # Unit tests with coverage report - coverage erase - coverage run --rcfile=./.coveragerc -m unittest discover -s tests/unit/ -p test_*py3*.py + coverage run --rcfile=./.coveragerc -m pytest -c tests/.configs/pytest.ini coverage report --rcfile=./.coveragerc coverage html --rcfile=./.coveragerc diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index abb4b520..6005b270 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -15,7 +15,9 @@ Develop - Add support of ``with`` statement for ``container.override_providers()`` method. - Drop support of Python 3.4. There are no immediate breaking changes, but Dependency Injector will no longer be tested on Python 3.4 and any bugs will not be fixed. +- Fix ``Dependency.is_defined`` attribute to always return boolean value. - Update documentation and fix typos. +- Migrate tests to ``pytest``. 4.36.2 ------ diff --git a/requirements-dev.txt b/requirements-dev.txt index 201a4de2..13d2d170 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,6 @@ cython==0.29.22 +pytest +pytest-asyncio tox coverage flake8 diff --git a/src/dependency_injector/containers.pyi b/src/dependency_injector/containers.pyi index aa38f961..ca013906 100644 --- a/src/dependency_injector/containers.pyi +++ b/src/dependency_injector/containers.pyi @@ -49,7 +49,7 @@ class Container: def __delattr__(self, name: str) -> None: ... def set_providers(self, **providers: Provider): ... def set_provider(self, name: str, provider: Provider) -> None: ... - def override(self, overriding: C_Base) -> None: ... + def override(self, overriding: Union[Container, Type[Container]]) -> None: ... def override_providers(self, **overriding_providers: Union[Provider, Any]) -> ProvidersOverridingContext[C_Base]: ... def reset_last_overriding(self) -> None: ... def reset_override(self) -> None: ... @@ -88,6 +88,14 @@ class DeclarativeContainer(Container): cls_providers: ClassVar[Dict[str, Provider]] inherited_providers: ClassVar[Dict[str, Provider]] def __init__(self, **overriding_providers: Union[Provider, Any]) -> None: ... + @classmethod + def override(cls, overriding: Union[Container, Type[Container]]) -> None: ... + @classmethod + def override_providers(cls, **overriding_providers: Union[Provider, Any]) -> ProvidersOverridingContext[C_Base]: ... + @classmethod + def reset_last_overriding(cls) -> None: ... + @classmethod + def reset_override(cls) -> None: ... class ProvidersOverridingContext(Generic[T]): diff --git a/src/dependency_injector/providers.c b/src/dependency_injector/providers.c index bc4ff2e7..4afdfd08 100644 --- a/src/dependency_injector/providers.c +++ b/src/dependency_injector/providers.c @@ -17675,7 +17675,7 @@ static PyObject *__pyx_pf_19dependency_injector_9providers_10Dependency_14set_de * @property * def is_defined(self): # <<<<<<<<<<<<<< * """Return True if dependency is defined.""" - * return self.__last_overriding or self.__default + * return self.__last_overriding is not None or self.__default is not None */ /* Python wrapper */ @@ -17696,6 +17696,7 @@ static PyObject *__pyx_pf_19dependency_injector_9providers_10Dependency_10is_def __Pyx_RefNannyDeclarations PyObject *__pyx_t_1 = NULL; int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; int __pyx_lineno = 0; const char *__pyx_filename = NULL; int __pyx_clineno = 0; @@ -17704,20 +17705,25 @@ static PyObject *__pyx_pf_19dependency_injector_9providers_10Dependency_10is_def /* "dependency_injector/providers.pyx":774 * def is_defined(self): * """Return True if dependency is defined.""" - * return self.__last_overriding or self.__default # <<<<<<<<<<<<<< + * return self.__last_overriding is not None or self.__default is not None # <<<<<<<<<<<<<< * * def provided_by(self, provider): */ __Pyx_XDECREF(__pyx_r); - __pyx_t_2 = __Pyx_PyObject_IsTrue(((PyObject *)__pyx_v_self->__pyx_base.__pyx___last_overriding)); if (unlikely(__pyx_t_2 < 0)) __PYX_ERR(1, 774, __pyx_L1_error) + __pyx_t_2 = (((PyObject *)__pyx_v_self->__pyx_base.__pyx___last_overriding) != Py_None); if (!__pyx_t_2) { } else { - __Pyx_INCREF(((PyObject *)__pyx_v_self->__pyx_base.__pyx___last_overriding)); - __pyx_t_1 = ((PyObject *)__pyx_v_self->__pyx_base.__pyx___last_overriding); + __pyx_t_3 = __Pyx_PyBool_FromLong(__pyx_t_2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 774, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 = __pyx_t_3; + __pyx_t_3 = 0; goto __pyx_L3_bool_binop_done; } - __Pyx_INCREF(__pyx_v_self->__pyx___default); - __pyx_t_1 = __pyx_v_self->__pyx___default; + __pyx_t_2 = (__pyx_v_self->__pyx___default != Py_None); + __pyx_t_3 = __Pyx_PyBool_FromLong(__pyx_t_2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 774, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 = __pyx_t_3; + __pyx_t_3 = 0; __pyx_L3_bool_binop_done:; __pyx_r = __pyx_t_1; __pyx_t_1 = 0; @@ -17728,12 +17734,13 @@ static PyObject *__pyx_pf_19dependency_injector_9providers_10Dependency_10is_def * @property * def is_defined(self): # <<<<<<<<<<<<<< * """Return True if dependency is defined.""" - * return self.__last_overriding or self.__default + * return self.__last_overriding is not None or self.__default is not None */ /* function exit code */ __pyx_L1_error:; __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); __Pyx_AddTraceback("dependency_injector.providers.Dependency.is_defined.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); __pyx_r = NULL; __pyx_L0:; @@ -17743,7 +17750,7 @@ static PyObject *__pyx_pf_19dependency_injector_9providers_10Dependency_10is_def } /* "dependency_injector/providers.pyx":776 - * return self.__last_overriding or self.__default + * return self.__last_overriding is not None or self.__default is not None * * def provided_by(self, provider): # <<<<<<<<<<<<<< * """Set external dependency provider. @@ -17805,7 +17812,7 @@ static PyObject *__pyx_pf_19dependency_injector_9providers_10Dependency_16provid goto __pyx_L0; /* "dependency_injector/providers.pyx":776 - * return self.__last_overriding or self.__default + * return self.__last_overriding is not None or self.__default is not None * * def provided_by(self, provider): # <<<<<<<<<<<<<< * """Set external dependency provider. diff --git a/src/dependency_injector/providers.pyx b/src/dependency_injector/providers.pyx index 9dbab382..e38f4a64 100644 --- a/src/dependency_injector/providers.pyx +++ b/src/dependency_injector/providers.pyx @@ -771,7 +771,7 @@ cdef class Dependency(Provider): @property def is_defined(self): """Return True if dependency is defined.""" - return self.__last_overriding or self.__default + return self.__last_overriding is not None or self.__default is not None def provided_by(self, provider): """Set external dependency provider. diff --git a/tests/.configs/pytest-py27.ini b/tests/.configs/pytest-py27.ini new file mode 100644 index 00000000..14d76ae7 --- /dev/null +++ b/tests/.configs/pytest-py27.ini @@ -0,0 +1,6 @@ +[pytest] +testpaths = tests/unit +python_files = test_*_py2_py3.py +filterwarnings = + ignore:Module \"dependency_injector.ext.aiohttp\" is deprecated since version 4\.0\.0:DeprecationWarning + ignore:Module \"dependency_injector.ext.flask\" is deprecated since version 4\.0\.0:DeprecationWarning diff --git a/tests/.configs/pytest-py35.ini b/tests/.configs/pytest-py35.ini new file mode 100644 index 00000000..81330704 --- /dev/null +++ b/tests/.configs/pytest-py35.ini @@ -0,0 +1,6 @@ +[pytest] +testpaths = tests/unit +python_files = test_*_py3.py +filterwarnings = + ignore:Module \"dependency_injector.ext.aiohttp\" is deprecated since version 4\.0\.0:DeprecationWarning + ignore:Module \"dependency_injector.ext.flask\" is deprecated since version 4\.0\.0:DeprecationWarning diff --git a/tests/.configs/pytest.ini b/tests/.configs/pytest.ini new file mode 100644 index 00000000..53643540 --- /dev/null +++ b/tests/.configs/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +testpaths = tests/unit/ +python_files = test_*_py3*.py +filterwarnings = + ignore:Module \"dependency_injector.ext.aiohttp\" is deprecated since version 4\.0\.0:DeprecationWarning + ignore:Module \"dependency_injector.ext.flask\" is deprecated since version 4\.0\.0:DeprecationWarning diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index b2572276..46816ddf 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -1 +1 @@ -"""Dependency injector unit tests.""" +"""Tests package.""" diff --git a/tests/unit/asyncutils.py b/tests/unit/asyncutils.py deleted file mode 100644 index 53cf30cd..00000000 --- a/tests/unit/asyncutils.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Test utils.""" - -import asyncio -import contextlib -import sys -import gc -import unittest - - -def run(main): - loop = asyncio.get_event_loop() - return loop.run_until_complete(main) - - -def setup_test_loop( - loop_factory=asyncio.new_event_loop -) -> asyncio.AbstractEventLoop: - loop = loop_factory() - try: - module = loop.__class__.__module__ - skip_watcher = "uvloop" in module - except AttributeError: # pragma: no cover - # Just in case - skip_watcher = True - asyncio.set_event_loop(loop) - if sys.platform != "win32" and not skip_watcher: - policy = asyncio.get_event_loop_policy() - watcher = asyncio.SafeChildWatcher() # type: ignore - watcher.attach_loop(loop) - with contextlib.suppress(NotImplementedError): - policy.set_child_watcher(watcher) - return loop - - -def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None: - closed = loop.is_closed() - if not closed: - loop.call_soon(loop.stop) - loop.run_forever() - loop.close() - - if not fast: - gc.collect() - - asyncio.set_event_loop(None) - - -class AsyncTestCase(unittest.TestCase): - - def setUp(self): - self.loop = setup_test_loop() - - def tearDown(self): - teardown_test_loop(self.loop) - - def _run(self, f): - return self.loop.run_until_complete(f) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 00000000..f40a23e8 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,8 @@ +"""Fixtures module.""" + +import sys +import os.path + + +# Add current package to import samples/ dir +sys.path.append(os.path.dirname(__file__)) diff --git a/tests/unit/containers/__init__.py b/tests/unit/containers/__init__.py index 69cef935..bfcb6292 100644 --- a/tests/unit/containers/__init__.py +++ b/tests/unit/containers/__init__.py @@ -1 +1 @@ -"""Dependency injector container unit tests.""" +"""Container tests.""" diff --git a/tests/unit/containers/cls/__init__.py b/tests/unit/containers/cls/__init__.py new file mode 100644 index 00000000..ff13b5c1 --- /dev/null +++ b/tests/unit/containers/cls/__init__.py @@ -0,0 +1 @@ +"""Container class tests.""" diff --git a/tests/unit/containers/cls/test_custom_strings_py2_py3.py b/tests/unit/containers/cls/test_custom_strings_py2_py3.py new file mode 100644 index 00000000..940a41f3 --- /dev/null +++ b/tests/unit/containers/cls/test_custom_strings_py2_py3.py @@ -0,0 +1,36 @@ +"""Tests for container cls with custom string classes as attribute names. + +See: https://github.com/ets-labs/python-dependency-injector/issues/479 +""" + +from dependency_injector import containers, providers +from pytest import fixture, raises + + +class CustomString(str): + pass + + +class CustomClass: + thing = None + + +class Container(containers.DeclarativeContainer): + pass + + +@fixture +def provider(): + return providers.Provider() + + +def test_setattr(provider): + setattr(Container, CustomString("test_attr"), provider) + assert Container.test_attr is provider + + +def test_delattr(): + setattr(Container, CustomString("test_attr"), provider) + delattr(Container, CustomString("test_attr")) + with raises(AttributeError): + Container.test_attr diff --git a/tests/unit/containers/cls/test_main_py2_py3.py b/tests/unit/containers/cls/test_main_py2_py3.py new file mode 100644 index 00000000..1d13d5b3 --- /dev/null +++ b/tests/unit/containers/cls/test_main_py2_py3.py @@ -0,0 +1,498 @@ +"""Main container class tests.""" + +import collections + +from dependency_injector import containers, providers, errors +from pytest import raises + + +class ContainerA(containers.DeclarativeContainer): + p11 = providers.Provider() + p12 = providers.Provider() + + +class ContainerB(ContainerA): + p21 = providers.Provider() + p22 = providers.Provider() + + +class ContainerC(ContainerB): + p31 = providers.Provider() + p32 = providers.Provider() + + +def test_providers_attribute(): + assert ContainerA.providers == dict(p11=ContainerA.p11, p12=ContainerA.p12) + assert ContainerB.providers == dict( + p11=ContainerA.p11, + p12=ContainerA.p12, + p21=ContainerB.p21, + p22=ContainerB.p22, + ) + assert ContainerC.providers == dict( + p11=ContainerA.p11, + p12=ContainerA.p12, + p21=ContainerB.p21, + p22=ContainerB.p22, + p31=ContainerC.p31, + p32=ContainerC.p32, + ) + + +def test_providers_attribute_with_redefinition(): + p1 = providers.Provider() + p2 = providers.Provider() + + class ContainerA2(ContainerA): + p11 = p1 + p12 = p2 + + assert ContainerA.providers == { + "p11": ContainerA.p11, + "p12": ContainerA.p12, + } + assert ContainerA2.providers == { + "p11": p1, + "p12": p2, + } + + +def test_cls_providers_attribute(): + assert ContainerA.cls_providers == dict(p11=ContainerA.p11, p12=ContainerA.p12) + assert ContainerB.cls_providers == dict(p21=ContainerB.p21, p22=ContainerB.p22) + assert ContainerC.cls_providers == dict(p31=ContainerC.p31, p32=ContainerC.p32) + + +def test_inherited_providers_attribute(): + assert ContainerA.inherited_providers == dict() + assert ContainerB.inherited_providers == dict(p11=ContainerA.p11, p12=ContainerA.p12) + assert ContainerC.inherited_providers == dict( + p11=ContainerA.p11, + p12=ContainerA.p12, + p21=ContainerB.p21, + p22=ContainerB.p22, + ) + + +def test_dependencies_attribute(): + class ContainerD(ContainerC): + p41 = providers.Dependency() + p42 = providers.DependenciesContainer() + + class ContainerE(ContainerD): + p51 = providers.Dependency() + p52 = providers.DependenciesContainer() + + assert ContainerD.dependencies == { + "p41": ContainerD.p41, + "p42": ContainerD.p42, + } + assert ContainerE.dependencies == { + "p41": ContainerD.p41, + "p42": ContainerD.p42, + "p51": ContainerE.p51, + "p52": ContainerE.p52, + } + + +def test_set_get_del_providers(): + a_p13 = providers.Provider() + b_p23 = providers.Provider() + + ContainerA.p13 = a_p13 + ContainerB.p23 = b_p23 + + assert ContainerA.providers == dict( + p11=ContainerA.p11, + p12=ContainerA.p12, + p13=a_p13, + ) + assert ContainerB.providers == dict( + p11=ContainerA.p11, + p12=ContainerA.p12, + p21=ContainerB.p21, + p22=ContainerB.p22, + p23=b_p23, + ) + + assert ContainerA.cls_providers == dict( + p11=ContainerA.p11, + p12=ContainerA.p12, + p13=a_p13, + ) + assert ContainerB.cls_providers == dict( + p21=ContainerB.p21, + p22=ContainerB.p22, + p23=b_p23, + ) + + del ContainerA.p13 + del ContainerB.p23 + + assert ContainerA.providers == dict(p11=ContainerA.p11, p12=ContainerA.p12) + assert ContainerB.providers == dict( + p11=ContainerA.p11, + p12=ContainerA.p12, + p21=ContainerB.p21, + p22=ContainerB.p22, + ) + + assert ContainerA.cls_providers == dict(p11=ContainerA.p11, p12=ContainerA.p12) + assert ContainerB.cls_providers == dict(p21=ContainerB.p21, p22=ContainerB.p22) + + +def test_declare_with_valid_provider_type(): + class _Container(containers.DeclarativeContainer): + provider_type = providers.Object + px = providers.Object(object()) + + assert isinstance(_Container.px, providers.Object) + + +def test_declare_with_invalid_provider_type(): + with raises(errors.Error): + class _Container(containers.DeclarativeContainer): + provider_type = providers.Object + px = providers.Provider() + + +def test_seth_valid_provider_type(): + class _Container(containers.DeclarativeContainer): + provider_type = providers.Object + + _Container.px = providers.Object(object()) + + assert isinstance(_Container.px, providers.Object) + +def test_set_invalid_provider_type(): + class _Container(containers.DeclarativeContainer): + provider_type = providers.Object + + with raises(errors.Error): + _Container.px = providers.Provider() + + +def test_override(): + class _Container(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer1(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer2(containers.DeclarativeContainer): + p11 = providers.Provider() + p12 = providers.Provider() + + _Container.override(_OverridingContainer1) + _Container.override(_OverridingContainer2) + + assert _Container.overridden == (_OverridingContainer1, _OverridingContainer2) + assert _Container.p11.overridden == (_OverridingContainer1.p11, _OverridingContainer2.p11) + + +def test_override_with_it(): + with raises(errors.Error): + ContainerA.override(ContainerA) + + +def test_override_with_parent(): + with raises(errors.Error): + ContainerB.override(ContainerA) + + +def test_override_decorator(): + class _Container(containers.DeclarativeContainer): + p11 = providers.Provider() + + @containers.override(_Container) + class _OverridingContainer1(containers.DeclarativeContainer): + p11 = providers.Provider() + + @containers.override(_Container) + class _OverridingContainer2(containers.DeclarativeContainer): + p11 = providers.Provider() + p12 = providers.Provider() + + assert _Container.overridden == (_OverridingContainer1, _OverridingContainer2) + assert _Container.p11.overridden == (_OverridingContainer1.p11, _OverridingContainer2.p11) + + +def test_reset_last_overriding(): + class _Container(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer1(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer2(containers.DeclarativeContainer): + p11 = providers.Provider() + p12 = providers.Provider() + + _Container.override(_OverridingContainer1) + _Container.override(_OverridingContainer2) + _Container.reset_last_overriding() + + assert _Container.overridden == (_OverridingContainer1,) + assert _Container.p11.overridden == (_OverridingContainer1.p11,) + + +def test_reset_last_overriding_when_not_overridden(): + with raises(errors.Error): + ContainerA.reset_last_overriding() + + +def test_reset_override(): + class _Container(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer1(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer2(containers.DeclarativeContainer): + p11 = providers.Provider() + p12 = providers.Provider() + + _Container.override(_OverridingContainer1) + _Container.override(_OverridingContainer2) + _Container.reset_override() + + assert _Container.overridden == tuple() + assert _Container.p11.overridden == tuple() + + +def test_copy(): + @containers.copy(ContainerA) + class _Container1(ContainerA): + pass + + @containers.copy(ContainerA) + class _Container2(ContainerA): + pass + + assert ContainerA.p11 is not _Container1.p11 + assert ContainerA.p12 is not _Container1.p12 + + assert ContainerA.p11 is not _Container2.p11 + assert ContainerA.p12 is not _Container2.p12 + + assert _Container1.p11 is not _Container2.p11 + assert _Container1.p12 is not _Container2.p12 + + +def test_copy_with_replacing(): + class _Container(containers.DeclarativeContainer): + p11 = providers.Object(0) + p12 = providers.Factory(dict, p11=p11) + + @containers.copy(_Container) + class _Container1(_Container): + p11 = providers.Object(1) + p13 = providers.Object(11) + + @containers.copy(_Container) + class _Container2(_Container): + p11 = providers.Object(2) + p13 = providers.Object(22) + + assert _Container.p11 is not _Container1.p11 + assert _Container.p12 is not _Container1.p12 + + assert _Container.p11 is not _Container2.p11 + assert _Container.p12 is not _Container2.p12 + + assert _Container1.p11 is not _Container2.p11 + assert _Container1.p12 is not _Container2.p12 + + assert _Container.p12() == {"p11": 0} + assert _Container1.p12() == {"p11": 1} + assert _Container2.p12() == {"p11": 2} + + assert _Container1.p13() == 11 + assert _Container2.p13() == 22 + + +def test_copy_with_parent_dependency(): + # See: https://github.com/ets-labs/python-dependency-injector/issues/477 + class Base(containers.DeclarativeContainer): + p11 = providers.Object(0) + p12 = providers.Factory(dict, p11=p11) + + @containers.copy(Base) + class New(Base): + p13 = providers.Factory(dict, p12=Base.p12) + + new1 = New() + new2 = New(p11=1) + new3 = New(p11=2) + + assert new1.p13() == {"p12": {"p11": 0}} + assert new2.p13() == {"p12": {"p11": 1}} + assert new3.p13() == {"p12": {"p11": 2}} + + +def test_copy_with_replacing_subcontainer_providers(): + # See: https://github.com/ets-labs/python-dependency-injector/issues/374 + class X(containers.DeclarativeContainer): + foo = providers.Dependency(instance_of=str) + + def build_x(): + return X(foo="1") + + class A(containers.DeclarativeContainer): + x = providers.DependenciesContainer(**X.providers) + y = x.foo + + @containers.copy(A) + class B1(A): + x = providers.Container(build_x) + + b1 = B1() + + assert b1.y() == "1" + + +def test_containers_attribute(): + class Container(containers.DeclarativeContainer): + class Container1(containers.DeclarativeContainer): + pass + + class Container2(containers.DeclarativeContainer): + pass + + Container3 = containers.DynamicContainer() + + assert Container.containers == dict( + Container1=Container.Container1, + Container2=Container.Container2, + Container3=Container.Container3, + ) + + +def test_init_with_overriding_providers(): + p1 = providers.Provider() + p2 = providers.Provider() + + container = ContainerA(p11=p1, p12=p2) + + assert container.p11.last_overriding is p1 + assert container.p12.last_overriding is p2 + + +def test_init_with_overridden_dependency(): + # Bug: https://github.com/ets-labs/python-dependency-injector/issues/198 + class _Container(containers.DeclarativeContainer): + p1 = providers.Dependency(instance_of=int) + + p2 = providers.Dependency(object) + p2.override(providers.Factory(dict, p1=p1)) + + container = _Container(p1=1) + + assert container.p2() == {"p1": 1} + assert container.p2.last_overriding.kwargs["p1"] is container.p1 + assert container.p2.last_overriding.kwargs["p1"] is not _Container.p1 + assert _Container.p2.last_overriding.kwargs["p1"] is _Container.p1 + + +def test_init_with_chained_dependency(): + # Bug: https://github.com/ets-labs/python-dependency-injector/issues/200 + class _Container(containers.DeclarativeContainer): + p1 = providers.Dependency(instance_of=int) + p2 = providers.Factory(p1) + + container = _Container(p1=1) + + assert container.p2() == 1 + assert container.p2.cls is container.p1 + assert _Container.p2.cls is _Container.p1 + assert container.p2.cls is not _Container.p1 + + +def test_init_with_dependency_delegation(): + # Bug: https://github.com/ets-labs/python-dependency-injector/issues/235 + A = collections.namedtuple("A", []) + B = collections.namedtuple("B", ["fa"]) + C = collections.namedtuple("B", ["a"]) + + class Services(containers.DeclarativeContainer): + a = providers.Dependency() + c = providers.Factory(C, a=a) + b = providers.Factory(B, fa=a.provider) + + a = providers.Factory(A) + assert isinstance(Services(a=a).c().a, A) # OK + Services(a=a).b().fa() + + +def test_init_with_grand_child_provider(): + # Bug: https://github.com/ets-labs/python-dependency-injector/issues/350 + provider = providers.Provider() + container = ContainerC(p11=provider) + + assert isinstance(container.p11, providers.Provider) + assert isinstance(container.p12, providers.Provider) + assert isinstance(container.p21, providers.Provider) + assert isinstance(container.p22, providers.Provider) + assert isinstance(container.p31, providers.Provider) + assert isinstance(container.p32, providers.Provider) + assert container.p11.last_overriding is provider + + +def test_parent_set_in__new__(): + class Container(containers.DeclarativeContainer): + dependency = providers.Dependency() + dependencies_container = providers.DependenciesContainer() + container = providers.Container(ContainerA) + + assert Container.dependency.parent is Container + assert Container.dependencies_container.parent is Container + assert Container.container.parent is Container + + +def test_parent_set_in__setattr__(): + class Container(containers.DeclarativeContainer): + pass + + Container.dependency = providers.Dependency() + Container.dependencies_container = providers.DependenciesContainer() + Container.container = providers.Container(ContainerA) + + assert Container.dependency.parent is Container + assert Container.dependencies_container.parent is Container + assert Container.container.parent is Container + + +def test_resolve_provider_name(): + assert ContainerA.resolve_provider_name(ContainerA.p11) == "p11" + + +def test_resolve_provider_name_no_provider(): + with raises(errors.Error): + ContainerA.resolve_provider_name(providers.Provider()) + + +def test_child_dependency_parent_name(): + class Container(containers.DeclarativeContainer): + dependency = providers.Dependency() + + with raises(errors.Error, match="Dependency \"Container.dependency\" is not defined"): + Container.dependency() + + +def test_child_dependencies_container_parent_name(): + class Container(containers.DeclarativeContainer): + dependencies_container = providers.DependenciesContainer() + + with raises(errors.Error, match="Dependency \"Container.dependencies_container.dependency\" is not defined"): + Container.dependencies_container.dependency() + + +def test_child_container_parent_name(): + class ChildContainer(containers.DeclarativeContainer): + dependency = providers.Dependency() + + class Container(containers.DeclarativeContainer): + child_container = providers.Container(ChildContainer) + + with raises(errors.Error, match="Dependency \"Container.child_container.dependency\" is not defined"): + Container.child_container.dependency() diff --git a/tests/unit/containers/instance/__init__.py b/tests/unit/containers/instance/__init__.py new file mode 100644 index 00000000..df1e8b69 --- /dev/null +++ b/tests/unit/containers/instance/__init__.py @@ -0,0 +1 @@ +"""Container instance tests.""" diff --git a/tests/unit/containers/instance/test_async_resources_py36.py b/tests/unit/containers/instance/test_async_resources_py36.py new file mode 100644 index 00000000..b365b60d --- /dev/null +++ b/tests/unit/containers/instance/test_async_resources_py36.py @@ -0,0 +1,147 @@ +"""Tests for container async resources.""" + +import asyncio + +from dependency_injector import containers, providers +from pytest import mark, raises + + +@mark.asyncio +async def test_init_and_shutdown_ordering(): + """Test init and shutdown resources. + + Methods .init_resources() and .shutdown_resources() should respect resources dependencies. + Initialization should first initialize resources without dependencies and then provide + these resources to other resources. Resources shutdown should follow the same rule: first + shutdown resources without initialized dependencies and then continue correspondingly + until all resources are shutdown. + """ + initialized_resources = [] + shutdown_resources = [] + + async def _resource(name, delay, **_): + await asyncio.sleep(delay) + initialized_resources.append(name) + + yield name + + await asyncio.sleep(delay) + shutdown_resources.append(name) + + class Container(containers.DeclarativeContainer): + resource1 = providers.Resource( + _resource, + name="r1", + delay=0.03, + ) + resource2 = providers.Resource( + _resource, + name="r2", + delay=0.02, + r1=resource1, + ) + resource3 = providers.Resource( + _resource, + name="r3", + delay=0.01, + r2=resource2, + ) + + container = Container() + + await container.init_resources() + assert initialized_resources == ["r1", "r2", "r3"] + assert shutdown_resources == [] + + await container.shutdown_resources() + assert initialized_resources == ["r1", "r2", "r3"] + assert shutdown_resources == ["r3", "r2", "r1"] + + await container.init_resources() + assert initialized_resources == ["r1", "r2", "r3", "r1", "r2", "r3"] + assert shutdown_resources == ["r3", "r2", "r1"] + + await container.shutdown_resources() + assert initialized_resources == ["r1", "r2", "r3", "r1", "r2", "r3"] + assert shutdown_resources == ["r3", "r2", "r1", "r3", "r2", "r1"] + + +@mark.asyncio +async def test_shutdown_circular_dependencies_breaker(): + async def _resource(name, **_): + yield name + + class Container(containers.DeclarativeContainer): + resource1 = providers.Resource( + _resource, + name="r1", + ) + resource2 = providers.Resource( + _resource, + name="r2", + r1=resource1, + ) + resource3 = providers.Resource( + _resource, + name="r3", + r2=resource2, + ) + + container = Container() + await container.init_resources() + + # Create circular dependency after initialization (r3 -> r2 -> r1 -> r3 -> ...) + container.resource1.add_kwargs(r3=container.resource3) + + with raises(RuntimeError, match="Unable to resolve resources shutdown order"): + await container.shutdown_resources() + + +@mark.asyncio +async def test_shutdown_sync_and_async_ordering(): + initialized_resources = [] + shutdown_resources = [] + + def _sync_resource(name, **_): + initialized_resources.append(name) + yield name + shutdown_resources.append(name) + + async def _async_resource(name, **_): + initialized_resources.append(name) + yield name + shutdown_resources.append(name) + + class Container(containers.DeclarativeContainer): + resource1 = providers.Resource( + _sync_resource, + name="r1", + ) + resource2 = providers.Resource( + _sync_resource, + name="r2", + r1=resource1, + ) + resource3 = providers.Resource( + _async_resource, + name="r3", + r2=resource2, + ) + + container = Container() + + await container.init_resources() + assert initialized_resources == ["r1", "r2", "r3"] + assert shutdown_resources == [] + + await container.shutdown_resources() + assert initialized_resources == ["r1", "r2", "r3"] + assert shutdown_resources == ["r3", "r2", "r1"] + + await container.init_resources() + assert initialized_resources == ["r1", "r2", "r3", "r1", "r2", "r3"] + assert shutdown_resources == ["r3", "r2", "r1"] + + await container.shutdown_resources() + assert initialized_resources == ["r1", "r2", "r3", "r1", "r2", "r3"] + assert shutdown_resources == ["r3", "r2", "r1", "r3", "r2", "r1"] diff --git a/tests/unit/containers/instance/test_custom_strings_py2_py3.py b/tests/unit/containers/instance/test_custom_strings_py2_py3.py new file mode 100644 index 00000000..45cc451d --- /dev/null +++ b/tests/unit/containers/instance/test_custom_strings_py2_py3.py @@ -0,0 +1,47 @@ +"""Tests for container with custom string classes as attribute names. + +See: https://github.com/ets-labs/python-dependency-injector/issues/479 +""" + +from dependency_injector import containers, providers +from pytest import fixture, raises + + +class CustomString(str): + pass + + +class CustomClass: + thing = None + + +@fixture +def container(): + return containers.DynamicContainer() + + +@fixture +def provider(): + return providers.Provider() + + +def test_setattr(container, provider): + setattr(container, CustomString("test_attr"), provider) + assert container.test_attr is provider + + +def test_delattr(container, provider): + setattr(container, CustomString("test_attr"), provider) + delattr(container, CustomString("test_attr")) + with raises(AttributeError): + container.test_attr + + +def test_set_provider(container, provider): + container.set_provider(CustomString("test_attr"), provider) + assert container.test_attr is provider + + +def test_set_providers(container, provider): + container.set_providers(**{CustomString("test_attr"): provider}) + assert container.test_attr is provider diff --git a/tests/unit/containers/instance/test_main_py2_py3.py b/tests/unit/containers/instance/test_main_py2_py3.py new file mode 100644 index 00000000..ddd61de9 --- /dev/null +++ b/tests/unit/containers/instance/test_main_py2_py3.py @@ -0,0 +1,493 @@ +"""Main container instance tests.""" + +from dependency_injector import containers, providers, errors +from pytest import raises + + +class Container(containers.DeclarativeContainer): + p11 = providers.Provider() + p12 = providers.Provider() + + +def test_providers_attribute(): + container_1 = Container() + container_2 = Container() + + assert container_1.p11 is not container_2.p11 + assert container_1.p12 is not container_2.p12 + assert container_1.providers != container_2.providers + + +def test_dependencies_attribute(): + container = Container() + container.a1 = providers.Dependency() + container.a2 = providers.DependenciesContainer() + assert container.dependencies == {"a1": container.a1, "a2": container.a2} + + +def test_set_get_del_providers(): + p13 = providers.Provider() + + container_1 = Container() + container_2 = Container() + + container_1.p13 = p13 + container_2.p13 = p13 + + assert Container.providers == dict(p11=Container.p11, p12=Container.p12) + assert Container.cls_providers, dict(p11=Container.p11, p12=Container.p12) + + assert container_1.providers == dict(p11=container_1.p11, p12=container_1.p12, p13=p13) + assert container_2.providers == dict(p11=container_2.p11, p12=container_2.p12, p13=p13) + + del container_1.p13 + assert container_1.providers == dict(p11=container_1.p11, p12=container_1.p12) + + del container_2.p13 + assert container_2.providers == dict(p11=container_2.p11, p12=container_2.p12) + + del container_1.p11 + del container_1.p12 + assert container_1.providers == dict() + assert Container.providers == dict(p11=Container.p11, p12=Container.p12) + + del container_2.p11 + del container_2.p12 + assert container_2.providers == dict() + assert Container.providers == dict(p11=Container.p11, p12=Container.p12) + + +def test_set_invalid_provider_type(): + container = Container() + container.provider_type = providers.Object + + with raises(errors.Error): + container.px = providers.Provider() + + assert Container.provider_type is containers.DeclarativeContainer.provider_type + + +def test_set_providers(): + p13 = providers.Provider() + p14 = providers.Provider() + container = Container() + + container.set_providers(p13=p13, p14=p14) + + assert container.p13 is p13 + assert container.p14 is p14 + + +def test_override(): + class _Container(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer1(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer2(containers.DeclarativeContainer): + p11 = providers.Provider() + p12 = providers.Provider() + + container = _Container() + overriding_container1 = _OverridingContainer1() + overriding_container2 = _OverridingContainer2() + + container.override(overriding_container1) + container.override(overriding_container2) + + assert container.overridden == (overriding_container1, overriding_container2) + assert container.p11.overridden == (overriding_container1.p11, overriding_container2.p11) + + assert _Container.overridden == tuple() + assert _Container.p11.overridden == tuple() + + +def test_override_with_it(): + container = Container() + with raises(errors.Error): + container.override(container) + + +def test_override_providers(): + p1 = providers.Provider() + p2 = providers.Provider() + container = Container() + + container.override_providers(p11=p1, p12=p2) + + assert container.p11.last_overriding is p1 + assert container.p12.last_overriding is p2 + + +def test_override_providers_context_manager(): + p1 = providers.Provider() + p2 = providers.Provider() + container = Container() + + with container.override_providers(p11=p1, p12=p2) as context_container: + assert container is context_container + assert container.p11.last_overriding is p1 + assert container.p12.last_overriding is p2 + + assert container.p11.last_overriding is None + assert container.p12.last_overriding is None + + +def test_override_providers_with_unknown_provider(): + container = Container() + with raises(AttributeError): + container.override_providers(unknown=providers.Provider()) + + +def test_reset_last_overriding(): + class _Container(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer1(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer2(containers.DeclarativeContainer): + p11 = providers.Provider() + p12 = providers.Provider() + + container = _Container() + overriding_container1 = _OverridingContainer1() + overriding_container2 = _OverridingContainer2() + + container.override(overriding_container1) + container.override(overriding_container2) + container.reset_last_overriding() + + assert container.overridden == (overriding_container1,) + assert container.p11.overridden, (overriding_container1.p11,) + + +def test_reset_last_overriding_when_not_overridden(): + container = Container() + with raises(errors.Error): + container.reset_last_overriding() + + +def test_reset_override(): + class _Container(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer1(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer2(containers.DeclarativeContainer): + p11 = providers.Provider() + p12 = providers.Provider() + + container = _Container() + overriding_container1 = _OverridingContainer1() + overriding_container2 = _OverridingContainer2() + + container.override(overriding_container1) + container.override(overriding_container2) + container.reset_override() + + assert container.overridden == tuple() + assert container.p11.overridden == tuple() + + +def test_init_and_shutdown_resources_ordering(): + """Test init and shutdown resources. + + Methods .init_resources() and .shutdown_resources() should respect resources dependencies. + Initialization should first initialize resources without dependencies and then provide + these resources to other resources. Resources shutdown should follow the same rule: first + shutdown resources without initialized dependencies and then continue correspondingly + until all resources are shutdown. + """ + initialized_resources = [] + shutdown_resources = [] + + def _resource(name, **_): + initialized_resources.append(name) + yield name + shutdown_resources.append(name) + + class Container(containers.DeclarativeContainer): + resource1 = providers.Resource( + _resource, + name="r1", + ) + resource2 = providers.Resource( + _resource, + name="r2", + r1=resource1, + ) + resource3 = providers.Resource( + _resource, + name="r3", + r2=resource2, + ) + + container = Container() + + container.init_resources() + assert initialized_resources == ["r1", "r2", "r3"] + assert shutdown_resources == [] + + container.shutdown_resources() + assert initialized_resources == ["r1", "r2", "r3"] + assert shutdown_resources == ["r3", "r2", "r1"] + + container.init_resources() + assert initialized_resources == ["r1", "r2", "r3", "r1", "r2", "r3"] + assert shutdown_resources == ["r3", "r2", "r1"] + + container.shutdown_resources() + assert initialized_resources == ["r1", "r2", "r3", "r1", "r2", "r3"] + assert shutdown_resources == ["r3", "r2", "r1", "r3", "r2", "r1"] + + +def test_shutdown_resources_circular_dependencies_breaker(): + def _resource(name, **_): + yield name + + class Container(containers.DeclarativeContainer): + resource1 = providers.Resource( + _resource, + name="r1", + ) + resource2 = providers.Resource( + _resource, + name="r2", + r1=resource1, + ) + resource3 = providers.Resource( + _resource, + name="r3", + r2=resource2, + ) + + container = Container() + container.init_resources() + + # Create circular dependency after initialization (r3 -> r2 -> r1 -> r3 -> ...) + container.resource1.add_kwargs(r3=container.resource3) + + with raises(RuntimeError, match="Unable to resolve resources shutdown order"): + container.shutdown_resources() + + +def test_init_shutdown_nested_resources(): + def _init1(): + _init1.init_counter += 1 + yield + _init1.shutdown_counter += 1 + + _init1.init_counter = 0 + _init1.shutdown_counter = 0 + + def _init2(): + _init2.init_counter += 1 + yield + _init2.shutdown_counter += 1 + + _init2.init_counter = 0 + _init2.shutdown_counter = 0 + + class Container(containers.DeclarativeContainer): + + service = providers.Factory( + dict, + resource1=providers.Resource(_init1), + resource2=providers.Resource(_init2), + ) + + container = Container() + assert _init1.init_counter == 0 + assert _init1.shutdown_counter == 0 + assert _init2.init_counter == 0 + assert _init2.shutdown_counter == 0 + + container.init_resources() + assert _init1.init_counter == 1 + assert _init1.shutdown_counter == 0 + assert _init2.init_counter == 1 + assert _init2.shutdown_counter == 0 + + container.shutdown_resources() + assert _init1.init_counter == 1 + assert _init1.shutdown_counter == 1 + assert _init2.init_counter == 1 + assert _init2.shutdown_counter == 1 + + container.init_resources() + container.shutdown_resources() + assert _init1.init_counter == 2 + assert _init1.shutdown_counter == 2 + assert _init2.init_counter == 2 + assert _init2.shutdown_counter == 2 + + +def test_reset_singletons(): + class SubSubContainer(containers.DeclarativeContainer): + singleton = providers.Singleton(object) + + class SubContainer(containers.DeclarativeContainer): + singleton = providers.Singleton(object) + sub_sub_container = providers.Container(SubSubContainer) + + class Container(containers.DeclarativeContainer): + singleton = providers.Singleton(object) + sub_container = providers.Container(SubContainer) + + container = Container() + + obj11 = container.singleton() + obj12 = container.sub_container().singleton() + obj13 = container.sub_container().sub_sub_container().singleton() + + obj21 = container.singleton() + obj22 = container.sub_container().singleton() + obj23 = container.sub_container().sub_sub_container().singleton() + + assert obj11 is obj21 + assert obj12 is obj22 + assert obj13 is obj23 + + container.reset_singletons() + + obj31 = container.singleton() + obj32 = container.sub_container().singleton() + obj33 = container.sub_container().sub_sub_container().singleton() + + obj41 = container.singleton() + obj42 = container.sub_container().singleton() + obj43 = container.sub_container().sub_sub_container().singleton() + + assert obj11 is not obj31 + assert obj12 is not obj32 + assert obj13 is not obj33 + + assert obj21 is not obj31 + assert obj22 is not obj32 + assert obj23 is not obj33 + + assert obj31 is obj41 + assert obj32 is obj42 + assert obj33 is obj43 + + +def test_reset_singletons_context_manager(): + class Item: + def __init__(self, dependency): + self.dependency = dependency + + class Container(containers.DeclarativeContainer): + dependent = providers.Singleton(object) + singleton = providers.Singleton(Item, dependency=dependent) + + container = Container() + + instance1 = container.singleton() + with container.reset_singletons(): + instance2 = container.singleton() + instance3 = container.singleton() + + assert len({instance1, instance2, instance3}) == 3 + assert len({instance1.dependency, instance2.dependency, instance3.dependency}) == 3 + + +def test_reset_singletons_context_manager_as_attribute(): + container = containers.DeclarativeContainer() + with container.reset_singletons() as alias: + pass + assert container is alias + + +def test_check_dependencies(): + class SubContainer(containers.DeclarativeContainer): + dependency = providers.Dependency() + + class Container(containers.DeclarativeContainer): + dependency = providers.Dependency() + dependencies_container = providers.DependenciesContainer() + provider = providers.List(dependencies_container.dependency) + sub_container = providers.Container(SubContainer) + + container = Container() + + with raises(errors.Error) as exception_info: + container.check_dependencies() + + assert "Container \"Container\" has undefined dependencies:" in str(exception_info.value) + assert "\"Container.dependency\"" in str(exception_info.value) + assert "\"Container.dependencies_container.dependency\"" in str(exception_info.value) + assert "\"Container.sub_container.dependency\"" in str(exception_info.value) + + +def test_check_dependencies_all_defined(): + class Container(containers.DeclarativeContainer): + dependency = providers.Dependency() + + container = Container(dependency="provided") + result = container.check_dependencies() + + assert result is None + + +def test_assign_parent(): + parent = providers.DependenciesContainer() + container = Container() + + container.assign_parent(parent) + + assert container.parent is parent + + +def test_parent_name_declarative_parent(): + container = Container() + assert container.parent_name == "Container" + + +def test_parent_name(): + container = Container() + assert container.parent_name == "Container" + + +def test_parent_name_with_deep_parenting(): + class Container2(containers.DeclarativeContainer): + name = providers.Container(Container) + + class Container1(containers.DeclarativeContainer): + container = providers.Container(Container2) + + container = Container1() + assert container.container().name.parent_name == "Container1.container.name" + + +def test_parent_name_is_none(): + container = containers.DynamicContainer() + assert container.parent_name is None + + +def test_parent_deepcopy(): + class ParentContainer(containers.DeclarativeContainer): + child = providers.Container(Container) + + container = ParentContainer() + copied = providers.deepcopy(container) + + assert container.child.parent is container + assert copied.child.parent is copied + + assert container is not copied + assert container.child is not copied.child + assert container.child.parent is not copied.child.parent + + +def test_resolve_provider_name(): + container = Container() + assert container.resolve_provider_name(container.p11) == "p11" + + +def test_resolve_provider_name_no_provider(): + container = Container() + with raises(errors.Error): + container.resolve_provider_name(providers.Provider()) diff --git a/tests/unit/containers/instance/test_self_py2_py3.py b/tests/unit/containers/instance/test_self_py2_py3.py new file mode 100644 index 00000000..7938f378 --- /dev/null +++ b/tests/unit/containers/instance/test_self_py2_py3.py @@ -0,0 +1,215 @@ +"""Tests for container self provier.""" + +from dependency_injector import containers, providers, errors +from pytest import raises + + +def test_self(): + def call_bar(container): + return container.bar() + + class Container(containers.DeclarativeContainer): + __self__ = providers.Self() + foo = providers.Callable(call_bar, __self__) + bar = providers.Object("hello") + + container = Container() + assert container.foo() is "hello" + + +def test_self_attribute_implicit(): + class Container(containers.DeclarativeContainer): + pass + + container = Container() + assert container.__self__() is container + + +def test_self_attribute_explicit(): + class Container(containers.DeclarativeContainer): + __self__ = providers.Self() + + container = Container() + assert container.__self__() is container + + +def test_single_self(): + with raises(errors.Error): + class Container(containers.DeclarativeContainer): + self1 = providers.Self() + self2 = providers.Self() + + +def test_self_attribute_alt_name_implicit(): + class Container(containers.DeclarativeContainer): + foo = providers.Self() + + container = Container() + + assert container.__self__ is container.foo + assert set(container.__self__.alt_names) == {"foo"} + + +def test_self_attribute_alt_name_explicit_1(): + class Container(containers.DeclarativeContainer): + __self__ = providers.Self() + foo = __self__ + bar = __self__ + + container = Container() + + assert container.__self__ is container.foo + assert container.__self__ is container.bar + assert set(container.__self__.alt_names) == {"foo", "bar"} + + +def test_self_attribute_alt_name_explicit_2(): + class Container(containers.DeclarativeContainer): + foo = providers.Self() + bar = foo + + container = Container() + + assert container.__self__ is container.foo + assert container.__self__ is container.bar + assert set(container.__self__.alt_names) == {"foo", "bar"} + + +def test_providers_attribute_1(): + class Container(containers.DeclarativeContainer): + __self__ = providers.Self() + foo = __self__ + bar = __self__ + + container = Container() + + assert container.providers == {} + assert Container.providers == {} + + +def test_providers_attribute_2(): + class Container(containers.DeclarativeContainer): + foo = providers.Self() + bar = foo + + container = Container() + + assert container.providers == {} + assert Container.providers == {} + + +def test_container_multiple_instances(): + class Container(containers.DeclarativeContainer): + __self__ = providers.Self() + + container1 = Container() + container2 = Container() + + assert container1 is not container2 + assert container1.__self__() is container1 + assert container2.__self__() is container2 + + +def test_deepcopy(): + def call_bar(container): + return container.bar() + + class Container(containers.DeclarativeContainer): + __self__ = providers.Self() + foo = providers.Callable(call_bar, __self__) + bar = providers.Object("hello") + + container1 = Container() + container2 = providers.deepcopy(container1) + container1.bar.override("bye") + + assert container1.foo() == "bye" + assert container2.foo() == "hello" + + +def test_deepcopy_alt_names_1(): + class Container(containers.DeclarativeContainer): + __self__ = providers.Self() + foo = __self__ + bar = foo + + container1 = Container() + container2 = providers.deepcopy(container1) + + assert container2.__self__() is container2 + assert container2.foo() is container2 + assert container2.bar() is container2 + + +def test_deepcopy_alt_names_2(): + class Container(containers.DeclarativeContainer): + self = providers.Self() + + container1 = Container() + container2 = providers.deepcopy(container1) + + assert container2.__self__() is container2 + assert container2.self() is container2 + + +def test_deepcopy_no_self_dependencies(): + class Container(containers.DeclarativeContainer): + __self__ = providers.Self() + + container1 = Container() + container2 = providers.deepcopy(container1) + + assert container1 is not container2 + assert container1.__self__ is not container2.__self__ + assert container1.__self__() is container1 + assert container2.__self__() is container2 + + +def test_with_container_provider(): + def call_bar(container): + return container.bar() + + class SubContainer(containers.DeclarativeContainer): + __self__ = providers.Self() + foo = providers.Callable(call_bar, __self__) + bar = providers.Object("hello") + + class Container(containers.DeclarativeContainer): + sub_container = providers.Container(SubContainer) + + baz = providers.Callable(lambda value: value, sub_container.foo) + + container = Container() + assert container.baz() == "hello" + + +def test_with_container_provider_overriding(): + def call_bar(container): + return container.bar() + + class SubContainer(containers.DeclarativeContainer): + __self__ = providers.Self() + foo = providers.Callable(call_bar, __self__) + bar = providers.Object("hello") + + class Container(containers.DeclarativeContainer): + sub_container = providers.Container(SubContainer, bar="bye") + + baz = providers.Callable(lambda value: value, sub_container.foo) + + container = Container() + assert container.baz() == "bye" + + +def test_with_container_provider_self(): + class SubContainer(containers.DeclarativeContainer): + __self__ = providers.Self() + + class Container(containers.DeclarativeContainer): + sub_container = providers.Container(SubContainer) + + container = Container() + + assert container.__self__() is container + assert container.sub_container().__self__() is container.sub_container() + diff --git a/tests/unit/containers/test_declarative_py2_py3.py b/tests/unit/containers/test_declarative_py2_py3.py deleted file mode 100644 index f6176893..00000000 --- a/tests/unit/containers/test_declarative_py2_py3.py +++ /dev/null @@ -1,539 +0,0 @@ -"""Dependency injector declarative container unit tests.""" - -import collections -import unittest - -from dependency_injector import ( - containers, - providers, - errors, -) - - -class ContainerA(containers.DeclarativeContainer): - p11 = providers.Provider() - p12 = providers.Provider() - - -class ContainerB(ContainerA): - p21 = providers.Provider() - p22 = providers.Provider() - - -class ContainerC(ContainerB): - p31 = providers.Provider() - p32 = providers.Provider() - - -class DeclarativeContainerTests(unittest.TestCase): - - def test_providers_attribute(self): - self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11, - p12=ContainerA.p12)) - self.assertEqual(ContainerB.providers, dict(p11=ContainerA.p11, - p12=ContainerA.p12, - p21=ContainerB.p21, - p22=ContainerB.p22)) - self.assertEqual(ContainerC.providers, dict(p11=ContainerA.p11, - p12=ContainerA.p12, - p21=ContainerB.p21, - p22=ContainerB.p22, - p31=ContainerC.p31, - p32=ContainerC.p32)) - - def test_providers_attribute_with_redefinition(self): - p1 = providers.Provider() - p2 = providers.Provider() - - class ContainerA2(ContainerA): - p11 = p1 - p12 = p2 - - self.assertEqual( - ContainerA.providers, - { - "p11": ContainerA.p11, - "p12": ContainerA.p12, - }, - ) - self.assertEqual( - ContainerA2.providers, - { - "p11": p1, - "p12": p2, - }, - ) - - def test_cls_providers_attribute(self): - self.assertEqual(ContainerA.cls_providers, dict(p11=ContainerA.p11, - p12=ContainerA.p12)) - self.assertEqual(ContainerB.cls_providers, dict(p21=ContainerB.p21, - p22=ContainerB.p22)) - self.assertEqual(ContainerC.cls_providers, dict(p31=ContainerC.p31, - p32=ContainerC.p32)) - - def test_inherited_providers_attribute(self): - self.assertEqual(ContainerA.inherited_providers, dict()) - self.assertEqual(ContainerB.inherited_providers, - dict(p11=ContainerA.p11, - p12=ContainerA.p12)) - self.assertEqual(ContainerC.inherited_providers, - dict(p11=ContainerA.p11, - p12=ContainerA.p12, - p21=ContainerB.p21, - p22=ContainerB.p22)) - - def test_dependencies_attribute(self): - class ContainerD(ContainerC): - p41 = providers.Dependency() - p42 = providers.DependenciesContainer() - - class ContainerE(ContainerD): - p51 = providers.Dependency() - p52 = providers.DependenciesContainer() - - self.assertEqual( - ContainerD.dependencies, - { - "p41": ContainerD.p41, - "p42": ContainerD.p42, - }, - ) - self.assertEqual( - ContainerE.dependencies, - { - "p41": ContainerD.p41, - "p42": ContainerD.p42, - "p51": ContainerE.p51, - "p52": ContainerE.p52, - }, - ) - - def test_set_get_del_providers(self): - a_p13 = providers.Provider() - b_p23 = providers.Provider() - - ContainerA.p13 = a_p13 - ContainerB.p23 = b_p23 - - self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11, - p12=ContainerA.p12, - p13=a_p13)) - self.assertEqual(ContainerB.providers, dict(p11=ContainerA.p11, - p12=ContainerA.p12, - p21=ContainerB.p21, - p22=ContainerB.p22, - p23=b_p23)) - - self.assertEqual(ContainerA.cls_providers, dict(p11=ContainerA.p11, - p12=ContainerA.p12, - p13=a_p13)) - self.assertEqual(ContainerB.cls_providers, dict(p21=ContainerB.p21, - p22=ContainerB.p22, - p23=b_p23)) - - del ContainerA.p13 - del ContainerB.p23 - - self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11, - p12=ContainerA.p12)) - self.assertEqual(ContainerB.providers, dict(p11=ContainerA.p11, - p12=ContainerA.p12, - p21=ContainerB.p21, - p22=ContainerB.p22)) - - self.assertEqual(ContainerA.cls_providers, dict(p11=ContainerA.p11, - p12=ContainerA.p12)) - self.assertEqual(ContainerB.cls_providers, dict(p21=ContainerB.p21, - p22=ContainerB.p22)) - - def test_declare_with_valid_provider_type(self): - class _Container(containers.DeclarativeContainer): - provider_type = providers.Object - px = providers.Object(object()) - - self.assertIsInstance(_Container.px, providers.Object) - - def test_declare_with_invalid_provider_type(self): - with self.assertRaises(errors.Error): - class _Container(containers.DeclarativeContainer): - provider_type = providers.Object - px = providers.Provider() - - def test_seth_valid_provider_type(self): - class _Container(containers.DeclarativeContainer): - provider_type = providers.Object - - _Container.px = providers.Object(object()) - - self.assertIsInstance(_Container.px, providers.Object) - - def test_set_invalid_provider_type(self): - class _Container(containers.DeclarativeContainer): - provider_type = providers.Object - - with self.assertRaises(errors.Error): - _Container.px = providers.Provider() - - def test_override(self): - class _Container(containers.DeclarativeContainer): - p11 = providers.Provider() - - class _OverridingContainer1(containers.DeclarativeContainer): - p11 = providers.Provider() - - class _OverridingContainer2(containers.DeclarativeContainer): - p11 = providers.Provider() - p12 = providers.Provider() - - _Container.override(_OverridingContainer1) - _Container.override(_OverridingContainer2) - - self.assertEqual(_Container.overridden, - (_OverridingContainer1, - _OverridingContainer2)) - self.assertEqual(_Container.p11.overridden, - (_OverridingContainer1.p11, - _OverridingContainer2.p11)) - - def test_override_with_itself(self): - with self.assertRaises(errors.Error): - ContainerA.override(ContainerA) - - def test_override_with_parent(self): - with self.assertRaises(errors.Error): - ContainerB.override(ContainerA) - - def test_override_decorator(self): - class _Container(containers.DeclarativeContainer): - p11 = providers.Provider() - - @containers.override(_Container) - class _OverridingContainer1(containers.DeclarativeContainer): - p11 = providers.Provider() - - @containers.override(_Container) - class _OverridingContainer2(containers.DeclarativeContainer): - p11 = providers.Provider() - p12 = providers.Provider() - - self.assertEqual(_Container.overridden, - (_OverridingContainer1, - _OverridingContainer2)) - self.assertEqual(_Container.p11.overridden, - (_OverridingContainer1.p11, - _OverridingContainer2.p11)) - - def test_reset_last_overriding(self): - class _Container(containers.DeclarativeContainer): - p11 = providers.Provider() - - class _OverridingContainer1(containers.DeclarativeContainer): - p11 = providers.Provider() - - class _OverridingContainer2(containers.DeclarativeContainer): - p11 = providers.Provider() - p12 = providers.Provider() - - _Container.override(_OverridingContainer1) - _Container.override(_OverridingContainer2) - _Container.reset_last_overriding() - - self.assertEqual(_Container.overridden, - (_OverridingContainer1,)) - self.assertEqual(_Container.p11.overridden, - (_OverridingContainer1.p11,)) - - def test_reset_last_overriding_when_not_overridden(self): - with self.assertRaises(errors.Error): - ContainerA.reset_last_overriding() - - def test_reset_override(self): - class _Container(containers.DeclarativeContainer): - p11 = providers.Provider() - - class _OverridingContainer1(containers.DeclarativeContainer): - p11 = providers.Provider() - - class _OverridingContainer2(containers.DeclarativeContainer): - p11 = providers.Provider() - p12 = providers.Provider() - - _Container.override(_OverridingContainer1) - _Container.override(_OverridingContainer2) - _Container.reset_override() - - self.assertEqual(_Container.overridden, tuple()) - self.assertEqual(_Container.p11.overridden, tuple()) - - def test_copy(self): - @containers.copy(ContainerA) - class _Container1(ContainerA): - pass - - @containers.copy(ContainerA) - class _Container2(ContainerA): - pass - - self.assertIsNot(ContainerA.p11, _Container1.p11) - self.assertIsNot(ContainerA.p12, _Container1.p12) - - self.assertIsNot(ContainerA.p11, _Container2.p11) - self.assertIsNot(ContainerA.p12, _Container2.p12) - - self.assertIsNot(_Container1.p11, _Container2.p11) - self.assertIsNot(_Container1.p12, _Container2.p12) - - def test_copy_with_replacing(self): - class _Container(containers.DeclarativeContainer): - p11 = providers.Object(0) - p12 = providers.Factory(dict, p11=p11) - - @containers.copy(_Container) - class _Container1(_Container): - p11 = providers.Object(1) - p13 = providers.Object(11) - - @containers.copy(_Container) - class _Container2(_Container): - p11 = providers.Object(2) - p13 = providers.Object(22) - - self.assertIsNot(_Container.p11, _Container1.p11) - self.assertIsNot(_Container.p12, _Container1.p12) - - self.assertIsNot(_Container.p11, _Container2.p11) - self.assertIsNot(_Container.p12, _Container2.p12) - - self.assertIsNot(_Container1.p11, _Container2.p11) - self.assertIsNot(_Container1.p12, _Container2.p12) - - self.assertEqual(_Container.p12(), {"p11": 0}) - self.assertEqual(_Container1.p12(), {"p11": 1}) - self.assertEqual(_Container2.p12(), {"p11": 2}) - - self.assertEqual(_Container1.p13(), 11) - self.assertEqual(_Container2.p13(), 22) - - def test_copy_with_parent_dependency(self): - # See: https://github.com/ets-labs/python-dependency-injector/issues/477 - class Base(containers.DeclarativeContainer): - p11 = providers.Object(0) - p12 = providers.Factory(dict, p11=p11) - - @containers.copy(Base) - class New(Base): - p13 = providers.Factory(dict, p12=Base.p12) - - new1 = New() - new2 = New(p11=1) - new3 = New(p11=2) - - self.assertEqual(new1.p13(), {"p12": {"p11": 0}}) - self.assertEqual(new2.p13(), {"p12": {"p11": 1}}) - self.assertEqual(new3.p13(), {"p12": {"p11": 2}}) - - def test_copy_with_replacing_subcontainer_providers(self): - # See: https://github.com/ets-labs/python-dependency-injector/issues/374 - class X(containers.DeclarativeContainer): - foo = providers.Dependency(instance_of=str) - - def build_x(): - return X(foo="1") - - class A(containers.DeclarativeContainer): - x = providers.DependenciesContainer(**X.providers) - y = x.foo - - @containers.copy(A) - class B1(A): - x = providers.Container(build_x) - - b1 = B1() - - self.assertEqual(b1.y(), "1") - - def test_containers_attribute(self): - class Container(containers.DeclarativeContainer): - class Container1(containers.DeclarativeContainer): - pass - - class Container2(containers.DeclarativeContainer): - pass - - Container3 = containers.DynamicContainer() - - self.assertEqual(Container.containers, - dict(Container1=Container.Container1, - Container2=Container.Container2, - Container3=Container.Container3)) - - def test_init_with_overriding_providers(self): - p1 = providers.Provider() - p2 = providers.Provider() - - container = ContainerA(p11=p1, p12=p2) - - self.assertIs(container.p11.last_overriding, p1) - self.assertIs(container.p12.last_overriding, p2) - - def test_init_with_overridden_dependency(self): - # Bug: - # https://github.com/ets-labs/python-dependency-injector/issues/198 - class _Container(containers.DeclarativeContainer): - p1 = providers.Dependency(instance_of=int) - - p2 = providers.Dependency(object) - p2.override(providers.Factory(dict, p1=p1)) - - container = _Container(p1=1) - - self.assertEqual(container.p2(), {"p1": 1}) - self.assertIs( - container.p2.last_overriding.kwargs["p1"], - container.p1, - ) - self.assertIsNot( - container.p2.last_overriding.kwargs["p1"], - _Container.p1, - ) - self.assertIs( - _Container.p2.last_overriding.kwargs["p1"], - _Container.p1, - ) - - def test_init_with_chained_dependency(self): - # Bug: - # https://github.com/ets-labs/python-dependency-injector/issues/200 - class _Container(containers.DeclarativeContainer): - p1 = providers.Dependency(instance_of=int) - p2 = providers.Factory(p1) - - container = _Container(p1=1) - - self.assertEqual(container.p2(), 1) - self.assertIs(container.p2.cls, container.p1) - self.assertIs(_Container.p2.cls, _Container.p1) - self.assertIsNot(container.p2.cls, _Container.p1) - - def test_init_with_dependency_delegation(self): - # Bug: - # https://github.com/ets-labs/python-dependency-injector/issues/235 - A = collections.namedtuple("A", []) - B = collections.namedtuple("B", ["fa"]) - C = collections.namedtuple("B", ["a"]) - - class Services(containers.DeclarativeContainer): - a = providers.Dependency() - c = providers.Factory(C, a=a) - b = providers.Factory(B, fa=a.provider) - - a = providers.Factory(A) - assert isinstance(Services(a=a).c().a, A) # ok - Services(a=a).b().fa() - - def test_init_with_grand_child_provider(self): - # Bug: - # https://github.com/ets-labs/python-dependency-injector/issues/350 - provider = providers.Provider() - container = ContainerC(p11=provider) - - self.assertIsInstance(container.p11, providers.Provider) - self.assertIsInstance(container.p12, providers.Provider) - self.assertIsInstance(container.p21, providers.Provider) - self.assertIsInstance(container.p22, providers.Provider) - self.assertIsInstance(container.p31, providers.Provider) - self.assertIsInstance(container.p32, providers.Provider) - self.assertIs(container.p11.last_overriding, provider) - - def test_parent_set_in__new__(self): - class Container(containers.DeclarativeContainer): - dependency = providers.Dependency() - dependencies_container = providers.DependenciesContainer() - container = providers.Container(ContainerA) - - self.assertIs(Container.dependency.parent, Container) - self.assertIs(Container.dependencies_container.parent, Container) - self.assertIs(Container.container.parent, Container) - - def test_parent_set_in__setattr__(self): - class Container(containers.DeclarativeContainer): - pass - - Container.dependency = providers.Dependency() - Container.dependencies_container = providers.DependenciesContainer() - Container.container = providers.Container(ContainerA) - - self.assertIs(Container.dependency.parent, Container) - self.assertIs(Container.dependencies_container.parent, Container) - self.assertIs(Container.container.parent, Container) - - def test_resolve_provider_name(self): - self.assertEqual(ContainerA.resolve_provider_name(ContainerA.p11), "p11") - - def test_resolve_provider_name_no_provider(self): - with self.assertRaises(errors.Error): - ContainerA.resolve_provider_name(providers.Provider()) - - def test_child_dependency_parent_name(self): - class Container(containers.DeclarativeContainer): - dependency = providers.Dependency() - - with self.assertRaises(errors.Error) as context: - Container.dependency() - self.assertEqual( - str(context.exception), - "Dependency \"Container.dependency\" is not defined", - ) - - def test_child_dependencies_container_parent_name(self): - class Container(containers.DeclarativeContainer): - dependencies_container = providers.DependenciesContainer() - - with self.assertRaises(errors.Error) as context: - Container.dependencies_container.dependency() - self.assertEqual( - str(context.exception), - "Dependency \"Container.dependencies_container.dependency\" is not defined", - ) - - def test_child_container_parent_name(self): - class ChildContainer(containers.DeclarativeContainer): - dependency = providers.Dependency() - - class Container(containers.DeclarativeContainer): - child_container = providers.Container(ChildContainer) - - with self.assertRaises(errors.Error) as context: - Container.child_container.dependency() - self.assertEqual( - str(context.exception), - "Dependency \"Container.child_container.dependency\" is not defined", - ) - - -class DeclarativeContainerWithCustomStringTests(unittest.TestCase): - # See: https://github.com/ets-labs/python-dependency-injector/issues/479 - - class CustomString(str): - pass - - class CustomClass: - thing = None - - class CustomContainer(containers.DeclarativeContainer): - pass - - def setUp(self): - self.container = self.CustomContainer - self.provider = providers.Provider() - - def test_setattr(self): - setattr(self.container, self.CustomString("test_attr"), self.provider) - self.assertIs(self.container.test_attr, self.provider) - - def test_delattr(self): - setattr(self.container, self.CustomString("test_attr"), self.provider) - delattr(self.container, self.CustomString("test_attr")) - with self.assertRaises(AttributeError): - self.container.test_attr diff --git a/tests/unit/containers/test_dynamic_async_resources_py36.py b/tests/unit/containers/test_dynamic_async_resources_py36.py deleted file mode 100644 index eaad68b9..00000000 --- a/tests/unit/containers/test_dynamic_async_resources_py36.py +++ /dev/null @@ -1,159 +0,0 @@ -"""Dependency injector dynamic container unit tests for async resources.""" -import asyncio - -# Runtime import to get asyncutils module -import os -_TOP_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../", - )), -) -import sys -sys.path.append(_TOP_DIR) - -from asyncutils import AsyncTestCase - -from dependency_injector import ( - containers, - providers, -) - - -class AsyncResourcesTest(AsyncTestCase): - - def test_init_and_shutdown_ordering(self): - """Test init and shutdown resources. - - Methods .init_resources() and .shutdown_resources() should respect resources dependencies. - Initialization should first initialize resources without dependencies and then provide - these resources to other resources. Resources shutdown should follow the same rule: first - shutdown resources without initialized dependencies and then continue correspondingly - until all resources are shutdown. - """ - initialized_resources = [] - shutdown_resources = [] - - async def _resource(name, delay, **_): - await asyncio.sleep(delay) - initialized_resources.append(name) - - yield name - - await asyncio.sleep(delay) - shutdown_resources.append(name) - - class Container(containers.DeclarativeContainer): - resource1 = providers.Resource( - _resource, - name="r1", - delay=0.03, - ) - resource2 = providers.Resource( - _resource, - name="r2", - delay=0.02, - r1=resource1, - ) - resource3 = providers.Resource( - _resource, - name="r3", - delay=0.01, - r2=resource2, - ) - - container = Container() - - self._run(container.init_resources()) - self.assertEqual(initialized_resources, ["r1", "r2", "r3"]) - self.assertEqual(shutdown_resources, []) - - self._run(container.shutdown_resources()) - self.assertEqual(initialized_resources, ["r1", "r2", "r3"]) - self.assertEqual(shutdown_resources, ["r3", "r2", "r1"]) - - self._run(container.init_resources()) - self.assertEqual(initialized_resources, ["r1", "r2", "r3", "r1", "r2", "r3"]) - self.assertEqual(shutdown_resources, ["r3", "r2", "r1"]) - - self._run(container.shutdown_resources()) - self.assertEqual(initialized_resources, ["r1", "r2", "r3", "r1", "r2", "r3"]) - self.assertEqual(shutdown_resources, ["r3", "r2", "r1", "r3", "r2", "r1"]) - - def test_shutdown_circular_dependencies_breaker(self): - async def _resource(name, **_): - yield name - - class Container(containers.DeclarativeContainer): - resource1 = providers.Resource( - _resource, - name="r1", - ) - resource2 = providers.Resource( - _resource, - name="r2", - r1=resource1, - ) - resource3 = providers.Resource( - _resource, - name="r3", - r2=resource2, - ) - - container = Container() - self._run(container.init_resources()) - - # Create circular dependency after initialization (r3 -> r2 -> r1 -> r3 -> ...) - container.resource1.add_kwargs(r3=container.resource3) - - with self.assertRaises(RuntimeError) as context: - self._run(container.shutdown_resources()) - self.assertEqual(str(context.exception), "Unable to resolve resources shutdown order") - - def test_shutdown_sync_and_async_ordering(self): - initialized_resources = [] - shutdown_resources = [] - - def _sync_resource(name, **_): - initialized_resources.append(name) - yield name - shutdown_resources.append(name) - - async def _async_resource(name, **_): - initialized_resources.append(name) - yield name - shutdown_resources.append(name) - - class Container(containers.DeclarativeContainer): - resource1 = providers.Resource( - _sync_resource, - name="r1", - ) - resource2 = providers.Resource( - _sync_resource, - name="r2", - r1=resource1, - ) - resource3 = providers.Resource( - _async_resource, - name="r3", - r2=resource2, - ) - - container = Container() - - self._run(container.init_resources()) - self.assertEqual(initialized_resources, ["r1", "r2", "r3"]) - self.assertEqual(shutdown_resources, []) - - self._run(container.shutdown_resources()) - self.assertEqual(initialized_resources, ["r1", "r2", "r3"]) - self.assertEqual(shutdown_resources, ["r3", "r2", "r1"]) - - self._run(container.init_resources()) - self.assertEqual(initialized_resources, ["r1", "r2", "r3", "r1", "r2", "r3"]) - self.assertEqual(shutdown_resources, ["r3", "r2", "r1"]) - - self._run(container.shutdown_resources()) - self.assertEqual(initialized_resources, ["r1", "r2", "r3", "r1", "r2", "r3"]) - self.assertEqual(shutdown_resources, ["r3", "r2", "r1", "r3", "r2", "r1"]) diff --git a/tests/unit/containers/test_dynamic_py2_py3.py b/tests/unit/containers/test_dynamic_py2_py3.py deleted file mode 100644 index d24bb62e..00000000 --- a/tests/unit/containers/test_dynamic_py2_py3.py +++ /dev/null @@ -1,733 +0,0 @@ -"""Dependency injector dynamic container unit tests.""" - -import unittest - -from dependency_injector import ( - containers, - providers, - errors, -) - - -class ContainerA(containers.DeclarativeContainer): - p11 = providers.Provider() - p12 = providers.Provider() - - -class DeclarativeContainerInstanceTests(unittest.TestCase): - - def test_providers_attribute(self): - container_a1 = ContainerA() - container_a2 = ContainerA() - - self.assertIsNot(container_a1.p11, container_a2.p11) - self.assertIsNot(container_a1.p12, container_a2.p12) - self.assertNotEqual(container_a1.providers, container_a2.providers) - - def test_dependencies_attribute(self): - container = ContainerA() - container.a1 = providers.Dependency() - container.a2 = providers.DependenciesContainer() - self.assertEqual(container.dependencies, {"a1": container.a1, "a2": container.a2}) - - def test_set_get_del_providers(self): - p13 = providers.Provider() - - container_a1 = ContainerA() - container_a2 = ContainerA() - - container_a1.p13 = p13 - container_a2.p13 = p13 - - self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11, - p12=ContainerA.p12)) - self.assertEqual(ContainerA.cls_providers, dict(p11=ContainerA.p11, - p12=ContainerA.p12)) - - self.assertEqual(container_a1.providers, dict(p11=container_a1.p11, - p12=container_a1.p12, - p13=p13)) - self.assertEqual(container_a2.providers, dict(p11=container_a2.p11, - p12=container_a2.p12, - p13=p13)) - - del container_a1.p13 - self.assertEqual(container_a1.providers, dict(p11=container_a1.p11, - p12=container_a1.p12)) - - del container_a2.p13 - self.assertEqual(container_a2.providers, dict(p11=container_a2.p11, - p12=container_a2.p12)) - - del container_a1.p11 - del container_a1.p12 - self.assertEqual(container_a1.providers, dict()) - self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11, - p12=ContainerA.p12)) - - del container_a2.p11 - del container_a2.p12 - self.assertEqual(container_a2.providers, dict()) - self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11, - p12=ContainerA.p12)) - - def test_set_invalid_provider_type(self): - container_a = ContainerA() - container_a.provider_type = providers.Object - - with self.assertRaises(errors.Error): - container_a.px = providers.Provider() - - self.assertIs(ContainerA.provider_type, - containers.DeclarativeContainer.provider_type) - - def test_set_providers(self): - p13 = providers.Provider() - p14 = providers.Provider() - container_a = ContainerA() - - container_a.set_providers(p13=p13, p14=p14) - - self.assertIs(container_a.p13, p13) - self.assertIs(container_a.p14, p14) - - def test_override(self): - class _Container(containers.DeclarativeContainer): - p11 = providers.Provider() - - class _OverridingContainer1(containers.DeclarativeContainer): - p11 = providers.Provider() - - class _OverridingContainer2(containers.DeclarativeContainer): - p11 = providers.Provider() - p12 = providers.Provider() - - container = _Container() - overriding_container1 = _OverridingContainer1() - overriding_container2 = _OverridingContainer2() - - container.override(overriding_container1) - container.override(overriding_container2) - - self.assertEqual(container.overridden, - (overriding_container1, - overriding_container2)) - self.assertEqual(container.p11.overridden, - (overriding_container1.p11, - overriding_container2.p11)) - - self.assertEqual(_Container.overridden, tuple()) - self.assertEqual(_Container.p11.overridden, tuple()) - - def test_override_with_itself(self): - container = ContainerA() - with self.assertRaises(errors.Error): - container.override(container) - - def test_override_providers(self): - p1 = providers.Provider() - p2 = providers.Provider() - container_a = ContainerA() - - container_a.override_providers(p11=p1, p12=p2) - - self.assertIs(container_a.p11.last_overriding, p1) - self.assertIs(container_a.p12.last_overriding, p2) - - def test_override_providers_context_manager(self): - p1 = providers.Provider() - p2 = providers.Provider() - container_a = ContainerA() - - with container_a.override_providers(p11=p1, p12=p2) as container: - self.assertIs(container, container_a) - self.assertIs(container_a.p11.last_overriding, p1) - self.assertIs(container_a.p12.last_overriding, p2) - - self.assertIsNone(container_a.p11.last_overriding) - self.assertIsNone(container_a.p12.last_overriding) - - def test_override_providers_with_unknown_provider(self): - container_a = ContainerA() - - with self.assertRaises(AttributeError): - container_a.override_providers(unknown=providers.Provider()) - - def test_reset_last_overriding(self): - class _Container(containers.DeclarativeContainer): - p11 = providers.Provider() - - class _OverridingContainer1(containers.DeclarativeContainer): - p11 = providers.Provider() - - class _OverridingContainer2(containers.DeclarativeContainer): - p11 = providers.Provider() - p12 = providers.Provider() - - container = _Container() - overriding_container1 = _OverridingContainer1() - overriding_container2 = _OverridingContainer2() - - container.override(overriding_container1) - container.override(overriding_container2) - container.reset_last_overriding() - - self.assertEqual(container.overridden, - (overriding_container1,)) - self.assertEqual(container.p11.overridden, - (overriding_container1.p11,)) - - def test_reset_last_overriding_when_not_overridden(self): - container = ContainerA() - - with self.assertRaises(errors.Error): - container.reset_last_overriding() - - def test_reset_override(self): - class _Container(containers.DeclarativeContainer): - p11 = providers.Provider() - - class _OverridingContainer1(containers.DeclarativeContainer): - p11 = providers.Provider() - - class _OverridingContainer2(containers.DeclarativeContainer): - p11 = providers.Provider() - p12 = providers.Provider() - - container = _Container() - overriding_container1 = _OverridingContainer1() - overriding_container2 = _OverridingContainer2() - - container.override(overriding_container1) - container.override(overriding_container2) - container.reset_override() - - self.assertEqual(container.overridden, tuple()) - self.assertEqual(container.p11.overridden, tuple()) - - def test_init_and_shutdown_resources_ordering(self): - """Test init and shutdown resources. - - Methods .init_resources() and .shutdown_resources() should respect resources dependencies. - Initialization should first initialize resources without dependencies and then provide - these resources to other resources. Resources shutdown should follow the same rule: first - shutdown resources without initialized dependencies and then continue correspondingly - until all resources are shutdown. - """ - initialized_resources = [] - shutdown_resources = [] - - def _resource(name, **_): - initialized_resources.append(name) - yield name - shutdown_resources.append(name) - - class Container(containers.DeclarativeContainer): - resource1 = providers.Resource( - _resource, - name="r1", - ) - resource2 = providers.Resource( - _resource, - name="r2", - r1=resource1, - ) - resource3 = providers.Resource( - _resource, - name="r3", - r2=resource2, - ) - - container = Container() - - container.init_resources() - self.assertEqual(initialized_resources, ["r1", "r2", "r3"]) - self.assertEqual(shutdown_resources, []) - - container.shutdown_resources() - self.assertEqual(initialized_resources, ["r1", "r2", "r3"]) - self.assertEqual(shutdown_resources, ["r3", "r2", "r1"]) - - container.init_resources() - self.assertEqual(initialized_resources, ["r1", "r2", "r3", "r1", "r2", "r3"]) - self.assertEqual(shutdown_resources, ["r3", "r2", "r1"]) - - container.shutdown_resources() - self.assertEqual(initialized_resources, ["r1", "r2", "r3", "r1", "r2", "r3"]) - self.assertEqual(shutdown_resources, ["r3", "r2", "r1", "r3", "r2", "r1"]) - - def test_shutdown_resources_circular_dependencies_breaker(self): - def _resource(name, **_): - yield name - - class Container(containers.DeclarativeContainer): - resource1 = providers.Resource( - _resource, - name="r1", - ) - resource2 = providers.Resource( - _resource, - name="r2", - r1=resource1, - ) - resource3 = providers.Resource( - _resource, - name="r3", - r2=resource2, - ) - - container = Container() - container.init_resources() - - # Create circular dependency after initialization (r3 -> r2 -> r1 -> r3 -> ...) - container.resource1.add_kwargs(r3=container.resource3) - - with self.assertRaises(RuntimeError) as context: - container.shutdown_resources() - self.assertEqual(str(context.exception), "Unable to resolve resources shutdown order") - - def test_init_shutdown_nested_resources(self): - def _init1(): - _init1.init_counter += 1 - yield - _init1.shutdown_counter += 1 - - _init1.init_counter = 0 - _init1.shutdown_counter = 0 - - def _init2(): - _init2.init_counter += 1 - yield - _init2.shutdown_counter += 1 - - _init2.init_counter = 0 - _init2.shutdown_counter = 0 - - class Container(containers.DeclarativeContainer): - - service = providers.Factory( - dict, - resource1=providers.Resource(_init1), - resource2=providers.Resource(_init2), - ) - - container = Container() - self.assertEqual(_init1.init_counter, 0) - self.assertEqual(_init1.shutdown_counter, 0) - self.assertEqual(_init2.init_counter, 0) - self.assertEqual(_init2.shutdown_counter, 0) - - container.init_resources() - self.assertEqual(_init1.init_counter, 1) - self.assertEqual(_init1.shutdown_counter, 0) - self.assertEqual(_init2.init_counter, 1) - self.assertEqual(_init2.shutdown_counter, 0) - - container.shutdown_resources() - self.assertEqual(_init1.init_counter, 1) - self.assertEqual(_init1.shutdown_counter, 1) - self.assertEqual(_init2.init_counter, 1) - self.assertEqual(_init2.shutdown_counter, 1) - - container.init_resources() - container.shutdown_resources() - self.assertEqual(_init1.init_counter, 2) - self.assertEqual(_init1.shutdown_counter, 2) - self.assertEqual(_init2.init_counter, 2) - self.assertEqual(_init2.shutdown_counter, 2) - - def test_reset_singletons(self): - class SubSubContainer(containers.DeclarativeContainer): - singleton = providers.Singleton(object) - - class SubContainer(containers.DeclarativeContainer): - singleton = providers.Singleton(object) - sub_sub_container = providers.Container(SubSubContainer) - - class Container(containers.DeclarativeContainer): - singleton = providers.Singleton(object) - sub_container = providers.Container(SubContainer) - - container = Container() - - obj11 = container.singleton() - obj12 = container.sub_container().singleton() - obj13 = container.sub_container().sub_sub_container().singleton() - - obj21 = container.singleton() - obj22 = container.sub_container().singleton() - obj23 = container.sub_container().sub_sub_container().singleton() - - self.assertIs(obj11, obj21) - self.assertIs(obj12, obj22) - self.assertIs(obj13, obj23) - - container.reset_singletons() - - obj31 = container.singleton() - obj32 = container.sub_container().singleton() - obj33 = container.sub_container().sub_sub_container().singleton() - - obj41 = container.singleton() - obj42 = container.sub_container().singleton() - obj43 = container.sub_container().sub_sub_container().singleton() - - self.assertIsNot(obj11, obj31) - self.assertIsNot(obj12, obj32) - self.assertIsNot(obj13, obj33) - - self.assertIsNot(obj21, obj31) - self.assertIsNot(obj22, obj32) - self.assertIsNot(obj23, obj33) - - self.assertIs(obj31, obj41) - self.assertIs(obj32, obj42) - self.assertIs(obj33, obj43) - - def test_reset_singletons_context_manager(self): - class Item: - def __init__(self, dependency): - self.dependency = dependency - - class Container(containers.DeclarativeContainer): - dependent = providers.Singleton(object) - singleton = providers.Singleton(Item, dependency=dependent) - - container = Container() - - instance1 = container.singleton() - with container.reset_singletons(): - instance2 = container.singleton() - instance3 = container.singleton() - - self.assertEqual(len({instance1, instance2, instance3}), 3) - self.assertEqual( - len({instance1.dependency, instance2.dependency, instance3.dependency}), - 3, - ) - - def test_reset_singletons_context_manager_as_attribute(self): - container = containers.DeclarativeContainer() - - with container.reset_singletons() as alias: - pass - - self.assertIs(container, alias) - - def test_check_dependencies(self): - class SubContainer(containers.DeclarativeContainer): - dependency = providers.Dependency() - - class Container(containers.DeclarativeContainer): - dependency = providers.Dependency() - dependencies_container = providers.DependenciesContainer() - provider = providers.List(dependencies_container.dependency) - sub_container = providers.Container(SubContainer) - - container = Container() - - with self.assertRaises(errors.Error) as context: - container.check_dependencies() - - self.assertIn("Container \"Container\" has undefined dependencies:", str(context.exception)) - self.assertIn("\"Container.dependency\"", str(context.exception)) - self.assertIn("\"Container.dependencies_container.dependency\"", str(context.exception)) - self.assertIn("\"Container.sub_container.dependency\"", str(context.exception)) - - def test_check_dependencies_all_defined(self): - class Container(containers.DeclarativeContainer): - dependency = providers.Dependency() - - container = Container(dependency="provided") - result = container.check_dependencies() - - self.assertIsNone(result) - - def test_assign_parent(self): - parent = providers.DependenciesContainer() - container = ContainerA() - - container.assign_parent(parent) - - self.assertIs(container.parent, parent) - - def test_parent_name_declarative_parent(self): - container = ContainerA() - self.assertEqual(container.parent_name, "ContainerA") - - def test_parent_name(self): - container = ContainerA() - self.assertEqual(container.parent_name, "ContainerA") - - def test_parent_name_with_deep_parenting(self): - class Container2(containers.DeclarativeContainer): - - name = providers.Container(ContainerA) - - class Container1(containers.DeclarativeContainer): - - container = providers.Container(Container2) - - container = Container1() - self.assertEqual(container.container().name.parent_name, "Container1.container.name") - - def test_parent_name_is_none(self): - container = containers.DynamicContainer() - self.assertIsNone(container.parent_name) - - def test_parent_deepcopy(self): - class Container(containers.DeclarativeContainer): - container = providers.Container(ContainerA) - - container = Container() - - copied = providers.deepcopy(container) - - self.assertIs(container.container.parent, container) - self.assertIs(copied.container.parent, copied) - - self.assertIsNot(container, copied) - self.assertIsNot(container.container, copied.container) - self.assertIsNot(container.container.parent, copied.container.parent) - - def test_resolve_provider_name(self): - container = ContainerA() - self.assertEqual(container.resolve_provider_name(container.p11), "p11") - - def test_resolve_provider_name_no_provider(self): - container = ContainerA() - with self.assertRaises(errors.Error): - container.resolve_provider_name(providers.Provider()) - - -class SelfTests(unittest.TestCase): - - def test_self(self): - def call_bar(container): - return container.bar() - - class Container(containers.DeclarativeContainer): - __self__ = providers.Self() - foo = providers.Callable(call_bar, __self__) - bar = providers.Object("hello") - - container = Container() - - self.assertIs(container.foo(), "hello") - - def test_self_attribute_implicit(self): - class Container(containers.DeclarativeContainer): - pass - - container = Container() - - self.assertIs(container.__self__(), container) - - def test_self_attribute_explicit(self): - class Container(containers.DeclarativeContainer): - __self__ = providers.Self() - - container = Container() - - self.assertIs(container.__self__(), container) - - def test_single_self(self): - with self.assertRaises(errors.Error): - class Container(containers.DeclarativeContainer): - self1 = providers.Self() - self2 = providers.Self() - - def test_self_attribute_alt_name_implicit(self): - class Container(containers.DeclarativeContainer): - foo = providers.Self() - - container = Container() - - self.assertIs(container.__self__, container.foo) - self.assertEqual(set(container.__self__.alt_names), {"foo"}) - - def test_self_attribute_alt_name_explicit_1(self): - class Container(containers.DeclarativeContainer): - __self__ = providers.Self() - foo = __self__ - bar = __self__ - - container = Container() - - self.assertIs(container.__self__, container.foo) - self.assertIs(container.__self__, container.bar) - self.assertEqual(set(container.__self__.alt_names), {"foo", "bar"}) - - def test_self_attribute_alt_name_explicit_2(self): - class Container(containers.DeclarativeContainer): - foo = providers.Self() - bar = foo - - container = Container() - - self.assertIs(container.__self__, container.foo) - self.assertIs(container.__self__, container.bar) - self.assertEqual(set(container.__self__.alt_names), {"foo", "bar"}) - - def test_providers_attribute_1(self): - class Container(containers.DeclarativeContainer): - __self__ = providers.Self() - foo = __self__ - bar = __self__ - - container = Container() - - self.assertEqual(container.providers, {}) - self.assertEqual(Container.providers, {}) - - def test_providers_attribute_2(self): - class Container(containers.DeclarativeContainer): - foo = providers.Self() - bar = foo - - container = Container() - - self.assertEqual(container.providers, {}) - self.assertEqual(Container.providers, {}) - - def test_container_multiple_instances(self): - class Container(containers.DeclarativeContainer): - __self__ = providers.Self() - - container1 = Container() - container2 = Container() - - self.assertIsNot(container1, container2) - self.assertIs(container1.__self__(), container1) - self.assertIs(container2.__self__(), container2) - - def test_deepcopy(self): - def call_bar(container): - return container.bar() - - class Container(containers.DeclarativeContainer): - __self__ = providers.Self() - foo = providers.Callable(call_bar, __self__) - bar = providers.Object("hello") - - container1 = Container() - container2 = providers.deepcopy(container1) - container1.bar.override("bye") - - self.assertIs(container1.foo(), "bye") - self.assertIs(container2.foo(), "hello") - - def test_deepcopy_alt_names_1(self): - class Container(containers.DeclarativeContainer): - __self__ = providers.Self() - foo = __self__ - bar = foo - - container1 = Container() - container2 = providers.deepcopy(container1) - - self.assertIs(container2.__self__(), container2) - self.assertIs(container2.foo(), container2) - self.assertIs(container2.bar(), container2) - - def test_deepcopy_alt_names_2(self): - class Container(containers.DeclarativeContainer): - self = providers.Self() - - container1 = Container() - container2 = providers.deepcopy(container1) - - self.assertIs(container2.__self__(), container2) - self.assertIs(container2.self(), container2) - - def test_deepcopy_no_self_dependencies(self): - class Container(containers.DeclarativeContainer): - __self__ = providers.Self() - - container1 = Container() - container2 = providers.deepcopy(container1) - - self.assertIsNot(container1, container2) - self.assertIsNot(container1.__self__, container2.__self__) - self.assertIs(container1.__self__(), container1) - self.assertIs(container2.__self__(), container2) - - def test_with_container_provider(self): - def call_bar(container): - return container.bar() - - class SubContainer(containers.DeclarativeContainer): - __self__ = providers.Self() - foo = providers.Callable(call_bar, __self__) - bar = providers.Object("hello") - - class Container(containers.DeclarativeContainer): - sub_container = providers.Container(SubContainer) - - baz = providers.Callable(lambda value: value, sub_container.foo) - - container = Container() - - self.assertIs(container.baz(), "hello") - - def test_with_container_provider_overriding(self): - def call_bar(container): - return container.bar() - - class SubContainer(containers.DeclarativeContainer): - __self__ = providers.Self() - foo = providers.Callable(call_bar, __self__) - bar = providers.Object("hello") - - class Container(containers.DeclarativeContainer): - sub_container = providers.Container(SubContainer, bar="bye") - - baz = providers.Callable(lambda value: value, sub_container.foo) - - container = Container() - - self.assertIs(container.baz(), "bye") - - def test_with_container_provider_self(self): - class SubContainer(containers.DeclarativeContainer): - __self__ = providers.Self() - - class Container(containers.DeclarativeContainer): - sub_container = providers.Container(SubContainer) - - container = Container() - - self.assertIs(container.__self__(), container) - self.assertIs(container.sub_container().__self__(), container.sub_container()) - - -class DynamicContainerWithCustomStringTests(unittest.TestCase): - # See: https://github.com/ets-labs/python-dependency-injector/issues/479 - - class CustomString(str): - pass - - class CustomClass: - thing = None - - def setUp(self): - self.container = containers.DynamicContainer() - self.provider = providers.Provider() - - def test_setattr(self): - setattr(self.container, self.CustomString("test_attr"), self.provider) - self.assertIs(self.container.test_attr, self.provider) - - def test_delattr(self): - setattr(self.container, self.CustomString("test_attr"), self.provider) - delattr(self.container, self.CustomString("test_attr")) - with self.assertRaises(AttributeError): - self.container.test_attr - - def test_set_provider(self): - self.container.set_provider(self.CustomString("test_attr"), self.provider) - self.assertIs(self.container.test_attr, self.provider) - - def test_set_providers(self): - self.container.set_providers(**{self.CustomString("test_attr"): self.provider}) - self.assertIs(self.container.test_attr, self.provider) diff --git a/tests/unit/containers/test_traversal_py3.py b/tests/unit/containers/test_traversal_py3.py index 59b81197..d10d1dff 100644 --- a/tests/unit/containers/test_traversal_py3.py +++ b/tests/unit/containers/test_traversal_py3.py @@ -1,93 +1,53 @@ -import unittest +"""Container traversing tests.""" from dependency_injector import containers, providers -class TraverseProviderTests(unittest.TestCase): - - def test_nested_providers(self): - class Container(containers.DeclarativeContainer): - obj_factory = providers.DelegatedFactory( - dict, - foo=providers.Resource( - dict, - foo="bar" - ), - bar=providers.Resource( - dict, - foo="bar" - ) - ) - - container = Container() - all_providers = list(container.traverse()) - - self.assertIn(container.obj_factory, all_providers) - self.assertIn(container.obj_factory.kwargs["foo"], all_providers) - self.assertIn(container.obj_factory.kwargs["bar"], all_providers) - self.assertEqual(len(all_providers), 3) - - def test_nested_providers_with_filtering(self): - class Container(containers.DeclarativeContainer): - obj_factory = providers.DelegatedFactory( - dict, - foo=providers.Resource( - dict, - foo="bar" - ), - bar=providers.Resource( - dict, - foo="bar" - ) - ) - - container = Container() - all_providers = list(container.traverse(types=[providers.Resource])) - - self.assertIn(container.obj_factory.kwargs["foo"], all_providers) - self.assertIn(container.obj_factory.kwargs["bar"], all_providers) - self.assertEqual(len(all_providers), 2) +class Container(containers.DeclarativeContainer): + obj_factory = providers.DelegatedFactory( + dict, + foo=providers.Resource( + dict, + foo="bar" + ), + bar=providers.Resource( + dict, + foo="bar" + ) + ) -class TraverseProviderDeclarativeTests(unittest.TestCase): +def test_nested_providers(): + container = Container() + all_providers = list(container.traverse()) - def test_nested_providers(self): - class Container(containers.DeclarativeContainer): - obj_factory = providers.DelegatedFactory( - dict, - foo=providers.Resource( - dict, - foo="bar" - ), - bar=providers.Resource( - dict, - foo="bar" - ) - ) + assert container.obj_factory in all_providers + assert container.obj_factory.kwargs["foo"] in all_providers + assert container.obj_factory.kwargs["bar"] in all_providers + assert len(all_providers) == 3 - all_providers = list(Container.traverse()) - self.assertIn(Container.obj_factory, all_providers) - self.assertIn(Container.obj_factory.kwargs["foo"], all_providers) - self.assertIn(Container.obj_factory.kwargs["bar"], all_providers) - self.assertEqual(len(all_providers), 3) +def test_nested_providers_with_filtering(): + container = Container() + all_providers = list(container.traverse(types=[providers.Resource])) - def test_nested_providers_with_filtering(self): - class Container(containers.DeclarativeContainer): - obj_factory = providers.DelegatedFactory( - dict, - foo=providers.Resource( - dict, - foo="bar" - ), - bar=providers.Resource( - dict, - foo="bar" - ) - ) + assert container.obj_factory.kwargs["foo"] in all_providers + assert container.obj_factory.kwargs["bar"] in all_providers + assert len(all_providers) == 2 - all_providers = list(Container.traverse(types=[providers.Resource])) - self.assertIn(Container.obj_factory.kwargs["foo"], all_providers) - self.assertIn(Container.obj_factory.kwargs["bar"], all_providers) - self.assertEqual(len(all_providers), 2) +def test_container_cls_nested_providers(): + all_providers = list(Container.traverse()) + + assert Container.obj_factory in all_providers + assert Container.obj_factory.kwargs["foo"] in all_providers + assert Container.obj_factory.kwargs["bar"] in all_providers + assert len(all_providers) == 3 + + +def test_container_cls_nested_providers_with_filtering(): + all_providers = list(Container.traverse(types=[providers.Resource])) + + assert Container.obj_factory.kwargs["foo"] in all_providers + assert Container.obj_factory.kwargs["bar"] in all_providers + assert len(all_providers) == 2 diff --git a/tests/unit/containers/test_types_py36.py b/tests/unit/containers/test_types_py36.py index e55a2b1b..1349157e 100644 --- a/tests/unit/containers/test_types_py36.py +++ b/tests/unit/containers/test_types_py36.py @@ -1,18 +1,13 @@ -import unittest +"""Container typing in runtime tests.""" from dependency_injector import containers -class SomeClass: - ... +def test_types_declarative(): + container: containers.Container = containers.DeclarativeContainer() + assert isinstance(container, containers.Container) -class TypesTest(unittest.TestCase): - - def test_declarative(self): - container: containers.Container = containers.DeclarativeContainer() - self.assertIsInstance(container, containers.Container) - - def test_dynamic(self): - container: containers.Container = containers.DynamicContainer() - self.assertIsInstance(container, containers.Container) +def test_types_dynamic(): + container: containers.Container = containers.DynamicContainer() + assert isinstance(container, containers.Container) diff --git a/tests/unit/ext/__init__.py b/tests/unit/ext/__init__.py index a561a38b..a42e8d0a 100644 --- a/tests/unit/ext/__init__.py +++ b/tests/unit/ext/__init__.py @@ -1 +1 @@ -"""Dependency injector extension unit tests.""" +"""Extension tests.""" diff --git a/tests/unit/ext/test_aiohttp_py35.py b/tests/unit/ext/test_aiohttp_py35.py index 42dd7516..d4fcfc20 100644 --- a/tests/unit/ext/test_aiohttp_py35.py +++ b/tests/unit/ext/test_aiohttp_py35.py @@ -1,21 +1,20 @@ -"""Dependency injector Aiohttp extension unit tests.""" - -from aiohttp import web -from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop +"""Aiohttp extension tests.""" +from aiohttp import web, test_utils from dependency_injector import containers, providers from dependency_injector.ext import aiohttp +from pytest import fixture, mark -async def index(_): +async def index_view(_): return web.Response(text="Hello World!") -async def test(_): +async def second_view(_): return web.Response(text="Test!") -class Test(web.View): +class OtherClassBasedView(web.View): async def get(self): return web.Response(text="Test class-based!") @@ -46,48 +45,56 @@ class ApplicationContainer(containers.DeclarativeContainer): ), ) - index_view = aiohttp.View(index) - test_view = aiohttp.View(test) - test_class_view = aiohttp.ClassBasedView(Test) + index_view = aiohttp.View(index_view) + second_view = aiohttp.View(second_view) + other_class_based_view = aiohttp.ClassBasedView(OtherClassBasedView) -class ApplicationTests(AioHTTPTestCase): +@fixture +def app(): + container = ApplicationContainer() + app = container.app() + app.container = container + app.add_routes([ + web.get("/", container.index_view.as_view()), + web.get("/second", container.second_view.as_view(), name="second"), + web.get("/class-based", container.other_class_based_view.as_view()), + ]) + return app - async def get_application(self): - """ - Override the get_app method to return your application. - """ - container = ApplicationContainer() - app = container.app() - app.container = container - app.add_routes([ - web.get("/", container.index_view.as_view()), - web.get("/test", container.test_view.as_view(), name="test"), - web.get("/test-class", container.test_class_view.as_view()), - ]) - return app - @unittest_run_loop - async def test_index(self): - response = await self.client.get("/") +@fixture +async def client(app): + async with test_utils.TestClient(test_utils.TestServer(app)) as client: + yield client - self.assertEqual(response.status, 200) - self.assertEqual(await response.text(), "Hello World! wink2 wink1") - @unittest_run_loop - async def test_test(self): - response = await self.client.get("/test") +@mark.asyncio +@mark.filterwarnings("ignore:The loop argument is deprecated:DeprecationWarning") +async def test_index(client): + response = await client.get("/") - self.assertEqual(response.status, 200) - self.assertEqual(await response.text(), "Test! wink2 wink1") + assert response.status == 200 + assert await response.text() == "Hello World! wink2 wink1" - @unittest_run_loop - async def test_test_class_based(self): - response = await self.client.get("/test-class") - self.assertEqual(response.status, 200) - self.assertEqual(await response.text(), "Test class-based! wink2 wink1") +@mark.asyncio +@mark.filterwarnings("ignore:The loop argument is deprecated:DeprecationWarning") +async def test_second(client): + response = await client.get("/second") - @unittest_run_loop - async def test_endpoints(self): - self.assertEqual(str(self.app.router["test"].url_for()), "/test") + assert response.status == 200 + assert await response.text() == "Test! wink2 wink1" + + +@mark.asyncio +@mark.filterwarnings("ignore:The loop argument is deprecated:DeprecationWarning") +async def test_class_based(client): + response = await client.get("/class-based") + + assert response.status == 200 + assert await response.text() == "Test class-based! wink2 wink1" + + +def test_endpoints(app): + assert str(app.router["second"].url_for()) == "/second" diff --git a/tests/unit/ext/test_flask_py2_py3.py b/tests/unit/ext/test_flask_py2_py3.py index 60ca5408..e64de165 100644 --- a/tests/unit/ext/test_flask_py2_py3.py +++ b/tests/unit/ext/test_flask_py2_py3.py @@ -1,11 +1,10 @@ -"""Dependency injector Flask extension unit tests.""" - -import unittest -from flask import Flask, url_for -from flask.views import MethodView +"""Flask extension tests.""" from dependency_injector import containers from dependency_injector.ext import flask +from flask import Flask, url_for +from flask.views import MethodView +from pytest import fixture def index(): @@ -30,47 +29,47 @@ class ApplicationContainer(containers.DeclarativeContainer): test_class_view = flask.ClassBasedView(Test) -def create_app(): +@fixture +def app(): container = ApplicationContainer() app = container.app() app.container = container + app.config["SERVER_NAME"] = "test-server.com" app.add_url_rule("/", view_func=container.index_view.as_view()) app.add_url_rule("/test", "test-test", view_func=container.test_view.as_view()) app.add_url_rule("/test-class", view_func=container.test_class_view.as_view("test-class")) return app -class ApplicationTests(unittest.TestCase): +@fixture +def client(app): + with app.test_client() as client: + yield client - def setUp(self): - self.app = create_app() - self.app.config["SERVER_NAME"] = "test-server.com" - self.client = self.app.test_client() - self.client.__enter__() - def tearDown(self): - self.client.__exit__(None, None, None) +def test_index(client): + response = client.get("/") - def test_index(self): - response = self.client.get("/") + assert response.status_code == 200 + assert response.data == b"Hello World!" - self.assertEqual(response.status_code, 200) - self.assertEqual(response.data, b"Hello World!") - def test_test(self): - response = self.client.get("/test") +def test_test(client): + response = client.get("/test") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.data, b"Test!") + assert response.status_code == 200 + assert response.data == b"Test!" - def test_test_class_based(self): - response = self.client.get("/test-class") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.data, b"Test class-based!") +def test_test_class_based(client): + response = client.get("/test-class") - def test_endpoints(self): - with self.app.app_context(): - self.assertEqual(url_for("index"), "http://test-server.com/") - self.assertEqual(url_for("test-test"), "http://test-server.com/test") - self.assertEqual(url_for("test-class"), "http://test-server.com/test-class") + assert response.status_code == 200 + assert response.data == b"Test class-based!" + + +def test_endpoints(app): + with app.app_context(): + assert url_for("index") == "http://test-server.com/" + assert url_for("test-test") == "http://test-server.com/test" + assert url_for("test-class") == "http://test-server.com/test-class" diff --git a/tests/unit/providers/__init__.py b/tests/unit/providers/__init__.py index 5ee4b909..7e9190ff 100644 --- a/tests/unit/providers/__init__.py +++ b/tests/unit/providers/__init__.py @@ -1 +1 @@ -"""Dependency injector providers unit tests.""" +"""Providers tests.""" diff --git a/tests/unit/providers/async/__init__.py b/tests/unit/providers/async/__init__.py new file mode 100644 index 00000000..8b3ee44e --- /dev/null +++ b/tests/unit/providers/async/__init__.py @@ -0,0 +1 @@ +"""Provider asynchronous mode tests.""" diff --git a/tests/unit/providers/async/common.py b/tests/unit/providers/async/common.py new file mode 100644 index 00000000..ddea3e79 --- /dev/null +++ b/tests/unit/providers/async/common.py @@ -0,0 +1,45 @@ +"""Common test artifacts.""" + +import asyncio +import random + +from dependency_injector import containers, providers + + +RESOURCE1 = object() +RESOURCE2 = object() + + +async def init_resource(resource): + await asyncio.sleep(random.randint(1, 10) / 1000) + yield resource + await asyncio.sleep(random.randint(1, 10) / 1000) + + +class Client: + def __init__(self, resource1: object, resource2: object) -> None: + self.resource1 = resource1 + self.resource2 = resource2 + + +class Service: + def __init__(self, client: Client) -> None: + self.client = client + + +class BaseContainer(containers.DeclarativeContainer): + resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1)) + resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2)) + + +class Container(BaseContainer): + client = providers.Factory( + Client, + resource1=BaseContainer.resource1, + resource2=BaseContainer.resource2, + ) + + service = providers.Factory( + Service, + client=client, + ) diff --git a/tests/unit/providers/async/test_async_mode_api_py36.py b/tests/unit/providers/async/test_async_mode_api_py36.py new file mode 100644 index 00000000..bfc12723 --- /dev/null +++ b/tests/unit/providers/async/test_async_mode_api_py36.py @@ -0,0 +1,45 @@ +"""Tests for provider async mode API.""" + +from dependency_injector import providers +from pytest import fixture + + +@fixture +def provider(): + return providers.Provider() + + +def test_default_mode(provider: providers.Provider): + assert provider.is_async_mode_enabled() is False + assert provider.is_async_mode_disabled() is False + assert provider.is_async_mode_undefined() is True + + +def test_enable(provider: providers.Provider): + provider.enable_async_mode() + + assert provider.is_async_mode_enabled() is True + assert provider.is_async_mode_disabled() is False + assert provider.is_async_mode_undefined() is False + + +def test_disable(provider: providers.Provider): + provider.disable_async_mode() + + assert provider.is_async_mode_enabled() is False + assert provider.is_async_mode_disabled() is True + assert provider.is_async_mode_undefined() is False + + +def test_reset(provider: providers.Provider): + provider.enable_async_mode() + + assert provider.is_async_mode_enabled() is True + assert provider.is_async_mode_disabled() is False + assert provider.is_async_mode_undefined() is False + + provider.reset_async_mode() + + assert provider.is_async_mode_enabled() is False + assert provider.is_async_mode_disabled() is False + assert provider.is_async_mode_undefined() is True diff --git a/tests/unit/providers/async/test_delegated_singleton_py36.py b/tests/unit/providers/async/test_delegated_singleton_py36.py new file mode 100644 index 00000000..ebfbd772 --- /dev/null +++ b/tests/unit/providers/async/test_delegated_singleton_py36.py @@ -0,0 +1,38 @@ +"""DelegatedSingleton provider async mode tests.""" + +import asyncio + +from dependency_injector import providers +from pytest import mark + + +@mark.asyncio +async def test_async_mode(): + instance = object() + + async def create_instance(): + return instance + + provider = providers.DelegatedSingleton(create_instance) + + instance1 = await provider() + instance2 = await provider() + + assert instance1 is instance2 + assert instance1 is instance + assert instance2 is instance + + +@mark.asyncio +async def test_concurrent_init(): + async def create_instance(): + return object() + + provider = providers.DelegatedSingleton(create_instance) + + future_instance1 = provider() + future_instance2 = provider() + + instance1, instance2 = await asyncio.gather(future_instance1, future_instance2) + + assert instance1 is instance2 diff --git a/tests/unit/providers/async/test_delegated_thread_local_singleton_py36.py b/tests/unit/providers/async/test_delegated_thread_local_singleton_py36.py new file mode 100644 index 00000000..5f5e9423 --- /dev/null +++ b/tests/unit/providers/async/test_delegated_thread_local_singleton_py36.py @@ -0,0 +1,38 @@ +"""DelegatedThreadLocalSingleton provider async mode tests.""" + +import asyncio + +from dependency_injector import providers +from pytest import mark + + +@mark.asyncio +async def test_async_mode(): + instance = object() + + async def create_instance(): + return instance + + provider = providers.DelegatedThreadLocalSingleton(create_instance) + + instance1 = await provider() + instance2 = await provider() + + assert instance1 is instance2 + assert instance1 is instance + assert instance2 is instance + + +@mark.asyncio +async def test_concurrent_init(): + async def create_instance(): + return object() + + provider = providers.DelegatedThreadLocalSingleton(create_instance) + + future_instance1 = provider() + future_instance2 = provider() + + instance1, instance2 = await asyncio.gather(future_instance1, future_instance2) + + assert instance1 is instance2 diff --git a/tests/unit/providers/async/test_delegated_thread_safe_singleton_py36.py b/tests/unit/providers/async/test_delegated_thread_safe_singleton_py36.py new file mode 100644 index 00000000..046ce951 --- /dev/null +++ b/tests/unit/providers/async/test_delegated_thread_safe_singleton_py36.py @@ -0,0 +1,38 @@ +"""DelegatedThreadSafeSingleton provider async mode tests.""" + +import asyncio + +from dependency_injector import providers +from pytest import mark + + +@mark.asyncio +async def test_async_mode(): + instance = object() + + async def create_instance(): + return instance + + provider = providers.DelegatedThreadSafeSingleton(create_instance) + + instance1 = await provider() + instance2 = await provider() + + assert instance1 is instance2 + assert instance1 is instance + assert instance2 is instance + + +@mark.asyncio +async def test_concurrent_init(): + async def create_instance(): + return object() + + provider = providers.DelegatedThreadSafeSingleton(create_instance) + + future_instance1 = provider() + future_instance2 = provider() + + instance1, instance2 = await asyncio.gather(future_instance1, future_instance2) + + assert instance1 is instance2 diff --git a/tests/unit/providers/async/test_dependency_py36.py b/tests/unit/providers/async/test_dependency_py36.py new file mode 100644 index 00000000..b42d3a97 --- /dev/null +++ b/tests/unit/providers/async/test_dependency_py36.py @@ -0,0 +1,88 @@ +"""Dependency provider async mode tests.""" + +from dependency_injector import providers, errors +from pytest import mark, raises + + +@mark.asyncio +async def test_provide_error(): + async def get_async(): + raise Exception + + provider = providers.Dependency() + provider.override(providers.Callable(get_async)) + + with raises(Exception): + await provider() + + +@mark.asyncio +async def test_isinstance(): + dependency = 1.0 + + async def get_async(): + return dependency + + provider = providers.Dependency(instance_of=float) + provider.override(providers.Callable(get_async)) + + assert provider.is_async_mode_undefined() is True + + dependency1 = await provider() + + assert provider.is_async_mode_enabled() is True + + dependency2 = await provider() + + assert dependency1 == dependency + assert dependency2 == dependency + + +@mark.asyncio +async def test_isinstance_invalid(): + async def get_async(): + return {} + + provider = providers.Dependency(instance_of=float) + provider.override(providers.Callable(get_async)) + + assert provider.is_async_mode_undefined() is True + + with raises(errors.Error): + await provider() + + assert provider.is_async_mode_enabled() is True + + +@mark.asyncio +async def test_async_mode(): + dependency = 123 + + async def get_async(): + return dependency + + def get_sync(): + return dependency + + provider = providers.Dependency(instance_of=int) + provider.override(providers.Factory(get_async)) + + assert provider.is_async_mode_undefined() is True + + dependency1 = await provider() + + assert provider.is_async_mode_enabled() is True + + dependency2 = await provider() + assert dependency1 == dependency + assert dependency2 == dependency + + provider.override(providers.Factory(get_sync)) + + dependency3 = await provider() + + assert provider.is_async_mode_enabled() is True + + dependency4 = await provider() + assert dependency3 == dependency + assert dependency4 == dependency diff --git a/tests/unit/providers/async/test_dict_py36.py b/tests/unit/providers/async/test_dict_py36.py new file mode 100644 index 00000000..132df56c --- /dev/null +++ b/tests/unit/providers/async/test_dict_py36.py @@ -0,0 +1,23 @@ +"""Dict provider async mode tests.""" + +from dependency_injector import containers, providers +from pytest import mark + + +@mark.asyncio +async def test_provide(): + async def create_resource(param: str): + return param + + class Container(containers.DeclarativeContainer): + + resources = providers.Dict( + foo=providers.Resource(create_resource, "foo"), + bar=providers.Resource(create_resource, "bar") + ) + + container = Container() + resources = await container.resources() + + assert resources["foo"] == "foo" + assert resources["bar"] == "bar" diff --git a/tests/unit/providers/async/test_factory_aggregate_py36.py b/tests/unit/providers/async/test_factory_aggregate_py36.py new file mode 100644 index 00000000..ace7ffdf --- /dev/null +++ b/tests/unit/providers/async/test_factory_aggregate_py36.py @@ -0,0 +1,30 @@ +"""FactoryAggregate provider async mode tests.""" + +from dependency_injector import providers +from pytest import mark + + +@mark.asyncio +async def test_async_mode(): + object1 = object() + object2 = object() + + async def _get_object1(): + return object1 + + def _get_object2(): + return object2 + + provider = providers.FactoryAggregate( + object1=providers.Factory(_get_object1), + object2=providers.Factory(_get_object2), + ) + + assert provider.is_async_mode_undefined() is True + + created_object1 = await provider("object1") + assert created_object1 is object1 + assert provider.is_async_mode_enabled() is True + + created_object2 = await provider("object2") + assert created_object2 is object2 diff --git a/tests/unit/providers/async/test_factory_py36.py b/tests/unit/providers/async/test_factory_py36.py new file mode 100644 index 00000000..98ba2c83 --- /dev/null +++ b/tests/unit/providers/async/test_factory_py36.py @@ -0,0 +1,423 @@ +"""Factory provider async mode tests.""" + +import asyncio + +from dependency_injector import containers, providers +from pytest import mark, raises + +from .common import RESOURCE1, RESOURCE2, Client, Service, BaseContainer, Container, init_resource + + +@mark.asyncio +async def test_args_injection(): + class ContainerWithArgs(BaseContainer): + client = providers.Factory( + Client, + BaseContainer.resource1, + BaseContainer.resource2, + ) + + service = providers.Factory( + Service, + client, + ) + + container = ContainerWithArgs() + + client1 = await container.client() + client2 = await container.client() + + assert isinstance(client1, Client) + assert client1.resource1 is RESOURCE1 + assert client1.resource2 is RESOURCE2 + + assert isinstance(client2, Client) + assert client2.resource1 is RESOURCE1 + assert client2.resource2 is RESOURCE2 + + service1 = await container.service() + service2 = await container.service() + + assert isinstance(service1, Service) + assert isinstance(service1.client, Client) + assert service1.client.resource1 is RESOURCE1 + assert service1.client.resource2 is RESOURCE2 + + assert isinstance(service2, Service) + assert isinstance(service2.client, Client) + assert service2.client.resource1 is RESOURCE1 + assert service2.client.resource2 is RESOURCE2 + + assert service1.client is not service2.client + + +@mark.asyncio +async def test_kwargs_injection(): + class ContainerWithKwArgs(Container): + ... + + container = ContainerWithKwArgs() + + client1 = await container.client() + client2 = await container.client() + + assert isinstance(client1, Client) + assert client1.resource1 is RESOURCE1 + assert client1.resource2 is RESOURCE2 + + assert isinstance(client2, Client) + assert client2.resource1 is RESOURCE1 + assert client2.resource2 is RESOURCE2 + + service1 = await container.service() + service2 = await container.service() + + assert isinstance(service1, Service) + assert isinstance(service1.client, Client) + assert service1.client.resource1 is RESOURCE1 + assert service1.client.resource2 is RESOURCE2 + + assert isinstance(service2, Service) + assert isinstance(service2.client, Client) + assert service2.client.resource1 is RESOURCE1 + assert service2.client.resource2 is RESOURCE2 + + assert service1.client is not service2.client + + +@mark.asyncio +async def test_context_kwargs_injection(): + resource2_extra = object() + + container = Container() + + client1 = await container.client(resource2=resource2_extra) + client2 = await container.client(resource2=resource2_extra) + + assert isinstance(client1, Client) + assert client1.resource1 is RESOURCE1 + assert client1.resource2 is resource2_extra + + assert isinstance(client2, Client) + assert client2.resource1 is RESOURCE1 + assert client2.resource2 is resource2_extra + + +@mark.asyncio +async def test_args_kwargs_injection(): + class ContainerWithArgsAndKwArgs(BaseContainer): + client = providers.Factory( + Client, + BaseContainer.resource1, + resource2=BaseContainer.resource2, + ) + + service = providers.Factory( + Service, + client=client, + ) + + container = ContainerWithArgsAndKwArgs() + + client1 = await container.client() + client2 = await container.client() + + assert isinstance(client1, Client) + assert client1.resource1 is RESOURCE1 + assert client1.resource2 is RESOURCE2 + + assert isinstance(client2, Client) + assert client2.resource1 is RESOURCE1 + assert client2.resource2 is RESOURCE2 + + service1 = await container.service() + service2 = await container.service() + + assert isinstance(service1, Service) + assert isinstance(service1.client, Client) + assert service1.client.resource1 is RESOURCE1 + assert service1.client.resource2 is RESOURCE2 + + assert isinstance(service2, Service) + assert isinstance(service2.client, Client) + assert service2.client.resource1 is RESOURCE1 + assert service2.client.resource2 is RESOURCE2 + + assert service1.client is not service2.client + + +@mark.asyncio +async def test_async_provider_with_async_injections(): + # See: https://github.com/ets-labs/python-dependency-injector/issues/368 + async def async_client_provider(): + return {"client": "OK"} + + async def async_service(client): + return {"service": "OK", "client": client} + + class Container(containers.DeclarativeContainer): + client = providers.Factory(async_client_provider) + service = providers.Factory(async_service, client=client) + + container = Container() + service = await container.service() + + assert service == {"service": "OK", "client": {"client": "OK"}} + + +@mark.asyncio +async def test_with_awaitable_injection(): + class SomeResource: + def __await__(self): + raise RuntimeError("Should never happen") + + async def init_resource(): + yield SomeResource() + + class Service: + def __init__(self, resource) -> None: + self.resource = resource + + class Container(containers.DeclarativeContainer): + resource = providers.Resource(init_resource) + service = providers.Factory(Service, resource=resource) + + container = Container() + + assert isinstance(container.service(), asyncio.Future) + assert isinstance(container.resource(), asyncio.Future) + + resource = await container.resource() + service = await container.service() + + assert isinstance(resource, SomeResource) + assert isinstance(service.resource, SomeResource) + assert service.resource is resource + + +@mark.asyncio +async def test_with_awaitable_injection_and_with_init_resources_call(): + class SomeResource: + def __await__(self): + raise RuntimeError("Should never happen") + + async def init_resource(): + yield SomeResource() + + class Service: + def __init__(self, resource) -> None: + self.resource = resource + + class Container(containers.DeclarativeContainer): + resource = providers.Resource(init_resource) + service = providers.Factory(Service, resource=resource) + + container = Container() + + await container.init_resources() + assert isinstance(container.service(), asyncio.Future) + assert isinstance(container.resource(), asyncio.Future) + + resource = await container.resource() + service = await container.service() + + assert isinstance(resource, SomeResource) + assert isinstance(service.resource, SomeResource) + assert service.resource is resource + + +@mark.asyncio +async def test_injection_error(): + async def init_resource(): + raise Exception("Something went wrong") + + class Container(containers.DeclarativeContainer): + resource_with_error = providers.Resource(init_resource) + + client = providers.Factory( + Client, + resource1=resource_with_error, + resource2=None, + ) + + container = Container() + + with raises(Exception, match="Something went wrong"): + await container.client() + + +@mark.asyncio +async def test_injection_runtime_error_async_provides(): + async def create_client(*args, **kwargs): + raise Exception("Something went wrong") + + class Container(BaseContainer): + client = providers.Factory( + create_client, + resource1=BaseContainer.resource1, + resource2=None, + ) + + container = Container() + + with raises(Exception, match="Something went wrong"): + await container.client() + + +@mark.asyncio +async def test_injection_call_error_async_provides(): + async def create_client(): # <-- no args defined + ... + + class Container(BaseContainer): + client = providers.Factory( + create_client, + resource1=BaseContainer.resource1, + resource2=None, + ) + + container = Container() + + with raises(TypeError) as exception_info: + await container.client() + assert "create_client() got" in str(exception_info.value) + assert "unexpected keyword argument" in str(exception_info.value) + + +@mark.asyncio +async def test_attributes_injection(): + class ContainerWithAttributes(BaseContainer): + client = providers.Factory( + Client, + BaseContainer.resource1, + resource2=None, + ) + client.add_attributes(resource2=BaseContainer.resource2) + + service = providers.Factory( + Service, + client=None, + ) + service.add_attributes(client=client) + + container = ContainerWithAttributes() + + client1 = await container.client() + client2 = await container.client() + + assert isinstance(client1, Client) + assert client1.resource1 is RESOURCE1 + assert client1.resource2 is RESOURCE2 + + assert isinstance(client2, Client) + assert client2.resource1 is RESOURCE1 + assert client2.resource2 is RESOURCE2 + + service1 = await container.service() + service2 = await container.service() + + assert isinstance(service1, Service) + assert isinstance(service1.client, Client) + assert service1.client.resource1 is RESOURCE1 + assert service1.client.resource2 is RESOURCE2 + + assert isinstance(service2, Service) + assert isinstance(service2.client, Client) + assert service2.client.resource1 is RESOURCE1 + assert service2.client.resource2 is RESOURCE2 + + assert service1.client is not service2.client + + +@mark.asyncio +async def test_attributes_injection_attribute_error(): + class ClientWithException(Client): + @property + def attribute_set_error(self): + return None + + @attribute_set_error.setter + def attribute_set_error(self, value): + raise Exception("Something went wrong") + + class Container(BaseContainer): + client = providers.Factory( + ClientWithException, + resource1=BaseContainer.resource1, + resource2=BaseContainer.resource2, + ) + client.add_attributes(attribute_set_error=123) + + container = Container() + + with raises(Exception, match="Something went wrong"): + await container.client() + + +@mark.asyncio +async def test_attributes_injection_runtime_error(): + async def init_resource(): + raise Exception("Something went wrong") + + class Container(containers.DeclarativeContainer): + resource = providers.Resource(init_resource) + + client = providers.Factory( + Client, + resource1=None, + resource2=None, + ) + client.add_attributes(resource1=resource) + client.add_attributes(resource2=resource) + + container = Container() + + with raises(Exception, match="Something went wrong"): + await container.client() + + +@mark.asyncio +async def test_async_instance_and_sync_attributes_injection(): + class ContainerWithAttributes(BaseContainer): + resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1)) + + client = providers.Factory( + Client, + BaseContainer.resource1, + resource2=None, + ) + client.add_attributes(resource2=providers.Object(RESOURCE2)) + + service = providers.Factory( + Service, + client=None, + ) + service.add_attributes(client=client) + + container = ContainerWithAttributes() + + client1 = await container.client() + client2 = await container.client() + + assert isinstance(client1, Client) + assert client1.resource1 is RESOURCE1 + assert client1.resource2 is RESOURCE2 + + assert isinstance(client2, Client) + assert client2.resource1 is RESOURCE1 + assert client2.resource2 is RESOURCE2 + + service1 = await container.service() + service2 = await container.service() + + assert isinstance(service1, Service) + assert isinstance(service1.client, Client) + assert service1.client.resource1 is RESOURCE1 + assert service1.client.resource2 is RESOURCE2 + + assert isinstance(service2, Service) + assert isinstance(service2.client, Client) + assert service2.client.resource1 is RESOURCE1 + assert service2.client.resource2 is RESOURCE2 + + assert service1.client is not service2.client diff --git a/tests/unit/providers/async/test_list_py36.py b/tests/unit/providers/async/test_list_py36.py new file mode 100644 index 00000000..5f0162f8 --- /dev/null +++ b/tests/unit/providers/async/test_list_py36.py @@ -0,0 +1,24 @@ +"""List provider async mode tests.""" + +from dependency_injector import containers, providers +from pytest import mark + + +@mark.asyncio +async def test_provide(): + # See issue: https://github.com/ets-labs/python-dependency-injector/issues/450 + async def create_resource(param: str): + return param + + class Container(containers.DeclarativeContainer): + + resources = providers.List( + providers.Resource(create_resource, "foo"), + providers.Resource(create_resource, "bar") + ) + + container = Container() + resources = await container.resources() + + assert resources[0] == "foo" + assert resources[1] == "bar" diff --git a/tests/unit/providers/async/test_override_py36.py b/tests/unit/providers/async/test_override_py36.py new file mode 100644 index 00000000..6e76ac3b --- /dev/null +++ b/tests/unit/providers/async/test_override_py36.py @@ -0,0 +1,127 @@ +"""Tests for provider overriding in async mode.""" + +from dependency_injector import providers +from pytest import mark + + +@mark.asyncio +async def test_provider(): + dependency = object() + + async def _get_dependency_async(): + return dependency + + def _get_dependency_sync(): + return dependency + + provider = providers.Provider() + + provider.override(providers.Callable(_get_dependency_async)) + dependency1 = await provider() + + provider.override(providers.Callable(_get_dependency_sync)) + dependency2 = await provider() + + assert dependency1 is dependency + assert dependency2 is dependency + + +@mark.asyncio +async def test_callable(): + dependency = object() + + async def _get_dependency_async(): + return dependency + + def _get_dependency_sync(): + return dependency + + provider = providers.Callable(_get_dependency_async) + dependency1 = await provider() + + provider.override(providers.Callable(_get_dependency_sync)) + dependency2 = await provider() + + assert dependency1 is dependency + assert dependency2 is dependency + + +@mark.asyncio +async def test_factory(): + dependency = object() + + async def _get_dependency_async(): + return dependency + + def _get_dependency_sync(): + return dependency + + provider = providers.Factory(_get_dependency_async) + dependency1 = await provider() + + provider.override(providers.Callable(_get_dependency_sync)) + dependency2 = await provider() + + assert dependency1 is dependency + assert dependency2 is dependency + + +@mark.asyncio +async def test_async_mode_enabling(): + dependency = object() + + async def _get_dependency_async(): + return dependency + + provider = providers.Callable(_get_dependency_async) + assert provider.is_async_mode_undefined() is True + + await provider() + + assert provider.is_async_mode_enabled() is True + + +@mark.asyncio +async def test_async_mode_disabling(): + dependency = object() + + def _get_dependency(): + return dependency + + provider = providers.Callable(_get_dependency) + assert provider.is_async_mode_undefined() is True + + provider() + + assert provider.is_async_mode_disabled() is True + + +@mark.asyncio +async def test_async_mode_enabling_on_overriding(): + dependency = object() + + async def _get_dependency_async(): + return dependency + + provider = providers.Provider() + provider.override(providers.Callable(_get_dependency_async)) + assert provider.is_async_mode_undefined() is True + + await provider() + + assert provider.is_async_mode_enabled() is True + + +def test_async_mode_disabling_on_overriding(): + dependency = object() + + def _get_dependency(): + return dependency + + provider = providers.Provider() + provider.override(providers.Callable(_get_dependency)) + assert provider.is_async_mode_undefined() is True + + provider() + + assert provider.is_async_mode_disabled() is True diff --git a/tests/unit/providers/async/test_provided_instance_py36.py b/tests/unit/providers/async/test_provided_instance_py36.py new file mode 100644 index 00000000..faea4132 --- /dev/null +++ b/tests/unit/providers/async/test_provided_instance_py36.py @@ -0,0 +1,180 @@ +"""ProvidedInstance provider async mode tests.""" + +import asyncio + +from dependency_injector import containers, providers +from pytest import mark, raises + +from .common import RESOURCE1, init_resource + + +@mark.asyncio +async def test_provided_attribute(): + class TestClient: + def __init__(self, resource): + self.resource = resource + + class TestService: + def __init__(self, resource): + self.resource = resource + + class TestContainer(containers.DeclarativeContainer): + resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) + client = providers.Factory(TestClient, resource=resource) + service = providers.Factory(TestService, resource=client.provided.resource) + + container = TestContainer() + + instance1, instance2 = await asyncio.gather( + container.service(), + container.service(), + ) + + assert instance1.resource is RESOURCE1 + assert instance2.resource is RESOURCE1 + assert instance1.resource is instance2.resource + + +@mark.asyncio +async def test_provided_attribute_error(): + async def raise_exception(): + raise RuntimeError() + + class TestContainer(containers.DeclarativeContainer): + client = providers.Factory(raise_exception) + + container = TestContainer() + + with raises(RuntimeError): + await container.client.provided.attr() + + +@mark.asyncio +async def test_provided_attribute_undefined_attribute(): + class TestClient: + def __init__(self, resource): + self.resource = resource + + class TestContainer(containers.DeclarativeContainer): + resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) + client = providers.Factory(TestClient, resource=resource) + + container = TestContainer() + + with raises(AttributeError): + await container.client.provided.attr() + + +@mark.asyncio +async def test_provided_item(): + class TestClient: + def __init__(self, resource): + self.resource = resource + + def __getitem__(self, item): + return getattr(self, item) + + class TestService: + def __init__(self, resource): + self.resource = resource + + class TestContainer(containers.DeclarativeContainer): + resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) + client = providers.Factory(TestClient, resource=resource) + service = providers.Factory(TestService, resource=client.provided["resource"]) + + container = TestContainer() + + instance1, instance2 = await asyncio.gather( + container.service(), + container.service(), + ) + + assert instance1.resource is RESOURCE1 + assert instance2.resource is RESOURCE1 + assert instance1.resource is instance2.resource + + +@mark.asyncio +async def test_provided_item_error(): + async def raise_exception(): + raise RuntimeError() + + class TestContainer(containers.DeclarativeContainer): + client = providers.Factory(raise_exception) + + container = TestContainer() + + with raises(RuntimeError): + await container.client.provided["item"]() + + +@mark.asyncio +async def test_provided_item_undefined_item(): + class TestContainer(containers.DeclarativeContainer): + resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) + client = providers.Factory(dict, resource=resource) + + container = TestContainer() + + with raises(KeyError): + await container.client.provided["item"]() + + +@mark.asyncio +async def test_provided_method_call(): + class TestClient: + def __init__(self, resource): + self.resource = resource + + def get_resource(self): + return self.resource + + class TestService: + def __init__(self, resource): + self.resource = resource + + class TestContainer(containers.DeclarativeContainer): + resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) + client = providers.Factory(TestClient, resource=resource) + service = providers.Factory(TestService, resource=client.provided.get_resource.call()) + + container = TestContainer() + + instance1, instance2 = await asyncio.gather( + container.service(), + container.service(), + ) + + assert instance1.resource is RESOURCE1 + assert instance2.resource is RESOURCE1 + assert instance1.resource is instance2.resource + + +@mark.asyncio +async def test_provided_method_call_parent_error(): + async def raise_exception(): + raise RuntimeError() + + class TestContainer(containers.DeclarativeContainer): + client = providers.Factory(raise_exception) + + container = TestContainer() + + with raises(RuntimeError): + await container.client.provided.method.call()() + + +@mark.asyncio +async def test_provided_method_call_error(): + class TestClient: + def method(self): + raise RuntimeError() + + class TestContainer(containers.DeclarativeContainer): + client = providers.Factory(TestClient) + + container = TestContainer() + + with raises(RuntimeError): + await container.client.provided.method.call()() diff --git a/tests/unit/providers/async/test_singleton_py36.py b/tests/unit/providers/async/test_singleton_py36.py new file mode 100644 index 00000000..f1621bf5 --- /dev/null +++ b/tests/unit/providers/async/test_singleton_py36.py @@ -0,0 +1,117 @@ +"""Singleton provider async mode tests.""" + +import asyncio +import random + +from dependency_injector import providers +from pytest import mark, raises + +from .common import RESOURCE1, RESOURCE2, BaseContainer, Client, Service + + +@mark.asyncio +async def test_injections(): + class ContainerWithSingletons(BaseContainer): + client = providers.Singleton( + Client, + resource1=BaseContainer.resource1, + resource2=BaseContainer.resource2, + ) + + service = providers.Singleton( + Service, + client=client, + ) + + container = ContainerWithSingletons() + + client1 = await container.client() + client2 = await container.client() + + assert isinstance(client1, Client) + assert client1.resource1 is RESOURCE1 + assert client1.resource2 is RESOURCE2 + + assert isinstance(client2, Client) + assert client2.resource1 is RESOURCE1 + assert client2.resource2 is RESOURCE2 + + service1 = await container.service() + service2 = await container.service() + + assert isinstance(service1, Service) + assert isinstance(service1.client, Client) + assert service1.client.resource1 is RESOURCE1 + assert service1.client.resource2 is RESOURCE2 + + assert isinstance(service2, Service) + assert isinstance(service2.client, Client) + assert service2.client.resource1 is RESOURCE1 + assert service2.client.resource2 is RESOURCE2 + + assert service1 is service2 + assert service1.client is service2.client + assert service1.client is client1 + + assert service2.client is client2 + assert client1 is client2 + + +@mark.asyncio +async def test_async_mode(): + instance = object() + + async def create_instance(): + return instance + + provider = providers.Singleton(create_instance) + + instance1 = await provider() + instance2 = await provider() + + assert instance1 is instance2 + assert instance1 is instance + assert instance2 is instance + + +@mark.asyncio +async def test_concurrent_init(): + async def create_instance(): + await asyncio.sleep(random.randint(1, 10) / 1000) + return object() + + provider = providers.Singleton(create_instance) + + future_instance1 = provider() + future_instance2 = provider() + + instance1, instance2 = await asyncio.gather(future_instance1, future_instance2) + instance3 = await provider() + + assert instance1 is instance2 is instance3 + + +@mark.asyncio +async def test_async_init_with_error(): + async def create_instance(): + create_instance.counter += 1 + raise RuntimeError() + + create_instance.counter = 0 + + provider = providers.Singleton(create_instance) + + future = provider() + assert provider.is_async_mode_enabled() is True + + with raises(RuntimeError): + await future + + assert create_instance.counter == 1 + assert provider.is_async_mode_enabled() is True + + with raises(RuntimeError): + await provider() + + assert create_instance.counter == 2 + assert provider.is_async_mode_enabled() is True diff --git a/tests/unit/providers/async/test_thread_local_singleton_py36.py b/tests/unit/providers/async/test_thread_local_singleton_py36.py new file mode 100644 index 00000000..bf8ec3b3 --- /dev/null +++ b/tests/unit/providers/async/test_thread_local_singleton_py36.py @@ -0,0 +1,63 @@ +"""ThreadLocalSingleton provider async mode tests.""" + +import asyncio + +from dependency_injector import providers +from pytest import mark, raises + + +@mark.asyncio +async def test_async_mode(): + instance = object() + + async def create_instance(): + return instance + + provider = providers.ThreadLocalSingleton(create_instance) + + instance1 = await provider() + instance2 = await provider() + + assert instance1 is instance2 + assert instance1 is instance + assert instance2 is instance + + +@mark.asyncio +async def test_concurrent_init(): + async def create_instance(): + return object() + + provider = providers.ThreadLocalSingleton(create_instance) + + future_instance1 = provider() + future_instance2 = provider() + + instance1, instance2 = await asyncio.gather(future_instance1, future_instance2) + + assert instance1 is instance2 + + +@mark.asyncio +async def test_async_init_with_error(): + async def create_instance(): + create_instance.counter += 1 + raise RuntimeError() + create_instance.counter = 0 + + provider = providers.ThreadLocalSingleton(create_instance) + + future = provider() + assert provider.is_async_mode_enabled() is True + + with raises(RuntimeError): + await future + + assert create_instance.counter == 1 + assert provider.is_async_mode_enabled() is True + + with raises(RuntimeError): + await provider() + + assert create_instance.counter == 2 + assert provider.is_async_mode_enabled() is True diff --git a/tests/unit/providers/async/test_thread_safe_singleton_py36.py b/tests/unit/providers/async/test_thread_safe_singleton_py36.py new file mode 100644 index 00000000..13654150 --- /dev/null +++ b/tests/unit/providers/async/test_thread_safe_singleton_py36.py @@ -0,0 +1,38 @@ +"""ThreadSafeSingleton provider async mode tests.""" + +import asyncio + +from dependency_injector import providers +from pytest import mark + + +@mark.asyncio +async def test_async_mode(): + instance = object() + + async def create_instance(): + return instance + + provider = providers.ThreadSafeSingleton(create_instance) + + instance1 = await provider() + instance2 = await provider() + + assert instance1 is instance2 + assert instance1 is instance + assert instance2 is instance + + +@mark.asyncio +async def test_concurrent_init(): + async def create_instance(): + return object() + + provider = providers.ThreadSafeSingleton(create_instance) + + future_instance1 = provider() + future_instance2 = provider() + + instance1, instance2 = await asyncio.gather(future_instance1, future_instance2) + + assert instance1 is instance2 diff --git a/tests/unit/providers/async/test_typing_stubs_py36.py b/tests/unit/providers/async/test_typing_stubs_py36.py new file mode 100644 index 00000000..c7214d2d --- /dev/null +++ b/tests/unit/providers/async/test_typing_stubs_py36.py @@ -0,0 +1,36 @@ +"""Tests for provide async mode typing stubs.""" + +from pytest import mark + +from .common import Container, Client, Service, RESOURCE1, RESOURCE2 + + +@mark.asyncio +async def test_async_(): + container = Container() + + client1 = await container.client.async_() + client2 = await container.client.async_() + + assert isinstance(client1, Client) + assert client1.resource1 is RESOURCE1 + assert client1.resource2 is RESOURCE2 + + assert isinstance(client2, Client) + assert client2.resource1 is RESOURCE1 + assert client2.resource2 is RESOURCE2 + + service1 = await container.service.async_() + service2 = await container.service.async_() + + assert isinstance(service1, Service) + assert isinstance(service1.client, Client) + assert service1.client.resource1 is RESOURCE1 + assert service1.client.resource2 is RESOURCE2 + + assert isinstance(service2, Service) + assert isinstance(service2.client, Client) + assert service2.client.resource1 is RESOURCE1 + assert service2.client.resource2 is RESOURCE2 + + assert service1.client is not service2.client diff --git a/tests/unit/providers/callables/__init__.py b/tests/unit/providers/callables/__init__.py new file mode 100644 index 00000000..c79748fa --- /dev/null +++ b/tests/unit/providers/callables/__init__.py @@ -0,0 +1 @@ +"""Tests for callables.""" diff --git a/tests/unit/providers/callables/common.py b/tests/unit/providers/callables/common.py new file mode 100644 index 00000000..822970eb --- /dev/null +++ b/tests/unit/providers/callables/common.py @@ -0,0 +1,5 @@ +"""Common test artifacts.""" + + +def example(arg1, arg2, arg3, arg4): + return arg1, arg2, arg3, arg4 diff --git a/tests/unit/providers/callables/test_abstract_callable_py2_py3.py b/tests/unit/providers/callables/test_abstract_callable_py2_py3.py new file mode 100644 index 00000000..fa794840 --- /dev/null +++ b/tests/unit/providers/callables/test_abstract_callable_py2_py3.py @@ -0,0 +1,56 @@ +"""AbstractCallable provider tests.""" + +from dependency_injector import providers, errors +from pytest import raises + +from .common import example + + +def test_inheritance(): + assert isinstance(providers.AbstractCallable(example), providers.Callable) + + +def test_call_overridden_by_callable(): + def _abstract_example(): + pass + + provider = providers.AbstractCallable(_abstract_example) + provider.override(providers.Callable(example)) + + assert provider(1, 2, 3, 4) == (1, 2, 3, 4) + + +def test_call_overridden_by_delegated_callable(): + def _abstract_example(): + pass + + provider = providers.AbstractCallable(_abstract_example) + provider.override(providers.DelegatedCallable(example)) + + assert provider(1, 2, 3, 4) == (1, 2, 3, 4) + + +def test_call_not_overridden(): + provider = providers.AbstractCallable(example) + with raises(errors.Error): + provider(1, 2, 3, 4) + + +def test_override_by_not_callable(): + provider = providers.AbstractCallable(example) + with raises(errors.Error): + provider.override(providers.Factory(object)) + + +def test_provide_not_implemented(): + provider = providers.AbstractCallable(example) + with raises(NotImplementedError): + provider._provide((1, 2, 3, 4), dict()) + + +def test_repr(): + provider = providers.AbstractCallable(example) + assert repr(provider) == ( + "".format(repr(example), hex(id(provider))) + ) diff --git a/tests/unit/providers/callables/test_callable_delegate_py2_py3.py b/tests/unit/providers/callables/test_callable_delegate_py2_py3.py new file mode 100644 index 00000000..f33f2723 --- /dev/null +++ b/tests/unit/providers/callables/test_callable_delegate_py2_py3.py @@ -0,0 +1,17 @@ +"""CallableDelegate provider tests.""" + +from dependency_injector import providers, errors +from pytest import raises + +from .common import example + + +def test_is_delegate(): + provider = providers.Callable(example) + delegate = providers.CallableDelegate(provider) + assert isinstance(delegate, providers.Delegate) + + +def test_init_with_not_callable(): + with raises(errors.Error): + providers.CallableDelegate(providers.Object(object())) diff --git a/tests/unit/providers/callables/test_callable_py2_py3.py b/tests/unit/providers/callables/test_callable_py2_py3.py new file mode 100644 index 00000000..fade0a69 --- /dev/null +++ b/tests/unit/providers/callables/test_callable_py2_py3.py @@ -0,0 +1,210 @@ +"""Callable provider tests.""" + +import sys + +from dependency_injector import providers, errors +from pytest import raises + +from .common import example + + +def test_is_provider(): + assert providers.is_provider(providers.Callable(example)) is True + + +def test_init_with_not_callable(): + with raises(errors.Error): + providers.Callable(123) + + +def test_init_optional_provides(): + provider = providers.Callable() + provider.set_provides(object) + assert provider.provides is object + assert isinstance(provider(), object) + + +def test_set_provides_returns_(): + provider = providers.Callable() + assert provider.set_provides(object) is provider + + +def test_provided_instance_provider(): + provider = providers.Callable(example) + assert isinstance(provider.provided, providers.ProvidedInstance) + + +def test_call(): + provider = providers.Callable(lambda: True) + assert provider() is True + + +def test_call_with_positional_args(): + provider = providers.Callable(example, 1, 2, 3, 4) + assert provider() == (1, 2, 3, 4) + + +def test_call_with_keyword_args(): + provider = providers.Callable(example, arg1=1, arg2=2, arg3=3, arg4=4) + assert provider() == (1, 2, 3, 4) + + +def test_call_with_positional_and_keyword_args(): + provider = providers.Callable(example, 1, 2, arg3=3, arg4=4) + assert provider() == (1, 2, 3, 4) + + +def test_call_with_context_args(): + provider = providers.Callable(example, 1, 2) + assert provider(3, 4) == (1, 2, 3, 4) + + +def test_call_with_context_kwargs(): + provider = providers.Callable(example, arg1=1) + assert provider(arg2=2, arg3=3, arg4=4) == (1, 2, 3, 4) + + +def test_call_with_context_args_and_kwargs(): + provider = providers.Callable(example, 1) + assert provider(2, arg3=3, arg4=4) == (1, 2, 3, 4) + + +def test_fluent_interface(): + provider = providers.Singleton(example) \ + .add_args(1, 2) \ + .add_kwargs(arg3=3, arg4=4) + assert provider() == (1, 2, 3, 4) + + +def test_set_args(): + provider = providers.Callable(example) \ + .add_args(1, 2) \ + .set_args(3, 4) + assert provider.args == (3, 4) + + +def test_set_kwargs(): + provider = providers.Callable(example) \ + .add_kwargs(init_arg3=3, init_arg4=4) \ + .set_kwargs(init_arg3=4, init_arg4=5) + assert provider.kwargs == dict(init_arg3=4, init_arg4=5) + + +def test_clear_args(): + provider = providers.Callable(example) \ + .add_args(1, 2) \ + .clear_args() + assert provider.args == tuple() + + +def test_clear_kwargs(): + provider = providers.Callable(example) \ + .add_kwargs(init_arg3=3, init_arg4=4) \ + .clear_kwargs() + assert provider.kwargs == dict() + + +def test_call_overridden(): + provider = providers.Callable(example) + + provider.override(providers.Object((4, 3, 2, 1))) + provider.override(providers.Object((1, 2, 3, 4))) + + assert provider() == (1, 2, 3, 4) + + +def test_deepcopy(): + provider = providers.Callable(example) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert provider.provides is provider_copy.provides + assert isinstance(provider, providers.Callable) + + +def test_deepcopy_from_memo(): + provider = providers.Callable(example) + provider_copy_memo = providers.Callable(example) + + provider_copy = providers.deepcopy(provider, memo={id(provider): provider_copy_memo}) + + assert provider_copy is provider_copy_memo + + +def test_deepcopy_args(): + provider = providers.Callable(example) + dependent_provider1 = providers.Callable(list) + dependent_provider2 = providers.Callable(dict) + + provider.add_args(dependent_provider1, dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.args[0] + dependent_provider_copy2 = provider_copy.args[1] + + assert provider.args != provider_copy.args + + assert dependent_provider1.provides is dependent_provider_copy1.provides + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.provides is dependent_provider_copy2.provides + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_kwargs(): + provider = providers.Callable(example) + dependent_provider1 = providers.Callable(list) + dependent_provider2 = providers.Callable(dict) + + provider.add_kwargs(a1=dependent_provider1, a2=dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.kwargs["a1"] + dependent_provider_copy2 = provider_copy.kwargs["a2"] + + assert provider.kwargs != provider_copy.kwargs + + assert dependent_provider1.provides is dependent_provider_copy1.provides + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.provides is dependent_provider_copy2.provides + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_overridden(): + provider = providers.Callable(example) + object_provider = providers.Object(object()) + + provider.override(object_provider) + + provider_copy = providers.deepcopy(provider) + object_provider_copy = provider_copy.overridden[0] + + assert provider is not provider_copy + assert provider.provides is provider_copy.provides + assert isinstance(provider, providers.Callable) + + assert object_provider is not object_provider_copy + assert isinstance(object_provider_copy, providers.Object) + + +def test_deepcopy_with_sys_streams(): + provider = providers.Callable(example) + provider.add_args(sys.stdin) + provider.add_kwargs(a2=sys.stdout) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert isinstance(provider_copy, providers.Callable) + assert provider.args[0] is sys.stdin + assert provider.kwargs["a2"] is sys.stdout + + +def test_repr(): + provider = providers.Callable(example) + assert repr(provider) == ( + "".format(repr(example), hex(id(provider))) + ) diff --git a/tests/unit/providers/callables/test_delegated_callable_py2_py3.py b/tests/unit/providers/callables/test_delegated_callable_py2_py3.py new file mode 100644 index 00000000..15311164 --- /dev/null +++ b/tests/unit/providers/callables/test_delegated_callable_py2_py3.py @@ -0,0 +1,26 @@ +"""DelegatedCallable provider tests.""" + +from dependency_injector import providers + +from .common import example + + +def test_inheritance(): + assert isinstance(providers.DelegatedCallable(example), providers.Callable) + + +def test_is_provider(): + assert providers.is_provider(providers.DelegatedCallable(example)) is True + + +def test_is_delegated_provider(): + provider = providers.DelegatedCallable(example) + assert providers.is_delegated(provider) is True + + +def test_repr(): + provider = providers.DelegatedCallable(example) + assert repr(provider) == ( + "".format(repr(example), hex(id(provider))) + ) diff --git a/tests/unit/providers/configuration/__init__.py b/tests/unit/providers/configuration/__init__.py new file mode 100644 index 00000000..fda0d161 --- /dev/null +++ b/tests/unit/providers/configuration/__init__.py @@ -0,0 +1 @@ +"""Configuration provider tests.""" diff --git a/tests/unit/providers/configuration/conftest.py b/tests/unit/providers/configuration/conftest.py new file mode 100644 index 00000000..027630ce --- /dev/null +++ b/tests/unit/providers/configuration/conftest.py @@ -0,0 +1,19 @@ +"""Fixtures module.""" + +from dependency_injector import providers +from pytest import fixture + + +@fixture +def config_type(): + return "default" + + +@fixture +def config(config_type): + if config_type == "strict": + return providers.Configuration(strict=True) + elif config_type == "default": + return providers.Configuration() + else: + raise ValueError("Undefined config type \"{0}\"".format(config_type)) diff --git a/tests/unit/providers/configuration/test_config_linking_py2_py3.py b/tests/unit/providers/configuration/test_config_linking_py2_py3.py new file mode 100644 index 00000000..1a1720a5 --- /dev/null +++ b/tests/unit/providers/configuration/test_config_linking_py2_py3.py @@ -0,0 +1,130 @@ +"""Tests for configuration provider linking.""" + + +from dependency_injector import containers, providers + + +class Core(containers.DeclarativeContainer): + config = providers.Configuration("core") + value_getter = providers.Callable(lambda _: _, config.value) + + +class Services(containers.DeclarativeContainer): + config = providers.Configuration("services") + value_getter = providers.Callable(lambda _: _, config.value) + + +def test(): + root_config = providers.Configuration("main") + core = Core(config=root_config.core) + services = Services(config=root_config.services) + + root_config.override( + { + "core": { + "value": "core", + }, + "services": { + "value": "services", + }, + }, + ) + + assert core.config() == {"value": "core"} + assert core.config.value() == "core" + assert core.value_getter() == "core" + + assert services.config() == {"value": "services"} + assert services.config.value() == "services" + assert services.value_getter() == "services" + + +def test_double_override(): + root_config = providers.Configuration("main") + core = Core(config=root_config.core) + services = Services(config=root_config.services) + + root_config.override( + { + "core": { + "value": "core1", + }, + "services": { + "value": "services1", + }, + }, + ) + root_config.override( + { + "core": { + "value": "core2", + }, + "services": { + "value": "services2", + }, + }, + ) + + assert core.config() == {"value": "core2"} + assert core.config.value() == "core2" + assert core.value_getter() == "core2" + + assert services.config() == {"value": "services2"} + assert services.config.value() == "services2" + assert services.value_getter() == "services2" + + +def test_reset_overriding_cache(): + # See: https://github.com/ets-labs/python-dependency-injector/issues/428 + class Core(containers.DeclarativeContainer): + config = providers.Configuration() + + greetings = providers.Factory(str, config.greeting) + + class Application(containers.DeclarativeContainer): + config = providers.Configuration() + + core = providers.Container( + Core, + config=config, + ) + + greetings = providers.Factory(str, config.greeting) + + container = Application() + + container.config.set("greeting", "Hello World") + assert container.greetings() == "Hello World" + assert container.core.greetings() == "Hello World" + + container.config.set("greeting", "Hello Bob") + assert container.greetings() == "Hello Bob" + assert container.core.greetings() == "Hello Bob" + + +def test_reset_overriding_cache_for_option(): + # See: https://github.com/ets-labs/python-dependency-injector/issues/428 + class Core(containers.DeclarativeContainer): + config = providers.Configuration() + + greetings = providers.Factory(str, config.greeting) + + class Application(containers.DeclarativeContainer): + config = providers.Configuration() + + core = providers.Container( + Core, + config=config.option, + ) + + greetings = providers.Factory(str, config.option.greeting) + + container = Application() + + container.config.set("option.greeting", "Hello World") + assert container.greetings() == "Hello World" + assert container.core.greetings() == "Hello World" + + container.config.set("option.greeting", "Hello Bob") + assert container.greetings() == "Hello Bob" + assert container.core.greetings() == "Hello Bob" diff --git a/tests/unit/providers/configuration/test_config_py2_py3.py b/tests/unit/providers/configuration/test_config_py2_py3.py new file mode 100644 index 00000000..34f1b409 --- /dev/null +++ b/tests/unit/providers/configuration/test_config_py2_py3.py @@ -0,0 +1,338 @@ +"""Configuration provider tests.""" + +import decimal + +from dependency_injector import containers, providers, errors +from pytest import mark, raises + + +def test_init_optional(config): + config.set_name("myconfig") + config.set_default({"foo": "bar"}) + config.set_strict(True) + + assert config.get_name() == "myconfig" + assert config.get_default() == {"foo": "bar"} + assert config.get_strict() is True + + +def test_set_name_returns_self(config): + assert config.set_name("myconfig") is config + + +def test_set_default_returns_self(config): + assert config.set_default({}) is config + + +def test_set_strict_returns_self(config): + assert config.set_strict(True) is config + + +def test_default_name(config): + assert config.get_name() == "config" + + +def test_providers_are_providers(config): + assert providers.is_provider(config.a) is True + assert providers.is_provider(config.a.b) is True + assert providers.is_provider(config.a.b.c) is True + assert providers.is_provider(config.a.b.d) is True + + +def test_providers_are_not_delegates(config): + assert providers.is_delegated(config.a) is False + assert providers.is_delegated(config.a.b) is False + assert providers.is_delegated(config.a.b.c) is False + assert providers.is_delegated(config.a.b.d) is False + + +def test_providers_identity(config): + assert config.a is config.a + assert config.a.b is config.a.b + assert config.a.b.c is config.a.b.c + assert config.a.b.d is config.a.b.d + + +def test_get_name(config): + assert config.a.b.c.get_name() == "config.a.b.c" + + +def test_providers_value_setting(config): + a = config.a + ab = config.a.b + abc = config.a.b.c + abd = config.a.b.d + + config.update({"a": {"b": {"c": 1, "d": 2}}}) + + assert a() == {"b": {"c": 1, "d": 2}} + assert ab() == {"c": 1, "d": 2} + assert abc() == 1 + assert abd() == 2 + + +def test_providers_with_already_set_value(config): + config.update({"a": {"b": {"c": 1, "d": 2}}}) + + a = config.a + ab = config.a.b + abc = config.a.b.c + abd = config.a.b.d + + assert a() == {"b": {"c": 1, "d": 2}} + assert ab() == {"c": 1, "d": 2} + assert abc() == 1 + assert abd() == 2 + + +def test_as_int(config): + value_provider = providers.Callable(lambda value: value, config.test.as_int()) + config.from_dict({"test": "123"}) + + value = value_provider() + assert value == 123 + + +def test_as_float(config): + value_provider = providers.Callable(lambda value: value, config.test.as_float()) + config.from_dict({"test": "123.123"}) + + value = value_provider() + assert value == 123.123 + + +def test_as_(config): + value_provider = providers.Callable( + lambda value: value, + config.test.as_(decimal.Decimal), + ) + config.from_dict({"test": "123.123"}) + + value = value_provider() + assert value == decimal.Decimal("123.123") + + +def test_required(config): + provider = providers.Callable( + lambda value: value, + config.a.required(), + ) + with raises(errors.Error, match="Undefined configuration option \"config.a\""): + provider() + + +def test_required_defined_none(config): + provider = providers.Callable( + lambda value: value, + config.a.required(), + ) + config.from_dict({"a": None}) + assert provider() is None + + +def test_required_no_side_effect(config): + _ = providers.Callable( + lambda value: value, + config.a.required(), + ) + assert config.a() is None + + +def test_required_as_(config): + provider = providers.List( + config.int_test.required().as_int(), + config.float_test.required().as_float(), + config._as_test.required().as_(decimal.Decimal), + ) + config.from_dict({"int_test": "1", "float_test": "2.0", "_as_test": "3.0"}) + assert provider() == [1, 2.0, decimal.Decimal("3.0")] + + +def test_providers_value_override(config): + a = config.a + ab = config.a.b + abc = config.a.b.c + abd = config.a.b.d + + config.override({"a": {"b": {"c": 1, "d": 2}}}) + + assert a() == {"b": {"c": 1, "d": 2}} + assert ab() == {"c": 1, "d": 2} + assert abc() == 1 + assert abd() == 2 + + +def test_configuration_option_override_and_reset_override(config): + # Bug: https://github.com/ets-labs/python-dependency-injector/issues/319 + config.from_dict({"a": {"b": {"c": 1}}}) + + assert config.a.b.c() == 1 + + with config.set("a.b.c", "xxx"): + assert config.a.b.c() == "xxx" + assert config.a.b.c() == 1 + + with config.a.b.c.override("yyy"): + assert config.a.b.c() == "yyy" + + assert config.a.b.c() == 1 + + +def test_providers_with_already_overridden_value(config): + config.override({"a": {"b": {"c": 1, "d": 2}}}) + + a = config.a + ab = config.a.b + abc = config.a.b.c + abd = config.a.b.d + + assert a() == {"b": {"c": 1, "d": 2}} + assert ab() == {"c": 1, "d": 2} + assert abc() == 1 + assert abd() == 2 + + +def test_providers_with_default_value(config): + config.set_default({"a": {"b": {"c": 1, "d": 2}}}) + + a = config.a + ab = config.a.b + abc = config.a.b.c + abd = config.a.b.d + + assert a() == {"b": {"c": 1, "d": 2}} + assert ab() == {"c": 1, "d": 2} + assert abc() == 1 + assert abd() == 2 + + +def test_providers_with_default_value_overriding(config): + config.set_default({"a": {"b": {"c": 1, "d": 2}}}) + + assert config.a() == {"b": {"c": 1, "d": 2}} + assert config.a.b() == {"c": 1, "d": 2} + assert config.a.b.c() == 1 + assert config.a.b.d() == 2 + + config.override({"a": {"b": {"c": 3, "d": 4}}}) + assert config.a() == {"b": {"c": 3, "d": 4}} + assert config.a.b() == {"c": 3, "d": 4} + assert config.a.b.c() == 3 + assert config.a.b.d() == 4 + + config.reset_override() + assert config.a() == {"b": {"c": 1, "d": 2}} + assert config.a.b() == {"c": 1, "d": 2} + assert config.a.b.c() == 1 + assert config.a.b.d() == 2 + + +def test_value_of_undefined_option(config): + assert config.option() is None + + +@mark.parametrize("config_type", ["strict"]) +def test_value_of_undefined_option_in_strict_mode(config): + with raises(errors.Error, match="Undefined configuration option \"config.option\""): + config.option() + + +@mark.parametrize("config_type", ["strict"]) +def test_value_of_undefined_option_with_root_none_in_strict_mode(config): + config.override(None) + with raises(errors.Error, match="Undefined configuration option \"config.option\""): + config.option() + + +@mark.parametrize("config_type", ["strict"]) +def test_value_of_defined_none_option_in_strict_mode(config): + config.from_dict({"a": None}) + assert config.a() is None + + +def test_getting_of_special_attributes(config): + with raises(AttributeError): + config.__name__ + + +def test_getting_of_special_attributes_from_child(config): + with raises(AttributeError): + config.child.__name__ + + +def test_context_manager_alias(): + class Container(containers.DeclarativeContainer): + config = providers.Configuration() + + container = Container() + + with container.config as config: + config.override({"foo": "foo", "bar": "bar"}) + + assert container.config() == {"foo": "foo", "bar": "bar"} + assert config() == {"foo": "foo", "bar": "bar"} + assert container.config is config + + +def test_option_context_manager_alias(): + class Container(containers.DeclarativeContainer): + config = providers.Configuration() + + container = Container() + + with container.config.option as option: + option.override({"foo": "foo", "bar": "bar"}) + + assert container.config() == {"option": {"foo": "foo", "bar": "bar"}} + assert container.config.option() == {"foo": "foo", "bar": "bar"} + assert option() == {"foo": "foo", "bar": "bar"} + assert container.config.option is option + + +def test_missing_key(config): + # See: https://github.com/ets-labs/python-dependency-injector/issues/358 + config.override(None) + value = config.key() + assert value is None + + +def test_deepcopy(config): + config_copy = providers.deepcopy(config) + assert isinstance(config_copy, providers.Configuration) + assert config is not config_copy + + +def test_deepcopy_from_memo(config): + config_copy_memo = providers.Configuration() + + provider_copy = providers.deepcopy(config, memo={id(config): config_copy_memo}) + assert provider_copy is config_copy_memo + + +def test_deepcopy_overridden(config): + object_provider = providers.Object(object()) + + config.override(object_provider) + + provider_copy = providers.deepcopy(config) + object_provider_copy = provider_copy.overridden[0] + + assert config is not provider_copy + assert isinstance(config, providers.Configuration) + + assert object_provider is not object_provider_copy + assert isinstance(object_provider_copy, providers.Object) + + +def test_repr(config): + assert repr(config) == ( + "".format(repr("config"), hex(id(config))) + ) + + +def test_repr_child(config): + assert repr(config.a.b.c) == ( + "".format(repr("config.a.b.c"), hex(id(config.a.b.c))) + ) diff --git a/tests/unit/providers/configuration/test_from_dict_py2_py3.py b/tests/unit/providers/configuration/test_from_dict_py2_py3.py new file mode 100644 index 00000000..ebc833e7 --- /dev/null +++ b/tests/unit/providers/configuration/test_from_dict_py2_py3.py @@ -0,0 +1,103 @@ +"""Configuration.from_dict() tests.""" + +from pytest import mark, raises + + +CONFIG_OPTIONS_1 = { + "section1": { + "value1": "1", + }, + "section2": { + "value2": "2", + }, +} +CONFIG_OPTIONS_2 = { + "section1": { + "value1": "11", + "value11": "11", + }, + "section3": { + "value3": "3", + }, +} + + +def test(config): + config.from_dict(CONFIG_OPTIONS_1) + + assert config() == {"section1": {"value1": "1"}, "section2": {"value2": "2"}} + assert config.section1() == {"value1": "1"} + assert config.section1.value1() == "1" + assert config.section2() == {"value2": "2"} + assert config.section2.value2() == "2" + + +def test_merge(config): + config.from_dict(CONFIG_OPTIONS_1) + config.from_dict(CONFIG_OPTIONS_2) + + assert config() == { + "section1": { + "value1": "11", + "value11": "11", + }, + "section2": { + "value2": "2", + }, + "section3": { + "value3": "3", + }, + } + assert config.section1() == {"value1": "11", "value11": "11"} + assert config.section1.value1() == "11" + assert config.section1.value11() == "11" + assert config.section2() == {"value2": "2"} + assert config.section2.value2() == "2" + assert config.section3() == {"value3": "3"} + assert config.section3.value3() == "3" + + +def test_empty_dict(config): + config.from_dict({}) + assert config() == {} + + +def test_option_empty_dict(config): + config.option.from_dict({}) + assert config.option() == {} + + +@mark.parametrize("config_type", ["strict"]) +def test_empty_dict_in_strict_mode(config): + with raises(ValueError): + config.from_dict({}) + + +@mark.parametrize("config_type", ["strict"]) +def test_option_empty_dict_in_strict_mode(config): + with raises(ValueError): + config.option.from_dict({}) + + +def test_required_empty_dict(config): + with raises(ValueError): + config.from_dict({}, required=True) + + +def test_required_option_empty_dict(config): + with raises(ValueError): + config.option.from_dict({}, required=True) + + +@mark.parametrize("config_type", ["strict"]) +def test_not_required_empty_dict_strict_mode(config): + config.from_dict({}, required=False) + assert config() == {} + + +@mark.parametrize("config_type", ["strict"]) +def test_not_required_option_empty_dict_strict_mode(config): + config.option.from_dict({}, required=False) + assert config.option() == {} + assert config() == {"option": {}} + diff --git a/tests/unit/providers/configuration/test_from_env_py2_py3.py b/tests/unit/providers/configuration/test_from_env_py2_py3.py new file mode 100644 index 00000000..5c6b6395 --- /dev/null +++ b/tests/unit/providers/configuration/test_from_env_py2_py3.py @@ -0,0 +1,107 @@ +"""Configuration.from_env() tests.""" + +import os + +from pytest import fixture, mark, raises + + +@fixture(autouse=True) +def environment_variables(): + os.environ["CONFIG_TEST_ENV"] = "test-value" + yield + os.environ.pop("CONFIG_TEST_ENV", None) + + +def test(config): + config.from_env("CONFIG_TEST_ENV") + assert config() == "test-value" + + +def test_with_children(config): + config.section1.value1.from_env("CONFIG_TEST_ENV") + + assert config() == {"section1": {"value1": "test-value"}} + assert config.section1() == {"value1": "test-value"} + assert config.section1.value1() == "test-value" + + +def test_default(config): + config.from_env("UNDEFINED_ENV", "default-value") + assert config() == "default-value" + + +def test_default_none(config): + config.from_env("UNDEFINED_ENV") + assert config() is None + + +def test_option_default_none(config): + config.option.from_env("UNDEFINED_ENV") + assert config.option() is None + + +@mark.parametrize("config_type", ["strict"]) +def test_undefined_in_strict_mode(config): + with raises(ValueError): + config.from_env("UNDEFINED_ENV") + + +@mark.parametrize("config_type", ["strict"]) +def test_option_undefined_in_strict_mode(config): + with raises(ValueError): + config.option.from_env("UNDEFINED_ENV") + + +def test_undefined_in_strict_mode_with_default(config): + config.from_env("UNDEFINED_ENV", "default-value") + assert config() == "default-value" + + +@mark.parametrize("config_type", ["strict"]) +def test_option_undefined_in_strict_mode_with_default(config): + config.option.from_env("UNDEFINED_ENV", "default-value") + assert config.option() == "default-value" + + +def test_required_undefined(config): + with raises(ValueError): + config.from_env("UNDEFINED_ENV", required=True) + + +def test_required_undefined_with_default(config): + config.from_env("UNDEFINED_ENV", default="default-value", required=True) + assert config() == "default-value" + + +def test_option_required_undefined(config): + with raises(ValueError): + config.option.from_env("UNDEFINED_ENV", required=True) + + +def test_option_required_undefined_with_default(config): + config.option.from_env("UNDEFINED_ENV", default="default-value", required=True) + assert config.option() == "default-value" + + +@mark.parametrize("config_type", ["strict"]) +def test_not_required_undefined_in_strict_mode(config): + config.from_env("UNDEFINED_ENV", required=False) + assert config() is None + + +@mark.parametrize("config_type", ["strict"]) +def test_option_not_required_undefined_in_strict_mode(config): + config.option.from_env("UNDEFINED_ENV", required=False) + assert config.option() is None + + +@mark.parametrize("config_type", ["strict"]) +def test_not_required_undefined_with_default_in_strict_mode(config): + config.from_env("UNDEFINED_ENV", default="default-value", required=False) + assert config() == "default-value" + + +@mark.parametrize("config_type", ["strict"]) +def test_option_not_required_undefined_with_default_in_strict_mode(config): + config.option.from_env("UNDEFINED_ENV", default="default-value", required=False) + assert config.option() == "default-value" diff --git a/tests/unit/providers/configuration/test_from_ini_py2_py3.py b/tests/unit/providers/configuration/test_from_ini_py2_py3.py new file mode 100644 index 00000000..c9078c63 --- /dev/null +++ b/tests/unit/providers/configuration/test_from_ini_py2_py3.py @@ -0,0 +1,123 @@ +"""Configuration.from_ini() tests.""" + +from dependency_injector import errors +from pytest import fixture, mark, raises + + +@fixture +def config_file_1(tmp_path): + config_file = str(tmp_path / "config_1.ini") + with open(config_file, "w") as file: + file.write( + "[section1]\n" + "value1=1\n" + "\n" + "[section2]\n" + "value2=2\n" + ) + return config_file + + +@fixture +def config_file_2(tmp_path): + config_file = str(tmp_path / "config_2.ini") + with open(config_file, "w") as file: + file.write( + "[section1]\n" + "value1=11\n" + "value11=11\n" + "[section3]\n" + "value3=3\n" + ) + return config_file + + +def test(config, config_file_1): + config.from_ini(config_file_1) + + assert config() == {"section1": {"value1": "1"}, "section2": {"value2": "2"}} + assert config.section1() == {"value1": "1"} + assert config.section1.value1() == "1" + assert config.section2() == {"value2": "2"} + assert config.section2.value2() == "2" + + +def test_option(config, config_file_1): + config.option.from_ini(config_file_1) + + assert config() == {"option": {"section1": {"value1": "1"}, "section2": {"value2": "2"}}} + assert config.option() == {"section1": {"value1": "1"}, "section2": {"value2": "2"}} + assert config.option.section1() == {"value1": "1"} + assert config.option.section1.value1() == "1" + assert config.option.section2() == {"value2": "2"} + assert config.option.section2.value2() == "2" + + +def test_merge(config, config_file_1, config_file_2): + config.from_ini(config_file_1) + config.from_ini(config_file_2) + + assert config() == { + "section1": { + "value1": "11", + "value11": "11", + }, + "section2": { + "value2": "2", + }, + "section3": { + "value3": "3", + }, + } + assert config.section1() == {"value1": "11", "value11": "11"} + assert config.section1.value1() == "11" + assert config.section1.value11() == "11" + assert config.section2() == {"value2": "2"} + assert config.section2.value2() == "2" + assert config.section3() == {"value3": "3"} + assert config.section3.value3() == "3" + + +def test_file_does_not_exist(config): + config.from_ini("./does_not_exist.ini") + assert config() == {} + + +@mark.parametrize("config_type", ["strict"]) +def test_file_does_not_exist_strict_mode(config): + with raises(IOError): + config.from_ini("./does_not_exist.ini") + + +def test_option_file_does_not_exist(config): + config.option.from_ini("does_not_exist.ini") + assert config.option.undefined() is None + + +@mark.parametrize("config_type", ["strict"]) +def test_option_file_does_not_exist_strict_mode(config): + with raises(IOError): + config.option.from_ini("./does_not_exist.ini") + + +def test_required_file_does_not_exist(config): + with raises(IOError): + config.from_ini("./does_not_exist.ini", required=True) + + +def test_required_option_file_does_not_exist(config): + with raises(IOError): + config.option.from_ini("./does_not_exist.ini", required=True) + + +@mark.parametrize("config_type", ["strict"]) +def test_not_required_file_does_not_exist_strict_mode(config): + config.from_ini("./does_not_exist.ini", required=False) + assert config() == {} + + +@mark.parametrize("config_type", ["strict"]) +def test_not_required_option_file_does_not_exist_strict_mode(config): + config.option.from_ini("./does_not_exist.ini", required=False) + with raises(errors.Error): + config.option() diff --git a/tests/unit/providers/configuration/test_from_ini_with_env_py2_py3.py b/tests/unit/providers/configuration/test_from_ini_with_env_py2_py3.py new file mode 100644 index 00000000..cdbab241 --- /dev/null +++ b/tests/unit/providers/configuration/test_from_ini_with_env_py2_py3.py @@ -0,0 +1,145 @@ +"""Configuration.from_ini() with environment variables interpolation tests.""" + +import os + +from pytest import fixture, mark, raises + + +@fixture +def config_file(tmp_path): + config_file = str(tmp_path / "config_1.ini") + with open(config_file, "w") as file: + file.write( + "[section1]\n" + "value1=${CONFIG_TEST_ENV}\n" + "value2=${CONFIG_TEST_PATH}/path\n" + ) + return config_file + + +@fixture(autouse=True) +def environment_variables(): + os.environ["CONFIG_TEST_ENV"] = "test-value" + os.environ["CONFIG_TEST_PATH"] = "test-path" + os.environ["DEFINED"] = "defined" + yield + os.environ.pop("CONFIG_TEST_ENV", None) + os.environ.pop("CONFIG_TEST_PATH", None) + os.environ.pop("DEFINED", None) + + +def test_env_variable_interpolation(config, config_file): + config.from_ini(config_file) + + assert config() == { + "section1": { + "value1": "test-value", + "value2": "test-path/path", + }, + } + assert config.section1() == { + "value1": "test-value", + "value2": "test-path/path", + } + assert config.section1.value1() == "test-value" + assert config.section1.value2() == "test-path/path" + + +def test_missing_envs_not_required(config, config_file): + del os.environ["CONFIG_TEST_ENV"] + del os.environ["CONFIG_TEST_PATH"] + + config.from_ini(config_file) + + assert config() == { + "section1": { + "value1": "", + "value2": "/path", + }, + } + assert config.section1() == { + "value1": "", + "value2": "/path", + } + assert config.section1.value1() == "" + assert config.section1.value2() == "/path" + + +def test_missing_envs_required(config, config_file): + with open(config_file, "w") as file: + file.write( + "[section]\n" + "undefined=${UNDEFINED}\n" + ) + with raises(ValueError, match="Missing required environment variable \"UNDEFINED\""): + config.from_ini(config_file, envs_required=True) + + +@mark.parametrize("config_type", ["strict"]) +def test_missing_envs_strict_mode(config, config_file): + with open(config_file, "w") as file: + file.write( + "[section]\n" + "undefined=${UNDEFINED}\n" + ) + with raises(ValueError, match="Missing required environment variable \"UNDEFINED\""): + config.from_ini(config_file) + + +def test_option_missing_envs_not_required(config, config_file): + del os.environ["CONFIG_TEST_ENV"] + del os.environ["CONFIG_TEST_PATH"] + + config.option.from_ini(config_file) + + assert config.option() == { + "section1": { + "value1": "", + "value2": "/path", + }, + } + assert config.option.section1() == { + "value1": "", + "value2": "/path", + } + assert config.option.section1.value1() == "" + assert config.option.section1.value2() == "/path" + + +def test_option_missing_envs_required(config, config_file): + with open(config_file, "w") as file: + file.write( + "[section]\n" + "undefined=${UNDEFINED}\n" + ) + with raises(ValueError, match="Missing required environment variable \"UNDEFINED\""): + config.option.from_ini(config_file, envs_required=True) + + +@mark.parametrize("config_type", ["strict"]) +def test_option_missing_envs_strict_mode(config, config_file): + with open(config_file, "w") as file: + file.write( + "[section]\n" + "undefined=${UNDEFINED}\n" + ) + with raises(ValueError, match="Missing required environment variable \"UNDEFINED\""): + config.option.from_ini(config_file) + + +def test_default_values(config, config_file): + with open(config_file, "w") as file: + file.write( + "[section]\n" + "defined_with_default=${DEFINED:default}\n" + "undefined_with_default=${UNDEFINED:default}\n" + "complex=${DEFINED}/path/${DEFINED:default}/${UNDEFINED}/${UNDEFINED:default}\n" + ) + + config.from_ini(config_file) + + assert config.section() == { + "defined_with_default": "defined", + "undefined_with_default": "default", + "complex": "defined/path/defined//default", + } diff --git a/tests/unit/providers/configuration/test_from_pydantic_py36.py b/tests/unit/providers/configuration/test_from_pydantic_py36.py new file mode 100644 index 00000000..f5a2c97e --- /dev/null +++ b/tests/unit/providers/configuration/test_from_pydantic_py36.py @@ -0,0 +1,184 @@ +"""Configuration.from_pydantic() tests.""" + +import pydantic +from dependency_injector import providers, errors +from pytest import fixture, mark, raises + + +class Section11(pydantic.BaseModel): + value1 = 1 + + +class Section12(pydantic.BaseModel): + value2 = 2 + + +class Settings1(pydantic.BaseSettings): + section1 = Section11() + section2 = Section12() + + +class Section21(pydantic.BaseModel): + value1 = 11 + value11 = 11 + + +class Section3(pydantic.BaseModel): + value3 = 3 + + +class Settings2(pydantic.BaseSettings): + section1 = Section21() + section3 = Section3() + +@fixture +def no_pydantic_module_installed(): + providers.pydantic = None + yield + providers.pydantic = pydantic + + +def test(config): + config.from_pydantic(Settings1()) + + assert config() == {"section1": {"value1": 1}, "section2": {"value2": 2}} + assert config.section1() == {"value1": 1} + assert config.section1.value1() == 1 + assert config.section2() == {"value2": 2} + assert config.section2.value2() == 2 + + +def test_kwarg(config): + config.from_pydantic(Settings1(), exclude={"section2"}) + + assert config() == {"section1": {"value1": 1}} + assert config.section1() == {"value1": 1} + assert config.section1.value1() == 1 + + +def test_merge(config): + config.from_pydantic(Settings1()) + config.from_pydantic(Settings2()) + + assert config() == { + "section1": { + "value1": 11, + "value11": 11, + }, + "section2": { + "value2": 2, + }, + "section3": { + "value3": 3, + }, + } + assert config.section1() == {"value1": 11, "value11": 11} + assert config.section1.value1() == 11 + assert config.section1.value11() == 11 + assert config.section2() == {"value2": 2} + assert config.section2.value2() == 2 + assert config.section3() == {"value3": 3} + assert config.section3.value3() == 3 + + +def test_empty_settings(config): + config.from_pydantic(pydantic.BaseSettings()) + assert config() == {} + + +@mark.parametrize("config_type", ["strict"]) +def test_empty_settings_strict_mode(config): + with raises(ValueError): + config.from_pydantic(pydantic.BaseSettings()) + + +def test_option_empty_settings(config): + config.option.from_pydantic(pydantic.BaseSettings()) + assert config.option() == {} + + +@mark.parametrize("config_type", ["strict"]) +def test_option_empty_settings_strict_mode(config): + with raises(ValueError): + config.option.from_pydantic(pydantic.BaseSettings()) + + +def test_required_empty_settings(config): + with raises(ValueError): + config.from_pydantic(pydantic.BaseSettings(), required=True) + + +def test_required_option_empty_settings(config): + with raises(ValueError): + config.option.from_pydantic(pydantic.BaseSettings(), required=True) + + +@mark.parametrize("config_type", ["strict"]) +def test_not_required_empty_settings_strict_mode(config): + config.from_pydantic(pydantic.BaseSettings(), required=False) + assert config() == {} + + +@mark.parametrize("config_type", ["strict"]) +def test_not_required_option_empty_settings_strict_mode(config): + config.option.from_pydantic(pydantic.BaseSettings(), required=False) + assert config.option() == {} + assert config() == {"option": {}} + + +def test_not_instance_of_settings(config): + with raises(errors.Error) as error: + config.from_pydantic({}) + assert error.value.args[0] == ( + "Unable to recognize settings instance, expect \"pydantic.BaseSettings\", " + "got {0} instead".format({}) + ) + + +def test_option_not_instance_of_settings(config): + with raises(errors.Error) as error: + config.option.from_pydantic({}) + assert error.value.args[0] == ( + "Unable to recognize settings instance, expect \"pydantic.BaseSettings\", " + "got {0} instead".format({}) + ) + + +def test_subclass_instead_of_instance(config): + with raises(errors.Error) as error: + config.from_pydantic(Settings1) + assert error.value.args[0] == ( + "Got settings class, but expect instance: " + "instead \"Settings1\" use \"Settings1()\"" + ) + + +def test_option_subclass_instead_of_instance(config): + with raises(errors.Error) as error: + config.option.from_pydantic(Settings1) + assert error.value.args[0] == ( + "Got settings class, but expect instance: " + "instead \"Settings1\" use \"Settings1()\"" + ) + + +@mark.usefixtures("no_pydantic_module_installed") +def test_no_pydantic_installed(config): + with raises(errors.Error) as error: + config.from_pydantic(Settings1()) + assert error.value.args[0] == ( + "Unable to load pydantic configuration - pydantic is not installed. " + "Install pydantic or install Dependency Injector with pydantic extras: " + "\"pip install dependency-injector[pydantic]\"" + ) + + +@mark.usefixtures("no_pydantic_module_installed") +def test_option_no_pydantic_installed(config): + with raises(errors.Error) as error: + config.option.from_pydantic(Settings1()) + assert error.value.args[0] == ( + "Unable to load pydantic configuration - pydantic is not installed. " + "Install pydantic or install Dependency Injector with pydantic extras: " + "\"pip install dependency-injector[pydantic]\"" + ) diff --git a/tests/unit/providers/configuration/test_from_value_py2_py3.py b/tests/unit/providers/configuration/test_from_value_py2_py3.py new file mode 100644 index 00000000..23dd5191 --- /dev/null +++ b/tests/unit/providers/configuration/test_from_value_py2_py3.py @@ -0,0 +1,19 @@ +"""Configuration.from_value() tests.""" + + +def test_from_value(config): + test_value = 123321 + config.from_value(test_value) + assert config() == test_value + + +def test_option_from_value(config): + test_value_1 = 123 + test_value_2 = 321 + + config.option1.from_value(test_value_1) + config.option2.from_value(test_value_2) + + assert config() == {"option1": test_value_1, "option2": test_value_2} + assert config.option1() == test_value_1 + assert config.option2() == test_value_2 diff --git a/tests/unit/providers/configuration/test_from_yaml_py2_py3.py b/tests/unit/providers/configuration/test_from_yaml_py2_py3.py new file mode 100644 index 00000000..9e88f263 --- /dev/null +++ b/tests/unit/providers/configuration/test_from_yaml_py2_py3.py @@ -0,0 +1,142 @@ +"""Configuration.from_yaml() tests.""" + +from dependency_injector import providers, errors +from pytest import fixture, mark, raises + + +@fixture +def config_file_1(tmp_path): + config_file = str(tmp_path / "config_1.ini") + with open(config_file, "w") as file: + file.write( + "section1:\n" + " value1: 1\n" + "\n" + "section2:\n" + " value2: 2\n" + ) + return config_file + + +@fixture +def config_file_2(tmp_path): + config_file = str(tmp_path / "config_2.ini") + with open(config_file, "w") as file: + file.write( + "section1:\n" + " value1: 11\n" + " value11: 11\n" + "section3:\n" + " value3: 3\n" + ) + return config_file + + +@fixture +def no_yaml_module_installed(): + yaml = providers.yaml + providers.yaml = None + yield + providers.yaml = yaml + + +def test(config, config_file_1): + config.from_yaml(config_file_1) + + assert config() == {"section1": {"value1": 1}, "section2": {"value2": 2}} + assert config.section1() == {"value1": 1} + assert config.section1.value1() == 1 + assert config.section2() == {"value2": 2} + assert config.section2.value2() == 2 + + +def test_merge(config, config_file_1, config_file_2): + config.from_yaml(config_file_1) + config.from_yaml(config_file_2) + + assert config() == { + "section1": { + "value1": 11, + "value11": 11, + }, + "section2": { + "value2": 2, + }, + "section3": { + "value3": 3, + }, + } + assert config.section1() == {"value1": 11, "value11": 11} + assert config.section1.value1() == 11 + assert config.section1.value11() == 11 + assert config.section2() == {"value2": 2} + assert config.section2.value2() == 2 + assert config.section3() == {"value3": 3} + assert config.section3.value3() == 3 + + +def test_file_does_not_exist(config): + config.from_yaml("./does_not_exist.yml") + assert config() == {} + + +@mark.parametrize("config_type", ["strict"]) +def test_file_does_not_exist_strict_mode(config): + with raises(IOError): + config.from_yaml("./does_not_exist.yml") + + +def test_option_file_does_not_exist(config): + config.option.from_yaml("./does_not_exist.yml") + assert config.option() is None + + +@mark.parametrize("config_type", ["strict"]) +def test_option_file_does_not_exist_strict_mode(config): + with raises(IOError): + config.option.from_yaml("./does_not_exist.yml") + + +def test_required_file_does_not_exist(config): + with raises(IOError): + config.from_yaml("./does_not_exist.yml", required=True) + + +def test_required_option_file_does_not_exist(config): + with raises(IOError): + config.option.from_yaml("./does_not_exist.yml", required=True) + + +@mark.parametrize("config_type", ["strict"]) +def test_not_required_file_does_not_exist_strict_mode(config): + config.from_yaml("./does_not_exist.yml", required=False) + assert config() == {} + + +@mark.parametrize("config_type", ["strict"]) +def test_not_required_option_file_does_not_exist_strict_mode(config): + config.option.from_yaml("./does_not_exist.yml", required=False) + with raises(errors.Error): + config.option() + + +@mark.usefixtures("no_yaml_module_installed") +def test_no_yaml_installed(config, config_file_1): + with raises(errors.Error) as error: + config.from_yaml(config_file_1) + assert error.value.args[0] == ( + "Unable to load yaml configuration - PyYAML is not installed. " + "Install PyYAML or install Dependency Injector with yaml extras: " + "\"pip install dependency-injector[yaml]\"" + ) + + +@mark.usefixtures("no_yaml_module_installed") +def test_option_no_yaml_installed(config, config_file_1): + with raises(errors.Error) as error: + config.option.from_yaml(config_file_1) + assert error.value.args[0] == ( + "Unable to load yaml configuration - PyYAML is not installed. " + "Install PyYAML or install Dependency Injector with yaml extras: " + "\"pip install dependency-injector[yaml]\"" + ) diff --git a/tests/unit/providers/configuration/test_from_yaml_with_env_py2_py3.py b/tests/unit/providers/configuration/test_from_yaml_with_env_py2_py3.py new file mode 100644 index 00000000..8dc1e43f --- /dev/null +++ b/tests/unit/providers/configuration/test_from_yaml_with_env_py2_py3.py @@ -0,0 +1,185 @@ +"""Configuration.from_yaml() with environment variables interpolation tests.""" + +import os + +import yaml +from pytest import fixture, mark, raises + + +@fixture +def config_file(tmp_path): + config_file = str(tmp_path / "config_1.ini") + with open(config_file, "w") as file: + file.write( + "section1:\n" + " value1: ${CONFIG_TEST_ENV}\n" + " value2: ${CONFIG_TEST_PATH}/path\n" + ) + return config_file + + +@fixture(autouse=True) +def environment_variables(): + os.environ["CONFIG_TEST_ENV"] = "test-value" + os.environ["CONFIG_TEST_PATH"] = "test-path" + os.environ["DEFINED"] = "defined" + yield + os.environ.pop("CONFIG_TEST_ENV", None) + os.environ.pop("CONFIG_TEST_PATH", None) + os.environ.pop("DEFINED", None) + + +def test_env_variable_interpolation(config, config_file): + config.from_yaml(config_file) + + assert config() == { + "section1": { + "value1": "test-value", + "value2": "test-path/path", + }, + } + assert config.section1() == { + "value1": "test-value", + "value2": "test-path/path", + } + assert config.section1.value1() == "test-value" + assert config.section1.value2() == "test-path/path" + + +def test_missing_envs_not_required(config, config_file): + del os.environ["CONFIG_TEST_ENV"] + del os.environ["CONFIG_TEST_PATH"] + + config.from_yaml(config_file) + + assert config() == { + "section1": { + "value1": None, + "value2": "/path", + }, + } + assert config.section1() == { + "value1": None, + "value2": "/path", + } + assert config.section1.value1() is None + assert config.section1.value2() == "/path" + + +def test_missing_envs_required(config, config_file): + with open(config_file, "w") as file: + file.write( + "section:\n" + " undefined: ${UNDEFINED}\n" + ) + with raises(ValueError, match="Missing required environment variable \"UNDEFINED\""): + config.from_yaml(config_file, envs_required=True) + + +@mark.parametrize("config_type", ["strict"]) +def test_missing_envs_strict_mode(config, config_file): + with open(config_file, "w") as file: + file.write( + "section:\n" + " undefined: ${UNDEFINED}\n" + ) + with raises(ValueError, match="Missing required environment variable \"UNDEFINED\""): + config.from_yaml(config_file) + + +def test_option_missing_envs_not_required(config, config_file): + del os.environ["CONFIG_TEST_ENV"] + del os.environ["CONFIG_TEST_PATH"] + + config.option.from_yaml(config_file) + + assert config.option() == { + "section1": { + "value1": None, + "value2": "/path", + }, + } + assert config.option.section1() == { + "value1": None, + "value2": "/path", + } + assert config.option.section1.value1() is None + assert config.option.section1.value2() == "/path" + + +def test_option_missing_envs_required(config, config_file): + with open(config_file, "w") as file: + file.write( + "section:\n" + " undefined: ${UNDEFINED}\n" + ) + with raises(ValueError, match="Missing required environment variable \"UNDEFINED\""): + config.option.from_yaml(config_file, envs_required=True) + + +@mark.parametrize("config_type", ["strict"]) +def test_option_missing_envs_strict_mode(config, config_file): + with open(config_file, "w") as file: + file.write( + "section:\n" + " undefined: ${UNDEFINED}\n" + ) + with raises(ValueError, match="Missing required environment variable \"UNDEFINED\""): + config.option.from_yaml(config_file) + + +def test_default_values(config, config_file): + with open(config_file, "w") as file: + file.write( + "section:\n" + " defined_with_default: ${DEFINED:default}\n" + " undefined_with_default: ${UNDEFINED:default}\n" + " complex: ${DEFINED}/path/${DEFINED:default}/${UNDEFINED}/${UNDEFINED:default}\n" + ) + + config.from_yaml(config_file) + + assert config.section() == { + "defined_with_default": "defined", + "undefined_with_default": "default", + "complex": "defined/path/defined//default", + } + + +def test_option_env_variable_interpolation(config, config_file): + config.option.from_yaml(config_file) + + assert config.option() == { + "section1": { + "value1": "test-value", + "value2": "test-path/path", + }, + } + assert config.option.section1() == { + "value1": "test-value", + "value2": "test-path/path", + } + assert config.option.section1.value1() == "test-value" + assert config.option.section1.value2() == "test-path/path" + + +def test_env_variable_interpolation_custom_loader(config, config_file): + config.from_yaml(config_file, loader=yaml.UnsafeLoader) + + assert config.section1() == { + "value1": "test-value", + "value2": "test-path/path", + } + assert config.section1.value1() == "test-value" + assert config.section1.value2() == "test-path/path" + + +def test_option_env_variable_interpolation_custom_loader(config, config_file): + config.option.from_yaml(config_file, loader=yaml.UnsafeLoader) + + assert config.option.section1() == { + "value1": "test-value", + "value2": "test-path/path", + } + assert config.option.section1.value1() == "test-value" + assert config.option.section1.value2() == "test-path/path" diff --git a/tests/unit/providers/coroutines/__init__.py b/tests/unit/providers/coroutines/__init__.py new file mode 100644 index 00000000..90a8072d --- /dev/null +++ b/tests/unit/providers/coroutines/__init__.py @@ -0,0 +1 @@ +"""Tests for coroutine providers.""" diff --git a/tests/unit/providers/coroutines/common.py b/tests/unit/providers/coroutines/common.py new file mode 100644 index 00000000..e092bc46 --- /dev/null +++ b/tests/unit/providers/coroutines/common.py @@ -0,0 +1,5 @@ +"""Common test artifacts.""" + + +async def example(arg1, arg2, arg3, arg4): + return arg1, arg2, arg3, arg4 diff --git a/tests/unit/providers/coroutines/test_abstract_coroutine_py35.py b/tests/unit/providers/coroutines/test_abstract_coroutine_py35.py new file mode 100644 index 00000000..4f098442 --- /dev/null +++ b/tests/unit/providers/coroutines/test_abstract_coroutine_py35.py @@ -0,0 +1,66 @@ +"""AbstractCoroutine provider tests.""" + +import asyncio + +from dependency_injector import providers, errors +from pytest import mark, raises + +from .common import example + + +def test_inheritance(): + assert isinstance(providers.AbstractCoroutine(example), providers.Coroutine) + + +@mark.asyncio +@mark.filterwarnings("ignore") +async def test_call_overridden_by_coroutine(): + @asyncio.coroutine + def abstract_example(): + raise RuntimeError("Should not be raised") + + provider = providers.AbstractCoroutine(abstract_example) + provider.override(providers.Coroutine(example)) + + result = await provider(1, 2, 3, 4) + assert result == (1, 2, 3, 4) + + +@mark.asyncio +@mark.filterwarnings("ignore") +async def test_call_overridden_by_delegated_coroutine(): + @asyncio.coroutine + def abstract_example(): + raise RuntimeError("Should not be raised") + + provider = providers.AbstractCoroutine(abstract_example) + provider.override(providers.DelegatedCoroutine(example)) + + result = await provider(1, 2, 3, 4) + assert result == (1, 2, 3, 4) + + +def test_call_not_overridden(): + provider = providers.AbstractCoroutine(example) + with raises(errors.Error): + provider(1, 2, 3, 4) + + +def test_override_by_not_coroutine(): + provider = providers.AbstractCoroutine(example) + with raises(errors.Error): + provider.override(providers.Factory(object)) + + +def test_provide_not_implemented(): + provider = providers.AbstractCoroutine(example) + with raises(NotImplementedError): + provider._provide((1, 2, 3, 4), dict()) + + +def test_repr(): + provider = providers.AbstractCoroutine(example) + assert repr(provider) == ( + "".format(repr(example), hex(id(provider))) + ) diff --git a/tests/unit/providers/coroutines/test_coroutine_delegate_py35.py b/tests/unit/providers/coroutines/test_coroutine_delegate_py35.py new file mode 100644 index 00000000..d21b30c7 --- /dev/null +++ b/tests/unit/providers/coroutines/test_coroutine_delegate_py35.py @@ -0,0 +1,17 @@ +"""CoroutineDelegate provider tests.""" + +from dependency_injector import providers, errors +from pytest import raises + +from .common import example + + +def test_is_delegate(): + provider = providers.Coroutine(example) + delegate = providers.CoroutineDelegate(provider) + assert isinstance(delegate, providers.Delegate) + + +def test_init_with_not_coroutine(): + with raises(errors.Error): + providers.CoroutineDelegate(providers.Object(object())) diff --git a/tests/unit/providers/coroutines/test_coroutine_py35.py b/tests/unit/providers/coroutines/test_coroutine_py35.py new file mode 100644 index 00000000..09a8a946 --- /dev/null +++ b/tests/unit/providers/coroutines/test_coroutine_py35.py @@ -0,0 +1,199 @@ +"""Coroutine provider tests.""" + +from dependency_injector import providers, errors +from pytest import mark, raises + +from .common import example + + +def test_init_with_coroutine(): + assert isinstance(providers.Coroutine(example), providers.Coroutine) + + +def test_init_with_not_coroutine(): + with raises(errors.Error): + providers.Coroutine(lambda: None) + + +@mark.asyncio +async def test_init_optional_provides(): + provider = providers.Coroutine() + provider.set_provides(example) + + result = await provider(1, 2, 3, 4) + + assert result == (1, 2, 3, 4) + assert provider.provides is example + + +def test_set_provides_returns_self(): + provider = providers.Coroutine() + assert provider.set_provides(example) is provider + + +@mark.asyncio +async def test_call_with_positional_args(): + provider = providers.Coroutine(example, 1, 2, 3, 4) + result = await provider() + assert result == (1, 2, 3, 4) + + +@mark.asyncio +async def test_call_with_keyword_args(): + provider = providers.Coroutine(example, arg1=1, arg2=2, arg3=3, arg4=4) + result = await provider() + assert result == (1, 2, 3, 4) + + +@mark.asyncio +async def test_call_with_positional_and_keyword_args(): + provider = providers.Coroutine(example, 1, 2, arg3=3, arg4=4) + result = await provider() + assert result == (1, 2, 3, 4) + + +@mark.asyncio +async def test_call_with_context_args(): + provider = providers.Coroutine(example, 1, 2) + result = await provider(3, 4) + assert result == (1, 2, 3, 4) + + +@mark.asyncio +async def test_call_with_context_kwargs(): + provider = providers.Coroutine(example, arg1=1) + result = await provider(arg2=2, arg3=3, arg4=4) + assert result == (1, 2, 3, 4) + + +@mark.asyncio +async def test_call_with_context_args_and_kwargs(): + provider = providers.Coroutine(example, 1) + result = await provider(2, arg3=3, arg4=4) + assert result == (1, 2, 3, 4) + + +@mark.asyncio +async def test_fluent_interface(): + provider = providers.Coroutine(example) \ + .add_args(1, 2) \ + .add_kwargs(arg3=3, arg4=4) + result = await provider() + assert result == (1, 2, 3, 4) + + +def test_set_args(): + provider = providers.Coroutine(example) \ + .add_args(1, 2) \ + .set_args(3, 4) + assert provider.args == (3, 4) + + +def test_set_kwargs(): + provider = providers.Coroutine(example) \ + .add_kwargs(init_arg3=3, init_arg4=4) \ + .set_kwargs(init_arg3=4, init_arg4=5) + assert provider.kwargs == dict(init_arg3=4, init_arg4=5) + + +def test_clear_args(): + provider = providers.Coroutine(example) \ + .add_args(1, 2) \ + .clear_args() + assert provider.args == tuple() + + +def test_clear_kwargs(): + provider = providers.Coroutine(example) \ + .add_kwargs(init_arg3=3, init_arg4=4) \ + .clear_kwargs() + assert provider.kwargs == dict() + + +def test_call_overridden(): + provider = providers.Coroutine(example) + + provider.override(providers.Object((4, 3, 2, 1))) + provider.override(providers.Object((1, 2, 3, 4))) + + assert provider() == (1, 2, 3, 4) + + +def test_deepcopy(): + provider = providers.Coroutine(example) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert provider.provides is provider_copy.provides + assert isinstance(provider, providers.Coroutine) + + +def test_deepcopy_from_memo(): + provider = providers.Coroutine(example) + provider_copy_memo = providers.Coroutine(example) + + provider_copy = providers.deepcopy(provider, memo={id(provider): provider_copy_memo}) + + assert provider_copy is provider_copy_memo + + +def test_deepcopy_args(): + provider = providers.Coroutine(example) + dependent_provider1 = providers.Callable(list) + dependent_provider2 = providers.Callable(dict) + + provider.add_args(dependent_provider1, dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.args[0] + dependent_provider_copy2 = provider_copy.args[1] + + assert dependent_provider1.provides is dependent_provider_copy1.provides + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.provides is dependent_provider_copy2.provides + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_kwargs(): + provider = providers.Coroutine(example) + dependent_provider1 = providers.Callable(list) + dependent_provider2 = providers.Callable(dict) + + provider.add_kwargs(a1=dependent_provider1, a2=dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.kwargs["a1"] + dependent_provider_copy2 = provider_copy.kwargs["a2"] + + assert dependent_provider1.provides is dependent_provider_copy1.provides + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.provides is dependent_provider_copy2.provides + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_overridden(): + provider = providers.Coroutine(example) + object_provider = providers.Object(object()) + + provider.override(object_provider) + + provider_copy = providers.deepcopy(provider) + object_provider_copy = provider_copy.overridden[0] + + assert provider is not provider_copy + assert provider.provides is provider_copy.provides + assert isinstance(provider, providers.Callable) + + assert object_provider is not object_provider_copy + assert isinstance(object_provider_copy, providers.Object) + + +def test_repr(): + provider = providers.Coroutine(example) + assert repr(provider) == ( + "".format(repr(example), hex(id(provider))) + ) diff --git a/tests/unit/providers/coroutines/test_delegated_coroutine_py35.py b/tests/unit/providers/coroutines/test_delegated_coroutine_py35.py new file mode 100644 index 00000000..96de0036 --- /dev/null +++ b/tests/unit/providers/coroutines/test_delegated_coroutine_py35.py @@ -0,0 +1,26 @@ +"""DelegatedCoroutine provider tests.""" + +from dependency_injector import providers + +from .common import example + + +def test_inheritance(): + assert isinstance(providers.DelegatedCoroutine(example), providers.Coroutine) + + +def test_is_provider(): + assert providers.is_provider(providers.DelegatedCoroutine(example)) is True + + +def test_is_delegated_provider(): + provider = providers.DelegatedCoroutine(example) + assert providers.is_delegated(provider) is True + + +def test_repr(): + provider = providers.DelegatedCoroutine(example) + assert repr(provider) == ( + "".format(repr(example), hex(id(provider))) + ) diff --git a/tests/unit/providers/factories/__init__.py b/tests/unit/providers/factories/__init__.py new file mode 100644 index 00000000..8642fd5c --- /dev/null +++ b/tests/unit/providers/factories/__init__.py @@ -0,0 +1 @@ +"""Tests for factories.""" diff --git a/tests/unit/providers/factories/common.py b/tests/unit/providers/factories/common.py new file mode 100644 index 00000000..21c88738 --- /dev/null +++ b/tests/unit/providers/factories/common.py @@ -0,0 +1,20 @@ +"""Common test artifacts.""" + + +class Example: + def __init__(self, init_arg1=None, init_arg2=None, init_arg3=None, init_arg4=None): + self.init_arg1 = init_arg1 + self.init_arg2 = init_arg2 + self.init_arg3 = init_arg3 + self.init_arg4 = init_arg4 + + self.attribute1 = None + self.attribute2 = None + + +class ExampleA(Example): + pass + + +class ExampleB(Example): + pass diff --git a/tests/unit/providers/factories/test_abstract_factory_py2_py3.py b/tests/unit/providers/factories/test_abstract_factory_py2_py3.py new file mode 100644 index 00000000..b8dc9d0e --- /dev/null +++ b/tests/unit/providers/factories/test_abstract_factory_py2_py3.py @@ -0,0 +1,48 @@ +"""AbstractFactory provider tests.""" + +from dependency_injector import providers, errors +from pytest import raises + +from .common import Example + + +def test_inheritance(): + assert isinstance(providers.AbstractFactory(Example), providers.Factory) + + +def test_call_overridden_by_factory(): + provider = providers.AbstractFactory(object) + provider.override(providers.Factory(Example)) + assert isinstance(provider(), Example) + + +def test_call_overridden_by_delegated_factory(): + provider = providers.AbstractFactory(object) + provider.override(providers.DelegatedFactory(Example)) + assert isinstance(provider(), Example) + + +def test_call_not_overridden(): + provider = providers.AbstractFactory(object) + with raises(errors.Error): + provider() + + +def test_override_by_not_factory(): + provider = providers.AbstractFactory(object) + with raises(errors.Error): + provider.override(providers.Callable(object)) + + +def test_provide_not_implemented(): + provider = providers.AbstractFactory(Example) + with raises(NotImplementedError): + provider._provide(tuple(), dict()) + + +def test_repr(): + provider = providers.AbstractFactory(Example) + assert repr(provider) == ( + "".format(repr(Example), hex(id(provider))) + ) diff --git a/tests/unit/providers/factories/test_delegated_factory_py2_py3.py b/tests/unit/providers/factories/test_delegated_factory_py2_py3.py new file mode 100644 index 00000000..7fde2a87 --- /dev/null +++ b/tests/unit/providers/factories/test_delegated_factory_py2_py3.py @@ -0,0 +1,25 @@ +"""DelegatedFactory provider tests.""" + +from dependency_injector import providers + +from .common import Example + + +def test_inheritance(): + assert isinstance(providers.DelegatedFactory(object), providers.Factory) + + +def test_is_provider(): + assert providers.is_provider(providers.DelegatedFactory(object)) is True + + +def test_is_delegated_provider(): + assert providers.is_delegated(providers.DelegatedFactory(object)) is True + + +def test_repr(): + provider = providers.DelegatedFactory(Example) + assert repr(provider) == ( + "".format(repr(Example), hex(id(provider))) + ) diff --git a/tests/unit/providers/factories/test_factory_aggregate_py2_py3.py b/tests/unit/providers/factories/test_factory_aggregate_py2_py3.py new file mode 100644 index 00000000..269e24fb --- /dev/null +++ b/tests/unit/providers/factories/test_factory_aggregate_py2_py3.py @@ -0,0 +1,229 @@ +"""FactoryAggregate provider tests.""" + +from dependency_injector import providers, errors +from pytest import fixture, mark, raises + +from .common import ExampleA, ExampleB + + +@fixture +def factory_a(): + return providers.Factory(ExampleA) + + +@fixture +def factory_b(): + return providers.Factory(ExampleB) + + +@fixture +def factory_type(): + return "default" + + +@fixture +def factory_aggregate(factory_type, factory_a, factory_b): + if factory_type == "empty": + return providers.FactoryAggregate() + elif factory_type == "non-string-keys": + return providers.FactoryAggregate({ + ExampleA: factory_a, + ExampleB: factory_b, + }) + elif factory_type == "default": + return providers.FactoryAggregate( + example_a=factory_a, + example_b=factory_b, + ) + else: + raise ValueError("Unknown factory type \"{0}\"".format(factory_type)) + + +def test_is_provider(factory_aggregate): + assert providers.is_provider(factory_aggregate) is True + + +def test_is_delegated_provider(factory_aggregate): + assert providers.is_delegated(factory_aggregate) is True + + +@mark.parametrize("factory_type", ["non-string-keys"]) +def test_init_with_non_string_keys(factory_aggregate, factory_a, factory_b): + object_a = factory_aggregate(ExampleA, 1, 2, init_arg3=3, init_arg4=4) + object_b = factory_aggregate(ExampleB, 11, 22, init_arg3=33, init_arg4=44) + + assert isinstance(object_a, ExampleA) + assert object_a.init_arg1 == 1 + assert object_a.init_arg2 == 2 + assert object_a.init_arg3 == 3 + assert object_a.init_arg4 == 4 + + assert isinstance(object_b, ExampleB) + assert object_b.init_arg1 == 11 + assert object_b.init_arg2 == 22 + assert object_b.init_arg3 == 33 + assert object_b.init_arg4 == 44 + + assert factory_aggregate.factories == { + ExampleA: factory_a, + ExampleB: factory_b, + } + + +def test_init_with_not_a_factory(): + with raises(errors.Error): + providers.FactoryAggregate( + example_a=providers.Factory(ExampleA), + example_b=object(), + ) + + +@mark.parametrize("factory_type", ["empty"]) +def test_init_optional_factories(factory_aggregate, factory_a, factory_b): + factory_aggregate.set_factories( + example_a=factory_a, + example_b=factory_b, + ) + assert factory_aggregate.factories == { + "example_a": factory_a, + "example_b": factory_b, + } + assert isinstance(factory_aggregate("example_a"), ExampleA) + assert isinstance(factory_aggregate("example_b"), ExampleB) + + +@mark.parametrize("factory_type", ["non-string-keys"]) +def test_set_factories_with_non_string_keys(factory_aggregate, factory_a, factory_b): + factory_aggregate.set_factories({ + ExampleA: factory_a, + ExampleB: factory_b, + }) + + object_a = factory_aggregate(ExampleA, 1, 2, init_arg3=3, init_arg4=4) + object_b = factory_aggregate(ExampleB, 11, 22, init_arg3=33, init_arg4=44) + + assert isinstance(object_a, ExampleA) + assert object_a.init_arg1 == 1 + assert object_a.init_arg2 == 2 + assert object_a.init_arg3 == 3 + assert object_a.init_arg4 == 4 + + assert isinstance(object_b, ExampleB) + assert object_b.init_arg1 == 11 + assert object_b.init_arg2 == 22 + assert object_b.init_arg3 == 33 + assert object_b.init_arg4 == 44 + + assert factory_aggregate.factories == { + ExampleA: factory_a, + ExampleB: factory_b, + } + + +def test_set_factories_returns_self(factory_aggregate, factory_a): + assert factory_aggregate.set_factories(example_a=factory_a) is factory_aggregate + + +def test_call(factory_aggregate): + object_a = factory_aggregate("example_a", 1, 2, init_arg3=3, init_arg4=4) + object_b = factory_aggregate("example_b", 11, 22, init_arg3=33, init_arg4=44) + + assert isinstance(object_a, ExampleA) + assert object_a.init_arg1 == 1 + assert object_a.init_arg2 == 2 + assert object_a.init_arg3 == 3 + assert object_a.init_arg4 == 4 + + assert isinstance(object_b, ExampleB) + assert object_b.init_arg1 == 11 + assert object_b.init_arg2 == 22 + assert object_b.init_arg3 == 33 + assert object_b.init_arg4 == 44 + + +def test_call_factory_name_as_kwarg(factory_aggregate): + object_a = factory_aggregate( + factory_name="example_a", + init_arg1=1, + init_arg2=2, + init_arg3=3, + init_arg4=4, + ) + assert isinstance(object_a, ExampleA) + assert object_a.init_arg1 == 1 + assert object_a.init_arg2 == 2 + assert object_a.init_arg3 == 3 + assert object_a.init_arg4 == 4 + + +def test_call_no_factory_name(factory_aggregate): + with raises(TypeError): + factory_aggregate() + + +def test_call_no_such_provider(factory_aggregate): + with raises(errors.NoSuchProviderError): + factory_aggregate("unknown") + + +def test_overridden(factory_aggregate): + with raises(errors.Error): + factory_aggregate.override(providers.Object(object())) + + +def test_getattr(factory_aggregate, factory_a, factory_b): + assert factory_aggregate.example_a is factory_a + assert factory_aggregate.example_b is factory_b + + +def test_getattr_no_such_provider(factory_aggregate): + with raises(errors.NoSuchProviderError): + factory_aggregate.unknown + + +def test_factories(factory_aggregate, factory_a, factory_b): + assert factory_aggregate.factories == dict( + example_a=factory_a, + example_b=factory_b, + ) + + +def test_deepcopy(factory_aggregate): + provider_copy = providers.deepcopy(factory_aggregate) + + assert factory_aggregate is not provider_copy + assert isinstance(provider_copy, type(factory_aggregate)) + + assert factory_aggregate.example_a is not provider_copy.example_a + assert isinstance(factory_aggregate.example_a, type(provider_copy.example_a)) + assert factory_aggregate.example_a.cls is provider_copy.example_a.cls + + assert factory_aggregate.example_b is not provider_copy.example_b + assert isinstance(factory_aggregate.example_b, type(provider_copy.example_b)) + assert factory_aggregate.example_b.cls is provider_copy.example_b.cls + + +@mark.parametrize("factory_type", ["non-string-keys"]) +def test_deepcopy_with_non_string_keys(factory_aggregate): + provider_copy = providers.deepcopy(factory_aggregate) + + assert factory_aggregate is not provider_copy + assert isinstance(provider_copy, type(factory_aggregate)) + + assert factory_aggregate.factories[ExampleA] is not provider_copy.factories[ExampleA] + assert isinstance(factory_aggregate.factories[ExampleA], type(provider_copy.factories[ExampleA])) + assert factory_aggregate.factories[ExampleA].cls is provider_copy.factories[ExampleA].cls + + assert factory_aggregate.factories[ExampleB] is not provider_copy.factories[ExampleB] + assert isinstance(factory_aggregate.factories[ExampleB], type(provider_copy.factories[ExampleB])) + assert factory_aggregate.factories[ExampleB].cls is provider_copy.factories[ExampleB].cls + + +def test_repr(factory_aggregate): + assert repr(factory_aggregate) == ( + "".format( + repr(factory_aggregate.factories), + hex(id(factory_aggregate)), + ) + ) diff --git a/tests/unit/providers/factories/test_factory_delegate_py2_py3.py b/tests/unit/providers/factories/test_factory_delegate_py2_py3.py new file mode 100644 index 00000000..33fc30c4 --- /dev/null +++ b/tests/unit/providers/factories/test_factory_delegate_py2_py3.py @@ -0,0 +1,23 @@ +"""Factory delegate provider tests.""" + +from dependency_injector import providers, errors +from pytest import fixture, raises + + +@fixture +def factory(): + return providers.Factory(object) + + +@fixture +def delegate(factory): + return providers.FactoryDelegate(factory) + + +def test_is_delegate(delegate): + assert isinstance(delegate, providers.Delegate) + + +def test_init_with_not_factory(): + with raises(errors.Error): + providers.FactoryDelegate(providers.Object(object())) diff --git a/tests/unit/providers/factories/test_factory_py2_py3.py b/tests/unit/providers/factories/test_factory_py2_py3.py new file mode 100644 index 00000000..ce08006e --- /dev/null +++ b/tests/unit/providers/factories/test_factory_py2_py3.py @@ -0,0 +1,408 @@ +"""Factory provider tests.""" + +import sys + +from dependency_injector import providers, errors +from pytest import raises + +from .common import Example + + +def test_is_provider(): + assert providers.is_provider(providers.Factory(Example)) is True + + +def test_init_with_not_callable(): + with raises(errors.Error): + providers.Factory(123) + + +def test_init_optional_provides(): + provider = providers.Factory() + provider.set_provides(object) + assert provider.provides is object + assert isinstance(provider(), object) + + +def test_set_provides_returns_(): + provider = providers.Factory() + assert provider.set_provides(object) is provider + + +def test_init_with_valid_provided_type(): + class ExampleProvider(providers.Factory): + provided_type = Example + + example_provider = ExampleProvider(Example, 1, 2) + + assert isinstance(example_provider(), Example) + + +def test_init_with_valid_provided_subtype(): + class ExampleProvider(providers.Factory): + provided_type = Example + + class NewExample(Example): + pass + + example_provider = ExampleProvider(NewExample, 1, 2) + assert isinstance(example_provider(), NewExample) + + +def test_init_with_invalid_provided_type(): + class ExampleProvider(providers.Factory): + provided_type = Example + + with raises(errors.Error): + ExampleProvider(list) + + +def test_provided_instance_provider(): + provider = providers.Factory(Example) + assert isinstance(provider.provided, providers.ProvidedInstance) + + +def test_call(): + provider = providers.Factory(Example) + + instance1 = provider() + instance2 = provider() + + assert instance1 is not instance2 + assert isinstance(instance1, Example) + assert isinstance(instance2, Example) + + +def test_call_with_init_positional_args(): + provider = providers.Factory(Example, "i1", "i2") + + instance1 = provider() + instance2 = provider() + + assert instance1.init_arg1 == "i1" + assert instance1.init_arg2 == "i2" + + assert instance2.init_arg1 == "i1" + assert instance2.init_arg2 == "i2" + + assert instance1 is not instance2 + assert isinstance(instance1, Example) + assert isinstance(instance2, Example) + + +def test_call_with_init_keyword_args(): + provider = providers.Factory(Example, init_arg1="i1", init_arg2="i2") + + instance1 = provider() + instance2 = provider() + + assert instance1.init_arg1 == "i1" + assert instance1.init_arg2 == "i2" + + assert instance2.init_arg1 == "i1" + assert instance2.init_arg2 == "i2" + + assert instance1 is not instance2 + assert isinstance(instance1, Example) + assert isinstance(instance2, Example) + + +def test_call_with_init_positional_and_keyword_args(): + provider = providers.Factory(Example, "i1", init_arg2="i2") + + instance1 = provider() + instance2 = provider() + + assert instance1.init_arg1 == "i1" + assert instance1.init_arg2 == "i2" + + assert instance2.init_arg1 == "i1" + assert instance2.init_arg2 == "i2" + + assert instance1 is not instance2 + assert isinstance(instance1, Example) + assert isinstance(instance2, Example) + + +def test_call_with_attributes(): + provider = providers.Factory(Example) + provider.add_attributes(attribute1="a1", attribute2="a2") + + instance1 = provider() + instance2 = provider() + + assert instance1.attribute1 == "a1" + assert instance1.attribute2 == "a2" + + assert instance2.attribute1 == "a1" + assert instance2.attribute2 == "a2" + + assert instance1 is not instance2 + assert isinstance(instance1, Example) + assert isinstance(instance2, Example) + + +def test_call_with_context_args(): + provider = providers.Factory(Example, 11, 22) + + instance = provider(33, 44) + + assert instance.init_arg1 == 11 + assert instance.init_arg2 == 22 + assert instance.init_arg3 == 33 + assert instance.init_arg4 == 44 + + +def test_call_with_context_kwargs(): + provider = providers.Factory(Example, init_arg1=1) + + instance1 = provider(init_arg2=22) + assert instance1.init_arg1 == 1 + assert instance1.init_arg2 == 22 + + instance2 = provider(init_arg1=11, init_arg2=22) + assert instance2.init_arg1 == 11 + assert instance2.init_arg2 == 22 + + +def test_call_with_context_args_and_kwargs(): + provider = providers.Factory(Example, 11) + + instance = provider(22, init_arg3=33, init_arg4=44) + + assert instance.init_arg1 == 11 + assert instance.init_arg2 == 22 + assert instance.init_arg3 == 33 + assert instance.init_arg4 == 44 + + +def test_call_with_deep_context_kwargs(): + class Regularizer: + def __init__(self, alpha): + self.alpha = alpha + + class Loss: + def __init__(self, regularizer): + self.regularizer = regularizer + + class ClassificationTask: + def __init__(self, loss): + self.loss = loss + + class Algorithm: + def __init__(self, task): + self.task = task + + algorithm_factory = providers.Factory( + Algorithm, + task=providers.Factory( + ClassificationTask, + loss=providers.Factory( + Loss, + regularizer=providers.Factory( + Regularizer, + ), + ), + ), + ) + + algorithm_1 = algorithm_factory(task__loss__regularizer__alpha=0.5) + algorithm_2 = algorithm_factory(task__loss__regularizer__alpha=0.7) + algorithm_3 = algorithm_factory(task__loss__regularizer=Regularizer(alpha=0.8)) + + assert algorithm_1.task.loss.regularizer.alpha == 0.5 + assert algorithm_2.task.loss.regularizer.alpha == 0.7 + assert algorithm_3.task.loss.regularizer.alpha == 0.8 + + +def test_fluent_interface(): + provider = providers.Factory(Example) \ + .add_args(1, 2) \ + .add_kwargs(init_arg3=3, init_arg4=4) \ + .add_attributes(attribute1=5, attribute2=6) + + instance = provider() + + assert instance.init_arg1 == 1 + assert instance.init_arg2 == 2 + assert instance.init_arg3 == 3 + assert instance.init_arg4 == 4 + assert instance.attribute1 == 5 + assert instance.attribute2 == 6 + + +def test_set_args(): + provider = providers.Factory(Example) \ + .add_args(1, 2) \ + .set_args(3, 4) + assert provider.args == (3, 4) + + +def test_set_kwargs(): + provider = providers.Factory(Example) \ + .add_kwargs(init_arg3=3, init_arg4=4) \ + .set_kwargs(init_arg3=4, init_arg4=5) + assert provider.kwargs == dict(init_arg3=4, init_arg4=5) + + +def test_set_attributes(): + provider = providers.Factory(Example) \ + .add_attributes(attribute1=5, attribute2=6) \ + .set_attributes(attribute1=6, attribute2=7) + assert provider.attributes == dict(attribute1=6, attribute2=7) + + +def test_clear_args(): + provider = providers.Factory(Example) \ + .add_args(1, 2) \ + .clear_args() + assert provider.args == tuple() + + +def test_clear_kwargs(): + provider = providers.Factory(Example) \ + .add_kwargs(init_arg3=3, init_arg4=4) \ + .clear_kwargs() + assert provider.kwargs == dict() + + +def test_clear_attributes(): + provider = providers.Factory(Example) \ + .add_attributes(attribute1=5, attribute2=6) \ + .clear_attributes() + assert provider.attributes == dict() + + +def test_call_overridden(): + provider = providers.Factory(Example) + overriding_provider1 = providers.Factory(dict) + overriding_provider2 = providers.Factory(list) + + provider.override(overriding_provider1) + provider.override(overriding_provider2) + + instance1 = provider() + instance2 = provider() + + assert instance1 is not instance2 + assert isinstance(instance1, list) + assert isinstance(instance2, list) + + +def test_deepcopy(): + provider = providers.Factory(Example) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert provider.cls is provider_copy.cls + assert isinstance(provider, providers.Factory) + + +def test_deepcopy_from_memo(): + provider = providers.Factory(Example) + provider_copy_memo = providers.Factory(Example) + + provider_copy = providers.deepcopy(provider, memo={id(provider): provider_copy_memo}) + + assert provider_copy is provider_copy_memo + + +def test_deepcopy_args(): + provider = providers.Factory(Example) + dependent_provider1 = providers.Factory(list) + dependent_provider2 = providers.Factory(dict) + + provider.add_args(dependent_provider1, dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.args[0] + dependent_provider_copy2 = provider_copy.args[1] + + assert provider.args != provider_copy.args + + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_kwargs(): + provider = providers.Factory(Example) + dependent_provider1 = providers.Factory(list) + dependent_provider2 = providers.Factory(dict) + + provider.add_kwargs(a1=dependent_provider1, a2=dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.kwargs["a1"] + dependent_provider_copy2 = provider_copy.kwargs["a2"] + + assert provider.kwargs != provider_copy.kwargs + + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_attributes(): + provider = providers.Factory(Example) + dependent_provider1 = providers.Factory(list) + dependent_provider2 = providers.Factory(dict) + + provider.add_attributes(a1=dependent_provider1, a2=dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.attributes["a1"] + dependent_provider_copy2 = provider_copy.attributes["a2"] + + assert provider.attributes != provider_copy.attributes + + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_overridden(): + provider = providers.Factory(Example) + object_provider = providers.Object(object()) + + provider.override(object_provider) + + provider_copy = providers.deepcopy(provider) + object_provider_copy = provider_copy.overridden[0] + + assert provider is not provider_copy + assert provider.cls is provider_copy.cls + assert isinstance(provider, providers.Factory) + + assert object_provider is not object_provider_copy + assert isinstance(object_provider_copy, providers.Object) + + +def test_deepcopy_with_sys_streams(): + provider = providers.Factory(Example) + provider.add_args(sys.stdin) + provider.add_kwargs(a2=sys.stdout) + provider.add_attributes(a3=sys.stderr) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert isinstance(provider_copy, providers.Factory) + assert provider.args[0] is sys.stdin + assert provider.kwargs["a2"] is sys.stdout + assert provider.attributes["a3"] is sys.stderr + + +def test_repr(): + provider = providers.Factory(Example) + assert repr(provider) == ( + "".format(repr(Example), hex(id(provider))) + ) diff --git a/tests/unit/providers/injections/__init__.py b/tests/unit/providers/injections/__init__.py new file mode 100644 index 00000000..9de62e7f --- /dev/null +++ b/tests/unit/providers/injections/__init__.py @@ -0,0 +1 @@ +"""Tests for injection objects.""" diff --git a/tests/unit/providers/injections/test_named_py2_py3.py b/tests/unit/providers/injections/test_named_py2_py3.py new file mode 100644 index 00000000..7d88db93 --- /dev/null +++ b/tests/unit/providers/injections/test_named_py2_py3.py @@ -0,0 +1,55 @@ +"""Named injection tests.""" + +from dependency_injector import providers + + +def test_isinstance(): + injection = providers.NamedInjection("name", 1) + assert isinstance(injection, providers.Injection) + + +def test_get_name(): + injection = providers.NamedInjection("name", 123) + assert injection.get_name() == "name" + + +def test_get_value_with_not_provider(): + injection = providers.NamedInjection("name", 123) + assert injection.get_value() == 123 + + +def test_get_value_with_factory(): + injection = providers.NamedInjection("name", providers.Factory(object)) + + obj1 = injection.get_value() + obj2 = injection.get_value() + + assert type(obj1) is object + assert type(obj2) is object + assert obj1 is not obj2 + + +def test_get_original_value(): + provider = providers.Factory(object) + injection = providers.NamedInjection("name", provider) + assert injection.get_original_value() is provider + + +def test_deepcopy(): + provider = providers.Factory(object) + injection = providers.NamedInjection("name", provider) + + injection_copy = providers.deepcopy(injection) + + assert injection_copy is not injection + assert injection_copy.get_original_value() is not injection.get_original_value() + +def test_deepcopy_memo(): + provider = providers.Factory(object) + injection = providers.NamedInjection("name", provider) + injection_copy_orig = providers.NamedInjection("name", provider) + + injection_copy = providers.deepcopy(injection, {id(injection): injection_copy_orig}) + + assert injection_copy is injection_copy_orig + assert injection_copy.get_original_value() is injection.get_original_value() diff --git a/tests/unit/providers/injections/test_positional_py2_py3.py b/tests/unit/providers/injections/test_positional_py2_py3.py new file mode 100644 index 00000000..9cf41464 --- /dev/null +++ b/tests/unit/providers/injections/test_positional_py2_py3.py @@ -0,0 +1,51 @@ +"""Positional injection tests.""" + +from dependency_injector import providers + + +def test_isinstance(): + injection = providers.PositionalInjection(1) + assert isinstance(injection, providers.Injection) + + +def test_get_value_with_not_provider(): + injection = providers.PositionalInjection(123) + assert injection.get_value() == 123 + + +def test_get_value_with_factory(): + injection = providers.PositionalInjection(providers.Factory(object)) + + obj1 = injection.get_value() + obj2 = injection.get_value() + + assert type(obj1) is object + assert type(obj2) is object + assert obj1 is not obj2 + + +def test_get_original_value(): + provider = providers.Factory(object) + injection = providers.PositionalInjection(provider) + assert injection.get_original_value() is provider + + +def test_deepcopy(): + provider = providers.Factory(object) + injection = providers.PositionalInjection(provider) + + injection_copy = providers.deepcopy(injection) + + assert injection_copy is not injection + assert injection_copy.get_original_value() is not injection.get_original_value() + + +def test_deepcopy_memo(): + provider = providers.Factory(object) + injection = providers.PositionalInjection(provider) + injection_copy_orig = providers.PositionalInjection(provider) + + injection_copy = providers.deepcopy(injection, {id(injection): injection_copy_orig}) + + assert injection_copy is injection_copy_orig + assert injection_copy.get_original_value() is injection.get_original_value() diff --git a/tests/unit/providers/resource/__init__.py b/tests/unit/providers/resource/__init__.py new file mode 100644 index 00000000..888679fd --- /dev/null +++ b/tests/unit/providers/resource/__init__.py @@ -0,0 +1 @@ +"""Resource provider tests.""" \ No newline at end of file diff --git a/tests/unit/providers/resource/test_async_resource_py35.py b/tests/unit/providers/resource/test_async_resource_py35.py new file mode 100644 index 00000000..ba983d60 --- /dev/null +++ b/tests/unit/providers/resource/test_async_resource_py35.py @@ -0,0 +1,308 @@ +"""Resource provider async tests.""" + +import asyncio +import inspect +import sys +from typing import Any + +from dependency_injector import containers, providers, resources +from pytest import mark, raises + + +@mark.asyncio +async def test_init_async_function(): + resource = object() + + async def _init(): + await asyncio.sleep(0.001) + _init.counter += 1 + return resource + + _init.counter = 0 + + provider = providers.Resource(_init) + + result1 = await provider() + assert result1 is resource + assert _init.counter == 1 + + result2 = await provider() + assert result2 is resource + assert _init.counter == 1 + + await provider.shutdown() + + +@mark.asyncio +@mark.skipif(sys.version_info < (3, 6), reason="requires Python 3.6+") +async def test_init_async_generator(): + resource = object() + + async def _init(): + await asyncio.sleep(0.001) + _init.init_counter += 1 + + yield resource + + await asyncio.sleep(0.001) + _init.shutdown_counter += 1 + + _init.init_counter = 0 + _init.shutdown_counter = 0 + + provider = providers.Resource(_init) + + result1 = await provider() + assert result1 is resource + assert _init.init_counter == 1 + assert _init.shutdown_counter == 0 + + await provider.shutdown() + assert _init.init_counter == 1 + assert _init.shutdown_counter == 1 + + result2 = await provider() + assert result2 is resource + assert _init.init_counter == 2 + assert _init.shutdown_counter == 1 + + await provider.shutdown() + assert _init.init_counter == 2 + assert _init.shutdown_counter == 2 + + +@mark.asyncio +async def test_init_async_class(): + resource = object() + + class TestResource(resources.AsyncResource): + init_counter = 0 + shutdown_counter = 0 + + async def init(self): + await asyncio.sleep(0.001) + self.__class__.init_counter += 1 + return resource + + async def shutdown(self, resource_): + await asyncio.sleep(0.001) + self.__class__.shutdown_counter += 1 + assert resource_ is resource + + provider = providers.Resource(TestResource) + + result1 = await provider() + assert result1 is resource + assert TestResource.init_counter == 1 + assert TestResource.shutdown_counter == 0 + + await provider.shutdown() + assert TestResource.init_counter == 1 + assert TestResource.shutdown_counter == 1 + + result2 = await provider() + assert result2 is resource + assert TestResource.init_counter == 2 + assert TestResource.shutdown_counter == 1 + + await provider.shutdown() + assert TestResource.init_counter == 2 + assert TestResource.shutdown_counter == 2 + + +def test_init_async_class_generic_typing(): + # See issue: https://github.com/ets-labs/python-dependency-injector/issues/488 + class TestDependency: + ... + + class TestAsyncResource(resources.AsyncResource[TestDependency]): + async def init(self, *args: Any, **kwargs: Any) -> TestDependency: + return TestDependency() + + async def shutdown(self, resource: TestDependency) -> None: ... + + assert issubclass(TestAsyncResource, resources.AsyncResource) is True + + +def test_init_async_class_abc_init_definition_is_required(): + class TestAsyncResource(resources.AsyncResource): + ... + + with raises(TypeError) as context: + TestAsyncResource() + + assert "Can't instantiate abstract class TestAsyncResource" in str(context.value) + assert "init" in str(context.value) + + +def test_init_async_class_abc_shutdown_definition_is_not_required(): + class TestAsyncResource(resources.AsyncResource): + async def init(self): + ... + + assert hasattr(TestAsyncResource(), "shutdown") is True + assert inspect.iscoroutinefunction(TestAsyncResource.shutdown) is True + + +@mark.asyncio +async def test_init_with_error(): + async def _init(): + raise RuntimeError() + + provider = providers.Resource(_init) + + future = provider() + assert provider.initialized is True + assert provider.is_async_mode_enabled() is True + + with raises(RuntimeError): + await future + + assert provider.initialized is False + assert provider.is_async_mode_enabled() is True + + +@mark.asyncio +async def test_init_async_gen_with_error(): + async def _init(): + raise RuntimeError() + yield + + provider = providers.Resource(_init) + + future = provider() + assert provider.initialized is True + assert provider.is_async_mode_enabled() is True + + with raises(RuntimeError): + await future + + assert provider.initialized is False + assert provider.is_async_mode_enabled() is True + + +@mark.asyncio +async def test_init_async_subclass_with_error(): + class _Resource(resources.AsyncResource): + async def init(self): + raise RuntimeError() + + async def shutdown(self, resource): + pass + + provider = providers.Resource(_Resource) + + future = provider() + assert provider.initialized is True + assert provider.is_async_mode_enabled() is True + + with raises(RuntimeError): + await future + + assert provider.initialized is False + assert provider.is_async_mode_enabled() is True + + +@mark.asyncio +async def test_init_with_dependency_to_other_resource(): + # See: https://github.com/ets-labs/python-dependency-injector/issues/361 + async def init_db_connection(db_url: str): + await asyncio.sleep(0.001) + yield {"connection": "OK", "url": db_url} + + async def init_user_session(db): + await asyncio.sleep(0.001) + yield {"session": "OK", "db": db} + + class Container(containers.DeclarativeContainer): + config = providers.Configuration() + + db_connection = providers.Resource( + init_db_connection, + db_url=config.db_url, + ) + + user_session = providers.Resource( + init_user_session, + db=db_connection + ) + + async def main(): + container = Container(config={"db_url": "postgres://..."}) + try: + return await container.user_session() + finally: + await container.shutdown_resources() + + result = await main() + assert result == {"session": "OK", "db": {"connection": "OK", "url": "postgres://..."}} + + +@mark.asyncio +async def test_init_and_shutdown_methods(): + async def _init(): + await asyncio.sleep(0.001) + _init.init_counter += 1 + + yield + + await asyncio.sleep(0.001) + _init.shutdown_counter += 1 + + _init.init_counter = 0 + _init.shutdown_counter = 0 + + provider = providers.Resource(_init) + + await provider.init() + assert _init.init_counter == 1 + assert _init.shutdown_counter == 0 + + await provider.shutdown() + assert _init.init_counter == 1 + assert _init.shutdown_counter == 1 + + await provider.init() + assert _init.init_counter == 2 + assert _init.shutdown_counter == 1 + + await provider.shutdown() + assert _init.init_counter == 2 + assert _init.shutdown_counter == 2 + + +@mark.asyncio +async def test_shutdown_of_not_initialized(): + async def _init(): + yield + + provider = providers.Resource(_init) + provider.enable_async_mode() + + result = await provider.shutdown() + assert result is None + + +@mark.asyncio +async def test_concurrent_init(): + resource = object() + + async def _init(): + await asyncio.sleep(0.001) + _init.counter += 1 + return resource + + _init.counter = 0 + + provider = providers.Resource(_init) + + result1, result2 = await asyncio.gather( + provider(), + provider() + ) + + assert result1 is resource + assert _init.counter == 1 + + assert result2 is resource + assert _init.counter == 1 diff --git a/tests/unit/providers/resource/test_resource_py35.py b/tests/unit/providers/resource/test_resource_py35.py new file mode 100644 index 00000000..921ec8fa --- /dev/null +++ b/tests/unit/providers/resource/test_resource_py35.py @@ -0,0 +1,398 @@ +"""Resource provider tests.""" + +import sys +from typing import Any + +from dependency_injector import containers, providers, resources, errors +from pytest import raises + + +def init_fn(*args, **kwargs): + return args, kwargs + + +def test_is_provider(): + assert providers.is_provider(providers.Resource(init_fn)) is True + + +def test_init_optional_provides(): + provider = providers.Resource() + provider.set_provides(init_fn) + assert provider.provides is init_fn + assert provider() == (tuple(), dict()) + + +def test_set_provides_returns_(): + provider = providers.Resource() + assert provider.set_provides(init_fn) is provider + + +def test_provided_instance_provider(): + provider = providers.Resource(init_fn) + assert isinstance(provider.provided, providers.ProvidedInstance) + + +def test_injection(): + resource = object() + + def _init(): + _init.counter += 1 + return resource + + _init.counter = 0 + + class Container(containers.DeclarativeContainer): + resource = providers.Resource(_init) + dependency1 = providers.List(resource) + dependency2 = providers.List(resource) + + container = Container() + list1 = container.dependency1() + list2 = container.dependency2() + + assert list1 == [resource] + assert list1[0] is resource + + assert list2 == [resource] + assert list2[0] is resource + + assert _init.counter == 1 + + +def test_init_function(): + def _init(): + _init.counter += 1 + + _init.counter = 0 + + provider = providers.Resource(_init) + + result1 = provider() + assert result1 is None + assert _init.counter == 1 + + result2 = provider() + assert result2 is None + assert _init.counter == 1 + + provider.shutdown() + + +def test_init_generator(): + def _init(): + _init.init_counter += 1 + yield + _init.shutdown_counter += 1 + + _init.init_counter = 0 + _init.shutdown_counter = 0 + + provider = providers.Resource(_init) + + result1 = provider() + assert result1 is None + assert _init.init_counter == 1 + assert _init.shutdown_counter == 0 + + provider.shutdown() + assert _init.init_counter == 1 + assert _init.shutdown_counter == 1 + + result2 = provider() + assert result2 is None + assert _init.init_counter == 2 + assert _init.shutdown_counter == 1 + + provider.shutdown() + assert _init.init_counter == 2 + assert _init.shutdown_counter == 2 + + +def test_init_class(): + class TestResource(resources.Resource): + init_counter = 0 + shutdown_counter = 0 + + def init(self): + self.__class__.init_counter += 1 + + def shutdown(self, _): + self.__class__.shutdown_counter += 1 + + provider = providers.Resource(TestResource) + + result1 = provider() + assert result1 is None + assert TestResource.init_counter == 1 + assert TestResource.shutdown_counter == 0 + + provider.shutdown() + assert TestResource.init_counter == 1 + assert TestResource.shutdown_counter == 1 + + result2 = provider() + assert result2 is None + assert TestResource.init_counter == 2 + assert TestResource.shutdown_counter == 1 + + provider.shutdown() + assert TestResource.init_counter == 2 + assert TestResource.shutdown_counter == 2 + + +def test_init_class_generic_typing(): + # See issue: https://github.com/ets-labs/python-dependency-injector/issues/488 + class TestDependency: + ... + + class TestResource(resources.Resource[TestDependency]): + def init(self, *args: Any, **kwargs: Any) -> TestDependency: + return TestDependency() + + def shutdown(self, resource: TestDependency) -> None: ... + + assert issubclass(TestResource, resources.Resource) is True + + +def test_init_class_abc_init_definition_is_required(): + class TestResource(resources.Resource): + ... + + with raises(TypeError) as context: + TestResource() + + assert "Can't instantiate abstract class TestResource" in str(context.value) + assert "init" in str(context.value) + + +def test_init_class_abc_shutdown_definition_is_not_required(): + class TestResource(resources.Resource): + def init(self): + ... + + assert hasattr(TestResource(), "shutdown") is True + + +def test_init_not_callable(): + provider = providers.Resource(1) + with raises(errors.Error): + provider.init() + + +def test_init_and_shutdown(): + def _init(): + _init.init_counter += 1 + yield + _init.shutdown_counter += 1 + + _init.init_counter = 0 + _init.shutdown_counter = 0 + + provider = providers.Resource(_init) + + result1 = provider.init() + assert result1 is None + assert _init.init_counter == 1 + assert _init.shutdown_counter == 0 + + provider.shutdown() + assert _init.init_counter == 1 + assert _init.shutdown_counter == 1 + + result2 = provider.init() + assert result2 is None + assert _init.init_counter == 2 + assert _init.shutdown_counter == 1 + + provider.shutdown() + assert _init.init_counter == 2 + assert _init.shutdown_counter == 2 + + +def test_shutdown_of_not_initialized(): + def _init(): + yield + + provider = providers.Resource(_init) + + result = provider.shutdown() + assert result is None + + +def test_initialized(): + provider = providers.Resource(init_fn) + assert provider.initialized is False + + provider.init() + assert provider.initialized is True + + provider.shutdown() + assert provider.initialized is False + + +def test_call_with_context_args(): + provider = providers.Resource(init_fn, "i1", "i2") + assert provider("i3", i4=4) == (("i1", "i2", "i3"), {"i4": 4}) + + +def test_fluent_interface(): + provider = providers.Resource(init_fn) \ + .add_args(1, 2) \ + .add_kwargs(a3=3, a4=4) + assert provider() == ((1, 2), {"a3": 3, "a4": 4}) + + +def test_set_args(): + provider = providers.Resource(init_fn) \ + .add_args(1, 2) \ + .set_args(3, 4) + assert provider.args == (3, 4) + + +def test_clear_args(): + provider = providers.Resource(init_fn) \ + .add_args(1, 2) \ + .clear_args() + assert provider.args == tuple() + + +def test_set_kwargs(): + provider = providers.Resource(init_fn) \ + .add_kwargs(a1="i1", a2="i2") \ + .set_kwargs(a3="i3", a4="i4") + assert provider.kwargs == {"a3": "i3", "a4": "i4"} + + +def test_clear_kwargs(): + provider = providers.Resource(init_fn) \ + .add_kwargs(a1="i1", a2="i2") \ + .clear_kwargs() + assert provider.kwargs == {} + + +def test_call_overridden(): + provider = providers.Resource(init_fn, 1) + overriding_provider1 = providers.Resource(init_fn, 2) + overriding_provider2 = providers.Resource(init_fn, 3) + + provider.override(overriding_provider1) + provider.override(overriding_provider2) + + instance1 = provider() + instance2 = provider() + + assert instance1 is instance2 + assert instance1 == ((3,), {}) + assert instance2 == ((3,), {}) + + +def test_deepcopy(): + provider = providers.Resource(init_fn, 1, 2, a3=3, a4=4) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert provider.args == provider_copy.args + assert provider.kwargs == provider_copy.kwargs + assert isinstance(provider, providers.Resource) + + +def test_deepcopy_initialized(): + provider = providers.Resource(init_fn) + provider.init() + + with raises(errors.Error): + providers.deepcopy(provider) + + +def test_deepcopy_from_memo(): + provider = providers.Resource(init_fn) + provider_copy_memo = providers.Resource(init_fn) + + provider_copy = providers.deepcopy( + provider, + memo={id(provider): provider_copy_memo}, + ) + + assert provider_copy is provider_copy_memo + + +def test_deepcopy_args(): + provider = providers.Resource(init_fn) + dependent_provider1 = providers.Factory(list) + dependent_provider2 = providers.Factory(dict) + + provider.add_args(dependent_provider1, dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.args[0] + dependent_provider_copy2 = provider_copy.args[1] + + assert provider.args != provider_copy.args + + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_kwargs(): + provider = providers.Resource(init_fn) + dependent_provider1 = providers.Factory(list) + dependent_provider2 = providers.Factory(dict) + + provider.add_kwargs(d1=dependent_provider1, d2=dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.kwargs["d1"] + dependent_provider_copy2 = provider_copy.kwargs["d2"] + + assert provider.kwargs != provider_copy.kwargs + + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_overridden(): + provider = providers.Resource(init_fn) + object_provider = providers.Object(object()) + + provider.override(object_provider) + + provider_copy = providers.deepcopy(provider) + object_provider_copy = provider_copy.overridden[0] + + assert provider is not provider_copy + assert provider.args == provider_copy.args + assert isinstance(provider, providers.Resource) + + assert object_provider is not object_provider_copy + assert isinstance(object_provider_copy, providers.Object) + + +def test_deepcopy_with_sys_streams(): + provider = providers.Resource(init_fn) + provider.add_args(sys.stdin, sys.stdout, sys.stderr) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert isinstance(provider_copy, providers.Resource) + assert provider.args[0] is sys.stdin + assert provider.args[1] is sys.stdout + assert provider.args[2] is sys.stderr + + +def test_repr(): + provider = providers.Resource(init_fn) + + assert repr(provider) == ( + "".format( + repr(init_fn), + hex(id(provider)), + ) + ) diff --git a/tests/unit/providers/singleton/__init__.py b/tests/unit/providers/singleton/__init__.py new file mode 100644 index 00000000..c31f39dd --- /dev/null +++ b/tests/unit/providers/singleton/__init__.py @@ -0,0 +1 @@ +"""Singleton provider tests.""" diff --git a/tests/unit/providers/singleton/common.py b/tests/unit/providers/singleton/common.py new file mode 100644 index 00000000..991624ed --- /dev/null +++ b/tests/unit/providers/singleton/common.py @@ -0,0 +1,13 @@ +"""Common test artifacts.""" + + +class Example: + + def __init__(self, init_arg1=None, init_arg2=None, init_arg3=None, init_arg4=None): + self.init_arg1 = init_arg1 + self.init_arg2 = init_arg2 + self.init_arg3 = init_arg3 + self.init_arg4 = init_arg4 + + self.attribute1 = None + self.attribute2 = None diff --git a/tests/unit/providers/singleton/test_abstract_singleton_py2_py3.py b/tests/unit/providers/singleton/test_abstract_singleton_py2_py3.py new file mode 100644 index 00000000..3b372261 --- /dev/null +++ b/tests/unit/providers/singleton/test_abstract_singleton_py2_py3.py @@ -0,0 +1,63 @@ +"""AbstractSingleton provider tests.""" + +from dependency_injector import providers, errors +from pytest import raises + +from .common import Example + + +def test_inheritance(): + assert isinstance(providers.AbstractSingleton(Example), providers.BaseSingleton) + + +def test_call_overridden_by_singleton(): + provider = providers.AbstractSingleton(object) + provider.override(providers.Singleton(Example)) + assert isinstance(provider(), Example) + + +def test_call_overridden_by_delegated_singleton(): + provider = providers.AbstractSingleton(object) + provider.override(providers.DelegatedSingleton(Example)) + assert isinstance(provider(), Example) + + +def test_call_not_overridden(): + provider = providers.AbstractSingleton(object) + with raises(errors.Error): + provider() + + +def test_reset_overridden(): + provider = providers.AbstractSingleton(object) + provider.override(providers.Singleton(Example)) + + instance1 = provider() + + provider.reset() + + instance2 = provider() + + assert instance1 is not instance2 + assert isinstance(instance1, Example) + assert isinstance(instance2, Example) + + +def test_reset_not_overridden(): + provider = providers.AbstractSingleton(object) + with raises(errors.Error): + provider.reset() + + +def test_override_by_not_singleton(): + provider = providers.AbstractSingleton(object) + with raises(errors.Error): + provider.override(providers.Factory(object)) + + +def test_repr(): + provider = providers.AbstractSingleton(Example) + assert repr(provider) == ( + "".format(repr(Example), hex(id(provider))) + ) diff --git a/tests/unit/providers/singleton/test_delegated_singleton_py2_py3.py b/tests/unit/providers/singleton/test_delegated_singleton_py2_py3.py new file mode 100644 index 00000000..8a76ec0b --- /dev/null +++ b/tests/unit/providers/singleton/test_delegated_singleton_py2_py3.py @@ -0,0 +1,27 @@ +"""Delegated singleton provider tests.""" + +from dependency_injector import providers +from pytest import fixture + +from .common import Example + + +PROVIDER_CLASSES = [ + providers.DelegatedSingleton, + providers.DelegatedThreadLocalSingleton, + providers.DelegatedThreadSafeSingleton, +] + + +@fixture(params=PROVIDER_CLASSES) +def singleton_cls(request): + return request.param + + +@fixture +def provider(singleton_cls): + return singleton_cls(Example) + + +def test_is_delegated_provider(provider): + assert providers.is_delegated(provider) is True diff --git a/tests/unit/providers/singleton/test_singleton_delegate_py2_py3.py b/tests/unit/providers/singleton/test_singleton_delegate_py2_py3.py new file mode 100644 index 00000000..be6eb2b2 --- /dev/null +++ b/tests/unit/providers/singleton/test_singleton_delegate_py2_py3.py @@ -0,0 +1,23 @@ +"""SingletonDelegate provider tests.""" + +from dependency_injector import providers, errors +from pytest import fixture, raises + + +@fixture +def provider(): + return providers.Singleton(object) + + +@fixture +def delegate(provider): + return providers.SingletonDelegate(provider) + + +def test_is_delegate(delegate): + assert isinstance(delegate, providers.Delegate) + + +def test_init_with_not_factory(): + with raises(errors.Error): + providers.SingletonDelegate(providers.Object(object())) diff --git a/tests/unit/providers/singleton/test_singleton_py2_py3.py b/tests/unit/providers/singleton/test_singleton_py2_py3.py new file mode 100644 index 00000000..49f777fb --- /dev/null +++ b/tests/unit/providers/singleton/test_singleton_py2_py3.py @@ -0,0 +1,470 @@ +"""Singleton provider tests.""" + +import sys + +from dependency_injector import providers, errors +from pytest import fixture, raises + +from .common import Example + + +PROVIDER_CLASSES = [ + providers.Singleton, + providers.DelegatedSingleton, + providers.ThreadLocalSingleton, + providers.DelegatedThreadLocalSingleton, + providers.ThreadSafeSingleton, + providers.DelegatedThreadSafeSingleton, +] +if sys.version_info >= (3, 5): + PROVIDER_CLASSES.append(providers.ContextLocalSingleton) + + +@fixture(params=PROVIDER_CLASSES) +def singleton_cls(request): + return request.param + + +@fixture +def provider(singleton_cls): + return singleton_cls(Example) + + +def test_is_provider(provider): + assert providers.is_provider(provider) is True + + +def test_init_with_not_callable(singleton_cls): + with raises(errors.Error): + singleton_cls(123) + + +def test_init_optional_provides(provider): + provider.set_provides(object) + assert provider.provides is object + assert isinstance(provider(), object) + + +def test_set_provides_returns_self(provider): + assert provider.set_provides(object) is provider + + +def test_init_with_valid_provided_type(singleton_cls): + class ExampleProvider(singleton_cls): + provided_type = Example + + example_provider = ExampleProvider(Example, 1, 2) + assert isinstance(example_provider(), Example) + + +def test_init_with_valid_provided_subtype(singleton_cls): + class ExampleProvider(singleton_cls): + provided_type = Example + + class NewExample(Example): + pass + + example_provider = ExampleProvider(NewExample, 1, 2) + assert isinstance(example_provider(), NewExample) + + +def test_init_with_invalid_provided_type(singleton_cls): + class ExampleProvider(singleton_cls): + provided_type = Example + + with raises(errors.Error): + ExampleProvider(list) + + +def test_provided_instance_provider(provider): + assert isinstance(provider.provided, providers.ProvidedInstance) + + +def test_call(provider): + instance1 = provider() + instance2 = provider() + + assert instance1 is instance2 + assert isinstance(instance1, Example) + assert isinstance(instance2, Example) + + +def test_call_with_init_positional_args(singleton_cls): + provider = singleton_cls(Example, "i1", "i2") + + instance1 = provider() + instance2 = provider() + + assert instance1.init_arg1 == "i1" + assert instance1.init_arg2 == "i2" + + assert instance2.init_arg1 == "i1" + assert instance2.init_arg2 == "i2" + + assert instance1 is instance2 + assert isinstance(instance1, Example) + assert isinstance(instance2, Example) + + +def test_call_with_init_keyword_args(singleton_cls): + provider = singleton_cls(Example, init_arg1="i1", init_arg2="i2") + + instance1 = provider() + instance2 = provider() + + assert instance1.init_arg1 == "i1" + assert instance1.init_arg2 == "i2" + + assert instance2.init_arg1 == "i1" + assert instance2.init_arg2 == "i2" + + assert instance1 is instance2 + assert isinstance(instance1, Example) + assert isinstance(instance2, Example) + + +def test_call_with_init_positional_and_keyword_args(singleton_cls): + provider = singleton_cls(Example, "i1", init_arg2="i2") + + instance1 = provider() + instance2 = provider() + + assert instance1.init_arg1 == "i1" + assert instance1.init_arg2 == "i2" + + assert instance2.init_arg1 == "i1" + assert instance2.init_arg2 == "i2" + + assert instance1 is instance2 + assert isinstance(instance1, Example) + assert isinstance(instance2, Example) + + +def test_call_with_attributes(provider): + provider.add_attributes(attribute1="a1", attribute2="a2") + + instance1 = provider() + instance2 = provider() + + assert instance1.attribute1 == "a1" + assert instance1.attribute2 == "a2" + + assert instance2.attribute1 == "a1" + assert instance2.attribute2 == "a2" + + assert instance1 is instance2 + assert isinstance(instance1, Example) + assert isinstance(instance2, Example) + + +def test_call_with_context_args(provider): + instance = provider(11, 22) + + assert instance.init_arg1 == 11 + assert instance.init_arg2 == 22 + + +def test_call_with_context_kwargs(singleton_cls): + provider = singleton_cls(Example, init_arg1=1) + + instance1 = provider(init_arg2=22) + assert instance1.init_arg1 == 1 + assert instance1.init_arg2 == 22 + + # Instance is created earlier + instance1 = provider(init_arg1=11, init_arg2=22) + assert instance1.init_arg1 == 1 + assert instance1.init_arg2 == 22 + + +def test_call_with_context_args_and_kwargs(singleton_cls): + provider = singleton_cls(Example, 11) + + instance = provider(22, init_arg3=33, init_arg4=44) + + assert instance.init_arg1 == 11 + assert instance.init_arg2 == 22 + assert instance.init_arg3 == 33 + assert instance.init_arg4 == 44 + + +def test_fluent_interface(singleton_cls): + provider = singleton_cls(Example) \ + .add_args(1, 2) \ + .add_kwargs(init_arg3=3, init_arg4=4) \ + .add_attributes(attribute1=5, attribute2=6) + + instance = provider() + + assert instance.init_arg1 == 1 + assert instance.init_arg2 == 2 + assert instance.init_arg3 == 3 + assert instance.init_arg4 == 4 + assert instance.attribute1 == 5 + assert instance.attribute2 == 6 + + +def test_set_args(singleton_cls): + provider = singleton_cls(Example) \ + .add_args(1, 2) \ + .set_args(3, 4) + assert provider.args == (3, 4) + + +def test_set_kwargs(singleton_cls): + provider = singleton_cls(Example) \ + .add_kwargs(init_arg3=3, init_arg4=4) \ + .set_kwargs(init_arg3=4, init_arg4=5) + assert provider.kwargs == dict(init_arg3=4, init_arg4=5) + + +def test_set_attributes(singleton_cls): + provider = singleton_cls(Example) \ + .add_attributes(attribute1=5, attribute2=6) \ + .set_attributes(attribute1=6, attribute2=7) + assert provider.attributes == dict(attribute1=6, attribute2=7) + + +def test_clear_args(singleton_cls): + provider = singleton_cls(Example) \ + .add_args(1, 2) \ + .clear_args() + assert provider.args == tuple() + + +def test_clear_kwargs(singleton_cls): + provider = singleton_cls(Example) \ + .add_kwargs(init_arg3=3, init_arg4=4) \ + .clear_kwargs() + assert provider.kwargs == dict() + + +def test_clear_attributes(singleton_cls): + provider = singleton_cls(Example) \ + .add_attributes(attribute1=5, attribute2=6) \ + .clear_attributes() + assert provider.attributes == dict() + + +def test_call_overridden(singleton_cls): + provider = singleton_cls(Example) + overriding_provider1 = singleton_cls(dict) + overriding_provider2 = singleton_cls(list) + + provider.override(overriding_provider1) + provider.override(overriding_provider2) + + instance1 = provider() + instance2 = provider() + + assert instance1 is instance2 + assert isinstance(instance1, list) + assert isinstance(instance2, list) + + +def test_deepcopy(singleton_cls): + provider = singleton_cls(Example) + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert provider.cls is provider_copy.cls + assert isinstance(provider, singleton_cls) + + +def test_deepcopy_from_memo(singleton_cls): + provider = singleton_cls(Example) + provider_copy_memo = singleton_cls(Example) + + provider_copy = providers.deepcopy(provider, memo={id(provider): provider_copy_memo}) + assert provider_copy is provider_copy_memo + + +def test_deepcopy_args(singleton_cls): + provider = singleton_cls(Example) + dependent_provider1 = singleton_cls(list) + dependent_provider2 = singleton_cls(dict) + + provider.add_args(dependent_provider1, dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.args[0] + dependent_provider_copy2 = provider_copy.args[1] + + assert provider.args != provider_copy.args + + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_kwargs(singleton_cls): + provider = singleton_cls(Example) + dependent_provider1 = singleton_cls(list) + dependent_provider2 = singleton_cls(dict) + + provider.add_kwargs(a1=dependent_provider1, a2=dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.kwargs["a1"] + dependent_provider_copy2 = provider_copy.kwargs["a2"] + + assert provider.kwargs != provider_copy.kwargs + + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_attributes(singleton_cls): + provider = singleton_cls(Example) + dependent_provider1 = singleton_cls(list) + dependent_provider2 = singleton_cls(dict) + + provider.add_attributes(a1=dependent_provider1, a2=dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.attributes["a1"] + dependent_provider_copy2 = provider_copy.attributes["a2"] + + assert provider.attributes != provider_copy.attributes + + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_overridden(singleton_cls): + provider = singleton_cls(Example) + object_provider = providers.Object(object()) + + provider.override(object_provider) + + provider_copy = providers.deepcopy(provider) + object_provider_copy = provider_copy.overridden[0] + + assert provider is not provider_copy + assert provider.cls is provider_copy.cls + assert isinstance(provider, singleton_cls) + + assert object_provider is not object_provider_copy + assert isinstance(object_provider_copy, providers.Object) + + +def test_deepcopy_with_sys_streams(singleton_cls): + provider = singleton_cls(Example) + provider.add_args(sys.stdin) + provider.add_kwargs(a2=sys.stdout) + provider.add_attributes(a3=sys.stderr) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert isinstance(provider_copy, singleton_cls) + assert provider.args[0] is sys.stdin + assert provider.kwargs["a2"] is sys.stdout + assert provider.attributes["a3"] is sys.stderr + + +def test_reset(singleton_cls): + provider = singleton_cls(object) + + instance1 = provider() + assert isinstance(instance1, object) + + provider.reset() + + instance2 = provider() + assert isinstance(instance2, object) + assert instance1 is not instance2 + + +def test_reset_clean(provider): + instance1 = provider() + + provider.reset() + provider.reset() + + instance2 = provider() + assert instance1 is not instance2 + + +def test_reset_with_singleton(singleton_cls): + dependent_singleton = providers.Singleton(object) + provider = singleton_cls(dict, dependency=dependent_singleton) + + dependent_instance = dependent_singleton() + instance1 = provider() + assert instance1["dependency"] is dependent_instance + + provider.reset() + + instance2 = provider() + assert instance2["dependency"] is dependent_instance + assert instance1 is not instance2 + + +def test_reset_context_manager(provider): + instance1 = provider() + with provider.reset(): + instance2 = provider() + instance3 = provider() + assert len({instance1, instance2, instance3}) == 3 + + +def test_reset_context_manager_as_attribute(provider): + with provider.reset() as alias: + pass + assert provider is alias + + +def test_full_reset(singleton_cls): + dependent_singleton = providers.Singleton(object) + provider = singleton_cls(dict, dependency=dependent_singleton) + + dependent_instance1 = dependent_singleton() + instance1 = provider() + assert instance1["dependency"] is dependent_instance1 + + provider.full_reset() + + dependent_instance2 = dependent_singleton() + instance2 = provider() + assert instance2["dependency"] is not dependent_instance1 + assert dependent_instance1 is not dependent_instance2 + assert instance1 is not instance2 + + +def test_full_reset_context_manager(singleton_cls): + class Item: + def __init__(self, dependency): + self.dependency = dependency + + dependent_singleton = providers.Singleton(object) + singleton = singleton_cls(Item, dependency=dependent_singleton) + + instance1 = singleton() + with singleton.full_reset(): + instance2 = singleton() + instance3 = singleton() + + assert len({instance1, instance2, instance3}) == 3 + assert len({instance1.dependency, instance2.dependency, instance3.dependency}) == 3 + + +def test_full_reset_context_manager_as_attribute(provider): + with provider.full_reset() as alias: + pass + assert provider is alias + + +def test_repr(provider): + assert repr(provider) == ( + "".format(provider.__class__.__name__, repr(Example), hex(id(provider))) + ) diff --git a/tests/unit/providers/singleton_common.py b/tests/unit/providers/singleton_common.py deleted file mode 100644 index 8f0c4a3d..00000000 --- a/tests/unit/providers/singleton_common.py +++ /dev/null @@ -1,434 +0,0 @@ -import sys - -from dependency_injector import providers, errors - - -class Example(object): - - def __init__(self, init_arg1=None, init_arg2=None, init_arg3=None, - init_arg4=None): - self.init_arg1 = init_arg1 - self.init_arg2 = init_arg2 - self.init_arg3 = init_arg3 - self.init_arg4 = init_arg4 - - self.attribute1 = None - self.attribute2 = None - - -class _BaseSingletonTestCase(object): - - singleton_cls = None - - def test_is_provider(self): - self.assertTrue(providers.is_provider(self.singleton_cls(Example))) - - def test_init_with_callable(self): - self.assertTrue(self.singleton_cls(credits)) - - def test_init_with_not_callable(self): - self.assertRaises(errors.Error, self.singleton_cls, 123) - - def test_init_optional_provides(self): - provider = self.singleton_cls() - provider.set_provides(object) - self.assertIs(provider.provides, object) - self.assertIsInstance(provider(), object) - - def test_set_provides_returns_self(self): - provider = self.singleton_cls() - self.assertIs(provider.set_provides(object), provider) - - def test_init_with_valid_provided_type(self): - class ExampleProvider(self.singleton_cls): - provided_type = Example - - example_provider = ExampleProvider(Example, 1, 2) - - self.assertIsInstance(example_provider(), Example) - - def test_init_with_valid_provided_subtype(self): - class ExampleProvider(self.singleton_cls): - provided_type = Example - - class NewExampe(Example): - pass - - example_provider = ExampleProvider(NewExampe, 1, 2) - - self.assertIsInstance(example_provider(), NewExampe) - - def test_init_with_invalid_provided_type(self): - class ExampleProvider(self.singleton_cls): - provided_type = Example - - with self.assertRaises(errors.Error): - ExampleProvider(list) - - def test_provided_instance_provider(self): - provider = providers.Singleton(Example) - self.assertIsInstance(provider.provided, providers.ProvidedInstance) - - def test_call(self): - provider = self.singleton_cls(Example) - - instance1 = provider() - instance2 = provider() - - self.assertIs(instance1, instance2) - self.assertIsInstance(instance1, Example) - self.assertIsInstance(instance2, Example) - - def test_call_with_init_positional_args(self): - provider = self.singleton_cls(Example, "i1", "i2") - - instance1 = provider() - instance2 = provider() - - self.assertEqual(instance1.init_arg1, "i1") - self.assertEqual(instance1.init_arg2, "i2") - - self.assertEqual(instance2.init_arg1, "i1") - self.assertEqual(instance2.init_arg2, "i2") - - self.assertIs(instance1, instance2) - self.assertIsInstance(instance1, Example) - self.assertIsInstance(instance2, Example) - - def test_call_with_init_keyword_args(self): - provider = self.singleton_cls(Example, init_arg1="i1", init_arg2="i2") - - instance1 = provider() - instance2 = provider() - - self.assertEqual(instance1.init_arg1, "i1") - self.assertEqual(instance1.init_arg2, "i2") - - self.assertEqual(instance2.init_arg1, "i1") - self.assertEqual(instance2.init_arg2, "i2") - - self.assertIs(instance1, instance2) - self.assertIsInstance(instance1, Example) - self.assertIsInstance(instance2, Example) - - def test_call_with_init_positional_and_keyword_args(self): - provider = self.singleton_cls(Example, "i1", init_arg2="i2") - - instance1 = provider() - instance2 = provider() - - self.assertEqual(instance1.init_arg1, "i1") - self.assertEqual(instance1.init_arg2, "i2") - - self.assertEqual(instance2.init_arg1, "i1") - self.assertEqual(instance2.init_arg2, "i2") - - self.assertIs(instance1, instance2) - self.assertIsInstance(instance1, Example) - self.assertIsInstance(instance2, Example) - - def test_call_with_attributes(self): - provider = self.singleton_cls(Example) - provider.add_attributes(attribute1="a1", attribute2="a2") - - instance1 = provider() - instance2 = provider() - - self.assertEqual(instance1.attribute1, "a1") - self.assertEqual(instance1.attribute2, "a2") - - self.assertEqual(instance2.attribute1, "a1") - self.assertEqual(instance2.attribute2, "a2") - - self.assertIs(instance1, instance2) - self.assertIsInstance(instance1, Example) - self.assertIsInstance(instance2, Example) - - def test_call_with_context_args(self): - provider = self.singleton_cls(Example) - - instance = provider(11, 22) - - self.assertEqual(instance.init_arg1, 11) - self.assertEqual(instance.init_arg2, 22) - - def test_call_with_context_kwargs(self): - provider = self.singleton_cls(Example, init_arg1=1) - - instance1 = provider(init_arg2=22) - self.assertEqual(instance1.init_arg1, 1) - self.assertEqual(instance1.init_arg2, 22) - - # Instance is created earlier - instance1 = provider(init_arg1=11, init_arg2=22) - self.assertEqual(instance1.init_arg1, 1) - self.assertEqual(instance1.init_arg2, 22) - - def test_call_with_context_args_and_kwargs(self): - provider = self.singleton_cls(Example, 11) - - instance = provider(22, init_arg3=33, init_arg4=44) - - self.assertEqual(instance.init_arg1, 11) - self.assertEqual(instance.init_arg2, 22) - self.assertEqual(instance.init_arg3, 33) - self.assertEqual(instance.init_arg4, 44) - - def test_fluent_interface(self): - provider = self.singleton_cls(Example) \ - .add_args(1, 2) \ - .add_kwargs(init_arg3=3, init_arg4=4) \ - .add_attributes(attribute1=5, attribute2=6) - - instance = provider() - - self.assertEqual(instance.init_arg1, 1) - self.assertEqual(instance.init_arg2, 2) - self.assertEqual(instance.init_arg3, 3) - self.assertEqual(instance.init_arg4, 4) - self.assertEqual(instance.attribute1, 5) - self.assertEqual(instance.attribute2, 6) - - def test_set_args(self): - provider = self.singleton_cls(Example) \ - .add_args(1, 2) \ - .set_args(3, 4) - self.assertEqual(provider.args, (3, 4)) - - def test_set_kwargs(self): - provider = self.singleton_cls(Example) \ - .add_kwargs(init_arg3=3, init_arg4=4) \ - .set_kwargs(init_arg3=4, init_arg4=5) - self.assertEqual(provider.kwargs, dict(init_arg3=4, init_arg4=5)) - - def test_set_attributes(self): - provider = self.singleton_cls(Example) \ - .add_attributes(attribute1=5, attribute2=6) \ - .set_attributes(attribute1=6, attribute2=7) - self.assertEqual(provider.attributes, dict(attribute1=6, attribute2=7)) - - def test_clear_args(self): - provider = self.singleton_cls(Example) \ - .add_args(1, 2) \ - .clear_args() - self.assertEqual(provider.args, tuple()) - - def test_clear_kwargs(self): - provider = self.singleton_cls(Example) \ - .add_kwargs(init_arg3=3, init_arg4=4) \ - .clear_kwargs() - self.assertEqual(provider.kwargs, dict()) - - def test_clear_attributes(self): - provider = self.singleton_cls(Example) \ - .add_attributes(attribute1=5, attribute2=6) \ - .clear_attributes() - self.assertEqual(provider.attributes, dict()) - - def test_call_overridden(self): - provider = self.singleton_cls(Example) - overriding_provider1 = self.singleton_cls(dict) - overriding_provider2 = self.singleton_cls(list) - - provider.override(overriding_provider1) - provider.override(overriding_provider2) - - instance1 = provider() - instance2 = provider() - - self.assertIs(instance1, instance2) - self.assertIsInstance(instance1, list) - self.assertIsInstance(instance2, list) - - def test_deepcopy(self): - provider = self.singleton_cls(Example) - - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertIs(provider.cls, provider_copy.cls) - self.assertIsInstance(provider, self.singleton_cls) - - def test_deepcopy_from_memo(self): - provider = self.singleton_cls(Example) - provider_copy_memo = self.singleton_cls(Example) - - provider_copy = providers.deepcopy( - provider, memo={id(provider): provider_copy_memo}) - - self.assertIs(provider_copy, provider_copy_memo) - - def test_deepcopy_args(self): - provider = self.singleton_cls(Example) - dependent_provider1 = self.singleton_cls(list) - dependent_provider2 = self.singleton_cls(dict) - - provider.add_args(dependent_provider1, dependent_provider2) - - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.args[0] - dependent_provider_copy2 = provider_copy.args[1] - - self.assertNotEqual(provider.args, provider_copy.args) - - self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) - - self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) - - def test_deepcopy_kwargs(self): - provider = self.singleton_cls(Example) - dependent_provider1 = self.singleton_cls(list) - dependent_provider2 = self.singleton_cls(dict) - - provider.add_kwargs(a1=dependent_provider1, a2=dependent_provider2) - - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.kwargs["a1"] - dependent_provider_copy2 = provider_copy.kwargs["a2"] - - self.assertNotEqual(provider.kwargs, provider_copy.kwargs) - - self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) - - self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) - - def test_deepcopy_attributes(self): - provider = self.singleton_cls(Example) - dependent_provider1 = self.singleton_cls(list) - dependent_provider2 = self.singleton_cls(dict) - - provider.add_attributes(a1=dependent_provider1, a2=dependent_provider2) - - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.attributes["a1"] - dependent_provider_copy2 = provider_copy.attributes["a2"] - - self.assertNotEqual(provider.attributes, provider_copy.attributes) - - self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) - - self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) - - def test_deepcopy_overridden(self): - provider = self.singleton_cls(Example) - object_provider = providers.Object(object()) - - provider.override(object_provider) - - provider_copy = providers.deepcopy(provider) - object_provider_copy = provider_copy.overridden[0] - - self.assertIsNot(provider, provider_copy) - self.assertIs(provider.cls, provider_copy.cls) - self.assertIsInstance(provider, self.singleton_cls) - - self.assertIsNot(object_provider, object_provider_copy) - self.assertIsInstance(object_provider_copy, providers.Object) - - def test_deepcopy_with_sys_streams(self): - provider = providers.Singleton(Example) - provider.add_args(sys.stdin) - provider.add_kwargs(a2=sys.stdout) - provider.add_attributes(a3=sys.stderr) - - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider_copy, providers.Singleton) - self.assertIs(provider.args[0], sys.stdin) - self.assertIs(provider.kwargs["a2"], sys.stdout) - self.assertIs(provider.attributes["a3"], sys.stderr) - - def test_reset(self): - provider = self.singleton_cls(object) - - instance1 = provider() - self.assertIsInstance(instance1, object) - - provider.reset() - - instance2 = provider() - self.assertIsInstance(instance2, object) - - self.assertIsNot(instance1, instance2) - - def test_reset_with_singleton(self): - dependent_singleton = providers.Singleton(object) - provider = self.singleton_cls(dict, dependency=dependent_singleton) - - dependent_instance = dependent_singleton() - instance1 = provider() - self.assertIs(instance1["dependency"], dependent_instance) - - provider.reset() - - instance2 = provider() - self.assertIs(instance1["dependency"], dependent_instance) - - self.assertIsNot(instance1, instance2) - - def test_reset_context_manager(self): - singleton = self.singleton_cls(object) - - instance1 = singleton() - with singleton.reset(): - instance2 = singleton() - instance3 = singleton() - self.assertEqual(len({instance1, instance2, instance3}), 3) - - def test_reset_context_manager_as_attribute(self): - singleton = self.singleton_cls(object) - - with singleton.reset() as alias: - pass - - self.assertIs(singleton, alias) - - def test_full_reset(self): - dependent_singleton = providers.Singleton(object) - provider = self.singleton_cls(dict, dependency=dependent_singleton) - - dependent_instance1 = dependent_singleton() - instance1 = provider() - self.assertIs(instance1["dependency"], dependent_instance1) - - provider.full_reset() - - dependent_instance2 = dependent_singleton() - instance2 = provider() - self.assertIsNot(instance2["dependency"], dependent_instance1) - self.assertIsNot(dependent_instance1, dependent_instance2) - self.assertIsNot(instance1, instance2) - - def test_full_reset_context_manager(self): - class Item: - def __init__(self, dependency): - self.dependency = dependency - - dependent_singleton = providers.Singleton(object) - singleton = self.singleton_cls(Item, dependency=dependent_singleton) - - instance1 = singleton() - with singleton.full_reset(): - instance2 = singleton() - instance3 = singleton() - - self.assertEqual(len({instance1, instance2, instance3}), 3) - self.assertEqual( - len({instance1.dependency, instance2.dependency, instance3.dependency}), - 3, - ) - - def test_full_reset_context_manager_as_attribute(self): - singleton = self.singleton_cls(object) - - with singleton.full_reset() as alias: - pass - - self.assertIs(singleton, alias) diff --git a/tests/unit/providers/test_async_py36.py b/tests/unit/providers/test_async_py36.py deleted file mode 100644 index e0603a7b..00000000 --- a/tests/unit/providers/test_async_py36.py +++ /dev/null @@ -1,1175 +0,0 @@ -import asyncio -import random -import unittest - -from dependency_injector import containers, providers, errors - -# Runtime import to get asyncutils module -import os -_TOP_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../", - )), -) -import sys -sys.path.append(_TOP_DIR) - -from asyncutils import AsyncTestCase - - -RESOURCE1 = object() -RESOURCE2 = object() - - -async def init_resource(resource): - await asyncio.sleep(random.randint(1, 10) / 1000) - yield resource - await asyncio.sleep(random.randint(1, 10) / 1000) - - -class Client: - def __init__(self, resource1: object, resource2: object) -> None: - self.resource1 = resource1 - self.resource2 = resource2 - - -class Service: - def __init__(self, client: Client) -> None: - self.client = client - - -class Container(containers.DeclarativeContainer): - resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1)) - resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2)) - - client = providers.Factory( - Client, - resource1=resource1, - resource2=resource2, - ) - - service = providers.Factory( - Service, - client=client, - ) - - -class FactoryTests(AsyncTestCase): - - def test_args_injection(self): - class ContainerWithArgs(containers.DeclarativeContainer): - resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1)) - resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2)) - - client = providers.Factory( - Client, - resource1, - resource2, - ) - - service = providers.Factory( - Service, - client, - ) - - container = ContainerWithArgs() - - client1 = self._run(container.client()) - client2 = self._run(container.client()) - - self.assertIsInstance(client1, Client) - self.assertIs(client1.resource1, RESOURCE1) - self.assertIs(client1.resource2, RESOURCE2) - - self.assertIsInstance(client2, Client) - self.assertIs(client2.resource1, RESOURCE1) - self.assertIs(client2.resource2, RESOURCE2) - - service1 = self._run(container.service()) - service2 = self._run(container.service()) - - self.assertIsInstance(service1, Service) - self.assertIsInstance(service1.client, Client) - self.assertIs(service1.client.resource1, RESOURCE1) - self.assertIs(service1.client.resource2, RESOURCE2) - - self.assertIsInstance(service2, Service) - self.assertIsInstance(service2.client, Client) - self.assertIs(service2.client.resource1, RESOURCE1) - self.assertIs(service2.client.resource2, RESOURCE2) - - self.assertIsNot(service1.client, service2.client) - - def test_kwargs_injection(self): - container = Container() - - client1 = self._run(container.client()) - client2 = self._run(container.client()) - - self.assertIsInstance(client1, Client) - self.assertIs(client1.resource1, RESOURCE1) - self.assertIs(client1.resource2, RESOURCE2) - - self.assertIsInstance(client2, Client) - self.assertIs(client2.resource1, RESOURCE1) - self.assertIs(client2.resource2, RESOURCE2) - - service1 = self._run(container.service()) - service2 = self._run(container.service()) - - self.assertIsInstance(service1, Service) - self.assertIsInstance(service1.client, Client) - self.assertIs(service1.client.resource1, RESOURCE1) - self.assertIs(service1.client.resource2, RESOURCE2) - - self.assertIsInstance(service2, Service) - self.assertIsInstance(service2.client, Client) - self.assertIs(service2.client.resource1, RESOURCE1) - self.assertIs(service2.client.resource2, RESOURCE2) - - self.assertIsNot(service1.client, service2.client) - - def test_context_kwargs_injection(self): - resource2_extra = object() - - container = Container() - - client1 = self._run(container.client(resource2=resource2_extra)) - client2 = self._run(container.client(resource2=resource2_extra)) - - self.assertIsInstance(client1, Client) - self.assertIs(client1.resource1, RESOURCE1) - self.assertIs(client1.resource2, resource2_extra) - - self.assertIsInstance(client2, Client) - self.assertIs(client2.resource1, RESOURCE1) - self.assertIs(client2.resource2, resource2_extra) - - def test_args_kwargs_injection(self): - class ContainerWithArgsAndKwArgs(containers.DeclarativeContainer): - resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1)) - resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2)) - - client = providers.Factory( - Client, - resource1, - resource2=resource2, - ) - - service = providers.Factory( - Service, - client=client, - ) - - container = ContainerWithArgsAndKwArgs() - - client1 = self._run(container.client()) - client2 = self._run(container.client()) - - self.assertIsInstance(client1, Client) - self.assertIs(client1.resource1, RESOURCE1) - self.assertIs(client1.resource2, RESOURCE2) - - self.assertIsInstance(client2, Client) - self.assertIs(client2.resource1, RESOURCE1) - self.assertIs(client2.resource2, RESOURCE2) - - service1 = self._run(container.service()) - service2 = self._run(container.service()) - - self.assertIsInstance(service1, Service) - self.assertIsInstance(service1.client, Client) - self.assertIs(service1.client.resource1, RESOURCE1) - self.assertIs(service1.client.resource2, RESOURCE2) - - self.assertIsInstance(service2, Service) - self.assertIsInstance(service2.client, Client) - self.assertIs(service2.client.resource1, RESOURCE1) - self.assertIs(service2.client.resource2, RESOURCE2) - - self.assertIsNot(service1.client, service2.client) - - def test_injection_error(self): - async def init_resource(): - raise Exception("Something went wrong") - - class Container(containers.DeclarativeContainer): - resource_with_error = providers.Resource(init_resource) - - client = providers.Factory( - Client, - resource1=resource_with_error, - resource2=None, - ) - - container = Container() - - with self.assertRaises(Exception) as context: - self._run(container.client()) - self.assertEqual(str(context.exception), "Something went wrong") - - def test_injection_runtime_error_async_provides(self): - async def create_client(*args, **kwargs): - raise Exception("Something went wrong") - - class Container(containers.DeclarativeContainer): - resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) - - client = providers.Factory( - create_client, - resource1=resource, - resource2=None, - ) - - container = Container() - - with self.assertRaises(Exception) as context: - self._run(container.client()) - self.assertEqual(str(context.exception), "Something went wrong") - - def test_injection_call_error_async_provides(self): - async def create_client(): # <-- no args defined - ... - - class Container(containers.DeclarativeContainer): - resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) - - client = providers.Factory( - create_client, - resource1=resource, - resource2=None, - ) - - container = Container() - - with self.assertRaises(TypeError) as context: - self._run(container.client()) - self.assertIn("create_client() got", str(context.exception)) - self.assertIn("unexpected keyword argument", str(context.exception)) - - def test_attributes_injection(self): - class ContainerWithAttributes(containers.DeclarativeContainer): - resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1)) - resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2)) - - client = providers.Factory( - Client, - resource1, - resource2=None, - ) - client.add_attributes(resource2=resource2) - - service = providers.Factory( - Service, - client=None, - ) - service.add_attributes(client=client) - - container = ContainerWithAttributes() - - client1 = self._run(container.client()) - client2 = self._run(container.client()) - - self.assertIsInstance(client1, Client) - self.assertIs(client1.resource1, RESOURCE1) - self.assertIs(client1.resource2, RESOURCE2) - - self.assertIsInstance(client2, Client) - self.assertIs(client2.resource1, RESOURCE1) - self.assertIs(client2.resource2, RESOURCE2) - - service1 = self._run(container.service()) - service2 = self._run(container.service()) - - self.assertIsInstance(service1, Service) - self.assertIsInstance(service1.client, Client) - self.assertIs(service1.client.resource1, RESOURCE1) - self.assertIs(service1.client.resource2, RESOURCE2) - - self.assertIsInstance(service2, Service) - self.assertIsInstance(service2.client, Client) - self.assertIs(service2.client.resource1, RESOURCE1) - self.assertIs(service2.client.resource2, RESOURCE2) - - self.assertIsNot(service1.client, service2.client) - - def test_attributes_injection_attribute_error(self): - class ClientWithException(Client): - @property - def attribute_set_error(self): - return None - - @attribute_set_error.setter - def attribute_set_error(self, value): - raise Exception("Something went wrong") - - class Container(containers.DeclarativeContainer): - resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) - - client = providers.Factory( - ClientWithException, - resource1=resource, - resource2=resource, - ) - client.add_attributes(attribute_set_error=123) - - container = Container() - - with self.assertRaises(Exception) as context: - self._run(container.client()) - self.assertEqual(str(context.exception), "Something went wrong") - - def test_attributes_injection_runtime_error(self): - async def init_resource(): - raise Exception("Something went wrong") - - class Container(containers.DeclarativeContainer): - resource = providers.Resource(init_resource) - - client = providers.Factory( - Client, - resource1=None, - resource2=None, - ) - client.add_attributes(resource1=resource) - client.add_attributes(resource2=resource) - - container = Container() - - with self.assertRaises(Exception) as context: - self._run(container.client()) - self.assertEqual(str(context.exception), "Something went wrong") - - def test_async_instance_and_sync_attributes_injection(self): - class ContainerWithAttributes(containers.DeclarativeContainer): - resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1)) - - client = providers.Factory( - Client, - resource1, - resource2=None, - ) - client.add_attributes(resource2=providers.Object(RESOURCE2)) - - service = providers.Factory( - Service, - client=None, - ) - service.add_attributes(client=client) - - container = ContainerWithAttributes() - - client1 = self._run(container.client()) - client2 = self._run(container.client()) - - self.assertIsInstance(client1, Client) - self.assertIs(client1.resource1, RESOURCE1) - self.assertIs(client1.resource2, RESOURCE2) - - self.assertIsInstance(client2, Client) - self.assertIs(client2.resource1, RESOURCE1) - self.assertIs(client2.resource2, RESOURCE2) - - service1 = self._run(container.service()) - service2 = self._run(container.service()) - - self.assertIsInstance(service1, Service) - self.assertIsInstance(service1.client, Client) - self.assertIs(service1.client.resource1, RESOURCE1) - self.assertIs(service1.client.resource2, RESOURCE2) - - self.assertIsInstance(service2, Service) - self.assertIsInstance(service2.client, Client) - self.assertIs(service2.client.resource1, RESOURCE1) - self.assertIs(service2.client.resource2, RESOURCE2) - - self.assertIsNot(service1.client, service2.client) - - -class FactoryAggregateTests(AsyncTestCase): - - def test_async_mode(self): - object1 = object() - object2 = object() - - async def _get_object1(): - return object1 - - def _get_object2(): - return object2 - - provider = providers.FactoryAggregate( - object1=providers.Factory(_get_object1), - object2=providers.Factory(_get_object2), - ) - - self.assertTrue(provider.is_async_mode_undefined()) - - created_object1 = self._run(provider("object1")) - self.assertIs(created_object1, object1) - self.assertTrue(provider.is_async_mode_enabled()) - - created_object2 = self._run(provider("object2")) - self.assertIs(created_object2, object2) - - -class SingletonTests(AsyncTestCase): - - def test_injections(self): - class ContainerWithSingletons(containers.DeclarativeContainer): - resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1)) - resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2)) - - client = providers.Singleton( - Client, - resource1=resource1, - resource2=resource2, - ) - - service = providers.Singleton( - Service, - client=client, - ) - - container = ContainerWithSingletons() - - client1 = self._run(container.client()) - client2 = self._run(container.client()) - - self.assertIsInstance(client1, Client) - self.assertIs(client1.resource1, RESOURCE1) - self.assertIs(client1.resource2, RESOURCE2) - - self.assertIsInstance(client2, Client) - self.assertIs(client2.resource1, RESOURCE1) - self.assertIs(client2.resource2, RESOURCE2) - - service1 = self._run(container.service()) - service2 = self._run(container.service()) - - self.assertIsInstance(service1, Service) - self.assertIsInstance(service1.client, Client) - self.assertIs(service1.client.resource1, RESOURCE1) - self.assertIs(service1.client.resource2, RESOURCE2) - - self.assertIsInstance(service2, Service) - self.assertIsInstance(service2.client, Client) - self.assertIs(service2.client.resource1, RESOURCE1) - self.assertIs(service2.client.resource2, RESOURCE2) - - self.assertIs(service1, service2) - self.assertIs(service1.client, service2.client) - self.assertIs(service1.client, client1) - - self.assertIs(service2.client, client2) - self.assertIs(client1, client2) - - def test_async_mode(self): - instance = object() - - async def create_instance(): - return instance - - provider = providers.Singleton(create_instance) - - instance1 = self._run(provider()) - instance2 = self._run(provider()) - - self.assertIs(instance1, instance2) - self.assertIs(instance, instance) - - def test_async_init_with_error(self): - # Disable default exception handling to prevent output - asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...) - - async def create_instance(): - create_instance.counter += 1 - raise RuntimeError() - - create_instance.counter = 0 - - provider = providers.Singleton(create_instance) - - - future = provider() - self.assertTrue(provider.is_async_mode_enabled()) - - with self.assertRaises(RuntimeError): - self._run(future) - - self.assertEqual(create_instance.counter, 1) - self.assertTrue(provider.is_async_mode_enabled()) - - with self.assertRaises(RuntimeError): - self._run(provider()) - - self.assertEqual(create_instance.counter, 2) - self.assertTrue(provider.is_async_mode_enabled()) - - # Restore default exception handling - asyncio.get_event_loop().set_exception_handler(None) - - -class DelegatedSingletonTests(AsyncTestCase): - - def test_async_mode(self): - instance = object() - - async def create_instance(): - return instance - - provider = providers.DelegatedSingleton(create_instance) - - instance1 = self._run(provider()) - instance2 = self._run(provider()) - - self.assertIs(instance1, instance2) - self.assertIs(instance, instance) - - -class ThreadSafeSingletonTests(AsyncTestCase): - - def test_async_mode(self): - instance = object() - - async def create_instance(): - return instance - - provider = providers.ThreadSafeSingleton(create_instance) - - instance1 = self._run(provider()) - instance2 = self._run(provider()) - - self.assertIs(instance1, instance2) - self.assertIs(instance, instance) - - -class DelegatedThreadSafeSingletonTests(AsyncTestCase): - - def test_async_mode(self): - instance = object() - - async def create_instance(): - return instance - - provider = providers.DelegatedThreadSafeSingleton(create_instance) - - instance1 = self._run(provider()) - instance2 = self._run(provider()) - - self.assertIs(instance1, instance2) - self.assertIs(instance, instance) - - -class ThreadLocalSingletonTests(AsyncTestCase): - - def test_async_mode(self): - instance = object() - - async def create_instance(): - return instance - - provider = providers.ThreadLocalSingleton(create_instance) - - instance1 = self._run(provider()) - instance2 = self._run(provider()) - - self.assertIs(instance1, instance2) - self.assertIs(instance, instance) - - - def test_async_init_with_error(self): - # Disable default exception handling to prevent output - asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...) - - async def create_instance(): - create_instance.counter += 1 - raise RuntimeError() - create_instance.counter = 0 - - provider = providers.ThreadLocalSingleton(create_instance) - - future = provider() - self.assertTrue(provider.is_async_mode_enabled()) - - with self.assertRaises(RuntimeError): - self._run(future) - - self.assertEqual(create_instance.counter, 1) - self.assertTrue(provider.is_async_mode_enabled()) - - with self.assertRaises(RuntimeError): - self._run(provider()) - - self.assertEqual(create_instance.counter, 2) - self.assertTrue(provider.is_async_mode_enabled()) - - # Restore default exception handling - asyncio.get_event_loop().set_exception_handler(None) - - -class DelegatedThreadLocalSingletonTests(AsyncTestCase): - - def test_async_mode(self): - instance = object() - - async def create_instance(): - return instance - - provider = providers.DelegatedThreadLocalSingleton(create_instance) - - instance1 = self._run(provider()) - instance2 = self._run(provider()) - - self.assertIs(instance1, instance2) - self.assertIs(instance, instance) - - -class ProvidedInstanceTests(AsyncTestCase): - - def test_provided_attribute(self): - class TestClient: - def __init__(self, resource): - self.resource = resource - - class TestService: - def __init__(self, resource): - self.resource = resource - - class TestContainer(containers.DeclarativeContainer): - resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) - client = providers.Factory(TestClient, resource=resource) - service = providers.Factory(TestService, resource=client.provided.resource) - - container = TestContainer() - - instance1, instance2 = self._run( - asyncio.gather( - container.service(), - container.service(), - ), - ) - - self.assertIs(instance1.resource, RESOURCE1) - self.assertIs(instance2.resource, RESOURCE1) - self.assertIs(instance1.resource, instance2.resource) - - def test_provided_attribute_error(self): - async def raise_exception(): - raise RuntimeError() - - class TestContainer(containers.DeclarativeContainer): - client = providers.Factory(raise_exception) - - container = TestContainer() - - with self.assertRaises(RuntimeError): - self._run(container.client.provided.attr()) - - def test_provided_attribute_undefined_attribute(self): - class TestClient: - def __init__(self, resource): - self.resource = resource - - class TestContainer(containers.DeclarativeContainer): - resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) - client = providers.Factory(TestClient, resource=resource) - - container = TestContainer() - - with self.assertRaises(AttributeError): - self._run(container.client.provided.attr()) - - def test_provided_item(self): - class TestClient: - def __init__(self, resource): - self.resource = resource - - def __getitem__(self, item): - return getattr(self, item) - - class TestService: - def __init__(self, resource): - self.resource = resource - - class TestContainer(containers.DeclarativeContainer): - resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) - client = providers.Factory(TestClient, resource=resource) - service = providers.Factory(TestService, resource=client.provided["resource"]) - - container = TestContainer() - - instance1, instance2 = self._run( - asyncio.gather( - container.service(), - container.service(), - ), - ) - - self.assertIs(instance1.resource, RESOURCE1) - self.assertIs(instance2.resource, RESOURCE1) - self.assertIs(instance1.resource, instance2.resource) - - def test_provided_item_error(self): - async def raise_exception(): - raise RuntimeError() - - class TestContainer(containers.DeclarativeContainer): - client = providers.Factory(raise_exception) - - container = TestContainer() - - with self.assertRaises(RuntimeError): - self._run(container.client.provided["item"]()) - - def test_provided_item_undefined_item(self): - class TestContainer(containers.DeclarativeContainer): - resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) - client = providers.Factory(dict, resource=resource) - - container = TestContainer() - - with self.assertRaises(KeyError): - self._run(container.client.provided["item"]()) - - def test_provided_method_call(self): - class TestClient: - def __init__(self, resource): - self.resource = resource - - def get_resource(self): - return self.resource - - class TestService: - def __init__(self, resource): - self.resource = resource - - class TestContainer(containers.DeclarativeContainer): - resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) - client = providers.Factory(TestClient, resource=resource) - service = providers.Factory(TestService, resource=client.provided.get_resource.call()) - - container = TestContainer() - - instance1, instance2 = self._run( - asyncio.gather( - container.service(), - container.service(), - ), - ) - - self.assertIs(instance1.resource, RESOURCE1) - self.assertIs(instance2.resource, RESOURCE1) - self.assertIs(instance1.resource, instance2.resource) - - def test_provided_method_call_parent_error(self): - async def raise_exception(): - raise RuntimeError() - - class TestContainer(containers.DeclarativeContainer): - client = providers.Factory(raise_exception) - - container = TestContainer() - - with self.assertRaises(RuntimeError): - self._run(container.client.provided.method.call()()) - - def test_provided_method_call_error(self): - class TestClient: - def method(self): - raise RuntimeError() - - class TestContainer(containers.DeclarativeContainer): - client = providers.Factory(TestClient) - - container = TestContainer() - - with self.assertRaises(RuntimeError): - self._run(container.client.provided.method.call()()) - - -class DependencyTests(AsyncTestCase): - - def test_provide_error(self): - async def get_async(): - raise Exception - - provider = providers.Dependency() - provider.override(providers.Callable(get_async)) - - with self.assertRaises(Exception): - self._run(provider()) - - def test_isinstance(self): - dependency = 1.0 - - async def get_async(): - return dependency - - provider = providers.Dependency(instance_of=float) - provider.override(providers.Callable(get_async)) - - self.assertTrue(provider.is_async_mode_undefined()) - - dependency1 = self._run(provider()) - - self.assertTrue(provider.is_async_mode_enabled()) - - dependency2 = self._run(provider()) - - self.assertEqual(dependency1, dependency) - self.assertEqual(dependency2, dependency) - - def test_isinstance_invalid(self): - async def get_async(): - return {} - - provider = providers.Dependency(instance_of=float) - provider.override(providers.Callable(get_async)) - - self.assertTrue(provider.is_async_mode_undefined()) - - with self.assertRaises(errors.Error): - self._run(provider()) - - self.assertTrue(provider.is_async_mode_enabled()) - - def test_async_mode(self): - dependency = 123 - - async def get_async(): - return dependency - - def get_sync(): - return dependency - - provider = providers.Dependency(instance_of=int) - provider.override(providers.Factory(get_async)) - - self.assertTrue(provider.is_async_mode_undefined()) - - dependency1 = self._run(provider()) - - self.assertTrue(provider.is_async_mode_enabled()) - - dependency2 = self._run(provider()) - self.assertEqual(dependency1, dependency) - self.assertEqual(dependency2, dependency) - - provider.override(providers.Factory(get_sync)) - - dependency3 = self._run(provider()) - - self.assertTrue(provider.is_async_mode_enabled()) - - dependency4 = self._run(provider()) - self.assertEqual(dependency3, dependency) - self.assertEqual(dependency4, dependency) - - -class ListTests(AsyncTestCase): - - def test_provide(self): - # See issue: https://github.com/ets-labs/python-dependency-injector/issues/450 - async def create_resource(param: str): - return param - - class Container(containers.DeclarativeContainer): - - resources = providers.List( - providers.Resource(create_resource, "foo"), - providers.Resource(create_resource, "bar") - ) - - container = Container() - resources = self._run(container.resources()) - - self.assertEqual(resources[0], "foo") - self.assertEqual(resources[1], "bar") - - -class DictTests(AsyncTestCase): - - def test_provide(self): - async def create_resource(param: str): - return param - - class Container(containers.DeclarativeContainer): - - resources = providers.Dict( - foo=providers.Resource(create_resource, "foo"), - bar=providers.Resource(create_resource, "bar") - ) - - container = Container() - resources = self._run(container.resources()) - - self.assertEqual(resources["foo"], "foo") - self.assertEqual(resources["bar"], "bar") - - -class OverrideTests(AsyncTestCase): - - def test_provider(self): - dependency = object() - - async def _get_dependency_async(): - return dependency - - def _get_dependency_sync(): - return dependency - - provider = providers.Provider() - - provider.override(providers.Callable(_get_dependency_async)) - dependency1 = self._run(provider()) - - provider.override(providers.Callable(_get_dependency_sync)) - dependency2 = self._run(provider()) - - self.assertIs(dependency1, dependency) - self.assertIs(dependency2, dependency) - - def test_callable(self): - dependency = object() - - async def _get_dependency_async(): - return dependency - - def _get_dependency_sync(): - return dependency - - provider = providers.Callable(_get_dependency_async) - dependency1 = self._run(provider()) - - provider.override(providers.Callable(_get_dependency_sync)) - dependency2 = self._run(provider()) - - self.assertIs(dependency1, dependency) - self.assertIs(dependency2, dependency) - - def test_factory(self): - dependency = object() - - async def _get_dependency_async(): - return dependency - - def _get_dependency_sync(): - return dependency - - provider = providers.Factory(_get_dependency_async) - dependency1 = self._run(provider()) - - provider.override(providers.Callable(_get_dependency_sync)) - dependency2 = self._run(provider()) - - self.assertIs(dependency1, dependency) - self.assertIs(dependency2, dependency) - - def test_async_mode_enabling(self): - dependency = object() - - async def _get_dependency_async(): - return dependency - - provider = providers.Callable(_get_dependency_async) - self.assertTrue(provider.is_async_mode_undefined()) - - self._run(provider()) - - self.assertTrue(provider.is_async_mode_enabled()) - - def test_async_mode_disabling(self): - dependency = object() - - def _get_dependency(): - return dependency - - provider = providers.Callable(_get_dependency) - self.assertTrue(provider.is_async_mode_undefined()) - - provider() - - self.assertTrue(provider.is_async_mode_disabled()) - - def test_async_mode_enabling_on_overriding(self): - dependency = object() - - async def _get_dependency_async(): - return dependency - - provider = providers.Provider() - provider.override(providers.Callable(_get_dependency_async)) - self.assertTrue(provider.is_async_mode_undefined()) - - self._run(provider()) - - self.assertTrue(provider.is_async_mode_enabled()) - - def test_async_mode_disabling_on_overriding(self): - dependency = object() - - def _get_dependency(): - return dependency - - provider = providers.Provider() - provider.override(providers.Callable(_get_dependency)) - self.assertTrue(provider.is_async_mode_undefined()) - - provider() - - self.assertTrue(provider.is_async_mode_disabled()) - - -class TestAsyncModeApi(unittest.TestCase): - - def setUp(self): - self.provider = providers.Provider() - - def test_default_mode(self): - self.assertFalse(self.provider.is_async_mode_enabled()) - self.assertFalse(self.provider.is_async_mode_disabled()) - self.assertTrue(self.provider.is_async_mode_undefined()) - - def test_enable(self): - self.provider.enable_async_mode() - - self.assertTrue(self.provider.is_async_mode_enabled()) - self.assertFalse(self.provider.is_async_mode_disabled()) - self.assertFalse(self.provider.is_async_mode_undefined()) - - def test_disable(self): - self.provider.disable_async_mode() - - self.assertFalse(self.provider.is_async_mode_enabled()) - self.assertTrue(self.provider.is_async_mode_disabled()) - self.assertFalse(self.provider.is_async_mode_undefined()) - - def test_reset(self): - self.provider.enable_async_mode() - - self.assertTrue(self.provider.is_async_mode_enabled()) - self.assertFalse(self.provider.is_async_mode_disabled()) - self.assertFalse(self.provider.is_async_mode_undefined()) - - self.provider.reset_async_mode() - - self.assertFalse(self.provider.is_async_mode_enabled()) - self.assertFalse(self.provider.is_async_mode_disabled()) - self.assertTrue(self.provider.is_async_mode_undefined()) - - -class AsyncTypingStubTests(AsyncTestCase): - - def test_async_(self): - container = Container() - - client1 = self._run(container.client.async_()) - client2 = self._run(container.client.async_()) - - self.assertIsInstance(client1, Client) - self.assertIs(client1.resource1, RESOURCE1) - self.assertIs(client1.resource2, RESOURCE2) - - self.assertIsInstance(client2, Client) - self.assertIs(client2.resource1, RESOURCE1) - self.assertIs(client2.resource2, RESOURCE2) - - service1 = self._run(container.service.async_()) - service2 = self._run(container.service.async_()) - - self.assertIsInstance(service1, Service) - self.assertIsInstance(service1.client, Client) - self.assertIs(service1.client.resource1, RESOURCE1) - self.assertIs(service1.client.resource2, RESOURCE2) - - self.assertIsInstance(service2, Service) - self.assertIsInstance(service2.client, Client) - self.assertIs(service2.client.resource1, RESOURCE1) - self.assertIs(service2.client.resource2, RESOURCE2) - - self.assertIsNot(service1.client, service2.client) - - -class AsyncProvidersWithAsyncDependenciesTests(AsyncTestCase): - - def test_injections(self): - # See: https://github.com/ets-labs/python-dependency-injector/issues/368 - async def async_db_provider(): - return {"db": "ok"} - - async def async_service(db=None): - return {"service": "ok", "db": db} - - class Container(containers.DeclarativeContainer): - - db = providers.Factory(async_db_provider) - service = providers.Singleton(async_service, db=db) - - container = Container() - service = self._run(container.service()) - - self.assertEqual(service, {"service": "ok", "db": {"db": "ok"}}) - - -class AsyncProviderWithAwaitableObjectTests(AsyncTestCase): - - def test(self): - class SomeResource: - def __await__(self): - raise RuntimeError("Should never happen") - - async def init_resource(): - pool = SomeResource() - yield pool - - class Service: - def __init__(self, resource) -> None: - self.resource = resource - - class Container(containers.DeclarativeContainer): - resource = providers.Resource(init_resource) - service = providers.Singleton(Service, resource=resource) - - container = Container() - - self._run(container.init_resources()) - self.assertIsInstance(container.service(), asyncio.Future) - self.assertIsInstance(container.resource(), asyncio.Future) - - resource = self._run(container.resource()) - service = self._run(container.service()) - - self.assertIsInstance(resource, SomeResource) - self.assertIsInstance(service.resource, SomeResource) - self.assertIs(service.resource, resource) - - def test_without_init_resources(self): - class SomeResource: - def __await__(self): - raise RuntimeError("Should never happen") - - async def init_resource(): - pool = SomeResource() - yield pool - - class Service: - def __init__(self, resource) -> None: - self.resource = resource - - class Container(containers.DeclarativeContainer): - resource = providers.Resource(init_resource) - service = providers.Singleton(Service, resource=resource) - - container = Container() - - self.assertIsInstance(container.service(), asyncio.Future) - self.assertIsInstance(container.resource(), asyncio.Future) - - resource = self._run(container.resource()) - service = self._run(container.service()) - - self.assertIsInstance(resource, SomeResource) - self.assertIsInstance(service.resource, SomeResource) - self.assertIs(service.resource, resource) diff --git a/tests/unit/providers/test_base_py2_py3.py b/tests/unit/providers/test_base_py2_py3.py deleted file mode 100644 index b93f7f6b..00000000 --- a/tests/unit/providers/test_base_py2_py3.py +++ /dev/null @@ -1,791 +0,0 @@ -"""Dependency injector base providers unit tests.""" - -import unittest -import warnings - -from dependency_injector import ( - containers, - providers, - errors, -) - - -class ProviderTests(unittest.TestCase): - - def setUp(self): - self.provider = providers.Provider() - - def test_is_provider(self): - self.assertTrue(providers.is_provider(self.provider)) - - def test_call(self): - self.assertRaises(NotImplementedError, self.provider.__call__) - - def test_delegate(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - delegate1 = self.provider.delegate() - delegate2 = self.provider.delegate() - - self.assertIsInstance(delegate1, providers.Delegate) - self.assertIs(delegate1(), self.provider) - - self.assertIsInstance(delegate2, providers.Delegate) - self.assertIs(delegate2(), self.provider) - - self.assertIsNot(delegate1, delegate2) - - def test_provider(self): - delegate1 = self.provider.provider - - self.assertIsInstance(delegate1, providers.Delegate) - self.assertIs(delegate1(), self.provider) - - delegate2 = self.provider.provider - - self.assertIsInstance(delegate2, providers.Delegate) - self.assertIs(delegate2(), self.provider) - - self.assertIsNot(delegate1, delegate2) - - def test_override(self): - overriding_provider = providers.Provider() - self.provider.override(overriding_provider) - self.assertTrue(self.provider.overridden) - self.assertIs(self.provider.last_overriding, overriding_provider) - - def test_double_override(self): - overriding_provider1 = providers.Object(1) - overriding_provider2 = providers.Object(2) - - self.provider.override(overriding_provider1) - overriding_provider1.override(overriding_provider2) - - self.assertEqual(self.provider(), overriding_provider2()) - - def test_overriding_context(self): - overriding_provider = providers.Provider() - with self.provider.override(overriding_provider): - self.assertTrue(self.provider.overridden) - self.assertFalse(self.provider.overridden) - - def test_override_with_itself(self): - self.assertRaises(errors.Error, self.provider.override, self.provider) - - def test_override_with_not_provider(self): - obj = object() - self.provider.override(obj) - self.assertIs(self.provider(), obj) - - def test_reset_last_overriding(self): - overriding_provider1 = providers.Provider() - overriding_provider2 = providers.Provider() - - self.provider.override(overriding_provider1) - self.provider.override(overriding_provider2) - - self.assertIs(self.provider.overridden[-1], overriding_provider2) - self.assertIs(self.provider.last_overriding, overriding_provider2) - - self.provider.reset_last_overriding() - self.assertIs(self.provider.overridden[-1], overriding_provider1) - self.assertIs(self.provider.last_overriding, overriding_provider1) - - self.provider.reset_last_overriding() - self.assertFalse(self.provider.overridden) - self.assertIsNone(self.provider.last_overriding) - - def test_reset_last_overriding_of_not_overridden_provider(self): - self.assertRaises(errors.Error, self.provider.reset_last_overriding) - - def test_reset_override(self): - overriding_provider = providers.Provider() - self.provider.override(overriding_provider) - - self.assertTrue(self.provider.overridden) - self.assertEqual(self.provider.overridden, (overriding_provider,)) - - self.provider.reset_override() - - self.assertEqual(self.provider.overridden, tuple()) - - def test_deepcopy(self): - provider = providers.Provider() - - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider, providers.Provider) - - def test_deepcopy_from_memo(self): - provider = providers.Provider() - provider_copy_memo = providers.Provider() - - provider_copy = providers.deepcopy( - provider, memo={id(provider): provider_copy_memo}) - - self.assertIs(provider_copy, provider_copy_memo) - - def test_deepcopy_overridden(self): - provider = providers.Provider() - overriding_provider = providers.Provider() - - provider.override(overriding_provider) - - provider_copy = providers.deepcopy(provider) - overriding_provider_copy = provider_copy.overridden[0] - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider, providers.Provider) - - self.assertIsNot(overriding_provider, overriding_provider_copy) - self.assertIsInstance(overriding_provider_copy, providers.Provider) - - def test_repr(self): - self.assertEqual(repr(self.provider), - "".format(hex(id(self.provider)))) - - -class ObjectProviderTests(unittest.TestCase): - - def test_is_provider(self): - self.assertTrue(providers.is_provider(providers.Object(object()))) - - def test_init_optional_provides(self): - instance = object() - provider = providers.Object() - provider.set_provides(instance) - self.assertIs(provider.provides, instance) - self.assertIs(provider(), instance) - - def test_set_provides_returns_self(self): - provider = providers.Object() - self.assertIs(provider.set_provides(object()), provider) - - def test_provided_instance_provider(self): - provider = providers.Object(object()) - self.assertIsInstance(provider.provided, providers.ProvidedInstance) - - def test_call_object_provider(self): - obj = object() - self.assertIs(providers.Object(obj)(), obj) - - def test_call_overridden_object_provider(self): - obj1 = object() - obj2 = object() - provider = providers.Object(obj1) - provider.override(providers.Object(obj2)) - self.assertIs(provider(), obj2) - - def test_deepcopy(self): - provider = providers.Object(1) - - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider, providers.Object) - - def test_deepcopy_from_memo(self): - provider = providers.Object(1) - provider_copy_memo = providers.Provider() - - provider_copy = providers.deepcopy( - provider, memo={id(provider): provider_copy_memo}) - - self.assertIs(provider_copy, provider_copy_memo) - - def test_deepcopy_overridden(self): - provider = providers.Object(1) - overriding_provider = providers.Provider() - - provider.override(overriding_provider) - - provider_copy = providers.deepcopy(provider) - overriding_provider_copy = provider_copy.overridden[0] - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider, providers.Object) - - self.assertIsNot(overriding_provider, overriding_provider_copy) - self.assertIsInstance(overriding_provider_copy, providers.Provider) - - def test_deepcopy_doesnt_copy_provided_object(self): - # Fixes bug #231 - # Details: https://github.com/ets-labs/python-dependency-injector/issues/231 - some_object = object() - provider = providers.Object(some_object) - - provider_copy = providers.deepcopy(provider) - - self.assertIs(provider(), some_object) - self.assertIs(provider_copy(), some_object) - - def test_repr(self): - some_object = object() - provider = providers.Object(some_object) - self.assertEqual(repr(provider), - "".format( - repr(some_object), - hex(id(provider)))) - - -class SelfProviderTests(unittest.TestCase): - - def test_is_provider(self): - self.assertTrue(providers.is_provider(providers.Self())) - - def test_call_object_provider(self): - container = containers.DeclarativeContainer() - self.assertIs(providers.Self(container)(), container) - - def test_set_container(self): - container = containers.DeclarativeContainer() - provider = providers.Self() - provider.set_container(container) - self.assertIs(provider(), container) - - def test_set_alt_names(self): - provider = providers.Self() - provider.set_alt_names({"foo", "bar", "baz"}) - self.assertEqual(set(provider.alt_names), {"foo", "bar", "baz"}) - - def test_deepcopy(self): - provider = providers.Self() - - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider, providers.Self) - - def test_deepcopy_from_memo(self): - provider = providers.Self() - provider_copy_memo = providers.Provider() - - provider_copy = providers.deepcopy( - provider, memo={id(provider): provider_copy_memo}) - - self.assertIs(provider_copy, provider_copy_memo) - - def test_deepcopy_overridden(self): - provider = providers.Self() - overriding_provider = providers.Provider() - - provider.override(overriding_provider) - - provider_copy = providers.deepcopy(provider) - overriding_provider_copy = provider_copy.overridden[0] - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider, providers.Self) - - self.assertIsNot(overriding_provider, overriding_provider_copy) - self.assertIsInstance(overriding_provider_copy, providers.Provider) - - def test_repr(self): - container = containers.DeclarativeContainer() - provider = providers.Self(container) - self.assertEqual(repr(provider), - "".format( - repr(container), - hex(id(provider)))) - - -class DelegateTests(unittest.TestCase): - - def setUp(self): - self.delegated = providers.Provider() - self.delegate = providers.Delegate(self.delegated) - - def test_is_provider(self): - self.assertTrue(providers.is_provider(self.delegate)) - - def test_init_optional_provides(self): - provider = providers.Delegate() - provider.set_provides(self.delegated) - self.assertIs(provider.provides, self.delegated) - self.assertIs(provider(), self.delegated) - - def test_set_provides_returns_self(self): - provider = providers.Delegate() - self.assertIs(provider.set_provides(self.delegated), provider) - - def test_init_with_not_provider(self): - self.assertRaises(errors.Error, providers.Delegate, object()) - - def test_call(self): - delegated1 = self.delegate() - delegated2 = self.delegate() - - self.assertIs(delegated1, self.delegated) - self.assertIs(delegated2, self.delegated) - - def test_repr(self): - self.assertEqual(repr(self.delegate), - "".format( - repr(self.delegated), - hex(id(self.delegate)))) - - -class DependencyTests(unittest.TestCase): - - def setUp(self): - self.provider = providers.Dependency(instance_of=list) - - def test_init_optional(self): - list_provider = providers.List(1, 2, 3) - provider = providers.Dependency() - provider.set_instance_of(list) - provider.set_default(list_provider) - - self.assertIs(provider.instance_of, list) - self.assertIs(provider.default, list_provider) - self.assertEqual(provider(), [1, 2, 3]) - - def test_set_instance_of_returns_self(self): - provider = providers.Dependency() - self.assertIs(provider.set_instance_of(list), provider) - - def test_set_default_returns_self(self): - provider = providers.Dependency() - self.assertIs(provider.set_default(providers.Provider()), provider) - - def test_init_with_not_class(self): - self.assertRaises(TypeError, providers.Dependency, object()) - - def test_with_abc(self): - try: - import collections.abc as collections_abc - except ImportError: - import collections as collections_abc - - provider = providers.Dependency(collections_abc.Mapping) - provider.provided_by(providers.Factory(dict)) - - self.assertIsInstance(provider(), collections_abc.Mapping) - self.assertIsInstance(provider(), dict) - - def test_is_provider(self): - self.assertTrue(providers.is_provider(self.provider)) - - def test_provided_instance_provider(self): - self.assertIsInstance(self.provider.provided, providers.ProvidedInstance) - - def test_default(self): - provider = providers.Dependency(instance_of=dict, default={"foo": "bar"}) - self.assertEqual(provider(), {"foo": "bar"}) - - def test_default_attribute(self): - provider = providers.Dependency(instance_of=dict, default={"foo": "bar"}) - self.assertEqual(provider.default(), {"foo": "bar"}) - - def test_default_provider(self): - provider = providers.Dependency(instance_of=dict, default=providers.Factory(dict, foo="bar")) - self.assertEqual(provider.default(), {"foo": "bar"}) - - def test_default_attribute_provider(self): - default = providers.Factory(dict, foo="bar") - provider = providers.Dependency(instance_of=dict, default=default) - - self.assertEqual(provider.default(), {"foo": "bar"}) - self.assertIs(provider.default, default) - - def test_is_defined(self): - provider = providers.Dependency() - self.assertFalse(provider.is_defined) - - def test_is_defined_when_overridden(self): - provider = providers.Dependency() - provider.override("value") - self.assertTrue(provider.is_defined) - - def test_is_defined_with_default(self): - provider = providers.Dependency(default="value") - self.assertTrue(provider.is_defined) - - def test_call_overridden(self): - self.provider.provided_by(providers.Factory(list)) - self.assertIsInstance(self.provider(), list) - - def test_call_overridden_but_not_instance_of(self): - self.provider.provided_by(providers.Factory(dict)) - self.assertRaises(errors.Error, self.provider) - - def test_call_undefined(self): - with self.assertRaises(errors.Error) as context: - self.provider() - self.assertEqual(str(context.exception), "Dependency is not defined") - - def test_call_undefined_error_message_with_container_instance_parent(self): - class UserService: - def __init__(self, database): - self.database = database - - class Container(containers.DeclarativeContainer): - database = providers.Dependency() - - user_service = providers.Factory( - UserService, - database=database, # <---- missing dependency - ) - - container = Container() - - with self.assertRaises(errors.Error) as context: - container.user_service() - - self.assertEqual(str(context.exception), "Dependency \"Container.database\" is not defined") - - def test_call_undefined_error_message_with_container_provider_parent_deep(self): - class Database: - pass - - class UserService: - def __init__(self, db): - self.db = db - - class Gateways(containers.DeclarativeContainer): - database_client = providers.Singleton(Database) - - class Services(containers.DeclarativeContainer): - gateways = providers.DependenciesContainer() - - user = providers.Factory( - UserService, - db=gateways.database_client, - ) - - class Container(containers.DeclarativeContainer): - gateways = providers.Container(Gateways) - - services = providers.Container( - Services, - # gateways=gateways, # <---- missing dependency - ) - - container = Container() - - with self.assertRaises(errors.Error) as context: - container.services().user() - - self.assertEqual( - str(context.exception), - "Dependency \"Container.services.gateways.database_client\" is not defined", - ) - - def test_call_undefined_error_message_with_dependenciescontainer_provider_parent(self): - class UserService: - def __init__(self, db): - self.db = db - - class Services(containers.DeclarativeContainer): - gateways = providers.DependenciesContainer() - - user = providers.Factory( - UserService, - db=gateways.database_client, # <---- missing dependency - ) - - services = Services() - - with self.assertRaises(errors.Error) as context: - services.user() - - self.assertEqual( - str(context.exception), - "Dependency \"Services.gateways.database_client\" is not defined", - ) - - def test_assign_parent(self): - parent = providers.DependenciesContainer() - provider = providers.Dependency() - - provider.assign_parent(parent) - - self.assertIs(provider.parent, parent) - - def test_parent_name(self): - container = containers.DynamicContainer() - provider = providers.Dependency() - container.name = provider - self.assertEqual(provider.parent_name, "name") - - def test_parent_name_with_deep_parenting(self): - provider = providers.Dependency() - container = providers.DependenciesContainer(name=provider) - _ = providers.DependenciesContainer(container=container) - self.assertEqual(provider.parent_name, "container.name") - - def test_parent_name_is_none(self): - provider = providers.DependenciesContainer() - self.assertIsNone(provider.parent_name) - - def test_parent_deepcopy(self): - container = containers.DynamicContainer() - provider = providers.Dependency() - container.name = provider - - copied = providers.deepcopy(container) - - self.assertIs(container.name.parent, container) - self.assertIs(copied.name.parent, copied) - - self.assertIsNot(container, copied) - self.assertIsNot(container.name, copied.name) - self.assertIsNot(container.name.parent, copied.name.parent) - - def test_forward_attr_to_default(self): - default = providers.Configuration() - - provider = providers.Dependency(default=default) - provider.from_dict({"foo": "bar"}) - - self.assertEqual(default(), {"foo": "bar"}) - - def test_forward_attr_to_overriding(self): - overriding = providers.Configuration() - - provider = providers.Dependency() - provider.override(overriding) - provider.from_dict({"foo": "bar"}) - - self.assertEqual(overriding(), {"foo": "bar"}) - - def test_forward_attr_to_none(self): - provider = providers.Dependency() - with self.assertRaises(AttributeError): - provider.from_dict - - def test_deepcopy(self): - provider = providers.Dependency(int) - - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider, providers.Dependency) - - def test_deepcopy_from_memo(self): - provider = providers.Dependency(int) - provider_copy_memo = providers.Provider() - - provider_copy = providers.deepcopy( - provider, memo={id(provider): provider_copy_memo}) - - self.assertIs(provider_copy, provider_copy_memo) - - def test_deepcopy_overridden(self): - provider = providers.Dependency(int) - overriding_provider = providers.Provider() - - provider.override(overriding_provider) - - provider_copy = providers.deepcopy(provider) - overriding_provider_copy = provider_copy.overridden[0] - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider, providers.Dependency) - - self.assertIsNot(overriding_provider, overriding_provider_copy) - self.assertIsInstance(overriding_provider_copy, providers.Provider) - - def test_deep_copy_default_object(self): - default = {"foo": "bar"} - provider = providers.Dependency(dict, default=default) - - provider_copy = providers.deepcopy(provider) - - self.assertIs(provider_copy(), default) - self.assertIs(provider_copy.default(), default) - - def test_deep_copy_default_provider(self): - bar = object() - default = providers.Factory(dict, foo=providers.Object(bar)) - provider = providers.Dependency(dict, default=default) - - provider_copy = providers.deepcopy(provider) - - self.assertEqual(provider_copy(), {"foo": bar}) - self.assertEqual(provider_copy.default(), {"foo": bar}) - self.assertIs(provider_copy()["foo"], bar) - - def test_with_container_default_object(self): - default = {"foo": "bar"} - - class Container(containers.DeclarativeContainer): - provider = providers.Dependency(dict, default=default) - - container = Container() - - self.assertIs(container.provider(), default) - self.assertIs(container.provider.default(), default) - - def test_with_container_default_provider(self): - bar = object() - - class Container(containers.DeclarativeContainer): - provider = providers.Dependency(dict, default=providers.Factory(dict, foo=providers.Object(bar))) - - container = Container() - - self.assertEqual(container.provider(), {"foo": bar}) - self.assertEqual(container.provider.default(), {"foo": bar}) - self.assertIs(container.provider()["foo"], bar) - - def test_with_container_default_provider_with_overriding(self): - bar = object() - baz = object() - - class Container(containers.DeclarativeContainer): - provider = providers.Dependency(dict, default=providers.Factory(dict, foo=providers.Object(bar))) - - container = Container(provider=providers.Factory(dict, foo=providers.Object(baz))) - - self.assertEqual(container.provider(), {"foo": baz}) - self.assertEqual(container.provider.default(), {"foo": bar}) - self.assertIs(container.provider()["foo"], baz) - - def test_repr(self): - self.assertEqual(repr(self.provider), - "".format( - repr(list), - hex(id(self.provider)))) - - def test_repr_in_container(self): - class Container(containers.DeclarativeContainer): - dependency = providers.Dependency(instance_of=int) - - container = Container() - - self.assertEqual(repr(container.dependency), - "".format( - repr(int), - hex(id(container.dependency)))) - - -class ExternalDependencyTests(unittest.TestCase): - - def setUp(self): - self.provider = providers.ExternalDependency(instance_of=list) - - def test_is_instance(self): - self.assertIsInstance(self.provider, providers.Dependency) - - -class DependenciesContainerTests(unittest.TestCase): - - class Container(containers.DeclarativeContainer): - - dependency = providers.Provider() - - def setUp(self): - self.provider = providers.DependenciesContainer() - self.container = self.Container() - - def test_getattr(self): - has_dependency = hasattr(self.provider, "dependency") - dependency = self.provider.dependency - - self.assertIsInstance(dependency, providers.Dependency) - self.assertIs(dependency, self.provider.dependency) - self.assertTrue(has_dependency) - self.assertIsNone(dependency.last_overriding) - - def test_getattr_with_container(self): - self.provider.override(self.container) - - dependency = self.provider.dependency - - self.assertTrue(dependency.overridden) - self.assertIs(dependency.last_overriding, self.container.dependency) - - def test_providers(self): - dependency1 = self.provider.dependency1 - dependency2 = self.provider.dependency2 - self.assertEqual(self.provider.providers, {"dependency1": dependency1, - "dependency2": dependency2}) - - def test_override(self): - dependency = self.provider.dependency - self.provider.override(self.container) - - self.assertTrue(dependency.overridden) - self.assertIs(dependency.last_overriding, self.container.dependency) - - def test_reset_last_overriding(self): - dependency = self.provider.dependency - self.provider.override(self.container) - self.provider.reset_last_overriding() - - self.assertIsNone(dependency.last_overriding) - self.assertIsNone(dependency.last_overriding) - - def test_reset_override(self): - dependency = self.provider.dependency - self.provider.override(self.container) - self.provider.reset_override() - - self.assertFalse(dependency.overridden) - self.assertFalse(dependency.overridden) - - def test_assign_parent(self): - parent = providers.DependenciesContainer() - provider = providers.DependenciesContainer() - - provider.assign_parent(parent) - - self.assertIs(provider.parent, parent) - - def test_parent_name(self): - container = containers.DynamicContainer() - provider = providers.DependenciesContainer() - container.name = provider - self.assertEqual(provider.parent_name, "name") - - def test_parent_name_with_deep_parenting(self): - provider = providers.DependenciesContainer() - container = providers.DependenciesContainer(name=provider) - _ = providers.DependenciesContainer(container=container) - self.assertEqual(provider.parent_name, "container.name") - - def test_parent_name_is_none(self): - provider = providers.DependenciesContainer() - self.assertIsNone(provider.parent_name) - - def test_parent_deepcopy(self): - container = containers.DynamicContainer() - provider = providers.DependenciesContainer() - container.name = provider - - copied = providers.deepcopy(container) - - self.assertIs(container.name.parent, container) - self.assertIs(copied.name.parent, copied) - - self.assertIsNot(container, copied) - self.assertIsNot(container.name, copied.name) - self.assertIsNot(container.name.parent, copied.name.parent) - - def test_parent_set_on__getattr__(self): - provider = providers.DependenciesContainer() - self.assertIsInstance(provider.name, providers.Dependency) - self.assertIs(provider.name.parent, provider) - - def test_parent_set_on__init__(self): - provider = providers.Dependency() - container = providers.DependenciesContainer(name=provider) - self.assertIs(container.name, provider) - self.assertIs(container.name.parent, container) - - def test_resolve_provider_name(self): - container = providers.DependenciesContainer() - self.assertEqual(container.resolve_provider_name(container.name), "name") - - def test_resolve_provider_name_no_provider(self): - container = providers.DependenciesContainer() - with self.assertRaises(errors.Error): - container.resolve_provider_name(providers.Provider()) diff --git a/tests/unit/providers/test_callables_py2_py3.py b/tests/unit/providers/test_callables_py2_py3.py deleted file mode 100644 index 9b42294a..00000000 --- a/tests/unit/providers/test_callables_py2_py3.py +++ /dev/null @@ -1,296 +0,0 @@ -"""Dependency injector callable providers unit tests.""" - -import sys - -import unittest - -from dependency_injector import ( - providers, - errors, -) - - -def _example(arg1, arg2, arg3, arg4): - return arg1, arg2, arg3, arg4 - - -class CallableTests(unittest.TestCase): - - def test_init_with_callable(self): - self.assertTrue(providers.Callable(_example)) - - def test_init_with_not_callable(self): - self.assertRaises(errors.Error, providers.Callable, 123) - - def test_init_optional_provides(self): - provider = providers.Callable() - provider.set_provides(object) - self.assertIs(provider.provides, object) - self.assertIsInstance(provider(), object) - - def test_set_provides_returns_self(self): - provider = providers.Callable() - self.assertIs(provider.set_provides(object), provider) - - def test_provided_instance_provider(self): - provider = providers.Callable(_example) - self.assertIsInstance(provider.provided, providers.ProvidedInstance) - - def test_call(self): - provider = providers.Callable(lambda: True) - self.assertTrue(provider()) - - def test_call_with_positional_args(self): - provider = providers.Callable(_example, - 1, 2, 3, 4) - self.assertTupleEqual(provider(), (1, 2, 3, 4)) - - def test_call_with_keyword_args(self): - provider = providers.Callable(_example, - arg1=1, arg2=2, arg3=3, arg4=4) - self.assertTupleEqual(provider(), (1, 2, 3, 4)) - - def test_call_with_positional_and_keyword_args(self): - provider = providers.Callable(_example, - 1, 2, - arg3=3, arg4=4) - self.assertTupleEqual(provider(), (1, 2, 3, 4)) - - def test_call_with_context_args(self): - provider = providers.Callable(_example, 1, 2) - self.assertTupleEqual(provider(3, 4), (1, 2, 3, 4)) - - def test_call_with_context_kwargs(self): - provider = providers.Callable(_example, arg1=1) - self.assertTupleEqual(provider(arg2=2, arg3=3, arg4=4), (1, 2, 3, 4)) - - def test_call_with_context_args_and_kwargs(self): - provider = providers.Callable(_example, 1) - self.assertTupleEqual(provider(2, arg3=3, arg4=4), (1, 2, 3, 4)) - - def test_fluent_interface(self): - provider = providers.Singleton(_example) \ - .add_args(1, 2) \ - .add_kwargs(arg3=3, arg4=4) - - self.assertTupleEqual(provider(), (1, 2, 3, 4)) - - def test_set_args(self): - provider = providers.Callable(_example) \ - .add_args(1, 2) \ - .set_args(3, 4) - self.assertEqual(provider.args, (3, 4)) - - def test_set_kwargs(self): - provider = providers.Callable(_example) \ - .add_kwargs(init_arg3=3, init_arg4=4) \ - .set_kwargs(init_arg3=4, init_arg4=5) - self.assertEqual(provider.kwargs, dict(init_arg3=4, init_arg4=5)) - - def test_clear_args(self): - provider = providers.Callable(_example) \ - .add_args(1, 2) \ - .clear_args() - self.assertEqual(provider.args, tuple()) - - def test_clear_kwargs(self): - provider = providers.Callable(_example) \ - .add_kwargs(init_arg3=3, init_arg4=4) \ - .clear_kwargs() - self.assertEqual(provider.kwargs, dict()) - - def test_call_overridden(self): - provider = providers.Callable(_example) - - provider.override(providers.Object((4, 3, 2, 1))) - provider.override(providers.Object((1, 2, 3, 4))) - - self.assertTupleEqual(provider(), (1, 2, 3, 4)) - - def test_deepcopy(self): - provider = providers.Callable(_example) - - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertIs(provider.provides, provider_copy.provides) - self.assertIsInstance(provider, providers.Callable) - - def test_deepcopy_from_memo(self): - provider = providers.Callable(_example) - provider_copy_memo = providers.Callable(_example) - - provider_copy = providers.deepcopy( - provider, memo={id(provider): provider_copy_memo}) - - self.assertIs(provider_copy, provider_copy_memo) - - def test_deepcopy_args(self): - provider = providers.Callable(_example) - dependent_provider1 = providers.Callable(list) - dependent_provider2 = providers.Callable(dict) - - provider.add_args(dependent_provider1, dependent_provider2) - - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.args[0] - dependent_provider_copy2 = provider_copy.args[1] - - self.assertNotEqual(provider.args, provider_copy.args) - - self.assertIs(dependent_provider1.provides, - dependent_provider_copy1.provides) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) - - self.assertIs(dependent_provider2.provides, - dependent_provider_copy2.provides) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) - - def test_deepcopy_kwargs(self): - provider = providers.Callable(_example) - dependent_provider1 = providers.Callable(list) - dependent_provider2 = providers.Callable(dict) - - provider.add_kwargs(a1=dependent_provider1, a2=dependent_provider2) - - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.kwargs["a1"] - dependent_provider_copy2 = provider_copy.kwargs["a2"] - - self.assertNotEqual(provider.kwargs, provider_copy.kwargs) - - self.assertIs(dependent_provider1.provides, - dependent_provider_copy1.provides) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) - - self.assertIs(dependent_provider2.provides, - dependent_provider_copy2.provides) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) - - def test_deepcopy_overridden(self): - provider = providers.Callable(_example) - object_provider = providers.Object(object()) - - provider.override(object_provider) - - provider_copy = providers.deepcopy(provider) - object_provider_copy = provider_copy.overridden[0] - - self.assertIsNot(provider, provider_copy) - self.assertIs(provider.provides, provider_copy.provides) - self.assertIsInstance(provider, providers.Callable) - - self.assertIsNot(object_provider, object_provider_copy) - self.assertIsInstance(object_provider_copy, providers.Object) - - def test_deepcopy_with_sys_streams(self): - provider = providers.Callable(_example) - provider.add_args(sys.stdin) - provider.add_kwargs(a2=sys.stdout) - - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider_copy, providers.Callable) - self.assertIs(provider.args[0], sys.stdin) - self.assertIs(provider.kwargs["a2"], sys.stdout) - - def test_repr(self): - provider = providers.Callable(_example) - - self.assertEqual(repr(provider), - "".format( - repr(_example), - hex(id(provider)))) - - -class DelegatedCallableTests(unittest.TestCase): - - def test_inheritance(self): - self.assertIsInstance(providers.DelegatedCallable(_example), - providers.Callable) - - def test_is_provider(self): - self.assertTrue( - providers.is_provider(providers.DelegatedCallable(_example))) - - def test_is_delegated_provider(self): - provider = providers.DelegatedCallable(_example) - self.assertTrue(providers.is_delegated(provider)) - - def test_repr(self): - provider = providers.DelegatedCallable(_example) - - self.assertEqual(repr(provider), - "".format( - repr(_example), - hex(id(provider)))) - - -class AbstractCallableTests(unittest.TestCase): - - def test_inheritance(self): - self.assertIsInstance(providers.AbstractCallable(_example), - providers.Callable) - - def test_call_overridden_by_callable(self): - def _abstract_example(): - pass - - provider = providers.AbstractCallable(_abstract_example) - provider.override(providers.Callable(_example)) - - self.assertTrue(provider(1, 2, 3, 4), (1, 2, 3, 4)) - - def test_call_overridden_by_delegated_callable(self): - def _abstract_example(): - pass - - provider = providers.AbstractCallable(_abstract_example) - provider.override(providers.DelegatedCallable(_example)) - - self.assertTrue(provider(1, 2, 3, 4), (1, 2, 3, 4)) - - def test_call_not_overridden(self): - provider = providers.AbstractCallable(_example) - - with self.assertRaises(errors.Error): - provider(1, 2, 3, 4) - - def test_override_by_not_callable(self): - provider = providers.AbstractCallable(_example) - - with self.assertRaises(errors.Error): - provider.override(providers.Factory(object)) - - def test_provide_not_implemented(self): - provider = providers.AbstractCallable(_example) - - with self.assertRaises(NotImplementedError): - provider._provide((1, 2, 3, 4), dict()) - - def test_repr(self): - provider = providers.AbstractCallable(_example) - - self.assertEqual(repr(provider), - "".format( - repr(_example), - hex(id(provider)))) - - -class CallableDelegateTests(unittest.TestCase): - - def setUp(self): - self.delegated = providers.Callable(_example) - self.delegate = providers.CallableDelegate(self.delegated) - - def test_is_delegate(self): - self.assertIsInstance(self.delegate, providers.Delegate) - - def test_init_with_not_callable(self): - self.assertRaises(errors.Error, - providers.CallableDelegate, - providers.Object(object())) diff --git a/tests/unit/providers/test_configuration_py2_py3.py b/tests/unit/providers/test_configuration_py2_py3.py deleted file mode 100644 index dad317d0..00000000 --- a/tests/unit/providers/test_configuration_py2_py3.py +++ /dev/null @@ -1,1558 +0,0 @@ -"""Dependency injector config providers unit tests.""" - -import contextlib -import decimal -import os -import sys -import tempfile - -import unittest - -from dependency_injector import containers, providers, errors -try: - import yaml -except ImportError: - yaml = None - -try: - import pydantic -except ImportError: - pydantic = None - - -class ConfigTests(unittest.TestCase): - - def setUp(self): - self.config = providers.Configuration(name="config") - - def tearDown(self): - del self.config - - def test_init_optional(self): - provider = providers.Configuration() - provider.set_name("myconfig") - provider.set_default({"foo": "bar"}) - provider.set_strict(True) - - self.assertEqual(provider.get_name(), "myconfig") - self.assertEqual(provider.get_default(), {"foo": "bar"}) - self.assertTrue(provider.get_strict()) - - def test_set_name_returns_self(self): - provider = providers.Configuration() - self.assertIs(provider.set_name("myconfig"), provider) - - def test_set_default_returns_self(self): - provider = providers.Configuration() - self.assertIs(provider.set_default({}), provider) - - def test_set_strict_returns_self(self): - provider = providers.Configuration() - self.assertIs(provider.set_strict(True), provider) - - def test_default_name(self): - config = providers.Configuration() - self.assertEqual(config.get_name(), "config") - - def test_providers_are_providers(self): - self.assertTrue(providers.is_provider(self.config.a)) - self.assertTrue(providers.is_provider(self.config.a.b)) - self.assertTrue(providers.is_provider(self.config.a.b.c)) - self.assertTrue(providers.is_provider(self.config.a.b.d)) - - def test_providers_are_not_delegates(self): - self.assertFalse(providers.is_delegated(self.config.a)) - self.assertFalse(providers.is_delegated(self.config.a.b)) - self.assertFalse(providers.is_delegated(self.config.a.b.c)) - self.assertFalse(providers.is_delegated(self.config.a.b.d)) - - def test_providers_identity(self): - self.assertIs(self.config.a, self.config.a) - self.assertIs(self.config.a.b, self.config.a.b) - self.assertIs(self.config.a.b.c, self.config.a.b.c) - self.assertIs(self.config.a.b.d, self.config.a.b.d) - - def test_get_name(self): - self.assertEqual(self.config.a.b.c.get_name(), "config.a.b.c") - - def test_providers_value_setting(self): - a = self.config.a - ab = self.config.a.b - abc = self.config.a.b.c - abd = self.config.a.b.d - - self.config.update({"a": {"b": {"c": 1, "d": 2}}}) - - self.assertEqual(a(), {"b": {"c": 1, "d": 2}}) - self.assertEqual(ab(), {"c": 1, "d": 2}) - self.assertEqual(abc(), 1) - self.assertEqual(abd(), 2) - - def test_providers_with_already_set_value(self): - self.config.update({"a": {"b": {"c": 1, "d": 2}}}) - - a = self.config.a - ab = self.config.a.b - abc = self.config.a.b.c - abd = self.config.a.b.d - - self.assertEqual(a(), {"b": {"c": 1, "d": 2}}) - self.assertEqual(ab(), {"c": 1, "d": 2}) - self.assertEqual(abc(), 1) - self.assertEqual(abd(), 2) - - def test_as_int(self): - value_provider = providers.Callable(lambda value: value, self.config.test.as_int()) - self.config.from_dict({"test": "123"}) - - value = value_provider() - - self.assertEqual(value, 123) - - def test_as_float(self): - value_provider = providers.Callable(lambda value: value, self.config.test.as_float()) - self.config.from_dict({"test": "123.123"}) - - value = value_provider() - - self.assertEqual(value, 123.123) - - def test_as_(self): - value_provider = providers.Callable( - lambda value: value, - self.config.test.as_(decimal.Decimal), - ) - self.config.from_dict({"test": "123.123"}) - - value = value_provider() - - self.assertEqual(value, decimal.Decimal("123.123")) - - @unittest.skipIf(sys.version_info[:2] == (2, 7), "Python 2.7 does not support this assert") - def test_required(self): - provider = providers.Callable( - lambda value: value, - self.config.a.required(), - ) - with self.assertRaisesRegex(errors.Error, "Undefined configuration option \"config.a\""): - provider() - - def test_required_defined_none(self): - provider = providers.Callable( - lambda value: value, - self.config.a.required(), - ) - self.config.from_dict({"a": None}) - self.assertIsNone(provider()) - - def test_required_no_side_effect(self): - _ = providers.Callable( - lambda value: value, - self.config.a.required(), - ) - self.assertIsNone(self.config.a()) - - def test_required_as_(self): - provider = providers.List( - self.config.int_test.required().as_int(), - self.config.float_test.required().as_float(), - self.config._as_test.required().as_(decimal.Decimal), - ) - self.config.from_dict({"int_test": "1", "float_test": "2.0", "_as_test": "3.0"}) - - self.assertEqual(provider(), [1, 2.0, decimal.Decimal("3.0")]) - - def test_providers_value_override(self): - a = self.config.a - ab = self.config.a.b - abc = self.config.a.b.c - abd = self.config.a.b.d - - self.config.override({"a": {"b": {"c": 1, "d": 2}}}) - - self.assertEqual(a(), {"b": {"c": 1, "d": 2}}) - self.assertEqual(ab(), {"c": 1, "d": 2}) - self.assertEqual(abc(), 1) - self.assertEqual(abd(), 2) - - def test_configuration_option_override_and_reset_override(self): - # Bug: https://github.com/ets-labs/python-dependency-injector/issues/319 - self.config.from_dict({"a": {"b": {"c": 1}}}) - - self.assertEqual(self.config.a.b.c(), 1) - - with self.config.set("a.b.c", "xxx"): - self.assertEqual(self.config.a.b.c(), "xxx") - self.assertEqual(self.config.a.b.c(), 1) - - with self.config.a.b.c.override("yyy"): - self.assertEqual(self.config.a.b.c(), "yyy") - - self.assertEqual(self.config.a.b.c(), 1) - - def test_providers_with_already_overridden_value(self): - self.config.override({"a": {"b": {"c": 1, "d": 2}}}) - - a = self.config.a - ab = self.config.a.b - abc = self.config.a.b.c - abd = self.config.a.b.d - - self.assertEqual(a(), {"b": {"c": 1, "d": 2}}) - self.assertEqual(ab(), {"c": 1, "d": 2}) - self.assertEqual(abc(), 1) - self.assertEqual(abd(), 2) - - def test_providers_with_default_value(self): - self.config = providers.Configuration( - name="config", default={"a": {"b": {"c": 1, "d": 2}}}) - - a = self.config.a - ab = self.config.a.b - abc = self.config.a.b.c - abd = self.config.a.b.d - - self.assertEqual(a(), {"b": {"c": 1, "d": 2}}) - self.assertEqual(ab(), {"c": 1, "d": 2}) - self.assertEqual(abc(), 1) - self.assertEqual(abd(), 2) - - def test_providers_with_default_value_overriding(self): - self.config = providers.Configuration( - name="config", default={"a": {"b": {"c": 1, "d": 2}}}) - - self.assertEqual(self.config.a(), {"b": {"c": 1, "d": 2}}) - self.assertEqual(self.config.a.b(), {"c": 1, "d": 2}) - self.assertEqual(self.config.a.b.c(), 1) - self.assertEqual(self.config.a.b.d(), 2) - - self.config.override({"a": {"b": {"c": 3, "d": 4}}}) - self.assertEqual(self.config.a(), {"b": {"c": 3, "d": 4}}) - self.assertEqual(self.config.a.b(), {"c": 3, "d": 4}) - self.assertEqual(self.config.a.b.c(), 3) - self.assertEqual(self.config.a.b.d(), 4) - - self.config.reset_override() - self.assertEqual(self.config.a(), {"b": {"c": 1, "d": 2}}) - self.assertEqual(self.config.a.b(), {"c": 1, "d": 2}) - self.assertEqual(self.config.a.b.c(), 1) - self.assertEqual(self.config.a.b.d(), 2) - - def test_value_of_undefined_option(self): - self.assertIsNone(self.config.a()) - - @unittest.skipIf(sys.version_info[:2] == (2, 7), "Python 2.7 does not support this assert") - def test_value_of_undefined_option_in_strict_mode(self): - self.config = providers.Configuration(strict=True) - with self.assertRaisesRegex(errors.Error, "Undefined configuration option \"config.a\""): - self.config.a() - - @unittest.skipIf(sys.version_info[:2] == (2, 7), "Python 2.7 does not support this assert") - def test_value_of_undefined_option_with_root_none_in_strict_mode(self): - self.config = providers.Configuration(strict=True) - self.config.override(None) - with self.assertRaisesRegex(errors.Error, "Undefined configuration option \"config.a\""): - self.config.a() - - def test_value_of_defined_none_option_in_strict_mode(self): - self.config = providers.Configuration(strict=True) - self.config.from_dict({"a": None}) - self.assertIsNone(self.config.a()) - - def test_getting_of_special_attributes(self): - with self.assertRaises(AttributeError): - self.config.__name__ - - def test_getting_of_special_attributes_from_child(self): - a = self.config.a - with self.assertRaises(AttributeError): - a.__name__ - - def test_context_manager_alias(self): - class Container(containers.DeclarativeContainer): - config = providers.Configuration() - - container = Container() - - with container.config as cfg: - cfg.override({"foo": "foo", "bar": "bar"}) - - self.assertEqual(container.config(), {"foo": "foo", "bar": "bar"}) - self.assertEqual(cfg(), {"foo": "foo", "bar": "bar"}) - self.assertIs(container.config, cfg) - - def test_option_context_manager_alias(self): - class Container(containers.DeclarativeContainer): - config = providers.Configuration() - - container = Container() - - with container.config.option as opt: - opt.override({"foo": "foo", "bar": "bar"}) - - self.assertEqual(container.config(), {"option": {"foo": "foo", "bar": "bar"}}) - self.assertEqual(container.config.option(), {"foo": "foo", "bar": "bar"}) - self.assertEqual(opt(), {"foo": "foo", "bar": "bar"}) - self.assertIs(container.config.option, opt) - - def test_missing_key(self): - # See: https://github.com/ets-labs/python-dependency-injector/issues/358 - self.config.override(None) - value = self.config.key() - - self.assertIsNone(value) - - def test_deepcopy(self): - provider = providers.Configuration("config") - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider, providers.Configuration) - - def test_deepcopy_from_memo(self): - provider = providers.Configuration("config") - provider_copy_memo = providers.Configuration("config") - - provider_copy = providers.deepcopy( - provider, memo={id(provider): provider_copy_memo}) - - self.assertIs(provider_copy, provider_copy_memo) - - def test_deepcopy_overridden(self): - provider = providers.Configuration("config") - object_provider = providers.Object(object()) - - provider.override(object_provider) - - provider_copy = providers.deepcopy(provider) - object_provider_copy = provider_copy.overridden[0] - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider, providers.Configuration) - - self.assertIsNot(object_provider, object_provider_copy) - self.assertIsInstance(object_provider_copy, providers.Object) - - def test_repr(self): - self.assertEqual(repr(self.config), - "".format( - repr("config"), - hex(id(self.config)))) - - def test_repr_child(self): - self.assertEqual(repr(self.config.a.b.c), - "".format( - repr("config.a.b.c"), - hex(id(self.config.a.b.c)))) - - -class ConfigLinkingTests(unittest.TestCase): - - class TestCore(containers.DeclarativeContainer): - config = providers.Configuration("core") - value_getter = providers.Callable(lambda _: _, config.value) - - class TestServices(containers.DeclarativeContainer): - config = providers.Configuration("services") - value_getter = providers.Callable(lambda _: _, config.value) - - def test(self): - root_config = providers.Configuration("main") - core = self.TestCore(config=root_config.core) - services = self.TestServices(config=root_config.services) - - root_config.override( - { - "core": { - "value": "core", - }, - "services": { - "value": "services", - }, - }, - ) - - self.assertEqual(core.config(), {"value": "core"}) - self.assertEqual(core.config.value(), "core") - self.assertEqual(core.value_getter(), "core") - - self.assertEqual(services.config(), {"value": "services"}) - self.assertEqual(services.config.value(), "services") - self.assertEqual(services.value_getter(), "services") - - def test_double_override(self): - root_config = providers.Configuration("main") - core = self.TestCore(config=root_config.core) - services = self.TestServices(config=root_config.services) - - root_config.override( - { - "core": { - "value": "core1", - }, - "services": { - "value": "services1", - }, - }, - ) - root_config.override( - { - "core": { - "value": "core2", - }, - "services": { - "value": "services2", - }, - }, - ) - - self.assertEqual(core.config(), {"value": "core2"}) - self.assertEqual(core.config.value(), "core2") - self.assertEqual(core.value_getter(), "core2") - - self.assertEqual(services.config(), {"value": "services2"}) - self.assertEqual(services.config.value(), "services2") - self.assertEqual(services.value_getter(), "services2") - - def test_reset_overriding_cache(self): - # See: https://github.com/ets-labs/python-dependency-injector/issues/428 - class Core(containers.DeclarativeContainer): - config = providers.Configuration() - - greetings = providers.Factory(str, config.greeting) - - class Application(containers.DeclarativeContainer): - config = providers.Configuration() - - core = providers.Container( - Core, - config=config, - ) - - greetings = providers.Factory(str, config.greeting) - - container = Application() - - container.config.set("greeting", "Hello World") - self.assertEqual(container.greetings(), "Hello World") - self.assertEqual(container.core.greetings(), "Hello World") - - container.config.set("greeting", "Hello Bob") - self.assertEqual(container.greetings(), "Hello Bob") - self.assertEqual(container.core.greetings(), "Hello Bob") - - def test_reset_overriding_cache_for_option(self): - # See: https://github.com/ets-labs/python-dependency-injector/issues/428 - class Core(containers.DeclarativeContainer): - config = providers.Configuration() - - greetings = providers.Factory(str, config.greeting) - - class Application(containers.DeclarativeContainer): - config = providers.Configuration() - - core = providers.Container( - Core, - config=config.option, - ) - - greetings = providers.Factory(str, config.option.greeting) - - container = Application() - - container.config.set("option.greeting", "Hello World") - self.assertEqual(container.greetings(), "Hello World") - self.assertEqual(container.core.greetings(), "Hello World") - - container.config.set("option.greeting", "Hello Bob") - self.assertEqual(container.greetings(), "Hello Bob") - self.assertEqual(container.core.greetings(), "Hello Bob") - - -class ConfigFromIniTests(unittest.TestCase): - - def setUp(self): - self.config = providers.Configuration(name="config") - - _, self.config_file_1 = tempfile.mkstemp() - with open(self.config_file_1, "w") as config_file: - config_file.write( - "[section1]\n" - "value1=1\n" - "\n" - "[section2]\n" - "value2=2\n" - ) - - _, self.config_file_2 = tempfile.mkstemp() - with open(self.config_file_2, "w") as config_file: - config_file.write( - "[section1]\n" - "value1=11\n" - "value11=11\n" - "[section3]\n" - "value3=3\n" - ) - - def tearDown(self): - del self.config - os.unlink(self.config_file_1) - os.unlink(self.config_file_2) - - def test(self): - self.config.from_ini(self.config_file_1) - - self.assertEqual(self.config(), {"section1": {"value1": "1"}, "section2": {"value2": "2"}}) - self.assertEqual(self.config.section1(), {"value1": "1"}) - self.assertEqual(self.config.section1.value1(), "1") - self.assertEqual(self.config.section2(), {"value2": "2"}) - self.assertEqual(self.config.section2.value2(), "2") - - def test_option(self): - self.config.option.from_ini(self.config_file_1) - - self.assertEqual(self.config(), {"option": {"section1": {"value1": "1"}, "section2": {"value2": "2"}}}) - self.assertEqual(self.config.option(), {"section1": {"value1": "1"}, "section2": {"value2": "2"}}) - self.assertEqual(self.config.option.section1(), {"value1": "1"}) - self.assertEqual(self.config.option.section1.value1(), "1") - self.assertEqual(self.config.option.section2(), {"value2": "2"}) - self.assertEqual(self.config.option.section2.value2(), "2") - - def test_merge(self): - self.config.from_ini(self.config_file_1) - self.config.from_ini(self.config_file_2) - - self.assertEqual( - self.config(), - { - "section1": { - "value1": "11", - "value11": "11", - }, - "section2": { - "value2": "2", - }, - "section3": { - "value3": "3", - }, - }, - ) - self.assertEqual(self.config.section1(), {"value1": "11", "value11": "11"}) - self.assertEqual(self.config.section1.value1(), "11") - self.assertEqual(self.config.section1.value11(), "11") - self.assertEqual(self.config.section2(), {"value2": "2"}) - self.assertEqual(self.config.section2.value2(), "2") - self.assertEqual(self.config.section3(), {"value3": "3"}) - self.assertEqual(self.config.section3.value3(), "3") - - def test_file_does_not_exist(self): - self.config.from_ini("./does_not_exist.ini") - self.assertEqual(self.config(), {}) - - def test_file_does_not_exist_strict_mode(self): - self.config = providers.Configuration(strict=True) - with self.assertRaises(IOError): - self.config.from_ini("./does_not_exist.ini") - - def test_option_file_does_not_exist(self): - self.config.option.from_ini("does_not_exist.ini") - self.assertIsNone(self.config.option.undefined()) - - def test_option_file_does_not_exist_strict_mode(self): - self.config = providers.Configuration(strict=True) - with self.assertRaises(IOError): - self.config.option.from_ini("./does_not_exist.ini") - - def test_required_file_does_not_exist(self): - with self.assertRaises(IOError): - self.config.from_ini("./does_not_exist.ini", required=True) - - def test_required_option_file_does_not_exist(self): - with self.assertRaises(IOError): - self.config.option.from_ini("./does_not_exist.ini", required=True) - - def test_not_required_file_does_not_exist_strict_mode(self): - self.config = providers.Configuration(strict=True) - self.config.from_ini("./does_not_exist.ini", required=False) - self.assertEqual(self.config(), {}) - - def test_not_required_option_file_does_not_exist_strict_mode(self): - self.config = providers.Configuration(strict=True) - self.config.option.from_ini("./does_not_exist.ini", required=False) - with self.assertRaises(errors.Error): - self.config.option() - - -class ConfigFromIniWithEnvInterpolationTests(unittest.TestCase): - - def setUp(self): - self.config = providers.Configuration(name="config") - - os.environ["CONFIG_TEST_ENV"] = "test-value" - os.environ["CONFIG_TEST_PATH"] = "test-path" - - _, self.config_file = tempfile.mkstemp() - with open(self.config_file, "w") as config_file: - config_file.write( - "[section1]\n" - "value1=${CONFIG_TEST_ENV}\n" - "value2=${CONFIG_TEST_PATH}/path\n" - ) - - def tearDown(self): - del self.config - os.environ.pop("CONFIG_TEST_ENV", None) - os.environ.pop("CONFIG_TEST_PATH", None) - os.unlink(self.config_file) - - def test_env_variable_interpolation(self): - self.config.from_ini(self.config_file) - - self.assertEqual( - self.config(), - { - "section1": { - "value1": "test-value", - "value2": "test-path/path", - }, - }, - ) - self.assertEqual( - self.config.section1(), - { - "value1": "test-value", - "value2": "test-path/path", - }, - ) - self.assertEqual(self.config.section1.value1(), "test-value") - self.assertEqual(self.config.section1.value2(), "test-path/path") - - def test_missing_envs_not_required(self): - del os.environ["CONFIG_TEST_ENV"] - del os.environ["CONFIG_TEST_PATH"] - - self.config.from_ini(self.config_file) - - self.assertEqual( - self.config(), - { - "section1": { - "value1": "", - "value2": "/path", - }, - }, - ) - self.assertEqual( - self.config.section1(), - { - "value1": "", - "value2": "/path", - }, - ) - self.assertEqual(self.config.section1.value1(), "") - self.assertEqual(self.config.section1.value2(), "/path") - - def test_missing_envs_required(self): - with open(self.config_file, "w") as config_file: - config_file.write( - "[section]\n" - "undefined=${UNDEFINED}\n" - ) - - with self.assertRaises(ValueError) as context: - self.config.from_ini(self.config_file, envs_required=True) - - self.assertEqual( - str(context.exception), - "Missing required environment variable \"UNDEFINED\"", - ) - - def test_missing_envs_strict_mode(self): - with open(self.config_file, "w") as config_file: - config_file.write( - "[section]\n" - "undefined=${UNDEFINED}\n" - ) - - self.config.set_strict(True) - with self.assertRaises(ValueError) as context: - self.config.from_ini(self.config_file) - - self.assertEqual( - str(context.exception), - "Missing required environment variable \"UNDEFINED\"", - ) - - def test_option_missing_envs_not_required(self): - del os.environ["CONFIG_TEST_ENV"] - del os.environ["CONFIG_TEST_PATH"] - - self.config.option.from_ini(self.config_file) - - self.assertEqual( - self.config.option(), - { - "section1": { - "value1": "", - "value2": "/path", - }, - }, - ) - self.assertEqual( - self.config.option.section1(), - { - "value1": "", - "value2": "/path", - }, - ) - self.assertEqual(self.config.option.section1.value1(), "") - self.assertEqual(self.config.option.section1.value2(), "/path") - - def test_option_missing_envs_required(self): - with open(self.config_file, "w") as config_file: - config_file.write( - "[section]\n" - "undefined=${UNDEFINED}\n" - ) - - with self.assertRaises(ValueError) as context: - self.config.option.from_ini(self.config_file, envs_required=True) - - self.assertEqual( - str(context.exception), - "Missing required environment variable \"UNDEFINED\"", - ) - - def test_option_missing_envs_strict_mode(self): - with open(self.config_file, "w") as config_file: - config_file.write( - "[section]\n" - "undefined=${UNDEFINED}\n" - ) - - self.config.set_strict(True) - with self.assertRaises(ValueError) as context: - self.config.option.from_ini(self.config_file) - - self.assertEqual( - str(context.exception), - "Missing required environment variable \"UNDEFINED\"", - ) - - def test_default_values(self): - os.environ["DEFINED"] = "defined" - self.addCleanup(os.environ.pop, "DEFINED") - - with open(self.config_file, "w") as config_file: - config_file.write( - "[section]\n" - "defined_with_default=${DEFINED:default}\n" - "undefined_with_default=${UNDEFINED:default}\n" - "complex=${DEFINED}/path/${DEFINED:default}/${UNDEFINED}/${UNDEFINED:default}\n" - ) - - self.config.from_ini(self.config_file) - - self.assertEqual( - self.config.section(), - { - "defined_with_default": "defined", - "undefined_with_default": "default", - "complex": "defined/path/defined//default", - }, - ) - - -class ConfigFromYamlTests(unittest.TestCase): - - def setUp(self): - self.config = providers.Configuration(name="config") - - _, self.config_file_1 = tempfile.mkstemp() - with open(self.config_file_1, "w") as config_file: - config_file.write( - "section1:\n" - " value1: 1\n" - "\n" - "section2:\n" - " value2: 2\n" - ) - - _, self.config_file_2 = tempfile.mkstemp() - with open(self.config_file_2, "w") as config_file: - config_file.write( - "section1:\n" - " value1: 11\n" - " value11: 11\n" - "section3:\n" - " value3: 3\n" - ) - - def tearDown(self): - del self.config - os.unlink(self.config_file_1) - os.unlink(self.config_file_2) - - def test(self): - self.config.from_yaml(self.config_file_1) - - self.assertEqual(self.config(), {"section1": {"value1": 1}, "section2": {"value2": 2}}) - self.assertEqual(self.config.section1(), {"value1": 1}) - self.assertEqual(self.config.section1.value1(), 1) - self.assertEqual(self.config.section2(), {"value2": 2}) - self.assertEqual(self.config.section2.value2(), 2) - - def test_merge(self): - self.config.from_yaml(self.config_file_1) - self.config.from_yaml(self.config_file_2) - - self.assertEqual( - self.config(), - { - "section1": { - "value1": 11, - "value11": 11, - }, - "section2": { - "value2": 2, - }, - "section3": { - "value3": 3, - }, - }, - ) - self.assertEqual(self.config.section1(), {"value1": 11, "value11": 11}) - self.assertEqual(self.config.section1.value1(), 11) - self.assertEqual(self.config.section1.value11(), 11) - self.assertEqual(self.config.section2(), {"value2": 2}) - self.assertEqual(self.config.section2.value2(), 2) - self.assertEqual(self.config.section3(), {"value3": 3}) - self.assertEqual(self.config.section3.value3(), 3) - - def test_file_does_not_exist(self): - self.config.from_yaml("./does_not_exist.yml") - self.assertEqual(self.config(), {}) - - def test_file_does_not_exist_strict_mode(self): - self.config = providers.Configuration(strict=True) - with self.assertRaises(IOError): - self.config.from_yaml("./does_not_exist.yml") - - def test_option_file_does_not_exist(self): - self.config.option.from_yaml("./does_not_exist.yml") - self.assertIsNone(self.config.option()) - - def test_option_file_does_not_exist_strict_mode(self): - self.config = providers.Configuration(strict=True) - with self.assertRaises(IOError): - self.config.option.from_yaml("./does_not_exist.yml") - - def test_required_file_does_not_exist(self): - with self.assertRaises(IOError): - self.config.from_yaml("./does_not_exist.yml", required=True) - - def test_required_option_file_does_not_exist(self): - with self.assertRaises(IOError): - self.config.option.from_yaml("./does_not_exist.yml", required=True) - - def test_not_required_file_does_not_exist_strict_mode(self): - self.config = providers.Configuration(strict=True) - self.config.from_yaml("./does_not_exist.yml", required=False) - self.assertEqual(self.config(), {}) - - def test_not_required_option_file_does_not_exist_strict_mode(self): - self.config = providers.Configuration(strict=True) - self.config.option.from_yaml("./does_not_exist.yml", required=False) - with self.assertRaises(errors.Error): - self.config.option() - - def test_no_yaml_installed(self): - @contextlib.contextmanager - def no_yaml_module(): - yaml = providers.yaml - providers.yaml = None - - yield - - providers.yaml = yaml - - with no_yaml_module(): - with self.assertRaises(errors.Error) as error: - self.config.from_yaml(self.config_file_1) - - self.assertEqual( - error.exception.args[0], - "Unable to load yaml configuration - PyYAML is not installed. " - "Install PyYAML or install Dependency Injector with yaml extras: " - "\"pip install dependency-injector[yaml]\"", - ) - - def test_option_no_yaml_installed(self): - @contextlib.contextmanager - def no_yaml_module(): - yaml = providers.yaml - providers.yaml = None - - yield - - providers.yaml = yaml - - with no_yaml_module(): - with self.assertRaises(errors.Error) as error: - self.config.option.from_yaml(self.config_file_1) - - self.assertEqual( - error.exception.args[0], - "Unable to load yaml configuration - PyYAML is not installed. " - "Install PyYAML or install Dependency Injector with yaml extras: " - "\"pip install dependency-injector[yaml]\"", - ) - - -class ConfigFromYamlWithEnvInterpolationTests(unittest.TestCase): - - def setUp(self): - self.config = providers.Configuration(name="config") - - os.environ["CONFIG_TEST_ENV"] = "test-value" - os.environ["CONFIG_TEST_PATH"] = "test-path" - - _, self.config_file = tempfile.mkstemp() - with open(self.config_file, "w") as config_file: - config_file.write( - "section1:\n" - " value1: ${CONFIG_TEST_ENV}\n" - " value2: ${CONFIG_TEST_PATH}/path\n" - ) - - def tearDown(self): - del self.config - os.environ.pop("CONFIG_TEST_ENV", None) - os.environ.pop("CONFIG_TEST_PATH", None) - os.unlink(self.config_file) - - def test_env_variable_interpolation(self): - self.config.from_yaml(self.config_file) - - self.assertEqual( - self.config(), - { - "section1": { - "value1": "test-value", - "value2": "test-path/path", - }, - }, - ) - self.assertEqual( - self.config.section1(), - { - "value1": "test-value", - "value2": "test-path/path", - }, - ) - self.assertEqual(self.config.section1.value1(), "test-value") - self.assertEqual(self.config.section1.value2(), "test-path/path") - - def test_missing_envs_not_required(self): - del os.environ["CONFIG_TEST_ENV"] - del os.environ["CONFIG_TEST_PATH"] - - self.config.from_yaml(self.config_file) - - self.assertEqual( - self.config(), - { - "section1": { - "value1": None, - "value2": "/path", - }, - }, - ) - self.assertEqual( - self.config.section1(), - { - "value1": None, - "value2": "/path", - }, - ) - self.assertIsNone(self.config.section1.value1()) - self.assertEqual(self.config.section1.value2(), "/path") - - def test_missing_envs_required(self): - with open(self.config_file, "w") as config_file: - config_file.write( - "section:\n" - " undefined: ${UNDEFINED}\n" - ) - - with self.assertRaises(ValueError) as context: - self.config.from_yaml(self.config_file, envs_required=True) - - self.assertEqual( - str(context.exception), - "Missing required environment variable \"UNDEFINED\"", - ) - - def test_missing_envs_strict_mode(self): - with open(self.config_file, "w") as config_file: - config_file.write( - "section:\n" - " undefined: ${UNDEFINED}\n" - ) - - self.config.set_strict(True) - with self.assertRaises(ValueError) as context: - self.config.from_yaml(self.config_file) - - self.assertEqual( - str(context.exception), - "Missing required environment variable \"UNDEFINED\"", - ) - - def test_option_missing_envs_not_required(self): - del os.environ["CONFIG_TEST_ENV"] - del os.environ["CONFIG_TEST_PATH"] - - self.config.option.from_yaml(self.config_file) - - self.assertEqual( - self.config.option(), - { - "section1": { - "value1": None, - "value2": "/path", - }, - }, - ) - self.assertEqual( - self.config.option.section1(), - { - "value1": None, - "value2": "/path", - }, - ) - self.assertIsNone(self.config.option.section1.value1()) - self.assertEqual(self.config.option.section1.value2(), "/path") - - def test_option_missing_envs_required(self): - with open(self.config_file, "w") as config_file: - config_file.write( - "section:\n" - " undefined: ${UNDEFINED}\n" - ) - - with self.assertRaises(ValueError) as context: - self.config.option.from_yaml(self.config_file, envs_required=True) - - self.assertEqual( - str(context.exception), - "Missing required environment variable \"UNDEFINED\"", - ) - - def test_option_missing_envs_strict_mode(self): - with open(self.config_file, "w") as config_file: - config_file.write( - "section:\n" - " undefined: ${UNDEFINED}\n" - ) - - self.config.set_strict(True) - with self.assertRaises(ValueError) as context: - self.config.option.from_yaml(self.config_file) - - self.assertEqual( - str(context.exception), - "Missing required environment variable \"UNDEFINED\"", - ) - - def test_default_values(self): - os.environ["DEFINED"] = "defined" - self.addCleanup(os.environ.pop, "DEFINED") - - with open(self.config_file, "w") as config_file: - config_file.write( - "section:\n" - " defined_with_default: ${DEFINED:default}\n" - " undefined_with_default: ${UNDEFINED:default}\n" - " complex: ${DEFINED}/path/${DEFINED:default}/${UNDEFINED}/${UNDEFINED:default}\n" - ) - - self.config.from_yaml(self.config_file) - - self.assertEqual( - self.config.section(), - { - "defined_with_default": "defined", - "undefined_with_default": "default", - "complex": "defined/path/defined//default", - }, - ) - - def test_option_env_variable_interpolation(self): - self.config.option.from_yaml(self.config_file) - - self.assertEqual( - self.config.option(), - { - "section1": { - "value1": "test-value", - "value2": "test-path/path", - }, - }, - ) - self.assertEqual( - self.config.option.section1(), - { - "value1": "test-value", - "value2": "test-path/path", - }, - ) - self.assertEqual(self.config.option.section1.value1(), "test-value") - self.assertEqual(self.config.option.section1.value2(), "test-path/path") - - def test_env_variable_interpolation_custom_loader(self): - self.config.from_yaml(self.config_file, loader=yaml.UnsafeLoader) - - self.assertEqual( - self.config.section1(), - { - "value1": "test-value", - "value2": "test-path/path", - }, - ) - self.assertEqual(self.config.section1.value1(), "test-value") - self.assertEqual(self.config.section1.value2(), "test-path/path") - - def test_option_env_variable_interpolation_custom_loader(self): - self.config.option.from_yaml(self.config_file, loader=yaml.UnsafeLoader) - - self.assertEqual( - self.config.option.section1(), - { - "value1": "test-value", - "value2": "test-path/path", - }, - ) - self.assertEqual(self.config.option.section1.value1(), "test-value") - self.assertEqual(self.config.option.section1.value2(), "test-path/path") - - -class ConfigFromPydanticTests(unittest.TestCase): - - def setUp(self): - self.config = providers.Configuration(name="config") - - class Section11(pydantic.BaseModel): - value1 = 1 - - class Section12(pydantic.BaseModel): - value2 = 2 - - class Settings1(pydantic.BaseSettings): - section1 = Section11() - section2 = Section12() - - self.Settings1 = Settings1 - - class Section21(pydantic.BaseModel): - value1 = 11 - value11 = 11 - - class Section3(pydantic.BaseModel): - value3 = 3 - - class Settings2(pydantic.BaseSettings): - section1 = Section21() - section3 = Section3() - - self.Settings2 = Settings2 - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test(self): - self.config.from_pydantic(self.Settings1()) - - self.assertEqual(self.config(), {"section1": {"value1": 1}, "section2": {"value2": 2}}) - self.assertEqual(self.config.section1(), {"value1": 1}) - self.assertEqual(self.config.section1.value1(), 1) - self.assertEqual(self.config.section2(), {"value2": 2}) - self.assertEqual(self.config.section2.value2(), 2) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_kwarg(self): - self.config.from_pydantic(self.Settings1(), exclude={"section2"}) - - self.assertEqual(self.config(), {"section1": {"value1": 1}}) - self.assertEqual(self.config.section1(), {"value1": 1}) - self.assertEqual(self.config.section1.value1(), 1) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_merge(self): - self.config.from_pydantic(self.Settings1()) - self.config.from_pydantic(self.Settings2()) - - self.assertEqual( - self.config(), - { - "section1": { - "value1": 11, - "value11": 11, - }, - "section2": { - "value2": 2, - }, - "section3": { - "value3": 3, - }, - }, - ) - self.assertEqual(self.config.section1(), {"value1": 11, "value11": 11}) - self.assertEqual(self.config.section1.value1(), 11) - self.assertEqual(self.config.section1.value11(), 11) - self.assertEqual(self.config.section2(), {"value2": 2}) - self.assertEqual(self.config.section2.value2(), 2) - self.assertEqual(self.config.section3(), {"value3": 3}) - self.assertEqual(self.config.section3.value3(), 3) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_empty_settings(self): - self.config.from_pydantic(pydantic.BaseSettings()) - self.assertEqual(self.config(), {}) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_empty_settings_strict_mode(self): - self.config = providers.Configuration(strict=True) - with self.assertRaises(ValueError): - self.config.from_pydantic(pydantic.BaseSettings()) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_option_empty_settings(self): - self.config.option.from_pydantic(pydantic.BaseSettings()) - self.assertEqual(self.config.option(), {}) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_option_empty_settings_strict_mode(self): - self.config = providers.Configuration(strict=True) - with self.assertRaises(ValueError): - self.config.option.from_pydantic(pydantic.BaseSettings()) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_required_empty_settings(self): - with self.assertRaises(ValueError): - self.config.from_pydantic(pydantic.BaseSettings(), required=True) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_required_option_empty_settings(self): - with self.assertRaises(ValueError): - self.config.option.from_pydantic(pydantic.BaseSettings(), required=True) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_not_required_empty_settings_strict_mode(self): - self.config = providers.Configuration(strict=True) - self.config.from_pydantic(pydantic.BaseSettings(), required=False) - self.assertEqual(self.config(), {}) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_not_required_option_empty_settings_strict_mode(self): - self.config = providers.Configuration(strict=True) - self.config.option.from_pydantic(pydantic.BaseSettings(), required=False) - self.assertEqual(self.config.option(), {}) - self.assertEqual(self.config(), {"option": {}}) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_not_instance_of_settings(self): - with self.assertRaises(errors.Error) as error: - self.config.from_pydantic({}) - - self.assertEqual( - error.exception.args[0], - "Unable to recognize settings instance, expect \"pydantic.BaseSettings\", " - "got {0} instead".format({}) - ) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_option_not_instance_of_settings(self): - with self.assertRaises(errors.Error) as error: - self.config.option.from_pydantic({}) - - self.assertEqual( - error.exception.args[0], - "Unable to recognize settings instance, expect \"pydantic.BaseSettings\", " - "got {0} instead".format({}) - ) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_subclass_instead_of_instance(self): - with self.assertRaises(errors.Error) as error: - self.config.from_pydantic(self.Settings1) - - self.assertEqual( - error.exception.args[0], - "Got settings class, but expect instance: " - "instead \"Settings1\" use \"Settings1()\"" - ) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_option_subclass_instead_of_instance(self): - with self.assertRaises(errors.Error) as error: - self.config.option.from_pydantic(self.Settings1) - - self.assertEqual( - error.exception.args[0], - "Got settings class, but expect instance: " - "instead \"Settings1\" use \"Settings1()\"" - ) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_no_pydantic_installed(self): - @contextlib.contextmanager - def no_pydantic_module(): - pydantic = providers.pydantic - providers.pydantic = None - - yield - - providers.pydantic = pydantic - - with no_pydantic_module(): - with self.assertRaises(errors.Error) as error: - self.config.from_pydantic(self.Settings1()) - - self.assertEqual( - error.exception.args[0], - "Unable to load pydantic configuration - pydantic is not installed. " - "Install pydantic or install Dependency Injector with pydantic extras: " - "\"pip install dependency-injector[pydantic]\"", - ) - - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Pydantic supports Python 3.6+") - def test_option_no_pydantic_installed(self): - @contextlib.contextmanager - def no_pydantic_module(): - pydantic = providers.pydantic - providers.pydantic = None - - yield - - providers.pydantic = pydantic - - with no_pydantic_module(): - with self.assertRaises(errors.Error) as error: - self.config.option.from_pydantic(self.Settings1()) - - self.assertEqual( - error.exception.args[0], - "Unable to load pydantic configuration - pydantic is not installed. " - "Install pydantic or install Dependency Injector with pydantic extras: " - "\"pip install dependency-injector[pydantic]\"", - ) - - -class ConfigFromDict(unittest.TestCase): - - def setUp(self): - self.config = providers.Configuration(name="config") - - self.config_options_1 = { - "section1": { - "value1": "1", - }, - "section2": { - "value2": "2", - }, - } - self.config_options_2 = { - "section1": { - "value1": "11", - "value11": "11", - }, - "section3": { - "value3": "3", - }, - } - - def test(self): - self.config.from_dict(self.config_options_1) - - self.assertEqual(self.config(), {"section1": {"value1": "1"}, "section2": {"value2": "2"}}) - self.assertEqual(self.config.section1(), {"value1": "1"}) - self.assertEqual(self.config.section1.value1(), "1") - self.assertEqual(self.config.section2(), {"value2": "2"}) - self.assertEqual(self.config.section2.value2(), "2") - - def test_merge(self): - self.config.from_dict(self.config_options_1) - self.config.from_dict(self.config_options_2) - - self.assertEqual( - self.config(), - { - "section1": { - "value1": "11", - "value11": "11", - }, - "section2": { - "value2": "2", - }, - "section3": { - "value3": "3", - }, - }, - ) - self.assertEqual(self.config.section1(), {"value1": "11", "value11": "11"}) - self.assertEqual(self.config.section1.value1(), "11") - self.assertEqual(self.config.section1.value11(), "11") - self.assertEqual(self.config.section2(), {"value2": "2"}) - self.assertEqual(self.config.section2.value2(), "2") - self.assertEqual(self.config.section3(), {"value3": "3"}) - self.assertEqual(self.config.section3.value3(), "3") - - def test_empty_dict(self): - self.config.from_dict({}) - self.assertEqual(self.config(), {}) - - def test_option_empty_dict(self): - self.config.option.from_dict({}) - self.assertEqual(self.config.option(), {}) - - def test_empty_dict_in_strict_mode(self): - self.config = providers.Configuration(strict=True) - with self.assertRaises(ValueError): - self.config.from_dict({}) - - def test_option_empty_dict_in_strict_mode(self): - self.config = providers.Configuration(strict=True) - with self.assertRaises(ValueError): - self.config.option.from_dict({}) - - def test_required_empty_dict(self): - with self.assertRaises(ValueError): - self.config.from_dict({}, required=True) - - def test_required_option_empty_dict(self): - with self.assertRaises(ValueError): - self.config.option.from_dict({}, required=True) - - def test_not_required_empty_dict_strict_mode(self): - self.config = providers.Configuration(strict=True) - self.config.from_dict({}, required=False) - self.assertEqual(self.config(), {}) - - def test_not_required_option_empty_dict_strict_mode(self): - self.config = providers.Configuration(strict=True) - self.config.option.from_dict({}, required=False) - self.assertEqual(self.config.option(), {}) - self.assertEqual(self.config(), {"option": {}}) - - -class ConfigFromEnvTests(unittest.TestCase): - - def setUp(self): - self.config = providers.Configuration(name="config") - os.environ["CONFIG_TEST_ENV"] = "test-value" - - def tearDown(self): - del self.config - del os.environ["CONFIG_TEST_ENV"] - - def test(self): - self.config.from_env("CONFIG_TEST_ENV") - self.assertEqual(self.config(), "test-value") - - def test_with_children(self): - self.config.section1.value1.from_env("CONFIG_TEST_ENV") - - self.assertEqual(self.config(), {"section1": {"value1": "test-value"}}) - self.assertEqual(self.config.section1(), {"value1": "test-value"}) - self.assertEqual(self.config.section1.value1(), "test-value") - - def test_default(self): - self.config.from_env("UNDEFINED_ENV", "default-value") - self.assertEqual(self.config(), "default-value") - - def test_default_none(self): - self.config.from_env("UNDEFINED_ENV") - self.assertIsNone(self.config()) - - def test_option_default_none(self): - self.config.option.from_env("UNDEFINED_ENV") - self.assertIsNone(self.config.option()) - - def test_undefined_in_strict_mode(self): - self.config = providers.Configuration(strict=True) - with self.assertRaises(ValueError): - self.config.from_env("UNDEFINED_ENV") - - def test_option_undefined_in_strict_mode(self): - self.config = providers.Configuration(strict=True) - with self.assertRaises(ValueError): - self.config.option.from_env("UNDEFINED_ENV") - - def test_undefined_in_strict_mode_with_default(self): - self.config = providers.Configuration(strict=True) - self.config.from_env("UNDEFINED_ENV", "default-value") - self.assertEqual(self.config(), "default-value") - - def test_option_undefined_in_strict_mode_with_default(self): - self.config = providers.Configuration(strict=True) - self.config.option.from_env("UNDEFINED_ENV", "default-value") - self.assertEqual(self.config.option(), "default-value") - - def test_required_undefined(self): - with self.assertRaises(ValueError): - self.config.from_env("UNDEFINED_ENV", required=True) - - def test_required_undefined_with_default(self): - self.config.from_env("UNDEFINED_ENV", default="default-value", required=True) - self.assertEqual(self.config(), "default-value") - - def test_option_required_undefined(self): - with self.assertRaises(ValueError): - self.config.option.from_env("UNDEFINED_ENV", required=True) - - def test_option_required_undefined_with_default(self): - self.config.option.from_env("UNDEFINED_ENV", default="default-value", required=True) - self.assertEqual(self.config.option(), "default-value") - - def test_not_required_undefined_in_strict_mode(self): - self.config = providers.Configuration(strict=True) - self.config.from_env("UNDEFINED_ENV", required=False) - self.assertIsNone(self.config()) - - def test_option_not_required_undefined_in_strict_mode(self): - self.config = providers.Configuration(strict=True) - self.config.option.from_env("UNDEFINED_ENV", required=False) - self.assertIsNone(self.config.option()) - - def test_not_required_undefined_with_default_in_strict_mode(self): - self.config = providers.Configuration(strict=True) - self.config.from_env("UNDEFINED_ENV", default="default-value", required=False) - self.assertEqual(self.config(), "default-value") - - def test_option_not_required_undefined_with_default_in_strict_mode(self): - self.config = providers.Configuration(strict=True) - self.config.option.from_env("UNDEFINED_ENV", default="default-value", required=False) - self.assertEqual(self.config.option(), "default-value") - - -class ConfigFromValueTests(unittest.TestCase): - - def setUp(self): - self.config = providers.Configuration(name="config") - - def test_from_value(self): - test_value = 123321 - self.config.from_value(test_value) - self.assertEqual(self.config(), test_value) - - def test_option_from_value(self): - test_value_1 = 123 - test_value_2 = 321 - - self.config.option1.from_value(test_value_1) - self.config.option2.from_value(test_value_2) - - self.assertEqual(self.config(), {"option1": test_value_1, "option2": test_value_2}) - self.assertEqual(self.config.option1(), test_value_1) - self.assertEqual(self.config.option2(), test_value_2) diff --git a/tests/unit/providers/test_container_py2_py3.py b/tests/unit/providers/test_container_py2_py3.py index 341d1643..d594f369 100644 --- a/tests/unit/providers/test_container_py2_py3.py +++ b/tests/unit/providers/test_container_py2_py3.py @@ -1,10 +1,9 @@ -"""Dependency injector container provider unit tests.""" +"""Container provider tests.""" import copy -import unittest - from dependency_injector import containers, providers, errors +from pytest import raises TEST_VALUE_1 = "core_section_value1" @@ -30,237 +29,238 @@ def _copied(value): return copy.deepcopy(value) -class TestCore(containers.DeclarativeContainer): +class Core(containers.DeclarativeContainer): config = providers.Configuration("core") value_getter = providers.Callable(lambda _: _, config.section.value) -class TestApplication(containers.DeclarativeContainer): +class Application(containers.DeclarativeContainer): config = providers.Configuration("config") - core = providers.Container(TestCore, config=config.core) + core = providers.Container(Core, config=config.core) dict_factory = providers.Factory(dict, value=core.value_getter) -class ContainerTests(unittest.TestCase): +def test(): + application = Application(config=_copied(TEST_CONFIG_1)) + assert application.dict_factory() == {"value": TEST_VALUE_1} - def test(self): - application = TestApplication(config=_copied(TEST_CONFIG_1)) - self.assertEqual(application.dict_factory(), {"value": TEST_VALUE_1}) - def test_double_override(self): - application = TestApplication() - application.config.override(_copied(TEST_CONFIG_1)) - application.config.override(_copied(TEST_CONFIG_2)) - self.assertEqual(application.dict_factory(), {"value": TEST_VALUE_2}) +def test_double_override(): + application = Application() + application.config.override(_copied(TEST_CONFIG_1)) + application.config.override(_copied(TEST_CONFIG_2)) + assert application.dict_factory() == {"value": TEST_VALUE_2} - def test_override(self): - # See: https://github.com/ets-labs/python-dependency-injector/issues/354 - class D(containers.DeclarativeContainer): - foo = providers.Object("foo") - class A(containers.DeclarativeContainer): - d = providers.DependenciesContainer() - bar = providers.Callable(lambda f: f + "++", d.foo.provided) +def test_override(): + # See: https://github.com/ets-labs/python-dependency-injector/issues/354 + class D(containers.DeclarativeContainer): + foo = providers.Object("foo") - class B(containers.DeclarativeContainer): - d = providers.Container(D) + class A(containers.DeclarativeContainer): + d = providers.DependenciesContainer() + bar = providers.Callable(lambda f: f + "++", d.foo.provided) - a = providers.Container(A, d=d) + class B(containers.DeclarativeContainer): + d = providers.Container(D) - b = B(d=D()) - result = b.a().bar() - self.assertEqual(result, "foo++") + a = providers.Container(A, d=d) - def test_override_not_root_provider(self): - # See: https://github.com/ets-labs/python-dependency-injector/issues/379 - class NestedContainer(containers.DeclarativeContainer): - settings = providers.Configuration() + b = B(d=D()) + result = b.a().bar() + assert result == "foo++" - print_settings = providers.Callable( - lambda s: s, - settings, - ) - class TestContainer(containers.DeclarativeContainer): - settings = providers.Configuration() +def test_override_not_root_provider(): + # See: https://github.com/ets-labs/python-dependency-injector/issues/379 + class NestedContainer(containers.DeclarativeContainer): + settings = providers.Configuration() - root_container = providers.Container( + print_settings = providers.Callable( + lambda s: s, + settings, + ) + + class TestContainer(containers.DeclarativeContainer): + settings = providers.Configuration() + + root_container = providers.Container( + NestedContainer, + settings=settings, + ) + + not_root_container = providers.Selector( + settings.container, + using_factory=providers.Factory( + NestedContainer, + settings=settings, + ), + using_container=providers.Container( NestedContainer, settings=settings, ) - - not_root_container = providers.Selector( - settings.container, - using_factory=providers.Factory( - NestedContainer, - settings=settings, - ), - using_container=providers.Container( - NestedContainer, - settings=settings, - ) - ) - - container_using_factory = TestContainer(settings=dict( - container="using_factory", - foo="bar" - )) - self.assertEqual( - container_using_factory.root_container().print_settings(), - {"container": "using_factory", "foo": "bar"}, - ) - self.assertEqual( - container_using_factory.not_root_container().print_settings(), - {"container": "using_factory", "foo": "bar"}, ) + container_using_factory = TestContainer(settings=dict( + container="using_factory", + foo="bar" + )) + assert container_using_factory.root_container().print_settings() == {"container": "using_factory", "foo": "bar"} + assert container_using_factory.not_root_container().print_settings() == {"container": "using_factory", "foo": "bar"} - container_using_container = TestContainer(settings=dict( - container="using_container", - foo="bar" - )) - self.assertEqual( - container_using_container.root_container().print_settings(), - {"container": "using_container", "foo": "bar"}, - ) - self.assertEqual( - container_using_container.not_root_container().print_settings(), - {"container": "using_container", "foo": "bar"}, - ) + container_using_container = TestContainer(settings=dict( + container="using_container", + foo="bar" + )) + assert container_using_container.root_container().print_settings() == {"container": "using_container", "foo": "bar"} + assert container_using_container.not_root_container().print_settings() == {"container": "using_container", "foo": "bar"} - def test_override_by_not_a_container(self): - provider = providers.Container(TestCore) - with self.assertRaises(errors.Error): - provider.override(providers.Object("foo")) +def test_override_by_not_a_container(): + provider = providers.Container(Core) - def test_lazy_overriding(self): - # See: https://github.com/ets-labs/python-dependency-injector/issues/354 + with raises(errors.Error): + provider.override(providers.Object("foo")) - class D(containers.DeclarativeContainer): - foo = providers.Object("foo") - class A(containers.DeclarativeContainer): - d = providers.DependenciesContainer() - bar = providers.Callable(lambda f: f + "++", d.foo.provided) +def test_lazy_overriding(): + # See: https://github.com/ets-labs/python-dependency-injector/issues/354 + class D(containers.DeclarativeContainer): + foo = providers.Object("foo") - class B(containers.DeclarativeContainer): - d = providers.DependenciesContainer() + class A(containers.DeclarativeContainer): + d = providers.DependenciesContainer() + bar = providers.Callable(lambda f: f + "++", d.foo.provided) - a = providers.Container(A, d=d) + class B(containers.DeclarativeContainer): + d = providers.DependenciesContainer() - b = B(d=D()) - result = b.a().bar() - self.assertEqual(result, "foo++") + a = providers.Container(A, d=d) - def test_lazy_overriding_deep(self): - # Extended version of test_lazy_overriding() + b = B(d=D()) + result = b.a().bar() + assert result == "foo++" - class D(containers.DeclarativeContainer): - foo = providers.Object("foo") - class C(containers.DeclarativeContainer): - d = providers.DependenciesContainer() - bar = providers.Callable(lambda f: f + "++", d.foo.provided) +def test_lazy_overriding_deep(): + # Extended version of test_lazy_overriding() + class D(containers.DeclarativeContainer): + foo = providers.Object("foo") - class A(containers.DeclarativeContainer): - d = providers.DependenciesContainer() - c = providers.Container(C, d=d) + class C(containers.DeclarativeContainer): + d = providers.DependenciesContainer() + bar = providers.Callable(lambda f: f + "++", d.foo.provided) - class B(containers.DeclarativeContainer): - d = providers.DependenciesContainer() + class A(containers.DeclarativeContainer): + d = providers.DependenciesContainer() + c = providers.Container(C, d=d) - a = providers.Container(A, d=d) + class B(containers.DeclarativeContainer): + d = providers.DependenciesContainer() - b = B(d=D()) - result = b.a().c().bar() - self.assertEqual(result, "foo++") + a = providers.Container(A, d=d) - def test_reset_last_overriding(self): - application = TestApplication(config=_copied(TEST_CONFIG_1)) - application.core.override(TestCore(config=_copied(TEST_CONFIG_2["core"]))) + b = B(d=D()) + result = b.a().c().bar() + assert result == "foo++" - application.core.reset_last_overriding() - self.assertEqual(application.dict_factory(), {"value": TEST_VALUE_1}) +def test_reset_last_overriding(): + application = Application(config=_copied(TEST_CONFIG_1)) + application.core.override(Core(config=_copied(TEST_CONFIG_2["core"]))) - def test_reset_last_overriding_only_overridden(self): - application = TestApplication(config=_copied(TEST_CONFIG_1)) - application.core.override(providers.DependenciesContainer(config=_copied(TEST_CONFIG_2["core"]))) + application.core.reset_last_overriding() - application.core.reset_last_overriding() + assert application.dict_factory() == {"value": TEST_VALUE_1} - self.assertEqual(application.dict_factory(), {"value": TEST_VALUE_1}) - def test_override_context_manager(self): - application = TestApplication(config=_copied(TEST_CONFIG_1)) - overriding_core = TestCore(config=_copied(TEST_CONFIG_2["core"])) +def test_reset_last_overriding_only_overridden(): + application = Application(config=_copied(TEST_CONFIG_1)) + application.core.override(providers.DependenciesContainer(config=_copied(TEST_CONFIG_2["core"]))) - with application.core.override(overriding_core) as context_core: - self.assertEqual(application.dict_factory(), {"value": TEST_VALUE_2}) - self.assertIs(context_core(), overriding_core) + application.core.reset_last_overriding() - self.assertEqual(application.dict_factory(), {"value": TEST_VALUE_1}) + assert application.dict_factory() == {"value": TEST_VALUE_1} - def test_reset_override(self): - application = TestApplication(config=_copied(TEST_CONFIG_1)) - application.core.override(TestCore(config=_copied(TEST_CONFIG_2["core"]))) - application.core.reset_override() +def test_override_context_manager(): + application = Application(config=_copied(TEST_CONFIG_1)) + overriding_core = Core(config=_copied(TEST_CONFIG_2["core"])) - self.assertEqual(application.dict_factory(), {"value": None}) + with application.core.override(overriding_core) as context_core: + assert application.dict_factory() == {"value": TEST_VALUE_2} + assert context_core() is overriding_core - def test_reset_override_only_overridden(self): - application = TestApplication(config=_copied(TEST_CONFIG_1)) - application.core.override(providers.DependenciesContainer(config=_copied(TEST_CONFIG_2["core"]))) + assert application.dict_factory() == {"value": TEST_VALUE_1} - application.core.reset_override() - self.assertEqual(application.dict_factory(), {"value": None}) +def test_reset_override(): + application = Application(config=_copied(TEST_CONFIG_1)) + application.core.override(Core(config=_copied(TEST_CONFIG_2["core"]))) - def test_assign_parent(self): - parent = providers.DependenciesContainer() - provider = providers.Container(TestCore) + application.core.reset_override() - provider.assign_parent(parent) + assert application.dict_factory() == {"value": None} - self.assertIs(provider.parent, parent) - def test_parent_name(self): - container = containers.DynamicContainer() - provider = providers.Container(TestCore) - container.name = provider - self.assertEqual(provider.parent_name, "name") +def test_reset_override_only_overridden(): + application = Application(config=_copied(TEST_CONFIG_1)) + application.core.override(providers.DependenciesContainer(config=_copied(TEST_CONFIG_2["core"]))) - def test_parent_name_with_deep_parenting(self): - provider = providers.Container(TestCore) - container = providers.DependenciesContainer(name=provider) - _ = providers.DependenciesContainer(container=container) - self.assertEqual(provider.parent_name, "container.name") + application.core.reset_override() - def test_parent_name_is_none(self): - provider = providers.Container(TestCore) - self.assertIsNone(provider.parent_name) + assert application.dict_factory() == {"value": None} - def test_parent_deepcopy(self): - container = containers.DynamicContainer() - provider = providers.Container(TestCore) - container.name = provider - copied = providers.deepcopy(container) +def test_assign_parent(): + parent = providers.DependenciesContainer() + provider = providers.Container(Core) - self.assertIs(container.name.parent, container) - self.assertIs(copied.name.parent, copied) + provider.assign_parent(parent) - self.assertIsNot(container, copied) - self.assertIsNot(container.name, copied.name) - self.assertIsNot(container.name.parent, copied.name.parent) + assert provider.parent is parent - def test_resolve_provider_name(self): - container = providers.Container(TestCore) - self.assertEqual(container.resolve_provider_name(container.value_getter), "value_getter") - def test_resolve_provider_name_no_provider(self): - container = providers.Container(TestCore) - with self.assertRaises(errors.Error): - container.resolve_provider_name(providers.Provider()) +def test_parent_name(): + container = containers.DynamicContainer() + provider = providers.Container(Core) + container.name = provider + assert provider.parent_name == "name" + + +def test_parent_name_with_deep_parenting(): + provider = providers.Container(Core) + container = providers.DependenciesContainer(name=provider) + _ = providers.DependenciesContainer(container=container) + assert provider.parent_name == "container.name" + + +def test_parent_name_is_none(): + provider = providers.Container(Core) + assert provider.parent_name is None + + +def test_parent_deepcopy(): + container = containers.DynamicContainer() + provider = providers.Container(Core) + container.name = provider + + copied = providers.deepcopy(container) + + assert container.name.parent is container + assert copied.name.parent is copied + + assert container is not copied + assert container.name is not copied.name + assert container.name.parent is not copied.name.parent + + +def test_resolve_provider_name(): + container = providers.Container(Core) + assert container.resolve_provider_name(container.value_getter) == "value_getter" + + +def test_resolve_provider_name_no_provider(): + container = providers.Container(Core) + with raises(errors.Error): + container.resolve_provider_name(providers.Provider()) diff --git a/tests/unit/providers/test_coroutines_py35.py b/tests/unit/providers/test_coroutines_py35.py deleted file mode 100644 index 18eabaf0..00000000 --- a/tests/unit/providers/test_coroutines_py35.py +++ /dev/null @@ -1,310 +0,0 @@ -"""Dependency injector coroutine providers unit tests.""" - -import asyncio -import unittest -import warnings - -from dependency_injector import ( - providers, - errors, -) - -# Runtime import to get asyncutils module -import os -_TOP_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../", - )), -) -import sys -sys.path.append(_TOP_DIR) - -from asyncutils import AsyncTestCase - - -async def _example(arg1, arg2, arg3, arg4): - future = asyncio.Future() - future.set_result(None) - await future - return arg1, arg2, arg3, arg4 - - -def run(main): - loop = asyncio.get_event_loop() - return loop.run_until_complete(main) - - -class CoroutineTests(AsyncTestCase): - - def test_init_with_coroutine(self): - self.assertTrue(providers.Coroutine(_example)) - - def test_init_with_not_coroutine(self): - self.assertRaises(errors.Error, providers.Coroutine, lambda: None) - - def test_init_optional_provides(self): - provider = providers.Coroutine() - provider.set_provides(_example) - self.assertIs(provider.provides, _example) - self.assertEqual(run(provider(1, 2, 3, 4)), (1, 2, 3, 4)) - - def test_set_provides_returns_self(self): - provider = providers.Coroutine() - self.assertIs(provider.set_provides(_example), provider) - - def test_call_with_positional_args(self): - provider = providers.Coroutine(_example, 1, 2, 3, 4) - self.assertTupleEqual(self._run(provider()), (1, 2, 3, 4)) - - def test_call_with_keyword_args(self): - provider = providers.Coroutine(_example, - arg1=1, arg2=2, arg3=3, arg4=4) - self.assertTupleEqual(self._run(provider()), (1, 2, 3, 4)) - - def test_call_with_positional_and_keyword_args(self): - provider = providers.Coroutine(_example, - 1, 2, - arg3=3, arg4=4) - self.assertTupleEqual(run(provider()), (1, 2, 3, 4)) - - def test_call_with_context_args(self): - provider = providers.Coroutine(_example, 1, 2) - self.assertTupleEqual(self._run(provider(3, 4)), (1, 2, 3, 4)) - - def test_call_with_context_kwargs(self): - provider = providers.Coroutine(_example, arg1=1) - self.assertTupleEqual( - self._run(provider(arg2=2, arg3=3, arg4=4)), - (1, 2, 3, 4), - ) - - def test_call_with_context_args_and_kwargs(self): - provider = providers.Coroutine(_example, 1) - self.assertTupleEqual( - self._run(provider(2, arg3=3, arg4=4)), - (1, 2, 3, 4), - ) - - def test_fluent_interface(self): - provider = providers.Coroutine(_example) \ - .add_args(1, 2) \ - .add_kwargs(arg3=3, arg4=4) - - self.assertTupleEqual(self._run(provider()), (1, 2, 3, 4)) - - def test_set_args(self): - provider = providers.Coroutine(_example) \ - .add_args(1, 2) \ - .set_args(3, 4) - self.assertEqual(provider.args, (3, 4)) - - def test_set_kwargs(self): - provider = providers.Coroutine(_example) \ - .add_kwargs(init_arg3=3, init_arg4=4) \ - .set_kwargs(init_arg3=4, init_arg4=5) - self.assertEqual(provider.kwargs, dict(init_arg3=4, init_arg4=5)) - - def test_clear_args(self): - provider = providers.Coroutine(_example) \ - .add_args(1, 2) \ - .clear_args() - self.assertEqual(provider.args, tuple()) - - def test_clear_kwargs(self): - provider = providers.Coroutine(_example) \ - .add_kwargs(init_arg3=3, init_arg4=4) \ - .clear_kwargs() - self.assertEqual(provider.kwargs, dict()) - - def test_call_overridden(self): - provider = providers.Coroutine(_example) - - provider.override(providers.Object((4, 3, 2, 1))) - provider.override(providers.Object((1, 2, 3, 4))) - - self.assertTupleEqual(provider(), (1, 2, 3, 4)) - - def test_deepcopy(self): - provider = providers.Coroutine(_example) - - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertIs(provider.provides, provider_copy.provides) - self.assertIsInstance(provider, providers.Coroutine) - - def test_deepcopy_from_memo(self): - provider = providers.Coroutine(_example) - provider_copy_memo = providers.Coroutine(_example) - - provider_copy = providers.deepcopy( - provider, memo={id(provider): provider_copy_memo}) - - self.assertIs(provider_copy, provider_copy_memo) - - def test_deepcopy_args(self): - provider = providers.Coroutine(_example) - dependent_provider1 = providers.Callable(list) - dependent_provider2 = providers.Callable(dict) - - provider.add_args(dependent_provider1, dependent_provider2) - - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.args[0] - dependent_provider_copy2 = provider_copy.args[1] - - self.assertNotEqual(provider.args, provider_copy.args) - - self.assertIs(dependent_provider1.provides, - dependent_provider_copy1.provides) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) - - self.assertIs(dependent_provider2.provides, - dependent_provider_copy2.provides) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) - - def test_deepcopy_kwargs(self): - provider = providers.Coroutine(_example) - dependent_provider1 = providers.Callable(list) - dependent_provider2 = providers.Callable(dict) - - provider.add_kwargs(a1=dependent_provider1, a2=dependent_provider2) - - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.kwargs["a1"] - dependent_provider_copy2 = provider_copy.kwargs["a2"] - - self.assertNotEqual(provider.kwargs, provider_copy.kwargs) - - self.assertIs(dependent_provider1.provides, - dependent_provider_copy1.provides) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) - - self.assertIs(dependent_provider2.provides, - dependent_provider_copy2.provides) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) - - def test_deepcopy_overridden(self): - provider = providers.Coroutine(_example) - object_provider = providers.Object(object()) - - provider.override(object_provider) - - provider_copy = providers.deepcopy(provider) - object_provider_copy = provider_copy.overridden[0] - - self.assertIsNot(provider, provider_copy) - self.assertIs(provider.provides, provider_copy.provides) - self.assertIsInstance(provider, providers.Callable) - - self.assertIsNot(object_provider, object_provider_copy) - self.assertIsInstance(object_provider_copy, providers.Object) - - def test_repr(self): - provider = providers.Coroutine(_example) - - self.assertEqual(repr(provider), - "".format( - repr(_example), - hex(id(provider)))) - - -class DelegatedCoroutineTests(unittest.TestCase): - - def test_inheritance(self): - self.assertIsInstance(providers.DelegatedCoroutine(_example), - providers.Coroutine) - - def test_is_provider(self): - self.assertTrue( - providers.is_provider(providers.DelegatedCoroutine(_example))) - - def test_is_delegated_provider(self): - provider = providers.DelegatedCoroutine(_example) - self.assertTrue(providers.is_delegated(provider)) - - def test_repr(self): - provider = providers.DelegatedCoroutine(_example) - - self.assertEqual(repr(provider), - "".format( - repr(_example), - hex(id(provider)))) - - -class AbstractCoroutineTests(AsyncTestCase): - - def test_inheritance(self): - self.assertIsInstance(providers.AbstractCoroutine(_example), - providers.Coroutine) - - def test_call_overridden_by_coroutine(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - @asyncio.coroutine - def _abstract_example(): - raise RuntimeError("Should not be raised") - - provider = providers.AbstractCoroutine(_abstract_example) - provider.override(providers.Coroutine(_example)) - - self.assertTrue(self._run(provider(1, 2, 3, 4)), (1, 2, 3, 4)) - - def test_call_overridden_by_delegated_coroutine(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - @asyncio.coroutine - def _abstract_example(): - raise RuntimeError("Should not be raised") - - provider = providers.AbstractCoroutine(_abstract_example) - provider.override(providers.DelegatedCoroutine(_example)) - - self.assertTrue(self._run(provider(1, 2, 3, 4)), (1, 2, 3, 4)) - - def test_call_not_overridden(self): - provider = providers.AbstractCoroutine(_example) - - with self.assertRaises(errors.Error): - provider(1, 2, 3, 4) - - def test_override_by_not_coroutine(self): - provider = providers.AbstractCoroutine(_example) - - with self.assertRaises(errors.Error): - provider.override(providers.Factory(object)) - - def test_provide_not_implemented(self): - provider = providers.AbstractCoroutine(_example) - - with self.assertRaises(NotImplementedError): - provider._provide((1, 2, 3, 4), dict()) - - def test_repr(self): - provider = providers.AbstractCoroutine(_example) - - self.assertEqual(repr(provider), - "".format( - repr(_example), - hex(id(provider)))) - - -class CoroutineDelegateTests(unittest.TestCase): - - def setUp(self): - self.delegated = providers.Coroutine(_example) - self.delegate = providers.CoroutineDelegate(self.delegated) - - def test_is_delegate(self): - self.assertIsInstance(self.delegate, providers.Delegate) - - def test_init_with_not_callable(self): - self.assertRaises(errors.Error, - providers.CoroutineDelegate, - providers.Object(object())) diff --git a/tests/unit/providers/test_delegate_py2_py3.py b/tests/unit/providers/test_delegate_py2_py3.py new file mode 100644 index 00000000..24be727b --- /dev/null +++ b/tests/unit/providers/test_delegate_py2_py3.py @@ -0,0 +1,49 @@ +"""Delegate provider tests.""" + +from dependency_injector import providers, errors +from pytest import fixture, raises + + +@fixture +def provider(): + return providers.Provider() + + +@fixture +def delegate(provider): + return providers.Delegate(provider) + + +def test_is_provider(delegate): + assert providers.is_provider(delegate) is True + + +def test_init_optional_provides(provider): + delegate = providers.Delegate() + delegate.set_provides(provider) + assert delegate.provides is provider + assert delegate() is provider + + +def test_set_provides_returns_self(delegate, provider): + assert delegate.set_provides(provider) is delegate + + +def test_init_with_not_provider(): + with raises(errors.Error): + providers.Delegate(object()) + + +def test_call(delegate, provider): + delegated1 = delegate() + delegated2 = delegate() + + assert delegated1 is provider + assert delegated2 is provider + + +def test_repr(delegate, provider): + assert repr(delegate) == ( + "".format(repr(provider), hex(id(delegate))) + ) diff --git a/tests/unit/providers/test_dependencies_container_py2_py3.py b/tests/unit/providers/test_dependencies_container_py2_py3.py new file mode 100644 index 00000000..600aa2f7 --- /dev/null +++ b/tests/unit/providers/test_dependencies_container_py2_py3.py @@ -0,0 +1,125 @@ +"""DependencyContainer provider tests.""" + +from dependency_injector import containers, providers, errors +from pytest import fixture, raises + + +class Container(containers.DeclarativeContainer): + + dependency = providers.Provider() + + +@fixture +def provider(): + return providers.DependenciesContainer() + + +@fixture +def container(): + return Container() + + +def test_getattr(provider): + has_dependency = hasattr(provider, "dependency") + dependency = provider.dependency + + assert isinstance(dependency, providers.Dependency) + assert dependency is provider.dependency + assert has_dependency is True + assert dependency.last_overriding is None + + +def test_getattr_with_container(provider, container): + provider.override(container) + + dependency = provider.dependency + + assert dependency.overridden == (container.dependency,) + assert dependency.last_overriding is container.dependency + + +def test_providers(provider): + dependency1 = provider.dependency1 + dependency2 = provider.dependency2 + assert provider.providers == {"dependency1": dependency1, "dependency2": dependency2} + + +def test_override(provider, container): + dependency = provider.dependency + provider.override(container) + + assert dependency.overridden == (container.dependency,) + assert dependency.last_overriding is container.dependency + + +def test_reset_last_overriding(provider, container): + dependency = provider.dependency + provider.override(container) + provider.reset_last_overriding() + + assert dependency.last_overriding is None + assert dependency.last_overriding is None + + +def test_reset_override(provider, container): + dependency = provider.dependency + provider.override(container) + provider.reset_override() + + assert dependency.overridden == tuple() + assert not dependency.overridden + + +def test_assign_parent(provider): + parent = providers.DependenciesContainer() + provider.assign_parent(parent) + assert provider.parent is parent + + +def test_parent_name(provider): + container = containers.DynamicContainer() + container.name = provider + assert provider.parent_name == "name" + + +def test_parent_name_with_deep_parenting(provider): + container = providers.DependenciesContainer(name=provider) + _ = providers.DependenciesContainer(container=container) + assert provider.parent_name == "container.name" + + +def test_parent_name_is_none(provider): + assert provider.parent_name is None + + +def test_parent_deepcopy(provider, container): + container.name = provider + copied = providers.deepcopy(container) + + assert container.name.parent is container + assert copied.name.parent is copied + + assert container is not copied + assert container.name is not copied.name + assert container.name.parent is not copied.name.parent + + +def test_parent_set_on__getattr__(provider): + assert isinstance(provider.name, providers.Dependency) + assert provider.name.parent is provider + + +def test_parent_set_on__init__(): + provider = providers.Dependency() + container = providers.DependenciesContainer(name=provider) + assert container.name is provider + assert container.name.parent is container + + +def test_resolve_provider_name(provider): + assert provider.resolve_provider_name(provider.name) == "name" + + +def test_resolve_provider_name_no_provider(provider): + with raises(errors.Error): + provider.resolve_provider_name(providers.Provider()) diff --git a/tests/unit/providers/test_dependency_py2_py3.py b/tests/unit/providers/test_dependency_py2_py3.py new file mode 100644 index 00000000..c4c2cea1 --- /dev/null +++ b/tests/unit/providers/test_dependency_py2_py3.py @@ -0,0 +1,348 @@ +"""Dependency provider tests.""" + +from dependency_injector import containers, providers, errors +from pytest import fixture, raises + + +@fixture +def provider(): + return providers.Dependency(instance_of=list) + + +def test_init_optional(): + list_provider = providers.List(1, 2, 3) + provider = providers.Dependency() + provider.set_instance_of(list) + provider.set_default(list_provider) + + assert provider.instance_of is list + assert provider.default is list_provider + assert provider() == [1, 2, 3] + + +def test_set_instance_of_returns_self(provider): + assert provider.set_instance_of(list) is provider + + +def test_set_default_returns_self(provider): + assert provider.set_default(providers.Provider()) is provider + + +def test_init_with_not_class(): + with raises(TypeError): + providers.Dependency(object()) + + +def test_with_abc(): + try: + import collections.abc as collections_abc + except ImportError: + import collections as collections_abc + + provider = providers.Dependency(collections_abc.Mapping) + provider.provided_by(providers.Factory(dict)) + + assert isinstance(provider(), collections_abc.Mapping) + assert isinstance(provider(), dict) + + +def test_is_provider(provider): + assert providers.is_provider(provider) is True + + +def test_provided_instance_provider(provider): + assert isinstance(provider.provided, providers.ProvidedInstance) + + +def test_default(): + provider = providers.Dependency(instance_of=dict, default={"foo": "bar"}) + assert provider() == {"foo": "bar"} + + +def test_default_attribute(): + provider = providers.Dependency(instance_of=dict, default={"foo": "bar"}) + assert provider.default() == {"foo": "bar"} + + +def test_default_provider(): + provider = providers.Dependency(instance_of=dict, default=providers.Factory(dict, foo="bar")) + assert provider.default() == {"foo": "bar"} + + +def test_default_attribute_provider(): + default = providers.Factory(dict, foo="bar") + provider = providers.Dependency(instance_of=dict, default=default) + + assert provider.default() == {"foo": "bar"} + assert provider.default is default + + +def test_is_defined(provider): + assert provider.is_defined is False + + +def test_is_defined_when_overridden(provider): + provider.override("value") + assert provider.is_defined is True + + +def test_is_defined_with_default(): + provider = providers.Dependency(default="value") + assert provider.is_defined is True + + +def test_call_overridden(provider): + provider.provided_by(providers.Factory(list)) + assert isinstance(provider(), list) + + +def test_call_overridden_but_not_instance_of(provider): + provider.provided_by(providers.Factory(dict)) + with raises(errors.Error): + provider() + + +def test_call_undefined(provider): + with raises(errors.Error, match="Dependency is not defined"): + provider() + + +def test_call_undefined_error_message_with_container_instance_parent(): + class UserService: + def __init__(self, database): + self.database = database + + class Container(containers.DeclarativeContainer): + database = providers.Dependency() + + user_service = providers.Factory( + UserService, + database=database, # <---- missing dependency + ) + + container = Container() + + with raises(errors.Error, match="Dependency \"Container.database\" is not defined"): + container.user_service() + + +def test_call_undefined_error_message_with_container_provider_parent_deep(): + class Database: + pass + + class UserService: + def __init__(self, db): + self.db = db + + class Gateways(containers.DeclarativeContainer): + database_client = providers.Singleton(Database) + + class Services(containers.DeclarativeContainer): + gateways = providers.DependenciesContainer() + + user = providers.Factory( + UserService, + db=gateways.database_client, + ) + + class Container(containers.DeclarativeContainer): + gateways = providers.Container(Gateways) + + services = providers.Container( + Services, + # gateways=gateways, # <---- missing dependency + ) + + container = Container() + + with raises(errors.Error, match="Dependency \"Container.services.gateways.database_client\" is not defined"): + container.services().user() + + +def test_call_undefined_error_message_with_dependenciescontainer_provider_parent(): + class UserService: + def __init__(self, db): + self.db = db + + class Services(containers.DeclarativeContainer): + gateways = providers.DependenciesContainer() + + user = providers.Factory( + UserService, + db=gateways.database_client, # <---- missing dependency + ) + + services = Services() + + with raises(errors.Error, match="Dependency \"Services.gateways.database_client\" is not defined"): + services.user() + + +def test_assign_parent(provider): + parent = providers.DependenciesContainer() + provider.assign_parent(parent) + assert provider.parent is parent + + +def test_parent_name(provider): + container = containers.DynamicContainer() + container.name = provider + assert provider.parent_name == "name" + + +def test_parent_name_with_deep_parenting(provider): + container = providers.DependenciesContainer(name=provider) + _ = providers.DependenciesContainer(container=container) + assert provider.parent_name == "container.name" + + +def test_parent_name_is_none(): + provider = providers.Dependency() + assert provider.parent_name is None + + +def test_parent_deepcopy(provider): + container = containers.DynamicContainer() + container.name = provider + + copied = providers.deepcopy(container) + + assert container.name.parent is container + assert copied.name.parent is copied + + assert container is not copied + assert container.name is not copied.name + assert container.name.parent is not copied.name.parent + + +def test_forward_attr_to_default(): + default = providers.Configuration() + provider = providers.Dependency(default=default) + provider.from_dict({"foo": "bar"}) + assert default() == {"foo": "bar"} + + +def test_forward_attr_to_overriding(provider): + overriding = providers.Configuration() + provider.override(overriding) + provider.from_dict({"foo": "bar"}) + assert overriding() == {"foo": "bar"} + + +def test_forward_attr_to_none(provider): + with raises(AttributeError): + provider.from_dict + + +def test_deepcopy(provider): + provider_copy = providers.deepcopy(provider) + assert provider is not provider_copy + assert isinstance(provider, providers.Dependency) + + +def test_deepcopy_from_memo(provider): + provider_copy_memo = providers.Provider() + provider_copy = providers.deepcopy(provider, memo={id(provider): provider_copy_memo}) + assert provider_copy is provider_copy_memo + + +def test_deepcopy_overridden(provider): + overriding_provider = providers.Provider() + + provider.override(overriding_provider) + + provider_copy = providers.deepcopy(provider) + overriding_provider_copy = provider_copy.overridden[0] + + assert provider is not provider_copy + assert isinstance(provider, providers.Dependency) + + assert overriding_provider is not overriding_provider_copy + assert isinstance(overriding_provider_copy, providers.Provider) + + +def test_deep_copy_default_object(): + default = {"foo": "bar"} + provider = providers.Dependency(dict, default=default) + + provider_copy = providers.deepcopy(provider) + + assert provider_copy() is default + assert provider_copy.default() is default + + +def test_deep_copy_default_provider(): + bar = object() + default = providers.Factory(dict, foo=providers.Object(bar)) + provider = providers.Dependency(dict, default=default) + + provider_copy = providers.deepcopy(provider) + + assert provider_copy() == {"foo": bar} + assert provider_copy.default() == {"foo": bar} + assert provider_copy()["foo"] is bar + + +def test_with_container_default_object(): + default = {"foo": "bar"} + + class Container(containers.DeclarativeContainer): + provider = providers.Dependency(dict, default=default) + + container = Container() + + assert container.provider() is default + assert container.provider.default() is default + + +def test_with_container_default_provider(): + bar = object() + + class Container(containers.DeclarativeContainer): + provider = providers.Dependency(dict, default=providers.Factory(dict, foo=providers.Object(bar))) + + container = Container() + + assert container.provider() == {"foo": bar} + assert container.provider.default() == {"foo": bar} + assert container.provider()["foo"] is bar + + +def test_with_container_default_provider_with_overriding(): + bar = object() + baz = object() + + class Container(containers.DeclarativeContainer): + provider = providers.Dependency(dict, default=providers.Factory(dict, foo=providers.Object(bar))) + + container = Container(provider=providers.Factory(dict, foo=providers.Object(baz))) + + assert container.provider() == {"foo": baz} + assert container.provider.default() == {"foo": bar} + assert container.provider()["foo"] is baz + + +def test_repr(provider): + assert repr(provider) == ( + "".format(repr(list), hex(id(provider))) + ) + + +def test_repr_in_container(): + class Container(containers.DeclarativeContainer): + dependency = providers.Dependency(instance_of=int) + + container = Container() + + assert repr(container.dependency) == ( + "".format( + repr(int), + hex(id(container.dependency)), + ) + ) + + +def test_external_dependency(): + assert isinstance(providers.ExternalDependency(), providers.Dependency) diff --git a/tests/unit/providers/test_dict_py2_py3.py b/tests/unit/providers/test_dict_py2_py3.py index 08e6670a..306d8a23 100644 --- a/tests/unit/providers/test_dict_py2_py3.py +++ b/tests/unit/providers/test_dict_py2_py3.py @@ -1,231 +1,245 @@ -"""Dependency injector dict provider unit tests.""" +"""Dict provider tests.""" import sys -import unittest - from dependency_injector import providers -class DictTests(unittest.TestCase): +def test_is_provider(): + assert providers.is_provider(providers.Dict()) is True - def test_is_provider(self): - self.assertTrue(providers.is_provider(providers.Dict())) - def test_provided_instance_provider(self): - provider = providers.Dict() - self.assertIsInstance(provider.provided, providers.ProvidedInstance) +def test_provided_instance_provider(): + provider = providers.Dict() + assert isinstance(provider.provided, providers.ProvidedInstance) - def test_init_with_non_string_keys(self): - a1 = object() - a2 = object() - provider = providers.Dict({a1: "i1", a2: "i2"}) - dict1 = provider() - dict2 = provider() +def test_init_with_non_string_keys(): + a1 = object() + a2 = object() + provider = providers.Dict({a1: "i1", a2: "i2"}) - self.assertEqual(dict1, {a1: "i1", a2: "i2"}) - self.assertEqual(dict2, {a1: "i1", a2: "i2"}) + dict1 = provider() + dict2 = provider() - self.assertIsNot(dict1, dict2) + assert dict1 == {a1: "i1", a2: "i2"} + assert dict2 == {a1: "i1", a2: "i2"} - def test_init_with_string_and_non_string_keys(self): - a1 = object() - provider = providers.Dict({a1: "i1"}, a2="i2") + assert dict1 is not dict2 - dict1 = provider() - dict2 = provider() - self.assertEqual(dict1, {a1: "i1", "a2": "i2"}) - self.assertEqual(dict2, {a1: "i1", "a2": "i2"}) +def test_init_with_string_and_non_string_keys(): + a1 = object() + provider = providers.Dict({a1: "i1"}, a2="i2") - self.assertIsNot(dict1, dict2) + dict1 = provider() + dict2 = provider() - def test_call_with_init_keyword_args(self): - provider = providers.Dict(a1="i1", a2="i2") + assert dict1 == {a1: "i1", "a2": "i2"} + assert dict2 == {a1: "i1", "a2": "i2"} - dict1 = provider() - dict2 = provider() + assert dict1 is not dict2 - self.assertEqual(dict1, {"a1": "i1", "a2": "i2"}) - self.assertEqual(dict2, {"a1": "i1", "a2": "i2"}) - self.assertIsNot(dict1, dict2) +def test_call_with_init_keyword_args(): + provider = providers.Dict(a1="i1", a2="i2") - def test_call_with_context_keyword_args(self): - provider = providers.Dict(a1="i1", a2="i2") - self.assertEqual( - provider(a3="i3", a4="i4"), - {"a1": "i1", "a2": "i2", "a3": "i3", "a4": "i4"}, - ) + dict1 = provider() + dict2 = provider() - def test_call_with_provider(self): - provider = providers.Dict( - a1=providers.Factory(str, "i1"), - a2=providers.Factory(str, "i2"), - ) - self.assertEqual(provider(), {"a1": "i1", "a2": "i2"}) + assert dict1 == {"a1": "i1", "a2": "i2"} + assert dict2 == {"a1": "i1", "a2": "i2"} - def test_fluent_interface(self): - provider = providers.Dict() \ - .add_kwargs(a1="i1", a2="i2") - self.assertEqual(provider(), {"a1": "i1", "a2": "i2"}) + assert dict1 is not dict2 - def test_add_kwargs(self): - provider = providers.Dict() \ - .add_kwargs(a1="i1") \ - .add_kwargs(a2="i2") - self.assertEqual(provider.kwargs, {"a1": "i1", "a2": "i2"}) - def test_add_kwargs_non_string_keys(self): - a1 = object() - a2 = object() - provider = providers.Dict() \ - .add_kwargs({a1: "i1"}) \ - .add_kwargs({a2: "i2"}) - self.assertEqual(provider.kwargs, {a1: "i1", a2: "i2"}) +def test_call_with_context_keyword_args(): + provider = providers.Dict(a1="i1", a2="i2") + assert provider(a3="i3", a4="i4") == {"a1": "i1", "a2": "i2", "a3": "i3", "a4": "i4"} - def test_add_kwargs_string_and_non_string_keys(self): - a2 = object() - provider = providers.Dict() \ - .add_kwargs(a1="i1") \ - .add_kwargs({a2: "i2"}) - self.assertEqual(provider.kwargs, {"a1": "i1", a2: "i2"}) - def test_set_kwargs(self): - provider = providers.Dict() \ - .add_kwargs(a1="i1", a2="i2") \ - .set_kwargs(a3="i3", a4="i4") - self.assertEqual(provider.kwargs, {"a3": "i3", "a4": "i4"}) +def test_call_with_provider(): + provider = providers.Dict( + a1=providers.Factory(str, "i1"), + a2=providers.Factory(str, "i2"), + ) + assert provider() == {"a1": "i1", "a2": "i2"} - def test_set_kwargs_non_string_keys(self): - a3 = object() - a4 = object() - provider = providers.Dict() \ - .add_kwargs(a1="i1", a2="i2") \ - .set_kwargs({a3: "i3", a4: "i4"}) - self.assertEqual(provider.kwargs, {a3: "i3", a4: "i4"}) - def test_set_kwargs_string_and_non_string_keys(self): - a3 = object() - provider = providers.Dict() \ - .add_kwargs(a1="i1", a2="i2") \ - .set_kwargs({a3: "i3"}, a4="i4") - self.assertEqual(provider.kwargs, {a3: "i3", "a4": "i4"}) +def test_fluent_interface(): + provider = providers.Dict() \ + .add_kwargs(a1="i1", a2="i2") + assert provider() == {"a1": "i1", "a2": "i2"} - def test_clear_kwargs(self): - provider = providers.Dict() \ - .add_kwargs(a1="i1", a2="i2") \ - .clear_kwargs() - self.assertEqual(provider.kwargs, {}) - def test_call_overridden(self): - provider = providers.Dict(a1="i1", a2="i2") - overriding_provider1 = providers.Dict(a2="i2", a3="i3") - overriding_provider2 = providers.Dict(a3="i3", a4="i4") +def test_add_kwargs(): + provider = providers.Dict() \ + .add_kwargs(a1="i1") \ + .add_kwargs(a2="i2") + assert provider.kwargs == {"a1": "i1", "a2": "i2"} - provider.override(overriding_provider1) - provider.override(overriding_provider2) - instance1 = provider() - instance2 = provider() +def test_add_kwargs_non_string_keys(): + a1 = object() + a2 = object() + provider = providers.Dict() \ + .add_kwargs({a1: "i1"}) \ + .add_kwargs({a2: "i2"}) + assert provider.kwargs == {a1: "i1", a2: "i2"} - self.assertIsNot(instance1, instance2) - self.assertEqual(instance1, {"a3": "i3", "a4": "i4"}) - self.assertEqual(instance2, {"a3": "i3", "a4": "i4"}) - def test_deepcopy(self): - provider = providers.Dict(a1="i1", a2="i2") +def test_add_kwargs_string_and_non_string_keys(): + a2 = object() + provider = providers.Dict() \ + .add_kwargs(a1="i1") \ + .add_kwargs({a2: "i2"}) + assert provider.kwargs == {"a1": "i1", a2: "i2"} - provider_copy = providers.deepcopy(provider) - self.assertIsNot(provider, provider_copy) - self.assertEqual(provider.kwargs, provider_copy.kwargs) - self.assertIsInstance(provider, providers.Dict) +def test_set_kwargs(): + provider = providers.Dict() \ + .add_kwargs(a1="i1", a2="i2") \ + .set_kwargs(a3="i3", a4="i4") + assert provider.kwargs == {"a3": "i3", "a4": "i4"} - def test_deepcopy_from_memo(self): - provider = providers.Dict(a1="i1", a2="i2") - provider_copy_memo = providers.Dict(a1="i1", a2="i2") - provider_copy = providers.deepcopy( - provider, - memo={id(provider): provider_copy_memo}, - ) +def test_set_kwargs_non_string_keys(): + a3 = object() + a4 = object() + provider = providers.Dict() \ + .add_kwargs(a1="i1", a2="i2") \ + .set_kwargs({a3: "i3", a4: "i4"}) + assert provider.kwargs == {a3: "i3", a4: "i4"} - self.assertIs(provider_copy, provider_copy_memo) - def test_deepcopy_kwargs(self): - provider = providers.Dict() - dependent_provider1 = providers.Factory(list) - dependent_provider2 = providers.Factory(dict) +def test_set_kwargs_string_and_non_string_keys(): + a3 = object() + provider = providers.Dict() \ + .add_kwargs(a1="i1", a2="i2") \ + .set_kwargs({a3: "i3"}, a4="i4") + assert provider.kwargs == {a3: "i3", "a4": "i4"} - provider.add_kwargs(d1=dependent_provider1, d2=dependent_provider2) - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.kwargs["d1"] - dependent_provider_copy2 = provider_copy.kwargs["d2"] +def test_clear_kwargs(): + provider = providers.Dict() \ + .add_kwargs(a1="i1", a2="i2") \ + .clear_kwargs() + assert provider.kwargs == {} - self.assertNotEqual(provider.kwargs, provider_copy.kwargs) - self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) +def test_call_overridden(): + provider = providers.Dict(a1="i1", a2="i2") + overriding_provider1 = providers.Dict(a2="i2", a3="i3") + overriding_provider2 = providers.Dict(a3="i3", a4="i4") - self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) + provider.override(overriding_provider1) + provider.override(overriding_provider2) - def test_deepcopy_kwargs_non_string_keys(self): - a1 = object() - a2 = object() + instance1 = provider() + instance2 = provider() - dependent_provider1 = providers.Factory(list) - dependent_provider2 = providers.Factory(dict) + assert instance1 is not instance2 + assert instance1 == {"a3": "i3", "a4": "i4"} + assert instance2 == {"a3": "i3", "a4": "i4"} - provider = providers.Dict({a1: dependent_provider1, a2: dependent_provider2}) - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.kwargs[a1] - dependent_provider_copy2 = provider_copy.kwargs[a2] +def test_deepcopy(): + provider = providers.Dict(a1="i1", a2="i2") - self.assertNotEqual(provider.kwargs, provider_copy.kwargs) + provider_copy = providers.deepcopy(provider) - self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) + assert provider is not provider_copy + assert provider.kwargs == provider_copy.kwargs + assert isinstance(provider, providers.Dict) - self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) - def test_deepcopy_overridden(self): - provider = providers.Dict() - object_provider = providers.Object(object()) +def test_deepcopy_from_memo(): + provider = providers.Dict(a1="i1", a2="i2") + provider_copy_memo = providers.Dict(a1="i1", a2="i2") - provider.override(object_provider) + provider_copy = providers.deepcopy( + provider, + memo={id(provider): provider_copy_memo}, + ) - provider_copy = providers.deepcopy(provider) - object_provider_copy = provider_copy.overridden[0] + assert provider_copy is provider_copy_memo - self.assertIsNot(provider, provider_copy) - self.assertEqual(provider.kwargs, provider_copy.kwargs) - self.assertIsInstance(provider, providers.Dict) - self.assertIsNot(object_provider, object_provider_copy) - self.assertIsInstance(object_provider_copy, providers.Object) +def test_deepcopy_kwargs(): + provider = providers.Dict() + dependent_provider1 = providers.Factory(list) + dependent_provider2 = providers.Factory(dict) - def test_deepcopy_with_sys_streams(self): - provider = providers.Dict() - provider.add_kwargs(stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr) + provider.add_kwargs(d1=dependent_provider1, d2=dependent_provider2) - provider_copy = providers.deepcopy(provider) + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.kwargs["d1"] + dependent_provider_copy2 = provider_copy.kwargs["d2"] - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider_copy, providers.Dict) - self.assertIs(provider.kwargs["stdin"], sys.stdin) - self.assertIs(provider.kwargs["stdout"], sys.stdout) - self.assertIs(provider.kwargs["stderr"], sys.stderr) + assert provider.kwargs != provider_copy.kwargs - def test_repr(self): - provider = providers.Dict(a1=1, a2=2) - self.assertEqual(repr(provider), - "".format( - repr(provider.kwargs), - hex(id(provider)))) + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_kwargs_non_string_keys(): + a1 = object() + a2 = object() + + dependent_provider1 = providers.Factory(list) + dependent_provider2 = providers.Factory(dict) + + provider = providers.Dict({a1: dependent_provider1, a2: dependent_provider2}) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.kwargs[a1] + dependent_provider_copy2 = provider_copy.kwargs[a2] + + assert provider.kwargs != provider_copy.kwargs + + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_overridden(): + provider = providers.Dict() + object_provider = providers.Object(object()) + + provider.override(object_provider) + + provider_copy = providers.deepcopy(provider) + object_provider_copy = provider_copy.overridden[0] + + assert provider is not provider_copy + assert provider.kwargs == provider_copy.kwargs + assert isinstance(provider, providers.Dict) + + assert object_provider is not object_provider_copy + assert isinstance(object_provider_copy, providers.Object) + + +def test_deepcopy_with_sys_streams(): + provider = providers.Dict() + provider.add_kwargs(stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert isinstance(provider_copy, providers.Dict) + assert provider.kwargs["stdin"] is sys.stdin + assert provider.kwargs["stdout"] is sys.stdout + assert provider.kwargs["stderr"] is sys.stderr + + +def test_repr(): + provider = providers.Dict(a1=1, a2=2) + assert repr(provider) == ( + "".format(repr(provider.kwargs), hex(id(provider))) + ) diff --git a/tests/unit/providers/test_factories_py2_py3.py b/tests/unit/providers/test_factories_py2_py3.py deleted file mode 100644 index c0ed5971..00000000 --- a/tests/unit/providers/test_factories_py2_py3.py +++ /dev/null @@ -1,690 +0,0 @@ -"""Dependency injector factory providers unit tests.""" - -import sys - -import unittest - -from dependency_injector import ( - containers, - providers, - errors, -) - - -class Example(object): - - def __init__(self, init_arg1=None, init_arg2=None, init_arg3=None, - init_arg4=None): - self.init_arg1 = init_arg1 - self.init_arg2 = init_arg2 - self.init_arg3 = init_arg3 - self.init_arg4 = init_arg4 - - self.attribute1 = None - self.attribute2 = None - - -class FactoryTests(unittest.TestCase): - - def test_is_provider(self): - self.assertTrue(providers.is_provider(providers.Factory(Example))) - - def test_init_with_callable(self): - self.assertTrue(providers.Factory(credits)) - - def test_init_with_not_callable(self): - self.assertRaises(errors.Error, providers.Factory, 123) - - def test_init_optional_provides(self): - provider = providers.Factory() - provider.set_provides(object) - self.assertIs(provider.provides, object) - self.assertIsInstance(provider(), object) - - def test_set_provides_returns_self(self): - provider = providers.Factory() - self.assertIs(provider.set_provides(object), provider) - - def test_init_with_valid_provided_type(self): - class ExampleProvider(providers.Factory): - provided_type = Example - - example_provider = ExampleProvider(Example, 1, 2) - - self.assertIsInstance(example_provider(), Example) - - def test_init_with_valid_provided_subtype(self): - class ExampleProvider(providers.Factory): - provided_type = Example - - class NewExampe(Example): - pass - - example_provider = ExampleProvider(NewExampe, 1, 2) - - self.assertIsInstance(example_provider(), NewExampe) - - def test_init_with_invalid_provided_type(self): - class ExampleProvider(providers.Factory): - provided_type = Example - - with self.assertRaises(errors.Error): - ExampleProvider(list) - - def test_provided_instance_provider(self): - provider = providers.Factory(Example) - self.assertIsInstance(provider.provided, providers.ProvidedInstance) - - def test_call(self): - provider = providers.Factory(Example) - - instance1 = provider() - instance2 = provider() - - self.assertIsNot(instance1, instance2) - self.assertIsInstance(instance1, Example) - self.assertIsInstance(instance2, Example) - - def test_call_with_init_positional_args(self): - provider = providers.Factory(Example, "i1", "i2") - - instance1 = provider() - instance2 = provider() - - self.assertEqual(instance1.init_arg1, "i1") - self.assertEqual(instance1.init_arg2, "i2") - - self.assertEqual(instance2.init_arg1, "i1") - self.assertEqual(instance2.init_arg2, "i2") - - self.assertIsNot(instance1, instance2) - self.assertIsInstance(instance1, Example) - self.assertIsInstance(instance2, Example) - - def test_call_with_init_keyword_args(self): - provider = providers.Factory(Example, init_arg1="i1", init_arg2="i2") - - instance1 = provider() - instance2 = provider() - - self.assertEqual(instance1.init_arg1, "i1") - self.assertEqual(instance1.init_arg2, "i2") - - self.assertEqual(instance2.init_arg1, "i1") - self.assertEqual(instance2.init_arg2, "i2") - - self.assertIsNot(instance1, instance2) - self.assertIsInstance(instance1, Example) - self.assertIsInstance(instance2, Example) - - def test_call_with_init_positional_and_keyword_args(self): - provider = providers.Factory(Example, "i1", init_arg2="i2") - - instance1 = provider() - instance2 = provider() - - self.assertEqual(instance1.init_arg1, "i1") - self.assertEqual(instance1.init_arg2, "i2") - - self.assertEqual(instance2.init_arg1, "i1") - self.assertEqual(instance2.init_arg2, "i2") - - self.assertIsNot(instance1, instance2) - self.assertIsInstance(instance1, Example) - self.assertIsInstance(instance2, Example) - - def test_call_with_attributes(self): - provider = providers.Factory(Example) - provider.add_attributes(attribute1="a1", attribute2="a2") - - instance1 = provider() - instance2 = provider() - - self.assertEqual(instance1.attribute1, "a1") - self.assertEqual(instance1.attribute2, "a2") - - self.assertEqual(instance2.attribute1, "a1") - self.assertEqual(instance2.attribute2, "a2") - - self.assertIsNot(instance1, instance2) - self.assertIsInstance(instance1, Example) - self.assertIsInstance(instance2, Example) - - def test_call_with_context_args(self): - provider = providers.Factory(Example, 11, 22) - - instance = provider(33, 44) - - self.assertEqual(instance.init_arg1, 11) - self.assertEqual(instance.init_arg2, 22) - self.assertEqual(instance.init_arg3, 33) - self.assertEqual(instance.init_arg4, 44) - - def test_call_with_context_kwargs(self): - provider = providers.Factory(Example, init_arg1=1) - - instance1 = provider(init_arg2=22) - self.assertEqual(instance1.init_arg1, 1) - self.assertEqual(instance1.init_arg2, 22) - - instance2 = provider(init_arg1=11, init_arg2=22) - self.assertEqual(instance2.init_arg1, 11) - self.assertEqual(instance2.init_arg2, 22) - - def test_call_with_context_args_and_kwargs(self): - provider = providers.Factory(Example, 11) - - instance = provider(22, init_arg3=33, init_arg4=44) - - self.assertEqual(instance.init_arg1, 11) - self.assertEqual(instance.init_arg2, 22) - self.assertEqual(instance.init_arg3, 33) - self.assertEqual(instance.init_arg4, 44) - - def test_call_with_deep_context_kwargs(self): - """`Factory` providers deep init injections example.""" - class Regularizer: - def __init__(self, alpha): - self.alpha = alpha - - class Loss: - def __init__(self, regularizer): - self.regularizer = regularizer - - class ClassificationTask: - def __init__(self, loss): - self.loss = loss - - class Algorithm: - def __init__(self, task): - self.task = task - - algorithm_factory = providers.Factory( - Algorithm, - task=providers.Factory( - ClassificationTask, - loss=providers.Factory( - Loss, - regularizer=providers.Factory( - Regularizer, - ), - ), - ), - ) - - algorithm_1 = algorithm_factory(task__loss__regularizer__alpha=0.5) - algorithm_2 = algorithm_factory(task__loss__regularizer__alpha=0.7) - algorithm_3 = algorithm_factory(task__loss__regularizer=Regularizer(alpha=0.8)) - - self.assertEqual(algorithm_1.task.loss.regularizer.alpha, 0.5) - self.assertEqual(algorithm_2.task.loss.regularizer.alpha, 0.7) - self.assertEqual(algorithm_3.task.loss.regularizer.alpha, 0.8) - - def test_fluent_interface(self): - provider = providers.Factory(Example) \ - .add_args(1, 2) \ - .add_kwargs(init_arg3=3, init_arg4=4) \ - .add_attributes(attribute1=5, attribute2=6) - - instance = provider() - - self.assertEqual(instance.init_arg1, 1) - self.assertEqual(instance.init_arg2, 2) - self.assertEqual(instance.init_arg3, 3) - self.assertEqual(instance.init_arg4, 4) - self.assertEqual(instance.attribute1, 5) - self.assertEqual(instance.attribute2, 6) - - def test_set_args(self): - provider = providers.Factory(Example) \ - .add_args(1, 2) \ - .set_args(3, 4) - self.assertEqual(provider.args, (3, 4)) - - def test_set_kwargs(self): - provider = providers.Factory(Example) \ - .add_kwargs(init_arg3=3, init_arg4=4) \ - .set_kwargs(init_arg3=4, init_arg4=5) - self.assertEqual(provider.kwargs, dict(init_arg3=4, init_arg4=5)) - - def test_set_attributes(self): - provider = providers.Factory(Example) \ - .add_attributes(attribute1=5, attribute2=6) \ - .set_attributes(attribute1=6, attribute2=7) - self.assertEqual(provider.attributes, dict(attribute1=6, attribute2=7)) - - def test_clear_args(self): - provider = providers.Factory(Example) \ - .add_args(1, 2) \ - .clear_args() - self.assertEqual(provider.args, tuple()) - - def test_clear_kwargs(self): - provider = providers.Factory(Example) \ - .add_kwargs(init_arg3=3, init_arg4=4) \ - .clear_kwargs() - self.assertEqual(provider.kwargs, dict()) - - def test_clear_attributes(self): - provider = providers.Factory(Example) \ - .add_attributes(attribute1=5, attribute2=6) \ - .clear_attributes() - self.assertEqual(provider.attributes, dict()) - - def test_call_overridden(self): - provider = providers.Factory(Example) - overriding_provider1 = providers.Factory(dict) - overriding_provider2 = providers.Factory(list) - - provider.override(overriding_provider1) - provider.override(overriding_provider2) - - instance1 = provider() - instance2 = provider() - - self.assertIsNot(instance1, instance2) - self.assertIsInstance(instance1, list) - self.assertIsInstance(instance2, list) - - def test_deepcopy(self): - provider = providers.Factory(Example) - - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertIs(provider.cls, provider_copy.cls) - self.assertIsInstance(provider, providers.Factory) - - def test_deepcopy_from_memo(self): - provider = providers.Factory(Example) - provider_copy_memo = providers.Factory(Example) - - provider_copy = providers.deepcopy( - provider, memo={id(provider): provider_copy_memo}) - - self.assertIs(provider_copy, provider_copy_memo) - - def test_deepcopy_args(self): - provider = providers.Factory(Example) - dependent_provider1 = providers.Factory(list) - dependent_provider2 = providers.Factory(dict) - - provider.add_args(dependent_provider1, dependent_provider2) - - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.args[0] - dependent_provider_copy2 = provider_copy.args[1] - - self.assertNotEqual(provider.args, provider_copy.args) - - self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) - - self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) - - def test_deepcopy_kwargs(self): - provider = providers.Factory(Example) - dependent_provider1 = providers.Factory(list) - dependent_provider2 = providers.Factory(dict) - - provider.add_kwargs(a1=dependent_provider1, a2=dependent_provider2) - - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.kwargs["a1"] - dependent_provider_copy2 = provider_copy.kwargs["a2"] - - self.assertNotEqual(provider.kwargs, provider_copy.kwargs) - - self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) - - self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) - - def test_deepcopy_attributes(self): - provider = providers.Factory(Example) - dependent_provider1 = providers.Factory(list) - dependent_provider2 = providers.Factory(dict) - - provider.add_attributes(a1=dependent_provider1, a2=dependent_provider2) - - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.attributes["a1"] - dependent_provider_copy2 = provider_copy.attributes["a2"] - - self.assertNotEqual(provider.attributes, provider_copy.attributes) - - self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) - - self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) - - def test_deepcopy_overridden(self): - provider = providers.Factory(Example) - object_provider = providers.Object(object()) - - provider.override(object_provider) - - provider_copy = providers.deepcopy(provider) - object_provider_copy = provider_copy.overridden[0] - - self.assertIsNot(provider, provider_copy) - self.assertIs(provider.cls, provider_copy.cls) - self.assertIsInstance(provider, providers.Factory) - - self.assertIsNot(object_provider, object_provider_copy) - self.assertIsInstance(object_provider_copy, providers.Object) - - def test_deepcopy_with_sys_streams(self): - provider = providers.Factory(Example) - provider.add_args(sys.stdin) - provider.add_kwargs(a2=sys.stdout) - provider.add_attributes(a3=sys.stderr) - - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider_copy, providers.Factory) - self.assertIs(provider.args[0], sys.stdin) - self.assertIs(provider.kwargs["a2"], sys.stdout) - self.assertIs(provider.attributes["a3"], sys.stderr) - - def test_repr(self): - provider = providers.Factory(Example) - - self.assertEqual(repr(provider), - "".format( - repr(Example), - hex(id(provider)))) - - -class DelegatedFactoryTests(unittest.TestCase): - - def test_inheritance(self): - self.assertIsInstance(providers.DelegatedFactory(object), - providers.Factory) - - def test_is_provider(self): - self.assertTrue( - providers.is_provider(providers.DelegatedFactory(object))) - - def test_is_delegated_provider(self): - self.assertTrue( - providers.is_delegated(providers.DelegatedFactory(object))) - - def test_repr(self): - provider = providers.DelegatedFactory(Example) - - self.assertEqual(repr(provider), - "".format( - repr(Example), - hex(id(provider)))) - - -class AbstractFactoryTests(unittest.TestCase): - - def test_inheritance(self): - self.assertIsInstance(providers.AbstractFactory(Example), - providers.Factory) - - def test_call_overridden_by_factory(self): - provider = providers.AbstractFactory(object) - provider.override(providers.Factory(Example)) - - self.assertIsInstance(provider(), Example) - - def test_call_overridden_by_delegated_factory(self): - provider = providers.AbstractFactory(object) - provider.override(providers.DelegatedFactory(Example)) - - self.assertIsInstance(provider(), Example) - - def test_call_not_overridden(self): - provider = providers.AbstractFactory(object) - - with self.assertRaises(errors.Error): - provider() - - def test_override_by_not_factory(self): - provider = providers.AbstractFactory(object) - - with self.assertRaises(errors.Error): - provider.override(providers.Callable(object)) - - def test_provide_not_implemented(self): - provider = providers.AbstractFactory(Example) - - with self.assertRaises(NotImplementedError): - provider._provide(tuple(), dict()) - - def test_repr(self): - provider = providers.AbstractFactory(Example) - - self.assertEqual(repr(provider), - "".format( - repr(Example), - hex(id(provider)))) - - -class FactoryDelegateTests(unittest.TestCase): - - def setUp(self): - self.delegated = providers.Factory(object) - self.delegate = providers.FactoryDelegate(self.delegated) - - def test_is_delegate(self): - self.assertIsInstance(self.delegate, providers.Delegate) - - def test_init_with_not_factory(self): - self.assertRaises(errors.Error, - providers.FactoryDelegate, - providers.Object(object())) - - -class FactoryAggregateTests(unittest.TestCase): - - class ExampleA(Example): - pass - - class ExampleB(Example): - pass - - def setUp(self): - self.example_a_factory = providers.Factory(self.ExampleA) - self.example_b_factory = providers.Factory(self.ExampleB) - self.factory_aggregate = providers.FactoryAggregate( - example_a=self.example_a_factory, - example_b=self.example_b_factory, - ) - - def test_is_provider(self): - self.assertTrue(providers.is_provider(self.factory_aggregate)) - - def test_is_delegated_provider(self): - self.assertTrue(providers.is_delegated(self.factory_aggregate)) - - def test_init_with_non_string_keys(self): - factory = providers.FactoryAggregate({ - self.ExampleA: self.example_a_factory, - self.ExampleB: self.example_b_factory, - }) - - object_a = factory(self.ExampleA, 1, 2, init_arg3=3, init_arg4=4) - object_b = factory(self.ExampleB, 11, 22, init_arg3=33, init_arg4=44) - - self.assertIsInstance(object_a, self.ExampleA) - self.assertEqual(object_a.init_arg1, 1) - self.assertEqual(object_a.init_arg2, 2) - self.assertEqual(object_a.init_arg3, 3) - self.assertEqual(object_a.init_arg4, 4) - - self.assertIsInstance(object_b, self.ExampleB) - self.assertEqual(object_b.init_arg1, 11) - self.assertEqual(object_b.init_arg2, 22) - self.assertEqual(object_b.init_arg3, 33) - self.assertEqual(object_b.init_arg4, 44) - - self.assertEqual( - factory.factories, - { - self.ExampleA: self.example_a_factory, - self.ExampleB: self.example_b_factory, - }, - ) - - def test_init_with_not_a_factory(self): - with self.assertRaises(errors.Error): - providers.FactoryAggregate( - example_a=providers.Factory(self.ExampleA), - example_b=object()) - - def test_init_optional_factories(self): - provider = providers.FactoryAggregate() - provider.set_factories( - example_a=self.example_a_factory, - example_b=self.example_b_factory, - ) - self.assertEqual( - provider.factories, - { - "example_a": self.example_a_factory, - "example_b": self.example_b_factory, - }, - ) - self.assertIsInstance(provider("example_a"), self.ExampleA) - self.assertIsInstance(provider("example_b"), self.ExampleB) - - def test_set_factories_with_non_string_keys(self): - factory = providers.FactoryAggregate() - factory.set_factories({ - self.ExampleA: self.example_a_factory, - self.ExampleB: self.example_b_factory, - }) - - object_a = factory(self.ExampleA, 1, 2, init_arg3=3, init_arg4=4) - object_b = factory(self.ExampleB, 11, 22, init_arg3=33, init_arg4=44) - - self.assertIsInstance(object_a, self.ExampleA) - self.assertEqual(object_a.init_arg1, 1) - self.assertEqual(object_a.init_arg2, 2) - self.assertEqual(object_a.init_arg3, 3) - self.assertEqual(object_a.init_arg4, 4) - - self.assertIsInstance(object_b, self.ExampleB) - self.assertEqual(object_b.init_arg1, 11) - self.assertEqual(object_b.init_arg2, 22) - self.assertEqual(object_b.init_arg3, 33) - self.assertEqual(object_b.init_arg4, 44) - - self.assertEqual( - factory.factories, - { - self.ExampleA: self.example_a_factory, - self.ExampleB: self.example_b_factory, - }, - ) - - def test_set_factories_returns_self(self): - provider = providers.FactoryAggregate() - self.assertIs(provider.set_factories(example_a=self.example_a_factory), provider) - - def test_call(self): - object_a = self.factory_aggregate("example_a", - 1, 2, init_arg3=3, init_arg4=4) - object_b = self.factory_aggregate("example_b", - 11, 22, init_arg3=33, init_arg4=44) - - self.assertIsInstance(object_a, self.ExampleA) - self.assertEqual(object_a.init_arg1, 1) - self.assertEqual(object_a.init_arg2, 2) - self.assertEqual(object_a.init_arg3, 3) - self.assertEqual(object_a.init_arg4, 4) - - self.assertIsInstance(object_b, self.ExampleB) - self.assertEqual(object_b.init_arg1, 11) - self.assertEqual(object_b.init_arg2, 22) - self.assertEqual(object_b.init_arg3, 33) - self.assertEqual(object_b.init_arg4, 44) - - def test_call_factory_name_as_kwarg(self): - object_a = self.factory_aggregate( - factory_name="example_a", - init_arg1=1, - init_arg2=2, - init_arg3=3, - init_arg4=4, - ) - self.assertIsInstance(object_a, self.ExampleA) - self.assertEqual(object_a.init_arg1, 1) - self.assertEqual(object_a.init_arg2, 2) - self.assertEqual(object_a.init_arg3, 3) - self.assertEqual(object_a.init_arg4, 4) - - def test_call_no_factory_name(self): - with self.assertRaises(TypeError): - self.factory_aggregate() - - def test_call_no_such_provider(self): - with self.assertRaises(errors.NoSuchProviderError): - self.factory_aggregate("unknown") - - def test_overridden(self): - with self.assertRaises(errors.Error): - self.factory_aggregate.override(providers.Object(object())) - - def test_getattr(self): - self.assertIs(self.factory_aggregate.example_a, self.example_a_factory) - self.assertIs(self.factory_aggregate.example_b, self.example_b_factory) - - def test_getattr_no_such_provider(self): - with self.assertRaises(errors.NoSuchProviderError): - self.factory_aggregate.unknown - - def test_factories(self): - self.assertDictEqual(self.factory_aggregate.factories, - dict(example_a=self.example_a_factory, - example_b=self.example_b_factory)) - - def test_deepcopy(self): - provider_copy = providers.deepcopy(self.factory_aggregate) - - self.assertIsNot(self.factory_aggregate, provider_copy) - self.assertIsInstance(provider_copy, type(self.factory_aggregate)) - - self.assertIsNot(self.factory_aggregate.example_a, provider_copy.example_a) - self.assertIsInstance(self.factory_aggregate.example_a, type(provider_copy.example_a)) - self.assertIs(self.factory_aggregate.example_a.cls, provider_copy.example_a.cls) - - self.assertIsNot(self.factory_aggregate.example_b, provider_copy.example_b) - self.assertIsInstance(self.factory_aggregate.example_b, type(provider_copy.example_b)) - self.assertIs(self.factory_aggregate.example_b.cls, provider_copy.example_b.cls) - - def test_deepcopy_with_non_string_keys(self): - factory_aggregate = providers.FactoryAggregate({ - self.ExampleA: self.example_a_factory, - self.ExampleB: self.example_b_factory, - }) - provider_copy = providers.deepcopy(factory_aggregate) - - self.assertIsNot(factory_aggregate, provider_copy) - self.assertIsInstance(provider_copy, type(factory_aggregate)) - - self.assertIsNot(factory_aggregate.factories[self.ExampleA], provider_copy.factories[self.ExampleA]) - self.assertIsInstance(factory_aggregate.factories[self.ExampleA], type(provider_copy.factories[self.ExampleA])) - self.assertIs(factory_aggregate.factories[self.ExampleA].cls, provider_copy.factories[self.ExampleA].cls) - - self.assertIsNot(factory_aggregate.factories[self.ExampleB], provider_copy.factories[self.ExampleB]) - self.assertIsInstance(factory_aggregate.factories[self.ExampleB], type(provider_copy.factories[self.ExampleB])) - self.assertIs(factory_aggregate.factories[self.ExampleB].cls, provider_copy.factories[self.ExampleB].cls) - - def test_repr(self): - self.assertEqual(repr(self.factory_aggregate), - "".format( - repr(self.factory_aggregate.factories), - hex(id(self.factory_aggregate)))) diff --git a/tests/unit/providers/test_injections_py2_py3.py b/tests/unit/providers/test_injections_py2_py3.py deleted file mode 100644 index c257f230..00000000 --- a/tests/unit/providers/test_injections_py2_py3.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Dependency injector injections unit tests.""" - -import unittest - -from dependency_injector import providers - - -class PositionalInjectionTests(unittest.TestCase): - - def test_isinstance(self): - injection = providers.PositionalInjection(1) - self.assertIsInstance(injection, providers.Injection) - - def test_get_value_with_not_provider(self): - injection = providers.PositionalInjection(123) - self.assertEqual(injection.get_value(), 123) - - def test_get_value_with_factory(self): - injection = providers.PositionalInjection(providers.Factory(object)) - - obj1 = injection.get_value() - obj2 = injection.get_value() - - self.assertIs(type(obj1), object) - self.assertIs(type(obj2), object) - self.assertIsNot(obj1, obj2) - - def test_get_original_value(self): - provider = providers.Factory(object) - injection = providers.PositionalInjection(provider) - self.assertIs(injection.get_original_value(), provider) - - def test_deepcopy(self): - provider = providers.Factory(object) - injection = providers.PositionalInjection(provider) - - injection_copy = providers.deepcopy(injection) - - self.assertIsNot(injection_copy, injection) - self.assertIsNot(injection_copy.get_original_value(), - injection.get_original_value()) - - def test_deepcopy_memo(self): - provider = providers.Factory(object) - injection = providers.PositionalInjection(provider) - injection_copy_orig = providers.PositionalInjection(provider) - - injection_copy = providers.deepcopy( - injection, {id(injection): injection_copy_orig}) - - self.assertIs(injection_copy, injection_copy_orig) - self.assertIs(injection_copy.get_original_value(), - injection.get_original_value()) - - -class NamedInjectionTests(unittest.TestCase): - - def test_isinstance(self): - injection = providers.NamedInjection("name", 1) - self.assertIsInstance(injection, providers.Injection) - - def test_get_name(self): - injection = providers.NamedInjection("name", 123) - self.assertEqual(injection.get_name(), "name") - - def test_get_value_with_not_provider(self): - injection = providers.NamedInjection("name", 123) - self.assertEqual(injection.get_value(), 123) - - def test_get_value_with_factory(self): - injection = providers.NamedInjection("name", - providers.Factory(object)) - - obj1 = injection.get_value() - obj2 = injection.get_value() - - self.assertIs(type(obj1), object) - self.assertIs(type(obj2), object) - self.assertIsNot(obj1, obj2) - - def test_get_original_value(self): - provider = providers.Factory(object) - injection = providers.NamedInjection("name", provider) - self.assertIs(injection.get_original_value(), provider) - - def test_deepcopy(self): - provider = providers.Factory(object) - injection = providers.NamedInjection("name", provider) - - injection_copy = providers.deepcopy(injection) - - self.assertIsNot(injection_copy, injection) - self.assertIsNot(injection_copy.get_original_value(), - injection.get_original_value()) - - def test_deepcopy_memo(self): - provider = providers.Factory(object) - injection = providers.NamedInjection("name", provider) - injection_copy_orig = providers.NamedInjection("name", provider) - - injection_copy = providers.deepcopy( - injection, {id(injection): injection_copy_orig}) - - self.assertIs(injection_copy, injection_copy_orig) - self.assertIs(injection_copy.get_original_value(), - injection.get_original_value()) diff --git a/tests/unit/providers/test_list_py2_py3.py b/tests/unit/providers/test_list_py2_py3.py index a234ec43..9e709bd4 100644 --- a/tests/unit/providers/test_list_py2_py3.py +++ b/tests/unit/providers/test_list_py2_py3.py @@ -1,140 +1,142 @@ -"""Dependency injector list provider unit tests.""" +"""List provider tests.""" import sys -import unittest - from dependency_injector import providers -class ListTests(unittest.TestCase): +def test_is_provider(): + assert providers.is_provider(providers.List()) is True + - def test_is_provider(self): - self.assertTrue(providers.is_provider(providers.List())) +def test_provided_instance_provider(): + provider = providers.List() + assert isinstance(provider.provided, providers.ProvidedInstance) - def test_provided_instance_provider(self): - provider = providers.List() - self.assertIsInstance(provider.provided, providers.ProvidedInstance) - def test_call_with_init_positional_args(self): - provider = providers.List("i1", "i2") +def test_call_with_init_positional_args(): + provider = providers.List("i1", "i2") - list1 = provider() - list2 = provider() + list1 = provider() + list2 = provider() - self.assertEqual(list1, ["i1", "i2"]) - self.assertEqual(list2, ["i1", "i2"]) + assert list1 == ["i1", "i2"] + assert list2 == ["i1", "i2"] + assert list1 is not list2 - self.assertIsNot(list1, list2) - def test_call_with_context_args(self): - provider = providers.List("i1", "i2") +def test_call_with_context_args(): + provider = providers.List("i1", "i2") + assert provider("i3", "i4") == ["i1", "i2", "i3", "i4"] - self.assertEqual(provider("i3", "i4"), ["i1", "i2", "i3", "i4"]) - def test_fluent_interface(self): - provider = providers.List() \ - .add_args(1, 2) +def test_fluent_interface(): + provider = providers.List() \ + .add_args(1, 2) + assert provider() == [1, 2] - self.assertEqual(provider(), [1, 2]) - def test_set_args(self): - provider = providers.List() \ - .add_args(1, 2) \ - .set_args(3, 4) - self.assertEqual(provider.args, (3, 4)) +def test_set_args(): + provider = providers.List() \ + .add_args(1, 2) \ + .set_args(3, 4) + assert provider.args == (3, 4) - def test_clear_args(self): - provider = providers.List() \ - .add_args(1, 2) \ - .clear_args() - self.assertEqual(provider.args, tuple()) - def test_call_overridden(self): - provider = providers.List(1, 2) - overriding_provider1 = providers.List(2, 3) - overriding_provider2 = providers.List(3, 4) +def test_clear_args(): + provider = providers.List() \ + .add_args(1, 2) \ + .clear_args() + assert provider.args == tuple() - provider.override(overriding_provider1) - provider.override(overriding_provider2) - instance1 = provider() - instance2 = provider() +def test_call_overridden(): + provider = providers.List(1, 2) + overriding_provider1 = providers.List(2, 3) + overriding_provider2 = providers.List(3, 4) - self.assertIsNot(instance1, instance2) - self.assertEqual(instance1, [3, 4]) - self.assertEqual(instance2, [3, 4]) + provider.override(overriding_provider1) + provider.override(overriding_provider2) - def test_deepcopy(self): - provider = providers.List(1, 2) + instance1 = provider() + instance2 = provider() - provider_copy = providers.deepcopy(provider) + assert instance1 is not instance2 + assert instance1 == [3, 4] + assert instance2 == [3, 4] - self.assertIsNot(provider, provider_copy) - self.assertEqual(provider.args, provider_copy.args) - self.assertIsInstance(provider, providers.List) - def test_deepcopy_from_memo(self): - provider = providers.List(1, 2) - provider_copy_memo = providers.List(1, 2) +def test_deepcopy(): + provider = providers.List(1, 2) - provider_copy = providers.deepcopy( - provider, memo={id(provider): provider_copy_memo}) + provider_copy = providers.deepcopy(provider) - self.assertIs(provider_copy, provider_copy_memo) + assert provider is not provider_copy + assert provider.args == provider_copy.args + assert isinstance(provider, providers.List) - def test_deepcopy_args(self): - provider = providers.List() - dependent_provider1 = providers.Factory(list) - dependent_provider2 = providers.Factory(dict) - provider.add_args(dependent_provider1, dependent_provider2) +def test_deepcopy_from_memo(): + provider = providers.List(1, 2) + provider_copy_memo = providers.List(1, 2) - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.args[0] - dependent_provider_copy2 = provider_copy.args[1] + provider_copy = providers.deepcopy(provider, memo={id(provider): provider_copy_memo}) + assert provider_copy is provider_copy_memo - self.assertNotEqual(provider.args, provider_copy.args) - self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) +def test_deepcopy_args(): + provider = providers.List() + dependent_provider1 = providers.Factory(list) + dependent_provider2 = providers.Factory(dict) - self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) + provider.add_args(dependent_provider1, dependent_provider2) - def test_deepcopy_overridden(self): - provider = providers.List() - object_provider = providers.Object(object()) + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.args[0] + dependent_provider_copy2 = provider_copy.args[1] - provider.override(object_provider) + assert provider.args != provider_copy.args - provider_copy = providers.deepcopy(provider) - object_provider_copy = provider_copy.overridden[0] + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 - self.assertIsNot(provider, provider_copy) - self.assertEqual(provider.args, provider_copy.args) - self.assertIsInstance(provider, providers.List) + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 - self.assertIsNot(object_provider, object_provider_copy) - self.assertIsInstance(object_provider_copy, providers.Object) - def test_deepcopy_with_sys_streams(self): - provider = providers.List() - provider.add_args(sys.stdin, sys.stdout, sys.stderr) +def test_deepcopy_overridden(): + provider = providers.List() + object_provider = providers.Object(object()) - provider_copy = providers.deepcopy(provider) + provider.override(object_provider) - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider_copy, providers.List) - self.assertIs(provider.args[0], sys.stdin) - self.assertIs(provider.args[1], sys.stdout) - self.assertIs(provider.args[2], sys.stderr) + provider_copy = providers.deepcopy(provider) + object_provider_copy = provider_copy.overridden[0] - def test_repr(self): - provider = providers.List(1, 2) + assert provider is not provider_copy + assert provider.args == provider_copy.args + assert isinstance(provider, providers.List) - self.assertEqual(repr(provider), - "".format( - repr(list(provider.args)), - hex(id(provider)))) + assert object_provider is not object_provider_copy + assert isinstance(object_provider_copy, providers.Object) + + +def test_deepcopy_with_sys_streams(): + provider = providers.List() + provider.add_args(sys.stdin, sys.stdout, sys.stderr) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert isinstance(provider_copy, providers.List) + assert provider.args[0] is sys.stdin + assert provider.args[1] is sys.stdout + assert provider.args[2] is sys.stderr + + +def test_repr(): + provider = providers.List(1, 2) + assert repr(provider) == ( + "".format(repr(list(provider.args)), hex(id(provider))) + ) diff --git a/tests/unit/providers/test_object_py2_py3.py b/tests/unit/providers/test_object_py2_py3.py new file mode 100644 index 00000000..d49aa71a --- /dev/null +++ b/tests/unit/providers/test_object_py2_py3.py @@ -0,0 +1,93 @@ +"""Object provider tests.""" + +from dependency_injector import providers + + +def test_is_provider(): + assert providers.is_provider(providers.Object(object())) is True + + +def test_init_optional_provides(): + instance = object() + provider = providers.Object() + provider.set_provides(instance) + assert provider.provides is instance + assert provider() is instance + + +def test_set_provides_returns_(): + provider = providers.Object() + assert provider.set_provides(object()) is provider + + +def test_provided_instance_provider(): + provider = providers.Object(object()) + assert isinstance(provider.provided, providers.ProvidedInstance) + + +def test_call_object_provider(): + obj = object() + assert providers.Object(obj)() is obj + + +def test_call_overridden_object_provider(): + obj1 = object() + obj2 = object() + provider = providers.Object(obj1) + provider.override(providers.Object(obj2)) + assert provider() is obj2 + + +def test_deepcopy(): + provider = providers.Object(1) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert isinstance(provider, providers.Object) + + +def test_deepcopy_from_memo(): + provider = providers.Object(1) + provider_copy_memo = providers.Provider() + + provider_copy = providers.deepcopy(provider, memo={id(provider): provider_copy_memo}) + + assert provider_copy is provider_copy_memo + + +def test_deepcopy_overridden(): + provider = providers.Object(1) + overriding_provider = providers.Provider() + + provider.override(overriding_provider) + + provider_copy = providers.deepcopy(provider) + overriding_provider_copy = provider_copy.overridden[0] + + assert provider is not provider_copy + assert isinstance(provider, providers.Object) + + assert overriding_provider is not overriding_provider_copy + assert isinstance(overriding_provider_copy, providers.Provider) + + +def test_deepcopy_doesnt_copy_provided_object(): + # Fixes bug #231 + # Details: https://github.com/ets-labs/python-dependency-injector/issues/231 + some_object = object() + provider = providers.Object(some_object) + + provider_copy = providers.deepcopy(provider) + + assert provider() is some_object + assert provider_copy() is some_object + + +def test_repr(): + some_object = object() + provider = providers.Object(some_object) + assert repr(provider) == ( + "".format(repr(some_object), hex(id(provider))) + ) diff --git a/tests/unit/providers/test_provided_instance_py2_py3.py b/tests/unit/providers/test_provided_instance_py2_py3.py index 98b7b6b3..d5ba7730 100644 --- a/tests/unit/providers/test_provided_instance_py2_py3.py +++ b/tests/unit/providers/test_provided_instance_py2_py3.py @@ -1,8 +1,7 @@ -"""Dependency injector provided instance provider unit tests.""" - -import unittest +"""ProvidedInstance provider tests.""" from dependency_injector import containers, providers +from pytest import fixture class Service: @@ -64,146 +63,140 @@ class Container(containers.DeclarativeContainer): ) -class ProvidedInstanceTests(unittest.TestCase): - - def setUp(self): - self.container = Container() - - def test_is_provider(self): - self.assertTrue(providers.is_provider(self.container.service.provided)) - - def test_attribute(self): - client = self.container.client_attribute() - self.assertEqual(client.value, "foo") - - def test_item(self): - client = self.container.client_item() - self.assertEqual(client.value, "foo") - - def test_attribute_item(self): - client = self.container.client_attribute_item() - self.assertEqual(client.value, "foo") - - def test_method_call(self): - client = self.container.client_method_call() - self.assertEqual(client.value, "foo") - - def test_method_closure_call(self): - client = self.container.client_method_closure_call() - self.assertEqual(client.value, "foo") - - def test_provided_call(self): - client = self.container.client_provided_call() - self.assertEqual(client.value, "foo") - - def test_call_overridden(self): - value = "bar" - with self.container.service.override(Service(value)): - self.assertEqual(self.container.client_attribute().value, value) - self.assertEqual(self.container.client_item().value, value) - self.assertEqual(self.container.client_attribute_item().value, value) - self.assertEqual(self.container.client_method_call().value, value) - - def test_repr_provided_instance(self): - provider = self.container.service.provided - self.assertEqual( - "ProvidedInstance(\"{0}\")".format(repr(self.container.service)), - repr(provider), - ) - - def test_repr_attribute_getter(self): - provider = self.container.service.provided.value - self.assertEqual( - "AttributeGetter(\"value\")", - repr(provider), - ) - - def test_repr_item_getter(self): - provider = self.container.service.provided["test-test"] - self.assertEqual( - "ItemGetter(\"test-test\")", - repr(provider), - ) +@fixture +def container(): + return Container() -class LazyInitTests(unittest.TestCase): - - def test_provided_instance(self): - provides = providers.Object(object()) - provider = providers.ProvidedInstance() - provider.set_provides(provides) - self.assertIs(provider.provides, provides) - self.assertIs(provider.set_provides(providers.Provider()), provider) - - def test_attribute_getter(self): - provides = providers.Object(object()) - provider = providers.AttributeGetter() - provider.set_provides(provides) - provider.set_name("__dict__") - self.assertIs(provider.provides, provides) - self.assertEqual(provider.name, "__dict__") - self.assertIs(provider.set_provides(providers.Provider()), provider) - self.assertIs(provider.set_name("__dict__"), provider) - - def test_item_getter(self): - provides = providers.Object({"foo": "bar"}) - provider = providers.ItemGetter() - provider.set_provides(provides) - provider.set_name("foo") - self.assertIs(provider.provides, provides) - self.assertEqual(provider.name, "foo") - self.assertIs(provider.set_provides(providers.Provider()), provider) - self.assertIs(provider.set_name("foo"), provider) - - def test_method_caller(self): - provides = providers.Object(lambda: 42) - provider = providers.MethodCaller() - provider.set_provides(provides) - self.assertIs(provider.provides, provides) - self.assertEqual(provider(), 42) - self.assertIs(provider.set_provides(providers.Provider()), provider) +def test_is_provider(container): + assert providers.is_provider(container.service.provided) is True -class ProvidedInstancePuzzleTests(unittest.TestCase): +def test_attribute(container): + client = container.client_attribute() + assert client.value == "foo" - def test_puzzled(self): - service = providers.Singleton(Service, value="foo-bar") - dependency = providers.Object( - { - "a": { - "b": { - "c1": 10, - "c2": lambda arg: {"arg": arg} - }, +def test_item(container): + client = container.client_item() + assert client.value == "foo" + + +def test_attribute_item(container): + client = container.client_attribute_item() + assert client.value == "foo" + + +def test_method_call(container): + client = container.client_method_call() + assert client.value == "foo" + + +def test_method_closure_call(container): + client = container.client_method_closure_call() + assert client.value == "foo" + + +def test_provided_call(container): + client = container.client_provided_call() + assert client.value == "foo" + + +def test_call_overridden(container): + value = "bar" + with container.service.override(Service(value)): + assert container.client_attribute().value == value + assert container.client_item().value == value + assert container.client_attribute_item().value == value + assert container.client_method_call().value == value + + +def test_repr_provided_instance(container): + provider = container.service.provided + assert repr(provider) == "ProvidedInstance(\"{0}\")".format(repr(container.service)) + + +def test_repr_attribute_getter(container): + provider = container.service.provided.value + assert repr(provider) == "AttributeGetter(\"value\")" + + +def test_repr_item_getter(container): + provider = container.service.provided["test-test"] + assert repr(provider) == "ItemGetter(\"test-test\")" + + +def test_provided_instance(): + provides = providers.Object(object()) + provider = providers.ProvidedInstance() + provider.set_provides(provides) + assert provider.provides is provides + assert provider.set_provides(providers.Provider()) is provider + + +def test_attribute_getter(): + provides = providers.Object(object()) + provider = providers.AttributeGetter() + provider.set_provides(provides) + provider.set_name("__dict__") + assert provider.provides is provides + assert provider.name == "__dict__" + assert provider.set_provides(providers.Provider()) is provider + assert provider.set_name("__dict__") is provider + + +def test_item_getter(): + provides = providers.Object({"foo": "bar"}) + provider = providers.ItemGetter() + provider.set_provides(provides) + provider.set_name("foo") + assert provider.provides is provides + assert provider.name == "foo" + assert provider.set_provides(providers.Provider()) is provider + assert provider.set_name("foo") is provider + + +def test_method_caller(): + provides = providers.Object(lambda: 42) + provider = providers.MethodCaller() + provider.set_provides(provides) + assert provider.provides is provides + assert provider() == 42 + assert provider.set_provides(providers.Provider()) is provider + + +def test_puzzled(): + service = providers.Singleton(Service, value="foo-bar") + + dependency = providers.Object( + { + "a": { + "b": { + "c1": 10, + "c2": lambda arg: {"arg": arg} }, }, - ) + }, + ) - test_list = providers.List( - dependency.provided["a"]["b"]["c1"], - dependency.provided["a"]["b"]["c2"].call(22)["arg"], - dependency.provided["a"]["b"]["c2"].call(service)["arg"], - dependency.provided["a"]["b"]["c2"].call(service)["arg"].value, - dependency.provided["a"]["b"]["c2"].call(service)["arg"].get_value.call(), - ) + test_list = providers.List( + dependency.provided["a"]["b"]["c1"], + dependency.provided["a"]["b"]["c2"].call(22)["arg"], + dependency.provided["a"]["b"]["c2"].call(service)["arg"], + dependency.provided["a"]["b"]["c2"].call(service)["arg"].value, + dependency.provided["a"]["b"]["c2"].call(service)["arg"].get_value.call(), + ) - result = test_list() - - self.assertEqual( - result, - [ - 10, - 22, - service(), - "foo-bar", - "foo-bar", - ], - ) + result = test_list() + assert result == [ + 10, + 22, + service(), + "foo-bar", + "foo-bar", + ] -class ProvidedInstanceInBaseClassTests(unittest.TestCase): - - def test_provided_attribute(self): - provider = providers.Provider() - assert isinstance(provider.provided, providers.ProvidedInstance) +def test_provided_attribute_in_base_class(): + provider = providers.Provider() + assert isinstance(provider.provided, providers.ProvidedInstance) diff --git a/tests/unit/providers/test_provider_py2_py3.py b/tests/unit/providers/test_provider_py2_py3.py new file mode 100644 index 00000000..50f4dc3e --- /dev/null +++ b/tests/unit/providers/test_provider_py2_py3.py @@ -0,0 +1,157 @@ +"""Provider tests.""" + +import warnings + +from dependency_injector import providers, errors +from pytest import fixture, raises + + +@fixture +def provider(): + return providers.Provider() + + +def test_is_provider(provider): + assert providers.is_provider(provider) is True + + +def test_call(provider): + with raises(NotImplementedError): + provider() + + +def test_delegate(provider): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + delegate1 = provider.delegate() + delegate2 = provider.delegate() + + assert isinstance(delegate1, providers.Delegate) + assert delegate1() is provider + + assert isinstance(delegate2, providers.Delegate) + assert delegate2() is provider + + assert delegate1 is not delegate2 + + +def test_provider(provider): + delegate1 = provider.provider + + assert isinstance(delegate1, providers.Delegate) + assert delegate1() is provider + + delegate2 = provider.provider + + assert isinstance(delegate2, providers.Delegate) + assert delegate2() is provider + + assert delegate1 is not delegate2 + + +def test_override(provider): + overriding_provider = providers.Provider() + provider.override(overriding_provider) + assert provider.overridden == (overriding_provider,) + assert provider.last_overriding is overriding_provider + + +def test_double_override(provider): + overriding_provider1 = providers.Object(1) + overriding_provider2 = providers.Object(2) + + provider.override(overriding_provider1) + overriding_provider1.override(overriding_provider2) + + assert provider() == overriding_provider2() + + +def test_overriding_context(provider): + overriding_provider = providers.Provider() + with provider.override(overriding_provider): + assert provider.overridden == (overriding_provider,) + assert provider.overridden == tuple() + assert not provider.overridden + + +def test_override_with_itself(provider): + with raises(errors.Error): + provider.override(provider) + + +def test_override_with_not_provider(provider): + obj = object() + provider.override(obj) + assert provider() is obj + + +def test_reset_last_overriding(provider): + overriding_provider1 = providers.Provider() + overriding_provider2 = providers.Provider() + + provider.override(overriding_provider1) + provider.override(overriding_provider2) + + assert provider.overridden[-1] is overriding_provider2 + assert provider.last_overriding is overriding_provider2 + + provider.reset_last_overriding() + assert provider.overridden[-1] is overriding_provider1 + assert provider.last_overriding is overriding_provider1 + + provider.reset_last_overriding() + assert provider.overridden == tuple() + assert not provider.overridden + assert provider.last_overriding is None + + +def test_reset_last_overriding_of_not_overridden_provider(provider): + with raises(errors.Error): + provider.reset_last_overriding() + + +def test_reset_override(provider): + overriding_provider = providers.Provider() + provider.override(overriding_provider) + + assert provider.overridden + assert provider.overridden == (overriding_provider,) + + provider.reset_override() + + assert provider.overridden == tuple() + + +def test_deepcopy(provider): + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert isinstance(provider, providers.Provider) + + +def test_deepcopy_from_memo(provider): + provider_copy_memo = providers.Provider() + provider_copy = providers.deepcopy(provider, memo={id(provider): provider_copy_memo}) + assert provider_copy is provider_copy_memo + + +def test_deepcopy_overridden(provider): + overriding_provider = providers.Provider() + + provider.override(overriding_provider) + + provider_copy = providers.deepcopy(provider) + overriding_provider_copy = provider_copy.overridden[0] + + assert provider is not provider_copy + assert isinstance(provider, providers.Provider) + + assert overriding_provider is not overriding_provider_copy + assert isinstance(overriding_provider_copy, providers.Provider) + + +def test_repr(provider): + assert repr(provider) == ( + "".format(hex(id(provider))) + ) diff --git a/tests/unit/providers/test_resource_py35.py b/tests/unit/providers/test_resource_py35.py deleted file mode 100644 index 67f63821..00000000 --- a/tests/unit/providers/test_resource_py35.py +++ /dev/null @@ -1,665 +0,0 @@ -"""Dependency injector resource provider unit tests.""" - -import asyncio -import inspect -import unittest -from typing import Any - -from dependency_injector import containers, providers, resources, errors - -# Runtime import to get asyncutils module -import os -_TOP_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../", - )), -) -import sys -sys.path.append(_TOP_DIR) - -from asyncutils import AsyncTestCase - - -def init_fn(*args, **kwargs): - return args, kwargs - - -class ResourceTests(unittest.TestCase): - - def test_is_provider(self): - self.assertTrue(providers.is_provider(providers.Resource(init_fn))) - - def test_init_optional_provides(self): - provider = providers.Resource() - provider.set_provides(init_fn) - self.assertIs(provider.provides, init_fn) - self.assertEqual(provider(), (tuple(), dict())) - - def test_set_provides_returns_self(self): - provider = providers.Resource() - self.assertIs(provider.set_provides(init_fn), provider) - - def test_provided_instance_provider(self): - provider = providers.Resource(init_fn) - self.assertIsInstance(provider.provided, providers.ProvidedInstance) - - def test_injection(self): - resource = object() - - def _init(): - _init.counter += 1 - return resource - _init.counter = 0 - - class Container(containers.DeclarativeContainer): - resource = providers.Resource(_init) - dependency1 = providers.List(resource) - dependency2 = providers.List(resource) - - container = Container() - list1 = container.dependency1() - list2 = container.dependency2() - - self.assertEqual(list1, [resource]) - self.assertIs(list1[0], resource) - - self.assertEqual(list2, [resource]) - self.assertIs(list2[0], resource) - - self.assertEqual(_init.counter, 1) - - def test_init_function(self): - def _init(): - _init.counter += 1 - _init.counter = 0 - - provider = providers.Resource(_init) - - result1 = provider() - self.assertIsNone(result1) - self.assertEqual(_init.counter, 1) - - result2 = provider() - self.assertIsNone(result2) - self.assertEqual(_init.counter, 1) - - provider.shutdown() - - def test_init_generator(self): - def _init(): - _init.init_counter += 1 - yield - _init.shutdown_counter += 1 - - _init.init_counter = 0 - _init.shutdown_counter = 0 - - provider = providers.Resource(_init) - - result1 = provider() - self.assertIsNone(result1) - self.assertEqual(_init.init_counter, 1) - self.assertEqual(_init.shutdown_counter, 0) - - provider.shutdown() - self.assertEqual(_init.init_counter, 1) - self.assertEqual(_init.shutdown_counter, 1) - - result2 = provider() - self.assertIsNone(result2) - self.assertEqual(_init.init_counter, 2) - self.assertEqual(_init.shutdown_counter, 1) - - provider.shutdown() - self.assertEqual(_init.init_counter, 2) - self.assertEqual(_init.shutdown_counter, 2) - - def test_init_class(self): - class TestResource(resources.Resource): - init_counter = 0 - shutdown_counter = 0 - - def init(self): - self.__class__.init_counter += 1 - - def shutdown(self, _): - self.__class__.shutdown_counter += 1 - - provider = providers.Resource(TestResource) - - result1 = provider() - self.assertIsNone(result1) - self.assertEqual(TestResource.init_counter, 1) - self.assertEqual(TestResource.shutdown_counter, 0) - - provider.shutdown() - self.assertEqual(TestResource.init_counter, 1) - self.assertEqual(TestResource.shutdown_counter, 1) - - result2 = provider() - self.assertIsNone(result2) - self.assertEqual(TestResource.init_counter, 2) - self.assertEqual(TestResource.shutdown_counter, 1) - - provider.shutdown() - self.assertEqual(TestResource.init_counter, 2) - self.assertEqual(TestResource.shutdown_counter, 2) - - def test_init_class_generic_typing(self): - # See issue: https://github.com/ets-labs/python-dependency-injector/issues/488 - class TestDependency: - ... - - class TestResource(resources.Resource[TestDependency]): - def init(self, *args: Any, **kwargs: Any) -> TestDependency: - return TestDependency() - - def shutdown(self, resource: TestDependency) -> None: ... - - self.assertTrue(issubclass(TestResource, resources.Resource)) - - def test_init_class_abc_init_definition_is_required(self): - class TestResource(resources.Resource): - ... - - with self.assertRaises(TypeError) as context: - TestResource() - - self.assertIn("Can't instantiate abstract class TestResource", str(context.exception)) - self.assertIn("init", str(context.exception)) - - def test_init_class_abc_shutdown_definition_is_not_required(self): - class TestResource(resources.Resource): - def init(self): - ... - self.assertTrue(hasattr(TestResource(), "shutdown")) - - def test_init_not_callable(self): - provider = providers.Resource(1) - with self.assertRaises(errors.Error): - provider.init() - - def test_init_and_shutdown(self): - def _init(): - _init.init_counter += 1 - yield - _init.shutdown_counter += 1 - - _init.init_counter = 0 - _init.shutdown_counter = 0 - - provider = providers.Resource(_init) - - result1 = provider.init() - self.assertIsNone(result1) - self.assertEqual(_init.init_counter, 1) - self.assertEqual(_init.shutdown_counter, 0) - - provider.shutdown() - self.assertEqual(_init.init_counter, 1) - self.assertEqual(_init.shutdown_counter, 1) - - result2 = provider.init() - self.assertIsNone(result2) - self.assertEqual(_init.init_counter, 2) - self.assertEqual(_init.shutdown_counter, 1) - - provider.shutdown() - self.assertEqual(_init.init_counter, 2) - self.assertEqual(_init.shutdown_counter, 2) - - def test_shutdown_of_not_initialized(self): - def _init(): - yield - - provider = providers.Resource(_init) - - result = provider.shutdown() - self.assertIsNone(result) - - def test_initialized(self): - provider = providers.Resource(init_fn) - self.assertFalse(provider.initialized) - - provider.init() - self.assertTrue(provider.initialized) - - provider.shutdown() - self.assertFalse(provider.initialized) - - def test_call_with_context_args(self): - provider = providers.Resource(init_fn, "i1", "i2") - self.assertEqual(provider("i3", i4=4), (("i1", "i2", "i3"), {"i4": 4})) - - def test_fluent_interface(self): - provider = providers.Resource(init_fn) \ - .add_args(1, 2) \ - .add_kwargs(a3=3, a4=4) - - self.assertEqual(provider(), ((1, 2), {"a3": 3, "a4": 4})) - - def test_set_args(self): - provider = providers.Resource(init_fn) \ - .add_args(1, 2) \ - .set_args(3, 4) - self.assertEqual(provider.args, (3, 4)) - - def test_clear_args(self): - provider = providers.Resource(init_fn) \ - .add_args(1, 2) \ - .clear_args() - self.assertEqual(provider.args, tuple()) - - def test_set_kwargs(self): - provider = providers.Resource(init_fn) \ - .add_kwargs(a1="i1", a2="i2") \ - .set_kwargs(a3="i3", a4="i4") - self.assertEqual(provider.kwargs, {"a3": "i3", "a4": "i4"}) - - def test_clear_kwargs(self): - provider = providers.Resource(init_fn) \ - .add_kwargs(a1="i1", a2="i2") \ - .clear_kwargs() - self.assertEqual(provider.kwargs, {}) - - def test_call_overridden(self): - provider = providers.Resource(init_fn, 1) - overriding_provider1 = providers.Resource(init_fn, 2) - overriding_provider2 = providers.Resource(init_fn, 3) - - provider.override(overriding_provider1) - provider.override(overriding_provider2) - - instance1 = provider() - instance2 = provider() - - self.assertIs(instance1, instance2) - self.assertEqual(instance1, ((3,), {})) - self.assertEqual(instance2, ((3,), {})) - - def test_deepcopy(self): - provider = providers.Resource(init_fn, 1, 2, a3=3, a4=4) - - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertEqual(provider.args, provider_copy.args) - self.assertEqual(provider.kwargs, provider_copy.kwargs) - self.assertIsInstance(provider, providers.Resource) - - def test_deepcopy_initialized(self): - provider = providers.Resource(init_fn) - provider.init() - - with self.assertRaises(errors.Error): - providers.deepcopy(provider) - - def test_deepcopy_from_memo(self): - provider = providers.Resource(init_fn) - provider_copy_memo = providers.Resource(init_fn) - - provider_copy = providers.deepcopy( - provider, - memo={id(provider): provider_copy_memo}, - ) - - self.assertIs(provider_copy, provider_copy_memo) - - def test_deepcopy_args(self): - provider = providers.Resource(init_fn) - dependent_provider1 = providers.Factory(list) - dependent_provider2 = providers.Factory(dict) - - provider.add_args(dependent_provider1, dependent_provider2) - - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.args[0] - dependent_provider_copy2 = provider_copy.args[1] - - self.assertNotEqual(provider.args, provider_copy.args) - - self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) - - self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) - - def test_deepcopy_kwargs(self): - provider = providers.Resource(init_fn) - dependent_provider1 = providers.Factory(list) - dependent_provider2 = providers.Factory(dict) - - provider.add_kwargs(d1=dependent_provider1, d2=dependent_provider2) - - provider_copy = providers.deepcopy(provider) - dependent_provider_copy1 = provider_copy.kwargs["d1"] - dependent_provider_copy2 = provider_copy.kwargs["d2"] - - self.assertNotEqual(provider.kwargs, provider_copy.kwargs) - - self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) - self.assertIsNot(dependent_provider1, dependent_provider_copy1) - - self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) - self.assertIsNot(dependent_provider2, dependent_provider_copy2) - - def test_deepcopy_overridden(self): - provider = providers.Resource(init_fn) - object_provider = providers.Object(object()) - - provider.override(object_provider) - - provider_copy = providers.deepcopy(provider) - object_provider_copy = provider_copy.overridden[0] - - self.assertIsNot(provider, provider_copy) - self.assertEqual(provider.args, provider_copy.args) - self.assertIsInstance(provider, providers.Resource) - - self.assertIsNot(object_provider, object_provider_copy) - self.assertIsInstance(object_provider_copy, providers.Object) - - def test_deepcopy_with_sys_streams(self): - provider = providers.Resource(init_fn) - provider.add_args(sys.stdin, sys.stdout, sys.stderr) - - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider_copy, providers.Resource) - self.assertIs(provider.args[0], sys.stdin) - self.assertIs(provider.args[1], sys.stdout) - self.assertIs(provider.args[2], sys.stderr) - - def test_repr(self): - provider = providers.Resource(init_fn) - - self.assertEqual( - repr(provider), - "".format( - repr(init_fn), - hex(id(provider)), - ) - ) - - -class AsyncResourceTest(AsyncTestCase): - - def test_init_async_function(self): - resource = object() - - async def _init(): - await asyncio.sleep(0.001) - _init.counter += 1 - return resource - _init.counter = 0 - - provider = providers.Resource(_init) - - result1 = self._run(provider()) - self.assertIs(result1, resource) - self.assertEqual(_init.counter, 1) - - result2 = self._run(provider()) - self.assertIs(result2, resource) - self.assertEqual(_init.counter, 1) - - self._run(provider.shutdown()) - - def test_init_async_generator(self): - resource = object() - - async def _init(): - await asyncio.sleep(0.001) - _init.init_counter += 1 - - yield resource - - await asyncio.sleep(0.001) - _init.shutdown_counter += 1 - - _init.init_counter = 0 - _init.shutdown_counter = 0 - - provider = providers.Resource(_init) - - result1 = self._run(provider()) - self.assertIs(result1, resource) - self.assertEqual(_init.init_counter, 1) - self.assertEqual(_init.shutdown_counter, 0) - - self._run(provider.shutdown()) - self.assertEqual(_init.init_counter, 1) - self.assertEqual(_init.shutdown_counter, 1) - - result2 = self._run(provider()) - self.assertIs(result2, resource) - self.assertEqual(_init.init_counter, 2) - self.assertEqual(_init.shutdown_counter, 1) - - self._run(provider.shutdown()) - self.assertEqual(_init.init_counter, 2) - self.assertEqual(_init.shutdown_counter, 2) - - def test_init_async_class(self): - resource = object() - - class TestResource(resources.AsyncResource): - init_counter = 0 - shutdown_counter = 0 - - async def init(self): - await asyncio.sleep(0.001) - self.__class__.init_counter += 1 - return resource - - async def shutdown(self, resource_): - await asyncio.sleep(0.001) - self.__class__.shutdown_counter += 1 - assert resource_ is resource - - provider = providers.Resource(TestResource) - - result1 = self._run(provider()) - self.assertIs(result1, resource) - self.assertEqual(TestResource.init_counter, 1) - self.assertEqual(TestResource.shutdown_counter, 0) - - self._run(provider.shutdown()) - self.assertEqual(TestResource.init_counter, 1) - self.assertEqual(TestResource.shutdown_counter, 1) - - result2 = self._run(provider()) - self.assertIs(result2, resource) - self.assertEqual(TestResource.init_counter, 2) - self.assertEqual(TestResource.shutdown_counter, 1) - - self._run(provider.shutdown()) - self.assertEqual(TestResource.init_counter, 2) - self.assertEqual(TestResource.shutdown_counter, 2) - - def test_init_async_class_generic_typing(self): - # See issue: https://github.com/ets-labs/python-dependency-injector/issues/488 - class TestDependency: - ... - - class TestAsyncResource(resources.AsyncResource[TestDependency]): - async def init(self, *args: Any, **kwargs: Any) -> TestDependency: - return TestDependency() - - async def shutdown(self, resource: TestDependency) -> None: ... - - self.assertTrue(issubclass(TestAsyncResource, resources.AsyncResource)) - - def test_init_async_class_abc_init_definition_is_required(self): - class TestAsyncResource(resources.AsyncResource): - ... - - with self.assertRaises(TypeError) as context: - TestAsyncResource() - - self.assertIn("Can't instantiate abstract class TestAsyncResource", str(context.exception)) - self.assertIn("init", str(context.exception)) - - def test_init_async_class_abc_shutdown_definition_is_not_required(self): - class TestAsyncResource(resources.AsyncResource): - async def init(self): - ... - self.assertTrue(hasattr(TestAsyncResource(), "shutdown")) - self.assertTrue(inspect.iscoroutinefunction(TestAsyncResource.shutdown)) - - def test_init_with_error(self): - async def _init(): - raise RuntimeError() - - provider = providers.Resource(_init) - - future = provider() - self.assertTrue(provider.initialized) - self.assertTrue(provider.is_async_mode_enabled()) - - with self.assertRaises(RuntimeError): - self._run(future) - - self.assertFalse(provider.initialized) - self.assertTrue(provider.is_async_mode_enabled()) - - def test_init_async_gen_with_error(self): - async def _init(): - raise RuntimeError() - yield - - provider = providers.Resource(_init) - - future = provider() - self.assertTrue(provider.initialized) - self.assertTrue(provider.is_async_mode_enabled()) - - with self.assertRaises(RuntimeError): - self._run(future) - - self.assertFalse(provider.initialized) - self.assertTrue(provider.is_async_mode_enabled()) - - def test_init_async_subclass_with_error(self): - class _Resource(resources.AsyncResource): - async def init(self): - raise RuntimeError() - - async def shutdown(self, resource): - pass - - provider = providers.Resource(_Resource) - - future = provider() - self.assertTrue(provider.initialized) - self.assertTrue(provider.is_async_mode_enabled()) - - with self.assertRaises(RuntimeError): - self._run(future) - - self.assertFalse(provider.initialized) - self.assertTrue(provider.is_async_mode_enabled()) - - def test_init_with_dependency_to_other_resource(self): - # See: https://github.com/ets-labs/python-dependency-injector/issues/361 - async def init_db_connection(db_url: str): - await asyncio.sleep(0.001) - yield {"connection": "ok", "url": db_url} - - async def init_user_session(db): - await asyncio.sleep(0.001) - yield {"session": "ok", "db": db} - - class Container(containers.DeclarativeContainer): - config = providers.Configuration() - - db_connection = providers.Resource( - init_db_connection, - db_url=config.db_url, - ) - - user_session = providers.Resource( - init_user_session, - db=db_connection - ) - - async def main(): - container = Container(config={"db_url": "postgres://..."}) - try: - return await container.user_session() - finally: - await container.shutdown_resources() - - result = self._run(main()) - - self.assertEqual( - result, - {"session": "ok", "db": {"connection": "ok", "url": "postgres://..."}}, - ) - - def test_init_and_shutdown_methods(self): - async def _init(): - await asyncio.sleep(0.001) - _init.init_counter += 1 - - yield - - await asyncio.sleep(0.001) - _init.shutdown_counter += 1 - - _init.init_counter = 0 - _init.shutdown_counter = 0 - - provider = providers.Resource(_init) - - self._run(provider.init()) - self.assertEqual(_init.init_counter, 1) - self.assertEqual(_init.shutdown_counter, 0) - - self._run(provider.shutdown()) - self.assertEqual(_init.init_counter, 1) - self.assertEqual(_init.shutdown_counter, 1) - - self._run(provider.init()) - self.assertEqual(_init.init_counter, 2) - self.assertEqual(_init.shutdown_counter, 1) - - self._run(provider.shutdown()) - self.assertEqual(_init.init_counter, 2) - self.assertEqual(_init.shutdown_counter, 2) - - def test_shutdown_of_not_initialized(self): - async def _init(): - yield - - provider = providers.Resource(_init) - provider.enable_async_mode() - - result = self._run(provider.shutdown()) - self.assertIsNone(result) - - def test_concurrent_init(self): - resource = object() - - async def _init(): - await asyncio.sleep(0.001) - _init.counter += 1 - return resource - _init.counter = 0 - - provider = providers.Resource(_init) - - result1, result2 = self._run( - asyncio.gather( - provider(), - provider() - ), - ) - - self.assertIs(result1, resource) - self.assertEqual(_init.counter, 1) - - self.assertIs(result2, resource) - self.assertEqual(_init.counter, 1) diff --git a/tests/unit/providers/test_selector_py2_py3.py b/tests/unit/providers/test_selector_py2_py3.py index 8ccd4aca..c7e96d01 100644 --- a/tests/unit/providers/test_selector_py2_py3.py +++ b/tests/unit/providers/test_selector_py2_py3.py @@ -1,223 +1,209 @@ -"""Dependency injector selector provider unit tests.""" +"""Selector provider tests.""" import functools import itertools import sys -import unittest - from dependency_injector import providers, errors +from pytest import fixture, mark, raises -class SelectorTests(unittest.TestCase): +@fixture +def switch(): + return providers.Configuration() - selector = providers.Configuration() - def test_is_provider(self): - self.assertTrue(providers.is_provider(providers.Selector(self.selector))) +@fixture +def one(): + return providers.Object(1) - def test_init_optional(self): - one = providers.Object(1) - two = providers.Object(2) - provider = providers.Selector() - provider.set_selector(self.selector) - provider.set_providers(one=one, two=two) +@fixture +def two(): + return providers.Object(2) - self.assertEqual(provider.providers, {"one": one, "two": two}) - with self.selector.override("one"): - self.assertEqual(provider(), one()) - with self.selector.override("two"): - self.assertEqual(provider(), two()) - def test_set_selector_returns_self(self): - provider = providers.Selector() - self.assertIs(provider.set_selector(self.selector), provider) +@fixture +def selector_type(): + return "default" - def test_set_providers_returns_self(self): - provider = providers.Selector() - self.assertIs(provider.set_providers(one=providers.Provider()), provider) - def test_provided_instance_provider(self): - provider = providers.Selector(self.selector) - self.assertIsInstance(provider.provided, providers.ProvidedInstance) - - def test_call(self): - provider = providers.Selector( - self.selector, - one=providers.Object(1), - two=providers.Object(2), - ) - - with self.selector.override("one"): - self.assertEqual(provider(), 1) - - with self.selector.override("two"): - self.assertEqual(provider(), 2) - - def test_call_undefined_provider(self): - provider = providers.Selector( - self.selector, - one=providers.Object(1), - two=providers.Object(2), - ) - - with self.selector.override("three"): - with self.assertRaises(errors.Error): - provider() - - def test_call_selector_is_none(self): - provider = providers.Selector( - self.selector, - one=providers.Object(1), - two=providers.Object(2), - ) - - with self.selector.override(None): - with self.assertRaises(errors.Error): - provider() - - def test_call_any_callable(self): - provider = providers.Selector( - functools.partial(next, itertools.cycle(["one", "two"])), - one=providers.Object(1), - two=providers.Object(2), - ) - - self.assertEqual(provider(), 1) - self.assertEqual(provider(), 2) - self.assertEqual(provider(), 1) - self.assertEqual(provider(), 2) - - def test_call_with_context_args(self): - provider = providers.Selector( - self.selector, - one=providers.Callable(lambda *args, **kwargs: (args, kwargs)), - ) - - with self.selector.override("one"): - args, kwargs = provider(1, 2, three=3, four=4) - - self.assertEqual(args, (1, 2)) - self.assertEqual(kwargs, {"three": 3, "four": 4}) - - def test_getattr(self): - provider_one = providers.Object(1) - provider_two = providers.Object(2) - - provider = providers.Selector( - self.selector, - one=provider_one, - two=provider_two, - ) - - self.assertIs(provider.one, provider_one) - self.assertIs(provider.two, provider_two) - - def test_getattr_attribute_error(self): - provider_one = providers.Object(1) - provider_two = providers.Object(2) - - provider = providers.Selector( - self.selector, - one=provider_one, - two=provider_two, - ) - - with self.assertRaises(AttributeError): - _ = provider.provider_three - - def test_call_overridden(self): - provider = providers.Selector(self.selector, sample=providers.Object(1)) - overriding_provider1 = providers.Selector(self.selector, sample=providers.Object(2)) - overriding_provider2 = providers.Selector(self.selector, sample=providers.Object(3)) - - provider.override(overriding_provider1) - provider.override(overriding_provider2) - - with self.selector.override("sample"): - self.assertEqual(provider(), 3) - - def test_providers_attribute(self): - provider_one = providers.Object(1) - provider_two = providers.Object(2) - - provider = providers.Selector( - self.selector, - one=provider_one, - two=provider_two, - ) - - self.assertEqual(provider.providers, {"one": provider_one, "two": provider_two}) - - def test_deepcopy(self): - provider = providers.Selector(self.selector) - - provider_copy = providers.deepcopy(provider) - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider, providers.Selector) - - def test_deepcopy_from_memo(self): - provider = providers.Selector(self.selector) - provider_copy_memo = providers.Selector(self.selector) - - provider_copy = providers.deepcopy( - provider, - memo={id(provider): provider_copy_memo}, - ) - - self.assertIs(provider_copy, provider_copy_memo) - - def test_deepcopy_overridden(self): - provider = providers.Selector(self.selector) - object_provider = providers.Object(object()) - - provider.override(object_provider) - - provider_copy = providers.deepcopy(provider) - object_provider_copy = provider_copy.overridden[0] - - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider, providers.Selector) - - self.assertIsNot(object_provider, object_provider_copy) - self.assertIsInstance(object_provider_copy, providers.Object) - - def test_deepcopy_with_sys_streams(self): - provider = providers.Selector( - self.selector, +@fixture +def selector(selector_type, switch, one, two): + if selector_type == "default": + return providers.Selector(switch, one=one, two=two) + elif selector_type == "empty": + return providers.Selector() + elif selector_type == "sys-streams": + return providers.Selector( + switch, stdin=providers.Object(sys.stdin), stdout=providers.Object(sys.stdout), stderr=providers.Object(sys.stderr), - ) + else: + raise ValueError("Unknown selector type \"{0}\"".format(selector_type)) - provider_copy = providers.deepcopy(provider) - self.assertIsNot(provider, provider_copy) - self.assertIsInstance(provider_copy, providers.Selector) +def test_is_provider(selector): + assert providers.is_provider(selector) is True - with self.selector.override("stdin"): - self.assertIs(provider(), sys.stdin) - with self.selector.override("stdout"): - self.assertIs(provider(), sys.stdout) +@mark.parametrize("selector_type", ["empty"]) +def test_init_optional(selector, switch, one, two): + selector.set_selector(switch) + selector.set_providers(one=one, two=two) - with self.selector.override("stderr"): - self.assertIs(provider(), sys.stderr) + assert selector.providers == {"one": one, "two": two} + with switch.override("one"): + assert selector() == one() + with switch.override("two"): + assert selector() == two() - def test_repr(self): - provider = providers.Selector( - self.selector, - one=providers.Object(1), - two=providers.Object(2), - ) - self.assertIn( - "".format(repr(container), hex(id(provider))) + ) diff --git a/tests/unit/providers/test_singletons_py2_py3.py b/tests/unit/providers/test_singletons_py2_py3.py deleted file mode 100644 index 1dff4d15..00000000 --- a/tests/unit/providers/test_singletons_py2_py3.py +++ /dev/null @@ -1,206 +0,0 @@ -"""Dependency injector singleton providers unit tests.""" - -import unittest - -from dependency_injector import ( - providers, - errors, -) - -from .singleton_common import Example, _BaseSingletonTestCase - - -class SingletonTests(_BaseSingletonTestCase, unittest.TestCase): - - singleton_cls = providers.Singleton - - def test_repr(self): - provider = self.singleton_cls(Example) - - self.assertEqual(repr(provider), - "".format( - repr(Example), - hex(id(provider)))) - - -class DelegatedSingletonTests(_BaseSingletonTestCase, unittest.TestCase): - - singleton_cls = providers.DelegatedSingleton - - def test_is_delegated_provider(self): - provider = self.singleton_cls(object) - self.assertTrue(providers.is_delegated(provider)) - - def test_repr(self): - provider = self.singleton_cls(Example) - - self.assertEqual(repr(provider), - "".format( - repr(Example), - hex(id(provider)))) - - -class ThreadLocalSingletonTests(_BaseSingletonTestCase, unittest.TestCase): - - singleton_cls = providers.ThreadLocalSingleton - - def test_repr(self): - provider = providers.ThreadLocalSingleton(Example) - - self.assertEqual(repr(provider), - "".format( - repr(Example), - hex(id(provider)))) - - def test_reset(self): - provider = providers.ThreadLocalSingleton(Example) - - instance1 = provider() - self.assertIsInstance(instance1, Example) - - provider.reset() - - instance2 = provider() - self.assertIsInstance(instance2, Example) - - self.assertIsNot(instance1, instance2) - - def test_reset_clean(self): - provider = providers.ThreadLocalSingleton(Example) - instance1 = provider() - - provider.reset() - provider.reset() - - instance2 = provider() - self.assertIsNot(instance1, instance2) - - -class DelegatedThreadLocalSingletonTests(_BaseSingletonTestCase, - unittest.TestCase): - - singleton_cls = providers.DelegatedThreadLocalSingleton - - def test_is_delegated_provider(self): - provider = self.singleton_cls(object) - self.assertTrue(providers.is_delegated(provider)) - - def test_repr(self): - provider = self.singleton_cls(Example) - - self.assertEqual(repr(provider), - "".format( - repr(Example), - hex(id(provider)))) - - -class ThreadSafeSingletonTests(_BaseSingletonTestCase, unittest.TestCase): - - singleton_cls = providers.ThreadSafeSingleton - - def test_repr(self): - provider = self.singleton_cls(Example) - - self.assertEqual(repr(provider), - "".format( - repr(Example), - hex(id(provider)))) - - -class DelegatedThreadSafeSingletonTests(_BaseSingletonTestCase, - unittest.TestCase): - - singleton_cls = providers.DelegatedThreadSafeSingleton - - def test_is_delegated_provider(self): - provider = self.singleton_cls(object) - self.assertTrue(providers.is_delegated(provider)) - - def test_repr(self): - provider = self.singleton_cls(Example) - - self.assertEqual(repr(provider), - "".format( - repr(Example), - hex(id(provider)))) - - -class AbstractSingletonTests(unittest.TestCase): - - def test_inheritance(self): - self.assertIsInstance(providers.AbstractSingleton(Example), - providers.BaseSingleton) - - def test_call_overridden_by_singleton(self): - provider = providers.AbstractSingleton(object) - provider.override(providers.Singleton(Example)) - - self.assertIsInstance(provider(), Example) - - def test_call_overridden_by_delegated_singleton(self): - provider = providers.AbstractSingleton(object) - provider.override(providers.DelegatedSingleton(Example)) - - self.assertIsInstance(provider(), Example) - - def test_call_not_overridden(self): - provider = providers.AbstractSingleton(object) - - with self.assertRaises(errors.Error): - provider() - - def test_reset_overridden(self): - provider = providers.AbstractSingleton(object) - provider.override(providers.Singleton(Example)) - - instance1 = provider() - - provider.reset() - - instance2 = provider() - - self.assertIsNot(instance1, instance2) - self.assertIsInstance(instance1, Example) - self.assertIsInstance(instance2, Example) - - def test_reset_not_overridden(self): - provider = providers.AbstractSingleton(object) - - with self.assertRaises(errors.Error): - provider.reset() - - def test_override_by_not_singleton(self): - provider = providers.AbstractSingleton(object) - - with self.assertRaises(errors.Error): - provider.override(providers.Factory(object)) - - def test_repr(self): - provider = providers.AbstractSingleton(Example) - - self.assertEqual(repr(provider), - "".format( - repr(Example), - hex(id(provider)))) - - -class SingletonDelegateTests(unittest.TestCase): - - def setUp(self): - self.delegated = providers.Singleton(Example) - self.delegate = providers.SingletonDelegate(self.delegated) - - def test_is_delegate(self): - self.assertIsInstance(self.delegate, providers.Delegate) - - def test_init_with_not_singleton(self): - self.assertRaises(errors.Error, - providers.SingletonDelegate, - providers.Object(object())) diff --git a/tests/unit/providers/test_singletons_py3.py b/tests/unit/providers/test_singletons_py3.py deleted file mode 100644 index 6a09f763..00000000 --- a/tests/unit/providers/test_singletons_py3.py +++ /dev/null @@ -1,42 +0,0 @@ -import unittest - -from dependency_injector import providers - -from .singleton_common import Example, _BaseSingletonTestCase - - -class ContextLocalSingletonTests(_BaseSingletonTestCase, unittest.TestCase): - - singleton_cls = providers.ContextLocalSingleton - - def test_repr(self): - provider = providers.ContextLocalSingleton(Example) - - self.assertEqual(repr(provider), - "".format( - repr(Example), - hex(id(provider)))) - - def test_reset(self): - provider = providers.ContextLocalSingleton(Example) - - instance1 = provider() - self.assertIsInstance(instance1, Example) - - provider.reset() - - instance2 = provider() - self.assertIsInstance(instance2, Example) - - self.assertIsNot(instance1, instance2) - - def test_reset_clean(self): - provider = providers.ContextLocalSingleton(Example) - instance1 = provider() - - provider.reset() - provider.reset() - - instance2 = provider() - self.assertIsNot(instance1, instance2) diff --git a/tests/unit/providers/test_traversal_py3.py b/tests/unit/providers/test_traversal_py3.py deleted file mode 100644 index 597a0119..00000000 --- a/tests/unit/providers/test_traversal_py3.py +++ /dev/null @@ -1,875 +0,0 @@ -import unittest - -from dependency_injector import containers, providers - - -class TraverseTests(unittest.TestCase): - - def test_traverse_cycled_graph(self): - provider1 = providers.Provider() - - provider2 = providers.Provider() - provider2.override(provider1) - - provider3 = providers.Provider() - provider3.override(provider2) - - provider1.override(provider3) # Cycle: provider3 -> provider2 -> provider1 -> provider3 - - all_providers = list(providers.traverse(provider1)) - - self.assertEqual(len(all_providers), 3) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provider3, all_providers) - - def test_traverse_types_filtering(self): - provider1 = providers.Resource(dict) - provider2 = providers.Resource(dict) - provider3 = providers.Provider() - - provider = providers.Provider() - - provider.override(provider1) - provider.override(provider2) - provider.override(provider3) - - all_providers = list(providers.traverse(provider, types=[providers.Resource])) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - -class ProviderTests(unittest.TestCase): - - def test_traversal_overriding(self): - provider1 = providers.Provider() - provider2 = providers.Provider() - provider3 = providers.Provider() - - provider = providers.Provider() - - provider.override(provider1) - provider.override(provider2) - provider.override(provider3) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provider3, all_providers) - - def test_traversal_overriding_nested(self): - provider1 = providers.Provider() - - provider2 = providers.Provider() - provider2.override(provider1) - - provider3 = providers.Provider() - provider3.override(provider2) - - provider = providers.Provider() - provider.override(provider3) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provider3, all_providers) - - def test_traverse_types_filtering(self): - provider1 = providers.Resource(dict) - provider2 = providers.Resource(dict) - provider3 = providers.Provider() - - provider = providers.Provider() - - provider.override(provider1) - provider.override(provider2) - provider.override(provider3) - - all_providers = list(provider.traverse(types=[providers.Resource])) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - -class ObjectTests(unittest.TestCase): - - def test_traversal(self): - provider = providers.Object("string") - all_providers = list(provider.traverse()) - self.assertEqual(len(all_providers), 0) - - def test_traversal_provider(self): - another_provider = providers.Provider() - provider = providers.Object(another_provider) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 1) - self.assertIn(another_provider, all_providers) - - def test_traversal_provider_and_overriding(self): - another_provider_1 = providers.Provider() - another_provider_2 = providers.Provider() - another_provider_3 = providers.Provider() - - provider = providers.Object(another_provider_1) - - provider.override(another_provider_2) - provider.override(another_provider_3) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(another_provider_1, all_providers) - self.assertIn(another_provider_2, all_providers) - self.assertIn(another_provider_3, all_providers) - - -class DelegateTests(unittest.TestCase): - - def test_traversal_provider(self): - another_provider = providers.Provider() - provider = providers.Delegate(another_provider) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 1) - self.assertIn(another_provider, all_providers) - - def test_traversal_provider_and_overriding(self): - provider1 = providers.Provider() - provider2 = providers.Provider() - - provider3 = providers.Provider() - provider3.override(provider2) - - provider = providers.Delegate(provider1) - - provider.override(provider3) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provider3, all_providers) - - -class DependencyTests(unittest.TestCase): - - def test_traversal(self): - provider = providers.Dependency() - all_providers = list(provider.traverse()) - self.assertEqual(len(all_providers), 0) - - def test_traversal_default(self): - another_provider = providers.Provider() - provider = providers.Dependency(default=another_provider) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 1) - self.assertIn(another_provider, all_providers) - - def test_traversal_overriding(self): - provider1 = providers.Provider() - - provider2 = providers.Provider() - provider2.override(provider1) - - provider = providers.Dependency() - provider.override(provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - -class DependenciesContainerTests(unittest.TestCase): - - def test_traversal(self): - provider = providers.DependenciesContainer() - all_providers = list(provider.traverse()) - self.assertEqual(len(all_providers), 0) - - def test_traversal_default(self): - another_provider = providers.Provider() - provider = providers.DependenciesContainer(default=another_provider) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 1) - self.assertIn(another_provider, all_providers) - - def test_traversal_fluent_interface(self): - provider = providers.DependenciesContainer() - provider1 = provider.provider1 - provider2 = provider.provider2 - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traversal_overriding(self): - provider1 = providers.Provider() - provider2 = providers.Provider() - provider3 = providers.DependenciesContainer( - provider1=provider1, - provider2=provider2, - ) - - provider = providers.DependenciesContainer() - provider.override(provider3) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 5) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provider3, all_providers) - self.assertIn(provider.provider1, all_providers) - self.assertIn(provider.provider2, all_providers) - - -class CallableTests(unittest.TestCase): - - def test_traverse(self): - provider = providers.Callable(dict) - all_providers = list(provider.traverse()) - self.assertEqual(len(all_providers), 0) - - def test_traverse_args(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Callable(list, "foo", provider1, provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_kwargs(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Callable(dict, foo="foo", bar=provider1, baz=provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_overridden(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - - provider = providers.Callable(dict, "foo") - provider.override(provider1) - provider.override(provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_provides(self): - provider1 = providers.Callable(list) - provider2 = providers.Object("bar") - provider3 = providers.Object("baz") - - provider = providers.Callable(provider1, provider2) - provider.override(provider3) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provider3, all_providers) - - -class ConfigurationTests(unittest.TestCase): - - def test_traverse(self): - config = providers.Configuration(default={"option1": {"option2": "option2"}}) - option1 = config.option1 - option2 = config.option1.option2 - option3 = config.option1[config.option1.option2] - - all_providers = list(config.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(option1, all_providers) - self.assertIn(option2, all_providers) - self.assertIn(option3, all_providers) - - def test_traverse_typed(self): - config = providers.Configuration() - option = config.option - typed_option = config.option.as_int() - - all_providers = list(typed_option.traverse()) - - self.assertEqual(len(all_providers), 1) - self.assertIn(option, all_providers) - - def test_traverse_overridden(self): - options = {"option1": {"option2": "option2"}} - config = providers.Configuration() - config.from_dict(options) - - all_providers = list(config.traverse()) - - self.assertEqual(len(all_providers), 1) - overridden, = all_providers - self.assertEqual(overridden(), options) - self.assertIs(overridden, config.last_overriding) - - def test_traverse_overridden_option_1(self): - options = {"option2": "option2"} - config = providers.Configuration() - config.option1.from_dict(options) - - all_providers = list(config.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(config.option1, all_providers) - self.assertIn(config.last_overriding, all_providers) - - def test_traverse_overridden_option_2(self): - options = {"option2": "option2"} - config = providers.Configuration() - config.option1.from_dict(options) - - all_providers = list(config.option1.traverse()) - - self.assertEqual(len(all_providers), 0) - - -class FactoryTests(unittest.TestCase): - - def test_traverse(self): - provider = providers.Factory(dict) - all_providers = list(provider.traverse()) - self.assertEqual(len(all_providers), 0) - - def test_traverse_args(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Factory(list, "foo", provider1, provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_kwargs(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Factory(dict, foo="foo", bar=provider1, baz=provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_attributes(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Factory(dict) - provider.add_attributes(foo="foo", bar=provider1, baz=provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_overridden(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - - provider = providers.Factory(dict, "foo") - provider.override(provider1) - provider.override(provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_provides(self): - provider1 = providers.Callable(list) - provider2 = providers.Object("bar") - provider3 = providers.Object("baz") - - provider = providers.Factory(provider1, provider2) - provider.override(provider3) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provider3, all_providers) - - -class FactoryAggregateTests(unittest.TestCase): - - def test_traverse(self): - factory1 = providers.Factory(dict) - factory2 = providers.Factory(list) - provider = providers.FactoryAggregate(factory1=factory1, factory2=factory2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(factory1, all_providers) - self.assertIn(factory2, all_providers) - - -class BaseSingletonTests(unittest.TestCase): - - def test_traverse(self): - provider = providers.Singleton(dict) - all_providers = list(provider.traverse()) - self.assertEqual(len(all_providers), 0) - - def test_traverse_args(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Singleton(list, "foo", provider1, provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_kwargs(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Singleton(dict, foo="foo", bar=provider1, baz=provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_attributes(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Singleton(dict) - provider.add_attributes(foo="foo", bar=provider1, baz=provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_overridden(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - - provider = providers.Singleton(dict, "foo") - provider.override(provider1) - provider.override(provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_provides(self): - provider1 = providers.Callable(list) - provider2 = providers.Object("bar") - provider3 = providers.Object("baz") - - provider = providers.Singleton(provider1, provider2) - provider.override(provider3) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provider3, all_providers) - - -class ListTests(unittest.TestCase): - - def test_traverse_args(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.List("foo", provider1, provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_overridden(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider3 = providers.List(provider1, provider2) - - provider = providers.List("foo") - provider.override(provider3) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provider3, all_providers) - - -class DictTests(unittest.TestCase): - - def test_traverse_kwargs(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Dict(foo="foo", bar=provider1, baz=provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_overridden(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider3 = providers.Dict(bar=provider1, baz=provider2) - - provider = providers.Dict(foo="foo") - provider.override(provider3) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provider3, all_providers) - - -class ResourceTests(unittest.TestCase): - - def test_traverse(self): - provider = providers.Resource(dict) - all_providers = list(provider.traverse()) - self.assertEqual(len(all_providers), 0) - - def test_traverse_args(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Resource(list, "foo", provider1, provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_kwargs(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Resource(dict, foo="foo", bar=provider1, baz=provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_overridden(self): - provider1 = providers.Resource(list) - provider2 = providers.Resource(tuple) - - provider = providers.Resource(dict, "foo") - provider.override(provider1) - provider.override(provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_provides(self): - provider1 = providers.Callable(list) - - provider = providers.Resource(provider1) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 1) - self.assertIn(provider1, all_providers) - - -class ContainerTests(unittest.TestCase): - - def test_traverse(self): - class Container(containers.DeclarativeContainer): - provider1 = providers.Callable(list) - provider2 = providers.Callable(dict) - - provider = providers.Container(Container) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertEqual( - {provider.provides for provider in all_providers}, - {list, dict}, - ) - - def test_traverse_overridden(self): - class Container1(containers.DeclarativeContainer): - provider1 = providers.Callable(list) - provider2 = providers.Callable(dict) - - class Container2(containers.DeclarativeContainer): - provider1 = providers.Callable(tuple) - provider2 = providers.Callable(str) - - container2 = Container2() - - provider = providers.Container(Container1) - provider.override(container2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 5) - self.assertEqual( - { - provider.provides - for provider in all_providers - if isinstance(provider, providers.Callable) - }, - {list, dict, tuple, str}, - ) - self.assertIn(provider.last_overriding, all_providers) - self.assertIs(provider.last_overriding(), container2) - - -class SelectorTests(unittest.TestCase): - - def test_traverse(self): - switch = lambda: "provider1" - provider1 = providers.Callable(list) - provider2 = providers.Callable(dict) - - provider = providers.Selector( - switch, - provider1=provider1, - provider2=provider2, - ) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_switch(self): - switch = providers.Callable(lambda: "provider1") - provider1 = providers.Callable(list) - provider2 = providers.Callable(dict) - - provider = providers.Selector( - switch, - provider1=provider1, - provider2=provider2, - ) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(switch, all_providers) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - def test_traverse_overridden(self): - provider1 = providers.Callable(list) - provider2 = providers.Callable(dict) - selector1 = providers.Selector(lambda: "provider1", provider1=provider1) - - provider = providers.Selector( - lambda: "provider2", - provider2=provider2, - ) - provider.override(selector1) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(selector1, all_providers) - - -class ProvidedInstanceTests(unittest.TestCase): - - def test_traverse(self): - provider1 = providers.Provider() - provider = provider1.provided - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 1) - self.assertIn(provider1, all_providers) - - def test_traverse_overridden(self): - provider1 = providers.Provider() - provider2 = providers.Provider() - - provider = provider1.provided - provider.override(provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - - -class AttributeGetterTests(unittest.TestCase): - - def test_traverse(self): - provider1 = providers.Provider() - provided = provider1.provided - provider = provided.attr - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provided, all_providers) - - def test_traverse_overridden(self): - provider1 = providers.Provider() - provided = provider1.provided - provider2 = providers.Provider() - - provider = provided.attr - provider.override(provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provided, all_providers) - - -class ItemGetterTests(unittest.TestCase): - - def test_traverse(self): - provider1 = providers.Provider() - provided = provider1.provided - provider = provided["item"] - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 2) - self.assertIn(provider1, all_providers) - self.assertIn(provided, all_providers) - - def test_traverse_overridden(self): - provider1 = providers.Provider() - provided = provider1.provided - provider2 = providers.Provider() - - provider = provided["item"] - provider.override(provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provided, all_providers) - - -class MethodCallerTests(unittest.TestCase): - - def test_traverse(self): - provider1 = providers.Provider() - provided = provider1.provided - method = provided.method - provider = method.call() - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 3) - self.assertIn(provider1, all_providers) - self.assertIn(provided, all_providers) - self.assertIn(method, all_providers) - - def test_traverse_args(self): - provider1 = providers.Provider() - provided = provider1.provided - method = provided.method - provider2 = providers.Provider() - provider = method.call("foo", provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 4) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provided, all_providers) - self.assertIn(method, all_providers) - - def test_traverse_kwargs(self): - provider1 = providers.Provider() - provided = provider1.provided - method = provided.method - provider2 = providers.Provider() - provider = method.call(foo="foo", bar=provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 4) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provided, all_providers) - self.assertIn(method, all_providers) - - def test_traverse_overridden(self): - provider1 = providers.Provider() - provided = provider1.provided - method = provided.method - provider2 = providers.Provider() - - provider = method.call() - provider.override(provider2) - - all_providers = list(provider.traverse()) - - self.assertEqual(len(all_providers), 4) - self.assertIn(provider1, all_providers) - self.assertIn(provider2, all_providers) - self.assertIn(provided, all_providers) - self.assertIn(method, all_providers) diff --git a/tests/unit/providers/test_types_py36.py b/tests/unit/providers/test_types_py36.py index acaabb2a..430a1857 100644 --- a/tests/unit/providers/test_types_py36.py +++ b/tests/unit/providers/test_types_py36.py @@ -1,4 +1,4 @@ -import unittest +"""Provider typing in runtime tests.""" from dependency_injector import providers @@ -7,9 +7,7 @@ class SomeClass: ... -class TypesTest(unittest.TestCase): - - def test_provider(self): - provider: providers.Provider[SomeClass] = providers.Factory(SomeClass) - some_object = provider() - self.assertIsInstance(some_object, SomeClass) +def test_provider(): + provider: providers.Provider[SomeClass] = providers.Factory(SomeClass) + some_object = provider() + assert isinstance(some_object, SomeClass) diff --git a/tests/unit/providers/test_utils_py2_py3.py b/tests/unit/providers/test_utils_py2_py3.py deleted file mode 100644 index 3ff105a3..00000000 --- a/tests/unit/providers/test_utils_py2_py3.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Dependency injector provider utils unit tests.""" - -import unittest - -from dependency_injector import ( - providers, - errors, -) - - -class IsProviderTests(unittest.TestCase): - - def test_with_instance(self): - self.assertTrue(providers.is_provider(providers.Provider())) - - def test_with_class(self): - self.assertFalse(providers.is_provider(providers.Provider)) - - def test_with_string(self): - self.assertFalse(providers.is_provider("some_string")) - - def test_with_object(self): - self.assertFalse(providers.is_provider(object())) - - def test_with_subclass_instance(self): - class SomeProvider(providers.Provider): - pass - - self.assertTrue(providers.is_provider(SomeProvider())) - - def test_with_class_with_getattr(self): - class SomeClass(object): - def __getattr__(self, _): - return False - - self.assertFalse(providers.is_provider(SomeClass())) - - -class EnsureIsProviderTests(unittest.TestCase): - - def test_with_instance(self): - provider = providers.Provider() - self.assertIs(providers.ensure_is_provider(provider), provider) - - def test_with_class(self): - self.assertRaises(errors.Error, - providers.ensure_is_provider, - providers.Provider) - - def test_with_string(self): - self.assertRaises(errors.Error, - providers.ensure_is_provider, - "some_string") - - def test_with_object(self): - self.assertRaises(errors.Error, providers.ensure_is_provider, object()) diff --git a/tests/unit/providers/traversal/__init__.py b/tests/unit/providers/traversal/__init__.py new file mode 100644 index 00000000..c95aa224 --- /dev/null +++ b/tests/unit/providers/traversal/__init__.py @@ -0,0 +1 @@ +"""Traversal tests.""" diff --git a/tests/unit/providers/traversal/test_attribute_getter_py3.py b/tests/unit/providers/traversal/test_attribute_getter_py3.py new file mode 100644 index 00000000..27db8584 --- /dev/null +++ b/tests/unit/providers/traversal/test_attribute_getter_py3.py @@ -0,0 +1,31 @@ +"""AttributeGetter provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider1 = providers.Provider() + provided = provider1.provided + provider = provided.attr + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provided in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Provider() + provided = provider1.provided + provider2 = providers.Provider() + + provider = provided.attr + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provided in all_providers diff --git a/tests/unit/providers/traversal/test_callable_py3.py b/tests/unit/providers/traversal/test_callable_py3.py new file mode 100644 index 00000000..aebf603a --- /dev/null +++ b/tests/unit/providers/traversal/test_callable_py3.py @@ -0,0 +1,64 @@ +"""Callable provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider = providers.Callable(dict) + all_providers = list(provider.traverse()) + assert len(all_providers) == 0 + + +def test_traverse_args(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Callable(list, "foo", provider1, provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_kwargs(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Callable(dict, foo="foo", bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + + provider = providers.Callable(dict, "foo") + provider.override(provider1) + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_provides(): + provider1 = providers.Callable(list) + provider2 = providers.Object("bar") + provider3 = providers.Object("baz") + + provider = providers.Callable(provider1, provider2) + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers diff --git a/tests/unit/providers/traversal/test_configuration_py3.py b/tests/unit/providers/traversal/test_configuration_py3.py new file mode 100644 index 00000000..8b68f673 --- /dev/null +++ b/tests/unit/providers/traversal/test_configuration_py3.py @@ -0,0 +1,63 @@ +"""Configuration provider tests.""" + +from dependency_injector import providers + + +def test_traverse(): + config = providers.Configuration(default={"option1": {"option2": "option2"}}) + option1 = config.option1 + option2 = config.option1.option2 + option3 = config.option1[config.option1.option2] + + all_providers = list(config.traverse()) + + assert len(all_providers) == 3 + assert option1 in all_providers + assert option2 in all_providers + assert option3 in all_providers + + +def test_traverse_typed(): + config = providers.Configuration() + option = config.option + typed_option = config.option.as_int() + + all_providers = list(typed_option.traverse()) + + assert len(all_providers) == 1 + assert option in all_providers + + +def test_traverse_overridden(): + options = {"option1": {"option2": "option2"}} + config = providers.Configuration() + config.from_dict(options) + + all_providers = list(config.traverse()) + + assert len(all_providers) == 1 + overridden, = all_providers + assert overridden() == options + assert overridden is config.last_overriding + + +def test_traverse_overridden_option_1(): + options = {"option2": "option2"} + config = providers.Configuration() + config.option1.from_dict(options) + + all_providers = list(config.traverse()) + + assert len(all_providers) == 2 + assert config.option1 in all_providers + assert config.last_overriding in all_providers + + +def test_traverse_overridden_option_2(): + options = {"option2": "option2"} + config = providers.Configuration() + config.option1.from_dict(options) + + all_providers = list(config.option1.traverse()) + + assert len(all_providers) == 0 diff --git a/tests/unit/providers/traversal/test_container_py3.py b/tests/unit/providers/traversal/test_container_py3.py new file mode 100644 index 00000000..ebc6cca6 --- /dev/null +++ b/tests/unit/providers/traversal/test_container_py3.py @@ -0,0 +1,42 @@ +"""Container provider traversal tests.""" + +from dependency_injector import containers, providers + + +def test_traverse(): + class Container(containers.DeclarativeContainer): + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + + provider = providers.Container(Container) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert {list, dict} == {provider.provides for provider in all_providers} + + +def test_traverse_overridden(): + class Container1(containers.DeclarativeContainer): + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + + class Container2(containers.DeclarativeContainer): + provider1 = providers.Callable(tuple) + provider2 = providers.Callable(str) + + container2 = Container2() + + provider = providers.Container(Container1) + provider.override(container2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 5 + assert {list, dict, tuple, str} == { + provider.provides + for provider in all_providers + if isinstance(provider, providers.Callable) + } + assert provider.last_overriding in all_providers + assert provider.last_overriding() is container2 diff --git a/tests/unit/providers/traversal/test_delegate_py3.py b/tests/unit/providers/traversal/test_delegate_py3.py new file mode 100644 index 00000000..c251ec2f --- /dev/null +++ b/tests/unit/providers/traversal/test_delegate_py3.py @@ -0,0 +1,32 @@ +"""Delegate provider traversal tests.""" + +from dependency_injector import providers + + +def test_traversal_provider(): + another_provider = providers.Provider() + provider = providers.Delegate(another_provider) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 1 + assert another_provider in all_providers + + +def test_traversal_provider_and_overriding(): + provider1 = providers.Provider() + provider2 = providers.Provider() + + provider3 = providers.Provider() + provider3.override(provider2) + + provider = providers.Delegate(provider1) + + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers diff --git a/tests/unit/providers/traversal/test_dependencies_container_py3.py b/tests/unit/providers/traversal/test_dependencies_container_py3.py new file mode 100644 index 00000000..38d8b72a --- /dev/null +++ b/tests/unit/providers/traversal/test_dependencies_container_py3.py @@ -0,0 +1,52 @@ +"""DependenciesContainer provider traversal tests.""" + +from dependency_injector import providers + + +def test_traversal(): + provider = providers.DependenciesContainer() + all_providers = list(provider.traverse()) + assert len(all_providers) == 0 + + +def test_traversal_default(): + another_provider = providers.Provider() + provider = providers.DependenciesContainer(default=another_provider) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 1 + assert another_provider in all_providers + + +def test_traversal_fluent_interface(): + provider = providers.DependenciesContainer() + provider1 = provider.provider1 + provider2 = provider.provider2 + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traversal_overriding(): + provider1 = providers.Provider() + provider2 = providers.Provider() + provider3 = providers.DependenciesContainer( + provider1=provider1, + provider2=provider2, + ) + + provider = providers.DependenciesContainer() + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 5 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers + assert provider.provider1 in all_providers + assert provider.provider2 in all_providers diff --git a/tests/unit/providers/traversal/test_dependency_py3.py b/tests/unit/providers/traversal/test_dependency_py3.py new file mode 100644 index 00000000..4939a135 --- /dev/null +++ b/tests/unit/providers/traversal/test_dependency_py3.py @@ -0,0 +1,35 @@ +"""Dependency provider traversal tests.""" + +from dependency_injector import providers + + +def test_traversal(): + provider = providers.Dependency() + all_providers = list(provider.traverse()) + assert len(all_providers) == 0 + + +def test_traversal_default(): + another_provider = providers.Provider() + provider = providers.Dependency(default=another_provider) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 1 + assert another_provider in all_providers + + +def test_traversal_overriding(): + provider1 = providers.Provider() + + provider2 = providers.Provider() + provider2.override(provider1) + + provider = providers.Dependency() + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers diff --git a/tests/unit/providers/traversal/test_dict_py3.py b/tests/unit/providers/traversal/test_dict_py3.py new file mode 100644 index 00000000..9469d48f --- /dev/null +++ b/tests/unit/providers/traversal/test_dict_py3.py @@ -0,0 +1,31 @@ +"""Dict provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse_kwargs(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Dict(foo="foo", bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider3 = providers.Dict(bar=provider1, baz=provider2) + + provider = providers.Dict(foo="foo") + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers diff --git a/tests/unit/providers/traversal/test_factory_aggregate_py3.py b/tests/unit/providers/traversal/test_factory_aggregate_py3.py new file mode 100644 index 00000000..54bd8f6a --- /dev/null +++ b/tests/unit/providers/traversal/test_factory_aggregate_py3.py @@ -0,0 +1,15 @@ +"""FactoryAggregate provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + factory1 = providers.Factory(dict) + factory2 = providers.Factory(list) + provider = providers.FactoryAggregate(factory1=factory1, factory2=factory2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert factory1 in all_providers + assert factory2 in all_providers diff --git a/tests/unit/providers/traversal/test_factory_py3.py b/tests/unit/providers/traversal/test_factory_py3.py new file mode 100644 index 00000000..57bdf25e --- /dev/null +++ b/tests/unit/providers/traversal/test_factory_py3.py @@ -0,0 +1,77 @@ +"""Factory provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider = providers.Factory(dict) + all_providers = list(provider.traverse()) + assert len(all_providers) == 0 + + +def test_traverse_args(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Factory(list, "foo", provider1, provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_kwargs(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Factory(dict, foo="foo", bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_attributes(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Factory(dict) + provider.add_attributes(foo="foo", bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + + provider = providers.Factory(dict, "foo") + provider.override(provider1) + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_provides(): + provider1 = providers.Callable(list) + provider2 = providers.Object("bar") + provider3 = providers.Object("baz") + + provider = providers.Factory(provider1, provider2) + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers diff --git a/tests/unit/providers/traversal/test_item_getter_py3.py b/tests/unit/providers/traversal/test_item_getter_py3.py new file mode 100644 index 00000000..5629b536 --- /dev/null +++ b/tests/unit/providers/traversal/test_item_getter_py3.py @@ -0,0 +1,31 @@ +"""ItemGetter provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider1 = providers.Provider() + provided = provider1.provided + provider = provided["item"] + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provided in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Provider() + provided = provider1.provided + provider2 = providers.Provider() + + provider = provided["item"] + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provided in all_providers diff --git a/tests/unit/providers/traversal/test_list_py3.py b/tests/unit/providers/traversal/test_list_py3.py new file mode 100644 index 00000000..da6b3b74 --- /dev/null +++ b/tests/unit/providers/traversal/test_list_py3.py @@ -0,0 +1,31 @@ +"""List provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse_args(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.List("foo", provider1, provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider3 = providers.List(provider1, provider2) + + provider = providers.List("foo") + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers diff --git a/tests/unit/providers/traversal/test_method_caller_py3.py b/tests/unit/providers/traversal/test_method_caller_py3.py new file mode 100644 index 00000000..47bbfbab --- /dev/null +++ b/tests/unit/providers/traversal/test_method_caller_py3.py @@ -0,0 +1,67 @@ +"""MethodCaller provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider1 = providers.Provider() + provided = provider1.provided + method = provided.method + provider = method.call() + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provided in all_providers + assert method in all_providers + + +def test_traverse_args(): + provider1 = providers.Provider() + provided = provider1.provided + method = provided.method + provider2 = providers.Provider() + provider = method.call("foo", provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 4 + assert provider1 in all_providers + assert provider2 in all_providers + assert provided in all_providers + assert method in all_providers + + +def test_traverse_kwargs(): + provider1 = providers.Provider() + provided = provider1.provided + method = provided.method + provider2 = providers.Provider() + provider = method.call(foo="foo", bar=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 4 + assert provider1 in all_providers + assert provider2 in all_providers + assert provided in all_providers + assert method in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Provider() + provided = provider1.provided + method = provided.method + provider2 = providers.Provider() + + provider = method.call() + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 4 + assert provider1 in all_providers + assert provider2 in all_providers + assert provided in all_providers + assert method in all_providers diff --git a/tests/unit/providers/traversal/test_object_py3.py b/tests/unit/providers/traversal/test_object_py3.py new file mode 100644 index 00000000..6c55b93b --- /dev/null +++ b/tests/unit/providers/traversal/test_object_py3.py @@ -0,0 +1,37 @@ +"""Object provider traversal tests.""" + +from dependency_injector import providers + + +def test_traversal(): + provider = providers.Object("string") + all_providers = list(provider.traverse()) + assert len(all_providers) == 0 + + +def test_traversal_provider(): + another_provider = providers.Provider() + provider = providers.Object(another_provider) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 1 + assert another_provider in all_providers + + +def test_traversal_provider_and_overriding(): + another_provider_1 = providers.Provider() + another_provider_2 = providers.Provider() + another_provider_3 = providers.Provider() + + provider = providers.Object(another_provider_1) + + provider.override(another_provider_2) + provider.override(another_provider_3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert another_provider_1 in all_providers + assert another_provider_2 in all_providers + assert another_provider_3 in all_providers diff --git a/tests/unit/providers/traversal/test_provided_instance_py3.py b/tests/unit/providers/traversal/test_provided_instance_py3.py new file mode 100644 index 00000000..8e13dbf4 --- /dev/null +++ b/tests/unit/providers/traversal/test_provided_instance_py3.py @@ -0,0 +1,27 @@ +"""ProvidedInstance provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider1 = providers.Provider() + provider = provider1.provided + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 1 + assert provider1 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Provider() + provider2 = providers.Provider() + + provider = provider1.provided + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers diff --git a/tests/unit/providers/traversal/test_provider_py3.py b/tests/unit/providers/traversal/test_provider_py3.py new file mode 100644 index 00000000..748709f9 --- /dev/null +++ b/tests/unit/providers/traversal/test_provider_py3.py @@ -0,0 +1,60 @@ +"""Provider traversal tests.""" + +from dependency_injector import providers + + +def test_traversal_overriding(): + provider1 = providers.Provider() + provider2 = providers.Provider() + provider3 = providers.Provider() + + provider = providers.Provider() + + provider.override(provider1) + provider.override(provider2) + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers + + +def test_traversal_overriding_nested(): + provider1 = providers.Provider() + + provider2 = providers.Provider() + provider2.override(provider1) + + provider3 = providers.Provider() + provider3.override(provider2) + + provider = providers.Provider() + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers + + +def test_traverse_types_filtering(): + provider1 = providers.Resource(dict) + provider2 = providers.Resource(dict) + provider3 = providers.Provider() + + provider = providers.Provider() + + provider.override(provider1) + provider.override(provider2) + provider.override(provider3) + + all_providers = list(provider.traverse(types=[providers.Resource])) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers diff --git a/tests/unit/providers/traversal/test_resource_py3.py b/tests/unit/providers/traversal/test_resource_py3.py new file mode 100644 index 00000000..b4a1179c --- /dev/null +++ b/tests/unit/providers/traversal/test_resource_py3.py @@ -0,0 +1,60 @@ +"""Resource provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider = providers.Resource(dict) + all_providers = list(provider.traverse()) + assert len(all_providers) == 0 + + +def test_traverse_args(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Resource(list, "foo", provider1, provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_kwargs(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Resource(dict, foo="foo", bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Resource(list) + provider2 = providers.Resource(tuple) + + provider = providers.Resource(dict, "foo") + provider.override(provider1) + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_provides(): + provider1 = providers.Callable(list) + + provider = providers.Resource(provider1) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 1 + assert provider1 in all_providers + diff --git a/tests/unit/providers/traversal/test_selector_py3.py b/tests/unit/providers/traversal/test_selector_py3.py new file mode 100644 index 00000000..bd345076 --- /dev/null +++ b/tests/unit/providers/traversal/test_selector_py3.py @@ -0,0 +1,59 @@ +"""Selector provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + switch = lambda: "provider1" + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + + provider = providers.Selector( + switch, + provider1=provider1, + provider2=provider2, + ) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_switch(): + switch = providers.Callable(lambda: "provider1") + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + + provider = providers.Selector( + switch, + provider1=provider1, + provider2=provider2, + ) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert switch in all_providers + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + selector1 = providers.Selector(lambda: "provider1", provider1=provider1) + + provider = providers.Selector( + lambda: "provider2", + provider2=provider2, + ) + provider.override(selector1) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert selector1 in all_providers diff --git a/tests/unit/providers/traversal/test_singleton_py3.py b/tests/unit/providers/traversal/test_singleton_py3.py new file mode 100644 index 00000000..78240732 --- /dev/null +++ b/tests/unit/providers/traversal/test_singleton_py3.py @@ -0,0 +1,77 @@ +"""Singleton provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider = providers.Singleton(dict) + all_providers = list(provider.traverse()) + assert len(all_providers) == 0 + + +def test_traverse_args(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Singleton(list, "foo", provider1, provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_kwargs(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Singleton(dict, foo="foo", bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_attributes(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Singleton(dict) + provider.add_attributes(foo="foo", bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + + provider = providers.Singleton(dict, "foo") + provider.override(provider1) + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_provides(): + provider1 = providers.Callable(list) + provider2 = providers.Object("bar") + provider3 = providers.Object("baz") + + provider = providers.Singleton(provider1, provider2) + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers diff --git a/tests/unit/providers/traversal/test_traverse_py3.py b/tests/unit/providers/traversal/test_traverse_py3.py new file mode 100644 index 00000000..01797b4a --- /dev/null +++ b/tests/unit/providers/traversal/test_traverse_py3.py @@ -0,0 +1,40 @@ +"""Provider's traversal tests.""" + +from dependency_injector import providers + + +def test_traverse_cycled_graph(): + provider1 = providers.Provider() + + provider2 = providers.Provider() + provider2.override(provider1) + + provider3 = providers.Provider() + provider3.override(provider2) + + provider1.override(provider3) # Cycle: provider3 -> provider2 -> provider1 -> provider3 + + all_providers = list(providers.traverse(provider1)) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers + + +def test_traverse_types_filtering(): + provider1 = providers.Resource(dict) + provider2 = providers.Resource(dict) + provider3 = providers.Provider() + + provider = providers.Provider() + + provider.override(provider1) + provider.override(provider2) + provider.override(provider3) + + all_providers = list(providers.traverse(provider, types=[providers.Resource])) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers diff --git a/tests/unit/providers/utils/__init__.py b/tests/unit/providers/utils/__init__.py new file mode 100644 index 00000000..1c2f9686 --- /dev/null +++ b/tests/unit/providers/utils/__init__.py @@ -0,0 +1 @@ +"""Provider utils tests.""" diff --git a/tests/unit/providers/utils/test_ensure_is_provider_py2_py3.py b/tests/unit/providers/utils/test_ensure_is_provider_py2_py3.py new file mode 100644 index 00000000..cbb38fda --- /dev/null +++ b/tests/unit/providers/utils/test_ensure_is_provider_py2_py3.py @@ -0,0 +1,24 @@ +"""Provider utils tests.""" + +from dependency_injector import providers, errors +from pytest import raises + + +def test_with_instance(): + provider = providers.Provider() + assert providers.ensure_is_provider(provider), provider + + +def test_with_class(): + with raises(errors.Error): + providers.ensure_is_provider(providers.Provider) + + +def test_with_string(): + with raises(errors.Error): + providers.ensure_is_provider("some_string") + + +def test_with_object(): + with raises(errors.Error): + providers.ensure_is_provider(object()) diff --git a/tests/unit/providers/utils/test_is_provider_py2_py3.py b/tests/unit/providers/utils/test_is_provider_py2_py3.py new file mode 100644 index 00000000..5e9e537e --- /dev/null +++ b/tests/unit/providers/utils/test_is_provider_py2_py3.py @@ -0,0 +1,34 @@ +"""Provider utils tests.""" + +from dependency_injector import providers + + +def test_with_instance(): + assert providers.is_provider(providers.Provider()) is True + + +def test_with_class(): + assert providers.is_provider(providers.Provider) is False + + +def test_with_string(): + assert providers.is_provider("some_string") is False + + +def test_with_object(): + assert providers.is_provider(object()) is False + + +def test_with_subclass_instance(): + class SomeProvider(providers.Provider): + pass + + assert providers.is_provider(SomeProvider()) is True + + +def test_with_class_with_getattr(): + class SomeClass(object): + def __getattr__(self, _): + return False + + assert providers.is_provider(SomeClass()) is False diff --git a/tests/unit/samples/__init__.py b/tests/unit/samples/__init__.py new file mode 100644 index 00000000..b7aba378 --- /dev/null +++ b/tests/unit/samples/__init__.py @@ -0,0 +1 @@ +"""Sample code for testing.""" diff --git a/tests/unit/samples/schemasample/__init__.py b/tests/unit/samples/schema/__init__.py similarity index 100% rename from tests/unit/samples/schemasample/__init__.py rename to tests/unit/samples/schema/__init__.py diff --git a/tests/unit/samples/schemasample/container-boto3-session.yml b/tests/unit/samples/schema/container-boto3-session.yml similarity index 100% rename from tests/unit/samples/schemasample/container-boto3-session.yml rename to tests/unit/samples/schema/container-boto3-session.yml diff --git a/tests/unit/samples/schemasample/container-multiple-inline.yml b/tests/unit/samples/schema/container-multiple-inline.yml similarity index 75% rename from tests/unit/samples/schemasample/container-multiple-inline.yml rename to tests/unit/samples/schema/container-multiple-inline.yml index a089508f..f5394a2d 100644 --- a/tests/unit/samples/schemasample/container-multiple-inline.yml +++ b/tests/unit/samples/schema/container-multiple-inline.yml @@ -12,7 +12,7 @@ container: provides: sqlite3.connect args: - provider: Callable - provides: schemasample.utils.return_ + provides: samples.schema.utils.return_ args: - container.core.config.database.dsn @@ -27,32 +27,32 @@ container: services: user: provider: Factory - provides: schemasample.services.UserService + provides: samples.schema.services.UserService kwargs: db: provider: Callable - provides: schemasample.utils.return_ + provides: samples.schema.utils.return_ args: - container.gateways.database_client auth: provider: Factory - provides: schemasample.services.AuthService + provides: samples.schema.services.AuthService kwargs: db: provider: Callable - provides: schemasample.utils.return_ + provides: samples.schema.utils.return_ args: - container.gateways.database_client token_ttl: container.core.config.auth.token_ttl.as_int() photo: provider: Factory - provides: schemasample.services.PhotoService + provides: samples.schema.services.PhotoService kwargs: db: provider: Callable - provides: schemasample.utils.return_ + provides: samples.schema.utils.return_ args: - container.gateways.database_client s3: container.gateways.s3_client diff --git a/tests/unit/samples/schemasample/container-multiple-reordered.yml b/tests/unit/samples/schema/container-multiple-reordered.yml similarity index 84% rename from tests/unit/samples/schemasample/container-multiple-reordered.yml rename to tests/unit/samples/schema/container-multiple-reordered.yml index 245c4f37..499cf0a2 100644 --- a/tests/unit/samples/schemasample/container-multiple-reordered.yml +++ b/tests/unit/samples/schema/container-multiple-reordered.yml @@ -5,20 +5,20 @@ container: services: user: provider: Factory - provides: schemasample.services.UserService + provides: samples.schema.services.UserService kwargs: db: container.gateways.database_client auth: provider: Factory - provides: schemasample.services.AuthService + provides: samples.schema.services.AuthService kwargs: db: container.gateways.database_client token_ttl: container.core.config.auth.token_ttl.as_int() photo: provider: Factory - provides: schemasample.services.PhotoService + provides: samples.schema.services.PhotoService kwargs: db: container.gateways.database_client s3: container.gateways.s3_client diff --git a/tests/unit/samples/schemasample/container-multiple.yml b/tests/unit/samples/schema/container-multiple.yml similarity index 84% rename from tests/unit/samples/schemasample/container-multiple.yml rename to tests/unit/samples/schema/container-multiple.yml index 03a5221a..e7dc40aa 100644 --- a/tests/unit/samples/schemasample/container-multiple.yml +++ b/tests/unit/samples/schema/container-multiple.yml @@ -24,20 +24,20 @@ container: services: user: provider: Factory - provides: schemasample.services.UserService + provides: samples.schema.services.UserService kwargs: db: container.gateways.database_client auth: provider: Factory - provides: schemasample.services.AuthService + provides: samples.schema.services.AuthService kwargs: db: container.gateways.database_client token_ttl: container.core.config.auth.token_ttl.as_int() photo: provider: Factory - provides: schemasample.services.PhotoService + provides: samples.schema.services.PhotoService kwargs: db: container.gateways.database_client s3: container.gateways.s3_client diff --git a/tests/unit/samples/schemasample/container-single.yml b/tests/unit/samples/schema/container-single.yml similarity index 83% rename from tests/unit/samples/schemasample/container-single.yml rename to tests/unit/samples/schema/container-single.yml index ad732992..553ab4b0 100644 --- a/tests/unit/samples/schemasample/container-single.yml +++ b/tests/unit/samples/schema/container-single.yml @@ -20,20 +20,20 @@ container: user_service: provider: Factory - provides: schemasample.services.UserService + provides: samples.schema.services.UserService kwargs: db: container.database_client auth_service: provider: Factory - provides: schemasample.services.AuthService + provides: samples.schema.services.AuthService kwargs: db: container.database_client token_ttl: container.config.auth.token_ttl.as_int() photo_service: provider: Factory - provides: schemasample.services.PhotoService + provides: samples.schema.services.PhotoService kwargs: db: container.database_client s3: container.s3_client diff --git a/tests/unit/samples/schemasample/services.py b/tests/unit/samples/schema/services.py similarity index 100% rename from tests/unit/samples/schemasample/services.py rename to tests/unit/samples/schema/services.py diff --git a/tests/unit/samples/schemasample/utils.py b/tests/unit/samples/schema/utils.py similarity index 100% rename from tests/unit/samples/schemasample/utils.py rename to tests/unit/samples/schema/utils.py diff --git a/tests/unit/samples/wiringsamples/__init__.py b/tests/unit/samples/wiring/__init__.py similarity index 100% rename from tests/unit/samples/wiringsamples/__init__.py rename to tests/unit/samples/wiring/__init__.py diff --git a/tests/unit/samples/wiringsamples/asyncinjections.py b/tests/unit/samples/wiring/asyncinjections.py similarity index 100% rename from tests/unit/samples/wiringsamples/asyncinjections.py rename to tests/unit/samples/wiring/asyncinjections.py diff --git a/tests/unit/samples/wiringsamples/container.py b/tests/unit/samples/wiring/container.py similarity index 100% rename from tests/unit/samples/wiringsamples/container.py rename to tests/unit/samples/wiring/container.py diff --git a/tests/unit/samples/wiringsamples/imports.py b/tests/unit/samples/wiring/imports.py similarity index 100% rename from tests/unit/samples/wiringsamples/imports.py rename to tests/unit/samples/wiring/imports.py diff --git a/tests/unit/samples/wiringsamples/module.py b/tests/unit/samples/wiring/module.py similarity index 100% rename from tests/unit/samples/wiringsamples/module.py rename to tests/unit/samples/wiring/module.py diff --git a/tests/unit/samples/wiringsamples/module_invalid_attr_injection.py b/tests/unit/samples/wiring/module_invalid_attr_injection.py similarity index 100% rename from tests/unit/samples/wiringsamples/module_invalid_attr_injection.py rename to tests/unit/samples/wiring/module_invalid_attr_injection.py diff --git a/tests/unit/samples/wiringsamples/package/__init__.py b/tests/unit/samples/wiring/package/__init__.py similarity index 100% rename from tests/unit/samples/wiringsamples/package/__init__.py rename to tests/unit/samples/wiring/package/__init__.py diff --git a/tests/unit/samples/wiringsamples/package/subpackage/__init__.py b/tests/unit/samples/wiring/package/subpackage/__init__.py similarity index 100% rename from tests/unit/samples/wiringsamples/package/subpackage/__init__.py rename to tests/unit/samples/wiring/package/subpackage/__init__.py diff --git a/tests/unit/samples/wiringsamples/package/subpackage/submodule.py b/tests/unit/samples/wiring/package/subpackage/submodule.py similarity index 100% rename from tests/unit/samples/wiringsamples/package/subpackage/submodule.py rename to tests/unit/samples/wiring/package/subpackage/submodule.py diff --git a/tests/unit/samples/wiringsamples/queuemodule.py b/tests/unit/samples/wiring/queuemodule.py similarity index 100% rename from tests/unit/samples/wiringsamples/queuemodule.py rename to tests/unit/samples/wiring/queuemodule.py diff --git a/tests/unit/samples/wiringsamples/resourceclosing.py b/tests/unit/samples/wiring/resourceclosing.py similarity index 100% rename from tests/unit/samples/wiringsamples/resourceclosing.py rename to tests/unit/samples/wiring/resourceclosing.py diff --git a/tests/unit/samples/wiringsamples/service.py b/tests/unit/samples/wiring/service.py similarity index 100% rename from tests/unit/samples/wiringsamples/service.py rename to tests/unit/samples/wiring/service.py diff --git a/tests/unit/samples/wiringsamples/wire_relative_string_names.py b/tests/unit/samples/wiring/wire_relative_string_names.py similarity index 100% rename from tests/unit/samples/wiringsamples/wire_relative_string_names.py rename to tests/unit/samples/wiring/wire_relative_string_names.py diff --git a/tests/unit/samples/wiringflask/web.py b/tests/unit/samples/wiringflask/web.py index fe6e39dc..37fbd5e0 100644 --- a/tests/unit/samples/wiringflask/web.py +++ b/tests/unit/samples/wiringflask/web.py @@ -10,7 +10,7 @@ _request_ctx_stack, _app_ctx_stack # noqa class Service: def process(self) -> str: - return "Ok" + return "OK" class Container(containers.DeclarativeContainer): diff --git a/tests/unit/samples/wiringstringidssamples/__init__.py b/tests/unit/samples/wiringstringids/__init__.py similarity index 100% rename from tests/unit/samples/wiringstringidssamples/__init__.py rename to tests/unit/samples/wiringstringids/__init__.py diff --git a/tests/unit/samples/wiringstringidssamples/asyncinjections.py b/tests/unit/samples/wiringstringids/asyncinjections.py similarity index 100% rename from tests/unit/samples/wiringstringidssamples/asyncinjections.py rename to tests/unit/samples/wiringstringids/asyncinjections.py diff --git a/tests/unit/samples/wiringstringidssamples/container.py b/tests/unit/samples/wiringstringids/container.py similarity index 100% rename from tests/unit/samples/wiringstringidssamples/container.py rename to tests/unit/samples/wiringstringids/container.py diff --git a/tests/unit/samples/wiringstringidssamples/module.py b/tests/unit/samples/wiringstringids/module.py similarity index 95% rename from tests/unit/samples/wiringstringidssamples/module.py rename to tests/unit/samples/wiringstringids/module.py index 0e3708fe..aac85aa8 100644 --- a/tests/unit/samples/wiringstringidssamples/module.py +++ b/tests/unit/samples/wiringstringids/module.py @@ -98,6 +98,12 @@ def test_provide_provider(service_provider: Callable[..., Service] = Provide["se return service +@inject +def test_provider_provider(service_provider: Callable[..., Service] = Provider["service.provider"]): + service = service_provider() + return service + + @inject def test_provided_instance(some_value: int = Provide["service", provided().foo["bar"].call()]): return some_value diff --git a/tests/unit/samples/wiringstringids/module_invalid_attr_injection.py b/tests/unit/samples/wiringstringids/module_invalid_attr_injection.py new file mode 100644 index 00000000..78b407f7 --- /dev/null +++ b/tests/unit/samples/wiringstringids/module_invalid_attr_injection.py @@ -0,0 +1,6 @@ +"""Test module for wiring with invalid type of marker for attribute injection.""" + +from dependency_injector.wiring import Closing + + +service = Closing["service"] diff --git a/tests/unit/samples/wiringstringidssamples/package/__init__.py b/tests/unit/samples/wiringstringids/package/__init__.py similarity index 100% rename from tests/unit/samples/wiringstringidssamples/package/__init__.py rename to tests/unit/samples/wiringstringids/package/__init__.py diff --git a/tests/unit/samples/wiringstringidssamples/package/subpackage/__init__.py b/tests/unit/samples/wiringstringids/package/subpackage/__init__.py similarity index 100% rename from tests/unit/samples/wiringstringidssamples/package/subpackage/__init__.py rename to tests/unit/samples/wiringstringids/package/subpackage/__init__.py diff --git a/tests/unit/samples/wiringstringidssamples/package/subpackage/submodule.py b/tests/unit/samples/wiringstringids/package/subpackage/submodule.py similarity index 100% rename from tests/unit/samples/wiringstringidssamples/package/subpackage/submodule.py rename to tests/unit/samples/wiringstringids/package/subpackage/submodule.py diff --git a/tests/unit/samples/wiringstringidssamples/resourceclosing.py b/tests/unit/samples/wiringstringids/resourceclosing.py similarity index 100% rename from tests/unit/samples/wiringstringidssamples/resourceclosing.py rename to tests/unit/samples/wiringstringids/resourceclosing.py diff --git a/tests/unit/samples/wiringstringidssamples/service.py b/tests/unit/samples/wiringstringids/service.py similarity index 100% rename from tests/unit/samples/wiringstringidssamples/service.py rename to tests/unit/samples/wiringstringids/service.py diff --git a/tests/unit/schema/conftest.py b/tests/unit/schema/conftest.py new file mode 100644 index 00000000..fe0a8a30 --- /dev/null +++ b/tests/unit/schema/conftest.py @@ -0,0 +1,9 @@ +"""Container schema fixtures.""" + +from dependency_injector import containers +from pytest import fixture + + +@fixture +def container(): + return containers.DynamicContainer() diff --git a/tests/unit/schema/test_container_api_py36.py b/tests/unit/schema/test_container_api_py36.py new file mode 100644 index 00000000..b8bc3a0a --- /dev/null +++ b/tests/unit/schema/test_container_api_py36.py @@ -0,0 +1,145 @@ +"""Container API tests for building container from schema.""" + +import contextlib +import json +import pathlib +import re + +import yaml +from dependency_injector import containers, providers, errors +from pytest import raises + + +def test_from_schema(container: containers.DynamicContainer): + container.from_schema( + { + "version": "1", + "container": { + "provider1": { + "provider": "Factory", + "provides": "list", + "args": [1, 2, 3], + }, + "provider2": { + "provider": "Factory", + "provides": "dict", + "kwargs": { + "one": "container.provider1", + "two": 2, + }, + }, + }, + }, + ) + + assert isinstance(container.provider1, providers.Factory) + assert container.provider1.provides is list + assert container.provider1.args == (1, 2, 3) + + assert isinstance(container.provider2, providers.Factory) + assert container.provider2.provides is dict + assert container.provider2.kwargs == {"one": container.provider1, "two": 2} + + +def test_from_yaml_schema(container: containers.DynamicContainer, tmp_path: pathlib.Path): + schema_path = tmp_path / "schema.yml" + with open(schema_path, "w") as file: + file.write(""" + version: "1" + container: + provider1: + provider: Factory + provides: list + args: + - 1 + - 2 + - 3 + provider2: + provider: Factory + provides: dict + kwargs: + one: container.provider1 + two: 2 + """) + container.from_yaml_schema(schema_path) + + assert isinstance(container.provider1, providers.Factory) + assert container.provider1.provides == list + assert container.provider1.args == (1, 2, 3) + + assert isinstance(container.provider2, providers.Factory) + assert container.provider2.provides is dict + assert container.provider2.kwargs == {"one": container.provider1, "two": 2} + + +def test_from_yaml_schema_with_loader(container: containers.DynamicContainer, tmp_path: pathlib.Path): + schema_path = tmp_path / "schema.yml" + with open(schema_path, "w") as file: + file.write(""" + version: "1" + container: + provider: + provider: Factory + provides: list + args: [1, 2, 3] + """) + container.from_yaml_schema(schema_path, loader=yaml.Loader) + + assert isinstance(container.provider, providers.Factory) + assert container.provider.provides is list + assert container.provider.args == (1, 2, 3) + + +def test_from_yaml_schema_no_yaml_installed(container: containers.DynamicContainer): + @contextlib.contextmanager + def no_yaml_module(): + containers.yaml = None + yield + containers.yaml = yaml + + error_message = re.escape( + "Unable to load yaml schema - PyYAML is not installed. " + "Install PyYAML or install Dependency Injector with yaml extras: " + "\"pip install dependency-injector[yaml]\"" + ) + + with no_yaml_module(): + with raises(errors.Error, match=error_message): + container.from_yaml_schema("./no-yaml-installed.yml") + + +def test_from_json_schema(container: containers.DynamicContainer, tmp_path: pathlib.Path): + schema_path = tmp_path / "schema.json" + with open(schema_path, "w") as file: + file.write( + json.dumps( + { + "version": "1", + "container": { + "provider1": { + "provider": "Factory", + "provides": "list", + "args": [1, 2, 3], + }, + "provider2": { + "provider": "Factory", + "provides": "dict", + "kwargs": { + "one": "container.provider1", + "two": 2, + }, + }, + }, + }, + indent=4, + ), + ) + container.from_json_schema(schema_path) + + assert isinstance(container.provider1, providers.Factory) + assert container.provider1.provides is list + assert container.provider1.args == (1, 2, 3) + + assert isinstance(container.provider2, providers.Factory) + assert container.provider2.provides is dict + assert container.provider2.kwargs == {"one": container.provider1, "two": 2} diff --git a/tests/unit/schema/test_containers_api_py36.py b/tests/unit/schema/test_containers_api_py36.py deleted file mode 100644 index 1678aabf..00000000 --- a/tests/unit/schema/test_containers_api_py36.py +++ /dev/null @@ -1,162 +0,0 @@ -import contextlib -import json -import os.path -import tempfile -import unittest - -import yaml -from dependency_injector import containers, providers, errors - - -class FromSchemaTests(unittest.TestCase): - - def test(self): - container = containers.DynamicContainer() - container.from_schema( - { - "version": "1", - "container": { - "provider1": { - "provider": "Factory", - "provides": "list", - "args": [1, 2, 3], - }, - "provider2": { - "provider": "Factory", - "provides": "dict", - "kwargs": { - "one": "container.provider1", - "two": 2, - }, - }, - }, - }, - ) - - self.assertIsInstance(container.provider1, providers.Factory) - self.assertIs(container.provider1.provides, list) - self.assertEqual(container.provider1.args, (1, 2, 3)) - - self.assertIsInstance(container.provider2, providers.Factory) - self.assertIs(container.provider2.provides, dict) - self.assertEqual(container.provider2.kwargs, {"one": container.provider1, "two": 2}) - - -class FromYamlSchemaTests(unittest.TestCase): - - def test(self): - container = containers.DynamicContainer() - - with tempfile.TemporaryDirectory() as tmp_dir: - schema_path = os.path.join(tmp_dir, "schema.yml") - with open(schema_path, "w") as file: - file.write(""" - version: "1" - container: - provider1: - provider: Factory - provides: list - args: - - 1 - - 2 - - 3 - provider2: - provider: Factory - provides: dict - kwargs: - one: container.provider1 - two: 2 - """) - - container.from_yaml_schema(schema_path) - - self.assertIsInstance(container.provider1, providers.Factory) - self.assertIs(container.provider1.provides, list) - self.assertEqual(container.provider1.args, (1, 2, 3)) - - self.assertIsInstance(container.provider2, providers.Factory) - self.assertIs(container.provider2.provides, dict) - self.assertEqual(container.provider2.kwargs, {"one": container.provider1, "two": 2}) - - def test_with_loader(self): - container = containers.DynamicContainer() - - with tempfile.TemporaryDirectory() as tmp_dir: - schema_path = os.path.join(tmp_dir, "schema.yml") - with open(schema_path, "w") as file: - file.write(""" - version: "1" - container: - provider: - provider: Factory - provides: list - args: [1, 2, 3] - """) - - container.from_yaml_schema(schema_path, loader=yaml.Loader) - - self.assertIsInstance(container.provider, providers.Factory) - self.assertIs(container.provider.provides, list) - self.assertEqual(container.provider.args, (1, 2, 3)) - - def test_no_yaml_installed(self): - @contextlib.contextmanager - def no_yaml_module(): - containers.yaml = None - yield - containers.yaml = yaml - - container = containers.DynamicContainer() - with no_yaml_module(): - with self.assertRaises(errors.Error) as error: - container.from_yaml_schema("./no-yaml-installed.yml") - - self.assertEqual( - error.exception.args[0], - "Unable to load yaml schema - PyYAML is not installed. " - "Install PyYAML or install Dependency Injector with yaml extras: " - "\"pip install dependency-injector[yaml]\"", - ) - - -class FromJsonSchemaTests(unittest.TestCase): - - def test(self): - container = containers.DynamicContainer() - - with tempfile.TemporaryDirectory() as tmp_dir: - schema_path = os.path.join(tmp_dir, "schema.json") - with open(schema_path, "w") as file: - file.write( - json.dumps( - { - "version": "1", - "container": { - "provider1": { - "provider": "Factory", - "provides": "list", - "args": [1, 2, 3], - }, - "provider2": { - "provider": "Factory", - "provides": "dict", - "kwargs": { - "one": "container.provider1", - "two": 2, - }, - }, - }, - }, - indent=4, - ), - ) - - container.from_json_schema(schema_path) - - self.assertIsInstance(container.provider1, providers.Factory) - self.assertIs(container.provider1.provides, list) - self.assertEqual(container.provider1.args, (1, 2, 3)) - - self.assertIsInstance(container.provider2, providers.Factory) - self.assertIs(container.provider2.provides, dict) - self.assertEqual(container.provider2.kwargs, {"one": container.provider1, "two": 2}) diff --git a/tests/unit/schema/test_integration_py36.py b/tests/unit/schema/test_integration_py36.py index f4eb615c..aa7a6fd0 100644 --- a/tests/unit/schema/test_integration_py36.py +++ b/tests/unit/schema/test_integration_py36.py @@ -1,294 +1,273 @@ +"""Container tests for building containers from configuration files.""" + +import os import sqlite3 -import unittest from dependency_injector import containers +from pytest import mark -# Runtime import -import os -_TOP_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../", - )), -) -_SAMPLES_DIR = os.path.abspath( +from samples.schema.services import UserService, AuthService, PhotoService + + +SAMPLES_DIR = os.path.abspath( os.path.sep.join(( os.path.dirname(__file__), "../samples/", )), ) -import sys -sys.path.append(_SAMPLES_DIR) - -from schemasample.services import UserService, AuthService, PhotoService -class TestSchemaSingleContainer(unittest.TestCase): +def test_single_container_schema(container: containers.DynamicContainer): + container.from_yaml_schema(f"{SAMPLES_DIR}/schema/container-single.yml") + container.config.from_dict({ + "database": { + "dsn": ":memory:", + }, + "aws": { + "access_key_id": "KEY", + "secret_access_key": "SECRET", + }, + "auth": { + "token_ttl": 3600, + }, + }) - def test(self): - container = containers.DynamicContainer() - container.from_yaml_schema(f"{_SAMPLES_DIR}/schemasample/container-single.yml") - container.config.from_dict({ - "database": { - "dsn": ":memory:", - }, - "aws": { - "access_key_id": "KEY", - "secret_access_key": "SECRET", - }, - "auth": { - "token_ttl": 3600, - }, - }) + # User service + user_service1 = container.user_service() + user_service2 = container.user_service() + assert isinstance(user_service1, UserService) + assert isinstance(user_service2, UserService) + assert user_service1 is not user_service2 - # User service - user_service1 = container.user_service() - user_service2 = container.user_service() - self.assertIsInstance(user_service1, UserService) - self.assertIsInstance(user_service2, UserService) - self.assertIsNot(user_service1, user_service2) + assert isinstance(user_service1.db, sqlite3.Connection) + assert isinstance(user_service2.db, sqlite3.Connection) + assert user_service1.db is user_service2.db - self.assertIsInstance(user_service1.db, sqlite3.Connection) - self.assertIsInstance(user_service2.db, sqlite3.Connection) - self.assertIs(user_service1.db, user_service2.db) + # Auth service + auth_service1 = container.auth_service() + auth_service2 = container.auth_service() + assert isinstance(auth_service1, AuthService) + assert isinstance(auth_service2, AuthService) + assert auth_service1 is not auth_service2 - # Auth service - auth_service1 = container.auth_service() - auth_service2 = container.auth_service() - self.assertIsInstance(auth_service1, AuthService) - self.assertIsInstance(auth_service2, AuthService) - self.assertIsNot(auth_service1, auth_service2) + assert isinstance(auth_service1.db, sqlite3.Connection) + assert isinstance(auth_service2.db, sqlite3.Connection) + assert auth_service1.db is auth_service2.db + assert auth_service1.db is container.database_client() + assert auth_service2.db is container.database_client() - self.assertIsInstance(auth_service1.db, sqlite3.Connection) - self.assertIsInstance(auth_service2.db, sqlite3.Connection) - self.assertIs(auth_service1.db, auth_service2.db) - self.assertIs(auth_service1.db, container.database_client()) - self.assertIs(auth_service2.db, container.database_client()) + assert auth_service1.token_ttl == 3600 + assert auth_service2.token_ttl == 3600 - self.assertEqual(auth_service1.token_ttl, 3600) - self.assertEqual(auth_service2.token_ttl, 3600) + # Photo service + photo_service1 = container.photo_service() + photo_service2 = container.photo_service() + assert isinstance(photo_service1, PhotoService) + assert isinstance(photo_service2, PhotoService) + assert photo_service1 is not photo_service2 - # Photo service - photo_service1 = container.photo_service() - photo_service2 = container.photo_service() - self.assertIsInstance(photo_service1, PhotoService) - self.assertIsInstance(photo_service2, PhotoService) - self.assertIsNot(photo_service1, photo_service2) + assert isinstance(photo_service1.db, sqlite3.Connection) + assert isinstance(photo_service2.db, sqlite3.Connection) + assert photo_service1.db is photo_service2.db + assert photo_service1.db is container.database_client() + assert photo_service2.db is container.database_client() - self.assertIsInstance(photo_service1.db, sqlite3.Connection) - self.assertIsInstance(photo_service2.db, sqlite3.Connection) - self.assertIs(photo_service1.db, photo_service2.db) - self.assertIs(photo_service1.db, container.database_client()) - self.assertIs(photo_service2.db, container.database_client()) - - self.assertIs(photo_service1.s3, photo_service2.s3) - self.assertIs(photo_service1.s3, container.s3_client()) - self.assertIs(photo_service2.s3, container.s3_client()) + assert photo_service1.s3 is photo_service2.s3 + assert photo_service1.s3 is container.s3_client() + assert photo_service2.s3 is container.s3_client() -class TestSchemaMultipleContainers(unittest.TestCase): +def test_multiple_containers_schema(container: containers.DynamicContainer): + container.from_yaml_schema(f"{SAMPLES_DIR}/schema/container-multiple.yml") + container.core.config.from_dict({ + "database": { + "dsn": ":memory:", + }, + "aws": { + "access_key_id": "KEY", + "secret_access_key": "SECRET", + }, + "auth": { + "token_ttl": 3600, + }, + }) - def test(self): - container = containers.DynamicContainer() - container.from_yaml_schema(f"{_SAMPLES_DIR}/schemasample/container-multiple.yml") - container.core.config.from_dict({ - "database": { - "dsn": ":memory:", - }, - "aws": { - "access_key_id": "KEY", - "secret_access_key": "SECRET", - }, - "auth": { - "token_ttl": 3600, - }, - }) + # User service + user_service1 = container.services.user() + user_service2 = container.services.user() + assert isinstance(user_service1, UserService) + assert isinstance(user_service2, UserService) + assert user_service1 is not user_service2 - # User service - user_service1 = container.services.user() - user_service2 = container.services.user() - self.assertIsInstance(user_service1, UserService) - self.assertIsInstance(user_service2, UserService) - self.assertIsNot(user_service1, user_service2) + assert isinstance(user_service1.db, sqlite3.Connection) + assert isinstance(user_service2.db, sqlite3.Connection) + assert user_service1.db is user_service2.db - self.assertIsInstance(user_service1.db, sqlite3.Connection) - self.assertIsInstance(user_service2.db, sqlite3.Connection) - self.assertIs(user_service1.db, user_service2.db) + # Auth service + auth_service1 = container.services.auth() + auth_service2 = container.services.auth() + assert isinstance(auth_service1, AuthService) + assert isinstance(auth_service2, AuthService) + assert auth_service1 is not auth_service2 - # Auth service - auth_service1 = container.services.auth() - auth_service2 = container.services.auth() - self.assertIsInstance(auth_service1, AuthService) - self.assertIsInstance(auth_service2, AuthService) - self.assertIsNot(auth_service1, auth_service2) + assert isinstance(auth_service1.db, sqlite3.Connection) + assert isinstance(auth_service2.db, sqlite3.Connection) + assert auth_service1.db is auth_service2.db + assert auth_service1.db is container.gateways.database_client() + assert auth_service2.db is container.gateways.database_client() - self.assertIsInstance(auth_service1.db, sqlite3.Connection) - self.assertIsInstance(auth_service2.db, sqlite3.Connection) - self.assertIs(auth_service1.db, auth_service2.db) - self.assertIs(auth_service1.db, container.gateways.database_client()) - self.assertIs(auth_service2.db, container.gateways.database_client()) + assert auth_service1.token_ttl == 3600 + assert auth_service2.token_ttl == 3600 - self.assertEqual(auth_service1.token_ttl, 3600) - self.assertEqual(auth_service2.token_ttl, 3600) + # Photo service + photo_service1 = container.services.photo() + photo_service2 = container.services.photo() + assert isinstance(photo_service1, PhotoService) + assert isinstance(photo_service2, PhotoService) + assert photo_service1 is not photo_service2 - # Photo service - photo_service1 = container.services.photo() - photo_service2 = container.services.photo() - self.assertIsInstance(photo_service1, PhotoService) - self.assertIsInstance(photo_service2, PhotoService) - self.assertIsNot(photo_service1, photo_service2) + assert isinstance(photo_service1.db, sqlite3.Connection) + assert isinstance(photo_service2.db, sqlite3.Connection) + assert photo_service1.db is photo_service2.db + assert photo_service1.db is container.gateways.database_client() + assert photo_service2.db is container.gateways.database_client() - self.assertIsInstance(photo_service1.db, sqlite3.Connection) - self.assertIsInstance(photo_service2.db, sqlite3.Connection) - self.assertIs(photo_service1.db, photo_service2.db) - self.assertIs(photo_service1.db, container.gateways.database_client()) - self.assertIs(photo_service2.db, container.gateways.database_client()) - - self.assertIs(photo_service1.s3, photo_service2.s3) - self.assertIs(photo_service1.s3, container.gateways.s3_client()) - self.assertIs(photo_service2.s3, container.gateways.s3_client()) + assert photo_service1.s3 is photo_service2.s3 + assert photo_service1.s3 is container.gateways.s3_client() + assert photo_service2.s3 is container.gateways.s3_client() -class TestSchemaMultipleContainersReordered(unittest.TestCase): +def test_multiple_reordered_containers_schema(container: containers.DynamicContainer): + container.from_yaml_schema(f"{SAMPLES_DIR}/schema/container-multiple-reordered.yml") + container.core.config.from_dict({ + "database": { + "dsn": ":memory:", + }, + "aws": { + "access_key_id": "KEY", + "secret_access_key": "SECRET", + }, + "auth": { + "token_ttl": 3600, + }, + }) - def test(self): - container = containers.DynamicContainer() - container.from_yaml_schema(f"{_SAMPLES_DIR}/schemasample/container-multiple-reordered.yml") - container.core.config.from_dict({ - "database": { - "dsn": ":memory:", - }, - "aws": { - "access_key_id": "KEY", - "secret_access_key": "SECRET", - }, - "auth": { - "token_ttl": 3600, - }, - }) + # User service + user_service1 = container.services.user() + user_service2 = container.services.user() + assert isinstance(user_service1, UserService) + assert isinstance(user_service2, UserService) + assert user_service1 is not user_service2 - # User service - user_service1 = container.services.user() - user_service2 = container.services.user() - self.assertIsInstance(user_service1, UserService) - self.assertIsInstance(user_service2, UserService) - self.assertIsNot(user_service1, user_service2) + assert isinstance(user_service1.db, sqlite3.Connection) + assert isinstance(user_service2.db, sqlite3.Connection) + assert user_service1.db is user_service2.db - self.assertIsInstance(user_service1.db, sqlite3.Connection) - self.assertIsInstance(user_service2.db, sqlite3.Connection) - self.assertIs(user_service1.db, user_service2.db) + # Auth service + auth_service1 = container.services.auth() + auth_service2 = container.services.auth() + assert isinstance(auth_service1, AuthService) + assert isinstance(auth_service2, AuthService) + assert auth_service1 is not auth_service2 - # Auth service - auth_service1 = container.services.auth() - auth_service2 = container.services.auth() - self.assertIsInstance(auth_service1, AuthService) - self.assertIsInstance(auth_service2, AuthService) - self.assertIsNot(auth_service1, auth_service2) + assert isinstance(auth_service1.db, sqlite3.Connection) + assert isinstance(auth_service2.db, sqlite3.Connection) + assert auth_service1.db is auth_service2.db + assert auth_service1.db is container.gateways.database_client() + assert auth_service2.db is container.gateways.database_client() - self.assertIsInstance(auth_service1.db, sqlite3.Connection) - self.assertIsInstance(auth_service2.db, sqlite3.Connection) - self.assertIs(auth_service1.db, auth_service2.db) - self.assertIs(auth_service1.db, container.gateways.database_client()) - self.assertIs(auth_service2.db, container.gateways.database_client()) + assert auth_service1.token_ttl == 3600 + assert auth_service2.token_ttl == 3600 - self.assertEqual(auth_service1.token_ttl, 3600) - self.assertEqual(auth_service2.token_ttl, 3600) + # Photo service + photo_service1 = container.services.photo() + photo_service2 = container.services.photo() + assert isinstance(photo_service1, PhotoService) + assert isinstance(photo_service2, PhotoService) + assert photo_service1 is not photo_service2 - # Photo service - photo_service1 = container.services.photo() - photo_service2 = container.services.photo() - self.assertIsInstance(photo_service1, PhotoService) - self.assertIsInstance(photo_service2, PhotoService) - self.assertIsNot(photo_service1, photo_service2) + assert isinstance(photo_service1.db, sqlite3.Connection) + assert isinstance(photo_service2.db, sqlite3.Connection) + assert photo_service1.db is photo_service2.db + assert photo_service1.db is container.gateways.database_client() + assert photo_service2.db is container.gateways.database_client() - self.assertIsInstance(photo_service1.db, sqlite3.Connection) - self.assertIsInstance(photo_service2.db, sqlite3.Connection) - self.assertIs(photo_service1.db, photo_service2.db) - self.assertIs(photo_service1.db, container.gateways.database_client()) - self.assertIs(photo_service2.db, container.gateways.database_client()) - - self.assertIs(photo_service1.s3, photo_service2.s3) - self.assertIs(photo_service1.s3, container.gateways.s3_client()) - self.assertIs(photo_service2.s3, container.gateways.s3_client()) + assert photo_service1.s3 is photo_service2.s3 + assert photo_service1.s3 is container.gateways.s3_client() + assert photo_service2.s3 is container.gateways.s3_client() -class TestSchemaMultipleContainersWithInlineProviders(unittest.TestCase): +def test_multiple_containers_with_inline_providers_schema(container: containers.DynamicContainer): + container.from_yaml_schema(f"{SAMPLES_DIR}/schema/container-multiple-inline.yml") + container.core.config.from_dict({ + "database": { + "dsn": ":memory:", + }, + "aws": { + "access_key_id": "KEY", + "secret_access_key": "SECRET", + }, + "auth": { + "token_ttl": 3600, + }, + }) - def test(self): - container = containers.DynamicContainer() - container.from_yaml_schema(f"{_SAMPLES_DIR}/schemasample/container-multiple-inline.yml") - container.core.config.from_dict({ - "database": { - "dsn": ":memory:", - }, - "aws": { - "access_key_id": "KEY", - "secret_access_key": "SECRET", - }, - "auth": { - "token_ttl": 3600, - }, - }) + # User service + user_service1 = container.services.user() + user_service2 = container.services.user() + assert isinstance(user_service1, UserService) + assert isinstance(user_service2, UserService) + assert user_service1 is not user_service2 - # User service - user_service1 = container.services.user() - user_service2 = container.services.user() - self.assertIsInstance(user_service1, UserService) - self.assertIsInstance(user_service2, UserService) - self.assertIsNot(user_service1, user_service2) + assert isinstance(user_service1.db, sqlite3.Connection) + assert isinstance(user_service2.db, sqlite3.Connection) + assert user_service1.db is user_service2.db - self.assertIsInstance(user_service1.db, sqlite3.Connection) - self.assertIsInstance(user_service2.db, sqlite3.Connection) - self.assertIs(user_service1.db, user_service2.db) + # Auth service + auth_service1 = container.services.auth() + auth_service2 = container.services.auth() + assert isinstance(auth_service1, AuthService) + assert isinstance(auth_service2, AuthService) + assert auth_service1 is not auth_service2 - # Auth service - auth_service1 = container.services.auth() - auth_service2 = container.services.auth() - self.assertIsInstance(auth_service1, AuthService) - self.assertIsInstance(auth_service2, AuthService) - self.assertIsNot(auth_service1, auth_service2) + assert isinstance(auth_service1.db, sqlite3.Connection) + assert isinstance(auth_service2.db, sqlite3.Connection) + assert auth_service1.db is auth_service2.db + assert auth_service1.db is container.gateways.database_client() + assert auth_service2.db is container.gateways.database_client() - self.assertIsInstance(auth_service1.db, sqlite3.Connection) - self.assertIsInstance(auth_service2.db, sqlite3.Connection) - self.assertIs(auth_service1.db, auth_service2.db) - self.assertIs(auth_service1.db, container.gateways.database_client()) - self.assertIs(auth_service2.db, container.gateways.database_client()) + assert auth_service1.token_ttl == 3600 + assert auth_service2.token_ttl == 3600 - self.assertEqual(auth_service1.token_ttl, 3600) - self.assertEqual(auth_service2.token_ttl, 3600) + # Photo service + photo_service1 = container.services.photo() + photo_service2 = container.services.photo() + assert isinstance(photo_service1, PhotoService) + assert isinstance(photo_service2, PhotoService) + assert photo_service1 is not photo_service2 - # Photo service - photo_service1 = container.services.photo() - photo_service2 = container.services.photo() - self.assertIsInstance(photo_service1, PhotoService) - self.assertIsInstance(photo_service2, PhotoService) - self.assertIsNot(photo_service1, photo_service2) + assert isinstance(photo_service1.db, sqlite3.Connection) + assert isinstance(photo_service2.db, sqlite3.Connection) + assert photo_service1.db is photo_service2.db + assert photo_service1.db is container.gateways.database_client() + assert photo_service2.db is container.gateways.database_client() - self.assertIsInstance(photo_service1.db, sqlite3.Connection) - self.assertIsInstance(photo_service2.db, sqlite3.Connection) - self.assertIs(photo_service1.db, photo_service2.db) - self.assertIs(photo_service1.db, container.gateways.database_client()) - self.assertIs(photo_service2.db, container.gateways.database_client()) - - self.assertIs(photo_service1.s3, photo_service2.s3) - self.assertIs(photo_service1.s3, container.gateways.s3_client()) - self.assertIs(photo_service2.s3, container.gateways.s3_client()) + assert photo_service1.s3 is photo_service2.s3 + assert photo_service1.s3 is container.gateways.s3_client() + assert photo_service2.s3 is container.gateways.s3_client() -class TestSchemaBoto3Session(unittest.TestCase): +@mark.skip(reason="Boto3 tries to connect to the internet") +def test_schema_with_boto3_session(container: containers.DynamicContainer): + container.from_yaml_schema(f"{SAMPLES_DIR}/schema/container-boto3-session.yml") + container.config.from_dict( + { + "aws_access_key_id": "key", + "aws_secret_access_key": "secret", + "aws_session_token": "token", + "aws_region_name": "us-east-1", + }, + ) - @unittest.skip("Boto3 tries to connect to the internet") - def test(self): - container = containers.DynamicContainer() - container.from_yaml_schema(f"{_SAMPLES_DIR}/schemasample/container-boto3-session.yml") - container.config.from_dict( - { - "aws_access_key_id": "key", - "aws_secret_access_key": "secret", - "aws_session_token": "token", - "aws_region_name": "us-east-1", - }, - ) - - self.assertEqual(container.s3_client().__class__.__name__, "S3") - self.assertEqual(container.sqs_client().__class__.__name__, "SQS") + assert container.s3_client().__class__.__name__ == "S3" + assert container.sqs_client().__class__.__name__ == "SQS" diff --git a/tests/unit/test_common_py2_py3.py b/tests/unit/test_common_py2_py3.py deleted file mode 100644 index 7315dfbe..00000000 --- a/tests/unit/test_common_py2_py3.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Dependency injector common unit tests.""" - -import unittest - -from dependency_injector import __version__ - - -class VersionTest(unittest.TestCase): - - def test_version_follows_semantic_versioning(self): - self.assertEqual(len(__version__.split(".")), 3) diff --git a/tests/unit/test_version_py2_py3.py b/tests/unit/test_version_py2_py3.py new file mode 100644 index 00000000..e1e0bf39 --- /dev/null +++ b/tests/unit/test_version_py2_py3.py @@ -0,0 +1,7 @@ +"""Dependency injector common unit tests.""" + +from dependency_injector import __version__ + + +def test_version_follows_semantic_versioning(): + assert len(__version__.split(".")) == 3 diff --git a/tests/unit/wiring/provider_ids/__init__.py b/tests/unit/wiring/provider_ids/__init__.py new file mode 100644 index 00000000..4947ead7 --- /dev/null +++ b/tests/unit/wiring/provider_ids/__init__.py @@ -0,0 +1 @@ +"""Tests for wiring based on provider instance identification.""" diff --git a/tests/unit/wiring/provider_ids/test_async_injections_py36.py b/tests/unit/wiring/provider_ids/test_async_injections_py36.py new file mode 100644 index 00000000..f17f19c7 --- /dev/null +++ b/tests/unit/wiring/provider_ids/test_async_injections_py36.py @@ -0,0 +1,55 @@ +"""Async injection tests.""" + +from pytest import fixture, mark + +from samples.wiring import asyncinjections + + +@fixture(autouse=True) +def container(): + container = asyncinjections.Container() + container.wire(modules=[asyncinjections]) + yield container + container.unwire() + + +@fixture(autouse=True) +def reset_counters(): + asyncinjections.resource1.reset_counters() + asyncinjections.resource2.reset_counters() + + +@mark.asyncio +async def test_async_injections(): + resource1, resource2 = await asyncinjections.async_injection() + + assert resource1 is asyncinjections.resource1 + assert asyncinjections.resource1.init_counter == 1 + assert asyncinjections.resource1.shutdown_counter == 0 + + assert resource2 is asyncinjections.resource2 + assert asyncinjections.resource2.init_counter == 1 + assert asyncinjections.resource2.shutdown_counter == 0 + + +@mark.asyncio +async def test_async_injections_with_closing(): + resource1, resource2 = await asyncinjections.async_injection_with_closing() + + assert resource1 is asyncinjections.resource1 + assert asyncinjections.resource1.init_counter == 1 + assert asyncinjections.resource1.shutdown_counter == 1 + + assert resource2 is asyncinjections.resource2 + assert asyncinjections.resource2.init_counter == 1 + assert asyncinjections.resource2.shutdown_counter == 1 + + resource1, resource2 = await asyncinjections.async_injection_with_closing() + + assert resource1 is asyncinjections.resource1 + assert asyncinjections.resource1.init_counter == 2 + assert asyncinjections.resource1.shutdown_counter == 2 + + assert resource2 is asyncinjections.resource2 + assert asyncinjections.resource2.init_counter == 2 + assert asyncinjections.resource2.shutdown_counter == 2 diff --git a/tests/unit/wiring/provider_ids/test_autoloader_py36.py b/tests/unit/wiring/provider_ids/test_autoloader_py36.py new file mode 100644 index 00000000..becafe18 --- /dev/null +++ b/tests/unit/wiring/provider_ids/test_autoloader_py36.py @@ -0,0 +1,41 @@ +"""Auto loader tests.""" + +import contextlib +import importlib + +from dependency_injector.wiring import register_loader_containers, unregister_loader_containers +from pytest import fixture + +from samples.wiring import module +from samples.wiring.service import Service +from samples.wiring.container import Container + + +@fixture +def container(): + container = Container() + + yield container + + with contextlib.suppress(ValueError): + unregister_loader_containers(container) + container.unwire() + importlib.reload(module) + + +def test_register_container(container: Container) -> None: + register_loader_containers(container) + importlib.reload(module) + + service = module.test_function() + assert isinstance(service, Service) + + +def test_numpy_scipy_and_builtins_dont_break_wiring(container: Container) -> None: + register_loader_containers(container) + importlib.reload(module) + importlib.import_module("samples.wiring.imports") + + service = module.test_function() + + assert isinstance(service, Service) diff --git a/tests/unit/wiring/provider_ids/test_main_py36.py b/tests/unit/wiring/provider_ids/test_main_py36.py new file mode 100644 index 00000000..15ac31c0 --- /dev/null +++ b/tests/unit/wiring/provider_ids/test_main_py36.py @@ -0,0 +1,337 @@ +"""Main wiring tests.""" + +from decimal import Decimal + +from dependency_injector import errors +from dependency_injector.wiring import Closing, Provide, Provider, wire +from pytest import fixture, mark, raises + +from samples.wiring import module, package, resourceclosing +from samples.wiring.service import Service +from samples.wiring.container import Container, SubContainer + + +@fixture(autouse=True) +def container(): + container = Container(config={"a": {"b": {"c": 10}}}) + container.wire( + modules=[module], + packages=[package], + ) + yield container + container.unwire() + + +@fixture +def subcontainer(): + container = SubContainer() + container.wire( + modules=[module], + packages=[package], + ) + yield container + container.unwire() + + +@fixture +def resourceclosing_container(): + container = resourceclosing.Container() + container.wire(modules=[resourceclosing]) + yield container + container.unwire() + + +def test_package_lookup(): + from samples.wiring.package import test_package_function + service = test_package_function() + assert isinstance(service, Service) + + +def test_package_subpackage_lookup(): + from samples.wiring.package.subpackage import test_package_function + service = test_package_function() + assert isinstance(service, Service) + + +def test_package_submodule_lookup(): + from samples.wiring.package.subpackage.submodule import test_function + service = test_function() + assert isinstance(service, Service) + + +def test_module_attributes_wiring(): + assert isinstance(module.service, Service) + assert isinstance(module.service_provider(), Service) + assert isinstance(module.undefined, Provide) + + +def test_module_attribute_wiring_with_invalid_marker(container: Container): + from samples.wiring import module_invalid_attr_injection + with raises(Exception, match="Unknown type of marker {0}".format(module_invalid_attr_injection.service)): + container.wire(modules=[module_invalid_attr_injection]) + + +def test_class_wiring(): + test_class_object = module.TestClass() + assert isinstance(test_class_object.service, Service) + + +def test_class_wiring_context_arg(container: Container): + test_service = container.service() + test_class_object = module.TestClass(service=test_service) + assert test_class_object.service is test_service + + +def test_class_method_wiring(): + test_class_object = module.TestClass() + service = test_class_object.method() + assert isinstance(service, Service) + + +def test_class_classmethod_wiring(): + service = module.TestClass.class_method() + assert isinstance(service, Service) + + +def test_instance_classmethod_wiring(): + instance = module.TestClass() + service = instance.class_method() + assert isinstance(service, Service) + + +def test_class_staticmethod_wiring(): + service = module.TestClass.static_method() + assert isinstance(service, Service) + + +def test_instance_staticmethod_wiring(): + instance = module.TestClass() + service = instance.static_method() + assert isinstance(service, Service) + + +def test_class_attribute_wiring(): + assert isinstance(module.TestClass.service, Service) + assert isinstance(module.TestClass.service_provider(), Service) + assert isinstance(module.TestClass.undefined, Provide) + + +def test_function_wiring(): + service = module.test_function() + assert isinstance(service, Service) + + +def test_function_wiring_context_arg(container: Container): + test_service = container.service() + service = module.test_function(service=test_service) + assert service is test_service + + +def test_function_wiring_provider(): + service = module.test_function_provider() + assert isinstance(service, Service) + + +def test_function_wiring_provider_context_arg(container: Container): + test_service = container.service() + service = module.test_function_provider(service_provider=lambda: test_service) + assert service is test_service + + +def test_configuration_option(): + ( + value_int, + value_float, + value_str, + value_decimal, + value_required, + value_required_int, + value_required_float, + value_required_str, + value_required_decimal, + ) = module.test_config_value() + + assert value_int == 10 + assert value_float == 10.0 + assert value_str == "10" + assert value_decimal == Decimal(10) + assert value_required == 10 + assert value_required_int == 10 + assert value_required_float == 10.0 + assert value_required_str == "10" + assert value_required_decimal == Decimal(10) + + +def test_configuration_option_required_undefined(container: Container): + container.config.reset_override() + with raises(errors.Error, match="Undefined configuration option \"config.a.b.c\""): + module.test_config_value_required_undefined() + + +def test_provide_provider(): + service = module.test_provide_provider() + assert isinstance(service, Service) + + +def test_provider_provider(): + service = module.test_provider_provider() + assert isinstance(service, Service) + + +def test_provided_instance(container: Container): + class TestService: + foo = {"bar": lambda: 10} + + with container.service.override(TestService()): + some_value = module.test_provided_instance() + assert some_value == 10 + + +def test_subcontainer(): + some_value = module.test_subcontainer_provider() + assert some_value == 1 + + +def test_config_invariant(container: Container): + config = { + "option": { + "a": 1, + "b": 2, + }, + "switch": "a", + } + container.config.from_dict(config) + + value_default = module.test_config_invariant() + assert value_default == 1 + + with container.config.switch.override("a"): + value_a = module.test_config_invariant() + assert value_a == 1 + + with container.config.switch.override("b"): + value_b = module.test_config_invariant() + assert value_b == 2 + + +def test_wire_with_class_error(): + with raises(Exception): + wire( + container=Container, + modules=[module], + ) + + +def test_unwire_function(container: Container): + container.unwire() + assert isinstance(module.test_function(), Provide) + + +def test_unwire_class(container: Container): + container.unwire() + test_class_object = module.TestClass() + assert isinstance(test_class_object.service, Provide) + + +def test_unwire_class_method(container: Container): + container.unwire() + test_class_object = module.TestClass() + assert isinstance(test_class_object.method(), Provide) + + +def test_unwire_package_function(container: Container): + container.unwire() + from samples.wiring.package.subpackage.submodule import test_function + assert isinstance(test_function(), Provide) + + +def test_unwire_package_function_by_reference(container: Container): + from samples.wiring.package.subpackage import submodule + container.unwire() + assert isinstance(submodule.test_function(), Provide) + + +def test_unwire_module_attributes(container: Container): + container.unwire() + assert isinstance(module.service, Provide) + assert isinstance(module.service_provider, Provider) + assert isinstance(module.undefined, Provide) + + +def test_unwire_class_attributes(container: Container): + container.unwire() + assert isinstance(module.TestClass.service, Provide) + assert isinstance(module.TestClass.service_provider, Provider) + assert isinstance(module.TestClass.undefined, Provide) + + +@mark.usefixtures("subcontainer") +def test_wire_multiple_containers(): + service, some_value = module.test_provide_from_different_containers() + assert isinstance(service, Service) + assert some_value == 1 + + +@mark.usefixtures("resourceclosing_container") +def test_closing_resource(): + resourceclosing.Service.reset_counter() + + result_1 = resourceclosing.test_function() + assert isinstance(result_1, resourceclosing.Service) + assert result_1.init_counter == 1 + assert result_1.shutdown_counter == 1 + + result_2 = resourceclosing.test_function() + assert isinstance(result_2, resourceclosing.Service) + assert result_2.init_counter == 2 + assert result_2.shutdown_counter == 2 + + assert result_1 is not result_2 + + +@mark.usefixtures("resourceclosing_container") +def test_closing_resource_bypass_marker_injection(): + resourceclosing.Service.reset_counter() + + result_1 = resourceclosing.test_function(service=Closing[Provide[resourceclosing.Container.service]]) + assert isinstance(result_1, resourceclosing.Service) + assert result_1.init_counter == 1 + assert result_1.shutdown_counter == 1 + + result_2 = resourceclosing.test_function(service=Closing[Provide[resourceclosing.Container.service]]) + assert isinstance(result_2, resourceclosing.Service) + assert result_2.init_counter == 2 + assert result_2.shutdown_counter == 2 + + assert result_1 is not result_2 + + +@mark.usefixtures("resourceclosing_container") +def test_closing_resource_context(): + resourceclosing.Service.reset_counter() + service = resourceclosing.Service() + + result_1 = resourceclosing.test_function(service=service) + assert result_1 is service + assert result_1.init_counter == 0 + assert result_1.shutdown_counter == 0 + + result_2 = resourceclosing.test_function(service=service) + assert result_2 is service + assert result_2.init_counter == 0 + assert result_2.shutdown_counter == 0 + + +def test_class_decorator(): + service = module.test_class_decorator() + assert isinstance(service, Service) + + +def test_container(): + service = module.test_container() + assert isinstance(service, Service) + + +def test_bypass_marker_injection(): + service = module.test_function(service=Provide[Container.service]) + assert isinstance(service, Service) diff --git a/tests/unit/wiring/string_ids/__init__.py b/tests/unit/wiring/string_ids/__init__.py new file mode 100644 index 00000000..cc5689cc --- /dev/null +++ b/tests/unit/wiring/string_ids/__init__.py @@ -0,0 +1 @@ +"""Tests for wiring based on provider string name identification.""" diff --git a/tests/unit/wiring/string_ids/test_async_injections_py36.py b/tests/unit/wiring/string_ids/test_async_injections_py36.py new file mode 100644 index 00000000..cff13ce5 --- /dev/null +++ b/tests/unit/wiring/string_ids/test_async_injections_py36.py @@ -0,0 +1,55 @@ +"""Async injection tests.""" + +from pytest import fixture, mark + +from samples.wiringstringids import asyncinjections + + +@fixture(autouse=True) +def container(): + container = asyncinjections.Container() + container.wire(modules=[asyncinjections]) + yield container + container.unwire() + + +@fixture(autouse=True) +def reset_counters(): + asyncinjections.resource1.reset_counters() + asyncinjections.resource2.reset_counters() + + +@mark.asyncio +async def test_async_injections(): + resource1, resource2 = await asyncinjections.async_injection() + + assert resource1 is asyncinjections.resource1 + assert asyncinjections.resource1.init_counter == 1 + assert asyncinjections.resource1.shutdown_counter == 0 + + assert resource2 is asyncinjections.resource2 + assert asyncinjections.resource2.init_counter == 1 + assert asyncinjections.resource2.shutdown_counter == 0 + + +@mark.asyncio +async def test_async_injections_with_closing(): + resource1, resource2 = await asyncinjections.async_injection_with_closing() + + assert resource1 is asyncinjections.resource1 + assert asyncinjections.resource1.init_counter == 1 + assert asyncinjections.resource1.shutdown_counter == 1 + + assert resource2 is asyncinjections.resource2 + assert asyncinjections.resource2.init_counter == 1 + assert asyncinjections.resource2.shutdown_counter == 1 + + resource1, resource2 = await asyncinjections.async_injection_with_closing() + + assert resource1 is asyncinjections.resource1 + assert asyncinjections.resource1.init_counter == 2 + assert asyncinjections.resource1.shutdown_counter == 2 + + assert resource2 is asyncinjections.resource2 + assert asyncinjections.resource2.init_counter == 2 + assert asyncinjections.resource2.shutdown_counter == 2 diff --git a/tests/unit/wiring/string_ids/test_autoloader_py36.py b/tests/unit/wiring/string_ids/test_autoloader_py36.py new file mode 100644 index 00000000..98e4a4f1 --- /dev/null +++ b/tests/unit/wiring/string_ids/test_autoloader_py36.py @@ -0,0 +1,31 @@ +"""Auto loader tests.""" + +import contextlib +import importlib + +from dependency_injector.wiring import register_loader_containers, unregister_loader_containers +from pytest import fixture + +from samples.wiringstringids import module +from samples.wiringstringids.service import Service +from samples.wiringstringids.container import Container + + +@fixture +def container(): + container = Container() + + yield container + + with contextlib.suppress(ValueError): + unregister_loader_containers(container) + container.unwire() + importlib.reload(module) + + +def test_register_container(container: Container) -> None: + register_loader_containers(container) + importlib.reload(module) + + service = module.test_function() + assert isinstance(service, Service) diff --git a/tests/unit/wiring/string_ids/test_dynamic_container_py36.py b/tests/unit/wiring/string_ids/test_dynamic_container_py36.py new file mode 100644 index 00000000..e6250b37 --- /dev/null +++ b/tests/unit/wiring/string_ids/test_dynamic_container_py36.py @@ -0,0 +1,32 @@ +"""Tests for wiring with dynamic container.""" + +from dependency_injector import containers, providers +from pytest import fixture + +from samples.wiringstringids import module, package +from samples.wiringstringids.service import Service + + +@fixture(autouse=True) +def container(): + sub = containers.DynamicContainer() + sub.int_object = providers.Object(1) + + container = containers.DynamicContainer() + container.config = providers.Configuration() + container.service = providers.Factory(Service) + container.sub = sub + + container.wire( + modules=[module], + packages=[package], + ) + + yield container + + container.unwire() + + +def test_wire(): + service = module.test_function() + assert isinstance(service, Service) diff --git a/tests/unit/wiring/string_ids/test_main_py36.py b/tests/unit/wiring/string_ids/test_main_py36.py new file mode 100644 index 00000000..4c8f2e55 --- /dev/null +++ b/tests/unit/wiring/string_ids/test_main_py36.py @@ -0,0 +1,337 @@ +"""Main wiring tests.""" + +from decimal import Decimal + +from dependency_injector import errors +from dependency_injector.wiring import Closing, Provide, Provider, wire +from pytest import fixture, mark, raises + +from samples.wiringstringids import module, package, resourceclosing +from samples.wiringstringids.service import Service +from samples.wiringstringids.container import Container, SubContainer + + +@fixture(autouse=True) +def container(): + container = Container(config={"a": {"b": {"c": 10}}}) + container.wire( + modules=[module], + packages=[package], + ) + yield container + container.unwire() + + +@fixture +def subcontainer(): + container = SubContainer() + container.wire( + modules=[module], + packages=[package], + ) + yield container + container.unwire() + + +@fixture +def resourceclosing_container(): + container = resourceclosing.Container() + container.wire(modules=[resourceclosing]) + yield container + container.unwire() + + +def test_package_lookup(): + from samples.wiringstringids.package import test_package_function + service = test_package_function() + assert isinstance(service, Service) + + +def test_package_subpackage_lookup(): + from samples.wiringstringids.package.subpackage import test_package_function + service = test_package_function() + assert isinstance(service, Service) + + +def test_package_submodule_lookup(): + from samples.wiringstringids.package.subpackage.submodule import test_function + service = test_function() + assert isinstance(service, Service) + + +def test_module_attributes_wiring(): + assert isinstance(module.service, Service) + assert isinstance(module.service_provider(), Service) + assert isinstance(module.undefined, Provide) + + +def test_module_attribute_wiring_with_invalid_marker(container: Container): + from samples.wiringstringids import module_invalid_attr_injection + with raises(Exception, match="Unknown type of marker {0}".format(module_invalid_attr_injection.service)): + container.wire(modules=[module_invalid_attr_injection]) + + +def test_class_wiring(): + test_class_object = module.TestClass() + assert isinstance(test_class_object.service, Service) + + +def test_class_wiring_context_arg(container: Container): + test_service = container.service() + test_class_object = module.TestClass(service=test_service) + assert test_class_object.service is test_service + + +def test_class_method_wiring(): + test_class_object = module.TestClass() + service = test_class_object.method() + assert isinstance(service, Service) + + +def test_class_classmethod_wiring(): + service = module.TestClass.class_method() + assert isinstance(service, Service) + + +def test_instance_classmethod_wiring(): + instance = module.TestClass() + service = instance.class_method() + assert isinstance(service, Service) + + +def test_class_staticmethod_wiring(): + service = module.TestClass.static_method() + assert isinstance(service, Service) + + +def test_instance_staticmethod_wiring(): + instance = module.TestClass() + service = instance.static_method() + assert isinstance(service, Service) + + +def test_class_attribute_wiring(): + assert isinstance(module.TestClass.service, Service) + assert isinstance(module.TestClass.service_provider(), Service) + assert isinstance(module.TestClass.undefined, Provide) + + +def test_function_wiring(): + service = module.test_function() + assert isinstance(service, Service) + + +def test_function_wiring_context_arg(container: Container): + test_service = container.service() + service = module.test_function(service=test_service) + assert service is test_service + + +def test_function_wiring_provider(): + service = module.test_function_provider() + assert isinstance(service, Service) + + +def test_function_wiring_provider_context_arg(container: Container): + test_service = container.service() + service = module.test_function_provider(service_provider=lambda: test_service) + assert service is test_service + + +def test_configuration_option(): + ( + value_int, + value_float, + value_str, + value_decimal, + value_required, + value_required_int, + value_required_float, + value_required_str, + value_required_decimal, + ) = module.test_config_value() + + assert value_int == 10 + assert value_float == 10.0 + assert value_str == "10" + assert value_decimal == Decimal(10) + assert value_required == 10 + assert value_required_int == 10 + assert value_required_float == 10.0 + assert value_required_str == "10" + assert value_required_decimal == Decimal(10) + + +def test_configuration_option_required_undefined(container: Container): + container.config.reset_override() + with raises(errors.Error, match="Undefined configuration option \"config.a.b.c\""): + module.test_config_value_required_undefined() + + +def test_provide_provider(): + service = module.test_provide_provider() + assert isinstance(service, Service) + + +def test_provider_provider(): + service = module.test_provider_provider() + assert isinstance(service, Service) + + +def test_provided_instance(container: Container): + class TestService: + foo = {"bar": lambda: 10} + + with container.service.override(TestService()): + some_value = module.test_provided_instance() + assert some_value == 10 + + +def test_subcontainer(): + some_value = module.test_subcontainer_provider() + assert some_value == 1 + + +def test_config_invariant(container: Container): + config = { + "option": { + "a": 1, + "b": 2, + }, + "switch": "a", + } + container.config.from_dict(config) + + value_default = module.test_config_invariant() + assert value_default == 1 + + with container.config.switch.override("a"): + value_a = module.test_config_invariant() + assert value_a == 1 + + with container.config.switch.override("b"): + value_b = module.test_config_invariant() + assert value_b == 2 + + +def test_wire_with_class_error(): + with raises(Exception): + wire( + container=Container, + modules=[module], + ) + + +def test_unwire_function(container: Container): + container.unwire() + assert isinstance(module.test_function(), Provide) + + +def test_unwire_class(container: Container): + container.unwire() + test_class_object = module.TestClass() + assert isinstance(test_class_object.service, Provide) + + +def test_unwire_class_method(container: Container): + container.unwire() + test_class_object = module.TestClass() + assert isinstance(test_class_object.method(), Provide) + + +def test_unwire_package_function(container: Container): + container.unwire() + from samples.wiringstringids.package.subpackage.submodule import test_function + assert isinstance(test_function(), Provide) + + +def test_unwire_package_function_by_reference(container: Container): + from samples.wiringstringids.package.subpackage import submodule + container.unwire() + assert isinstance(submodule.test_function(), Provide) + + +def test_unwire_module_attributes(container: Container): + container.unwire() + assert isinstance(module.service, Provide) + assert isinstance(module.service_provider, Provider) + assert isinstance(module.undefined, Provide) + + +def test_unwire_class_attributes(container: Container): + container.unwire() + assert isinstance(module.TestClass.service, Provide) + assert isinstance(module.TestClass.service_provider, Provider) + assert isinstance(module.TestClass.undefined, Provide) + + +@mark.usefixtures("subcontainer") +def test_wire_multiple_containers(): + service, some_value = module.test_provide_from_different_containers() + assert isinstance(service, Service) + assert some_value == 1 + + +@mark.usefixtures("resourceclosing_container") +def test_closing_resource(): + resourceclosing.Service.reset_counter() + + result_1 = resourceclosing.test_function() + assert isinstance(result_1, resourceclosing.Service) + assert result_1.init_counter == 1 + assert result_1.shutdown_counter == 1 + + result_2 = resourceclosing.test_function() + assert isinstance(result_2, resourceclosing.Service) + assert result_2.init_counter == 2 + assert result_2.shutdown_counter == 2 + + assert result_1 is not result_2 + + +@mark.usefixtures("resourceclosing_container") +def test_closing_resource_bypass_marker_injection(): + resourceclosing.Service.reset_counter() + + result_1 = resourceclosing.test_function(service=Closing[Provide["service"]]) + assert isinstance(result_1, resourceclosing.Service) + assert result_1.init_counter == 1 + assert result_1.shutdown_counter == 1 + + result_2 = resourceclosing.test_function(service=Closing[Provide["service"]]) + assert isinstance(result_2, resourceclosing.Service) + assert result_2.init_counter == 2 + assert result_2.shutdown_counter == 2 + + assert result_1 is not result_2 + + +@mark.usefixtures("resourceclosing_container") +def test_closing_resource_context(): + resourceclosing.Service.reset_counter() + service = resourceclosing.Service() + + result_1 = resourceclosing.test_function(service=service) + assert result_1 is service + assert result_1.init_counter == 0 + assert result_1.shutdown_counter == 0 + + result_2 = resourceclosing.test_function(service=service) + assert result_2 is service + assert result_2.init_counter == 0 + assert result_2.shutdown_counter == 0 + + +def test_class_decorator(): + service = module.test_class_decorator() + assert isinstance(service, Service) + + +def test_container(): + service = module.test_container() + assert isinstance(service, Service) + + +def test_bypass_marker_injection(): + service = module.test_function(service=Provide["service"]) + assert isinstance(service, Service) diff --git a/tests/unit/wiring/test_fastapi_py36.py b/tests/unit/wiring/test_fastapi_py36.py new file mode 100644 index 00000000..d93cc9c2 --- /dev/null +++ b/tests/unit/wiring/test_fastapi_py36.py @@ -0,0 +1,43 @@ +from httpx import AsyncClient +from pytest import fixture, mark + +# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir +import os +_SAMPLES_DIR = os.path.abspath( + os.path.sep.join(( + os.path.dirname(__file__), + "../samples/", + )), +) +import sys +sys.path.append(_SAMPLES_DIR) + + +from wiringfastapi import web + + +@fixture +async def async_client(): + client = AsyncClient(app=web.app, base_url="http://test") + yield client + await client.aclose() + + +@mark.asyncio +async def test_depends_marker_injection(async_client: AsyncClient): + class ServiceMock: + async def process(self): + return "Foo" + + with web.container.service.override(ServiceMock()): + response = await async_client.get("/") + + assert response.status_code == 200 + assert response.json() == {"result": "Foo"} + + +@mark.asyncio +async def test_depends_injection(async_client: AsyncClient): + response = await async_client.get("/auth", auth=("john_smith", "secret")) + assert response.status_code == 200 + assert response.json() == {"username": "john_smith", "password": "secret"} diff --git a/tests/unit/wiring/test_wiringflask_py36.py b/tests/unit/wiring/test_flask_py36.py similarity index 58% rename from tests/unit/wiring/test_wiringflask_py36.py rename to tests/unit/wiring/test_flask_py36.py index 586ade16..751f04d8 100644 --- a/tests/unit/wiring/test_wiringflask_py36.py +++ b/tests/unit/wiring/test_flask_py36.py @@ -1,4 +1,3 @@ -import unittest import json # Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir @@ -22,13 +21,11 @@ sys.path.append(_SAMPLES_DIR) from wiringflask import web -class WiringFlaskTest(unittest.TestCase): +def test_wiring_with_flask(): + client = web.app.test_client() - def test(self): - client = web.app.test_client() + with web.app.app_context(): + response = client.get("/") - with web.app.app_context(): - response = client.get("/") - - self.assertEqual(response.status_code, 200) - self.assertEqual(json.loads(response.data), {"result": "Ok"}) + assert response.status_code == 200 + assert json.loads(response.data) == {"result": "OK"} diff --git a/tests/unit/wiring/test_module_as_package_py36.py b/tests/unit/wiring/test_module_as_package_py36.py new file mode 100644 index 00000000..dbb5cc67 --- /dev/null +++ b/tests/unit/wiring/test_module_as_package_py36.py @@ -0,0 +1,20 @@ +"""Tests for wiring to module as package.""" + +from pytest import fixture + +from samples.wiring import module +from samples.wiring.service import Service +from samples.wiring.container import Container + + +@fixture +def container(): + container = Container() + yield container + container.unwire() + + +def test_module_as_package_wiring(container: Container): + # See: https://github.com/ets-labs/python-dependency-injector/issues/481 + container.wire(packages=[module]) + assert isinstance(module.service, Service) diff --git a/tests/unit/wiring/test_string_module_names_py36.py b/tests/unit/wiring/test_string_module_names_py36.py new file mode 100644 index 00000000..04930a0b --- /dev/null +++ b/tests/unit/wiring/test_string_module_names_py36.py @@ -0,0 +1,57 @@ +"""Tests for string module and package names.""" + +from pytest import fixture + +from samples.wiring import module +from samples.wiring.service import Service +from samples.wiring.container import Container +from samples.wiring.wire_relative_string_names import wire_with_relative_string_names + + +@fixture +def container(): + container = Container() + yield container + container.unwire() + + +def test_absolute_names(container: Container): + container.wire( + modules=["samples.wiring.module"], + packages=["samples.wiring.package"], + ) + + service = module.test_function() + assert isinstance(service, Service) + + from samples.wiring.package.subpackage.submodule import test_function + service = test_function() + assert isinstance(service, Service) + + +def test_relative_names_with_explicit_package(container: Container): + container.wire( + modules=[".module"], + packages=[".package"], + from_package="samples.wiring", + ) + + service = module.test_function() + assert isinstance(service, Service) + + from samples.wiring.package.subpackage.submodule import test_function + service = test_function() + assert isinstance(service, Service) + + +def test_relative_names_with_auto_package(container: Container): + wire_with_relative_string_names(container) + + service = module.test_function() + assert isinstance(service, Service) + + from samples.wiring.package.subpackage.submodule import test_function + service = test_function() + assert isinstance(service, Service) + + diff --git a/tests/unit/wiring/test_wiring_config_in_container_py36.py b/tests/unit/wiring/test_wiring_config_in_container_py36.py new file mode 100644 index 00000000..a4ac1fbe --- /dev/null +++ b/tests/unit/wiring/test_wiring_config_in_container_py36.py @@ -0,0 +1,92 @@ +"""Tests for specifying wiring config in the container.""" + +from dependency_injector import containers +from dependency_injector.wiring import Provide +from pytest import fixture, mark + +from samples.wiring import module +from samples.wiring.service import Service +from samples.wiring.container import Container + + +@fixture(autouse=True) +def container(wiring_config: containers.WiringConfiguration): + original_wiring_config = Container.wiring_config + Container.wiring_config = wiring_config + container = Container() + yield container + container.unwire() + Container.wiring_config = original_wiring_config + + +@mark.parametrize( + "wiring_config", + [ + containers.WiringConfiguration( + modules=["samples.wiring.module"], + packages=["samples.wiring.package"], + ), + ], +) +def test_absolute_names(): + service = module.test_function() + assert isinstance(service, Service) + + from samples.wiring.package.subpackage.submodule import test_function + service = test_function() + assert isinstance(service, Service) + + +@mark.parametrize( + "wiring_config", + [ + containers.WiringConfiguration( + modules=[".module"], + packages=[".package"], + from_package="samples.wiring", + ), + ], +) +def test_relative_names_with_explicit_package(): + service = module.test_function() + assert isinstance(service, Service) + + from samples.wiring.package.subpackage.submodule import test_function + service = test_function() + assert isinstance(service, Service) + + +@mark.parametrize( + "wiring_config", + [ + containers.WiringConfiguration( + modules=[".module"], + packages=[".package"], + ), + ], +) +def test_relative_names_with_auto_package(): + service = module.test_function() + assert isinstance(service, Service) + + from samples.wiring.package.subpackage.submodule import test_function + service = test_function() + assert isinstance(service, Service) + + +@mark.parametrize( + "wiring_config", + [ + containers.WiringConfiguration( + modules=[".module"], + auto_wire=False, + ), + ], +) +def test_auto_wire_disabled(container: Container): + service = module.test_function() + assert isinstance(service, Provide) + + container.wire() + service = module.test_function() + assert isinstance(service, Service) diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py deleted file mode 100644 index e423c398..00000000 --- a/tests/unit/wiring/test_wiring_py36.py +++ /dev/null @@ -1,573 +0,0 @@ -import contextlib -from decimal import Decimal -import importlib -import unittest - -from dependency_injector.wiring import ( - wire, - Provide, - Provider, - Closing, - register_loader_containers, - unregister_loader_containers, -) -from dependency_injector import containers, errors - -# Runtime import to avoid syntax errors in samples on Python < 3.5 -import os -_TOP_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../", - )), -) -_SAMPLES_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../samples/", - )), -) -import sys -sys.path.append(_TOP_DIR) -sys.path.append(_SAMPLES_DIR) - -from asyncutils import AsyncTestCase - -from wiringsamples import module, package -from wiringsamples.service import Service -from wiringsamples.container import Container, SubContainer -from wiringsamples.wire_relative_string_names import wire_with_relative_string_names - - -class WiringTest(unittest.TestCase): - - container: Container - - def setUp(self) -> None: - self.container = Container(config={"a": {"b": {"c": 10}}}) - self.container.wire( - modules=[module], - packages=[package], - ) - self.addCleanup(self.container.unwire) - - def test_package_lookup(self): - from wiringsamples.package import test_package_function - service = test_package_function() - self.assertIsInstance(service, Service) - - def test_package_subpackage_lookup(self): - from wiringsamples.package.subpackage import test_package_function - service = test_package_function() - self.assertIsInstance(service, Service) - - def test_package_submodule_lookup(self): - from wiringsamples.package.subpackage.submodule import test_function - service = test_function() - self.assertIsInstance(service, Service) - - def test_module_attributes_wiring(self): - self.assertIsInstance(module.service, Service) - self.assertIsInstance(module.service_provider(), Service) - self.assertIsInstance(module.undefined, Provide) - - def test_module_attribute_wiring_with_invalid_marker(self): - from wiringsamples import module_invalid_attr_injection - with self.assertRaises(Exception) as context: - self.container.wire(modules=[module_invalid_attr_injection]) - self.assertEqual( - str(context.exception), - "Unknown type of marker {0}".format(module_invalid_attr_injection.service), - ) - - def test_class_wiring(self): - test_class_object = module.TestClass() - self.assertIsInstance(test_class_object.service, Service) - - def test_class_wiring_context_arg(self): - test_service = self.container.service() - - test_class_object = module.TestClass(service=test_service) - self.assertIs(test_class_object.service, test_service) - - def test_class_method_wiring(self): - test_class_object = module.TestClass() - service = test_class_object.method() - self.assertIsInstance(service, Service) - - def test_class_classmethod_wiring(self): - service = module.TestClass.class_method() - self.assertIsInstance(service, Service) - - def test_instance_classmethod_wiring(self): - instance = module.TestClass() - service = instance.class_method() - self.assertIsInstance(service, Service) - - def test_class_staticmethod_wiring(self): - service = module.TestClass.static_method() - self.assertIsInstance(service, Service) - - def test_instance_staticmethod_wiring(self): - instance = module.TestClass() - service = instance.static_method() - self.assertIsInstance(service, Service) - - def test_class_attribute_wiring(self): - self.assertIsInstance(module.TestClass.service, Service) - self.assertIsInstance(module.TestClass.service_provider(), Service) - self.assertIsInstance(module.TestClass.undefined, Provide) - - def test_function_wiring(self): - service = module.test_function() - self.assertIsInstance(service, Service) - - def test_function_wiring_context_arg(self): - test_service = self.container.service() - - service = module.test_function(service=test_service) - self.assertIs(service, test_service) - - def test_function_wiring_provider(self): - service = module.test_function_provider() - self.assertIsInstance(service, Service) - - def test_function_wiring_provider_context_arg(self): - test_service = self.container.service() - - service = module.test_function_provider(service_provider=lambda: test_service) - self.assertIs(service, test_service) - - def test_configuration_option(self): - ( - value_int, - value_float, - value_str, - value_decimal, - value_required, - value_required_int, - value_required_float, - value_required_str, - value_required_decimal, - ) = module.test_config_value() - - self.assertEqual(value_int, 10) - self.assertEqual(value_float, 10.0) - self.assertEqual(value_str, "10") - self.assertEqual(value_decimal, Decimal(10)) - self.assertEqual(value_required, 10) - self.assertEqual(value_required_int, 10) - self.assertEqual(value_required_float, 10.0) - self.assertEqual(value_required_str, "10") - self.assertEqual(value_required_decimal, Decimal(10)) - - def test_configuration_option_required_undefined(self): - self.container.config.reset_override() - with self.assertRaisesRegex(errors.Error, "Undefined configuration option \"config.a.b.c\""): - module.test_config_value_required_undefined() - - def test_provide_provider(self): - service = module.test_provide_provider() - self.assertIsInstance(service, Service) - - def test_provider_provider(self): - service = module.test_provider_provider() - self.assertIsInstance(service, Service) - - def test_provided_instance(self): - class TestService: - foo = { - "bar": lambda: 10, - } - - with self.container.service.override(TestService()): - some_value = module.test_provided_instance() - self.assertEqual(some_value, 10) - - def test_subcontainer(self): - some_value = module.test_subcontainer_provider() - self.assertEqual(some_value, 1) - - def test_config_invariant(self): - config = { - "option": { - "a": 1, - "b": 2, - }, - "switch": "a", - } - self.container.config.from_dict(config) - - value_default = module.test_config_invariant() - self.assertEqual(value_default, 1) - - with self.container.config.switch.override("a"): - value_a = module.test_config_invariant() - self.assertEqual(value_a, 1) - - with self.container.config.switch.override("b"): - value_b = module.test_config_invariant() - self.assertEqual(value_b, 2) - - def test_wire_with_class_error(self): - with self.assertRaises(Exception): - wire( - container=Container, - modules=[module], - ) - - def test_unwire_function(self): - self.container.unwire() - self.assertIsInstance(module.test_function(), Provide) - - def test_unwire_class(self): - self.container.unwire() - test_class_object = module.TestClass() - self.assertIsInstance(test_class_object.service, Provide) - - def test_unwire_class_method(self): - self.container.unwire() - test_class_object = module.TestClass() - self.assertIsInstance(test_class_object.method(), Provide) - - def test_unwire_package_function(self): - self.container.unwire() - from wiringsamples.package.subpackage.submodule import test_function - self.assertIsInstance(test_function(), Provide) - - def test_unwire_package_function_by_reference(self): - from wiringsamples.package.subpackage import submodule - self.container.unwire() - self.assertIsInstance(submodule.test_function(), Provide) - - def test_unwire_module_attributes(self): - self.container.unwire() - self.assertIsInstance(module.service, Provide) - self.assertIsInstance(module.service_provider, Provider) - self.assertIsInstance(module.undefined, Provide) - - def test_unwire_class_attributes(self): - self.container.unwire() - self.assertIsInstance(module.TestClass.service, Provide) - self.assertIsInstance(module.TestClass.service_provider, Provider) - self.assertIsInstance(module.TestClass.undefined, Provide) - - def test_wire_multiple_containers(self): - sub_container = SubContainer() - sub_container.wire( - modules=[module], - packages=[package], - ) - self.addCleanup(sub_container.unwire) - - service, some_value = module.test_provide_from_different_containers() - - self.assertIsInstance(service, Service) - self.assertEqual(some_value, 1) - - def test_closing_resource(self): - from wiringsamples import resourceclosing - - resourceclosing.Service.reset_counter() - - container = resourceclosing.Container() - container.wire(modules=[resourceclosing]) - self.addCleanup(container.unwire) - - result_1 = resourceclosing.test_function() - self.assertIsInstance(result_1, resourceclosing.Service) - self.assertEqual(result_1.init_counter, 1) - self.assertEqual(result_1.shutdown_counter, 1) - - result_2 = resourceclosing.test_function() - self.assertIsInstance(result_2, resourceclosing.Service) - self.assertEqual(result_2.init_counter, 2) - self.assertEqual(result_2.shutdown_counter, 2) - - self.assertIsNot(result_1, result_2) - - def test_closing_resource_context(self): - from wiringsamples import resourceclosing - - resourceclosing.Service.reset_counter() - service = resourceclosing.Service() - - container = resourceclosing.Container() - container.wire(modules=[resourceclosing]) - self.addCleanup(container.unwire) - - result_1 = resourceclosing.test_function(service=service) - self.assertIs(result_1, service) - self.assertEqual(result_1.init_counter, 0) - self.assertEqual(result_1.shutdown_counter, 0) - - result_2 = resourceclosing.test_function(service=service) - self.assertIs(result_2, service) - self.assertEqual(result_2.init_counter, 0) - self.assertEqual(result_2.shutdown_counter, 0) - - def test_class_decorator(self): - service = module.test_class_decorator() - self.assertIsInstance(service, Service) - - def test_container(self): - service = module.test_container() - self.assertIsInstance(service, Service) - - -class WiringWithStringModuleAndPackageNamesTest(unittest.TestCase): - - container: Container - - def setUp(self) -> None: - self.container = Container() - self.addCleanup(self.container.unwire) - - def test_absolute_names(self): - self.container.wire( - modules=["wiringsamples.module"], - packages=["wiringsamples.package"], - ) - - service = module.test_function() - self.assertIsInstance(service, Service) - - from wiringsamples.package.subpackage.submodule import test_function - service = test_function() - self.assertIsInstance(service, Service) - - def test_relative_names_with_explicit_package(self): - self.container.wire( - modules=[".module"], - packages=[".package"], - from_package="wiringsamples", - ) - - service = module.test_function() - self.assertIsInstance(service, Service) - - from wiringsamples.package.subpackage.submodule import test_function - service = test_function() - self.assertIsInstance(service, Service) - - def test_relative_names_with_auto_package(self): - wire_with_relative_string_names(self.container) - - service = module.test_function() - self.assertIsInstance(service, Service) - - from wiringsamples.package.subpackage.submodule import test_function - service = test_function() - self.assertIsInstance(service, Service) - - -class WiringWithWiringConfigInTheContainerTest(unittest.TestCase): - - container: Container - original_wiring_config = Container.wiring_config - - def tearDown(self) -> None: - Container.wiring_config = self.original_wiring_config - self.container.unwire() - - def test_absolute_names(self): - Container.wiring_config = containers.WiringConfiguration( - modules=["wiringsamples.module"], - packages=["wiringsamples.package"], - ) - self.container = Container() - - service = module.test_function() - self.assertIsInstance(service, Service) - - from wiringsamples.package.subpackage.submodule import test_function - service = test_function() - self.assertIsInstance(service, Service) - - def test_relative_names_with_explicit_package(self): - Container.wiring_config = containers.WiringConfiguration( - modules=[".module"], - packages=[".package"], - from_package="wiringsamples", - ) - self.container = Container() - - service = module.test_function() - self.assertIsInstance(service, Service) - - from wiringsamples.package.subpackage.submodule import test_function - service = test_function() - self.assertIsInstance(service, Service) - - def test_relative_names_with_auto_package(self): - Container.wiring_config = containers.WiringConfiguration( - modules=[".module"], - packages=[".package"], - ) - self.container = Container() - - service = module.test_function() - self.assertIsInstance(service, Service) - - from wiringsamples.package.subpackage.submodule import test_function - service = test_function() - self.assertIsInstance(service, Service) - - def test_auto_wire_disabled(self): - Container.wiring_config = containers.WiringConfiguration( - modules=[".module"], - auto_wire=False, - ) - self.container = Container() - - service = module.test_function() - self.assertIsInstance(service, Provide) - - self.container.wire() - service = module.test_function() - self.assertIsInstance(service, Service) - - -class ModuleAsPackageTest(unittest.TestCase): - - def setUp(self): - self.container = Container(config={"a": {"b": {"c": 10}}}) - self.addCleanup(self.container.unwire) - - def test_module_as_package_wiring(self): - # See: https://github.com/ets-labs/python-dependency-injector/issues/481 - self.container.wire(packages=[module]) - self.assertIsInstance(module.service, Service) - - -class WiringAndQueue(unittest.TestCase): - - def test_wire_queue(self) -> None: - from wiringsamples import queuemodule - container = Container() - self.addCleanup(container.unwire) - - # Should not raise exception - # See: https://github.com/ets-labs/python-dependency-injector/issues/362 - try: - container.wire(modules=[queuemodule]) - except: - raise - - -class WiringAndFastAPITest(unittest.TestCase): - - container: Container - - def test_bypass_marker_injection(self): - container = Container() - container.wire(modules=[module]) - self.addCleanup(container.unwire) - - service = module.test_function(service=Provide[Container.service]) - self.assertIsInstance(service, Service) - - def test_closing_resource_bypass_marker_injection(self): - from wiringsamples import resourceclosing - - resourceclosing.Service.reset_counter() - - container = resourceclosing.Container() - container.wire(modules=[resourceclosing]) - self.addCleanup(container.unwire) - - result_1 = resourceclosing.test_function( - service=Closing[Provide[resourceclosing.Container.service]], - ) - self.assertIsInstance(result_1, resourceclosing.Service) - self.assertEqual(result_1.init_counter, 1) - self.assertEqual(result_1.shutdown_counter, 1) - - result_2 = resourceclosing.test_function( - service=Closing[Provide[resourceclosing.Container.service]], - ) - self.assertIsInstance(result_2, resourceclosing.Service) - self.assertEqual(result_2.init_counter, 2) - self.assertEqual(result_2.shutdown_counter, 2) - - self.assertIsNot(result_1, result_2) - - -class WiringAsyncInjectionsTest(AsyncTestCase): - - def test_async_injections(self): - from wiringsamples import asyncinjections - - container = asyncinjections.Container() - container.wire(modules=[asyncinjections]) - self.addCleanup(container.unwire) - - asyncinjections.resource1.reset_counters() - asyncinjections.resource2.reset_counters() - - resource1, resource2 = self._run(asyncinjections.async_injection()) - - self.assertIs(resource1, asyncinjections.resource1) - self.assertEqual(asyncinjections.resource1.init_counter, 1) - self.assertEqual(asyncinjections.resource1.shutdown_counter, 0) - - self.assertIs(resource2, asyncinjections.resource2) - self.assertEqual(asyncinjections.resource2.init_counter, 1) - self.assertEqual(asyncinjections.resource2.shutdown_counter, 0) - - def test_async_injections_with_closing(self): - from wiringsamples import asyncinjections - - container = asyncinjections.Container() - container.wire(modules=[asyncinjections]) - self.addCleanup(container.unwire) - - asyncinjections.resource1.reset_counters() - asyncinjections.resource2.reset_counters() - - resource1, resource2 = self._run(asyncinjections.async_injection_with_closing()) - - self.assertIs(resource1, asyncinjections.resource1) - self.assertEqual(asyncinjections.resource1.init_counter, 1) - self.assertEqual(asyncinjections.resource1.shutdown_counter, 1) - - self.assertIs(resource2, asyncinjections.resource2) - self.assertEqual(asyncinjections.resource2.init_counter, 1) - self.assertEqual(asyncinjections.resource2.shutdown_counter, 1) - - resource1, resource2 = self._run(asyncinjections.async_injection_with_closing()) - - self.assertIs(resource1, asyncinjections.resource1) - self.assertEqual(asyncinjections.resource1.init_counter, 2) - self.assertEqual(asyncinjections.resource1.shutdown_counter, 2) - - self.assertIs(resource2, asyncinjections.resource2) - self.assertEqual(asyncinjections.resource2.init_counter, 2) - self.assertEqual(asyncinjections.resource2.shutdown_counter, 2) - - -class AutoLoaderTest(unittest.TestCase): - - container: Container - - def setUp(self) -> None: - self.container = Container(config={"a": {"b": {"c": 10}}}) - importlib.reload(module) - - def tearDown(self) -> None: - with contextlib.suppress(ValueError): - unregister_loader_containers(self.container) - - self.container.unwire() - - @classmethod - def tearDownClass(cls) -> None: - importlib.reload(module) - - def test_register_container(self): - register_loader_containers(self.container) - importlib.reload(module) - importlib.import_module("wiringsamples.imports") - - service = module.test_function() - self.assertIsInstance(service, Service) diff --git a/tests/unit/wiring/test_wiring_string_ids_py36.py b/tests/unit/wiring/test_wiring_string_ids_py36.py deleted file mode 100644 index 01a5df4d..00000000 --- a/tests/unit/wiring/test_wiring_string_ids_py36.py +++ /dev/null @@ -1,439 +0,0 @@ -import contextlib -from decimal import Decimal -import importlib -import unittest - -from dependency_injector.wiring import ( - wire, - Provide, - Provider, - Closing, - register_loader_containers, - unregister_loader_containers, -) -from dependency_injector import containers, providers, errors - -# Runtime import to avoid syntax errors in samples on Python < 3.5 -import os -_TOP_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../", - )), -) -_SAMPLES_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../samples/", - )), -) -import sys -sys.path.append(_TOP_DIR) -sys.path.append(_SAMPLES_DIR) - -from asyncutils import AsyncTestCase - -from wiringstringidssamples import module, package -from wiringstringidssamples.service import Service -from wiringstringidssamples.container import Container, SubContainer - - -class WiringTest(unittest.TestCase): - - container: Container - - def setUp(self) -> None: - self.container = Container(config={"a": {"b": {"c": 10}}}) - self.container.wire( - modules=[module], - packages=[package], - ) - self.addCleanup(self.container.unwire) - - def test_package_lookup(self): - from wiringstringidssamples.package import test_package_function - service = test_package_function() - self.assertIsInstance(service, Service) - - def test_package_subpackage_lookup(self): - from wiringstringidssamples.package.subpackage import test_package_function - service = test_package_function() - self.assertIsInstance(service, Service) - - def test_package_submodule_lookup(self): - from wiringstringidssamples.package.subpackage.submodule import test_function - service = test_function() - self.assertIsInstance(service, Service) - - def test_module_attributes_wiring(self): - self.assertIsInstance(module.service, Service) - self.assertIsInstance(module.service_provider(), Service) - self.assertIsInstance(module.undefined, Provide) - - def test_class_wiring(self): - test_class_object = module.TestClass() - self.assertIsInstance(test_class_object.service, Service) - - def test_class_wiring_context_arg(self): - test_service = self.container.service() - - test_class_object = module.TestClass(service=test_service) - self.assertIs(test_class_object.service, test_service) - - def test_class_method_wiring(self): - test_class_object = module.TestClass() - service = test_class_object.method() - self.assertIsInstance(service, Service) - - def test_class_classmethod_wiring(self): - service = module.TestClass.class_method() - self.assertIsInstance(service, Service) - - def test_instance_classmethod_wiring(self): - instance = module.TestClass() - service = instance.class_method() - self.assertIsInstance(service, Service) - - def test_class_staticmethod_wiring(self): - service = module.TestClass.static_method() - self.assertIsInstance(service, Service) - - def test_instance_staticmethod_wiring(self): - instance = module.TestClass() - service = instance.static_method() - self.assertIsInstance(service, Service) - - def test_class_attribute_wiring(self): - self.assertIsInstance(module.TestClass.service, Service) - self.assertIsInstance(module.TestClass.service_provider(), Service) - self.assertIsInstance(module.TestClass.undefined, Provide) - - def test_function_wiring(self): - service = module.test_function() - self.assertIsInstance(service, Service) - - def test_function_wiring_context_arg(self): - test_service = self.container.service() - - service = module.test_function(service=test_service) - self.assertIs(service, test_service) - - def test_function_wiring_provider(self): - service = module.test_function_provider() - self.assertIsInstance(service, Service) - - def test_function_wiring_provider_context_arg(self): - test_service = self.container.service() - - service = module.test_function_provider(service_provider=lambda: test_service) - self.assertIs(service, test_service) - - def test_configuration_option(self): - ( - value_int, - value_float, - value_str, - value_decimal, - value_required, - value_required_int, - value_required_float, - value_required_str, - value_required_decimal, - ) = module.test_config_value() - - self.assertEqual(value_int, 10) - self.assertEqual(value_float, 10.0) - self.assertEqual(value_str, "10") - self.assertEqual(value_decimal, Decimal(10)) - self.assertEqual(value_required, 10) - self.assertEqual(value_required_int, 10) - self.assertEqual(value_required_float, 10.0) - self.assertEqual(value_required_str, "10") - self.assertEqual(value_required_decimal, Decimal(10)) - - def test_configuration_option_required_undefined(self): - self.container.config.reset_override() - with self.assertRaisesRegex(errors.Error, "Undefined configuration option \"config.a.b.c\""): - module.test_config_value_required_undefined() - - def test_provide_provider(self): - service = module.test_provide_provider() - self.assertIsInstance(service, Service) - - def test_provided_instance(self): - class TestService: - foo = { - "bar": lambda: 10, - } - - with self.container.service.override(TestService()): - some_value = module.test_provided_instance() - self.assertEqual(some_value, 10) - - def test_subcontainer(self): - some_value = module.test_subcontainer_provider() - self.assertEqual(some_value, 1) - - def test_config_invariant(self): - config = { - "option": { - "a": 1, - "b": 2, - }, - "switch": "a", - } - self.container.config.from_dict(config) - - value_default = module.test_config_invariant() - self.assertEqual(value_default, 1) - - with self.container.config.switch.override("a"): - value_a = module.test_config_invariant() - self.assertEqual(value_a, 1) - - with self.container.config.switch.override("b"): - value_b = module.test_config_invariant() - self.assertEqual(value_b, 2) - - def test_wire_with_class_error(self): - with self.assertRaises(Exception): - wire( - container=Container, - modules=[module], - ) - - def test_unwire_function(self): - self.container.unwire() - self.assertIsInstance(module.test_function(), Provide) - - def test_unwire_class(self): - self.container.unwire() - test_class_object = module.TestClass() - self.assertIsInstance(test_class_object.service, Provide) - - def test_unwire_class_method(self): - self.container.unwire() - test_class_object = module.TestClass() - self.assertIsInstance(test_class_object.method(), Provide) - - def test_unwire_package_function(self): - self.container.unwire() - from wiringstringidssamples.package.subpackage.submodule import test_function - self.assertIsInstance(test_function(), Provide) - - def test_unwire_package_function_by_reference(self): - from wiringstringidssamples.package.subpackage import submodule - self.container.unwire() - self.assertIsInstance(submodule.test_function(), Provide) - - def test_unwire_module_attributes(self): - self.container.unwire() - self.assertIsInstance(module.service, Provide) - self.assertIsInstance(module.service_provider, Provider) - self.assertIsInstance(module.undefined, Provide) - - def test_unwire_class_attributes(self): - self.container.unwire() - self.assertIsInstance(module.TestClass.service, Provide) - self.assertIsInstance(module.TestClass.service_provider, Provider) - self.assertIsInstance(module.TestClass.undefined, Provide) - - def test_wire_multiple_containers(self): - sub_container = SubContainer() - sub_container.wire( - modules=[module], - packages=[package], - ) - self.addCleanup(sub_container.unwire) - - service, some_value = module.test_provide_from_different_containers() - - self.assertIsInstance(service, Service) - self.assertEqual(some_value, 1) - - def test_closing_resource(self): - from wiringstringidssamples import resourceclosing - - resourceclosing.Service.reset_counter() - - container = resourceclosing.Container() - container.wire(modules=[resourceclosing]) - self.addCleanup(container.unwire) - - result_1 = resourceclosing.test_function() - self.assertIsInstance(result_1, resourceclosing.Service) - self.assertEqual(result_1.init_counter, 1) - self.assertEqual(result_1.shutdown_counter, 1) - - result_2 = resourceclosing.test_function() - self.assertIsInstance(result_2, resourceclosing.Service) - self.assertEqual(result_2.init_counter, 2) - self.assertEqual(result_2.shutdown_counter, 2) - - self.assertIsNot(result_1, result_2) - - def test_closing_resource_context(self): - from wiringstringidssamples import resourceclosing - - resourceclosing.Service.reset_counter() - service = resourceclosing.Service() - - container = resourceclosing.Container() - container.wire(modules=[resourceclosing]) - self.addCleanup(container.unwire) - - result_1 = resourceclosing.test_function(service=service) - self.assertIs(result_1, service) - self.assertEqual(result_1.init_counter, 0) - self.assertEqual(result_1.shutdown_counter, 0) - - result_2 = resourceclosing.test_function(service=service) - self.assertIs(result_2, service) - self.assertEqual(result_2.init_counter, 0) - self.assertEqual(result_2.shutdown_counter, 0) - - def test_class_decorator(self): - service = module.test_class_decorator() - self.assertIsInstance(service, Service) - - def test_container(self): - service = module.test_container() - self.assertIsInstance(service, Service) - - -class WiringAndFastAPITest(unittest.TestCase): - - container: Container - - def test_bypass_marker_injection(self): - container = Container() - container.wire(modules=[module]) - self.addCleanup(container.unwire) - - service = module.test_function(service=Provide[Container.service]) - self.assertIsInstance(service, Service) - - def test_closing_resource_bypass_marker_injection(self): - from wiringstringidssamples import resourceclosing - - resourceclosing.Service.reset_counter() - - container = resourceclosing.Container() - container.wire(modules=[resourceclosing]) - self.addCleanup(container.unwire) - - result_1 = resourceclosing.test_function( - service=Closing[Provide[resourceclosing.Container.service]], - ) - self.assertIsInstance(result_1, resourceclosing.Service) - self.assertEqual(result_1.init_counter, 1) - self.assertEqual(result_1.shutdown_counter, 1) - - result_2 = resourceclosing.test_function( - service=Closing[Provide[resourceclosing.Container.service]], - ) - self.assertIsInstance(result_2, resourceclosing.Service) - self.assertEqual(result_2.init_counter, 2) - self.assertEqual(result_2.shutdown_counter, 2) - - self.assertIsNot(result_1, result_2) - - -class WireDynamicContainerTest(unittest.TestCase): - - def test_wire(self): - sub = containers.DynamicContainer() - sub.int_object = providers.Object(1) - - container = containers.DynamicContainer() - container.config = providers.Configuration() - container.service = providers.Factory(Service) - container.sub = sub - - container.wire( - modules=[module], - packages=[package], - ) - self.addCleanup(container.unwire) - - service = module.test_function() - self.assertIsInstance(service, Service) - - -class WiringAsyncInjectionsTest(AsyncTestCase): - - def test_async_injections(self): - from wiringstringidssamples import asyncinjections - - container = asyncinjections.Container() - container.wire(modules=[asyncinjections]) - self.addCleanup(container.unwire) - - asyncinjections.resource1.reset_counters() - asyncinjections.resource2.reset_counters() - - resource1, resource2 = self._run(asyncinjections.async_injection()) - - self.assertIs(resource1, asyncinjections.resource1) - self.assertEqual(asyncinjections.resource1.init_counter, 1) - self.assertEqual(asyncinjections.resource1.shutdown_counter, 0) - - self.assertIs(resource2, asyncinjections.resource2) - self.assertEqual(asyncinjections.resource2.init_counter, 1) - self.assertEqual(asyncinjections.resource2.shutdown_counter, 0) - - def test_async_injections_with_closing(self): - from wiringstringidssamples import asyncinjections - - container = asyncinjections.Container() - container.wire(modules=[asyncinjections]) - self.addCleanup(container.unwire) - - asyncinjections.resource1.reset_counters() - asyncinjections.resource2.reset_counters() - - resource1, resource2 = self._run(asyncinjections.async_injection_with_closing()) - - self.assertIs(resource1, asyncinjections.resource1) - self.assertEqual(asyncinjections.resource1.init_counter, 1) - self.assertEqual(asyncinjections.resource1.shutdown_counter, 1) - - self.assertIs(resource2, asyncinjections.resource2) - self.assertEqual(asyncinjections.resource2.init_counter, 1) - self.assertEqual(asyncinjections.resource2.shutdown_counter, 1) - - resource1, resource2 = self._run(asyncinjections.async_injection_with_closing()) - - self.assertIs(resource1, asyncinjections.resource1) - self.assertEqual(asyncinjections.resource1.init_counter, 2) - self.assertEqual(asyncinjections.resource1.shutdown_counter, 2) - - self.assertIs(resource2, asyncinjections.resource2) - self.assertEqual(asyncinjections.resource2.init_counter, 2) - self.assertEqual(asyncinjections.resource2.shutdown_counter, 2) - - -class AutoLoaderTest(unittest.TestCase): - - container: Container - - def setUp(self) -> None: - self.container = Container(config={"a": {"b": {"c": 10}}}) - importlib.reload(module) - - def tearDown(self) -> None: - with contextlib.suppress(ValueError): - unregister_loader_containers(self.container) - - self.container.unwire() - - @classmethod - def tearDownClass(cls) -> None: - importlib.reload(module) - - def test_register_container(self): - register_loader_containers(self.container) - importlib.reload(module) - - service = module.test_function() - self.assertIsInstance(service, Service) diff --git a/tests/unit/wiring/test_wiringfastapi_py36.py b/tests/unit/wiring/test_wiringfastapi_py36.py deleted file mode 100644 index baf93e3f..00000000 --- a/tests/unit/wiring/test_wiringfastapi_py36.py +++ /dev/null @@ -1,52 +0,0 @@ -from httpx import AsyncClient - -# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir -import os -_TOP_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../", - )), -) -_SAMPLES_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../samples/", - )), -) -import sys -sys.path.append(_TOP_DIR) -sys.path.append(_SAMPLES_DIR) - -from asyncutils import AsyncTestCase - -from wiringfastapi import web - - -class WiringFastAPITest(AsyncTestCase): - - client: AsyncClient - - def setUp(self) -> None: - super().setUp() - self.client = AsyncClient(app=web.app, base_url="http://test") - - def tearDown(self) -> None: - self._run(self.client.aclose()) - super().tearDown() - - def test_depends_marker_injection(self): - class ServiceMock: - async def process(self): - return "Foo" - - with web.container.service.override(ServiceMock()): - response = self._run(self.client.get("/")) - - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json(), {"result": "Foo"}) - - def test_depends_injection(self): - response = self._run(self.client.get("/auth", auth=("john_smith", "secret"))) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json(), {"username": "john_smith", "password": "secret"}) diff --git a/tests/unit/wiring/test_with_stdlib_queue_py36.py b/tests/unit/wiring/test_with_stdlib_queue_py36.py new file mode 100644 index 00000000..f80ce981 --- /dev/null +++ b/tests/unit/wiring/test_with_stdlib_queue_py36.py @@ -0,0 +1,22 @@ +"""Tests for wiring causes no issues with queue.Queue from std lib.""" + +from pytest import fixture + +from samples.wiring import queuemodule +from samples.wiring.container import Container + + +@fixture +def container(): + container = Container() + yield container + container.unwire() + + +def test_wire_queue(container: Container): + # See: https://github.com/ets-labs/python-dependency-injector/issues/362 + # Should not raise exception + try: + container.wire(modules=[queuemodule]) + except: + raise diff --git a/tox.ini b/tox.ini index e3e44ed3..e1b2fcdb 100644 --- a/tox.ini +++ b/tox.ini @@ -4,6 +4,8 @@ envlist= [testenv] deps= + pytest + pytest-asyncio # TODO: Hotfix, remove when fixed https://github.com/aio-libs/aiohttp/issues/5107 typing_extensions httpx @@ -17,8 +19,8 @@ extras= pydantic flask aiohttp -commands= - python -m unittest discover -s tests/unit -p test_*_py3*.py +commands = pytest -c tests/.configs/pytest.ini +python_files = test_*_py3*.py [testenv:coveralls] passenv = GITHUB_* COVERALLS_* @@ -31,37 +33,40 @@ deps= coveralls commands= coverage erase - coverage run --rcfile=./.coveragerc -m unittest discover -s tests/unit/ -p test_*_py3*.py + coverage run --rcfile=./.coveragerc -m pytest -c tests/.configs/pytest.ini coverage report --rcfile=./.coveragerc coveralls [testenv:2.7] deps= + pytest extras= yaml flask -commands= - python -m unittest discover -s tests/unit -p test_*_py2_py3.py +commands = pytest -c tests/.configs/pytest-py27.ini [testenv:3.5] deps= + pytest + pytest-asyncio contextvars extras= yaml flask -commands= - python -m unittest discover -s tests/unit -p test_*_py3.py +commands = pytest -c tests/.configs/pytest-py35.ini [testenv:pypy2] deps= + pytest extras= yaml flask -commands= - python -m unittest discover -s tests/unit -p test_*_py2_py3.py +commands = pytest -c tests/.configs/pytest-py27.ini [testenv:pypy3] deps= + pytest + pytest-asyncio httpx fastapi boto3 @@ -69,8 +74,7 @@ deps= extras= yaml flask -commands= - python -m unittest discover -s tests/unit -p test_*_py2_py3.py +commands = pytest -c tests/.configs/pytest-py27.ini [testenv:pylint]