mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2024-11-25 11:04:01 +03:00
Add some functionality and tests for declarative containers
+ Add checks for valid provider type + Add some wider functionality for overriding
This commit is contained in:
parent
68ae1b80df
commit
a35db5889d
|
@ -3,6 +3,7 @@
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from dependency_injector import (
|
from dependency_injector import (
|
||||||
|
providers,
|
||||||
utils,
|
utils,
|
||||||
errors,
|
errors,
|
||||||
)
|
)
|
||||||
|
@ -18,7 +19,8 @@ class DeclarativeContainerMetaClass(type):
|
||||||
if utils.is_provider(provider))
|
if utils.is_provider(provider))
|
||||||
|
|
||||||
inherited_providers = tuple((name, provider)
|
inherited_providers = tuple((name, provider)
|
||||||
for base in bases if utils.is_catalog(base)
|
for base in bases if utils.is_container(
|
||||||
|
base)
|
||||||
for name, provider in six.iteritems(
|
for name, provider in six.iteritems(
|
||||||
base.cls_providers))
|
base.cls_providers))
|
||||||
|
|
||||||
|
@ -28,14 +30,8 @@ class DeclarativeContainerMetaClass(type):
|
||||||
|
|
||||||
cls = type.__new__(mcs, class_name, bases, attributes)
|
cls = type.__new__(mcs, class_name, bases, attributes)
|
||||||
|
|
||||||
if cls.provider_type:
|
|
||||||
for provider in six.itervalues(cls.providers):
|
for provider in six.itervalues(cls.providers):
|
||||||
try:
|
cls._check_provider_type(provider)
|
||||||
assert isinstance(provider, cls.provider_type)
|
|
||||||
except AssertionError:
|
|
||||||
raise errors.Error('{0} can contain only {1} '
|
|
||||||
'instances'.format(cls,
|
|
||||||
cls.provider_type))
|
|
||||||
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
@ -46,6 +42,7 @@ class DeclarativeContainerMetaClass(type):
|
||||||
dictionary.
|
dictionary.
|
||||||
"""
|
"""
|
||||||
if utils.is_provider(value):
|
if utils.is_provider(value):
|
||||||
|
cls._check_provider_type(value)
|
||||||
cls.providers[name] = value
|
cls.providers[name] = value
|
||||||
cls.cls_providers[name] = value
|
cls.cls_providers[name] = value
|
||||||
super(DeclarativeContainerMetaClass, cls).__setattr__(name, value)
|
super(DeclarativeContainerMetaClass, cls).__setattr__(name, value)
|
||||||
|
@ -61,14 +58,19 @@ class DeclarativeContainerMetaClass(type):
|
||||||
del cls.cls_providers[name]
|
del cls.cls_providers[name]
|
||||||
super(DeclarativeContainerMetaClass, cls).__delattr__(name)
|
super(DeclarativeContainerMetaClass, cls).__delattr__(name)
|
||||||
|
|
||||||
|
def _check_provider_type(cls, provider):
|
||||||
|
if not isinstance(provider, cls.provider_type):
|
||||||
|
raise errors.Error('{0} can contain only {1} '
|
||||||
|
'instances'.format(cls, cls.provider_type))
|
||||||
|
|
||||||
|
|
||||||
@six.add_metaclass(DeclarativeContainerMetaClass)
|
@six.add_metaclass(DeclarativeContainerMetaClass)
|
||||||
class DeclarativeContainer(object):
|
class DeclarativeContainer(object):
|
||||||
"""Declarative inversion of control container."""
|
"""Declarative inversion of control container."""
|
||||||
|
|
||||||
__IS_CATALOG__ = True
|
__IS_CONTAINER__ = True
|
||||||
|
|
||||||
provider_type = None
|
provider_type = providers.Provider
|
||||||
|
|
||||||
providers = dict()
|
providers = dict()
|
||||||
cls_providers = dict()
|
cls_providers = dict()
|
||||||
|
@ -89,7 +91,7 @@ class DeclarativeContainer(object):
|
||||||
:rtype: None
|
:rtype: None
|
||||||
"""
|
"""
|
||||||
if issubclass(cls, overriding):
|
if issubclass(cls, overriding):
|
||||||
raise errors.Error('Catalog {0} could not be overridden '
|
raise errors.Error('Container {0} could not be overridden '
|
||||||
'with itself or its subclasses'.format(cls))
|
'with itself or its subclasses'.format(cls))
|
||||||
|
|
||||||
cls.overridden_by += (overriding,)
|
cls.overridden_by += (overriding,)
|
||||||
|
@ -100,6 +102,31 @@ class DeclarativeContainer(object):
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reset_last_overriding(cls):
|
||||||
|
"""Reset last overriding provider for each container providers.
|
||||||
|
|
||||||
|
:rtype: None
|
||||||
|
"""
|
||||||
|
if not cls.overridden_by:
|
||||||
|
raise errors.Error('Container {0} is not overridden'.format(cls))
|
||||||
|
|
||||||
|
cls.overridden_by = cls.overridden_by[:-1]
|
||||||
|
|
||||||
|
for provider in six.itervalues(cls.providers):
|
||||||
|
provider.reset_last_overriding()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reset_override(cls):
|
||||||
|
"""Reset all overridings for each container providers.
|
||||||
|
|
||||||
|
:rtype: None
|
||||||
|
"""
|
||||||
|
cls.overridden_by = tuple()
|
||||||
|
|
||||||
|
for provider in six.itervalues(cls.providers):
|
||||||
|
provider.reset_override()
|
||||||
|
|
||||||
|
|
||||||
def override(container):
|
def override(container):
|
||||||
""":py:class:`DeclarativeContainer` overriding decorator.
|
""":py:class:`DeclarativeContainer` overriding decorator.
|
||||||
|
|
|
@ -64,7 +64,7 @@ class Provider(object):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initializer."""
|
"""Initializer."""
|
||||||
self.overridden_by = None
|
self.overridden_by = tuple()
|
||||||
super(Provider, self).__init__()
|
super(Provider, self).__init__()
|
||||||
# Enable __call__() / _provide() optimization
|
# Enable __call__() / _provide() optimization
|
||||||
if self.__class__.__OPTIMIZED_CALLS__:
|
if self.__class__.__OPTIMIZED_CALLS__:
|
||||||
|
@ -124,9 +124,6 @@ class Provider(object):
|
||||||
if not is_provider(provider):
|
if not is_provider(provider):
|
||||||
provider = Object(provider)
|
provider = Object(provider)
|
||||||
|
|
||||||
if not self.is_overridden:
|
|
||||||
self.overridden_by = (ensure_is_provider(provider),)
|
|
||||||
else:
|
|
||||||
self.overridden_by += (ensure_is_provider(provider),)
|
self.overridden_by += (ensure_is_provider(provider),)
|
||||||
|
|
||||||
# Disable __call__() / _provide() optimization
|
# Disable __call__() / _provide() optimization
|
||||||
|
@ -145,6 +142,7 @@ class Provider(object):
|
||||||
"""
|
"""
|
||||||
if not self.overridden_by:
|
if not self.overridden_by:
|
||||||
raise Error('Provider {0} is not overridden'.format(str(self)))
|
raise Error('Provider {0} is not overridden'.format(str(self)))
|
||||||
|
|
||||||
self.overridden_by = self.overridden_by[:-1]
|
self.overridden_by = self.overridden_by[:-1]
|
||||||
|
|
||||||
if not self.is_overridden:
|
if not self.is_overridden:
|
||||||
|
@ -157,7 +155,7 @@ class Provider(object):
|
||||||
|
|
||||||
:rtype: None
|
:rtype: None
|
||||||
"""
|
"""
|
||||||
self.overridden_by = None
|
self.overridden_by = tuple()
|
||||||
|
|
||||||
# Enable __call__() / _provide() optimization
|
# Enable __call__() / _provide() optimization
|
||||||
if self.__class__.__OPTIMIZED_CALLS__:
|
if self.__class__.__OPTIMIZED_CALLS__:
|
||||||
|
|
|
@ -59,6 +59,18 @@ def ensure_is_provider(instance):
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
def is_container(instance):
|
||||||
|
"""Check if instance is container instance.
|
||||||
|
|
||||||
|
:param instance: Instance to be checked.
|
||||||
|
:type instance: object
|
||||||
|
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
return (hasattr(instance, '__IS_CONTAINER__') and
|
||||||
|
getattr(instance, '__IS_CONTAINER__', False) is True)
|
||||||
|
|
||||||
|
|
||||||
def is_catalog(instance):
|
def is_catalog(instance):
|
||||||
"""Check if instance is catalog instance.
|
"""Check if instance is catalog instance.
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import unittest2 as unittest
|
||||||
from dependency_injector import (
|
from dependency_injector import (
|
||||||
containers,
|
containers,
|
||||||
providers,
|
providers,
|
||||||
|
errors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,7 +29,7 @@ class ContainerB(ContainerA):
|
||||||
class DeclarativeContainerTests(unittest.TestCase):
|
class DeclarativeContainerTests(unittest.TestCase):
|
||||||
"""Declarative container tests."""
|
"""Declarative container tests."""
|
||||||
|
|
||||||
def test_providers_attribute_with(self):
|
def test_providers_attribute(self):
|
||||||
"""Test providers attribute."""
|
"""Test providers attribute."""
|
||||||
self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11,
|
self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11,
|
||||||
p12=ContainerA.p12))
|
p12=ContainerA.p12))
|
||||||
|
@ -37,7 +38,7 @@ class DeclarativeContainerTests(unittest.TestCase):
|
||||||
p21=ContainerB.p21,
|
p21=ContainerB.p21,
|
||||||
p22=ContainerB.p22))
|
p22=ContainerB.p22))
|
||||||
|
|
||||||
def test_cls_providers_attribute_with(self):
|
def test_cls_providers_attribute(self):
|
||||||
"""Test cls_providers attribute."""
|
"""Test cls_providers attribute."""
|
||||||
self.assertEqual(ContainerA.cls_providers, dict(p11=ContainerA.p11,
|
self.assertEqual(ContainerA.cls_providers, dict(p11=ContainerA.p11,
|
||||||
p12=ContainerA.p12))
|
p12=ContainerA.p12))
|
||||||
|
@ -51,7 +52,7 @@ class DeclarativeContainerTests(unittest.TestCase):
|
||||||
dict(p11=ContainerA.p11,
|
dict(p11=ContainerA.p11,
|
||||||
p12=ContainerA.p12))
|
p12=ContainerA.p12))
|
||||||
|
|
||||||
def test_set_get_del_provider_attribute(self):
|
def test_set_get_del_providers(self):
|
||||||
"""Test set/get/del provider attributes."""
|
"""Test set/get/del provider attributes."""
|
||||||
a_p13 = providers.Provider()
|
a_p13 = providers.Provider()
|
||||||
b_p23 = providers.Provider()
|
b_p23 = providers.Provider()
|
||||||
|
@ -90,6 +91,120 @@ class DeclarativeContainerTests(unittest.TestCase):
|
||||||
self.assertEqual(ContainerB.cls_providers, dict(p21=ContainerB.p21,
|
self.assertEqual(ContainerB.cls_providers, dict(p21=ContainerB.p21,
|
||||||
p22=ContainerB.p22))
|
p22=ContainerB.p22))
|
||||||
|
|
||||||
|
def test_declare_with_valid_provider_type(self):
|
||||||
|
"""Test declaration of container with valid provider type."""
|
||||||
|
class _Container(containers.DeclarativeContainer):
|
||||||
|
provider_type = providers.Object
|
||||||
|
px = providers.Object(object())
|
||||||
|
|
||||||
|
self.assertIsInstance(_Container.px, providers.Object)
|
||||||
|
|
||||||
|
def test_declare_with_invalid_provider_type(self):
|
||||||
|
"""Test declaration of container with invalid provider type."""
|
||||||
|
with self.assertRaises(errors.Error):
|
||||||
|
class _Container(containers.DeclarativeContainer):
|
||||||
|
provider_type = providers.Object
|
||||||
|
px = providers.Provider()
|
||||||
|
|
||||||
|
def test_seth_valid_provider_type(self):
|
||||||
|
"""Test setting of valid provider."""
|
||||||
|
class _Container(containers.DeclarativeContainer):
|
||||||
|
provider_type = providers.Object
|
||||||
|
|
||||||
|
_Container.px = providers.Object(object())
|
||||||
|
|
||||||
|
self.assertIsInstance(_Container.px, providers.Object)
|
||||||
|
|
||||||
|
def test_set_invalid_provider_type(self):
|
||||||
|
"""Test setting of invalid provider."""
|
||||||
|
class _Container(containers.DeclarativeContainer):
|
||||||
|
provider_type = providers.Object
|
||||||
|
|
||||||
|
with self.assertRaises(errors.Error):
|
||||||
|
_Container.px = providers.Provider()
|
||||||
|
|
||||||
|
def test_override(self):
|
||||||
|
"""Test override."""
|
||||||
|
class _Container(containers.DeclarativeContainer):
|
||||||
|
p11 = providers.Provider()
|
||||||
|
|
||||||
|
class _OverridingContainer1(containers.DeclarativeContainer):
|
||||||
|
p11 = providers.Provider()
|
||||||
|
|
||||||
|
class _OverridingContainer2(containers.DeclarativeContainer):
|
||||||
|
p11 = providers.Provider()
|
||||||
|
p12 = providers.Provider()
|
||||||
|
|
||||||
|
_Container.override(_OverridingContainer1)
|
||||||
|
_Container.override(_OverridingContainer2)
|
||||||
|
|
||||||
|
self.assertEqual(_Container.overridden_by,
|
||||||
|
(_OverridingContainer1,
|
||||||
|
_OverridingContainer2))
|
||||||
|
self.assertEqual(_Container.p11.overridden_by,
|
||||||
|
(_OverridingContainer1.p11,
|
||||||
|
_OverridingContainer2.p11))
|
||||||
|
|
||||||
|
def test_override_decorator(self):
|
||||||
|
"""Test override decorator."""
|
||||||
|
class _Container(containers.DeclarativeContainer):
|
||||||
|
p11 = providers.Provider()
|
||||||
|
|
||||||
|
@containers.override(_Container)
|
||||||
|
class _OverridingContainer1(containers.DeclarativeContainer):
|
||||||
|
p11 = providers.Provider()
|
||||||
|
|
||||||
|
@containers.override(_Container)
|
||||||
|
class _OverridingContainer2(containers.DeclarativeContainer):
|
||||||
|
p11 = providers.Provider()
|
||||||
|
p12 = providers.Provider()
|
||||||
|
|
||||||
|
self.assertEqual(_Container.overridden_by,
|
||||||
|
(_OverridingContainer1,
|
||||||
|
_OverridingContainer2))
|
||||||
|
self.assertEqual(_Container.p11.overridden_by,
|
||||||
|
(_OverridingContainer1.p11,
|
||||||
|
_OverridingContainer2.p11))
|
||||||
|
|
||||||
|
def test_reset_last_overridding(self):
|
||||||
|
"""Test reset of last overriding."""
|
||||||
|
class _Container(containers.DeclarativeContainer):
|
||||||
|
p11 = providers.Provider()
|
||||||
|
|
||||||
|
class _OverridingContainer1(containers.DeclarativeContainer):
|
||||||
|
p11 = providers.Provider()
|
||||||
|
|
||||||
|
class _OverridingContainer2(containers.DeclarativeContainer):
|
||||||
|
p11 = providers.Provider()
|
||||||
|
p12 = providers.Provider()
|
||||||
|
|
||||||
|
_Container.override(_OverridingContainer1)
|
||||||
|
_Container.override(_OverridingContainer2)
|
||||||
|
_Container.reset_last_overriding()
|
||||||
|
|
||||||
|
self.assertEqual(_Container.overridden_by,
|
||||||
|
(_OverridingContainer1,))
|
||||||
|
self.assertEqual(_Container.p11.overridden_by,
|
||||||
|
(_OverridingContainer1.p11,))
|
||||||
|
|
||||||
|
def test_reset_override(self):
|
||||||
|
"""Test reset all overridings."""
|
||||||
|
class _Container(containers.DeclarativeContainer):
|
||||||
|
p11 = providers.Provider()
|
||||||
|
|
||||||
|
class _OverridingContainer1(containers.DeclarativeContainer):
|
||||||
|
p11 = providers.Provider()
|
||||||
|
|
||||||
|
class _OverridingContainer2(containers.DeclarativeContainer):
|
||||||
|
p11 = providers.Provider()
|
||||||
|
p12 = providers.Provider()
|
||||||
|
|
||||||
|
_Container.override(_OverridingContainer1)
|
||||||
|
_Container.override(_OverridingContainer2)
|
||||||
|
_Container.reset_override()
|
||||||
|
|
||||||
|
self.assertEqual(_Container.overridden_by, tuple())
|
||||||
|
self.assertEqual(_Container.p11.overridden_by, tuple())
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user