diff --git a/dependency_injector/__init__.py b/dependency_injector/__init__.py index 8420bd47..5e38585a 100644 --- a/dependency_injector/__init__.py +++ b/dependency_injector/__init__.py @@ -1,6 +1,7 @@ """Dependency injector.""" from .catalog import AbstractCatalog +from .catalog import CatalogBundle from .catalog import override from .providers import Provider @@ -30,6 +31,8 @@ from .utils import is_kwarg_injection from .utils import is_attribute_injection from .utils import is_method_injection from .utils import is_catalog +from .utils import is_catalog_bundle +from .utils import ensure_is_catalog_bundle from .errors import Error @@ -37,6 +40,7 @@ from .errors import Error __all__ = ( # Catalogs 'AbstractCatalog', + 'CatalogBundle', 'override', # Providers @@ -69,6 +73,8 @@ __all__ = ( 'is_attribute_injection', 'is_method_injection', 'is_catalog', + 'is_catalog_bundle', + 'ensure_is_catalog_bundle', # Errors 'Error', diff --git a/dependency_injector/catalog.py b/dependency_injector/catalog.py index 61f1663c..f0abb854 100644 --- a/dependency_injector/catalog.py +++ b/dependency_injector/catalog.py @@ -6,6 +6,7 @@ from .errors import Error from .utils import is_provider from .utils import is_catalog +from .utils import ensure_is_catalog_bundle class CatalogBundle(object): @@ -14,6 +15,7 @@ class CatalogBundle(object): catalog = None """:type: AbstractCatalog""" + __IS_CATALOG_BUNDLE__ = True __slots__ = ('providers', '__dict__') def __init__(self, *providers): @@ -134,6 +136,11 @@ class AbstractCatalog(object): __IS_CATALOG__ = True + @classmethod + def is_bundle_owner(cls, bundle): + """Check if catalog is bundle owner.""" + return ensure_is_catalog_bundle(bundle) and bundle.catalog is cls + @classmethod def filter(cls, provider_type): """Return dict of providers, that are instance of provided type.""" diff --git a/dependency_injector/utils.py b/dependency_injector/utils.py index 42a18202..29547b9c 100644 --- a/dependency_injector/utils.py +++ b/dependency_injector/utils.py @@ -17,7 +17,10 @@ def is_provider(instance): def ensure_is_provider(instance): - """Check if instance is provider instance, otherwise raise and error.""" + """Check if instance is provider instance and return it. + + :raise: Error if provided instance is not provider. + """ if not is_provider(instance): raise Error('Expected provider instance, ' 'got {0}'.format(str(instance))) @@ -62,6 +65,23 @@ def is_catalog(instance): getattr(instance, '__IS_CATALOG__', False) is True) +def is_catalog_bundle(instance): + """Check if instance is catalog bundle instance.""" + return (not isinstance(instance, six.class_types) and + getattr(instance, '__IS_CATALOG_BUNDLE__', False) is True) + + +def ensure_is_catalog_bundle(instance): + """Check if instance is catalog bundle instance and return it. + + :raise: Error if provided instance is not catalog bundle. + """ + if not is_catalog_bundle(instance): + raise Error('Expected catalog bundle instance, ' + 'got {0}'.format(str(instance))) + return instance + + def get_injectable_kwargs(kwargs, injections): """Return dictionary of kwargs, patched with injections.""" init_kwargs = dict(((injection.name, injection.value) diff --git a/tests/test_catalog.py b/tests/test_catalog.py index c5e1e182..f55cff9e 100644 --- a/tests/test_catalog.py +++ b/tests/test_catalog.py @@ -154,6 +154,16 @@ class CatalogBundleTests(unittest.TestCase): self.assertRaises(di.Error, CatalogC.Bundle, CatalogC.p31, TestCatalog.p31) + def test_is_bundle_owner(self): + """Test that catalog bundle is owned by catalog.""" + self.assertTrue(CatalogC.is_bundle_owner(self.bundle)) + self.assertFalse(CatalogB.is_bundle_owner(self.bundle)) + self.assertFalse(CatalogA.is_bundle_owner(self.bundle)) + + def test_is_bundle_owner_with_not_bundle_instance(self): + """Test that check of bundle ownership raises error with not bundle.""" + self.assertRaises(di.Error, CatalogC.is_bundle_owner, object()) + class CatalogTests(unittest.TestCase): """Catalog test cases.""" diff --git a/tests/test_utils.py b/tests/test_utils.py index de27bab1..61705ff1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -208,3 +208,45 @@ class IsCatalogTests(unittest.TestCase): def test_with_object(self): """Test with object.""" self.assertFalse(di.is_catalog(object())) + + +class IsCatalogBundleTests(unittest.TestCase): + """`is_catalog_bundle()` test cases.""" + + def test_with_instance(self): + """Test with instance.""" + self.assertTrue(di.is_catalog_bundle(di.CatalogBundle())) + + def test_with_cls(self): + """Test with class.""" + self.assertFalse(di.is_catalog_bundle(di.CatalogBundle)) + + def test_with_string(self): + """Test with string.""" + self.assertFalse(di.is_catalog_bundle('some_string')) + + def test_with_object(self): + """Test with object.""" + self.assertFalse(di.is_catalog_bundle(object())) + + +class EnsureIsCatalogBundleTests(unittest.TestCase): + """`ensure_is_catalog_bundle` test cases.""" + + def test_with_instance(self): + """Test with instance.""" + bundle = di.CatalogBundle() + self.assertIs(di.ensure_is_catalog_bundle(bundle), bundle) + + def test_with_class(self): + """Test with class.""" + self.assertRaises(di.Error, di.ensure_is_catalog_bundle, + di.CatalogBundle) + + def test_with_string(self): + """Test with string.""" + self.assertRaises(di.Error, di.ensure_is_catalog_bundle, 'some_string') + + def test_with_object(self): + """Test with object.""" + self.assertRaises(di.Error, di.ensure_is_catalog_bundle, object())