mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-02-07 07:00:49 +03:00
Wiring refactoring (#296)
* Refactor wiring * Add todos to wiring * Implement wiring of config invariant * Implement sub containers wiring + add tests * Add test for wiring config invariant
This commit is contained in:
parent
7f854548d6
commit
6182b8448a
File diff suppressed because it is too large
Load Diff
|
@ -186,20 +186,6 @@ class DynamicContainer(object):
|
||||||
for provider in six.itervalues(self.providers):
|
for provider in six.itervalues(self.providers):
|
||||||
provider.reset_override()
|
provider.reset_override()
|
||||||
|
|
||||||
def resolve_provider_name(self, provider_to_resolve):
|
|
||||||
"""Try to resolve provider name by its instance."""
|
|
||||||
if self.declarative_parent:
|
|
||||||
provider_name = self.declarative_parent.resolve_provider_name(provider_to_resolve)
|
|
||||||
if provider_name:
|
|
||||||
return provider_name
|
|
||||||
|
|
||||||
for provider_name, container_provider in self.providers.items():
|
|
||||||
if container_provider is provider_to_resolve:
|
|
||||||
return provider_name
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def wire(self, modules=None, packages=None):
|
def wire(self, modules=None, packages=None):
|
||||||
"""Wire container providers with provided packages and modules.
|
"""Wire container providers with provided packages and modules.
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -185,9 +185,9 @@ cdef class List(Provider):
|
||||||
|
|
||||||
|
|
||||||
cdef class Container(Provider):
|
cdef class Container(Provider):
|
||||||
cdef object container_cls
|
cdef object __container_cls
|
||||||
cdef dict overriding_providers
|
cdef dict __overriding_providers
|
||||||
cdef object container
|
cdef object __container
|
||||||
|
|
||||||
cpdef object _provide(self, tuple args, dict kwargs)
|
cpdef object _provide(self, tuple args, dict kwargs)
|
||||||
|
|
||||||
|
|
|
@ -54,6 +54,8 @@ class Object(Provider, Generic[T]):
|
||||||
class Delegate(Provider):
|
class Delegate(Provider):
|
||||||
def __init__(self, provides: Provider) -> None: ...
|
def __init__(self, provides: Provider) -> None: ...
|
||||||
def __call__(self, *args: Injection, **kwargs: Injection) -> Provider: ...
|
def __call__(self, *args: Injection, **kwargs: Injection) -> Provider: ...
|
||||||
|
@property
|
||||||
|
def provides(self) -> Provider: ...
|
||||||
|
|
||||||
|
|
||||||
class Dependency(Provider, Generic[T]):
|
class Dependency(Provider, Generic[T]):
|
||||||
|
@ -125,9 +127,11 @@ class ConfigurationOption(Provider):
|
||||||
def __init__(self, name: Tuple[str], root: Configuration) -> None: ...
|
def __init__(self, name: Tuple[str], root: Configuration) -> None: ...
|
||||||
def __call__(self, *args: Injection, **kwargs: Injection) -> Any: ...
|
def __call__(self, *args: Injection, **kwargs: Injection) -> Any: ...
|
||||||
def __getattr__(self, item: str) -> ConfigurationOption: ...
|
def __getattr__(self, item: str) -> ConfigurationOption: ...
|
||||||
def __getitem__(self, item: str) -> ConfigurationOption: ...
|
def __getitem__(self, item: Union[str, Provider]) -> ConfigurationOption: ...
|
||||||
|
@property
|
||||||
|
def root(self) -> Configuration: ...
|
||||||
def get_name(self) -> str: ...
|
def get_name(self) -> str: ...
|
||||||
def get_option_provider(self, selector: str) -> ConfigurationOption: ...
|
def get_name_segments(self) -> Tuple[Union[str, Provider]]: ...
|
||||||
def as_int(self) -> TypedConfigurationOption[int]: ...
|
def as_int(self) -> TypedConfigurationOption[int]: ...
|
||||||
def as_float(self) -> TypedConfigurationOption[float]: ...
|
def as_float(self) -> TypedConfigurationOption[float]: ...
|
||||||
def as_(self, callback: _Callable[..., T], *args: Injection, **kwargs: Injection) -> TypedConfigurationOption[T]: ...
|
def as_(self, callback: _Callable[..., T], *args: Injection, **kwargs: Injection) -> TypedConfigurationOption[T]: ...
|
||||||
|
@ -147,10 +151,9 @@ class Configuration(Object):
|
||||||
DEFAULT_NAME: str = 'config'
|
DEFAULT_NAME: str = 'config'
|
||||||
def __init__(self, name: str = DEFAULT_NAME, default: Optional[Any] = None) -> None: ...
|
def __init__(self, name: str = DEFAULT_NAME, default: Optional[Any] = None) -> None: ...
|
||||||
def __getattr__(self, item: str) -> ConfigurationOption: ...
|
def __getattr__(self, item: str) -> ConfigurationOption: ...
|
||||||
def __getitem__(self, item: str) -> ConfigurationOption: ...
|
def __getitem__(self, item: Union[str, Provider]) -> ConfigurationOption: ...
|
||||||
def get_name(self) -> str: ...
|
def get_name(self) -> str: ...
|
||||||
def get(self, selector: str) -> Any: ...
|
def get(self, selector: str) -> Any: ...
|
||||||
def get_option_provider(self, selector: str) -> ConfigurationOption: ...
|
|
||||||
def set(self, selector: str, value: Any) -> OverridingContext: ...
|
def set(self, selector: str, value: Any) -> OverridingContext: ...
|
||||||
def reset_cache(self) -> None: ...
|
def reset_cache(self) -> None: ...
|
||||||
def update(self, value: Any) -> None: ...
|
def update(self, value: Any) -> None: ...
|
||||||
|
@ -275,6 +278,8 @@ class Container(Provider):
|
||||||
def __init__(self, container_cls: Type[T], container: Optional[T] = None, **overriding_providers: Provider) -> None: ...
|
def __init__(self, container_cls: Type[T], container: Optional[T] = None, **overriding_providers: Provider) -> None: ...
|
||||||
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
|
def __call__(self, *args: Injection, **kwargs: Injection) -> T: ...
|
||||||
def __getattr__(self, name: str) -> Provider: ...
|
def __getattr__(self, name: str) -> Provider: ...
|
||||||
|
@property
|
||||||
|
def container(self) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
class Selector(Provider):
|
class Selector(Provider):
|
||||||
|
|
|
@ -1144,26 +1144,16 @@ cdef class ConfigurationOption(Provider):
|
||||||
segment() if is_provider(segment) else segment for segment in self.__name
|
segment() if is_provider(segment) else segment for segment in self.__name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def root(self):
|
||||||
|
return self.__root_ref()
|
||||||
|
|
||||||
def get_name(self):
|
def get_name(self):
|
||||||
root = self.__root_ref()
|
root = self.__root_ref()
|
||||||
return '.'.join((root.get_name(), self._get_self_name()))
|
return '.'.join((root.get_name(), self._get_self_name()))
|
||||||
|
|
||||||
def get_option_provider(self, selector):
|
def get_name_segments(self):
|
||||||
"""Return configuration option provider.
|
return self.__name
|
||||||
|
|
||||||
:param selector: Selector string, e.g. "option1.option2"
|
|
||||||
:type selector: str
|
|
||||||
|
|
||||||
:return: Option provider.
|
|
||||||
:rtype: :py:class:`ConfigurationOption`
|
|
||||||
"""
|
|
||||||
key, *other_keys = selector.split('.')
|
|
||||||
child = getattr(self, key)
|
|
||||||
|
|
||||||
if other_keys:
|
|
||||||
child = child.get_option_provider('.'.join(other_keys))
|
|
||||||
|
|
||||||
return child
|
|
||||||
|
|
||||||
def as_int(self):
|
def as_int(self):
|
||||||
return TypedConfigurationOption(int, self)
|
return TypedConfigurationOption(int, self)
|
||||||
|
@ -1386,23 +1376,6 @@ cdef class Configuration(Object):
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def get_option_provider(self, selector):
|
|
||||||
"""Return configuration option provider.
|
|
||||||
|
|
||||||
:param selector: Selector string, e.g. "option1.option2"
|
|
||||||
:type selector: str
|
|
||||||
|
|
||||||
:return: Option provider.
|
|
||||||
:rtype: :py:class:`ConfigurationOption`
|
|
||||||
"""
|
|
||||||
key, *other_keys = selector.split('.')
|
|
||||||
child = getattr(self, key)
|
|
||||||
|
|
||||||
if other_keys:
|
|
||||||
child = child.get_option_provider('.'.join(other_keys))
|
|
||||||
|
|
||||||
return child
|
|
||||||
|
|
||||||
def set(self, selector, value):
|
def set(self, selector, value):
|
||||||
"""Override configuration option.
|
"""Override configuration option.
|
||||||
|
|
||||||
|
@ -2504,13 +2477,13 @@ cdef class Container(Provider):
|
||||||
|
|
||||||
def __init__(self, container_cls, container=None, **overriding_providers):
|
def __init__(self, container_cls, container=None, **overriding_providers):
|
||||||
"""Initialize provider."""
|
"""Initialize provider."""
|
||||||
self.container_cls = container_cls
|
self.__container_cls = container_cls
|
||||||
self.overriding_providers = overriding_providers
|
self.__overriding_providers = overriding_providers
|
||||||
|
|
||||||
if container is None:
|
if container is None:
|
||||||
container = container_cls()
|
container = container_cls()
|
||||||
container.override_providers(**overriding_providers)
|
container.override_providers(**overriding_providers)
|
||||||
self.container = container
|
self.__container = container
|
||||||
|
|
||||||
super(Container, self).__init__()
|
super(Container, self).__init__()
|
||||||
|
|
||||||
|
@ -2521,9 +2494,9 @@ cdef class Container(Provider):
|
||||||
return copied
|
return copied
|
||||||
|
|
||||||
copied = self.__class__(
|
copied = self.__class__(
|
||||||
self.container_cls,
|
self.__container_cls,
|
||||||
deepcopy(self.container, memo),
|
deepcopy(self.__container, memo),
|
||||||
**deepcopy(self.overriding_providers, memo),
|
**deepcopy(self.__overriding_providers, memo),
|
||||||
)
|
)
|
||||||
|
|
||||||
return copied
|
return copied
|
||||||
|
@ -2535,7 +2508,11 @@ cdef class Container(Provider):
|
||||||
'\'{cls}\' object has no attribute '
|
'\'{cls}\' object has no attribute '
|
||||||
'\'{attribute_name}\''.format(cls=self.__class__.__name__,
|
'\'{attribute_name}\''.format(cls=self.__class__.__name__,
|
||||||
attribute_name=name))
|
attribute_name=name))
|
||||||
return getattr(self.container, name)
|
return getattr(self.__container, name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def container(self):
|
||||||
|
return self.__container
|
||||||
|
|
||||||
def override(self, provider):
|
def override(self, provider):
|
||||||
"""Override provider with another provider."""
|
"""Override provider with another provider."""
|
||||||
|
@ -2543,7 +2520,7 @@ cdef class Container(Provider):
|
||||||
|
|
||||||
cpdef object _provide(self, tuple args, dict kwargs):
|
cpdef object _provide(self, tuple args, dict kwargs):
|
||||||
"""Return single instance."""
|
"""Return single instance."""
|
||||||
return self.container
|
return self.__container
|
||||||
|
|
||||||
|
|
||||||
cdef class Selector(Provider):
|
cdef class Selector(Provider):
|
||||||
|
|
|
@ -5,7 +5,7 @@ import inspect
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import sys
|
import sys
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar
|
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar, cast
|
||||||
|
|
||||||
if sys.version_info < (3, 7):
|
if sys.version_info < (3, 7):
|
||||||
from typing import GenericMeta
|
from typing import GenericMeta
|
||||||
|
@ -17,8 +17,114 @@ else:
|
||||||
from . import providers
|
from . import providers
|
||||||
|
|
||||||
|
|
||||||
AnyContainer = Any
|
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
|
AnyContainer = Any
|
||||||
|
|
||||||
|
|
||||||
|
class ProvidersMap:
|
||||||
|
|
||||||
|
def __init__(self, container):
|
||||||
|
self._container = container
|
||||||
|
self._map = self._create_providers_map(
|
||||||
|
current_providers=container.providers,
|
||||||
|
original_providers=container.declarative_parent.providers,
|
||||||
|
)
|
||||||
|
|
||||||
|
def resolve_provider(self, provider: providers.Provider) -> providers.Provider:
|
||||||
|
if isinstance(provider, providers.Delegate):
|
||||||
|
return self._resolve_delegate(provider)
|
||||||
|
elif isinstance(provider, (
|
||||||
|
providers.ProvidedInstance,
|
||||||
|
providers.AttributeGetter,
|
||||||
|
providers.ItemGetter,
|
||||||
|
providers.MethodCaller,
|
||||||
|
)):
|
||||||
|
return self._resolve_provided_instance(provider)
|
||||||
|
elif isinstance(provider, providers.ConfigurationOption):
|
||||||
|
return self._resolve_config_option(provider)
|
||||||
|
elif isinstance(provider, providers.TypedConfigurationOption):
|
||||||
|
return self._resolve_config_option(provider.option, as_=provider.provides)
|
||||||
|
else:
|
||||||
|
return self._resolve_provider(provider)
|
||||||
|
|
||||||
|
def _resolve_delegate(self, original: providers.Delegate) -> providers.Provider:
|
||||||
|
return self._resolve_provider(original.provides)
|
||||||
|
|
||||||
|
def _resolve_provided_instance(self, original: providers.Provider) -> providers.Provider:
|
||||||
|
modifiers = []
|
||||||
|
while isinstance(original, (
|
||||||
|
providers.ProvidedInstance,
|
||||||
|
providers.AttributeGetter,
|
||||||
|
providers.ItemGetter,
|
||||||
|
providers.MethodCaller,
|
||||||
|
)):
|
||||||
|
modifiers.insert(0, original)
|
||||||
|
original = original.provides
|
||||||
|
|
||||||
|
new = self._resolve_provider(original)
|
||||||
|
|
||||||
|
for modifier in modifiers:
|
||||||
|
if isinstance(modifier, providers.ProvidedInstance):
|
||||||
|
new = new.provided
|
||||||
|
elif isinstance(modifier, providers.AttributeGetter):
|
||||||
|
new = getattr(new, modifier.name)
|
||||||
|
elif isinstance(modifier, providers.ItemGetter):
|
||||||
|
new = new[modifier.name]
|
||||||
|
elif isinstance(modifier, providers.MethodCaller):
|
||||||
|
new = new.call(
|
||||||
|
*modifier.args,
|
||||||
|
**modifier.kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return new
|
||||||
|
|
||||||
|
def _resolve_config_option(
|
||||||
|
self,
|
||||||
|
original: providers.ConfigurationOption,
|
||||||
|
as_: Any = None,
|
||||||
|
) -> providers.Provider:
|
||||||
|
original_root = original.root
|
||||||
|
new = self._resolve_provider(original_root)
|
||||||
|
new = cast(providers.Configuration, new)
|
||||||
|
|
||||||
|
for segment in original.get_name_segments():
|
||||||
|
if providers.is_provider(segment):
|
||||||
|
segment = self.resolve_provider(segment)
|
||||||
|
new = new[segment]
|
||||||
|
else:
|
||||||
|
new = getattr(new, segment)
|
||||||
|
|
||||||
|
if as_:
|
||||||
|
new = new.as_(as_)
|
||||||
|
|
||||||
|
return new
|
||||||
|
|
||||||
|
def _resolve_provider(self, original: providers.Provider) -> providers.Provider:
|
||||||
|
try:
|
||||||
|
return self._map[original]
|
||||||
|
except KeyError:
|
||||||
|
raise Exception('Unable to resolve original provider')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _create_providers_map(
|
||||||
|
cls,
|
||||||
|
current_providers: Dict[str, providers.Provider],
|
||||||
|
original_providers: Dict[str, providers.Provider],
|
||||||
|
) -> Dict[providers.Provider, providers.Provider]:
|
||||||
|
providers_map = {}
|
||||||
|
for provider_name, current_provider in current_providers.items():
|
||||||
|
original_provider = original_providers[provider_name]
|
||||||
|
providers_map[original_provider] = current_provider
|
||||||
|
|
||||||
|
if isinstance(current_provider, providers.Container) \
|
||||||
|
and isinstance(original_provider, providers.Container):
|
||||||
|
subcontainer_map = cls._create_providers_map(
|
||||||
|
current_providers=current_provider.container.providers,
|
||||||
|
original_providers=original_provider.container.providers,
|
||||||
|
)
|
||||||
|
providers_map.update(subcontainer_map)
|
||||||
|
|
||||||
|
return providers_map
|
||||||
|
|
||||||
|
|
||||||
def wire(
|
def wire(
|
||||||
|
@ -28,6 +134,7 @@ def wire(
|
||||||
packages: Optional[Iterable[ModuleType]] = None,
|
packages: Optional[Iterable[ModuleType]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Wire container providers with provided packages and modules."""
|
"""Wire container providers with provided packages and modules."""
|
||||||
|
# TODO: Add protection to only wire declarative container instances
|
||||||
if not modules:
|
if not modules:
|
||||||
modules = []
|
modules = []
|
||||||
|
|
||||||
|
@ -35,12 +142,14 @@ def wire(
|
||||||
for package in packages:
|
for package in packages:
|
||||||
modules.extend(_fetch_modules(package))
|
modules.extend(_fetch_modules(package))
|
||||||
|
|
||||||
|
providers_map = ProvidersMap(container)
|
||||||
|
|
||||||
for module in modules:
|
for module in modules:
|
||||||
for name, member in inspect.getmembers(module):
|
for name, member in inspect.getmembers(module):
|
||||||
if inspect.isfunction(member):
|
if inspect.isfunction(member):
|
||||||
_patch_fn(module, name, member, container)
|
_patch_fn(module, name, member, providers_map)
|
||||||
elif inspect.isclass(member):
|
elif inspect.isclass(member):
|
||||||
_patch_cls(member, container)
|
_patch_cls(member, providers_map)
|
||||||
|
|
||||||
|
|
||||||
def unwire(
|
def unwire(
|
||||||
|
@ -66,12 +175,12 @@ def unwire(
|
||||||
|
|
||||||
def _patch_cls(
|
def _patch_cls(
|
||||||
cls: Type[Any],
|
cls: Type[Any],
|
||||||
container: AnyContainer,
|
providers_map: ProvidersMap,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not hasattr(cls, '__init__'):
|
if not hasattr(cls, '__init__'):
|
||||||
return
|
return
|
||||||
init_method = getattr(cls, '__init__')
|
init_method = getattr(cls, '__init__')
|
||||||
injections = _resolve_injections(init_method, container)
|
injections = _resolve_injections(init_method, providers_map)
|
||||||
if not injections:
|
if not injections:
|
||||||
return
|
return
|
||||||
setattr(cls, '__init__', _patch_with_injections(init_method, injections))
|
setattr(cls, '__init__', _patch_with_injections(init_method, injections))
|
||||||
|
@ -90,9 +199,9 @@ def _patch_fn(
|
||||||
module: ModuleType,
|
module: ModuleType,
|
||||||
name: str,
|
name: str,
|
||||||
fn: Callable[..., Any],
|
fn: Callable[..., Any],
|
||||||
container: AnyContainer,
|
providers_map: ProvidersMap,
|
||||||
) -> None:
|
) -> None:
|
||||||
injections = _resolve_injections(fn, container)
|
injections = _resolve_injections(fn, providers_map)
|
||||||
if not injections:
|
if not injections:
|
||||||
return
|
return
|
||||||
setattr(module, name, _patch_with_injections(fn, injections))
|
setattr(module, name, _patch_with_injections(fn, injections))
|
||||||
|
@ -108,9 +217,7 @@ def _unpatch_fn(
|
||||||
setattr(module, name, _get_original_from_patched(fn))
|
setattr(module, name, _get_original_from_patched(fn))
|
||||||
|
|
||||||
|
|
||||||
def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict[str, Any]: # noqa
|
def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Dict[str, Any]: # noqa
|
||||||
config = _resolve_container_config(container)
|
|
||||||
|
|
||||||
signature = inspect.signature(fn)
|
signature = inspect.signature(fn)
|
||||||
|
|
||||||
injections = {}
|
injections = {}
|
||||||
|
@ -119,35 +226,7 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
|
||||||
continue
|
continue
|
||||||
marker = parameter.default
|
marker = parameter.default
|
||||||
|
|
||||||
if config and isinstance(marker.provider, providers.ConfigurationOption):
|
provider = providers_map.resolve_provider(marker.provider)
|
||||||
provider = _prepare_config_injection(marker.provider.get_name(), config)
|
|
||||||
elif config and isinstance(marker.provider, providers.TypedConfigurationOption):
|
|
||||||
provider = _prepare_config_injection(
|
|
||||||
marker.provider.option.get_name(),
|
|
||||||
config,
|
|
||||||
marker.provider.provides,
|
|
||||||
)
|
|
||||||
elif isinstance(marker.provider, providers.Delegate):
|
|
||||||
provider_name = container.resolve_provider_name(marker.provider.provides)
|
|
||||||
provider = container.providers[provider_name]
|
|
||||||
elif isinstance(marker.provider, (
|
|
||||||
providers.ProvidedInstance,
|
|
||||||
providers.AttributeGetter,
|
|
||||||
providers.ItemGetter,
|
|
||||||
providers.MethodCaller,
|
|
||||||
)):
|
|
||||||
provider = _prepare_provided_instance_injection(
|
|
||||||
marker.provider,
|
|
||||||
container,
|
|
||||||
)
|
|
||||||
elif isinstance(marker.provider, providers.Provider):
|
|
||||||
provider_name = container.resolve_provider_name(marker.provider)
|
|
||||||
if not provider_name:
|
|
||||||
continue
|
|
||||||
provider = container.providers[provider_name]
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(marker, Provide):
|
if isinstance(marker, Provide):
|
||||||
injections[parameter_name] = provider
|
injections[parameter_name] = provider
|
||||||
elif isinstance(marker, Provider):
|
elif isinstance(marker, Provider):
|
||||||
|
@ -156,61 +235,6 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
|
||||||
return injections
|
return injections
|
||||||
|
|
||||||
|
|
||||||
def _prepare_config_injection(
|
|
||||||
option_name: str,
|
|
||||||
config: providers.Configuration,
|
|
||||||
as_: Any = None,
|
|
||||||
) -> providers.Provider:
|
|
||||||
_, *parts = option_name.split('.')
|
|
||||||
relative_option_name = '.'.join(parts)
|
|
||||||
provider = config.get_option_provider(relative_option_name)
|
|
||||||
if as_:
|
|
||||||
return provider.as_(as_)
|
|
||||||
return provider
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_provided_instance_injection(
|
|
||||||
current_provider: providers.Provider,
|
|
||||||
container: AnyContainer,
|
|
||||||
) -> providers.Provider:
|
|
||||||
provided_instance_markers = []
|
|
||||||
instance_provider_marker = current_provider
|
|
||||||
while isinstance(instance_provider_marker, (
|
|
||||||
providers.ProvidedInstance,
|
|
||||||
providers.AttributeGetter,
|
|
||||||
providers.ItemGetter,
|
|
||||||
providers.MethodCaller,
|
|
||||||
)):
|
|
||||||
provided_instance_markers.insert(0, instance_provider_marker)
|
|
||||||
instance_provider_marker = instance_provider_marker.provides
|
|
||||||
|
|
||||||
provider_name = container.resolve_provider_name(instance_provider_marker)
|
|
||||||
provider = container.providers[provider_name]
|
|
||||||
|
|
||||||
for provided_instance in provided_instance_markers:
|
|
||||||
if isinstance(provided_instance, providers.ProvidedInstance):
|
|
||||||
provider = provider.provided
|
|
||||||
elif isinstance(provided_instance, providers.AttributeGetter):
|
|
||||||
provider = getattr(provider, provided_instance.name)
|
|
||||||
elif isinstance(provided_instance, providers.ItemGetter):
|
|
||||||
provider = provider[provided_instance.name]
|
|
||||||
elif isinstance(provided_instance, providers.MethodCaller):
|
|
||||||
provider = provider.call(
|
|
||||||
*provided_instance.args,
|
|
||||||
**provided_instance.kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return provider
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_container_config(container: AnyContainer) -> Optional[providers.Configuration]:
|
|
||||||
for provider in container.providers.values():
|
|
||||||
if isinstance(provider, providers.Configuration):
|
|
||||||
return provider
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_modules(package):
|
def _fetch_modules(package):
|
||||||
modules = []
|
modules = []
|
||||||
for loader, module_name, is_pkg in pkgutil.walk_packages(
|
for loader, module_name, is_pkg in pkgutil.walk_packages(
|
||||||
|
|
|
@ -3,8 +3,15 @@ from dependency_injector import containers, providers
|
||||||
from .service import Service
|
from .service import Service
|
||||||
|
|
||||||
|
|
||||||
|
class SubContainer(containers.DeclarativeContainer):
|
||||||
|
|
||||||
|
int_object = providers.Object(1)
|
||||||
|
|
||||||
|
|
||||||
class Container(containers.DeclarativeContainer):
|
class Container(containers.DeclarativeContainer):
|
||||||
|
|
||||||
config = providers.Configuration()
|
config = providers.Configuration()
|
||||||
|
|
||||||
service = providers.Factory(Service)
|
service = providers.Factory(Service)
|
||||||
|
|
||||||
|
sub = providers.Container(SubContainer)
|
||||||
|
|
|
@ -39,3 +39,11 @@ def test_provide_provider(service_provider: Callable[..., Service] = Provider[Co
|
||||||
|
|
||||||
def test_provided_instance(some_value: int = Provide[Container.service.provided.foo['bar'].call()]):
|
def test_provided_instance(some_value: int = Provide[Container.service.provided.foo['bar'].call()]):
|
||||||
return some_value
|
return some_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_subcontainer_provider(some_value: int = Provide[Container.sub.int_object]):
|
||||||
|
return some_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_invariant(some_value: int = Provide[Container.config.option[Container.config.switch]]):
|
||||||
|
return some_value
|
||||||
|
|
|
@ -72,3 +72,25 @@ class WiringTest(unittest.TestCase):
|
||||||
with self.container.service.override(TestService()):
|
with self.container.service.override(TestService()):
|
||||||
some_value = module.test_provided_instance()
|
some_value = module.test_provided_instance()
|
||||||
self.assertEqual(some_value, 10)
|
self.assertEqual(some_value, 10)
|
||||||
|
|
||||||
|
def test_subcontainer(self):
|
||||||
|
some_value = module.test_subcontainer_provider()
|
||||||
|
self.assertEqual(some_value, 1)
|
||||||
|
|
||||||
|
def test_config_invariant(self):
|
||||||
|
config = {
|
||||||
|
'option': {
|
||||||
|
'a': 1,
|
||||||
|
'b': 2,
|
||||||
|
},
|
||||||
|
'switch': 'a',
|
||||||
|
}
|
||||||
|
self.container.config.from_dict(config)
|
||||||
|
|
||||||
|
with self.container.config.switch.override('a'):
|
||||||
|
value_a = module.test_config_invariant()
|
||||||
|
self.assertEqual(value_a, 1)
|
||||||
|
|
||||||
|
with self.container.config.switch.override('b'):
|
||||||
|
value_b = module.test_config_invariant()
|
||||||
|
self.assertEqual(value_b, 2)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user