From 859d8c99157a631292a5cdbbc17cddd0c706101b Mon Sep 17 00:00:00 2001 From: Roman Mogylatov Date: Sun, 31 Jan 2021 09:49:43 -0500 Subject: [PATCH] Add singleton provider tests --- tests/unit/providers/test_traversal_py3.py | 71 ++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/tests/unit/providers/test_traversal_py3.py b/tests/unit/providers/test_traversal_py3.py index 3c1a3a42..c7ff7a8b 100644 --- a/tests/unit/providers/test_traversal_py3.py +++ b/tests/unit/providers/test_traversal_py3.py @@ -442,3 +442,74 @@ class FactoryAggregateTests(unittest.TestCase): self.assertEqual(len(all_providers), 2) self.assertIn(factory1, all_providers) self.assertIn(factory2, all_providers) + + +class BaseSingletonTests(unittest.TestCase): + + def test_traverse(self): + provider = providers.Singleton(dict) + all_providers = list(provider.traverse()) + self.assertEqual(len(all_providers), 0) + + def test_traverse_args(self): + provider1 = providers.Object('bar') + provider2 = providers.Object('baz') + provider = providers.Singleton(list, 'foo', provider1, provider2) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 2) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + + def test_traverse_kwargs(self): + provider1 = providers.Object('bar') + provider2 = providers.Object('baz') + provider = providers.Singleton(dict, foo='foo', bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 2) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + + def test_traverse_attributes(self): + provider1 = providers.Object('bar') + provider2 = providers.Object('baz') + provider = providers.Singleton(dict) + provider.add_attributes(foo='foo', bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 2) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + + def test_traverse_overridden(self): + provider1 = providers.Object('bar') + provider2 = providers.Object('baz') + + provider = providers.Singleton(dict, 'foo') + provider.override(provider1) + provider.override(provider2) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 2) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + + def test_traverse_provides(self): + provider1 = providers.Callable(list) + provider2 = providers.Object('bar') + provider3 = providers.Object('baz') + + provider = providers.Singleton(provider1, provider2) + provider.override(provider3) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 3) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + self.assertIn(provider3, all_providers)