From 2ae3e32429f04e63297674f675266996187cf2b1 Mon Sep 17 00:00:00 2001 From: Roman Mogylatov Date: Sun, 31 Jan 2021 10:51:31 -0500 Subject: [PATCH] Add MethodCaller provider tests --- tests/unit/providers/test_traversal_py3.py | 63 ++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/tests/unit/providers/test_traversal_py3.py b/tests/unit/providers/test_traversal_py3.py index 0aac714d..f32bd913 100644 --- a/tests/unit/providers/test_traversal_py3.py +++ b/tests/unit/providers/test_traversal_py3.py @@ -810,3 +810,66 @@ class ItemGetterTests(unittest.TestCase): self.assertIn(provider1, all_providers) self.assertIn(provider2, all_providers) self.assertIn(provided, all_providers) + + +class MethodCallerTests(unittest.TestCase): + + def test_traverse(self): + provider1 = providers.Provider() + provided = provider1.provided + method = provided.method + provider = method.call() + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 3) + self.assertIn(provider1, all_providers) + self.assertIn(provided, all_providers) + self.assertIn(method, all_providers) + + def test_traverse_args(self): + provider1 = providers.Provider() + provided = provider1.provided + method = provided.method + provider2 = providers.Provider() + provider = method.call('foo', provider2) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 4) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + self.assertIn(provided, all_providers) + self.assertIn(method, all_providers) + + def test_traverse_kwargs(self): + provider1 = providers.Provider() + provided = provider1.provided + method = provided.method + provider2 = providers.Provider() + provider = method.call(foo='foo', bar=provider2) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 4) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + self.assertIn(provided, all_providers) + self.assertIn(method, all_providers) + + def test_traverse_overridden(self): + provider1 = providers.Provider() + provided = provider1.provided + method = provided.method + provider2 = providers.Provider() + + provider = method.call() + provider.override(provider2) + + all_providers = list(provider.traverse()) + + self.assertEqual(len(all_providers), 4) + self.assertIn(provider1, all_providers) + self.assertIn(provider2, all_providers) + self.assertIn(provided, all_providers) + self.assertIn(method, all_providers)