Change provider-to-catalog binding restrictions

This commit is contained in:
Roman Mogilatov 2015-11-10 15:01:06 +02:00
parent 9f526a3ceb
commit c5602cf88b
3 changed files with 86 additions and 57 deletions

View File

@ -21,14 +21,14 @@ class CatalogBundle(object):
def __init__(self, *providers):
"""Initializer."""
self.providers = dict((provider.bind.name, provider)
for provider in providers
if self._ensure_provider_is_bound(provider))
self.providers = dict((self.catalog.get_provider_bind_name(provider),
provider)
for provider in providers)
self.__dict__.update(self.providers)
super(CatalogBundle, self).__init__()
def get(self, name):
"""Return provider with specified name or raises error."""
"""Return provider with specified name or raise an error."""
try:
return self.providers[name]
except KeyError:
@ -38,16 +38,6 @@ class CatalogBundle(object):
"""Check if there is provider with certain name."""
return name in self.providers
def _ensure_provider_is_bound(self, provider):
"""Check that provider is bound to the bundle's catalog."""
if not provider.is_bound:
raise Error('Provider {0} is not bound to '
'any catalog'.format(provider))
if provider is not self.catalog.get(provider.bind.name):
raise Error('{0} can contain providers from '
'catalog {0}'.format(self.__class__, self.catalog))
return True
def _raise_undefined_provider_error(self, name):
"""Raise error for cases when there is no such provider in bundle."""
raise Error('Provider "{0}" is not a part of {1}'.format(name, self))
@ -67,15 +57,6 @@ class CatalogBundle(object):
__str__ = __repr__
class Catalog(object):
"""Catalog of providers."""
def __init__(self, name, **providers):
"""Initializer."""
self.name = name
self.providers = providers
@six.python_2_unicode_compatible
class DeclarativeCatalogMetaClass(type):
"""Declarative catalog meta class."""
@ -105,13 +86,13 @@ class DeclarativeCatalogMetaClass(type):
cls.Bundle = mcs.bundle_cls_factory(cls)
for name, provider in six.iteritems(cls_providers):
if provider.is_bound:
raise Error('Provider {0} has been already bound to catalog'
'{1} as "{2}"'.format(provider,
provider.bind.catalog,
provider.bind.name))
provider.bind = ProviderBinding(cls, name)
cls.provider_names = dict()
for name, provider in six.iteritems(providers):
if provider in cls.provider_names:
raise Error('Provider {0} could not be bound to the same '
'catalog (or catalogs hierarchy) more '
'than once'.format(provider))
cls.provider_names[provider] = name
return cls
@ -151,6 +132,10 @@ class DeclarativeCatalog(object):
:param providers: Dict of all catalog providers, including inherited from
parent catalogs
:type provider_names: dict[dependency_injector.Provider, str]
:param provider_names: Dict of all catalog providers, including inherited
from parent catalogs
:type cls_providers: dict[str, dependency_injector.Provider]
:param cls_providers: Dict of current catalog providers
@ -174,6 +159,7 @@ class DeclarativeCatalog(object):
cls_providers = dict()
inherited_providers = dict()
providers = dict()
provider_names = dict()
overridden_by = tuple()
is_overridden = bool
@ -186,6 +172,19 @@ class DeclarativeCatalog(object):
"""Check if catalog is bundle owner."""
return ensure_is_catalog_bundle(bundle) and bundle.catalog is cls
@classmethod
def get_provider_bind_name(cls, provider):
"""Return provider's name in catalog."""
if not cls.is_provider_bound(provider):
raise Error('Can not find bind name for {0} in catalog {1}'.format(
provider, cls))
return cls.provider_names[provider]
@classmethod
def is_provider_bound(cls, provider):
"""Check if provider is bound to the catalog."""
return provider in cls.provider_names
@classmethod
def filter(cls, provider_type):
"""Return dict of providers, that are instance of provided type."""
@ -238,17 +237,6 @@ class DeclarativeCatalog(object):
AbstractCatalog = DeclarativeCatalog
class ProviderBinding(object):
"""Catalog provider binding."""
__slots__ = ('catalog', 'name')
def __init__(self, catalog, name):
"""Initializer."""
self.catalog = catalog
self.name = name
def override(catalog):
"""Catalog overriding decorator."""
def decorator(overriding_catalog):

View File

@ -19,12 +19,11 @@ class Provider(object):
"""Base provider class."""
__IS_PROVIDER__ = True
__slots__ = ('overridden_by', 'bind')
__slots__ = ('overridden_by',)
def __init__(self):
"""Initializer."""
self.overridden_by = None
self.bind = None
def __call__(self, *args, **kwargs):
"""Return provided instance."""
@ -75,11 +74,6 @@ class Provider(object):
"""Reset all overriding providers."""
self.overridden_by = None
@property
def is_bound(self):
"""Check if provider is bound to any catalog."""
return bool(self.bind)
class Delegate(Provider):
"""Provider's delegate."""

View File

@ -76,17 +76,64 @@ class CatalogProvidersBindingTests(unittest.TestCase):
def test_provider_is_bound(self):
"""Test that providers are bound to the catalogs."""
self.assertIs(CatalogA.p11.bind.catalog, CatalogA)
self.assertEquals(CatalogA.p11.bind.name, 'p11')
self.assertTrue(CatalogA.is_provider_bound(CatalogA.p11))
self.assertEquals(CatalogA.get_provider_bind_name(CatalogA.p11), 'p11')
self.assertIs(CatalogA.p12.bind.catalog, CatalogA)
self.assertEquals(CatalogA.p12.bind.name, 'p12')
self.assertTrue(CatalogA.is_provider_bound(CatalogA.p12))
self.assertEquals(CatalogA.get_provider_bind_name(CatalogA.p12), 'p12')
def test_provider_rebinding(self):
"""Test that provider could not be bound twice."""
self.assertRaises(di.Error, type, 'TestCatalog',
(di.DeclarativeCatalog,),
dict(some_name=CatalogA.p11))
def test_provider_binding_to_different_catalogs(self):
"""Test that provider could be bound to different catalogs."""
p11 = CatalogA.p11
p12 = CatalogA.p12
class CatalogD(di.DeclarativeCatalog):
"""Test catalog."""
pd1 = p11
pd2 = p12
class CatalogE(di.DeclarativeCatalog):
"""Test catalog."""
pe1 = p11
pe2 = p12
self.assertTrue(CatalogA.is_provider_bound(p11))
self.assertTrue(CatalogD.is_provider_bound(p11))
self.assertTrue(CatalogE.is_provider_bound(p11))
self.assertEquals(CatalogA.get_provider_bind_name(p11), 'p11')
self.assertEquals(CatalogD.get_provider_bind_name(p11), 'pd1')
self.assertEquals(CatalogE.get_provider_bind_name(p11), 'pe1')
self.assertTrue(CatalogA.is_provider_bound(p12))
self.assertTrue(CatalogD.is_provider_bound(p12))
self.assertTrue(CatalogE.is_provider_bound(p12))
self.assertEquals(CatalogA.get_provider_bind_name(p12), 'p12')
self.assertEquals(CatalogD.get_provider_bind_name(p12), 'pd2')
self.assertEquals(CatalogE.get_provider_bind_name(p12), 'pe2')
def test_provider_rebinding_to_the_same_catalog(self):
"""Test provider rebinding to the same catalog."""
with self.assertRaises(di.Error):
class TestCatalog(di.DeclarativeCatalog):
"""Test catalog."""
p1 = di.Provider()
p2 = p1
def test_provider_rebinding_to_the_same_catalogs_hierarchy(self):
"""Test provider rebinding to the same catalogs hierarchy."""
class TestCatalog1(di.DeclarativeCatalog):
"""Test catalog."""
p1 = di.Provider()
with self.assertRaises(di.Error):
class TestCatalog2(TestCatalog1):
"""Test catalog."""
p2 = TestCatalog1.p1
class CatalogBundleTests(unittest.TestCase):