diff --git a/src/dependency_injector/providers.pyi b/src/dependency_injector/providers.pyi index ed82035b..d312174e 100644 --- a/src/dependency_injector/providers.pyi +++ b/src/dependency_injector/providers.pyi @@ -15,6 +15,7 @@ from typing import ( Union, Coroutine as _Coroutine, Iterator as _Iterator, + AsyncIterator as _AsyncIterator, Generator as _Generator, overload, ) @@ -287,10 +288,20 @@ class Resource(Provider, Generic[T]): @overload def __init__(self, initializer: _Callable[..., resources.Resource[T]], *args: Injection, **kwargs: Injection) -> None: ... @overload + def __init__(self, initializer: _Callable[..., resources.AsyncResource[T]], *args: Injection, **kwargs: Injection) -> None: ... + @overload def __init__(self, initializer: _Callable[..., _Iterator[T]], *args: Injection, **kwargs: Injection) -> None: ... @overload + def __init__(self, initializer: _Callable[..., _AsyncIterator[T]], *args: Injection, **kwargs: Injection) -> None: ... + @overload + def __init__(self, initializer: _Callable[..., _Coroutine[Injection, Injection, T]], *args: Injection, **kwargs: Injection) -> None: ... + @overload def __init__(self, initializer: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ... + @overload def __call__(self, *args: Injection, **kwargs: Injection) -> T: ... + @overload + def __call__(self, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ... + def async_(self, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ... @property def args(self) -> Tuple[Injection]: ... def add_args(self, *args: Injection) -> Resource: ... diff --git a/tests/typing/resource.py b/tests/typing/resource.py index 71808b13..b5687451 100644 --- a/tests/typing/resource.py +++ b/tests/typing/resource.py @@ -1,4 +1,4 @@ -from typing import List, Iterator, Generator +from typing import List, Iterator, Generator, AsyncIterator, AsyncGenerator from dependency_injector import providers, resources @@ -41,3 +41,59 @@ class MyResource4(resources.Resource[List[int]]): provider4 = providers.Resource(MyResource4) var4: List[int] = provider4() + + +# Test 5: to check the return type with async function +async def init5() -> List[int]: + ... + + +provider5 = providers.Resource(init5) + + +async def _provide5() -> None: + var1: List[int] = await provider5() # type: ignore + var2: List[int] = await provider5.async_() + + +# Test 6: to check the return type with async iterator +async def init6() -> AsyncIterator[List[int]]: + yield [] + + +provider6 = providers.Resource(init6) + + +async def _provide6() -> None: + var1: List[int] = await provider6() # type: ignore + var2: List[int] = await provider6.async_() + + +# Test 7: to check the return type with async generator +async def init7() -> AsyncGenerator[List[int], None]: + yield [] + + +provider7 = providers.Resource(init7) + + +async def _provide7() -> None: + var1: List[int] = await provider7() # type: ignore + var2: List[int] = await provider7.async_() + + +# Test 8: to check the return type with async resource subclass +class MyResource8(resources.AsyncResource[List[int]]): + async def init(self, *args, **kwargs) -> List[int]: + return [] + + async def shutdown(self, resource: List[int]) -> None: + ... + + +provider8 = providers.Resource(MyResource8) + + +async def _provide8() -> None: + var1: List[int] = await provider8() # type: ignore + var2: List[int] = await provider8.async_()