Wiring refactoring (#296)

* Refactor wiring

* Add todos to wiring

* Implement wiring of config invariant

* Implement sub containers wiring + add tests

* Add test for wiring config invariant
This commit is contained in:
Roman Mogylatov 2020-09-27 23:10:11 -04:00 committed by GitHub
parent 7f854548d6
commit 6182b8448a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 3847 additions and 4409 deletions

File diff suppressed because it is too large Load Diff

View File

@ -186,20 +186,6 @@ class DynamicContainer(object):
for provider in six.itervalues(self.providers):
provider.reset_override()
def resolve_provider_name(self, provider_to_resolve):
"""Try to resolve provider name by its instance."""
if self.declarative_parent:
provider_name = self.declarative_parent.resolve_provider_name(provider_to_resolve)
if provider_name:
return provider_name
for provider_name, container_provider in self.providers.items():
if container_provider is provider_to_resolve:
return provider_name
else:
return None
def wire(self, modules=None, packages=None):
"""Wire container providers with provided packages and modules.

File diff suppressed because it is too large Load Diff

View File

@ -185,9 +185,9 @@ cdef class List(Provider):
cdef class Container(Provider):
cdef object container_cls
cdef dict overriding_providers
cdef object container
cdef object __container_cls
cdef dict __overriding_providers
cdef object __container
cpdef object _provide(self, tuple args, dict kwargs)

View File

@ -54,6 +54,8 @@ class Object(Provider, Generic[T]):
class Delegate(Provider):
def __init__(self, provides: Provider) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> Provider: ...
@property
def provides(self) -> Provider: ...
class Dependency(Provider, Generic[T]):
@ -125,9 +127,11 @@ class ConfigurationOption(Provider):
def __init__(self, name: Tuple[str], root: Configuration) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> Any: ...
def __getattr__(self, item: str) -> ConfigurationOption: ...
def __getitem__(self, item: str) -> ConfigurationOption: ...
def __getitem__(self, item: Union[str, Provider]) -> ConfigurationOption: ...
@property
def root(self) -> Configuration: ...
def get_name(self) -> str: ...
def get_option_provider(self, selector: str) -> ConfigurationOption: ...
def get_name_segments(self) -> Tuple[Union[str, Provider]]: ...
def as_int(self) -> TypedConfigurationOption[int]: ...
def as_float(self) -> TypedConfigurationOption[float]: ...
def as_(self, callback: _Callable[..., T], *args: Injection, **kwargs: Injection) -> TypedConfigurationOption[T]: ...
@ -147,10 +151,9 @@ class Configuration(Object):
DEFAULT_NAME: str = 'config'
def __init__(self, name: str = DEFAULT_NAME, default: Optional[Any] = None) -> None: ...
def __getattr__(self, item: str) -> ConfigurationOption: ...
def __getitem__(self, item: str) -> ConfigurationOption: ...
def __getitem__(self, item: Union[str, Provider]) -> ConfigurationOption: ...
def get_name(self) -> str: ...
def get(self, selector: str) -> Any: ...
def get_option_provider(self, selector: str) -> ConfigurationOption: ...
def set(self, selector: str, value: Any) -> OverridingContext: ...
def reset_cache(self) -> None: ...
def update(self, value: Any) -> None: ...
@ -275,6 +278,8 @@ class Container(Provider):
def __init__(self, container_cls: Type[T], container: Optional[T] = None, **overriding_providers: Provider) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
def __getattr__(self, name: str) -> Provider: ...
@property
def container(self) -> T: ...
class Selector(Provider):

View File

@ -1144,26 +1144,16 @@ cdef class ConfigurationOption(Provider):
segment() if is_provider(segment) else segment for segment in self.__name
)
@property
def root(self):
return self.__root_ref()
def get_name(self):
root = self.__root_ref()
return '.'.join((root.get_name(), self._get_self_name()))
def get_option_provider(self, selector):
"""Return configuration option provider.
:param selector: Selector string, e.g. "option1.option2"
:type selector: str
:return: Option provider.
:rtype: :py:class:`ConfigurationOption`
"""
key, *other_keys = selector.split('.')
child = getattr(self, key)
if other_keys:
child = child.get_option_provider('.'.join(other_keys))
return child
def get_name_segments(self):
return self.__name
def as_int(self):
return TypedConfigurationOption(int, self)
@ -1386,23 +1376,6 @@ cdef class Configuration(Object):
return value
def get_option_provider(self, selector):
"""Return configuration option provider.
:param selector: Selector string, e.g. "option1.option2"
:type selector: str
:return: Option provider.
:rtype: :py:class:`ConfigurationOption`
"""
key, *other_keys = selector.split('.')
child = getattr(self, key)
if other_keys:
child = child.get_option_provider('.'.join(other_keys))
return child
def set(self, selector, value):
"""Override configuration option.
@ -2504,13 +2477,13 @@ cdef class Container(Provider):
def __init__(self, container_cls, container=None, **overriding_providers):
"""Initialize provider."""
self.container_cls = container_cls
self.overriding_providers = overriding_providers
self.__container_cls = container_cls
self.__overriding_providers = overriding_providers
if container is None:
container = container_cls()
container.override_providers(**overriding_providers)
self.container = container
self.__container = container
super(Container, self).__init__()
@ -2521,9 +2494,9 @@ cdef class Container(Provider):
return copied
copied = self.__class__(
self.container_cls,
deepcopy(self.container, memo),
**deepcopy(self.overriding_providers, memo),
self.__container_cls,
deepcopy(self.__container, memo),
**deepcopy(self.__overriding_providers, memo),
)
return copied
@ -2535,7 +2508,11 @@ cdef class Container(Provider):
'\'{cls}\' object has no attribute '
'\'{attribute_name}\''.format(cls=self.__class__.__name__,
attribute_name=name))
return getattr(self.container, name)
return getattr(self.__container, name)
@property
def container(self):
return self.__container
def override(self, provider):
"""Override provider with another provider."""
@ -2543,7 +2520,7 @@ cdef class Container(Provider):
cpdef object _provide(self, tuple args, dict kwargs):
"""Return single instance."""
return self.container
return self.__container
cdef class Selector(Provider):

View File

@ -5,7 +5,7 @@ import inspect
import pkgutil
import sys
from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar, cast
if sys.version_info < (3, 7):
from typing import GenericMeta
@ -17,8 +17,114 @@ else:
from . import providers
AnyContainer = Any
T = TypeVar('T')
AnyContainer = Any
class ProvidersMap:
def __init__(self, container):
self._container = container
self._map = self._create_providers_map(
current_providers=container.providers,
original_providers=container.declarative_parent.providers,
)
def resolve_provider(self, provider: providers.Provider) -> providers.Provider:
if isinstance(provider, providers.Delegate):
return self._resolve_delegate(provider)
elif isinstance(provider, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
return self._resolve_provided_instance(provider)
elif isinstance(provider, providers.ConfigurationOption):
return self._resolve_config_option(provider)
elif isinstance(provider, providers.TypedConfigurationOption):
return self._resolve_config_option(provider.option, as_=provider.provides)
else:
return self._resolve_provider(provider)
def _resolve_delegate(self, original: providers.Delegate) -> providers.Provider:
return self._resolve_provider(original.provides)
def _resolve_provided_instance(self, original: providers.Provider) -> providers.Provider:
modifiers = []
while isinstance(original, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
modifiers.insert(0, original)
original = original.provides
new = self._resolve_provider(original)
for modifier in modifiers:
if isinstance(modifier, providers.ProvidedInstance):
new = new.provided
elif isinstance(modifier, providers.AttributeGetter):
new = getattr(new, modifier.name)
elif isinstance(modifier, providers.ItemGetter):
new = new[modifier.name]
elif isinstance(modifier, providers.MethodCaller):
new = new.call(
*modifier.args,
**modifier.kwargs,
)
return new
def _resolve_config_option(
self,
original: providers.ConfigurationOption,
as_: Any = None,
) -> providers.Provider:
original_root = original.root
new = self._resolve_provider(original_root)
new = cast(providers.Configuration, new)
for segment in original.get_name_segments():
if providers.is_provider(segment):
segment = self.resolve_provider(segment)
new = new[segment]
else:
new = getattr(new, segment)
if as_:
new = new.as_(as_)
return new
def _resolve_provider(self, original: providers.Provider) -> providers.Provider:
try:
return self._map[original]
except KeyError:
raise Exception('Unable to resolve original provider')
@classmethod
def _create_providers_map(
cls,
current_providers: Dict[str, providers.Provider],
original_providers: Dict[str, providers.Provider],
) -> Dict[providers.Provider, providers.Provider]:
providers_map = {}
for provider_name, current_provider in current_providers.items():
original_provider = original_providers[provider_name]
providers_map[original_provider] = current_provider
if isinstance(current_provider, providers.Container) \
and isinstance(original_provider, providers.Container):
subcontainer_map = cls._create_providers_map(
current_providers=current_provider.container.providers,
original_providers=original_provider.container.providers,
)
providers_map.update(subcontainer_map)
return providers_map
def wire(
@ -28,6 +134,7 @@ def wire(
packages: Optional[Iterable[ModuleType]] = None,
) -> None:
"""Wire container providers with provided packages and modules."""
# TODO: Add protection to only wire declarative container instances
if not modules:
modules = []
@ -35,12 +142,14 @@ def wire(
for package in packages:
modules.extend(_fetch_modules(package))
providers_map = ProvidersMap(container)
for module in modules:
for name, member in inspect.getmembers(module):
if inspect.isfunction(member):
_patch_fn(module, name, member, container)
_patch_fn(module, name, member, providers_map)
elif inspect.isclass(member):
_patch_cls(member, container)
_patch_cls(member, providers_map)
def unwire(
@ -66,12 +175,12 @@ def unwire(
def _patch_cls(
cls: Type[Any],
container: AnyContainer,
providers_map: ProvidersMap,
) -> None:
if not hasattr(cls, '__init__'):
return
init_method = getattr(cls, '__init__')
injections = _resolve_injections(init_method, container)
injections = _resolve_injections(init_method, providers_map)
if not injections:
return
setattr(cls, '__init__', _patch_with_injections(init_method, injections))
@ -90,9 +199,9 @@ def _patch_fn(
module: ModuleType,
name: str,
fn: Callable[..., Any],
container: AnyContainer,
providers_map: ProvidersMap,
) -> None:
injections = _resolve_injections(fn, container)
injections = _resolve_injections(fn, providers_map)
if not injections:
return
setattr(module, name, _patch_with_injections(fn, injections))
@ -108,9 +217,7 @@ def _unpatch_fn(
setattr(module, name, _get_original_from_patched(fn))
def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict[str, Any]: # noqa
config = _resolve_container_config(container)
def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Dict[str, Any]: # noqa
signature = inspect.signature(fn)
injections = {}
@ -119,35 +226,7 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
continue
marker = parameter.default
if config and isinstance(marker.provider, providers.ConfigurationOption):
provider = _prepare_config_injection(marker.provider.get_name(), config)
elif config and isinstance(marker.provider, providers.TypedConfigurationOption):
provider = _prepare_config_injection(
marker.provider.option.get_name(),
config,
marker.provider.provides,
)
elif isinstance(marker.provider, providers.Delegate):
provider_name = container.resolve_provider_name(marker.provider.provides)
provider = container.providers[provider_name]
elif isinstance(marker.provider, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
provider = _prepare_provided_instance_injection(
marker.provider,
container,
)
elif isinstance(marker.provider, providers.Provider):
provider_name = container.resolve_provider_name(marker.provider)
if not provider_name:
continue
provider = container.providers[provider_name]
else:
continue
provider = providers_map.resolve_provider(marker.provider)
if isinstance(marker, Provide):
injections[parameter_name] = provider
elif isinstance(marker, Provider):
@ -156,61 +235,6 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
return injections
def _prepare_config_injection(
option_name: str,
config: providers.Configuration,
as_: Any = None,
) -> providers.Provider:
_, *parts = option_name.split('.')
relative_option_name = '.'.join(parts)
provider = config.get_option_provider(relative_option_name)
if as_:
return provider.as_(as_)
return provider
def _prepare_provided_instance_injection(
current_provider: providers.Provider,
container: AnyContainer,
) -> providers.Provider:
provided_instance_markers = []
instance_provider_marker = current_provider
while isinstance(instance_provider_marker, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
provided_instance_markers.insert(0, instance_provider_marker)
instance_provider_marker = instance_provider_marker.provides
provider_name = container.resolve_provider_name(instance_provider_marker)
provider = container.providers[provider_name]
for provided_instance in provided_instance_markers:
if isinstance(provided_instance, providers.ProvidedInstance):
provider = provider.provided
elif isinstance(provided_instance, providers.AttributeGetter):
provider = getattr(provider, provided_instance.name)
elif isinstance(provided_instance, providers.ItemGetter):
provider = provider[provided_instance.name]
elif isinstance(provided_instance, providers.MethodCaller):
provider = provider.call(
*provided_instance.args,
**provided_instance.kwargs,
)
return provider
def _resolve_container_config(container: AnyContainer) -> Optional[providers.Configuration]:
for provider in container.providers.values():
if isinstance(provider, providers.Configuration):
return provider
else:
return None
def _fetch_modules(package):
modules = []
for loader, module_name, is_pkg in pkgutil.walk_packages(

View File

@ -3,8 +3,15 @@ from dependency_injector import containers, providers
from .service import Service
class SubContainer(containers.DeclarativeContainer):
int_object = providers.Object(1)
class Container(containers.DeclarativeContainer):
config = providers.Configuration()
service = providers.Factory(Service)
sub = providers.Container(SubContainer)

View File

@ -39,3 +39,11 @@ def test_provide_provider(service_provider: Callable[..., Service] = Provider[Co
def test_provided_instance(some_value: int = Provide[Container.service.provided.foo['bar'].call()]):
return some_value
def test_subcontainer_provider(some_value: int = Provide[Container.sub.int_object]):
return some_value
def test_config_invariant(some_value: int = Provide[Container.config.option[Container.config.switch]]):
return some_value

View File

@ -72,3 +72,25 @@ class WiringTest(unittest.TestCase):
with self.container.service.override(TestService()):
some_value = module.test_provided_instance()
self.assertEqual(some_value, 10)
def test_subcontainer(self):
some_value = module.test_subcontainer_provider()
self.assertEqual(some_value, 1)
def test_config_invariant(self):
config = {
'option': {
'a': 1,
'b': 2,
},
'switch': 'a',
}
self.container.config.from_dict(config)
with self.container.config.switch.override('a'):
value_a = module.test_config_invariant()
self.assertEqual(value_a, 1)
with self.container.config.switch.override('b'):
value_b = module.test_config_invariant()
self.assertEqual(value_b, 2)