python-dependency-injector/src/dependency_injector/wiring.py
2020-11-15 15:19:45 -05:00

432 lines
13 KiB
Python

"""Wiring module."""
import functools
import inspect
import importlib
import pkgutil
import sys
from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast
# ``GenericMeta`` is taken from ``typing`` on interpreters older than 3.7.
# On 3.7 and newer a bare ``type`` subclass with the same name is defined
# instead, so the name can be used as a metaclass base in either case.
if sys.version_info >= (3, 7):
    class GenericMeta(type):
        ...
else:
    from typing import GenericMeta
from . import providers
# Names exported by ``from <this module> import *``.
__all__ = (
    'wire',
    'unwire',
    'inject',
    'Provide',
    'Provider',
    'Closing',
)

# Type variable that parametrizes the generic ``_Marker`` base class.
T = TypeVar('T')

# Containers are not typed more precisely than ``Any`` in this module.
Container = Any
class ProvidersMap:
    """Translate providers of a declarative container into instance providers.

    The internal map is keyed by the providers found in
    ``container.declarative_parent.providers`` and valued by the providers of
    the same name found in ``container.providers``. Providers of nested
    ``providers.Container`` entries are merged into the same map.
    """

    def __init__(self, container):
        """Build the lookup map for ``container``."""
        self._container = container
        self._map = self._create_providers_map(
            current_providers=container.providers,
            original_providers=container.declarative_parent.providers,
        )

    def resolve_provider(
            self,
            provider: providers.Provider,
    ) -> Optional[providers.Provider]:
        """Return the container-instance counterpart of ``provider``.

        The lookup strategy is selected by the provider type. ``None`` is
        returned when the underlying provider is not present in the map.
        """
        if isinstance(provider, providers.Delegate):
            return self._resolve_delegate(provider)

        provided_instance_types = (
            providers.ProvidedInstance,
            providers.AttributeGetter,
            providers.ItemGetter,
            providers.MethodCaller,
        )
        if isinstance(provider, provided_instance_types):
            return self._resolve_provided_instance(provider)

        if isinstance(provider, providers.ConfigurationOption):
            return self._resolve_config_option(provider)

        if isinstance(provider, providers.TypedConfigurationOption):
            return self._resolve_config_option(provider.option, as_=provider.provides)

        return self._resolve_provider(provider)

    def _resolve_delegate(
            self,
            original: providers.Delegate,
    ) -> Optional[providers.Provider]:
        """Resolve the provider that ``original`` delegates to."""
        delegated = original.provides
        return self._resolve_provider(delegated)

    def _resolve_provided_instance(
            self,
            original: providers.Provider,
    ) -> Optional[providers.Provider]:
        """Resolve a chain of ``.provided`` / attribute / item / call wrappers.

        The chain is unwound down to its root provider, the root is looked
        up in the map, and the same chain is then rebuilt on top of the
        resolved root. Returns ``None`` when the root is not in the map.
        """
        chain_types = (
            providers.ProvidedInstance,
            providers.AttributeGetter,
            providers.ItemGetter,
            providers.MethodCaller,
        )

        # Collected outermost-first while descending towards the root.
        wrappers = []
        root = original
        while isinstance(root, chain_types):
            wrappers.append(root)
            root = root.provides

        resolved = self._resolve_provider(root)
        if resolved is None:
            return None

        # Re-apply innermost-first so the rebuilt chain mirrors the original.
        for wrapper in reversed(wrappers):
            if isinstance(wrapper, providers.ProvidedInstance):
                resolved = resolved.provided
            elif isinstance(wrapper, providers.AttributeGetter):
                resolved = getattr(resolved, wrapper.name)
            elif isinstance(wrapper, providers.ItemGetter):
                resolved = resolved[wrapper.name]
            elif isinstance(wrapper, providers.MethodCaller):
                resolved = resolved.call(
                    *wrapper.args,
                    **wrapper.kwargs,
                )
        return resolved

    def _resolve_config_option(
            self,
            original: providers.ConfigurationOption,
            as_: Any = None,
    ) -> Optional[providers.Provider]:
        """Resolve a configuration option by replaying its name segments.

        The option's root is looked up in the map; each name segment is then
        applied to the result. Segments that are providers themselves are
        resolved first and applied with item access, all other segments are
        applied with attribute access. When ``as_`` is truthy the result is
        finally passed through ``.as_(as_)``.
        """
        resolved = self._resolve_provider(original.root)
        if resolved is None:
            return None
        resolved = cast(providers.Configuration, resolved)

        for segment in original.get_name_segments():
            if not providers.is_provider(segment):
                resolved = getattr(resolved, segment)
                continue
            segment = self.resolve_provider(segment)
            resolved = resolved[segment]

        if as_:
            resolved = resolved.as_(as_)
        return resolved

    def _resolve_provider(
            self,
            original: providers.Provider,
    ) -> Optional[providers.Provider]:
        """Return the mapped provider for ``original`` or ``None`` if unknown."""
        return self._map.get(original)

    @classmethod
    def _create_providers_map(
            cls,
            current_providers: Dict[str, providers.Provider],
            original_providers: Dict[str, providers.Provider],
    ) -> Dict[providers.Provider, providers.Provider]:
        """Pair providers of both dictionaries by name.

        For each name in ``current_providers`` the provider of the same name
        in ``original_providers`` becomes the key and the current provider
        the value. When both are ``providers.Container`` the providers of
        their nested containers are mapped recursively and merged in.
        """
        mapping = {}
        for name, current in current_providers.items():
            original = original_providers[name]
            mapping[original] = current

            both_are_containers = (
                isinstance(current, providers.Container)
                and isinstance(original, providers.Container)
            )
            if not both_are_containers:
                continue

            mapping.update(
                cls._create_providers_map(
                    current_providers=current.container.providers,
                    original_providers=original.container.providers,
                ),
            )
        return mapping
def wire(
        container: Container,
        *,
        modules: Optional[Iterable[ModuleType]] = None,
        packages: Optional[Iterable[ModuleType]] = None,
) -> None:
    """Wire container providers with provided packages and modules.

    Every function found in the given modules, and every method of every
    class found there, is patched so that parameters whose default is a
    marker get their value from the matching provider of ``container``.
    Each package contributes itself and the modules discovered beneath it.

    :param container: Instance of a declarative container.
    :param modules: Modules to wire. Any iterable is accepted; it is not
        modified.
    :param packages: Packages whose modules should be wired as well.
    :raise Exception: If ``container`` is not a declarative container
        instance.
    """
    if not _is_declarative_container_instance(container):
        raise Exception('Can wire only an instance of the declarative container')

    # Work on a private list: the argument may be a tuple or a generator
    # (no ``extend``), and a list passed by the caller must not grow as a
    # side effect of wiring packages.
    modules = list(modules) if modules is not None else []

    if packages:
        for package in packages:
            modules.extend(_fetch_modules(package))

    providers_map = ProvidersMap(container)

    for module in modules:
        for name, member in inspect.getmembers(module):
            if inspect.isfunction(member):
                _patch_fn(module, name, member, providers_map)
            elif inspect.isclass(member):
                for method_name, method in inspect.getmembers(member, _is_method):
                    _patch_method(member, method_name, method, providers_map)
def unwire(
        *,
        modules: Optional[Iterable[ModuleType]] = None,
        packages: Optional[Iterable[ModuleType]] = None,
) -> None:
    """Unwire provided packages and modules from previously wired providers.

    The injections bound to every patched function found in the given
    modules, and to every patched function found on their classes, are
    reset to empty. The patched wrappers themselves stay in place.

    :param modules: Modules to unwire. Any iterable is accepted; it is not
        modified.
    :param packages: Packages whose modules should be unwired as well.
    """
    # Work on a private list: the argument may be a tuple or a generator
    # (no ``extend``), and a list passed by the caller must not grow as a
    # side effect of unwiring packages.
    modules = list(modules) if modules is not None else []

    if packages:
        for package in packages:
            modules.extend(_fetch_modules(package))

    for module in modules:
        for name, member in inspect.getmembers(module):
            if inspect.isfunction(member):
                _unpatch(module, name, member)
            elif inspect.isclass(member):
                for method_name, method in inspect.getmembers(member, inspect.isfunction):
                    _unpatch(member, method_name, method)
def inject(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate callable with injecting decorator.

    The markers used as parameter defaults of ``fn`` are collected and
    stored on the returned wrapper. No provider is bound at this point.
    """
    return _get_patched(fn, *_fetch_reference_injections(fn))
def _patch_fn(
        module: ModuleType,
        name: str,
        fn: Callable[..., Any],
        providers_map: ProvidersMap,
) -> None:
    """Patch module-level function ``fn`` and bind its injections.

    An already patched function is reused as is. Otherwise a wrapper is
    created, unless ``fn`` declares no marker defaults, in which case
    nothing happens. The resulting function is stored back on ``module``
    under ``name``.
    """
    if _is_patched(fn):
        target = fn
    else:
        markers, closing_markers = _fetch_reference_injections(fn)
        if not markers:
            return
        target = _get_patched(fn, markers, closing_markers)

    _bind_injections(target, providers_map)
    setattr(module, name, target)
def _patch_method(
        cls: Type,
        name: str,
        method: Callable[..., Any],
        providers_map: ProvidersMap,
) -> None:
    """Patch a method of ``cls`` and bind its injections.

    When the class's own ``__dict__`` holds a ``classmethod`` or
    ``staticmethod`` under ``name``, the underlying function is patched and
    wrapped again in the same descriptor type before being stored back.
    Methods without marker defaults are left untouched.
    """
    declared = getattr(cls, '__dict__', {}).get(name)
    if isinstance(declared, (classmethod, staticmethod)):
        method = declared
        fn = declared.__func__
    else:
        fn = method

    if not _is_patched(fn):
        markers, closing_markers = _fetch_reference_injections(fn)
        if not markers:
            return
        fn = _get_patched(fn, markers, closing_markers)

    _bind_injections(fn, providers_map)

    # Restore the original descriptor kind around the patched function.
    if isinstance(method, (classmethod, staticmethod)):
        fn = type(method)(fn)

    setattr(cls, name, fn)
def _unpatch(
        module: ModuleType,
        name: str,
        fn: Callable[..., Any],
) -> None:
    """Reset the injections bound to ``fn`` if it is a patched function.

    ``module`` is the owner of the attribute; callers in this file pass
    either a module or a class. A ``classmethod`` / ``staticmethod`` stored
    under ``name`` in the owner's ``__dict__`` is unwrapped first.
    """
    declared = getattr(module, '__dict__', {}).get(name)
    if isinstance(declared, (classmethod, staticmethod)):
        fn = declared.__func__

    if _is_patched(fn):
        _unbind_injections(fn)
def _fetch_reference_injections(
fn: Callable[..., Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
signature = inspect.signature(fn)
injections = {}
closing = {}
for parameter_name, parameter in signature.parameters.items():
if not isinstance(parameter.default, _Marker):
continue
marker = parameter.default
if isinstance(marker, Closing):
marker = marker.provider
closing[parameter_name] = marker
injections[parameter_name] = marker
return injections, closing
def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None:
    """Bind the reference markers of patched ``fn`` to resolved providers.

    Markers whose provider cannot be resolved through ``providers_map`` are
    skipped. For a ``Provide`` marker the resolved provider is stored; for a
    ``Provider`` marker its ``.provider`` attribute is stored instead.
    Parameters listed in ``fn.__reference_closing__`` additionally get the
    resolved provider recorded in ``fn.__closing__``.
    """
    for parameter_name, marker in fn.__reference_injections__.items():
        resolved = providers_map.resolve_provider(marker.provider)
        if resolved is not None:
            if isinstance(marker, Provide):
                fn.__injections__[parameter_name] = resolved
            elif isinstance(marker, Provider):
                fn.__injections__[parameter_name] = resolved.provider

            if parameter_name in fn.__reference_closing__:
                fn.__closing__[parameter_name] = resolved
def _unbind_injections(fn: Callable[..., Any]) -> None:
fn.__injections__ = {}
fn.__closing__ = {}
def _fetch_modules(package):
modules = [package]
for module_info in pkgutil.walk_packages(
path=package.__path__,
prefix=package.__name__ + '.',
):
module = importlib.import_module(module_info.name)
modules.append(module)
return modules
def _is_method(member):
return inspect.ismethod(member) or inspect.isfunction(member)
def _get_patched(fn, reference_injections, reference_closing):
    """Wrap ``fn`` into an injecting wrapper and attach the wiring state.

    Coroutine functions get the asynchronous wrapper, everything else the
    synchronous one. The wrapper starts with empty ``__injections__`` and
    ``__closing__``; the given reference dictionaries are stored unchanged.
    """
    if inspect.iscoroutinefunction(fn):
        make_wrapper = _get_async_patched
    else:
        make_wrapper = _get_sync_patched
    patched = make_wrapper(fn)

    wiring_state = (
        ('__wired__', True),
        ('__original__', fn),
        ('__injections__', {}),
        ('__reference_injections__', reference_injections),
        ('__closing__', {}),
        ('__reference_closing__', reference_closing),
    )
    for attribute, value in wiring_state:
        setattr(patched, attribute, value)

    return patched
def _get_sync_patched(fn):
@functools.wraps(fn)
def _patched(*args, **kwargs):
to_inject = kwargs.copy()
for injection, provider in _patched.__injections__.items():
if injection not in kwargs \
or _is_fastapi_default_arg_injection(injection, kwargs):
to_inject[injection] = provider()
result = fn(*args, **to_inject)
for injection, provider in _patched.__closing__.items():
if injection in kwargs \
and not _is_fastapi_default_arg_injection(injection, kwargs):
continue
if not isinstance(provider, providers.Resource):
continue
provider.shutdown()
return result
return _patched
def _get_async_patched(fn):
@functools.wraps(fn)
async def _patched(*args, **kwargs):
to_inject = kwargs.copy()
for injection, provider in _patched.__injections__.items():
if injection not in kwargs \
or _is_fastapi_default_arg_injection(injection, kwargs):
to_inject[injection] = provider()
result = await fn(*args, **to_inject)
for injection, provider in _patched.__closing__.items():
if injection in kwargs \
and not _is_fastapi_default_arg_injection(injection, kwargs):
continue
if not isinstance(provider, providers.Resource):
continue
provider.shutdown()
return result
return _patched
def _is_fastapi_default_arg_injection(injection, kwargs):
"""Check if injection is FastAPI injection of the default argument."""
return injection in kwargs and isinstance(kwargs[injection], _Marker)
def _is_patched(fn):
return getattr(fn, '__wired__', False) is True
def _is_declarative_container_instance(instance: Any) -> bool:
return (not isinstance(instance, type)
and getattr(instance, '__IS_CONTAINER__', False) is True
and getattr(instance, 'declarative_parent', None) is not None)
class ClassGetItemMeta(GenericMeta):
    """Metaclass that makes ``SomeClass[item]`` return ``SomeClass(item)``."""

    def __getitem__(cls, item):
        # Spike for Python 3.6
        # NOTE(review): presumably needed because ``__class_getitem__`` is
        # not honoured on that version - confirm.
        return cls(item)
class _Marker(Generic[T], metaclass=ClassGetItemMeta):
    """Base class of the objects used as parameter defaults to request injection.

    A marker only stores what it was created with in ``provider``.
    Subscripting the class (``Marker[x]``) creates an instance, the same as
    calling ``Marker(x)``.
    """

    def __init__(self, provider: providers.Provider) -> None:
        """Remember ``provider`` on the marker."""
        self.provider = provider

    def __class_getitem__(cls, item) -> T:
        """Return ``cls(item)``, i.e. a marker instance wrapping ``item``."""
        return cls(item)
class Provide(_Marker):
    """Marker for which the resolved provider is called to produce the argument."""
    ...
class Provider(_Marker):
    """Marker for which the ``.provider`` attribute of the resolved provider is called.

    NOTE(review): presumably this injects the provider object itself rather
    than the value it provides - confirm against the ``providers`` module.
    """
    ...
class Closing(_Marker):
    """Marker wrapping another marker whose resource is shut down after the call.

    When collecting parameter defaults the wrapped marker is what gets
    injected; the parameter is additionally recorded as "closing". The
    patched wrapper calls ``shutdown()`` on the resolved provider after the
    wrapped function has run, provided it is a ``providers.Resource`` and
    the argument was injected by the wrapper.
    """
    ...