mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2024-12-01 05:53:59 +03:00
Add optimization for wiring of async coroutines
This commit is contained in:
parent
09cf3459c5
commit
10c84165eb
File diff suppressed because it is too large
Load Diff
|
@ -1 +0,0 @@
|
||||||
"""Wiring optimizations module."""
|
|
|
@ -1,25 +1,15 @@
|
||||||
"""Wiring optimizations module."""
|
"""Wiring optimizations module."""
|
||||||
|
|
||||||
import copy
|
import asyncio
|
||||||
|
import collections.abc
|
||||||
import functools
|
import functools
|
||||||
import sys
|
import inspect
|
||||||
import types
|
import types
|
||||||
|
|
||||||
from . import providers
|
from . import providers
|
||||||
from .wiring import _Marker
|
from .wiring import _Marker
|
||||||
|
|
||||||
|
|
||||||
if sys.version_info[0] == 3: # pragma: no cover
|
|
||||||
CLASS_TYPES = (type,)
|
|
||||||
else: # pragma: no cover
|
|
||||||
CLASS_TYPES = (type, types.ClassType)
|
|
||||||
|
|
||||||
copy._deepcopy_dispatch[types.MethodType] = \
|
|
||||||
lambda obj, memo: type(obj)(obj.im_func,
|
|
||||||
copy.deepcopy(obj.im_self, memo),
|
|
||||||
obj.im_class)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_sync_patched(fn):
|
def _get_sync_patched(fn):
|
||||||
@functools.wraps(fn)
|
@functools.wraps(fn)
|
||||||
def _patched(*args, **kwargs):
|
def _patched(*args, **kwargs):
|
||||||
|
@ -28,16 +18,14 @@ def _get_sync_patched(fn):
|
||||||
|
|
||||||
to_inject = kwargs.copy()
|
to_inject = kwargs.copy()
|
||||||
for injection, provider in _patched.__injections__.items():
|
for injection, provider in _patched.__injections__.items():
|
||||||
if injection not in kwargs \
|
if injection not in kwargs or isinstance(kwargs[injection], _Marker):
|
||||||
or _is_fastapi_default_arg_injection(injection, kwargs):
|
|
||||||
to_inject[injection] = provider()
|
to_inject[injection] = provider()
|
||||||
|
|
||||||
result = fn(*args, **to_inject)
|
result = fn(*args, **to_inject)
|
||||||
|
|
||||||
if _patched.__closing__:
|
if _patched.__closing__:
|
||||||
for injection, provider in _patched.__closing__.items():
|
for injection, provider in _patched.__closing__.items():
|
||||||
if injection in kwargs \
|
if injection in kwargs and not isinstance(kwargs[injection], _Marker):
|
||||||
and not _is_fastapi_default_arg_injection(injection, kwargs):
|
|
||||||
continue
|
continue
|
||||||
if not isinstance(provider, providers.Resource):
|
if not isinstance(provider, providers.Resource):
|
||||||
continue
|
continue
|
||||||
|
@ -47,6 +35,54 @@ def _get_sync_patched(fn):
|
||||||
return _patched
|
return _patched
|
||||||
|
|
||||||
|
|
||||||
cdef bint _is_fastapi_default_arg_injection(object injection, dict kwargs):
|
def _get_async_patched(fn):
|
||||||
"""Check if injection is FastAPI injection of the default argument."""
|
@functools.wraps(fn)
|
||||||
return injection in kwargs and isinstance(kwargs[injection], _Marker)
|
async def _patched(*args, **kwargs):
|
||||||
|
cdef object result
|
||||||
|
cdef dict to_inject
|
||||||
|
cdef list to_inject_await = []
|
||||||
|
cdef list to_close_await = []
|
||||||
|
|
||||||
|
to_inject = kwargs.copy()
|
||||||
|
for injection, provider in _patched.__injections__.items():
|
||||||
|
if injection not in kwargs or isinstance(kwargs[injection], _Marker):
|
||||||
|
provide = provider()
|
||||||
|
if _isawaitable(provide):
|
||||||
|
to_inject_await.append((injection, provide))
|
||||||
|
else:
|
||||||
|
to_inject[injection] = provide
|
||||||
|
|
||||||
|
if to_inject_await:
|
||||||
|
async_to_inject = await asyncio.gather(*(provide for _, provide in to_inject_await))
|
||||||
|
for provide, (injection, _) in zip(async_to_inject, to_inject_await):
|
||||||
|
to_inject[injection] = provide
|
||||||
|
|
||||||
|
result = await fn(*args, **to_inject)
|
||||||
|
|
||||||
|
if _patched.__closing__:
|
||||||
|
for injection, provider in _patched.__closing__.items():
|
||||||
|
if injection in kwargs \
|
||||||
|
and isinstance(kwargs[injection], _Marker):
|
||||||
|
continue
|
||||||
|
if not isinstance(provider, providers.Resource):
|
||||||
|
continue
|
||||||
|
shutdown = provider.shutdown()
|
||||||
|
if _isawaitable(shutdown):
|
||||||
|
to_close_await.append(shutdown)
|
||||||
|
|
||||||
|
await asyncio.gather(*to_close_await)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Hotfix for iscoroutinefunction() for Cython < 3.0.0; can be removed after migration to Cython 3.0.0+
|
||||||
|
_patched._is_coroutine = asyncio.coroutines._is_coroutine
|
||||||
|
|
||||||
|
return _patched
|
||||||
|
|
||||||
|
|
||||||
|
cdef bint _isawaitable(object instance):
|
||||||
|
"""Return true if object can be passed to an ``await`` expression."""
|
||||||
|
return (isinstance(instance, types.CoroutineType) or
|
||||||
|
isinstance(instance, types.GeneratorType) and
|
||||||
|
bool(instance.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) or
|
||||||
|
isinstance(instance, collections.abc.Awaitable))
|
||||||
|
|
|
@ -599,48 +599,6 @@ def _get_patched(fn, reference_injections, reference_closing):
|
||||||
return patched
|
return patched
|
||||||
|
|
||||||
|
|
||||||
def _get_async_patched(fn):
|
|
||||||
@functools.wraps(fn)
|
|
||||||
async def _patched(*args, **kwargs):
|
|
||||||
to_inject = kwargs.copy()
|
|
||||||
to_inject_await = []
|
|
||||||
to_close_await = []
|
|
||||||
for injection, provider in _patched.__injections__.items():
|
|
||||||
if injection not in kwargs \
|
|
||||||
or _is_fastapi_default_arg_injection(injection, kwargs):
|
|
||||||
provide = provider()
|
|
||||||
if inspect.isawaitable(provide):
|
|
||||||
to_inject_await.append((injection, provide))
|
|
||||||
else:
|
|
||||||
to_inject[injection] = provide
|
|
||||||
|
|
||||||
async_to_inject = await asyncio.gather(*[provide for _, provide in to_inject_await])
|
|
||||||
for provide, (injection, _) in zip(async_to_inject, to_inject_await):
|
|
||||||
to_inject[injection] = provide
|
|
||||||
|
|
||||||
result = await fn(*args, **to_inject)
|
|
||||||
|
|
||||||
for injection, provider in _patched.__closing__.items():
|
|
||||||
if injection in kwargs \
|
|
||||||
and not _is_fastapi_default_arg_injection(injection, kwargs):
|
|
||||||
continue
|
|
||||||
if not isinstance(provider, providers.Resource):
|
|
||||||
continue
|
|
||||||
shutdown = provider.shutdown()
|
|
||||||
if inspect.isawaitable(shutdown):
|
|
||||||
to_close_await.append(shutdown)
|
|
||||||
|
|
||||||
await asyncio.gather(*to_close_await)
|
|
||||||
|
|
||||||
return result
|
|
||||||
return _patched
|
|
||||||
|
|
||||||
|
|
||||||
def _is_fastapi_default_arg_injection(injection, kwargs):
|
|
||||||
"""Check if injection is FastAPI injection of the default argument."""
|
|
||||||
return injection in kwargs and isinstance(kwargs[injection], _Marker)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_fastapi_depends(param: Any) -> bool:
|
def _is_fastapi_depends(param: Any) -> bool:
|
||||||
return fastapi and isinstance(param, fastapi.params.Depends)
|
return fastapi and isinstance(param, fastapi.params.Depends)
|
||||||
|
|
||||||
|
@ -939,3 +897,4 @@ _loader = AutoLoader()
|
||||||
|
|
||||||
# Optimizations
|
# Optimizations
|
||||||
from ._cwiring import _get_sync_patched # noqa
|
from ._cwiring import _get_sync_patched # noqa
|
||||||
|
from ._cwiring import _get_async_patched # noqa
|
||||||
|
|
Loading…
Reference in New Issue
Block a user