Add some functionality and tests for declarative containers

+ Add checks for valid provider type
+ Add some wider functionality for overriding
This commit is contained in:
Roman Mogilatov 2016-05-30 23:34:14 +03:00
parent 68ae1b80df
commit a35db5889d
4 changed files with 173 additions and 21 deletions

View File

@ -3,6 +3,7 @@
import six import six
from dependency_injector import ( from dependency_injector import (
providers,
utils, utils,
errors, errors,
) )
@ -18,7 +19,8 @@ class DeclarativeContainerMetaClass(type):
if utils.is_provider(provider)) if utils.is_provider(provider))
inherited_providers = tuple((name, provider) inherited_providers = tuple((name, provider)
for base in bases if utils.is_catalog(base) for base in bases if utils.is_container(
base)
for name, provider in six.iteritems( for name, provider in six.iteritems(
base.cls_providers)) base.cls_providers))
@ -28,14 +30,8 @@ class DeclarativeContainerMetaClass(type):
cls = type.__new__(mcs, class_name, bases, attributes) cls = type.__new__(mcs, class_name, bases, attributes)
if cls.provider_type: for provider in six.itervalues(cls.providers):
for provider in six.itervalues(cls.providers): cls._check_provider_type(provider)
try:
assert isinstance(provider, cls.provider_type)
except AssertionError:
raise errors.Error('{0} can contain only {1} '
'instances'.format(cls,
cls.provider_type))
return cls return cls
@ -46,6 +42,7 @@ class DeclarativeContainerMetaClass(type):
dictionary. dictionary.
""" """
if utils.is_provider(value): if utils.is_provider(value):
cls._check_provider_type(value)
cls.providers[name] = value cls.providers[name] = value
cls.cls_providers[name] = value cls.cls_providers[name] = value
super(DeclarativeContainerMetaClass, cls).__setattr__(name, value) super(DeclarativeContainerMetaClass, cls).__setattr__(name, value)
@ -61,14 +58,19 @@ class DeclarativeContainerMetaClass(type):
del cls.cls_providers[name] del cls.cls_providers[name]
super(DeclarativeContainerMetaClass, cls).__delattr__(name) super(DeclarativeContainerMetaClass, cls).__delattr__(name)
def _check_provider_type(cls, provider):
if not isinstance(provider, cls.provider_type):
raise errors.Error('{0} can contain only {1} '
'instances'.format(cls, cls.provider_type))
@six.add_metaclass(DeclarativeContainerMetaClass) @six.add_metaclass(DeclarativeContainerMetaClass)
class DeclarativeContainer(object): class DeclarativeContainer(object):
"""Declarative inversion of control container.""" """Declarative inversion of control container."""
__IS_CATALOG__ = True __IS_CONTAINER__ = True
provider_type = None provider_type = providers.Provider
providers = dict() providers = dict()
cls_providers = dict() cls_providers = dict()
@ -89,7 +91,7 @@ class DeclarativeContainer(object):
:rtype: None :rtype: None
""" """
if issubclass(cls, overriding): if issubclass(cls, overriding):
raise errors.Error('Catalog {0} could not be overridden ' raise errors.Error('Container {0} could not be overridden '
'with itself or its subclasses'.format(cls)) 'with itself or its subclasses'.format(cls))
cls.overridden_by += (overriding,) cls.overridden_by += (overriding,)
@ -100,6 +102,31 @@ class DeclarativeContainer(object):
except AttributeError: except AttributeError:
pass pass
@classmethod
def reset_last_overriding(cls):
"""Reset last overriding provider for each container providers.
:rtype: None
"""
if not cls.overridden_by:
raise errors.Error('Container {0} is not overridden'.format(cls))
cls.overridden_by = cls.overridden_by[:-1]
for provider in six.itervalues(cls.providers):
provider.reset_last_overriding()
@classmethod
def reset_override(cls):
"""Reset all overridings for each container providers.
:rtype: None
"""
cls.overridden_by = tuple()
for provider in six.itervalues(cls.providers):
provider.reset_override()
def override(container): def override(container):
""":py:class:`DeclarativeContainer` overriding decorator. """:py:class:`DeclarativeContainer` overriding decorator.

View File

@ -64,7 +64,7 @@ class Provider(object):
def __init__(self): def __init__(self):
"""Initializer.""" """Initializer."""
self.overridden_by = None self.overridden_by = tuple()
super(Provider, self).__init__() super(Provider, self).__init__()
# Enable __call__() / _provide() optimization # Enable __call__() / _provide() optimization
if self.__class__.__OPTIMIZED_CALLS__: if self.__class__.__OPTIMIZED_CALLS__:
@ -124,10 +124,7 @@ class Provider(object):
if not is_provider(provider): if not is_provider(provider):
provider = Object(provider) provider = Object(provider)
if not self.is_overridden: self.overridden_by += (ensure_is_provider(provider),)
self.overridden_by = (ensure_is_provider(provider),)
else:
self.overridden_by += (ensure_is_provider(provider),)
# Disable __call__() / _provide() optimization # Disable __call__() / _provide() optimization
if self.__class__.__OPTIMIZED_CALLS__: if self.__class__.__OPTIMIZED_CALLS__:
@ -145,6 +142,7 @@ class Provider(object):
""" """
if not self.overridden_by: if not self.overridden_by:
raise Error('Provider {0} is not overridden'.format(str(self))) raise Error('Provider {0} is not overridden'.format(str(self)))
self.overridden_by = self.overridden_by[:-1] self.overridden_by = self.overridden_by[:-1]
if not self.is_overridden: if not self.is_overridden:
@ -157,7 +155,7 @@ class Provider(object):
:rtype: None :rtype: None
""" """
self.overridden_by = None self.overridden_by = tuple()
# Enable __call__() / _provide() optimization # Enable __call__() / _provide() optimization
if self.__class__.__OPTIMIZED_CALLS__: if self.__class__.__OPTIMIZED_CALLS__:

View File

@ -59,6 +59,18 @@ def ensure_is_provider(instance):
return instance return instance
def is_container(instance):
"""Check if instance is container instance.
:param instance: Instance to be checked.
:type instance: object
:rtype: bool
"""
return (hasattr(instance, '__IS_CONTAINER__') and
getattr(instance, '__IS_CONTAINER__', False) is True)
def is_catalog(instance): def is_catalog(instance):
"""Check if instance is catalog instance. """Check if instance is catalog instance.

View File

@ -5,6 +5,7 @@ import unittest2 as unittest
from dependency_injector import ( from dependency_injector import (
containers, containers,
providers, providers,
errors,
) )
@ -28,7 +29,7 @@ class ContainerB(ContainerA):
class DeclarativeContainerTests(unittest.TestCase): class DeclarativeContainerTests(unittest.TestCase):
"""Declarative container tests.""" """Declarative container tests."""
def test_providers_attribute_with(self): def test_providers_attribute(self):
"""Test providers attribute.""" """Test providers attribute."""
self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11, self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11,
p12=ContainerA.p12)) p12=ContainerA.p12))
@ -37,7 +38,7 @@ class DeclarativeContainerTests(unittest.TestCase):
p21=ContainerB.p21, p21=ContainerB.p21,
p22=ContainerB.p22)) p22=ContainerB.p22))
def test_cls_providers_attribute_with(self): def test_cls_providers_attribute(self):
"""Test cls_providers attribute.""" """Test cls_providers attribute."""
self.assertEqual(ContainerA.cls_providers, dict(p11=ContainerA.p11, self.assertEqual(ContainerA.cls_providers, dict(p11=ContainerA.p11,
p12=ContainerA.p12)) p12=ContainerA.p12))
@ -51,7 +52,7 @@ class DeclarativeContainerTests(unittest.TestCase):
dict(p11=ContainerA.p11, dict(p11=ContainerA.p11,
p12=ContainerA.p12)) p12=ContainerA.p12))
def test_set_get_del_provider_attribute(self): def test_set_get_del_providers(self):
"""Test set/get/del provider attributes.""" """Test set/get/del provider attributes."""
a_p13 = providers.Provider() a_p13 = providers.Provider()
b_p23 = providers.Provider() b_p23 = providers.Provider()
@ -90,6 +91,120 @@ class DeclarativeContainerTests(unittest.TestCase):
self.assertEqual(ContainerB.cls_providers, dict(p21=ContainerB.p21, self.assertEqual(ContainerB.cls_providers, dict(p21=ContainerB.p21,
p22=ContainerB.p22)) p22=ContainerB.p22))
def test_declare_with_valid_provider_type(self):
"""Test declaration of container with valid provider type."""
class _Container(containers.DeclarativeContainer):
provider_type = providers.Object
px = providers.Object(object())
self.assertIsInstance(_Container.px, providers.Object)
def test_declare_with_invalid_provider_type(self):
"""Test declaration of container with invalid provider type."""
with self.assertRaises(errors.Error):
class _Container(containers.DeclarativeContainer):
provider_type = providers.Object
px = providers.Provider()
def test_seth_valid_provider_type(self):
"""Test setting of valid provider."""
class _Container(containers.DeclarativeContainer):
provider_type = providers.Object
_Container.px = providers.Object(object())
self.assertIsInstance(_Container.px, providers.Object)
def test_set_invalid_provider_type(self):
"""Test setting of invalid provider."""
class _Container(containers.DeclarativeContainer):
provider_type = providers.Object
with self.assertRaises(errors.Error):
_Container.px = providers.Provider()
def test_override(self):
"""Test override."""
class _Container(containers.DeclarativeContainer):
p11 = providers.Provider()
class _OverridingContainer1(containers.DeclarativeContainer):
p11 = providers.Provider()
class _OverridingContainer2(containers.DeclarativeContainer):
p11 = providers.Provider()
p12 = providers.Provider()
_Container.override(_OverridingContainer1)
_Container.override(_OverridingContainer2)
self.assertEqual(_Container.overridden_by,
(_OverridingContainer1,
_OverridingContainer2))
self.assertEqual(_Container.p11.overridden_by,
(_OverridingContainer1.p11,
_OverridingContainer2.p11))
def test_override_decorator(self):
"""Test override decorator."""
class _Container(containers.DeclarativeContainer):
p11 = providers.Provider()
@containers.override(_Container)
class _OverridingContainer1(containers.DeclarativeContainer):
p11 = providers.Provider()
@containers.override(_Container)
class _OverridingContainer2(containers.DeclarativeContainer):
p11 = providers.Provider()
p12 = providers.Provider()
self.assertEqual(_Container.overridden_by,
(_OverridingContainer1,
_OverridingContainer2))
self.assertEqual(_Container.p11.overridden_by,
(_OverridingContainer1.p11,
_OverridingContainer2.p11))
def test_reset_last_overridding(self):
"""Test reset of last overriding."""
class _Container(containers.DeclarativeContainer):
p11 = providers.Provider()
class _OverridingContainer1(containers.DeclarativeContainer):
p11 = providers.Provider()
class _OverridingContainer2(containers.DeclarativeContainer):
p11 = providers.Provider()
p12 = providers.Provider()
_Container.override(_OverridingContainer1)
_Container.override(_OverridingContainer2)
_Container.reset_last_overriding()
self.assertEqual(_Container.overridden_by,
(_OverridingContainer1,))
self.assertEqual(_Container.p11.overridden_by,
(_OverridingContainer1.p11,))
def test_reset_override(self):
"""Test reset all overridings."""
class _Container(containers.DeclarativeContainer):
p11 = providers.Provider()
class _OverridingContainer1(containers.DeclarativeContainer):
p11 = providers.Provider()
class _OverridingContainer2(containers.DeclarativeContainer):
p11 = providers.Provider()
p12 = providers.Provider()
_Container.override(_OverridingContainer1)
_Container.override(_OverridingContainer2)
_Container.reset_override()
self.assertEqual(_Container.overridden_by, tuple())
self.assertEqual(_Container.p11.overridden_by, tuple())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()