Refactor wiring

This commit is contained in:
Roman Mogylatov 2020-09-27 15:00:39 -04:00
parent 7f854548d6
commit c145d62263
6 changed files with 3641 additions and 3874 deletions

File diff suppressed because it is too large Load Diff

View File

@ -186,20 +186,6 @@ class DynamicContainer(object):
for provider in six.itervalues(self.providers):
provider.reset_override()
def resolve_provider_name(self, provider_to_resolve):
"""Try to resolve provider name by its instance."""
if self.declarative_parent:
provider_name = self.declarative_parent.resolve_provider_name(provider_to_resolve)
if provider_name:
return provider_name
for provider_name, container_provider in self.providers.items():
if container_provider is provider_to_resolve:
return provider_name
else:
return None
def wire(self, modules=None, packages=None):
"""Wire container providers with provided packages and modules.

File diff suppressed because it is too large Load Diff

View File

@ -54,6 +54,8 @@ class Object(Provider, Generic[T]):
class Delegate(Provider):
def __init__(self, provides: Provider) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> Provider: ...
@property
def provides(self) -> Provider: ...
class Dependency(Provider, Generic[T]):
@ -126,7 +128,10 @@ class ConfigurationOption(Provider):
def __call__(self, *args: Injection, **kwargs: Injection) -> Any: ...
def __getattr__(self, item: str) -> ConfigurationOption: ...
def __getitem__(self, item: str) -> ConfigurationOption: ...
@property
def root(self) -> Configuration: ...
def get_name(self) -> str: ...
def get_relative_name(self) -> str: ...
def get_option_provider(self, selector: str) -> ConfigurationOption: ...
def as_int(self) -> TypedConfigurationOption[int]: ...
def as_float(self) -> TypedConfigurationOption[float]: ...

View File

@ -1144,10 +1144,17 @@ cdef class ConfigurationOption(Provider):
segment() if is_provider(segment) else segment for segment in self.__name
)
@property
def root(self):
return self.__root_ref()
def get_name(self):
root = self.__root_ref()
return '.'.join((root.get_name(), self._get_self_name()))
def get_relative_name(self):
return self._get_self_name()
def get_option_provider(self, selector):
"""Return configuration option provider.

View File

@ -5,7 +5,7 @@ import inspect
import pkgutil
import sys
from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar, cast
if sys.version_info < (3, 7):
from typing import GenericMeta
@ -17,8 +17,100 @@ else:
from . import providers
AnyContainer = Any
T = TypeVar('T')
AnyContainer = Any
class ProvidersMap:
def __init__(self, container):
self._container = container
self._map = self._create_providers_map(container)
def resolve_provider(self, provider: providers.Provider) -> providers.Provider:
if isinstance(provider, providers.Delegate):
return self._resolve_delegate(provider)
elif isinstance(provider, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
return self._resolve_provided_instance(provider)
elif isinstance(provider, providers.ConfigurationOption):
return self._resolve_config_option(provider)
elif isinstance(provider, providers.TypedConfigurationOption):
return self._resolve_config_option(provider.option, as_=provider.provides)
else:
return self._resolve_provider(provider)
def _resolve_delegate(self, original: providers.Delegate) -> providers.Provider:
return self._resolve_provider(original.provides)
def _resolve_provided_instance(self, original: providers.Provider) -> providers.Provider:
modifiers = []
while isinstance(original, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
modifiers.insert(0, original)
original = original.provides
new = self._resolve_provider(original)
for modifier in modifiers:
if isinstance(modifier, providers.ProvidedInstance):
new = new.provided
elif isinstance(modifier, providers.AttributeGetter):
new = getattr(new, modifier.name)
elif isinstance(modifier, providers.ItemGetter):
new = new[modifier.name]
elif isinstance(modifier, providers.MethodCaller):
new = new.call(
*modifier.args,
**modifier.kwargs,
)
return new
def _resolve_config_option(
self,
original: providers.ConfigurationOption,
as_: Any = None,
) -> providers.Provider:
original_root = original.root
new_root: providers.Configuration = cast(providers.Configuration, self._resolve_provider(original_root))
new_option = new_root.get_option_provider(original.get_relative_name())
if as_:
new_option = new_option.as_(as_)
return new_option
def _resolve_provider(self, original: providers.Provider) -> providers.Provider:
try:
return self._map[original]
except KeyError:
raise Exception('Unable to resolve original provider')
@classmethod
def _create_providers_map(
cls,
container: AnyContainer,
) -> Dict[providers.Provider, providers.Provider]:
current_providers = container.providers
original_providers = container.declarative_parent.providers
providers_map = {}
for provider_name, current_provider in current_providers.items():
original_provider = original_providers[provider_name]
providers_map[original_provider] = current_provider
if isinstance(current_provider, providers.Container):
subcontainer_map = cls._create_providers_map(current_provider.container)
providers_map.update(subcontainer_map)
return providers_map
def wire(
@ -35,12 +127,14 @@ def wire(
for package in packages:
modules.extend(_fetch_modules(package))
providers_map = ProvidersMap(container)
for module in modules:
for name, member in inspect.getmembers(module):
if inspect.isfunction(member):
_patch_fn(module, name, member, container)
_patch_fn(module, name, member, providers_map)
elif inspect.isclass(member):
_patch_cls(member, container)
_patch_cls(member, providers_map)
def unwire(
@ -66,12 +160,12 @@ def unwire(
def _patch_cls(
cls: Type[Any],
container: AnyContainer,
providers_map: ProvidersMap,
) -> None:
if not hasattr(cls, '__init__'):
return
init_method = getattr(cls, '__init__')
injections = _resolve_injections(init_method, container)
injections = _resolve_injections(init_method, providers_map)
if not injections:
return
setattr(cls, '__init__', _patch_with_injections(init_method, injections))
@ -90,9 +184,9 @@ def _patch_fn(
module: ModuleType,
name: str,
fn: Callable[..., Any],
container: AnyContainer,
providers_map: ProvidersMap,
) -> None:
injections = _resolve_injections(fn, container)
injections = _resolve_injections(fn, providers_map)
if not injections:
return
setattr(module, name, _patch_with_injections(fn, injections))
@ -108,9 +202,7 @@ def _unpatch_fn(
setattr(module, name, _get_original_from_patched(fn))
def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict[str, Any]: # noqa
config = _resolve_container_config(container)
def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Dict[str, Any]: # noqa
signature = inspect.signature(fn)
injections = {}
@ -119,35 +211,7 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
continue
marker = parameter.default
if config and isinstance(marker.provider, providers.ConfigurationOption):
provider = _prepare_config_injection(marker.provider.get_name(), config)
elif config and isinstance(marker.provider, providers.TypedConfigurationOption):
provider = _prepare_config_injection(
marker.provider.option.get_name(),
config,
marker.provider.provides,
)
elif isinstance(marker.provider, providers.Delegate):
provider_name = container.resolve_provider_name(marker.provider.provides)
provider = container.providers[provider_name]
elif isinstance(marker.provider, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
provider = _prepare_provided_instance_injection(
marker.provider,
container,
)
elif isinstance(marker.provider, providers.Provider):
provider_name = container.resolve_provider_name(marker.provider)
if not provider_name:
continue
provider = container.providers[provider_name]
else:
continue
provider = providers_map.resolve_provider(marker.provider)
if isinstance(marker, Provide):
injections[parameter_name] = provider
elif isinstance(marker, Provider):
@ -156,61 +220,6 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
return injections
def _prepare_config_injection(
option_name: str,
config: providers.Configuration,
as_: Any = None,
) -> providers.Provider:
_, *parts = option_name.split('.')
relative_option_name = '.'.join(parts)
provider = config.get_option_provider(relative_option_name)
if as_:
return provider.as_(as_)
return provider
def _prepare_provided_instance_injection(
current_provider: providers.Provider,
container: AnyContainer,
) -> providers.Provider:
provided_instance_markers = []
instance_provider_marker = current_provider
while isinstance(instance_provider_marker, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
provided_instance_markers.insert(0, instance_provider_marker)
instance_provider_marker = instance_provider_marker.provides
provider_name = container.resolve_provider_name(instance_provider_marker)
provider = container.providers[provider_name]
for provided_instance in provided_instance_markers:
if isinstance(provided_instance, providers.ProvidedInstance):
provider = provider.provided
elif isinstance(provided_instance, providers.AttributeGetter):
provider = getattr(provider, provided_instance.name)
elif isinstance(provided_instance, providers.ItemGetter):
provider = provider[provided_instance.name]
elif isinstance(provided_instance, providers.MethodCaller):
provider = provider.call(
*provided_instance.args,
**provided_instance.kwargs,
)
return provider
def _resolve_container_config(container: AnyContainer) -> Optional[providers.Configuration]:
for provider in container.providers.values():
if isinstance(provider, providers.Configuration):
return provider
else:
return None
def _fetch_modules(package):
modules = []
for loader, module_name, is_pkg in pkgutil.walk_packages(