diff --git a/objects/catalog.py b/objects/catalog.py index 75b34e1d..490d752a 100644 --- a/objects/catalog.py +++ b/objects/catalog.py @@ -1,14 +1,38 @@ """Catalog module.""" -from .providers import Provider +from six import iteritems + from .errors import Error +from .utils import is_provider + + +class CatalogMetaClass(type): + + """Providers catalog meta class.""" + + def __new__(mcs, class_name, bases, attributes): + """Meta class factory.""" + providers = dict() + new_attributes = dict() + for name, value in attributes.iteritems(): + if is_provider(value): + providers[name] = value + new_attributes[name] = value + + cls = type.__new__(mcs, class_name, bases, new_attributes) + cls.providers = cls.providers.copy() + cls.providers.update(providers) + return cls class AbstractCatalog(object): - """Abstract object provides catalog.""" + """Abstract providers catalog.""" - __slots__ = ('__used_providers__',) + providers = dict() + + __slots__ = ('providers', '__used_providers__',) + __metaclass__ = CatalogMetaClass def __init__(self, *used_providers): """Initializer.""" @@ -17,7 +41,7 @@ class AbstractCatalog(object): def __getattribute__(self, item): """Return providers.""" attribute = super(AbstractCatalog, self).__getattribute__(item) - if item in ('__used_providers__',): + if item in ('providers', '__used_providers__',): return attribute if attribute not in self.__used_providers__: @@ -26,15 +50,11 @@ class AbstractCatalog(object): return attribute @classmethod - def all_providers(cls, provider_type=Provider): - """Return set of all class providers.""" - providers = set() - for attr_name in set(dir(cls)) - set(dir(AbstractCatalog)): - provider = getattr(cls, attr_name) - if not isinstance(provider, provider_type): - continue - providers.add((attr_name, provider)) - return providers + def filter(cls, provider_type): + """Return dict of providers, that are instance of provided type.""" + return dict([(name, provider) + for name, provider in iteritems(cls.providers) + if isinstance(provider, provider_type)]) @classmethod def override(cls, overriding): @@ -42,7 +62,5 @@ class AbstractCatalog(object): :type overriding: AbstractCatalog """ - overridden = overriding.all_providers() - cls.all_providers() - for name, provider in overridden: - overridden_provider = getattr(cls, name) - overridden_provider.override(provider) + for name, provider in iteritems(overriding.providers): + cls.providers[name].override(provider) diff --git a/tests/test_catalog.py b/tests/test_catalog.py index cc0baf8b..16bedf1c 100644 --- a/tests/test_catalog.py +++ b/tests/test_catalog.py @@ -33,19 +33,38 @@ class CatalogTests(unittest.TestCase): def test_all_providers(self): """Test getting of all catalog providers.""" - all_providers = self.Catalog.all_providers() - all_providers_dict = dict(all_providers) + self.assertTrue(len(self.Catalog.providers) == 2) - self.assertIsInstance(all_providers, set) - self.assertTrue(len(all_providers) == 2) + self.assertIn('obj', self.Catalog.providers) + self.assertIn(self.Catalog.obj, self.Catalog.providers.values()) - self.assertIn('obj', all_providers_dict) - self.assertIn(self.Catalog.obj, all_providers_dict.values()) - - self.assertIn('another_obj', all_providers_dict) - self.assertIn(self.Catalog.another_obj, all_providers_dict.values()) + self.assertIn('another_obj', self.Catalog.providers) + self.assertIn(self.Catalog.another_obj, + self.Catalog.providers.values()) def test_all_providers_by_type(self): """Test getting of all catalog providers of specific type.""" - self.assertTrue(len(self.Catalog.all_providers(Object)) == 2) - self.assertTrue(len(self.Catalog.all_providers(Value)) == 0) + self.assertTrue(len(self.Catalog.filter(Object)) == 2) + self.assertTrue(len(self.Catalog.filter(Value)) == 0) + + def test_metaclass_with_several_catalogs(self): + """Test that metaclass work well with several catalogs.""" + class Catalog1(AbstractCatalog): + + """Catalog1.""" + + provider = Object(object()) + + class Catalog2(AbstractCatalog): + + """Catalog2.""" + + provider = Object(object()) + + self.assertTrue(len(Catalog1.providers) == 1) + self.assertIs(Catalog1.provider, Catalog1.providers['provider']) + + self.assertTrue(len(Catalog2.providers) == 1) + self.assertIs(Catalog2.provider, Catalog2.providers['provider']) + + self.assertIsNot(Catalog1.provider, Catalog2.provider)