Update traverse interface + add some tests

This commit is contained in:
Roman Mogylatov 2021-01-30 16:24:37 -05:00
parent 0d4b9574e6
commit 897f2d3110
4 changed files with 6746 additions and 6202 deletions

File diff suppressed because it is too large Load Diff

View File

@ -13,8 +13,8 @@ from typing import (
Dict as _Dict,
Optional,
Union,
Sequence,
Coroutine as _Coroutine,
Iterable as _Iterable,
Iterator as _Iterator,
AsyncIterator as _AsyncIterator,
Generator as _Generator,
@ -70,7 +70,8 @@ class Provider(Generic[T]):
def is_async_mode_disabled(self) -> bool: ...
def is_async_mode_undefined(self) -> bool: ...
@property
def providers_traversal(self) -> _Iterator[Provider]: ...
def related(self) -> _Iterator[Provider]: ...
def traverse(self, types: Optional[_Iterable[Type]] = None) -> _Iterator[Provider]: ...
def _copy_overridings(self, copied: Provider, memo: Optional[_Dict[Any, Any]]) -> None: ...
@ -385,7 +386,7 @@ def deepcopy(instance: Any, memo: Optional[_Dict[Any, Any]] = None): Any: ...
def merge_dicts(dict1: _Dict[Any, Any], dict2: _Dict[Any, Any]) -> _Dict[Any, Any]: ...
def traverse(*providers: Provider, types: Optional[Sequence[Type]]=None) -> _Iterator[Provider]: ...
def traverse(*providers: Provider, types: Optional[_Iterable[Type]]=None) -> _Iterator[Provider]: ...
if yaml:

View File

@ -364,10 +364,14 @@ cdef class Provider(object):
return self.__async_mode == ASYNC_MODE_UNDEFINED
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield from self.overridden
def traverse(self, types=None):
"""Return providers traversal generator."""
return traverse(*self.related)
cpdef object _provide(self, tuple args, dict kwargs):
"""Providing strategy implementation.
@ -429,11 +433,11 @@ cdef class Object(Provider):
return self.__str__()
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
if isinstance(self.__provides, Provider):
yield self.__provides
yield from super().providers_traversal
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return provided instance.
@ -500,10 +504,10 @@ cdef class Delegate(Provider):
return self.__provides
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield self.__provides
yield from super().providers_traversal
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return provided instance.
@ -638,11 +642,11 @@ cdef class Dependency(Provider):
return self.__default
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
if self.__default is not UNDEFINED:
yield self.__default
yield from super().providers_traversal
yield from super().related
def provided_by(self, provider):
"""Set external dependency provider.
@ -816,10 +820,10 @@ cdef class DependenciesContainer(Object):
super(DependenciesContainer, self).reset_override()
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield from self.providers.values()
yield from super().providers_traversal
yield from super().related
cpdef object _override_providers(self, object container):
"""Override providers with providers from provided container."""
@ -1030,12 +1034,12 @@ cdef class Callable(Provider):
return self
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.provides])
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from super().providers_traversal
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return result of provided callable's call."""
@ -1482,12 +1486,12 @@ cdef class ConfigurationOption(Provider):
self.override(value)
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, self.__name)
yield from [self.__root]
yield from self.__children.values()
yield from super().providers_traversal
yield from super().related
def _is_strict_mode_enabled(self):
return self.__root.__strict
@ -1811,10 +1815,10 @@ cdef class Configuration(Object):
self.override(value)
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield from self.__children.values()
yield from super().providers_traversal
yield from super().related
def _is_strict_mode_enabled(self):
return self.__strict
@ -2033,13 +2037,13 @@ cdef class Factory(Provider):
return self
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.provides])
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from filter(is_provider, self.attributes.values())
yield from super().providers_traversal
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return new instance."""
@ -2204,10 +2208,10 @@ cdef class FactoryAggregate(Provider):
'{0} providers could not be overridden'.format(self.__class__))
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield from self.__factories.values()
yield from super().providers_traversal
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
try:
@ -2383,10 +2387,13 @@ cdef class BaseSingleton(Provider):
raise NotImplementedError()
@property
def providers_traversal(self):
"""Return providers traversal generator."""
yield self.__instantiator
yield from super().providers_traversal
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.__instantiator.provs])
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from filter(is_provider, self.attributes.values())
yield from super().related
def _async_init_instance(self, future_result, result):
try:
@ -2816,10 +2823,10 @@ cdef class List(Provider):
return self
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, self.args)
yield from super().providers_traversal
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return result of provided callable's call."""
@ -2934,10 +2941,10 @@ cdef class Dict(Provider):
return self
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, self.kwargs.values())
yield from super().providers_traversal
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return result of provided callable's call."""
@ -3108,12 +3115,12 @@ cdef class Resource(Provider):
return result
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.__initializer])
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from super().providers_traversal
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
if self.__initialized:
@ -3341,10 +3348,10 @@ cdef class Container(Provider):
self.__container.override_providers(**self.__overriding_providers)
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield from self.providers.values()
yield from super().providers_traversal
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return single instance."""
@ -3436,11 +3443,11 @@ cdef class Selector(Provider):
return dict(self.__providers)
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield from filter(is_provider, [self.__selector])
yield from self.providers.values()
yield from super().providers_traversal
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
"""Return single instance."""
@ -3519,10 +3526,10 @@ cdef class ProvidedInstance(Provider):
return MethodCaller(self, *args, **kwargs)
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield self.__provider
yield from super().providers_traversal
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
return self.__provider(*args, **kwargs)
@ -3574,10 +3581,10 @@ cdef class AttributeGetter(Provider):
return MethodCaller(self, *args, **kwargs)
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield self.__provider
yield from super().providers_traversal
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
provided = self.__provider(*args, **kwargs)
@ -3640,10 +3647,10 @@ cdef class ItemGetter(Provider):
return MethodCaller(self, *args, **kwargs)
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield self.__provider
yield from super().providers_traversal
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
provided = self.__provider(*args, **kwargs)
@ -3738,12 +3745,12 @@ cdef class MethodCaller(Provider):
return MethodCaller(self, *args, **kwargs)
@property
def providers_traversal(self):
"""Return providers traversal generator."""
def related(self):
"""Return related providers generator."""
yield self.__provider
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from super().providers_traversal
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
call = self.__provider()
@ -3989,9 +3996,10 @@ def traverse(*providers, types=None):
while len(to_visit) > 0:
visiting = to_visit.pop()
visited.add(visiting)
for child in visiting.providers_traversal:
for child in visiting.related:
if child in visited:
continue
to_visit.add(child)

View File

@ -0,0 +1,209 @@
import unittest
from dependency_injector import providers
class ProviderTests(unittest.TestCase):
def test_traversal_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider3 = providers.Provider()
provider = providers.Provider()
provider.override(provider1)
provider.override(provider2)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
def test_traversal_overriding_nested(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider2.override(provider1)
provider3 = providers.Provider()
provider3.override(provider2)
provider = providers.Provider()
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
def test_traversal_overriding_cycled(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider2.override(provider1)
provider3 = providers.Provider()
provider3.override(provider2)
provider1.override(provider3) # Cycle: provider3 -> provider2 -> provider1 -> provider3
provider = providers.Provider()
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class ObjectTests(unittest.TestCase):
def test_traversal(self):
provider = providers.Object('string')
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traversal_provider(self):
another_provider = providers.Provider()
provider = providers.Object(another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_provider_and_overriding(self):
another_provider_1 = providers.Provider()
another_provider_2 = providers.Provider()
another_provider_3 = providers.Provider()
provider = providers.Object(another_provider_1)
provider.override(another_provider_2)
provider.override(another_provider_3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(another_provider_1, all_providers)
self.assertIn(another_provider_2, all_providers)
self.assertIn(another_provider_3, all_providers)
class DelegateTests(unittest.TestCase):
def test_traversal_provider(self):
another_provider = providers.Provider()
provider = providers.Delegate(another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_provider_and_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider3 = providers.Provider()
provider3.override(provider2)
provider = providers.Delegate(provider1)
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
class DependencyTests(unittest.TestCase):
def test_traversal(self):
provider = providers.Dependency()
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traversal_default(self):
another_provider = providers.Provider()
provider = providers.Dependency(default=another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider2.override(provider1)
provider = providers.Dependency()
provider.override(provider2)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
class DependenciesContainerTests(unittest.TestCase):
def test_traversal(self):
provider = providers.DependenciesContainer()
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0)
def test_traversal_default(self):
another_provider = providers.Provider()
provider = providers.DependenciesContainer(default=another_provider)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1)
self.assertIn(another_provider, all_providers)
def test_traversal_fluent_interface(self):
provider = providers.DependenciesContainer()
provider1 = provider.provider1
provider2 = provider.provider2
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
def test_traversal_overriding(self):
provider1 = providers.Provider()
provider2 = providers.Provider()
provider3 = providers.DependenciesContainer(
provider1=provider1,
provider2=provider2,
)
provider = providers.DependenciesContainer()
provider.override(provider3)
all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 5)
self.assertIn(provider1, all_providers)
self.assertIn(provider2, all_providers)
self.assertIn(provider3, all_providers)
self.assertIn(provider.provider1, all_providers)
self.assertIn(provider.provider2, all_providers)