"""Tests for provider overriding in async mode."""

from dependency_injector import providers
from pytest import mark


@mark.asyncio
async def test_provider():
    """Await a base ``Provider`` overridden first by an async, then by a sync callable.

    Both awaited calls must hand back the very same dependency object.
    """
    expected = object()

    async def _produce_async():
        return expected

    def _produce_sync():
        return expected

    base = providers.Provider()

    # First override: a callable provider wrapping the coroutine function.
    async_override = providers.Callable(_produce_async)
    base.override(async_override)
    from_async = await base()

    # Second override: a plain function; the call is still awaited.
    sync_override = providers.Callable(_produce_sync)
    base.override(sync_override)
    from_sync = await base()

    assert from_async is expected
    assert from_sync is expected


@mark.asyncio
async def test_callable():
    """Await an async ``Callable`` before and after overriding it with a sync one.

    Each awaited call must return the same dependency object.
    """
    expected = object()

    async def _produce_async():
        return expected

    def _produce_sync():
        return expected

    target = providers.Callable(_produce_async)
    resolved = [await target()]

    # Swap in the synchronous implementation; the call below is still awaited.
    target.override(providers.Callable(_produce_sync))
    resolved.append(await target())

    # Checked in call order: result before overriding, then result after.
    for value in resolved:
        assert value is expected


@mark.asyncio
async def test_factory():
    """Await an async ``Factory`` before and after overriding it with a sync ``Callable``.

    Each awaited call must return the same dependency object.
    """
    expected = object()

    async def _produce_async():
        return expected

    def _produce_sync():
        return expected

    factory = providers.Factory(_produce_async)
    before_override = await factory()

    # The override is a Callable provider (not a Factory) around a plain function.
    sync_override = providers.Callable(_produce_sync)
    factory.override(sync_override)
    after_override = await factory()

    assert before_override is expected
    assert after_override is expected


@mark.asyncio
async def test_async_mode_enabling():
    """Check that awaiting an async ``Callable`` moves async mode from undefined to enabled."""
    expected = object()

    async def _produce_async():
        return expected

    target = providers.Callable(_produce_async)

    # Before the first call the mode is reported as undefined.
    undefined_before_call = target.is_async_mode_undefined()
    assert undefined_before_call is True

    await target()

    # After the awaited call the mode is reported as enabled.
    enabled_after_call = target.is_async_mode_enabled()
    assert enabled_after_call is True


def test_async_mode_disabling():
    """Check that calling a sync ``Callable`` moves async mode from undefined to disabled.

    Nothing in this test is awaited, so it is a plain synchronous test
    function, in line with ``test_async_mode_disabling_on_overriding``.
    """
    dependency = object()

    def _get_dependency():
        return dependency

    provider = providers.Callable(_get_dependency)
    # Before the first call the mode is reported as undefined.
    assert provider.is_async_mode_undefined() is True

    provider()

    # After the synchronous call the mode is reported as disabled.
    assert provider.is_async_mode_disabled() is True


@mark.asyncio
async def test_async_mode_enabling_on_overriding():
    """Check that awaiting a ``Provider`` overridden by an async callable enables async mode."""
    expected = object()

    async def _produce_async():
        return expected

    base = providers.Provider()
    async_override = providers.Callable(_produce_async)
    base.override(async_override)

    # Overriding alone leaves the base provider's mode undefined.
    undefined_before_call = base.is_async_mode_undefined()
    assert undefined_before_call is True

    await base()

    # After the awaited call the base provider reports async mode as enabled.
    enabled_after_call = base.is_async_mode_enabled()
    assert enabled_after_call is True


def test_async_mode_disabling_on_overriding():
    """Check that calling a ``Provider`` overridden by a sync callable disables async mode."""
    expected = object()

    def _produce_sync():
        return expected

    base = providers.Provider()
    sync_override = providers.Callable(_produce_sync)
    base.override(sync_override)

    # Overriding alone leaves the base provider's mode undefined.
    undefined_before_call = base.is_async_mode_undefined()
    assert undefined_before_call is True

    base()

    # After the synchronous call the base provider reports async mode as disabled.
    disabled_after_call = base.is_async_mode_disabled()
    assert disabled_after_call is True