Refactor tests

This commit is contained in:
Roman Mogylatov 2021-01-30 16:36:22 -05:00
parent 897f2d3110
commit b7da1ed9f1

View File

@ -3,6 +3,44 @@ import unittest
from dependency_injector import providers
class TraverseTests(unittest.TestCase):
def test_traverse_cycled_graph(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider2.override(provider1)
provider3 = providers.Provider()
provider3.override(provider2)
provider1.override(provider3) # Cycle: provider3 -> provider2 -> provider1 -> provider3
all_providers = list(providers.traverse(provider1))
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
def test_traverse_types_filtering(self):
provider1 = providers.Resource(dict)
provider2 = providers.Resource(dict)
provider3 = providers.Provider()
provider = providers.Provider()
provider.override(provider1)
provider.override(provider2)
provider.override(provider3)
all_providers = list(providers.traverse(provider, types=[providers.Resource]))
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
class ProviderTests(unittest.TestCase):
def test_traversal_overriding(self):
@ -42,26 +80,22 @@ class ProviderTests(unittest.TestCase):
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
def test_traversal_overriding_cycled(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider2.override(provider1)
def test_traverse_types_filtering(self):
provider1 = providers.Resource(dict)
provider2 = providers.Resource(dict)
provider3 = providers.Provider()
provider3.override(provider2)
provider1.override(provider3) # Cycle: provider3 -> provider2 -> provider1 -> provider3
provider = providers.Provider()
provider.override(provider1)
provider.override(provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
all_providers = list(provider.traverse(types=[providers.Resource]))
self.assertEqual(len(all_providers), 3)
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class ObjectTests(unittest.TestCase):