From d9ba72f95016157676674be8b7653fb77e621db6 Mon Sep 17 00:00:00 2001 From: Roman Mogylatov Date: Sat, 16 Oct 2021 11:51:43 -0400 Subject: [PATCH] Migrate selector provider tests --- tests/unit/providers/test_selector_py2_py3.py | 375 +++++++++--------- 1 file changed, 182 insertions(+), 193 deletions(-) diff --git a/tests/unit/providers/test_selector_py2_py3.py b/tests/unit/providers/test_selector_py2_py3.py index b6c22ca1..c7e96d01 100644 --- a/tests/unit/providers/test_selector_py2_py3.py +++ b/tests/unit/providers/test_selector_py2_py3.py @@ -1,220 +1,209 @@ -"""Dependency injector selector provider unit tests.""" +"""Selector provider tests.""" import functools import itertools import sys -import unittest - from dependency_injector import providers, errors -from pytest import raises +from pytest import fixture, mark, raises -class SelectorTests(unittest.TestCase): +@fixture +def switch(): + return providers.Configuration() - selector = providers.Configuration() - def test_is_provider(self): - assert providers.is_provider(providers.Selector(self.selector)) is True +@fixture +def one(): + return providers.Object(1) - def test_init_optional(self): - one = providers.Object(1) - two = providers.Object(2) - provider = providers.Selector() - provider.set_selector(self.selector) - provider.set_providers(one=one, two=two) +@fixture +def two(): + return providers.Object(2) - assert provider.providers == {"one": one, "two": two} - with self.selector.override("one"): - assert provider() == one() - with self.selector.override("two"): - assert provider() == two() - def test_set_selector_returns_self(self): - provider = providers.Selector() - assert provider.set_selector(self.selector) is provider +@fixture +def selector_type(): + return "default" - def test_set_providers_returns_self(self): - provider = providers.Selector() - assert provider.set_providers(one=providers.Provider()) is provider - def test_provided_instance_provider(self): - provider = providers.Selector(self.selector) - assert isinstance(provider.provided, providers.ProvidedInstance) - - def test_call(self): - provider = providers.Selector( - self.selector, - one=providers.Object(1), - two=providers.Object(2), - ) - - with self.selector.override("one"): - assert provider() == 1 - - with self.selector.override("two"): - assert provider() == 2 - - def test_call_undefined_provider(self): - provider = providers.Selector( - self.selector, - one=providers.Object(1), - two=providers.Object(2), - ) - - with self.selector.override("three"): - with raises(errors.Error): - provider() - - def test_call_selector_is_none(self): - provider = providers.Selector( - self.selector, - one=providers.Object(1), - two=providers.Object(2), - ) - - with self.selector.override(None): - with raises(errors.Error): - provider() - - def test_call_any_callable(self): - provider = providers.Selector( - functools.partial(next, itertools.cycle(["one", "two"])), - one=providers.Object(1), - two=providers.Object(2), - ) - - assert provider() == 1 - assert provider() == 2 - assert provider() == 1 - assert provider() == 2 - - def test_call_with_context_args(self): - provider = providers.Selector( - self.selector, - one=providers.Callable(lambda *args, **kwargs: (args, kwargs)), - ) - - with self.selector.override("one"): - args, kwargs = provider(1, 2, three=3, four=4) - - assert args == (1, 2) - assert kwargs == {"three": 3, "four": 4} - - def test_getattr(self): - provider_one = providers.Object(1) - provider_two = providers.Object(2) - - provider = providers.Selector( - self.selector, - one=provider_one, - two=provider_two, - ) - - assert provider.one is provider_one - assert provider.two is provider_two - - def test_getattr_attribute_error(self): - provider_one = providers.Object(1) - provider_two = providers.Object(2) - - provider = providers.Selector( - self.selector, - one=provider_one, - two=provider_two, - ) - - with raises(AttributeError): - _ = provider.provider_three - - def test_call_overridden(self): - provider = providers.Selector(self.selector, sample=providers.Object(1)) - overriding_provider1 = providers.Selector(self.selector, sample=providers.Object(2)) - overriding_provider2 = providers.Selector(self.selector, sample=providers.Object(3)) - - provider.override(overriding_provider1) - provider.override(overriding_provider2) - - with self.selector.override("sample"): - assert provider() == 3 - - def test_providers_attribute(self): - provider_one = providers.Object(1) - provider_two = providers.Object(2) - - provider = providers.Selector( - self.selector, - one=provider_one, - two=provider_two, - ) - assert provider.providers == {"one": provider_one, "two": provider_two} - - def test_deepcopy(self): - provider = providers.Selector(self.selector) - - provider_copy = providers.deepcopy(provider) - - assert provider is not provider_copy - assert isinstance(provider, providers.Selector) - - def test_deepcopy_from_memo(self): - provider = providers.Selector(self.selector) - provider_copy_memo = providers.Selector(self.selector) - - provider_copy = providers.deepcopy( - provider, - memo={id(provider): provider_copy_memo}, - ) - - assert provider_copy is provider_copy_memo - - def test_deepcopy_overridden(self): - provider = providers.Selector(self.selector) - object_provider = providers.Object(object()) - - provider.override(object_provider) - - provider_copy = providers.deepcopy(provider) - object_provider_copy = provider_copy.overridden[0] - - assert provider is not provider_copy - assert isinstance(provider, providers.Selector) - - assert object_provider is not object_provider_copy - assert isinstance(object_provider_copy, providers.Object) - - def test_deepcopy_with_sys_streams(self): - provider = providers.Selector( - self.selector, +@fixture +def selector(selector_type, switch, one, two): + if selector_type == "default": + return providers.Selector(switch, one=one, two=two) + elif selector_type == "empty": + return providers.Selector() + elif selector_type == "sys-streams": + return providers.Selector( + switch, stdin=providers.Object(sys.stdin), stdout=providers.Object(sys.stdout), stderr=providers.Object(sys.stderr), - ) + else: + raise ValueError("Unknown selector type \"{0}\"".format(selector_type)) - provider_copy = providers.deepcopy(provider) - assert provider is not provider_copy - assert isinstance(provider_copy, providers.Selector) +def test_is_provider(selector): + assert providers.is_provider(selector) is True - with self.selector.override("stdin"): - assert provider() is sys.stdin - with self.selector.override("stdout"): - assert provider() is sys.stdout +@mark.parametrize("selector_type", ["empty"]) +def test_init_optional(selector, switch, one, two): + selector.set_selector(switch) + selector.set_providers(one=one, two=two) - with self.selector.override("stderr"): - assert provider() is sys.stderr + assert selector.providers == {"one": one, "two": two} + with switch.override("one"): + assert selector() == one() + with switch.override("two"): + assert selector() == two() - def test_repr(self): - provider = providers.Selector( - self.selector, - one=providers.Object(1), - two=providers.Object(2), - ) - assert "