Add Factory tests

This commit is contained in:
Roman Mogylatov 2021-01-31 09:37:10 -05:00
parent 127bf42526
commit 79b14cf7c4

View File

@ -357,3 +357,74 @@ class ConfigurationTests(unittest.TestCase):
all_providers = list(config.option1.traverse()) all_providers = list(config.option1.traverse())
self.assertEqual(len(all_providers), 0) self.assertEqual(len(all_providers), 0)
class FactoryTests(unittest.TestCase):
def test_traverse(self):
provider = providers.Factory(dict)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(list, 'foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(dict, foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_attributes(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(dict)
provider.add_attributes(foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(dict, 'foo')
provider.override(provider1)
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_provides(self):
provider1 = providers.Callable(list)
provider2 = providers.Object('bar')
provider3 = providers.Object('baz')
provider = providers.Factory(provider1, provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)