python-dependency-injector/tests/unit/providers/test_traversal_py3.py
Roman Mogylatov 3ca6dd9af1
Providers traversal (#385)
* Implement providers traversal (first approximation)

* Implement traversal for all providers

* Update traverse interface + add some tests

* Refactor tests

* Add tests for callable provider

* Add configuration tests

* Add Factory tests

* Add FactoryAggregate tests

* Add .provides attribute to singleton providers

* Add singleton provider tests

* Add list and dict provider tests

* Add resource tests

* Add Container provider tests

* Add Selector provider tests

* Add ProvidedInstance provider tests

* Add AttributeGetter provider tests

* Add ItemGetter provider tests

* Add MethodCaller provider tests

* Refactor container interface

* Update resource provider string representation

* Add .initializer attribute to Resource provider

* Add docs and examples

* Remove not needed EOL in the tests

* Make cosmetic refactoring

* Ignore flake8 line width error in traverse example
2021-02-01 09:42:21 -05:00

876 lines
27 KiB
Python

import unittest
from dependency_injector import containers, providers
class TraverseTests(unittest.TestCase):
def test_traverse_cycled_graph(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider2.override(provider1)
provider3 = providers.Provider()
provider3.override(provider2)
provider1.override(provider3) # Cycle: provider3 -> provider2 -> provider1 -> provider3
all_providers = list(providers.traverse(provider1))
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
def test_traverse_types_filtering(self):
provider1 = providers.Resource(dict)
provider2 = providers.Resource(dict)
provider3 = providers.Provider()
provider = providers.Provider()
provider.override(provider1)
provider.override(provider2)
provider.override(provider3)
all_providers = list(providers.traverse(provider, types=[providers.Resource]))
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
class ProviderTests(unittest.TestCase):
def test_traversal_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider3 = providers.Provider()
provider = providers.Provider()
provider.override(provider1)
provider.override(provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
def test_traversal_overriding_nested(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider2.override(provider1)
provider3 = providers.Provider()
provider3.override(provider2)
provider = providers.Provider()
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
def test_traverse_types_filtering(self):
provider1 = providers.Resource(dict)
provider2 = providers.Resource(dict)
provider3 = providers.Provider()
provider = providers.Provider()
provider.override(provider1)
provider.override(provider2)
provider.override(provider3)
all_providers = list(provider.traverse(types=[providers.Resource]))
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
class ObjectTests(unittest.TestCase):
def test_traversal(self):
provider = providers.Object('string')
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traversal_provider(self):
another_provider = providers.Provider()
provider = providers.Object(another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_provider_and_overriding(self):
another_provider_1 = providers.Provider()
another_provider_2 = providers.Provider()
another_provider_3 = providers.Provider()
provider = providers.Object(another_provider_1)
provider.override(another_provider_2)
provider.override(another_provider_3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(another_provider_1, all_providers)
self.assertIn(another_provider_2, all_providers)
self.assertIn(another_provider_3, all_providers)
class DelegateTests(unittest.TestCase):
def test_traversal_provider(self):
another_provider = providers.Provider()
provider = providers.Delegate(another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_provider_and_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider3 = providers.Provider()
provider3.override(provider2)
provider = providers.Delegate(provider1)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class DependencyTests(unittest.TestCase):
def test_traversal(self):
provider = providers.Dependency()
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traversal_default(self):
another_provider = providers.Provider()
provider = providers.Dependency(default=another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider2.override(provider1)
provider = providers.Dependency()
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
class DependenciesContainerTests(unittest.TestCase):
def test_traversal(self):
provider = providers.DependenciesContainer()
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traversal_default(self):
another_provider = providers.Provider()
provider = providers.DependenciesContainer(default=another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_fluent_interface(self):
provider = providers.DependenciesContainer()
provider1 = provider.provider1
provider2 = provider.provider2
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traversal_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider3 = providers.DependenciesContainer(
provider1=provider1,
provider2=provider2,
)
provider = providers.DependenciesContainer()
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 5)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
self.assertIn(provider.provider1, all_providers)
self.assertIn(provider.provider2, all_providers)
class CallableTests(unittest.TestCase):
def test_traverse(self):
provider = providers.Callable(dict)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Callable(list, 'foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Callable(dict, foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Callable(dict, 'foo')
provider.override(provider1)
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_provides(self):
provider1 = providers.Callable(list)
provider2 = providers.Object('bar')
provider3 = providers.Object('baz')
provider = providers.Callable(provider1, provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class ConfigurationTests(unittest.TestCase):
def test_traverse(self):
config = providers.Configuration(default={'option1': {'option2': 'option2'}})
option1 = config.option1
option2 = config.option1.option2
option3 = config.option1[config.option1.option2]
all_providers = list(config.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(option1, all_providers)
self.assertIn(option2, all_providers)
self.assertIn(option3, all_providers)
def test_traverse_typed(self):
config = providers.Configuration()
option = config.option
typed_option = config.option.as_int()
all_providers = list(typed_option.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(option, all_providers)
def test_traverse_overridden(self):
options = {'option1': {'option2': 'option2'}}
config = providers.Configuration()
config.from_dict(options)
all_providers = list(config.traverse())
self.assertEqual(len(all_providers), 1)
overridden, = all_providers
self.assertEqual(overridden(), options)
self.assertIs(overridden, config.last_overriding)
def test_traverse_overridden_option_1(self):
options = {'option2': 'option2'}
config = providers.Configuration()
config.option1.from_dict(options)
all_providers = list(config.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(config.option1, all_providers)
self.assertIn(config.last_overriding, all_providers)
def test_traverse_overridden_option_2(self):
options = {'option2': 'option2'}
config = providers.Configuration()
config.option1.from_dict(options)
all_providers = list(config.option1.traverse())
self.assertEqual(len(all_providers), 0)
class FactoryTests(unittest.TestCase):
def test_traverse(self):
provider = providers.Factory(dict)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(list, 'foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(dict, foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_attributes(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(dict)
provider.add_attributes(foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(dict, 'foo')
provider.override(provider1)
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_provides(self):
provider1 = providers.Callable(list)
provider2 = providers.Object('bar')
provider3 = providers.Object('baz')
provider = providers.Factory(provider1, provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class FactoryAggregateTests(unittest.TestCase):
def test_traverse(self):
factory1 = providers.Factory(dict)
factory2 = providers.Factory(list)
provider = providers.FactoryAggregate(factory1=factory1, factory2=factory2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(factory1, all_providers)
self.assertIn(factory2, all_providers)
class BaseSingletonTests(unittest.TestCase):
def test_traverse(self):
provider = providers.Singleton(dict)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Singleton(list, 'foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Singleton(dict, foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_attributes(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Singleton(dict)
provider.add_attributes(foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Singleton(dict, 'foo')
provider.override(provider1)
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_provides(self):
provider1 = providers.Callable(list)
provider2 = providers.Object('bar')
provider3 = providers.Object('baz')
provider = providers.Singleton(provider1, provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class ListTests(unittest.TestCase):
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.List('foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider3 = providers.List(provider1, provider2)
provider = providers.List('foo')
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class DictTests(unittest.TestCase):
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Dict(foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider3 = providers.Dict(bar=provider1, baz=provider2)
provider = providers.Dict(foo='foo')
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class ResourceTests(unittest.TestCase):
def test_traverse(self):
provider = providers.Resource(dict)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Resource(list, 'foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Resource(dict, foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Resource(list)
provider2 = providers.Resource(tuple)
provider = providers.Resource(dict, 'foo')
provider.override(provider1)
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_provides(self):
provider1 = providers.Callable(list)
provider = providers.Resource(provider1)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(provider1, all_providers)
class ContainerTests(unittest.TestCase):
def test_traverse(self):
class Container(containers.DeclarativeContainer):
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
provider = providers.Container(Container)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertEqual(
{provider.provides for provider in all_providers},
{list, dict},
)
def test_traverse_overridden(self):
class Container1(containers.DeclarativeContainer):
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
class Container2(containers.DeclarativeContainer):
provider1 = providers.Callable(tuple)
provider2 = providers.Callable(str)
container2 = Container2()
provider = providers.Container(Container1)
provider.override(container2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 5)
self.assertEqual(
{
provider.provides
for provider in all_providers
if isinstance(provider, providers.Callable)
},
{list, dict, tuple, str},
)
self.assertIn(provider.last_overriding, all_providers)
self.assertIs(provider.last_overriding(), container2)
class SelectorTests(unittest.TestCase):
def test_traverse(self):
switch = lambda: 'provider1'
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
provider = providers.Selector(
switch,
provider1=provider1,
provider2=provider2,
)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_switch(self):
switch = providers.Callable(lambda: 'provider1')
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
provider = providers.Selector(
switch,
provider1=provider1,
provider2=provider2,
)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(switch, all_providers)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
selector1 = providers.Selector(lambda: 'provider1', provider1=provider1)
provider = providers.Selector(
lambda: 'provider2',
provider2=provider2,
)
provider.override(selector1)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(selector1, all_providers)
class ProvidedInstanceTests(unittest.TestCase):
def test_traverse(self):
provider1 = providers.Provider()
provider = provider1.provided
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(provider1, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider = provider1.provided
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
class AttributeGetterTests(unittest.TestCase):
def test_traverse(self):
provider1 = providers.Provider()
provided = provider1.provided
provider = provided.attr
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provided, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Provider()
provided = provider1.provided
provider2 = providers.Provider()
provider = provided.attr
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
class ItemGetterTests(unittest.TestCase):
def test_traverse(self):
provider1 = providers.Provider()
provided = provider1.provided
provider = provided['item']
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provided, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Provider()
provided = provider1.provided
provider2 = providers.Provider()
provider = provided['item']
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
class MethodCallerTests(unittest.TestCase):
def test_traverse(self):
provider1 = providers.Provider()
provided = provider1.provided
method = provided.method
provider = method.call()
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provided, all_providers)
self.assertIn(method, all_providers)
def test_traverse_args(self):
provider1 = providers.Provider()
provided = provider1.provided
method = provided.method
provider2 = providers.Provider()
provider = method.call('foo', provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 4)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
self.assertIn(method, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Provider()
provided = provider1.provided
method = provided.method
provider2 = providers.Provider()
provider = method.call(foo='foo', bar=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 4)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
self.assertIn(method, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Provider()
provided = provider1.provided
method = provided.method
provider2 = providers.Provider()
provider = method.call()
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 4)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
self.assertIn(method, all_providers)