"""Selector provider traversal tests."""

from dependency_injector import providers


def test_traverse():
    """Check traversal of a selector whose switch is a plain callable.

    The selector function here is an ordinary lambda rather than a provider,
    so the test expects exactly the two option providers to be yielded.
    """
    list_provider = providers.Callable(list)
    dict_provider = providers.Callable(dict)
    selector = providers.Selector(
        lambda: "provider1",
        provider1=list_provider,
        provider2=dict_provider,
    )

    found = list(selector.traverse())

    assert len(found) == 2
    for expected in (list_provider, dict_provider):
        assert expected in found


def test_traverse_switch():
    """Check traversal of a selector whose switch is itself a provider.

    The test expects the switch provider to be yielded alongside the two
    option providers, for three in total.
    """
    switch_provider = providers.Callable(lambda: "provider1")
    options = {
        "provider1": providers.Callable(list),
        "provider2": providers.Callable(dict),
    }
    selector = providers.Selector(switch_provider, **options)

    found = list(selector.traverse())

    expected = (switch_provider, options["provider1"], options["provider2"])
    assert len(found) == 3
    for item in expected:
        assert item in found


def test_traverse_overridden():
    """Check traversal of a selector that is overridden by another selector.

    The test expects three providers: the overriding selector, the option
    registered on it, and the option registered on the original selector.
    """
    overriding_option = providers.Callable(list)
    original_option = providers.Callable(dict)
    overriding_selector = providers.Selector(
        lambda: "provider1",
        provider1=overriding_option,
    )
    selector = providers.Selector(lambda: "provider2", provider2=original_option)

    selector.override(overriding_selector)
    found = list(selector.traverse())

    assert len(found) == 3
    for expected in (overriding_option, original_option, overriding_selector):
        assert expected in found