"""Tests for provider graph traversal.

Covers the module-level ``providers.traverse()`` function and the
``traverse()`` method of each provider type.
"""

import unittest

from dependency_injector import containers, providers


class TraverseTests(unittest.TestCase):
    """Tests for the module-level ``providers.traverse()`` function."""

    def test_traverse_cycled_graph(self):
        provider1 = providers.Provider()

        provider2 = providers.Provider()
        provider2.override(provider1)

        provider3 = providers.Provider()
        provider3.override(provider2)

        provider1.override(provider3)  # Cycle: provider3 -> provider2 -> provider1 -> provider3

        all_providers = list(providers.traverse(provider1))

        # Each provider in the cycle is expected exactly once.
        assert len(all_providers) == 3
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provider3 in all_providers

    def test_traverse_types_filtering(self):
        provider1 = providers.Resource(dict)
        provider2 = providers.Resource(dict)
        provider3 = providers.Provider()

        provider = providers.Provider()

        provider.override(provider1)
        provider.override(provider2)
        provider.override(provider3)

        all_providers = list(providers.traverse(provider, types=[providers.Resource]))

        # provider3 is a plain Provider, so the ``types`` filter excludes it.
        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers


class ProviderTests(unittest.TestCase):
    """Traversal of base ``Provider`` overriding chains."""

    def test_traversal_overriding(self):
        provider1 = providers.Provider()
        provider2 = providers.Provider()
        provider3 = providers.Provider()

        provider = providers.Provider()

        provider.override(provider1)
        provider.override(provider2)
        provider.override(provider3)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 3
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provider3 in all_providers

    def test_traversal_overriding_nested(self):
        provider1 = providers.Provider()

        provider2 = providers.Provider()
        provider2.override(provider1)

        provider3 = providers.Provider()
        provider3.override(provider2)

        provider = providers.Provider()
        provider.override(provider3)

        all_providers = list(provider.traverse())

        # Overrides of overrides are reached transitively.
        assert len(all_providers) == 3
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provider3 in all_providers

    def test_traverse_types_filtering(self):
        provider1 = providers.Resource(dict)
        provider2 = providers.Resource(dict)
        provider3 = providers.Provider()

        provider = providers.Provider()

        provider.override(provider1)
        provider.override(provider2)
        provider.override(provider3)

        all_providers = list(provider.traverse(types=[providers.Resource]))

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers


class ObjectTests(unittest.TestCase):
    """Traversal of ``providers.Object``: provider-valued objects are yielded, plain values are not."""

    def test_traversal(self):
        provider = providers.Object("string")
        all_providers = list(provider.traverse())
        assert len(all_providers) == 0

    def test_traversal_provider(self):
        another_provider = providers.Provider()
        provider = providers.Object(another_provider)
        all_providers = list(provider.traverse())

        assert len(all_providers) == 1
        assert another_provider in all_providers

    def test_traversal_provider_and_overriding(self):
        another_provider_1 = providers.Provider()
        another_provider_2 = providers.Provider()
        another_provider_3 = providers.Provider()

        provider = providers.Object(another_provider_1)

        provider.override(another_provider_2)
        provider.override(another_provider_3)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 3
        assert another_provider_1 in all_providers
        assert another_provider_2 in all_providers
        assert another_provider_3 in all_providers


class DelegateTests(unittest.TestCase):
    """Traversal of ``providers.Delegate``."""

    def test_traversal_provider(self):
        another_provider = providers.Provider()
        provider = providers.Delegate(another_provider)
        all_providers = list(provider.traverse())

        assert len(all_providers) == 1
        assert another_provider in all_providers

    def test_traversal_provider_and_overriding(self):
        provider1 = providers.Provider()
        provider2 = providers.Provider()

        provider3 = providers.Provider()
        provider3.override(provider2)

        provider = providers.Delegate(provider1)

        provider.override(provider3)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 3
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provider3 in all_providers


class DependencyTests(unittest.TestCase):
    """Traversal of ``providers.Dependency`` (default and overriding providers)."""

    def test_traversal(self):
        provider = providers.Dependency()
        all_providers = list(provider.traverse())
        assert len(all_providers) == 0

    def test_traversal_default(self):
        another_provider = providers.Provider()
        provider = providers.Dependency(default=another_provider)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 1
        assert another_provider in all_providers

    def test_traversal_overriding(self):
        provider1 = providers.Provider()

        provider2 = providers.Provider()
        provider2.override(provider1)

        provider = providers.Dependency()
        provider.override(provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers


class DependenciesContainerTests(unittest.TestCase):
    """Traversal of ``providers.DependenciesContainer``, including fluent-interface attributes."""

    def test_traversal(self):
        provider = providers.DependenciesContainer()
        all_providers = list(provider.traverse())
        assert len(all_providers) == 0

    def test_traversal_default(self):
        another_provider = providers.Provider()
        provider = providers.DependenciesContainer(default=another_provider)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 1
        assert another_provider in all_providers

    def test_traversal_fluent_interface(self):
        provider = providers.DependenciesContainer()
        # Attribute access creates the child providers that traversal should report.
        provider1 = provider.provider1
        provider2 = provider.provider2

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traversal_overriding(self):
        provider1 = providers.Provider()
        provider2 = providers.Provider()
        provider3 = providers.DependenciesContainer(
            provider1=provider1,
            provider2=provider2,
        )

        provider = providers.DependenciesContainer()
        provider.override(provider3)

        all_providers = list(provider.traverse())

        # Expected: provider1, provider2, the overriding container provider3,
        # and the two attributes provider.provider1 / provider.provider2.
        assert len(all_providers) == 5
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provider3 in all_providers
        assert provider.provider1 in all_providers
        assert provider.provider2 in all_providers


class CallableTests(unittest.TestCase):
    """Traversal of ``providers.Callable`` args, kwargs, overrides and provider-valued ``provides``."""

    def test_traverse(self):
        provider = providers.Callable(dict)
        all_providers = list(provider.traverse())
        assert len(all_providers) == 0

    def test_traverse_args(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")
        provider = providers.Callable(list, "foo", provider1, provider2)

        all_providers = list(provider.traverse())

        # The plain "foo" argument is not a provider and is not yielded.
        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_kwargs(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")
        provider = providers.Callable(dict, foo="foo", bar=provider1, baz=provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_overridden(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")

        provider = providers.Callable(dict, "foo")
        provider.override(provider1)
        provider.override(provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_provides(self):
        provider1 = providers.Callable(list)
        provider2 = providers.Object("bar")
        provider3 = providers.Object("baz")

        provider = providers.Callable(provider1, provider2)
        provider.override(provider3)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 3
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provider3 in all_providers


class ConfigurationTests(unittest.TestCase):
    """Traversal of ``providers.Configuration`` options and overrides."""

    def test_traverse(self):
        config = providers.Configuration(default={"option1": {"option2": "option2"}})
        option1 = config.option1
        option2 = config.option1.option2
        option3 = config.option1[config.option1.option2]

        all_providers = list(config.traverse())

        assert len(all_providers) == 3
        assert option1 in all_providers
        assert option2 in all_providers
        assert option3 in all_providers

    def test_traverse_typed(self):
        config = providers.Configuration()
        option = config.option
        typed_option = config.option.as_int()

        all_providers = list(typed_option.traverse())

        assert len(all_providers) == 1
        assert option in all_providers

    def test_traverse_overridden(self):
        options = {"option1": {"option2": "option2"}}
        config = providers.Configuration()
        config.from_dict(options)

        all_providers = list(config.traverse())

        assert len(all_providers) == 1
        overridden, = all_providers
        assert overridden() == options
        assert overridden is config.last_overriding

    def test_traverse_overridden_option_1(self):
        options = {"option2": "option2"}
        config = providers.Configuration()
        config.option1.from_dict(options)

        all_providers = list(config.traverse())

        assert len(all_providers) == 2
        assert config.option1 in all_providers
        assert config.last_overriding in all_providers

    def test_traverse_overridden_option_2(self):
        options = {"option2": "option2"}
        config = providers.Configuration()
        config.option1.from_dict(options)

        # Same setup as test_traverse_overridden_option_1, but traversing the
        # option itself: nothing is expected to be yielded.
        all_providers = list(config.option1.traverse())

        assert len(all_providers) == 0


class FactoryTests(unittest.TestCase):
    """Traversal of ``providers.Factory`` args, kwargs, attributes, overrides and ``provides``."""

    def test_traverse(self):
        provider = providers.Factory(dict)
        all_providers = list(provider.traverse())
        assert len(all_providers) == 0

    def test_traverse_args(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")
        provider = providers.Factory(list, "foo", provider1, provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_kwargs(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")
        provider = providers.Factory(dict, foo="foo", bar=provider1, baz=provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_attributes(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")
        provider = providers.Factory(dict)
        provider.add_attributes(foo="foo", bar=provider1, baz=provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_overridden(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")

        provider = providers.Factory(dict, "foo")
        provider.override(provider1)
        provider.override(provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_provides(self):
        provider1 = providers.Callable(list)
        provider2 = providers.Object("bar")
        provider3 = providers.Object("baz")

        provider = providers.Factory(provider1, provider2)
        provider.override(provider3)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 3
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provider3 in all_providers


class FactoryAggregateTests(unittest.TestCase):
    """Traversal of ``providers.FactoryAggregate``."""

    def test_traverse(self):
        factory1 = providers.Factory(dict)
        factory2 = providers.Factory(list)
        provider = providers.FactoryAggregate(factory1=factory1, factory2=factory2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert factory1 in all_providers
        assert factory2 in all_providers


class BaseSingletonTests(unittest.TestCase):
    """Traversal of ``providers.Singleton`` args, kwargs, attributes, overrides and ``provides``."""

    def test_traverse(self):
        provider = providers.Singleton(dict)
        all_providers = list(provider.traverse())
        assert len(all_providers) == 0

    def test_traverse_args(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")
        provider = providers.Singleton(list, "foo", provider1, provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_kwargs(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")
        provider = providers.Singleton(dict, foo="foo", bar=provider1, baz=provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_attributes(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")
        provider = providers.Singleton(dict)
        provider.add_attributes(foo="foo", bar=provider1, baz=provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_overridden(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")

        provider = providers.Singleton(dict, "foo")
        provider.override(provider1)
        provider.override(provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_provides(self):
        provider1 = providers.Callable(list)
        provider2 = providers.Object("bar")
        provider3 = providers.Object("baz")

        provider = providers.Singleton(provider1, provider2)
        provider.override(provider3)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 3
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provider3 in all_providers


class ListTests(unittest.TestCase):
    """Traversal of ``providers.List``."""

    def test_traverse_args(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")
        provider = providers.List("foo", provider1, provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_overridden(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")
        provider3 = providers.List(provider1, provider2)

        provider = providers.List("foo")
        provider.override(provider3)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 3
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provider3 in all_providers


class DictTests(unittest.TestCase):
    """Traversal of ``providers.Dict``."""

    def test_traverse_kwargs(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")
        provider = providers.Dict(foo="foo", bar=provider1, baz=provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_overridden(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")
        provider3 = providers.Dict(bar=provider1, baz=provider2)

        provider = providers.Dict(foo="foo")
        provider.override(provider3)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 3
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provider3 in all_providers


class ResourceTests(unittest.TestCase):
    """Traversal of ``providers.Resource``."""

    def test_traverse(self):
        provider = providers.Resource(dict)
        all_providers = list(provider.traverse())
        assert len(all_providers) == 0

    def test_traverse_args(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")
        provider = providers.Resource(list, "foo", provider1, provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_kwargs(self):
        provider1 = providers.Object("bar")
        provider2 = providers.Object("baz")
        provider = providers.Resource(dict, foo="foo", bar=provider1, baz=provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_overridden(self):
        provider1 = providers.Resource(list)
        provider2 = providers.Resource(tuple)

        provider = providers.Resource(dict, "foo")
        provider.override(provider1)
        provider.override(provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_provides(self):
        provider1 = providers.Callable(list)

        provider = providers.Resource(provider1)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 1
        assert provider1 in all_providers


class ContainerTests(unittest.TestCase):
    """Traversal of ``providers.Container``."""

    def test_traverse(self):
        class Container(containers.DeclarativeContainer):
            provider1 = providers.Callable(list)
            provider2 = providers.Callable(dict)

        provider = providers.Container(Container)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        # ``child`` avoids shadowing the ``provider`` under test.
        assert {list, dict} == {child.provides for child in all_providers}

    def test_traverse_overridden(self):
        class Container1(containers.DeclarativeContainer):
            provider1 = providers.Callable(list)
            provider2 = providers.Callable(dict)

        class Container2(containers.DeclarativeContainer):
            provider1 = providers.Callable(tuple)
            provider2 = providers.Callable(str)

        container2 = Container2()

        provider = providers.Container(Container1)
        provider.override(container2)

        all_providers = list(provider.traverse())

        # Expected: the four Callable providers from both containers, plus the
        # provider recorded as ``last_overriding`` (it returns ``container2``).
        assert len(all_providers) == 5
        assert {list, dict, tuple, str} == {
            child.provides
            for child in all_providers
            if isinstance(child, providers.Callable)
        }
        assert provider.last_overriding in all_providers
        assert provider.last_overriding() is container2


class SelectorTests(unittest.TestCase):
    """Traversal of ``providers.Selector`` (switch, options and overrides)."""

    def test_traverse(self):
        def switch():
            return "provider1"

        provider1 = providers.Callable(list)
        provider2 = providers.Callable(dict)

        provider = providers.Selector(
            switch,
            provider1=provider1,
            provider2=provider2,
        )

        all_providers = list(provider.traverse())

        # A plain-callable switch is not a provider, so only the options are yielded.
        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_switch(self):
        switch = providers.Callable(lambda: "provider1")
        provider1 = providers.Callable(list)
        provider2 = providers.Callable(dict)

        provider = providers.Selector(
            switch,
            provider1=provider1,
            provider2=provider2,
        )

        all_providers = list(provider.traverse())

        assert len(all_providers) == 3
        assert switch in all_providers
        assert provider1 in all_providers
        assert provider2 in all_providers

    def test_traverse_overridden(self):
        provider1 = providers.Callable(list)
        provider2 = providers.Callable(dict)
        selector1 = providers.Selector(lambda: "provider1", provider1=provider1)

        provider = providers.Selector(
            lambda: "provider2",
            provider2=provider2,
        )
        provider.override(selector1)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 3
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert selector1 in all_providers


class ProvidedInstanceTests(unittest.TestCase):
    """Traversal of ``Provider.provided`` instances."""

    def test_traverse(self):
        provider1 = providers.Provider()
        provider = provider1.provided

        all_providers = list(provider.traverse())

        assert len(all_providers) == 1
        assert provider1 in all_providers

    def test_traverse_overridden(self):
        provider1 = providers.Provider()
        provider2 = providers.Provider()

        provider = provider1.provided
        provider.override(provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provider2 in all_providers


class AttributeGetterTests(unittest.TestCase):
    """Traversal of attribute access on ``provided``."""

    def test_traverse(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        provider = provided.attr

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provided in all_providers

    def test_traverse_overridden(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        provider2 = providers.Provider()

        provider = provided.attr
        provider.override(provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 3
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provided in all_providers


class ItemGetterTests(unittest.TestCase):
    """Traversal of item access on ``provided``."""

    def test_traverse(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        provider = provided["item"]

        all_providers = list(provider.traverse())

        assert len(all_providers) == 2
        assert provider1 in all_providers
        assert provided in all_providers

    def test_traverse_overridden(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        provider2 = providers.Provider()

        provider = provided["item"]
        provider.override(provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 3
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provided in all_providers


class MethodCallerTests(unittest.TestCase):
    """Traversal of method calls on ``provided`` attributes."""

    def test_traverse(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        method = provided.method
        provider = method.call()

        all_providers = list(provider.traverse())

        # The whole chain provider1 -> provided -> method is yielded.
        assert len(all_providers) == 3
        assert provider1 in all_providers
        assert provided in all_providers
        assert method in all_providers

    def test_traverse_args(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        method = provided.method
        provider2 = providers.Provider()
        provider = method.call("foo", provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 4
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provided in all_providers
        assert method in all_providers

    def test_traverse_kwargs(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        method = provided.method
        provider2 = providers.Provider()
        provider = method.call(foo="foo", bar=provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 4
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provided in all_providers
        assert method in all_providers

    def test_traverse_overridden(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        method = provided.method
        provider2 = providers.Provider()

        provider = method.call()
        provider.override(provider2)

        all_providers = list(provider.traverse())

        assert len(all_providers) == 4
        assert provider1 in all_providers
        assert provider2 in all_providers
        assert provided in all_providers
        assert method in all_providers