Refactor DeclarativeCatalog

This commit is contained in:
Roman Mogilatov 2015-11-12 16:56:00 +02:00
parent be1ff0445d
commit 2236d77313
2 changed files with 88 additions and 6 deletions

View File

@ -178,6 +178,12 @@ class DynamicCatalog(object):
"""Check if there is provider with certain name.""" """Check if there is provider with certain name."""
return name in self.providers return name in self.providers
def unbind_provider(self, name):
"""Remove provider binding."""
provider = self.get_provider(name)
del self.providers[name]
del self.provider_names[provider]
def __getattr__(self, name): def __getattr__(self, name):
"""Return provider with specified name or raise en error.""" """Return provider with specified name or raise en error."""
return self.get_provider(name) return self.get_provider(name)
@ -192,6 +198,14 @@ class DynamicCatalog(object):
return self.bind_provider(name, value) return self.bind_provider(name, value)
return super(DynamicCatalog, self).__setattr__(name, value) return super(DynamicCatalog, self).__setattr__(name, value)
def __delattr__(self, name):
"""Handle deleting of catalog attibute.
Deleting of attributes works as usual, but if value of attribute is
provider, this provider will be unbound from catalog correctly.
"""
self.unbind_provider(name)
def __repr__(self): def __repr__(self):
"""Return Python representation of catalog.""" """Return Python representation of catalog."""
return '<{0}({1})>'.format(self.name, return '<{0}({1})>'.format(self.name,
@ -270,6 +284,16 @@ class DeclarativeCatalogMetaClass(type):
setattr(cls.catalog, name, value) setattr(cls.catalog, name, value)
return super(DeclarativeCatalogMetaClass, cls).__setattr__(name, value) return super(DeclarativeCatalogMetaClass, cls).__setattr__(name, value)
def __delattr__(cls, name):
"""Handle deleting of catalog attibute.
Deleting of attributes works as usual, but if value of attribute is
provider, this provider will be unbound from catalog correctly.
"""
if is_provider(getattr(cls, name)):
delattr(cls.catalog, name)
return super(DeclarativeCatalogMetaClass, cls).__delattr__(name)
def __repr__(cls): def __repr__(cls):
"""Return string representation of the catalog.""" """Return string representation of the catalog."""
return '<{0}({1})>'.format(cls.name, return '<{0}({1})>'.format(cls.name,
@ -360,12 +384,12 @@ class DeclarativeCatalog(object):
@classmethod @classmethod
def reset_last_overriding(cls): def reset_last_overriding(cls):
"""Reset last overriding catalog.""" """Reset last overriding catalog."""
return cls.catalog.reset_last_overriding() cls.catalog.reset_last_overriding()
@classmethod @classmethod
def reset_override(cls): def reset_override(cls):
"""Reset all overridings for all catalog providers.""" """Reset all overridings for all catalog providers."""
return cls.catalog.reset_override() cls.catalog.reset_override()
@classmethod @classmethod
def get_provider(cls, name): def get_provider(cls, name):
@ -375,17 +399,23 @@ class DeclarativeCatalog(object):
@classmethod @classmethod
def bind_provider(cls, name, provider): def bind_provider(cls, name, provider):
"""Bind provider to catalog with specified name.""" """Bind provider to catalog with specified name."""
return cls.catalog.bind_provider(name, provider) setattr(cls, name, provider)
@classmethod @classmethod
def bind_providers(cls, providers): def bind_providers(cls, providers):
"""Bind providers dictionary to catalog.""" """Bind providers dictionary to catalog."""
return cls.catalog.bind_providers(providers) for name, provider in six.iteritems(providers):
setattr(cls, name, provider)
@classmethod @classmethod
def has_provider(cls, name): def has_provider(cls, name):
"""Check if there is provider with certain name.""" """Check if there is provider with certain name."""
return cls.catalog.has_provider(name) return hasattr(cls, name)
@classmethod
def unbind_provider(cls, name):
"""Remove provider binding."""
delattr(cls, name)
get = get_provider # Backward compatibility for versions < 0.11.* get = get_provider # Backward compatibility for versions < 0.11.*
has = has_provider # Backward compatibility for versions < 0.11.* has = has_provider # Backward compatibility for versions < 0.11.*

View File

@ -157,8 +157,45 @@ class DeclarativeCatalogTests(unittest.TestCase):
p31=CatalogC.p31, p31=CatalogC.p31,
p32=CatalogC.p32)) p32=CatalogC.p32))
def test_bind_provider(self):
"""Test setting of provider via bind_provider() to catalog."""
px = di.Provider()
py = di.Provider()
CatalogC.bind_provider('px', px)
CatalogC.bind_provider('py', py)
self.assertIs(CatalogC.px, px)
self.assertIs(CatalogC.get_provider('px'), px)
self.assertIs(CatalogC.catalog.px, px)
self.assertIs(CatalogC.py, py)
self.assertIs(CatalogC.get_provider('py'), py)
self.assertIs(CatalogC.catalog.py, py)
del CatalogC.px
del CatalogC.py
def test_bind_providers(self):
"""Test setting of provider via bind_providers() to catalog."""
px = di.Provider()
py = di.Provider()
CatalogC.bind_providers(dict(px=px, py=py))
self.assertIs(CatalogC.px, px)
self.assertIs(CatalogC.get_provider('px'), px)
self.assertIs(CatalogC.catalog.px, px)
self.assertIs(CatalogC.py, py)
self.assertIs(CatalogC.get_provider('py'), py)
self.assertIs(CatalogC.catalog.py, py)
del CatalogC.px
del CatalogC.py
def test_setattr(self): def test_setattr(self):
"""Test setting of provider attributes to catalog.""" """Test setting of providers via attributes to catalog."""
px = di.Provider() px = di.Provider()
py = di.Provider() py = di.Provider()
@ -173,6 +210,21 @@ class DeclarativeCatalogTests(unittest.TestCase):
self.assertIs(CatalogC.get_provider('py'), py) self.assertIs(CatalogC.get_provider('py'), py)
self.assertIs(CatalogC.catalog.py, py) self.assertIs(CatalogC.catalog.py, py)
del CatalogC.px
del CatalogC.py
def test_unbind_provider(self):
"""Test that catalog unbinds provider correct."""
CatalogC.px = di.Provider()
CatalogC.unbind_provider('px')
self.assertFalse(CatalogC.has_provider('px'))
def test_unbind_via_delattr(self):
"""Test that catalog unbinds provider correct."""
CatalogC.px = di.Provider()
del CatalogC.px
self.assertFalse(CatalogC.has_provider('px'))
def test_provider_is_bound(self): def test_provider_is_bound(self):
"""Test that providers are bound to the catalogs.""" """Test that providers are bound to the catalogs."""
self.assertTrue(CatalogA.is_provider_bound(CatalogA.p11)) self.assertTrue(CatalogA.is_provider_bound(CatalogA.p11))