mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-07-04 20:33:13 +03:00
Implement providers traversal in first precision
This commit is contained in:
parent
0c1a08174f
commit
90c6759975
File diff suppressed because it is too large
Load Diff
|
@ -7,9 +7,12 @@ from typing import (
|
|||
Union,
|
||||
ClassVar,
|
||||
Callable as _Callable,
|
||||
Sequence,
|
||||
Iterable,
|
||||
Iterator,
|
||||
TypeVar,
|
||||
Awaitable,
|
||||
overload,
|
||||
)
|
||||
|
||||
from .providers import Provider
|
||||
|
@ -41,6 +44,12 @@ class Container:
|
|||
def init_resources(self) -> Optional[Awaitable]: ...
|
||||
def shutdown_resources(self) -> Optional[Awaitable]: ...
|
||||
|
||||
@overload
|
||||
def traverse_providers(self, types: Optional[Sequence[Type]] = None) -> Iterator[Provider]: ...
|
||||
@classmethod
|
||||
@overload
|
||||
def traverse_providers(cls, types: Optional[Sequence[Type]] = None) -> Iterator[Provider]: ...
|
||||
|
||||
|
||||
class DynamicContainer(Container): ...
|
||||
|
||||
|
|
|
@ -10,16 +10,7 @@ except ImportError:
|
|||
|
||||
import six
|
||||
|
||||
from .errors import Error
|
||||
from .providers cimport (
|
||||
Provider,
|
||||
Object,
|
||||
Resource,
|
||||
Dependency,
|
||||
DependenciesContainer,
|
||||
Container as ContainerProvider,
|
||||
deepcopy,
|
||||
)
|
||||
from . import providers, errors
|
||||
|
||||
|
||||
if sys.version_info[:2] >= (3, 6):
|
||||
|
@ -68,13 +59,13 @@ class DynamicContainer(object):
|
|||
|
||||
:rtype: None
|
||||
"""
|
||||
self.provider_type = Provider
|
||||
self.provider_type = providers.Provider
|
||||
self.providers = {}
|
||||
self.overridden = tuple()
|
||||
self.declarative_parent = None
|
||||
self.wired_to_modules = []
|
||||
self.wired_to_packages = []
|
||||
self.__self__ = Object(self)
|
||||
self.__self__ = providers.Object(self)
|
||||
super(DynamicContainer, self).__init__()
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
|
@ -84,11 +75,11 @@ class DynamicContainer(object):
|
|||
return copied
|
||||
|
||||
copied = self.__class__()
|
||||
copied.provider_type = Provider
|
||||
copied.overridden = deepcopy(self.overridden, memo)
|
||||
copied.provider_type = providers.Provider
|
||||
copied.overridden = providers.deepcopy(self.overridden, memo)
|
||||
copied.declarative_parent = self.declarative_parent
|
||||
|
||||
for name, provider in deepcopy(self.providers, memo).items():
|
||||
for name, provider in providers.deepcopy(self.providers, memo).items():
|
||||
setattr(copied, name, provider)
|
||||
|
||||
return copied
|
||||
|
@ -107,7 +98,7 @@ class DynamicContainer(object):
|
|||
|
||||
:rtype: None
|
||||
"""
|
||||
if isinstance(value, Provider) and name != '__self__':
|
||||
if isinstance(value, providers.Provider) and name != '__self__':
|
||||
_check_provider_type(self, value)
|
||||
self.providers[name] = value
|
||||
super(DynamicContainer, self).__setattr__(name, value)
|
||||
|
@ -140,9 +131,13 @@ class DynamicContainer(object):
|
|||
return {
|
||||
name: provider
|
||||
for name, provider in self.providers.items()
|
||||
if isinstance(provider, (Dependency, DependenciesContainer))
|
||||
if isinstance(provider, (providers.Dependency, providers.DependenciesContainer))
|
||||
}
|
||||
|
||||
def traverse_providers(self, types=None):
|
||||
"""Return providers traversal generator."""
|
||||
yield from providers.traverse(*self.providers.values(), types=types)
|
||||
|
||||
def set_providers(self, **providers):
|
||||
"""Set container providers.
|
||||
|
||||
|
@ -167,8 +162,8 @@ class DynamicContainer(object):
|
|||
:rtype: None
|
||||
"""
|
||||
if overriding is self:
|
||||
raise Error('Container {0} could not be overridden '
|
||||
'with itself'.format(self))
|
||||
raise errors.Error('Container {0} could not be overridden '
|
||||
'with itself'.format(self))
|
||||
|
||||
self.overridden += (overriding,)
|
||||
|
||||
|
@ -197,7 +192,7 @@ class DynamicContainer(object):
|
|||
:rtype: None
|
||||
"""
|
||||
if not self.overridden:
|
||||
raise Error('Container {0} is not overridden'.format(self))
|
||||
raise errors.Error('Container {0} is not overridden'.format(self))
|
||||
|
||||
self.overridden = self.overridden[:-1]
|
||||
|
||||
|
@ -245,7 +240,7 @@ class DynamicContainer(object):
|
|||
"""Initialize all container resources."""
|
||||
futures = []
|
||||
for provider in self.providers.values():
|
||||
if not isinstance(provider, Resource):
|
||||
if not isinstance(provider, providers.Resource):
|
||||
continue
|
||||
|
||||
resource = provider.init()
|
||||
|
@ -260,7 +255,7 @@ class DynamicContainer(object):
|
|||
"""Shutdown all container resources."""
|
||||
futures = []
|
||||
for provider in self.providers.values():
|
||||
if not isinstance(provider, Resource):
|
||||
if not isinstance(provider, providers.Resource):
|
||||
continue
|
||||
|
||||
shutdown = provider.shutdown()
|
||||
|
@ -274,7 +269,7 @@ class DynamicContainer(object):
|
|||
def apply_container_providers_overridings(self):
|
||||
"""Apply container providers' overridings."""
|
||||
for provider in self.providers.values():
|
||||
if not isinstance(provider, ContainerProvider):
|
||||
if not isinstance(provider, providers.Container):
|
||||
continue
|
||||
provider.apply_overridings()
|
||||
|
||||
|
@ -293,7 +288,7 @@ class DeclarativeContainerMetaClass(type):
|
|||
cls_providers = {
|
||||
name: provider
|
||||
for name, provider in six.iteritems(attributes)
|
||||
if isinstance(provider, Provider)
|
||||
if isinstance(provider, providers.Provider)
|
||||
}
|
||||
|
||||
inherited_providers = {
|
||||
|
@ -303,18 +298,18 @@ class DeclarativeContainerMetaClass(type):
|
|||
for name, provider in six.iteritems(base.providers)
|
||||
}
|
||||
|
||||
providers = {}
|
||||
providers.update(inherited_providers)
|
||||
providers.update(cls_providers)
|
||||
all_providers = {}
|
||||
all_providers.update(inherited_providers)
|
||||
all_providers.update(cls_providers)
|
||||
|
||||
attributes['containers'] = containers
|
||||
attributes['inherited_providers'] = inherited_providers
|
||||
attributes['cls_providers'] = cls_providers
|
||||
attributes['providers'] = providers
|
||||
attributes['providers'] = all_providers
|
||||
|
||||
cls = <type>type.__new__(mcs, class_name, bases, attributes)
|
||||
|
||||
cls.__self__ = Object(cls)
|
||||
cls.__self__ = providers.Object(cls)
|
||||
|
||||
for provider in six.itervalues(cls.providers):
|
||||
_check_provider_type(cls, provider)
|
||||
|
@ -335,7 +330,7 @@ class DeclarativeContainerMetaClass(type):
|
|||
|
||||
:rtype: None
|
||||
"""
|
||||
if isinstance(value, Provider) and name != '__self__':
|
||||
if isinstance(value, providers.Provider) and name != '__self__':
|
||||
_check_provider_type(cls, value)
|
||||
cls.providers[name] = value
|
||||
cls.cls_providers[name] = value
|
||||
|
@ -370,9 +365,13 @@ class DeclarativeContainerMetaClass(type):
|
|||
return {
|
||||
name: provider
|
||||
for name, provider in cls.providers.items()
|
||||
if isinstance(provider, (Dependency, DependenciesContainer))
|
||||
if isinstance(provider, (providers.Dependency, providers.DependenciesContainer))
|
||||
}
|
||||
|
||||
def traverse_providers(cls, types=None):
|
||||
"""Return providers traversal generator."""
|
||||
yield from providers.traverse(*cls.providers.values(), types=types)
|
||||
|
||||
|
||||
@six.add_metaclass(DeclarativeContainerMetaClass)
|
||||
class DeclarativeContainer(object):
|
||||
|
@ -388,7 +387,7 @@ class DeclarativeContainer(object):
|
|||
|
||||
__IS_CONTAINER__ = True
|
||||
|
||||
provider_type = Provider
|
||||
provider_type = providers.Provider
|
||||
"""Type of providers that could be placed in container.
|
||||
|
||||
:type: type
|
||||
|
@ -446,7 +445,7 @@ class DeclarativeContainer(object):
|
|||
container = cls.instance_type()
|
||||
container.provider_type = cls.provider_type
|
||||
container.declarative_parent = cls
|
||||
container.set_providers(**deepcopy(cls.providers))
|
||||
container.set_providers(**providers.deepcopy(cls.providers))
|
||||
container.override_providers(**overriding_providers)
|
||||
container.apply_container_providers_overridings()
|
||||
return container
|
||||
|
@ -464,8 +463,8 @@ class DeclarativeContainer(object):
|
|||
:rtype: None
|
||||
"""
|
||||
if issubclass(cls, overriding):
|
||||
raise Error('Container {0} could not be overridden '
|
||||
'with itself or its subclasses'.format(cls))
|
||||
raise errors.Error('Container {0} could not be overridden '
|
||||
'with itself or its subclasses'.format(cls))
|
||||
|
||||
cls.overridden += (overriding,)
|
||||
|
||||
|
@ -482,7 +481,7 @@ class DeclarativeContainer(object):
|
|||
:rtype: None
|
||||
"""
|
||||
if not cls.overridden:
|
||||
raise Error('Container {0} is not overridden'.format(cls))
|
||||
raise errors.Error('Container {0} is not overridden'.format(cls))
|
||||
|
||||
cls.overridden = cls.overridden[:-1]
|
||||
|
||||
|
@ -559,7 +558,7 @@ def copy(object container):
|
|||
def _decorator(copied_container):
|
||||
memo = _get_providers_memo(copied_container.cls_providers, container.providers)
|
||||
|
||||
providers_copy = deepcopy(container.providers, memo)
|
||||
providers_copy = providers.deepcopy(container.providers, memo)
|
||||
for name, provider in six.iteritems(providers_copy):
|
||||
setattr(copied_container, name, provider)
|
||||
|
||||
|
@ -580,8 +579,8 @@ cpdef bint is_container(object instance):
|
|||
|
||||
cpdef object _check_provider_type(object container, object provider):
|
||||
if not isinstance(provider, container.provider_type):
|
||||
raise Error('{0} can contain only {1} '
|
||||
'instances'.format(container, container.provider_type))
|
||||
raise errors.Error('{0} can contain only {1} '
|
||||
'instances'.format(container, container.provider_type))
|
||||
|
||||
|
||||
cpdef bint _isawaitable(object instance):
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -13,6 +13,7 @@ from typing import (
|
|||
Dict as _Dict,
|
||||
Optional,
|
||||
Union,
|
||||
Sequence,
|
||||
Coroutine as _Coroutine,
|
||||
Iterator as _Iterator,
|
||||
AsyncIterator as _AsyncIterator,
|
||||
|
@ -68,6 +69,8 @@ class Provider(Generic[T]):
|
|||
def is_async_mode_enabled(self) -> bool: ...
|
||||
def is_async_mode_disabled(self) -> bool: ...
|
||||
def is_async_mode_undefined(self) -> bool: ...
|
||||
@property
|
||||
def providers_traversal(self) -> _Iterator[Provider]: ...
|
||||
def _copy_overridings(self, copied: Provider, memo: Optional[_Dict[Any, Any]]) -> None: ...
|
||||
|
||||
|
||||
|
@ -382,6 +385,9 @@ def deepcopy(instance: Any, memo: Optional[_Dict[Any, Any]] = None): Any: ...
|
|||
def merge_dicts(dict1: _Dict[Any, Any], dict2: _Dict[Any, Any]) -> _Dict[Any, Any]: ...
|
||||
|
||||
|
||||
def traverse(*providers: Provider, types: Optional[Sequence[Type]]=None) -> _Iterator[Provider]: ...
|
||||
|
||||
|
||||
if yaml:
|
||||
class YamlLoader(yaml.SafeLoader): ...
|
||||
else:
|
||||
|
|
|
@ -363,6 +363,11 @@ cdef class Provider(object):
|
|||
"""Check if async mode is undefined."""
|
||||
return self.__async_mode == ASYNC_MODE_UNDEFINED
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
yield from self.overridden
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Providing strategy implementation.
|
||||
|
||||
|
@ -423,6 +428,13 @@ cdef class Object(Provider):
|
|||
"""
|
||||
return self.__str__()
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
if isinstance(self.__provides, Provider):
|
||||
yield self.__provides
|
||||
yield from super().providers_traversal
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return provided instance.
|
||||
|
||||
|
@ -487,6 +499,12 @@ cdef class Delegate(Provider):
|
|||
"""Return provider."""
|
||||
return self.__provides
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
yield self.__provides
|
||||
yield from super().providers_traversal
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return provided instance.
|
||||
|
||||
|
@ -619,6 +637,13 @@ cdef class Dependency(Provider):
|
|||
"""Return default provider."""
|
||||
return self.__default
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
if self.__default is not UNDEFINED:
|
||||
yield self.__default
|
||||
yield from super().providers_traversal
|
||||
|
||||
def provided_by(self, provider):
|
||||
"""Set external dependency provider.
|
||||
|
||||
|
@ -790,6 +815,12 @@ cdef class DependenciesContainer(Object):
|
|||
child.reset_override()
|
||||
super(DependenciesContainer, self).reset_override()
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
yield from self.providers.values()
|
||||
yield from super().providers_traversal
|
||||
|
||||
cpdef object _override_providers(self, object container):
|
||||
"""Override providers with providers from provided container."""
|
||||
for name, dependency_provider in container.providers.items():
|
||||
|
@ -998,6 +1029,14 @@ cdef class Callable(Provider):
|
|||
self.__kwargs_len = len(self.__kwargs)
|
||||
return self
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
yield from filter(is_provider, [self.provides])
|
||||
yield from filter(is_provider, self.args)
|
||||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from super().providers_traversal
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return result of provided callable's call."""
|
||||
return __callable_call(self, args, kwargs)
|
||||
|
@ -1442,6 +1481,12 @@ cdef class ConfigurationOption(Provider):
|
|||
|
||||
self.override(value)
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
# TODO: revise
|
||||
yield from super().providers_traversal
|
||||
|
||||
def _is_strict_mode_enabled(self):
|
||||
return self.__root.__strict
|
||||
|
||||
|
@ -1979,6 +2024,15 @@ cdef class Factory(Provider):
|
|||
self.__attributes_len = len(self.__attributes)
|
||||
return self
|
||||
|
||||
@property
|
||||
def providers_traversal(self):
|
||||
"""Return providers traversal generator."""
|
||||
yield from filter(is_provider, [self.provides])
|
||||
yield from filter(is_provider, self.args)
|
||||
yield from filter(is_provider, self.kwargs.values())
|
||||
yield from filter(is_provider, self.attributes.values())
|
||||
yield from super().providers_traversal
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
"""Return new instance."""
|
||||
return __factory_call(self, args, kwargs)
|
||||
|
@ -3846,6 +3900,29 @@ def merge_dicts(dict1, dict2):
|
|||
return result
|
||||
|
||||
|
||||
def traverse(*providers, types=None):
|
||||
"""Return providers traversal generator."""
|
||||
visited = set()
|
||||
to_visit = set(providers)
|
||||
|
||||
if types:
|
||||
types = tuple(types)
|
||||
|
||||
while len(to_visit) > 0:
|
||||
visiting = to_visit.pop()
|
||||
visited.add(visiting)
|
||||
|
||||
for child in visiting.providers_traversal:
|
||||
if child in visited:
|
||||
continue
|
||||
to_visit.add(child)
|
||||
|
||||
if types and not isinstance(visiting, types):
|
||||
continue
|
||||
|
||||
yield visiting
|
||||
|
||||
|
||||
def isawaitable(obj):
|
||||
"""Check if object is a coroutine function.
|
||||
|
||||
|
|
|
@ -431,3 +431,90 @@ class DeclarativeContainerTests(unittest.TestCase):
|
|||
self.assertIsInstance(container.p31, providers.Provider)
|
||||
self.assertIsInstance(container.p32, providers.Provider)
|
||||
self.assertIs(container.p11.last_overriding, provider)
|
||||
|
||||
|
||||
class ProvidersTraversalTests(unittest.TestCase):
|
||||
|
||||
def test_nested_providers(self):
|
||||
class Container(containers.DeclarativeContainer):
|
||||
obj_factory = providers.DelegatedFactory(
|
||||
dict,
|
||||
foo=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
),
|
||||
bar=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
)
|
||||
)
|
||||
|
||||
container = Container()
|
||||
all_providers = list(container.traverse_providers())
|
||||
|
||||
self.assertIn(container.obj_factory, all_providers)
|
||||
self.assertIn(container.obj_factory.kwargs['foo'], all_providers)
|
||||
self.assertIn(container.obj_factory.kwargs['bar'], all_providers)
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
|
||||
def test_nested_providers_class(self):
|
||||
class Container(containers.DeclarativeContainer):
|
||||
obj_factory = providers.DelegatedFactory(
|
||||
dict,
|
||||
foo=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
),
|
||||
bar=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
)
|
||||
)
|
||||
|
||||
all_providers = list(Container.traverse_providers())
|
||||
|
||||
self.assertIn(Container.obj_factory, all_providers)
|
||||
self.assertIn(Container.obj_factory.kwargs['foo'], all_providers)
|
||||
self.assertIn(Container.obj_factory.kwargs['bar'], all_providers)
|
||||
self.assertEqual(len(all_providers), 3)
|
||||
|
||||
def test_nested_providers_with_filtering(self):
|
||||
class Container(containers.DeclarativeContainer):
|
||||
obj_factory = providers.DelegatedFactory(
|
||||
dict,
|
||||
foo=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
),
|
||||
bar=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
)
|
||||
)
|
||||
|
||||
container = Container()
|
||||
all_providers = list(container.traverse_providers(types=[providers.Resource]))
|
||||
|
||||
self.assertIn(container.obj_factory.kwargs['foo'], all_providers)
|
||||
self.assertIn(container.obj_factory.kwargs['bar'], all_providers)
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
|
||||
def test_nested_providers_class_with_filtering(self):
|
||||
class Container(containers.DeclarativeContainer):
|
||||
obj_factory = providers.DelegatedFactory(
|
||||
dict,
|
||||
foo=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
),
|
||||
bar=providers.Resource(
|
||||
dict,
|
||||
foo='bar'
|
||||
)
|
||||
)
|
||||
|
||||
all_providers = list(Container.traverse_providers(types=[providers.Resource]))
|
||||
|
||||
self.assertIn(Container.obj_factory.kwargs['foo'], all_providers)
|
||||
self.assertIn(Container.obj_factory.kwargs['bar'], all_providers)
|
||||
self.assertEqual(len(all_providers), 2)
|
||||
|
|
Loading…
Reference in New Issue
Block a user