"""Selector provider tests.""" import functools import itertools import sys from dependency_injector import providers, errors from pytest import fixture, mark, raises @fixture def switch(): return providers.Configuration() @fixture def one(): return providers.Object(1) @fixture def two(): return providers.Object(2) @fixture def selector_type(): return "default" @fixture def selector(selector_type, switch, one, two): if selector_type == "default": return providers.Selector(switch, one=one, two=two) elif selector_type == "empty": return providers.Selector() elif selector_type == "sys-streams": return providers.Selector( switch, stdin=providers.Object(sys.stdin), stdout=providers.Object(sys.stdout), stderr=providers.Object(sys.stderr), ) else: raise ValueError("Unknown selector type \"{0}\"".format(selector_type)) def test_is_provider(selector): assert providers.is_provider(selector) is True @mark.parametrize("selector_type", ["empty"]) def test_init_optional(selector, switch, one, two): selector.set_selector(switch) selector.set_providers(one=one, two=two) assert selector.providers == {"one": one, "two": two} with switch.override("one"): assert selector() == one() with switch.override("two"): assert selector() == two() def test_set_selector_returns_self(selector, switch): assert selector.set_selector(switch) is selector def test_set_providers_returns_self(selector, one): assert selector.set_providers(one=one) is selector def test_provided_instance_provider(selector): assert isinstance(selector.provided, providers.ProvidedInstance) def test_call(selector, switch): with switch.override("one"): assert selector() == 1 with switch.override("two"): assert selector() == 2 def test_call_undefined_provider(selector, switch): with switch.override("three"): with raises(errors.Error): selector() def test_call_selector_is_none(selector, switch): with switch.override(None): with raises(errors.Error): selector() @mark.parametrize("selector_type", ["empty"]) def test_call_any_callable(selector): selector.set_selector(functools.partial(next, itertools.cycle(["one", "two"]))) selector.set_providers( one=providers.Object(1), two=providers.Object(2), ) assert selector() == 1 assert selector() == 2 assert selector() == 1 assert selector() == 2 @mark.parametrize("selector_type", ["empty"]) def test_call_with_context_args(selector, switch): selector.set_selector(switch) selector.set_providers(one=providers.Callable(lambda *args, **kwargs: (args, kwargs))) with switch.override("one"): args, kwargs = selector(1, 2, three=3, four=4) assert args == (1, 2) assert kwargs == {"three": 3, "four": 4} def test_getattr(selector, one, two): assert selector.one is one assert selector.two is two def test_getattr_attribute_error(selector): with raises(AttributeError): _ = selector.provider_three def test_call_overridden(selector, switch): overriding_provider1 = providers.Selector(switch, one=providers.Object(2)) overriding_provider2 = providers.Selector(switch, one=providers.Object(3)) selector.override(overriding_provider1) selector.override(overriding_provider2) with switch.override("one"): assert selector() == 3 def test_providers_attribute(selector, one, two): assert selector.providers == {"one": one, "two": two} def test_deepcopy(selector): provider_copy = providers.deepcopy(selector) assert provider_copy is not selector assert isinstance(selector, providers.Selector) assert provider_copy.selector is not selector.selector assert isinstance(provider_copy.selector, providers.Configuration) assert provider_copy.one is not selector.one assert isinstance(provider_copy.one, providers.Object) assert provider_copy.one.provides == 1 assert provider_copy.two is not selector.two assert isinstance(provider_copy.two, providers.Object) assert provider_copy.two.provides == 2 def test_deepcopy_from_memo(selector): provider_copy = providers.deepcopy( selector, memo={id(selector): selector}, ) assert provider_copy is selector def test_deepcopy_overridden(selector): object_provider = providers.Object(object()) selector.override(object_provider) provider_copy = providers.deepcopy(selector) object_provider_copy = provider_copy.overridden[0] assert selector is not provider_copy assert isinstance(selector, providers.Selector) assert object_provider is not object_provider_copy assert isinstance(object_provider_copy, providers.Object) @mark.parametrize("selector_type", ["sys-streams"]) def test_deepcopy_with_sys_streams(selector, switch): provider_copy = providers.deepcopy(selector) assert selector is not provider_copy assert isinstance(provider_copy, providers.Selector) with switch.override("stdin"): assert selector() is sys.stdin with switch.override("stdout"): assert selector() is sys.stdout with switch.override("stderr"): assert selector() is sys.stderr def test_repr(selector, switch): assert "