mirror of
				https://github.com/ets-labs/python-dependency-injector.git
				synced 2025-10-31 16:07:51 +03:00 
			
		
		
		
	Add some functionality and tests for declarative containers
+ Add checks for valid provider type + Add some wider functionality for overriding
This commit is contained in:
		
							parent
							
								
									68ae1b80df
								
							
						
					
					
						commit
						a35db5889d
					
				|  | @ -3,6 +3,7 @@ | ||||||
| import six | import six | ||||||
| 
 | 
 | ||||||
| from dependency_injector import ( | from dependency_injector import ( | ||||||
|  |     providers, | ||||||
|     utils, |     utils, | ||||||
|     errors, |     errors, | ||||||
| ) | ) | ||||||
|  | @ -18,7 +19,8 @@ class DeclarativeContainerMetaClass(type): | ||||||
|                               if utils.is_provider(provider)) |                               if utils.is_provider(provider)) | ||||||
| 
 | 
 | ||||||
|         inherited_providers = tuple((name, provider) |         inherited_providers = tuple((name, provider) | ||||||
|                                     for base in bases if utils.is_catalog(base) |                                     for base in bases if utils.is_container( | ||||||
|  |                                         base) | ||||||
|                                     for name, provider in six.iteritems( |                                     for name, provider in six.iteritems( | ||||||
|                                         base.cls_providers)) |                                         base.cls_providers)) | ||||||
| 
 | 
 | ||||||
|  | @ -28,14 +30,8 @@ class DeclarativeContainerMetaClass(type): | ||||||
| 
 | 
 | ||||||
|         cls = type.__new__(mcs, class_name, bases, attributes) |         cls = type.__new__(mcs, class_name, bases, attributes) | ||||||
| 
 | 
 | ||||||
|         if cls.provider_type: |         for provider in six.itervalues(cls.providers): | ||||||
|             for provider in six.itervalues(cls.providers): |             cls._check_provider_type(provider) | ||||||
|                 try: |  | ||||||
|                     assert isinstance(provider, cls.provider_type) |  | ||||||
|                 except AssertionError: |  | ||||||
|                     raise errors.Error('{0} can contain only {1} ' |  | ||||||
|                                        'instances'.format(cls, |  | ||||||
|                                                           cls.provider_type)) |  | ||||||
| 
 | 
 | ||||||
|         return cls |         return cls | ||||||
| 
 | 
 | ||||||
|  | @ -46,6 +42,7 @@ class DeclarativeContainerMetaClass(type): | ||||||
|         dictionary. |         dictionary. | ||||||
|         """ |         """ | ||||||
|         if utils.is_provider(value): |         if utils.is_provider(value): | ||||||
|  |             cls._check_provider_type(value) | ||||||
|             cls.providers[name] = value |             cls.providers[name] = value | ||||||
|             cls.cls_providers[name] = value |             cls.cls_providers[name] = value | ||||||
|         super(DeclarativeContainerMetaClass, cls).__setattr__(name, value) |         super(DeclarativeContainerMetaClass, cls).__setattr__(name, value) | ||||||
|  | @ -61,14 +58,19 @@ class DeclarativeContainerMetaClass(type): | ||||||
|             del cls.cls_providers[name] |             del cls.cls_providers[name] | ||||||
|         super(DeclarativeContainerMetaClass, cls).__delattr__(name) |         super(DeclarativeContainerMetaClass, cls).__delattr__(name) | ||||||
| 
 | 
 | ||||||
|  |     def _check_provider_type(cls, provider): | ||||||
|  |         if not isinstance(provider, cls.provider_type): | ||||||
|  |             raise errors.Error('{0} can contain only {1} ' | ||||||
|  |                                'instances'.format(cls, cls.provider_type)) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| @six.add_metaclass(DeclarativeContainerMetaClass) | @six.add_metaclass(DeclarativeContainerMetaClass) | ||||||
| class DeclarativeContainer(object): | class DeclarativeContainer(object): | ||||||
|     """Declarative inversion of control container.""" |     """Declarative inversion of control container.""" | ||||||
| 
 | 
 | ||||||
|     __IS_CATALOG__ = True |     __IS_CONTAINER__ = True | ||||||
| 
 | 
 | ||||||
|     provider_type = None |     provider_type = providers.Provider | ||||||
| 
 | 
 | ||||||
|     providers = dict() |     providers = dict() | ||||||
|     cls_providers = dict() |     cls_providers = dict() | ||||||
|  | @ -89,7 +91,7 @@ class DeclarativeContainer(object): | ||||||
|         :rtype: None |         :rtype: None | ||||||
|         """ |         """ | ||||||
|         if issubclass(cls, overriding): |         if issubclass(cls, overriding): | ||||||
|             raise errors.Error('Catalog {0} could not be overridden ' |             raise errors.Error('Container {0} could not be overridden ' | ||||||
|                                'with itself or its subclasses'.format(cls)) |                                'with itself or its subclasses'.format(cls)) | ||||||
| 
 | 
 | ||||||
|         cls.overridden_by += (overriding,) |         cls.overridden_by += (overriding,) | ||||||
|  | @ -100,6 +102,31 @@ class DeclarativeContainer(object): | ||||||
|             except AttributeError: |             except AttributeError: | ||||||
|                 pass |                 pass | ||||||
| 
 | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def reset_last_overriding(cls): | ||||||
|  |         """Reset last overriding provider for each container providers. | ||||||
|  | 
 | ||||||
|  |         :rtype: None | ||||||
|  |         """ | ||||||
|  |         if not cls.overridden_by: | ||||||
|  |             raise errors.Error('Container {0} is not overridden'.format(cls)) | ||||||
|  | 
 | ||||||
|  |         cls.overridden_by = cls.overridden_by[:-1] | ||||||
|  | 
 | ||||||
|  |         for provider in six.itervalues(cls.providers): | ||||||
|  |             provider.reset_last_overriding() | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def reset_override(cls): | ||||||
|  |         """Reset all overridings for each container providers. | ||||||
|  | 
 | ||||||
|  |         :rtype: None | ||||||
|  |         """ | ||||||
|  |         cls.overridden_by = tuple() | ||||||
|  | 
 | ||||||
|  |         for provider in six.itervalues(cls.providers): | ||||||
|  |             provider.reset_override() | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def override(container): | def override(container): | ||||||
|     """:py:class:`DeclarativeContainer` overriding decorator. |     """:py:class:`DeclarativeContainer` overriding decorator. | ||||||
|  |  | ||||||
|  | @ -64,7 +64,7 @@ class Provider(object): | ||||||
| 
 | 
 | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         """Initializer.""" |         """Initializer.""" | ||||||
|         self.overridden_by = None |         self.overridden_by = tuple() | ||||||
|         super(Provider, self).__init__() |         super(Provider, self).__init__() | ||||||
|         # Enable __call__() / _provide() optimization |         # Enable __call__() / _provide() optimization | ||||||
|         if self.__class__.__OPTIMIZED_CALLS__: |         if self.__class__.__OPTIMIZED_CALLS__: | ||||||
|  | @ -124,10 +124,7 @@ class Provider(object): | ||||||
|         if not is_provider(provider): |         if not is_provider(provider): | ||||||
|             provider = Object(provider) |             provider = Object(provider) | ||||||
| 
 | 
 | ||||||
|         if not self.is_overridden: |         self.overridden_by += (ensure_is_provider(provider),) | ||||||
|             self.overridden_by = (ensure_is_provider(provider),) |  | ||||||
|         else: |  | ||||||
|             self.overridden_by += (ensure_is_provider(provider),) |  | ||||||
| 
 | 
 | ||||||
|         # Disable __call__() / _provide() optimization |         # Disable __call__() / _provide() optimization | ||||||
|         if self.__class__.__OPTIMIZED_CALLS__: |         if self.__class__.__OPTIMIZED_CALLS__: | ||||||
|  | @ -145,6 +142,7 @@ class Provider(object): | ||||||
|         """ |         """ | ||||||
|         if not self.overridden_by: |         if not self.overridden_by: | ||||||
|             raise Error('Provider {0} is not overridden'.format(str(self))) |             raise Error('Provider {0} is not overridden'.format(str(self))) | ||||||
|  | 
 | ||||||
|         self.overridden_by = self.overridden_by[:-1] |         self.overridden_by = self.overridden_by[:-1] | ||||||
| 
 | 
 | ||||||
|         if not self.is_overridden: |         if not self.is_overridden: | ||||||
|  | @ -157,7 +155,7 @@ class Provider(object): | ||||||
| 
 | 
 | ||||||
|         :rtype: None |         :rtype: None | ||||||
|         """ |         """ | ||||||
|         self.overridden_by = None |         self.overridden_by = tuple() | ||||||
| 
 | 
 | ||||||
|         # Enable __call__() / _provide() optimization |         # Enable __call__() / _provide() optimization | ||||||
|         if self.__class__.__OPTIMIZED_CALLS__: |         if self.__class__.__OPTIMIZED_CALLS__: | ||||||
|  |  | ||||||
|  | @ -59,6 +59,18 @@ def ensure_is_provider(instance): | ||||||
|     return instance |     return instance | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def is_container(instance): | ||||||
|  |     """Check if instance is container instance. | ||||||
|  | 
 | ||||||
|  |     :param instance: Instance to be checked. | ||||||
|  |     :type instance: object | ||||||
|  | 
 | ||||||
|  |     :rtype: bool | ||||||
|  |     """ | ||||||
|  |     return (hasattr(instance, '__IS_CONTAINER__') and | ||||||
|  |             getattr(instance, '__IS_CONTAINER__', False) is True) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def is_catalog(instance): | def is_catalog(instance): | ||||||
|     """Check if instance is catalog instance. |     """Check if instance is catalog instance. | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ import unittest2 as unittest | ||||||
| from dependency_injector import ( | from dependency_injector import ( | ||||||
|     containers, |     containers, | ||||||
|     providers, |     providers, | ||||||
|  |     errors, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -28,7 +29,7 @@ class ContainerB(ContainerA): | ||||||
| class DeclarativeContainerTests(unittest.TestCase): | class DeclarativeContainerTests(unittest.TestCase): | ||||||
|     """Declarative container tests.""" |     """Declarative container tests.""" | ||||||
| 
 | 
 | ||||||
|     def test_providers_attribute_with(self): |     def test_providers_attribute(self): | ||||||
|         """Test providers attribute.""" |         """Test providers attribute.""" | ||||||
|         self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11, |         self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11, | ||||||
|                                                     p12=ContainerA.p12)) |                                                     p12=ContainerA.p12)) | ||||||
|  | @ -37,7 +38,7 @@ class DeclarativeContainerTests(unittest.TestCase): | ||||||
|                                                     p21=ContainerB.p21, |                                                     p21=ContainerB.p21, | ||||||
|                                                     p22=ContainerB.p22)) |                                                     p22=ContainerB.p22)) | ||||||
| 
 | 
 | ||||||
|     def test_cls_providers_attribute_with(self): |     def test_cls_providers_attribute(self): | ||||||
|         """Test cls_providers attribute.""" |         """Test cls_providers attribute.""" | ||||||
|         self.assertEqual(ContainerA.cls_providers, dict(p11=ContainerA.p11, |         self.assertEqual(ContainerA.cls_providers, dict(p11=ContainerA.p11, | ||||||
|                                                         p12=ContainerA.p12)) |                                                         p12=ContainerA.p12)) | ||||||
|  | @ -51,7 +52,7 @@ class DeclarativeContainerTests(unittest.TestCase): | ||||||
|                          dict(p11=ContainerA.p11, |                          dict(p11=ContainerA.p11, | ||||||
|                               p12=ContainerA.p12)) |                               p12=ContainerA.p12)) | ||||||
| 
 | 
 | ||||||
|     def test_set_get_del_provider_attribute(self): |     def test_set_get_del_providers(self): | ||||||
|         """Test set/get/del provider attributes.""" |         """Test set/get/del provider attributes.""" | ||||||
|         a_p13 = providers.Provider() |         a_p13 = providers.Provider() | ||||||
|         b_p23 = providers.Provider() |         b_p23 = providers.Provider() | ||||||
|  | @ -90,6 +91,120 @@ class DeclarativeContainerTests(unittest.TestCase): | ||||||
|         self.assertEqual(ContainerB.cls_providers, dict(p21=ContainerB.p21, |         self.assertEqual(ContainerB.cls_providers, dict(p21=ContainerB.p21, | ||||||
|                                                         p22=ContainerB.p22)) |                                                         p22=ContainerB.p22)) | ||||||
| 
 | 
 | ||||||
|  |     def test_declare_with_valid_provider_type(self): | ||||||
|  |         """Test declaration of container with valid provider type.""" | ||||||
|  |         class _Container(containers.DeclarativeContainer): | ||||||
|  |             provider_type = providers.Object | ||||||
|  |             px = providers.Object(object()) | ||||||
|  | 
 | ||||||
|  |         self.assertIsInstance(_Container.px, providers.Object) | ||||||
|  | 
 | ||||||
|  |     def test_declare_with_invalid_provider_type(self): | ||||||
|  |         """Test declaration of container with invalid provider type.""" | ||||||
|  |         with self.assertRaises(errors.Error): | ||||||
|  |             class _Container(containers.DeclarativeContainer): | ||||||
|  |                 provider_type = providers.Object | ||||||
|  |                 px = providers.Provider() | ||||||
|  | 
 | ||||||
|  |     def test_seth_valid_provider_type(self): | ||||||
|  |         """Test setting of valid provider.""" | ||||||
|  |         class _Container(containers.DeclarativeContainer): | ||||||
|  |             provider_type = providers.Object | ||||||
|  | 
 | ||||||
|  |         _Container.px = providers.Object(object()) | ||||||
|  | 
 | ||||||
|  |         self.assertIsInstance(_Container.px, providers.Object) | ||||||
|  | 
 | ||||||
|  |     def test_set_invalid_provider_type(self): | ||||||
|  |         """Test setting of invalid provider.""" | ||||||
|  |         class _Container(containers.DeclarativeContainer): | ||||||
|  |             provider_type = providers.Object | ||||||
|  | 
 | ||||||
|  |         with self.assertRaises(errors.Error): | ||||||
|  |             _Container.px = providers.Provider() | ||||||
|  | 
 | ||||||
|  |     def test_override(self): | ||||||
|  |         """Test override.""" | ||||||
|  |         class _Container(containers.DeclarativeContainer): | ||||||
|  |             p11 = providers.Provider() | ||||||
|  | 
 | ||||||
|  |         class _OverridingContainer1(containers.DeclarativeContainer): | ||||||
|  |             p11 = providers.Provider() | ||||||
|  | 
 | ||||||
|  |         class _OverridingContainer2(containers.DeclarativeContainer): | ||||||
|  |             p11 = providers.Provider() | ||||||
|  |             p12 = providers.Provider() | ||||||
|  | 
 | ||||||
|  |         _Container.override(_OverridingContainer1) | ||||||
|  |         _Container.override(_OverridingContainer2) | ||||||
|  | 
 | ||||||
|  |         self.assertEqual(_Container.overridden_by, | ||||||
|  |                          (_OverridingContainer1, | ||||||
|  |                           _OverridingContainer2)) | ||||||
|  |         self.assertEqual(_Container.p11.overridden_by, | ||||||
|  |                          (_OverridingContainer1.p11, | ||||||
|  |                           _OverridingContainer2.p11)) | ||||||
|  | 
 | ||||||
|  |     def test_override_decorator(self): | ||||||
|  |         """Test override decorator.""" | ||||||
|  |         class _Container(containers.DeclarativeContainer): | ||||||
|  |             p11 = providers.Provider() | ||||||
|  | 
 | ||||||
|  |         @containers.override(_Container) | ||||||
|  |         class _OverridingContainer1(containers.DeclarativeContainer): | ||||||
|  |             p11 = providers.Provider() | ||||||
|  | 
 | ||||||
|  |         @containers.override(_Container) | ||||||
|  |         class _OverridingContainer2(containers.DeclarativeContainer): | ||||||
|  |             p11 = providers.Provider() | ||||||
|  |             p12 = providers.Provider() | ||||||
|  | 
 | ||||||
|  |         self.assertEqual(_Container.overridden_by, | ||||||
|  |                          (_OverridingContainer1, | ||||||
|  |                           _OverridingContainer2)) | ||||||
|  |         self.assertEqual(_Container.p11.overridden_by, | ||||||
|  |                          (_OverridingContainer1.p11, | ||||||
|  |                           _OverridingContainer2.p11)) | ||||||
|  | 
 | ||||||
|  |     def test_reset_last_overridding(self): | ||||||
|  |         """Test reset of last overriding.""" | ||||||
|  |         class _Container(containers.DeclarativeContainer): | ||||||
|  |             p11 = providers.Provider() | ||||||
|  | 
 | ||||||
|  |         class _OverridingContainer1(containers.DeclarativeContainer): | ||||||
|  |             p11 = providers.Provider() | ||||||
|  | 
 | ||||||
|  |         class _OverridingContainer2(containers.DeclarativeContainer): | ||||||
|  |             p11 = providers.Provider() | ||||||
|  |             p12 = providers.Provider() | ||||||
|  | 
 | ||||||
|  |         _Container.override(_OverridingContainer1) | ||||||
|  |         _Container.override(_OverridingContainer2) | ||||||
|  |         _Container.reset_last_overriding() | ||||||
|  | 
 | ||||||
|  |         self.assertEqual(_Container.overridden_by, | ||||||
|  |                          (_OverridingContainer1,)) | ||||||
|  |         self.assertEqual(_Container.p11.overridden_by, | ||||||
|  |                          (_OverridingContainer1.p11,)) | ||||||
|  | 
 | ||||||
|  |     def test_reset_override(self): | ||||||
|  |         """Test reset all overridings.""" | ||||||
|  |         class _Container(containers.DeclarativeContainer): | ||||||
|  |             p11 = providers.Provider() | ||||||
|  | 
 | ||||||
|  |         class _OverridingContainer1(containers.DeclarativeContainer): | ||||||
|  |             p11 = providers.Provider() | ||||||
|  | 
 | ||||||
|  |         class _OverridingContainer2(containers.DeclarativeContainer): | ||||||
|  |             p11 = providers.Provider() | ||||||
|  |             p12 = providers.Provider() | ||||||
|  | 
 | ||||||
|  |         _Container.override(_OverridingContainer1) | ||||||
|  |         _Container.override(_OverridingContainer2) | ||||||
|  |         _Container.reset_override() | ||||||
|  | 
 | ||||||
|  |         self.assertEqual(_Container.overridden_by, tuple()) | ||||||
|  |         self.assertEqual(_Container.p11.overridden_by, tuple()) | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user