Implement providers traversal in first precision

This commit is contained in:
Roman Mogylatov 2021-01-29 19:26:38 -05:00
parent 0c1a08174f
commit 90c6759975
7 changed files with 15718 additions and 17320 deletions

File diff suppressed because it is too large Load Diff

View File

@ -7,9 +7,12 @@ from typing import (
Union,
ClassVar,
Callable as _Callable,
Sequence,
Iterable,
Iterator,
TypeVar,
Awaitable,
overload,
)
from .providers import Provider
@ -41,6 +44,12 @@ class Container:
def init_resources(self) -> Optional[Awaitable]: ...
def shutdown_resources(self) -> Optional[Awaitable]: ...
@overload
def traverse_providers(self, types: Optional[Sequence[Type]] = None) -> Iterator[Provider]: ...
@classmethod
@overload
def traverse_providers(cls, types: Optional[Sequence[Type]] = None) -> Iterator[Provider]: ...
class DynamicContainer(Container): ...

View File

@ -10,16 +10,7 @@ except ImportError:
import six
from .errors import Error
from .providers cimport (
Provider,
Object,
Resource,
Dependency,
DependenciesContainer,
Container as ContainerProvider,
deepcopy,
)
from . import providers, errors
if sys.version_info[:2] >= (3, 6):
@ -68,13 +59,13 @@ class DynamicContainer(object):
:rtype: None
"""
self.provider_type = Provider
self.provider_type = providers.Provider
self.providers = {}
self.overridden = tuple()
self.declarative_parent = None
self.wired_to_modules = []
self.wired_to_packages = []
self.__self__ = Object(self)
self.__self__ = providers.Object(self)
super(DynamicContainer, self).__init__()
def __deepcopy__(self, memo):
@ -84,11 +75,11 @@ class DynamicContainer(object):
return copied
copied = self.__class__()
copied.provider_type = Provider
copied.overridden = deepcopy(self.overridden, memo)
copied.provider_type = providers.Provider
copied.overridden = providers.deepcopy(self.overridden, memo)
copied.declarative_parent = self.declarative_parent
for name, provider in deepcopy(self.providers, memo).items():
for name, provider in providers.deepcopy(self.providers, memo).items():
setattr(copied, name, provider)
return copied
@ -107,7 +98,7 @@ class DynamicContainer(object):
:rtype: None
"""
if isinstance(value, Provider) and name != '__self__':
if isinstance(value, providers.Provider) and name != '__self__':
_check_provider_type(self, value)
self.providers[name] = value
super(DynamicContainer, self).__setattr__(name, value)
@ -140,9 +131,13 @@ class DynamicContainer(object):
return {
name: provider
for name, provider in self.providers.items()
if isinstance(provider, (Dependency, DependenciesContainer))
if isinstance(provider, (providers.Dependency, providers.DependenciesContainer))
}
def traverse_providers(self, types=None):
"""Return providers traversal generator."""
yield from providers.traverse(*self.providers.values(), types=types)
def set_providers(self, **providers):
"""Set container providers.
@ -167,8 +162,8 @@ class DynamicContainer(object):
:rtype: None
"""
if overriding is self:
raise Error('Container {0} could not be overridden '
'with itself'.format(self))
raise errors.Error('Container {0} could not be overridden '
'with itself'.format(self))
self.overridden += (overriding,)
@ -197,7 +192,7 @@ class DynamicContainer(object):
:rtype: None
"""
if not self.overridden:
raise Error('Container {0} is not overridden'.format(self))
raise errors.Error('Container {0} is not overridden'.format(self))
self.overridden = self.overridden[:-1]
@ -245,7 +240,7 @@ class DynamicContainer(object):
"""Initialize all container resources."""
futures = []
for provider in self.providers.values():
if not isinstance(provider, Resource):
if not isinstance(provider, providers.Resource):
continue
resource = provider.init()
@ -260,7 +255,7 @@ class DynamicContainer(object):
"""Shutdown all container resources."""
futures = []
for provider in self.providers.values():
if not isinstance(provider, Resource):
if not isinstance(provider, providers.Resource):
continue
shutdown = provider.shutdown()
@ -274,7 +269,7 @@ class DynamicContainer(object):
def apply_container_providers_overridings(self):
"""Apply container providers' overridings."""
for provider in self.providers.values():
if not isinstance(provider, ContainerProvider):
if not isinstance(provider, providers.Container):
continue
provider.apply_overridings()
@ -293,7 +288,7 @@ class DeclarativeContainerMetaClass(type):
cls_providers = {
name: provider
for name, provider in six.iteritems(attributes)
if isinstance(provider, Provider)
if isinstance(provider, providers.Provider)
}
inherited_providers = {
@ -303,18 +298,18 @@ class DeclarativeContainerMetaClass(type):
for name, provider in six.iteritems(base.providers)
}
providers = {}
providers.update(inherited_providers)
providers.update(cls_providers)
all_providers = {}
all_providers.update(inherited_providers)
all_providers.update(cls_providers)
attributes['containers'] = containers
attributes['inherited_providers'] = inherited_providers
attributes['cls_providers'] = cls_providers
attributes['providers'] = providers
attributes['providers'] = all_providers
cls = <type>type.__new__(mcs, class_name, bases, attributes)
cls.__self__ = Object(cls)
cls.__self__ = providers.Object(cls)
for provider in six.itervalues(cls.providers):
_check_provider_type(cls, provider)
@ -335,7 +330,7 @@ class DeclarativeContainerMetaClass(type):
:rtype: None
"""
if isinstance(value, Provider) and name != '__self__':
if isinstance(value, providers.Provider) and name != '__self__':
_check_provider_type(cls, value)
cls.providers[name] = value
cls.cls_providers[name] = value
@ -370,9 +365,13 @@ class DeclarativeContainerMetaClass(type):
return {
name: provider
for name, provider in cls.providers.items()
if isinstance(provider, (Dependency, DependenciesContainer))
if isinstance(provider, (providers.Dependency, providers.DependenciesContainer))
}
def traverse_providers(cls, types=None):
"""Return providers traversal generator."""
yield from providers.traverse(*cls.providers.values(), types=types)
@six.add_metaclass(DeclarativeContainerMetaClass)
class DeclarativeContainer(object):
@ -388,7 +387,7 @@ class DeclarativeContainer(object):
__IS_CONTAINER__ = True
provider_type = Provider
provider_type = providers.Provider
"""Type of providers that could be placed in container.
:type: type
@ -446,7 +445,7 @@ class DeclarativeContainer(object):
container = cls.instance_type()
container.provider_type = cls.provider_type
container.declarative_parent = cls
container.set_providers(**deepcopy(cls.providers))
container.set_providers(**providers.deepcopy(cls.providers))
container.override_providers(**overriding_providers)
container.apply_container_providers_overridings()
return container
@ -464,8 +463,8 @@ class DeclarativeContainer(object):
:rtype: None
"""
if issubclass(cls, overriding):
raise Error('Container {0} could not be overridden '
'with itself or its subclasses'.format(cls))
raise errors.Error('Container {0} could not be overridden '
'with itself or its subclasses'.format(cls))
cls.overridden += (overriding,)
@ -482,7 +481,7 @@ class DeclarativeContainer(object):
:rtype: None
"""
if not cls.overridden:
raise Error('Container {0} is not overridden'.format(cls))
raise errors.Error('Container {0} is not overridden'.format(cls))
cls.overridden = cls.overridden[:-1]
@ -559,7 +558,7 @@ def copy(object container):
def _decorator(copied_container):
memo = _get_providers_memo(copied_container.cls_providers, container.providers)
providers_copy = deepcopy(container.providers, memo)
providers_copy = providers.deepcopy(container.providers, memo)
for name, provider in six.iteritems(providers_copy):
setattr(copied_container, name, provider)
@ -580,8 +579,8 @@ cpdef bint is_container(object instance):
cpdef object _check_provider_type(object container, object provider):
if not isinstance(provider, container.provider_type):
raise Error('{0} can contain only {1} '
'instances'.format(container, container.provider_type))
raise errors.Error('{0} can contain only {1} '
'instances'.format(container, container.provider_type))
cpdef bint _isawaitable(object instance):

File diff suppressed because it is too large Load Diff

View File

@ -13,6 +13,7 @@ from typing import (
Dict as _Dict,
Optional,
Union,
Sequence,
Coroutine as _Coroutine,
Iterator as _Iterator,
AsyncIterator as _AsyncIterator,
@ -68,6 +69,8 @@ class Provider(Generic[T]):
def is_async_mode_enabled(self) -> bool: ...
def is_async_mode_disabled(self) -> bool: ...
def is_async_mode_undefined(self) -> bool: ...
@property
def providers_traversal(self) -> _Iterator[Provider]: ...
def _copy_overridings(self, copied: Provider, memo: Optional[_Dict[Any, Any]]) -> None: ...
@ -382,6 +385,9 @@ def deepcopy(instance: Any, memo: Optional[_Dict[Any, Any]] = None): Any: ...
def merge_dicts(dict1: _Dict[Any, Any], dict2: _Dict[Any, Any]) -> _Dict[Any, Any]: ...
def traverse(*providers: Provider, types: Optional[Sequence[Type]]=None) -> _Iterator[Provider]: ...
if yaml:
class YamlLoader(yaml.SafeLoader): ...
else:

View File

@ -363,6 +363,11 @@ cdef class Provider(object):
"""Check if async mode is undefined."""
return self.__async_mode == ASYNC_MODE_UNDEFINED
@property
def providers_traversal(self):
"""Return providers traversal generator."""
yield from self.overridden
cpdef object _provide(self, tuple args, dict kwargs):
"""Providing strategy implementation.
@ -423,6 +428,13 @@ cdef class Object(Provider):
"""
return self.__str__()
@property
def providers_traversal(self):
"""Return providers traversal generator."""
if isinstance(self.__provides, Provider):
yield self.__provides
yield from super().providers_traversal
cpdef object _provide(self, tuple args, dict kwargs):
"""Return provided instance.
@ -487,6 +499,12 @@ cdef class Delegate(Provider):
"""Return provider."""
return self.__provides
@property
def providers_traversal(self):
"""Return providers traversal generator."""
yield self.__provides
yield from super().providers_traversal
cpdef object _provide(self, tuple args, dict kwargs):
"""Return provided instance.
@ -619,6 +637,13 @@ cdef class Dependency(Provider):
"""Return default provider."""
return self.__default
@property
def providers_traversal(self):
"""Return providers traversal generator."""
if self.__default is not UNDEFINED:
yield self.__default
yield from super().providers_traversal
def provided_by(self, provider):
"""Set external dependency provider.
@ -790,6 +815,12 @@ cdef class DependenciesContainer(Object):
child.reset_override()
super(DependenciesContainer, self).reset_override()
@property
def providers_traversal(self):
"""Return providers traversal generator."""
yield from self.providers.values()
yield from super().providers_traversal
cpdef object _override_providers(self, object container):
"""Override providers with providers from provided container."""
for name, dependency_provider in container.providers.items():
@ -998,6 +1029,14 @@ cdef class Callable(Provider):
self.__kwargs_len = len(self.__kwargs)
return self
@property
def providers_traversal(self):
"""Return providers traversal generator."""
yield from filter(is_provider, [self.provides])
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from super().providers_traversal
cpdef object _provide(self, tuple args, dict kwargs):
"""Return result of provided callable's call."""
return __callable_call(self, args, kwargs)
@ -1442,6 +1481,12 @@ cdef class ConfigurationOption(Provider):
self.override(value)
@property
def providers_traversal(self):
"""Return providers traversal generator."""
# TODO: revise
yield from super().providers_traversal
def _is_strict_mode_enabled(self):
return self.__root.__strict
@ -1979,6 +2024,15 @@ cdef class Factory(Provider):
self.__attributes_len = len(self.__attributes)
return self
@property
def providers_traversal(self):
"""Return providers traversal generator."""
yield from filter(is_provider, [self.provides])
yield from filter(is_provider, self.args)
yield from filter(is_provider, self.kwargs.values())
yield from filter(is_provider, self.attributes.values())
yield from super().providers_traversal
cpdef object _provide(self, tuple args, dict kwargs):
"""Return new instance."""
return __factory_call(self, args, kwargs)
@ -3846,6 +3900,29 @@ def merge_dicts(dict1, dict2):
return result
def traverse(*providers, types=None):
"""Return providers traversal generator."""
visited = set()
to_visit = set(providers)
if types:
types = tuple(types)
while len(to_visit) > 0:
visiting = to_visit.pop()
visited.add(visiting)
for child in visiting.providers_traversal:
if child in visited:
continue
to_visit.add(child)
if types and not isinstance(visiting, types):
continue
yield visiting
def isawaitable(obj):
"""Check if object is a coroutine function.

View File

@ -431,3 +431,90 @@ class DeclarativeContainerTests(unittest.TestCase):
self.assertIsInstance(container.p31, providers.Provider)
self.assertIsInstance(container.p32, providers.Provider)
self.assertIs(container.p11.last_overriding, provider)
class ProvidersTraversalTests(unittest.TestCase):
def test_nested_providers(self):
class Container(containers.DeclarativeContainer):
obj_factory = providers.DelegatedFactory(
dict,
foo=providers.Resource(
dict,
foo='bar'
),
bar=providers.Resource(
dict,
foo='bar'
)
)
container = Container()
all_providers = list(container.traverse_providers())
self.assertIn(container.obj_factory, all_providers)
self.assertIn(container.obj_factory.kwargs['foo'], all_providers)
self.assertIn(container.obj_factory.kwargs['bar'], all_providers)
self.assertEqual(len(all_providers), 3)
def test_nested_providers_class(self):
class Container(containers.DeclarativeContainer):
obj_factory = providers.DelegatedFactory(
dict,
foo=providers.Resource(
dict,
foo='bar'
),
bar=providers.Resource(
dict,
foo='bar'
)
)
all_providers = list(Container.traverse_providers())
self.assertIn(Container.obj_factory, all_providers)
self.assertIn(Container.obj_factory.kwargs['foo'], all_providers)
self.assertIn(Container.obj_factory.kwargs['bar'], all_providers)
self.assertEqual(len(all_providers), 3)
def test_nested_providers_with_filtering(self):
class Container(containers.DeclarativeContainer):
obj_factory = providers.DelegatedFactory(
dict,
foo=providers.Resource(
dict,
foo='bar'
),
bar=providers.Resource(
dict,
foo='bar'
)
)
container = Container()
all_providers = list(container.traverse_providers(types=[providers.Resource]))
self.assertIn(container.obj_factory.kwargs['foo'], all_providers)
self.assertIn(container.obj_factory.kwargs['bar'], all_providers)
self.assertEqual(len(all_providers), 2)
def test_nested_providers_class_with_filtering(self):
class Container(containers.DeclarativeContainer):
obj_factory = providers.DelegatedFactory(
dict,
foo=providers.Resource(
dict,
foo='bar'
),
bar=providers.Resource(
dict,
foo='bar'
)
)
all_providers = list(Container.traverse_providers(types=[providers.Resource]))
self.assertIn(Container.obj_factory.kwargs['foo'], all_providers)
self.assertIn(Container.obj_factory.kwargs['bar'], all_providers)
self.assertEqual(len(all_providers), 2)