Add optimization for wiring of async coroutines

This commit is contained in:
Roman Mogylatov 2022-03-27 21:59:19 -04:00
parent 09cf3459c5
commit 10c84165eb
4 changed files with 4851 additions and 1332 deletions

File diff suppressed because it is too large Load Diff

View File

@ -1 +0,0 @@
"""Wiring optimizations module."""

View File

@ -1,25 +1,15 @@
"""Wiring optimizations module."""
import copy
import asyncio
import collections.abc
import functools
import sys
import inspect
import types
from . import providers
from .wiring import _Marker
if sys.version_info[0] == 3: # pragma: no cover
CLASS_TYPES = (type,)
else: # pragma: no cover
CLASS_TYPES = (type, types.ClassType)
copy._deepcopy_dispatch[types.MethodType] = \
lambda obj, memo: type(obj)(obj.im_func,
copy.deepcopy(obj.im_self, memo),
obj.im_class)
def _get_sync_patched(fn):
@functools.wraps(fn)
def _patched(*args, **kwargs):
@ -28,16 +18,14 @@ def _get_sync_patched(fn):
to_inject = kwargs.copy()
for injection, provider in _patched.__injections__.items():
if injection not in kwargs \
or _is_fastapi_default_arg_injection(injection, kwargs):
if injection not in kwargs or isinstance(kwargs[injection], _Marker):
to_inject[injection] = provider()
result = fn(*args, **to_inject)
if _patched.__closing__:
for injection, provider in _patched.__closing__.items():
if injection in kwargs \
and not _is_fastapi_default_arg_injection(injection, kwargs):
if injection in kwargs and not isinstance(kwargs[injection], _Marker):
continue
if not isinstance(provider, providers.Resource):
continue
@ -47,6 +35,54 @@ def _get_sync_patched(fn):
return _patched
cdef bint _is_fastapi_default_arg_injection(object injection, dict kwargs):
"""Check if injection is FastAPI injection of the default argument."""
return injection in kwargs and isinstance(kwargs[injection], _Marker)
def _get_async_patched(fn):
@functools.wraps(fn)
async def _patched(*args, **kwargs):
cdef object result
cdef dict to_inject
cdef list to_inject_await = []
cdef list to_close_await = []
to_inject = kwargs.copy()
for injection, provider in _patched.__injections__.items():
if injection not in kwargs or isinstance(kwargs[injection], _Marker):
provide = provider()
if _isawaitable(provide):
to_inject_await.append((injection, provide))
else:
to_inject[injection] = provide
if to_inject_await:
async_to_inject = await asyncio.gather(*(provide for _, provide in to_inject_await))
for provide, (injection, _) in zip(async_to_inject, to_inject_await):
to_inject[injection] = provide
result = await fn(*args, **to_inject)
if _patched.__closing__:
for injection, provider in _patched.__closing__.items():
if injection in kwargs \
and isinstance(kwargs[injection], _Marker):
continue
if not isinstance(provider, providers.Resource):
continue
shutdown = provider.shutdown()
if _isawaitable(shutdown):
to_close_await.append(shutdown)
await asyncio.gather(*to_close_await)
return result
# Hotfix for iscoroutinefunction() for Cython < 3.0.0; can be removed after migration to Cython 3.0.0+
_patched._is_coroutine = asyncio.coroutines._is_coroutine
return _patched
cdef bint _isawaitable(object instance):
"""Return true if object can be passed to an ``await`` expression."""
return (isinstance(instance, types.CoroutineType) or
isinstance(instance, types.GeneratorType) and
bool(instance.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) or
isinstance(instance, collections.abc.Awaitable))

View File

@ -599,48 +599,6 @@ def _get_patched(fn, reference_injections, reference_closing):
return patched
def _get_async_patched(fn):
@functools.wraps(fn)
async def _patched(*args, **kwargs):
to_inject = kwargs.copy()
to_inject_await = []
to_close_await = []
for injection, provider in _patched.__injections__.items():
if injection not in kwargs \
or _is_fastapi_default_arg_injection(injection, kwargs):
provide = provider()
if inspect.isawaitable(provide):
to_inject_await.append((injection, provide))
else:
to_inject[injection] = provide
async_to_inject = await asyncio.gather(*[provide for _, provide in to_inject_await])
for provide, (injection, _) in zip(async_to_inject, to_inject_await):
to_inject[injection] = provide
result = await fn(*args, **to_inject)
for injection, provider in _patched.__closing__.items():
if injection in kwargs \
and not _is_fastapi_default_arg_injection(injection, kwargs):
continue
if not isinstance(provider, providers.Resource):
continue
shutdown = provider.shutdown()
if inspect.isawaitable(shutdown):
to_close_await.append(shutdown)
await asyncio.gather(*to_close_await)
return result
return _patched
def _is_fastapi_default_arg_injection(injection, kwargs):
"""Check if injection is FastAPI injection of the default argument."""
return injection in kwargs and isinstance(kwargs[injection], _Marker)
def _is_fastapi_depends(param: Any) -> bool:
return fastapi and isinstance(param, fastapi.params.Depends)
@ -939,3 +897,4 @@ _loader = AutoLoader()
# Optimizations
from ._cwiring import _get_sync_patched # noqa
from ._cwiring import _get_async_patched # noqa