Refactor DeclarativeCatalog

This commit is contained in:
Roman Mogilatov 2015-11-12 13:43:54 +02:00
parent 4252fbfe4c
commit be1ff0445d
3 changed files with 93 additions and 11 deletions

View File

@ -41,6 +41,7 @@ from .utils import is_catalog_bundle
from .utils import ensure_is_catalog_bundle
from .errors import Error
from .errors import UndefinedProviderError
# Backward compatibility for versions < 0.11.*
from . import catalogs
@ -96,6 +97,7 @@ __all__ = (
# Errors
'Error',
'UndefinedProviderError',
# Version
'VERSION'

View File

@ -88,6 +88,7 @@ class DynamicCatalog(object):
self.Bundle = CatalogBundle.sub_cls_factory(self)
self.bind_providers(providers)
super(DynamicCatalog, self).__init__()
def is_bundle_owner(self, bundle):
"""Check if catalog is bundle owner."""
@ -181,6 +182,16 @@ class DynamicCatalog(object):
"""Return provider with specified name or raise en error."""
return self.get_provider(name)
def __setattr__(self, name, value):
"""Handle setting of catalog attributes.
Setting of attributes works as usual, but if value of attribute is
provider, this provider will be bound to catalog correctly.
"""
if is_provider(value):
return self.bind_provider(name, value)
return super(DynamicCatalog, self).__setattr__(name, value)
def __repr__(self):
"""Return Python representation of catalog."""
return '<{0}({1})>'.format(self.name,
@ -195,8 +206,6 @@ class DeclarativeCatalogMetaClass(type):
def __new__(mcs, class_name, bases, attributes):
"""Declarative catalog class factory."""
cls = type.__new__(mcs, class_name, bases, attributes)
cls_providers = tuple((name, provider)
for name, provider in six.iteritems(attributes)
if is_provider(provider))
@ -208,10 +217,10 @@ class DeclarativeCatalogMetaClass(type):
providers = cls_providers + inherited_providers
cls.name = '.'.join((cls.__module__, cls.__name__))
cls = type.__new__(mcs, class_name, bases, attributes)
cls.catalog = DynamicCatalog()
cls.catalog.name = cls.name
cls.catalog.name = '.'.join((cls.__module__, cls.__name__))
cls.catalog.bind_providers(dict(providers))
cls.cls_providers = dict(cls_providers)
@ -221,6 +230,11 @@ class DeclarativeCatalogMetaClass(type):
return cls
@property
def name(cls):
"""Return catalog's name."""
return cls.catalog.name
@property
def providers(cls):
"""Return dict of catalog's providers."""
@ -241,6 +255,21 @@ class DeclarativeCatalogMetaClass(type):
"""Return last overriding catalog."""
return cls.catalog.last_overriding
def __getattr__(cls, name):
"""Return provider with specified name or raise en error."""
raise UndefinedProviderError('There is no provider "{0}" in '
'catalog {1}'.format(name, cls))
def __setattr__(cls, name, value):
"""Handle setting of catalog attributes.
Setting of attributes works as usual, but if value of attribute is
provider, this provider will be bound to catalog correctly.
"""
if is_provider(value):
setattr(cls.catalog, name, value)
return super(DeclarativeCatalogMetaClass, cls).__setattr__(name, value)
def __repr__(cls):
"""Return string representation of the catalog."""
return '<{0}({1})>'.format(cls.name,
@ -326,28 +355,41 @@ class DeclarativeCatalog(object):
:type overriding: DeclarativeCatalog | DynamicCatalog
"""
cls.catalog.override(overriding)
return cls.catalog.override(overriding)
@classmethod
def reset_last_overriding(cls):
"""Reset last overriding catalog."""
cls.catalog.reset_last_overriding()
return cls.catalog.reset_last_overriding()
@classmethod
def reset_override(cls):
"""Reset all overridings for all catalog providers."""
cls.catalog.reset_override()
return cls.catalog.reset_override()
@classmethod
def get(cls, name):
"""Return provider with specified name or raises error."""
def get_provider(cls, name):
"""Return provider with specified name or raise an error."""
return cls.catalog.get_provider(name)
@classmethod
def has(cls, name):
def bind_provider(cls, name, provider):
"""Bind provider to catalog with specified name."""
return cls.catalog.bind_provider(name, provider)
@classmethod
def bind_providers(cls, providers):
"""Bind providers dictionary to catalog."""
return cls.catalog.bind_providers(providers)
@classmethod
def has_provider(cls, name):
"""Check if there is provider with certain name."""
return cls.catalog.has_provider(name)
get = get_provider # Backward compatibility for versions < 0.11.*
has = has_provider # Backward compatibility for versions < 0.11.*
# Backward compatibility for versions < 0.11.*
AbstractCatalog = DeclarativeCatalog

View File

@ -157,6 +157,22 @@ class DeclarativeCatalogTests(unittest.TestCase):
p31=CatalogC.p31,
p32=CatalogC.p32))
def test_setattr(self):
"""Test setting of provider attributes to catalog."""
px = di.Provider()
py = di.Provider()
CatalogC.px = px
CatalogC.py = py
self.assertIs(CatalogC.px, px)
self.assertIs(CatalogC.get_provider('px'), px)
self.assertIs(CatalogC.catalog.px, px)
self.assertIs(CatalogC.py, py)
self.assertIs(CatalogC.get_provider('py'), py)
self.assertIs(CatalogC.catalog.py, py)
def test_provider_is_bound(self):
"""Test that providers are bound to the catalogs."""
self.assertTrue(CatalogA.is_provider_bound(CatalogA.p11))
@ -227,9 +243,23 @@ class DeclarativeCatalogTests(unittest.TestCase):
self.assertIs(CatalogC.get('p32'), CatalogC.p32)
self.assertIs(CatalogC.get('p32'), CatalogC.p32)
self.assertIs(CatalogC.get_provider('p11'), CatalogC.p11)
self.assertIs(CatalogC.get_provider('p12'), CatalogC.p12)
self.assertIs(CatalogC.get_provider('p22'), CatalogC.p22)
self.assertIs(CatalogC.get_provider('p22'), CatalogC.p22)
self.assertIs(CatalogC.get_provider('p32'), CatalogC.p32)
self.assertIs(CatalogC.get_provider('p32'), CatalogC.p32)
def test_get_undefined(self):
"""Test getting of undefined providers using get() method."""
self.assertRaises(di.Error, CatalogC.get, 'undefined')
with self.assertRaises(di.UndefinedProviderError):
CatalogC.get('undefined')
with self.assertRaises(di.UndefinedProviderError):
CatalogC.get_provider('undefined')
with self.assertRaises(di.UndefinedProviderError):
CatalogC.undefined
def test_has(self):
"""Test checks of providers availability in catalog."""
@ -241,6 +271,14 @@ class DeclarativeCatalogTests(unittest.TestCase):
self.assertTrue(CatalogC.has('p32'))
self.assertFalse(CatalogC.has('undefined'))
self.assertTrue(CatalogC.has_provider('p11'))
self.assertTrue(CatalogC.has_provider('p12'))
self.assertTrue(CatalogC.has_provider('p21'))
self.assertTrue(CatalogC.has_provider('p22'))
self.assertTrue(CatalogC.has_provider('p31'))
self.assertTrue(CatalogC.has_provider('p32'))
self.assertFalse(CatalogC.has_provider('undefined'))
def test_filter_all_providers_by_type(self):
"""Test getting of all catalog providers of specific type."""
self.assertTrue(len(CatalogC.filter(di.Provider)) == 6)