"""Dependency injector selector provider unit tests.""" import functools import itertools import sys import unittest from dependency_injector import providers, errors from pytest import raises class SelectorTests(unittest.TestCase): selector = providers.Configuration() def test_is_provider(self): assert providers.is_provider(providers.Selector(self.selector)) is True def test_init_optional(self): one = providers.Object(1) two = providers.Object(2) provider = providers.Selector() provider.set_selector(self.selector) provider.set_providers(one=one, two=two) assert provider.providers == {"one": one, "two": two} with self.selector.override("one"): assert provider() == one() with self.selector.override("two"): assert provider() == two() def test_set_selector_returns_self(self): provider = providers.Selector() assert provider.set_selector(self.selector) is provider def test_set_providers_returns_self(self): provider = providers.Selector() assert provider.set_providers(one=providers.Provider()) is provider def test_provided_instance_provider(self): provider = providers.Selector(self.selector) assert isinstance(provider.provided, providers.ProvidedInstance) def test_call(self): provider = providers.Selector( self.selector, one=providers.Object(1), two=providers.Object(2), ) with self.selector.override("one"): assert provider() == 1 with self.selector.override("two"): assert provider() == 2 def test_call_undefined_provider(self): provider = providers.Selector( self.selector, one=providers.Object(1), two=providers.Object(2), ) with self.selector.override("three"): with raises(errors.Error): provider() def test_call_selector_is_none(self): provider = providers.Selector( self.selector, one=providers.Object(1), two=providers.Object(2), ) with self.selector.override(None): with raises(errors.Error): provider() def test_call_any_callable(self): provider = providers.Selector( functools.partial(next, itertools.cycle(["one", "two"])), one=providers.Object(1), two=providers.Object(2), ) assert provider() == 1 assert provider() == 2 assert provider() == 1 assert provider() == 2 def test_call_with_context_args(self): provider = providers.Selector( self.selector, one=providers.Callable(lambda *args, **kwargs: (args, kwargs)), ) with self.selector.override("one"): args, kwargs = provider(1, 2, three=3, four=4) assert args == (1, 2) assert kwargs == {"three": 3, "four": 4} def test_getattr(self): provider_one = providers.Object(1) provider_two = providers.Object(2) provider = providers.Selector( self.selector, one=provider_one, two=provider_two, ) assert provider.one is provider_one assert provider.two is provider_two def test_getattr_attribute_error(self): provider_one = providers.Object(1) provider_two = providers.Object(2) provider = providers.Selector( self.selector, one=provider_one, two=provider_two, ) with raises(AttributeError): _ = provider.provider_three def test_call_overridden(self): provider = providers.Selector(self.selector, sample=providers.Object(1)) overriding_provider1 = providers.Selector(self.selector, sample=providers.Object(2)) overriding_provider2 = providers.Selector(self.selector, sample=providers.Object(3)) provider.override(overriding_provider1) provider.override(overriding_provider2) with self.selector.override("sample"): assert provider() == 3 def test_providers_attribute(self): provider_one = providers.Object(1) provider_two = providers.Object(2) provider = providers.Selector( self.selector, one=provider_one, two=provider_two, ) assert provider.providers == {"one": provider_one, "two": provider_two} def test_deepcopy(self): provider = providers.Selector(self.selector) provider_copy = providers.deepcopy(provider) assert provider is not provider_copy assert isinstance(provider, providers.Selector) def test_deepcopy_from_memo(self): provider = providers.Selector(self.selector) provider_copy_memo = providers.Selector(self.selector) provider_copy = providers.deepcopy( provider, memo={id(provider): provider_copy_memo}, ) assert provider_copy is provider_copy_memo def test_deepcopy_overridden(self): provider = providers.Selector(self.selector) object_provider = providers.Object(object()) provider.override(object_provider) provider_copy = providers.deepcopy(provider) object_provider_copy = provider_copy.overridden[0] assert provider is not provider_copy assert isinstance(provider, providers.Selector) assert object_provider is not object_provider_copy assert isinstance(object_provider_copy, providers.Object) def test_deepcopy_with_sys_streams(self): provider = providers.Selector( self.selector, stdin=providers.Object(sys.stdin), stdout=providers.Object(sys.stdout), stderr=providers.Object(sys.stderr), ) provider_copy = providers.deepcopy(provider) assert provider is not provider_copy assert isinstance(provider_copy, providers.Selector) with self.selector.override("stdin"): assert provider() is sys.stdin with self.selector.override("stdout"): assert provider() is sys.stdout with self.selector.override("stderr"): assert provider() is sys.stderr def test_repr(self): provider = providers.Selector( self.selector, one=providers.Object(1), two=providers.Object(2), ) assert "