Add Selector provider tests

This commit is contained in:
Roman Mogylatov 2021-01-31 10:34:09 -05:00
parent edcd542fdf
commit 5d99a6f674

View File

@ -671,3 +671,59 @@ class ContainerTests(unittest.TestCase):
)
self.assertIn(provider.last_overriding, all_providers)
self.assertIs(provider.last_overriding(), container2)
class SelectorTests(unittest.TestCase):
def test_traverse(self):
switch = lambda: 'provider1'
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
provider = providers.Selector(
switch,
provider1=provider1,
provider2=provider2,
)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_switch(self):
switch = providers.Callable(lambda: 'provider1')
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
provider = providers.Selector(
switch,
provider1=provider1,
provider2=provider2,
)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(switch, all_providers)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
selector1 = providers.Selector(lambda: 'provider1', provider1=provider1)
provider = providers.Selector(
lambda: 'provider2',
provider2=provider2,
)
provider.override(selector1)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(selector1, all_providers)