Merge remote-tracking branch 'origin/enhancement_of_catalog_inheritance'

This commit is contained in:
Roman Mogilatov 2015-10-07 19:57:40 +03:00
commit b7bab4938f
8 changed files with 180 additions and 71 deletions

View File

@ -29,6 +29,7 @@ from .utils import ensure_is_injection
from .utils import is_kwarg_injection from .utils import is_kwarg_injection
from .utils import is_attribute_injection from .utils import is_attribute_injection
from .utils import is_method_injection from .utils import is_method_injection
from .utils import is_catalog
from .errors import Error from .errors import Error
@ -67,6 +68,7 @@ __all__ = (
'is_kwarg_injection', 'is_kwarg_injection',
'is_attribute_injection', 'is_attribute_injection',
'is_method_injection', 'is_method_injection',
'is_catalog',
# Errors # Errors
'Error', 'Error',

View File

@ -3,7 +3,9 @@
import six import six
from .errors import Error from .errors import Error
from .utils import is_provider from .utils import is_provider
from .utils import is_catalog
class CatalogMetaClass(type): class CatalogMetaClass(type):
@ -12,26 +14,47 @@ class CatalogMetaClass(type):
def __new__(mcs, class_name, bases, attributes): def __new__(mcs, class_name, bases, attributes):
"""Meta class factory.""" """Meta class factory."""
providers = dict() cls_providers = dict((name, provider)
new_attributes = dict() for name, provider in six.iteritems(attributes)
for name, value in six.iteritems(attributes): if is_provider(provider))
if is_provider(value):
providers[name] = value
new_attributes[name] = value
cls = type.__new__(mcs, class_name, bases, new_attributes) inherited_providers = dict((name, provider)
cls.providers = cls.providers.copy() for base in bases if is_catalog(base)
cls.providers.update(providers) for name, provider in six.iteritems(
return cls base.providers))
providers = dict()
providers.update(cls_providers)
providers.update(inherited_providers)
attributes['cls_providers'] = cls_providers
attributes['inherited_providers'] = inherited_providers
attributes['providers'] = providers
return type.__new__(mcs, class_name, bases, attributes)
@six.add_metaclass(CatalogMetaClass) @six.add_metaclass(CatalogMetaClass)
class AbstractCatalog(object): class AbstractCatalog(object):
"""Abstract providers catalog.""" """Abstract providers catalog.
:type providers: dict[str, dependency_injector.Provider]
:param providers: Dict of all catalog providers, including inherited from
parent catalogs
:type cls_providers: dict[str, dependency_injector.Provider]
:param cls_providers: Dict of current catalog providers
:type inherited_providers: dict[str, dependency_injector.Provider]
:param inherited_providers: Dict of providers, that are inherited from
parent catalogs
"""
providers = dict() providers = dict()
cls_providers = dict()
inherited_providers = dict()
__IS_CATALOG__ = True
__slots__ = ('used_providers',) __slots__ = ('used_providers',)
def __init__(self, *used_providers): def __init__(self, *used_providers):
@ -41,7 +64,7 @@ class AbstractCatalog(object):
def __getattribute__(self, item): def __getattribute__(self, item):
"""Return providers.""" """Return providers."""
attribute = super(AbstractCatalog, self).__getattribute__(item) attribute = super(AbstractCatalog, self).__getattribute__(item)
if item in ('providers', 'used_providers',): if item in ('providers', 'used_providers', '__class__'):
return attribute return attribute
if attribute not in self.used_providers: if attribute not in self.used_providers:
@ -62,7 +85,7 @@ class AbstractCatalog(object):
:type overriding: AbstractCatalog :type overriding: AbstractCatalog
""" """
for name, provider in six.iteritems(overriding.providers): for name, provider in six.iteritems(overriding.cls_providers):
cls.providers[name].override(provider) cls.providers[name].override(provider)

View File

@ -56,6 +56,12 @@ def is_method_injection(instance):
getattr(instance, '__IS_METHOD_INJECTION__', False) is True) getattr(instance, '__IS_METHOD_INJECTION__', False) is True)
def is_catalog(instance):
"""Check if instance is catalog instance."""
return (isinstance(instance, six.class_types) and
getattr(instance, '__IS_CATALOG__', False) is True)
def get_injectable_kwargs(kwargs, injections): def get_injectable_kwargs(kwargs, injections):
"""Return dictionary of kwargs, patched with injections.""" """Return dictionary of kwargs, patched with injections."""
init_kwargs = dict(((injection.name, injection.value) init_kwargs = dict(((injection.name, injection.value)

View File

@ -45,16 +45,23 @@ Example:
Operating with catalog providers Operating with catalog providers
-------------------------------- --------------------------------
There are several things that could be useful for operating with catalog ``di.AbstractCatalog`` has several features that could be useful for some kind
providers: of operations on catalog's providers:
- First of all, ``di.AbstractCatalog.providers`` attribute contains ``dict`` - ``di.AbstractCatalog.providers`` is read-only attribute that contains
with all catalog providers. This dictionary could be used for any kind of ``dict`` of all catalog providers, including providers that are inherited
operations that could be done with providers. The only note, is that from parent catalogs, where key is the name of provider and value is
``di.AbstractCatalog.providers`` attribute is read-only. provider itself.
- Second one, ``di.AbstractCatalog.filter(provider_type=di.Provider)`` method - ``di.AbstractCatalog.cls_providers`` is read-only attribute contains ``dict``
could be used for filtering catalog providers by provider types (for example, of current catalog providers, where key is the name of provider and value is
for getting all ``di.Factory`` providers). provider itself.
- ``di.AbstractCatalog.inherited_providers`` is read-only attribute contains
``dict`` of all providers that are inherited from parent catalogs, where key
is the name of provider and value is provider itself.
- ``di.AbstractCatalog.filter(provider_type=di.Provider)`` is a class method
that could be used for filtering catalog providers by provider types
(for example, for getting all ``di.Factory`` providers).
``di.AbstractCatalog.filter()`` method use ``di.AbstractCatalog.providers``.
Example: Example:

View File

@ -12,6 +12,7 @@ Development version
------------------- -------------------
- Add functionality for decorating classes with ``@di.inject``. - Add functionality for decorating classes with ``@di.inject``.
- Add enhancement for ``di.AbstractCatalog`` inheritance.
0.9.5 0.9.5
----- -----

View File

@ -3,29 +3,35 @@
import dependency_injector as di import dependency_injector as di
class Catalog(di.AbstractCatalog): class CatalogA(di.AbstractCatalog):
"""Providers catalog.""" """Example catalog A."""
provider1 = di.Factory(object) provider1 = di.Factory(object)
""":type: (di.Provider) -> object""" """:type: (di.Provider) -> object"""
provider2 = di.Factory(object)
""":type: (di.Provider) -> object"""
provider3 = di.Singleton(object) class CatalogB(CatalogA):
""":type: (di.Provider) -> object"""
provider4 = di.Singleton(object) """Example catalog B."""
provider2 = di.Singleton(object)
""":type: (di.Provider) -> object""" """:type: (di.Provider) -> object"""
# Making some asserts: # Making some asserts for `providers` attribute:
assert Catalog.providers == dict(provider1=Catalog.provider1, assert CatalogA.providers == dict(provider1=CatalogA.provider1)
provider2=Catalog.provider2, assert CatalogB.providers == dict(provider1=CatalogA.provider1,
provider3=Catalog.provider3, provider2=CatalogB.provider2)
provider4=Catalog.provider4)
assert Catalog.filter(di.Factory) == dict(provider1=Catalog.provider1, # Making some asserts for `cls_providers` attribute:
provider2=Catalog.provider2) assert CatalogA.cls_providers == dict(provider1=CatalogA.provider1)
assert Catalog.filter(di.Singleton) == dict(provider3=Catalog.provider3, assert CatalogB.cls_providers == dict(provider2=CatalogB.provider2)
provider4=Catalog.provider4)
# Making some asserts for `inherited_providers` attribute:
assert CatalogA.inherited_providers == dict()
assert CatalogB.inherited_providers == dict(provider1=CatalogA.provider1)
# Making some asserts for `filter()` method:
assert CatalogB.filter(di.Factory) == dict(provider1=CatalogA.provider1)
assert CatalogB.filter(di.Singleton) == dict(provider2=CatalogB.provider2)

View File

@ -4,6 +4,74 @@ import unittest2 as unittest
import dependency_injector as di import dependency_injector as di
class CatalogsInheritanceTests(unittest.TestCase):
"""Catalogs inheritance tests."""
class CatalogA(di.AbstractCatalog):
"""Test catalog A."""
p11 = di.Provider()
p12 = di.Provider()
class CatalogB(CatalogA):
"""Test catalog B."""
p21 = di.Provider()
p22 = di.Provider()
class CatalogC(CatalogB):
"""Test catalog C."""
p31 = di.Provider()
p32 = di.Provider()
def test_cls_providers(self):
"""Test `di.AbstractCatalog.cls_providers` contents."""
self.assertDictEqual(self.CatalogA.cls_providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12))
self.assertDictEqual(self.CatalogB.cls_providers,
dict(p21=self.CatalogB.p21,
p22=self.CatalogB.p22))
self.assertDictEqual(self.CatalogC.cls_providers,
dict(p31=self.CatalogC.p31,
p32=self.CatalogC.p32))
def test_inherited_providers(self):
"""Test `di.AbstractCatalog.inherited_providers` contents."""
self.assertDictEqual(self.CatalogA.inherited_providers, dict())
self.assertDictEqual(self.CatalogB.inherited_providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12))
self.assertDictEqual(self.CatalogC.inherited_providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12,
p21=self.CatalogB.p21,
p22=self.CatalogB.p22))
def test_providers(self):
"""Test `di.AbstractCatalog.inherited_providers` contents."""
self.assertDictEqual(self.CatalogA.providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12))
self.assertDictEqual(self.CatalogB.providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12,
p21=self.CatalogB.p21,
p22=self.CatalogB.p22))
self.assertDictEqual(self.CatalogC.providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12,
p21=self.CatalogB.p21,
p22=self.CatalogB.p22,
p31=self.CatalogC.p31,
p32=self.CatalogC.p32))
class CatalogTests(unittest.TestCase): class CatalogTests(unittest.TestCase):
"""Catalog test cases.""" """Catalog test cases."""
@ -25,44 +93,11 @@ class CatalogTests(unittest.TestCase):
catalog = self.Catalog() catalog = self.Catalog()
self.assertRaises(di.Error, getattr, catalog, 'obj') self.assertRaises(di.Error, getattr, catalog, 'obj')
def test_all_providers(self):
"""Test getting of all catalog providers."""
self.assertTrue(len(self.Catalog.providers) == 2)
self.assertIn('obj', self.Catalog.providers)
self.assertIn(self.Catalog.obj, self.Catalog.providers.values())
self.assertIn('another_obj', self.Catalog.providers)
self.assertIn(self.Catalog.another_obj,
self.Catalog.providers.values())
def test_all_providers_by_type(self): def test_all_providers_by_type(self):
"""Test getting of all catalog providers of specific type.""" """Test getting of all catalog providers of specific type."""
self.assertTrue(len(self.Catalog.filter(di.Object)) == 2) self.assertTrue(len(self.Catalog.filter(di.Object)) == 2)
self.assertTrue(len(self.Catalog.filter(di.Value)) == 0) self.assertTrue(len(self.Catalog.filter(di.Value)) == 0)
def test_metaclass_with_several_catalogs(self):
"""Test that metaclass work well with several catalogs."""
class Catalog1(di.AbstractCatalog):
"""Catalog1."""
provider = di.Object(object())
class Catalog2(di.AbstractCatalog):
"""Catalog2."""
provider = di.Object(object())
self.assertTrue(len(Catalog1.providers) == 1)
self.assertIs(Catalog1.provider, Catalog1.providers['provider'])
self.assertTrue(len(Catalog2.providers) == 1)
self.assertIs(Catalog2.provider, Catalog2.providers['provider'])
self.assertIsNot(Catalog1.provider, Catalog2.provider)
class OverrideTests(unittest.TestCase): class OverrideTests(unittest.TestCase):

View File

@ -190,3 +190,32 @@ class IsMethodInjectionTests(unittest.TestCase):
def test_with_object(self): def test_with_object(self):
"""Test with object.""" """Test with object."""
self.assertFalse(di.is_method_injection(object())) self.assertFalse(di.is_method_injection(object()))
class IsCatalogTests(unittest.TestCase):
"""`is_catalog()` test cases."""
def test_with_cls(self):
"""Test with class."""
self.assertTrue(di.is_catalog(di.AbstractCatalog))
def test_with_instance(self):
"""Test with class."""
self.assertFalse(di.is_catalog(di.AbstractCatalog()))
def test_with_child_class(self):
"""Test with parent class."""
class Catalog(di.AbstractCatalog):
"""Example catalog child class."""
self.assertTrue(di.is_catalog(Catalog))
def test_with_string(self):
"""Test with string."""
self.assertFalse(di.is_catalog('some_string'))
def test_with_object(self):
"""Test with object."""
self.assertFalse(di.is_catalog(object()))