+
Github Navigator
+
+
+
+
Results found: {{ repositories|length }}
+
+
+
+
+{% endblock %}
diff --git a/examples/miniapps/flask-blueprints/githubnavigator/tests.py b/examples/miniapps/flask-blueprints/githubnavigator/tests.py
new file mode 100644
index 00000000..ddf4c256
--- /dev/null
+++ b/examples/miniapps/flask-blueprints/githubnavigator/tests.py
@@ -0,0 +1,71 @@
+"""Tests module."""
+
+from unittest import mock
+
+import pytest
+from github import Github
+from flask import url_for
+
+from .application import create_app
+
+
+@pytest.fixture
+def app():
+ app = create_app()
+ yield app
+ app.container.unwire()
+
+
+def test_index(client, app):
+ github_client_mock = mock.Mock(spec=Github)
+ github_client_mock.search_repositories.return_value = [
+ mock.Mock(
+ html_url='repo1-url',
+ name='repo1-name',
+ owner=mock.Mock(
+ login='owner1-login',
+ html_url='owner1-url',
+ avatar_url='owner1-avatar-url',
+ ),
+ get_commits=mock.Mock(return_value=[mock.Mock()]),
+ ),
+ mock.Mock(
+ html_url='repo2-url',
+ name='repo2-name',
+ owner=mock.Mock(
+ login='owner2-login',
+ html_url='owner2-url',
+ avatar_url='owner2-avatar-url',
+ ),
+ get_commits=mock.Mock(return_value=[mock.Mock()]),
+ ),
+ ]
+
+ with app.container.github_client.override(github_client_mock):
+ response = client.get(url_for('example.index'))
+
+ assert response.status_code == 200
+ assert b'Results found: 2' in response.data
+
+ assert b'repo1-url' in response.data
+ assert b'repo1-name' in response.data
+ assert b'owner1-login' in response.data
+ assert b'owner1-url' in response.data
+ assert b'owner1-avatar-url' in response.data
+
+ assert b'repo2-url' in response.data
+ assert b'repo2-name' in response.data
+ assert b'owner2-login' in response.data
+ assert b'owner2-url' in response.data
+ assert b'owner2-avatar-url' in response.data
+
+
+def test_index_no_results(client, app):
+ github_client_mock = mock.Mock(spec=Github)
+ github_client_mock.search_repositories.return_value = []
+
+ with app.container.github_client.override(github_client_mock):
+ response = client.get(url_for('example.index'))
+
+ assert response.status_code == 200
+ assert b'Results found: 0' in response.data
diff --git a/examples/miniapps/flask-blueprints/requirements.txt b/examples/miniapps/flask-blueprints/requirements.txt
new file mode 100644
index 00000000..78a650f6
--- /dev/null
+++ b/examples/miniapps/flask-blueprints/requirements.txt
@@ -0,0 +1,7 @@
+dependency-injector
+flask
+bootstrap-flask
+pygithub
+pyyaml
+pytest-flask
+pytest-cov
diff --git a/examples/miniapps/flask-blueprints/screenshot.png b/examples/miniapps/flask-blueprints/screenshot.png
new file mode 100644
index 00000000..350aaa67
Binary files /dev/null and b/examples/miniapps/flask-blueprints/screenshot.png differ
diff --git a/examples/miniapps/flask/README.rst b/examples/miniapps/flask/README.rst
index 118c4e52..d825a81e 100644
--- a/examples/miniapps/flask/README.rst
+++ b/examples/miniapps/flask/README.rst
@@ -95,6 +95,6 @@ The output should be something like:
githubnavigator/containers.py 7 0 100%
githubnavigator/services.py 14 0 100%
githubnavigator/tests.py 34 0 100%
- githubnavigator/views.py 9 0 100%
+ githubnavigator/views.py 10 0 100%
----------------------------------------------------
- TOTAL 79 0 100%
+ TOTAL 80 0 100%
diff --git a/examples/miniapps/flask/githubnavigator/views.py b/examples/miniapps/flask/githubnavigator/views.py
index 22a3d35d..5cccdad2 100644
--- a/examples/miniapps/flask/githubnavigator/views.py
+++ b/examples/miniapps/flask/githubnavigator/views.py
@@ -1,12 +1,13 @@
"""Views module."""
from flask import request, render_template
-from dependency_injector.wiring import Provide
+from dependency_injector.wiring import inject, Provide
from .services import SearchService
from .containers import Container
+@inject
def index(
search_service: SearchService = Provide[Container.search_service],
default_query: str = Provide[Container.config.default.query],
diff --git a/examples/miniapps/movie-lister/README.rst b/examples/miniapps/movie-lister/README.rst
index dcc85066..cdcbc20c 100644
--- a/examples/miniapps/movie-lister/README.rst
+++ b/examples/miniapps/movie-lister/README.rst
@@ -68,11 +68,11 @@ The output should be something like:
Name Stmts Miss Cover
------------------------------------------
movies/__init__.py 0 0 100%
- movies/__main__.py 17 17 0%
+ movies/__main__.py 18 18 0%
movies/containers.py 9 0 100%
movies/entities.py 7 1 86%
movies/finders.py 26 13 50%
movies/listers.py 8 0 100%
movies/tests.py 24 0 100%
------------------------------------------
- TOTAL 91 31 66%
+ TOTAL 92 32 65%
diff --git a/examples/miniapps/movie-lister/movies/__main__.py b/examples/miniapps/movie-lister/movies/__main__.py
index df39fdc5..975618f3 100644
--- a/examples/miniapps/movie-lister/movies/__main__.py
+++ b/examples/miniapps/movie-lister/movies/__main__.py
@@ -2,12 +2,13 @@
import sys
-from dependency_injector.wiring import Provide
+from dependency_injector.wiring import inject, Provide
from .listers import MovieLister
from .containers import Container
+@inject
def main(lister: MovieLister = Provide[Container.lister]) -> None:
print('Francis Lawrence movies:')
for movie in lister.movies_directed_by('Francis Lawrence'):
diff --git a/examples/miniapps/sanic/README.rst b/examples/miniapps/sanic/README.rst
index 9550fccc..adce2810 100644
--- a/examples/miniapps/sanic/README.rst
+++ b/examples/miniapps/sanic/README.rst
@@ -112,8 +112,8 @@ The output should be something like:
giphynavigator/application.py 12 0 100%
giphynavigator/containers.py 6 0 100%
giphynavigator/giphy.py 14 9 36%
- giphynavigator/handlers.py 10 0 100%
+ giphynavigator/handlers.py 11 0 100%
giphynavigator/services.py 9 1 89%
giphynavigator/tests.py 34 0 100%
---------------------------------------------------
- TOTAL 89 14 84%
+ TOTAL 90 14 84%
diff --git a/examples/miniapps/sanic/giphynavigator/handlers.py b/examples/miniapps/sanic/giphynavigator/handlers.py
index 6b319a25..3537827d 100644
--- a/examples/miniapps/sanic/giphynavigator/handlers.py
+++ b/examples/miniapps/sanic/giphynavigator/handlers.py
@@ -2,12 +2,13 @@
from sanic.request import Request
from sanic.response import HTTPResponse, json
-from dependency_injector.wiring import Provide
+from dependency_injector.wiring import inject, Provide
from .services import SearchService
from .containers import Container
+@inject
async def index(
request: Request,
search_service: SearchService = Provide[Container.search_service],
diff --git a/examples/wiring/example.py b/examples/wiring/example.py
index fe8f27ee..7a62e830 100644
--- a/examples/wiring/example.py
+++ b/examples/wiring/example.py
@@ -3,7 +3,7 @@
import sys
from dependency_injector import containers, providers
-from dependency_injector.wiring import Provide
+from dependency_injector.wiring import inject, Provide
class Service:
@@ -15,6 +15,7 @@ class Container(containers.DeclarativeContainer):
service = providers.Factory(Service)
+@inject
def main(service: Service = Provide[Container.service]) -> None:
...
diff --git a/examples/wiring/flask_example.py b/examples/wiring/flask_example.py
index 3ec0963f..033d0ee9 100644
--- a/examples/wiring/flask_example.py
+++ b/examples/wiring/flask_example.py
@@ -3,7 +3,7 @@
import sys
from dependency_injector import containers, providers
-from dependency_injector.wiring import Provide
+from dependency_injector.wiring import inject, Provide
from flask import Flask, json
@@ -16,6 +16,7 @@ class Container(containers.DeclarativeContainer):
service = providers.Factory(Service)
+@inject
def index_view(service: Service = Provide[Container.service]) -> str:
return json.dumps({'service_id': id(service)})
diff --git a/examples/wiring/flask_resource_closing.py b/examples/wiring/flask_resource_closing.py
index dec40140..05b62c37 100644
--- a/examples/wiring/flask_resource_closing.py
+++ b/examples/wiring/flask_resource_closing.py
@@ -3,7 +3,7 @@
import sys
from dependency_injector import containers, providers
-from dependency_injector.wiring import Provide, Closing
+from dependency_injector.wiring import inject, Provide, Closing
from flask import Flask, current_app
@@ -22,6 +22,7 @@ class Container(containers.DeclarativeContainer):
service = providers.Resource(init_service)
+@inject
def index_view(service: Service = Closing[Provide[Container.service]]):
assert service is current_app.container.service()
return 'Hello World!'
diff --git a/src/dependency_injector/__init__.py b/src/dependency_injector/__init__.py
index 06ffecd0..78668b04 100644
--- a/src/dependency_injector/__init__.py
+++ b/src/dependency_injector/__init__.py
@@ -1,6 +1,6 @@
"""Top-level package."""
-__version__ = '4.3.9'
+__version__ = '4.4.0'
"""Version number.
:type: str
diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py
index 1c671bb1..1ba73639 100644
--- a/src/dependency_injector/wiring.py
+++ b/src/dependency_injector/wiring.py
@@ -6,7 +6,19 @@ import importlib
import pkgutil
import sys
from types import ModuleType
-from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast
+from typing import (
+ Optional,
+ Iterable,
+ Iterator,
+ Callable,
+ Any,
+ Tuple,
+ Dict,
+ Generic,
+ TypeVar,
+ Type,
+ cast,
+)
if sys.version_info < (3, 7):
from typing import GenericMeta
@@ -21,15 +33,35 @@ from . import providers
__all__ = (
'wire',
'unwire',
+ 'inject',
'Provide',
'Provider',
'Closing',
)
T = TypeVar('T')
+F = TypeVar('F', bound=Callable[..., Any])
Container = Any
+class Registry:
+
+ def __init__(self):
+ self._storage = set()
+
+ def add(self, patched: Callable[..., Any]) -> None:
+ self._storage.add(patched)
+
+ def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
+ for patched in self._storage:
+ if patched.__module__ != module.__name__:
+ continue
+ yield patched
+
+
+_patched_registry = Registry()
+
+
class ProvidersMap:
def __init__(self, container):
@@ -152,7 +184,7 @@ class ProvidersMap:
return providers_map
-def wire(
+def wire( # noqa: C901
container: Container,
*,
modules: Optional[Iterable[ModuleType]] = None,
@@ -179,6 +211,9 @@ def wire(
for method_name, method in inspect.getmembers(member, _is_method):
_patch_method(member, method_name, method, providers_map)
+ for patched in _patched_registry.get_from_module(module):
+ _bind_injections(patched, providers_map)
+
def unwire(
*,
@@ -201,6 +236,17 @@ def unwire(
for method_name, method in inspect.getmembers(member, inspect.isfunction):
_unpatch(member, method_name, method)
+ for patched in _patched_registry.get_from_module(module):
+ _unbind_injections(patched)
+
+
+def inject(fn: F) -> F:
+ """Decorate callable with injecting decorator."""
+ reference_injections, reference_closing = _fetch_reference_injections(fn)
+ patched = _get_patched(fn, reference_injections, reference_closing)
+ _patched_registry.add(patched)
+ return cast(F, patched)
+
def _patch_fn(
module: ModuleType,
@@ -208,11 +254,16 @@ def _patch_fn(
fn: Callable[..., Any],
providers_map: ProvidersMap,
) -> None:
- injections, closing = _resolve_injections(fn, providers_map)
- if not injections:
- return
- patched = _patch_with_injections(fn, injections, closing)
- setattr(module, name, _wrap_patched(patched, fn, injections, closing))
+ if not _is_patched(fn):
+ reference_injections, reference_closing = _fetch_reference_injections(fn)
+ if not reference_injections:
+ return
+ fn = _get_patched(fn, reference_injections, reference_closing)
+ _patched_registry.add(fn)
+
+ _bind_injections(fn, providers_map)
+
+ setattr(module, name, fn)
def _patch_method(
@@ -221,28 +272,27 @@ def _patch_method(
method: Callable[..., Any],
providers_map: ProvidersMap,
) -> None:
- injections, closing = _resolve_injections(method, providers_map)
- if not injections:
- return
-
if hasattr(cls, '__dict__') \
and name in cls.__dict__ \
and isinstance(cls.__dict__[name], (classmethod, staticmethod)):
method = cls.__dict__[name]
- patched = _patch_with_injections(method.__func__, injections, closing)
- patched = type(method)(patched)
+ fn = method.__func__
else:
- patched = _patch_with_injections(method, injections, closing)
+ fn = method
- setattr(cls, name, _wrap_patched(patched, method, injections, closing))
+ if not _is_patched(fn):
+ reference_injections, reference_closing = _fetch_reference_injections(fn)
+ if not reference_injections:
+ return
+ fn = _get_patched(fn, reference_injections, reference_closing)
+ _patched_registry.add(fn)
+ _bind_injections(fn, providers_map)
-def _wrap_patched(patched: Callable[..., Any], original, injections, closing):
- patched.__wired__ = True
- patched.__original__ = original
- patched.__injections__ = injections
- patched.__closing__ = closing
- return patched
+ if isinstance(method, (classmethod, staticmethod)):
+ fn = type(method)(fn)
+
+ setattr(cls, name, fn)
def _unpatch(
@@ -250,14 +300,20 @@ def _unpatch(
name: str,
fn: Callable[..., Any],
) -> None:
+ if hasattr(module, '__dict__') \
+ and name in module.__dict__ \
+ and isinstance(module.__dict__[name], (classmethod, staticmethod)):
+ method = module.__dict__[name]
+ fn = method.__func__
+
if not _is_patched(fn):
return
- setattr(module, name, _get_original_from_patched(fn))
+
+ _unbind_injections(fn)
-def _resolve_injections(
+def _fetch_reference_injections(
fn: Callable[..., Any],
- providers_map: ProvidersMap,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
signature = inspect.signature(fn)
@@ -268,24 +324,33 @@ def _resolve_injections(
continue
marker = parameter.default
- closing_modifier = False
if isinstance(marker, Closing):
- closing_modifier = True
marker = marker.provider
+ closing[parameter_name] = marker
+ injections[parameter_name] = marker
+ return injections, closing
+
+
+def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None:
+ for injection, marker in fn.__reference_injections__.items():
provider = providers_map.resolve_provider(marker.provider)
+
if provider is None:
continue
- if closing_modifier:
- closing[parameter_name] = provider
-
if isinstance(marker, Provide):
- injections[parameter_name] = provider
+ fn.__injections__[injection] = provider
elif isinstance(marker, Provider):
- injections[parameter_name] = provider.provider
+ fn.__injections__[injection] = provider.provider
- return injections, closing
+ if injection in fn.__reference_closing__:
+ fn.__closing__[injection] = provider
+
+
+def _unbind_injections(fn: Callable[..., Any]) -> None:
+ fn.__injections__ = {}
+ fn.__closing__ = {}
def _fetch_modules(package):
@@ -303,26 +368,34 @@ def _is_method(member):
return inspect.ismethod(member) or inspect.isfunction(member)
-def _patch_with_injections(fn, injections, closing):
+def _get_patched(fn, reference_injections, reference_closing):
if inspect.iscoroutinefunction(fn):
- _patched = _get_async_patched(fn, injections, closing)
+ patched = _get_async_patched(fn)
else:
- _patched = _get_patched(fn, injections, closing)
- return _patched
+ patched = _get_sync_patched(fn)
+
+ patched.__wired__ = True
+ patched.__original__ = fn
+ patched.__injections__ = {}
+ patched.__reference_injections__ = reference_injections
+ patched.__closing__ = {}
+ patched.__reference_closing__ = reference_closing
+
+ return patched
-def _get_patched(fn, injections, closing):
+def _get_sync_patched(fn):
@functools.wraps(fn)
def _patched(*args, **kwargs):
to_inject = kwargs.copy()
- for injection, provider in injections.items():
+ for injection, provider in _patched.__injections__.items():
if injection not in kwargs \
or _is_fastapi_default_arg_injection(injection, kwargs):
to_inject[injection] = provider()
result = fn(*args, **to_inject)
- for injection, provider in closing.items():
+ for injection, provider in _patched.__closing__.items():
if injection in kwargs \
and not _is_fastapi_default_arg_injection(injection, kwargs):
continue
@@ -334,18 +407,18 @@ def _get_patched(fn, injections, closing):
return _patched
-def _get_async_patched(fn, injections, closing):
+def _get_async_patched(fn):
@functools.wraps(fn)
async def _patched(*args, **kwargs):
to_inject = kwargs.copy()
- for injection, provider in injections.items():
+ for injection, provider in _patched.__injections__.items():
if injection not in kwargs \
or _is_fastapi_default_arg_injection(injection, kwargs):
to_inject[injection] = provider()
result = await fn(*args, **to_inject)
- for injection, provider in closing.items():
+ for injection, provider in _patched.__closing__.items():
if injection in kwargs \
and not _is_fastapi_default_arg_injection(injection, kwargs):
continue
@@ -366,10 +439,6 @@ def _is_patched(fn):
return getattr(fn, '__wired__', False) is True
-def _get_original_from_patched(fn):
- return getattr(fn, '__original__')
-
-
def _is_declarative_container_instance(instance: Any) -> bool:
return (not isinstance(instance, type)
and getattr(instance, '__IS_CONTAINER__', False) is True
diff --git a/tests/unit/samples/wiringsamples/module.py b/tests/unit/samples/wiringsamples/module.py
index c39d1414..7a3ae557 100644
--- a/tests/unit/samples/wiringsamples/module.py
+++ b/tests/unit/samples/wiringsamples/module.py
@@ -3,7 +3,7 @@
from decimal import Decimal
from typing import Callable
-from dependency_injector.wiring import Provide, Provider
+from dependency_injector.wiring import inject, Provide, Provider
from .container import Container, SubContainer
from .service import Service
@@ -11,30 +11,37 @@ from .service import Service
class TestClass:
+ @inject
def __init__(self, service: Service = Provide[Container.service]):
self.service = service
+ @inject
def method(self, service: Service = Provide[Container.service]):
return service
@classmethod
+ @inject
def class_method(cls, service: Service = Provide[Container.service]):
return service
@staticmethod
+ @inject
def static_method(service: Service = Provide[Container.service]):
return service
+@inject
def test_function(service: Service = Provide[Container.service]):
return service
+@inject
def test_function_provider(service_provider: Callable[..., Service] = Provider[Container.service]):
service = service_provider()
return service
+@inject
def test_config_value(
some_value_int: int = Provide[Container.config.a.b.c.as_int()],
some_value_str: str = Provide[Container.config.a.b.c.as_(str)],
@@ -43,25 +50,44 @@ def test_config_value(
return some_value_int, some_value_str, some_value_decimal
+@inject
def test_provide_provider(service_provider: Callable[..., Service] = Provider[Container.service.provider]):
service = service_provider()
return service
+@inject
def test_provided_instance(some_value: int = Provide[Container.service.provided.foo['bar'].call()]):
return some_value
+@inject
def test_subcontainer_provider(some_value: int = Provide[Container.sub.int_object]):
return some_value
+@inject
def test_config_invariant(some_value: int = Provide[Container.config.option[Container.config.switch]]):
return some_value
+@inject
def test_provide_from_different_containers(
service: Service = Provide[Container.service],
some_value: int = Provide[SubContainer.int_object],
):
return service, some_value
+
+
+class ClassDecorator:
+ def __init__(self, fn):
+ self._fn = fn
+
+ def __call__(self, *args, **kwargs):
+ return self._fn(*args, **kwargs)
+
+
+@ClassDecorator
+@inject
+def test_class_decorator(service: Service = Provide[Container.service]):
+ return service
diff --git a/tests/unit/samples/wiringsamples/package/subpackage/submodule.py b/tests/unit/samples/wiringsamples/package/subpackage/submodule.py
index 8f8e0d7f..e99c131e 100644
--- a/tests/unit/samples/wiringsamples/package/subpackage/submodule.py
+++ b/tests/unit/samples/wiringsamples/package/subpackage/submodule.py
@@ -1,8 +1,9 @@
-from dependency_injector.wiring import Provide
+from dependency_injector.wiring import inject, Provide
from ...container import Container
from ...service import Service
+@inject
def test_function(service: Service = Provide[Container.service]):
return service
diff --git a/tests/unit/samples/wiringsamples/resourceclosing.py b/tests/unit/samples/wiringsamples/resourceclosing.py
index f7f35bd1..8dfee241 100644
--- a/tests/unit/samples/wiringsamples/resourceclosing.py
+++ b/tests/unit/samples/wiringsamples/resourceclosing.py
@@ -1,5 +1,5 @@
from dependency_injector import containers, providers
-from dependency_injector.wiring import Provide, Closing
+from dependency_injector.wiring import inject, Provide, Closing
class Service:
@@ -32,5 +32,6 @@ class Container(containers.DeclarativeContainer):
service = providers.Resource(init_service)
+@inject
def test_function(service: Service = Closing[Provide[Container.service]]):
return service
diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py
index b1886ea7..7c061ad9 100644
--- a/tests/unit/wiring/test_wiring_py36.py
+++ b/tests/unit/wiring/test_wiring_py36.py
@@ -226,6 +226,10 @@ class WiringTest(unittest.TestCase):
self.assertEqual(result_2.init_counter, 0)
self.assertEqual(result_2.shutdown_counter, 0)
+ def test_class_decorator(self):
+ service = module.test_class_decorator()
+ self.assertIsInstance(service, Service)
+
class WiringAndFastAPITest(unittest.TestCase):