From a35db5889df52bb6b78fad9dfdf38d9ebb04683a Mon Sep 17 00:00:00 2001 From: Roman Mogilatov Date: Mon, 30 May 2016 23:34:14 +0300 Subject: [PATCH] Add some functionality and tests for declarative containers + Add checks for valid provider type + Add some wider functionality for overriding --- dependency_injector/containers.py | 51 ++++++++--- dependency_injector/providers/base.py | 10 +-- dependency_injector/utils.py | 12 +++ tests/test_containers.py | 121 +++++++++++++++++++++++++- 4 files changed, 173 insertions(+), 21 deletions(-) diff --git a/dependency_injector/containers.py b/dependency_injector/containers.py index c745d57f..2cf435fc 100644 --- a/dependency_injector/containers.py +++ b/dependency_injector/containers.py @@ -3,6 +3,7 @@ import six from dependency_injector import ( + providers, utils, errors, ) @@ -18,7 +19,8 @@ class DeclarativeContainerMetaClass(type): if utils.is_provider(provider)) inherited_providers = tuple((name, provider) - for base in bases if utils.is_catalog(base) + for base in bases if utils.is_container( + base) for name, provider in six.iteritems( base.cls_providers)) @@ -28,14 +30,8 @@ class DeclarativeContainerMetaClass(type): cls = type.__new__(mcs, class_name, bases, attributes) - if cls.provider_type: - for provider in six.itervalues(cls.providers): - try: - assert isinstance(provider, cls.provider_type) - except AssertionError: - raise errors.Error('{0} can contain only {1} ' - 'instances'.format(cls, - cls.provider_type)) + for provider in six.itervalues(cls.providers): + cls._check_provider_type(provider) return cls @@ -46,6 +42,7 @@ class DeclarativeContainerMetaClass(type): dictionary. """ if utils.is_provider(value): + cls._check_provider_type(value) cls.providers[name] = value cls.cls_providers[name] = value super(DeclarativeContainerMetaClass, cls).__setattr__(name, value) @@ -61,14 +58,19 @@ class DeclarativeContainerMetaClass(type): del cls.cls_providers[name] super(DeclarativeContainerMetaClass, cls).__delattr__(name) + def _check_provider_type(cls, provider): + if not isinstance(provider, cls.provider_type): + raise errors.Error('{0} can contain only {1} ' + 'instances'.format(cls, cls.provider_type)) + @six.add_metaclass(DeclarativeContainerMetaClass) class DeclarativeContainer(object): """Declarative inversion of control container.""" - __IS_CATALOG__ = True + __IS_CONTAINER__ = True - provider_type = None + provider_type = providers.Provider providers = dict() cls_providers = dict() @@ -89,7 +91,7 @@ class DeclarativeContainer(object): :rtype: None """ if issubclass(cls, overriding): - raise errors.Error('Catalog {0} could not be overridden ' + raise errors.Error('Container {0} could not be overridden ' 'with itself or its subclasses'.format(cls)) cls.overridden_by += (overriding,) @@ -100,6 +102,31 @@ class DeclarativeContainer(object): except AttributeError: pass + @classmethod + def reset_last_overriding(cls): + """Reset last overriding provider for each container providers. + + :rtype: None + """ + if not cls.overridden_by: + raise errors.Error('Container {0} is not overridden'.format(cls)) + + cls.overridden_by = cls.overridden_by[:-1] + + for provider in six.itervalues(cls.providers): + provider.reset_last_overriding() + + @classmethod + def reset_override(cls): + """Reset all overridings for each container providers. + + :rtype: None + """ + cls.overridden_by = tuple() + + for provider in six.itervalues(cls.providers): + provider.reset_override() + def override(container): """:py:class:`DeclarativeContainer` overriding decorator. diff --git a/dependency_injector/providers/base.py b/dependency_injector/providers/base.py index 36b9a6e9..4b740497 100644 --- a/dependency_injector/providers/base.py +++ b/dependency_injector/providers/base.py @@ -64,7 +64,7 @@ class Provider(object): def __init__(self): """Initializer.""" - self.overridden_by = None + self.overridden_by = tuple() super(Provider, self).__init__() # Enable __call__() / _provide() optimization if self.__class__.__OPTIMIZED_CALLS__: @@ -124,10 +124,7 @@ class Provider(object): if not is_provider(provider): provider = Object(provider) - if not self.is_overridden: - self.overridden_by = (ensure_is_provider(provider),) - else: - self.overridden_by += (ensure_is_provider(provider),) + self.overridden_by += (ensure_is_provider(provider),) # Disable __call__() / _provide() optimization if self.__class__.__OPTIMIZED_CALLS__: @@ -145,6 +142,7 @@ class Provider(object): """ if not self.overridden_by: raise Error('Provider {0} is not overridden'.format(str(self))) + self.overridden_by = self.overridden_by[:-1] if not self.is_overridden: @@ -157,7 +155,7 @@ class Provider(object): :rtype: None """ - self.overridden_by = None + self.overridden_by = tuple() # Enable __call__() / _provide() optimization if self.__class__.__OPTIMIZED_CALLS__: diff --git a/dependency_injector/utils.py b/dependency_injector/utils.py index f65dfcea..f1354a14 100644 --- a/dependency_injector/utils.py +++ b/dependency_injector/utils.py @@ -59,6 +59,18 @@ def ensure_is_provider(instance): return instance +def is_container(instance): + """Check if instance is container instance. + + :param instance: Instance to be checked. + :type instance: object + + :rtype: bool + """ + return (hasattr(instance, '__IS_CONTAINER__') and + getattr(instance, '__IS_CONTAINER__', False) is True) + + def is_catalog(instance): """Check if instance is catalog instance. diff --git a/tests/test_containers.py b/tests/test_containers.py index 38e6e667..39bf7d1e 100644 --- a/tests/test_containers.py +++ b/tests/test_containers.py @@ -5,6 +5,7 @@ import unittest2 as unittest from dependency_injector import ( containers, providers, + errors, ) @@ -28,7 +29,7 @@ class ContainerB(ContainerA): class DeclarativeContainerTests(unittest.TestCase): """Declarative container tests.""" - def test_providers_attribute_with(self): + def test_providers_attribute(self): """Test providers attribute.""" self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11, p12=ContainerA.p12)) @@ -37,7 +38,7 @@ class DeclarativeContainerTests(unittest.TestCase): p21=ContainerB.p21, p22=ContainerB.p22)) - def test_cls_providers_attribute_with(self): + def test_cls_providers_attribute(self): """Test cls_providers attribute.""" self.assertEqual(ContainerA.cls_providers, dict(p11=ContainerA.p11, p12=ContainerA.p12)) @@ -51,7 +52,7 @@ class DeclarativeContainerTests(unittest.TestCase): dict(p11=ContainerA.p11, p12=ContainerA.p12)) - def test_set_get_del_provider_attribute(self): + def test_set_get_del_providers(self): """Test set/get/del provider attributes.""" a_p13 = providers.Provider() b_p23 = providers.Provider() @@ -90,6 +91,120 @@ class DeclarativeContainerTests(unittest.TestCase): self.assertEqual(ContainerB.cls_providers, dict(p21=ContainerB.p21, p22=ContainerB.p22)) + def test_declare_with_valid_provider_type(self): + """Test declaration of container with valid provider type.""" + class _Container(containers.DeclarativeContainer): + provider_type = providers.Object + px = providers.Object(object()) + + self.assertIsInstance(_Container.px, providers.Object) + + def test_declare_with_invalid_provider_type(self): + """Test declaration of container with invalid provider type.""" + with self.assertRaises(errors.Error): + class _Container(containers.DeclarativeContainer): + provider_type = providers.Object + px = providers.Provider() + + def test_seth_valid_provider_type(self): + """Test setting of valid provider.""" + class _Container(containers.DeclarativeContainer): + provider_type = providers.Object + + _Container.px = providers.Object(object()) + + self.assertIsInstance(_Container.px, providers.Object) + + def test_set_invalid_provider_type(self): + """Test setting of invalid provider.""" + class _Container(containers.DeclarativeContainer): + provider_type = providers.Object + + with self.assertRaises(errors.Error): + _Container.px = providers.Provider() + + def test_override(self): + """Test override.""" + class _Container(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer1(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer2(containers.DeclarativeContainer): + p11 = providers.Provider() + p12 = providers.Provider() + + _Container.override(_OverridingContainer1) + _Container.override(_OverridingContainer2) + + self.assertEqual(_Container.overridden_by, + (_OverridingContainer1, + _OverridingContainer2)) + self.assertEqual(_Container.p11.overridden_by, + (_OverridingContainer1.p11, + _OverridingContainer2.p11)) + + def test_override_decorator(self): + """Test override decorator.""" + class _Container(containers.DeclarativeContainer): + p11 = providers.Provider() + + @containers.override(_Container) + class _OverridingContainer1(containers.DeclarativeContainer): + p11 = providers.Provider() + + @containers.override(_Container) + class _OverridingContainer2(containers.DeclarativeContainer): + p11 = providers.Provider() + p12 = providers.Provider() + + self.assertEqual(_Container.overridden_by, + (_OverridingContainer1, + _OverridingContainer2)) + self.assertEqual(_Container.p11.overridden_by, + (_OverridingContainer1.p11, + _OverridingContainer2.p11)) + + def test_reset_last_overridding(self): + """Test reset of last overriding.""" + class _Container(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer1(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer2(containers.DeclarativeContainer): + p11 = providers.Provider() + p12 = providers.Provider() + + _Container.override(_OverridingContainer1) + _Container.override(_OverridingContainer2) + _Container.reset_last_overriding() + + self.assertEqual(_Container.overridden_by, + (_OverridingContainer1,)) + self.assertEqual(_Container.p11.overridden_by, + (_OverridingContainer1.p11,)) + + def test_reset_override(self): + """Test reset all overridings.""" + class _Container(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer1(containers.DeclarativeContainer): + p11 = providers.Provider() + + class _OverridingContainer2(containers.DeclarativeContainer): + p11 = providers.Provider() + p12 = providers.Provider() + + _Container.override(_OverridingContainer1) + _Container.override(_OverridingContainer2) + _Container.reset_override() + + self.assertEqual(_Container.overridden_by, tuple()) + self.assertEqual(_Container.p11.overridden_by, tuple()) if __name__ == '__main__': unittest.main()