"""Wiring module.""" import functools import inspect import importlib import pkgutil import sys from types import ModuleType from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast if sys.version_info < (3, 7): from typing import GenericMeta else: class GenericMeta(type): ... from . import providers __all__ = ( 'wire', 'unwire', 'inject', 'Provide', 'Provider', 'Closing', ) T = TypeVar('T') Container = Any class ProvidersMap: def __init__(self, container): self._container = container self._map = self._create_providers_map( current_providers=container.providers, original_providers=container.declarative_parent.providers, ) def resolve_provider( self, provider: providers.Provider, ) -> Optional[providers.Provider]: if isinstance(provider, providers.Delegate): return self._resolve_delegate(provider) elif isinstance(provider, ( providers.ProvidedInstance, providers.AttributeGetter, providers.ItemGetter, providers.MethodCaller, )): return self._resolve_provided_instance(provider) elif isinstance(provider, providers.ConfigurationOption): return self._resolve_config_option(provider) elif isinstance(provider, providers.TypedConfigurationOption): return self._resolve_config_option(provider.option, as_=provider.provides) else: return self._resolve_provider(provider) def _resolve_delegate( self, original: providers.Delegate, ) -> Optional[providers.Provider]: return self._resolve_provider(original.provides) def _resolve_provided_instance( self, original: providers.Provider, ) -> Optional[providers.Provider]: modifiers = [] while isinstance(original, ( providers.ProvidedInstance, providers.AttributeGetter, providers.ItemGetter, providers.MethodCaller, )): modifiers.insert(0, original) original = original.provides new = self._resolve_provider(original) if new is None: return None for modifier in modifiers: if isinstance(modifier, providers.ProvidedInstance): new = new.provided elif isinstance(modifier, providers.AttributeGetter): new = getattr(new, modifier.name) elif isinstance(modifier, providers.ItemGetter): new = new[modifier.name] elif isinstance(modifier, providers.MethodCaller): new = new.call( *modifier.args, **modifier.kwargs, ) return new def _resolve_config_option( self, original: providers.ConfigurationOption, as_: Any = None, ) -> Optional[providers.Provider]: original_root = original.root new = self._resolve_provider(original_root) if new is None: return None new = cast(providers.Configuration, new) for segment in original.get_name_segments(): if providers.is_provider(segment): segment = self.resolve_provider(segment) new = new[segment] else: new = getattr(new, segment) if as_: new = new.as_(as_) return new def _resolve_provider( self, original: providers.Provider, ) -> Optional[providers.Provider]: try: return self._map[original] except KeyError: pass @classmethod def _create_providers_map( cls, current_providers: Dict[str, providers.Provider], original_providers: Dict[str, providers.Provider], ) -> Dict[providers.Provider, providers.Provider]: providers_map = {} for provider_name, current_provider in current_providers.items(): original_provider = original_providers[provider_name] providers_map[original_provider] = current_provider if isinstance(current_provider, providers.Container) \ and isinstance(original_provider, providers.Container): subcontainer_map = cls._create_providers_map( current_providers=current_provider.container.providers, original_providers=original_provider.container.providers, ) providers_map.update(subcontainer_map) return providers_map def wire( container: Container, *, modules: Optional[Iterable[ModuleType]] = None, packages: Optional[Iterable[ModuleType]] = None, ) -> None: """Wire container providers with provided packages and modules.""" if not _is_declarative_container_instance(container): raise Exception('Can wire only an instance of the declarative container') if not modules: modules = [] if packages: for package in packages: modules.extend(_fetch_modules(package)) providers_map = ProvidersMap(container) for module in modules: for name, member in inspect.getmembers(module): if inspect.isfunction(member): _patch_fn(module, name, member, providers_map) elif inspect.isclass(member): for method_name, method in inspect.getmembers(member, _is_method): _patch_method(member, method_name, method, providers_map) def unwire( *, modules: Optional[Iterable[ModuleType]] = None, packages: Optional[Iterable[ModuleType]] = None, ) -> None: """Wire provided packages and modules with previous wired providers.""" if not modules: modules = [] if packages: for package in packages: modules.extend(_fetch_modules(package)) for module in modules: for name, member in inspect.getmembers(module): if inspect.isfunction(member): _unpatch(module, name, member) elif inspect.isclass(member): for method_name, method in inspect.getmembers(member, inspect.isfunction): _unpatch(member, method_name, method) def inject(fn: Callable[..., Any]) -> Callable[..., Any]: """Decorate callable with injecting decorator.""" reference_injections, reference_closing = _fetch_reference_injections(fn) return _get_patched(fn, reference_injections, reference_closing) def _patch_fn( module: ModuleType, name: str, fn: Callable[..., Any], providers_map: ProvidersMap, ) -> None: if not _is_patched(fn): reference_injections, reference_closing = _fetch_reference_injections(fn) if not reference_injections: return fn = _get_patched(fn, reference_injections, reference_closing) _bind_injections(fn, providers_map) setattr(module, name, fn) def _patch_method( cls: Type, name: str, method: Callable[..., Any], providers_map: ProvidersMap, ) -> None: if hasattr(cls, '__dict__') \ and name in cls.__dict__ \ and isinstance(cls.__dict__[name], (classmethod, staticmethod)): method = cls.__dict__[name] fn = method.__func__ else: fn = method if not _is_patched(fn): reference_injections, reference_closing = _fetch_reference_injections(fn) if not reference_injections: return fn = _get_patched(fn, reference_injections, reference_closing) _bind_injections(fn, providers_map) if isinstance(method, (classmethod, staticmethod)): fn = type(method)(fn) setattr(cls, name, fn) def _unpatch( module: ModuleType, name: str, fn: Callable[..., Any], ) -> None: if hasattr(module, '__dict__') \ and name in module.__dict__ \ and isinstance(module.__dict__[name], (classmethod, staticmethod)): method = module.__dict__[name] fn = method.__func__ if not _is_patched(fn): return _unbind_injections(fn) def _fetch_reference_injections( fn: Callable[..., Any], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: signature = inspect.signature(fn) injections = {} closing = {} for parameter_name, parameter in signature.parameters.items(): if not isinstance(parameter.default, _Marker): continue marker = parameter.default if isinstance(marker, Closing): marker = marker.provider closing[parameter_name] = marker injections[parameter_name] = marker return injections, closing def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None: for injection, marker in fn.__reference_injections__.items(): provider = providers_map.resolve_provider(marker.provider) if provider is None: continue if isinstance(marker, Provide): fn.__injections__[injection] = provider elif isinstance(marker, Provider): fn.__injections__[injection] = provider.provider if injection in fn.__reference_closing__: fn.__closing__[injection] = provider def _unbind_injections(fn: Callable[..., Any]) -> None: fn.__injections__ = {} fn.__closing__ = {} def _fetch_modules(package): modules = [package] for module_info in pkgutil.walk_packages( path=package.__path__, prefix=package.__name__ + '.', ): module = importlib.import_module(module_info.name) modules.append(module) return modules def _is_method(member): return inspect.ismethod(member) or inspect.isfunction(member) def _get_patched(fn, reference_injections, reference_closing): if inspect.iscoroutinefunction(fn): patched = _get_async_patched(fn) else: patched = _get_sync_patched(fn) patched.__wired__ = True patched.__original__ = fn patched.__injections__ = {} patched.__reference_injections__ = reference_injections patched.__closing__ = {} patched.__reference_closing__ = reference_closing return patched def _get_sync_patched(fn): @functools.wraps(fn) def _patched(*args, **kwargs): to_inject = kwargs.copy() for injection, provider in _patched.__injections__.items(): if injection not in kwargs \ or _is_fastapi_default_arg_injection(injection, kwargs): to_inject[injection] = provider() result = fn(*args, **to_inject) for injection, provider in _patched.__closing__.items(): if injection in kwargs \ and not _is_fastapi_default_arg_injection(injection, kwargs): continue if not isinstance(provider, providers.Resource): continue provider.shutdown() return result return _patched def _get_async_patched(fn): @functools.wraps(fn) async def _patched(*args, **kwargs): to_inject = kwargs.copy() for injection, provider in _patched.__injections__.items(): if injection not in kwargs \ or _is_fastapi_default_arg_injection(injection, kwargs): to_inject[injection] = provider() result = await fn(*args, **to_inject) for injection, provider in _patched.__closing__.items(): if injection in kwargs \ and not _is_fastapi_default_arg_injection(injection, kwargs): continue if not isinstance(provider, providers.Resource): continue provider.shutdown() return result return _patched def _is_fastapi_default_arg_injection(injection, kwargs): """Check if injection is FastAPI injection of the default argument.""" return injection in kwargs and isinstance(kwargs[injection], _Marker) def _is_patched(fn): return getattr(fn, '__wired__', False) is True def _is_declarative_container_instance(instance: Any) -> bool: return (not isinstance(instance, type) and getattr(instance, '__IS_CONTAINER__', False) is True and getattr(instance, 'declarative_parent', None) is not None) class ClassGetItemMeta(GenericMeta): def __getitem__(cls, item): # Spike for Python 3.6 return cls(item) class _Marker(Generic[T], metaclass=ClassGetItemMeta): def __init__(self, provider: providers.Provider) -> None: self.provider = provider def __class_getitem__(cls, item) -> T: return cls(item) class Provide(_Marker): ... class Provider(_Marker): ... class Closing(_Marker): ...