Reengineer wiring

This commit is contained in:
Roman Mogylatov 2020-11-13 15:11:11 -05:00
parent b39c1c8046
commit b43c7017ee

View File

@ -208,11 +208,15 @@ def _patch_fn(
fn: Callable[..., Any], fn: Callable[..., Any],
providers_map: ProvidersMap, providers_map: ProvidersMap,
) -> None: ) -> None:
injections, closing = _resolve_injections(fn, providers_map) if not _is_patched(fn):
if not injections: reference_injections, reference_closing = _fetch_reference_injections(fn)
return if not reference_injections:
patched = _patch_with_injections(fn, injections, closing) return
setattr(module, name, _wrap_patched(patched, fn, injections, closing)) fn = _get_patched(fn, reference_injections, reference_closing)
_bind_injections(fn, providers_map)
setattr(module, name, fn)
def _patch_method( def _patch_method(
@ -221,28 +225,26 @@ def _patch_method(
method: Callable[..., Any], method: Callable[..., Any],
providers_map: ProvidersMap, providers_map: ProvidersMap,
) -> None: ) -> None:
injections, closing = _resolve_injections(method, providers_map)
if not injections:
return
if hasattr(cls, '__dict__') \ if hasattr(cls, '__dict__') \
and name in cls.__dict__ \ and name in cls.__dict__ \
and isinstance(cls.__dict__[name], (classmethod, staticmethod)): and isinstance(cls.__dict__[name], (classmethod, staticmethod)):
method = cls.__dict__[name] method = cls.__dict__[name]
patched = _patch_with_injections(method.__func__, injections, closing) fn = method.__func__
patched = type(method)(patched)
else: else:
patched = _patch_with_injections(method, injections, closing) fn = method
setattr(cls, name, _wrap_patched(patched, method, injections, closing)) if not _is_patched(fn):
reference_injections, reference_closing = _fetch_reference_injections(fn)
if not reference_injections:
return
fn = _get_patched(fn, reference_injections, reference_closing)
_bind_injections(fn, providers_map)
def _wrap_patched(patched: Callable[..., Any], original, injections, closing): if isinstance(method, (classmethod, staticmethod)):
patched.__wired__ = True fn = type(method)(fn)
patched.__original__ = original
patched.__injections__ = injections setattr(cls, name, fn)
patched.__closing__ = closing
return patched
def _unpatch( def _unpatch(
@ -250,14 +252,20 @@ def _unpatch(
name: str, name: str,
fn: Callable[..., Any], fn: Callable[..., Any],
) -> None: ) -> None:
if hasattr(module, '__dict__') \
and name in module.__dict__ \
and isinstance(module.__dict__[name], (classmethod, staticmethod)):
method = module.__dict__[name]
fn = method.__func__
if not _is_patched(fn): if not _is_patched(fn):
return return
setattr(module, name, _get_original_from_patched(fn))
_unbind_injections(fn)
def _resolve_injections( def _fetch_reference_injections(
fn: Callable[..., Any], fn: Callable[..., Any],
providers_map: ProvidersMap,
) -> Tuple[Dict[str, Any], Dict[str, Any]]: ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
signature = inspect.signature(fn) signature = inspect.signature(fn)
@ -268,24 +276,33 @@ def _resolve_injections(
continue continue
marker = parameter.default marker = parameter.default
closing_modifier = False
if isinstance(marker, Closing): if isinstance(marker, Closing):
closing_modifier = True
marker = marker.provider marker = marker.provider
closing[parameter_name] = marker
injections[parameter_name] = marker
return injections, closing
def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None:
for injection, marker in fn.__reference_injections__.items():
provider = providers_map.resolve_provider(marker.provider) provider = providers_map.resolve_provider(marker.provider)
if provider is None: if provider is None:
continue continue
if closing_modifier:
closing[parameter_name] = provider
if isinstance(marker, Provide): if isinstance(marker, Provide):
injections[parameter_name] = provider fn.__injections__[injection] = provider
elif isinstance(marker, Provider): elif isinstance(marker, Provider):
injections[parameter_name] = provider.provider fn.__injections__[injection] = provider.provider
return injections, closing if injection in fn.__reference_closing__:
fn.__closing__[injection] = provider
def _unbind_injections(fn: Callable[..., Any]) -> None:
fn.__injections__ = {}
fn.__closing__ = {}
def _fetch_modules(package): def _fetch_modules(package):
@ -303,26 +320,34 @@ def _is_method(member):
return inspect.ismethod(member) or inspect.isfunction(member) return inspect.ismethod(member) or inspect.isfunction(member)
def _patch_with_injections(fn, injections, closing): def _get_patched(fn, reference_injections, reference_closing):
if inspect.iscoroutinefunction(fn): if inspect.iscoroutinefunction(fn):
_patched = _get_async_patched(fn, injections, closing) patched = _get_async_patched(fn)
else: else:
_patched = _get_patched(fn, injections, closing) patched = _get_sync_patched(fn)
return _patched
patched.__wired__ = True
patched.__original__ = fn
patched.__injections__ = {}
patched.__reference_injections__ = reference_injections
patched.__closing__ = {}
patched.__reference_closing__ = reference_closing
return patched
def _get_patched(fn, injections, closing): def _get_sync_patched(fn):
@functools.wraps(fn) @functools.wraps(fn)
def _patched(*args, **kwargs): def _patched(*args, **kwargs):
to_inject = kwargs.copy() to_inject = kwargs.copy()
for injection, provider in injections.items(): for injection, provider in _patched.__injections__.items():
if injection not in kwargs \ if injection not in kwargs \
or _is_fastapi_default_arg_injection(injection, kwargs): or _is_fastapi_default_arg_injection(injection, kwargs):
to_inject[injection] = provider() to_inject[injection] = provider()
result = fn(*args, **to_inject) result = fn(*args, **to_inject)
for injection, provider in closing.items(): for injection, provider in _patched.__closing__.items():
if injection in kwargs \ if injection in kwargs \
and not _is_fastapi_default_arg_injection(injection, kwargs): and not _is_fastapi_default_arg_injection(injection, kwargs):
continue continue
@ -334,18 +359,18 @@ def _get_patched(fn, injections, closing):
return _patched return _patched
def _get_async_patched(fn, injections, closing): def _get_async_patched(fn):
@functools.wraps(fn) @functools.wraps(fn)
async def _patched(*args, **kwargs): async def _patched(*args, **kwargs):
to_inject = kwargs.copy() to_inject = kwargs.copy()
for injection, provider in injections.items(): for injection, provider in _patched.__injections__.items():
if injection not in kwargs \ if injection not in kwargs \
or _is_fastapi_default_arg_injection(injection, kwargs): or _is_fastapi_default_arg_injection(injection, kwargs):
to_inject[injection] = provider() to_inject[injection] = provider()
result = await fn(*args, **to_inject) result = await fn(*args, **to_inject)
for injection, provider in closing.items(): for injection, provider in _patched.__closing__.items():
if injection in kwargs \ if injection in kwargs \
and not _is_fastapi_default_arg_injection(injection, kwargs): and not _is_fastapi_default_arg_injection(injection, kwargs):
continue continue
@ -366,10 +391,6 @@ def _is_patched(fn):
return getattr(fn, '__wired__', False) is True return getattr(fn, '__wired__', False) is True
def _get_original_from_patched(fn):
return getattr(fn, '__original__')
def _is_declarative_container_instance(instance: Any) -> bool: def _is_declarative_container_instance(instance: Any) -> bool:
return (not isinstance(instance, type) return (not isinstance(instance, type)
and getattr(instance, '__IS_CONTAINER__', False) is True and getattr(instance, '__IS_CONTAINER__', False) is True