Change provider-to-catalog binding restrictions

This commit is contained in:
Roman Mogilatov 2015-11-10 15:01:06 +02:00
parent 9f526a3ceb
commit c5602cf88b
3 changed files with 86 additions and 57 deletions

View File

@ -21,14 +21,14 @@ class CatalogBundle(object):
def __init__(self, *providers): def __init__(self, *providers):
"""Initializer.""" """Initializer."""
self.providers = dict((provider.bind.name, provider) self.providers = dict((self.catalog.get_provider_bind_name(provider),
for provider in providers provider)
if self._ensure_provider_is_bound(provider)) for provider in providers)
self.__dict__.update(self.providers) self.__dict__.update(self.providers)
super(CatalogBundle, self).__init__() super(CatalogBundle, self).__init__()
def get(self, name): def get(self, name):
"""Return provider with specified name or raises error.""" """Return provider with specified name or raise an error."""
try: try:
return self.providers[name] return self.providers[name]
except KeyError: except KeyError:
@ -38,16 +38,6 @@ class CatalogBundle(object):
"""Check if there is provider with certain name.""" """Check if there is provider with certain name."""
return name in self.providers return name in self.providers
def _ensure_provider_is_bound(self, provider):
"""Check that provider is bound to the bundle's catalog."""
if not provider.is_bound:
raise Error('Provider {0} is not bound to '
'any catalog'.format(provider))
if provider is not self.catalog.get(provider.bind.name):
raise Error('{0} can contain providers from '
'catalog {0}'.format(self.__class__, self.catalog))
return True
def _raise_undefined_provider_error(self, name): def _raise_undefined_provider_error(self, name):
"""Raise error for cases when there is no such provider in bundle.""" """Raise error for cases when there is no such provider in bundle."""
raise Error('Provider "{0}" is not a part of {1}'.format(name, self)) raise Error('Provider "{0}" is not a part of {1}'.format(name, self))
@ -67,15 +57,6 @@ class CatalogBundle(object):
__str__ = __repr__ __str__ = __repr__
class Catalog(object):
"""Catalog of providers."""
def __init__(self, name, **providers):
"""Initializer."""
self.name = name
self.providers = providers
@six.python_2_unicode_compatible @six.python_2_unicode_compatible
class DeclarativeCatalogMetaClass(type): class DeclarativeCatalogMetaClass(type):
"""Declarative catalog meta class.""" """Declarative catalog meta class."""
@ -105,13 +86,13 @@ class DeclarativeCatalogMetaClass(type):
cls.Bundle = mcs.bundle_cls_factory(cls) cls.Bundle = mcs.bundle_cls_factory(cls)
for name, provider in six.iteritems(cls_providers): cls.provider_names = dict()
if provider.is_bound: for name, provider in six.iteritems(providers):
raise Error('Provider {0} has been already bound to catalog' if provider in cls.provider_names:
'{1} as "{2}"'.format(provider, raise Error('Provider {0} could not be bound to the same '
provider.bind.catalog, 'catalog (or catalogs hierarchy) more '
provider.bind.name)) 'than once'.format(provider))
provider.bind = ProviderBinding(cls, name) cls.provider_names[provider] = name
return cls return cls
@ -151,6 +132,10 @@ class DeclarativeCatalog(object):
:param providers: Dict of all catalog providers, including inherited from :param providers: Dict of all catalog providers, including inherited from
parent catalogs parent catalogs
:type provider_names: dict[dependency_injector.Provider, str]
:param provider_names: Dict of all catalog providers, including inherited
from parent catalogs
:type cls_providers: dict[str, dependency_injector.Provider] :type cls_providers: dict[str, dependency_injector.Provider]
:param cls_providers: Dict of current catalog providers :param cls_providers: Dict of current catalog providers
@ -174,6 +159,7 @@ class DeclarativeCatalog(object):
cls_providers = dict() cls_providers = dict()
inherited_providers = dict() inherited_providers = dict()
providers = dict() providers = dict()
provider_names = dict()
overridden_by = tuple() overridden_by = tuple()
is_overridden = bool is_overridden = bool
@ -186,6 +172,19 @@ class DeclarativeCatalog(object):
"""Check if catalog is bundle owner.""" """Check if catalog is bundle owner."""
return ensure_is_catalog_bundle(bundle) and bundle.catalog is cls return ensure_is_catalog_bundle(bundle) and bundle.catalog is cls
@classmethod
def get_provider_bind_name(cls, provider):
"""Return provider's name in catalog."""
if not cls.is_provider_bound(provider):
raise Error('Can not find bind name for {0} in catalog {1}'.format(
provider, cls))
return cls.provider_names[provider]
@classmethod
def is_provider_bound(cls, provider):
"""Check if provider is bound to the catalog."""
return provider in cls.provider_names
@classmethod @classmethod
def filter(cls, provider_type): def filter(cls, provider_type):
"""Return dict of providers, that are instance of provided type.""" """Return dict of providers, that are instance of provided type."""
@ -238,17 +237,6 @@ class DeclarativeCatalog(object):
AbstractCatalog = DeclarativeCatalog AbstractCatalog = DeclarativeCatalog
class ProviderBinding(object):
"""Catalog provider binding."""
__slots__ = ('catalog', 'name')
def __init__(self, catalog, name):
"""Initializer."""
self.catalog = catalog
self.name = name
def override(catalog): def override(catalog):
"""Catalog overriding decorator.""" """Catalog overriding decorator."""
def decorator(overriding_catalog): def decorator(overriding_catalog):

View File

@ -19,12 +19,11 @@ class Provider(object):
"""Base provider class.""" """Base provider class."""
__IS_PROVIDER__ = True __IS_PROVIDER__ = True
__slots__ = ('overridden_by', 'bind') __slots__ = ('overridden_by',)
def __init__(self): def __init__(self):
"""Initializer.""" """Initializer."""
self.overridden_by = None self.overridden_by = None
self.bind = None
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
"""Return provided instance.""" """Return provided instance."""
@ -75,11 +74,6 @@ class Provider(object):
"""Reset all overriding providers.""" """Reset all overriding providers."""
self.overridden_by = None self.overridden_by = None
@property
def is_bound(self):
"""Check if provider is bound to any catalog."""
return bool(self.bind)
class Delegate(Provider): class Delegate(Provider):
"""Provider's delegate.""" """Provider's delegate."""

View File

@ -76,17 +76,64 @@ class CatalogProvidersBindingTests(unittest.TestCase):
def test_provider_is_bound(self): def test_provider_is_bound(self):
"""Test that providers are bound to the catalogs.""" """Test that providers are bound to the catalogs."""
self.assertIs(CatalogA.p11.bind.catalog, CatalogA) self.assertTrue(CatalogA.is_provider_bound(CatalogA.p11))
self.assertEquals(CatalogA.p11.bind.name, 'p11') self.assertEquals(CatalogA.get_provider_bind_name(CatalogA.p11), 'p11')
self.assertIs(CatalogA.p12.bind.catalog, CatalogA) self.assertTrue(CatalogA.is_provider_bound(CatalogA.p12))
self.assertEquals(CatalogA.p12.bind.name, 'p12') self.assertEquals(CatalogA.get_provider_bind_name(CatalogA.p12), 'p12')
def test_provider_rebinding(self): def test_provider_binding_to_different_catalogs(self):
"""Test that provider could not be bound twice.""" """Test that provider could be bound to different catalogs."""
self.assertRaises(di.Error, type, 'TestCatalog', p11 = CatalogA.p11
(di.DeclarativeCatalog,), p12 = CatalogA.p12
dict(some_name=CatalogA.p11))
class CatalogD(di.DeclarativeCatalog):
"""Test catalog."""
pd1 = p11
pd2 = p12
class CatalogE(di.DeclarativeCatalog):
"""Test catalog."""
pe1 = p11
pe2 = p12
self.assertTrue(CatalogA.is_provider_bound(p11))
self.assertTrue(CatalogD.is_provider_bound(p11))
self.assertTrue(CatalogE.is_provider_bound(p11))
self.assertEquals(CatalogA.get_provider_bind_name(p11), 'p11')
self.assertEquals(CatalogD.get_provider_bind_name(p11), 'pd1')
self.assertEquals(CatalogE.get_provider_bind_name(p11), 'pe1')
self.assertTrue(CatalogA.is_provider_bound(p12))
self.assertTrue(CatalogD.is_provider_bound(p12))
self.assertTrue(CatalogE.is_provider_bound(p12))
self.assertEquals(CatalogA.get_provider_bind_name(p12), 'p12')
self.assertEquals(CatalogD.get_provider_bind_name(p12), 'pd2')
self.assertEquals(CatalogE.get_provider_bind_name(p12), 'pe2')
def test_provider_rebinding_to_the_same_catalog(self):
"""Test provider rebinding to the same catalog."""
with self.assertRaises(di.Error):
class TestCatalog(di.DeclarativeCatalog):
"""Test catalog."""
p1 = di.Provider()
p2 = p1
def test_provider_rebinding_to_the_same_catalogs_hierarchy(self):
"""Test provider rebinding to the same catalogs hierarchy."""
class TestCatalog1(di.DeclarativeCatalog):
"""Test catalog."""
p1 = di.Provider()
with self.assertRaises(di.Error):
class TestCatalog2(TestCatalog1):
"""Test catalog."""
p2 = TestCatalog1.p1
class CatalogBundleTests(unittest.TestCase): class CatalogBundleTests(unittest.TestCase):