Async resources and injections (#352)

* Add support of async injections into wiring

* Add support of async functions and async generators for resources

* Update resource provider typing stub for stutdown

* Add resource base class for async resources

* Fix tests

* Add tests for async injections in wiring @inject

* Refactor provider tests

* Add tests for async resources

* Rework async resources callbacks to .add_done_callback() style (fixes pypy3 issue)

* Add awaits into async resource class test

* Refactor FastAPI tests

* Implement async resources initialization in container

* Move container async resource tests to a separate module for Python 3.6+

* Fix init async resources in container on Python 2

* Add first dirty async injections implementation

* Fix isawaitable error

* Turm asyncio import to conditional for safer Py2 usage

* Refactor kwargs injections

* Implement positional injections, add tests and make refactoring

* Implement attribute injections and add tests

* Add singleton implementation + tests for all singleton types

* Implement injections in thread-local and thread-safe singleton providers

* Update .provided + fix resource concurent initialization issue

* Implement async mode for Dependency provider

* Add async mode for the provider

* Add overload for Factory typing

* Add typing stubs for async resource

* Refactor abstract* providers __call__()

* Add async mode API + tests

* Add typing stubs & tests for async mode API

* Add tests for async mode auto configuration

* Refactor Provider.__call__() to use async mode api

* Refactor Dependency provider to use async mode api

* Add tests for Dependency provider async mode

* Add support of async mode for FactoryAggregate provider + tests

* Refactor Singleton provider to use async mode api

* Refactor ThreadSafeSingleton provider to use async mode api

* Refactor ThreadLocalSingleton provider to use async mode api

* Finish Singleton refactoring to use async mode api

* Refactor Resource provider to use async mode api

* Add Provider.async_() method + tests

* Add typing stubs for async_() method + tests

* Refactor Singleton typing stubs to return singleton from argument methods

* Refactor provider typing stubs

* Improve resource typing stub

* Add tests for async context kwargs injections

* Fix typo in resource provider tests

* Cover shutdown of not initialized resource

* Add test to cover resource initialization with an error

* Fix Singleton and ThreadLocalSingleton to handle initialization errors

* Add FastAPI + Redis example

* Make cosmetic fixes to FastAPI + Redis example

* Add missing development requirements

* Update module docblock in fastapi + redis example

* Add FastAPI + Redis example docs

* Add references to FastAPI + Redis example

* Refactor resource docs

* Add asynchronous resources docs

* Refactor wiring docs

* Add async injections docs for wiring

* Add async injections page and update docs index, readme, and key features pages

* Add providers async injections example

* Add docs on provider async mode enabling

* Reword async provider docs

* Add provider async mode docs

* Add cross links to async docs

* Mute flake8 errors in async provider examples

* Update changelog

* Make cosmetic fix to containers.pyx
This commit is contained in:
Roman Mogylatov 2021-01-10 19:26:15 -05:00 committed by GitHub
parent 9f6d2bb522
commit feed916f46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
57 changed files with 33044 additions and 15188 deletions

View File

@ -70,6 +70,8 @@ Key features of the ``Dependency Injector``:
- **Wiring**. Injects dependencies into functions and methods. Helps integrating with
other frameworks: Django, Flask, Aiohttp, Sanic, FastAPI, etc.
See `Wiring <https://python-dependency-injector.ets-labs.org/wiring.html>`_.
- **Asynchronous**. Supports asynchronous injections.
See `Asynchronous injections <https://python-dependency-injector.ets-labs.org/providers/async.html>`_.
- **Typing**. Provides typing stubs, ``mypy``-friendly.
See `Typing and mypy <https://python-dependency-injector.ets-labs.org/providers/typing_mypy.html>`_.
- **Performance**. Fast. Written in ``Cython``.
@ -157,6 +159,8 @@ Choose one of the following:
- `Flask example <https://python-dependency-injector.ets-labs.org/examples/flask.html>`_
- `Aiohttp example <https://python-dependency-injector.ets-labs.org/examples/aiohttp.html>`_
- `Sanic example <https://python-dependency-injector.ets-labs.org/examples/sanic.html>`_
- `FastAPI example <https://python-dependency-injector.ets-labs.org/examples/fastapi.html>`_
- `FastAPI + Redis example <https://python-dependency-injector.ets-labs.org/examples/fastapi-redis.html>`_
Tutorials
---------
@ -223,4 +227,3 @@ Want to contribute?
.. |tell| unicode:: U+1F4AC .. tell sign
.. |fork| unicode:: U+1F500 .. fork sign
.. |pull| unicode:: U+2B05 U+FE0F .. pull sign

View File

@ -0,0 +1,98 @@
.. _fastapi-redis-example:
FastAPI + Redis example
=======================
.. meta::
:keywords: Python,Dependency Injection,FastAPI,Redis,Example
:description: This example demonstrates a usage of the FastAPI, Redis, and Dependency Injector.
This example shows how to use ``Dependency Injector`` with `FastAPI <https://fastapi.tiangolo.com/>`_ and
`Redis <https://redis.io/>`_.
The source code is available on the `Github <https://github.com/ets-labs/python-dependency-injector/tree/master/examples/miniapps/fastapi-redis>`_.
See also:
- Provider :ref:`async-injections`
- Resource provider :ref:`resource-async-initializers`
- Wiring :ref:`async-injections-wiring`
Application structure
---------------------
Application has next structure:
.. code-block:: bash
./
├── fastapiredis/
│ ├── __init__.py
│ ├── application.py
│ ├── containers.py
│ ├── redis.py
│ ├── services.py
│ └── tests.py
├── docker-compose.yml
├── Dockerfile
└── requirements.txt
Redis
-----
Module ``redis`` defines Redis connection pool initialization and shutdown. See ``fastapiredis/redis.py``:
.. literalinclude:: ../../examples/miniapps/fastapi-redis/fastapiredis/redis.py
:language: python
Service
-------
Module ``services`` contains example service. Service has a dependency on Redis connection pool.
It uses it for getting and setting a key asynchronously. Real life service will do something more meaningful.
See ``fastapiredis/services.py``:
.. literalinclude:: ../../examples/miniapps/fastapi-redis/fastapiredis/services.py
:language: python
Container
---------
Declarative container wires example service with Redis connection pool. See ``fastapiredis/containers.py``:
.. literalinclude:: ../../examples/miniapps/fastapi-redis/fastapiredis/containers.py
:language: python
Application
-----------
Module ``application`` creates ``FastAPI`` app, setup endpoint, and init container.
Endpoint ``index`` has a dependency on example service. The dependency is injected using :ref:`wiring` feature.
Listing of ``fastapiredis/application.py``:
.. literalinclude:: ../../examples/miniapps/fastapi-redis/fastapiredis/application.py
:language: python
Tests
-----
Tests use :ref:`provider-overriding` feature to replace example service with a mock. See ``fastapiredis/tests.py``:
.. literalinclude:: ../../examples/miniapps/fastapi-redis/fastapiredis/tests.py
:language: python
:emphasize-lines: 24
Sources
-------
The source code is available on the `Github <https://github.com/ets-labs/python-dependency-injector/tree/master/examples/miniapps/fastapi-redis>`_.
See also:
- Provider :ref:`async-injections`
- Resource provider :ref:`resource-async-initializers`
- Wiring :ref:`async-injections-wiring`
.. disqus::

View File

@ -19,5 +19,6 @@ Explore the examples to see the ``Dependency Injector`` in action.
aiohttp
sanic
fastapi
fastapi-redis
.. disqus::

View File

@ -78,6 +78,7 @@ Key features of the ``Dependency Injector``:
- **Containers**. Provides declarative and dynamic containers. See :ref:`containers`.
- **Wiring**. Injects dependencies into functions and methods. Helps integrating with
other frameworks: Django, Flask, Aiohttp, Sanic, FastAPI, etc. See :ref:`wiring`.
- **Asynchronous**. Supports asynchronous injections. See :ref:`async-injections`.
- **Typing**. Provides typing stubs, ``mypy``-friendly. See :ref:`provider-typing`.
- **Performance**. Fast. Written in ``Cython``.
- **Maturity**. Mature and production-ready. Well-tested, documented and supported.

View File

@ -287,6 +287,7 @@ Choose one of the following as a next step:
- :ref:`aiohttp-example`
- :ref:`sanic-example`
- :ref:`fastapi-example`
- :ref:`fastapi-redis-example`
- Pass the tutorials:
- :ref:`flask-tutorial`
- :ref:`aiohttp-tutorial`

View File

@ -24,6 +24,7 @@ Key features of the ``Dependency Injector``:
- **Containers**. Provides declarative and dynamic containers. See :ref:`containers`.
- **Wiring**. Injects dependencies into functions and methods. Helps integrating with
other frameworks: Django, Flask, Aiohttp, Sanic, FastAPI, etc. See :ref:`wiring`.
- **Asynchronous**. Supports asynchronous injections. See :ref:`async-injections`.
- **Typing**. Provides typing stubs, ``mypy``-friendly. See :ref:`provider-typing`.
- **Performance**. Fast. Written in ``Cython``.
- **Maturity**. Mature and production-ready. Well-tested, documented and supported.

View File

@ -9,6 +9,10 @@ follows `Semantic versioning`_
Development version
-------------------
- Add support of async injections for providers.
- Add support of async injections for wiring.
- Add support of async initializers for ``Resource`` provider.
- Add ``FastAPI`` + ``Redis`` example.
- Add ARM wheel builds.
See issue `#342 <https://github.com/ets-labs/python-dependency-injector/issues/342>`_ for details.
- Fix a typo in `ext.flask` deprecation warning.

108
docs/providers/async.rst Normal file
View File

@ -0,0 +1,108 @@
.. _async-injections:
Asynchronous injections
=======================
.. meta::
:keywords: Python,DI,Dependency injection,IoC,Inversion of Control,Providers,Async,Injections,Asynchronous,Await,
Asyncio
:description: Dependency Injector providers support asynchronous injections. This page
demonstrates how make asynchronous dependency injections in Python.
Providers support asynchronous injections.
.. literalinclude:: ../../examples/providers/async.py
:language: python
:emphasize-lines: 26-29
:lines: 3-
If provider has any awaitable injections it switches into async mode. In async mode provider always returns awaitable.
This causes a cascade effect:
.. code-block:: bash
provider1() <── Async mode enabled <──┐
│ │
├──> provider2() │
│ │
├──> provider3() <── Async mode enabled <──┤
│ │ │
│ └──> provider4() <── Async provider ───────┘
└──> provider5()
└──> provider6()
In async mode provider prepares injections asynchronously.
If provider has multiple awaitable dependencies, it will run them concurrently. Provider will wait until all
dependencies are ready and inject them afterwards.
.. code-block:: bash
provider1()
├──> provider2() <── Async mode enabled
├──> provider3() <── Async mode enabled
└──> provider4() <── Async mode enabled
Here is what provider will do for the previous example:
.. code-block:: python
injections = await asyncio.gather(
provider2(),
provider3(),
provider4(),
)
await provider1(*injections)
Overriding behaviour
--------------------
In async mode provider always returns awaitable. It applies to the overriding too. If provider in async mode is
overridden by a provider that doesn't return awaitable result, the result will be wrapped into awaitable.
.. literalinclude:: ../../examples/providers/async_overriding.py
:language: python
:emphasize-lines: 19-24
:lines: 3-
Async mode mechanics and API
----------------------------
By default provider's async mode is undefined.
When provider async mode is undefined, provider will automatically select the mode during the next call.
If the result is awaitable, provider will enable async mode, if not - disable it.
If provider async mode is enabled, provider always returns awaitable. If the result is not awaitable,
provider wraps it into awaitable explicitly. You can safely ``await`` provider in async mode.
If provider async mode is disabled, provider behaves the regular way. It doesn't do async injections
preparation or non-awaitables to awaitables conversion.
Once provider async mode is enabled or disabled, provider will stay in this state. No automatic switching
will be done.
.. image:: images/async_mode.png
You can also use following methods to change provider's async mode manually:
- ``Provider.enable_async_mode()``
- ``Provider.disable_async_mode()``
- ``Provider.reset_async_mode()``
To check the state of provider's async mode use:
- ``Provider.is_async_mode_enabled()``
- ``Provider.is_async_mode_disabled()``
- ``Provider.is_async_mode_undefined()``
See also:
- Wiring :ref:`async-injections-wiring`
- Resource provider :ref:`resource-async-initializers`
- :ref:`fastapi-redis-example`

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

View File

@ -51,4 +51,5 @@ Providers module API docs - :py:mod:`dependency_injector.providers`
overriding
provided_instance
custom
async
typing_mypy

View File

@ -21,7 +21,7 @@ Resource provider
Resource providers help to initialize and configure logging, event loop, thread or process pool, etc.
Resource provider is similar to ``Singleton``. Resource initialization happens only once.
You can do injections and use provided instance the same way like you do with any other provider.
You can make injections and use provided instance the same way like you do with any other provider.
.. code-block:: python
:emphasize-lines: 12
@ -40,7 +40,7 @@ You can do injections and use provided instance the same way like you do with an
executor=thread_pool,
)
Container has an interface to initialize and shutdown all resources:
Container has an interface to initialize and shutdown all resources at once:
.. code-block:: python
@ -48,7 +48,7 @@ Container has an interface to initialize and shutdown all resources:
container.init_resources()
container.shutdown_resources()
You also can initialize and shutdown resources one-by-one using ``init()`` and
You can also initialize and shutdown resources one-by-one using ``init()`` and
``shutdown()`` methods of the provider:
.. code-block:: python
@ -57,6 +57,10 @@ You also can initialize and shutdown resources one-by-one using ``init()`` and
container.thread_pool.init()
container.thread_pool.shutdown()
When you call ``.shutdown()`` method on a resource provider, it will remove the reference to the initialized resource,
if any, and switch to uninitialized state. Some of resource initializer types support specifying custom
resource shutdown.
Resource provider supports 3 types of initializers:
- Function
@ -97,7 +101,7 @@ you configure global resource:
fname='logging.ini',
)
Function initializer does not support shutdown.
Function initializer does not provide a way to specify custom resource shutdown.
Generator initializer
---------------------
@ -235,4 +239,124 @@ The example above produces next output:
Shutdown service
127.0.0.1 - - [29/Oct/2020 22:39:41] "GET / HTTP/1.1" 200 -
.. _resource-async-initializers:
Asynchronous initializers
-------------------------
When you write an asynchronous application, you might need to initialize resources asynchronously. Resource
provider supports asynchronous initialization and shutdown.
Asynchronous function initializer:
.. code-block:: python
async def init_async_resource(argument1=..., argument2=...):
return await connect()
class Container(containers.DeclarativeContainer):
resource = providers.Resource(
init_resource,
argument1=...,
argument2=...,
)
Asynchronous generator initializer:
.. code-block:: python
async def init_async_resource(argument1=..., argument2=...):
connection = await connect()
yield connection
await connection.close()
class Container(containers.DeclarativeContainer):
resource = providers.Resource(
init_async_resource,
argument1=...,
argument2=...,
)
Asynchronous subclass initializer:
.. code-block:: python
from dependency_injector import resources
class AsyncConnection(resources.AsyncResource):
async def init(self, argument1=..., argument2=...):
yield await connect()
async def shutdown(self, connection):
await connection.close()
class Container(containers.DeclarativeContainer):
resource = providers.Resource(
AsyncConnection,
argument1=...,
argument2=...,
)
When you use resource provider with asynchronous initializer you need to call its ``__call__()``,
``init()``, and ``shutdown()`` methods asynchronously:
.. code-block:: python
import asyncio
class Container(containers.DeclarativeContainer):
connection = providers.Resource(init_async_connection)
async def main():
container = Container()
connection = await container.connection()
connection = await container.connection.init()
connection = await container.connection.shutdown()
if __name__ == '__main__':
asyncio.run(main())
Container ``init_resources()`` and ``shutdown_resources()`` methods should be used asynchronously if there is
at least one asynchronous resource provider:
.. code-block:: python
import asyncio
class Container(containers.DeclarativeContainer):
connection1 = providers.Resource(init_async_connection)
connection2 = providers.Resource(init_sync_connection)
async def main():
container = Container()
await container.init_resources()
await container.shutdown_resources()
if __name__ == '__main__':
asyncio.run(main())
See also:
- Provider :ref:`async-injections`
- Wiring :ref:`async-injections-wiring`
- :ref:`fastapi-redis-example`
.. disqus::

View File

@ -167,21 +167,105 @@ You can use that in testing to re-create and re-wire a container before each tes
avoid re-wiring between tests.
.. note::
Python has a limitation on patching already imported individual members. To protect from errors
prefer an import of modules instead of individual members or make sure that imports happen
Python has a limitation on patching individually imported functions. To protect from errors
prefer importing modules to importing individual functions or make sure imports happen
after the wiring:
.. code-block:: python
# Potential error:
from .module import fn
fn()
Instead use next:
.. code-block:: python
# Always works:
from . import module
module.fn()
# instead of
.. _async-injections-wiring:
from .module import fn
Asynchronous injections
-----------------------
fn()
Wiring feature supports asynchronous injections:
.. code-block:: python
class Container(containers.DeclarativeContainer):
db = providers.Resource(init_async_db_client)
cache = providers.Resource(init_async_cache_client)
@inject
async def main(
db: Database = Provide[Container.db],
cache: Cache = Provide[Container.cache],
):
...
When you call asynchronous function wiring prepares injections asynchronously.
Here is what it does for previous example:
.. code-block:: python
db, cache = await asyncio.gather(
container.db(),
container.cache(),
)
await main(db=db, cache=cache)
You can also use ``Closing`` marker with the asynchronous ``Resource`` providers:
.. code-block:: python
@inject
async def main(
db: Database = Closing[Provide[Container.db]],
cache: Cache = Closing[Provide[Container.cache]],
):
...
Wiring does closing asynchronously:
.. code-block:: python
db, cache = await asyncio.gather(
container.db(),
container.cache(),
)
await main(db=db, cache=cache)
await asyncio.gather(
container.db.shutdown(),
container.cache.shutdown(),
)
See :ref:`Resources, wiring and per-function execution scope <resource-provider-wiring-closing>` for
details on ``Closing`` marker.
.. note::
Wiring does not not convert asynchronous injections to synchronous.
It handles asynchronous injections only for ``async def`` functions. Asynchronous injections into
synchronous ``def`` function still work, but you need to take care of awaitables by your own.
See also:
- Provider :ref:`async-injections`
- Resource provider :ref:`resource-async-initializers`
- :ref:`fastapi-redis-example`
Integration with other frameworks
---------------------------------
@ -211,5 +295,6 @@ Take a look at other application examples:
- :ref:`aiohttp-example`
- :ref:`sanic-example`
- :ref:`fastapi-example`
- :ref:`fastapi-redis-example`
.. disqus::

View File

@ -0,0 +1,10 @@
FROM python:3.8-buster
ENV PYTHONUNBUFFERED=1
WORKDIR /code
COPY . /code/
RUN pip install -r requirements.txt
CMD ["uvicorn", "fastapiredis.application:app", "--host", "0.0.0.0"]

View File

@ -0,0 +1,89 @@
FastAPI + Redis + Dependency Injector Example
=============================================
This is a `FastAPI <https://docs.python.org/3/library/asyncio.html>`_
+ `Redis <https://redis.io/>`_
+ `Dependency Injector <https://python-dependency-injector.ets-labs.org/>`_ example application.
Run
---
Build the Docker image:
.. code-block:: bash
docker-compose build
Run the docker-compose environment:
.. code-block:: bash
docker-compose up
The output should be something like:
.. code-block::
redis_1 | 1:C 04 Jan 2021 02:42:14.115 # oO0OoO0OoO0Oo Redis is starting oO0OoO0OoO0Oo
redis_1 | 1:C 04 Jan 2021 02:42:14.115 # Redis version=6.0.9, bits=64, commit=00000000, modified=0, pid=1, just started
redis_1 | 1:C 04 Jan 2021 02:42:14.115 # Configuration loaded
redis_1 | 1:M 04 Jan 2021 02:42:14.116 * Running mode=standalone, port=6379.
redis_1 | 1:M 04 Jan 2021 02:42:14.116 # WARNING: The TCP backlog setting of 511 cannot be enforced because /proc/sys/net/core/somaxconn is set to the lower value of 128.
redis_1 | 1:M 04 Jan 2021 02:42:14.116 # Server initialized
redis_1 | 1:M 04 Jan 2021 02:42:14.117 * Loading RDB produced by version 6.0.9
redis_1 | 1:M 04 Jan 2021 02:42:14.117 * RDB age 1 seconds
redis_1 | 1:M 04 Jan 2021 02:42:14.117 * RDB memory usage when created 0.77 Mb
redis_1 | 1:M 04 Jan 2021 02:42:14.117 * DB loaded from disk: 0.000 seconds
redis_1 | 1:M 04 Jan 2021 02:42:14.117 * Ready to accept connections
redis_1 | 1:signal-handler (1609728137) Received SIGTERM scheduling shutdown...
redis_1 | 1:M 04 Jan 2021 02:42:17.984 # User requested shutdown...
redis_1 | 1:M 04 Jan 2021 02:42:17.984 # Redis is now ready to exit, bye bye...
redis_1 | 1:C 04 Jan 2021 02:42:22.035 # oO0OoO0OoO0Oo Redis is starting oO0OoO0OoO0Oo
redis_1 | 1:C 04 Jan 2021 02:42:22.035 # Redis version=6.0.9, bits=64, commit=00000000, modified=0, pid=1, just started
redis_1 | 1:C 04 Jan 2021 02:42:22.035 # Configuration loaded
redis_1 | 1:M 04 Jan 2021 02:42:22.037 * Running mode=standalone, port=6379.
redis_1 | 1:M 04 Jan 2021 02:42:22.037 # WARNING: The TCP backlog setting of 511 cannot be enforced because /proc/sys/net/core/somaxconn is set to the lower value of 128.
redis_1 | 1:M 04 Jan 2021 02:42:22.037 # Server initialized
redis_1 | 1:M 04 Jan 2021 02:42:22.037 * Loading RDB produced by version 6.0.9
redis_1 | 1:M 04 Jan 2021 02:42:22.037 * RDB age 9 seconds
redis_1 | 1:M 04 Jan 2021 02:42:22.037 * RDB memory usage when created 0.77 Mb
redis_1 | 1:M 04 Jan 2021 02:42:22.037 * DB loaded from disk: 0.000 seconds
redis_1 | 1:M 04 Jan 2021 02:42:22.037 * Ready to accept connections
example_1 | INFO: Started server process [1]
example_1 | INFO: Waiting for application startup.
example_1 | INFO: Application startup complete.
example_1 | INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
Test
----
This application comes with the unit tests.
To run the tests do:
.. code-block:: bash
docker-compose run --rm example py.test fastapiredis/tests.py --cov=fastapiredis
The output should be something like:
.. code-block::
platform linux -- Python 3.8.6, pytest-6.2.1, py-1.10.0, pluggy-0.13.1
rootdir: /code
plugins: cov-2.10.1, asyncio-0.14.0
collected 1 item
fastapiredis/tests.py . [100%]
----------- coverage: platform linux, python 3.8.6-final-0 -----------
Name Stmts Miss Cover
-------------------------------------------------
fastapiredis/__init__.py 0 0 100%
fastapiredis/application.py 15 0 100%
fastapiredis/containers.py 6 0 100%
fastapiredis/redis.py 7 4 43%
fastapiredis/services.py 7 3 57%
fastapiredis/tests.py 18 0 100%
-------------------------------------------------
TOTAL 53 7 87%

View File

@ -0,0 +1,21 @@
version: "3.7"
services:
example:
build: .
environment:
REDIS_HOST: "redis"
REDIS_PASSWORD: "password"
ports:
- "8000:8000"
volumes:
- "./:/code"
depends_on:
- "redis"
redis:
image: redis
command: ["redis-server", "--requirepass", "password"]
ports:
- "6379:6379"

View File

@ -0,0 +1 @@
"""Top-level package."""

View File

@ -0,0 +1,25 @@
"""Application module."""
import sys
from fastapi import FastAPI, Depends
from dependency_injector.wiring import inject, Provide
from .containers import Container
from .services import Service
app = FastAPI()
@app.api_route('/')
@inject
async def index(service: Service = Depends(Provide[Container.service])):
value = await service.process()
return {'result': value}
container = Container()
container.config.redis_host.from_env('REDIS_HOST', 'localhost')
container.config.redis_password.from_env('REDIS_PASSWORD', 'password')
container.wire(modules=[sys.modules[__name__]])

View File

@ -0,0 +1,21 @@
"""Containers module."""
from dependency_injector import containers, providers
from . import redis, services
class Container(containers.DeclarativeContainer):
config = providers.Configuration()
redis_pool = providers.Resource(
redis.init_redis_pool,
host=config.redis_host,
password=config.redis_password,
)
service = providers.Factory(
services.Service,
redis=redis_pool,
)

View File

@ -0,0 +1,12 @@
"""Redis client module."""
from typing import AsyncIterator
from aioredis import create_redis_pool, Redis
async def init_redis_pool(host: str, password: str) -> AsyncIterator[Redis]:
pool = await create_redis_pool(f'redis://{host}', password=password)
yield pool
pool.close()
await pool.wait_closed()

View File

@ -0,0 +1,12 @@
"""Services module."""
from aioredis import Redis
class Service:
def __init__(self, redis: Redis) -> None:
self._redis = redis
async def process(self) -> str:
await self._redis.set('my-key', 'value')
return await self._redis.get('my-key', encoding='utf-8')

View File

@ -0,0 +1,28 @@
"""Tests module."""
from unittest import mock
import pytest
from httpx import AsyncClient
from .application import app, container
from .services import Service
@pytest.fixture
def client(event_loop):
client = AsyncClient(app=app, base_url='http://test')
yield client
event_loop.run_until_complete(client.aclose())
@pytest.mark.asyncio
async def test_index(client):
service_mock = mock.AsyncMock(spec=Service)
service_mock.process.return_value = 'Foo'
with container.service.override(service_mock):
response = await client.get('/')
assert response.status_code == 200
assert response.json() == {'result': 'Foo'}

View File

@ -0,0 +1,10 @@
dependency-injector
fastapi
uvicorn
aioredis
# For testing:
pytest
pytest-asyncio
pytest-cov
httpx

View File

@ -0,0 +1,37 @@
"""Asynchronous injections example."""
import asyncio
from dependency_injector import containers, providers
async def init_async_resource():
await asyncio.sleep(0.1)
yield 'Initialized'
class Service:
def __init__(self, resource):
self.resource = resource
class Container(containers.DeclarativeContainer):
resource = providers.Resource(init_async_resource)
service = providers.Factory(
Service,
resource=resource,
)
async def main(container: Container):
resource = await container.resource()
service = await container.service()
...
if __name__ == '__main__':
container = Container()
asyncio.run(main(container))

View File

@ -0,0 +1,32 @@
"""Provider overriding in async mode example."""
import asyncio
from dependency_injector import containers, providers
async def init_async_resource():
return ...
def init_resource_mock():
return ...
class Container(containers.DeclarativeContainer):
resource = providers.Resource(init_async_resource)
async def main(container: Container):
resource1 = await container.resource()
container.resource.override(providers.Callable(init_resource_mock))
resource2 = await container.resource()
...
if __name__ == '__main__':
container = Container()
asyncio.run(main(container))

View File

@ -7,5 +7,8 @@ pydocstyle
sphinx_autobuild
pip
mypy
pyyaml
httpx
fastapi
-r requirements-ext.txt

View File

@ -4,6 +4,8 @@ max_complexity = 10
exclude = types.py
per-file-ignores =
examples/demo/*: F841
examples/providers/async.py: F841
examples/providers/async_overriding.py: F841
examples/wiring/*: F841
[pydocstyle]

File diff suppressed because it is too large Load Diff

View File

@ -11,3 +11,6 @@ cpdef bint is_container(object instance)
cpdef object _check_provider_type(object container, object provider)
cpdef bint _isawaitable(object instance)

View File

@ -1,5 +1,17 @@
from types import ModuleType
from typing import Type, Dict, Tuple, Optional, Any, Union, ClassVar, Callable as _Callable, Iterable, TypeVar
from typing import (
Type,
Dict,
Tuple,
Optional,
Any,
Union,
ClassVar,
Callable as _Callable,
Iterable,
TypeVar,
Awaitable,
)
from .providers import Provider
@ -25,8 +37,8 @@ class Container:
def resolve_provider_name(self, provider_to_resolve: Provider) -> Optional[str]: ...
def wire(self, modules: Optional[Iterable[ModuleType]] = None, packages: Optional[Iterable[ModuleType]] = None) -> None: ...
def unwire(self) -> None: ...
def init_resources(self) -> None: ...
def shutdown_resources(self) -> None: ...
def init_resources(self) -> Optional[Awaitable]: ...
def shutdown_resources(self) -> Optional[Awaitable]: ...
class DynamicContainer(Container): ...

View File

@ -1,7 +1,13 @@
"""Containers module."""
import inspect
import sys
try:
import asyncio
except ImportError:
asyncio = None
import six
from .errors import Error
@ -216,17 +222,33 @@ class DynamicContainer(object):
def init_resources(self):
"""Initialize all container resources."""
futures = []
for provider in self.providers.values():
if not isinstance(provider, Resource):
continue
provider.init()
resource = provider.init()
if _isawaitable(resource):
futures.append(resource)
if futures:
return asyncio.gather(*futures)
def shutdown_resources(self):
"""Shutdown all container resources."""
futures = []
for provider in self.providers.values():
if not isinstance(provider, Resource):
continue
provider.shutdown()
shutdown = provider.shutdown()
if _isawaitable(shutdown):
futures.append(shutdown)
if futures:
return asyncio.gather(*futures)
class DeclarativeContainerMetaClass(type):
@ -494,3 +516,10 @@ cpdef object _check_provider_type(object container, object provider):
if not isinstance(provider, container.provider_type):
raise Error('{0} can contain only {1} '
'instances'.format(container, container.provider_type))
cpdef bint _isawaitable(object instance):
try:
return <bint> inspect.isawaitable(instance)
except AttributeError:
return <bint> False

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,12 @@
"""Providers module."""
try:
import asyncio
except ImportError:
asyncio = None
import inspect
cimport cython
@ -7,6 +14,7 @@ cimport cython
cdef class Provider(object):
cdef tuple __overridden
cdef Provider __last_overriding
cdef int __async_mode
cpdef object _provide(self, tuple args, dict kwargs)
cpdef void _copy_overridings(self, Provider copied, dict memo)
@ -134,10 +142,10 @@ cdef class FactoryAggregate(Provider):
# Singleton providers
cdef class BaseSingleton(Provider):
cdef Factory __instantiator
cdef object __storage
cdef class Singleton(BaseSingleton):
cdef object __storage
cpdef object _provide(self, tuple args, dict kwargs)
@ -147,7 +155,6 @@ cdef class DelegatedSingleton(Singleton):
cdef class ThreadSafeSingleton(BaseSingleton):
cdef object __storage
cdef object __storage_lock
cpdef object _provide(self, tuple args, dict kwargs)
@ -158,7 +165,6 @@ cdef class DelegatedThreadSafeSingleton(ThreadSafeSingleton):
cdef class ThreadLocalSingleton(BaseSingleton):
cdef object __storage
cpdef object _provide(self, tuple args, dict kwargs)
@ -331,30 +337,38 @@ cdef inline tuple __separate_prefixed_kwargs(dict kwargs):
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline tuple __provide_positional_args(
cdef inline object __provide_positional_args(
tuple args,
tuple inj_args,
int inj_args_len,
):
cdef int index
cdef list positional_args
cdef list positional_args = []
cdef list awaitables = []
cdef PositionalInjection injection
if inj_args_len == 0:
return args
positional_args = list()
for index in range(inj_args_len):
injection = <PositionalInjection>inj_args[index]
positional_args.append(__get_value(injection))
value = __get_value(injection)
positional_args.append(value)
if __isawaitable(value):
awaitables.append((index, value))
positional_args.extend(args)
return tuple(positional_args)
if awaitables:
return __awaitable_args_kwargs_future(positional_args, awaitables)
return positional_args
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline dict __provide_keyword_args(
cdef inline object __provide_keyword_args(
dict kwargs,
tuple inj_kwargs,
int inj_kwargs_len,
@ -362,14 +376,18 @@ cdef inline dict __provide_keyword_args(
cdef int index
cdef object name
cdef object value
cdef dict prefixed
cdef dict prefixed = {}
cdef list awaitables = []
cdef NamedInjection kw_injection
if len(kwargs) == 0:
for index in range(inj_kwargs_len):
kw_injection = <NamedInjection>inj_kwargs[index]
name = __get_name(kw_injection)
kwargs[name] = __get_value(kw_injection)
value = __get_value(kw_injection)
kwargs[name] = value
if __isawaitable(value):
awaitables.append((name, value))
else:
kwargs, prefixed = __separate_prefixed_kwargs(kwargs)
@ -387,23 +405,77 @@ cdef inline dict __provide_keyword_args(
value = __get_value(kw_injection)
kwargs[name] = value
if __isawaitable(value):
awaitables.append((name, value))
if awaitables:
return __awaitable_args_kwargs_future(kwargs, awaitables)
return kwargs
cdef inline object __awaitable_args_kwargs_future(object args, list awaitables):
future_result = asyncio.Future()
args_future = asyncio.Future()
args_future.set_result((future_result, args, awaitables))
args_ready = asyncio.gather(args_future, *[value for _, value in awaitables])
args_ready.add_done_callback(__async_prepare_args_kwargs_callback)
asyncio.ensure_future(args_ready)
return future_result
cdef inline void __async_prepare_args_kwargs_callback(object future):
(future_result, args, awaitables), *awaited = future.result()
for value, (key, _) in zip(awaited, awaitables):
args[key] = value
future_result.set_result(args)
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline object __inject_attributes(
object instance,
tuple attributes,
int attributes_len,
):
cdef inline object __provide_attributes(tuple attributes, int attributes_len):
cdef NamedInjection attr_injection
cdef dict attribute_injections = {}
cdef list awaitables = []
for index in range(attributes_len):
attr_injection = <NamedInjection>attributes[index]
setattr(instance,
__get_name(attr_injection),
__get_value(attr_injection))
name = __get_name(attr_injection)
value = __get_value(attr_injection)
attribute_injections[name] = value
if __isawaitable(value):
awaitables.append((name, value))
if awaitables:
return __awaitable_args_kwargs_future(attribute_injections, awaitables)
return attribute_injections
cdef inline object __async_inject_attributes(future_instance, future_attributes):
future_result = asyncio.Future()
future = asyncio.Future()
future.set_result(future_result)
attributes_ready = asyncio.gather(future, future_instance, future_attributes)
attributes_ready.add_done_callback(__async_inject_attributes_callback)
asyncio.ensure_future(attributes_ready)
return future_result
cdef inline void __async_inject_attributes_callback(future):
future_result, instance, attributes = future.result()
__inject_attributes(instance, attributes)
future_result.set_result(instance)
cdef inline void __inject_attributes(object instance, dict attributes):
for name, value in attributes.items():
setattr(instance, name, value)
cdef inline object __call(
@ -411,25 +483,53 @@ cdef inline object __call(
tuple context_args,
tuple injection_args,
int injection_args_len,
dict kwargs,
dict context_kwargs,
tuple injection_kwargs,
int injection_kwargs_len,
):
cdef tuple positional_args
cdef dict keyword_args
positional_args = __provide_positional_args(
args = __provide_positional_args(
context_args,
injection_args,
injection_args_len,
)
keyword_args = __provide_keyword_args(
kwargs,
kwargs = __provide_keyword_args(
context_kwargs,
injection_kwargs,
injection_kwargs_len,
)
return call(*positional_args, **keyword_args)
args_awaitable = __isawaitable(args)
kwargs_awaitable = __isawaitable(kwargs)
if args_awaitable or kwargs_awaitable:
if not args_awaitable:
future = asyncio.Future()
future.set_result(args)
args = future
if not kwargs_awaitable:
future = asyncio.Future()
future.set_result(kwargs)
kwargs = future
future_result = asyncio.Future()
future = asyncio.Future()
future.set_result((future_result, call))
args_kwargs_ready = asyncio.gather(future, args, kwargs)
args_kwargs_ready.add_done_callback(__async_call_callback)
asyncio.ensure_future(args_kwargs_ready)
return future_result
return call(*args, **kwargs)
cdef inline void __async_call_callback(object future):
(future_result, call), args, kwargs = future.result()
result = call(*args, **kwargs)
future_result.set_result(result)
cdef inline object __callable_call(Callable self, tuple args, dict kwargs):
@ -450,8 +550,40 @@ cdef inline object __factory_call(Factory self, tuple args, dict kwargs):
instance = __callable_call(self.__instantiator, args, kwargs)
if self.__attributes_len > 0:
__inject_attributes(instance,
self.__attributes,
self.__attributes_len)
attributes = __provide_attributes(self.__attributes, self.__attributes_len)
instance_awaitable = __isawaitable(instance)
attributes_awaitable = __isawaitable(attributes)
if instance_awaitable or attributes_awaitable:
if not instance_awaitable:
future = asyncio.Future()
future.set_result(instance)
instance = future
if not attributes_awaitable:
future = asyncio.Future()
future.set_result(attributes)
attributes = future
return __async_inject_attributes(instance, attributes)
__inject_attributes(instance, attributes)
return instance
cdef bint __has_isawaitable = False
cdef inline bint __isawaitable(object instance):
global __has_isawaitable
if __has_isawaitable is True:
return inspect.isawaitable(instance)
if hasattr(inspect, 'isawaitable'):
__has_isawaitable = True
return inspect.isawaitable(instance)
return False

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from pathlib import Path
from typing import (
Awaitable,
TypeVar,
Generic,
Type,
@ -14,6 +15,7 @@ from typing import (
Union,
Coroutine as _Coroutine,
Iterator as _Iterator,
AsyncIterator as _AsyncIterator,
Generator as _Generator,
overload,
)
@ -33,7 +35,13 @@ class OverridingContext:
class Provider(Generic[T]):
def __init__(self) -> None: ...
@overload
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
@overload
def __call__(self, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ...
def async_(self, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ...
def __deepcopy__(self, memo: Optional[_Dict[Any, Any]]) -> Provider: ...
def __str__(self) -> str: ...
def __repr__(self) -> str: ...
@ -49,30 +57,33 @@ class Provider(Generic[T]):
def provider(self) -> Provider: ...
@property
def provided(self) -> ProvidedInstance: ...
def enable_async_mode(self) -> None: ...
def disable_async_mode(self) -> None: ...
def reset_async_mode(self) -> None: ...
def is_async_mode_enabled(self) -> bool: ...
def is_async_mode_disabled(self) -> bool: ...
def is_async_mode_undefined(self) -> bool: ...
def _copy_overridings(self, copied: Provider, memo: Optional[_Dict[Any, Any]]) -> None: ...
class Object(Provider, Generic[T]):
class Object(Provider[T]):
def __init__(self, provides: T) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
class Delegate(Provider):
class Delegate(Provider[Provider]):
def __init__(self, provides: Provider) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> Provider: ...
@property
def provides(self) -> Provider: ...
class Dependency(Provider, Generic[T]):
class Dependency(Provider[T]):
def __init__(self, instance_of: Type[T] = object) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
@property
def instance_of(self) -> Type[T]: ...
def provided_by(self, provider: Provider) -> OverridingContext: ...
class ExternalDependency(Dependency): ...
class ExternalDependency(Dependency[T]): ...
class DependenciesContainer(Object):
@ -82,9 +93,8 @@ class DependenciesContainer(Object):
def providers(self) -> _Dict[str, Provider]: ...
class Callable(Provider, Generic[T]):
class Callable(Provider[T]):
def __init__(self, provides: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
@property
def provides(self) -> T: ...
@property
@ -93,16 +103,16 @@ class Callable(Provider, Generic[T]):
def set_args(self, *args: Injection) -> Callable[T]: ...
def clear_args(self) -> Callable[T]: ...
@property
def kwargs(self) -> _Dict[str, Injection]: ...
def kwargs(self) -> _Dict[Any, Injection]: ...
def add_kwargs(self, **kwargs: Injection) -> Callable[T]: ...
def set_kwargs(self, **kwargs: Injection) -> Callable[T]: ...
def clear_kwargs(self) -> Callable[T]: ...
class DelegatedCallable(Callable): ...
class DelegatedCallable(Callable[T]): ...
class AbstractCallable(Callable):
class AbstractCallable(Callable[T]):
def override(self, provider: Callable) -> OverridingContext: ...
@ -110,13 +120,13 @@ class CallableDelegate(Delegate):
def __init__(self, callable: Callable) -> None: ...
class Coroutine(Callable): ...
class Coroutine(Callable[T]): ...
class DelegatedCoroutine(Coroutine): ...
class DelegatedCoroutine(Coroutine[T]): ...
class AbstractCoroutine(Coroutine):
class AbstractCoroutine(Coroutine[T]):
def override(self, provider: Coroutine) -> OverridingContext: ...
@ -124,10 +134,9 @@ class CoroutineDelegate(Delegate):
def __init__(self, coroutine: Coroutine) -> None: ...
class ConfigurationOption(Provider):
class ConfigurationOption(Provider[Any]):
UNDEFINED: object
def __init__(self, name: Tuple[str], root: Configuration) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> Any: ...
def __getattr__(self, item: str) -> ConfigurationOption: ...
def __getitem__(self, item: Union[str, Provider]) -> ConfigurationOption: ...
@property
@ -149,7 +158,7 @@ class TypedConfigurationOption(Callable[T]):
def option(self) -> ConfigurationOption: ...
class Configuration(Object):
class Configuration(Object[Any]):
DEFAULT_NAME: str = 'config'
def __init__(self, name: str = DEFAULT_NAME, default: Optional[Any] = None) -> None: ...
def __getattr__(self, item: str) -> ConfigurationOption: ...
@ -165,10 +174,9 @@ class Configuration(Object):
def from_env(self, name: str, default: Optional[Any] = None) -> None: ...
class Factory(Provider, Generic[T]):
class Factory(Provider[T]):
provided_type: Optional[Type]
def __init__(self, provides: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
@property
def cls(self) -> T: ...
@property
@ -179,21 +187,21 @@ class Factory(Provider, Generic[T]):
def set_args(self, *args: Injection) -> Factory[T]: ...
def clear_args(self) -> Factory[T]: ...
@property
def kwargs(self) -> _Dict[str, Injection]: ...
def kwargs(self) -> _Dict[Any, Injection]: ...
def add_kwargs(self, **kwargs: Injection) -> Factory[T]: ...
def set_kwargs(self, **kwargs: Injection) -> Factory[T]: ...
def clear_kwargs(self) -> Factory[T]: ...
@property
def attributes(self) -> _Dict[str, Injection]: ...
def attributes(self) -> _Dict[Any, Injection]: ...
def add_attributes(self, **kwargs: Injection) -> Factory[T]: ...
def set_attributes(self, **kwargs: Injection) -> Factory[T]: ...
def clear_attributes(self) -> Factory[T]: ...
class DelegatedFactory(Factory): ...
class DelegatedFactory(Factory[T]): ...
class AbstractFactory(Factory):
class AbstractFactory(Factory[T]):
def override(self, provider: Factory) -> OverridingContext: ...
@ -203,55 +211,60 @@ class FactoryDelegate(Delegate):
class FactoryAggregate(Provider):
def __init__(self, **factories: Factory): ...
def __call__(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Any: ...
def __getattr__(self, factory_name: str) -> Factory: ...
@overload
def __call__(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Any: ...
@overload
def __call__(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Awaitable[Any]: ...
def async_(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Awaitable[Any]: ...
@property
def factories(self) -> _Dict[str, Factory]: ...
class BaseSingleton(Provider, Generic[T]):
class BaseSingleton(Provider[T]):
provided_type = Optional[Type]
def __init__(self, provides: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
@property
def cls(self) -> T: ...
@property
def args(self) -> Tuple[Injection]: ...
def add_args(self, *args: Injection) -> Factory[T]: ...
def set_args(self, *args: Injection) -> Factory[T]: ...
def clear_args(self) -> Factory[T]: ...
def add_args(self, *args: Injection) -> BaseSingleton[T]: ...
def set_args(self, *args: Injection) -> BaseSingleton[T]: ...
def clear_args(self) -> BaseSingleton[T]: ...
@property
def kwargs(self) -> _Dict[str, Injection]: ...
def add_kwargs(self, **kwargs: Injection) -> Factory[T]: ...
def set_kwargs(self, **kwargs: Injection) -> Factory[T]: ...
def clear_kwargs(self) -> Factory[T]: ...
def kwargs(self) -> _Dict[Any, Injection]: ...
def add_kwargs(self, **kwargs: Injection) -> BaseSingleton[T]: ...
def set_kwargs(self, **kwargs: Injection) -> BaseSingleton[T]: ...
def clear_kwargs(self) -> BaseSingleton[T]: ...
@property
def attributes(self) -> _Dict[str, Injection]: ...
def add_attributes(self, **kwargs: Injection) -> Factory[T]: ...
def set_attributes(self, **kwargs: Injection) -> Factory[T]: ...
def clear_attributes(self) -> Factory[T]: ...
def attributes(self) -> _Dict[Any, Injection]: ...
def add_attributes(self, **kwargs: Injection) -> BaseSingleton[T]: ...
def set_attributes(self, **kwargs: Injection) -> BaseSingleton[T]: ...
def clear_attributes(self) -> BaseSingleton[T]: ...
def reset(self) -> None: ...
class Singleton(BaseSingleton): ...
class Singleton(BaseSingleton[T]): ...
class DelegatedSingleton(Singleton): ...
class DelegatedSingleton(Singleton[T]): ...
class ThreadSafeSingleton(Singleton): ...
class ThreadSafeSingleton(Singleton[T]): ...
class DelegatedThreadSafeSingleton(ThreadSafeSingleton): ...
class DelegatedThreadSafeSingleton(ThreadSafeSingleton[T]): ...
class ThreadLocalSingleton(BaseSingleton): ...
class ThreadLocalSingleton(BaseSingleton[T]): ...
class DelegatedThreadLocalSingleton(ThreadLocalSingleton): ...
class DelegatedThreadLocalSingleton(ThreadLocalSingleton[T]): ...
class AbstractSingleton(BaseSingleton):
class AbstractSingleton(BaseSingleton[T]):
def override(self, provider: BaseSingleton) -> OverridingContext: ...
@ -259,19 +272,17 @@ class SingletonDelegate(Delegate):
def __init__(self, factory: BaseSingleton): ...
class List(Provider):
class List(Provider[_List]):
def __init__(self, *args: Injection): ...
def __call__(self, *args: Injection, **kwargs: Injection) -> _List[Any]: ...
@property
def args(self) -> Tuple[Injection]: ...
def add_args(self, *args: Injection) -> List: ...
def set_args(self, *args: Injection) -> List: ...
def clear_args(self) -> List: ...
def add_args(self, *args: Injection) -> List[T]: ...
def set_args(self, *args: Injection) -> List[T]: ...
def clear_args(self) -> List[T]: ...
class Dict(Provider):
class Dict(Provider[_Dict]):
def __init__(self, dict_: Optional[_Dict[Any, Injection]] = None, **kwargs: Injection): ...
def __call__(self, *args: Injection, **kwargs: Injection) -> _Dict[Any, Any]: ...
@property
def kwargs(self) -> _Dict[Any, Injection]: ...
def add_kwargs(self, dict_: Optional[_Dict[Any, Injection]] = None, **kwargs: Injection) -> Dict: ...
@ -279,42 +290,44 @@ class Dict(Provider):
def clear_kwargs(self) -> Dict: ...
class Resource(Provider, Generic[T]):
class Resource(Provider[T]):
@overload
def __init__(self, initializer: _Callable[..., resources.Resource[T]], *args: Injection, **kwargs: Injection) -> None: ...
def __init__(self, initializer: Type[resources.Resource[T]], *args: Injection, **kwargs: Injection) -> None: ...
@overload
def __init__(self, initializer: Type[resources.AsyncResource[T]], *args: Injection, **kwargs: Injection) -> None: ...
@overload
def __init__(self, initializer: _Callable[..., _Iterator[T]], *args: Injection, **kwargs: Injection) -> None: ...
@overload
def __init__(self, initializer: _Callable[..., _AsyncIterator[T]], *args: Injection, **kwargs: Injection) -> None: ...
@overload
def __init__(self, initializer: _Callable[..., _Coroutine[Injection, Injection, T]], *args: Injection, **kwargs: Injection) -> None: ...
@overload
def __init__(self, initializer: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
@property
def args(self) -> Tuple[Injection]: ...
def add_args(self, *args: Injection) -> Resource: ...
def set_args(self, *args: Injection) -> Resource: ...
def clear_args(self) -> Resource: ...
def add_args(self, *args: Injection) -> Resource[T]: ...
def set_args(self, *args: Injection) -> Resource[T]: ...
def clear_args(self) -> Resource[T]: ...
@property
def kwargs(self) -> _Dict[Any, Injection]: ...
def add_kwargs(self, **kwargs: Injection) -> Resource: ...
def set_kwargs(self, **kwargs: Injection) -> Resource: ...
def clear_kwargs(self) -> Resource: ...
def add_kwargs(self, **kwargs: Injection) -> Resource[T]: ...
def set_kwargs(self, **kwargs: Injection) -> Resource[T]: ...
def clear_kwargs(self) -> Resource[T]: ...
@property
def initialized(self) -> bool: ...
def init(self) -> T: ...
def shutdown(self) -> None: ...
def init(self) -> Optional[Awaitable[T]]: ...
def shutdown(self) -> Optional[Awaitable]: ...
class Container(Provider):
class Container(Provider[T]):
def __init__(self, container_cls: Type[T], container: Optional[T] = None, **overriding_providers: Provider) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
def __getattr__(self, name: str) -> Provider: ...
@property
def container(self) -> T: ...
class Selector(Provider):
class Selector(Provider[Any]):
def __init__(self, selector: _Callable[..., Any], **providers: Provider): ...
def __call__(self, *args: Injection, **kwargs: Injection) -> Any: ...
def __getattr__(self, name: str) -> Provider: ...
@property
def providers(self) -> _Dict[str, Provider]: ...

View File

@ -3,6 +3,7 @@
from __future__ import absolute_import
import copy
import functools
import inspect
import os
import re
@ -88,6 +89,11 @@ else:
return parser
cdef int ASYNC_MODE_UNDEFINED = 0
cdef int ASYNC_MODE_ENABLED = 1
cdef int ASYNC_MODE_DISABLED = 2
cdef class Provider(object):
"""Base provider class.
@ -148,6 +154,7 @@ cdef class Provider(object):
"""Initializer."""
self.__overridden = tuple()
self.__last_overriding = None
self.__async_mode = ASYNC_MODE_UNDEFINED
super(Provider, self).__init__()
def __call__(self, *args, **kwargs):
@ -156,8 +163,24 @@ cdef class Provider(object):
Callable interface implementation.
"""
if self.__last_overriding is not None:
return self.__last_overriding(*args, **kwargs)
return self._provide(args, kwargs)
result = self.__last_overriding(*args, **kwargs)
else:
result = self._provide(args, kwargs)
if self.is_async_mode_disabled():
return result
elif self.is_async_mode_enabled():
if not __isawaitable(result):
future_result = asyncio.Future()
future_result.set_result(result)
return future_result
return result
elif self.is_async_mode_undefined():
if __isawaitable(result):
self.enable_async_mode()
else:
self.disable_async_mode()
return result
def __deepcopy__(self, memo):
"""Create and return full copy of provider."""
@ -254,6 +277,23 @@ cdef class Provider(object):
self.__overridden = tuple()
self.__last_overriding = None
def async_(self, *args, **kwargs):
"""Return provided object asynchronously.
This method is a synonym of __call__().
It provides typing stubs for correct type checking with
`await` expression:
.. code-block:: python
database_provider: Provider[DatabaseConnection] = Resource(init_db_async)
async def main():
db: DatabaseConnection = await database_provider.async_()
...
"""
return self.__call__(*args, **kwargs)
def delegate(self):
"""Return provider's delegate.
@ -279,6 +319,33 @@ cdef class Provider(object):
"""Return :py:class:`ProvidedInstance` provider."""
return ProvidedInstance(self)
def enable_async_mode(self):
"""Enable async mode."""
self.__async_mode = ASYNC_MODE_ENABLED
def disable_async_mode(self):
"""Disable async mode."""
self.__async_mode = ASYNC_MODE_DISABLED
def reset_async_mode(self):
"""Reset async mode.
Provider will automatically set the mode on the next call.
"""
self.__async_mode = ASYNC_MODE_UNDEFINED
def is_async_mode_enabled(self):
"""Check if async mode is enabled."""
return self.__async_mode == ASYNC_MODE_ENABLED
def is_async_mode_disabled(self):
"""Check if async mode is disabled."""
return self.__async_mode == ASYNC_MODE_DISABLED
def is_async_mode_undefined(self):
"""Check if async mode is undefined."""
return self.__async_mode == ASYNC_MODE_UNDEFINED
cpdef object _provide(self, tuple args, dict kwargs):
"""Providing strategy implementation.
@ -472,18 +539,38 @@ cdef class Dependency(Provider):
:rtype: object
"""
cdef object instance
if self.__last_overriding is None:
raise Error('Dependency is not defined')
instance = self.__last_overriding(*args, **kwargs)
result = self.__last_overriding(*args, **kwargs)
if not isinstance(instance, self.instance_of):
raise Error('{0} is not an '.format(instance) +
'instance of {0}'.format(self.instance_of))
return instance
if self.is_async_mode_disabled():
self._check_instance_type(result)
return result
elif self.is_async_mode_enabled():
if __isawaitable(result):
future_result = asyncio.Future()
result = asyncio.ensure_future(result)
result.add_done_callback(functools.partial(self._async_provide, future_result))
return future_result
else:
self._check_instance_type(result)
future_result = asyncio.Future()
future_result.set_result(result)
return future_result
elif self.is_async_mode_undefined():
if __isawaitable(result):
self.enable_async_mode()
future_result = asyncio.Future()
result = asyncio.ensure_future(result)
result.add_done_callback(functools.partial(self._async_provide, future_result))
return future_result
else:
self.disable_async_mode()
self._check_instance_type(result)
return result
def __str__(self):
"""Return string representation of provider.
@ -514,6 +601,19 @@ cdef class Dependency(Provider):
"""
return self.override(provider)
def _async_provide(self, future_result, future):
instance = future.result()
try:
self._check_instance_type(instance)
except Error as exception:
future_result.set_exception(exception)
else:
future_result.set_result(instance)
def _check_instance_type(self, instance):
if not isinstance(instance, self.instance_of):
raise Error('{0} is not an instance of {1}'.format(instance, self.instance_of))
cdef class ExternalDependency(Dependency):
""":py:class:`ExternalDependency` provider describes dependency interface.
@ -904,7 +1004,7 @@ cdef class AbstractCallable(Callable):
"""
if self.__last_overriding is None:
raise Error('{0} must be overridden before calling'.format(self))
return self.__last_overriding(*args, **kwargs)
return super().__call__(*args, **kwargs)
def override(self, provider):
"""Override provider with another provider.
@ -1020,7 +1120,7 @@ cdef class AbstractCoroutine(Coroutine):
"""
if self.__last_overriding is None:
raise Error('{0} must be overridden before calling'.format(self))
return self.__last_overriding(*args, **kwargs)
return super().__call__(*args, **kwargs)
def override(self, provider):
"""Override provider with another provider.
@ -1790,7 +1890,7 @@ cdef class AbstractFactory(Factory):
"""
if self.__last_overriding is None:
raise Error('{0} must be overridden before calling'.format(self))
return self.__last_overriding(*args, **kwargs)
return super().__call__(*args, **kwargs)
def override(self, provider):
"""Override provider with another provider.
@ -1881,13 +1981,6 @@ cdef class FactoryAggregate(Provider):
return copied
def __call__(self, factory_name, *args, **kwargs):
"""Create new object using factory with provided name.
Callable interface implementation.
"""
return self.__get_factory(factory_name)(*args, **kwargs)
def __getattr__(self, factory_name):
"""Return aggregated factory."""
return self.__get_factory(factory_name)
@ -1915,6 +2008,19 @@ cdef class FactoryAggregate(Provider):
raise Error(
'{0} providers could not be overridden'.format(self.__class__))
cpdef object _provide(self, tuple args, dict kwargs):
try:
factory_name = args[0]
except IndexError:
try:
factory_name = kwargs.pop('factory_name')
except KeyError:
raise TypeError('Factory missing 1 required positional argument: \'factory_name\'')
else:
args = args[1:]
return self.__get_factory(factory_name)(*args, **kwargs)
cdef Factory __get_factory(self, str factory_name):
if factory_name not in self.__factories:
raise NoSuchProviderError(
@ -2075,6 +2181,16 @@ cdef class BaseSingleton(Provider):
"""
raise NotImplementedError()
def _async_init_instance(self, future_result, result):
try:
instance = result.result()
except Exception as exception:
self.__storage = None
future_result.set_exception(exception)
else:
self.__storage = instance
future_result.set_result(instance)
cdef class Singleton(BaseSingleton):
"""Singleton provider returns same instance on every call.
@ -2122,13 +2238,24 @@ cdef class Singleton(BaseSingleton):
:rtype: None
"""
if __isawaitable(self.__storage):
asyncio.ensure_future(self.__storage).cancel()
self.__storage = None
cpdef object _provide(self, tuple args, dict kwargs):
"""Return single instance."""
if self.__storage is None:
self.__storage = __factory_call(self.__instantiator,
args, kwargs)
instance = __factory_call(self.__instantiator, args, kwargs)
if __isawaitable(instance):
future_result = asyncio.Future()
instance = asyncio.ensure_future(instance)
instance.add_done_callback(functools.partial(self._async_init_instance, future_result))
self.__storage = future_result
return future_result
self.__storage = instance
return self.__storage
@ -2179,18 +2306,30 @@ cdef class ThreadSafeSingleton(BaseSingleton):
:rtype: None
"""
with self.__storage_lock:
if __isawaitable(self.__storage):
asyncio.ensure_future(self.__storage).cancel()
self.__storage = None
cpdef object _provide(self, tuple args, dict kwargs):
"""Return single instance."""
storage = self.__storage
if storage is None:
instance = self.__storage
if instance is None:
with self.__storage_lock:
if self.__storage is None:
self.__storage = __factory_call(self.__instantiator,
args, kwargs)
storage = self.__storage
return storage
instance = __factory_call(self.__instantiator, args, kwargs)
if __isawaitable(instance):
future_result = asyncio.Future()
instance = asyncio.ensure_future(instance)
instance.add_done_callback(functools.partial(self._async_init_instance, future_result))
self.__storage = future_result
return future_result
self.__storage = instance
return instance
cdef class DelegatedThreadSafeSingleton(ThreadSafeSingleton):
@ -2248,6 +2387,8 @@ cdef class ThreadLocalSingleton(BaseSingleton):
:rtype: None
"""
if __isawaitable(self.__storage.instance):
asyncio.ensure_future(self.__storage.instance).cancel()
del self.__storage.instance
cpdef object _provide(self, tuple args, dict kwargs):
@ -2258,10 +2399,28 @@ cdef class ThreadLocalSingleton(BaseSingleton):
instance = self.__storage.instance
except AttributeError:
instance = __factory_call(self.__instantiator, args, kwargs)
if __isawaitable(instance):
future_result = asyncio.Future()
instance = asyncio.ensure_future(instance)
instance.add_done_callback(functools.partial(self._async_init_instance, future_result))
self.__storage.instance = future_result
return future_result
self.__storage.instance = instance
finally:
return instance
def _async_init_instance(self, future_result, result):
try:
instance = result.result()
except Exception as exception:
del self.__storage.instance
future_result.set_exception(exception)
else:
self.__storage.instance = instance
future_result.set_result(instance)
cdef class DelegatedThreadLocalSingleton(ThreadLocalSingleton):
"""Delegated thread-local singleton is injected "as is".
@ -2302,7 +2461,7 @@ cdef class AbstractSingleton(BaseSingleton):
"""
if self.__last_overriding is None:
raise Error('{0} must be overridden before calling'.format(self))
return self.__last_overriding(*args, **kwargs)
return super().__call__(*args, **kwargs)
def override(self, provider):
"""Override provider with another provider.
@ -2705,18 +2864,30 @@ cdef class Resource(Provider):
def shutdown(self):
"""Shutdown resource."""
if not self.__initialized:
if self.is_async_mode_enabled():
result = asyncio.Future()
result.set_result(None)
return result
return
if self.__shutdowner:
try:
self.__shutdowner(self.__resource)
shutdown = self.__shutdowner(self.__resource)
except StopIteration:
pass
else:
if inspect.isawaitable(shutdown):
return self._create_shutdown_future(shutdown)
self.__resource = None
self.__initialized = False
self.__shutdowner = None
if self.is_async_mode_enabled():
result = asyncio.Future()
result.set_result(None)
return result
cpdef object _provide(self, tuple args, dict kwargs):
if self.__initialized:
return self.__resource
@ -2733,6 +2904,19 @@ cdef class Resource(Provider):
self.__kwargs_len,
)
self.__shutdowner = initializer.shutdown
elif self._is_async_resource_subclass(self.__initializer):
initializer = self.__initializer()
async_init = __call(
initializer.init,
args,
self.__args,
self.__args_len,
kwargs,
self.__kwargs,
self.__kwargs_len,
)
self.__initialized = True
return self._create_init_future(async_init, initializer.shutdown)
elif inspect.isgeneratorfunction(self.__initializer):
initializer = __call(
self.__initializer,
@ -2745,6 +2929,30 @@ cdef class Resource(Provider):
)
self.__resource = next(initializer)
self.__shutdowner = initializer.send
elif iscoroutinefunction(self.__initializer):
initializer = __call(
self.__initializer,
args,
self.__args,
self.__args_len,
kwargs,
self.__kwargs,
self.__kwargs_len,
)
self.__initialized = True
return self._create_init_future(initializer)
elif isasyncgenfunction(self.__initializer):
initializer = __call(
self.__initializer,
args,
self.__args,
self.__args_len,
kwargs,
self.__kwargs,
self.__kwargs_len,
)
self.__initialized = True
return self._create_init_future(initializer.__anext__(), initializer.asend)
elif callable(self.__initializer):
self.__resource = __call(
self.__initializer,
@ -2761,6 +2969,45 @@ cdef class Resource(Provider):
self.__initialized = True
return self.__resource
def _create_init_future(self, future, shutdowner=None):
callback = self._async_init_callback
if shutdowner:
callback = functools.partial(callback, shutdowner=shutdowner)
future = asyncio.ensure_future(future)
future.add_done_callback(callback)
self.__resource = future
return future
def _async_init_callback(self, initializer, shutdowner=None):
try:
resource = initializer.result()
except Exception:
self.__initialized = False
raise
else:
self.__resource = resource
self.__shutdowner = shutdowner
def _create_shutdown_future(self, shutdown_future):
future = asyncio.Future()
shutdown_future = asyncio.ensure_future(shutdown_future)
shutdown_future.add_done_callback(functools.partial(self._async_shutdown_callback, future))
return future
def _async_shutdown_callback(self, future_result, shutdowner):
try:
shutdowner.result()
except StopAsyncIteration:
pass
self.__resource = None
self.__initialized = False
self.__shutdowner = None
future_result.set_result(None)
@staticmethod
def _is_resource_subclass(instance):
if sys.version_info < (3, 5):
@ -2770,6 +3017,15 @@ cdef class Resource(Provider):
from . import resources
return issubclass(instance, resources.Resource)
@staticmethod
def _is_async_resource_subclass(instance):
if sys.version_info < (3, 5):
return False
if not isinstance(instance, CLASS_TYPES):
return
from . import resources
return issubclass(instance, resources.AsyncResource)
cdef class Container(Provider):
"""Container provider provides an instance of declarative container.
@ -3037,8 +3293,18 @@ cdef class AttributeGetter(Provider):
cpdef object _provide(self, tuple args, dict kwargs):
provided = self.__provider(*args, **kwargs)
if __isawaitable(provided):
future_result = asyncio.Future()
provided = asyncio.ensure_future(provided)
provided.add_done_callback(functools.partial(self._async_provide, future_result))
return future_result
return getattr(provided, self.__attribute)
def _async_provide(self, future_result, future):
provided = future.result()
result = getattr(provided, self.__attribute)
future_result.set_result(result)
cdef class ItemGetter(Provider):
"""Provider that returns the item of the injected instance.
@ -3087,8 +3353,18 @@ cdef class ItemGetter(Provider):
cpdef object _provide(self, tuple args, dict kwargs):
provided = self.__provider(*args, **kwargs)
if __isawaitable(provided):
future_result = asyncio.Future()
provided = asyncio.ensure_future(provided)
provided.add_done_callback(functools.partial(self._async_provide, future_result))
return future_result
return provided[self.__item]
def _async_provide(self, future_result, future):
provided = future.result()
result = provided[self.__item]
future_result.set_result(result)
cdef class MethodCaller(Provider):
"""Provider that calls the method of the injected instance.
@ -3169,6 +3445,11 @@ cdef class MethodCaller(Provider):
cpdef object _provide(self, tuple args, dict kwargs):
call = self.__provider()
if __isawaitable(call):
future_result = asyncio.Future()
call = asyncio.ensure_future(call)
call.add_done_callback(functools.partial(self._async_provide, future_result, args, kwargs))
return future_result
return __call(
call,
args,
@ -3179,6 +3460,19 @@ cdef class MethodCaller(Provider):
self.__kwargs_len,
)
def _async_provide(self, future_result, args, kwargs, future):
call = future.result()
result = __call(
call,
args,
self.__args,
self.__args_len,
kwargs,
self.__kwargs,
self.__kwargs_len,
)
future_result.set_result(result)
cdef class Injection(object):
"""Abstract injection class."""
@ -3381,3 +3675,36 @@ def merge_dicts(dict1, dict2):
result = dict1.copy()
result.update(dict2)
return result
def isawaitable(obj):
"""Check if object is a coroutine function.
Return False for any object in Python 3.4 or below.
"""
try:
return inspect.isawaitable(obj)
except AttributeError:
return False
def iscoroutinefunction(obj):
"""Check if object is a coroutine function.
Return False for any object in Python 3.4 or below.
"""
try:
return inspect.iscoroutinefunction(obj)
except AttributeError:
return False
def isasyncgenfunction(obj):
"""Check if object is an asynchronous generator function.
Return False for any object in Python 3.4 or below.
"""
try:
return inspect.isasyncgenfunction(obj)
except AttributeError:
return False

View File

@ -29,3 +29,14 @@ class Resource(Generic[T], metaclass=ResourceMeta):
@abc.abstractmethod
def shutdown(self, resource: T) -> None:
...
class AsyncResource(Generic[T], metaclass=ResourceMeta):
@abc.abstractmethod
async def init(self, *args, **kwargs) -> T:
...
@abc.abstractmethod
async def shutdown(self, resource: T) -> None:
...

View File

@ -1,5 +1,6 @@
"""Wiring module."""
import asyncio
import functools
import inspect
import importlib
@ -426,10 +427,20 @@ def _get_async_patched(fn):
@functools.wraps(fn)
async def _patched(*args, **kwargs):
to_inject = kwargs.copy()
to_inject_await = []
to_close_await = []
for injection, provider in _patched.__injections__.items():
if injection not in kwargs \
or _is_fastapi_default_arg_injection(injection, kwargs):
to_inject[injection] = provider()
provide = provider()
if inspect.isawaitable(provide):
to_inject_await.append((injection, provide))
else:
to_inject[injection] = provide
async_to_inject = await asyncio.gather(*[provide for _, provide in to_inject_await])
for provide, (injection, _) in zip(async_to_inject, to_inject_await):
to_inject[injection] = provide
result = await fn(*args, **to_inject)
@ -439,7 +450,11 @@ def _get_async_patched(fn):
continue
if not isinstance(provider, providers.Resource):
continue
provider.shutdown()
shutdown = provider.shutdown()
if inspect.isawaitable(shutdown):
to_close_await.append(shutdown)
await asyncio.gather(*to_close_await)
return result
return _patched

View File

@ -50,3 +50,9 @@ animal7: Animal = provider7(1, 2, 3, b='1', c=2, e=0.0)
# Test 8: to check the CallableDelegate __init__
provider8 = providers.CallableDelegate(providers.Callable(lambda: None))
# Test 9: to check the return type with await
provider9 = providers.Callable(Cat)
async def _async9() -> None:
animal1: Animal = await provider9(1, 2, 3, b='1', c=2, e=0.0) # type: ignore
animal2: Animal = await provider9.async_(1, 2, 3, b='1', c=2, e=0.0)

View File

@ -4,3 +4,9 @@ from dependency_injector import providers
# Test 1: to check the return type
provider1 = providers.Delegate(providers.Provider())
var1: providers.Provider = provider1()
# Test 2: to check the return type with await
provider2 = providers.Delegate(providers.Provider())
async def _async2() -> None:
var1: providers.Provider = await provider2() # type: ignore
var2: providers.Provider = await provider2.async_()

View File

@ -20,3 +20,9 @@ var1: Animal = provider1()
# Test 2: to check the return type
provider2 = providers.Dependency(instance_of=Animal)
var2: Type[Animal] = provider2.instance_of
# Test 3: to check the return type with await
provider3 = providers.Dependency(instance_of=Animal)
async def _async3() -> None:
var1: Animal = await provider3() # type: ignore
var2: Animal = await provider3.async_()

View File

@ -35,3 +35,13 @@ provider5 = providers.Dict(
a2=providers.Factory(object),
)
provided5: providers.ProvidedInstance = provider5.provided
# Test 6: to check the return type with await
provider6 = providers.Dict(
a1=providers.Factory(object),
a2=providers.Factory(object),
)
async def _async3() -> None:
var1: Dict[Any, Any] = await provider6() # type: ignore
var2: Dict[Any, Any] = await provider6.async_()

View File

@ -66,3 +66,9 @@ val9: Any = provider9('a')
# Test 10: to check the explicit typing
factory10: providers.Provider[Animal] = providers.Factory(Cat)
animal10: Animal = factory10()
# Test 11: to check the return type with await
provider11 = providers.Factory(Cat)
async def _async11() -> None:
animal1: Animal = await provider11(1, 2, 3, b='1', c=2, e=0.0) # type: ignore
animal2: Animal = await provider11.async_(1, 2, 3, b='1', c=2, e=0.0)

View File

@ -27,3 +27,12 @@ provided3: providers.ProvidedInstance = provider3.provided
attr_getter3: providers.AttributeGetter = provider3.provided.attr
item_getter3: providers.ItemGetter = provider3.provided['item']
method_caller3: providers.MethodCaller = provider3.provided.method.call(123, arg=324)
# Test 4: to check the return type with await
provider4 = providers.List(
providers.Factory(object),
providers.Factory(object),
)
async def _async4() -> None:
var1: List[Any] = await provider4() # type: ignore
var2: List[Any] = await provider4.async_()

View File

@ -11,3 +11,9 @@ provided2: providers.ProvidedInstance = provider2.provided
attr_getter2: providers.AttributeGetter = provider2.provided.attr
item_getter2: providers.ItemGetter = provider2.provided['item']
method_caller2: providers.MethodCaller = provider2.provided.method.call(123, arg=324)
# Test 3: to check the return type with await
provider3 = providers.Object(int(3))
async def _async3() -> None:
var1: int = await provider3() # type: ignore
var2: int = await provider3.async_()

View File

@ -4,3 +4,12 @@ from dependency_injector import providers
# Test 1: to check .provided attribute
provider1: providers.Provider[int] = providers.Object(1)
provided: providers.ProvidedInstance = provider1.provided
# Test 2: to check async mode API
provider2: providers.Provider = providers.Provider()
provider2.enable_async_mode()
provider2.disable_async_mode()
provider2.reset_async_mode()
r1: bool = provider2.is_async_mode_enabled()
r2: bool = provider2.is_async_mode_disabled()
r3: bool = provider2.is_async_mode_undefined()

View File

@ -1,4 +1,4 @@
from typing import List, Iterator, Generator
from typing import List, Iterator, Generator, AsyncIterator, AsyncGenerator
from dependency_injector import providers, resources
@ -41,3 +41,59 @@ class MyResource4(resources.Resource[List[int]]):
provider4 = providers.Resource(MyResource4)
var4: List[int] = provider4()
# Test 5: to check the return type with async function
async def init5() -> List[int]:
...
provider5 = providers.Resource(init5)
async def _provide5() -> None:
var1: List[int] = await provider5() # type: ignore
var2: List[int] = await provider5.async_()
# Test 6: to check the return type with async iterator
async def init6() -> AsyncIterator[List[int]]:
yield []
provider6 = providers.Resource(init6)
async def _provide6() -> None:
var1: List[int] = await provider6() # type: ignore
var2: List[int] = await provider6.async_()
# Test 7: to check the return type with async generator
async def init7() -> AsyncGenerator[List[int], None]:
yield []
provider7 = providers.Resource(init7)
async def _provide7() -> None:
var1: List[int] = await provider7() # type: ignore
var2: List[int] = await provider7.async_()
# Test 8: to check the return type with async resource subclass
class MyResource8(resources.AsyncResource[List[int]]):
async def init(self, *args, **kwargs) -> List[int]:
return []
async def shutdown(self, resource: List[int]) -> None:
...
provider8 = providers.Resource(MyResource8)
async def _provide8() -> None:
var1: List[int] = await provider8() # type: ignore
var2: List[int] = await provider8.async_()

View File

@ -1,3 +1,5 @@
from typing import Any
from dependency_injector import providers
@ -7,7 +9,7 @@ provider1 = providers.Selector(
a=providers.Factory(object),
b=providers.Factory(object),
)
var1: int = provider1()
var1: Any = provider1()
# Test 2: to check the provided instance interface
provider2 = providers.Selector(
@ -27,3 +29,13 @@ provider3 = providers.Selector(
b=providers.Factory(object),
)
attr3: providers.Provider = provider3.a
# Test 4: to check the return type with await
provider4 = providers.Selector(
lambda: 'a',
a=providers.Factory(object),
b=providers.Factory(object),
)
async def _async4() -> None:
var1: Any = await provider4() # type: ignore
var2: Any = await provider4.async_()

View File

@ -69,3 +69,9 @@ animal11: Animal = provider11(1, 2, 3, b='1', c=2, e=0.0)
# Test 12: to check the SingletonDelegate __init__
provider12 = providers.SingletonDelegate(providers.Singleton(object))
# Test 13: to check the return type with await
provider13 = providers.Singleton(Cat)
async def _async13() -> None:
animal1: Animal = await provider13(1, 2, 3, b='1', c=2, e=0.0) # type: ignore
animal2: Animal = await provider13.async_(1, 2, 3, b='1', c=2, e=0.0)

57
tests/unit/asyncutils.py Normal file
View File

@ -0,0 +1,57 @@
"""Test utils."""
import asyncio
import contextlib
import sys
import gc
import unittest
def run(main):
loop = asyncio.get_event_loop()
return loop.run_until_complete(main)
def setup_test_loop(
loop_factory=asyncio.new_event_loop
) -> asyncio.AbstractEventLoop:
loop = loop_factory()
try:
module = loop.__class__.__module__
skip_watcher = 'uvloop' in module
except AttributeError: # pragma: no cover
# Just in case
skip_watcher = True
asyncio.set_event_loop(loop)
if sys.platform != 'win32' and not skip_watcher:
policy = asyncio.get_event_loop_policy()
watcher = asyncio.SafeChildWatcher() # type: ignore
watcher.attach_loop(loop)
with contextlib.suppress(NotImplementedError):
policy.set_child_watcher(watcher)
return loop
def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
closed = loop.is_closed()
if not closed:
loop.call_soon(loop.stop)
loop.run_forever()
loop.close()
if not fast:
gc.collect()
asyncio.set_event_loop(None)
class AsyncTestCase(unittest.TestCase):
def setUp(self):
self.loop = setup_test_loop()
def tearDown(self):
teardown_test_loop(self.loop)
def _run(self, f):
return self.loop.run_until_complete(f)

View File

@ -0,0 +1,71 @@
"""Dependency injector dynamic container unit tests for async resources."""
import unittest2 as unittest
# Runtime import to get asyncutils module
import os
_TOP_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
'../',
)),
)
import sys
sys.path.append(_TOP_DIR)
from asyncutils import AsyncTestCase
from dependency_injector import (
containers,
providers,
)
class AsyncResourcesTest(AsyncTestCase):
@unittest.skipIf(sys.version_info[:2] <= (3, 5), 'Async test')
def test_async_init_resources(self):
async def _init1():
_init1.init_counter += 1
yield
_init1.shutdown_counter += 1
_init1.init_counter = 0
_init1.shutdown_counter = 0
async def _init2():
_init2.init_counter += 1
yield
_init2.shutdown_counter += 1
_init2.init_counter = 0
_init2.shutdown_counter = 0
class Container(containers.DeclarativeContainer):
resource1 = providers.Resource(_init1)
resource2 = providers.Resource(_init2)
container = Container()
self.assertEqual(_init1.init_counter, 0)
self.assertEqual(_init1.shutdown_counter, 0)
self.assertEqual(_init2.init_counter, 0)
self.assertEqual(_init2.shutdown_counter, 0)
self._run(container.init_resources())
self.assertEqual(_init1.init_counter, 1)
self.assertEqual(_init1.shutdown_counter, 0)
self.assertEqual(_init2.init_counter, 1)
self.assertEqual(_init2.shutdown_counter, 0)
self._run(container.shutdown_resources())
self.assertEqual(_init1.init_counter, 1)
self.assertEqual(_init1.shutdown_counter, 1)
self.assertEqual(_init2.init_counter, 1)
self.assertEqual(_init2.shutdown_counter, 1)
self._run(container.init_resources())
self._run(container.shutdown_resources())
self.assertEqual(_init1.init_counter, 2)
self.assertEqual(_init1.shutdown_counter, 2)
self.assertEqual(_init2.init_counter, 2)
self.assertEqual(_init2.shutdown_counter, 2)

View File

@ -231,7 +231,3 @@ class DeclarativeContainerInstanceTests(unittest.TestCase):
self.assertEqual(_init1.shutdown_counter, 2)
self.assertEqual(_init2.init_counter, 2)
self.assertEqual(_init2.shutdown_counter, 2)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,818 @@
import asyncio
import random
import unittest
from dependency_injector import containers, providers, errors
# Runtime import to get asyncutils module
import os
_TOP_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
'../',
)),
)
import sys
sys.path.append(_TOP_DIR)
from asyncutils import AsyncTestCase
RESOURCE1 = object()
RESOURCE2 = object()
async def init_resource(resource):
await asyncio.sleep(random.randint(1, 10) / 1000)
yield resource
await asyncio.sleep(random.randint(1, 10) / 1000)
class Client:
def __init__(self, resource1: object, resource2: object) -> None:
self.resource1 = resource1
self.resource2 = resource2
class Service:
def __init__(self, client: Client) -> None:
self.client = client
class Container(containers.DeclarativeContainer):
resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1))
resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2))
client = providers.Factory(
Client,
resource1=resource1,
resource2=resource2,
)
service = providers.Factory(
Service,
client=client,
)
class FactoryTests(AsyncTestCase):
def test_args_injection(self):
class ContainerWithArgs(containers.DeclarativeContainer):
resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1))
resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2))
client = providers.Factory(
Client,
resource1,
resource2,
)
service = providers.Factory(
Service,
client,
)
container = ContainerWithArgs()
client1 = self._run(container.client())
client2 = self._run(container.client())
self.assertIsInstance(client1, Client)
self.assertIs(client1.resource1, RESOURCE1)
self.assertIs(client1.resource2, RESOURCE2)
self.assertIsInstance(client2, Client)
self.assertIs(client2.resource1, RESOURCE1)
self.assertIs(client2.resource2, RESOURCE2)
service1 = self._run(container.service())
service2 = self._run(container.service())
self.assertIsInstance(service1, Service)
self.assertIsInstance(service1.client, Client)
self.assertIs(service1.client.resource1, RESOURCE1)
self.assertIs(service1.client.resource2, RESOURCE2)
self.assertIsInstance(service2, Service)
self.assertIsInstance(service2.client, Client)
self.assertIs(service2.client.resource1, RESOURCE1)
self.assertIs(service2.client.resource2, RESOURCE2)
self.assertIsNot(service1.client, service2.client)
def test_kwargs_injection(self):
container = Container()
client1 = self._run(container.client())
client2 = self._run(container.client())
self.assertIsInstance(client1, Client)
self.assertIs(client1.resource1, RESOURCE1)
self.assertIs(client1.resource2, RESOURCE2)
self.assertIsInstance(client2, Client)
self.assertIs(client2.resource1, RESOURCE1)
self.assertIs(client2.resource2, RESOURCE2)
service1 = self._run(container.service())
service2 = self._run(container.service())
self.assertIsInstance(service1, Service)
self.assertIsInstance(service1.client, Client)
self.assertIs(service1.client.resource1, RESOURCE1)
self.assertIs(service1.client.resource2, RESOURCE2)
self.assertIsInstance(service2, Service)
self.assertIsInstance(service2.client, Client)
self.assertIs(service2.client.resource1, RESOURCE1)
self.assertIs(service2.client.resource2, RESOURCE2)
self.assertIsNot(service1.client, service2.client)
def test_context_kwargs_injection(self):
resource2_extra = object()
container = Container()
client1 = self._run(container.client(resource2=resource2_extra))
client2 = self._run(container.client(resource2=resource2_extra))
self.assertIsInstance(client1, Client)
self.assertIs(client1.resource1, RESOURCE1)
self.assertIs(client1.resource2, resource2_extra)
self.assertIsInstance(client2, Client)
self.assertIs(client2.resource1, RESOURCE1)
self.assertIs(client2.resource2, resource2_extra)
def test_args_kwargs_injection(self):
class ContainerWithArgsAndKwArgs(containers.DeclarativeContainer):
resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1))
resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2))
client = providers.Factory(
Client,
resource1,
resource2=resource2,
)
service = providers.Factory(
Service,
client=client,
)
container = ContainerWithArgsAndKwArgs()
client1 = self._run(container.client())
client2 = self._run(container.client())
self.assertIsInstance(client1, Client)
self.assertIs(client1.resource1, RESOURCE1)
self.assertIs(client1.resource2, RESOURCE2)
self.assertIsInstance(client2, Client)
self.assertIs(client2.resource1, RESOURCE1)
self.assertIs(client2.resource2, RESOURCE2)
service1 = self._run(container.service())
service2 = self._run(container.service())
self.assertIsInstance(service1, Service)
self.assertIsInstance(service1.client, Client)
self.assertIs(service1.client.resource1, RESOURCE1)
self.assertIs(service1.client.resource2, RESOURCE2)
self.assertIsInstance(service2, Service)
self.assertIsInstance(service2.client, Client)
self.assertIs(service2.client.resource1, RESOURCE1)
self.assertIs(service2.client.resource2, RESOURCE2)
self.assertIsNot(service1.client, service2.client)
def test_attributes_injection(self):
class ContainerWithAttributes(containers.DeclarativeContainer):
resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1))
resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2))
client = providers.Factory(
Client,
resource1,
resource2=None,
)
client.add_attributes(resource2=resource2)
service = providers.Factory(
Service,
client=None,
)
service.add_attributes(client=client)
container = ContainerWithAttributes()
client1 = self._run(container.client())
client2 = self._run(container.client())
self.assertIsInstance(client1, Client)
self.assertIs(client1.resource1, RESOURCE1)
self.assertIs(client1.resource2, RESOURCE2)
self.assertIsInstance(client2, Client)
self.assertIs(client2.resource1, RESOURCE1)
self.assertIs(client2.resource2, RESOURCE2)
service1 = self._run(container.service())
service2 = self._run(container.service())
self.assertIsInstance(service1, Service)
self.assertIsInstance(service1.client, Client)
self.assertIs(service1.client.resource1, RESOURCE1)
self.assertIs(service1.client.resource2, RESOURCE2)
self.assertIsInstance(service2, Service)
self.assertIsInstance(service2.client, Client)
self.assertIs(service2.client.resource1, RESOURCE1)
self.assertIs(service2.client.resource2, RESOURCE2)
self.assertIsNot(service1.client, service2.client)
class FactoryAggregateTests(AsyncTestCase):
def test_async_mode(self):
object1 = object()
object2 = object()
async def _get_object1():
return object1
def _get_object2():
return object2
provider = providers.FactoryAggregate(
object1=providers.Factory(_get_object1),
object2=providers.Factory(_get_object2),
)
self.assertTrue(provider.is_async_mode_undefined())
created_object1 = self._run(provider('object1'))
self.assertIs(created_object1, object1)
self.assertTrue(provider.is_async_mode_enabled())
created_object2 = self._run(provider('object2'))
self.assertIs(created_object2, object2)
class SingletonTests(AsyncTestCase):
def test_injections(self):
class ContainerWithSingletons(containers.DeclarativeContainer):
resource1 = providers.Resource(init_resource, providers.Object(RESOURCE1))
resource2 = providers.Resource(init_resource, providers.Object(RESOURCE2))
client = providers.Singleton(
Client,
resource1=resource1,
resource2=resource2,
)
service = providers.Singleton(
Service,
client=client,
)
container = ContainerWithSingletons()
client1 = self._run(container.client())
client2 = self._run(container.client())
self.assertIsInstance(client1, Client)
self.assertIs(client1.resource1, RESOURCE1)
self.assertIs(client1.resource2, RESOURCE2)
self.assertIsInstance(client2, Client)
self.assertIs(client2.resource1, RESOURCE1)
self.assertIs(client2.resource2, RESOURCE2)
service1 = self._run(container.service())
service2 = self._run(container.service())
self.assertIsInstance(service1, Service)
self.assertIsInstance(service1.client, Client)
self.assertIs(service1.client.resource1, RESOURCE1)
self.assertIs(service1.client.resource2, RESOURCE2)
self.assertIsInstance(service2, Service)
self.assertIsInstance(service2.client, Client)
self.assertIs(service2.client.resource1, RESOURCE1)
self.assertIs(service2.client.resource2, RESOURCE2)
self.assertIs(service1, service2)
self.assertIs(service1.client, service2.client)
self.assertIs(service1.client, client1)
self.assertIs(service2.client, client2)
self.assertIs(client1, client2)
def test_async_mode(self):
instance = object()
async def create_instance():
return instance
provider = providers.Singleton(create_instance)
instance1 = self._run(provider())
instance2 = self._run(provider())
self.assertIs(instance1, instance2)
self.assertIs(instance, instance)
def test_async_init_with_error(self):
# Disable default exception handling to prevent output
asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...)
async def create_instance():
create_instance.counter += 1
raise RuntimeError()
create_instance.counter = 0
provider = providers.Singleton(create_instance)
future = provider()
self.assertTrue(provider.is_async_mode_enabled())
with self.assertRaises(RuntimeError):
self._run(future)
self.assertEqual(create_instance.counter, 1)
self.assertTrue(provider.is_async_mode_enabled())
with self.assertRaises(RuntimeError):
self._run(provider())
self.assertEqual(create_instance.counter, 2)
self.assertTrue(provider.is_async_mode_enabled())
# Restore default exception handling
asyncio.get_event_loop().set_exception_handler(None)
class DelegatedSingletonTests(AsyncTestCase):
def test_async_mode(self):
instance = object()
async def create_instance():
return instance
provider = providers.DelegatedSingleton(create_instance)
instance1 = self._run(provider())
instance2 = self._run(provider())
self.assertIs(instance1, instance2)
self.assertIs(instance, instance)
class ThreadSafeSingletonTests(AsyncTestCase):
def test_async_mode(self):
instance = object()
async def create_instance():
return instance
provider = providers.ThreadSafeSingleton(create_instance)
instance1 = self._run(provider())
instance2 = self._run(provider())
self.assertIs(instance1, instance2)
self.assertIs(instance, instance)
class DelegatedThreadSafeSingletonTests(AsyncTestCase):
def test_async_mode(self):
instance = object()
async def create_instance():
return instance
provider = providers.DelegatedThreadSafeSingleton(create_instance)
instance1 = self._run(provider())
instance2 = self._run(provider())
self.assertIs(instance1, instance2)
self.assertIs(instance, instance)
class ThreadLocalSingletonTests(AsyncTestCase):
def test_async_mode(self):
instance = object()
async def create_instance():
return instance
provider = providers.ThreadLocalSingleton(create_instance)
instance1 = self._run(provider())
instance2 = self._run(provider())
self.assertIs(instance1, instance2)
self.assertIs(instance, instance)
def test_async_init_with_error(self):
# Disable default exception handling to prevent output
asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...)
async def create_instance():
create_instance.counter += 1
raise RuntimeError()
create_instance.counter = 0
provider = providers.ThreadLocalSingleton(create_instance)
future = provider()
self.assertTrue(provider.is_async_mode_enabled())
with self.assertRaises(RuntimeError):
self._run(future)
self.assertEqual(create_instance.counter, 1)
self.assertTrue(provider.is_async_mode_enabled())
with self.assertRaises(RuntimeError):
self._run(provider())
self.assertEqual(create_instance.counter, 2)
self.assertTrue(provider.is_async_mode_enabled())
# Restore default exception handling
asyncio.get_event_loop().set_exception_handler(None)
class DelegatedThreadLocalSingletonTests(AsyncTestCase):
def test_async_mode(self):
instance = object()
async def create_instance():
return instance
provider = providers.DelegatedThreadLocalSingleton(create_instance)
instance1 = self._run(provider())
instance2 = self._run(provider())
self.assertIs(instance1, instance2)
self.assertIs(instance, instance)
class ProvidedInstanceTests(AsyncTestCase):
def test_provided_attribute(self):
class TestClient:
def __init__(self, resource):
self.resource = resource
class TestService:
def __init__(self, resource):
self.resource = resource
class TestContainer(containers.DeclarativeContainer):
resource = providers.Resource(init_resource, providers.Object(RESOURCE1))
client = providers.Factory(TestClient, resource=resource)
service = providers.Factory(TestService, resource=client.provided.resource)
container = TestContainer()
instance1, instance2 = self._run(
asyncio.gather(
container.service(),
container.service(),
),
)
self.assertIs(instance1.resource, RESOURCE1)
self.assertIs(instance2.resource, RESOURCE1)
self.assertIs(instance1.resource, instance2.resource)
def test_provided_item(self):
class TestClient:
def __init__(self, resource):
self.resource = resource
def __getitem__(self, item):
return getattr(self, item)
class TestService:
def __init__(self, resource):
self.resource = resource
class TestContainer(containers.DeclarativeContainer):
resource = providers.Resource(init_resource, providers.Object(RESOURCE1))
client = providers.Factory(TestClient, resource=resource)
service = providers.Factory(TestService, resource=client.provided['resource'])
container = TestContainer()
instance1, instance2 = self._run(
asyncio.gather(
container.service(),
container.service(),
),
)
self.assertIs(instance1.resource, RESOURCE1)
self.assertIs(instance2.resource, RESOURCE1)
self.assertIs(instance1.resource, instance2.resource)
def test_provided_method_call(self):
class TestClient:
def __init__(self, resource):
self.resource = resource
def get_resource(self):
return self.resource
class TestService:
def __init__(self, resource):
self.resource = resource
class TestContainer(containers.DeclarativeContainer):
resource = providers.Resource(init_resource, providers.Object(RESOURCE1))
client = providers.Factory(TestClient, resource=resource)
service = providers.Factory(TestService, resource=client.provided.get_resource.call())
container = TestContainer()
instance1, instance2 = self._run(
asyncio.gather(
container.service(),
container.service(),
),
)
self.assertIs(instance1.resource, RESOURCE1)
self.assertIs(instance2.resource, RESOURCE1)
self.assertIs(instance1.resource, instance2.resource)
class DependencyTests(AsyncTestCase):
def test_isinstance(self):
dependency = 1.0
async def get_async():
return dependency
provider = providers.Dependency(instance_of=float)
provider.override(providers.Callable(get_async))
self.assertTrue(provider.is_async_mode_undefined())
dependency1 = self._run(provider())
self.assertTrue(provider.is_async_mode_enabled())
dependency2 = self._run(provider())
self.assertEqual(dependency1, dependency)
self.assertEqual(dependency2, dependency)
def test_isinstance_invalid(self):
async def get_async():
return {}
provider = providers.Dependency(instance_of=float)
provider.override(providers.Callable(get_async))
self.assertTrue(provider.is_async_mode_undefined())
with self.assertRaises(errors.Error):
self._run(provider())
self.assertTrue(provider.is_async_mode_enabled())
def test_async_mode(self):
dependency = 123
async def get_async():
return dependency
def get_sync():
return dependency
provider = providers.Dependency(instance_of=int)
provider.override(providers.Factory(get_async))
self.assertTrue(provider.is_async_mode_undefined())
dependency1 = self._run(provider())
self.assertTrue(provider.is_async_mode_enabled())
dependency2 = self._run(provider())
self.assertEqual(dependency1, dependency)
self.assertEqual(dependency2, dependency)
provider.override(providers.Factory(get_sync))
dependency3 = self._run(provider())
self.assertTrue(provider.is_async_mode_enabled())
dependency4 = self._run(provider())
self.assertEqual(dependency3, dependency)
self.assertEqual(dependency4, dependency)
class OverrideTests(AsyncTestCase):
def test_provider(self):
dependency = object()
async def _get_dependency_async():
return dependency
def _get_dependency_sync():
return dependency
provider = providers.Provider()
provider.override(providers.Callable(_get_dependency_async))
dependency1 = self._run(provider())
provider.override(providers.Callable(_get_dependency_sync))
dependency2 = self._run(provider())
self.assertIs(dependency1, dependency)
self.assertIs(dependency2, dependency)
def test_callable(self):
dependency = object()
async def _get_dependency_async():
return dependency
def _get_dependency_sync():
return dependency
provider = providers.Callable(_get_dependency_async)
dependency1 = self._run(provider())
provider.override(providers.Callable(_get_dependency_sync))
dependency2 = self._run(provider())
self.assertIs(dependency1, dependency)
self.assertIs(dependency2, dependency)
def test_factory(self):
dependency = object()
async def _get_dependency_async():
return dependency
def _get_dependency_sync():
return dependency
provider = providers.Factory(_get_dependency_async)
dependency1 = self._run(provider())
provider.override(providers.Callable(_get_dependency_sync))
dependency2 = self._run(provider())
self.assertIs(dependency1, dependency)
self.assertIs(dependency2, dependency)
def test_async_mode_enabling(self):
dependency = object()
async def _get_dependency_async():
return dependency
provider = providers.Callable(_get_dependency_async)
self.assertTrue(provider.is_async_mode_undefined())
self._run(provider())
self.assertTrue(provider.is_async_mode_enabled())
def test_async_mode_disabling(self):
dependency = object()
def _get_dependency():
return dependency
provider = providers.Callable(_get_dependency)
self.assertTrue(provider.is_async_mode_undefined())
provider()
self.assertTrue(provider.is_async_mode_disabled())
def test_async_mode_enabling_on_overriding(self):
dependency = object()
async def _get_dependency_async():
return dependency
provider = providers.Provider()
provider.override(providers.Callable(_get_dependency_async))
self.assertTrue(provider.is_async_mode_undefined())
self._run(provider())
self.assertTrue(provider.is_async_mode_enabled())
def test_async_mode_disabling_on_overriding(self):
dependency = object()
def _get_dependency():
return dependency
provider = providers.Provider()
provider.override(providers.Callable(_get_dependency))
self.assertTrue(provider.is_async_mode_undefined())
provider()
self.assertTrue(provider.is_async_mode_disabled())
class TestAsyncModeApi(unittest.TestCase):
def setUp(self):
self.provider = providers.Provider()
def test_default_mode(self):
self.assertFalse(self.provider.is_async_mode_enabled())
self.assertFalse(self.provider.is_async_mode_disabled())
self.assertTrue(self.provider.is_async_mode_undefined())
def test_enable(self):
self.provider.enable_async_mode()
self.assertTrue(self.provider.is_async_mode_enabled())
self.assertFalse(self.provider.is_async_mode_disabled())
self.assertFalse(self.provider.is_async_mode_undefined())
def test_disable(self):
self.provider.disable_async_mode()
self.assertFalse(self.provider.is_async_mode_enabled())
self.assertTrue(self.provider.is_async_mode_disabled())
self.assertFalse(self.provider.is_async_mode_undefined())
def test_reset(self):
self.provider.enable_async_mode()
self.assertTrue(self.provider.is_async_mode_enabled())
self.assertFalse(self.provider.is_async_mode_disabled())
self.assertFalse(self.provider.is_async_mode_undefined())
self.provider.reset_async_mode()
self.assertFalse(self.provider.is_async_mode_enabled())
self.assertFalse(self.provider.is_async_mode_disabled())
self.assertTrue(self.provider.is_async_mode_undefined())
class AsyncTypingStubTests(AsyncTestCase):
def test_async_(self):
container = Container()
client1 = self._run(container.client.async_())
client2 = self._run(container.client.async_())
self.assertIsInstance(client1, Client)
self.assertIs(client1.resource1, RESOURCE1)
self.assertIs(client1.resource2, RESOURCE2)
self.assertIsInstance(client2, Client)
self.assertIs(client2.resource1, RESOURCE1)
self.assertIs(client2.resource2, RESOURCE2)
service1 = self._run(container.service.async_())
service2 = self._run(container.service.async_())
self.assertIsInstance(service1, Service)
self.assertIsInstance(service1.client, Client)
self.assertIs(service1.client.resource1, RESOURCE1)
self.assertIs(service1.client.resource2, RESOURCE2)
self.assertIsInstance(service2, Service)
self.assertIsInstance(service2.client, Client)
self.assertIs(service2.client.resource1, RESOURCE1)
self.assertIs(service2.client.resource2, RESOURCE2)
self.assertIsNot(service1.client, service2.client)

View File

@ -1,9 +1,6 @@
"""Dependency injector coroutine providers unit tests."""
import asyncio
import contextlib
import sys
import gc
import unittest2 as unittest
@ -12,6 +9,19 @@ from dependency_injector import (
errors,
)
# Runtime import to get asyncutils module
import os
_TOP_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
'../',
)),
)
import sys
sys.path.append(_TOP_DIR)
from asyncutils import AsyncTestCase
async def _example(arg1, arg2, arg3, arg4):
future = asyncio.Future()
@ -25,52 +35,6 @@ def run(main):
return loop.run_until_complete(main)
def setup_test_loop(
loop_factory=asyncio.new_event_loop
) -> asyncio.AbstractEventLoop:
loop = loop_factory()
try:
module = loop.__class__.__module__
skip_watcher = 'uvloop' in module
except AttributeError: # pragma: no cover
# Just in case
skip_watcher = True
asyncio.set_event_loop(loop)
if sys.platform != "win32" and not skip_watcher:
policy = asyncio.get_event_loop_policy()
watcher = asyncio.SafeChildWatcher() # type: ignore
watcher.attach_loop(loop)
with contextlib.suppress(NotImplementedError):
policy.set_child_watcher(watcher)
return loop
def teardown_test_loop(loop: asyncio.AbstractEventLoop,
fast: bool=False) -> None:
closed = loop.is_closed()
if not closed:
loop.call_soon(loop.stop)
loop.run_forever()
loop.close()
if not fast:
gc.collect()
asyncio.set_event_loop(None)
class AsyncTestCase(unittest.TestCase):
def setUp(self):
self.loop = setup_test_loop()
def tearDown(self):
teardown_test_loop(self.loop)
def _run(self, f):
return self.loop.run_until_complete(f)
class CoroutineTests(AsyncTestCase):
def test_init_with_coroutine(self):

View File

@ -520,6 +520,24 @@ class FactoryAggregateTests(unittest.TestCase):
self.assertEqual(object_b.init_arg3, 33)
self.assertEqual(object_b.init_arg4, 44)
def test_call_factory_name_as_kwarg(self):
object_a = self.factory_aggregate(
factory_name='example_a',
init_arg1=1,
init_arg2=2,
init_arg3=3,
init_arg4=4,
)
self.assertIsInstance(object_a, self.ExampleA)
self.assertEqual(object_a.init_arg1, 1)
self.assertEqual(object_a.init_arg2, 2)
self.assertEqual(object_a.init_arg3, 3)
self.assertEqual(object_a.init_arg4, 4)
def test_call_no_factory_name(self):
with self.assertRaises(TypeError):
self.factory_aggregate()
def test_call_no_such_provider(self):
with self.assertRaises(errors.NoSuchProviderError):
self.factory_aggregate('unknown')

View File

@ -1,11 +1,24 @@
"""Dependency injector resource provider unit tests."""
import sys
import asyncio
import unittest2 as unittest
from dependency_injector import containers, providers, resources, errors
# Runtime import to get asyncutils module
import os
_TOP_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
'../',
)),
)
import sys
sys.path.append(_TOP_DIR)
from asyncutils import AsyncTestCase
def init_fn(*args, **kwargs):
return args, kwargs
@ -156,6 +169,15 @@ class ResourceTests(unittest.TestCase):
self.assertEqual(_init.init_counter, 2)
self.assertEqual(_init.shutdown_counter, 2)
def test_shutdown_of_not_initialized(self):
def _init():
yield
provider = providers.Resource(_init)
result = provider.shutdown()
self.assertIsNone(result)
def test_initialized(self):
provider = providers.Resource(init_fn)
self.assertFalse(provider.initialized)
@ -320,3 +342,186 @@ class ResourceTests(unittest.TestCase):
provider.initialized,
)
)
class AsyncResourceTest(AsyncTestCase):
def test_init_async_function(self):
resource = object()
async def _init():
await asyncio.sleep(0.001)
_init.counter += 1
return resource
_init.counter = 0
provider = providers.Resource(_init)
result1 = self._run(provider())
self.assertIs(result1, resource)
self.assertEqual(_init.counter, 1)
result2 = self._run(provider())
self.assertIs(result2, resource)
self.assertEqual(_init.counter, 1)
self._run(provider.shutdown())
def test_init_async_generator(self):
resource = object()
async def _init():
await asyncio.sleep(0.001)
_init.init_counter += 1
yield resource
await asyncio.sleep(0.001)
_init.shutdown_counter += 1
_init.init_counter = 0
_init.shutdown_counter = 0
provider = providers.Resource(_init)
result1 = self._run(provider())
self.assertIs(result1, resource)
self.assertEqual(_init.init_counter, 1)
self.assertEqual(_init.shutdown_counter, 0)
self._run(provider.shutdown())
self.assertEqual(_init.init_counter, 1)
self.assertEqual(_init.shutdown_counter, 1)
result2 = self._run(provider())
self.assertIs(result2, resource)
self.assertEqual(_init.init_counter, 2)
self.assertEqual(_init.shutdown_counter, 1)
self._run(provider.shutdown())
self.assertEqual(_init.init_counter, 2)
self.assertEqual(_init.shutdown_counter, 2)
def test_init_async_class(self):
resource = object()
class TestResource(resources.AsyncResource):
init_counter = 0
shutdown_counter = 0
async def init(self):
await asyncio.sleep(0.001)
self.__class__.init_counter += 1
return resource
async def shutdown(self, resource_):
await asyncio.sleep(0.001)
self.__class__.shutdown_counter += 1
assert resource_ is resource
provider = providers.Resource(TestResource)
result1 = self._run(provider())
self.assertIs(result1, resource)
self.assertEqual(TestResource.init_counter, 1)
self.assertEqual(TestResource.shutdown_counter, 0)
self._run(provider.shutdown())
self.assertEqual(TestResource.init_counter, 1)
self.assertEqual(TestResource.shutdown_counter, 1)
result2 = self._run(provider())
self.assertIs(result2, resource)
self.assertEqual(TestResource.init_counter, 2)
self.assertEqual(TestResource.shutdown_counter, 1)
self._run(provider.shutdown())
self.assertEqual(TestResource.init_counter, 2)
self.assertEqual(TestResource.shutdown_counter, 2)
def test_init_with_error(self):
async def _init():
raise RuntimeError()
provider = providers.Resource(_init)
future = provider()
self.assertTrue(provider.initialized)
self.assertTrue(provider.is_async_mode_enabled())
# Disable default exception handling to prevent output
asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...)
with self.assertRaises(RuntimeError):
self._run(future)
# Restore default exception handling
asyncio.get_event_loop().set_exception_handler(None)
self.assertFalse(provider.initialized)
self.assertTrue(provider.is_async_mode_enabled())
def test_init_and_shutdown_methods(self):
async def _init():
await asyncio.sleep(0.001)
_init.init_counter += 1
yield
await asyncio.sleep(0.001)
_init.shutdown_counter += 1
_init.init_counter = 0
_init.shutdown_counter = 0
provider = providers.Resource(_init)
self._run(provider.init())
self.assertEqual(_init.init_counter, 1)
self.assertEqual(_init.shutdown_counter, 0)
self._run(provider.shutdown())
self.assertEqual(_init.init_counter, 1)
self.assertEqual(_init.shutdown_counter, 1)
self._run(provider.init())
self.assertEqual(_init.init_counter, 2)
self.assertEqual(_init.shutdown_counter, 1)
self._run(provider.shutdown())
self.assertEqual(_init.init_counter, 2)
self.assertEqual(_init.shutdown_counter, 2)
def test_shutdown_of_not_initialized(self):
async def _init():
yield
provider = providers.Resource(_init)
provider.enable_async_mode()
result = self._run(provider.shutdown())
self.assertIsNone(result)
def test_concurrent_init(self):
resource = object()
async def _init():
await asyncio.sleep(0.001)
_init.counter += 1
return resource
_init.counter = 0
provider = providers.Resource(_init)
result1, result2 = self._run(
asyncio.gather(
provider(),
provider()
),
)
self.assertIs(result1, resource)
self.assertEqual(_init.counter, 1)
self.assertIs(result2, resource)
self.assertEqual(_init.counter, 1)

View File

@ -0,0 +1,50 @@
import asyncio
from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide, Closing
class TestResource:
def __init__(self):
self.init_counter = 0
self.shutdown_counter = 0
def reset_counters(self):
self.init_counter = 0
self.shutdown_counter = 0
resource1 = TestResource()
resource2 = TestResource()
async def async_resource(resource):
await asyncio.sleep(0.001)
resource.init_counter += 1
yield resource
await asyncio.sleep(0.001)
resource.shutdown_counter += 1
class Container(containers.DeclarativeContainer):
resource1 = providers.Resource(async_resource, providers.Object(resource1))
resource2 = providers.Resource(async_resource, providers.Object(resource2))
@inject
async def async_injection(
resource1: object = Provide[Container.resource1],
resource2: object = Provide[Container.resource2],
):
return resource1, resource2
@inject
async def async_injection_with_closing(
resource1: object = Closing[Provide[Container.resource1]],
resource2: object = Closing[Provide[Container.resource2]],
):
return resource1, resource2

View File

@ -5,6 +5,12 @@ from dependency_injector.wiring import wire, Provide, Closing
# Runtime import to avoid syntax errors in samples on Python < 3.5
import os
_TOP_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
'../',
)),
)
_SAMPLES_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
@ -12,8 +18,11 @@ _SAMPLES_DIR = os.path.abspath(
)),
)
import sys
sys.path.append(_TOP_DIR)
sys.path.append(_SAMPLES_DIR)
from asyncutils import AsyncTestCase
from wiringsamples import module, package
from wiringsamples.service import Service
from wiringsamples.container import Container, SubContainer
@ -267,3 +276,56 @@ class WiringAndFastAPITest(unittest.TestCase):
self.assertEqual(result_2.shutdown_counter, 2)
self.assertIsNot(result_1, result_2)
class WiringAsyncInjectionsTest(AsyncTestCase):
def test_async_injections(self):
from wiringsamples import asyncinjections
container = asyncinjections.Container()
container.wire(modules=[asyncinjections])
self.addCleanup(container.unwire)
asyncinjections.resource1.reset_counters()
asyncinjections.resource2.reset_counters()
resource1, resource2 = self._run(asyncinjections.async_injection())
self.assertIs(resource1, asyncinjections.resource1)
self.assertEqual(asyncinjections.resource1.init_counter, 1)
self.assertEqual(asyncinjections.resource1.shutdown_counter, 0)
self.assertIs(resource2, asyncinjections.resource2)
self.assertEqual(asyncinjections.resource2.init_counter, 1)
self.assertEqual(asyncinjections.resource2.shutdown_counter, 0)
def test_async_injections_with_closing(self):
from wiringsamples import asyncinjections
container = asyncinjections.Container()
container.wire(modules=[asyncinjections])
self.addCleanup(container.unwire)
asyncinjections.resource1.reset_counters()
asyncinjections.resource2.reset_counters()
resource1, resource2 = self._run(asyncinjections.async_injection_with_closing())
self.assertIs(resource1, asyncinjections.resource1)
self.assertEqual(asyncinjections.resource1.init_counter, 1)
self.assertEqual(asyncinjections.resource1.shutdown_counter, 1)
self.assertIs(resource2, asyncinjections.resource2)
self.assertEqual(asyncinjections.resource2.init_counter, 1)
self.assertEqual(asyncinjections.resource2.shutdown_counter, 1)
resource1, resource2 = self._run(asyncinjections.async_injection_with_closing())
self.assertIs(resource1, asyncinjections.resource1)
self.assertEqual(asyncinjections.resource1.init_counter, 2)
self.assertEqual(asyncinjections.resource1.shutdown_counter, 2)
self.assertIs(resource2, asyncinjections.resource2)
self.assertEqual(asyncinjections.resource2.init_counter, 2)
self.assertEqual(asyncinjections.resource2.shutdown_counter, 2)

View File

@ -1,13 +1,13 @@
import asyncio
import contextlib
import gc
import unittest
from unittest import mock
from httpx import AsyncClient
# Runtime import to avoid syntax errors in samples on Python < 3.5
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
import os
_TOP_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
'../',
)),
)
_SAMPLES_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
@ -15,58 +15,14 @@ _SAMPLES_DIR = os.path.abspath(
)),
)
import sys
sys.path.append(_TOP_DIR)
sys.path.append(_SAMPLES_DIR)
from asyncutils import AsyncTestCase
from wiringfastapi import web
# TODO: Refactor to use common async test case
def setup_test_loop(
loop_factory=asyncio.new_event_loop
) -> asyncio.AbstractEventLoop:
loop = loop_factory()
try:
module = loop.__class__.__module__
skip_watcher = 'uvloop' in module
except AttributeError: # pragma: no cover
# Just in case
skip_watcher = True
asyncio.set_event_loop(loop)
if sys.platform != "win32" and not skip_watcher:
policy = asyncio.get_event_loop_policy()
watcher = asyncio.SafeChildWatcher() # type: ignore
watcher.attach_loop(loop)
with contextlib.suppress(NotImplementedError):
policy.set_child_watcher(watcher)
return loop
def teardown_test_loop(loop: asyncio.AbstractEventLoop,
fast: bool=False) -> None:
closed = loop.is_closed()
if not closed:
loop.call_soon(loop.stop)
loop.run_forever()
loop.close()
if not fast:
gc.collect()
asyncio.set_event_loop(None)
class AsyncTestCase(unittest.TestCase):
def setUp(self):
self.loop = setup_test_loop()
def tearDown(self):
teardown_test_loop(self.loop)
def _run(self, f):
return self.loop.run_until_complete(f)
class WiringFastAPITest(AsyncTestCase):
client: AsyncClient