diff --git a/tests/unit/providers/test_traversal_py3.py b/tests/unit/providers/test_traversal_py3.py index bad27d81..11c2f7b9 100644 --- a/tests/unit/providers/test_traversal_py3.py +++ b/tests/unit/providers/test_traversal_py3.py @@ -1,6 +1,6 @@ import unittest -from dependency_injector import providers +from dependency_injector import containers, providers class TraverseTests(unittest.TestCase): @@ -625,3 +625,49 @@ class ResourceTests(unittest.TestCase): self.assertEqual(len(all_providers), 1) self.assertIn(provider1, all_providers) + + +class ContainerTests(unittest.TestCase): + + def test_traverse(self): + class Container(containers.DeclarativeContainer): + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + + provider = providers.Container(Container) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 2) + self.assertEqual( + {provider.provides for provider in all_providers}, + {list, dict}, + ) + + def test_traverse_overridden(self): + class Container1(containers.DeclarativeContainer): + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + + class Container2(containers.DeclarativeContainer): + provider1 = providers.Callable(tuple) + provider2 = providers.Callable(str) + + container2 = Container2() + + provider = providers.Container(Container1) + provider.override(container2) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 5) + self.assertEqual( + { + provider.provides + for provider in all_providers + if isinstance(provider, providers.Callable) + }, + {list, dict, tuple, str}, + ) + self.assertIn(provider.last_overriding, all_providers) + self.assertIs(provider.last_overriding(), container2)