From 11721b0d5ab8aefe4faec545915a471a2d3c81b9 Mon Sep 17 00:00:00 2001 From: Roman Mogylatov Date: Fri, 15 Oct 2021 11:44:17 -0400 Subject: [PATCH] Migrate async tests --- tests/unit/providers/async/common.py | 45 + .../async/test_async_mode_api_py36.py | 45 + .../async/test_delegated_singleton_py36.py | 38 + ...t_delegated_thread_local_singleton_py36.py | 38 + ...st_delegated_thread_safe_singleton_py36.py | 38 + .../providers/async/test_dependency_py36.py | 88 ++ tests/unit/providers/async/test_dict_py36.py | 23 + .../async/test_factory_aggregate_py36.py | 30 + .../unit/providers/async/test_factory_py36.py | 122 ++- tests/unit/providers/async/test_list_py36.py | 24 + .../providers/async/test_override_py36.py | 127 +++ .../async/test_provided_instance_py36.py | 180 ++++ .../providers/async/test_singleton_py36.py | 117 +++ .../async/test_thread_local_singleton_py36.py | 63 ++ .../async/test_thread_safe_singleton_py36.py | 38 + .../providers/async/test_typing_stubs_py36.py | 36 + tests/unit/providers/test_async_py36.py | 843 ------------------ 17 files changed, 1012 insertions(+), 883 deletions(-) create mode 100644 tests/unit/providers/async/common.py create mode 100644 tests/unit/providers/async/test_async_mode_api_py36.py create mode 100644 tests/unit/providers/async/test_delegated_singleton_py36.py create mode 100644 tests/unit/providers/async/test_delegated_thread_local_singleton_py36.py create mode 100644 tests/unit/providers/async/test_delegated_thread_safe_singleton_py36.py create mode 100644 tests/unit/providers/async/test_dependency_py36.py create mode 100644 tests/unit/providers/async/test_dict_py36.py create mode 100644 tests/unit/providers/async/test_factory_aggregate_py36.py create mode 100644 tests/unit/providers/async/test_list_py36.py create mode 100644 tests/unit/providers/async/test_override_py36.py create mode 100644 tests/unit/providers/async/test_provided_instance_py36.py create mode 100644 tests/unit/providers/async/test_singleton_py36.py create mode 100644 tests/unit/providers/async/test_thread_local_singleton_py36.py create mode 100644 tests/unit/providers/async/test_thread_safe_singleton_py36.py create mode 100644 tests/unit/providers/async/test_typing_stubs_py36.py delete mode 100644 tests/unit/providers/test_async_py36.py diff --git a/tests/unit/providers/async/common.py b/tests/unit/providers/async/common.py new file mode 100644 index 00000000..ddea3e79 --- /dev/null +++ b/tests/unit/providers/async/common.py @@ -0,0 +1,45 @@ +"""Common test artifacts.""" + +import asyncio +import random + +from dependency_injector import containers, providers + + +RESOURCE1 = object() +RESOURCE2 = object() + + +async def init_resource(resource): + await asyncio.sleep(random.randint(1, 10) / 1000) + yield resource + await asyncio.sleep(random.randint(1, 10) / 1000) + + +class Client: + def __init__(self, resource1: object, resource2: object) -> None: + self.resource1 = resource1 + self.resource2 = resource2 + + +class Service: + def __init__(self, client: Client) -> None: + self.client = client + + +class BaseContainer(containers.DeclarativeContainer): + resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1)) + resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2)) + + +class Container(BaseContainer): + client = providers.Factory( + Client, + resource1=BaseContainer.resource1, + resource2=BaseContainer.resource2, + ) + + service = providers.Factory( + Service, + client=client, + ) diff --git a/tests/unit/providers/async/test_async_mode_api_py36.py b/tests/unit/providers/async/test_async_mode_api_py36.py new file mode 100644 index 00000000..bfc12723 --- /dev/null +++ b/tests/unit/providers/async/test_async_mode_api_py36.py @@ -0,0 +1,45 @@ +"""Tests for provider async mode API.""" + +from dependency_injector import providers +from pytest import fixture + + +@fixture +def provider(): + return providers.Provider() + + +def test_default_mode(provider: providers.Provider): + assert provider.is_async_mode_enabled() is False + assert provider.is_async_mode_disabled() is False + assert provider.is_async_mode_undefined() is True + + +def test_enable(provider: providers.Provider): + provider.enable_async_mode() + + assert provider.is_async_mode_enabled() is True + assert provider.is_async_mode_disabled() is False + assert provider.is_async_mode_undefined() is False + + +def test_disable(provider: providers.Provider): + provider.disable_async_mode() + + assert provider.is_async_mode_enabled() is False + assert provider.is_async_mode_disabled() is True + assert provider.is_async_mode_undefined() is False + + +def test_reset(provider: providers.Provider): + provider.enable_async_mode() + + assert provider.is_async_mode_enabled() is True + assert provider.is_async_mode_disabled() is False + assert provider.is_async_mode_undefined() is False + + provider.reset_async_mode() + + assert provider.is_async_mode_enabled() is False + assert provider.is_async_mode_disabled() is False + assert provider.is_async_mode_undefined() is True diff --git a/tests/unit/providers/async/test_delegated_singleton_py36.py b/tests/unit/providers/async/test_delegated_singleton_py36.py new file mode 100644 index 00000000..ebfbd772 --- /dev/null +++ b/tests/unit/providers/async/test_delegated_singleton_py36.py @@ -0,0 +1,38 @@ +"""DelegatedSingleton provider async mode tests.""" + +import asyncio + +from dependency_injector import providers +from pytest import mark + + +@mark.asyncio +async def test_async_mode(): + instance = object() + + async def create_instance(): + return instance + + provider = providers.DelegatedSingleton(create_instance) + + instance1 = await provider() + instance2 = await provider() + + assert instance1 is instance2 + assert instance1 is instance + assert instance2 is instance + + +@mark.asyncio +async def test_concurrent_init(): + async def create_instance(): + return object() + + provider = providers.DelegatedSingleton(create_instance) + + future_instance1 = provider() + future_instance2 = provider() + + instance1, instance2 = await asyncio.gather(future_instance1, future_instance2) + + assert instance1 is instance2 diff --git a/tests/unit/providers/async/test_delegated_thread_local_singleton_py36.py b/tests/unit/providers/async/test_delegated_thread_local_singleton_py36.py new file mode 100644 index 00000000..5f5e9423 --- /dev/null +++ b/tests/unit/providers/async/test_delegated_thread_local_singleton_py36.py @@ -0,0 +1,38 @@ +"""DelegatedThreadLocalSingleton provider async mode tests.""" + +import asyncio + +from dependency_injector import providers +from pytest import mark + + +@mark.asyncio +async def test_async_mode(): + instance = object() + + async def create_instance(): + return instance + + provider = providers.DelegatedThreadLocalSingleton(create_instance) + + instance1 = await provider() + instance2 = await provider() + + assert instance1 is instance2 + assert instance1 is instance + assert instance2 is instance + + +@mark.asyncio +async def test_concurrent_init(): + async def create_instance(): + return object() + + provider = providers.DelegatedThreadLocalSingleton(create_instance) + + future_instance1 = provider() + future_instance2 = provider() + + instance1, instance2 = await asyncio.gather(future_instance1, future_instance2) + + assert instance1 is instance2 diff --git a/tests/unit/providers/async/test_delegated_thread_safe_singleton_py36.py b/tests/unit/providers/async/test_delegated_thread_safe_singleton_py36.py new file mode 100644 index 00000000..046ce951 --- /dev/null +++ b/tests/unit/providers/async/test_delegated_thread_safe_singleton_py36.py @@ -0,0 +1,38 @@ +"""DelegatedThreadSafeSingleton provider async mode tests.""" + +import asyncio + +from dependency_injector import providers +from pytest import mark + + +@mark.asyncio +async def test_async_mode(): + instance = object() + + async def create_instance(): + return instance + + provider = providers.DelegatedThreadSafeSingleton(create_instance) + + instance1 = await provider() + instance2 = await provider() + + assert instance1 is instance2 + assert instance1 is instance + assert instance2 is instance + + +@mark.asyncio +async def test_concurrent_init(): + async def create_instance(): + return object() + + provider = providers.DelegatedThreadSafeSingleton(create_instance) + + future_instance1 = provider() + future_instance2 = provider() + + instance1, instance2 = await asyncio.gather(future_instance1, future_instance2) + + assert instance1 is instance2 diff --git a/tests/unit/providers/async/test_dependency_py36.py b/tests/unit/providers/async/test_dependency_py36.py new file mode 100644 index 00000000..b42d3a97 --- /dev/null +++ b/tests/unit/providers/async/test_dependency_py36.py @@ -0,0 +1,88 @@ +"""Dependency provider async mode tests.""" + +from dependency_injector import providers, errors +from pytest import mark, raises + + +@mark.asyncio +async def test_provide_error(): + async def get_async(): + raise Exception + + provider = providers.Dependency() + provider.override(providers.Callable(get_async)) + + with raises(Exception): + await provider() + + +@mark.asyncio +async def test_isinstance(): + dependency = 1.0 + + async def get_async(): + return dependency + + provider = providers.Dependency(instance_of=float) + provider.override(providers.Callable(get_async)) + + assert provider.is_async_mode_undefined() is True + + dependency1 = await provider() + + assert provider.is_async_mode_enabled() is True + + dependency2 = await provider() + + assert dependency1 == dependency + assert dependency2 == dependency + + +@mark.asyncio +async def test_isinstance_invalid(): + async def get_async(): + return {} + + provider = providers.Dependency(instance_of=float) + provider.override(providers.Callable(get_async)) + + assert provider.is_async_mode_undefined() is True + + with raises(errors.Error): + await provider() + + assert provider.is_async_mode_enabled() is True + + +@mark.asyncio +async def test_async_mode(): + dependency = 123 + + async def get_async(): + return dependency + + def get_sync(): + return dependency + + provider = providers.Dependency(instance_of=int) + provider.override(providers.Factory(get_async)) + + assert provider.is_async_mode_undefined() is True + + dependency1 = await provider() + + assert provider.is_async_mode_enabled() is True + + dependency2 = await provider() + assert dependency1 == dependency + assert dependency2 == dependency + + provider.override(providers.Factory(get_sync)) + + dependency3 = await provider() + + assert provider.is_async_mode_enabled() is True + + dependency4 = await provider() + assert dependency3 == dependency + assert dependency4 == dependency diff --git a/tests/unit/providers/async/test_dict_py36.py b/tests/unit/providers/async/test_dict_py36.py new file mode 100644 index 00000000..132df56c --- /dev/null +++ b/tests/unit/providers/async/test_dict_py36.py @@ -0,0 +1,23 @@ +"""Dict provider async mode tests.""" + +from dependency_injector import containers, providers +from pytest import mark + + +@mark.asyncio +async def test_provide(): + async def create_resource(param: str): + return param + + class Container(containers.DeclarativeContainer): + + resources = providers.Dict( + foo=providers.Resource(create_resource, "foo"), + bar=providers.Resource(create_resource, "bar") + ) + + container = Container() + resources = await container.resources() + + assert resources["foo"] == "foo" + assert resources["bar"] == "bar" diff --git a/tests/unit/providers/async/test_factory_aggregate_py36.py b/tests/unit/providers/async/test_factory_aggregate_py36.py new file mode 100644 index 00000000..ace7ffdf --- /dev/null +++ b/tests/unit/providers/async/test_factory_aggregate_py36.py @@ -0,0 +1,30 @@ +"""FactoryAggregate provider async mode tests.""" + +from dependency_injector import providers +from pytest import mark + + +@mark.asyncio +async def test_async_mode(): + object1 = object() + object2 = object() + + async def _get_object1(): + return object1 + + def _get_object2(): + return object2 + + provider = providers.FactoryAggregate( + object1=providers.Factory(_get_object1), + object2=providers.Factory(_get_object2), + ) + + assert provider.is_async_mode_undefined() is True + + created_object1 = await provider("object1") + assert created_object1 is object1 + assert provider.is_async_mode_enabled() is True + + created_object2 = await provider("object2") + assert created_object2 is object2 diff --git a/tests/unit/providers/async/test_factory_py36.py b/tests/unit/providers/async/test_factory_py36.py index 5337ce5b..98ba2c83 100644 --- a/tests/unit/providers/async/test_factory_py36.py +++ b/tests/unit/providers/async/test_factory_py36.py @@ -1,49 +1,11 @@ -"""Factory async mode tests.""" +"""Factory provider async mode tests.""" import asyncio -import random from dependency_injector import containers, providers from pytest import mark, raises - -RESOURCE1 = object() -RESOURCE2 = object() - - -async def init_resource(resource): - await asyncio.sleep(random.randint(1, 10) / 1000) - yield resource - await asyncio.sleep(random.randint(1, 10) / 1000) - - -class Client: - def __init__(self, resource1: object, resource2: object) -> None: - self.resource1 = resource1 - self.resource2 = resource2 - - -class Service: - def __init__(self, client: Client) -> None: - self.client = client - - -class BaseContainer(containers.DeclarativeContainer): - resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1)) - resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2)) - - -class Container(BaseContainer): - client = providers.Factory( - Client, - resource1=BaseContainer.resource1, - resource2=BaseContainer.resource2, - ) - - service = providers.Factory( - Service, - client=client, - ) +from .common import RESOURCE1, RESOURCE2, Client, Service, BaseContainer, Container, init_resource @mark.asyncio @@ -184,6 +146,86 @@ async def test_args_kwargs_injection(): assert service1.client is not service2.client +@mark.asyncio +async def test_async_provider_with_async_injections(): + # See: https://github.com/ets-labs/python-dependency-injector/issues/368 + async def async_client_provider(): + return {"client": "OK"} + + async def async_service(client): + return {"service": "OK", "client": client} + + class Container(containers.DeclarativeContainer): + client = providers.Factory(async_client_provider) + service = providers.Factory(async_service, client=client) + + container = Container() + service = await container.service() + + assert service == {"service": "OK", "client": {"client": "OK"}} + + +@mark.asyncio +async def test_with_awaitable_injection(): + class SomeResource: + def __await__(self): + raise RuntimeError("Should never happen") + + async def init_resource(): + yield SomeResource() + + class Service: + def __init__(self, resource) -> None: + self.resource = resource + + class Container(containers.DeclarativeContainer): + resource = providers.Resource(init_resource) + service = providers.Factory(Service, resource=resource) + + container = Container() + + assert isinstance(container.service(), asyncio.Future) + assert isinstance(container.resource(), asyncio.Future) + + resource = await container.resource() + service = await container.service() + + assert isinstance(resource, SomeResource) + assert isinstance(service.resource, SomeResource) + assert service.resource is resource + + +@mark.asyncio +async def test_with_awaitable_injection_and_with_init_resources_call(): + class SomeResource: + def __await__(self): + raise RuntimeError("Should never happen") + + async def init_resource(): + yield SomeResource() + + class Service: + def __init__(self, resource) -> None: + self.resource = resource + + class Container(containers.DeclarativeContainer): + resource = providers.Resource(init_resource) + service = providers.Factory(Service, resource=resource) + + container = Container() + + await container.init_resources() + assert isinstance(container.service(), asyncio.Future) + assert isinstance(container.resource(), asyncio.Future) + + resource = await container.resource() + service = await container.service() + + assert isinstance(resource, SomeResource) + assert isinstance(service.resource, SomeResource) + assert service.resource is resource + + @mark.asyncio async def test_injection_error(): async def init_resource(): diff --git a/tests/unit/providers/async/test_list_py36.py b/tests/unit/providers/async/test_list_py36.py new file mode 100644 index 00000000..5f0162f8 --- /dev/null +++ b/tests/unit/providers/async/test_list_py36.py @@ -0,0 +1,24 @@ +"""List provider async mode tests.""" + +from dependency_injector import containers, providers +from pytest import mark + + +@mark.asyncio +async def test_provide(): + # See issue: https://github.com/ets-labs/python-dependency-injector/issues/450 + async def create_resource(param: str): + return param + + class Container(containers.DeclarativeContainer): + + resources = providers.List( + providers.Resource(create_resource, "foo"), + providers.Resource(create_resource, "bar") + ) + + container = Container() + resources = await container.resources() + + assert resources[0] == "foo" + assert resources[1] == "bar" diff --git a/tests/unit/providers/async/test_override_py36.py b/tests/unit/providers/async/test_override_py36.py new file mode 100644 index 00000000..6e76ac3b --- /dev/null +++ b/tests/unit/providers/async/test_override_py36.py @@ -0,0 +1,127 @@ +"""Tests for provider overriding in async mode.""" + +from dependency_injector import providers +from pytest import mark + + +@mark.asyncio +async def test_provider(): + dependency = object() + + async def _get_dependency_async(): + return dependency + + def _get_dependency_sync(): + return dependency + + provider = providers.Provider() + + provider.override(providers.Callable(_get_dependency_async)) + dependency1 = await provider() + + provider.override(providers.Callable(_get_dependency_sync)) + dependency2 = await provider() + + assert dependency1 is dependency + assert dependency2 is dependency + + +@mark.asyncio +async def test_callable(): + dependency = object() + + async def _get_dependency_async(): + return dependency + + def _get_dependency_sync(): + return dependency + + provider = providers.Callable(_get_dependency_async) + dependency1 = await provider() + + provider.override(providers.Callable(_get_dependency_sync)) + dependency2 = await provider() + + assert dependency1 is dependency + assert dependency2 is dependency + + +@mark.asyncio +async def test_factory(): + dependency = object() + + async def _get_dependency_async(): + return dependency + + def _get_dependency_sync(): + return dependency + + provider = providers.Factory(_get_dependency_async) + dependency1 = await provider() + + provider.override(providers.Callable(_get_dependency_sync)) + dependency2 = await provider() + + assert dependency1 is dependency + assert dependency2 is dependency + + +@mark.asyncio +async def test_async_mode_enabling(): + dependency = object() + + async def _get_dependency_async(): + return dependency + + provider = providers.Callable(_get_dependency_async) + assert provider.is_async_mode_undefined() is True + + await provider() + + assert provider.is_async_mode_enabled() is True + + +@mark.asyncio +async def test_async_mode_disabling(): + dependency = object() + + def _get_dependency(): + return dependency + + provider = providers.Callable(_get_dependency) + assert provider.is_async_mode_undefined() is True + + provider() + + assert provider.is_async_mode_disabled() is True + + +@mark.asyncio +async def test_async_mode_enabling_on_overriding(): + dependency = object() + + async def _get_dependency_async(): + return dependency + + provider = providers.Provider() + provider.override(providers.Callable(_get_dependency_async)) + assert provider.is_async_mode_undefined() is True + + await provider() + + assert provider.is_async_mode_enabled() is True + + +def test_async_mode_disabling_on_overriding(): + dependency = object() + + def _get_dependency(): + return dependency + + provider = providers.Provider() + provider.override(providers.Callable(_get_dependency)) + assert provider.is_async_mode_undefined() is True + + provider() + + assert provider.is_async_mode_disabled() is True diff --git a/tests/unit/providers/async/test_provided_instance_py36.py b/tests/unit/providers/async/test_provided_instance_py36.py new file mode 100644 index 00000000..faea4132 --- /dev/null +++ b/tests/unit/providers/async/test_provided_instance_py36.py @@ -0,0 +1,180 @@ +"""ProvidedInstance provider async mode tests.""" + +import asyncio + +from dependency_injector import containers, providers +from pytest import mark, raises + +from .common import RESOURCE1, init_resource + + +@mark.asyncio +async def test_provided_attribute(): + class TestClient: + def __init__(self, resource): + self.resource = resource + + class TestService: + def __init__(self, resource): + self.resource = resource + + class TestContainer(containers.DeclarativeContainer): + resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) + client = providers.Factory(TestClient, resource=resource) + service = providers.Factory(TestService, resource=client.provided.resource) + + container = TestContainer() + + instance1, instance2 = await asyncio.gather( + container.service(), + container.service(), + ) + + assert instance1.resource is RESOURCE1 + assert instance2.resource is RESOURCE1 + assert instance1.resource is instance2.resource + + +@mark.asyncio +async def test_provided_attribute_error(): + async def raise_exception(): + raise RuntimeError() + + class TestContainer(containers.DeclarativeContainer): + client = providers.Factory(raise_exception) + + container = TestContainer() + + with raises(RuntimeError): + await container.client.provided.attr() + + +@mark.asyncio +async def test_provided_attribute_undefined_attribute(): + class TestClient: + def __init__(self, resource): + self.resource = resource + + class TestContainer(containers.DeclarativeContainer): + resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) + client = providers.Factory(TestClient, resource=resource) + + container = TestContainer() + + with raises(AttributeError): + await container.client.provided.attr() + + +@mark.asyncio +async def test_provided_item(): + class TestClient: + def __init__(self, resource): + self.resource = resource + + def __getitem__(self, item): + return getattr(self, item) + + class TestService: + def __init__(self, resource): + self.resource = resource + + class TestContainer(containers.DeclarativeContainer): + resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) + client = providers.Factory(TestClient, resource=resource) + service = providers.Factory(TestService, resource=client.provided["resource"]) + + container = TestContainer() + + instance1, instance2 = await asyncio.gather( + container.service(), + container.service(), + ) + + assert instance1.resource is RESOURCE1 + assert instance2.resource is RESOURCE1 + assert instance1.resource is instance2.resource + + +@mark.asyncio +async def test_provided_item_error(): + async def raise_exception(): + raise RuntimeError() + + class TestContainer(containers.DeclarativeContainer): + client = providers.Factory(raise_exception) + + container = TestContainer() + + with raises(RuntimeError): + await container.client.provided["item"]() + + +@mark.asyncio +async def test_provided_item_undefined_item(): + class TestContainer(containers.DeclarativeContainer): + resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) + client = providers.Factory(dict, resource=resource) + + container = TestContainer() + + with raises(KeyError): + await container.client.provided["item"]() + + +@mark.asyncio +async def test_provided_method_call(): + class TestClient: + def __init__(self, resource): + self.resource = resource + + def get_resource(self): + return self.resource + + class TestService: + def __init__(self, resource): + self.resource = resource + + class TestContainer(containers.DeclarativeContainer): + resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) + client = providers.Factory(TestClient, resource=resource) + service = providers.Factory(TestService, resource=client.provided.get_resource.call()) + + container = TestContainer() + + instance1, instance2 = await asyncio.gather( + container.service(), + container.service(), + ) + + assert instance1.resource is RESOURCE1 + assert instance2.resource is RESOURCE1 + assert instance1.resource is instance2.resource + + +@mark.asyncio +async def test_provided_method_call_parent_error(): + async def raise_exception(): + raise RuntimeError() + + class TestContainer(containers.DeclarativeContainer): + client = providers.Factory(raise_exception) + + container = TestContainer() + + with raises(RuntimeError): + await container.client.provided.method.call()() + + +@mark.asyncio +async def test_provided_method_call_error(): + class TestClient: + def method(self): + raise RuntimeError() + + class TestContainer(containers.DeclarativeContainer): + client = providers.Factory(TestClient) + + container = TestContainer() + + with raises(RuntimeError): + await container.client.provided.method.call()() diff --git a/tests/unit/providers/async/test_singleton_py36.py b/tests/unit/providers/async/test_singleton_py36.py new file mode 100644 index 00000000..f1621bf5 --- /dev/null +++ b/tests/unit/providers/async/test_singleton_py36.py @@ -0,0 +1,117 @@ +"""Singleton provider async mode tests.""" + +import asyncio +import random + +from dependency_injector import providers +from pytest import mark, raises + +from .common import RESOURCE1, RESOURCE2, BaseContainer, Client, Service + + +@mark.asyncio +async def test_injections(): + class ContainerWithSingletons(BaseContainer): + client = providers.Singleton( + Client, + resource1=BaseContainer.resource1, + resource2=BaseContainer.resource2, + ) + + service = providers.Singleton( + Service, + client=client, + ) + + container = ContainerWithSingletons() + + client1 = await container.client() + client2 = await container.client() + + assert isinstance(client1, Client) + assert client1.resource1 is RESOURCE1 + assert client1.resource2 is RESOURCE2 + + assert isinstance(client2, Client) + assert client2.resource1 is RESOURCE1 + assert client2.resource2 is RESOURCE2 + + service1 = await container.service() + service2 = await container.service() + + assert isinstance(service1, Service) + assert isinstance(service1.client, Client) + assert service1.client.resource1 is RESOURCE1 + assert service1.client.resource2 is RESOURCE2 + + assert isinstance(service2, Service) + assert isinstance(service2.client, Client) + assert service2.client.resource1 is RESOURCE1 + assert service2.client.resource2 is RESOURCE2 + + assert service1 is service2 + assert service1.client is service2.client + assert service1.client is client1 + + assert service2.client is client2 + assert client1 is client2 + + +@mark.asyncio +async def test_async_mode(): + instance = object() + + async def create_instance(): + return instance + + provider = providers.Singleton(create_instance) + + instance1 = await provider() + instance2 = await provider() + + assert instance1 is instance2 + assert instance1 is instance + assert instance2 is instance + + +@mark.asyncio +async def test_concurrent_init(): + async def create_instance(): + await asyncio.sleep(random.randint(1, 10) / 1000) + return object() + + provider = providers.Singleton(create_instance) + + future_instance1 = provider() + future_instance2 = provider() + + instance1, instance2 = await asyncio.gather(future_instance1, future_instance2) + instance3 = await provider() + + assert instance1 is instance2 is instance3 + + +@mark.asyncio +async def test_async_init_with_error(): + async def create_instance(): + create_instance.counter += 1 + raise RuntimeError() + + create_instance.counter = 0 + + provider = providers.Singleton(create_instance) + + future = provider() + assert provider.is_async_mode_enabled() is True + + with raises(RuntimeError): + await future + + assert create_instance.counter == 1 + assert provider.is_async_mode_enabled() is True + + with raises(RuntimeError): + await provider() + + assert create_instance.counter == 2 + assert provider.is_async_mode_enabled() is True diff --git a/tests/unit/providers/async/test_thread_local_singleton_py36.py b/tests/unit/providers/async/test_thread_local_singleton_py36.py new file mode 100644 index 00000000..bf8ec3b3 --- /dev/null +++ b/tests/unit/providers/async/test_thread_local_singleton_py36.py @@ -0,0 +1,63 @@ +"""ThreadLocalSingleton provider async mode tests.""" + +import asyncio + +from dependency_injector import providers +from pytest import mark, raises + + +@mark.asyncio +async def test_async_mode(): + instance = object() + + async def create_instance(): + return instance + + provider = providers.ThreadLocalSingleton(create_instance) + + instance1 = await provider() + instance2 = await provider() + + assert instance1 is instance2 + assert instance1 is instance + assert instance2 is instance + + +@mark.asyncio +async def test_concurrent_init(): + async def create_instance(): + return object() + + provider = providers.ThreadLocalSingleton(create_instance) + + future_instance1 = provider() + future_instance2 = provider() + + instance1, instance2 = await asyncio.gather(future_instance1, future_instance2) + + assert instance1 is instance2 + + +@mark.asyncio +async def test_async_init_with_error(): + async def create_instance(): + create_instance.counter += 1 + raise RuntimeError() + create_instance.counter = 0 + + provider = providers.ThreadLocalSingleton(create_instance) + + future = provider() + assert provider.is_async_mode_enabled() is True + + with raises(RuntimeError): + await future + + assert create_instance.counter == 1 + assert provider.is_async_mode_enabled() is True + + with raises(RuntimeError): + await provider() + + assert create_instance.counter == 2 + assert provider.is_async_mode_enabled() is True diff --git a/tests/unit/providers/async/test_thread_safe_singleton_py36.py b/tests/unit/providers/async/test_thread_safe_singleton_py36.py new file mode 100644 index 00000000..13654150 --- /dev/null +++ b/tests/unit/providers/async/test_thread_safe_singleton_py36.py @@ -0,0 +1,38 @@ +"""ThreadSafeSingleton provider async mode tests.""" + +import asyncio + +from dependency_injector import providers +from pytest import mark + + +@mark.asyncio +async def test_async_mode(): + instance = object() + + async def create_instance(): + return instance + + provider = providers.ThreadSafeSingleton(create_instance) + + instance1 = await provider() + instance2 = await provider() + + assert instance1 is instance2 + assert instance1 is instance + assert instance2 is instance + + +@mark.asyncio +async def test_concurrent_init(): + async def create_instance(): + return object() + + provider = providers.ThreadSafeSingleton(create_instance) + + future_instance1 = provider() + future_instance2 = provider() + + instance1, instance2 = await asyncio.gather(future_instance1, future_instance2) + + assert instance1 is instance2 diff --git a/tests/unit/providers/async/test_typing_stubs_py36.py b/tests/unit/providers/async/test_typing_stubs_py36.py new file mode 100644 index 00000000..c7214d2d --- /dev/null +++ b/tests/unit/providers/async/test_typing_stubs_py36.py @@ -0,0 +1,36 @@ +"""Tests for provide async mode typing stubs.""" + +from pytest import mark + +from .common import Container, Client, Service, RESOURCE1, RESOURCE2 + + +@mark.asyncio +async def test_async_(): + container = Container() + + client1 = await container.client.async_() + client2 = await container.client.async_() + + assert isinstance(client1, Client) + assert client1.resource1 is RESOURCE1 + assert client1.resource2 is RESOURCE2 + + assert isinstance(client2, Client) + assert client2.resource1 is RESOURCE1 + assert client2.resource2 is RESOURCE2 + + service1 = await container.service.async_() + service2 = await container.service.async_() + + assert isinstance(service1, Service) + assert isinstance(service1.client, Client) + assert service1.client.resource1 is RESOURCE1 + assert service1.client.resource2 is RESOURCE2 + + assert isinstance(service2, Service) + assert isinstance(service2.client, Client) + assert service2.client.resource1 is RESOURCE1 + assert service2.client.resource2 is RESOURCE2 + + assert service1.client is not service2.client diff --git a/tests/unit/providers/test_async_py36.py b/tests/unit/providers/test_async_py36.py deleted file mode 100644 index 79624e6d..00000000 --- a/tests/unit/providers/test_async_py36.py +++ /dev/null @@ -1,843 +0,0 @@ -import asyncio -import random -import unittest - -from dependency_injector import containers, providers, errors - -# Runtime import to get asyncutils module -import os -_TOP_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../", - )), -) -import sys -sys.path.append(_TOP_DIR) - -from asyncutils import AsyncTestCase - - -RESOURCE1 = object() -RESOURCE2 = object() - - -async def init_resource(resource): - await asyncio.sleep(random.randint(1, 10) / 1000) - yield resource - await asyncio.sleep(random.randint(1, 10) / 1000) - - -class Client: - def __init__(self, resource1: object, resource2: object) -> None: - self.resource1 = resource1 - self.resource2 = resource2 - - -class Service: - def __init__(self, client: Client) -> None: - self.client = client - - -class Container(containers.DeclarativeContainer): - resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1)) - resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2)) - - client = providers.Factory( - Client, - resource1=resource1, - resource2=resource2, - ) - - service = providers.Factory( - Service, - client=client, - ) - - -class FactoryAggregateTests(AsyncTestCase): - - def test_async_mode(self): - object1 = object() - object2 = object() - - async def _get_object1(): - return object1 - - def _get_object2(): - return object2 - - provider = providers.FactoryAggregate( - object1=providers.Factory(_get_object1), - object2=providers.Factory(_get_object2), - ) - - self.assertTrue(provider.is_async_mode_undefined()) - - created_object1 = self._run(provider("object1")) - self.assertIs(created_object1, object1) - self.assertTrue(provider.is_async_mode_enabled()) - - created_object2 = self._run(provider("object2")) - self.assertIs(created_object2, object2) - - -class SingletonTests(AsyncTestCase): - - def test_injections(self): - class ContainerWithSingletons(containers.DeclarativeContainer): - resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1)) - resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2)) - - client = providers.Singleton( - Client, - resource1=resource1, - resource2=resource2, - ) - - service = providers.Singleton( - Service, - client=client, - ) - - container = ContainerWithSingletons() - - client1 = self._run(container.client()) - client2 = self._run(container.client()) - - self.assertIsInstance(client1, Client) - self.assertIs(client1.resource1, RESOURCE1) - self.assertIs(client1.resource2, RESOURCE2) - - self.assertIsInstance(client2, Client) - self.assertIs(client2.resource1, RESOURCE1) - self.assertIs(client2.resource2, RESOURCE2) - - service1 = self._run(container.service()) - service2 = self._run(container.service()) - - self.assertIsInstance(service1, Service) - self.assertIsInstance(service1.client, Client) - self.assertIs(service1.client.resource1, RESOURCE1) - self.assertIs(service1.client.resource2, RESOURCE2) - - self.assertIsInstance(service2, Service) - self.assertIsInstance(service2.client, Client) - self.assertIs(service2.client.resource1, RESOURCE1) - self.assertIs(service2.client.resource2, RESOURCE2) - - self.assertIs(service1, service2) - self.assertIs(service1.client, service2.client) - self.assertIs(service1.client, client1) - - self.assertIs(service2.client, client2) - self.assertIs(client1, client2) - - def test_async_mode(self): - instance = object() - - async def create_instance(): - return instance - - provider = providers.Singleton(create_instance) - - instance1 = self._run(provider()) - instance2 = self._run(provider()) - - self.assertIs(instance1, instance2) - self.assertIs(instance, instance) - - def test_async_init_with_error(self): - # Disable default exception handling to prevent output - asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...) - - async def create_instance(): - create_instance.counter += 1 - raise RuntimeError() - - create_instance.counter = 0 - - provider = providers.Singleton(create_instance) - - - future = provider() - self.assertTrue(provider.is_async_mode_enabled()) - - with self.assertRaises(RuntimeError): - self._run(future) - - self.assertEqual(create_instance.counter, 1) - self.assertTrue(provider.is_async_mode_enabled()) - - with self.assertRaises(RuntimeError): - self._run(provider()) - - self.assertEqual(create_instance.counter, 2) - self.assertTrue(provider.is_async_mode_enabled()) - - # Restore default exception handling - asyncio.get_event_loop().set_exception_handler(None) - - -class DelegatedSingletonTests(AsyncTestCase): - - def test_async_mode(self): - instance = object() - - async def create_instance(): - return instance - - provider = providers.DelegatedSingleton(create_instance) - - instance1 = self._run(provider()) - instance2 = self._run(provider()) - - self.assertIs(instance1, instance2) - self.assertIs(instance, instance) - - -class ThreadSafeSingletonTests(AsyncTestCase): - - def test_async_mode(self): - instance = object() - - async def create_instance(): - return instance - - provider = providers.ThreadSafeSingleton(create_instance) - - instance1 = self._run(provider()) - instance2 = self._run(provider()) - - self.assertIs(instance1, instance2) - self.assertIs(instance, instance) - - -class DelegatedThreadSafeSingletonTests(AsyncTestCase): - - def test_async_mode(self): - instance = object() - - async def create_instance(): - return instance - - provider = providers.DelegatedThreadSafeSingleton(create_instance) - - instance1 = self._run(provider()) - instance2 = self._run(provider()) - - self.assertIs(instance1, instance2) - self.assertIs(instance, instance) - - -class ThreadLocalSingletonTests(AsyncTestCase): - - def test_async_mode(self): - instance = object() - - async def create_instance(): - return instance - - provider = providers.ThreadLocalSingleton(create_instance) - - instance1 = self._run(provider()) - instance2 = self._run(provider()) - - self.assertIs(instance1, instance2) - self.assertIs(instance, instance) - - - def test_async_init_with_error(self): - # Disable default exception handling to prevent output - asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...) - - async def create_instance(): - create_instance.counter += 1 - raise RuntimeError() - create_instance.counter = 0 - - provider = providers.ThreadLocalSingleton(create_instance) - - future = provider() - self.assertTrue(provider.is_async_mode_enabled()) - - with self.assertRaises(RuntimeError): - self._run(future) - - self.assertEqual(create_instance.counter, 1) - self.assertTrue(provider.is_async_mode_enabled()) - - with self.assertRaises(RuntimeError): - self._run(provider()) - - self.assertEqual(create_instance.counter, 2) - self.assertTrue(provider.is_async_mode_enabled()) - - # Restore default exception handling - asyncio.get_event_loop().set_exception_handler(None) - - -class DelegatedThreadLocalSingletonTests(AsyncTestCase): - - def test_async_mode(self): - instance = object() - - async def create_instance(): - return instance - - provider = providers.DelegatedThreadLocalSingleton(create_instance) - - instance1 = self._run(provider()) - instance2 = self._run(provider()) - - self.assertIs(instance1, instance2) - self.assertIs(instance, instance) - - -class ProvidedInstanceTests(AsyncTestCase): - - def test_provided_attribute(self): - class TestClient: - def __init__(self, resource): - self.resource = resource - - class TestService: - def __init__(self, resource): - self.resource = resource - - class TestContainer(containers.DeclarativeContainer): - resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) - client = providers.Factory(TestClient, resource=resource) - service = providers.Factory(TestService, resource=client.provided.resource) - - container = TestContainer() - - instance1, instance2 = self._run( - asyncio.gather( - container.service(), - container.service(), - ), - ) - - self.assertIs(instance1.resource, RESOURCE1) - self.assertIs(instance2.resource, RESOURCE1) - self.assertIs(instance1.resource, instance2.resource) - - def test_provided_attribute_error(self): - async def raise_exception(): - raise RuntimeError() - - class TestContainer(containers.DeclarativeContainer): - client = providers.Factory(raise_exception) - - container = TestContainer() - - with self.assertRaises(RuntimeError): - self._run(container.client.provided.attr()) - - def test_provided_attribute_undefined_attribute(self): - class TestClient: - def __init__(self, resource): - self.resource = resource - - class TestContainer(containers.DeclarativeContainer): - resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) - client = providers.Factory(TestClient, resource=resource) - - container = TestContainer() - - with self.assertRaises(AttributeError): - self._run(container.client.provided.attr()) - - def test_provided_item(self): - class TestClient: - def __init__(self, resource): - self.resource = resource - - def __getitem__(self, item): - return getattr(self, item) - - class TestService: - def __init__(self, resource): - self.resource = resource - - class TestContainer(containers.DeclarativeContainer): - resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) - client = providers.Factory(TestClient, resource=resource) - service = providers.Factory(TestService, resource=client.provided["resource"]) - - container = TestContainer() - - instance1, instance2 = self._run( - asyncio.gather( - container.service(), - container.service(), - ), - ) - - self.assertIs(instance1.resource, RESOURCE1) - self.assertIs(instance2.resource, RESOURCE1) - self.assertIs(instance1.resource, instance2.resource) - - def test_provided_item_error(self): - async def raise_exception(): - raise RuntimeError() - - class TestContainer(containers.DeclarativeContainer): - client = providers.Factory(raise_exception) - - container = TestContainer() - - with self.assertRaises(RuntimeError): - self._run(container.client.provided["item"]()) - - def test_provided_item_undefined_item(self): - class TestContainer(containers.DeclarativeContainer): - resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) - client = providers.Factory(dict, resource=resource) - - container = TestContainer() - - with self.assertRaises(KeyError): - self._run(container.client.provided["item"]()) - - def test_provided_method_call(self): - class TestClient: - def __init__(self, resource): - self.resource = resource - - def get_resource(self): - return self.resource - - class TestService: - def __init__(self, resource): - self.resource = resource - - class TestContainer(containers.DeclarativeContainer): - resource = providers.Resource(init_resource, providers.Object(RESOURCE1)) - client = providers.Factory(TestClient, resource=resource) - service = providers.Factory(TestService, resource=client.provided.get_resource.call()) - - container = TestContainer() - - instance1, instance2 = self._run( - asyncio.gather( - container.service(), - container.service(), - ), - ) - - self.assertIs(instance1.resource, RESOURCE1) - self.assertIs(instance2.resource, RESOURCE1) - self.assertIs(instance1.resource, instance2.resource) - - def test_provided_method_call_parent_error(self): - async def raise_exception(): - raise RuntimeError() - - class TestContainer(containers.DeclarativeContainer): - client = providers.Factory(raise_exception) - - container = TestContainer() - - with self.assertRaises(RuntimeError): - self._run(container.client.provided.method.call()()) - - def test_provided_method_call_error(self): - class TestClient: - def method(self): - raise RuntimeError() - - class TestContainer(containers.DeclarativeContainer): - client = providers.Factory(TestClient) - - container = TestContainer() - - with self.assertRaises(RuntimeError): - self._run(container.client.provided.method.call()()) - - -class DependencyTests(AsyncTestCase): - - def test_provide_error(self): - async def get_async(): - raise Exception - - provider = providers.Dependency() - provider.override(providers.Callable(get_async)) - - with self.assertRaises(Exception): - self._run(provider()) - - def test_isinstance(self): - dependency = 1.0 - - async def get_async(): - return dependency - - provider = providers.Dependency(instance_of=float) - provider.override(providers.Callable(get_async)) - - self.assertTrue(provider.is_async_mode_undefined()) - - dependency1 = self._run(provider()) - - self.assertTrue(provider.is_async_mode_enabled()) - - dependency2 = self._run(provider()) - - self.assertEqual(dependency1, dependency) - self.assertEqual(dependency2, dependency) - - def test_isinstance_invalid(self): - async def get_async(): - return {} - - provider = providers.Dependency(instance_of=float) - provider.override(providers.Callable(get_async)) - - self.assertTrue(provider.is_async_mode_undefined()) - - with self.assertRaises(errors.Error): - self._run(provider()) - - self.assertTrue(provider.is_async_mode_enabled()) - - def test_async_mode(self): - dependency = 123 - - async def get_async(): - return dependency - - def get_sync(): - return dependency - - provider = providers.Dependency(instance_of=int) - provider.override(providers.Factory(get_async)) - - self.assertTrue(provider.is_async_mode_undefined()) - - dependency1 = self._run(provider()) - - self.assertTrue(provider.is_async_mode_enabled()) - - dependency2 = self._run(provider()) - self.assertEqual(dependency1, dependency) - self.assertEqual(dependency2, dependency) - - provider.override(providers.Factory(get_sync)) - - dependency3 = self._run(provider()) - - self.assertTrue(provider.is_async_mode_enabled()) - - dependency4 = self._run(provider()) - self.assertEqual(dependency3, dependency) - self.assertEqual(dependency4, dependency) - - -class ListTests(AsyncTestCase): - - def test_provide(self): - # See issue: https://github.com/ets-labs/python-dependency-injector/issues/450 - async def create_resource(param: str): - return param - - class Container(containers.DeclarativeContainer): - - resources = providers.List( - providers.Resource(create_resource, "foo"), - providers.Resource(create_resource, "bar") - ) - - container = Container() - resources = self._run(container.resources()) - - self.assertEqual(resources[0], "foo") - self.assertEqual(resources[1], "bar") - - -class DictTests(AsyncTestCase): - - def test_provide(self): - async def create_resource(param: str): - return param - - class Container(containers.DeclarativeContainer): - - resources = providers.Dict( - foo=providers.Resource(create_resource, "foo"), - bar=providers.Resource(create_resource, "bar") - ) - - container = Container() - resources = self._run(container.resources()) - - self.assertEqual(resources["foo"], "foo") - self.assertEqual(resources["bar"], "bar") - - -class OverrideTests(AsyncTestCase): - - def test_provider(self): - dependency = object() - - async def _get_dependency_async(): - return dependency - - def _get_dependency_sync(): - return dependency - - provider = providers.Provider() - - provider.override(providers.Callable(_get_dependency_async)) - dependency1 = self._run(provider()) - - provider.override(providers.Callable(_get_dependency_sync)) - dependency2 = self._run(provider()) - - self.assertIs(dependency1, dependency) - self.assertIs(dependency2, dependency) - - def test_callable(self): - dependency = object() - - async def _get_dependency_async(): - return dependency - - def _get_dependency_sync(): - return dependency - - provider = providers.Callable(_get_dependency_async) - dependency1 = self._run(provider()) - - provider.override(providers.Callable(_get_dependency_sync)) - dependency2 = self._run(provider()) - - self.assertIs(dependency1, dependency) - self.assertIs(dependency2, dependency) - - def test_factory(self): - dependency = object() - - async def _get_dependency_async(): - return dependency - - def _get_dependency_sync(): - return dependency - - provider = providers.Factory(_get_dependency_async) - dependency1 = self._run(provider()) - - provider.override(providers.Callable(_get_dependency_sync)) - dependency2 = self._run(provider()) - - self.assertIs(dependency1, dependency) - self.assertIs(dependency2, dependency) - - def test_async_mode_enabling(self): - dependency = object() - - async def _get_dependency_async(): - return dependency - - provider = providers.Callable(_get_dependency_async) - self.assertTrue(provider.is_async_mode_undefined()) - - self._run(provider()) - - self.assertTrue(provider.is_async_mode_enabled()) - - def test_async_mode_disabling(self): - dependency = object() - - def _get_dependency(): - return dependency - - provider = providers.Callable(_get_dependency) - self.assertTrue(provider.is_async_mode_undefined()) - - provider() - - self.assertTrue(provider.is_async_mode_disabled()) - - def test_async_mode_enabling_on_overriding(self): - dependency = object() - - async def _get_dependency_async(): - return dependency - - provider = providers.Provider() - provider.override(providers.Callable(_get_dependency_async)) - self.assertTrue(provider.is_async_mode_undefined()) - - self._run(provider()) - - self.assertTrue(provider.is_async_mode_enabled()) - - def test_async_mode_disabling_on_overriding(self): - dependency = object() - - def _get_dependency(): - return dependency - - provider = providers.Provider() - provider.override(providers.Callable(_get_dependency)) - self.assertTrue(provider.is_async_mode_undefined()) - - provider() - - self.assertTrue(provider.is_async_mode_disabled()) - - -class TestAsyncModeApi(unittest.TestCase): - - def setUp(self): - self.provider = providers.Provider() - - def test_default_mode(self): - self.assertFalse(self.provider.is_async_mode_enabled()) - self.assertFalse(self.provider.is_async_mode_disabled()) - self.assertTrue(self.provider.is_async_mode_undefined()) - - def test_enable(self): - self.provider.enable_async_mode() - - self.assertTrue(self.provider.is_async_mode_enabled()) - self.assertFalse(self.provider.is_async_mode_disabled()) - self.assertFalse(self.provider.is_async_mode_undefined()) - - def test_disable(self): - self.provider.disable_async_mode() - - self.assertFalse(self.provider.is_async_mode_enabled()) - self.assertTrue(self.provider.is_async_mode_disabled()) - self.assertFalse(self.provider.is_async_mode_undefined()) - - def test_reset(self): - self.provider.enable_async_mode() - - self.assertTrue(self.provider.is_async_mode_enabled()) - self.assertFalse(self.provider.is_async_mode_disabled()) - self.assertFalse(self.provider.is_async_mode_undefined()) - - self.provider.reset_async_mode() - - self.assertFalse(self.provider.is_async_mode_enabled()) - self.assertFalse(self.provider.is_async_mode_disabled()) - self.assertTrue(self.provider.is_async_mode_undefined()) - - -class AsyncTypingStubTests(AsyncTestCase): - - def test_async_(self): - container = Container() - - client1 = self._run(container.client.async_()) - client2 = self._run(container.client.async_()) - - self.assertIsInstance(client1, Client) - self.assertIs(client1.resource1, RESOURCE1) - self.assertIs(client1.resource2, RESOURCE2) - - self.assertIsInstance(client2, Client) - self.assertIs(client2.resource1, RESOURCE1) - self.assertIs(client2.resource2, RESOURCE2) - - service1 = self._run(container.service.async_()) - service2 = self._run(container.service.async_()) - - self.assertIsInstance(service1, Service) - self.assertIsInstance(service1.client, Client) - self.assertIs(service1.client.resource1, RESOURCE1) - self.assertIs(service1.client.resource2, RESOURCE2) - - self.assertIsInstance(service2, Service) - self.assertIsInstance(service2.client, Client) - self.assertIs(service2.client.resource1, RESOURCE1) - self.assertIs(service2.client.resource2, RESOURCE2) - - self.assertIsNot(service1.client, service2.client) - - -class AsyncProvidersWithAsyncDependenciesTests(AsyncTestCase): - - def test_injections(self): - # See: https://github.com/ets-labs/python-dependency-injector/issues/368 - async def async_db_provider(): - return {"db": "ok"} - - async def async_service(db=None): - return {"service": "ok", "db": db} - - class Container(containers.DeclarativeContainer): - - db = providers.Factory(async_db_provider) - service = providers.Singleton(async_service, db=db) - - container = Container() - service = self._run(container.service()) - - self.assertEqual(service, {"service": "ok", "db": {"db": "ok"}}) - - -class AsyncProviderWithAwaitableObjectTests(AsyncTestCase): - - def test(self): - class SomeResource: - def __await__(self): - raise RuntimeError("Should never happen") - - async def init_resource(): - pool = SomeResource() - yield pool - - class Service: - def __init__(self, resource) -> None: - self.resource = resource - - class Container(containers.DeclarativeContainer): - resource = providers.Resource(init_resource) - service = providers.Singleton(Service, resource=resource) - - container = Container() - - self._run(container.init_resources()) - self.assertIsInstance(container.service(), asyncio.Future) - self.assertIsInstance(container.resource(), asyncio.Future) - - resource = self._run(container.resource()) - service = self._run(container.service()) - - self.assertIsInstance(resource, SomeResource) - self.assertIsInstance(service.resource, SomeResource) - self.assertIs(service.resource, resource) - - def test_without_init_resources(self): - class SomeResource: - def __await__(self): - raise RuntimeError("Should never happen") - - async def init_resource(): - pool = SomeResource() - yield pool - - class Service: - def __init__(self, resource) -> None: - self.resource = resource - - class Container(containers.DeclarativeContainer): - resource = providers.Resource(init_resource) - service = providers.Singleton(Service, resource=resource) - - container = Container() - - self.assertIsInstance(container.service(), asyncio.Future) - self.assertIsInstance(container.resource(), asyncio.Future) - - resource = self._run(container.resource()) - service = self._run(container.service()) - - self.assertIsInstance(resource, SomeResource) - self.assertIsInstance(service.resource, SomeResource) - self.assertIs(service.resource, resource)