Refactor catalog bundle

This commit is contained in:
Roman Mogilatov 2015-10-16 23:54:51 +03:00
parent 730d46f9a0
commit 40dc54b64b
3 changed files with 90 additions and 81 deletions

View File

@ -8,32 +8,21 @@ from .utils import is_provider
from .utils import is_catalog from .utils import is_catalog
class CatalogPackFactory(object): class CatalogBundle(object):
"""Factory of catalog packs.""" """Bundle of catalog providers."""
def __init__(self, catalog): catalog = None
""":type: AbstractCatalog"""
__slots__ = ('providers', '__dict__')
def __init__(self, *providers):
"""Initializer.""" """Initializer."""
self.catalog = catalog self.providers = dict((provider.bind.name, provider)
for provider in providers
def __call__(self, *providers): if self._ensure_provider_is_bound(provider))
"""Create catalog pack with specified providers.""" self.__dict__.update(self.providers)
return CatalogPack(self.catalog, *providers) super(CatalogBundle, self).__init__()
class CatalogPack(object):
"""Pack of catalog providers."""
__slots__ = ('catalog', 'providers', '__dict__')
def __init__(self, catalog, *providers):
"""Initializer."""
self.catalog = catalog
self.providers = dict()
for provider in providers:
self._ensure_provider_is_bound(provider)
self.__dict__[provider.bind.name] = provider
self.providers[provider.bind.name] = provider
super(CatalogPack, self).__init__()
def get(self, name): def get(self, name):
"""Return provider with specified name or raises error.""" """Return provider with specified name or raises error."""
@ -47,16 +36,17 @@ class CatalogPack(object):
return name in self.providers return name in self.providers
def _ensure_provider_is_bound(self, provider): def _ensure_provider_is_bound(self, provider):
"""Check that provider is bound.""" """Check that provider is bound to the bundle's catalog."""
if not provider.bind: if not provider.is_bound:
raise Error('Provider {0} is not bound to ' raise Error('Provider {0} is not bound to '
'any catalog'.format(provider)) 'any catalog'.format(provider))
if provider is not self.catalog.get(provider.bind.name): if provider is not self.catalog.get(provider.bind.name):
raise Error('{0} can contain providers from ' raise Error('{0} can contain providers from '
'catalog {0}' .format(self, self.catalog)) 'catalog {0}'.format(self.__class__, self.catalog))
return True
def _raise_undefined_provider_error(self, name): def _raise_undefined_provider_error(self, name):
"""Raise error for cases when there is no such provider in pack.""" """Raise error for cases when there is no such provider in bundle."""
raise Error('Provider "{0}" is not a part of {1}'.format(name, self)) raise Error('Provider "{0}" is not a part of {1}'.format(name, self))
def __getattr__(self, item): def __getattr__(self, item):
@ -64,16 +54,16 @@ class CatalogPack(object):
self._raise_undefined_provider_error(item) self._raise_undefined_provider_error(item)
def __repr__(self): def __repr__(self):
"""Return string representation of pack.""" """Return string representation of bundle."""
return '<Catalog pack ({0}), {1}>'.format( return '<Bundle of {0} providers ({1})>'.format(
', '.join(six.iterkeys(self.providers)), self.catalog) self.catalog, ', '.join(six.iterkeys(self.providers)))
class CatalogMetaClass(type): class CatalogMetaClass(type):
"""Providers catalog meta class.""" """Catalog meta class."""
def __new__(mcs, class_name, bases, attributes): def __new__(mcs, class_name, bases, attributes):
"""Meta class factory.""" """Catalog class factory."""
cls_providers = dict((name, provider) cls_providers = dict((name, provider)
for name, provider in six.iteritems(attributes) for name, provider in six.iteritems(attributes)
if is_provider(provider)) if is_provider(provider))
@ -93,13 +83,18 @@ class CatalogMetaClass(type):
cls.inherited_providers = inherited_providers cls.inherited_providers = inherited_providers
cls.providers = providers cls.providers = providers
cls.Pack = CatalogPackFactory(cls) cls.Bundle = mcs.bundle_cls_factory(cls)
for name, provider in six.iteritems(cls_providers): for name, provider in six.iteritems(cls_providers):
provider.bind = ProviderBinding(cls, name) provider.bind = ProviderBinding(cls, name)
return cls return cls
@classmethod
def bundle_cls_factory(mcs, cls):
"""Create bundle class for catalog."""
return type('{0}Bundle', (CatalogBundle,), dict(catalog=cls))
def __repr__(cls): def __repr__(cls):
"""Return string representation of the catalog class.""" """Return string representation of the catalog class."""
return '<Catalog "' + '.'.join((cls.__module__, cls.__name__)) + '">' return '<Catalog "' + '.'.join((cls.__module__, cls.__name__)) + '">'
@ -120,11 +115,11 @@ class AbstractCatalog(object):
:param inherited_providers: Dict of providers, that are inherited from :param inherited_providers: Dict of providers, that are inherited from
parent catalogs parent catalogs
:type Pack: CatalogPack :type Bundle: CatalogBundle
:param Pack: Catalog's pack class :param Bundle: Catalog's bundle class
""" """
Pack = CatalogPackFactory Bundle = CatalogBundle
cls_providers = dict() cls_providers = dict()
inherited_providers = dict() inherited_providers = dict()

View File

@ -74,6 +74,11 @@ class Provider(object):
"""Reset all overriding providers.""" """Reset all overriding providers."""
self.overridden_by = None self.overridden_by = None
@property
def is_bound(self):
"""Check if provider is bound to any catalog."""
return bool(self.bind)
class Delegate(Provider): class Delegate(Provider):
"""Provider's delegate.""" """Provider's delegate."""

View File

@ -25,13 +25,6 @@ class CatalogC(CatalogB):
p32 = di.Provider() p32 = di.Provider()
class CatalogD(di.AbstractCatalog):
"""Test catalog D."""
p11 = di.Provider()
p12 = di.Provider()
class CatalogsInheritanceTests(unittest.TestCase): class CatalogsInheritanceTests(unittest.TestCase):
"""Catalogs inheritance tests.""" """Catalogs inheritance tests."""
@ -78,55 +71,71 @@ class CatalogsInheritanceTests(unittest.TestCase):
p32=CatalogC.p32)) p32=CatalogC.p32))
class CatalogPackTests(unittest.TestCase): class CatalogBundleTests(unittest.TestCase):
"""Catalog pack test cases.""" """Catalog bundle test cases."""
def setUp(self): def setUp(self):
"""Set test environment up.""" """Set test environment up."""
self.pack = CatalogC.Pack(CatalogC.p11, self.bundle = CatalogC.Bundle(CatalogC.p11,
CatalogC.p12) CatalogC.p12)
def test_get_attr_from_pack(self): def test_get_attr_from_bundle(self):
"""Test get providers (attribute) from pack.""" """Test get providers (attribute) from catalog bundle."""
self.assertIs(self.pack.p11, CatalogC.p11) self.assertIs(self.bundle.p11, CatalogC.p11)
self.assertIs(self.pack.p12, CatalogC.p12) self.assertIs(self.bundle.p12, CatalogC.p12)
def test_get_attr_not_from_pack(self): def test_get_attr_not_from_bundle(self):
"""Test get providers (attribute) that are not in pack.""" """Test get providers (attribute) that are not in bundle."""
self.assertRaises(di.Error, getattr, self.pack, 'p21') self.assertRaises(di.Error, getattr, self.bundle, 'p21')
self.assertRaises(di.Error, getattr, self.pack, 'p22') self.assertRaises(di.Error, getattr, self.bundle, 'p22')
self.assertRaises(di.Error, getattr, self.pack, 'p31') self.assertRaises(di.Error, getattr, self.bundle, 'p31')
self.assertRaises(di.Error, getattr, self.pack, 'p32') self.assertRaises(di.Error, getattr, self.bundle, 'p32')
def test_get_method_from_pack(self): def test_get_method_from_bundle(self):
"""Test get providers (get() method) from pack.""" """Test get providers (get() method) from bundle."""
self.assertIs(self.pack.get('p11'), CatalogC.p11) self.assertIs(self.bundle.get('p11'), CatalogC.p11)
self.assertIs(self.pack.get('p12'), CatalogC.p12) self.assertIs(self.bundle.get('p12'), CatalogC.p12)
def test_get_method_not_from_pack(self): def test_get_method_not_from_bundle(self):
"""Test get providers (get() method) that are not in pack.""" """Test get providers (get() method) that are not in bundle."""
self.assertRaises(di.Error, self.pack.get, 'p21') self.assertRaises(di.Error, self.bundle.get, 'p21')
self.assertRaises(di.Error, self.pack.get, 'p22') self.assertRaises(di.Error, self.bundle.get, 'p22')
self.assertRaises(di.Error, self.pack.get, 'p31') self.assertRaises(di.Error, self.bundle.get, 'p31')
self.assertRaises(di.Error, self.pack.get, 'p32') self.assertRaises(di.Error, self.bundle.get, 'p32')
def test_has(self): def test_has(self):
"""Test checks of providers availability in pack.""" """Test checks of providers availability in bundle."""
self.assertTrue(self.pack.has('p11')) self.assertTrue(self.bundle.has('p11'))
self.assertTrue(self.pack.has('p12')) self.assertTrue(self.bundle.has('p12'))
self.assertFalse(self.pack.has('p21')) self.assertFalse(self.bundle.has('p21'))
self.assertFalse(self.pack.has('p22')) self.assertFalse(self.bundle.has('p22'))
self.assertFalse(self.pack.has('p31')) self.assertFalse(self.bundle.has('p31'))
self.assertFalse(self.pack.has('p32')) self.assertFalse(self.bundle.has('p32'))
def test_create_pack_with_another_catalog_provider(self): def test_create_bundle_with_unbound_provider(self):
"""Test that pack is not created with provider from another catalog.""" """Test that bundle is not created with unbound provider."""
self.assertRaises(di.Error, CatalogC.Pack, CatalogC.p31, CatalogD.p11) self.assertRaises(di.Error, CatalogC.Bundle, di.Provider())
def test_create_pack_with_unbound_provider(self): def test_create_bundle_with_another_catalog_provider(self):
"""Test that pack is not created with unbound provider.""" """Test that bundle can not contain another catalog's provider."""
self.assertRaises(di.Error, CatalogC.Pack, di.Provider()) class TestCatalog(di.AbstractCatalog):
"""Test catalog."""
provider = di.Provider()
self.assertRaises(di.Error,
CatalogC.Bundle, CatalogC.p31, TestCatalog.provider)
def test_create_bundle_with_another_catalog_provider_with_same_name(self):
"""Test that bundle can not contain another catalog's provider."""
class TestCatalog(di.AbstractCatalog):
"""Test catalog."""
p31 = di.Provider()
self.assertRaises(di.Error,
CatalogC.Bundle, CatalogC.p31, TestCatalog.p31)
class CatalogTests(unittest.TestCase): class CatalogTests(unittest.TestCase):