From d9bce7f53159476507e5e451356c22b8b8aa2a3f Mon Sep 17 00:00:00 2001 From: ZipFile Date: Tue, 5 Nov 2024 14:02:59 +0000 Subject: [PATCH] Add support for inspect.iscoroutinefunction() in Coroutine provider --- src/dependency_injector/providers.pyx | 14 ++++++++++---- .../providers/coroutines/test_coroutine_py35.py | 15 +++++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/dependency_injector/providers.pyx b/src/dependency_injector/providers.pyx index 402b513a..3ba2c4ab 100644 --- a/src/dependency_injector/providers.pyx +++ b/src/dependency_injector/providers.pyx @@ -26,17 +26,22 @@ except ImportError: # Python 2.7 import __builtin__ as builtins +try: + from inspect import _is_coroutine_marker +except ImportError: + _is_coroutine_marker = True + try: import asyncio except ImportError: asyncio = None - _is_coroutine_marker = None + _is_coroutine = None else: if sys.version_info >= (3, 5, 3): import asyncio.coroutines - _is_coroutine_marker = asyncio.coroutines._is_coroutine + _is_coroutine = asyncio.coroutines._is_coroutine else: - _is_coroutine_marker = True + _is_coroutine = True try: import ConfigParser as iniconfigparser @@ -1434,7 +1439,8 @@ cdef class Coroutine(Callable): some_coroutine.add_kwargs(keyword_argument1=3, keyword_argument=4) """ - _is_coroutine = _is_coroutine_marker + _is_coroutine_marker = _is_coroutine_marker # Python >=3.12 + _is_coroutine = _is_coroutine # Python <3.16 def set_provides(self, provides): """Set provider provides.""" diff --git a/tests/unit/providers/coroutines/test_coroutine_py35.py b/tests/unit/providers/coroutines/test_coroutine_py35.py index 22e794b1..de0e9c67 100644 --- a/tests/unit/providers/coroutines/test_coroutine_py35.py +++ b/tests/unit/providers/coroutines/test_coroutine_py35.py @@ -1,4 +1,5 @@ """Coroutine provider tests.""" +import sys from dependency_injector import providers, errors from pytest import mark, raises @@ -208,3 +209,17 @@ def test_repr(): "".format(repr(example), hex(id(provider))) ) + + +@mark.skipif(sys.version_info > (3, 15), reason="requires Python<3.16") +def test_asyncio_iscoroutinefunction() -> None: + from asyncio.coroutines import iscoroutinefunction + + assert iscoroutinefunction(providers.Coroutine(example)) + + +@mark.skipif(sys.version_info < (3, 12), reason="requires Python>=3.12") +def test_inspect_iscoroutinefunction() -> None: + from inspect import iscoroutinefunction + + assert iscoroutinefunction(providers.Coroutine(example))