Implement DynamicCatalog

This commit is contained in:
Roman Mogilatov 2015-11-10 17:58:04 +02:00
parent 8b175f0b71
commit b456d770b6
3 changed files with 185 additions and 71 deletions

View File

@ -2,6 +2,7 @@
from .catalog import DeclarativeCatalog from .catalog import DeclarativeCatalog
from .catalog import AbstractCatalog from .catalog import AbstractCatalog
from .catalog import DynamicCatalog
from .catalog import CatalogBundle from .catalog import CatalogBundle
from .catalog import override from .catalog import override
@ -47,6 +48,7 @@ __all__ = (
# Catalogs # Catalogs
'DeclarativeCatalog', 'DeclarativeCatalog',
'AbstractCatalog', 'AbstractCatalog',
'DynamicCatalog',
'CatalogBundle', 'CatalogBundle',
'override', 'override',

View File

@ -6,9 +6,117 @@ from .errors import Error
from .utils import is_provider from .utils import is_provider
from .utils import is_catalog from .utils import is_catalog
from .utils import ensure_is_provider
from .utils import ensure_is_catalog_bundle from .utils import ensure_is_catalog_bundle
@six.python_2_unicode_compatible
class DynamicCatalog(object):
"""Catalog of providers."""
__IS_CATALOG__ = True
__slots__ = ('name', 'Bundle', 'providers', 'provider_names',
'overridden_by')
def __init__(self, name, **providers):
"""Initializer.
:param name: Catalog's name
:type name: str
:param kwargs: Dict of providers with their catalog names
:type kwargs: dict[str, dependency_injector.providers.Provider]
"""
self.name = name
self.Bundle = CatalogBundle.sub_cls_factory(self)
self.providers = dict()
self.provider_names = dict()
for name, provider in six.iteritems(providers):
provider = ensure_is_provider(provider)
if provider in self.provider_names:
raise Error('Provider {0} could not be bound to the same '
'catalog (or catalogs hierarchy) more '
'than once'.format(provider))
self.provider_names[provider] = name
self.providers[name] = provider
self.overridden_by = tuple()
def is_bundle_owner(self, bundle):
"""Check if catalog is bundle owner."""
return ensure_is_catalog_bundle(bundle) and bundle.catalog is self
def get_provider_bind_name(self, provider):
"""Return provider's name in catalog."""
if not self.is_provider_bound(provider):
raise Error('Can not find bind name for {0} in catalog {1}'.format(
provider, self))
return self.provider_names[provider]
def is_provider_bound(self, provider):
"""Check if provider is bound to the catalog."""
return provider in self.provider_names
def filter(self, provider_type):
"""Return dict of providers, that are instance of provided type."""
return dict((name, provider)
for name, provider in six.iteritems(self.providers)
if isinstance(provider, provider_type))
@property
def is_overridden(self):
"""Check if catalog is overridden by another catalog."""
return bool(self.overridden_by)
@property
def last_overriding(self):
"""Return last overriding catalog."""
try:
return self.overridden_by[-1]
except (TypeError, IndexError):
raise Error('Catalog {0} is not overridden'.format(self))
def override(self, overriding):
"""Override current catalog providers by overriding catalog providers.
:type overriding: DynamicCatalog
"""
self.overridden_by += (overriding,)
for name, provider in six.iteritems(overriding.providers):
self.get(name).override(provider)
def reset_last_overriding(self):
"""Reset last overriding catalog."""
if not self.is_overridden:
raise Error('Catalog {0} is not overridden'.format(self))
self.overridden_by = self.overridden_by[:-1]
for provider in six.itervalues(self.providers):
provider.reset_last_overriding()
def reset_override(self):
"""Reset all overridings for all catalog providers."""
self.overridden_by = tuple()
for provider in six.itervalues(self.providers):
provider.reset_override()
def get(self, name):
"""Return provider with specified name or raise an error."""
try:
return self.providers[name]
except KeyError:
raise Error('{0} has no provider with such name - {1}'.format(
self, name))
def has(self, name):
"""Check if there is provider with certain name."""
return name in self.providers
def __repr__(self):
"""Return Python representation of catalog."""
return '<DynamicCatalog {0}>'.format(self.name)
__str__ = __repr__
@six.python_2_unicode_compatible @six.python_2_unicode_compatible
class CatalogBundle(object): class CatalogBundle(object):
"""Bundle of catalog providers.""" """Bundle of catalog providers."""
@ -27,32 +135,34 @@ class CatalogBundle(object):
self.__dict__.update(self.providers) self.__dict__.update(self.providers)
super(CatalogBundle, self).__init__() super(CatalogBundle, self).__init__()
@classmethod
def sub_cls_factory(cls, catalog):
"""Create bundle class for catalog."""
return type('{0}Bundle'.format(catalog.name), (cls,),
dict(catalog=catalog))
def get(self, name): def get(self, name):
"""Return provider with specified name or raise an error.""" """Return provider with specified name or raise an error."""
try: try:
return self.providers[name] return self.providers[name]
except KeyError: except KeyError:
self._raise_undefined_provider_error(name) raise Error('Provider "{0}" is not a part of {1}'.format(name,
self))
def has(self, name): def has(self, name):
"""Check if there is provider with certain name.""" """Check if there is provider with certain name."""
return name in self.providers return name in self.providers
def _raise_undefined_provider_error(self, name):
"""Raise error for cases when there is no such provider in bundle."""
raise Error('Provider "{0}" is not a part of {1}'.format(name, self))
def __getattr__(self, item): def __getattr__(self, item):
"""Raise an error on every attempt to get undefined provider.""" """Raise an error on every attempt to get undefined provider."""
if item.startswith('__') and item.endswith('__'): if item.startswith('__') and item.endswith('__'):
return super(CatalogBundle, self).__getattr__(item) return super(CatalogBundle, self).__getattr__(item)
self._raise_undefined_provider_error(item) raise Error('Provider "{0}" is not a part of {1}'.format(item, self))
def __repr__(self): def __repr__(self):
"""Return string representation of catalog bundle.""" """Return string representation of catalog bundle."""
return '<{0}.{1}.Bundle({2})>'.format( return '<{0}.Bundle({1})>'.format(
self.catalog.__module__, self.catalog.__name__, self.catalog.name, ', '.join(six.iterkeys(self.providers)))
', '.join(six.iterkeys(self.providers)))
__str__ = __repr__ __str__ = __repr__
@ -65,9 +175,6 @@ class DeclarativeCatalogMetaClass(type):
"""Declarative catalog class factory.""" """Declarative catalog class factory."""
cls = type.__new__(mcs, class_name, bases, attributes) cls = type.__new__(mcs, class_name, bases, attributes)
cls.Bundle = mcs.bundle_cls_factory(cls)
cls.overridden_by = tuple()
cls_providers = tuple((name, provider) cls_providers = tuple((name, provider)
for name, provider in six.iteritems(attributes) for name, provider in six.iteritems(attributes)
if is_provider(provider)) if is_provider(provider))
@ -79,40 +186,38 @@ class DeclarativeCatalogMetaClass(type):
providers = cls_providers + inherited_providers providers = cls_providers + inherited_providers
cls.name = '.'.join((cls.__module__, cls.__name__))
cls.catalog = DynamicCatalog(cls.name, **dict(providers))
cls.Bundle = cls.catalog.Bundle
cls.cls_providers = dict(cls_providers) cls.cls_providers = dict(cls_providers)
cls.inherited_providers = dict(inherited_providers) cls.inherited_providers = dict(inherited_providers)
cls.providers = dict(providers)
cls.provider_names = dict()
for name, provider in providers:
if provider in cls.provider_names:
raise Error('Provider {0} could not be bound to the same '
'catalog (or catalogs hierarchy) more '
'than once'.format(provider))
cls.provider_names[provider] = name
return cls return cls
@classmethod @property
def bundle_cls_factory(mcs, cls): def providers(cls):
"""Create bundle class for catalog.""" """Return dict of catalog's providers."""
return type('{0}Bundle', (CatalogBundle,), dict(catalog=cls)) return cls.catalog.providers
@property
def overridden_by(cls):
"""Return tuple of overriding catalogs."""
return cls.catalog.overridden_by
@property @property
def is_overridden(cls): def is_overridden(cls):
"""Check if catalog is overridden by another catalog.""" """Check if catalog is overridden by another catalog."""
return bool(cls.overridden_by) return cls.catalog.is_overridden
@property @property
def last_overriding(cls): def last_overriding(cls):
"""Return last overriding catalog.""" """Return last overriding catalog."""
try: return cls.catalog.last_overriding
return cls.overridden_by[-1]
except (TypeError, IndexError):
raise Error('Catalog {0} is not overridden'.format(str(cls)))
def __repr__(cls): def __repr__(cls):
"""Return string representation of the catalog class.""" """Return string representation of the catalog."""
return '<{0}.{1}>'.format(cls.__module__, cls.__name__) return '<DeclarativeCatalog {0}>'.format(cls.name)
__str__ = __repr__ __str__ = __repr__
@ -121,6 +226,12 @@ class DeclarativeCatalogMetaClass(type):
class DeclarativeCatalog(object): class DeclarativeCatalog(object):
"""Declarative catalog catalog of providers. """Declarative catalog catalog of providers.
:type name: str
:param name: Catalog's name
:type catalog: DynamicCatalog
:param catalog: Instance of dynamic catalog
:type Bundle: CatalogBundle :type Bundle: CatalogBundle
:param Bundle: Catalog's bundle class :param Bundle: Catalog's bundle class
@ -128,10 +239,6 @@ class DeclarativeCatalog(object):
:param providers: Dict of all catalog providers, including inherited from :param providers: Dict of all catalog providers, including inherited from
parent catalogs parent catalogs
:type provider_names: dict[dependency_injector.Provider, str]
:param provider_names: Dict of all catalog providers, including inherited
from parent catalogs
:type cls_providers: dict[str, dependency_injector.Provider] :type cls_providers: dict[str, dependency_injector.Provider]
:param cls_providers: Dict of current catalog providers :param cls_providers: Dict of current catalog providers
@ -150,12 +257,13 @@ class DeclarativeCatalog(object):
:param last_overriding: Reference to the last overriding catalog, if any :param last_overriding: Reference to the last overriding catalog, if any
""" """
name = str()
catalog = DynamicCatalog
Bundle = CatalogBundle Bundle = CatalogBundle
cls_providers = dict() cls_providers = dict()
inherited_providers = dict() inherited_providers = dict()
providers = dict() providers = dict()
provider_names = dict()
overridden_by = tuple() overridden_by = tuple()
is_overridden = bool is_overridden = bool
@ -166,67 +274,50 @@ class DeclarativeCatalog(object):
@classmethod @classmethod
def is_bundle_owner(cls, bundle): def is_bundle_owner(cls, bundle):
"""Check if catalog is bundle owner.""" """Check if catalog is bundle owner."""
return ensure_is_catalog_bundle(bundle) and bundle.catalog is cls return cls.catalog.is_bundle_owner(bundle)
@classmethod @classmethod
def get_provider_bind_name(cls, provider): def get_provider_bind_name(cls, provider):
"""Return provider's name in catalog.""" """Return provider's name in catalog."""
if not cls.is_provider_bound(provider): return cls.catalog.get_provider_bind_name(provider)
raise Error('Can not find bind name for {0} in catalog {1}'.format(
provider, cls))
return cls.provider_names[provider]
@classmethod @classmethod
def is_provider_bound(cls, provider): def is_provider_bound(cls, provider):
"""Check if provider is bound to the catalog.""" """Check if provider is bound to the catalog."""
return provider in cls.provider_names return cls.catalog.is_provider_bound(provider)
@classmethod @classmethod
def filter(cls, provider_type): def filter(cls, provider_type):
"""Return dict of providers, that are instance of provided type.""" """Return dict of providers, that are instance of provided type."""
return dict((name, provider) return cls.catalog.filter(provider_type)
for name, provider in six.iteritems(cls.providers)
if isinstance(provider, provider_type))
@classmethod @classmethod
def override(cls, overriding): def override(cls, overriding):
"""Override current catalog providers by overriding catalog providers. """Override current catalog providers by overriding catalog providers.
:type overriding: DeclarativeCatalog :type overriding: DeclarativeCatalog | DynamicCatalog
""" """
cls.overridden_by += (overriding,) cls.catalog.override(overriding)
for name, provider in six.iteritems(overriding.cls_providers):
cls.providers[name].override(provider)
@classmethod @classmethod
def reset_last_overriding(cls): def reset_last_overriding(cls):
"""Reset last overriding catalog.""" """Reset last overriding catalog."""
if not cls.is_overridden: cls.catalog.reset_last_overriding()
raise Error('Catalog {0} is not overridden'.format(str(cls)))
cls.overridden_by = cls.overridden_by[:-1]
for provider in six.itervalues(cls.providers):
provider.reset_last_overriding()
@classmethod @classmethod
def reset_override(cls): def reset_override(cls):
"""Reset all overridings for all catalog providers.""" """Reset all overridings for all catalog providers."""
cls.overridden_by = tuple() cls.catalog.reset_override()
for provider in six.itervalues(cls.providers):
provider.reset_override()
@classmethod @classmethod
def get(cls, name): def get(cls, name):
"""Return provider with specified name or raises error.""" """Return provider with specified name or raises error."""
try: return cls.catalog.get(name)
return cls.providers[name]
except KeyError:
raise Error('{0} has no provider with such name - {1}'.format(
cls, name))
@classmethod @classmethod
def has(cls, name): def has(cls, name):
"""Check if there is provider with certain name.""" """Check if there is provider with certain name."""
return name in cls.providers return cls.catalog.has(name)
# Backward compatibility for versions < 0.11.* # Backward compatibility for versions < 0.11.*

View File

@ -261,7 +261,7 @@ class OverrideTests(unittest.TestCase):
def test_overriding(self): def test_overriding(self):
"""Test catalog overriding with another catalog.""" """Test catalog overriding with another catalog."""
@di.override(self.Catalog) @di.override(self.Catalog)
class OverridingCatalog(self.Catalog): class OverridingCatalog(di.DeclarativeCatalog):
"""Overriding catalog.""" """Overriding catalog."""
obj = di.Value(1) obj = di.Value(1)
@ -269,13 +269,23 @@ class OverrideTests(unittest.TestCase):
self.assertEqual(self.Catalog.obj(), 1) self.assertEqual(self.Catalog.obj(), 1)
self.assertEqual(self.Catalog.another_obj(), 2) self.assertEqual(self.Catalog.another_obj(), 2)
self.assertEqual(len(self.Catalog.overridden_by), 1)
def test_overriding_with_dynamic_catalog(self):
"""Test catalog overriding with another dynamic catalog."""
self.Catalog.override(di.DynamicCatalog('OverridingCatalog',
obj=di.Value(1),
another_obj=di.Value(2)))
self.assertEqual(self.Catalog.obj(), 1)
self.assertEqual(self.Catalog.another_obj(), 2)
self.assertEqual(len(self.Catalog.overridden_by), 1)
def test_is_overridden(self): def test_is_overridden(self):
"""Test catalog is_overridden property.""" """Test catalog is_overridden property."""
self.assertFalse(self.Catalog.is_overridden) self.assertFalse(self.Catalog.is_overridden)
@di.override(self.Catalog) @di.override(self.Catalog)
class OverridingCatalog(self.Catalog): class OverridingCatalog(di.DeclarativeCatalog):
"""Overriding catalog.""" """Overriding catalog."""
self.assertTrue(self.Catalog.is_overridden) self.assertTrue(self.Catalog.is_overridden)
@ -283,11 +293,11 @@ class OverrideTests(unittest.TestCase):
def test_last_overriding(self): def test_last_overriding(self):
"""Test catalog last_overriding property.""" """Test catalog last_overriding property."""
@di.override(self.Catalog) @di.override(self.Catalog)
class OverridingCatalog1(self.Catalog): class OverridingCatalog1(di.DeclarativeCatalog):
"""Overriding catalog.""" """Overriding catalog."""
@di.override(self.Catalog) @di.override(self.Catalog)
class OverridingCatalog2(self.Catalog): class OverridingCatalog2(di.DeclarativeCatalog):
"""Overriding catalog.""" """Overriding catalog."""
self.assertIs(self.Catalog.last_overriding, OverridingCatalog2) self.assertIs(self.Catalog.last_overriding, OverridingCatalog2)
@ -300,14 +310,14 @@ class OverrideTests(unittest.TestCase):
def test_reset_last_overriding(self): def test_reset_last_overriding(self):
"""Test resetting last overriding catalog.""" """Test resetting last overriding catalog."""
@di.override(self.Catalog) @di.override(self.Catalog)
class OverridingCatalog1(self.Catalog): class OverridingCatalog1(di.DeclarativeCatalog):
"""Overriding catalog.""" """Overriding catalog."""
obj = di.Value(1) obj = di.Value(1)
another_obj = di.Value(2) another_obj = di.Value(2)
@di.override(self.Catalog) @di.override(self.Catalog)
class OverridingCatalog2(self.Catalog): class OverridingCatalog2(di.DeclarativeCatalog):
"""Overriding catalog.""" """Overriding catalog."""
obj = di.Value(3) obj = di.Value(3)
@ -326,14 +336,14 @@ class OverrideTests(unittest.TestCase):
def test_reset_override(self): def test_reset_override(self):
"""Test resetting all catalog overrides.""" """Test resetting all catalog overrides."""
@di.override(self.Catalog) @di.override(self.Catalog)
class OverridingCatalog1(self.Catalog): class OverridingCatalog1(di.DeclarativeCatalog):
"""Overriding catalog.""" """Overriding catalog."""
obj = di.Value(1) obj = di.Value(1)
another_obj = di.Value(2) another_obj = di.Value(2)
@di.override(self.Catalog) @di.override(self.Catalog)
class OverridingCatalog2(self.Catalog): class OverridingCatalog2(di.DeclarativeCatalog):
"""Overriding catalog.""" """Overriding catalog."""
obj = di.Value(3) obj = di.Value(3)
@ -345,6 +355,17 @@ class OverrideTests(unittest.TestCase):
self.assertIsInstance(self.Catalog.another_obj(), object) self.assertIsInstance(self.Catalog.another_obj(), object)
class DeclarativeCatalogReprTest(unittest.TestCase):
"""Tests for declarative catalog representation."""
def test_repr(self):
"""Test declarative catalog representation."""
class TestCatalog(di.DeclarativeCatalog):
"""Test catalog."""
self.assertIn('TestCatalog', repr(TestCatalog))
class AbstractCatalogCompatibilityTest(unittest.TestCase): class AbstractCatalogCompatibilityTest(unittest.TestCase):
"""Test backward compatibility with di.AbstractCatalog.""" """Test backward compatibility with di.AbstractCatalog."""