From ee73f03cb269aa30f2e1abceba343f7401bf99b5 Mon Sep 17 00:00:00 2001 From: Roman Mogilatov Date: Wed, 7 Oct 2015 13:36:28 +0300 Subject: [PATCH] Implement catalog providers, cls_providers and inherited providers class attributes --- dependency_injector/__init__.py | 2 + dependency_injector/catalog.py | 49 ++++++++++++---- dependency_injector/utils.py | 6 ++ tests/test_catalog.py | 101 +++++++++++++++++++++----------- tests/test_utils.py | 29 +++++++++ 5 files changed, 141 insertions(+), 46 deletions(-) diff --git a/dependency_injector/__init__.py b/dependency_injector/__init__.py index b87d0906..8420bd47 100644 --- a/dependency_injector/__init__.py +++ b/dependency_injector/__init__.py @@ -29,6 +29,7 @@ from .utils import ensure_is_injection from .utils import is_kwarg_injection from .utils import is_attribute_injection from .utils import is_method_injection +from .utils import is_catalog from .errors import Error @@ -67,6 +68,7 @@ __all__ = ( 'is_kwarg_injection', 'is_attribute_injection', 'is_method_injection', + 'is_catalog', # Errors 'Error', diff --git a/dependency_injector/catalog.py b/dependency_injector/catalog.py index 6300dd77..8f302320 100644 --- a/dependency_injector/catalog.py +++ b/dependency_injector/catalog.py @@ -3,7 +3,9 @@ import six from .errors import Error + from .utils import is_provider +from .utils import is_catalog class CatalogMetaClass(type): @@ -12,26 +14,47 @@ class CatalogMetaClass(type): def __new__(mcs, class_name, bases, attributes): """Meta class factory.""" - providers = dict() - new_attributes = dict() - for name, value in six.iteritems(attributes): - if is_provider(value): - providers[name] = value - new_attributes[name] = value + cls_providers = dict((name, provider) + for name, provider in six.iteritems(attributes) + if is_provider(provider)) - cls = type.__new__(mcs, class_name, bases, new_attributes) - cls.providers = cls.providers.copy() - cls.providers.update(providers) - return cls + inherited_providers = dict((name, provider) + for base in bases if is_catalog(base) + for name, provider in six.iteritems( + base.providers)) + + providers = dict() + providers.update(cls_providers) + providers.update(inherited_providers) + + attributes['cls_providers'] = cls_providers + attributes['inherited_providers'] = inherited_providers + attributes['providers'] = providers + return type.__new__(mcs, class_name, bases, attributes) @six.add_metaclass(CatalogMetaClass) class AbstractCatalog(object): - """Abstract providers catalog.""" + """Abstract providers catalog. + + :type providers: dict[str, dependency_injector.Provider] + :param providers: Dict of all catalog providers, including inherited from + parent catalogs + + :type cls_providers: dict[str, dependency_injector.Provider] + :param cls_providers: Dict of current catalog providers + + :type inherited_providers: dict[str, dependency_injector.Provider] + :param inherited_providers: Dict of providers, that are inherited from + parent catalogs + """ providers = dict() + cls_providers = dict() + inherited_providers = dict() + __IS_CATALOG__ = True __slots__ = ('used_providers',) def __init__(self, *used_providers): @@ -41,7 +64,7 @@ class AbstractCatalog(object): def __getattribute__(self, item): """Return providers.""" attribute = super(AbstractCatalog, self).__getattribute__(item) - if item in ('providers', 'used_providers',): + if item in ('providers', 'used_providers', '__class__'): return attribute if attribute not in self.used_providers: @@ -62,7 +85,7 @@ class AbstractCatalog(object): :type overriding: AbstractCatalog """ - for name, provider in six.iteritems(overriding.providers): + for name, provider in six.iteritems(overriding.cls_providers): cls.providers[name].override(provider) diff --git a/dependency_injector/utils.py b/dependency_injector/utils.py index f737be81..42a18202 100644 --- a/dependency_injector/utils.py +++ b/dependency_injector/utils.py @@ -56,6 +56,12 @@ def is_method_injection(instance): getattr(instance, '__IS_METHOD_INJECTION__', False) is True) +def is_catalog(instance): + """Check if instance is catalog instance.""" + return (isinstance(instance, six.class_types) and + getattr(instance, '__IS_CATALOG__', False) is True) + + def get_injectable_kwargs(kwargs, injections): """Return dictionary of kwargs, patched with injections.""" init_kwargs = dict(((injection.name, injection.value) diff --git a/tests/test_catalog.py b/tests/test_catalog.py index e74a511e..f1829222 100644 --- a/tests/test_catalog.py +++ b/tests/test_catalog.py @@ -4,6 +4,74 @@ import unittest2 as unittest import dependency_injector as di +class CatalogsInheritanceTests(unittest.TestCase): + + """Catalogs inheritance tests.""" + + class CatalogA(di.AbstractCatalog): + + """Test catalog A.""" + + p11 = di.Provider() + p12 = di.Provider() + + class CatalogB(CatalogA): + + """Test catalog B.""" + + p21 = di.Provider() + p22 = di.Provider() + + class CatalogC(CatalogB): + + """Test catalog C.""" + + p31 = di.Provider() + p32 = di.Provider() + + def test_cls_providers(self): + """Test `di.AbstractCatalog.cls_providers` contents.""" + self.assertDictEqual(self.CatalogA.cls_providers, + dict(p11=self.CatalogA.p11, + p12=self.CatalogA.p12)) + self.assertDictEqual(self.CatalogB.cls_providers, + dict(p21=self.CatalogB.p21, + p22=self.CatalogB.p22)) + self.assertDictEqual(self.CatalogC.cls_providers, + dict(p31=self.CatalogC.p31, + p32=self.CatalogC.p32)) + + def test_inherited_providers(self): + """Test `di.AbstractCatalog.inherited_providers` contents.""" + self.assertDictEqual(self.CatalogA.inherited_providers, dict()) + self.assertDictEqual(self.CatalogB.inherited_providers, + dict(p11=self.CatalogA.p11, + p12=self.CatalogA.p12)) + self.assertDictEqual(self.CatalogC.inherited_providers, + dict(p11=self.CatalogA.p11, + p12=self.CatalogA.p12, + p21=self.CatalogB.p21, + p22=self.CatalogB.p22)) + + def test_providers(self): + """Test `di.AbstractCatalog.inherited_providers` contents.""" + self.assertDictEqual(self.CatalogA.providers, + dict(p11=self.CatalogA.p11, + p12=self.CatalogA.p12)) + self.assertDictEqual(self.CatalogB.providers, + dict(p11=self.CatalogA.p11, + p12=self.CatalogA.p12, + p21=self.CatalogB.p21, + p22=self.CatalogB.p22)) + self.assertDictEqual(self.CatalogC.providers, + dict(p11=self.CatalogA.p11, + p12=self.CatalogA.p12, + p21=self.CatalogB.p21, + p22=self.CatalogB.p22, + p31=self.CatalogC.p31, + p32=self.CatalogC.p32)) + + class CatalogTests(unittest.TestCase): """Catalog test cases.""" @@ -25,44 +93,11 @@ class CatalogTests(unittest.TestCase): catalog = self.Catalog() self.assertRaises(di.Error, getattr, catalog, 'obj') - def test_all_providers(self): - """Test getting of all catalog providers.""" - self.assertTrue(len(self.Catalog.providers) == 2) - - self.assertIn('obj', self.Catalog.providers) - self.assertIn(self.Catalog.obj, self.Catalog.providers.values()) - - self.assertIn('another_obj', self.Catalog.providers) - self.assertIn(self.Catalog.another_obj, - self.Catalog.providers.values()) - def test_all_providers_by_type(self): """Test getting of all catalog providers of specific type.""" self.assertTrue(len(self.Catalog.filter(di.Object)) == 2) self.assertTrue(len(self.Catalog.filter(di.Value)) == 0) - def test_metaclass_with_several_catalogs(self): - """Test that metaclass work well with several catalogs.""" - class Catalog1(di.AbstractCatalog): - - """Catalog1.""" - - provider = di.Object(object()) - - class Catalog2(di.AbstractCatalog): - - """Catalog2.""" - - provider = di.Object(object()) - - self.assertTrue(len(Catalog1.providers) == 1) - self.assertIs(Catalog1.provider, Catalog1.providers['provider']) - - self.assertTrue(len(Catalog2.providers) == 1) - self.assertIs(Catalog2.provider, Catalog2.providers['provider']) - - self.assertIsNot(Catalog1.provider, Catalog2.provider) - class OverrideTests(unittest.TestCase): diff --git a/tests/test_utils.py b/tests/test_utils.py index de14695b..2943aa3d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -190,3 +190,32 @@ class IsMethodInjectionTests(unittest.TestCase): def test_with_object(self): """Test with object.""" self.assertFalse(di.is_method_injection(object())) + + +class IsCatalogTests(unittest.TestCase): + + """`is_catalog()` test cases.""" + + def test_with_cls(self): + """Test with class.""" + self.assertTrue(di.is_catalog(di.AbstractCatalog)) + + def test_with_instance(self): + """Test with class.""" + self.assertFalse(di.is_catalog(di.AbstractCatalog())) + + def test_with_child_class(self): + """Test with parent class.""" + class Catalog(di.AbstractCatalog): + + """Example catalog child class.""" + + self.assertTrue(di.is_catalog(Catalog)) + + def test_with_string(self): + """Test with string.""" + self.assertFalse(di.is_catalog('some_string')) + + def test_with_object(self): + """Test with object.""" + self.assertFalse(di.is_catalog(object()))