diff --git a/tests/unit/providers/test_traversal_py3.py b/tests/unit/providers/test_traversal_py3.py index 11c2f7b9..1cfa1b4a 100644 --- a/tests/unit/providers/test_traversal_py3.py +++ b/tests/unit/providers/test_traversal_py3.py @@ -671,3 +671,59 @@ class ContainerTests(unittest.TestCase): ) self.assertIn(provider.last_overriding, all_providers) self.assertIs(provider.last_overriding(), container2) + + +class SelectorTests(unittest.TestCase): + + def test_traverse(self): + switch = lambda: 'provider1' + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + + provider = providers.Selector( + switch, + provider1=provider1, + provider2=provider2, + ) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 2) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + + def test_traverse_switch(self): + switch = providers.Callable(lambda: 'provider1') + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + + provider = providers.Selector( + switch, + provider1=provider1, + provider2=provider2, + ) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 3) + self.assertIn(switch, all_providers) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + + def test_traverse_overridden(self): + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + selector1 = providers.Selector(lambda: 'provider1', provider1=provider1) + + provider = providers.Selector( + lambda: 'provider2', + provider2=provider2, + ) + provider.override(selector1) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 3) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + self.assertIn(selector1, all_providers)