From 07901d1ce53ba3e9418a3a20c32cece322b1bb5d Mon Sep 17 00:00:00 2001 From: Roman Mogylatov Date: Sun, 31 Jan 2021 10:03:32 -0500 Subject: [PATCH] Add list and dict provider tests --- tests/unit/providers/test_traversal_py3.py | 58 ++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tests/unit/providers/test_traversal_py3.py b/tests/unit/providers/test_traversal_py3.py index c7ff7a8b..e249b759 100644 --- a/tests/unit/providers/test_traversal_py3.py +++ b/tests/unit/providers/test_traversal_py3.py @@ -513,3 +513,61 @@ class BaseSingletonTests(unittest.TestCase): self.assertIn(provider1, all_providers) self.assertIn(provider2, all_providers) self.assertIn(provider3, all_providers) + + +class ListTests(unittest.TestCase): + + def test_traverse_args(self): + provider1 = providers.Object('bar') + provider2 = providers.Object('baz') + provider = providers.List('foo', provider1, provider2) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 2) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + + def test_traverse_overridden(self): + provider1 = providers.Object('bar') + provider2 = providers.Object('baz') + provider3 = providers.List(provider1, provider2) + + provider = providers.List('foo') + provider.override(provider3) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 3) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + self.assertIn(provider3, all_providers) + + +class DictTests(unittest.TestCase): + + def test_traverse_kwargs(self): + provider1 = providers.Object('bar') + provider2 = providers.Object('baz') + provider = providers.Dict(foo='foo', bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 2) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + + def test_traverse_overridden(self): + provider1 = providers.Object('bar') + provider2 = providers.Object('baz') + provider3 = providers.Dict(bar=provider1, baz=provider2) + + provider = providers.Dict(foo='foo') + provider.override(provider3) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 3) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + self.assertIn(provider3, all_providers)