mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2024-11-21 17:16:46 +03:00
Providers traversal (#385)
* Implement providers traversal in first precision * Implement traversal for all providers * Update traverse interface + add some tests * Refactor tests * Add tests for callable provider * Add configuration tests * Add Factory tests * Add FactoryAggrefate tests * Add .provides attribute to singleton providers * Add singleton provider tests * Add list and dict provider tests * Add resource tests * Add Container provider tests * Add Selector provider tests * Add ProvidedInstance provider tests * Add AttributeGetter provider tests * Add ItemGetter provider tests * Add MethodCaller provider tests * Refactor container interface * Update resource provider string representation * Add .initializer attribute to Resource provider * Add docs and examples * Remove not needed EOL in the tests * Make cosmetic refactoring * Ignore flake8 line width error in traverse example
This commit is contained in:
parent
0c1a08174f
commit
3ca6dd9af1
|
@ -23,3 +23,4 @@ Containers module API docs - :py:mod:`dependency_injector.containers`.
|
|||
dynamic
|
||||
specialization
|
||||
overriding
|
||||
traversal
|
||||
|
|
33
docs/containers/traversal.rst
Normal file
33
docs/containers/traversal.rst
Normal file
|
@ -0,0 +1,33 @@
|
|||
Container providers traversal
|
||||
-----------------------------
|
||||
|
||||
To traverse container providers use method ``.traverse()``.
|
||||
|
||||
.. literalinclude:: ../../examples/containers/traverse.py
|
||||
:language: python
|
||||
:lines: 3-
|
||||
:emphasize-lines: 38
|
||||
|
||||
Method ``.traverse()`` returns a generator. Traversal generator visits all container providers.
|
||||
This includes nested providers even if they are not present on the root level of the container.
|
||||
|
||||
Traversal generator guarantees that each container provider will be visited only once.
|
||||
It can traverse cyclic provider graphs.
|
||||
|
||||
Traversal generator does not guarantee traversal order.
|
||||
|
||||
You can use ``types=[...]`` argument to filter providers. Traversal generator will only return
|
||||
providers matching specified types.
|
||||
|
||||
.. code-block:: python
|
||||
:emphasize-lines: 3
|
||||
|
||||
container = Container()
|
||||
|
||||
for provider in container.traverse(types=[providers.Resource]):
|
||||
print(provider)
|
||||
|
||||
# <dependency_injector.providers.Resource(<function init_database at 0x10bd2cb80>) at 0x10d346b40>
|
||||
# <dependency_injector.providers.Resource(<function init_cache at 0x10be373a0>) at 0x10d346bc0>
|
||||
|
||||
.. disqus::
|
|
@ -7,6 +7,15 @@ that were made in every particular version.
|
|||
From version 0.7.6 *Dependency Injector* framework strictly
|
||||
follows `Semantic versioning`_
|
||||
|
||||
Development version
|
||||
-------------------
|
||||
- Add container providers traversal.
|
||||
- Add ``.provides`` attribute to ``Singleton`` and its subclasses.
|
||||
It's a consistency change to make ``Singleton`` match ``Callable``
|
||||
and ``Factory`` interfaces.
|
||||
- Add ``.initializer`` attribute to ``Resource`` provider.
|
||||
- Update string representation of ``Resource`` provider.
|
||||
|
||||
4.13.2
|
||||
------
|
||||
- Fix PyCharm typing warning "Expected type 'Optional[Iterable[ModuleType]]',
|
||||
|
|
48
examples/containers/traverse.py
Normal file
48
examples/containers/traverse.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
"""Container traversal example."""
|
||||
|
||||
from dependency_injector import containers, providers
|
||||
|
||||
|
||||
def init_database():
|
||||
return ...
|
||||
|
||||
|
||||
def init_cache():
|
||||
return ...
|
||||
|
||||
|
||||
class Service:
|
||||
def __init__(self, database, cache):
|
||||
self.database = database
|
||||
self.cache = cache
|
||||
|
||||
|
||||
class Container(containers.DeclarativeContainer):
|
||||
|
||||
config = providers.Configuration()
|
||||
|
||||
service = providers.Factory(
|
||||
Service,
|
||||
database=providers.Resource(
|
||||
init_database,
|
||||
url=config.database_url,
|
||||
),
|
||||
cache=providers.Resource(
|
||||
init_cache,
|
||||
hosts=config.cache_hosts,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
container = Container()
|
||||
|
||||
for provider in container.traverse():
|
||||
print(provider)
|
||||
|
||||
# <dependency_injector.providers.Configuration('config') at 0x10d37d200>
|
||||
# <dependency_injector.providers.Factory(<class '__main__.Service'>) at 0x10d3a2820>
|
||||
# <dependency_injector.providers.Resource(<function init_database at 0x10bd2cb80>) at 0x10d346b40>
|
||||
# <dependency_injector.providers.ConfigurationOption('config.cache_hosts') at 0x10d37d350>
|
||||
# <dependency_injector.providers.Resource(<function init_cache at 0x10be373a0>) at 0x10d346bc0>
|
||||
# <dependency_injector.providers.ConfigurationOption('config.database_url') at 0x10d37d2e0>
|
|
@ -4,6 +4,7 @@ max_complexity = 10
|
|||
exclude = types.py
|
||||
per-file-ignores =
|
||||
examples/demo/*: F841
|
||||
examples/containers/traverse.py: E501
|
||||
examples/providers/async.py: F841
|
||||
examples/providers/async_overriding.py: F841
|
||||
examples/wiring/*: F841
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -7,9 +7,12 @@ from typing import (
|
|||
Union,
|
||||
ClassVar,
|
||||
Callable as _Callable,
|
||||
Sequence,
|
||||
Iterable,
|
||||
Iterator,
|
||||
TypeVar,
|
||||
Awaitable,
|
||||
overload,
|
||||
)
|
||||
|
||||
from .providers import Provider
|
||||
|
@ -40,6 +43,11 @@ class Container:
|
|||
def unwire(self) -> None: ...
|
||||
def init_resources(self) -> Optional[Awaitable]: ...
|
||||
def shutdown_resources(self) -> Optional[Awaitable]: ...
|
||||
@overload
|
||||
def traverse(self, types: Optional[Sequence[Type]] = None) -> Iterator[Provider]: ...
|
||||
@classmethod
|
||||
@overload
|
||||
def traverse(cls, types: Optional[Sequence[Type]] = None) -> Iterator[Provider]: ...
|
||||
|
||||
|
||||
class DynamicContainer(Container): ...
|
||||
|
|
|
@ -10,16 +10,7 @@ except ImportError:
|
|||
|
||||
import six
|
||||
|
||||
from .errors import Error
|
||||
from .providers cimport (
|
||||
Provider,
|
||||
Object,
|
||||
Resource,
|
||||
Dependency,
|
||||
DependenciesContainer,
|
||||
Container as ContainerProvider,
|
||||
deepcopy,
|
||||
)
|
||||
from . import providers, errors
|
||||
|
||||
|
||||
if sys.version_info[:2] >= (3, 6):
|
||||
|
@ -68,13 +59,13 @@ class DynamicContainer(object):
|
|||
|
||||
:rtype: None
|
||||
"""
|
||||
self.provider_type = Provider
|
||||
self.provider_type = providers.Provider
|
||||
self.providers = {}
|
||||
self.overridden = tuple()
|
||||
self.declarative_parent = None
|
||||
self.wired_to_modules = []
|
||||
self.wired_to_packages = []
|
||||
self.__self__ = Object(self)
|
||||
self.__self__ = providers.Object(self)
|
||||
super(DynamicContainer, self).__init__()
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
|
@ -84,11 +75,11 @@ class DynamicContainer(object):
|
|||
return copied
|
||||
|
||||
copied = self.__class__()
|
||||
copied.provider_type = Provider
|
||||
copied.overridden = deepcopy(self.overridden, memo)
|
||||
copied.provider_type = providers.Provider
|
||||
copied.overridden = providers.deepcopy(self.overridden, memo)
|
||||
copied.declarative_parent = self.declarative_parent
|
||||
|
||||
for name, provider in deepcopy(self.providers, memo).items():
|
||||
for name, provider in providers.deepcopy(self.providers, memo).items():
|
||||
setattr(copied, name, provider)
|
||||
|
||||
return copied
|
||||
|
@ -107,7 +98,7 @@ class DynamicContainer(object):
|
|||
|
||||
:rtype: None
|
||||
"""
|
||||
if isinstance(value, Provider) and name != '__self__':
|
||||
if isinstance(value, providers.Provider) and name != '__self__':
|
||||
_check_provider_type(self, value)
|
||||
self.providers[name] = value
|
||||
super(DynamicContainer, self).__setattr__(name, value)
|
||||
|
@ -140,9 +131,13 @@ class DynamicContainer(object):
|
|||
return {
|
||||
name: provider
|
||||
for name, provider in self.providers.items()
|
||||
if isinstance(provider, (Dependency, DependenciesContainer))
|
||||
if isinstance(provider, (providers.Dependency, providers.DependenciesContainer))
|
||||
}
|
||||
|
||||
def traverse(self, types=None):
|
||||
"""Return providers traversal generator."""
|
||||
yield from providers.traverse(*self.providers.values(), types=types)
|
||||
|
||||
def set_providers(self, **providers):
|
||||
"""Set container providers.
|
||||
|
||||
|
@ -167,8 +162,8 @@ class DynamicContainer(object):
|
|||
:rtype: None
|
||||
"""
|
||||
if overriding is self:
|
||||
raise Error('Container {0} could not be overridden '
|
||||
'with itself'.format(self))
|
||||
raise errors.Error('Container {0} could not be overridden '
|
||||
'with itself'.format(self))
|
||||
|
||||
self.overridden += (overriding,)
|
||||
|
||||
|
@ -197,7 +192,7 @@ class DynamicContainer(object):
|
|||
:rtype: None
|
||||
"""
|
||||
if not self.overridden:
|
||||
raise Error('Container {0} is not overridden'.format(self))
|
||||
raise errors.Error('Container {0} is not overridden'.format(self))
|
||||
|
||||
self.overridden = self.overridden[:-1]
|
||||
|
||||
|
@ -245,7 +240,7 @@ class DynamicContainer(object):
|
|||
"""Initialize all container resources."""
|
||||
futures = []
|
||||
for provider in self.providers.values():
|
||||
if not isinstance(provider, Resource):
|
||||
if not isinstance(provider, providers.Resource):
|
||||
continue
|
||||
|
||||
resource = provider.init()
|
||||
|
@ -260,7 +255,7 @@ class DynamicContainer(object):
|
|||
"""Shutdown all container resources."""
|
||||
futures = []
|
||||
for provider in self.providers.values():
|
||||
if not isinstance(provider, Resource):
|
||||
if not isinstance(provider, providers.Resource):
|
||||
continue
|
||||
|
||||
shutdown = provider.shutdown()
|
||||
|
@ -274,7 +269,7 @@ class DynamicContainer(object):
|
|||
def apply_container_providers_overridings(self):
|
||||
"""Apply container providers' overridings."""
|
||||
for provider in self.providers.values():
|
||||
if not isinstance(provider, ContainerProvider):
|
||||
if not isinstance(provider, providers.Container):
|
||||
continue
|
||||
provider.apply_overridings()
|
||||
|
||||
|
@ -293,7 +288,7 @@ class DeclarativeContainerMetaClass(type):
|
|||
cls_providers = {
|
||||
name: provider
|
||||
for name, provider in six.iteritems(attributes)
|
||||
if isinstance(provider, Provider)
|
||||
if isinstance(provider, providers.Provider)
|
||||
}
|
||||
|
||||
inherited_providers = {
|
||||
|
@ -303,18 +298,18 @@ class DeclarativeContainerMetaClass(type):
|
|||
for name, provider in six.iteritems(base.providers)
|
||||
}
|
||||
|
||||
providers = {}
|
||||
providers.update(inherited_providers)
|
||||
providers.update(cls_providers)
|
||||
all_providers = {}
|
||||
all_providers.update(inherited_providers)
|
||||
all_providers.update(cls_providers)
|
||||
|
||||
attributes['containers'] = containers
|
||||
attributes['inherited_providers'] = inherited_providers
|
||||
attributes['cls_providers'] = cls_providers
|
||||
attributes['providers'] = providers
|
||||
attributes['providers'] = all_providers
|
||||
|
||||
cls = <type>type.__new__(mcs, class_name, bases, attributes)
|
||||
|
||||
cls.__self__ = Object(cls)
|
||||
cls.__self__ = providers.Object(cls)
|
||||
|
||||
for provider in six.itervalues(cls.providers):
|
||||
_check_provider_type(cls, provider)
|
||||
|
@ -335,7 +330,7 @@ class DeclarativeContainerMetaClass(type):
|
|||
|
||||
:rtype: None
|
||||
"""
|
||||
if isinstance(value, Provider) and name != '__self__':
|
||||
if isinstance(value, providers.Provider) and name != '__self__':
|
||||
_check_provider_type(cls, value)
|
||||
cls.providers[name] = value
|
||||
cls.cls_providers[name] = value
|
||||
|
@ -370,9 +365,13 @@ class DeclarativeContainerMetaClass(type):
|
|||
return {
|
||||
name: provider
|
||||
for name, provider in cls.providers.items()
|
||||
if isinstance(provider, (Dependency, DependenciesContainer))
|
||||
if isinstance(provider, (providers.Dependency, providers.DependenciesContainer))
|
||||
}
|
||||
|
||||
def traverse(cls, types=None):
|
||||
"""Return providers traversal generator."""
|
||||
yield from providers.traverse(*cls.providers.values(), types=types)
|
||||
|
||||
|
||||
@six.add_metaclass(DeclarativeContainerMetaClass)
|
||||
class DeclarativeContainer(object):
|
||||
|
@ -388,7 +387,7 @@ class DeclarativeContainer(object):
|
|||
|
||||
__IS_CONTAINER__ = True
|
||||
|
||||
provider_type = Provider
|
||||
provider_type = providers.Provider
|
||||
"""Type of providers that could be placed in container.
|
||||
|
||||
:type: type
|
||||
|
@ -446,7 +445,7 @@ class DeclarativeContainer(object):
|
|||
container = cls.instance_type()
|
||||
container.provider_type = cls.provider_type
|
||||
container.declarative_parent = cls
|
||||
container.set_providers(**deepcopy(cls.providers))
|
||||
container.set_providers(**providers.deepcopy(cls.providers))
|
||||
container.override_providers(**overriding_providers)
|
||||
container.apply_container_providers_overridings()
|
||||
return container
|
||||
|
@ -464,8 +463,8 @@ class DeclarativeContainer(object):
|
|||
:rtype: None
|
||||
"""
|
||||
if issubclass(cls, overriding):
|
||||
raise Error('Container {0} could not be overridden '
|
||||
'with itself or its subclasses'.format(cls))
|
||||
raise errors.Error('Container {0} could not be overridden '
|
||||
'with itself or its subclasses'.format(cls))
|
||||
|
||||
cls.overridden += (overriding,)
|
||||
|
||||
|
@ -482,7 +481,7 @@ class DeclarativeContainer(object):
|
|||
:rtype: None
|
||||
"""
|
||||
if not cls.overridden:
|
||||
raise Error('Container {0} is not overridden'.format(cls))
|
||||
raise errors.Error('Container {0} is not overridden'.format(cls))
|
||||
|
||||
cls.overridden = cls.overridden[:-1]
|
||||
|
||||
|
@ -559,7 +558,7 @@ def copy(object container):
|
|||
def _decorator(copied_container):
|
||||
memo = _get_providers_memo(copied_container.cls_providers, container.providers)
|
||||
|
||||
providers_copy = deepcopy(container.providers, memo)
|
||||
providers_copy = providers.deepcopy(container.providers, memo)
|
||||
for name, provider in six.iteritems(providers_copy):
|
||||
setattr(copied_container, name, provider)
|
||||
|
||||
|
@ -580,8 +579,8 @@ cpdef bint is_container(object instance):
|
|||
|
||||
cpdef object _check_provider_type(object container, object provider):
|
||||
if not isinstance(provider, container.provider_type):
|
||||
raise Error('{0} can contain only {1} '
|
||||
'instances'.format(container, container.provider_type))
|
||||
raise errors.Error('{0} can contain only {1} '
|
||||
'instances'.format(container, container.provider_type))
|
||||
|
||||
|
||||
cpdef bint _isawaitable(object instance):
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -14,6 +14,7 @@ from typing import (
|
|||
Optional,
|
||||
Union,
|
||||
Coroutine as _Coroutine,
|
||||
Iterable as _Iterable,
|
||||
Iterator as _Iterator,
|
||||
AsyncIterator as _AsyncIterator,
|
||||
Generator as _Generator,
|
||||
|
@ -68,6 +69,9 @@ class Provider(Generic[T]):
|
|||
def is_async_mode_enabled(self) -> bool: ...
|
||||
def is_async_mode_disabled(self) -> bool: ...
|
||||
def is_async_mode_undefined(self) -> bool: ...
|
||||
@property
|
||||
def related(self) -> _Iterator[Provider]: ...
|
||||
def traverse(self, types: Optional[_Iterable[Type]] = None) -> _Iterator[Provider]: ...
|
||||
def _copy_overridings(self, copied: Provider, memo: Optional[_Dict[Any, Any]]) -> None: ...
|
||||
|
||||
|
||||
|
@ -238,6 +242,8 @@ class BaseSingleton(Provider[T]):
|
|||
@property
|
||||
def cls(self) -> T: ...
|
||||
@property
|
||||
def provides(self) -> T: ...
|
||||
@property
|
||||
def args(self) -> Tuple[Injection]: ...
|
||||
def add_args(self, *args: Injection) -> BaseSingleton[T]: ...
|
||||
def set_args(self, *args: Injection) -> BaseSingleton[T]: ...
|
||||
|
@ -313,6 +319,8 @@ class Resource(Provider[T]):
|
|||
@overload
|
||||
def __init__(self, initializer: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ...
|
||||
@property
|
||||
def initializer(self) -> _Callable[..., Any]: ...
|
||||
@property
|
||||
def args(self) -> Tuple[Injection]: ...
|
||||
def add_args(self, *args: Injection) -> Resource[T]: ...
|
||||
def set_args(self, *args: Injection) -> Resource[T]: ...
|
||||
|
@ -382,6 +390,9 @@ def deepcopy(instance: Any, memo: Optional[_Dict[Any, Any]] = None): Any: ...
|
|||
def merge_dicts(dict1: _Dict[Any, Any], dict2: _Dict[Any, Any]) -> _Dict[Any, Any]: ...
|
||||
|
||||
|
||||
def traverse(*providers: Provider, types: Optional[_Iterable[Type]]=None) -> _Iterator[Provider]: ...
|
||||
|
||||
|
||||
if yaml:
|
||||
class YamlLoader(yaml.SafeLoader): ...
|
||||
else:
|
||||
|
|
|
@ -363,6 +363,15 @@ cdef class Provider(object):
|
|||
"""Check if async mode is undefined."""
|
||||
return self.__async_mode == ASYNC_MODE_UNDEFINED
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from self.overridden
|
||||
|
||||
def traverse(self, types=None):
|
||||
"""Return providers traversal generator."""
|
||||
return traverse(*self.related, types=types)
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Providing strategy implementation.
|
||||
|
||||
|
@ -423,6 +432,13 @@ cdef class Object(Provider):
|
|||
"""
|
||||
return self.__str__()
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
if isinstance(self.__provides, Provider):
|
||||
yield self.__provides
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return provided instance.
|
||||
|
||||
|
@ -487,6 +503,12 @@ cdef class Delegate(Provider):
|
|||
"""Return provider."""
|
||||
return self.__provides
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield self.__provides
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return provided instance.
|
||||
|
||||
|
@ -619,6 +641,13 @@ cdef class Dependency(Provider):
|
|||
"""Return default provider."""
|
||||
return self.__default
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
if self.__default is not UNDEFINED:
|
||||
yield self.__default
|
||||
yield from super().related
|
||||
|
||||
def provided_by(self, provider):
|
||||
"""Set external dependency provider.
|
||||
|
||||
|
@ -790,6 +819,12 @@ cdef class DependenciesContainer(Object):
|
|||
child.reset_override()
|
||||
super(DependenciesContainer, self).reset_override()
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from self.providers.values()
|
||||
yield from super().related
|
||||
|
||||
cpdef object _override_providers(self, object container):
|
||||
"""Override providers with providers from provided container."""
|
||||
for name, dependency_provider in container.providers.items():
|
||||
|
@ -998,6 +1033,14 @@ cdef class Callable(Provider):
|
|||
self.__kwargs_len = len(self.__kwargs)
|
||||
return self
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, [self.provides])
|
||||
yield from filter(is_provider, self.args)
|
||||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return result of provided callable's call."""
|
||||
return __callable_call(self, args, kwargs)
|
||||
|
@ -1442,6 +1485,13 @@ cdef class ConfigurationOption(Provider):
|
|||
|
||||
self.override(value)
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, self.__name)
|
||||
yield from self.__children.values()
|
||||
yield from super().related
|
||||
|
||||
def _is_strict_mode_enabled(self):
|
||||
return self.__root.__strict
|
||||
|
||||
|
@ -1763,6 +1813,12 @@ cdef class Configuration(Object):
|
|||
|
||||
self.override(value)
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from self.__children.values()
|
||||
yield from super().related
|
||||
|
||||
def _is_strict_mode_enabled(self):
|
||||
return self.__strict
|
||||
|
||||
|
@ -1979,6 +2035,15 @@ cdef class Factory(Provider):
|
|||
self.__attributes_len = len(self.__attributes)
|
||||
return self
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, [self.provides])
|
||||
yield from filter(is_provider, self.args)
|
||||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from filter(is_provider, self.attributes.values())
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return new instance."""
|
||||
return __factory_call(self, args, kwargs)
|
||||
|
@ -2141,6 +2206,12 @@ cdef class FactoryAggregate(Provider):
|
|||
raise Error(
|
||||
'{0} providers could not be overridden'.format(self.__class__))
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from self.__factories.values()
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
try:
|
||||
factory_name = args[0]
|
||||
|
@ -2212,7 +2283,12 @@ cdef class BaseSingleton(Provider):
|
|||
@property
|
||||
def cls(self):
|
||||
"""Return provided type."""
|
||||
return self.__instantiator.cls
|
||||
return self.provides
|
||||
|
||||
@property
|
||||
def provides(self):
|
||||
"""Return provided type."""
|
||||
return self.__instantiator.provides
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
|
@ -2314,6 +2390,15 @@ cdef class BaseSingleton(Provider):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, [self.__instantiator.provides])
|
||||
yield from filter(is_provider, self.args)
|
||||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from filter(is_provider, self.attributes.values())
|
||||
yield from super().related
|
||||
|
||||
def _async_init_instance(self, future_result, result):
|
||||
try:
|
||||
instance = result.result()
|
||||
|
@ -2741,6 +2826,12 @@ cdef class List(Provider):
|
|||
self.__args_len = len(self.__args)
|
||||
return self
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, self.args)
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return result of provided callable's call."""
|
||||
return list(__provide_positional_args(args, self.__args, self.__args_len))
|
||||
|
@ -2853,6 +2944,12 @@ cdef class Dict(Provider):
|
|||
self.__kwargs_len = len(self.__kwargs)
|
||||
return self
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return result of provided callable's call."""
|
||||
return __provide_keyword_args(kwargs, self.__kwargs, self.__kwargs_len)
|
||||
|
@ -2895,11 +2992,17 @@ cdef class Resource(Provider):
|
|||
|
||||
return copied
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'{self.__class__.__name__}({self.__initializer}, '
|
||||
f'initialized={self.__initialized})'
|
||||
)
|
||||
def __str__(self):
|
||||
"""Return string representation of provider.
|
||||
|
||||
:rtype: str
|
||||
"""
|
||||
return represent_provider(provider=self, provides=self.__initializer)
|
||||
|
||||
@property
|
||||
def initializer(self):
|
||||
"""Return initializer."""
|
||||
return self.__initializer
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
|
@ -3021,6 +3124,14 @@ cdef class Resource(Provider):
|
|||
result.set_result(None)
|
||||
return result
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, [self.__initializer])
|
||||
yield from filter(is_provider, self.args)
|
||||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
if self.__initialized:
|
||||
return self.__resource
|
||||
|
@ -3246,6 +3357,12 @@ cdef class Container(Provider):
|
|||
declarative container initialization."""
|
||||
self.__container.override_providers(**self.__overriding_providers)
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from self.providers.values()
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return single instance."""
|
||||
return self.__container
|
||||
|
@ -3335,6 +3452,13 @@ cdef class Selector(Provider):
|
|||
"""Return providers."""
|
||||
return dict(self.__providers)
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, [self.__selector])
|
||||
yield from self.providers.values()
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return single instance."""
|
||||
selector_value = self.__selector()
|
||||
|
@ -3411,6 +3535,12 @@ cdef class ProvidedInstance(Provider):
|
|||
def call(self, *args, **kwargs):
|
||||
return MethodCaller(self, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield self.__provider
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
return self.__provider(*args, **kwargs)
|
||||
|
||||
|
@ -3460,6 +3590,12 @@ cdef class AttributeGetter(Provider):
|
|||
def call(self, *args, **kwargs):
|
||||
return MethodCaller(self, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield self.__provider
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
provided = self.__provider(*args, **kwargs)
|
||||
if __isawaitable(provided):
|
||||
|
@ -3520,6 +3656,12 @@ cdef class ItemGetter(Provider):
|
|||
def call(self, *args, **kwargs):
|
||||
return MethodCaller(self, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield self.__provider
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
provided = self.__provider(*args, **kwargs)
|
||||
if __isawaitable(provided):
|
||||
|
@ -3612,6 +3754,14 @@ cdef class MethodCaller(Provider):
|
|||
def call(self, *args, **kwargs):
|
||||
return MethodCaller(self, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield self.__provider
|
||||
yield from filter(is_provider, self.args)
|
||||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
call = self.__provider()
|
||||
if __isawaitable(call):
|
||||
|
@ -3846,6 +3996,29 @@ def merge_dicts(dict1, dict2):
|
|||
return result
|
||||
|
||||
|
||||
def traverse(*providers, types=None):
|
||||
"""Return providers traversal generator."""
|
||||
visited = set()
|
||||
to_visit = set(providers)
|
||||
|
||||
if types:
|
||||
types = tuple(types)
|
||||
|
||||
while len(to_visit) > 0:
|
||||
visiting = to_visit.pop()
|
||||
visited.add(visiting)
|
||||
|
||||
for child in visiting.related:
|
||||
if child in visited:
|
||||
continue
|
||||
to_visit.add(child)
|
||||
|
||||
if types and not isinstance(visiting, types):
|
||||
continue
|
||||
|
||||
yield visiting
|
||||
|
||||
|
||||
def isawaitable(obj):
|
||||
"""Check if object is a coroutine function.
|
||||
|
||||
|
|
93
tests/unit/containers/test_traversal_py3.py
Normal file
93
tests/unit/containers/test_traversal_py3.py
Normal file
|
@ -0,0 +1,93 @@
|
|||
import unittest
|
||||
|
||||
from dependency_injector import containers, providers
|
||||
|
||||
|
||||
class TraverseProviderTests(unittest.TestCase):
|
||||
|
||||
def test_nested_providers(self):
|
||||
class Container(containers.DeclarativeContainer):
|
||||
obj_factory = providers.DelegatedFactory(
|
||||
dict,
|
||||
foo=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
),
|
||||
bar=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
)
|
||||
)
|
||||
|
||||
container = Container()
|
||||
all_providers = list(container.traverse())
|
||||
|
||||
self.assertIn(container.obj_factory, all_providers)
|
||||
self.assertIn(container.obj_factory.kwargs['foo'], all_providers)
|
||||
self.assertIn(container.obj_factory.kwargs['bar'], all_providers)
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
|
||||
def test_nested_providers_with_filtering(self):
|
||||
class Container(containers.DeclarativeContainer):
|
||||
obj_factory = providers.DelegatedFactory(
|
||||
dict,
|
||||
foo=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
),
|
||||
bar=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
)
|
||||
)
|
||||
|
||||
container = Container()
|
||||
all_providers = list(container.traverse(types=[providers.Resource]))
|
||||
|
||||
self.assertIn(container.obj_factory.kwargs['foo'], all_providers)
|
||||
self.assertIn(container.obj_factory.kwargs['bar'], all_providers)
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
|
||||
|
||||
class TraverseProviderDeclarativeTests(unittest.TestCase):
|
||||
|
||||
def test_nested_providers(self):
|
||||
class Container(containers.DeclarativeContainer):
|
||||
obj_factory = providers.DelegatedFactory(
|
||||
dict,
|
||||
foo=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
),
|
||||
bar=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
)
|
||||
)
|
||||
|
||||
all_providers = list(Container.traverse())
|
||||
|
||||
self.assertIn(Container.obj_factory, all_providers)
|
||||
self.assertIn(Container.obj_factory.kwargs['foo'], all_providers)
|
||||
self.assertIn(Container.obj_factory.kwargs['bar'], all_providers)
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
|
||||
def test_nested_providers_with_filtering(self):
|
||||
class Container(containers.DeclarativeContainer):
|
||||
obj_factory = providers.DelegatedFactory(
|
||||
dict,
|
||||
foo=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
),
|
||||
bar=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
)
|
||||
)
|
||||
|
||||
all_providers = list(Container.traverse(types=[providers.Resource]))
|
||||
|
||||
self.assertIn(Container.obj_factory.kwargs['foo'], all_providers)
|
||||
self.assertIn(Container.obj_factory.kwargs['bar'], all_providers)
|
||||
self.assertEqual(len(all_providers), 2)
|
|
@ -337,9 +337,9 @@ class ResourceTests(unittest.TestCase):
|
|||
|
||||
self.assertEqual(
|
||||
repr(provider),
|
||||
'Resource({0}, initialized={1})'.format(
|
||||
init_fn,
|
||||
provider.initialized,
|
||||
'<dependency_injector.providers.Resource({0}) at {1}>'.format(
|
||||
repr(init_fn),
|
||||
hex(id(provider)),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
875
tests/unit/providers/test_traversal_py3.py
Normal file
875
tests/unit/providers/test_traversal_py3.py
Normal file
|
@ -0,0 +1,875 @@
|
|||
import unittest
|
||||
|
||||
from dependency_injector import containers, providers
|
||||
|
||||
|
||||
class TraverseTests(unittest.TestCase):
|
||||
|
||||
def test_traverse_cycled_graph(self):
|
||||
provider1 = providers.Provider()
|
||||
|
||||
provider2 = providers.Provider()
|
||||
provider2.override(provider1)
|
||||
|
||||
provider3 = providers.Provider()
|
||||
provider3.override(provider2)
|
||||
|
||||
provider1.override(provider3) # Cycle: provider3 -> provider2 -> provider1 -> provider3
|
||||
|
||||
all_providers = list(providers.traverse(provider1))
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
|
||||
def test_traverse_types_filtering(self):
|
||||
provider1 = providers.Resource(dict)
|
||||
provider2 = providers.Resource(dict)
|
||||
provider3 = providers.Provider()
|
||||
|
||||
provider = providers.Provider()
|
||||
|
||||
provider.override(provider1)
|
||||
provider.override(provider2)
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(providers.traverse(provider, types=[providers.Resource]))
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
|
||||
class ProviderTests(unittest.TestCase):
|
||||
|
||||
def test_traversal_overriding(self):
|
||||
provider1 = providers.Provider()
|
||||
provider2 = providers.Provider()
|
||||
provider3 = providers.Provider()
|
||||
|
||||
provider = providers.Provider()
|
||||
|
||||
provider.override(provider1)
|
||||
provider.override(provider2)
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
|
||||
def test_traversal_overriding_nested(self):
|
||||
provider1 = providers.Provider()
|
||||
|
||||
provider2 = providers.Provider()
|
||||
provider2.override(provider1)
|
||||
|
||||
provider3 = providers.Provider()
|
||||
provider3.override(provider2)
|
||||
|
||||
provider = providers.Provider()
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
|
||||
def test_traverse_types_filtering(self):
|
||||
provider1 = providers.Resource(dict)
|
||||
provider2 = providers.Resource(dict)
|
||||
provider3 = providers.Provider()
|
||||
|
||||
provider = providers.Provider()
|
||||
|
||||
provider.override(provider1)
|
||||
provider.override(provider2)
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse(types=[providers.Resource]))
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
|
||||
class ObjectTests(unittest.TestCase):
|
||||
|
||||
def test_traversal(self):
|
||||
provider = providers.Object('string')
|
||||
all_providers = list(provider.traverse())
|
||||
self.assertEqual(len(all_providers), 0)
|
||||
|
||||
def test_traversal_provider(self):
|
||||
another_provider = providers.Provider()
|
||||
provider = providers.Object(another_provider)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 1)
|
||||
self.assertIn(another_provider, all_providers)
|
||||
|
||||
def test_traversal_provider_and_overriding(self):
|
||||
another_provider_1 = providers.Provider()
|
||||
another_provider_2 = providers.Provider()
|
||||
another_provider_3 = providers.Provider()
|
||||
|
||||
provider = providers.Object(another_provider_1)
|
||||
|
||||
provider.override(another_provider_2)
|
||||
provider.override(another_provider_3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(another_provider_1, all_providers)
|
||||
self.assertIn(another_provider_2, all_providers)
|
||||
self.assertIn(another_provider_3, all_providers)
|
||||
|
||||
|
||||
class DelegateTests(unittest.TestCase):
|
||||
|
||||
def test_traversal_provider(self):
|
||||
another_provider = providers.Provider()
|
||||
provider = providers.Delegate(another_provider)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 1)
|
||||
self.assertIn(another_provider, all_providers)
|
||||
|
||||
def test_traversal_provider_and_overriding(self):
|
||||
provider1 = providers.Provider()
|
||||
provider2 = providers.Provider()
|
||||
|
||||
provider3 = providers.Provider()
|
||||
provider3.override(provider2)
|
||||
|
||||
provider = providers.Delegate(provider1)
|
||||
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
|
||||
|
||||
class DependencyTests(unittest.TestCase):
|
||||
|
||||
def test_traversal(self):
|
||||
provider = providers.Dependency()
|
||||
all_providers = list(provider.traverse())
|
||||
self.assertEqual(len(all_providers), 0)
|
||||
|
||||
def test_traversal_default(self):
|
||||
another_provider = providers.Provider()
|
||||
provider = providers.Dependency(default=another_provider)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 1)
|
||||
self.assertIn(another_provider, all_providers)
|
||||
|
||||
def test_traversal_overriding(self):
|
||||
provider1 = providers.Provider()
|
||||
|
||||
provider2 = providers.Provider()
|
||||
provider2.override(provider1)
|
||||
|
||||
provider = providers.Dependency()
|
||||
provider.override(provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
|
||||
class DependenciesContainerTests(unittest.TestCase):
|
||||
|
||||
def test_traversal(self):
|
||||
provider = providers.DependenciesContainer()
|
||||
all_providers = list(provider.traverse())
|
||||
self.assertEqual(len(all_providers), 0)
|
||||
|
||||
def test_traversal_default(self):
|
||||
another_provider = providers.Provider()
|
||||
provider = providers.DependenciesContainer(default=another_provider)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 1)
|
||||
self.assertIn(another_provider, all_providers)
|
||||
|
||||
def test_traversal_fluent_interface(self):
|
||||
provider = providers.DependenciesContainer()
|
||||
provider1 = provider.provider1
|
||||
provider2 = provider.provider2
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traversal_overriding(self):
|
||||
provider1 = providers.Provider()
|
||||
provider2 = providers.Provider()
|
||||
provider3 = providers.DependenciesContainer(
|
||||
provider1=provider1,
|
||||
provider2=provider2,
|
||||
)
|
||||
|
||||
provider = providers.DependenciesContainer()
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 5)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
self.assertIn(provider.provider1, all_providers)
|
||||
self.assertIn(provider.provider2, all_providers)
|
||||
|
||||
|
||||
class CallableTests(unittest.TestCase):
|
||||
|
||||
def test_traverse(self):
|
||||
provider = providers.Callable(dict)
|
||||
all_providers = list(provider.traverse())
|
||||
self.assertEqual(len(all_providers), 0)
|
||||
|
||||
def test_traverse_args(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
provider = providers.Callable(list, 'foo', provider1, provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_kwargs(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
provider = providers.Callable(dict, foo='foo', bar=provider1, baz=provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_overridden(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
|
||||
provider = providers.Callable(dict, 'foo')
|
||||
provider.override(provider1)
|
||||
provider.override(provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_provides(self):
|
||||
provider1 = providers.Callable(list)
|
||||
provider2 = providers.Object('bar')
|
||||
provider3 = providers.Object('baz')
|
||||
|
||||
provider = providers.Callable(provider1, provider2)
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
|
||||
|
||||
class ConfigurationTests(unittest.TestCase):
|
||||
|
||||
def test_traverse(self):
|
||||
config = providers.Configuration(default={'option1': {'option2': 'option2'}})
|
||||
option1 = config.option1
|
||||
option2 = config.option1.option2
|
||||
option3 = config.option1[config.option1.option2]
|
||||
|
||||
all_providers = list(config.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(option1, all_providers)
|
||||
self.assertIn(option2, all_providers)
|
||||
self.assertIn(option3, all_providers)
|
||||
|
||||
def test_traverse_typed(self):
|
||||
config = providers.Configuration()
|
||||
option = config.option
|
||||
typed_option = config.option.as_int()
|
||||
|
||||
all_providers = list(typed_option.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 1)
|
||||
self.assertIn(option, all_providers)
|
||||
|
||||
def test_traverse_overridden(self):
|
||||
options = {'option1': {'option2': 'option2'}}
|
||||
config = providers.Configuration()
|
||||
config.from_dict(options)
|
||||
|
||||
all_providers = list(config.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 1)
|
||||
overridden, = all_providers
|
||||
self.assertEqual(overridden(), options)
|
||||
self.assertIs(overridden, config.last_overriding)
|
||||
|
||||
def test_traverse_overridden_option_1(self):
|
||||
options = {'option2': 'option2'}
|
||||
config = providers.Configuration()
|
||||
config.option1.from_dict(options)
|
||||
|
||||
all_providers = list(config.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(config.option1, all_providers)
|
||||
self.assertIn(config.last_overriding, all_providers)
|
||||
|
||||
def test_traverse_overridden_option_2(self):
|
||||
options = {'option2': 'option2'}
|
||||
config = providers.Configuration()
|
||||
config.option1.from_dict(options)
|
||||
|
||||
all_providers = list(config.option1.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 0)
|
||||
|
||||
|
||||
class FactoryTests(unittest.TestCase):
|
||||
|
||||
def test_traverse(self):
|
||||
provider = providers.Factory(dict)
|
||||
all_providers = list(provider.traverse())
|
||||
self.assertEqual(len(all_providers), 0)
|
||||
|
||||
def test_traverse_args(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
provider = providers.Factory(list, 'foo', provider1, provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_kwargs(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
provider = providers.Factory(dict, foo='foo', bar=provider1, baz=provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_attributes(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
provider = providers.Factory(dict)
|
||||
provider.add_attributes(foo='foo', bar=provider1, baz=provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_overridden(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
|
||||
provider = providers.Factory(dict, 'foo')
|
||||
provider.override(provider1)
|
||||
provider.override(provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_provides(self):
|
||||
provider1 = providers.Callable(list)
|
||||
provider2 = providers.Object('bar')
|
||||
provider3 = providers.Object('baz')
|
||||
|
||||
provider = providers.Factory(provider1, provider2)
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
|
||||
|
||||
class FactoryAggregateTests(unittest.TestCase):
|
||||
|
||||
def test_traverse(self):
|
||||
factory1 = providers.Factory(dict)
|
||||
factory2 = providers.Factory(list)
|
||||
provider = providers.FactoryAggregate(factory1=factory1, factory2=factory2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(factory1, all_providers)
|
||||
self.assertIn(factory2, all_providers)
|
||||
|
||||
|
||||
class BaseSingletonTests(unittest.TestCase):
|
||||
|
||||
def test_traverse(self):
|
||||
provider = providers.Singleton(dict)
|
||||
all_providers = list(provider.traverse())
|
||||
self.assertEqual(len(all_providers), 0)
|
||||
|
||||
def test_traverse_args(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
provider = providers.Singleton(list, 'foo', provider1, provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_kwargs(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
provider = providers.Singleton(dict, foo='foo', bar=provider1, baz=provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_attributes(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
provider = providers.Singleton(dict)
|
||||
provider.add_attributes(foo='foo', bar=provider1, baz=provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_overridden(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
|
||||
provider = providers.Singleton(dict, 'foo')
|
||||
provider.override(provider1)
|
||||
provider.override(provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_provides(self):
|
||||
provider1 = providers.Callable(list)
|
||||
provider2 = providers.Object('bar')
|
||||
provider3 = providers.Object('baz')
|
||||
|
||||
provider = providers.Singleton(provider1, provider2)
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
|
||||
|
||||
class ListTests(unittest.TestCase):
|
||||
|
||||
def test_traverse_args(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
provider = providers.List('foo', provider1, provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_overridden(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
provider3 = providers.List(provider1, provider2)
|
||||
|
||||
provider = providers.List('foo')
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
|
||||
|
||||
class DictTests(unittest.TestCase):
|
||||
|
||||
def test_traverse_kwargs(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
provider = providers.Dict(foo='foo', bar=provider1, baz=provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_overridden(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
provider3 = providers.Dict(bar=provider1, baz=provider2)
|
||||
|
||||
provider = providers.Dict(foo='foo')
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
|
||||
|
||||
class ResourceTests(unittest.TestCase):
|
||||
|
||||
def test_traverse(self):
|
||||
provider = providers.Resource(dict)
|
||||
all_providers = list(provider.traverse())
|
||||
self.assertEqual(len(all_providers), 0)
|
||||
|
||||
def test_traverse_args(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
provider = providers.Resource(list, 'foo', provider1, provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_kwargs(self):
|
||||
provider1 = providers.Object('bar')
|
||||
provider2 = providers.Object('baz')
|
||||
provider = providers.Resource(dict, foo='foo', bar=provider1, baz=provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_overridden(self):
|
||||
provider1 = providers.Resource(list)
|
||||
provider2 = providers.Resource(tuple)
|
||||
|
||||
provider = providers.Resource(dict, 'foo')
|
||||
provider.override(provider1)
|
||||
provider.override(provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_provides(self):
|
||||
provider1 = providers.Callable(list)
|
||||
|
||||
provider = providers.Resource(provider1)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 1)
|
||||
self.assertIn(provider1, all_providers)
|
||||
|
||||
|
||||
class ContainerTests(unittest.TestCase):
|
||||
|
||||
def test_traverse(self):
|
||||
class Container(containers.DeclarativeContainer):
|
||||
provider1 = providers.Callable(list)
|
||||
provider2 = providers.Callable(dict)
|
||||
|
||||
provider = providers.Container(Container)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertEqual(
|
||||
{provider.provides for provider in all_providers},
|
||||
{list, dict},
|
||||
)
|
||||
|
||||
def test_traverse_overridden(self):
|
||||
class Container1(containers.DeclarativeContainer):
|
||||
provider1 = providers.Callable(list)
|
||||
provider2 = providers.Callable(dict)
|
||||
|
||||
class Container2(containers.DeclarativeContainer):
|
||||
provider1 = providers.Callable(tuple)
|
||||
provider2 = providers.Callable(str)
|
||||
|
||||
container2 = Container2()
|
||||
|
||||
provider = providers.Container(Container1)
|
||||
provider.override(container2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 5)
|
||||
self.assertEqual(
|
||||
{
|
||||
provider.provides
|
||||
for provider in all_providers
|
||||
if isinstance(provider, providers.Callable)
|
||||
},
|
||||
{list, dict, tuple, str},
|
||||
)
|
||||
self.assertIn(provider.last_overriding, all_providers)
|
||||
self.assertIs(provider.last_overriding(), container2)
|
||||
|
||||
|
||||
class SelectorTests(unittest.TestCase):
|
||||
|
||||
def test_traverse(self):
|
||||
switch = lambda: 'provider1'
|
||||
provider1 = providers.Callable(list)
|
||||
provider2 = providers.Callable(dict)
|
||||
|
||||
provider = providers.Selector(
|
||||
switch,
|
||||
provider1=provider1,
|
||||
provider2=provider2,
|
||||
)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_switch(self):
|
||||
switch = providers.Callable(lambda: 'provider1')
|
||||
provider1 = providers.Callable(list)
|
||||
provider2 = providers.Callable(dict)
|
||||
|
||||
provider = providers.Selector(
|
||||
switch,
|
||||
provider1=provider1,
|
||||
provider2=provider2,
|
||||
)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(switch, all_providers)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traverse_overridden(self):
|
||||
provider1 = providers.Callable(list)
|
||||
provider2 = providers.Callable(dict)
|
||||
selector1 = providers.Selector(lambda: 'provider1', provider1=provider1)
|
||||
|
||||
provider = providers.Selector(
|
||||
lambda: 'provider2',
|
||||
provider2=provider2,
|
||||
)
|
||||
provider.override(selector1)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(selector1, all_providers)
|
||||
|
||||
|
||||
class ProvidedInstanceTests(unittest.TestCase):
|
||||
|
||||
def test_traverse(self):
|
||||
provider1 = providers.Provider()
|
||||
provider = provider1.provided
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 1)
|
||||
self.assertIn(provider1, all_providers)
|
||||
|
||||
def test_traverse_overridden(self):
|
||||
provider1 = providers.Provider()
|
||||
provider2 = providers.Provider()
|
||||
|
||||
provider = provider1.provided
|
||||
provider.override(provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
|
||||
class AttributeGetterTests(unittest.TestCase):
|
||||
|
||||
def test_traverse(self):
|
||||
provider1 = providers.Provider()
|
||||
provided = provider1.provided
|
||||
provider = provided.attr
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provided, all_providers)
|
||||
|
||||
def test_traverse_overridden(self):
|
||||
provider1 = providers.Provider()
|
||||
provided = provider1.provided
|
||||
provider2 = providers.Provider()
|
||||
|
||||
provider = provided.attr
|
||||
provider.override(provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provided, all_providers)
|
||||
|
||||
|
||||
class ItemGetterTests(unittest.TestCase):
|
||||
|
||||
def test_traverse(self):
|
||||
provider1 = providers.Provider()
|
||||
provided = provider1.provided
|
||||
provider = provided['item']
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provided, all_providers)
|
||||
|
||||
def test_traverse_overridden(self):
|
||||
provider1 = providers.Provider()
|
||||
provided = provider1.provided
|
||||
provider2 = providers.Provider()
|
||||
|
||||
provider = provided['item']
|
||||
provider.override(provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provided, all_providers)
|
||||
|
||||
|
||||
class MethodCallerTests(unittest.TestCase):
|
||||
|
||||
def test_traverse(self):
|
||||
provider1 = providers.Provider()
|
||||
provided = provider1.provided
|
||||
method = provided.method
|
||||
provider = method.call()
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provided, all_providers)
|
||||
self.assertIn(method, all_providers)
|
||||
|
||||
def test_traverse_args(self):
|
||||
provider1 = providers.Provider()
|
||||
provided = provider1.provided
|
||||
method = provided.method
|
||||
provider2 = providers.Provider()
|
||||
provider = method.call('foo', provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 4)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provided, all_providers)
|
||||
self.assertIn(method, all_providers)
|
||||
|
||||
def test_traverse_kwargs(self):
|
||||
provider1 = providers.Provider()
|
||||
provided = provider1.provided
|
||||
method = provided.method
|
||||
provider2 = providers.Provider()
|
||||
provider = method.call(foo='foo', bar=provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 4)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provided, all_providers)
|
||||
self.assertIn(method, all_providers)
|
||||
|
||||
def test_traverse_overridden(self):
|
||||
provider1 = providers.Provider()
|
||||
provided = provider1.provided
|
||||
method = provided.method
|
||||
provider2 = providers.Provider()
|
||||
|
||||
provider = method.call()
|
||||
provider.override(provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 4)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provided, all_providers)
|
||||
self.assertIn(method, all_providers)
|
Loading…
Reference in New Issue
Block a user