Providers traversal (#385)

* Implement providers traversal in first precision

* Implement traversal for all providers

* Update traverse interface + add some tests

* Refactor tests

* Add tests for callable provider

* Add configuration tests

* Add Factory tests

* Add FactoryAggrefate tests

* Add .provides attribute to singleton providers

* Add singleton provider tests

* Add list and dict provider tests

* Add resource tests

* Add Container provider tests

* Add Selector provider tests

* Add ProvidedInstance provider tests

* Add AttributeGetter provider tests

* Add ItemGetter provider tests

* Add MethodCaller provider tests

* Refactor container interface

* Update resource provider string representation

* Add .initializer attribute to Resource provider

* Add docs and examples

* Remove not needed EOL in the tests

* Make cosmetic refactoring

* Ignore flake8 line width error in traverse example
This commit is contained in:
Roman Mogylatov 2021-02-01 09:42:21 -05:00 committed by GitHub
parent 0c1a08174f
commit 3ca6dd9af1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 21688 additions and 17638 deletions

View File

@ -23,3 +23,4 @@ Containers module API docs - :py:mod:`dependency_injector.containers`.
dynamic
specialization
overriding
traversal

View File

@ -0,0 +1,33 @@
Container providers traversal
-----------------------------
To traverse container providers use method ``.traverse()``.
.. literalinclude:: ../../examples/containers/traverse.py
:language: python
:lines: 3-
:emphasize-lines: 38
Method ``.traverse()`` returns a generator. Traversal generator visits all container providers.
This includes nested providers even if they are not present on the root level of the container.
Traversal generator guarantees that each container provider will be visited only once.
It can traverse cyclic provider graphs.
Traversal generator does not guarantee traversal order.
You can use ``types=[...]`` argument to filter providers. Traversal generator will only return
providers matching specified types.
.. code-block:: python
:emphasize-lines: 3
container = Container()
for provider in container.traverse(types=[providers.Resource]):
print(provider)
# <dependency_injector.providers.Resource(<function init_database at 0x10bd2cb80>) at 0x10d346b40>
# <dependency_injector.providers.Resource(<function init_cache at 0x10be373a0>) at 0x10d346bc0>
.. disqus::

View File

@ -7,6 +7,15 @@ that were made in every particular version.
From version 0.7.6 *Dependency Injector* framework strictly
follows `Semantic versioning`_
Development version
-------------------
- Add container providers traversal.
- Add ``.provides`` attribute to ``Singleton`` and its subclasses.
It's a consistency change to make ``Singleton`` match ``Callable``
and ``Factory`` interfaces.
- Add ``.initializer`` attribute to ``Resource`` provider.
- Update string representation of ``Resource`` provider.
4.13.2
------
- Fix PyCharm typing warning "Expected type 'Optional[Iterable[ModuleType]]',

View File

@ -0,0 +1,48 @@
"""Container traversal example."""
from dependency_injector import containers, providers
def init_database():
return ...
def init_cache():
return ...
class Service:
def __init__(self, database, cache):
self.database = database
self.cache = cache
class Container(containers.DeclarativeContainer):
config = providers.Configuration()
service = providers.Factory(
Service,
database=providers.Resource(
init_database,
url=config.database_url,
),
cache=providers.Resource(
init_cache,
hosts=config.cache_hosts,
),
)
if __name__ == '__main__':
container = Container()
for provider in container.traverse():
print(provider)
# <dependency_injector.providers.Configuration('config') at 0x10d37d200>
# <dependency_injector.providers.Factory(<class '__main__.Service'>) at 0x10d3a2820>
# <dependency_injector.providers.Resource(<function init_database at 0x10bd2cb80>) at 0x10d346b40>
# <dependency_injector.providers.ConfigurationOption('config.cache_hosts') at 0x10d37d350>
# <dependency_injector.providers.Resource(<function init_cache at 0x10be373a0>) at 0x10d346bc0>
# <dependency_injector.providers.ConfigurationOption('config.database_url') at 0x10d37d2e0>

View File

@ -4,6 +4,7 @@ max_complexity = 10
exclude = types.py
per-file-ignores =
examples/demo/*: F841
examples/containers/traverse.py: E501
examples/providers/async.py: F841
examples/providers/async_overriding.py: F841
examples/wiring/*: F841

File diff suppressed because it is too large Load Diff

View File

@ -7,9 +7,12 @@ from typing import (
Union,
ClassVar,
Callable as _Callable,
Sequence,
Iterable,
Iterator,
TypeVar,
Awaitable,
overload,
)
from .providers import Provider
@ -40,6 +43,11 @@ class Container:
def unwire(self) -> None: ...
def init_resources(self) -> Optional[Awaitable]: ...
def shutdown_resources(self) -> Optional[Awaitable]: ...
@overload
def traverse(self, types: Optional[Sequence[Type]] = None) -> Iterator[Provider]: ...
@classmethod
@overload
def traverse(cls, types: Optional[Sequence[Type]] = None) -> Iterator[Provider]: ...
class DynamicContainer(Container): ...

View File

@ -10,16 +10,7 @@ except ImportError:
import six
from .errors import Error
from .providers cimport (
Provider,
Object,
Resource,
Dependency,
DependenciesContainer,
Container as ContainerProvider,
deepcopy,
)
from . import providers, errors
if sys.version_info[:2] >= (3, 6):
@ -68,13 +59,13 @@ class DynamicContainer(object):
:rtype: None
"""
self.provider_type = Provider
self.provider_type = providers.Provider
self.providers = {}
self.overridden = tuple()
self.declarative_parent = None
self.wired_to_modules = []
self.wired_to_packages = []
self.__self__ = Object(self)
self.__self__ = providers.Object(self)
super(DynamicContainer, self).__init__()
def __deepcopy__(self, memo):
@ -84,11 +75,11 @@ class DynamicContainer(object):
return copied
copied = self.__class__()
copied.provider_type = Provider
copied.overridden = deepcopy(self.overridden, memo)
copied.provider_type = providers.Provider
copied.overridden = providers.deepcopy(self.overridden, memo)
copied.declarative_parent = self.declarative_parent
for name, provider in deepcopy(self.providers, memo).items():
for name, provider in providers.deepcopy(self.providers, memo).items():
setattr(copied, name, provider)
return copied
@ -107,7 +98,7 @@ class DynamicContainer(object):
:rtype: None
"""
if isinstance(value, Provider) and name != '__self__':
if isinstance(value, providers.Provider) and name != '__self__':
_check_provider_type(self, value)
self.providers[name] = value
super(DynamicContainer, self).__setattr__(name, value)
@ -140,9 +131,13 @@ class DynamicContainer(object):
return {
name: provider
for name, provider in self.providers.items()
if isinstance(provider, (Dependency, DependenciesContainer))
if isinstance(provider, (providers.Dependency, providers.DependenciesContainer))
}
def traverse(self, types=None):
"""Return providers traversal generator."""
yield from providers.traverse(*self.providers.values(), types=types)
def set_providers(self, **providers):
"""Set container providers.
@ -167,7 +162,7 @@ class DynamicContainer(object):
:rtype: None
"""
if overriding is self:
raise Error('Container {0} could not be overridden '
raise errors.Error('Container {0} could not be overridden '
'with itself'.format(self))
self.overridden += (overriding,)
@ -197,7 +192,7 @@ class DynamicContainer(object):
:rtype: None
"""
if not self.overridden:
raise Error('Container {0} is not overridden'.format(self))
raise errors.Error('Container {0} is not overridden'.format(self))
self.overridden = self.overridden[:-1]
@ -245,7 +240,7 @@ class DynamicContainer(object):
"""Initialize all container resources."""
futures = []
for provider in self.providers.values():
if not isinstance(provider, Resource):
if not isinstance(provider, providers.Resource):
continue
resource = provider.init()
@ -260,7 +255,7 @@ class DynamicContainer(object):
"""Shutdown all container resources."""
futures = []
for provider in self.providers.values():
if not isinstance(provider, Resource):
if not isinstance(provider, providers.Resource):
continue
shutdown = provider.shutdown()
@ -274,7 +269,7 @@ class DynamicContainer(object):
def apply_container_providers_overridings(self):
"""Apply container providers' overridings."""
for provider in self.providers.values():
if not isinstance(provider, ContainerProvider):
if not isinstance(provider, providers.Container):
continue
provider.apply_overridings()
@ -293,7 +288,7 @@ class DeclarativeContainerMetaClass(type):
cls_providers = {
name: provider
for name, provider in six.iteritems(attributes)
if isinstance(provider, Provider)
if isinstance(provider, providers.Provider)
}
inherited_providers = {
@ -303,18 +298,18 @@ class DeclarativeContainerMetaClass(type):
for name, provider in six.iteritems(base.providers)
}
providers = {}
providers.update(inherited_providers)
providers.update(cls_providers)
all_providers = {}
all_providers.update(inherited_providers)
all_providers.update(cls_providers)
attributes['containers'] = containers
attributes['inherited_providers'] = inherited_providers
attributes['cls_providers'] = cls_providers
attributes['providers'] = providers
attributes['providers'] = all_providers
cls = <type>type.__new__(mcs, class_name, bases, attributes)
cls.__self__ = Object(cls)
cls.__self__ = providers.Object(cls)
for provider in six.itervalues(cls.providers):
_check_provider_type(cls, provider)
@ -335,7 +330,7 @@ class DeclarativeContainerMetaClass(type):
:rtype: None
"""
if isinstance(value, Provider) and name != '__self__':
if isinstance(value, providers.Provider) and name != '__self__':
_check_provider_type(cls, value)
cls.providers[name] = value
cls.cls_providers[name] = value
@ -370,9 +365,13 @@ class DeclarativeContainerMetaClass(type):
return {
name: provider
for name, provider in cls.providers.items()
if isinstance(provider, (Dependency, DependenciesContainer))
if isinstance(provider, (providers.Dependency, providers.DependenciesContainer))
}
def traverse(cls, types=None):
"""Return providers traversal generator."""
yield from providers.traverse(*cls.providers.values(), types=types)
@six.add_metaclass(DeclarativeContainerMetaClass)
class DeclarativeContainer(object):
@ -388,7 +387,7 @@ class DeclarativeContainer(object):
__IS_CONTAINER__ = True
provider_type = Provider
provider_type = providers.Provider
"""Type of providers that could be placed in container.
:type: type
@ -446,7 +445,7 @@ class DeclarativeContainer(object):
container = cls.instance_type()
container.provider_type = cls.provider_type
container.declarative_parent = cls
container.set_providers(**deepcopy(cls.providers))
container.set_providers(**providers.deepcopy(cls.providers))
container.override_providers(**overriding_providers)
container.apply_container_providers_overridings()
return container
@ -464,7 +463,7 @@ class DeclarativeContainer(object):
:rtype: None
"""
if issubclass(cls, overriding):
raise Error('Container {0} could not be overridden '
raise errors.Error('Container {0} could not be overridden '
'with itself or its subclasses'.format(cls))
cls.overridden += (overriding,)
@ -482,7 +481,7 @@ class DeclarativeContainer(object):
:rtype: None
"""
if not cls.overridden:
raise Error('Container {0} is not overridden'.format(cls))
raise errors.Error('Container {0} is not overridden'.format(cls))
cls.overridden = cls.overridden[:-1]
@ -559,7 +558,7 @@ def copy(object container):
def _decorator(copied_container):
memo = _get_providers_memo(copied_container.cls_providers, container.providers)
providers_copy = deepcopy(container.providers, memo)
providers_copy = providers.deepcopy(container.providers, memo)
for name, provider in six.iteritems(providers_copy):
setattr(copied_container, name, provider)
@ -580,7 +579,7 @@ cpdef bint is_container(object instance):
cpdef object _check_provider_type(object container, object provider):
if not isinstance(provider, container.provider_type):
raise Error('{0} can contain only {1} '
raise errors.Error('{0} can contain only {1} '
'instances'.format(container, container.provider_type))

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,7 @@ from typing import (
Optional,
Union,
Coroutine as _Coroutine,
Iterable as _Iterable,
Iterator as _Iterator,
AsyncIterator as _AsyncIterator,
Generator as _Generator,
@ -68,6 +69,9 @@ class Provider(Generic[T]):
def is_async_mode_enabled(self) -> bool: ...
def is_async_mode_disabled(self) -> bool: ...
def is_async_mode_undefined(self) -> bool: ...
@property
def related(self) -> _Iterator[Provider]: ...
def traverse(self, types: Optional[_Iterable[Type]] = None) -> _Iterator[Provider]: ...
def _copy_overridings(self, copied: Provider, memo: Optional[_Dict[Any, Any]]) -> None: ...
@ -238,6 +242,8 @@ class BaseSingleton(Provider[T]):
@property
def cls(self) -> T: ...
@property
def provides(self) -> T: ...
@property
def args(self) -> Tuple[Injection]: ...
def add_args(self, *args: Injection) -> BaseSingleton[T]: ...
def set_args(self, *args: Injection) -> BaseSingleton[T]: ...
@ -313,6 +319,8 @@ class Resource(Provider[T]):
@overload
def __init__(self, initializer: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ...
@property
def initializer(self) -> _Callable[..., Any]: ...
@property
def args(self) -> Tuple[Injection]: ...
def add_args(self, *args: Injection) -> Resource[T]: ...
def set_args(self, *args: Injection) -> Resource[T]: ...
@ -382,6 +390,9 @@ def deepcopy(instance: Any, memo: Optional[_Dict[Any, Any]] = None): Any: ...
def merge_dicts(dict1: _Dict[Any, Any], dict2: _Dict[Any, Any]) -> _Dict[Any, Any]: ...
def traverse(*providers: Provider, types: Optional[_Iterable[Type]]=None) -> _Iterator[Provider]: ...
if yaml:
class YamlLoader(yaml.SafeLoader): ...
else:

View File

@ -363,6 +363,15 @@ cdef class Provider(object):
"""Check if async mode is undefined."""
return self.__async_mode == ASYNC_MODE_UNDEFINED
@property
def related(self):
"""Return related providers generator."""
yield from self.overridden
def traverse(self, types=None):
"""Return providers traversal generator."""
return traverse(*self.related, types=types)
cpdef object _provide(self, tuple args, dict kwargs):
"""Providing strategy implementation.
@ -423,6 +432,13 @@ cdef class Object(Provider):
"""
return self.__str__()
@property
def related(self):
"""Return related providers generator."""
if isinstance(self.__provides, Provider):
yield self.__provides
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return provided instance.
@ -487,6 +503,12 @@ cdef class Delegate(Provider):
"""Return provider."""
return self.__provides
@property
def related(self):
"""Return related providers generator."""
yield self.__provides
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return provided instance.
@ -619,6 +641,13 @@ cdef class Dependency(Provider):
"""Return default provider."""
return self.__default
@property
def related(self):
"""Return related providers generator."""
if self.__default is not UNDEFINED:
yield self.__default
yield from super().related
def provided_by(self, provider):
"""Set external dependency provider.
@ -790,6 +819,12 @@ cdef class DependenciesContainer(Object):
child.reset_override()
super(DependenciesContainer, self).reset_override()
@property
def related(self):
"""Return related providers generator."""
yield from self.providers.values()
yield from super().related
cpdef object _override_providers(self, object container):
"""Override providers with providers from provided container."""
for name, dependency_provider in container.providers.items():
@ -998,6 +1033,14 @@ cdef class Callable(Provider):
self.__kwargs_len = len(self.__kwargs)
return self
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.provides])
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return result of provided callable's call."""
return __callable_call(self, args, kwargs)
@ -1442,6 +1485,13 @@ cdef class ConfigurationOption(Provider):
self.override(value)
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, self.__name)
yield from self.__children.values()
yield from super().related
def _is_strict_mode_enabled(self):
return self.__root.__strict
@ -1763,6 +1813,12 @@ cdef class Configuration(Object):
self.override(value)
@property
def related(self):
"""Return related providers generator."""
yield from self.__children.values()
yield from super().related
def _is_strict_mode_enabled(self):
return self.__strict
@ -1979,6 +2035,15 @@ cdef class Factory(Provider):
self.__attributes_len = len(self.__attributes)
return self
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.provides])
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from filter(is_provider, self.attributes.values())
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return new instance."""
return __factory_call(self, args, kwargs)
@ -2141,6 +2206,12 @@ cdef class FactoryAggregate(Provider):
raise Error(
'{0} providers could not be overridden'.format(self.__class__))
@property
def related(self):
"""Return related providers generator."""
yield from self.__factories.values()
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
try:
factory_name = args[0]
@ -2212,7 +2283,12 @@ cdef class BaseSingleton(Provider):
@property
def cls(self):
"""Return provided type."""
return self.__instantiator.cls
return self.provides
@property
def provides(self):
"""Return provided type."""
return self.__instantiator.provides
@property
def args(self):
@ -2314,6 +2390,15 @@ cdef class BaseSingleton(Provider):
"""
raise NotImplementedError()
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.__instantiator.provides])
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from filter(is_provider, self.attributes.values())
yield from super().related
def _async_init_instance(self, future_result, result):
try:
instance = result.result()
@ -2741,6 +2826,12 @@ cdef class List(Provider):
self.__args_len = len(self.__args)
return self
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, self.args)
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return result of provided callable's call."""
return list(__provide_positional_args(args, self.__args, self.__args_len))
@ -2853,6 +2944,12 @@ cdef class Dict(Provider):
self.__kwargs_len = len(self.__kwargs)
return self
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, self.kwargs.values())
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return result of provided callable's call."""
return __provide_keyword_args(kwargs, self.__kwargs, self.__kwargs_len)
@ -2895,11 +2992,17 @@ cdef class Resource(Provider):
return copied
def __repr__(self):
return (
f'{self.__class__.__name__}({self.__initializer}, '
f'initialized={self.__initialized})'
)
def __str__(self):
"""Return string representation of provider.
:rtype: str
"""
return represent_provider(provider=self, provides=self.__initializer)
@property
def initializer(self):
"""Return initializer."""
return self.__initializer
@property
def args(self):
@ -3021,6 +3124,14 @@ cdef class Resource(Provider):
result.set_result(None)
return result
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.__initializer])
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
if self.__initialized:
return self.__resource
@ -3246,6 +3357,12 @@ cdef class Container(Provider):
declarative container initialization."""
self.__container.override_providers(**self.__overriding_providers)
@property
def related(self):
"""Return related providers generator."""
yield from self.providers.values()
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return single instance."""
return self.__container
@ -3335,6 +3452,13 @@ cdef class Selector(Provider):
"""Return providers."""
return dict(self.__providers)
@property
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.__selector])
yield from self.providers.values()
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return single instance."""
selector_value = self.__selector()
@ -3411,6 +3535,12 @@ cdef class ProvidedInstance(Provider):
def call(self, *args, **kwargs):
return MethodCaller(self, *args, **kwargs)
@property
def related(self):
"""Return related providers generator."""
yield self.__provider
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
return self.__provider(*args, **kwargs)
@ -3460,6 +3590,12 @@ cdef class AttributeGetter(Provider):
def call(self, *args, **kwargs):
return MethodCaller(self, *args, **kwargs)
@property
def related(self):
"""Return related providers generator."""
yield self.__provider
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
provided = self.__provider(*args, **kwargs)
if __isawaitable(provided):
@ -3520,6 +3656,12 @@ cdef class ItemGetter(Provider):
def call(self, *args, **kwargs):
return MethodCaller(self, *args, **kwargs)
@property
def related(self):
"""Return related providers generator."""
yield self.__provider
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
provided = self.__provider(*args, **kwargs)
if __isawaitable(provided):
@ -3612,6 +3754,14 @@ cdef class MethodCaller(Provider):
def call(self, *args, **kwargs):
return MethodCaller(self, *args, **kwargs)
@property
def related(self):
"""Return related providers generator."""
yield self.__provider
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
call = self.__provider()
if __isawaitable(call):
@ -3846,6 +3996,29 @@ def merge_dicts(dict1, dict2):
return result
def traverse(*providers, types=None):
"""Return providers traversal generator."""
visited = set()
to_visit = set(providers)
if types:
types = tuple(types)
while len(to_visit) > 0:
visiting = to_visit.pop()
visited.add(visiting)
for child in visiting.related:
if child in visited:
continue
to_visit.add(child)
if types and not isinstance(visiting, types):
continue
yield visiting
def isawaitable(obj):
"""Check if object is a coroutine function.

View File

@ -0,0 +1,93 @@
import unittest
from dependency_injector import containers, providers
class TraverseProviderTests(unittest.TestCase):
def test_nested_providers(self):
class Container(containers.DeclarativeContainer):
obj_factory = providers.DelegatedFactory(
dict,
foo=providers.Resource(
dict,
foo='bar'
),
bar=providers.Resource(
dict,
foo='bar'
)
)
container = Container()
all_providers = list(container.traverse())
self.assertIn(container.obj_factory, all_providers)
self.assertIn(container.obj_factory.kwargs['foo'], all_providers)
self.assertIn(container.obj_factory.kwargs['bar'], all_providers)
self.assertEqual(len(all_providers), 3)
def test_nested_providers_with_filtering(self):
class Container(containers.DeclarativeContainer):
obj_factory = providers.DelegatedFactory(
dict,
foo=providers.Resource(
dict,
foo='bar'
),
bar=providers.Resource(
dict,
foo='bar'
)
)
container = Container()
all_providers = list(container.traverse(types=[providers.Resource]))
self.assertIn(container.obj_factory.kwargs['foo'], all_providers)
self.assertIn(container.obj_factory.kwargs['bar'], all_providers)
self.assertEqual(len(all_providers), 2)
class TraverseProviderDeclarativeTests(unittest.TestCase):
def test_nested_providers(self):
class Container(containers.DeclarativeContainer):
obj_factory = providers.DelegatedFactory(
dict,
foo=providers.Resource(
dict,
foo='bar'
),
bar=providers.Resource(
dict,
foo='bar'
)
)
all_providers = list(Container.traverse())
self.assertIn(Container.obj_factory, all_providers)
self.assertIn(Container.obj_factory.kwargs['foo'], all_providers)
self.assertIn(Container.obj_factory.kwargs['bar'], all_providers)
self.assertEqual(len(all_providers), 3)
def test_nested_providers_with_filtering(self):
class Container(containers.DeclarativeContainer):
obj_factory = providers.DelegatedFactory(
dict,
foo=providers.Resource(
dict,
foo='bar'
),
bar=providers.Resource(
dict,
foo='bar'
)
)
all_providers = list(Container.traverse(types=[providers.Resource]))
self.assertIn(Container.obj_factory.kwargs['foo'], all_providers)
self.assertIn(Container.obj_factory.kwargs['bar'], all_providers)
self.assertEqual(len(all_providers), 2)

View File

@ -337,9 +337,9 @@ class ResourceTests(unittest.TestCase):
self.assertEqual(
repr(provider),
'Resource({0}, initialized={1})'.format(
init_fn,
provider.initialized,
'<dependency_injector.providers.Resource({0}) at {1}>'.format(
repr(init_fn),
hex(id(provider)),
)
)

View File

@ -0,0 +1,875 @@
import unittest
from dependency_injector import containers, providers
class TraverseTests(unittest.TestCase):
def test_traverse_cycled_graph(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider2.override(provider1)
provider3 = providers.Provider()
provider3.override(provider2)
provider1.override(provider3) # Cycle: provider3 -> provider2 -> provider1 -> provider3
all_providers = list(providers.traverse(provider1))
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
def test_traverse_types_filtering(self):
provider1 = providers.Resource(dict)
provider2 = providers.Resource(dict)
provider3 = providers.Provider()
provider = providers.Provider()
provider.override(provider1)
provider.override(provider2)
provider.override(provider3)
all_providers = list(providers.traverse(provider, types=[providers.Resource]))
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
class ProviderTests(unittest.TestCase):
def test_traversal_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider3 = providers.Provider()
provider = providers.Provider()
provider.override(provider1)
provider.override(provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
def test_traversal_overriding_nested(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider2.override(provider1)
provider3 = providers.Provider()
provider3.override(provider2)
provider = providers.Provider()
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
def test_traverse_types_filtering(self):
provider1 = providers.Resource(dict)
provider2 = providers.Resource(dict)
provider3 = providers.Provider()
provider = providers.Provider()
provider.override(provider1)
provider.override(provider2)
provider.override(provider3)
all_providers = list(provider.traverse(types=[providers.Resource]))
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
class ObjectTests(unittest.TestCase):
def test_traversal(self):
provider = providers.Object('string')
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traversal_provider(self):
another_provider = providers.Provider()
provider = providers.Object(another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_provider_and_overriding(self):
another_provider_1 = providers.Provider()
another_provider_2 = providers.Provider()
another_provider_3 = providers.Provider()
provider = providers.Object(another_provider_1)
provider.override(another_provider_2)
provider.override(another_provider_3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(another_provider_1, all_providers)
self.assertIn(another_provider_2, all_providers)
self.assertIn(another_provider_3, all_providers)
class DelegateTests(unittest.TestCase):
def test_traversal_provider(self):
another_provider = providers.Provider()
provider = providers.Delegate(another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_provider_and_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider3 = providers.Provider()
provider3.override(provider2)
provider = providers.Delegate(provider1)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class DependencyTests(unittest.TestCase):
def test_traversal(self):
provider = providers.Dependency()
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traversal_default(self):
another_provider = providers.Provider()
provider = providers.Dependency(default=another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider2.override(provider1)
provider = providers.Dependency()
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
class DependenciesContainerTests(unittest.TestCase):
def test_traversal(self):
provider = providers.DependenciesContainer()
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traversal_default(self):
another_provider = providers.Provider()
provider = providers.DependenciesContainer(default=another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_fluent_interface(self):
provider = providers.DependenciesContainer()
provider1 = provider.provider1
provider2 = provider.provider2
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traversal_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider3 = providers.DependenciesContainer(
provider1=provider1,
provider2=provider2,
)
provider = providers.DependenciesContainer()
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 5)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
self.assertIn(provider.provider1, all_providers)
self.assertIn(provider.provider2, all_providers)
class CallableTests(unittest.TestCase):
def test_traverse(self):
provider = providers.Callable(dict)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Callable(list, 'foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Callable(dict, foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Callable(dict, 'foo')
provider.override(provider1)
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_provides(self):
provider1 = providers.Callable(list)
provider2 = providers.Object('bar')
provider3 = providers.Object('baz')
provider = providers.Callable(provider1, provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class ConfigurationTests(unittest.TestCase):
def test_traverse(self):
config = providers.Configuration(default={'option1': {'option2': 'option2'}})
option1 = config.option1
option2 = config.option1.option2
option3 = config.option1[config.option1.option2]
all_providers = list(config.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(option1, all_providers)
self.assertIn(option2, all_providers)
self.assertIn(option3, all_providers)
def test_traverse_typed(self):
config = providers.Configuration()
option = config.option
typed_option = config.option.as_int()
all_providers = list(typed_option.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(option, all_providers)
def test_traverse_overridden(self):
options = {'option1': {'option2': 'option2'}}
config = providers.Configuration()
config.from_dict(options)
all_providers = list(config.traverse())
self.assertEqual(len(all_providers), 1)
overridden, = all_providers
self.assertEqual(overridden(), options)
self.assertIs(overridden, config.last_overriding)
def test_traverse_overridden_option_1(self):
options = {'option2': 'option2'}
config = providers.Configuration()
config.option1.from_dict(options)
all_providers = list(config.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(config.option1, all_providers)
self.assertIn(config.last_overriding, all_providers)
def test_traverse_overridden_option_2(self):
options = {'option2': 'option2'}
config = providers.Configuration()
config.option1.from_dict(options)
all_providers = list(config.option1.traverse())
self.assertEqual(len(all_providers), 0)
class FactoryTests(unittest.TestCase):
def test_traverse(self):
provider = providers.Factory(dict)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(list, 'foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(dict, foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_attributes(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(dict)
provider.add_attributes(foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Factory(dict, 'foo')
provider.override(provider1)
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_provides(self):
provider1 = providers.Callable(list)
provider2 = providers.Object('bar')
provider3 = providers.Object('baz')
provider = providers.Factory(provider1, provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class FactoryAggregateTests(unittest.TestCase):
def test_traverse(self):
factory1 = providers.Factory(dict)
factory2 = providers.Factory(list)
provider = providers.FactoryAggregate(factory1=factory1, factory2=factory2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(factory1, all_providers)
self.assertIn(factory2, all_providers)
class BaseSingletonTests(unittest.TestCase):
def test_traverse(self):
provider = providers.Singleton(dict)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Singleton(list, 'foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Singleton(dict, foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_attributes(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Singleton(dict)
provider.add_attributes(foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Singleton(dict, 'foo')
provider.override(provider1)
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_provides(self):
provider1 = providers.Callable(list)
provider2 = providers.Object('bar')
provider3 = providers.Object('baz')
provider = providers.Singleton(provider1, provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class ListTests(unittest.TestCase):
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.List('foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider3 = providers.List(provider1, provider2)
provider = providers.List('foo')
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class DictTests(unittest.TestCase):
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Dict(foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider3 = providers.Dict(bar=provider1, baz=provider2)
provider = providers.Dict(foo='foo')
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class ResourceTests(unittest.TestCase):
def test_traverse(self):
provider = providers.Resource(dict)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traverse_args(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Resource(list, 'foo', provider1, provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Object('bar')
provider2 = providers.Object('baz')
provider = providers.Resource(dict, foo='foo', bar=provider1, baz=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Resource(list)
provider2 = providers.Resource(tuple)
provider = providers.Resource(dict, 'foo')
provider.override(provider1)
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_provides(self):
provider1 = providers.Callable(list)
provider = providers.Resource(provider1)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(provider1, all_providers)
class ContainerTests(unittest.TestCase):
def test_traverse(self):
class Container(containers.DeclarativeContainer):
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
provider = providers.Container(Container)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertEqual(
{provider.provides for provider in all_providers},
{list, dict},
)
def test_traverse_overridden(self):
class Container1(containers.DeclarativeContainer):
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
class Container2(containers.DeclarativeContainer):
provider1 = providers.Callable(tuple)
provider2 = providers.Callable(str)
container2 = Container2()
provider = providers.Container(Container1)
provider.override(container2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 5)
self.assertEqual(
{
provider.provides
for provider in all_providers
if isinstance(provider, providers.Callable)
},
{list, dict, tuple, str},
)
self.assertIn(provider.last_overriding, all_providers)
self.assertIs(provider.last_overriding(), container2)
class SelectorTests(unittest.TestCase):
def test_traverse(self):
switch = lambda: 'provider1'
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
provider = providers.Selector(
switch,
provider1=provider1,
provider2=provider2,
)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_switch(self):
switch = providers.Callable(lambda: 'provider1')
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
provider = providers.Selector(
switch,
provider1=provider1,
provider2=provider2,
)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(switch, all_providers)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Callable(list)
provider2 = providers.Callable(dict)
selector1 = providers.Selector(lambda: 'provider1', provider1=provider1)
provider = providers.Selector(
lambda: 'provider2',
provider2=provider2,
)
provider.override(selector1)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(selector1, all_providers)
class ProvidedInstanceTests(unittest.TestCase):
def test_traverse(self):
provider1 = providers.Provider()
provider = provider1.provided
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(provider1, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider = provider1.provided
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
class AttributeGetterTests(unittest.TestCase):
def test_traverse(self):
provider1 = providers.Provider()
provided = provider1.provided
provider = provided.attr
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provided, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Provider()
provided = provider1.provided
provider2 = providers.Provider()
provider = provided.attr
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
class ItemGetterTests(unittest.TestCase):
def test_traverse(self):
provider1 = providers.Provider()
provided = provider1.provided
provider = provided['item']
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provided, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Provider()
provided = provider1.provided
provider2 = providers.Provider()
provider = provided['item']
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
class MethodCallerTests(unittest.TestCase):
def test_traverse(self):
provider1 = providers.Provider()
provided = provider1.provided
method = provided.method
provider = method.call()
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provided, all_providers)
self.assertIn(method, all_providers)
def test_traverse_args(self):
provider1 = providers.Provider()
provided = provider1.provided
method = provided.method
provider2 = providers.Provider()
provider = method.call('foo', provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 4)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
self.assertIn(method, all_providers)
def test_traverse_kwargs(self):
provider1 = providers.Provider()
provided = provider1.provided
method = provided.method
provider2 = providers.Provider()
provider = method.call(foo='foo', bar=provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 4)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
self.assertIn(method, all_providers)
def test_traverse_overridden(self):
provider1 = providers.Provider()
provided = provider1.provided
method = provided.method
provider2 = providers.Provider()
provider = method.call()
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 4)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provided, all_providers)
self.assertIn(method, all_providers)