mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-02-12 01:20:51 +03:00
Add wiring (#294)
* Add wiring module * Fix code style * Fix package test * Add version fix * Try spike for 3.6 * Try another fix with metaclass * Downsample required version to 3.6 * Introduce concept with annotations * Fix bugs * Add debug message * Add extra tests * Add extra debugging * Update config resolving * Remove 3.6 generic meta fix * Fix Flake8 * Add spike for 3.6 * Add Python 3.6 spike * Add unwire functionality * Add support of corouting functions
This commit is contained in:
parent
53b7ad0275
commit
af7364e062
File diff suppressed because it is too large
Load Diff
|
@ -1,4 +1,5 @@
|
||||||
from typing import Type, Dict, Tuple, Optional, Any, Union, ClassVar, Callable as _Callable
|
from types import ModuleType
|
||||||
|
from typing import Type, Dict, Tuple, Optional, Any, Union, ClassVar, Callable as _Callable, Iterable
|
||||||
|
|
||||||
from .providers import Provider
|
from .providers import Provider
|
||||||
|
|
||||||
|
@ -16,6 +17,8 @@ class Container:
|
||||||
def override_providers(self, **overriding_providers: Provider) -> None: ...
|
def override_providers(self, **overriding_providers: Provider) -> None: ...
|
||||||
def reset_last_overriding(self) -> None: ...
|
def reset_last_overriding(self) -> None: ...
|
||||||
def reset_override(self) -> None: ...
|
def reset_override(self) -> None: ...
|
||||||
|
def resolve_provider_name(self, provider_to_resolve: Provider) -> Optional[str]: ...
|
||||||
|
def wire(self, modules: Optional[Iterable[ModuleType]] = None, packages: Optional[Iterable[ModuleType]] = None) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class DynamicContainer(Container): ...
|
class DynamicContainer(Container): ...
|
||||||
|
|
|
@ -1,15 +1,26 @@
|
||||||
"""Containers module."""
|
"""Containers module."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from .errors import Error
|
from .errors import Error
|
||||||
|
|
||||||
from .providers cimport (
|
from .providers cimport (
|
||||||
Provider,
|
Provider,
|
||||||
deepcopy,
|
deepcopy,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if sys.version_info[:2] >= (3, 6):
|
||||||
|
from .wiring import wire, unwire
|
||||||
|
else:
|
||||||
|
def wire(*args, **kwargs):
|
||||||
|
raise NotADirectoryError('Wiring requires Python 3.6 or above')
|
||||||
|
|
||||||
|
def unwire(*args, **kwargs):
|
||||||
|
raise NotADirectoryError('Wiring requires Python 3.6 or above')
|
||||||
|
|
||||||
|
|
||||||
class DynamicContainer(object):
|
class DynamicContainer(object):
|
||||||
"""Dynamic inversion of control container.
|
"""Dynamic inversion of control container.
|
||||||
|
|
||||||
|
@ -47,8 +58,11 @@ class DynamicContainer(object):
|
||||||
:rtype: None
|
:rtype: None
|
||||||
"""
|
"""
|
||||||
self.provider_type = Provider
|
self.provider_type = Provider
|
||||||
self.providers = dict()
|
self.providers = {}
|
||||||
self.overridden = tuple()
|
self.overridden = tuple()
|
||||||
|
self.declarative_parent = None
|
||||||
|
self.wired_to_modules = []
|
||||||
|
self.wired_to_packages = []
|
||||||
super(DynamicContainer, self).__init__()
|
super(DynamicContainer, self).__init__()
|
||||||
|
|
||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self, memo):
|
||||||
|
@ -60,6 +74,7 @@ class DynamicContainer(object):
|
||||||
copied = self.__class__()
|
copied = self.__class__()
|
||||||
copied.provider_type = Provider
|
copied.provider_type = Provider
|
||||||
copied.overridden = deepcopy(self.overridden, memo)
|
copied.overridden = deepcopy(self.overridden, memo)
|
||||||
|
copied.declarative_parent = self.declarative_parent
|
||||||
|
|
||||||
for name, provider in deepcopy(self.providers, memo).items():
|
for name, provider in deepcopy(self.providers, memo).items():
|
||||||
setattr(copied, name, provider)
|
setattr(copied, name, provider)
|
||||||
|
@ -171,6 +186,48 @@ class DynamicContainer(object):
|
||||||
for provider in six.itervalues(self.providers):
|
for provider in six.itervalues(self.providers):
|
||||||
provider.reset_override()
|
provider.reset_override()
|
||||||
|
|
||||||
|
def resolve_provider_name(self, provider_to_resolve):
|
||||||
|
"""Try to resolve provider name by its instance."""
|
||||||
|
if self.declarative_parent:
|
||||||
|
provider_name = self.declarative_parent.resolve_provider_name(provider_to_resolve)
|
||||||
|
if provider_name:
|
||||||
|
return provider_name
|
||||||
|
|
||||||
|
for provider_name, container_provider in self.providers.items():
|
||||||
|
if container_provider is provider_to_resolve:
|
||||||
|
return provider_name
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def wire(self, modules=None, packages=None):
|
||||||
|
"""Wire container providers with provided packages and modules.
|
||||||
|
|
||||||
|
:rtype: None
|
||||||
|
"""
|
||||||
|
wire(
|
||||||
|
container=self,
|
||||||
|
modules=modules,
|
||||||
|
packages=packages,
|
||||||
|
)
|
||||||
|
|
||||||
|
if modules:
|
||||||
|
self.wired_to_modules.extend(modules)
|
||||||
|
|
||||||
|
if packages:
|
||||||
|
self.wired_to_packages.extend(packages)
|
||||||
|
|
||||||
|
def unwire(self):
|
||||||
|
"""Unwire container providers from previously wired packages and modules."""
|
||||||
|
unwire(
|
||||||
|
modules=self.wired_to_modules,
|
||||||
|
packages=self.wired_to_packages,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.wired_to_modules.clear()
|
||||||
|
self.wired_to_packages.clear()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DeclarativeContainerMetaClass(type):
|
class DeclarativeContainerMetaClass(type):
|
||||||
"""Declarative inversion of control container meta class."""
|
"""Declarative inversion of control container meta class."""
|
||||||
|
@ -310,6 +367,7 @@ class DeclarativeContainer(object):
|
||||||
"""
|
"""
|
||||||
container = cls.instance_type()
|
container = cls.instance_type()
|
||||||
container.provider_type = cls.provider_type
|
container.provider_type = cls.provider_type
|
||||||
|
container.declarative_parent = cls
|
||||||
container.set_providers(**deepcopy(cls.providers))
|
container.set_providers(**deepcopy(cls.providers))
|
||||||
container.override_providers(**overriding_providers)
|
container.override_providers(**overriding_providers)
|
||||||
return container
|
return container
|
||||||
|
@ -363,6 +421,27 @@ class DeclarativeContainer(object):
|
||||||
for provider in six.itervalues(cls.providers):
|
for provider in six.itervalues(cls.providers):
|
||||||
provider.reset_override()
|
provider.reset_override()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def resolve_provider_name(cls, provider_to_resolve):
|
||||||
|
"""Try to resolve provider name by its instance."""
|
||||||
|
for provider_name, container_provider in cls.providers.items():
|
||||||
|
if container_provider is provider_to_resolve:
|
||||||
|
return provider_name
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def wire(cls, modules=None, packages=None):
|
||||||
|
"""Wire container providers with provided packages and modules by name.
|
||||||
|
|
||||||
|
:rtype: None
|
||||||
|
"""
|
||||||
|
wire(
|
||||||
|
container=cls,
|
||||||
|
modules=modules,
|
||||||
|
packages=packages,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def override(object container):
|
def override(object container):
|
||||||
""":py:class:`DeclarativeContainer` overriding decorator.
|
""":py:class:`DeclarativeContainer` overriding decorator.
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -93,6 +93,10 @@ cdef class ConfigurationOption(Provider):
|
||||||
cdef object __cache
|
cdef object __cache
|
||||||
|
|
||||||
|
|
||||||
|
cdef class TypedConfigurationOption(Callable):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
cdef class Configuration(Object):
|
cdef class Configuration(Object):
|
||||||
cdef str __name
|
cdef str __name
|
||||||
cdef dict __children
|
cdef dict __children
|
||||||
|
|
|
@ -122,14 +122,15 @@ class CoroutineDelegate(Delegate):
|
||||||
|
|
||||||
class ConfigurationOption(Provider):
|
class ConfigurationOption(Provider):
|
||||||
UNDEFINED: object
|
UNDEFINED: object
|
||||||
def __init__(self, name: str, root: Configuration) -> None: ...
|
def __init__(self, name: Tuple[str], root: Configuration) -> None: ...
|
||||||
def __call__(self, *args: Injection, **kwargs: Injection) -> Any: ...
|
def __call__(self, *args: Injection, **kwargs: Injection) -> Any: ...
|
||||||
def __getattr__(self, item: str) -> ConfigurationOption: ...
|
def __getattr__(self, item: str) -> ConfigurationOption: ...
|
||||||
def __getitem__(self, item: str) -> ConfigurationOption: ...
|
def __getitem__(self, item: str) -> ConfigurationOption: ...
|
||||||
def get_name(self) -> str: ...
|
def get_name(self) -> str: ...
|
||||||
def as_int(self) -> Callable[int]: ...
|
def get_option_provider(self, selector: str) -> ConfigurationOption: ...
|
||||||
def as_float(self) -> Callable[float]: ...
|
def as_int(self) -> TypedConfigurationOption[int]: ...
|
||||||
def as_(self, callback: _Callable[..., T], *args: Injection, **kwargs: Injection) -> Callable[T]: ...
|
def as_float(self) -> TypedConfigurationOption[float]: ...
|
||||||
|
def as_(self, callback: _Callable[..., T], *args: Injection, **kwargs: Injection) -> TypedConfigurationOption[T]: ...
|
||||||
def update(self, value: Any) -> None: ...
|
def update(self, value: Any) -> None: ...
|
||||||
def from_ini(self, filepath: str) -> None: ...
|
def from_ini(self, filepath: str) -> None: ...
|
||||||
def from_yaml(self, filepath: str) -> None: ...
|
def from_yaml(self, filepath: str) -> None: ...
|
||||||
|
@ -137,6 +138,11 @@ class ConfigurationOption(Provider):
|
||||||
def from_env(self, name: str, default: Optional[Any] = None) -> None: ...
|
def from_env(self, name: str, default: Optional[Any] = None) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class TypedConfigurationOption(Callable[T]):
|
||||||
|
@property
|
||||||
|
def option(self) -> ConfigurationOption: ...
|
||||||
|
|
||||||
|
|
||||||
class Configuration(Object):
|
class Configuration(Object):
|
||||||
DEFAULT_NAME: str = 'config'
|
DEFAULT_NAME: str = 'config'
|
||||||
def __init__(self, name: str = DEFAULT_NAME, default: Optional[Any] = None) -> None: ...
|
def __init__(self, name: str = DEFAULT_NAME, default: Optional[Any] = None) -> None: ...
|
||||||
|
@ -144,6 +150,7 @@ class Configuration(Object):
|
||||||
def __getitem__(self, item: str) -> ConfigurationOption: ...
|
def __getitem__(self, item: str) -> ConfigurationOption: ...
|
||||||
def get_name(self) -> str: ...
|
def get_name(self) -> str: ...
|
||||||
def get(self, selector: str) -> Any: ...
|
def get(self, selector: str) -> Any: ...
|
||||||
|
def get_option_provider(self, selector: str) -> ConfigurationOption: ...
|
||||||
def set(self, selector: str, value: Any) -> OverridingContext: ...
|
def set(self, selector: str, value: Any) -> OverridingContext: ...
|
||||||
def reset_cache(self) -> None: ...
|
def reset_cache(self) -> None: ...
|
||||||
def update(self, value: Any) -> None: ...
|
def update(self, value: Any) -> None: ...
|
||||||
|
|
|
@ -1143,14 +1143,31 @@ cdef class ConfigurationOption(Provider):
|
||||||
root = self.__root_ref()
|
root = self.__root_ref()
|
||||||
return '.'.join((root.get_name(), self._get_self_name()))
|
return '.'.join((root.get_name(), self._get_self_name()))
|
||||||
|
|
||||||
|
def get_option_provider(self, selector):
|
||||||
|
"""Return configuration option provider.
|
||||||
|
|
||||||
|
:param selector: Selector string, e.g. "option1.option2"
|
||||||
|
:type selector: str
|
||||||
|
|
||||||
|
:return: Option provider.
|
||||||
|
:rtype: :py:class:`ConfigurationOption`
|
||||||
|
"""
|
||||||
|
key, *other_keys = selector.split('.')
|
||||||
|
child = getattr(self, key)
|
||||||
|
|
||||||
|
if other_keys:
|
||||||
|
child = child.get_option_provider('.'.join(other_keys))
|
||||||
|
|
||||||
|
return child
|
||||||
|
|
||||||
def as_int(self):
|
def as_int(self):
|
||||||
return Callable(int, self)
|
return TypedConfigurationOption(int, self)
|
||||||
|
|
||||||
def as_float(self):
|
def as_float(self):
|
||||||
return Callable(float, self)
|
return TypedConfigurationOption(float, self)
|
||||||
|
|
||||||
def as_(self, callback, *args, **kwargs):
|
def as_(self, callback, *args, **kwargs):
|
||||||
return Callable(callback, self, *args, **kwargs)
|
return TypedConfigurationOption(callback, self, *args, **kwargs)
|
||||||
|
|
||||||
def override(self, value):
|
def override(self, value):
|
||||||
if isinstance(value, Provider):
|
if isinstance(value, Provider):
|
||||||
|
@ -1262,6 +1279,13 @@ cdef class ConfigurationOption(Provider):
|
||||||
self.override(value)
|
self.override(value)
|
||||||
|
|
||||||
|
|
||||||
|
cdef class TypedConfigurationOption(Callable):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def option(self):
|
||||||
|
return self.args[0]
|
||||||
|
|
||||||
|
|
||||||
cdef class Configuration(Object):
|
cdef class Configuration(Object):
|
||||||
"""Configuration provider provides configuration options to the other providers.
|
"""Configuration provider provides configuration options to the other providers.
|
||||||
|
|
||||||
|
@ -1357,6 +1381,23 @@ cdef class Configuration(Object):
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
def get_option_provider(self, selector):
|
||||||
|
"""Return configuration option provider.
|
||||||
|
|
||||||
|
:param selector: Selector string, e.g. "option1.option2"
|
||||||
|
:type selector: str
|
||||||
|
|
||||||
|
:return: Option provider.
|
||||||
|
:rtype: :py:class:`ConfigurationOption`
|
||||||
|
"""
|
||||||
|
key, *other_keys = selector.split('.')
|
||||||
|
child = getattr(self, key)
|
||||||
|
|
||||||
|
if other_keys:
|
||||||
|
child = child.get_option_provider('.'.join(other_keys))
|
||||||
|
|
||||||
|
return child
|
||||||
|
|
||||||
def set(self, selector, value):
|
def set(self, selector, value):
|
||||||
"""Override configuration option.
|
"""Override configuration option.
|
||||||
|
|
||||||
|
|
235
src/dependency_injector/wiring.py
Normal file
235
src/dependency_injector/wiring.py
Normal file
|
@ -0,0 +1,235 @@
|
||||||
|
"""Wiring module."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import inspect
|
||||||
|
import pkgutil
|
||||||
|
import sys
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar
|
||||||
|
|
||||||
|
if sys.version_info < (3, 7):
|
||||||
|
from typing import GenericMeta
|
||||||
|
else:
|
||||||
|
class GenericMeta(type):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
from . import providers
|
||||||
|
|
||||||
|
|
||||||
|
AnyContainer = Any
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
|
def wire(
|
||||||
|
container: AnyContainer,
|
||||||
|
*,
|
||||||
|
modules: Optional[Iterable[ModuleType]] = None,
|
||||||
|
packages: Optional[Iterable[ModuleType]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Wire container providers with provided packages and modules."""
|
||||||
|
if not modules:
|
||||||
|
modules = []
|
||||||
|
|
||||||
|
if packages:
|
||||||
|
for package in packages:
|
||||||
|
modules.extend(_fetch_modules(package))
|
||||||
|
|
||||||
|
for module in modules:
|
||||||
|
for name, member in inspect.getmembers(module):
|
||||||
|
if inspect.isfunction(member):
|
||||||
|
_patch_fn(module, name, member, container)
|
||||||
|
elif inspect.isclass(member):
|
||||||
|
_patch_cls(member, container)
|
||||||
|
|
||||||
|
|
||||||
|
def unwire(
|
||||||
|
*,
|
||||||
|
modules: Optional[Iterable[ModuleType]] = None,
|
||||||
|
packages: Optional[Iterable[ModuleType]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Wire provided packages and modules with previous wired providers."""
|
||||||
|
if not modules:
|
||||||
|
modules = []
|
||||||
|
|
||||||
|
if packages:
|
||||||
|
for package in packages:
|
||||||
|
modules.extend(_fetch_modules(package))
|
||||||
|
|
||||||
|
for module in modules:
|
||||||
|
for name, member in inspect.getmembers(module):
|
||||||
|
if inspect.isfunction(member):
|
||||||
|
_unpatch_fn(module, name, member)
|
||||||
|
elif inspect.isclass(member):
|
||||||
|
_unpatch_cls(member,)
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_cls(
|
||||||
|
cls: Type[Any],
|
||||||
|
container: AnyContainer,
|
||||||
|
) -> None:
|
||||||
|
if not hasattr(cls, '__init__'):
|
||||||
|
return
|
||||||
|
init_method = getattr(cls, '__init__')
|
||||||
|
injections = _resolve_injections(init_method, container)
|
||||||
|
if not injections:
|
||||||
|
return
|
||||||
|
setattr(cls, '__init__', _patch_with_injections(init_method, injections))
|
||||||
|
|
||||||
|
|
||||||
|
def _unpatch_cls(cls: Type[Any]) -> None:
|
||||||
|
if not hasattr(cls, '__init__'):
|
||||||
|
return
|
||||||
|
init_method = getattr(cls, '__init__')
|
||||||
|
if not _is_patched(init_method):
|
||||||
|
return
|
||||||
|
setattr(cls, '__init__', _get_original_from_patched(init_method))
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_fn(
|
||||||
|
module: ModuleType,
|
||||||
|
name: str,
|
||||||
|
fn: Callable[..., Any],
|
||||||
|
container: AnyContainer,
|
||||||
|
) -> None:
|
||||||
|
injections = _resolve_injections(fn, container)
|
||||||
|
if not injections:
|
||||||
|
return
|
||||||
|
setattr(module, name, _patch_with_injections(fn, injections))
|
||||||
|
|
||||||
|
|
||||||
|
def _unpatch_fn(
|
||||||
|
module: ModuleType,
|
||||||
|
name: str,
|
||||||
|
fn: Callable[..., Any],
|
||||||
|
) -> None:
|
||||||
|
if not _is_patched(fn):
|
||||||
|
return
|
||||||
|
setattr(module, name, _get_original_from_patched(fn))
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict[str, Any]:
|
||||||
|
config = _resolve_container_config(container)
|
||||||
|
|
||||||
|
signature = inspect.signature(fn)
|
||||||
|
|
||||||
|
injections = {}
|
||||||
|
for parameter_name, parameter in signature.parameters.items():
|
||||||
|
if not isinstance(parameter.default, _Marker):
|
||||||
|
continue
|
||||||
|
marker = parameter.default
|
||||||
|
|
||||||
|
if config and isinstance(marker.provider, providers.ConfigurationOption):
|
||||||
|
provider = _prepare_config_injection(marker.provider.get_name(), config)
|
||||||
|
elif config and isinstance(marker.provider, providers.TypedConfigurationOption):
|
||||||
|
provider = _prepare_config_injection(
|
||||||
|
marker.provider.option.get_name(),
|
||||||
|
config,
|
||||||
|
marker.provider.provides,
|
||||||
|
)
|
||||||
|
elif isinstance(marker.provider, providers.Provider):
|
||||||
|
provider_name = container.resolve_provider_name(marker.provider)
|
||||||
|
if not provider_name:
|
||||||
|
continue
|
||||||
|
provider = container.providers[provider_name]
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(marker, Provide):
|
||||||
|
injections[parameter_name] = provider
|
||||||
|
elif isinstance(marker, Provider):
|
||||||
|
injections[parameter_name] = provider.provider
|
||||||
|
|
||||||
|
return injections
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_config_injection(
|
||||||
|
option_name: str,
|
||||||
|
config: providers.Configuration,
|
||||||
|
as_: Any = None,
|
||||||
|
) -> providers.Provider:
|
||||||
|
_, *parts = option_name.split('.')
|
||||||
|
relative_option_name = '.'.join(parts)
|
||||||
|
provider = config.get_option_provider(relative_option_name)
|
||||||
|
if as_:
|
||||||
|
return provider.as_(as_)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_container_config(container: AnyContainer) -> Optional[providers.Configuration]:
|
||||||
|
for provider in container.providers.values():
|
||||||
|
if isinstance(provider, providers.Configuration):
|
||||||
|
return provider
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_modules(package):
|
||||||
|
modules = []
|
||||||
|
for loader, module_name, is_pkg in pkgutil.walk_packages(
|
||||||
|
path=package.__path__,
|
||||||
|
prefix=package.__name__ + '.',
|
||||||
|
):
|
||||||
|
module = loader.find_module(module_name).load_module(module_name)
|
||||||
|
modules.append(module)
|
||||||
|
return modules
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_with_injections(fn, injections):
|
||||||
|
if inspect.iscoroutinefunction(fn):
|
||||||
|
@functools.wraps(fn)
|
||||||
|
async def _patched(*args, **kwargs):
|
||||||
|
to_inject = {}
|
||||||
|
for injection, provider in injections.items():
|
||||||
|
to_inject[injection] = provider()
|
||||||
|
|
||||||
|
to_inject.update(kwargs)
|
||||||
|
|
||||||
|
return await fn(*args, **to_inject)
|
||||||
|
else:
|
||||||
|
@functools.wraps(fn)
|
||||||
|
def _patched(*args, **kwargs):
|
||||||
|
to_inject = {}
|
||||||
|
for injection, provider in injections.items():
|
||||||
|
to_inject[injection] = provider()
|
||||||
|
|
||||||
|
to_inject.update(kwargs)
|
||||||
|
|
||||||
|
return fn(*args, **to_inject)
|
||||||
|
|
||||||
|
_patched.__wired__ = True
|
||||||
|
_patched.__original__ = fn
|
||||||
|
_patched.__injections__ = injections
|
||||||
|
|
||||||
|
return _patched
|
||||||
|
|
||||||
|
|
||||||
|
def _is_patched(fn):
|
||||||
|
return getattr(fn, '__wired__', False) is True
|
||||||
|
|
||||||
|
|
||||||
|
def _get_original_from_patched(fn):
|
||||||
|
return getattr(fn, '__original__')
|
||||||
|
|
||||||
|
|
||||||
|
class ClassGetItemMeta(GenericMeta):
|
||||||
|
def __getitem__(cls, item):
|
||||||
|
# Spike for Python 3.6
|
||||||
|
return cls(item)
|
||||||
|
|
||||||
|
|
||||||
|
class _Marker(Generic[T], metaclass=ClassGetItemMeta):
|
||||||
|
|
||||||
|
def __init__(self, provider: providers.Provider) -> None:
|
||||||
|
self.provider = provider
|
||||||
|
|
||||||
|
def __class_getitem__(cls, item) -> T:
|
||||||
|
return cls(item)
|
||||||
|
|
||||||
|
|
||||||
|
class Provide(_Marker):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class Provider(_Marker):
|
||||||
|
...
|
1
tests/unit/wiring/__init__.py
Normal file
1
tests/unit/wiring/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
"""Wiring tests."""
|
10
tests/unit/wiring/container.py
Normal file
10
tests/unit/wiring/container.py
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
from dependency_injector import containers, providers
|
||||||
|
|
||||||
|
from .service import Service
|
||||||
|
|
||||||
|
|
||||||
|
class Container(containers.DeclarativeContainer):
|
||||||
|
|
||||||
|
config = providers.Configuration()
|
||||||
|
|
||||||
|
service = providers.Factory(Service)
|
32
tests/unit/wiring/module.py
Normal file
32
tests/unit/wiring/module.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
"""Test module for wiring."""
|
||||||
|
|
||||||
|
from decimal import Decimal
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from dependency_injector.wiring import Provide, Provider
|
||||||
|
|
||||||
|
from .container import Container
|
||||||
|
from .service import Service
|
||||||
|
|
||||||
|
|
||||||
|
class TestClass:
|
||||||
|
|
||||||
|
def __init__(self, service: Service = Provide[Container.service]):
|
||||||
|
self.service = service
|
||||||
|
|
||||||
|
|
||||||
|
def test_function(service: Service = Provide[Container.service]):
|
||||||
|
return service
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_provider(service_provider: Callable[..., Service] = Provider[Container.service]):
|
||||||
|
service = service_provider()
|
||||||
|
return service
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_value(
|
||||||
|
some_value_int: int = Provide[Container.config.a.b.c.as_int()],
|
||||||
|
some_value_str: str = Provide[Container.config.a.b.c.as_(str)],
|
||||||
|
some_value_decimal: Decimal = Provide[Container.config.a.b.c.as_(Decimal)],
|
||||||
|
):
|
||||||
|
return some_value_int, some_value_str, some_value_decimal
|
0
tests/unit/wiring/package/__init__.py
Normal file
0
tests/unit/wiring/package/__init__.py
Normal file
0
tests/unit/wiring/package/subpackage/__init__.py
Normal file
0
tests/unit/wiring/package/subpackage/__init__.py
Normal file
8
tests/unit/wiring/package/subpackage/submodule.py
Normal file
8
tests/unit/wiring/package/subpackage/submodule.py
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
from dependency_injector.wiring import Provide
|
||||||
|
|
||||||
|
from ...container import Container
|
||||||
|
from ...service import Service
|
||||||
|
|
||||||
|
|
||||||
|
def test_function(service: Service = Provide[Container.service]):
|
||||||
|
return service
|
2
tests/unit/wiring/service.py
Normal file
2
tests/unit/wiring/service.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
class Service:
|
||||||
|
service_attr: int
|
60
tests/unit/wiring/test_wiring_py36.py
Normal file
60
tests/unit/wiring/test_wiring_py36.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
from decimal import Decimal
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from . import module, package
|
||||||
|
from .service import Service
|
||||||
|
from .container import Container
|
||||||
|
|
||||||
|
|
||||||
|
class WiringTest(unittest.TestCase):
|
||||||
|
|
||||||
|
container: Container
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.container = Container(config={'a': {'b': {'c': 10}}})
|
||||||
|
self.container.wire(
|
||||||
|
modules=[module],
|
||||||
|
packages=[package],
|
||||||
|
)
|
||||||
|
self.addCleanup(self.container.unwire)
|
||||||
|
|
||||||
|
def test_package_lookup(self):
|
||||||
|
from .package.subpackage.submodule import test_function
|
||||||
|
service = test_function()
|
||||||
|
self.assertIsInstance(service, Service)
|
||||||
|
|
||||||
|
def test_class_wiring(self):
|
||||||
|
test_class_object = module.TestClass()
|
||||||
|
self.assertIsInstance(test_class_object.service, Service)
|
||||||
|
|
||||||
|
def test_class_wiring_context_arg(self):
|
||||||
|
test_service = self.container.service()
|
||||||
|
|
||||||
|
test_class_object = module.TestClass(service=test_service)
|
||||||
|
self.assertIs(test_class_object.service, test_service)
|
||||||
|
|
||||||
|
def test_function_wiring(self):
|
||||||
|
service = module.test_function()
|
||||||
|
self.assertIsInstance(service, Service)
|
||||||
|
|
||||||
|
def test_function_wiring_context_arg(self):
|
||||||
|
test_service = self.container.service()
|
||||||
|
|
||||||
|
service = module.test_function(service=test_service)
|
||||||
|
self.assertIs(service, test_service)
|
||||||
|
|
||||||
|
def test_function_wiring_provider(self):
|
||||||
|
service = module.test_function_provider()
|
||||||
|
self.assertIsInstance(service, Service)
|
||||||
|
|
||||||
|
def test_function_wiring_provider_context_arg(self):
|
||||||
|
test_service = self.container.service()
|
||||||
|
|
||||||
|
service = module.test_function_provider(service_provider=lambda: test_service)
|
||||||
|
self.assertIs(service, test_service)
|
||||||
|
|
||||||
|
def test_configuration_option(self):
|
||||||
|
int_value, str_value, decimal_value = module.test_config_value()
|
||||||
|
self.assertEqual(int_value, 10)
|
||||||
|
self.assertEqual(str_value, '10')
|
||||||
|
self.assertEqual(decimal_value, Decimal(10))
|
Loading…
Reference in New Issue
Block a user