From 633f60ac19b04abaf3de33f9a299f62f8a8fe792 Mon Sep 17 00:00:00 2001 From: Taein Min Date: Sun, 9 Jul 2023 18:19:14 +0900 Subject: [PATCH] Add support for typing.Annotated --- examples/wiring/example.py | 9 ++++- requirements-dev.txt | 1 + src/dependency_injector/wiring.py | 45 +++++++++++++++++++------ tests/unit/samples/wiringfastapi/web.py | 7 ++++ tests/unit/samples/wiringflask/web.py | 9 +++++ tests/unit/wiring/test_fastapi_py36.py | 14 ++++++++ tests/unit/wiring/test_flask_py36.py | 10 ++++++ 7 files changed, 84 insertions(+), 11 deletions(-) diff --git a/examples/wiring/example.py b/examples/wiring/example.py index 4221ab13..5361fa95 100644 --- a/examples/wiring/example.py +++ b/examples/wiring/example.py @@ -2,6 +2,7 @@ from dependency_injector import containers, providers from dependency_injector.wiring import Provide, inject +from typing import Annotated class Service: @@ -12,12 +13,18 @@ class Container(containers.DeclarativeContainer): service = providers.Factory(Service) - +# You can place marker on parameter default value @inject def main(service: Service = Provide[Container.service]) -> None: ... +# Also, you can place marker with typing.Annotated +@inject +def main_with_annotated(service: Annotated[Service, Provide[Container.service]]) -> None: + ... + + if __name__ == "__main__": container = Container() container.wire(modules=[__name__]) diff --git a/requirements-dev.txt b/requirements-dev.txt index 2c101e8c..9e3151d5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -16,5 +16,6 @@ numpy scipy boto3 mypy_boto3_s3 +typing_extensions -r requirements-ext.txt diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index b1f01622..d88bac9f 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -36,6 +36,20 @@ if sys.version_info >= (3, 9): else: GenericAlias = None +if sys.version_info >= (3, 9): + from typing import Annotated, get_args, get_origin +else: + try: + from typing_extensions import Annotated, get_args, get_origin + except ImportError: + Annotated = object() + + # For preventing NameError. Never executes + def get_args(hint): + return () + + def get_origin(tp): + return None try: import fastapi.params @@ -548,6 +562,24 @@ def _unpatch_attribute(patched: PatchedAttribute) -> None: setattr(patched.member, patched.name, patched.marker) +def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]: + if get_origin(parameter.annotation) is Annotated: + marker = get_args(parameter.annotation)[1] + else: + marker = parameter.default + + if not isinstance(marker, _Marker) and not _is_fastapi_depends(marker): + return None + + if _is_fastapi_depends(marker): + marker = marker.dependency + + if not isinstance(marker, _Marker): + return None + + return marker + + def _fetch_reference_injections( # noqa: C901 fn: Callable[..., Any], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -573,18 +605,11 @@ def _fetch_reference_injections( # noqa: C901 injections = {} closing = {} for parameter_name, parameter in signature.parameters.items(): - if not isinstance(parameter.default, _Marker) \ - and not _is_fastapi_depends(parameter.default): + marker = _extract_marker(parameter) + + if marker is None: continue - marker = parameter.default - - if _is_fastapi_depends(marker): - marker = marker.dependency - - if not isinstance(marker, _Marker): - continue - if isinstance(marker, Closing): marker = marker.provider closing[parameter_name] = marker diff --git a/tests/unit/samples/wiringfastapi/web.py b/tests/unit/samples/wiringfastapi/web.py index 3cee5450..f5c22d58 100644 --- a/tests/unit/samples/wiringfastapi/web.py +++ b/tests/unit/samples/wiringfastapi/web.py @@ -1,5 +1,7 @@ import sys +from typing_extensions import Annotated + from fastapi import FastAPI, Depends from fastapi import Request # See: https://github.com/ets-labs/python-dependency-injector/issues/398 from fastapi.security import HTTPBasic, HTTPBasicCredentials @@ -27,6 +29,11 @@ async def index(service: Service = Depends(Provide[Container.service])): result = await service.process() return {"result": result} +@app.api_route('/annotated') +@inject +async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]): + result = await service.process() + return {'result': result} @app.get("/auth") @inject diff --git a/tests/unit/samples/wiringflask/web.py b/tests/unit/samples/wiringflask/web.py index 37fbd5e0..0d21e3f9 100644 --- a/tests/unit/samples/wiringflask/web.py +++ b/tests/unit/samples/wiringflask/web.py @@ -1,3 +1,5 @@ +from typing_extensions import Annotated + from flask import Flask, jsonify, request, current_app, session, g from flask import _request_ctx_stack, _app_ctx_stack from dependency_injector import containers, providers @@ -28,5 +30,12 @@ def index(service: Service = Provide[Container.service]): return jsonify({"result": result}) +@app.route("/annotated") +@inject +def annotated(service: Annotated[Service, Provide[Container.service]]): + result = service.process() + return jsonify({'result': result}) + + container = Container() container.wire(modules=[__name__]) diff --git a/tests/unit/wiring/test_fastapi_py36.py b/tests/unit/wiring/test_fastapi_py36.py index d93cc9c2..74655775 100644 --- a/tests/unit/wiring/test_fastapi_py36.py +++ b/tests/unit/wiring/test_fastapi_py36.py @@ -36,6 +36,20 @@ async def test_depends_marker_injection(async_client: AsyncClient): assert response.json() == {"result": "Foo"} +@mark.asyncio +async def test_depends_with_annotated(async_client: AsyncClient): + class ServiceMock: + async def process(self): + return "Foo" + + with web.container.service.override(ServiceMock()): + response = await async_client.get("/") + + assert response.status_code == 200 + assert response.json() == {"result": "Foo"} + + + @mark.asyncio async def test_depends_injection(async_client: AsyncClient): response = await async_client.get("/auth", auth=("john_smith", "secret")) diff --git a/tests/unit/wiring/test_flask_py36.py b/tests/unit/wiring/test_flask_py36.py index 751f04d8..beb23fdc 100644 --- a/tests/unit/wiring/test_flask_py36.py +++ b/tests/unit/wiring/test_flask_py36.py @@ -29,3 +29,13 @@ def test_wiring_with_flask(): assert response.status_code == 200 assert json.loads(response.data) == {"result": "OK"} + + +def test_wiring_with_annotated(): + client = web.app.test_client() + + with web.app.app_context(): + response = client.get("/annotated") + + assert response.status_code == 200 + assert json.loads(response.data) == {"result": "OK"}