From 75c9bdbd7ad2243e5a3e6d294b681cb2a5da3320 Mon Sep 17 00:00:00 2001 From: Roman Mogylatov Date: Mon, 18 Oct 2021 10:30:46 -0400 Subject: [PATCH] Migrate traversal tests --- tests/unit/providers/test_traversal_py3.py | 869 ------------------ tests/unit/providers/traversal/__init__.py | 1 + .../traversal/test_attribute_getter_py3.py | 31 + .../providers/traversal/test_callable_py3.py | 64 ++ .../traversal/test_configuration_py3.py | 63 ++ .../providers/traversal/test_container_py3.py | 42 + .../providers/traversal/test_delegate_py3.py | 32 + .../test_dependencies_container_py3.py | 52 ++ .../traversal/test_dependency_py3.py | 35 + .../unit/providers/traversal/test_dict_py3.py | 31 + .../traversal/test_factory_aggregate_py3.py | 15 + .../providers/traversal/test_factory_py3.py | 77 ++ .../traversal/test_item_getter_py3.py | 31 + .../unit/providers/traversal/test_list_py3.py | 31 + .../traversal/test_method_caller_py3.py | 67 ++ .../providers/traversal/test_object_py3.py | 37 + .../traversal/test_provided_instance_py3.py | 27 + .../providers/traversal/test_provider_py3.py | 60 ++ .../providers/traversal/test_resource_py3.py | 60 ++ .../providers/traversal/test_selector_py3.py | 59 ++ .../providers/traversal/test_singleton_py3.py | 77 ++ .../providers/traversal/test_traverse_py3.py | 40 + 22 files changed, 932 insertions(+), 869 deletions(-) delete mode 100644 tests/unit/providers/test_traversal_py3.py create mode 100644 tests/unit/providers/traversal/__init__.py create mode 100644 tests/unit/providers/traversal/test_attribute_getter_py3.py create mode 100644 tests/unit/providers/traversal/test_callable_py3.py create mode 100644 tests/unit/providers/traversal/test_configuration_py3.py create mode 100644 tests/unit/providers/traversal/test_container_py3.py create mode 100644 tests/unit/providers/traversal/test_delegate_py3.py create mode 100644 tests/unit/providers/traversal/test_dependencies_container_py3.py create mode 100644 tests/unit/providers/traversal/test_dependency_py3.py create mode 100644 tests/unit/providers/traversal/test_dict_py3.py create mode 100644 tests/unit/providers/traversal/test_factory_aggregate_py3.py create mode 100644 tests/unit/providers/traversal/test_factory_py3.py create mode 100644 tests/unit/providers/traversal/test_item_getter_py3.py create mode 100644 tests/unit/providers/traversal/test_list_py3.py create mode 100644 tests/unit/providers/traversal/test_method_caller_py3.py create mode 100644 tests/unit/providers/traversal/test_object_py3.py create mode 100644 tests/unit/providers/traversal/test_provided_instance_py3.py create mode 100644 tests/unit/providers/traversal/test_provider_py3.py create mode 100644 tests/unit/providers/traversal/test_resource_py3.py create mode 100644 tests/unit/providers/traversal/test_selector_py3.py create mode 100644 tests/unit/providers/traversal/test_singleton_py3.py create mode 100644 tests/unit/providers/traversal/test_traverse_py3.py diff --git a/tests/unit/providers/test_traversal_py3.py b/tests/unit/providers/test_traversal_py3.py deleted file mode 100644 index cb2b7970..00000000 --- a/tests/unit/providers/test_traversal_py3.py +++ /dev/null @@ -1,869 +0,0 @@ -import unittest - -from dependency_injector import containers, providers - - -class TraverseTests(unittest.TestCase): - - def test_traverse_cycled_graph(self): - provider1 = providers.Provider() - - provider2 = providers.Provider() - provider2.override(provider1) - - provider3 = providers.Provider() - provider3.override(provider2) - - provider1.override(provider3) # Cycle: provider3 -> provider2 -> provider1 -> provider3 - - all_providers = list(providers.traverse(provider1)) - - assert len(all_providers) == 3 - assert provider1 in all_providers - assert provider2 in all_providers - assert provider3 in all_providers - - def test_traverse_types_filtering(self): - provider1 = providers.Resource(dict) - provider2 = providers.Resource(dict) - provider3 = providers.Provider() - - provider = providers.Provider() - - provider.override(provider1) - provider.override(provider2) - provider.override(provider3) - - all_providers = list(providers.traverse(provider, types=[providers.Resource])) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - -class ProviderTests(unittest.TestCase): - - def test_traversal_overriding(self): - provider1 = providers.Provider() - provider2 = providers.Provider() - provider3 = providers.Provider() - - provider = providers.Provider() - - provider.override(provider1) - provider.override(provider2) - provider.override(provider3) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 3 - assert provider1 in all_providers - assert provider2 in all_providers - assert provider3 in all_providers - - def test_traversal_overriding_nested(self): - provider1 = providers.Provider() - - provider2 = providers.Provider() - provider2.override(provider1) - - provider3 = providers.Provider() - provider3.override(provider2) - - provider = providers.Provider() - provider.override(provider3) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 3 - assert provider1 in all_providers - assert provider2 in all_providers - assert provider3 in all_providers - - def test_traverse_types_filtering(self): - provider1 = providers.Resource(dict) - provider2 = providers.Resource(dict) - provider3 = providers.Provider() - - provider = providers.Provider() - - provider.override(provider1) - provider.override(provider2) - provider.override(provider3) - - all_providers = list(provider.traverse(types=[providers.Resource])) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - -class ObjectTests(unittest.TestCase): - - def test_traversal(self): - provider = providers.Object("string") - all_providers = list(provider.traverse()) - assert len(all_providers) == 0 - - def test_traversal_provider(self): - another_provider = providers.Provider() - provider = providers.Object(another_provider) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 1 - assert another_provider in all_providers - - def test_traversal_provider_and_overriding(self): - another_provider_1 = providers.Provider() - another_provider_2 = providers.Provider() - another_provider_3 = providers.Provider() - - provider = providers.Object(another_provider_1) - - provider.override(another_provider_2) - provider.override(another_provider_3) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 3 - assert another_provider_1 in all_providers - assert another_provider_2 in all_providers - assert another_provider_3 in all_providers - - -class DelegateTests(unittest.TestCase): - - def test_traversal_provider(self): - another_provider = providers.Provider() - provider = providers.Delegate(another_provider) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 1 - assert another_provider in all_providers - - def test_traversal_provider_and_overriding(self): - provider1 = providers.Provider() - provider2 = providers.Provider() - - provider3 = providers.Provider() - provider3.override(provider2) - - provider = providers.Delegate(provider1) - - provider.override(provider3) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 3 - assert provider1 in all_providers - assert provider2 in all_providers - assert provider3 in all_providers - - -class DependencyTests(unittest.TestCase): - - def test_traversal(self): - provider = providers.Dependency() - all_providers = list(provider.traverse()) - assert len(all_providers) == 0 - - def test_traversal_default(self): - another_provider = providers.Provider() - provider = providers.Dependency(default=another_provider) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 1 - assert another_provider in all_providers - - def test_traversal_overriding(self): - provider1 = providers.Provider() - - provider2 = providers.Provider() - provider2.override(provider1) - - provider = providers.Dependency() - provider.override(provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - -class DependenciesContainerTests(unittest.TestCase): - - def test_traversal(self): - provider = providers.DependenciesContainer() - all_providers = list(provider.traverse()) - assert len(all_providers) == 0 - - def test_traversal_default(self): - another_provider = providers.Provider() - provider = providers.DependenciesContainer(default=another_provider) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 1 - assert another_provider in all_providers - - def test_traversal_fluent_interface(self): - provider = providers.DependenciesContainer() - provider1 = provider.provider1 - provider2 = provider.provider2 - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traversal_overriding(self): - provider1 = providers.Provider() - provider2 = providers.Provider() - provider3 = providers.DependenciesContainer( - provider1=provider1, - provider2=provider2, - ) - - provider = providers.DependenciesContainer() - provider.override(provider3) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 5 - assert provider1 in all_providers - assert provider2 in all_providers - assert provider3 in all_providers - assert provider.provider1 in all_providers - assert provider.provider2 in all_providers - - -class CallableTests(unittest.TestCase): - - def test_traverse(self): - provider = providers.Callable(dict) - all_providers = list(provider.traverse()) - assert len(all_providers) == 0 - - def test_traverse_args(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Callable(list, "foo", provider1, provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_kwargs(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Callable(dict, foo="foo", bar=provider1, baz=provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_overridden(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - - provider = providers.Callable(dict, "foo") - provider.override(provider1) - provider.override(provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_provides(self): - provider1 = providers.Callable(list) - provider2 = providers.Object("bar") - provider3 = providers.Object("baz") - - provider = providers.Callable(provider1, provider2) - provider.override(provider3) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 3 - assert provider1 in all_providers - assert provider2 in all_providers - assert provider3 in all_providers - - -class ConfigurationTests(unittest.TestCase): - - def test_traverse(self): - config = providers.Configuration(default={"option1": {"option2": "option2"}}) - option1 = config.option1 - option2 = config.option1.option2 - option3 = config.option1[config.option1.option2] - - all_providers = list(config.traverse()) - - assert len(all_providers) == 3 - assert option1 in all_providers - assert option2 in all_providers - assert option3 in all_providers - - def test_traverse_typed(self): - config = providers.Configuration() - option = config.option - typed_option = config.option.as_int() - - all_providers = list(typed_option.traverse()) - - assert len(all_providers) == 1 - assert option in all_providers - - def test_traverse_overridden(self): - options = {"option1": {"option2": "option2"}} - config = providers.Configuration() - config.from_dict(options) - - all_providers = list(config.traverse()) - - assert len(all_providers) == 1 - overridden, = all_providers - assert overridden() == options - assert overridden is config.last_overriding - - def test_traverse_overridden_option_1(self): - options = {"option2": "option2"} - config = providers.Configuration() - config.option1.from_dict(options) - - all_providers = list(config.traverse()) - - assert len(all_providers) == 2 - assert config.option1 in all_providers - assert config.last_overriding in all_providers - - def test_traverse_overridden_option_2(self): - options = {"option2": "option2"} - config = providers.Configuration() - config.option1.from_dict(options) - - all_providers = list(config.option1.traverse()) - - assert len(all_providers) == 0 - - -class FactoryTests(unittest.TestCase): - - def test_traverse(self): - provider = providers.Factory(dict) - all_providers = list(provider.traverse()) - assert len(all_providers) == 0 - - def test_traverse_args(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Factory(list, "foo", provider1, provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_kwargs(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Factory(dict, foo="foo", bar=provider1, baz=provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_attributes(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Factory(dict) - provider.add_attributes(foo="foo", bar=provider1, baz=provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_overridden(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - - provider = providers.Factory(dict, "foo") - provider.override(provider1) - provider.override(provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_provides(self): - provider1 = providers.Callable(list) - provider2 = providers.Object("bar") - provider3 = providers.Object("baz") - - provider = providers.Factory(provider1, provider2) - provider.override(provider3) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 3 - assert provider1 in all_providers - assert provider2 in all_providers - assert provider3 in all_providers - - -class FactoryAggregateTests(unittest.TestCase): - - def test_traverse(self): - factory1 = providers.Factory(dict) - factory2 = providers.Factory(list) - provider = providers.FactoryAggregate(factory1=factory1, factory2=factory2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert factory1 in all_providers - assert factory2 in all_providers - - -class BaseSingletonTests(unittest.TestCase): - - def test_traverse(self): - provider = providers.Singleton(dict) - all_providers = list(provider.traverse()) - assert len(all_providers) == 0 - - def test_traverse_args(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Singleton(list, "foo", provider1, provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_kwargs(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Singleton(dict, foo="foo", bar=provider1, baz=provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_attributes(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Singleton(dict) - provider.add_attributes(foo="foo", bar=provider1, baz=provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_overridden(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - - provider = providers.Singleton(dict, "foo") - provider.override(provider1) - provider.override(provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_provides(self): - provider1 = providers.Callable(list) - provider2 = providers.Object("bar") - provider3 = providers.Object("baz") - - provider = providers.Singleton(provider1, provider2) - provider.override(provider3) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 3 - assert provider1 in all_providers - assert provider2 in all_providers - assert provider3 in all_providers - - -class ListTests(unittest.TestCase): - - def test_traverse_args(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.List("foo", provider1, provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_overridden(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider3 = providers.List(provider1, provider2) - - provider = providers.List("foo") - provider.override(provider3) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 3 - assert provider1 in all_providers - assert provider2 in all_providers - assert provider3 in all_providers - - -class DictTests(unittest.TestCase): - - def test_traverse_kwargs(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Dict(foo="foo", bar=provider1, baz=provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_overridden(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider3 = providers.Dict(bar=provider1, baz=provider2) - - provider = providers.Dict(foo="foo") - provider.override(provider3) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 3 - assert provider1 in all_providers - assert provider2 in all_providers - assert provider3 in all_providers - - -class ResourceTests(unittest.TestCase): - - def test_traverse(self): - provider = providers.Resource(dict) - all_providers = list(provider.traverse()) - assert len(all_providers) == 0 - - def test_traverse_args(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Resource(list, "foo", provider1, provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_kwargs(self): - provider1 = providers.Object("bar") - provider2 = providers.Object("baz") - provider = providers.Resource(dict, foo="foo", bar=provider1, baz=provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_overridden(self): - provider1 = providers.Resource(list) - provider2 = providers.Resource(tuple) - - provider = providers.Resource(dict, "foo") - provider.override(provider1) - provider.override(provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_provides(self): - provider1 = providers.Callable(list) - - provider = providers.Resource(provider1) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 1 - assert provider1 in all_providers - - -class ContainerTests(unittest.TestCase): - - def test_traverse(self): - class Container(containers.DeclarativeContainer): - provider1 = providers.Callable(list) - provider2 = providers.Callable(dict) - - provider = providers.Container(Container) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert {list, dict} == {provider.provides for provider in all_providers} - - def test_traverse_overridden(self): - class Container1(containers.DeclarativeContainer): - provider1 = providers.Callable(list) - provider2 = providers.Callable(dict) - - class Container2(containers.DeclarativeContainer): - provider1 = providers.Callable(tuple) - provider2 = providers.Callable(str) - - container2 = Container2() - - provider = providers.Container(Container1) - provider.override(container2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 5 - assert {list, dict, tuple, str} == { - provider.provides - for provider in all_providers - if isinstance(provider, providers.Callable) - } - assert provider.last_overriding in all_providers - assert provider.last_overriding() is container2 - - -class SelectorTests(unittest.TestCase): - - def test_traverse(self): - switch = lambda: "provider1" - provider1 = providers.Callable(list) - provider2 = providers.Callable(dict) - - provider = providers.Selector( - switch, - provider1=provider1, - provider2=provider2, - ) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_switch(self): - switch = providers.Callable(lambda: "provider1") - provider1 = providers.Callable(list) - provider2 = providers.Callable(dict) - - provider = providers.Selector( - switch, - provider1=provider1, - provider2=provider2, - ) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 3 - assert switch in all_providers - assert provider1 in all_providers - assert provider2 in all_providers - - def test_traverse_overridden(self): - provider1 = providers.Callable(list) - provider2 = providers.Callable(dict) - selector1 = providers.Selector(lambda: "provider1", provider1=provider1) - - provider = providers.Selector( - lambda: "provider2", - provider2=provider2, - ) - provider.override(selector1) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 3 - assert provider1 in all_providers - assert provider2 in all_providers - assert selector1 in all_providers - - -class ProvidedInstanceTests(unittest.TestCase): - - def test_traverse(self): - provider1 = providers.Provider() - provider = provider1.provided - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 1 - assert provider1 in all_providers - - def test_traverse_overridden(self): - provider1 = providers.Provider() - provider2 = providers.Provider() - - provider = provider1.provided - provider.override(provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provider2 in all_providers - - -class AttributeGetterTests(unittest.TestCase): - - def test_traverse(self): - provider1 = providers.Provider() - provided = provider1.provided - provider = provided.attr - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provided in all_providers - - def test_traverse_overridden(self): - provider1 = providers.Provider() - provided = provider1.provided - provider2 = providers.Provider() - - provider = provided.attr - provider.override(provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 3 - assert provider1 in all_providers - assert provider2 in all_providers - assert provided in all_providers - - -class ItemGetterTests(unittest.TestCase): - - def test_traverse(self): - provider1 = providers.Provider() - provided = provider1.provided - provider = provided["item"] - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 2 - assert provider1 in all_providers - assert provided in all_providers - - def test_traverse_overridden(self): - provider1 = providers.Provider() - provided = provider1.provided - provider2 = providers.Provider() - - provider = provided["item"] - provider.override(provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 3 - assert provider1 in all_providers - assert provider2 in all_providers - assert provided in all_providers - - -class MethodCallerTests(unittest.TestCase): - - def test_traverse(self): - provider1 = providers.Provider() - provided = provider1.provided - method = provided.method - provider = method.call() - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 3 - assert provider1 in all_providers - assert provided in all_providers - assert method in all_providers - - def test_traverse_args(self): - provider1 = providers.Provider() - provided = provider1.provided - method = provided.method - provider2 = providers.Provider() - provider = method.call("foo", provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 4 - assert provider1 in all_providers - assert provider2 in all_providers - assert provided in all_providers - assert method in all_providers - - def test_traverse_kwargs(self): - provider1 = providers.Provider() - provided = provider1.provided - method = provided.method - provider2 = providers.Provider() - provider = method.call(foo="foo", bar=provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 4 - assert provider1 in all_providers - assert provider2 in all_providers - assert provided in all_providers - assert method in all_providers - - def test_traverse_overridden(self): - provider1 = providers.Provider() - provided = provider1.provided - method = provided.method - provider2 = providers.Provider() - - provider = method.call() - provider.override(provider2) - - all_providers = list(provider.traverse()) - - assert len(all_providers) == 4 - assert provider1 in all_providers - assert provider2 in all_providers - assert provided in all_providers - assert method in all_providers diff --git a/tests/unit/providers/traversal/__init__.py b/tests/unit/providers/traversal/__init__.py new file mode 100644 index 00000000..c95aa224 --- /dev/null +++ b/tests/unit/providers/traversal/__init__.py @@ -0,0 +1 @@ +"""Traversal tests.""" diff --git a/tests/unit/providers/traversal/test_attribute_getter_py3.py b/tests/unit/providers/traversal/test_attribute_getter_py3.py new file mode 100644 index 00000000..27db8584 --- /dev/null +++ b/tests/unit/providers/traversal/test_attribute_getter_py3.py @@ -0,0 +1,31 @@ +"""AttributeGetter provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider1 = providers.Provider() + provided = provider1.provided + provider = provided.attr + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provided in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Provider() + provided = provider1.provided + provider2 = providers.Provider() + + provider = provided.attr + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provided in all_providers diff --git a/tests/unit/providers/traversal/test_callable_py3.py b/tests/unit/providers/traversal/test_callable_py3.py new file mode 100644 index 00000000..aebf603a --- /dev/null +++ b/tests/unit/providers/traversal/test_callable_py3.py @@ -0,0 +1,64 @@ +"""Callable provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider = providers.Callable(dict) + all_providers = list(provider.traverse()) + assert len(all_providers) == 0 + + +def test_traverse_args(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Callable(list, "foo", provider1, provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_kwargs(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Callable(dict, foo="foo", bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + + provider = providers.Callable(dict, "foo") + provider.override(provider1) + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_provides(): + provider1 = providers.Callable(list) + provider2 = providers.Object("bar") + provider3 = providers.Object("baz") + + provider = providers.Callable(provider1, provider2) + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers diff --git a/tests/unit/providers/traversal/test_configuration_py3.py b/tests/unit/providers/traversal/test_configuration_py3.py new file mode 100644 index 00000000..8b68f673 --- /dev/null +++ b/tests/unit/providers/traversal/test_configuration_py3.py @@ -0,0 +1,63 @@ +"""Configuration provider tests.""" + +from dependency_injector import providers + + +def test_traverse(): + config = providers.Configuration(default={"option1": {"option2": "option2"}}) + option1 = config.option1 + option2 = config.option1.option2 + option3 = config.option1[config.option1.option2] + + all_providers = list(config.traverse()) + + assert len(all_providers) == 3 + assert option1 in all_providers + assert option2 in all_providers + assert option3 in all_providers + + +def test_traverse_typed(): + config = providers.Configuration() + option = config.option + typed_option = config.option.as_int() + + all_providers = list(typed_option.traverse()) + + assert len(all_providers) == 1 + assert option in all_providers + + +def test_traverse_overridden(): + options = {"option1": {"option2": "option2"}} + config = providers.Configuration() + config.from_dict(options) + + all_providers = list(config.traverse()) + + assert len(all_providers) == 1 + overridden, = all_providers + assert overridden() == options + assert overridden is config.last_overriding + + +def test_traverse_overridden_option_1(): + options = {"option2": "option2"} + config = providers.Configuration() + config.option1.from_dict(options) + + all_providers = list(config.traverse()) + + assert len(all_providers) == 2 + assert config.option1 in all_providers + assert config.last_overriding in all_providers + + +def test_traverse_overridden_option_2(): + options = {"option2": "option2"} + config = providers.Configuration() + config.option1.from_dict(options) + + all_providers = list(config.option1.traverse()) + + assert len(all_providers) == 0 diff --git a/tests/unit/providers/traversal/test_container_py3.py b/tests/unit/providers/traversal/test_container_py3.py new file mode 100644 index 00000000..ebc6cca6 --- /dev/null +++ b/tests/unit/providers/traversal/test_container_py3.py @@ -0,0 +1,42 @@ +"""Container provider traversal tests.""" + +from dependency_injector import containers, providers + + +def test_traverse(): + class Container(containers.DeclarativeContainer): + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + + provider = providers.Container(Container) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert {list, dict} == {provider.provides for provider in all_providers} + + +def test_traverse_overridden(): + class Container1(containers.DeclarativeContainer): + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + + class Container2(containers.DeclarativeContainer): + provider1 = providers.Callable(tuple) + provider2 = providers.Callable(str) + + container2 = Container2() + + provider = providers.Container(Container1) + provider.override(container2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 5 + assert {list, dict, tuple, str} == { + provider.provides + for provider in all_providers + if isinstance(provider, providers.Callable) + } + assert provider.last_overriding in all_providers + assert provider.last_overriding() is container2 diff --git a/tests/unit/providers/traversal/test_delegate_py3.py b/tests/unit/providers/traversal/test_delegate_py3.py new file mode 100644 index 00000000..c251ec2f --- /dev/null +++ b/tests/unit/providers/traversal/test_delegate_py3.py @@ -0,0 +1,32 @@ +"""Delegate provider traversal tests.""" + +from dependency_injector import providers + + +def test_traversal_provider(): + another_provider = providers.Provider() + provider = providers.Delegate(another_provider) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 1 + assert another_provider in all_providers + + +def test_traversal_provider_and_overriding(): + provider1 = providers.Provider() + provider2 = providers.Provider() + + provider3 = providers.Provider() + provider3.override(provider2) + + provider = providers.Delegate(provider1) + + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers diff --git a/tests/unit/providers/traversal/test_dependencies_container_py3.py b/tests/unit/providers/traversal/test_dependencies_container_py3.py new file mode 100644 index 00000000..38d8b72a --- /dev/null +++ b/tests/unit/providers/traversal/test_dependencies_container_py3.py @@ -0,0 +1,52 @@ +"""DependenciesContainer provider traversal tests.""" + +from dependency_injector import providers + + +def test_traversal(): + provider = providers.DependenciesContainer() + all_providers = list(provider.traverse()) + assert len(all_providers) == 0 + + +def test_traversal_default(): + another_provider = providers.Provider() + provider = providers.DependenciesContainer(default=another_provider) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 1 + assert another_provider in all_providers + + +def test_traversal_fluent_interface(): + provider = providers.DependenciesContainer() + provider1 = provider.provider1 + provider2 = provider.provider2 + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traversal_overriding(): + provider1 = providers.Provider() + provider2 = providers.Provider() + provider3 = providers.DependenciesContainer( + provider1=provider1, + provider2=provider2, + ) + + provider = providers.DependenciesContainer() + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 5 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers + assert provider.provider1 in all_providers + assert provider.provider2 in all_providers diff --git a/tests/unit/providers/traversal/test_dependency_py3.py b/tests/unit/providers/traversal/test_dependency_py3.py new file mode 100644 index 00000000..4939a135 --- /dev/null +++ b/tests/unit/providers/traversal/test_dependency_py3.py @@ -0,0 +1,35 @@ +"""Dependency provider traversal tests.""" + +from dependency_injector import providers + + +def test_traversal(): + provider = providers.Dependency() + all_providers = list(provider.traverse()) + assert len(all_providers) == 0 + + +def test_traversal_default(): + another_provider = providers.Provider() + provider = providers.Dependency(default=another_provider) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 1 + assert another_provider in all_providers + + +def test_traversal_overriding(): + provider1 = providers.Provider() + + provider2 = providers.Provider() + provider2.override(provider1) + + provider = providers.Dependency() + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers diff --git a/tests/unit/providers/traversal/test_dict_py3.py b/tests/unit/providers/traversal/test_dict_py3.py new file mode 100644 index 00000000..9469d48f --- /dev/null +++ b/tests/unit/providers/traversal/test_dict_py3.py @@ -0,0 +1,31 @@ +"""Dict provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse_kwargs(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Dict(foo="foo", bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider3 = providers.Dict(bar=provider1, baz=provider2) + + provider = providers.Dict(foo="foo") + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers diff --git a/tests/unit/providers/traversal/test_factory_aggregate_py3.py b/tests/unit/providers/traversal/test_factory_aggregate_py3.py new file mode 100644 index 00000000..54bd8f6a --- /dev/null +++ b/tests/unit/providers/traversal/test_factory_aggregate_py3.py @@ -0,0 +1,15 @@ +"""FactoryAggregate provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + factory1 = providers.Factory(dict) + factory2 = providers.Factory(list) + provider = providers.FactoryAggregate(factory1=factory1, factory2=factory2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert factory1 in all_providers + assert factory2 in all_providers diff --git a/tests/unit/providers/traversal/test_factory_py3.py b/tests/unit/providers/traversal/test_factory_py3.py new file mode 100644 index 00000000..57bdf25e --- /dev/null +++ b/tests/unit/providers/traversal/test_factory_py3.py @@ -0,0 +1,77 @@ +"""Factory provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider = providers.Factory(dict) + all_providers = list(provider.traverse()) + assert len(all_providers) == 0 + + +def test_traverse_args(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Factory(list, "foo", provider1, provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_kwargs(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Factory(dict, foo="foo", bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_attributes(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Factory(dict) + provider.add_attributes(foo="foo", bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + + provider = providers.Factory(dict, "foo") + provider.override(provider1) + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_provides(): + provider1 = providers.Callable(list) + provider2 = providers.Object("bar") + provider3 = providers.Object("baz") + + provider = providers.Factory(provider1, provider2) + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers diff --git a/tests/unit/providers/traversal/test_item_getter_py3.py b/tests/unit/providers/traversal/test_item_getter_py3.py new file mode 100644 index 00000000..5629b536 --- /dev/null +++ b/tests/unit/providers/traversal/test_item_getter_py3.py @@ -0,0 +1,31 @@ +"""ItemGetter provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider1 = providers.Provider() + provided = provider1.provided + provider = provided["item"] + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provided in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Provider() + provided = provider1.provided + provider2 = providers.Provider() + + provider = provided["item"] + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provided in all_providers diff --git a/tests/unit/providers/traversal/test_list_py3.py b/tests/unit/providers/traversal/test_list_py3.py new file mode 100644 index 00000000..da6b3b74 --- /dev/null +++ b/tests/unit/providers/traversal/test_list_py3.py @@ -0,0 +1,31 @@ +"""List provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse_args(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.List("foo", provider1, provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider3 = providers.List(provider1, provider2) + + provider = providers.List("foo") + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers diff --git a/tests/unit/providers/traversal/test_method_caller_py3.py b/tests/unit/providers/traversal/test_method_caller_py3.py new file mode 100644 index 00000000..47bbfbab --- /dev/null +++ b/tests/unit/providers/traversal/test_method_caller_py3.py @@ -0,0 +1,67 @@ +"""MethodCaller provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider1 = providers.Provider() + provided = provider1.provided + method = provided.method + provider = method.call() + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provided in all_providers + assert method in all_providers + + +def test_traverse_args(): + provider1 = providers.Provider() + provided = provider1.provided + method = provided.method + provider2 = providers.Provider() + provider = method.call("foo", provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 4 + assert provider1 in all_providers + assert provider2 in all_providers + assert provided in all_providers + assert method in all_providers + + +def test_traverse_kwargs(): + provider1 = providers.Provider() + provided = provider1.provided + method = provided.method + provider2 = providers.Provider() + provider = method.call(foo="foo", bar=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 4 + assert provider1 in all_providers + assert provider2 in all_providers + assert provided in all_providers + assert method in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Provider() + provided = provider1.provided + method = provided.method + provider2 = providers.Provider() + + provider = method.call() + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 4 + assert provider1 in all_providers + assert provider2 in all_providers + assert provided in all_providers + assert method in all_providers diff --git a/tests/unit/providers/traversal/test_object_py3.py b/tests/unit/providers/traversal/test_object_py3.py new file mode 100644 index 00000000..6c55b93b --- /dev/null +++ b/tests/unit/providers/traversal/test_object_py3.py @@ -0,0 +1,37 @@ +"""Object provider traversal tests.""" + +from dependency_injector import providers + + +def test_traversal(): + provider = providers.Object("string") + all_providers = list(provider.traverse()) + assert len(all_providers) == 0 + + +def test_traversal_provider(): + another_provider = providers.Provider() + provider = providers.Object(another_provider) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 1 + assert another_provider in all_providers + + +def test_traversal_provider_and_overriding(): + another_provider_1 = providers.Provider() + another_provider_2 = providers.Provider() + another_provider_3 = providers.Provider() + + provider = providers.Object(another_provider_1) + + provider.override(another_provider_2) + provider.override(another_provider_3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert another_provider_1 in all_providers + assert another_provider_2 in all_providers + assert another_provider_3 in all_providers diff --git a/tests/unit/providers/traversal/test_provided_instance_py3.py b/tests/unit/providers/traversal/test_provided_instance_py3.py new file mode 100644 index 00000000..8e13dbf4 --- /dev/null +++ b/tests/unit/providers/traversal/test_provided_instance_py3.py @@ -0,0 +1,27 @@ +"""ProvidedInstance provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider1 = providers.Provider() + provider = provider1.provided + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 1 + assert provider1 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Provider() + provider2 = providers.Provider() + + provider = provider1.provided + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers diff --git a/tests/unit/providers/traversal/test_provider_py3.py b/tests/unit/providers/traversal/test_provider_py3.py new file mode 100644 index 00000000..748709f9 --- /dev/null +++ b/tests/unit/providers/traversal/test_provider_py3.py @@ -0,0 +1,60 @@ +"""Provider traversal tests.""" + +from dependency_injector import providers + + +def test_traversal_overriding(): + provider1 = providers.Provider() + provider2 = providers.Provider() + provider3 = providers.Provider() + + provider = providers.Provider() + + provider.override(provider1) + provider.override(provider2) + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers + + +def test_traversal_overriding_nested(): + provider1 = providers.Provider() + + provider2 = providers.Provider() + provider2.override(provider1) + + provider3 = providers.Provider() + provider3.override(provider2) + + provider = providers.Provider() + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers + + +def test_traverse_types_filtering(): + provider1 = providers.Resource(dict) + provider2 = providers.Resource(dict) + provider3 = providers.Provider() + + provider = providers.Provider() + + provider.override(provider1) + provider.override(provider2) + provider.override(provider3) + + all_providers = list(provider.traverse(types=[providers.Resource])) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers diff --git a/tests/unit/providers/traversal/test_resource_py3.py b/tests/unit/providers/traversal/test_resource_py3.py new file mode 100644 index 00000000..b4a1179c --- /dev/null +++ b/tests/unit/providers/traversal/test_resource_py3.py @@ -0,0 +1,60 @@ +"""Resource provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider = providers.Resource(dict) + all_providers = list(provider.traverse()) + assert len(all_providers) == 0 + + +def test_traverse_args(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Resource(list, "foo", provider1, provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_kwargs(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Resource(dict, foo="foo", bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Resource(list) + provider2 = providers.Resource(tuple) + + provider = providers.Resource(dict, "foo") + provider.override(provider1) + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_provides(): + provider1 = providers.Callable(list) + + provider = providers.Resource(provider1) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 1 + assert provider1 in all_providers + diff --git a/tests/unit/providers/traversal/test_selector_py3.py b/tests/unit/providers/traversal/test_selector_py3.py new file mode 100644 index 00000000..bd345076 --- /dev/null +++ b/tests/unit/providers/traversal/test_selector_py3.py @@ -0,0 +1,59 @@ +"""Selector provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + switch = lambda: "provider1" + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + + provider = providers.Selector( + switch, + provider1=provider1, + provider2=provider2, + ) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_switch(): + switch = providers.Callable(lambda: "provider1") + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + + provider = providers.Selector( + switch, + provider1=provider1, + provider2=provider2, + ) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert switch in all_providers + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Callable(list) + provider2 = providers.Callable(dict) + selector1 = providers.Selector(lambda: "provider1", provider1=provider1) + + provider = providers.Selector( + lambda: "provider2", + provider2=provider2, + ) + provider.override(selector1) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert selector1 in all_providers diff --git a/tests/unit/providers/traversal/test_singleton_py3.py b/tests/unit/providers/traversal/test_singleton_py3.py new file mode 100644 index 00000000..78240732 --- /dev/null +++ b/tests/unit/providers/traversal/test_singleton_py3.py @@ -0,0 +1,77 @@ +"""Singleton provider traversal tests.""" + +from dependency_injector import providers + + +def test_traverse(): + provider = providers.Singleton(dict) + all_providers = list(provider.traverse()) + assert len(all_providers) == 0 + + +def test_traverse_args(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Singleton(list, "foo", provider1, provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_kwargs(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Singleton(dict, foo="foo", bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_attributes(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + provider = providers.Singleton(dict) + provider.add_attributes(foo="foo", bar=provider1, baz=provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_overridden(): + provider1 = providers.Object("bar") + provider2 = providers.Object("baz") + + provider = providers.Singleton(dict, "foo") + provider.override(provider1) + provider.override(provider2) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers + + +def test_traverse_provides(): + provider1 = providers.Callable(list) + provider2 = providers.Object("bar") + provider3 = providers.Object("baz") + + provider = providers.Singleton(provider1, provider2) + provider.override(provider3) + + all_providers = list(provider.traverse()) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers diff --git a/tests/unit/providers/traversal/test_traverse_py3.py b/tests/unit/providers/traversal/test_traverse_py3.py new file mode 100644 index 00000000..01797b4a --- /dev/null +++ b/tests/unit/providers/traversal/test_traverse_py3.py @@ -0,0 +1,40 @@ +"""Provider's traversal tests.""" + +from dependency_injector import providers + + +def test_traverse_cycled_graph(): + provider1 = providers.Provider() + + provider2 = providers.Provider() + provider2.override(provider1) + + provider3 = providers.Provider() + provider3.override(provider2) + + provider1.override(provider3) # Cycle: provider3 -> provider2 -> provider1 -> provider3 + + all_providers = list(providers.traverse(provider1)) + + assert len(all_providers) == 3 + assert provider1 in all_providers + assert provider2 in all_providers + assert provider3 in all_providers + + +def test_traverse_types_filtering(): + provider1 = providers.Resource(dict) + provider2 = providers.Resource(dict) + provider3 = providers.Provider() + + provider = providers.Provider() + + provider.override(provider1) + provider.override(provider2) + provider.override(provider3) + + all_providers = list(providers.traverse(provider, types=[providers.Resource])) + + assert len(all_providers) == 2 + assert provider1 in all_providers + assert provider2 in all_providers