mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-07-04 20:33:13 +03:00
Update traverse interface + add some tests
This commit is contained in:
parent
0d4b9574e6
commit
897f2d3110
File diff suppressed because it is too large
Load Diff
|
@ -13,8 +13,8 @@ from typing import (
|
|||
Dict as _Dict,
|
||||
Optional,
|
||||
Union,
|
||||
Sequence,
|
||||
Coroutine as _Coroutine,
|
||||
Iterable as _Iterable,
|
||||
Iterator as _Iterator,
|
||||
AsyncIterator as _AsyncIterator,
|
||||
Generator as _Generator,
|
||||
|
@ -70,7 +70,8 @@ class Provider(Generic[T]):
|
|||
def is_async_mode_disabled(self) -> bool: ...
|
||||
def is_async_mode_undefined(self) -> bool: ...
|
||||
@property
|
||||
def providers_traversal(self) -> _Iterator[Provider]: ...
|
||||
def related(self) -> _Iterator[Provider]: ...
|
||||
def traverse(self, types: Optional[_Iterable[Type]] = None) -> _Iterator[Provider]: ...
|
||||
def _copy_overridings(self, copied: Provider, memo: Optional[_Dict[Any, Any]]) -> None: ...
|
||||
|
||||
|
||||
|
@ -385,7 +386,7 @@ def deepcopy(instance: Any, memo: Optional[_Dict[Any, Any]] = None): Any: ...
|
|||
def merge_dicts(dict1: _Dict[Any, Any], dict2: _Dict[Any, Any]) -> _Dict[Any, Any]: ...
|
||||
|
||||
|
||||
def traverse(*providers: Provider, types: Optional[Sequence[Type]]=None) -> _Iterator[Provider]: ...
|
||||
def traverse(*providers: Provider, types: Optional[_Iterable[Type]]=None) -> _Iterator[Provider]: ...
|
||||
|
||||
|
||||
if yaml:
|
||||
|
|
|
@ -364,10 +364,14 @@ cdef class Provider(object):
|
|||
return self.__async_mode == ASYNC_MODE_UNDEFINED
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from self.overridden
|
||||
|
||||
def traverse(self, types=None):
|
||||
"""Return providers traversal generator."""
|
||||
return traverse(*self.related)
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Providing strategy implementation.
|
||||
|
||||
|
@ -429,11 +433,11 @@ cdef class Object(Provider):
|
|||
return self.__str__()
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
if isinstance(self.__provides, Provider):
|
||||
yield self.__provides
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return provided instance.
|
||||
|
@ -500,10 +504,10 @@ cdef class Delegate(Provider):
|
|||
return self.__provides
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield self.__provides
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return provided instance.
|
||||
|
@ -638,11 +642,11 @@ cdef class Dependency(Provider):
|
|||
return self.__default
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
if self.__default is not UNDEFINED:
|
||||
yield self.__default
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
def provided_by(self, provider):
|
||||
"""Set external dependency provider.
|
||||
|
@ -816,10 +820,10 @@ cdef class DependenciesContainer(Object):
|
|||
super(DependenciesContainer, self).reset_override()
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from self.providers.values()
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _override_providers(self, object container):
|
||||
"""Override providers with providers from provided container."""
|
||||
|
@ -1030,12 +1034,12 @@ cdef class Callable(Provider):
|
|||
return self
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, [self.provides])
|
||||
yield from filter(is_provider, self.args)
|
||||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return result of provided callable's call."""
|
||||
|
@ -1482,12 +1486,12 @@ cdef class ConfigurationOption(Provider):
|
|||
self.override(value)
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, self.__name)
|
||||
yield from [self.__root]
|
||||
yield from self.__children.values()
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
def _is_strict_mode_enabled(self):
|
||||
return self.__root.__strict
|
||||
|
@ -1811,10 +1815,10 @@ cdef class Configuration(Object):
|
|||
self.override(value)
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from self.__children.values()
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
def _is_strict_mode_enabled(self):
|
||||
return self.__strict
|
||||
|
@ -2033,13 +2037,13 @@ cdef class Factory(Provider):
|
|||
return self
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, [self.provides])
|
||||
yield from filter(is_provider, self.args)
|
||||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from filter(is_provider, self.attributes.values())
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return new instance."""
|
||||
|
@ -2204,10 +2208,10 @@ cdef class FactoryAggregate(Provider):
|
|||
'{0} providers could not be overridden'.format(self.__class__))
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from self.__factories.values()
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
try:
|
||||
|
@ -2383,10 +2387,13 @@ cdef class BaseSingleton(Provider):
|
|||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
yield self.__instantiator
|
||||
yield from super().providers_traversal
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, [self.__instantiator.provs])
|
||||
yield from filter(is_provider, self.args)
|
||||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from filter(is_provider, self.attributes.values())
|
||||
yield from super().related
|
||||
|
||||
def _async_init_instance(self, future_result, result):
|
||||
try:
|
||||
|
@ -2816,10 +2823,10 @@ cdef class List(Provider):
|
|||
return self
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, self.args)
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return result of provided callable's call."""
|
||||
|
@ -2934,10 +2941,10 @@ cdef class Dict(Provider):
|
|||
return self
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return result of provided callable's call."""
|
||||
|
@ -3108,12 +3115,12 @@ cdef class Resource(Provider):
|
|||
return result
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, [self.__initializer])
|
||||
yield from filter(is_provider, self.args)
|
||||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
if self.__initialized:
|
||||
|
@ -3341,10 +3348,10 @@ cdef class Container(Provider):
|
|||
self.__container.override_providers(**self.__overriding_providers)
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from self.providers.values()
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return single instance."""
|
||||
|
@ -3436,11 +3443,11 @@ cdef class Selector(Provider):
|
|||
return dict(self.__providers)
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield from filter(is_provider, [self.__selector])
|
||||
yield from self.providers.values()
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return single instance."""
|
||||
|
@ -3519,10 +3526,10 @@ cdef class ProvidedInstance(Provider):
|
|||
return MethodCaller(self, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield self.__provider
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
return self.__provider(*args, **kwargs)
|
||||
|
@ -3574,10 +3581,10 @@ cdef class AttributeGetter(Provider):
|
|||
return MethodCaller(self, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield self.__provider
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
provided = self.__provider(*args, **kwargs)
|
||||
|
@ -3640,10 +3647,10 @@ cdef class ItemGetter(Provider):
|
|||
return MethodCaller(self, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield self.__provider
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
provided = self.__provider(*args, **kwargs)
|
||||
|
@ -3738,12 +3745,12 @@ cdef class MethodCaller(Provider):
|
|||
return MethodCaller(self, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
def related(self):
|
||||
"""Return related providers generator."""
|
||||
yield self.__provider
|
||||
yield from filter(is_provider, self.args)
|
||||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from super().providers_traversal
|
||||
yield from super().related
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
call = self.__provider()
|
||||
|
@ -3989,9 +3996,10 @@ def traverse(*providers, types=None):
|
|||
|
||||
while len(to_visit) > 0:
|
||||
visiting = to_visit.pop()
|
||||
|
||||
visited.add(visiting)
|
||||
|
||||
for child in visiting.providers_traversal:
|
||||
for child in visiting.related:
|
||||
if child in visited:
|
||||
continue
|
||||
to_visit.add(child)
|
||||
|
|
209
tests/unit/providers/test_traversal_py3.py
Normal file
209
tests/unit/providers/test_traversal_py3.py
Normal file
|
@ -0,0 +1,209 @@
|
|||
import unittest
|
||||
|
||||
from dependency_injector import providers
|
||||
|
||||
|
||||
class ProviderTests(unittest.TestCase):
|
||||
|
||||
def test_traversal_overriding(self):
|
||||
provider1 = providers.Provider()
|
||||
provider2 = providers.Provider()
|
||||
provider3 = providers.Provider()
|
||||
|
||||
provider = providers.Provider()
|
||||
|
||||
provider.override(provider1)
|
||||
provider.override(provider2)
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
|
||||
def test_traversal_overriding_nested(self):
|
||||
provider1 = providers.Provider()
|
||||
|
||||
provider2 = providers.Provider()
|
||||
provider2.override(provider1)
|
||||
|
||||
provider3 = providers.Provider()
|
||||
provider3.override(provider2)
|
||||
|
||||
provider = providers.Provider()
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
|
||||
def test_traversal_overriding_cycled(self):
|
||||
provider1 = providers.Provider()
|
||||
|
||||
provider2 = providers.Provider()
|
||||
provider2.override(provider1)
|
||||
|
||||
provider3 = providers.Provider()
|
||||
provider3.override(provider2)
|
||||
|
||||
provider1.override(provider3) # Cycle: provider3 -> provider2 -> provider1 -> provider3
|
||||
|
||||
provider = providers.Provider()
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
|
||||
|
||||
class ObjectTests(unittest.TestCase):
|
||||
|
||||
def test_traversal(self):
|
||||
provider = providers.Object('string')
|
||||
all_providers = list(provider.traverse())
|
||||
self.assertEqual(len(all_providers), 0)
|
||||
|
||||
def test_traversal_provider(self):
|
||||
another_provider = providers.Provider()
|
||||
provider = providers.Object(another_provider)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 1)
|
||||
self.assertIn(another_provider, all_providers)
|
||||
|
||||
def test_traversal_provider_and_overriding(self):
|
||||
another_provider_1 = providers.Provider()
|
||||
another_provider_2 = providers.Provider()
|
||||
another_provider_3 = providers.Provider()
|
||||
|
||||
provider = providers.Object(another_provider_1)
|
||||
|
||||
provider.override(another_provider_2)
|
||||
provider.override(another_provider_3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(another_provider_1, all_providers)
|
||||
self.assertIn(another_provider_2, all_providers)
|
||||
self.assertIn(another_provider_3, all_providers)
|
||||
|
||||
|
||||
class DelegateTests(unittest.TestCase):
|
||||
|
||||
def test_traversal_provider(self):
|
||||
another_provider = providers.Provider()
|
||||
provider = providers.Delegate(another_provider)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 1)
|
||||
self.assertIn(another_provider, all_providers)
|
||||
|
||||
def test_traversal_provider_and_overriding(self):
|
||||
provider1 = providers.Provider()
|
||||
provider2 = providers.Provider()
|
||||
|
||||
provider3 = providers.Provider()
|
||||
provider3.override(provider2)
|
||||
|
||||
provider = providers.Delegate(provider1)
|
||||
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
|
||||
|
||||
class DependencyTests(unittest.TestCase):
|
||||
|
||||
def test_traversal(self):
|
||||
provider = providers.Dependency()
|
||||
all_providers = list(provider.traverse())
|
||||
self.assertEqual(len(all_providers), 0)
|
||||
|
||||
def test_traversal_default(self):
|
||||
another_provider = providers.Provider()
|
||||
provider = providers.Dependency(default=another_provider)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 1)
|
||||
self.assertIn(another_provider, all_providers)
|
||||
|
||||
def test_traversal_overriding(self):
|
||||
provider1 = providers.Provider()
|
||||
|
||||
provider2 = providers.Provider()
|
||||
provider2.override(provider1)
|
||||
|
||||
provider = providers.Dependency()
|
||||
provider.override(provider2)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
|
||||
class DependenciesContainerTests(unittest.TestCase):
|
||||
|
||||
def test_traversal(self):
|
||||
provider = providers.DependenciesContainer()
|
||||
all_providers = list(provider.traverse())
|
||||
self.assertEqual(len(all_providers), 0)
|
||||
|
||||
def test_traversal_default(self):
|
||||
another_provider = providers.Provider()
|
||||
provider = providers.DependenciesContainer(default=another_provider)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 1)
|
||||
self.assertIn(another_provider, all_providers)
|
||||
|
||||
def test_traversal_fluent_interface(self):
|
||||
provider = providers.DependenciesContainer()
|
||||
provider1 = provider.provider1
|
||||
provider2 = provider.provider2
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
|
||||
def test_traversal_overriding(self):
|
||||
provider1 = providers.Provider()
|
||||
provider2 = providers.Provider()
|
||||
provider3 = providers.DependenciesContainer(
|
||||
provider1=provider1,
|
||||
provider2=provider2,
|
||||
)
|
||||
|
||||
provider = providers.DependenciesContainer()
|
||||
provider.override(provider3)
|
||||
|
||||
all_providers = list(provider.traverse())
|
||||
|
||||
self.assertEqual(len(all_providers), 5)
|
||||
self.assertIn(provider1, all_providers)
|
||||
self.assertIn(provider2, all_providers)
|
||||
self.assertIn(provider3, all_providers)
|
||||
self.assertIn(provider.provider1, all_providers)
|
||||
self.assertIn(provider.provider2, all_providers)
|
||||
|
Loading…
Reference in New Issue
Block a user