Implement catalog providers, cls_providers and inherited providers class attributes

This commit is contained in:
Roman Mogilatov 2015-10-07 13:36:28 +03:00
parent 53e8f62c89
commit ee73f03cb2
5 changed files with 141 additions and 46 deletions

View File

@ -29,6 +29,7 @@ from .utils import ensure_is_injection
from .utils import is_kwarg_injection
from .utils import is_attribute_injection
from .utils import is_method_injection
from .utils import is_catalog
from .errors import Error
@ -67,6 +68,7 @@ __all__ = (
'is_kwarg_injection',
'is_attribute_injection',
'is_method_injection',
'is_catalog',
# Errors
'Error',

View File

@ -3,7 +3,9 @@
import six
from .errors import Error
from .utils import is_provider
from .utils import is_catalog
class CatalogMetaClass(type):
@ -12,26 +14,47 @@ class CatalogMetaClass(type):
def __new__(mcs, class_name, bases, attributes):
"""Meta class factory."""
providers = dict()
new_attributes = dict()
for name, value in six.iteritems(attributes):
if is_provider(value):
providers[name] = value
new_attributes[name] = value
cls_providers = dict((name, provider)
for name, provider in six.iteritems(attributes)
if is_provider(provider))
cls = type.__new__(mcs, class_name, bases, new_attributes)
cls.providers = cls.providers.copy()
cls.providers.update(providers)
return cls
inherited_providers = dict((name, provider)
for base in bases if is_catalog(base)
for name, provider in six.iteritems(
base.providers))
providers = dict()
providers.update(cls_providers)
providers.update(inherited_providers)
attributes['cls_providers'] = cls_providers
attributes['inherited_providers'] = inherited_providers
attributes['providers'] = providers
return type.__new__(mcs, class_name, bases, attributes)
@six.add_metaclass(CatalogMetaClass)
class AbstractCatalog(object):
"""Abstract providers catalog."""
"""Abstract providers catalog.
:type providers: dict[str, dependency_injector.Provider]
:param providers: Dict of all catalog providers, including inherited from
parent catalogs
:type cls_providers: dict[str, dependency_injector.Provider]
:param cls_providers: Dict of current catalog providers
:type inherited_providers: dict[str, dependency_injector.Provider]
:param inherited_providers: Dict of providers, that are inherited from
parent catalogs
"""
providers = dict()
cls_providers = dict()
inherited_providers = dict()
__IS_CATALOG__ = True
__slots__ = ('used_providers',)
def __init__(self, *used_providers):
@ -41,7 +64,7 @@ class AbstractCatalog(object):
def __getattribute__(self, item):
"""Return providers."""
attribute = super(AbstractCatalog, self).__getattribute__(item)
if item in ('providers', 'used_providers',):
if item in ('providers', 'used_providers', '__class__'):
return attribute
if attribute not in self.used_providers:
@ -62,7 +85,7 @@ class AbstractCatalog(object):
:type overriding: AbstractCatalog
"""
for name, provider in six.iteritems(overriding.providers):
for name, provider in six.iteritems(overriding.cls_providers):
cls.providers[name].override(provider)

View File

@ -56,6 +56,12 @@ def is_method_injection(instance):
getattr(instance, '__IS_METHOD_INJECTION__', False) is True)
def is_catalog(instance):
"""Check if instance is catalog instance."""
return (isinstance(instance, six.class_types) and
getattr(instance, '__IS_CATALOG__', False) is True)
def get_injectable_kwargs(kwargs, injections):
"""Return dictionary of kwargs, patched with injections."""
init_kwargs = dict(((injection.name, injection.value)

View File

@ -4,6 +4,74 @@ import unittest2 as unittest
import dependency_injector as di
class CatalogsInheritanceTests(unittest.TestCase):
"""Catalogs inheritance tests."""
class CatalogA(di.AbstractCatalog):
"""Test catalog A."""
p11 = di.Provider()
p12 = di.Provider()
class CatalogB(CatalogA):
"""Test catalog B."""
p21 = di.Provider()
p22 = di.Provider()
class CatalogC(CatalogB):
"""Test catalog C."""
p31 = di.Provider()
p32 = di.Provider()
def test_cls_providers(self):
"""Test `di.AbstractCatalog.cls_providers` contents."""
self.assertDictEqual(self.CatalogA.cls_providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12))
self.assertDictEqual(self.CatalogB.cls_providers,
dict(p21=self.CatalogB.p21,
p22=self.CatalogB.p22))
self.assertDictEqual(self.CatalogC.cls_providers,
dict(p31=self.CatalogC.p31,
p32=self.CatalogC.p32))
def test_inherited_providers(self):
"""Test `di.AbstractCatalog.inherited_providers` contents."""
self.assertDictEqual(self.CatalogA.inherited_providers, dict())
self.assertDictEqual(self.CatalogB.inherited_providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12))
self.assertDictEqual(self.CatalogC.inherited_providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12,
p21=self.CatalogB.p21,
p22=self.CatalogB.p22))
def test_providers(self):
"""Test `di.AbstractCatalog.inherited_providers` contents."""
self.assertDictEqual(self.CatalogA.providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12))
self.assertDictEqual(self.CatalogB.providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12,
p21=self.CatalogB.p21,
p22=self.CatalogB.p22))
self.assertDictEqual(self.CatalogC.providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12,
p21=self.CatalogB.p21,
p22=self.CatalogB.p22,
p31=self.CatalogC.p31,
p32=self.CatalogC.p32))
class CatalogTests(unittest.TestCase):
"""Catalog test cases."""
@ -25,44 +93,11 @@ class CatalogTests(unittest.TestCase):
catalog = self.Catalog()
self.assertRaises(di.Error, getattr, catalog, 'obj')
def test_all_providers(self):
"""Test getting of all catalog providers."""
self.assertTrue(len(self.Catalog.providers) == 2)
self.assertIn('obj', self.Catalog.providers)
self.assertIn(self.Catalog.obj, self.Catalog.providers.values())
self.assertIn('another_obj', self.Catalog.providers)
self.assertIn(self.Catalog.another_obj,
self.Catalog.providers.values())
def test_all_providers_by_type(self):
"""Test getting of all catalog providers of specific type."""
self.assertTrue(len(self.Catalog.filter(di.Object)) == 2)
self.assertTrue(len(self.Catalog.filter(di.Value)) == 0)
def test_metaclass_with_several_catalogs(self):
"""Test that metaclass work well with several catalogs."""
class Catalog1(di.AbstractCatalog):
"""Catalog1."""
provider = di.Object(object())
class Catalog2(di.AbstractCatalog):
"""Catalog2."""
provider = di.Object(object())
self.assertTrue(len(Catalog1.providers) == 1)
self.assertIs(Catalog1.provider, Catalog1.providers['provider'])
self.assertTrue(len(Catalog2.providers) == 1)
self.assertIs(Catalog2.provider, Catalog2.providers['provider'])
self.assertIsNot(Catalog1.provider, Catalog2.provider)
class OverrideTests(unittest.TestCase):

View File

@ -190,3 +190,32 @@ class IsMethodInjectionTests(unittest.TestCase):
def test_with_object(self):
"""Test with object."""
self.assertFalse(di.is_method_injection(object()))
class IsCatalogTests(unittest.TestCase):
"""`is_catalog()` test cases."""
def test_with_cls(self):
"""Test with class."""
self.assertTrue(di.is_catalog(di.AbstractCatalog))
def test_with_instance(self):
"""Test with class."""
self.assertFalse(di.is_catalog(di.AbstractCatalog()))
def test_with_child_class(self):
"""Test with parent class."""
class Catalog(di.AbstractCatalog):
"""Example catalog child class."""
self.assertTrue(di.is_catalog(Catalog))
def test_with_string(self):
"""Test with string."""
self.assertFalse(di.is_catalog('some_string'))
def test_with_object(self):
"""Test with object."""
self.assertFalse(di.is_catalog(object()))