diff --git a/dependency_injector/catalog.py b/dependency_injector/catalog.py index f0abb854..06718458 100644 --- a/dependency_injector/catalog.py +++ b/dependency_injector/catalog.py @@ -87,6 +87,8 @@ class CatalogMetaClass(type): cls.inherited_providers = inherited_providers cls.providers = providers + cls.overridden_by = tuple() + cls.Bundle = mcs.bundle_cls_factory(cls) for name, provider in six.iteritems(cls_providers): @@ -104,6 +106,19 @@ class CatalogMetaClass(type): """Create bundle class for catalog.""" return type('{0}Bundle', (CatalogBundle,), dict(catalog=cls)) + @property + def is_overridden(cls): + """Check if catalog is overridden by another catalog.""" + return bool(cls.overridden_by) + + @property + def last_overriding(cls): + """Return last overriding catalog.""" + try: + return cls.overridden_by[-1] + except (TypeError, IndexError): + raise Error('Catalog {0} is not overridden'.format(str(cls))) + def __repr__(cls): """Return string representation of the catalog class.""" return '' @@ -113,6 +128,9 @@ class CatalogMetaClass(type): class AbstractCatalog(object): """Abstract providers catalog. + :type Bundle: CatalogBundle + :param Bundle: Catalog's bundle class + :type providers: dict[str, dependency_injector.Provider] :param providers: Dict of all catalog providers, including inherited from parent catalogs @@ -124,8 +142,15 @@ class AbstractCatalog(object): :param inherited_providers: Dict of providers, that are inherited from parent catalogs - :type Bundle: CatalogBundle - :param Bundle: Catalog's bundle class + :type overridden_by: tuple[AbstractCatalog] + :param overridden_by: Tuple of overriding catalogs + + :type is_overridden: bool + :param is_overridden: Read-only, evaluated in runtime, property that is + set to True if catalog is overridden + + :type last_overriding: AbstractCatalog | None + :param last_overriding: Reference to the last overriding catalog, if any """ Bundle = CatalogBundle @@ -134,6 +159,10 @@ class AbstractCatalog(object): inherited_providers = dict() providers = dict() + overridden_by = tuple() + is_overridden = bool + last_overriding = None + __IS_CATALOG__ = True @classmethod @@ -154,9 +183,26 @@ class AbstractCatalog(object): :type overriding: AbstractCatalog """ + cls.overridden_by += (overriding,) for name, provider in six.iteritems(overriding.cls_providers): cls.providers[name].override(provider) + @classmethod + def reset_last_overriding(cls): + """Reset last overriding catalog.""" + if not cls.is_overridden: + raise Error('Catalog {0} is not overridden'.format(str(cls))) + cls.overridden_by = cls.overridden_by[:-1] + for provider in six.itervalues(cls.providers): + provider.reset_last_overriding() + + @classmethod + def reset_override(cls): + """Reset all overridings for all catalog providers.""" + cls.overridden_by = tuple() + for provider in six.itervalues(cls.providers): + provider.reset_override() + @classmethod def get(cls, name): """Return provider with specified name or raises error.""" diff --git a/docs/catalogs/overriding.rst b/docs/catalogs/overriding.rst index 4cd3a949..03ba19e6 100644 --- a/docs/catalogs/overriding.rst +++ b/docs/catalogs/overriding.rst @@ -20,3 +20,15 @@ Example of overriding catalog using ``@di.override()`` decorator: .. literalinclude:: ../../examples/catalogs/override_decorator.py :language: python +Also there are several useful methods and properties that help to work with +catalog overridings: + +- ``di.AbstractCatalog.is_overridden`` - read-only, evaluated in runtime, + property that is set to True if catalog is overridden. +- ``di.AbstractCatalog.last_overriding`` - reference to the last overriding + catalog, if any. +- ``di.AbstractCatalog.overridden_by`` - tuple of all overriding catalogs. +- ``di.AbstractCatalog.reset_last_overriding()`` - reset last overriding + catalog. +- ``di.AbstractCatalog.reset_override()`` - reset all overridings for all + catalog providers. diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index 129df1d1..3597af20 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -12,7 +12,8 @@ Development version ------------------- - Add functionality for creating ``di.AbstractCatalog`` provider bundles. -- Enhance ``di.AbstractCatalog`` inheritance. +- Improve ``di.AbstractCatalog`` inheritance. +- Improve ``di.AbstractCatalog`` overriding. - Add images for catalog "Writing catalogs" and "Operating with catalogs" examples. - Add functionality for using positional argument injections with diff --git a/tests/test_catalog.py b/tests/test_catalog.py index f55cff9e..779a343d 100644 --- a/tests/test_catalog.py +++ b/tests/test_catalog.py @@ -198,7 +198,7 @@ class CatalogTests(unittest.TestCase): class OverrideTests(unittest.TestCase): - """Override decorator test cases.""" + """Catalog overriding and override decorator test cases.""" class Catalog(di.AbstractCatalog): """Test catalog.""" @@ -206,6 +206,10 @@ class OverrideTests(unittest.TestCase): obj = di.Object(object()) another_obj = di.Object(object()) + def tearDown(self): + """Tear test environment down.""" + self.Catalog.reset_override() + def test_overriding(self): """Test catalog overriding with another catalog.""" @di.override(self.Catalog) @@ -217,3 +221,77 @@ class OverrideTests(unittest.TestCase): self.assertEqual(self.Catalog.obj(), 1) self.assertEqual(self.Catalog.another_obj(), 2) + + def test_is_overridden(self): + """Test catalog is_overridden property.""" + self.assertFalse(self.Catalog.is_overridden) + + @di.override(self.Catalog) + class OverridingCatalog(self.Catalog): + """Overriding catalog.""" + + self.assertTrue(self.Catalog.is_overridden) + + def test_last_overriding(self): + """Test catalog last_overriding property.""" + @di.override(self.Catalog) + class OverridingCatalog1(self.Catalog): + """Overriding catalog.""" + + @di.override(self.Catalog) + class OverridingCatalog2(self.Catalog): + """Overriding catalog.""" + + self.assertIs(self.Catalog.last_overriding, OverridingCatalog2) + + def test_last_overriding_on_not_overridden(self): + """Test catalog last_overriding property on not overridden catalog.""" + with self.assertRaises(di.Error): + self.Catalog.last_overriding + + def test_reset_last_overriding(self): + """Test resetting last overriding catalog.""" + @di.override(self.Catalog) + class OverridingCatalog1(self.Catalog): + """Overriding catalog.""" + + obj = di.Value(1) + another_obj = di.Value(2) + + @di.override(self.Catalog) + class OverridingCatalog2(self.Catalog): + """Overriding catalog.""" + + obj = di.Value(3) + another_obj = di.Value(4) + + self.Catalog.reset_last_overriding() + + self.assertEqual(self.Catalog.obj(), 1) + self.assertEqual(self.Catalog.another_obj(), 2) + + def test_reset_last_overriding_when_not_overridden(self): + """Test resetting last overriding catalog when it is not overridden.""" + with self.assertRaises(di.Error): + self.Catalog.reset_last_overriding() + + def test_reset_override(self): + """Test resetting all catalog overrides.""" + @di.override(self.Catalog) + class OverridingCatalog1(self.Catalog): + """Overriding catalog.""" + + obj = di.Value(1) + another_obj = di.Value(2) + + @di.override(self.Catalog) + class OverridingCatalog2(self.Catalog): + """Overriding catalog.""" + + obj = di.Value(3) + another_obj = di.Value(4) + + self.Catalog.reset_override() + + self.assertIsInstance(self.Catalog.obj(), object) + self.assertIsInstance(self.Catalog.another_obj(), object)