diff --git a/tests/typing/resource.py b/tests/typing/resource.py index d01a5106..0d3bf254 100644 --- a/tests/typing/resource.py +++ b/tests/typing/resource.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager, asynccontextmanager from typing import ( Any, AsyncGenerator, @@ -7,6 +8,7 @@ from typing import ( Iterator, List, Optional, + Self, ) from dependency_injector import providers, resources @@ -109,3 +111,62 @@ async def _provide8() -> None: # Test 9: to check string imports provider9: providers.Resource[Dict[Any, Any]] = providers.Resource("builtins.dict") provider9.set_provides("builtins.dict") + + +# Test 10: to check the return type with classes implementing AbstractContextManager protocol +class MyResource10: + def __init__(self) -> None: + pass + + def __enter__(self) -> Self: + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + return None + + +provider10 = providers.Resource(MyResource10) +var10: MyResource10 = provider10() + + +# Test 11: to check the return type with functions decorated with contextlib.contextmanager +@contextmanager +def init11() -> Iterator[int]: + yield 1 + + +provider11 = providers.Resource(init11) +var11: int = provider11() + + +# Test 12: to check the return type with classes implementing AbstractAsyncContextManager protocol +class MyResource12: + def __init__(self) -> None: + pass + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: + return None + + +provider12 = providers.Resource(MyResource12) + + +async def _provide12() -> None: + var1: MyResource12 = await provider12() # type: ignore + var2: MyResource12 = await provider12.async_() + + +# Test 13: to check the return type with functions decorated with contextlib.asynccontextmanager +@asynccontextmanager +async def init13() -> AsyncIterator[int]: + yield 1 + + +provider13 = providers.Resource(init13) + +async def _provide13() -> None: + var1: int = await provider13() # type: ignore + var2: int = await provider13.async_()