Improve functionality of catalog overriding

This commit is contained in:
Roman Mogilatov 2015-10-23 16:23:12 +03:00
parent f00b3612f8
commit b269023f7a
4 changed files with 141 additions and 4 deletions

View File

@ -87,6 +87,8 @@ class CatalogMetaClass(type):
cls.inherited_providers = inherited_providers cls.inherited_providers = inherited_providers
cls.providers = providers cls.providers = providers
cls.overridden_by = tuple()
cls.Bundle = mcs.bundle_cls_factory(cls) cls.Bundle = mcs.bundle_cls_factory(cls)
for name, provider in six.iteritems(cls_providers): for name, provider in six.iteritems(cls_providers):
@ -104,6 +106,19 @@ class CatalogMetaClass(type):
"""Create bundle class for catalog.""" """Create bundle class for catalog."""
return type('{0}Bundle', (CatalogBundle,), dict(catalog=cls)) return type('{0}Bundle', (CatalogBundle,), dict(catalog=cls))
@property
def is_overridden(cls):
"""Check if catalog is overridden by another catalog."""
return bool(cls.overridden_by)
@property
def last_overriding(cls):
"""Return last overriding catalog."""
try:
return cls.overridden_by[-1]
except (TypeError, IndexError):
raise Error('Catalog {0} is not overridden'.format(str(cls)))
def __repr__(cls): def __repr__(cls):
"""Return string representation of the catalog class.""" """Return string representation of the catalog class."""
return '<Catalog "' + '.'.join((cls.__module__, cls.__name__)) + '">' return '<Catalog "' + '.'.join((cls.__module__, cls.__name__)) + '">'
@ -113,6 +128,9 @@ class CatalogMetaClass(type):
class AbstractCatalog(object): class AbstractCatalog(object):
"""Abstract providers catalog. """Abstract providers catalog.
:type Bundle: CatalogBundle
:param Bundle: Catalog's bundle class
:type providers: dict[str, dependency_injector.Provider] :type providers: dict[str, dependency_injector.Provider]
:param providers: Dict of all catalog providers, including inherited from :param providers: Dict of all catalog providers, including inherited from
parent catalogs parent catalogs
@ -124,8 +142,15 @@ class AbstractCatalog(object):
:param inherited_providers: Dict of providers, that are inherited from :param inherited_providers: Dict of providers, that are inherited from
parent catalogs parent catalogs
:type Bundle: CatalogBundle :type overridden_by: tuple[AbstractCatalog]
:param Bundle: Catalog's bundle class :param overridden_by: Tuple of overriding catalogs
:type is_overridden: bool
:param is_overridden: Read-only, evaluated in runtime, property that is
set to True if catalog is overridden
:type last_overriding: AbstractCatalog | None
:param last_overriding: Reference to the last overriding catalog, if any
""" """
Bundle = CatalogBundle Bundle = CatalogBundle
@ -134,6 +159,10 @@ class AbstractCatalog(object):
inherited_providers = dict() inherited_providers = dict()
providers = dict() providers = dict()
overridden_by = tuple()
is_overridden = bool
last_overriding = None
__IS_CATALOG__ = True __IS_CATALOG__ = True
@classmethod @classmethod
@ -154,9 +183,26 @@ class AbstractCatalog(object):
:type overriding: AbstractCatalog :type overriding: AbstractCatalog
""" """
cls.overridden_by += (overriding,)
for name, provider in six.iteritems(overriding.cls_providers): for name, provider in six.iteritems(overriding.cls_providers):
cls.providers[name].override(provider) cls.providers[name].override(provider)
@classmethod
def reset_last_overriding(cls):
"""Reset last overriding catalog."""
if not cls.is_overridden:
raise Error('Catalog {0} is not overridden'.format(str(cls)))
cls.overridden_by = cls.overridden_by[:-1]
for provider in six.itervalues(cls.providers):
provider.reset_last_overriding()
@classmethod
def reset_override(cls):
"""Reset all overridings for all catalog providers."""
cls.overridden_by = tuple()
for provider in six.itervalues(cls.providers):
provider.reset_override()
@classmethod @classmethod
def get(cls, name): def get(cls, name):
"""Return provider with specified name or raises error.""" """Return provider with specified name or raises error."""

View File

@ -20,3 +20,15 @@ Example of overriding catalog using ``@di.override()`` decorator:
.. literalinclude:: ../../examples/catalogs/override_decorator.py .. literalinclude:: ../../examples/catalogs/override_decorator.py
:language: python :language: python
Also there are several useful methods and properties that help to work with
catalog overridings:
- ``di.AbstractCatalog.is_overridden`` - read-only, evaluated in runtime,
property that is set to True if catalog is overridden.
- ``di.AbstractCatalog.last_overriding`` - reference to the last overriding
catalog, if any.
- ``di.AbstractCatalog.overridden_by`` - tuple of all overriding catalogs.
- ``di.AbstractCatalog.reset_last_overriding()`` - reset last overriding
catalog.
- ``di.AbstractCatalog.reset_override()`` - reset all overridings for all
catalog providers.

View File

@ -12,7 +12,8 @@ Development version
------------------- -------------------
- Add functionality for creating ``di.AbstractCatalog`` provider bundles. - Add functionality for creating ``di.AbstractCatalog`` provider bundles.
- Enhance ``di.AbstractCatalog`` inheritance. - Improve ``di.AbstractCatalog`` inheritance.
- Improve ``di.AbstractCatalog`` overriding.
- Add images for catalog "Writing catalogs" and "Operating with catalogs" - Add images for catalog "Writing catalogs" and "Operating with catalogs"
examples. examples.
- Add functionality for using positional argument injections with - Add functionality for using positional argument injections with

View File

@ -198,7 +198,7 @@ class CatalogTests(unittest.TestCase):
class OverrideTests(unittest.TestCase): class OverrideTests(unittest.TestCase):
"""Override decorator test cases.""" """Catalog overriding and override decorator test cases."""
class Catalog(di.AbstractCatalog): class Catalog(di.AbstractCatalog):
"""Test catalog.""" """Test catalog."""
@ -206,6 +206,10 @@ class OverrideTests(unittest.TestCase):
obj = di.Object(object()) obj = di.Object(object())
another_obj = di.Object(object()) another_obj = di.Object(object())
def tearDown(self):
"""Tear test environment down."""
self.Catalog.reset_override()
def test_overriding(self): def test_overriding(self):
"""Test catalog overriding with another catalog.""" """Test catalog overriding with another catalog."""
@di.override(self.Catalog) @di.override(self.Catalog)
@ -217,3 +221,77 @@ class OverrideTests(unittest.TestCase):
self.assertEqual(self.Catalog.obj(), 1) self.assertEqual(self.Catalog.obj(), 1)
self.assertEqual(self.Catalog.another_obj(), 2) self.assertEqual(self.Catalog.another_obj(), 2)
def test_is_overridden(self):
"""Test catalog is_overridden property."""
self.assertFalse(self.Catalog.is_overridden)
@di.override(self.Catalog)
class OverridingCatalog(self.Catalog):
"""Overriding catalog."""
self.assertTrue(self.Catalog.is_overridden)
def test_last_overriding(self):
"""Test catalog last_overriding property."""
@di.override(self.Catalog)
class OverridingCatalog1(self.Catalog):
"""Overriding catalog."""
@di.override(self.Catalog)
class OverridingCatalog2(self.Catalog):
"""Overriding catalog."""
self.assertIs(self.Catalog.last_overriding, OverridingCatalog2)
def test_last_overriding_on_not_overridden(self):
"""Test catalog last_overriding property on not overridden catalog."""
with self.assertRaises(di.Error):
self.Catalog.last_overriding
def test_reset_last_overriding(self):
"""Test resetting last overriding catalog."""
@di.override(self.Catalog)
class OverridingCatalog1(self.Catalog):
"""Overriding catalog."""
obj = di.Value(1)
another_obj = di.Value(2)
@di.override(self.Catalog)
class OverridingCatalog2(self.Catalog):
"""Overriding catalog."""
obj = di.Value(3)
another_obj = di.Value(4)
self.Catalog.reset_last_overriding()
self.assertEqual(self.Catalog.obj(), 1)
self.assertEqual(self.Catalog.another_obj(), 2)
def test_reset_last_overriding_when_not_overridden(self):
"""Test resetting last overriding catalog when it is not overridden."""
with self.assertRaises(di.Error):
self.Catalog.reset_last_overriding()
def test_reset_override(self):
"""Test resetting all catalog overrides."""
@di.override(self.Catalog)
class OverridingCatalog1(self.Catalog):
"""Overriding catalog."""
obj = di.Value(1)
another_obj = di.Value(2)
@di.override(self.Catalog)
class OverridingCatalog2(self.Catalog):
"""Overriding catalog."""
obj = di.Value(3)
another_obj = di.Value(4)
self.Catalog.reset_override()
self.assertIsInstance(self.Catalog.obj(), object)
self.assertIsInstance(self.Catalog.another_obj(), object)