Wiring refactoring (#296)

* Refactor wiring

* Add todos to wiring

* Implement wiring of config invariant

* Implement sub containers wiring + add tests

* Add test for wiring config invariant
This commit is contained in:
Roman Mogylatov 2020-09-27 23:10:11 -04:00 committed by GitHub
parent 7f854548d6
commit 6182b8448a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 3847 additions and 4409 deletions

File diff suppressed because it is too large Load Diff

View File

@ -186,20 +186,6 @@ class DynamicContainer(object):
for provider in six.itervalues(self.providers): for provider in six.itervalues(self.providers):
provider.reset_override() provider.reset_override()
def resolve_provider_name(self, provider_to_resolve):
"""Try to resolve provider name by its instance."""
if self.declarative_parent:
provider_name = self.declarative_parent.resolve_provider_name(provider_to_resolve)
if provider_name:
return provider_name
for provider_name, container_provider in self.providers.items():
if container_provider is provider_to_resolve:
return provider_name
else:
return None
def wire(self, modules=None, packages=None): def wire(self, modules=None, packages=None):
"""Wire container providers with provided packages and modules. """Wire container providers with provided packages and modules.

File diff suppressed because it is too large Load Diff

View File

@ -185,9 +185,9 @@ cdef class List(Provider):
cdef class Container(Provider): cdef class Container(Provider):
cdef object container_cls cdef object __container_cls
cdef dict overriding_providers cdef dict __overriding_providers
cdef object container cdef object __container
cpdef object _provide(self, tuple args, dict kwargs) cpdef object _provide(self, tuple args, dict kwargs)

View File

@ -54,6 +54,8 @@ class Object(Provider, Generic[T]):
class Delegate(Provider): class Delegate(Provider):
def __init__(self, provides: Provider) -> None: ... def __init__(self, provides: Provider) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> Provider: ... def __call__(self, *args: Injection, **kwargs: Injection) -> Provider: ...
@property
def provides(self) -> Provider: ...
class Dependency(Provider, Generic[T]): class Dependency(Provider, Generic[T]):
@ -125,9 +127,11 @@ class ConfigurationOption(Provider):
def __init__(self, name: Tuple[str], root: Configuration) -> None: ... def __init__(self, name: Tuple[str], root: Configuration) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> Any: ... def __call__(self, *args: Injection, **kwargs: Injection) -> Any: ...
def __getattr__(self, item: str) -> ConfigurationOption: ... def __getattr__(self, item: str) -> ConfigurationOption: ...
def __getitem__(self, item: str) -> ConfigurationOption: ... def __getitem__(self, item: Union[str, Provider]) -> ConfigurationOption: ...
@property
def root(self) -> Configuration: ...
def get_name(self) -> str: ... def get_name(self) -> str: ...
def get_option_provider(self, selector: str) -> ConfigurationOption: ... def get_name_segments(self) -> Tuple[Union[str, Provider]]: ...
def as_int(self) -> TypedConfigurationOption[int]: ... def as_int(self) -> TypedConfigurationOption[int]: ...
def as_float(self) -> TypedConfigurationOption[float]: ... def as_float(self) -> TypedConfigurationOption[float]: ...
def as_(self, callback: _Callable[..., T], *args: Injection, **kwargs: Injection) -> TypedConfigurationOption[T]: ... def as_(self, callback: _Callable[..., T], *args: Injection, **kwargs: Injection) -> TypedConfigurationOption[T]: ...
@ -147,10 +151,9 @@ class Configuration(Object):
DEFAULT_NAME: str = 'config' DEFAULT_NAME: str = 'config'
def __init__(self, name: str = DEFAULT_NAME, default: Optional[Any] = None) -> None: ... def __init__(self, name: str = DEFAULT_NAME, default: Optional[Any] = None) -> None: ...
def __getattr__(self, item: str) -> ConfigurationOption: ... def __getattr__(self, item: str) -> ConfigurationOption: ...
def __getitem__(self, item: str) -> ConfigurationOption: ... def __getitem__(self, item: Union[str, Provider]) -> ConfigurationOption: ...
def get_name(self) -> str: ... def get_name(self) -> str: ...
def get(self, selector: str) -> Any: ... def get(self, selector: str) -> Any: ...
def get_option_provider(self, selector: str) -> ConfigurationOption: ...
def set(self, selector: str, value: Any) -> OverridingContext: ... def set(self, selector: str, value: Any) -> OverridingContext: ...
def reset_cache(self) -> None: ... def reset_cache(self) -> None: ...
def update(self, value: Any) -> None: ... def update(self, value: Any) -> None: ...
@ -275,6 +278,8 @@ class Container(Provider):
def __init__(self, container_cls: Type[T], container: Optional[T] = None, **overriding_providers: Provider) -> None: ... def __init__(self, container_cls: Type[T], container: Optional[T] = None, **overriding_providers: Provider) -> None: ...
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ... def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
def __getattr__(self, name: str) -> Provider: ... def __getattr__(self, name: str) -> Provider: ...
@property
def container(self) -> T: ...
class Selector(Provider): class Selector(Provider):

View File

@ -1144,26 +1144,16 @@ cdef class ConfigurationOption(Provider):
segment() if is_provider(segment) else segment for segment in self.__name segment() if is_provider(segment) else segment for segment in self.__name
) )
@property
def root(self):
return self.__root_ref()
def get_name(self): def get_name(self):
root = self.__root_ref() root = self.__root_ref()
return '.'.join((root.get_name(), self._get_self_name())) return '.'.join((root.get_name(), self._get_self_name()))
def get_option_provider(self, selector): def get_name_segments(self):
"""Return configuration option provider. return self.__name
:param selector: Selector string, e.g. "option1.option2"
:type selector: str
:return: Option provider.
:rtype: :py:class:`ConfigurationOption`
"""
key, *other_keys = selector.split('.')
child = getattr(self, key)
if other_keys:
child = child.get_option_provider('.'.join(other_keys))
return child
def as_int(self): def as_int(self):
return TypedConfigurationOption(int, self) return TypedConfigurationOption(int, self)
@ -1386,23 +1376,6 @@ cdef class Configuration(Object):
return value return value
def get_option_provider(self, selector):
"""Return configuration option provider.
:param selector: Selector string, e.g. "option1.option2"
:type selector: str
:return: Option provider.
:rtype: :py:class:`ConfigurationOption`
"""
key, *other_keys = selector.split('.')
child = getattr(self, key)
if other_keys:
child = child.get_option_provider('.'.join(other_keys))
return child
def set(self, selector, value): def set(self, selector, value):
"""Override configuration option. """Override configuration option.
@ -2504,13 +2477,13 @@ cdef class Container(Provider):
def __init__(self, container_cls, container=None, **overriding_providers): def __init__(self, container_cls, container=None, **overriding_providers):
"""Initialize provider.""" """Initialize provider."""
self.container_cls = container_cls self.__container_cls = container_cls
self.overriding_providers = overriding_providers self.__overriding_providers = overriding_providers
if container is None: if container is None:
container = container_cls() container = container_cls()
container.override_providers(**overriding_providers) container.override_providers(**overriding_providers)
self.container = container self.__container = container
super(Container, self).__init__() super(Container, self).__init__()
@ -2521,9 +2494,9 @@ cdef class Container(Provider):
return copied return copied
copied = self.__class__( copied = self.__class__(
self.container_cls, self.__container_cls,
deepcopy(self.container, memo), deepcopy(self.__container, memo),
**deepcopy(self.overriding_providers, memo), **deepcopy(self.__overriding_providers, memo),
) )
return copied return copied
@ -2535,7 +2508,11 @@ cdef class Container(Provider):
'\'{cls}\' object has no attribute ' '\'{cls}\' object has no attribute '
'\'{attribute_name}\''.format(cls=self.__class__.__name__, '\'{attribute_name}\''.format(cls=self.__class__.__name__,
attribute_name=name)) attribute_name=name))
return getattr(self.container, name) return getattr(self.__container, name)
@property
def container(self):
return self.__container
def override(self, provider): def override(self, provider):
"""Override provider with another provider.""" """Override provider with another provider."""
@ -2543,7 +2520,7 @@ cdef class Container(Provider):
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
"""Return single instance.""" """Return single instance."""
return self.container return self.__container
cdef class Selector(Provider): cdef class Selector(Provider):

View File

@ -5,7 +5,7 @@ import inspect
import pkgutil import pkgutil
import sys import sys
from types import ModuleType from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar, cast
if sys.version_info < (3, 7): if sys.version_info < (3, 7):
from typing import GenericMeta from typing import GenericMeta
@ -17,8 +17,114 @@ else:
from . import providers from . import providers
AnyContainer = Any
T = TypeVar('T') T = TypeVar('T')
AnyContainer = Any
class ProvidersMap:
def __init__(self, container):
self._container = container
self._map = self._create_providers_map(
current_providers=container.providers,
original_providers=container.declarative_parent.providers,
)
def resolve_provider(self, provider: providers.Provider) -> providers.Provider:
if isinstance(provider, providers.Delegate):
return self._resolve_delegate(provider)
elif isinstance(provider, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
return self._resolve_provided_instance(provider)
elif isinstance(provider, providers.ConfigurationOption):
return self._resolve_config_option(provider)
elif isinstance(provider, providers.TypedConfigurationOption):
return self._resolve_config_option(provider.option, as_=provider.provides)
else:
return self._resolve_provider(provider)
def _resolve_delegate(self, original: providers.Delegate) -> providers.Provider:
return self._resolve_provider(original.provides)
def _resolve_provided_instance(self, original: providers.Provider) -> providers.Provider:
modifiers = []
while isinstance(original, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
modifiers.insert(0, original)
original = original.provides
new = self._resolve_provider(original)
for modifier in modifiers:
if isinstance(modifier, providers.ProvidedInstance):
new = new.provided
elif isinstance(modifier, providers.AttributeGetter):
new = getattr(new, modifier.name)
elif isinstance(modifier, providers.ItemGetter):
new = new[modifier.name]
elif isinstance(modifier, providers.MethodCaller):
new = new.call(
*modifier.args,
**modifier.kwargs,
)
return new
def _resolve_config_option(
self,
original: providers.ConfigurationOption,
as_: Any = None,
) -> providers.Provider:
original_root = original.root
new = self._resolve_provider(original_root)
new = cast(providers.Configuration, new)
for segment in original.get_name_segments():
if providers.is_provider(segment):
segment = self.resolve_provider(segment)
new = new[segment]
else:
new = getattr(new, segment)
if as_:
new = new.as_(as_)
return new
def _resolve_provider(self, original: providers.Provider) -> providers.Provider:
try:
return self._map[original]
except KeyError:
raise Exception('Unable to resolve original provider')
@classmethod
def _create_providers_map(
cls,
current_providers: Dict[str, providers.Provider],
original_providers: Dict[str, providers.Provider],
) -> Dict[providers.Provider, providers.Provider]:
providers_map = {}
for provider_name, current_provider in current_providers.items():
original_provider = original_providers[provider_name]
providers_map[original_provider] = current_provider
if isinstance(current_provider, providers.Container) \
and isinstance(original_provider, providers.Container):
subcontainer_map = cls._create_providers_map(
current_providers=current_provider.container.providers,
original_providers=original_provider.container.providers,
)
providers_map.update(subcontainer_map)
return providers_map
def wire( def wire(
@ -28,6 +134,7 @@ def wire(
packages: Optional[Iterable[ModuleType]] = None, packages: Optional[Iterable[ModuleType]] = None,
) -> None: ) -> None:
"""Wire container providers with provided packages and modules.""" """Wire container providers with provided packages and modules."""
# TODO: Add protection to only wire declarative container instances
if not modules: if not modules:
modules = [] modules = []
@ -35,12 +142,14 @@ def wire(
for package in packages: for package in packages:
modules.extend(_fetch_modules(package)) modules.extend(_fetch_modules(package))
providers_map = ProvidersMap(container)
for module in modules: for module in modules:
for name, member in inspect.getmembers(module): for name, member in inspect.getmembers(module):
if inspect.isfunction(member): if inspect.isfunction(member):
_patch_fn(module, name, member, container) _patch_fn(module, name, member, providers_map)
elif inspect.isclass(member): elif inspect.isclass(member):
_patch_cls(member, container) _patch_cls(member, providers_map)
def unwire( def unwire(
@ -66,12 +175,12 @@ def unwire(
def _patch_cls( def _patch_cls(
cls: Type[Any], cls: Type[Any],
container: AnyContainer, providers_map: ProvidersMap,
) -> None: ) -> None:
if not hasattr(cls, '__init__'): if not hasattr(cls, '__init__'):
return return
init_method = getattr(cls, '__init__') init_method = getattr(cls, '__init__')
injections = _resolve_injections(init_method, container) injections = _resolve_injections(init_method, providers_map)
if not injections: if not injections:
return return
setattr(cls, '__init__', _patch_with_injections(init_method, injections)) setattr(cls, '__init__', _patch_with_injections(init_method, injections))
@ -90,9 +199,9 @@ def _patch_fn(
module: ModuleType, module: ModuleType,
name: str, name: str,
fn: Callable[..., Any], fn: Callable[..., Any],
container: AnyContainer, providers_map: ProvidersMap,
) -> None: ) -> None:
injections = _resolve_injections(fn, container) injections = _resolve_injections(fn, providers_map)
if not injections: if not injections:
return return
setattr(module, name, _patch_with_injections(fn, injections)) setattr(module, name, _patch_with_injections(fn, injections))
@ -108,9 +217,7 @@ def _unpatch_fn(
setattr(module, name, _get_original_from_patched(fn)) setattr(module, name, _get_original_from_patched(fn))
def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict[str, Any]: # noqa def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Dict[str, Any]: # noqa
config = _resolve_container_config(container)
signature = inspect.signature(fn) signature = inspect.signature(fn)
injections = {} injections = {}
@ -119,35 +226,7 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
continue continue
marker = parameter.default marker = parameter.default
if config and isinstance(marker.provider, providers.ConfigurationOption): provider = providers_map.resolve_provider(marker.provider)
provider = _prepare_config_injection(marker.provider.get_name(), config)
elif config and isinstance(marker.provider, providers.TypedConfigurationOption):
provider = _prepare_config_injection(
marker.provider.option.get_name(),
config,
marker.provider.provides,
)
elif isinstance(marker.provider, providers.Delegate):
provider_name = container.resolve_provider_name(marker.provider.provides)
provider = container.providers[provider_name]
elif isinstance(marker.provider, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
provider = _prepare_provided_instance_injection(
marker.provider,
container,
)
elif isinstance(marker.provider, providers.Provider):
provider_name = container.resolve_provider_name(marker.provider)
if not provider_name:
continue
provider = container.providers[provider_name]
else:
continue
if isinstance(marker, Provide): if isinstance(marker, Provide):
injections[parameter_name] = provider injections[parameter_name] = provider
elif isinstance(marker, Provider): elif isinstance(marker, Provider):
@ -156,61 +235,6 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
return injections return injections
def _prepare_config_injection(
option_name: str,
config: providers.Configuration,
as_: Any = None,
) -> providers.Provider:
_, *parts = option_name.split('.')
relative_option_name = '.'.join(parts)
provider = config.get_option_provider(relative_option_name)
if as_:
return provider.as_(as_)
return provider
def _prepare_provided_instance_injection(
current_provider: providers.Provider,
container: AnyContainer,
) -> providers.Provider:
provided_instance_markers = []
instance_provider_marker = current_provider
while isinstance(instance_provider_marker, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
provided_instance_markers.insert(0, instance_provider_marker)
instance_provider_marker = instance_provider_marker.provides
provider_name = container.resolve_provider_name(instance_provider_marker)
provider = container.providers[provider_name]
for provided_instance in provided_instance_markers:
if isinstance(provided_instance, providers.ProvidedInstance):
provider = provider.provided
elif isinstance(provided_instance, providers.AttributeGetter):
provider = getattr(provider, provided_instance.name)
elif isinstance(provided_instance, providers.ItemGetter):
provider = provider[provided_instance.name]
elif isinstance(provided_instance, providers.MethodCaller):
provider = provider.call(
*provided_instance.args,
**provided_instance.kwargs,
)
return provider
def _resolve_container_config(container: AnyContainer) -> Optional[providers.Configuration]:
for provider in container.providers.values():
if isinstance(provider, providers.Configuration):
return provider
else:
return None
def _fetch_modules(package): def _fetch_modules(package):
modules = [] modules = []
for loader, module_name, is_pkg in pkgutil.walk_packages( for loader, module_name, is_pkg in pkgutil.walk_packages(

View File

@ -3,8 +3,15 @@ from dependency_injector import containers, providers
from .service import Service from .service import Service
class SubContainer(containers.DeclarativeContainer):
int_object = providers.Object(1)
class Container(containers.DeclarativeContainer): class Container(containers.DeclarativeContainer):
config = providers.Configuration() config = providers.Configuration()
service = providers.Factory(Service) service = providers.Factory(Service)
sub = providers.Container(SubContainer)

View File

@ -39,3 +39,11 @@ def test_provide_provider(service_provider: Callable[..., Service] = Provider[Co
def test_provided_instance(some_value: int = Provide[Container.service.provided.foo['bar'].call()]): def test_provided_instance(some_value: int = Provide[Container.service.provided.foo['bar'].call()]):
return some_value return some_value
def test_subcontainer_provider(some_value: int = Provide[Container.sub.int_object]):
return some_value
def test_config_invariant(some_value: int = Provide[Container.config.option[Container.config.switch]]):
return some_value

View File

@ -72,3 +72,25 @@ class WiringTest(unittest.TestCase):
with self.container.service.override(TestService()): with self.container.service.override(TestService()):
some_value = module.test_provided_instance() some_value = module.test_provided_instance()
self.assertEqual(some_value, 10) self.assertEqual(some_value, 10)
def test_subcontainer(self):
some_value = module.test_subcontainer_provider()
self.assertEqual(some_value, 1)
def test_config_invariant(self):
config = {
'option': {
'a': 1,
'b': 2,
},
'switch': 'a',
}
self.container.config.from_dict(config)
with self.container.config.switch.override('a'):
value_a = module.test_config_invariant()
self.assertEqual(value_a, 1)
with self.container.config.switch.override('b'):
value_b = module.test_config_invariant()
self.assertEqual(value_b, 2)