diff --git a/docs/wiring.rst b/docs/wiring.rst index 2de708d0..74026879 100644 --- a/docs/wiring.rst +++ b/docs/wiring.rst @@ -64,7 +64,7 @@ FastAPI example: @app.api_route("/") @inject - async def index(service: Service = Depends(Provide[Container.service])): + async def index(service: Annotated[Service, Depends(Provide[Container.service])]): value = await service.process() return {"result": value} diff --git a/examples/miniapps/fastapi-redis/fastapiredis/application.py b/examples/miniapps/fastapi-redis/fastapiredis/application.py index f8e4a3bb..52f14366 100644 --- a/examples/miniapps/fastapi-redis/fastapiredis/application.py +++ b/examples/miniapps/fastapi-redis/fastapiredis/application.py @@ -1,18 +1,22 @@ """Application module.""" -from dependency_injector.wiring import inject, Provide -from fastapi import FastAPI, Depends +from typing import Annotated + +from fastapi import Depends, FastAPI + +from dependency_injector.wiring import Provide, inject from .containers import Container from .services import Service - app = FastAPI() @app.api_route("/") @inject -async def index(service: Service = Depends(Provide[Container.service])): +async def index( + service: Annotated[Service, Depends(Provide[Container.service])] +) -> dict[str, str]: value = await service.process() return {"result": value} diff --git a/examples/miniapps/fastapi-simple/fastapi_di_example.py b/examples/miniapps/fastapi-simple/fastapi_di_example.py index 9f3d3f83..6d50499c 100644 --- a/examples/miniapps/fastapi-simple/fastapi_di_example.py +++ b/examples/miniapps/fastapi-simple/fastapi_di_example.py @@ -1,4 +1,7 @@ -from fastapi import FastAPI, Depends +from typing import Annotated + +from fastapi import Depends, FastAPI + from dependency_injector import containers, providers from dependency_injector.wiring import Provide, inject @@ -18,7 +21,9 @@ app = FastAPI() @app.api_route("/") @inject -async def index(service: Service = Depends(Provide[Container.service])): +async def index( + service: Annotated[Service, Depends(Provide[Container.service])] +) -> dict[str, str]: result = await service.process() return {"result": result} diff --git a/examples/miniapps/fastapi-sqlalchemy/webapp/endpoints.py b/examples/miniapps/fastapi-sqlalchemy/webapp/endpoints.py index 4d27101e..e02c2740 100644 --- a/examples/miniapps/fastapi-sqlalchemy/webapp/endpoints.py +++ b/examples/miniapps/fastapi-sqlalchemy/webapp/endpoints.py @@ -1,11 +1,14 @@ """Endpoints module.""" +from typing import Annotated + from fastapi import APIRouter, Depends, Response, status -from dependency_injector.wiring import inject, Provide + +from dependency_injector.wiring import Provide, inject from .containers import Container -from .services import UserService from .repositories import NotFoundError +from .services import UserService router = APIRouter() @@ -13,7 +16,7 @@ router = APIRouter() @router.get("/users") @inject def get_list( - user_service: UserService = Depends(Provide[Container.user_service]), + user_service: Annotated[UserService, Depends(Provide[Container.user_service])], ): return user_service.get_users() @@ -21,8 +24,8 @@ def get_list( @router.get("/users/{user_id}") @inject def get_by_id( - user_id: int, - user_service: UserService = Depends(Provide[Container.user_service]), + user_id: int, + user_service: Annotated[UserService, Depends(Provide[Container.user_service])], ): try: return user_service.get_user_by_id(user_id) @@ -33,7 +36,7 @@ def get_by_id( @router.post("/users", status_code=status.HTTP_201_CREATED) @inject def add( - user_service: UserService = Depends(Provide[Container.user_service]), + user_service: Annotated[UserService, Depends(Provide[Container.user_service])], ): return user_service.create_user() @@ -41,9 +44,9 @@ def add( @router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @inject def remove( - user_id: int, - user_service: UserService = Depends(Provide[Container.user_service]), -): + user_id: int, + user_service: Annotated[UserService, Depends(Provide[Container.user_service])], +) -> Response: try: user_service.delete_user_by_id(user_id) except NotFoundError: diff --git a/examples/miniapps/fastapi/giphynavigator/endpoints.py b/examples/miniapps/fastapi/giphynavigator/endpoints.py index 2761f203..904eb71d 100644 --- a/examples/miniapps/fastapi/giphynavigator/endpoints.py +++ b/examples/miniapps/fastapi/giphynavigator/endpoints.py @@ -1,13 +1,14 @@ """Endpoints module.""" -from typing import Optional, List +from typing import Annotated, List from fastapi import APIRouter, Depends from pydantic import BaseModel -from dependency_injector.wiring import inject, Provide -from .services import SearchService +from dependency_injector.wiring import Provide, inject + from .containers import Container +from .services import SearchService class Gif(BaseModel): @@ -26,11 +27,15 @@ router = APIRouter() @router.get("/", response_model=Response) @inject async def index( - query: Optional[str] = None, - limit: Optional[str] = None, - default_query: str = Depends(Provide[Container.config.default.query]), - default_limit: int = Depends(Provide[Container.config.default.limit.as_int()]), - search_service: SearchService = Depends(Provide[Container.search_service]), + default_query: Annotated[str, Depends(Provide[Container.config.default.query])], + default_limit: Annotated[ + int, Depends(Provide[Container.config.default.limit.as_int()]) + ], + search_service: Annotated[ + SearchService, Depends(Provide[Container.search_service]) + ], + query: str | None = None, + limit: int | None = None, ): query = query or default_query limit = limit or default_limit diff --git a/src/dependency_injector/_cwiring.pyx b/src/dependency_injector/_cwiring.pyx index 88b6bc5a..84a5485f 100644 --- a/src/dependency_injector/_cwiring.pyx +++ b/src/dependency_injector/_cwiring.pyx @@ -2,44 +2,39 @@ import asyncio import collections.abc -import functools import inspect import types -from . import providers -from .wiring import _Marker, PatchedCallable +from .wiring import _Marker -from .providers cimport Provider +from .providers cimport Provider, Resource -def _get_sync_patched(fn, patched: PatchedCallable): - @functools.wraps(fn) - def _patched(*args, **kwargs): - cdef object result - cdef dict to_inject - cdef object arg_key - cdef Provider provider +def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): + cdef object result + cdef dict to_inject + cdef object arg_key + cdef Provider provider - to_inject = kwargs.copy() - for arg_key, provider in patched.injections.items(): - if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): - to_inject[arg_key] = provider() + to_inject = kwargs.copy() + for arg_key, provider in injections.items(): + if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): + to_inject[arg_key] = provider() - result = fn(*args, **to_inject) + result = fn(*args, **to_inject) - if patched.closing: - for arg_key, provider in patched.closing.items(): - if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker): - continue - if not isinstance(provider, providers.Resource): - continue - provider.shutdown() + if closings: + for arg_key, provider in closings.items(): + if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker): + continue + if not isinstance(provider, Resource): + continue + provider.shutdown() - return result - return _patched + return result -async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings): +async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): cdef object result cdef dict to_inject cdef list to_inject_await = [] @@ -69,7 +64,7 @@ async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dic for arg_key, provider in closings.items(): if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker): continue - if not isinstance(provider, providers.Resource): + if not isinstance(provider, Resource): continue shutdown = provider.shutdown() if _isawaitable(shutdown): diff --git a/src/dependency_injector/providers.pyx b/src/dependency_injector/providers.pyx index 39716ea0..71b048a1 100644 --- a/src/dependency_injector/providers.pyx +++ b/src/dependency_injector/providers.pyx @@ -2,6 +2,7 @@ from __future__ import absolute_import +import asyncio import copy import errno import functools @@ -27,17 +28,19 @@ except ImportError: import __builtin__ as builtins try: - import asyncio + from inspect import _is_coroutine_mark as _is_coroutine_marker except ImportError: - asyncio = None - _is_coroutine_marker = None -else: - if sys.version_info >= (3, 5, 3): - import asyncio.coroutines - _is_coroutine_marker = asyncio.coroutines._is_coroutine - else: + try: + # Python >=3.12.0,<3.12.5 + from inspect import _is_coroutine_marker + except ImportError: _is_coroutine_marker = True +try: + from asyncio.coroutines import _is_coroutine +except ImportError: + _is_coroutine = True + try: import ConfigParser as iniconfigparser except ImportError: @@ -1475,7 +1478,8 @@ cdef class Coroutine(Callable): some_coroutine.add_kwargs(keyword_argument1=3, keyword_argument=4) """ - _is_coroutine = _is_coroutine_marker + _is_coroutine_marker = _is_coroutine_marker # Python >=3.12 + _is_coroutine = _is_coroutine # Python <3.16 def set_provides(self, provides): """Set provider provides.""" diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 69b62c05..5cded9f5 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -1028,8 +1028,8 @@ _inspect_filter = InspectFilter() _loader = AutoLoader() # Optimizations +from ._cwiring import _sync_inject # noqa from ._cwiring import _async_inject # noqa -from ._cwiring import _get_sync_patched # noqa # Wiring uses the following Python wrapper because there is @@ -1045,4 +1045,17 @@ def _get_async_patched(fn: F, patched: PatchedCallable) -> F: patched.closing, ) - return _patched + return cast(F, _patched) + + +def _get_sync_patched(fn: F, patched: PatchedCallable) -> F: + @functools.wraps(fn) + def _patched(*args, **kwargs): + return _sync_inject( + fn, + args, + kwargs, + patched.injections, + patched.closing, + ) + return cast(F, _patched) diff --git a/tests/unit/providers/coroutines/test_coroutine_py35.py b/tests/unit/providers/coroutines/test_coroutine_py35.py index 22e794b1..de0e9c67 100644 --- a/tests/unit/providers/coroutines/test_coroutine_py35.py +++ b/tests/unit/providers/coroutines/test_coroutine_py35.py @@ -1,4 +1,5 @@ """Coroutine provider tests.""" +import sys from dependency_injector import providers, errors from pytest import mark, raises @@ -208,3 +209,17 @@ def test_repr(): "".format(repr(example), hex(id(provider))) ) + + +@mark.skipif(sys.version_info > (3, 15), reason="requires Python<3.16") +def test_asyncio_iscoroutinefunction() -> None: + from asyncio.coroutines import iscoroutinefunction + + assert iscoroutinefunction(providers.Coroutine(example)) + + +@mark.skipif(sys.version_info < (3, 12), reason="requires Python>=3.12") +def test_inspect_iscoroutinefunction() -> None: + from inspect import iscoroutinefunction + + assert iscoroutinefunction(providers.Coroutine(example)) diff --git a/tests/unit/wiring/test_introspection_py36.py b/tests/unit/wiring/test_introspection_py36.py index c7149602..66b36a80 100644 --- a/tests/unit/wiring/test_introspection_py36.py +++ b/tests/unit/wiring/test_introspection_py36.py @@ -6,6 +6,13 @@ import inspect from dependency_injector.wiring import inject +def test_isfunction(): + @inject + def foo(): ... + + assert inspect.isfunction(foo) + + def test_asyncio_iscoroutinefunction(): @inject async def foo():