Merge branch 'develop' into fix/nested-resource-resolution

This commit is contained in:
ZipFile 2025-02-23 16:24:59 +00:00
commit e050ad44cc
10 changed files with 113 additions and 62 deletions

View File

@ -64,7 +64,7 @@ FastAPI example:
@app.api_route("/") @app.api_route("/")
@inject @inject
async def index(service: Service = Depends(Provide[Container.service])): async def index(service: Annotated[Service, Depends(Provide[Container.service])]):
value = await service.process() value = await service.process()
return {"result": value} return {"result": value}

View File

@ -1,18 +1,22 @@
"""Application module.""" """Application module."""
from dependency_injector.wiring import inject, Provide from typing import Annotated
from fastapi import FastAPI, Depends
from fastapi import Depends, FastAPI
from dependency_injector.wiring import Provide, inject
from .containers import Container from .containers import Container
from .services import Service from .services import Service
app = FastAPI() app = FastAPI()
@app.api_route("/") @app.api_route("/")
@inject @inject
async def index(service: Service = Depends(Provide[Container.service])): async def index(
service: Annotated[Service, Depends(Provide[Container.service])]
) -> dict[str, str]:
value = await service.process() value = await service.process()
return {"result": value} return {"result": value}

View File

@ -1,4 +1,7 @@
from fastapi import FastAPI, Depends from typing import Annotated
from fastapi import Depends, FastAPI
from dependency_injector import containers, providers from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject from dependency_injector.wiring import Provide, inject
@ -18,7 +21,9 @@ app = FastAPI()
@app.api_route("/") @app.api_route("/")
@inject @inject
async def index(service: Service = Depends(Provide[Container.service])): async def index(
service: Annotated[Service, Depends(Provide[Container.service])]
) -> dict[str, str]:
result = await service.process() result = await service.process()
return {"result": result} return {"result": result}

View File

@ -1,11 +1,14 @@
"""Endpoints module.""" """Endpoints module."""
from typing import Annotated
from fastapi import APIRouter, Depends, Response, status from fastapi import APIRouter, Depends, Response, status
from dependency_injector.wiring import inject, Provide
from dependency_injector.wiring import Provide, inject
from .containers import Container from .containers import Container
from .services import UserService
from .repositories import NotFoundError from .repositories import NotFoundError
from .services import UserService
router = APIRouter() router = APIRouter()
@ -13,7 +16,7 @@ router = APIRouter()
@router.get("/users") @router.get("/users")
@inject @inject
def get_list( def get_list(
user_service: UserService = Depends(Provide[Container.user_service]), user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
): ):
return user_service.get_users() return user_service.get_users()
@ -21,8 +24,8 @@ def get_list(
@router.get("/users/{user_id}") @router.get("/users/{user_id}")
@inject @inject
def get_by_id( def get_by_id(
user_id: int, user_id: int,
user_service: UserService = Depends(Provide[Container.user_service]), user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
): ):
try: try:
return user_service.get_user_by_id(user_id) return user_service.get_user_by_id(user_id)
@ -33,7 +36,7 @@ def get_by_id(
@router.post("/users", status_code=status.HTTP_201_CREATED) @router.post("/users", status_code=status.HTTP_201_CREATED)
@inject @inject
def add( def add(
user_service: UserService = Depends(Provide[Container.user_service]), user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
): ):
return user_service.create_user() return user_service.create_user()
@ -41,9 +44,9 @@ def add(
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
@inject @inject
def remove( def remove(
user_id: int, user_id: int,
user_service: UserService = Depends(Provide[Container.user_service]), user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
): ) -> Response:
try: try:
user_service.delete_user_by_id(user_id) user_service.delete_user_by_id(user_id)
except NotFoundError: except NotFoundError:

View File

@ -1,13 +1,14 @@
"""Endpoints module.""" """Endpoints module."""
from typing import Optional, List from typing import Annotated, List
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from pydantic import BaseModel from pydantic import BaseModel
from dependency_injector.wiring import inject, Provide
from .services import SearchService from dependency_injector.wiring import Provide, inject
from .containers import Container from .containers import Container
from .services import SearchService
class Gif(BaseModel): class Gif(BaseModel):
@ -26,11 +27,15 @@ router = APIRouter()
@router.get("/", response_model=Response) @router.get("/", response_model=Response)
@inject @inject
async def index( async def index(
query: Optional[str] = None, default_query: Annotated[str, Depends(Provide[Container.config.default.query])],
limit: Optional[str] = None, default_limit: Annotated[
default_query: str = Depends(Provide[Container.config.default.query]), int, Depends(Provide[Container.config.default.limit.as_int()])
default_limit: int = Depends(Provide[Container.config.default.limit.as_int()]), ],
search_service: SearchService = Depends(Provide[Container.search_service]), search_service: Annotated[
SearchService, Depends(Provide[Container.search_service])
],
query: str | None = None,
limit: int | None = None,
): ):
query = query or default_query query = query or default_query
limit = limit or default_limit limit = limit or default_limit

View File

@ -2,44 +2,39 @@
import asyncio import asyncio
import collections.abc import collections.abc
import functools
import inspect import inspect
import types import types
from . import providers from .wiring import _Marker
from .wiring import _Marker, PatchedCallable
from .providers cimport Provider from .providers cimport Provider, Resource
def _get_sync_patched(fn, patched: PatchedCallable): def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
@functools.wraps(fn) cdef object result
def _patched(*args, **kwargs): cdef dict to_inject
cdef object result cdef object arg_key
cdef dict to_inject cdef Provider provider
cdef object arg_key
cdef Provider provider
to_inject = kwargs.copy() to_inject = kwargs.copy()
for arg_key, provider in patched.injections.items(): for arg_key, provider in injections.items():
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
to_inject[arg_key] = provider() to_inject[arg_key] = provider()
result = fn(*args, **to_inject) result = fn(*args, **to_inject)
if patched.closing: if closings:
for arg_key, provider in patched.closing.items(): for arg_key, provider in closings.items():
if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker): if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker):
continue continue
if not isinstance(provider, providers.Resource): if not isinstance(provider, Resource):
continue continue
provider.shutdown() provider.shutdown()
return result return result
return _patched
async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings): async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
cdef object result cdef object result
cdef dict to_inject cdef dict to_inject
cdef list to_inject_await = [] cdef list to_inject_await = []
@ -69,7 +64,7 @@ async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dic
for arg_key, provider in closings.items(): for arg_key, provider in closings.items():
if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker): if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker):
continue continue
if not isinstance(provider, providers.Resource): if not isinstance(provider, Resource):
continue continue
shutdown = provider.shutdown() shutdown = provider.shutdown()
if _isawaitable(shutdown): if _isawaitable(shutdown):

View File

@ -2,6 +2,7 @@
from __future__ import absolute_import from __future__ import absolute_import
import asyncio
import copy import copy
import errno import errno
import functools import functools
@ -27,17 +28,19 @@ except ImportError:
import __builtin__ as builtins import __builtin__ as builtins
try: try:
import asyncio from inspect import _is_coroutine_mark as _is_coroutine_marker
except ImportError: except ImportError:
asyncio = None try:
_is_coroutine_marker = None # Python >=3.12.0,<3.12.5
else: from inspect import _is_coroutine_marker
if sys.version_info >= (3, 5, 3): except ImportError:
import asyncio.coroutines
_is_coroutine_marker = asyncio.coroutines._is_coroutine
else:
_is_coroutine_marker = True _is_coroutine_marker = True
try:
from asyncio.coroutines import _is_coroutine
except ImportError:
_is_coroutine = True
try: try:
import ConfigParser as iniconfigparser import ConfigParser as iniconfigparser
except ImportError: except ImportError:
@ -1475,7 +1478,8 @@ cdef class Coroutine(Callable):
some_coroutine.add_kwargs(keyword_argument1=3, keyword_argument=4) some_coroutine.add_kwargs(keyword_argument1=3, keyword_argument=4)
""" """
_is_coroutine = _is_coroutine_marker _is_coroutine_marker = _is_coroutine_marker # Python >=3.12
_is_coroutine = _is_coroutine # Python <3.16
def set_provides(self, provides): def set_provides(self, provides):
"""Set provider provides.""" """Set provider provides."""

View File

@ -1028,8 +1028,8 @@ _inspect_filter = InspectFilter()
_loader = AutoLoader() _loader = AutoLoader()
# Optimizations # Optimizations
from ._cwiring import _sync_inject # noqa
from ._cwiring import _async_inject # noqa from ._cwiring import _async_inject # noqa
from ._cwiring import _get_sync_patched # noqa
# Wiring uses the following Python wrapper because there is # Wiring uses the following Python wrapper because there is
@ -1045,4 +1045,17 @@ def _get_async_patched(fn: F, patched: PatchedCallable) -> F:
patched.closing, patched.closing,
) )
return _patched return cast(F, _patched)
def _get_sync_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn)
def _patched(*args, **kwargs):
return _sync_inject(
fn,
args,
kwargs,
patched.injections,
patched.closing,
)
return cast(F, _patched)

View File

@ -1,4 +1,5 @@
"""Coroutine provider tests.""" """Coroutine provider tests."""
import sys
from dependency_injector import providers, errors from dependency_injector import providers, errors
from pytest import mark, raises from pytest import mark, raises
@ -208,3 +209,17 @@ def test_repr():
"<dependency_injector.providers." "<dependency_injector.providers."
"Coroutine({0}) at {1}>".format(repr(example), hex(id(provider))) "Coroutine({0}) at {1}>".format(repr(example), hex(id(provider)))
) )
@mark.skipif(sys.version_info > (3, 15), reason="requires Python<3.16")
def test_asyncio_iscoroutinefunction() -> None:
from asyncio.coroutines import iscoroutinefunction
assert iscoroutinefunction(providers.Coroutine(example))
@mark.skipif(sys.version_info < (3, 12), reason="requires Python>=3.12")
def test_inspect_iscoroutinefunction() -> None:
from inspect import iscoroutinefunction
assert iscoroutinefunction(providers.Coroutine(example))

View File

@ -6,6 +6,13 @@ import inspect
from dependency_injector.wiring import inject from dependency_injector.wiring import inject
def test_isfunction():
@inject
def foo(): ...
assert inspect.isfunction(foo)
def test_asyncio_iscoroutinefunction(): def test_asyncio_iscoroutinefunction():
@inject @inject
async def foo(): async def foo():