Providers traversal (#385)

* Implement providers traversal in first precision

* Implement traversal for all providers

* Update traverse interface + add some tests

* Refactor tests

* Add tests for callable provider

* Add configuration tests

* Add Factory tests

* Add FactoryAggrefate tests

* Add .provides attribute to singleton providers

* Add singleton provider tests

* Add list and dict provider tests

* Add resource tests

* Add Container provider tests

* Add Selector provider tests

* Add ProvidedInstance provider tests

* Add AttributeGetter provider tests

* Add ItemGetter provider tests

* Add MethodCaller provider tests

* Refactor container interface

* Update resource provider string representation

* Add .initializer attribute to Resource provider

* Add docs and examples

* Remove not needed EOL in the tests

* Make cosmetic refactoring

* Ignore flake8 line width error in traverse example
This commit is contained in:
Roman Mogylatov 2021-02-01 09:42:21 -05:00 committed by GitHub
parent 0c1a08174f
commit 3ca6dd9af1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 21688 additions and 17638 deletions

View File

@ -23,3 +23,4 @@ Containers module API docs - :py:mod:`dependency_injector.containers`.
dynamic dynamic
specialization specialization
overriding overriding
traversal

View File

@ -0,0 +1,33 @@
Container providers traversal
-----------------------------
To traverse container providers use method ``.traverse()``.
.. literalinclude:: ../../examples/containers/traverse.py
:language: python
:lines: 3-
:emphasize-lines: 38
Method ``.traverse()`` returns a generator. Traversal generator visits all container providers.
This includes nested providers even if they are not present on the root level of the container.
Traversal generator guarantees that each container provider will be visited only once.
It can traverse cyclic provider graphs.
Traversal generator does not guarantee traversal order.
You can use ``types=[...]`` argument to filter providers. Traversal generator will only return
providers matching specified types.
.. code-block:: python
:emphasize-lines: 3
container = Container()
for provider in container.traverse(types=[providers.Resource]):
print(provider)
# <dependency_injector.providers.Resource(<function init_database at 0x10bd2cb80>) at 0x10d346b40>
# <dependency_injector.providers.Resource(<function init_cache at 0x10be373a0>) at 0x10d346bc0>
.. disqus::

View File

@ -7,6 +7,15 @@ that were made in every particular version.
From version 0.7.6 *Dependency Injector* framework strictly From version 0.7.6 *Dependency Injector* framework strictly
follows `Semantic versioning`_ follows `Semantic versioning`_
Development version
-------------------
- Add container providers traversal.
- Add ``.provides`` attribute to ``Singleton`` and its subclasses.
It's a consistency change to make ``Singleton`` match ``Callable``
and ``Factory`` interfaces.
- Add ``.initializer`` attribute to ``Resource`` provider.
- Update string representation of ``Resource`` provider.
4.13.2 4.13.2
------ ------
- Fix PyCharm typing warning "Expected type 'Optional[Iterable[ModuleType]]', - Fix PyCharm typing warning "Expected type 'Optional[Iterable[ModuleType]]',

View File

@ -0,0 +1,48 @@
"""Container traversal example."""
from dependency_injector import containers, providers
def init_database():
return ...
def init_cache():
return ...
class Service:
def __init__(self, database, cache):
self.database = database
self.cache = cache
class Container(containers.DeclarativeContainer):
config = providers.Configuration()
service = providers.Factory(
Service,
database=providers.Resource(
init_database,
url=config.database_url,
),
cache=providers.Resource(
init_cache,
hosts=config.cache_hosts,
),
)
if __name__ == '__main__':
container = Container()
for provider in container.traverse():
print(provider)
# <dependency_injector.providers.Configuration('config') at 0x10d37d200>
# <dependency_injector.providers.Factory(<class '__main__.Service'>) at 0x10d3a2820>
# <dependency_injector.providers.Resource(<function init_database at 0x10bd2cb80>) at 0x10d346b40>
# <dependency_injector.providers.ConfigurationOption('config.cache_hosts') at 0x10d37d350>
# <dependency_injector.providers.Resource(<function init_cache at 0x10be373a0>) at 0x10d346bc0>
# <dependency_injector.providers.ConfigurationOption('config.database_url') at 0x10d37d2e0>

View File

@ -4,6 +4,7 @@ max_complexity = 10
exclude = types.py exclude = types.py
per-file-ignores = per-file-ignores =
examples/demo/*: F841 examples/demo/*: F841
examples/containers/traverse.py: E501
examples/providers/async.py: F841 examples/providers/async.py: F841
examples/providers/async_overriding.py: F841 examples/providers/async_overriding.py: F841
examples/wiring/*: F841 examples/wiring/*: F841

File diff suppressed because it is too large Load Diff

View File

@ -7,9 +7,12 @@ from typing import (
Union, Union,
ClassVar, ClassVar,
Callable as _Callable, Callable as _Callable,
Sequence,
Iterable, Iterable,
Iterator,
TypeVar, TypeVar,
Awaitable, Awaitable,
overload,
) )
from .providers import Provider from .providers import Provider
@ -40,6 +43,11 @@ class Container:
def unwire(self) -> None: ... def unwire(self) -> None: ...
def init_resources(self) -> Optional[Awaitable]: ... def init_resources(self) -> Optional[Awaitable]: ...
def shutdown_resources(self) -> Optional[Awaitable]: ... def shutdown_resources(self) -> Optional[Awaitable]: ...
@overload
def traverse(self, types: Optional[Sequence[Type]] = None) -> Iterator[Provider]: ...
@classmethod
@overload
def traverse(cls, types: Optional[Sequence[Type]] = None) -> Iterator[Provider]: ...
class DynamicContainer(Container): ... class DynamicContainer(Container): ...

View File

@ -10,16 +10,7 @@ except ImportError:
import six import six
from .errors import Error from . import providers, errors
from .providers cimport (
Provider,
Object,
Resource,
Dependency,
DependenciesContainer,
Container as ContainerProvider,
deepcopy,
)
if sys.version_info[:2] >= (3, 6): if sys.version_info[:2] >= (3, 6):
@ -68,13 +59,13 @@ class DynamicContainer(object):
:rtype: None :rtype: None
""" """
self.provider_type = Provider self.provider_type = providers.Provider
self.providers = {} self.providers = {}
self.overridden = tuple() self.overridden = tuple()
self.declarative_parent = None self.declarative_parent = None
self.wired_to_modules = [] self.wired_to_modules = []
self.wired_to_packages = [] self.wired_to_packages = []
self.__self__ = Object(self) self.__self__ = providers.Object(self)
super(DynamicContainer, self).__init__() super(DynamicContainer, self).__init__()
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
@ -84,11 +75,11 @@ class DynamicContainer(object):
return copied return copied
copied = self.__class__() copied = self.__class__()
copied.provider_type = Provider copied.provider_type = providers.Provider
copied.overridden = deepcopy(self.overridden, memo) copied.overridden = providers.deepcopy(self.overridden, memo)
copied.declarative_parent = self.declarative_parent copied.declarative_parent = self.declarative_parent
for name, provider in deepcopy(self.providers, memo).items(): for name, provider in providers.deepcopy(self.providers, memo).items():
setattr(copied, name, provider) setattr(copied, name, provider)
return copied return copied
@ -107,7 +98,7 @@ class DynamicContainer(object):
:rtype: None :rtype: None
""" """
if isinstance(value, Provider) and name != '__self__': if isinstance(value, providers.Provider) and name != '__self__':
_check_provider_type(self, value) _check_provider_type(self, value)
self.providers[name] = value self.providers[name] = value
super(DynamicContainer, self).__setattr__(name, value) super(DynamicContainer, self).__setattr__(name, value)
@ -140,9 +131,13 @@ class DynamicContainer(object):
return { return {
name: provider name: provider
for name, provider in self.providers.items() for name, provider in self.providers.items()
if isinstance(provider, (Dependency, DependenciesContainer)) if isinstance(provider, (providers.Dependency, providers.DependenciesContainer))
} }
def traverse(self, types=None):
"""Return providers traversal generator."""
yield from providers.traverse(*self.providers.values(), types=types)
def set_providers(self, **providers): def set_providers(self, **providers):
"""Set container providers. """Set container providers.
@ -167,7 +162,7 @@ class DynamicContainer(object):
:rtype: None :rtype: None
""" """
if overriding is self: if overriding is self:
raise Error('Container {0} could not be overridden ' raise errors.Error('Container {0} could not be overridden '
'with itself'.format(self)) 'with itself'.format(self))
self.overridden += (overriding,) self.overridden += (overriding,)
@ -197,7 +192,7 @@ class DynamicContainer(object):
:rtype: None :rtype: None
""" """
if not self.overridden: if not self.overridden:
raise Error('Container {0} is not overridden'.format(self)) raise errors.Error('Container {0} is not overridden'.format(self))
self.overridden = self.overridden[:-1] self.overridden = self.overridden[:-1]
@ -245,7 +240,7 @@ class DynamicContainer(object):
"""Initialize all container resources.""" """Initialize all container resources."""
futures = [] futures = []
for provider in self.providers.values(): for provider in self.providers.values():
if not isinstance(provider, Resource): if not isinstance(provider, providers.Resource):
continue continue
resource = provider.init() resource = provider.init()
@ -260,7 +255,7 @@ class DynamicContainer(object):
"""Shutdown all container resources.""" """Shutdown all container resources."""
futures = [] futures = []
for provider in self.providers.values(): for provider in self.providers.values():
if not isinstance(provider, Resource): if not isinstance(provider, providers.Resource):
continue continue
shutdown = provider.shutdown() shutdown = provider.shutdown()
@ -274,7 +269,7 @@ class DynamicContainer(object):
def apply_container_providers_overridings(self): def apply_container_providers_overridings(self):
"""Apply container providers' overridings.""" """Apply container providers' overridings."""
for provider in self.providers.values(): for provider in self.providers.values():
if not isinstance(provider, ContainerProvider): if not isinstance(provider, providers.Container):
continue continue
provider.apply_overridings() provider.apply_overridings()
@ -293,7 +288,7 @@ class DeclarativeContainerMetaClass(type):
cls_providers = { cls_providers = {
name: provider name: provider
for name, provider in six.iteritems(attributes) for name, provider in six.iteritems(attributes)
if isinstance(provider, Provider) if isinstance(provider, providers.Provider)
} }
inherited_providers = { inherited_providers = {
@ -303,18 +298,18 @@ class DeclarativeContainerMetaClass(type):
for name, provider in six.iteritems(base.providers) for name, provider in six.iteritems(base.providers)
} }
providers = {} all_providers = {}
providers.update(inherited_providers) all_providers.update(inherited_providers)
providers.update(cls_providers) all_providers.update(cls_providers)
attributes['containers'] = containers attributes['containers'] = containers
attributes['inherited_providers'] = inherited_providers attributes['inherited_providers'] = inherited_providers
attributes['cls_providers'] = cls_providers attributes['cls_providers'] = cls_providers
attributes['providers'] = providers attributes['providers'] = all_providers
cls = <type>type.__new__(mcs, class_name, bases, attributes) cls = <type>type.__new__(mcs, class_name, bases, attributes)
cls.__self__ = Object(cls) cls.__self__ = providers.Object(cls)
for provider in six.itervalues(cls.providers): for provider in six.itervalues(cls.providers):
_check_provider_type(cls, provider) _check_provider_type(cls, provider)
@ -335,7 +330,7 @@ class DeclarativeContainerMetaClass(type):
:rtype: None :rtype: None
""" """
if isinstance(value, Provider) and name != '__self__': if isinstance(value, providers.Provider) and name != '__self__':
_check_provider_type(cls, value) _check_provider_type(cls, value)
cls.providers[name] = value cls.providers[name] = value
cls.cls_providers[name] = value cls.cls_providers[name] = value
@ -370,9 +365,13 @@ class DeclarativeContainerMetaClass(type):
return { return {
name: provider name: provider
for name, provider in cls.providers.items() for name, provider in cls.providers.items()
if isinstance(provider, (Dependency, DependenciesContainer)) if isinstance(provider, (providers.Dependency, providers.DependenciesContainer))
} }
def traverse(cls, types=None):
"""Return providers traversal generator."""
yield from providers.traverse(*cls.providers.values(), types=types)
@six.add_metaclass(DeclarativeContainerMetaClass) @six.add_metaclass(DeclarativeContainerMetaClass)
class DeclarativeContainer(object): class DeclarativeContainer(object):
@ -388,7 +387,7 @@ class DeclarativeContainer(object):
__IS_CONTAINER__ = True __IS_CONTAINER__ = True
provider_type = Provider provider_type = providers.Provider
"""Type of providers that could be placed in container. """Type of providers that could be placed in container.
:type: type :type: type
@ -446,7 +445,7 @@ class DeclarativeContainer(object):
container = cls.instance_type() container = cls.instance_type()
container.provider_type = cls.provider_type container.provider_type = cls.provider_type
container.declarative_parent = cls container.declarative_parent = cls
container.set_providers(**deepcopy(cls.providers)) container.set_providers(**providers.deepcopy(cls.providers))
container.override_providers(**overriding_providers) container.override_providers(**overriding_providers)
container.apply_container_providers_overridings() container.apply_container_providers_overridings()
return container return container
@ -464,7 +463,7 @@ class DeclarativeContainer(object):
:rtype: None :rtype: None
""" """
if issubclass(cls, overriding): if issubclass(cls, overriding):
raise Error('Container {0} could not be overridden ' raise errors.Error('Container {0} could not be overridden '
'with itself or its subclasses'.format(cls)) 'with itself or its subclasses'.format(cls))
cls.overridden += (overriding,) cls.overridden += (overriding,)
@ -482,7 +481,7 @@ class DeclarativeContainer(object):
:rtype: None :rtype: None
""" """
if not cls.overridden: if not cls.overridden:
raise Error('Container {0} is not overridden'.format(cls)) raise errors.Error('Container {0} is not overridden'.format(cls))
cls.overridden = cls.overridden[:-1] cls.overridden = cls.overridden[:-1]
@ -559,7 +558,7 @@ def copy(object container):
def _decorator(copied_container): def _decorator(copied_container):
memo = _get_providers_memo(copied_container.cls_providers, container.providers) memo = _get_providers_memo(copied_container.cls_providers, container.providers)
providers_copy = deepcopy(container.providers, memo) providers_copy = providers.deepcopy(container.providers, memo)
for name, provider in six.iteritems(providers_copy): for name, provider in six.iteritems(providers_copy):
setattr(copied_container, name, provider) setattr(copied_container, name, provider)
@ -580,7 +579,7 @@ cpdef bint is_container(object instance):
cpdef object _check_provider_type(object container, object provider): cpdef object _check_provider_type(object container, object provider):
if not isinstance(provider, container.provider_type): if not isinstance(provider, container.provider_type):
raise Error('{0} can contain only {1} ' raise errors.Error('{0} can contain only {1} '
'instances'.format(container, container.provider_type)) 'instances'.format(container, container.provider_type))

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,7 @@ from typing import (
Optional, Optional,
Union, Union,
Coroutine as _Coroutine, Coroutine as _Coroutine,
Iterable as _Iterable,
Iterator as _Iterator, Iterator as _Iterator,
AsyncIterator as _AsyncIterator, AsyncIterator as _AsyncIterator,
Generator as _Generator, Generator as _Generator,
@ -68,6 +69,9 @@ class Provider(Generic[T]):
def is_async_mode_enabled(self) -> bool: ... def is_async_mode_enabled(self) -> bool: ...
def is_async_mode_disabled(self) -> bool: ... def is_async_mode_disabled(self) -> bool: ...
def is_async_mode_undefined(self) -> bool: ... def is_async_mode_undefined(self) -> bool: ...
@property
def related(self) -> _Iterator[Provider]: ...
def traverse(self, types: Optional[_Iterable[Type]] = None) -> _Iterator[Provider]: ...
def _copy_overridings(self, copied: Provider, memo: Optional[_Dict[Any, Any]]) -> None: ... def _copy_overridings(self, copied: Provider, memo: Optional[_Dict[Any, Any]]) -> None: ...
@ -238,6 +242,8 @@ class BaseSingleton(Provider[T]):
@property @property
def cls(self) -> T: ... def cls(self) -> T: ...
@property @property
def provides(self) -> T: ...
@property
def args(self) -> Tuple[Injection]: ... def args(self) -> Tuple[Injection]: ...
def add_args(self, *args: Injection) -> BaseSingleton[T]: ... def add_args(self, *args: Injection) -> BaseSingleton[T]: ...
def set_args(self, *args: Injection) -> BaseSingleton[T]: ... def set_args(self, *args: Injection) -> BaseSingleton[T]: ...
@ -313,6 +319,8 @@ class Resource(Provider[T]):
@overload @overload
def __init__(self, initializer: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ... def __init__(self, initializer: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ...
@property @property
def initializer(self) -> _Callable[..., Any]: ...
@property
def args(self) -> Tuple[Injection]: ... def args(self) -> Tuple[Injection]: ...
def add_args(self, *args: Injection) -> Resource[T]: ... def add_args(self, *args: Injection) -> Resource[T]: ...
def set_args(self, *args: Injection) -> Resource[T]: ... def set_args(self, *args: Injection) -> Resource[T]: ...
@ -382,6 +390,9 @@ def deepcopy(instance: Any, memo: Optional[_Dict[Any, Any]] = None): Any: ...
def merge_dicts(dict1: _Dict[Any, Any], dict2: _Dict[Any, Any]) -> _Dict[Any, Any]: ... def merge_dicts(dict1: _Dict[Any, Any], dict2: _Dict[Any, Any]) -> _Dict[Any, Any]: ...
def traverse(*providers: Provider, types: Optional[_Iterable[Type]]=None) -> _Iterator[Provider]: ...
if yaml: if yaml:
class YamlLoader(yaml.SafeLoader): ... class YamlLoader(yaml.SafeLoader): ...
else: else:

View File

@ -363,6 +363,15 @@ cdef class Provider(object):
"""Check if async mode is undefined.""" """Check if async mode is undefined."""
return self.__async_mode == ASYNC_MODE_UNDEFINED return self.__async_mode == ASYNC_MODE_UNDEFINED
@property
def related(self):
"""Return related providers generator."""
yield from self.overridden
def traverse(self, types=None):
"""Return providers traversal generator."""
return traverse(*self.related, types=types)
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
"""Providing strategy implementation. """Providing strategy implementation.
@ -423,6 +432,13 @@ cdef class Object(Provider):
""" """
return self.__str__() return self.__str__()
@property
def related(self):
"""Return related providers generator."""
if isinstance(self.__provides, Provider):
yield self.__provides
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
"""Return provided instance. """Return provided instance.
@ -487,6 +503,12 @@ cdef class Delegate(Provider):
"""Return provider.""" """Return provider."""
return self.__provides return self.__provides
@property
def related(self):
"""Return related providers generator."""
yield self.__provides
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
"""Return provided instance. """Return provided instance.
@ -619,6 +641,13 @@ cdef class Dependency(Provider):
"""Return default provider.""" """Return default provider."""
return self.__default return self.__default
@property
def related(self):
"""Return related providers generator."""
if self.__default is not UNDEFINED:
yield self.__default
yield from super().related
def provided_by(self, provider): def provided_by(self, provider):
"""Set external dependency provider. """Set external dependency provider.
@ -790,6 +819,12 @@ cdef class DependenciesContainer(Object):
child.reset_override() child.reset_override()
super(DependenciesContainer, self).reset_override() super(DependenciesContainer, self).reset_override()
@property
def related(self):
"""Return related providers generator."""
yield from self.providers.values()
yield from super().related
cpdef object _override_providers(self, object container): cpdef object _override_providers(self, object container):
"""Override providers with providers from provided container.""" """Override providers with providers from provided container."""
for name, dependency_provider in container.providers.items(): for name, dependency_provider in container.providers.items():
@ -998,6 +1033,14 @@ cdef class Callable(Provider):
self.__kwargs_len = len(self.__kwargs) self.__kwargs_len = len(self.__kwargs)
return self return self
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.provides])
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
"""Return result of provided callable's call.""" """Return result of provided callable's call."""
return __callable_call(self, args, kwargs) return __callable_call(self, args, kwargs)
@ -1442,6 +1485,13 @@ cdef class ConfigurationOption(Provider):
self.override(value) self.override(value)
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, self.__name)
yield from self.__children.values()
yield from super().related
def _is_strict_mode_enabled(self): def _is_strict_mode_enabled(self):
return self.__root.__strict return self.__root.__strict
@ -1763,6 +1813,12 @@ cdef class Configuration(Object):
self.override(value) self.override(value)
@property
def related(self):
"""Return related providers generator."""
yield from self.__children.values()
yield from super().related
def _is_strict_mode_enabled(self): def _is_strict_mode_enabled(self):
return self.__strict return self.__strict
@ -1979,6 +2035,15 @@ cdef class Factory(Provider):
self.__attributes_len = len(self.__attributes) self.__attributes_len = len(self.__attributes)
return self return self
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.provides])
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from filter(is_provider, self.attributes.values())
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
"""Return new instance.""" """Return new instance."""
return __factory_call(self, args, kwargs) return __factory_call(self, args, kwargs)
@ -2141,6 +2206,12 @@ cdef class FactoryAggregate(Provider):
raise Error( raise Error(
'{0} providers could not be overridden'.format(self.__class__)) '{0} providers could not be overridden'.format(self.__class__))
@property
def related(self):
"""Return related providers generator."""
yield from self.__factories.values()
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
try: try:
factory_name = args[0] factory_name = args[0]
@ -2212,7 +2283,12 @@ cdef class BaseSingleton(Provider):
@property @property
def cls(self): def cls(self):
"""Return provided type.""" """Return provided type."""
return self.__instantiator.cls return self.provides
@property
def provides(self):
"""Return provided type."""
return self.__instantiator.provides
@property @property
def args(self): def args(self):
@ -2314,6 +2390,15 @@ cdef class BaseSingleton(Provider):
""" """
raise NotImplementedError() raise NotImplementedError()
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.__instantiator.provides])
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from filter(is_provider, self.attributes.values())
yield from super().related
def _async_init_instance(self, future_result, result): def _async_init_instance(self, future_result, result):
try: try:
instance = result.result() instance = result.result()
@ -2741,6 +2826,12 @@ cdef class List(Provider):
self.__args_len = len(self.__args) self.__args_len = len(self.__args)
return self return self
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, self.args)
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
"""Return result of provided callable's call.""" """Return result of provided callable's call."""
return list(__provide_positional_args(args, self.__args, self.__args_len)) return list(__provide_positional_args(args, self.__args, self.__args_len))
@ -2853,6 +2944,12 @@ cdef class Dict(Provider):
self.__kwargs_len = len(self.__kwargs) self.__kwargs_len = len(self.__kwargs)
return self return self
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, self.kwargs.values())
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
"""Return result of provided callable's call.""" """Return result of provided callable's call."""
return __provide_keyword_args(kwargs, self.__kwargs, self.__kwargs_len) return __provide_keyword_args(kwargs, self.__kwargs, self.__kwargs_len)
@ -2895,11 +2992,17 @@ cdef class Resource(Provider):
return copied return copied
def __repr__(self): def __str__(self):
return ( """Return string representation of provider.
f'{self.__class__.__name__}({self.__initializer}, '
f'initialized={self.__initialized})' :rtype: str
) """
return represent_provider(provider=self, provides=self.__initializer)
@property
def initializer(self):
"""Return initializer."""
return self.__initializer
@property @property
def args(self): def args(self):
@ -3021,6 +3124,14 @@ cdef class Resource(Provider):
result.set_result(None) result.set_result(None)
return result return result
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.__initializer])
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
if self.__initialized: if self.__initialized:
return self.__resource return self.__resource
@ -3246,6 +3357,12 @@ cdef class Container(Provider):
declarative container initialization.""" declarative container initialization."""
self.__container.override_providers(**self.__overriding_providers) self.__container.override_providers(**self.__overriding_providers)
@property
def related(self):
"""Return related providers generator."""
yield from self.providers.values()
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
"""Return single instance.""" """Return single instance."""
return self.__container return self.__container
@ -3335,6 +3452,13 @@ cdef class Selector(Provider):
"""Return providers.""" """Return providers."""
return dict(self.__providers) return dict(self.__providers)
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.__selector])
yield from self.providers.values()
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
"""Return single instance.""" """Return single instance."""
selector_value = self.__selector() selector_value = self.__selector()
@ -3411,6 +3535,12 @@ cdef class ProvidedInstance(Provider):
def call(self, *args, **kwargs): def call(self, *args, **kwargs):
return MethodCaller(self, *args, **kwargs) return MethodCaller(self, *args, **kwargs)
@property
def related(self):
"""Return related providers generator."""
yield self.__provider
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
return self.__provider(*args, **kwargs) return self.__provider(*args, **kwargs)
@ -3460,6 +3590,12 @@ cdef class AttributeGetter(Provider):
def call(self, *args, **kwargs): def call(self, *args, **kwargs):
return MethodCaller(self, *args, **kwargs) return MethodCaller(self, *args, **kwargs)
@property
def related(self):
"""Return related providers generator."""
yield self.__provider
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
provided = self.__provider(*args, **kwargs) provided = self.__provider(*args, **kwargs)
if __isawaitable(provided): if __isawaitable(provided):
@ -3520,6 +3656,12 @@ cdef class ItemGetter(Provider):
def call(self, *args, **kwargs): def call(self, *args, **kwargs):
return MethodCaller(self, *args, **kwargs) return MethodCaller(self, *args, **kwargs)
@property
def related(self):
"""Return related providers generator."""
yield self.__provider
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
provided = self.__provider(*args, **kwargs) provided = self.__provider(*args, **kwargs)
if __isawaitable(provided): if __isawaitable(provided):
@ -3612,6 +3754,14 @@ cdef class MethodCaller(Provider):
def call(self, *args, **kwargs): def call(self, *args, **kwargs):
return MethodCaller(self, *args, **kwargs) return MethodCaller(self, *args, **kwargs)
@property
def related(self):
"""Return related providers generator."""
yield self.__provider
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
call = self.__provider() call = self.__provider()
if __isawaitable(call): if __isawaitable(call):
@ -3846,6 +3996,29 @@ def merge_dicts(dict1, dict2):
return result return result
def traverse(*providers, types=None):
"""Return providers traversal generator."""
visited = set()
to_visit = set(providers)
if types:
types = tuple(types)
while len(to_visit) > 0:
visiting = to_visit.pop()
visited.add(visiting)
for child in visiting.related:
if child in visited:
continue
to_visit.add(child)
if types and not isinstance(visiting, types):
continue
yield visiting
def isawaitable(obj): def isawaitable(obj):
"""Check if object is a coroutine function. """Check if object is a coroutine function.

View File

@ -0,0 +1,93 @@
import unittest
from dependency_injector import containers, providers
class TraverseProviderTests(unittest.TestCase):
def test_nested_providers(self):
class Container(containers.DeclarativeContainer):
obj_factory = providers.DelegatedFactory(
dict,
foo=providers.Resource(
dict,
foo='bar'
),
bar=providers.Resource(
dict,
foo='bar'
)
)
container = Container()
all_providers = list(container.traverse())
self.assertIn(container.obj_factory, all_providers)
self.assertIn(container.obj_factory.kwargs['foo'], all_providers)
self.assertIn(container.obj_factory.kwargs['bar'], all_providers)
self.assertEqual(len(all_providers), 3)
def test_nested_providers_with_filtering(self):
class Container(containers.DeclarativeContainer):
obj_factory = providers.DelegatedFactory(
dict,
foo=providers.Resource(
dict,
foo='bar'
),
bar=providers.Resource(
dict,
foo='bar'
)
)
container = Container()
all_providers = list(container.traverse(types=[providers.Resource]))
self.assertIn(container.obj_factory.kwargs['foo'], all_providers)
self.assertIn(container.obj_factory.kwargs['bar'], all_providers)
self.assertEqual(len(all_providers), 2)
class TraverseProviderDeclarativeTests(unittest.TestCase):
def test_nested_providers(self):
class Container(containers.DeclarativeContainer):
obj_factory = providers.DelegatedFactory(
dict,
foo=providers.Resource(
dict,
foo='bar'
),
bar=providers.Resource(
dict,
foo='bar'
)
)
all_providers = list(Container.traverse())
self.assertIn(Container.obj_factory, all_providers)
self.assertIn(Container.obj_factory.kwargs['foo'], all_providers)
self.assertIn(Container.obj_factory.kwargs['bar'], all_providers)
self.assertEqual(len(all_providers), 3)
def test_nested_providers_with_filtering(self):
class Container(containers.DeclarativeContainer):
obj_factory = providers.DelegatedFactory(
dict,
foo=providers.Resource(
dict,
foo='bar'
),
bar=providers.Resource(
dict,
foo='bar'
)
)
all_providers = list(Container.traverse(types=[providers.Resource]))
self.assertIn(Container.obj_factory.kwargs['foo'], all_providers)
self.assertIn(Container.obj_factory.kwargs['bar'], all_providers)
self.assertEqual(len(all_providers), 2)

View File

@ -337,9 +337,9 @@ class ResourceTests(unittest.TestCase):
self.assertEqual( self.assertEqual(
repr(provider), repr(provider),
'Resource({0}, initialized={1})'.format( '<dependency_injector.providers.Resource({0}) at {1}>'.format(
init_fn, repr(init_fn),
provider.initialized, hex(id(provider)),
) )
) )

View File

@ -0,0 +1,875 @@
import unittest
from dependency_injector import containers, providers
class TraverseTests(unittest.TestCase):
def test_traverse_cycled_graph(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider2.override(provider1)
provider3 = providers.Provider()
provider3.override(provider2)
provider1.override(provider3) # Cycle: provider3 -> provider2 -> provider1 -> provider3
all_providers = list(providers.traverse(provider1))
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
def test_traverse_types_filtering(self):
provider1 = providers.Resource(dict)
provider2 = providers.Resource(dict)
provider3 = providers.Provider()
provider = providers.Provider()
provider.override(provider1)
provider.override(provider2)
provider.override(provider3)
all_providers = list(providers.traverse(provider, types=[providers.Resource]))
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
class ProviderTests(unittest.TestCase):
def test_traversal_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider3 = providers.Provider()
provider = providers.Provider()
provider.override(provider1)
provider.override(provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
def test_traversal_overriding_nested(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider2.override(provider1)
provider3 = providers.Provider()
provider3.override(provider2)
provider = providers.Provider()
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
def test_traverse_types_filtering(self):
provider1 = providers.Resource(dict)
provider2 = providers.Resource(dict)
provider3 = providers.Provider()
provider = providers.Provider()
provider.override(provider1)
provider.override(provider2)
provider.override(provider3)
all_providers = list(provider.traverse(types=[providers.Resource]))
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
class ObjectTests(unittest.TestCase):
def test_traversal(self):
provider = providers.Object('string')
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traversal_provider(self):
another_provider = providers.Provider()
provider = providers.Object(another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_provider_and_overriding(self):
another_provider_1 = providers.Provider()
another_provider_2 = providers.Provider()
another_provider_3 = providers.Provider()
provider = providers.Object(another_provider_1)
provider.override(another_provider_2)
provider.override(another_provider_3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(another_provider_1, all_providers)
self.assertIn(another_provider_2, all_providers)
self.assertIn(another_provider_3, all_providers)
class DelegateTests(unittest.TestCase):
def test_traversal_provider(self):
another_provider = providers.Provider()
provider = providers.Delegate(another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_provider_and_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider3 = providers.Provider()
provider3.override(provider2)
provider = providers.Delegate(provider1)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class DependencyTests(unittest.TestCase):
def test_traversal(self):
provider = providers.Dependency()
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traversal_default(self):
another_provider = providers.Provider()
provider = providers.Dependency(default=another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider2.override(provider1)
provider = providers.Dependency()
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
class DependenciesContainerTests(unittest.TestCase):
def test_traversal(self):
provider = providers.DependenciesContainer()
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traversal_default(self):
another_provider = providers.Provider()
provider = providers.DependenciesContainer(default=another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_fluent_interface(self):
provider = providers.DependenciesContainer()
provider1 = provider.provider1
provider2 = provider.provider2
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traversal_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider3 = providers.DependenciesContainer(
provider1=provider1,
provider2=provider2,
)
provider = providers.DependenciesContainer()
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 5)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
self.assertIn(provider.provider1, all_providers)
self.assertIn(provider.provider2, all_providers)
class CallableTests(unittest.TestCase):
def test_traverse(self):
provider = providers.Callable(dict)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Callable(list, 'foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Callable(dict, foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Callable(dict, 'foo')
provider.override(provider1)
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_provides(self):
provider1 = providers.Callable(list)
provider2 = providers.Object('bar')
provider3 = providers.Object('baz')
provider = providers.Callable(provider1, provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class ConfigurationTests(unittest.TestCase):
def test_traverse(self):
config = providers.Configuration(default={'option1': {'option2': 'option2'}})
option1 = config.option1
option2 = config.option1.option2
option3 = config.option1[config.option1.option2]
all_providers = list(config.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(option1, all_providers)
self.assertIn(option2, all_providers)
self.assertIn(option3, all_providers)
def test_traverse_typed(self):
config = providers.Configuration()
option = config.option
typed_option = config.option.as_int()
all_providers = list(typed_option.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(option, all_providers)
def test_traverse_overridden(self):
options = {'option1': {'option2': 'option2'}}
config = providers.Configuration()
config.from_dict(options)
all_providers = list(config.traverse())
self.assertEqual(len(all_providers), 1)
overridden, = all_providers
self.assertEqual(overridden(), options)
self.assertIs(overridden, config.last_overriding)
def test_traverse_overridden_option_1(self):
options = {'option2': 'option2'}
config = providers.Configuration()
config.option1.from_dict(options)
all_providers = list(config.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(config.option1, all_providers)
self.assertIn(config.last_overriding, all_providers)
def test_traverse_overridden_option_2(self):
options = {'option2': 'option2'}
config = providers.Configuration()
config.option1.from_dict(options)
all_providers = list(config.option1.traverse())
self.assertEqual(len(all_providers), 0)
class FactoryTests(unittest.TestCase):
def test_traverse(self):
provider = providers.Factory(dict)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(list, 'foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(dict, foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_attributes(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(dict)
provider.add_attributes(foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(dict, 'foo')
provider.override(provider1)
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_provides(self):
provider1 = providers.Callable(list)
provider2 = providers.Object('bar')
provider3 = providers.Object('baz')
provider = providers.Factory(provider1, provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class FactoryAggregateTests(unittest.TestCase):
def test_traverse(self):
factory1 = providers.Factory(dict)
factory2 = providers.Factory(list)
provider = providers.FactoryAggregate(factory1=factory1, factory2=factory2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(factory1, all_providers)
self.assertIn(factory2, all_providers)
class BaseSingletonTests(unittest.TestCase):
def test_traverse(self):
provider = providers.Singleton(dict)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Singleton(list, 'foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Singleton(dict, foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_attributes(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Singleton(dict)
provider.add_attributes(foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Singleton(dict, 'foo')
provider.override(provider1)
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_provides(self):
provider1 = providers.Callable(list)
provider2 = providers.Object('bar')
provider3 = providers.Object('baz')
provider = providers.Singleton(provider1, provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class ListTests(unittest.TestCase):
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.List('foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider3 = providers.List(provider1, provider2)
provider = providers.List('foo')
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class DictTests(unittest.TestCase):
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Dict(foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider3 = providers.Dict(bar=provider1, baz=provider2)
provider = providers.Dict(foo='foo')
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class ResourceTests(unittest.TestCase):
def test_traverse(self):
provider = providers.Resource(dict)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Resource(list, 'foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Resource(dict, foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Resource(list)
provider2 = providers.Resource(tuple)
provider = providers.Resource(dict, 'foo')
provider.override(provider1)
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_provides(self):
provider1 = providers.Callable(list)
provider = providers.Resource(provider1)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(provider1, all_providers)
class ContainerTests(unittest.TestCase):
def test_traverse(self):
class Container(containers.DeclarativeContainer):
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
provider = providers.Container(Container)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertEqual(
{provider.provides for provider in all_providers},
{list, dict},
)
def test_traverse_overridden(self):
class Container1(containers.DeclarativeContainer):
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
class Container2(containers.DeclarativeContainer):
provider1 = providers.Callable(tuple)
provider2 = providers.Callable(str)
container2 = Container2()
provider = providers.Container(Container1)
provider.override(container2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 5)
self.assertEqual(
{
provider.provides
for provider in all_providers
if isinstance(provider, providers.Callable)
},
{list, dict, tuple, str},
)
self.assertIn(provider.last_overriding, all_providers)
self.assertIs(provider.last_overriding(), container2)
class SelectorTests(unittest.TestCase):
def test_traverse(self):
switch = lambda: 'provider1'
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
provider = providers.Selector(
switch,
provider1=provider1,
provider2=provider2,
)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_switch(self):
switch = providers.Callable(lambda: 'provider1')
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
provider = providers.Selector(
switch,
provider1=provider1,
provider2=provider2,
)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(switch, all_providers)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
selector1 = providers.Selector(lambda: 'provider1', provider1=provider1)
provider = providers.Selector(
lambda: 'provider2',
provider2=provider2,
)
provider.override(selector1)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(selector1, all_providers)
class ProvidedInstanceTests(unittest.TestCase):
def test_traverse(self):
provider1 = providers.Provider()
provider = provider1.provided
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(provider1, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider = provider1.provided
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
class AttributeGetterTests(unittest.TestCase):
def test_traverse(self):
provider1 = providers.Provider()
provided = provider1.provided
provider = provided.attr
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provided, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Provider()
provided = provider1.provided
provider2 = providers.Provider()
provider = provided.attr
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
class ItemGetterTests(unittest.TestCase):
def test_traverse(self):
provider1 = providers.Provider()
provided = provider1.provided
provider = provided['item']
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provided, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Provider()
provided = provider1.provided
provider2 = providers.Provider()
provider = provided['item']
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
class MethodCallerTests(unittest.TestCase):
def test_traverse(self):
provider1 = providers.Provider()
provided = provider1.provided
method = provided.method
provider = method.call()
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provided, all_providers)
self.assertIn(method, all_providers)
def test_traverse_args(self):
provider1 = providers.Provider()
provided = provider1.provided
method = provided.method
provider2 = providers.Provider()
provider = method.call('foo', provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 4)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
self.assertIn(method, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Provider()
provided = provider1.provided
method = provided.method
provider2 = providers.Provider()
provider = method.call(foo='foo', bar=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 4)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
self.assertIn(method, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Provider()
provided = provider1.provided
method = provided.method
provider2 = providers.Provider()
provider = method.call()
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 4)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
self.assertIn(method, all_providers)