2020-10-09 22:16:27 +03:00
|
|
|
"""Wiring module."""
|
|
|
|
|
2021-01-11 03:26:15 +03:00
|
|
|
import asyncio
|
2020-10-09 22:16:27 +03:00
|
|
|
import functools
|
|
|
|
import inspect
|
2020-11-05 20:20:09 +03:00
|
|
|
import importlib
|
2021-01-29 03:49:24 +03:00
|
|
|
import importlib.machinery
|
2020-10-09 22:16:27 +03:00
|
|
|
import pkgutil
|
|
|
|
import sys
|
|
|
|
from types import ModuleType
|
2020-11-16 00:06:42 +03:00
|
|
|
from typing import (
|
|
|
|
Optional,
|
|
|
|
Iterable,
|
|
|
|
Iterator,
|
|
|
|
Callable,
|
|
|
|
Any,
|
|
|
|
Tuple,
|
|
|
|
Dict,
|
|
|
|
Generic,
|
|
|
|
TypeVar,
|
|
|
|
Type,
|
2021-01-11 16:18:02 +03:00
|
|
|
Union,
|
2020-11-16 00:06:42 +03:00
|
|
|
cast,
|
|
|
|
)
|
2020-10-09 22:16:27 +03:00
|
|
|
|
|
|
|
# ``typing.GenericMeta`` is only imported on Python < 3.7; newer interpreters
# get an empty placeholder metaclass so that code below can subclass the same
# name on every supported version.
if sys.version_info < (3, 7):
    from typing import GenericMeta
else:
    class GenericMeta(type):
        ...
|
|
|
|
|
2021-01-19 04:49:56 +03:00
|
|
|
# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362
# ``types.GenericAlias`` is imported on Python 3.9+ only; on older versions the
# name is bound to ``None`` so callers can test it for truthiness.
if sys.version_info >= (3, 9):
    from types import GenericAlias
else:
    GenericAlias = None
|
|
|
|
|
2020-10-09 22:16:27 +03:00
|
|
|
|
2020-11-18 07:44:32 +03:00
|
|
|
try:
|
2021-02-27 17:45:49 +03:00
|
|
|
import fastapi.params
|
2020-11-18 07:44:32 +03:00
|
|
|
except ImportError:
|
2021-02-27 17:45:49 +03:00
|
|
|
fastapi = None
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
import starlette.requests
|
|
|
|
except ImportError:
|
|
|
|
starlette = None
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
import werkzeug.local
|
|
|
|
except ImportError:
|
|
|
|
werkzeug = None
|
2020-11-18 07:44:32 +03:00
|
|
|
|
|
|
|
|
2020-10-09 22:16:27 +03:00
|
|
|
from . import providers
|
|
|
|
|
|
|
|
|
|
|
|
# Public API of the wiring module.
__all__ = (
    'wire',
    'unwire',
    'inject',
    'as_int',
    'as_float',
    'as_',
    'required',
    'invariant',
    'provided',
    'Provide',
    'Provider',
    'Closing',
    'register_loader_containers',
    'unregister_loader_containers',
    'install_loader',
    'uninstall_loader',
    'is_loader_installed',
)
|
|
|
|
|
|
|
|
# Generic type parameter used by the injection markers.
T = TypeVar('T')

# Type of a decorated callable; lets ``inject`` return the type it was given.
F = TypeVar('F', bound=Callable[..., Any])

# Containers are not typed precisely in this module.
Container = Any
|
|
|
|
|
|
|
|
|
2020-11-16 00:06:42 +03:00
|
|
|
class Registry:
    """Set of patched callables, queryable by the module that defines them."""

    def __init__(self):
        self._storage = set()

    def add(self, patched: Callable[..., Any]) -> None:
        """Remember ``patched``; adding the same callable twice is a no-op."""
        self._storage.add(patched)

    def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
        """Yield stored callables whose ``__module__`` equals ``module.__name__``."""
        for patched in self._storage:
            if patched.__module__ != module.__name__:
                continue
            yield patched


# Module-level registry shared by ``inject``, ``wire`` and ``unwire``.
_patched_registry = Registry()
|
|
|
|
|
|
|
|
|
2020-10-09 22:16:27 +03:00
|
|
|
class ProvidersMap:
    """Translate providers referenced by markers into a container's providers.

    The internal map is keyed by the providers of the "original" container
    (``container.declarative_parent`` when it is truthy, otherwise the
    container itself) and maps each of them to the provider registered under
    the same name in the container being wired.
    """

    # String id that resolves to the container's ``__self__`` provider.
    CONTAINER_STRING_ID = '<container>'

    def __init__(self, container):
        self._container = container
        self._map = self._create_providers_map(
            current_container=container,
            original_container=(
                container.declarative_parent
                if container.declarative_parent
                else container
            ),
        )

    def resolve_provider(
            self,
            provider: Union[providers.Provider, str],
            modifier: Optional['Modifier'] = None,
    ) -> Optional[providers.Provider]:
        """Return the wired counterpart of ``provider`` or ``None``.

        Dispatches on the kind of ``provider``: delegates, provided-instance
        chains, configuration options, string ids and plain providers each
        have a dedicated resolver. ``modifier`` is only used for string ids.
        """
        if isinstance(provider, providers.Delegate):
            return self._resolve_delegate(provider)
        elif isinstance(provider, (
                providers.ProvidedInstance,
                providers.AttributeGetter,
                providers.ItemGetter,
                providers.MethodCaller,
        )):
            return self._resolve_provided_instance(provider)
        elif isinstance(provider, providers.ConfigurationOption):
            return self._resolve_config_option(provider)
        elif isinstance(provider, providers.TypedConfigurationOption):
            return self._resolve_config_option(provider.option, as_=provider.provides)
        elif isinstance(provider, str):
            return self._resolve_string_id(provider, modifier)
        else:
            return self._resolve_provider(provider)

    def _resolve_string_id(
            self,
            id: str,
            modifier: Optional['Modifier'] = None,
    ) -> Optional[providers.Provider]:
        """Resolve a dotted attribute path on the container.

        Returns ``None`` as soon as one path segment is missing. When a
        ``modifier`` is given it is applied to the provider that was found.
        """
        if id == self.CONTAINER_STRING_ID:
            return self._container.__self__

        provider = self._container
        for segment in id.split('.'):
            try:
                provider = getattr(provider, segment)
            except AttributeError:
                return None

        if modifier:
            provider = modifier.modify(provider, providers_map=self)
        return provider

    def _resolve_provided_instance(
            self,
            original: providers.Provider,
    ) -> Optional[providers.Provider]:
        """Rebuild a ``.provided`` attribute/item/call chain on the wired provider."""
        # Unwind the chain down to its root provider, keeping the links in
        # root-to-leaf order so they can be replayed below.
        modifiers = []
        while isinstance(original, (
                providers.ProvidedInstance,
                providers.AttributeGetter,
                providers.ItemGetter,
                providers.MethodCaller,
        )):
            modifiers.insert(0, original)
            original = original.provides

        new = self._resolve_provider(original)
        if new is None:
            return None

        for modifier in modifiers:
            if isinstance(modifier, providers.ProvidedInstance):
                new = new.provided
            elif isinstance(modifier, providers.AttributeGetter):
                new = getattr(new, modifier.name)
            elif isinstance(modifier, providers.ItemGetter):
                new = new[modifier.name]
            elif isinstance(modifier, providers.MethodCaller):
                new = new.call(
                    *modifier.args,
                    **modifier.kwargs,
                )

        return new

    def _resolve_delegate(
            self,
            original: providers.Delegate,
    ) -> Optional[providers.Provider]:
        """Resolve the provider wrapped by a delegate."""
        return self._resolve_provider(original.provides)

    def _resolve_config_option(
            self,
            original: providers.ConfigurationOption,
            as_: Any = None,
    ) -> Optional[providers.Provider]:
        """Rebuild a configuration option on the wired configuration provider.

        The option's root is resolved first, then every name segment is
        replayed on it. Segments that are providers themselves are resolved
        and used as item keys; other segments are read as attributes. The
        ``required`` flag and the optional ``as_`` conversion are re-applied
        at the end.
        """
        original_root = original.root
        new = self._resolve_provider(original_root)
        if new is None:
            return None
        new = cast(providers.Configuration, new)

        for segment in original.get_name_segments():
            if providers.is_provider(segment):
                segment = self.resolve_provider(segment)
                new = new[segment]
            else:
                new = getattr(new, segment)

        if original.is_required():
            new = new.required()

        if as_:
            new = new.as_(as_)

        return new

    def _resolve_provider(
            self,
            original: providers.Provider,
    ) -> Optional[providers.Provider]:
        """Look ``original`` up in the map; unknown providers give ``None``."""
        try:
            return self._map[original]
        except KeyError:
            return None

    @classmethod
    def _create_providers_map(
            cls,
            current_container: Container,
            original_container: Container,
    ) -> Dict[providers.Provider, providers.Provider]:
        """Pair providers of both containers by name, recursing into sub-containers.

        Providers that exist only in ``current_container`` are skipped: there
        is no original provider a marker could reference for them.
        """
        # NOTE(review): the ``providers`` mappings are updated in place here;
        # confirm whether the container attribute returns a copy.
        current_providers = current_container.providers
        current_providers['__self__'] = current_container.__self__

        original_providers = original_container.providers
        original_providers['__self__'] = original_container.__self__

        providers_map = {}
        for provider_name, current_provider in current_providers.items():
            try:
                original_provider = original_providers[provider_name]
            except KeyError:
                # Present only in the wired container; nothing to map.
                continue
            providers_map[original_provider] = current_provider

            if isinstance(current_provider, providers.Container) \
                    and isinstance(original_provider, providers.Container):
                subcontainer_map = cls._create_providers_map(
                    current_container=current_provider.container,
                    original_container=original_provider.container,
                )
                providers_map.update(subcontainer_map)

        return providers_map
|
|
|
|
|
|
|
|
|
2021-02-27 17:45:49 +03:00
|
|
|
class InspectFilter:
    """Decide which module members ``wire`` must not inspect."""

    def is_excluded(self, instance: object) -> bool:
        """Return ``True`` for werkzeug local proxies and starlette request classes."""
        if self._is_werkzeug_local_proxy(instance):
            return True
        elif self._is_starlette_request_cls(instance):
            return True
        else:
            return False

    def _is_werkzeug_local_proxy(self, instance: object) -> bool:
        # ``werkzeug`` is ``None`` when the package could not be imported;
        # compare explicitly so the result is always a real bool.
        return werkzeug is not None \
            and isinstance(instance, werkzeug.local.LocalProxy)

    def _is_starlette_request_cls(self, instance: object) -> bool:
        # ``starlette`` is ``None`` when the package could not be imported.
        return starlette is not None \
            and isinstance(instance, type) \
            and issubclass(instance, starlette.requests.Request)


# Shared filter instance used by ``wire``.
inspect_filter = InspectFilter()
|
|
|
|
|
|
|
|
|
2020-11-16 00:06:42 +03:00
|
|
|
def wire(  # noqa: C901
        container: Container,
        *,
        modules: Optional[Iterable[ModuleType]] = None,
        packages: Optional[Iterable[ModuleType]] = None,
) -> None:
    """Wire container providers with provided packages and modules.

    Every function and class method found in the given modules (and in all
    modules of the given packages) is patched when it declares injection
    markers, and the injections of already patched callables registered for
    those modules are bound to the container's providers.

    :raises Exception: if ``container`` is not a declarative container instance.
    """
    if not _is_declarative_container_instance(container):
        raise Exception('Can wire only an instance of the declarative container')

    # Work on a copy: the caller's sequence must not grow when packages are
    # expanded, and any iterable (tuple, generator, ...) has to be accepted.
    modules = list(modules) if modules is not None else []

    if packages:
        for package in packages:
            modules.extend(_fetch_modules(package))

    providers_map = ProvidersMap(container)

    for module in modules:
        for name, member in inspect.getmembers(module):
            if inspect_filter.is_excluded(member):
                continue

            if inspect.isfunction(member):
                _patch_fn(module, name, member, providers_map)
            elif inspect.isclass(member):
                for method_name, method in inspect.getmembers(member, _is_method):
                    _patch_method(member, method_name, method, providers_map)

        # Callables registered for this module (for example through
        # ``inject``) are bound as well, whether or not they were found above.
        for patched in _patched_registry.get_from_module(module):
            _bind_injections(patched, providers_map)
|
|
|
|
|
2020-10-09 22:16:27 +03:00
|
|
|
|
|
|
|
def unwire(
        *,
        modules: Optional[Iterable[ModuleType]] = None,
        packages: Optional[Iterable[ModuleType]] = None,
) -> None:
    """Unwire provided packages and modules from previously wired providers.

    The bound injections of every patched function and class method found in
    the given modules (and in all modules of the given packages) are cleared.
    The patched wrappers themselves stay in place.
    """
    # Work on a copy: the caller's sequence must not grow when packages are
    # expanded, and any iterable (tuple, generator, ...) has to be accepted.
    modules = list(modules) if modules is not None else []

    if packages:
        for package in packages:
            modules.extend(_fetch_modules(package))

    for module in modules:
        for name, member in inspect.getmembers(module):
            if inspect.isfunction(member):
                _unpatch(module, name, member)
            elif inspect.isclass(member):
                for method_name, method in inspect.getmembers(member, inspect.isfunction):
                    _unpatch(member, method_name, method)

        for patched in _patched_registry.get_from_module(module):
            _unbind_injections(patched)
|
|
|
|
|
|
|
|
|
|
|
|
def inject(fn: F) -> F:
    """Decorate callable with injecting decorator.

    The markers declared in the signature of ``fn`` are collected, ``fn`` is
    wrapped and the wrapper is stored in the module-level registry. No
    provider is bound here; the wrapper starts with empty injections.
    """
    reference_injections, reference_closing = _fetch_reference_injections(fn)
    patched = _get_patched(fn, reference_injections, reference_closing)
    _patched_registry.add(patched)
    return cast(F, patched)
|
|
|
|
|
2020-10-09 22:16:27 +03:00
|
|
|
|
|
|
|
def _patch_fn(
        module: ModuleType,
        name: str,
        fn: Callable[..., Any],
        providers_map: ProvidersMap,
) -> None:
    """Patch module-level function ``fn`` and bind its injections.

    A function that is not patched yet is wrapped only when its signature
    declares at least one marker; otherwise it is left untouched. The
    (possibly new) wrapper is bound to ``providers_map`` and stored back on
    ``module`` under ``name``.
    """
    if not _is_patched(fn):
        reference_injections, reference_closing = _fetch_reference_injections(fn)
        if not reference_injections:
            return
        fn = _get_patched(fn, reference_injections, reference_closing)
        _patched_registry.add(fn)

    _bind_injections(fn, providers_map)

    setattr(module, name, fn)
|
2020-10-09 22:16:27 +03:00
|
|
|
|
|
|
|
|
2020-11-03 23:59:02 +03:00
|
|
|
def _patch_method(
        cls: Type,
        name: str,
        method: Callable[..., Any],
        providers_map: ProvidersMap,
) -> None:
    """Patch method ``name`` of ``cls`` and bind its injections.

    Class and static methods are read from ``cls.__dict__`` so that the raw
    descriptor is available; the underlying function is patched and wrapped
    into the same descriptor type again before it is set on the class.
    """
    if hasattr(cls, '__dict__') \
            and name in cls.__dict__ \
            and isinstance(cls.__dict__[name], (classmethod, staticmethod)):
        method = cls.__dict__[name]
        fn = method.__func__
    else:
        fn = method

    if not _is_patched(fn):
        reference_injections, reference_closing = _fetch_reference_injections(fn)
        if not reference_injections:
            return
        fn = _get_patched(fn, reference_injections, reference_closing)
        _patched_registry.add(fn)

    _bind_injections(fn, providers_map)

    # Restore the classmethod/staticmethod wrapper around the patched function.
    if isinstance(method, (classmethod, staticmethod)):
        fn = type(method)(fn)

    setattr(cls, name, fn)
|
2020-11-03 23:59:02 +03:00
|
|
|
|
|
|
|
|
|
|
|
def _unpatch(
        module: ModuleType,
        name: str,
        fn: Callable[..., Any],
) -> None:
    """Clear the bound injections of ``fn`` if it is a patched callable.

    ``module`` is the owner of the attribute; ``unwire`` also passes classes
    here. When the owner stores a classmethod/staticmethod under ``name`` the
    underlying function is inspected instead of ``fn``.
    """
    if hasattr(module, '__dict__') \
            and name in module.__dict__ \
            and isinstance(module.__dict__[name], (classmethod, staticmethod)):
        method = module.__dict__[name]
        fn = method.__func__

    if not _is_patched(fn):
        return

    _unbind_injections(fn)
|
2020-10-09 22:16:27 +03:00
|
|
|
|
2020-11-16 00:06:42 +03:00
|
|
|
|
|
|
|
def _fetch_reference_injections(
        fn: Callable[..., Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Collect the markers declared as parameter defaults of ``fn``.

    :return: ``(injections, closing)``. ``injections`` maps parameter names to
        their marker. ``closing`` holds the subset of parameters that were
        declared with a ``Closing`` marker, mapped to the marker it wraps.
        Both are empty when ``fn`` exposes no signature.
    """
    # Hotfix, see:
    # - https://github.com/ets-labs/python-dependency-injector/issues/362
    # - https://github.com/ets-labs/python-dependency-injector/issues/398
    if GenericAlias and any((
            fn is GenericAlias,
            getattr(fn, '__func__', None) is GenericAlias
    )):
        fn = fn.__init__

    try:
        signature = inspect.signature(fn)
    except ValueError:
        # ``inspect.signature`` raises ValueError when no signature can be
        # provided for a callable; such a callable cannot declare markers, so
        # it must not abort the wiring of the whole module.
        return {}, {}

    injections = {}
    closing = {}
    for parameter_name, parameter in signature.parameters.items():
        if not isinstance(parameter.default, _Marker) \
                and not _is_fastapi_depends(parameter.default):
            continue

        marker = parameter.default

        # FastAPI ``Depends(...)`` defaults carry the marker as their dependency.
        if _is_fastapi_depends(marker):
            marker = marker.dependency

            if not isinstance(marker, _Marker):
                continue

        if isinstance(marker, Closing):
            marker = marker.provider
            closing[parameter_name] = marker

        injections[parameter_name] = marker
    return injections, closing
|
|
|
|
|
|
|
|
|
|
|
|
def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None:
    """Resolve the reference markers of ``fn`` and store the resulting providers.

    Markers whose provider cannot be resolved by ``providers_map`` are
    skipped. A ``Provide`` marker binds the resolved provider itself, a
    ``Provider`` marker binds its ``.provider`` attribute. Parameters listed in
    ``fn.__reference_closing__`` are additionally recorded in ``fn.__closing__``.
    """
    for injection, marker in fn.__reference_injections__.items():
        provider = providers_map.resolve_provider(marker.provider, marker.modifier)

        if provider is None:
            continue

        if isinstance(marker, Provide):
            fn.__injections__[injection] = provider
        elif isinstance(marker, Provider):
            fn.__injections__[injection] = provider.provider

        if injection in fn.__reference_closing__:
            fn.__closing__[injection] = provider
|
|
|
|
|
|
|
|
|
|
|
|
def _unbind_injections(fn: Callable[..., Any]) -> None:
|
|
|
|
fn.__injections__ = {}
|
|
|
|
fn.__closing__ = {}
|
2020-10-09 22:16:27 +03:00
|
|
|
|
|
|
|
|
|
|
|
def _fetch_modules(package):
|
2020-10-21 00:48:54 +03:00
|
|
|
modules = [package]
|
2020-11-05 20:20:09 +03:00
|
|
|
for module_info in pkgutil.walk_packages(
|
2020-10-09 22:16:27 +03:00
|
|
|
path=package.__path__,
|
|
|
|
prefix=package.__name__ + '.',
|
|
|
|
):
|
2020-11-05 20:20:09 +03:00
|
|
|
module = importlib.import_module(module_info.name)
|
2020-10-09 22:16:27 +03:00
|
|
|
modules.append(module)
|
|
|
|
return modules
|
|
|
|
|
|
|
|
|
2020-10-28 20:11:07 +03:00
|
|
|
def _is_method(member):
|
|
|
|
return inspect.ismethod(member) or inspect.isfunction(member)
|
|
|
|
|
|
|
|
|
2020-11-16 00:06:42 +03:00
|
|
|
def _get_patched(fn, reference_injections, reference_closing):
    """Wrap ``fn`` into an injecting wrapper and attach the wiring metadata.

    Coroutine functions get the async wrapper, everything else the sync one.
    The wrapper starts with empty ``__injections__`` / ``__closing__``; the
    reference dicts keep the markers those are later resolved from.
    """
    if inspect.iscoroutinefunction(fn):
        patched = _get_async_patched(fn)
    else:
        patched = _get_sync_patched(fn)

    patched.__wired__ = True
    patched.__original__ = fn
    patched.__injections__ = {}
    patched.__reference_injections__ = reference_injections
    patched.__closing__ = {}
    patched.__reference_closing__ = reference_closing

    return patched
|
2020-10-09 22:16:27 +03:00
|
|
|
|
|
|
|
|
2020-11-16 00:06:42 +03:00
|
|
|
def _get_sync_patched(fn):
|
2020-10-30 23:55:37 +03:00
|
|
|
@functools.wraps(fn)
|
2020-10-30 23:47:26 +03:00
|
|
|
def _patched(*args, **kwargs):
|
|
|
|
to_inject = kwargs.copy()
|
2020-11-16 00:06:42 +03:00
|
|
|
for injection, provider in _patched.__injections__.items():
|
2020-11-12 23:54:49 +03:00
|
|
|
if injection not in kwargs \
|
|
|
|
or _is_fastapi_default_arg_injection(injection, kwargs):
|
2020-10-30 23:47:26 +03:00
|
|
|
to_inject[injection] = provider()
|
|
|
|
|
|
|
|
result = fn(*args, **to_inject)
|
|
|
|
|
2020-11-16 00:06:42 +03:00
|
|
|
for injection, provider in _patched.__closing__.items():
|
2020-11-12 23:54:49 +03:00
|
|
|
if injection in kwargs \
|
|
|
|
and not _is_fastapi_default_arg_injection(injection, kwargs):
|
2020-10-30 23:47:26 +03:00
|
|
|
continue
|
|
|
|
if not isinstance(provider, providers.Resource):
|
|
|
|
continue
|
|
|
|
provider.shutdown()
|
|
|
|
|
|
|
|
return result
|
|
|
|
return _patched
|
|
|
|
|
|
|
|
|
2020-11-16 00:06:42 +03:00
|
|
|
def _get_async_patched(fn):
|
2020-10-30 23:47:26 +03:00
|
|
|
@functools.wraps(fn)
|
|
|
|
async def _patched(*args, **kwargs):
|
|
|
|
to_inject = kwargs.copy()
|
2021-01-11 03:26:15 +03:00
|
|
|
to_inject_await = []
|
|
|
|
to_close_await = []
|
2020-11-16 00:06:42 +03:00
|
|
|
for injection, provider in _patched.__injections__.items():
|
2020-11-12 23:54:49 +03:00
|
|
|
if injection not in kwargs \
|
|
|
|
or _is_fastapi_default_arg_injection(injection, kwargs):
|
2021-01-11 03:26:15 +03:00
|
|
|
provide = provider()
|
|
|
|
if inspect.isawaitable(provide):
|
|
|
|
to_inject_await.append((injection, provide))
|
|
|
|
else:
|
|
|
|
to_inject[injection] = provide
|
|
|
|
|
|
|
|
async_to_inject = await asyncio.gather(*[provide for _, provide in to_inject_await])
|
|
|
|
for provide, (injection, _) in zip(async_to_inject, to_inject_await):
|
|
|
|
to_inject[injection] = provide
|
2020-10-30 23:47:26 +03:00
|
|
|
|
|
|
|
result = await fn(*args, **to_inject)
|
|
|
|
|
2020-11-16 00:06:42 +03:00
|
|
|
for injection, provider in _patched.__closing__.items():
|
2020-11-12 23:54:49 +03:00
|
|
|
if injection in kwargs \
|
|
|
|
and not _is_fastapi_default_arg_injection(injection, kwargs):
|
2020-10-30 23:47:26 +03:00
|
|
|
continue
|
|
|
|
if not isinstance(provider, providers.Resource):
|
|
|
|
continue
|
2021-01-11 03:26:15 +03:00
|
|
|
shutdown = provider.shutdown()
|
|
|
|
if inspect.isawaitable(shutdown):
|
|
|
|
to_close_await.append(shutdown)
|
|
|
|
|
|
|
|
await asyncio.gather(*to_close_await)
|
2020-10-30 23:47:26 +03:00
|
|
|
|
|
|
|
return result
|
|
|
|
return _patched
|
|
|
|
|
|
|
|
|
2020-11-12 23:54:49 +03:00
|
|
|
def _is_fastapi_default_arg_injection(injection, kwargs):
|
|
|
|
"""Check if injection is FastAPI injection of the default argument."""
|
|
|
|
return injection in kwargs and isinstance(kwargs[injection], _Marker)
|
|
|
|
|
|
|
|
|
2020-11-18 07:44:32 +03:00
|
|
|
def _is_fastapi_depends(param: Any) -> bool:
    """Return ``True`` if ``param`` is a FastAPI ``Depends`` instance.

    Always ``False`` when FastAPI could not be imported.
    """
    # ``fastapi`` is ``None`` when the package is missing; compare explicitly
    # so the result is always a real bool.
    return fastapi is not None and isinstance(param, fastapi.params.Depends)
|
2020-11-18 07:44:32 +03:00
|
|
|
|
|
|
|
|
2020-10-09 22:16:27 +03:00
|
|
|
def _is_patched(fn):
|
|
|
|
return getattr(fn, '__wired__', False) is True
|
|
|
|
|
|
|
|
|
|
|
|
def _is_declarative_container_instance(instance: Any) -> bool:
|
|
|
|
return (not isinstance(instance, type)
|
|
|
|
and getattr(instance, '__IS_CONTAINER__', False) is True
|
|
|
|
and getattr(instance, 'declarative_parent', None) is not None)
|
|
|
|
|
|
|
|
|
2021-01-11 16:18:02 +03:00
|
|
|
def _is_declarative_container(instance: Any) -> bool:
|
|
|
|
return (isinstance(instance, type)
|
|
|
|
and getattr(instance, '__IS_CONTAINER__', False) is True
|
|
|
|
and getattr(instance, 'declarative_parent', None) is None)
|
|
|
|
|
|
|
|
|
2021-02-21 18:34:28 +03:00
|
|
|
class Modifier:
    """Base class of the modifiers applied to string-id resolved providers."""

    def modify(
            self,
            provider: providers.ConfigurationOption,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Return the modified provider; the base implementation does nothing."""
        ...
|
|
|
|
|
|
|
|
|
|
|
|
class TypeModifier(Modifier):
    """Modifier that converts the provided value via ``provider.as_(type_)``."""

    def __init__(self, type_: Type):
        self.type_ = type_

    def modify(
            self,
            provider: providers.ConfigurationOption,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Return ``provider.as_(self.type_)``; ``providers_map`` is unused."""
        return provider.as_(self.type_)
|
|
|
|
|
|
|
|
|
|
|
|
def as_int() -> TypeModifier:
    """Return int type modifier."""
    return TypeModifier(int)
|
|
|
|
|
|
|
|
|
|
|
|
def as_float() -> TypeModifier:
    """Return float type modifier."""
    return TypeModifier(float)
|
|
|
|
|
|
|
|
|
|
|
|
def as_(type_: Type) -> TypeModifier:
    """Return custom type modifier."""
    return TypeModifier(type_)
|
|
|
|
|
|
|
|
|
|
|
|
class RequiredModifier(Modifier):
    """Modifier that marks an option as required, with an optional type conversion.

    The ``as_*`` methods store a type modifier and return ``self`` so that
    they can be chained onto ``required()``.
    """

    def __init__(self):
        self.type_modifier = None

    def as_int(self) -> 'RequiredModifier':
        """Additionally convert the value to ``int``."""
        self.type_modifier = TypeModifier(int)
        return self

    def as_float(self) -> 'RequiredModifier':
        """Additionally convert the value to ``float``."""
        self.type_modifier = TypeModifier(float)
        return self

    def as_(self, type_: Type) -> 'RequiredModifier':
        """Additionally convert the value with ``type_``."""
        self.type_modifier = TypeModifier(type_)
        return self

    def modify(
            self,
            provider: providers.ConfigurationOption,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Return ``provider.required()``, converted when a type was chosen."""
        provider = provider.required()
        if self.type_modifier:
            provider = provider.as_(self.type_modifier.type_)
        return provider
|
|
|
|
|
|
|
|
|
|
|
|
def required() -> RequiredModifier:
    """Return required modifier."""
    return RequiredModifier()
|
|
|
|
|
|
|
|
|
|
|
|
class InvariantModifier(Modifier):
    """Modifier that indexes a provider with another provider given by string id."""

    def __init__(self, id: str) -> None:
        self.id = id

    def modify(
            self,
            provider: providers.ConfigurationOption,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Resolve ``self.id`` through ``providers_map`` and use it as item key."""
        invariant_segment = providers_map.resolve_provider(self.id)
        return provider[invariant_segment]
|
|
|
|
|
|
|
|
|
|
|
|
def invariant(id: str) -> InvariantModifier:
    """Return invariant modifier."""
    return InvariantModifier(id)
|
|
|
|
|
|
|
|
|
|
|
|
class ProvidedInstance(Modifier):
    """Modifier that records an attribute/item/call chain on a provided instance.

    Attribute access, item access and ``call()`` each append a segment and
    return ``self``; ``modify`` replays the segments on ``provider.provided``.
    """

    TYPE_ATTRIBUTE = 'attr'
    TYPE_ITEM = 'item'
    TYPE_CALL = 'call'

    def __init__(self):
        self.segments = []

    def __getattr__(self, item):
        # Only reached for names that normal lookup did not find. Dunder
        # probes (issued e.g. by ``copy`` and ``pickle``) and a missing
        # ``segments`` attribute must fail as usual: recording them would
        # corrupt the chain, or recurse endlessly for ``segments``.
        if item == 'segments' or (item.startswith('__') and item.endswith('__')):
            raise AttributeError(item)
        self.segments.append((self.TYPE_ATTRIBUTE, item))
        return self

    def __getitem__(self, item):
        self.segments.append((self.TYPE_ITEM, item))
        return self

    def call(self):
        """Record a call without arguments."""
        self.segments.append((self.TYPE_CALL, None))
        return self

    def modify(
            self,
            provider: providers.ConfigurationOption,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Replay the recorded segments on ``provider.provided``."""
        provider = provider.provided
        for type_, value in self.segments:
            if type_ == ProvidedInstance.TYPE_ATTRIBUTE:
                provider = getattr(provider, value)
            elif type_ == ProvidedInstance.TYPE_ITEM:
                provider = provider[value]
            elif type_ == ProvidedInstance.TYPE_CALL:
                provider = provider.call()
        return provider
|
|
|
|
|
|
|
|
|
|
|
|
def provided() -> ProvidedInstance:
    """Return provided instance modifier."""
    return ProvidedInstance()
|
|
|
|
|
|
|
|
|
2020-10-09 22:16:27 +03:00
|
|
|
class ClassGetItemMeta(GenericMeta):
    """Metaclass that turns ``Cls[item]`` into an instantiation of ``Cls``."""

    def __getitem__(cls, item):
        # Spike for Python 3.6
        # A tuple subscript is unpacked into positional constructor arguments.
        if isinstance(item, tuple):
            return cls(*item)
        return cls(item)
|
|
|
|
|
|
|
|
|
|
|
|
class _Marker(Generic[T], metaclass=ClassGetItemMeta):
    """Base class of the injection markers used as parameter defaults.

    A marker is created by subscription, ``Marker[provider]`` or
    ``Marker[provider, modifier]``, and keeps both values for later resolution.
    """

    def __init__(
            self,
            provider: Union[providers.Provider, Container, str],
            modifier: Optional[Modifier] = None,
    ) -> None:
        # A declarative container class is replaced by its ``__self__`` provider.
        if _is_declarative_container(provider):
            provider = provider.__self__
        self.provider = provider
        self.modifier = modifier

    def __class_getitem__(cls, item) -> T:
        # A tuple subscript is unpacked into positional constructor arguments.
        if isinstance(item, tuple):
            return cls(*item)
        return cls(item)

    def __call__(self) -> T:
        # Calling a marker returns the marker itself.
        return self
|
|
|
|
|
2020-10-09 22:16:27 +03:00
|
|
|
|
|
|
|
class Provide(_Marker):
    """Marker for which the resolved provider itself is bound as the injection."""
|
|
|
|
|
|
|
|
|
|
|
|
class Provider(_Marker):
    """Marker for which the resolved provider's ``.provider`` is bound as the injection."""
|
2020-10-30 05:55:09 +03:00
|
|
|
|
|
|
|
|
|
|
|
class Closing(_Marker):
    """Marker wrapping another marker whose resource is shut down after the call."""
|
2021-01-29 03:49:24 +03:00
|
|
|
|
|
|
|
|
|
|
|
class AutoLoader:
    """Auto-wiring module loader.

    Automatically wire containers when modules are imported.
    """

    def __init__(self):
        self.containers = []
        self._path_hook = None

    def register_containers(self, *containers):
        """Add containers and install the import hook if it is not installed."""
        self.containers.extend(containers)

        if not self.installed:
            self.install()

    def unregister_containers(self, *containers):
        """Remove containers; uninstall the import hook when none are left.

        :raises ValueError: if a container is not registered.
        """
        for container in containers:
            self.containers.remove(container)

        if not self.containers:
            self.uninstall()

    def wire_module(self, module):
        """Wire every registered container against ``module``."""
        for container in self.containers:
            container.wire(modules=[module])

    @property
    def installed(self):
        return self._path_hook is not None

    def install(self):
        """Insert a path hook whose loaders wire modules right after executing them."""
        if self.installed:
            return

        loader = self

        class SourcelessFileLoader(importlib.machinery.SourcelessFileLoader):
            def exec_module(self, module):
                super().exec_module(module)
                loader.wire_module(module)

        class SourceFileLoader(importlib.machinery.SourceFileLoader):
            def exec_module(self, module):
                super().exec_module(module)
                loader.wire_module(module)

        loader_details = [
            (SourcelessFileLoader, importlib.machinery.BYTECODE_SUFFIXES),
            (SourceFileLoader, importlib.machinery.SOURCE_SUFFIXES),
            # The hook is placed first in ``sys.path_hooks`` and therefore
            # serves every directory; without an extension loader, extension
            # modules could not be found there while the hook is installed.
            (
                importlib.machinery.ExtensionFileLoader,
                importlib.machinery.EXTENSION_SUFFIXES,
            ),
        ]

        self._path_hook = importlib.machinery.FileFinder.path_hook(*loader_details)

        sys.path_hooks.insert(0, self._path_hook)
        sys.path_importer_cache.clear()
        importlib.invalidate_caches()

    def uninstall(self):
        """Remove the path hook installed by ``install``."""
        if not self.installed:
            return

        # The hook may already have been removed from the list by other code.
        if self._path_hook in sys.path_hooks:
            sys.path_hooks.remove(self._path_hook)
        # Forget the hook so that ``installed`` turns False and a later
        # ``install`` really installs a hook again.
        self._path_hook = None
        sys.path_importer_cache.clear()
        importlib.invalidate_caches()


# Module-level loader used by the loader helper functions of this module.
_loader = AutoLoader()
|
|
|
|
|
|
|
|
|
|
|
|
def register_loader_containers(*containers: Container) -> None:
    """Register containers in auto-wiring module loader."""
    _loader.register_containers(*containers)
|
|
|
|
|
|
|
|
|
|
|
|
def unregister_loader_containers(*containers: Container) -> None:
    """Unregister containers from auto-wiring module loader."""
    _loader.unregister_containers(*containers)
|
|
|
|
|
|
|
|
|
|
|
|
def install_loader() -> None:
    """Install auto-wiring module loader hook."""
    _loader.install()
|
|
|
|
|
|
|
|
|
|
|
|
def uninstall_loader() -> None:
    """Uninstall auto-wiring module loader hook."""
    _loader.uninstall()
|
|
|
|
|
|
|
|
|
|
|
|
def is_loader_installed() -> bool:
    """Check if auto-wiring module loader hook is installed."""
    return _loader.installed
|