mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-04-17 15:42:09 +03:00
Refactor wiring
This commit is contained in:
parent
7f854548d6
commit
c145d62263
File diff suppressed because it is too large
Load Diff
|
@ -186,20 +186,6 @@ class DynamicContainer(object):
|
|||
for provider in six.itervalues(self.providers):
|
||||
provider.reset_override()
|
||||
|
||||
def resolve_provider_name(self, provider_to_resolve):
|
||||
"""Try to resolve provider name by its instance."""
|
||||
if self.declarative_parent:
|
||||
provider_name = self.declarative_parent.resolve_provider_name(provider_to_resolve)
|
||||
if provider_name:
|
||||
return provider_name
|
||||
|
||||
for provider_name, container_provider in self.providers.items():
|
||||
if container_provider is provider_to_resolve:
|
||||
return provider_name
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def wire(self, modules=None, packages=None):
|
||||
"""Wire container providers with provided packages and modules.
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -54,6 +54,8 @@ class Object(Provider, Generic[T]):
|
|||
class Delegate(Provider):
|
||||
def __init__(self, provides: Provider) -> None: ...
|
||||
def __call__(self, *args: Injection, **kwargs: Injection) -> Provider: ...
|
||||
@property
|
||||
def provides(self) -> Provider: ...
|
||||
|
||||
|
||||
class Dependency(Provider, Generic[T]):
|
||||
|
@ -126,7 +128,10 @@ class ConfigurationOption(Provider):
|
|||
def __call__(self, *args: Injection, **kwargs: Injection) -> Any: ...
|
||||
def __getattr__(self, item: str) -> ConfigurationOption: ...
|
||||
def __getitem__(self, item: str) -> ConfigurationOption: ...
|
||||
@property
|
||||
def root(self) -> Configuration: ...
|
||||
def get_name(self) -> str: ...
|
||||
def get_relative_name(self) -> str: ...
|
||||
def get_option_provider(self, selector: str) -> ConfigurationOption: ...
|
||||
def as_int(self) -> TypedConfigurationOption[int]: ...
|
||||
def as_float(self) -> TypedConfigurationOption[float]: ...
|
||||
|
|
|
@ -1144,10 +1144,17 @@ cdef class ConfigurationOption(Provider):
|
|||
segment() if is_provider(segment) else segment for segment in self.__name
|
||||
)
|
||||
|
||||
@property
|
||||
def root(self):
|
||||
return self.__root_ref()
|
||||
|
||||
def get_name(self):
|
||||
root = self.__root_ref()
|
||||
return '.'.join((root.get_name(), self._get_self_name()))
|
||||
|
||||
def get_relative_name(self):
|
||||
return self._get_self_name()
|
||||
|
||||
def get_option_provider(self, selector):
|
||||
"""Return configuration option provider.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import inspect
|
|||
import pkgutil
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar
|
||||
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar, cast
|
||||
|
||||
if sys.version_info < (3, 7):
|
||||
from typing import GenericMeta
|
||||
|
@ -17,8 +17,100 @@ else:
|
|||
from . import providers
|
||||
|
||||
|
||||
AnyContainer = Any
|
||||
T = TypeVar('T')
|
||||
AnyContainer = Any
|
||||
|
||||
|
||||
class ProvidersMap:
|
||||
|
||||
def __init__(self, container):
|
||||
self._container = container
|
||||
self._map = self._create_providers_map(container)
|
||||
|
||||
def resolve_provider(self, provider: providers.Provider) -> providers.Provider:
|
||||
if isinstance(provider, providers.Delegate):
|
||||
return self._resolve_delegate(provider)
|
||||
elif isinstance(provider, (
|
||||
providers.ProvidedInstance,
|
||||
providers.AttributeGetter,
|
||||
providers.ItemGetter,
|
||||
providers.MethodCaller,
|
||||
)):
|
||||
return self._resolve_provided_instance(provider)
|
||||
elif isinstance(provider, providers.ConfigurationOption):
|
||||
return self._resolve_config_option(provider)
|
||||
elif isinstance(provider, providers.TypedConfigurationOption):
|
||||
return self._resolve_config_option(provider.option, as_=provider.provides)
|
||||
else:
|
||||
return self._resolve_provider(provider)
|
||||
|
||||
def _resolve_delegate(self, original: providers.Delegate) -> providers.Provider:
|
||||
return self._resolve_provider(original.provides)
|
||||
|
||||
def _resolve_provided_instance(self, original: providers.Provider) -> providers.Provider:
|
||||
modifiers = []
|
||||
while isinstance(original, (
|
||||
providers.ProvidedInstance,
|
||||
providers.AttributeGetter,
|
||||
providers.ItemGetter,
|
||||
providers.MethodCaller,
|
||||
)):
|
||||
modifiers.insert(0, original)
|
||||
original = original.provides
|
||||
|
||||
new = self._resolve_provider(original)
|
||||
|
||||
for modifier in modifiers:
|
||||
if isinstance(modifier, providers.ProvidedInstance):
|
||||
new = new.provided
|
||||
elif isinstance(modifier, providers.AttributeGetter):
|
||||
new = getattr(new, modifier.name)
|
||||
elif isinstance(modifier, providers.ItemGetter):
|
||||
new = new[modifier.name]
|
||||
elif isinstance(modifier, providers.MethodCaller):
|
||||
new = new.call(
|
||||
*modifier.args,
|
||||
**modifier.kwargs,
|
||||
)
|
||||
|
||||
return new
|
||||
|
||||
def _resolve_config_option(
|
||||
self,
|
||||
original: providers.ConfigurationOption,
|
||||
as_: Any = None,
|
||||
) -> providers.Provider:
|
||||
original_root = original.root
|
||||
new_root: providers.Configuration = cast(providers.Configuration, self._resolve_provider(original_root))
|
||||
new_option = new_root.get_option_provider(original.get_relative_name())
|
||||
if as_:
|
||||
new_option = new_option.as_(as_)
|
||||
return new_option
|
||||
|
||||
def _resolve_provider(self, original: providers.Provider) -> providers.Provider:
|
||||
try:
|
||||
return self._map[original]
|
||||
except KeyError:
|
||||
raise Exception('Unable to resolve original provider')
|
||||
|
||||
@classmethod
|
||||
def _create_providers_map(
|
||||
cls,
|
||||
container: AnyContainer,
|
||||
) -> Dict[providers.Provider, providers.Provider]:
|
||||
current_providers = container.providers
|
||||
original_providers = container.declarative_parent.providers
|
||||
|
||||
providers_map = {}
|
||||
for provider_name, current_provider in current_providers.items():
|
||||
original_provider = original_providers[provider_name]
|
||||
providers_map[original_provider] = current_provider
|
||||
|
||||
if isinstance(current_provider, providers.Container):
|
||||
subcontainer_map = cls._create_providers_map(current_provider.container)
|
||||
providers_map.update(subcontainer_map)
|
||||
|
||||
return providers_map
|
||||
|
||||
|
||||
def wire(
|
||||
|
@ -35,12 +127,14 @@ def wire(
|
|||
for package in packages:
|
||||
modules.extend(_fetch_modules(package))
|
||||
|
||||
providers_map = ProvidersMap(container)
|
||||
|
||||
for module in modules:
|
||||
for name, member in inspect.getmembers(module):
|
||||
if inspect.isfunction(member):
|
||||
_patch_fn(module, name, member, container)
|
||||
_patch_fn(module, name, member, providers_map)
|
||||
elif inspect.isclass(member):
|
||||
_patch_cls(member, container)
|
||||
_patch_cls(member, providers_map)
|
||||
|
||||
|
||||
def unwire(
|
||||
|
@ -66,12 +160,12 @@ def unwire(
|
|||
|
||||
def _patch_cls(
|
||||
cls: Type[Any],
|
||||
container: AnyContainer,
|
||||
providers_map: ProvidersMap,
|
||||
) -> None:
|
||||
if not hasattr(cls, '__init__'):
|
||||
return
|
||||
init_method = getattr(cls, '__init__')
|
||||
injections = _resolve_injections(init_method, container)
|
||||
injections = _resolve_injections(init_method, providers_map)
|
||||
if not injections:
|
||||
return
|
||||
setattr(cls, '__init__', _patch_with_injections(init_method, injections))
|
||||
|
@ -90,9 +184,9 @@ def _patch_fn(
|
|||
module: ModuleType,
|
||||
name: str,
|
||||
fn: Callable[..., Any],
|
||||
container: AnyContainer,
|
||||
providers_map: ProvidersMap,
|
||||
) -> None:
|
||||
injections = _resolve_injections(fn, container)
|
||||
injections = _resolve_injections(fn, providers_map)
|
||||
if not injections:
|
||||
return
|
||||
setattr(module, name, _patch_with_injections(fn, injections))
|
||||
|
@ -108,9 +202,7 @@ def _unpatch_fn(
|
|||
setattr(module, name, _get_original_from_patched(fn))
|
||||
|
||||
|
||||
def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict[str, Any]: # noqa
|
||||
config = _resolve_container_config(container)
|
||||
|
||||
def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Dict[str, Any]: # noqa
|
||||
signature = inspect.signature(fn)
|
||||
|
||||
injections = {}
|
||||
|
@ -119,35 +211,7 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
|
|||
continue
|
||||
marker = parameter.default
|
||||
|
||||
if config and isinstance(marker.provider, providers.ConfigurationOption):
|
||||
provider = _prepare_config_injection(marker.provider.get_name(), config)
|
||||
elif config and isinstance(marker.provider, providers.TypedConfigurationOption):
|
||||
provider = _prepare_config_injection(
|
||||
marker.provider.option.get_name(),
|
||||
config,
|
||||
marker.provider.provides,
|
||||
)
|
||||
elif isinstance(marker.provider, providers.Delegate):
|
||||
provider_name = container.resolve_provider_name(marker.provider.provides)
|
||||
provider = container.providers[provider_name]
|
||||
elif isinstance(marker.provider, (
|
||||
providers.ProvidedInstance,
|
||||
providers.AttributeGetter,
|
||||
providers.ItemGetter,
|
||||
providers.MethodCaller,
|
||||
)):
|
||||
provider = _prepare_provided_instance_injection(
|
||||
marker.provider,
|
||||
container,
|
||||
)
|
||||
elif isinstance(marker.provider, providers.Provider):
|
||||
provider_name = container.resolve_provider_name(marker.provider)
|
||||
if not provider_name:
|
||||
continue
|
||||
provider = container.providers[provider_name]
|
||||
else:
|
||||
continue
|
||||
|
||||
provider = providers_map.resolve_provider(marker.provider)
|
||||
if isinstance(marker, Provide):
|
||||
injections[parameter_name] = provider
|
||||
elif isinstance(marker, Provider):
|
||||
|
@ -156,61 +220,6 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
|
|||
return injections
|
||||
|
||||
|
||||
def _prepare_config_injection(
|
||||
option_name: str,
|
||||
config: providers.Configuration,
|
||||
as_: Any = None,
|
||||
) -> providers.Provider:
|
||||
_, *parts = option_name.split('.')
|
||||
relative_option_name = '.'.join(parts)
|
||||
provider = config.get_option_provider(relative_option_name)
|
||||
if as_:
|
||||
return provider.as_(as_)
|
||||
return provider
|
||||
|
||||
|
||||
def _prepare_provided_instance_injection(
|
||||
current_provider: providers.Provider,
|
||||
container: AnyContainer,
|
||||
) -> providers.Provider:
|
||||
provided_instance_markers = []
|
||||
instance_provider_marker = current_provider
|
||||
while isinstance(instance_provider_marker, (
|
||||
providers.ProvidedInstance,
|
||||
providers.AttributeGetter,
|
||||
providers.ItemGetter,
|
||||
providers.MethodCaller,
|
||||
)):
|
||||
provided_instance_markers.insert(0, instance_provider_marker)
|
||||
instance_provider_marker = instance_provider_marker.provides
|
||||
|
||||
provider_name = container.resolve_provider_name(instance_provider_marker)
|
||||
provider = container.providers[provider_name]
|
||||
|
||||
for provided_instance in provided_instance_markers:
|
||||
if isinstance(provided_instance, providers.ProvidedInstance):
|
||||
provider = provider.provided
|
||||
elif isinstance(provided_instance, providers.AttributeGetter):
|
||||
provider = getattr(provider, provided_instance.name)
|
||||
elif isinstance(provided_instance, providers.ItemGetter):
|
||||
provider = provider[provided_instance.name]
|
||||
elif isinstance(provided_instance, providers.MethodCaller):
|
||||
provider = provider.call(
|
||||
*provided_instance.args,
|
||||
**provided_instance.kwargs,
|
||||
)
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
def _resolve_container_config(container: AnyContainer) -> Optional[providers.Configuration]:
|
||||
for provider in container.providers.values():
|
||||
if isinstance(provider, providers.Configuration):
|
||||
return provider
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_modules(package):
|
||||
modules = []
|
||||
for loader, module_name, is_pkg in pkgutil.walk_packages(
|
||||
|
|
Loading…
Reference in New Issue
Block a user