mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-02-07 07:00:49 +03:00
Wiring refactoring (#296)
* Refactor wiring * Add todos to wiring * Implement wiring of config invariant * Implement sub containers wiring + add tests * Add test for wiring config invariant
This commit is contained in:
parent
7f854548d6
commit
6182b8448a
File diff suppressed because it is too large
Load Diff
|
@ -186,20 +186,6 @@ class DynamicContainer(object):
|
|||
for provider in six.itervalues(self.providers):
|
||||
provider.reset_override()
|
||||
|
||||
def resolve_provider_name(self, provider_to_resolve):
|
||||
"""Try to resolve provider name by its instance."""
|
||||
if self.declarative_parent:
|
||||
provider_name = self.declarative_parent.resolve_provider_name(provider_to_resolve)
|
||||
if provider_name:
|
||||
return provider_name
|
||||
|
||||
for provider_name, container_provider in self.providers.items():
|
||||
if container_provider is provider_to_resolve:
|
||||
return provider_name
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def wire(self, modules=None, packages=None):
|
||||
"""Wire container providers with provided packages and modules.
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -185,9 +185,9 @@ cdef class List(Provider):
|
|||
|
||||
|
||||
cdef class Container(Provider):
|
||||
cdef object container_cls
|
||||
cdef dict overriding_providers
|
||||
cdef object container
|
||||
cdef object __container_cls
|
||||
cdef dict __overriding_providers
|
||||
cdef object __container
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs)
|
||||
|
||||
|
|
|
@ -54,6 +54,8 @@ class Object(Provider, Generic[T]):
|
|||
class Delegate(Provider):
|
||||
def __init__(self, provides: Provider) -> None: ...
|
||||
def __call__(self, *args: Injection, **kwargs: Injection) -> Provider: ...
|
||||
@property
|
||||
def provides(self) -> Provider: ...
|
||||
|
||||
|
||||
class Dependency(Provider, Generic[T]):
|
||||
|
@ -125,9 +127,11 @@ class ConfigurationOption(Provider):
|
|||
def __init__(self, name: Tuple[str], root: Configuration) -> None: ...
|
||||
def __call__(self, *args: Injection, **kwargs: Injection) -> Any: ...
|
||||
def __getattr__(self, item: str) -> ConfigurationOption: ...
|
||||
def __getitem__(self, item: str) -> ConfigurationOption: ...
|
||||
def __getitem__(self, item: Union[str, Provider]) -> ConfigurationOption: ...
|
||||
@property
|
||||
def root(self) -> Configuration: ...
|
||||
def get_name(self) -> str: ...
|
||||
def get_option_provider(self, selector: str) -> ConfigurationOption: ...
|
||||
def get_name_segments(self) -> Tuple[Union[str, Provider]]: ...
|
||||
def as_int(self) -> TypedConfigurationOption[int]: ...
|
||||
def as_float(self) -> TypedConfigurationOption[float]: ...
|
||||
def as_(self, callback: _Callable[..., T], *args: Injection, **kwargs: Injection) -> TypedConfigurationOption[T]: ...
|
||||
|
@ -147,10 +151,9 @@ class Configuration(Object):
|
|||
DEFAULT_NAME: str = 'config'
|
||||
def __init__(self, name: str = DEFAULT_NAME, default: Optional[Any] = None) -> None: ...
|
||||
def __getattr__(self, item: str) -> ConfigurationOption: ...
|
||||
def __getitem__(self, item: str) -> ConfigurationOption: ...
|
||||
def __getitem__(self, item: Union[str, Provider]) -> ConfigurationOption: ...
|
||||
def get_name(self) -> str: ...
|
||||
def get(self, selector: str) -> Any: ...
|
||||
def get_option_provider(self, selector: str) -> ConfigurationOption: ...
|
||||
def set(self, selector: str, value: Any) -> OverridingContext: ...
|
||||
def reset_cache(self) -> None: ...
|
||||
def update(self, value: Any) -> None: ...
|
||||
|
@ -275,6 +278,8 @@ class Container(Provider):
|
|||
def __init__(self, container_cls: Type[T], container: Optional[T] = None, **overriding_providers: Provider) -> None: ...
|
||||
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
|
||||
def __getattr__(self, name: str) -> Provider: ...
|
||||
@property
|
||||
def container(self) -> T: ...
|
||||
|
||||
|
||||
class Selector(Provider):
|
||||
|
|
|
@ -1144,26 +1144,16 @@ cdef class ConfigurationOption(Provider):
|
|||
segment() if is_provider(segment) else segment for segment in self.__name
|
||||
)
|
||||
|
||||
@property
|
||||
def root(self):
|
||||
return self.__root_ref()
|
||||
|
||||
def get_name(self):
|
||||
root = self.__root_ref()
|
||||
return '.'.join((root.get_name(), self._get_self_name()))
|
||||
|
||||
def get_option_provider(self, selector):
|
||||
"""Return configuration option provider.
|
||||
|
||||
:param selector: Selector string, e.g. "option1.option2"
|
||||
:type selector: str
|
||||
|
||||
:return: Option provider.
|
||||
:rtype: :py:class:`ConfigurationOption`
|
||||
"""
|
||||
key, *other_keys = selector.split('.')
|
||||
child = getattr(self, key)
|
||||
|
||||
if other_keys:
|
||||
child = child.get_option_provider('.'.join(other_keys))
|
||||
|
||||
return child
|
||||
def get_name_segments(self):
|
||||
return self.__name
|
||||
|
||||
def as_int(self):
|
||||
return TypedConfigurationOption(int, self)
|
||||
|
@ -1386,23 +1376,6 @@ cdef class Configuration(Object):
|
|||
|
||||
return value
|
||||
|
||||
def get_option_provider(self, selector):
|
||||
"""Return configuration option provider.
|
||||
|
||||
:param selector: Selector string, e.g. "option1.option2"
|
||||
:type selector: str
|
||||
|
||||
:return: Option provider.
|
||||
:rtype: :py:class:`ConfigurationOption`
|
||||
"""
|
||||
key, *other_keys = selector.split('.')
|
||||
child = getattr(self, key)
|
||||
|
||||
if other_keys:
|
||||
child = child.get_option_provider('.'.join(other_keys))
|
||||
|
||||
return child
|
||||
|
||||
def set(self, selector, value):
|
||||
"""Override configuration option.
|
||||
|
||||
|
@ -2504,13 +2477,13 @@ cdef class Container(Provider):
|
|||
|
||||
def __init__(self, container_cls, container=None, **overriding_providers):
|
||||
"""Initialize provider."""
|
||||
self.container_cls = container_cls
|
||||
self.overriding_providers = overriding_providers
|
||||
self.__container_cls = container_cls
|
||||
self.__overriding_providers = overriding_providers
|
||||
|
||||
if container is None:
|
||||
container = container_cls()
|
||||
container.override_providers(**overriding_providers)
|
||||
self.container = container
|
||||
self.__container = container
|
||||
|
||||
super(Container, self).__init__()
|
||||
|
||||
|
@ -2521,9 +2494,9 @@ cdef class Container(Provider):
|
|||
return copied
|
||||
|
||||
copied = self.__class__(
|
||||
self.container_cls,
|
||||
deepcopy(self.container, memo),
|
||||
**deepcopy(self.overriding_providers, memo),
|
||||
self.__container_cls,
|
||||
deepcopy(self.__container, memo),
|
||||
**deepcopy(self.__overriding_providers, memo),
|
||||
)
|
||||
|
||||
return copied
|
||||
|
@ -2535,7 +2508,11 @@ cdef class Container(Provider):
|
|||
'\'{cls}\' object has no attribute '
|
||||
'\'{attribute_name}\''.format(cls=self.__class__.__name__,
|
||||
attribute_name=name))
|
||||
return getattr(self.container, name)
|
||||
return getattr(self.__container, name)
|
||||
|
||||
@property
|
||||
def container(self):
|
||||
return self.__container
|
||||
|
||||
def override(self, provider):
|
||||
"""Override provider with another provider."""
|
||||
|
@ -2543,7 +2520,7 @@ cdef class Container(Provider):
|
|||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return single instance."""
|
||||
return self.container
|
||||
return self.__container
|
||||
|
||||
|
||||
cdef class Selector(Provider):
|
||||
|
|
|
@ -5,7 +5,7 @@ import inspect
|
|||
import pkgutil
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar
|
||||
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar, cast
|
||||
|
||||
if sys.version_info < (3, 7):
|
||||
from typing import GenericMeta
|
||||
|
@ -17,8 +17,114 @@ else:
|
|||
from . import providers
|
||||
|
||||
|
||||
AnyContainer = Any
|
||||
T = TypeVar('T')
|
||||
AnyContainer = Any
|
||||
|
||||
|
||||
class ProvidersMap:
|
||||
|
||||
def __init__(self, container):
|
||||
self._container = container
|
||||
self._map = self._create_providers_map(
|
||||
current_providers=container.providers,
|
||||
original_providers=container.declarative_parent.providers,
|
||||
)
|
||||
|
||||
def resolve_provider(self, provider: providers.Provider) -> providers.Provider:
|
||||
if isinstance(provider, providers.Delegate):
|
||||
return self._resolve_delegate(provider)
|
||||
elif isinstance(provider, (
|
||||
providers.ProvidedInstance,
|
||||
providers.AttributeGetter,
|
||||
providers.ItemGetter,
|
||||
providers.MethodCaller,
|
||||
)):
|
||||
return self._resolve_provided_instance(provider)
|
||||
elif isinstance(provider, providers.ConfigurationOption):
|
||||
return self._resolve_config_option(provider)
|
||||
elif isinstance(provider, providers.TypedConfigurationOption):
|
||||
return self._resolve_config_option(provider.option, as_=provider.provides)
|
||||
else:
|
||||
return self._resolve_provider(provider)
|
||||
|
||||
def _resolve_delegate(self, original: providers.Delegate) -> providers.Provider:
|
||||
return self._resolve_provider(original.provides)
|
||||
|
||||
def _resolve_provided_instance(self, original: providers.Provider) -> providers.Provider:
|
||||
modifiers = []
|
||||
while isinstance(original, (
|
||||
providers.ProvidedInstance,
|
||||
providers.AttributeGetter,
|
||||
providers.ItemGetter,
|
||||
providers.MethodCaller,
|
||||
)):
|
||||
modifiers.insert(0, original)
|
||||
original = original.provides
|
||||
|
||||
new = self._resolve_provider(original)
|
||||
|
||||
for modifier in modifiers:
|
||||
if isinstance(modifier, providers.ProvidedInstance):
|
||||
new = new.provided
|
||||
elif isinstance(modifier, providers.AttributeGetter):
|
||||
new = getattr(new, modifier.name)
|
||||
elif isinstance(modifier, providers.ItemGetter):
|
||||
new = new[modifier.name]
|
||||
elif isinstance(modifier, providers.MethodCaller):
|
||||
new = new.call(
|
||||
*modifier.args,
|
||||
**modifier.kwargs,
|
||||
)
|
||||
|
||||
return new
|
||||
|
||||
def _resolve_config_option(
|
||||
self,
|
||||
original: providers.ConfigurationOption,
|
||||
as_: Any = None,
|
||||
) -> providers.Provider:
|
||||
original_root = original.root
|
||||
new = self._resolve_provider(original_root)
|
||||
new = cast(providers.Configuration, new)
|
||||
|
||||
for segment in original.get_name_segments():
|
||||
if providers.is_provider(segment):
|
||||
segment = self.resolve_provider(segment)
|
||||
new = new[segment]
|
||||
else:
|
||||
new = getattr(new, segment)
|
||||
|
||||
if as_:
|
||||
new = new.as_(as_)
|
||||
|
||||
return new
|
||||
|
||||
def _resolve_provider(self, original: providers.Provider) -> providers.Provider:
|
||||
try:
|
||||
return self._map[original]
|
||||
except KeyError:
|
||||
raise Exception('Unable to resolve original provider')
|
||||
|
||||
@classmethod
|
||||
def _create_providers_map(
|
||||
cls,
|
||||
current_providers: Dict[str, providers.Provider],
|
||||
original_providers: Dict[str, providers.Provider],
|
||||
) -> Dict[providers.Provider, providers.Provider]:
|
||||
providers_map = {}
|
||||
for provider_name, current_provider in current_providers.items():
|
||||
original_provider = original_providers[provider_name]
|
||||
providers_map[original_provider] = current_provider
|
||||
|
||||
if isinstance(current_provider, providers.Container) \
|
||||
and isinstance(original_provider, providers.Container):
|
||||
subcontainer_map = cls._create_providers_map(
|
||||
current_providers=current_provider.container.providers,
|
||||
original_providers=original_provider.container.providers,
|
||||
)
|
||||
providers_map.update(subcontainer_map)
|
||||
|
||||
return providers_map
|
||||
|
||||
|
||||
def wire(
|
||||
|
@ -28,6 +134,7 @@ def wire(
|
|||
packages: Optional[Iterable[ModuleType]] = None,
|
||||
) -> None:
|
||||
"""Wire container providers with provided packages and modules."""
|
||||
# TODO: Add protection to only wire declarative container instances
|
||||
if not modules:
|
||||
modules = []
|
||||
|
||||
|
@ -35,12 +142,14 @@ def wire(
|
|||
for package in packages:
|
||||
modules.extend(_fetch_modules(package))
|
||||
|
||||
providers_map = ProvidersMap(container)
|
||||
|
||||
for module in modules:
|
||||
for name, member in inspect.getmembers(module):
|
||||
if inspect.isfunction(member):
|
||||
_patch_fn(module, name, member, container)
|
||||
_patch_fn(module, name, member, providers_map)
|
||||
elif inspect.isclass(member):
|
||||
_patch_cls(member, container)
|
||||
_patch_cls(member, providers_map)
|
||||
|
||||
|
||||
def unwire(
|
||||
|
@ -66,12 +175,12 @@ def unwire(
|
|||
|
||||
def _patch_cls(
|
||||
cls: Type[Any],
|
||||
container: AnyContainer,
|
||||
providers_map: ProvidersMap,
|
||||
) -> None:
|
||||
if not hasattr(cls, '__init__'):
|
||||
return
|
||||
init_method = getattr(cls, '__init__')
|
||||
injections = _resolve_injections(init_method, container)
|
||||
injections = _resolve_injections(init_method, providers_map)
|
||||
if not injections:
|
||||
return
|
||||
setattr(cls, '__init__', _patch_with_injections(init_method, injections))
|
||||
|
@ -90,9 +199,9 @@ def _patch_fn(
|
|||
module: ModuleType,
|
||||
name: str,
|
||||
fn: Callable[..., Any],
|
||||
container: AnyContainer,
|
||||
providers_map: ProvidersMap,
|
||||
) -> None:
|
||||
injections = _resolve_injections(fn, container)
|
||||
injections = _resolve_injections(fn, providers_map)
|
||||
if not injections:
|
||||
return
|
||||
setattr(module, name, _patch_with_injections(fn, injections))
|
||||
|
@ -108,9 +217,7 @@ def _unpatch_fn(
|
|||
setattr(module, name, _get_original_from_patched(fn))
|
||||
|
||||
|
||||
def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict[str, Any]: # noqa
|
||||
config = _resolve_container_config(container)
|
||||
|
||||
def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Dict[str, Any]: # noqa
|
||||
signature = inspect.signature(fn)
|
||||
|
||||
injections = {}
|
||||
|
@ -119,35 +226,7 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
|
|||
continue
|
||||
marker = parameter.default
|
||||
|
||||
if config and isinstance(marker.provider, providers.ConfigurationOption):
|
||||
provider = _prepare_config_injection(marker.provider.get_name(), config)
|
||||
elif config and isinstance(marker.provider, providers.TypedConfigurationOption):
|
||||
provider = _prepare_config_injection(
|
||||
marker.provider.option.get_name(),
|
||||
config,
|
||||
marker.provider.provides,
|
||||
)
|
||||
elif isinstance(marker.provider, providers.Delegate):
|
||||
provider_name = container.resolve_provider_name(marker.provider.provides)
|
||||
provider = container.providers[provider_name]
|
||||
elif isinstance(marker.provider, (
|
||||
providers.ProvidedInstance,
|
||||
providers.AttributeGetter,
|
||||
providers.ItemGetter,
|
||||
providers.MethodCaller,
|
||||
)):
|
||||
provider = _prepare_provided_instance_injection(
|
||||
marker.provider,
|
||||
container,
|
||||
)
|
||||
elif isinstance(marker.provider, providers.Provider):
|
||||
provider_name = container.resolve_provider_name(marker.provider)
|
||||
if not provider_name:
|
||||
continue
|
||||
provider = container.providers[provider_name]
|
||||
else:
|
||||
continue
|
||||
|
||||
provider = providers_map.resolve_provider(marker.provider)
|
||||
if isinstance(marker, Provide):
|
||||
injections[parameter_name] = provider
|
||||
elif isinstance(marker, Provider):
|
||||
|
@ -156,61 +235,6 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
|
|||
return injections
|
||||
|
||||
|
||||
def _prepare_config_injection(
|
||||
option_name: str,
|
||||
config: providers.Configuration,
|
||||
as_: Any = None,
|
||||
) -> providers.Provider:
|
||||
_, *parts = option_name.split('.')
|
||||
relative_option_name = '.'.join(parts)
|
||||
provider = config.get_option_provider(relative_option_name)
|
||||
if as_:
|
||||
return provider.as_(as_)
|
||||
return provider
|
||||
|
||||
|
||||
def _prepare_provided_instance_injection(
|
||||
current_provider: providers.Provider,
|
||||
container: AnyContainer,
|
||||
) -> providers.Provider:
|
||||
provided_instance_markers = []
|
||||
instance_provider_marker = current_provider
|
||||
while isinstance(instance_provider_marker, (
|
||||
providers.ProvidedInstance,
|
||||
providers.AttributeGetter,
|
||||
providers.ItemGetter,
|
||||
providers.MethodCaller,
|
||||
)):
|
||||
provided_instance_markers.insert(0, instance_provider_marker)
|
||||
instance_provider_marker = instance_provider_marker.provides
|
||||
|
||||
provider_name = container.resolve_provider_name(instance_provider_marker)
|
||||
provider = container.providers[provider_name]
|
||||
|
||||
for provided_instance in provided_instance_markers:
|
||||
if isinstance(provided_instance, providers.ProvidedInstance):
|
||||
provider = provider.provided
|
||||
elif isinstance(provided_instance, providers.AttributeGetter):
|
||||
provider = getattr(provider, provided_instance.name)
|
||||
elif isinstance(provided_instance, providers.ItemGetter):
|
||||
provider = provider[provided_instance.name]
|
||||
elif isinstance(provided_instance, providers.MethodCaller):
|
||||
provider = provider.call(
|
||||
*provided_instance.args,
|
||||
**provided_instance.kwargs,
|
||||
)
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
def _resolve_container_config(container: AnyContainer) -> Optional[providers.Configuration]:
|
||||
for provider in container.providers.values():
|
||||
if isinstance(provider, providers.Configuration):
|
||||
return provider
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_modules(package):
|
||||
modules = []
|
||||
for loader, module_name, is_pkg in pkgutil.walk_packages(
|
||||
|
|
|
@ -3,8 +3,15 @@ from dependency_injector import containers, providers
|
|||
from .service import Service
|
||||
|
||||
|
||||
class SubContainer(containers.DeclarativeContainer):
|
||||
|
||||
int_object = providers.Object(1)
|
||||
|
||||
|
||||
class Container(containers.DeclarativeContainer):
|
||||
|
||||
config = providers.Configuration()
|
||||
|
||||
service = providers.Factory(Service)
|
||||
|
||||
sub = providers.Container(SubContainer)
|
||||
|
|
|
@ -39,3 +39,11 @@ def test_provide_provider(service_provider: Callable[..., Service] = Provider[Co
|
|||
|
||||
def test_provided_instance(some_value: int = Provide[Container.service.provided.foo['bar'].call()]):
|
||||
return some_value
|
||||
|
||||
|
||||
def test_subcontainer_provider(some_value: int = Provide[Container.sub.int_object]):
|
||||
return some_value
|
||||
|
||||
|
||||
def test_config_invariant(some_value: int = Provide[Container.config.option[Container.config.switch]]):
|
||||
return some_value
|
||||
|
|
|
@ -72,3 +72,25 @@ class WiringTest(unittest.TestCase):
|
|||
with self.container.service.override(TestService()):
|
||||
some_value = module.test_provided_instance()
|
||||
self.assertEqual(some_value, 10)
|
||||
|
||||
def test_subcontainer(self):
|
||||
some_value = module.test_subcontainer_provider()
|
||||
self.assertEqual(some_value, 1)
|
||||
|
||||
def test_config_invariant(self):
|
||||
config = {
|
||||
'option': {
|
||||
'a': 1,
|
||||
'b': 2,
|
||||
},
|
||||
'switch': 'a',
|
||||
}
|
||||
self.container.config.from_dict(config)
|
||||
|
||||
with self.container.config.switch.override('a'):
|
||||
value_a = module.test_config_invariant()
|
||||
self.assertEqual(value_a, 1)
|
||||
|
||||
with self.container.config.switch.override('b'):
|
||||
value_b = module.test_config_invariant()
|
||||
self.assertEqual(value_b, 2)
|
||||
|
|
Loading…
Reference in New Issue
Block a user