mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-02-27 17:25:03 +03:00
Add support for typing.Annotated (#721)
This commit is contained in:
parent
29ae3e1337
commit
2330122de6
|
@ -2,10 +2,10 @@
|
||||||
|
|
||||||
from dependency_injector import containers, providers
|
from dependency_injector import containers, providers
|
||||||
from dependency_injector.wiring import Provide, inject
|
from dependency_injector.wiring import Provide, inject
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
|
||||||
class Service:
|
class Service: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class Container(containers.DeclarativeContainer):
|
class Container(containers.DeclarativeContainer):
|
||||||
|
@ -13,9 +13,16 @@ class Container(containers.DeclarativeContainer):
|
||||||
service = providers.Factory(Service)
|
service = providers.Factory(Service)
|
||||||
|
|
||||||
|
|
||||||
|
# You can place marker on parameter default value
|
||||||
@inject
|
@inject
|
||||||
def main(service: Service = Provide[Container.service]) -> None:
|
def main(service: Service = Provide[Container.service]) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
# Also, you can place marker with typing.Annotated
|
||||||
|
@inject
|
||||||
|
def main_with_annotated(
|
||||||
|
service: Annotated[Service, Provide[Container.service]]
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -18,5 +18,6 @@ numpy
|
||||||
scipy
|
scipy
|
||||||
boto3
|
boto3
|
||||||
mypy_boto3_s3
|
mypy_boto3_s3
|
||||||
|
typing_extensions
|
||||||
|
|
||||||
-r requirements-ext.txt
|
-r requirements-ext.txt
|
||||||
|
|
|
@ -37,6 +37,21 @@ if sys.version_info >= (3, 9):
|
||||||
else:
|
else:
|
||||||
GenericAlias = None
|
GenericAlias = None
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 9):
|
||||||
|
from typing import Annotated, get_args, get_origin
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from typing_extensions import Annotated, get_args, get_origin
|
||||||
|
except ImportError:
|
||||||
|
Annotated = object()
|
||||||
|
|
||||||
|
# For preventing NameError. Never executes
|
||||||
|
def get_args(hint):
|
||||||
|
return ()
|
||||||
|
|
||||||
|
def get_origin(tp):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import fastapi.params
|
import fastapi.params
|
||||||
|
@ -572,6 +587,24 @@ def _unpatch_attribute(patched: PatchedAttribute) -> None:
|
||||||
setattr(patched.member, patched.name, patched.marker)
|
setattr(patched.member, patched.name, patched.marker)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]:
|
||||||
|
if get_origin(parameter.annotation) is Annotated:
|
||||||
|
marker = get_args(parameter.annotation)[1]
|
||||||
|
else:
|
||||||
|
marker = parameter.default
|
||||||
|
|
||||||
|
if not isinstance(marker, _Marker) and not _is_fastapi_depends(marker):
|
||||||
|
return None
|
||||||
|
|
||||||
|
if _is_fastapi_depends(marker):
|
||||||
|
marker = marker.dependency
|
||||||
|
|
||||||
|
if not isinstance(marker, _Marker):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return marker
|
||||||
|
|
||||||
|
|
||||||
def _fetch_reference_injections( # noqa: C901
|
def _fetch_reference_injections( # noqa: C901
|
||||||
fn: Callable[..., Any],
|
fn: Callable[..., Any],
|
||||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||||
|
@ -596,19 +629,11 @@ def _fetch_reference_injections( # noqa: C901
|
||||||
injections = {}
|
injections = {}
|
||||||
closing = {}
|
closing = {}
|
||||||
for parameter_name, parameter in signature.parameters.items():
|
for parameter_name, parameter in signature.parameters.items():
|
||||||
if not isinstance(parameter.default, _Marker) and not _is_fastapi_depends(
|
marker = _extract_marker(parameter)
|
||||||
parameter.default
|
|
||||||
):
|
if marker is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
marker = parameter.default
|
|
||||||
|
|
||||||
if _is_fastapi_depends(marker):
|
|
||||||
marker = marker.dependency
|
|
||||||
|
|
||||||
if not isinstance(marker, _Marker):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(marker, Closing):
|
if isinstance(marker, Closing):
|
||||||
marker = marker.provider
|
marker = marker.provider
|
||||||
closing[parameter_name] = marker
|
closing[parameter_name] = marker
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from fastapi import FastAPI, Depends
|
from fastapi import FastAPI, Depends
|
||||||
from fastapi import Request # See: https://github.com/ets-labs/python-dependency-injector/issues/398
|
from fastapi import (
|
||||||
|
Request,
|
||||||
|
) # See: https://github.com/ets-labs/python-dependency-injector/issues/398
|
||||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||||
from dependency_injector import containers, providers
|
from dependency_injector import containers, providers
|
||||||
from dependency_injector.wiring import inject, Provide
|
from dependency_injector.wiring import inject, Provide
|
||||||
|
@ -28,11 +32,16 @@ async def index(service: Service = Depends(Provide[Container.service])):
|
||||||
return {"result": result}
|
return {"result": result}
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/annotated")
|
||||||
|
@inject
|
||||||
|
async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]):
|
||||||
|
result = await service.process()
|
||||||
|
return {"result": result}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/auth")
|
@app.get("/auth")
|
||||||
@inject
|
@inject
|
||||||
def read_current_user(
|
def read_current_user(credentials: HTTPBasicCredentials = Depends(security)):
|
||||||
credentials: HTTPBasicCredentials = Depends(security)
|
|
||||||
):
|
|
||||||
return {"username": credentials.username, "password": credentials.password}
|
return {"username": credentials.username, "password": credentials.password}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from flask import Flask, jsonify, request, current_app, session, g
|
from flask import Flask, jsonify, request, current_app, session, g
|
||||||
from dependency_injector import containers, providers
|
from dependency_injector import containers, providers
|
||||||
from dependency_injector.wiring import inject, Provide
|
from dependency_injector.wiring import inject, Provide
|
||||||
|
@ -26,5 +28,12 @@ def index(service: Service = Provide[Container.service]):
|
||||||
return jsonify({"result": result})
|
return jsonify({"result": result})
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/annotated")
|
||||||
|
@inject
|
||||||
|
def annotated(service: Annotated[Service, Provide[Container.service]]):
|
||||||
|
result = service.process()
|
||||||
|
return jsonify({"result": result})
|
||||||
|
|
||||||
|
|
||||||
container = Container()
|
container = Container()
|
||||||
container.wire(modules=[__name__])
|
container.wire(modules=[__name__])
|
||||||
|
|
|
@ -4,13 +4,17 @@ from pytest_asyncio import fixture as aio_fixture
|
||||||
|
|
||||||
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
|
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
|
||||||
import os
|
import os
|
||||||
|
|
||||||
_SAMPLES_DIR = os.path.abspath(
|
_SAMPLES_DIR = os.path.abspath(
|
||||||
os.path.sep.join((
|
os.path.sep.join(
|
||||||
os.path.dirname(__file__),
|
(
|
||||||
"../samples/",
|
os.path.dirname(__file__),
|
||||||
)),
|
"../samples/",
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append(_SAMPLES_DIR)
|
sys.path.append(_SAMPLES_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,6 +41,19 @@ async def test_depends_marker_injection(async_client: AsyncClient):
|
||||||
assert response.json() == {"result": "Foo"}
|
assert response.json() == {"result": "Foo"}
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_depends_with_annotated(async_client: AsyncClient):
|
||||||
|
class ServiceMock:
|
||||||
|
async def process(self):
|
||||||
|
return "Foo"
|
||||||
|
|
||||||
|
with web.container.service.override(ServiceMock()):
|
||||||
|
response = await async_client.get("/")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"result": "Foo"}
|
||||||
|
|
||||||
|
|
||||||
@mark.asyncio
|
@mark.asyncio
|
||||||
async def test_depends_injection(async_client: AsyncClient):
|
async def test_depends_injection(async_client: AsyncClient):
|
||||||
response = await async_client.get("/auth", auth=("john_smith", "secret"))
|
response = await async_client.get("/auth", auth=("john_smith", "secret"))
|
||||||
|
|
|
@ -2,19 +2,25 @@ import json
|
||||||
|
|
||||||
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
|
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
|
||||||
import os
|
import os
|
||||||
|
|
||||||
_TOP_DIR = os.path.abspath(
|
_TOP_DIR = os.path.abspath(
|
||||||
os.path.sep.join((
|
os.path.sep.join(
|
||||||
os.path.dirname(__file__),
|
(
|
||||||
"../",
|
os.path.dirname(__file__),
|
||||||
)),
|
"../",
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
_SAMPLES_DIR = os.path.abspath(
|
_SAMPLES_DIR = os.path.abspath(
|
||||||
os.path.sep.join((
|
os.path.sep.join(
|
||||||
os.path.dirname(__file__),
|
(
|
||||||
"../samples/",
|
os.path.dirname(__file__),
|
||||||
)),
|
"../samples/",
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append(_TOP_DIR)
|
sys.path.append(_TOP_DIR)
|
||||||
sys.path.append(_SAMPLES_DIR)
|
sys.path.append(_SAMPLES_DIR)
|
||||||
|
|
||||||
|
@ -29,3 +35,13 @@ def test_wiring_with_flask():
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert json.loads(response.data) == {"result": "OK"}
|
assert json.loads(response.data) == {"result": "OK"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_wiring_with_annotated():
|
||||||
|
client = web.app.test_client()
|
||||||
|
|
||||||
|
with web.app.app_context():
|
||||||
|
response = client.get("/annotated")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert json.loads(response.data) == {"result": "OK"}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user