import unittest

from dependency_injector import containers, providers


class TraverseTests(unittest.TestCase):

    def test_traverse_cycled_graph(self):
        provider1 = providers.Provider()

        provider2 = providers.Provider()
        provider2.override(provider1)

        provider3 = providers.Provider()
        provider3.override(provider2)

        provider1.override(provider3)  # Cycle: provider3 -> provider2 -> provider1 -> provider3

        all_providers = list(providers.traverse(provider1))

        self.assertEqual(len(all_providers), 3)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provider3, all_providers)

    def test_traverse_types_filtering(self):
        provider1 = providers.Resource(dict)
        provider2 = providers.Resource(dict)
        provider3 = providers.Provider()

        provider = providers.Provider()

        provider.override(provider1)
        provider.override(provider2)
        provider.override(provider3)

        all_providers = list(providers.traverse(provider, types=[providers.Resource]))

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)


class ProviderTests(unittest.TestCase):

    def test_traversal_overriding(self):
        provider1 = providers.Provider()
        provider2 = providers.Provider()
        provider3 = providers.Provider()

        provider = providers.Provider()

        provider.override(provider1)
        provider.override(provider2)
        provider.override(provider3)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provider3, all_providers)

    def test_traversal_overriding_nested(self):
        provider1 = providers.Provider()

        provider2 = providers.Provider()
        provider2.override(provider1)

        provider3 = providers.Provider()
        provider3.override(provider2)

        provider = providers.Provider()
        provider.override(provider3)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provider3, all_providers)

    def test_traverse_types_filtering(self):
        provider1 = providers.Resource(dict)
        provider2 = providers.Resource(dict)
        provider3 = providers.Provider()

        provider = providers.Provider()

        provider.override(provider1)
        provider.override(provider2)
        provider.override(provider3)

        all_providers = list(provider.traverse(types=[providers.Resource]))

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)


class ObjectTests(unittest.TestCase):

    def test_traversal(self):
        provider = providers.Object('string')
        all_providers = list(provider.traverse())
        self.assertEqual(len(all_providers), 0)

    def test_traversal_provider(self):
        another_provider = providers.Provider()
        provider = providers.Object(another_provider)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 1)
        self.assertIn(another_provider, all_providers)

    def test_traversal_provider_and_overriding(self):
        another_provider_1 = providers.Provider()
        another_provider_2 = providers.Provider()
        another_provider_3 = providers.Provider()

        provider = providers.Object(another_provider_1)

        provider.override(another_provider_2)
        provider.override(another_provider_3)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(another_provider_1, all_providers)
        self.assertIn(another_provider_2, all_providers)
        self.assertIn(another_provider_3, all_providers)


class DelegateTests(unittest.TestCase):

    def test_traversal_provider(self):
        another_provider = providers.Provider()
        provider = providers.Delegate(another_provider)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 1)
        self.assertIn(another_provider, all_providers)

    def test_traversal_provider_and_overriding(self):
        provider1 = providers.Provider()
        provider2 = providers.Provider()

        provider3 = providers.Provider()
        provider3.override(provider2)

        provider = providers.Delegate(provider1)

        provider.override(provider3)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provider3, all_providers)


class DependencyTests(unittest.TestCase):

    def test_traversal(self):
        provider = providers.Dependency()
        all_providers = list(provider.traverse())
        self.assertEqual(len(all_providers), 0)

    def test_traversal_default(self):
        another_provider = providers.Provider()
        provider = providers.Dependency(default=another_provider)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 1)
        self.assertIn(another_provider, all_providers)

    def test_traversal_overriding(self):
        provider1 = providers.Provider()

        provider2 = providers.Provider()
        provider2.override(provider1)

        provider = providers.Dependency()
        provider.override(provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)


class DependenciesContainerTests(unittest.TestCase):

    def test_traversal(self):
        provider = providers.DependenciesContainer()
        all_providers = list(provider.traverse())
        self.assertEqual(len(all_providers), 0)

    def test_traversal_default(self):
        another_provider = providers.Provider()
        provider = providers.DependenciesContainer(default=another_provider)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 1)
        self.assertIn(another_provider, all_providers)

    def test_traversal_fluent_interface(self):
        provider = providers.DependenciesContainer()
        provider1 = provider.provider1
        provider2 = provider.provider2

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traversal_overriding(self):
        provider1 = providers.Provider()
        provider2 = providers.Provider()
        provider3 = providers.DependenciesContainer(
            provider1=provider1,
            provider2=provider2,
        )

        provider = providers.DependenciesContainer()
        provider.override(provider3)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 5)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provider3, all_providers)
        self.assertIn(provider.provider1, all_providers)
        self.assertIn(provider.provider2, all_providers)


class CallableTests(unittest.TestCase):

    def test_traverse(self):
        provider = providers.Callable(dict)
        all_providers = list(provider.traverse())
        self.assertEqual(len(all_providers), 0)

    def test_traverse_args(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')
        provider = providers.Callable(list, 'foo', provider1, provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_kwargs(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')
        provider = providers.Callable(dict, foo='foo', bar=provider1, baz=provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_overridden(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')

        provider = providers.Callable(dict, 'foo')
        provider.override(provider1)
        provider.override(provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_provides(self):
        provider1 = providers.Callable(list)
        provider2 = providers.Object('bar')
        provider3 = providers.Object('baz')

        provider = providers.Callable(provider1, provider2)
        provider.override(provider3)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provider3, all_providers)


class ConfigurationTests(unittest.TestCase):

    def test_traverse(self):
        config = providers.Configuration(default={'option1': {'option2': 'option2'}})
        option1 = config.option1
        option2 = config.option1.option2
        option3 = config.option1[config.option1.option2]

        all_providers = list(config.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(option1, all_providers)
        self.assertIn(option2, all_providers)
        self.assertIn(option3, all_providers)

    def test_traverse_typed(self):
        config = providers.Configuration()
        option = config.option
        typed_option = config.option.as_int()

        all_providers = list(typed_option.traverse())

        self.assertEqual(len(all_providers), 1)
        self.assertIn(option, all_providers)

    def test_traverse_overridden(self):
        options = {'option1': {'option2': 'option2'}}
        config = providers.Configuration()
        config.from_dict(options)

        all_providers = list(config.traverse())

        self.assertEqual(len(all_providers), 1)
        overridden, = all_providers
        self.assertEqual(overridden(), options)
        self.assertIs(overridden, config.last_overriding)

    def test_traverse_overridden_option_1(self):
        options = {'option2': 'option2'}
        config = providers.Configuration()
        config.option1.from_dict(options)

        all_providers = list(config.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(config.option1, all_providers)
        self.assertIn(config.last_overriding, all_providers)

    def test_traverse_overridden_option_2(self):
        options = {'option2': 'option2'}
        config = providers.Configuration()
        config.option1.from_dict(options)

        all_providers = list(config.option1.traverse())

        self.assertEqual(len(all_providers), 0)


class FactoryTests(unittest.TestCase):

    def test_traverse(self):
        provider = providers.Factory(dict)
        all_providers = list(provider.traverse())
        self.assertEqual(len(all_providers), 0)

    def test_traverse_args(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')
        provider = providers.Factory(list, 'foo', provider1, provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_kwargs(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')
        provider = providers.Factory(dict, foo='foo', bar=provider1, baz=provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_attributes(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')
        provider = providers.Factory(dict)
        provider.add_attributes(foo='foo', bar=provider1, baz=provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_overridden(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')

        provider = providers.Factory(dict, 'foo')
        provider.override(provider1)
        provider.override(provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_provides(self):
        provider1 = providers.Callable(list)
        provider2 = providers.Object('bar')
        provider3 = providers.Object('baz')

        provider = providers.Factory(provider1, provider2)
        provider.override(provider3)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provider3, all_providers)


class FactoryAggregateTests(unittest.TestCase):

    def test_traverse(self):
        factory1 = providers.Factory(dict)
        factory2 = providers.Factory(list)
        provider = providers.FactoryAggregate(factory1=factory1, factory2=factory2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(factory1, all_providers)
        self.assertIn(factory2, all_providers)


class BaseSingletonTests(unittest.TestCase):

    def test_traverse(self):
        provider = providers.Singleton(dict)
        all_providers = list(provider.traverse())
        self.assertEqual(len(all_providers), 0)

    def test_traverse_args(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')
        provider = providers.Singleton(list, 'foo', provider1, provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_kwargs(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')
        provider = providers.Singleton(dict, foo='foo', bar=provider1, baz=provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_attributes(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')
        provider = providers.Singleton(dict)
        provider.add_attributes(foo='foo', bar=provider1, baz=provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_overridden(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')

        provider = providers.Singleton(dict, 'foo')
        provider.override(provider1)
        provider.override(provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_provides(self):
        provider1 = providers.Callable(list)
        provider2 = providers.Object('bar')
        provider3 = providers.Object('baz')

        provider = providers.Singleton(provider1, provider2)
        provider.override(provider3)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provider3, all_providers)


class ListTests(unittest.TestCase):

    def test_traverse_args(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')
        provider = providers.List('foo', provider1, provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_overridden(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')
        provider3 = providers.List(provider1, provider2)

        provider = providers.List('foo')
        provider.override(provider3)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provider3, all_providers)


class DictTests(unittest.TestCase):

    def test_traverse_kwargs(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')
        provider = providers.Dict(foo='foo', bar=provider1, baz=provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_overridden(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')
        provider3 = providers.Dict(bar=provider1, baz=provider2)

        provider = providers.Dict(foo='foo')
        provider.override(provider3)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provider3, all_providers)


class ResourceTests(unittest.TestCase):

    def test_traverse(self):
        provider = providers.Resource(dict)
        all_providers = list(provider.traverse())
        self.assertEqual(len(all_providers), 0)

    def test_traverse_args(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')
        provider = providers.Resource(list, 'foo', provider1, provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_kwargs(self):
        provider1 = providers.Object('bar')
        provider2 = providers.Object('baz')
        provider = providers.Resource(dict, foo='foo', bar=provider1, baz=provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_overridden(self):
        provider1 = providers.Resource(list)
        provider2 = providers.Resource(tuple)

        provider = providers.Resource(dict, 'foo')
        provider.override(provider1)
        provider.override(provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_provides(self):
        provider1 = providers.Callable(list)

        provider = providers.Resource(provider1)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 1)
        self.assertIn(provider1, all_providers)


class ContainerTests(unittest.TestCase):

    def test_traverse(self):
        class Container(containers.DeclarativeContainer):
            provider1 = providers.Callable(list)
            provider2 = providers.Callable(dict)

        provider = providers.Container(Container)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertEqual(
            {provider.provides for provider in all_providers},
            {list,  dict},
        )

    def test_traverse_overridden(self):
        class Container1(containers.DeclarativeContainer):
            provider1 = providers.Callable(list)
            provider2 = providers.Callable(dict)

        class Container2(containers.DeclarativeContainer):
            provider1 = providers.Callable(tuple)
            provider2 = providers.Callable(str)

        container2 = Container2()

        provider = providers.Container(Container1)
        provider.override(container2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 5)
        self.assertEqual(
            {
                provider.provides
                for provider in all_providers
                if isinstance(provider, providers.Callable)
            },
            {list, dict, tuple, str},
        )
        self.assertIn(provider.last_overriding, all_providers)
        self.assertIs(provider.last_overriding(), container2)


class SelectorTests(unittest.TestCase):

    def test_traverse(self):
        switch = lambda: 'provider1'
        provider1 = providers.Callable(list)
        provider2 = providers.Callable(dict)

        provider = providers.Selector(
            switch,
            provider1=provider1,
            provider2=provider2,
        )

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_switch(self):
        switch = providers.Callable(lambda: 'provider1')
        provider1 = providers.Callable(list)
        provider2 = providers.Callable(dict)

        provider = providers.Selector(
            switch,
            provider1=provider1,
            provider2=provider2,
        )

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(switch, all_providers)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)

    def test_traverse_overridden(self):
        provider1 = providers.Callable(list)
        provider2 = providers.Callable(dict)
        selector1 = providers.Selector(lambda: 'provider1', provider1=provider1)

        provider = providers.Selector(
            lambda: 'provider2',
            provider2=provider2,
        )
        provider.override(selector1)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(selector1, all_providers)


class ProvidedInstanceTests(unittest.TestCase):

    def test_traverse(self):
        provider1 = providers.Provider()
        provider = provider1.provided

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 1)
        self.assertIn(provider1, all_providers)

    def test_traverse_overridden(self):
        provider1 = providers.Provider()
        provider2 = providers.Provider()

        provider = provider1.provided
        provider.override(provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)


class AttributeGetterTests(unittest.TestCase):

    def test_traverse(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        provider = provided.attr

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provided, all_providers)

    def test_traverse_overridden(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        provider2 = providers.Provider()

        provider = provided.attr
        provider.override(provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provided, all_providers)


class ItemGetterTests(unittest.TestCase):

    def test_traverse(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        provider = provided['item']

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 2)
        self.assertIn(provider1, all_providers)
        self.assertIn(provided, all_providers)

    def test_traverse_overridden(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        provider2 = providers.Provider()

        provider = provided['item']
        provider.override(provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provided, all_providers)


class MethodCallerTests(unittest.TestCase):

    def test_traverse(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        method = provided.method
        provider = method.call()

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 3)
        self.assertIn(provider1, all_providers)
        self.assertIn(provided, all_providers)
        self.assertIn(method, all_providers)

    def test_traverse_args(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        method = provided.method
        provider2 = providers.Provider()
        provider = method.call('foo', provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 4)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provided, all_providers)
        self.assertIn(method, all_providers)

    def test_traverse_kwargs(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        method = provided.method
        provider2 = providers.Provider()
        provider = method.call(foo='foo', bar=provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 4)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provided, all_providers)
        self.assertIn(method, all_providers)

    def test_traverse_overridden(self):
        provider1 = providers.Provider()
        provided = provider1.provided
        method = provided.method
        provider2 = providers.Provider()

        provider = method.call()
        provider.override(provider2)

        all_providers = list(provider.traverse())

        self.assertEqual(len(all_providers), 4)
        self.assertIn(provider1, all_providers)
        self.assertIn(provider2, all_providers)
        self.assertIn(provided, all_providers)
        self.assertIn(method, all_providers)