Add support for typing.Annotated (#721)

This commit is contained in:
Taein Min 2025-01-21 00:37:28 +09:00 committed by GitHub
parent 29ae3e1337
commit 2330122de6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 115 additions and 31 deletions

View File

@ -2,10 +2,10 @@
from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject
from typing import Annotated
class Service:
...
class Service: ...
class Container(containers.DeclarativeContainer):
@ -13,9 +13,16 @@ class Container(containers.DeclarativeContainer):
service = providers.Factory(Service)
# You can place marker on parameter default value
@inject
def main(service: Service = Provide[Container.service]) -> None:
...
def main(service: Service = Provide[Container.service]) -> None: ...
# Also, you can place marker with typing.Annotated
@inject
def main_with_annotated(
service: Annotated[Service, Provide[Container.service]]
) -> None: ...
if __name__ == "__main__":

View File

@ -18,5 +18,6 @@ numpy
scipy
boto3
mypy_boto3_s3
typing_extensions
-r requirements-ext.txt

View File

@ -37,6 +37,21 @@ if sys.version_info >= (3, 9):
else:
GenericAlias = None
if sys.version_info >= (3, 9):
from typing import Annotated, get_args, get_origin
else:
try:
from typing_extensions import Annotated, get_args, get_origin
except ImportError:
Annotated = object()
# For preventing NameError. Never executes
def get_args(hint):
return ()
def get_origin(tp):
return None
try:
import fastapi.params
@ -572,6 +587,24 @@ def _unpatch_attribute(patched: PatchedAttribute) -> None:
setattr(patched.member, patched.name, patched.marker)
def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]:
if get_origin(parameter.annotation) is Annotated:
marker = get_args(parameter.annotation)[1]
else:
marker = parameter.default
if not isinstance(marker, _Marker) and not _is_fastapi_depends(marker):
return None
if _is_fastapi_depends(marker):
marker = marker.dependency
if not isinstance(marker, _Marker):
return None
return marker
def _fetch_reference_injections( # noqa: C901
fn: Callable[..., Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
@ -596,19 +629,11 @@ def _fetch_reference_injections( # noqa: C901
injections = {}
closing = {}
for parameter_name, parameter in signature.parameters.items():
if not isinstance(parameter.default, _Marker) and not _is_fastapi_depends(
parameter.default
):
marker = _extract_marker(parameter)
if marker is None:
continue
marker = parameter.default
if _is_fastapi_depends(marker):
marker = marker.dependency
if not isinstance(marker, _Marker):
continue
if isinstance(marker, Closing):
marker = marker.provider
closing[parameter_name] = marker

View File

@ -1,7 +1,11 @@
import sys
from typing_extensions import Annotated
from fastapi import FastAPI, Depends
from fastapi import Request # See: https://github.com/ets-labs/python-dependency-injector/issues/398
from fastapi import (
Request,
) # See: https://github.com/ets-labs/python-dependency-injector/issues/398
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide
@ -28,11 +32,16 @@ async def index(service: Service = Depends(Provide[Container.service])):
return {"result": result}
@app.api_route("/annotated")
@inject
async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]):
result = await service.process()
return {"result": result}
@app.get("/auth")
@inject
def read_current_user(
credentials: HTTPBasicCredentials = Depends(security)
):
def read_current_user(credentials: HTTPBasicCredentials = Depends(security)):
return {"username": credentials.username, "password": credentials.password}

View File

@ -1,3 +1,5 @@
from typing_extensions import Annotated
from flask import Flask, jsonify, request, current_app, session, g
from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide
@ -26,5 +28,12 @@ def index(service: Service = Provide[Container.service]):
return jsonify({"result": result})
@app.route("/annotated")
@inject
def annotated(service: Annotated[Service, Provide[Container.service]]):
result = service.process()
return jsonify({"result": result})
container = Container()
container.wire(modules=[__name__])

View File

@ -4,13 +4,17 @@ from pytest_asyncio import fixture as aio_fixture
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
import os
_SAMPLES_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
"../samples/",
)),
os.path.sep.join(
(
os.path.dirname(__file__),
"../samples/",
)
),
)
import sys
sys.path.append(_SAMPLES_DIR)
@ -37,6 +41,19 @@ async def test_depends_marker_injection(async_client: AsyncClient):
assert response.json() == {"result": "Foo"}
@mark.asyncio
async def test_depends_with_annotated(async_client: AsyncClient):
class ServiceMock:
async def process(self):
return "Foo"
with web.container.service.override(ServiceMock()):
response = await async_client.get("/")
assert response.status_code == 200
assert response.json() == {"result": "Foo"}
@mark.asyncio
async def test_depends_injection(async_client: AsyncClient):
response = await async_client.get("/auth", auth=("john_smith", "secret"))

View File

@ -2,19 +2,25 @@ import json
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
import os
_TOP_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
"../",
)),
os.path.sep.join(
(
os.path.dirname(__file__),
"../",
)
),
)
_SAMPLES_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
"../samples/",
)),
os.path.sep.join(
(
os.path.dirname(__file__),
"../samples/",
)
),
)
import sys
sys.path.append(_TOP_DIR)
sys.path.append(_SAMPLES_DIR)
@ -29,3 +35,13 @@ def test_wiring_with_flask():
assert response.status_code == 200
assert json.loads(response.data) == {"result": "OK"}
def test_wiring_with_annotated():
client = web.app.test_client()
with web.app.app_context():
response = client.get("/annotated")
assert response.status_code == 200
assert json.loads(response.data) == {"result": "OK"}