diff --git a/tests/unit/providers/test_traversal_py3.py b/tests/unit/providers/test_traversal_py3.py index 790557f5..d41c78b1 100644 --- a/tests/unit/providers/test_traversal_py3.py +++ b/tests/unit/providers/test_traversal_py3.py @@ -3,6 +3,44 @@ import unittest from dependency_injector import providers +class TraverseTests(unittest.TestCase): + + def test_traverse_cycled_graph(self): + provider1 = providers.Provider() + + provider2 = providers.Provider() + provider2.override(provider1) + + provider3 = providers.Provider() + provider3.override(provider2) + + provider1.override(provider3) # Cycle: provider3 -> provider2 -> provider1 -> provider3 + + all_providers = list(providers.traverse(provider1)) + + self.assertEqual(len(all_providers), 3) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + self.assertIn(provider3, all_providers) + + def test_traverse_types_filtering(self): + provider1 = providers.Resource(dict) + provider2 = providers.Resource(dict) + provider3 = providers.Provider() + + provider = providers.Provider() + + provider.override(provider1) + provider.override(provider2) + provider.override(provider3) + + all_providers = list(providers.traverse(provider, types=[providers.Resource])) + + self.assertEqual(len(all_providers), 2) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + + class ProviderTests(unittest.TestCase): def test_traversal_overriding(self): @@ -42,26 +80,22 @@ class ProviderTests(unittest.TestCase): self.assertIn(provider2, all_providers) self.assertIn(provider3, all_providers) - def test_traversal_overriding_cycled(self): - provider1 = providers.Provider() - - provider2 = providers.Provider() - provider2.override(provider1) - + def test_traverse_types_filtering(self): + provider1 = providers.Resource(dict) + provider2 = providers.Resource(dict) provider3 = providers.Provider() - provider3.override(provider2) - - provider1.override(provider3) # Cycle: provider3 -> provider2 -> provider1 -> provider3 provider = providers.Provider() + + provider.override(provider1) + provider.override(provider2) provider.override(provider3) - all_providers = list(provider.traverse()) + all_providers = list(provider.traverse(types=[providers.Resource])) - self.assertEqual(len(all_providers), 3) + self.assertEqual(len(all_providers), 2) self.assertIn(provider1, all_providers) self.assertIn(provider2, all_providers) - self.assertIn(provider3, all_providers) class ObjectTests(unittest.TestCase):