Clean up DeclarativeContainer and add tests

This commit is contained in:
Roman Mogylatov 2021-02-13 08:36:45 -05:00
parent 1d884b5101
commit 21c0c82144
4 changed files with 1490 additions and 714 deletions

File diff suppressed because it is too large Load Diff

View File

@ -45,7 +45,11 @@ class Container:
def shutdown_resources(self) -> Optional[Awaitable]: ... def shutdown_resources(self) -> Optional[Awaitable]: ...
def apply_container_providers_overridings(self) -> None: ... def apply_container_providers_overridings(self) -> None: ...
def reset_singletons(self) -> None: ... def reset_singletons(self) -> None: ...
@overload
def resolve_provider_name(self, provider: Provider) -> str: ... def resolve_provider_name(self, provider: Provider) -> str: ...
@classmethod
@overload
def resolve_provider_name(cls, provider: Provider) -> str: ...
@property @property
def parent(self) -> Optional[ProviderParent]: ... def parent(self) -> Optional[ProviderParent]: ...
@property @property

View File

@ -374,6 +374,10 @@ class DeclarativeContainerMetaClass(type):
for provider in six.itervalues(cls.providers): for provider in six.itervalues(cls.providers):
_check_provider_type(cls, provider) _check_provider_type(cls, provider)
for provider in six.itervalues(cls.cls_providers):
if isinstance(provider, providers.CHILD_PROVIDERS):
provider.assign_parent(cls)
return cls return cls
def __setattr__(cls, str name, object value): def __setattr__(cls, str name, object value):
@ -392,6 +396,10 @@ class DeclarativeContainerMetaClass(type):
""" """
if isinstance(value, providers.Provider) and name != '__self__': if isinstance(value, providers.Provider) and name != '__self__':
_check_provider_type(cls, value) _check_provider_type(cls, value)
if isinstance(value, providers.CHILD_PROVIDERS):
value.assign_parent(cls)
cls.providers[name] = value cls.providers[name] = value
cls.cls_providers[name] = value cls.cls_providers[name] = value
super(DeclarativeContainerMetaClass, cls).__setattr__(name, value) super(DeclarativeContainerMetaClass, cls).__setattr__(name, value)
@ -425,13 +433,26 @@ class DeclarativeContainerMetaClass(type):
return { return {
name: provider name: provider
for name, provider in cls.providers.items() for name, provider in cls.providers.items()
if isinstance(provider, providers.CHILD_PROVIDERS) if isinstance(provider, (providers.Dependency, providers.DependenciesContainer))
} }
def traverse(cls, types=None): def traverse(cls, types=None):
"""Return providers traversal generator.""" """Return providers traversal generator."""
yield from providers.traverse(*cls.providers.values(), types=types) yield from providers.traverse(*cls.providers.values(), types=types)
def resolve_provider_name(cls, provider):
"""Try to resolve provider name."""
for provider_name, container_provider in cls.providers.items():
if container_provider is provider:
return provider_name
else:
raise errors.Error(f'Can not resolve name for provider "{provider}"')
@property
def parent_name(cls):
"""Return parent name."""
return cls.__name__
@staticmethod @staticmethod
def __fetch_self(attributes): def __fetch_self(attributes):
self = None self = None

View File

@ -224,7 +224,7 @@ class DeclarativeContainerTests(unittest.TestCase):
(_OverridingContainer1.p11, (_OverridingContainer1.p11,
_OverridingContainer2.p11)) _OverridingContainer2.p11))
def test_reset_last_overridding(self): def test_reset_last_overriding(self):
class _Container(containers.DeclarativeContainer): class _Container(containers.DeclarativeContainer):
p11 = providers.Provider() p11 = providers.Provider()
@ -244,7 +244,7 @@ class DeclarativeContainerTests(unittest.TestCase):
self.assertEqual(_Container.p11.overridden, self.assertEqual(_Container.p11.overridden,
(_OverridingContainer1.p11,)) (_OverridingContainer1.p11,))
def test_reset_last_overridding_when_not_overridden(self): def test_reset_last_overriding_when_not_overridden(self):
with self.assertRaises(errors.Error): with self.assertRaises(errors.Error):
ContainerA.reset_last_overriding() ContainerA.reset_last_overriding()
@ -431,3 +431,68 @@ class DeclarativeContainerTests(unittest.TestCase):
self.assertIsInstance(container.p31, providers.Provider) self.assertIsInstance(container.p31, providers.Provider)
self.assertIsInstance(container.p32, providers.Provider) self.assertIsInstance(container.p32, providers.Provider)
self.assertIs(container.p11.last_overriding, provider) self.assertIs(container.p11.last_overriding, provider)
def test_parent_set_in__new__(self):
class Container(containers.DeclarativeContainer):
dependency = providers.Dependency()
dependencies_container = providers.DependenciesContainer()
container = providers.Container(ContainerA)
self.assertIs(Container.dependency.parent, Container)
self.assertIs(Container.dependencies_container.parent, Container)
self.assertIs(Container.container.parent, Container)
def test_parent_set_in__setattr__(self):
class Container(containers.DeclarativeContainer):
pass
Container.dependency = providers.Dependency()
Container.dependencies_container = providers.DependenciesContainer()
Container.container = providers.Container(ContainerA)
self.assertIs(Container.dependency.parent, Container)
self.assertIs(Container.dependencies_container.parent, Container)
self.assertIs(Container.container.parent, Container)
def test_resolve_provider_name(self):
self.assertEqual(ContainerA.resolve_provider_name(ContainerA.p11), 'p11')
def test_resolve_provider_name_no_provider(self):
with self.assertRaises(errors.Error):
ContainerA.resolve_provider_name(providers.Provider())
def test_child_dependency_parent_name(self):
class Container(containers.DeclarativeContainer):
dependency = providers.Dependency()
with self.assertRaises(errors.Error) as context:
Container.dependency()
self.assertEqual(
str(context.exception),
'Dependency "Container.dependency" is not defined',
)
def test_child_dependencies_container_parent_name(self):
class Container(containers.DeclarativeContainer):
dependencies_container = providers.DependenciesContainer()
with self.assertRaises(errors.Error) as context:
Container.dependencies_container.dependency()
self.assertEqual(
str(context.exception),
'Dependency "Container.dependencies_container.dependency" is not defined',
)
def test_child_container_parent_name(self):
class ChildContainer(containers.DeclarativeContainer):
dependency = providers.Dependency()
class Container(containers.DeclarativeContainer):
child_container = providers.Container(ChildContainer)
with self.assertRaises(errors.Error) as context:
Container.child_container.dependency()
self.assertEqual(
str(context.exception),
'Dependency "Container.child_container.dependency" is not defined',
)