Add support for inspect.iscoroutinefunction() in Coroutine provider

This commit is contained in:
ZipFile 2024-11-05 14:02:59 +00:00
parent cab75cb9c7
commit d9bce7f531
2 changed files with 25 additions and 4 deletions

View File

@ -26,17 +26,22 @@ except ImportError:
# Python 2.7 # Python 2.7
import __builtin__ as builtins import __builtin__ as builtins
try:
from inspect import _is_coroutine_marker
except ImportError:
_is_coroutine_marker = True
try: try:
import asyncio import asyncio
except ImportError: except ImportError:
asyncio = None asyncio = None
_is_coroutine_marker = None _is_coroutine = None
else: else:
if sys.version_info >= (3, 5, 3): if sys.version_info >= (3, 5, 3):
import asyncio.coroutines import asyncio.coroutines
_is_coroutine_marker = asyncio.coroutines._is_coroutine _is_coroutine = asyncio.coroutines._is_coroutine
else: else:
_is_coroutine_marker = True _is_coroutine = True
try: try:
import ConfigParser as iniconfigparser import ConfigParser as iniconfigparser
@ -1434,7 +1439,8 @@ cdef class Coroutine(Callable):
some_coroutine.add_kwargs(keyword_argument1=3, keyword_argument=4) some_coroutine.add_kwargs(keyword_argument1=3, keyword_argument=4)
""" """
_is_coroutine = _is_coroutine_marker _is_coroutine_marker = _is_coroutine_marker # Python >=3.12
_is_coroutine = _is_coroutine # Python <3.16
def set_provides(self, provides): def set_provides(self, provides):
"""Set provider provides.""" """Set provider provides."""

View File

@ -1,4 +1,5 @@
"""Coroutine provider tests.""" """Coroutine provider tests."""
import sys
from dependency_injector import providers, errors from dependency_injector import providers, errors
from pytest import mark, raises from pytest import mark, raises
@ -208,3 +209,17 @@ def test_repr():
"<dependency_injector.providers." "<dependency_injector.providers."
"Coroutine({0}) at {1}>".format(repr(example), hex(id(provider))) "Coroutine({0}) at {1}>".format(repr(example), hex(id(provider)))
) )
@mark.skipif(sys.version_info > (3, 15), reason="requires Python<3.16")
def test_asyncio_iscoroutinefunction() -> None:
from asyncio.coroutines import iscoroutinefunction
assert iscoroutinefunction(providers.Coroutine(example))
@mark.skipif(sys.version_info < (3, 12), reason="requires Python>=3.12")
def test_inspect_iscoroutinefunction() -> None:
from inspect import iscoroutinefunction
assert iscoroutinefunction(providers.Coroutine(example))