diff --git a/dependency_injector/__init__.py b/dependency_injector/__init__.py index d0cd6304..5e38585a 100644 --- a/dependency_injector/__init__.py +++ b/dependency_injector/__init__.py @@ -1,7 +1,7 @@ """Dependency injector.""" from .catalog import AbstractCatalog -from .catalog import CatalogSubset +from .catalog import CatalogBundle from .catalog import override from .providers import Provider @@ -31,7 +31,8 @@ from .utils import is_kwarg_injection from .utils import is_attribute_injection from .utils import is_method_injection from .utils import is_catalog -from .utils import is_catalog_subset +from .utils import is_catalog_bundle +from .utils import ensure_is_catalog_bundle from .errors import Error @@ -39,7 +40,7 @@ from .errors import Error __all__ = ( # Catalogs 'AbstractCatalog', - 'CatalogSubset', + 'CatalogBundle', 'override', # Providers @@ -72,7 +73,8 @@ __all__ = ( 'is_attribute_injection', 'is_method_injection', 'is_catalog', - 'is_catalog_subset', + 'is_catalog_bundle', + 'ensure_is_catalog_bundle', # Errors 'Error', diff --git a/dependency_injector/catalog.py b/dependency_injector/catalog.py index 544cb8c2..f0abb854 100644 --- a/dependency_injector/catalog.py +++ b/dependency_injector/catalog.py @@ -6,13 +6,68 @@ from .errors import Error from .utils import is_provider from .utils import is_catalog +from .utils import ensure_is_catalog_bundle + + +class CatalogBundle(object): + """Bundle of catalog providers.""" + + catalog = None + """:type: AbstractCatalog""" + + __IS_CATALOG_BUNDLE__ = True + __slots__ = ('providers', '__dict__') + + def __init__(self, *providers): + """Initializer.""" + self.providers = dict((provider.bind.name, provider) + for provider in providers + if self._ensure_provider_is_bound(provider)) + self.__dict__.update(self.providers) + super(CatalogBundle, self).__init__() + + def get(self, name): + """Return provider with specified name or raises error.""" + try: + return self.providers[name] + except KeyError: + self._raise_undefined_provider_error(name) + + def has(self, name): + """Check if there is provider with certain name.""" + return name in self.providers + + def _ensure_provider_is_bound(self, provider): + """Check that provider is bound to the bundle's catalog.""" + if not provider.is_bound: + raise Error('Provider {0} is not bound to ' + 'any catalog'.format(provider)) + if provider is not self.catalog.get(provider.bind.name): + raise Error('{0} can contain providers from ' + 'catalog {0}'.format(self.__class__, self.catalog)) + return True + + def _raise_undefined_provider_error(self, name): + """Raise error for cases when there is no such provider in bundle.""" + raise Error('Provider "{0}" is not a part of {1}'.format(name, self)) + + def __getattr__(self, item): + """Raise an error on every attempt to get undefined provider.""" + if item.startswith('__') and item.endswith('__'): + return super(CatalogBundle, self).__getattr__(item) + self._raise_undefined_provider_error(item) + + def __repr__(self): + """Return string representation of bundle.""" + return ''.format( + self.catalog, ', '.join(six.iterkeys(self.providers))) class CatalogMetaClass(type): - """Providers catalog meta class.""" + """Catalog meta class.""" def __new__(mcs, class_name, bases, attributes): - """Meta class factory.""" + """Catalog class factory.""" cls_providers = dict((name, provider) for name, provider in six.iteritems(attributes) if is_provider(provider)) @@ -26,10 +81,28 @@ class CatalogMetaClass(type): providers.update(cls_providers) providers.update(inherited_providers) - attributes['cls_providers'] = cls_providers - attributes['inherited_providers'] = inherited_providers - attributes['providers'] = providers - return type.__new__(mcs, class_name, bases, attributes) + cls = type.__new__(mcs, class_name, bases, attributes) + + cls.cls_providers = cls_providers + cls.inherited_providers = inherited_providers + cls.providers = providers + + cls.Bundle = mcs.bundle_cls_factory(cls) + + for name, provider in six.iteritems(cls_providers): + if provider.is_bound: + raise Error('Provider {0} has been already bound to catalog' + '{1} as "{2}"'.format(provider, + provider.bind.catalog, + provider.bind.name)) + provider.bind = ProviderBinding(cls, name) + + return cls + + @classmethod + def bundle_cls_factory(mcs, cls): + """Create bundle class for catalog.""" + return type('{0}Bundle', (CatalogBundle,), dict(catalog=cls)) def __repr__(cls): """Return string representation of the catalog class.""" @@ -50,27 +123,23 @@ class AbstractCatalog(object): :type inherited_providers: dict[str, dependency_injector.Provider] :param inherited_providers: Dict of providers, that are inherited from parent catalogs + + :type Bundle: CatalogBundle + :param Bundle: Catalog's bundle class """ - providers = dict() + Bundle = CatalogBundle + cls_providers = dict() inherited_providers = dict() + providers = dict() __IS_CATALOG__ = True - def __new__(cls, *providers): - """Catalog constructor. - - Catalogs are declaratives entities that could not be instantiated. - Catalog constructor is designed to produce subsets of catalog - providers. - """ - return CatalogSubset(catalog=cls, providers=providers) - @classmethod - def is_subset_owner(cls, subset): - """Check if catalog is subset owner.""" - return subset.catalog is cls + def is_bundle_owner(cls, bundle): + """Check if catalog is bundle owner.""" + return ensure_is_catalog_bundle(bundle) and bundle.catalog is cls @classmethod def filter(cls, provider_type): @@ -103,52 +172,15 @@ class AbstractCatalog(object): return name in cls.providers -class CatalogSubset(object): - """Subset of catalog providers.""" +class ProviderBinding(object): + """Catalog provider binding.""" - __IS_SUBSET__ = True - __slots__ = ('catalog', 'available_providers', 'providers', '__dict__') + __slots__ = ('catalog', 'name') - def __init__(self, catalog, providers): + def __init__(self, catalog, name): """Initializer.""" self.catalog = catalog - self.available_providers = set(providers) - self.providers = dict() - for provider_name in self.available_providers: - try: - provider = self.catalog.providers[provider_name] - except KeyError: - raise Error('Subset could not add "{0}" provider in scope, ' - 'because {1} has no provider with ' - 'such name'.format(provider_name, self.catalog)) - else: - self.providers[provider_name] = provider - self.__dict__.update(self.providers) - super(CatalogSubset, self).__init__() - - def get(self, name): - """Return provider with specified name or raises error.""" - try: - return self.providers[name] - except KeyError: - self._raise_undefined_provider_error(name) - - def has(self, name): - """Check if there is provider with certain name.""" - return name in self.providers - - def __getattr__(self, item): - """Raise an error on every attempt to get undefined provider.""" - self._raise_undefined_provider_error(item) - - def __repr__(self): - """Return string representation of subset.""" - return ''.format( - ', '.join(self.available_providers), self.catalog) - - def _raise_undefined_provider_error(self, name): - """Raise error for cases when there is no such provider in subset.""" - raise Error('Provider "{0}" is not a part of {1}'.format(name, self)) + self.name = name def override(catalog): diff --git a/dependency_injector/injections.py b/dependency_injector/injections.py index bc0cf797..a7a29084 100644 --- a/dependency_injector/injections.py +++ b/dependency_injector/injections.py @@ -21,17 +21,18 @@ class Injection(object): """Base injection class.""" __IS_INJECTION__ = True - __slots__ = ('name', 'injectable') + __slots__ = ('name', 'injectable', 'is_provider') def __init__(self, name, injectable): """Initializer.""" self.name = name self.injectable = injectable + self.is_provider = is_provider(injectable) @property def value(self): """Return injectable value.""" - if is_provider(self.injectable): + if self.is_provider: return self.injectable() return self.injectable diff --git a/dependency_injector/providers.py b/dependency_injector/providers.py index 69351cbb..b510bd6b 100644 --- a/dependency_injector/providers.py +++ b/dependency_injector/providers.py @@ -18,11 +18,12 @@ class Provider(object): """Base provider class.""" __IS_PROVIDER__ = True - __slots__ = ('overridden_by',) + __slots__ = ('overridden_by', 'bind') def __init__(self): """Initializer.""" self.overridden_by = None + self.bind = None def __call__(self, *args, **kwargs): """Return provided instance.""" @@ -73,6 +74,11 @@ class Provider(object): """Reset all overriding providers.""" self.overridden_by = None + @property + def is_bound(self): + """Check if provider is bound to any catalog.""" + return bool(self.bind) + class Delegate(Provider): """Provider's delegate.""" diff --git a/dependency_injector/utils.py b/dependency_injector/utils.py index 426584cd..29547b9c 100644 --- a/dependency_injector/utils.py +++ b/dependency_injector/utils.py @@ -17,7 +17,10 @@ def is_provider(instance): def ensure_is_provider(instance): - """Check if instance is provider instance, otherwise raise and error.""" + """Check if instance is provider instance and return it. + + :raise: Error if provided instance is not provider. + """ if not is_provider(instance): raise Error('Expected provider instance, ' 'got {0}'.format(str(instance))) @@ -62,10 +65,21 @@ def is_catalog(instance): getattr(instance, '__IS_CATALOG__', False) is True) -def is_catalog_subset(instance): - """Check if instance is catalog subset instance.""" +def is_catalog_bundle(instance): + """Check if instance is catalog bundle instance.""" return (not isinstance(instance, six.class_types) and - getattr(instance, '__IS_SUBSET__', False) is True) + getattr(instance, '__IS_CATALOG_BUNDLE__', False) is True) + + +def ensure_is_catalog_bundle(instance): + """Check if instance is catalog bundle instance and return it. + + :raise: Error if provided instance is not catalog bundle. + """ + if not is_catalog_bundle(instance): + raise Error('Expected catalog bundle instance, ' + 'got {0}'.format(str(instance))) + return instance def get_injectable_kwargs(kwargs, injections): diff --git a/docs/catalogs/bundles.rst b/docs/catalogs/bundles.rst new file mode 100644 index 00000000..49e916e2 --- /dev/null +++ b/docs/catalogs/bundles.rst @@ -0,0 +1,40 @@ +Creating catalog provider bundles +--------------------------------- + +``di.AbstractCatalog.Bundle`` is a limited collection of catalog providers. +While catalog could be used as a centralized place for particular providers +group, such bundles of catalog providers can be used for creating several +limited scopes that could be passed to different subsystems. + +``di.AbstractCatalog.Bundle`` has exactly the same API as +``di.AbstractCatalog`` except of the limitations on getting providers. + +Each ``di.AbstractCatalog`` has a reference to its bundle class - +``di.AbstractCatalog.Bundle``. For example, if some concrete catalog has name +``SomeCatalog``, then its bundle class could be reached as +``SomeCatalog.Bundle``. + +``di.AbstractCatalog.Bundle`` expects to get the list of its catalog providers +as positional arguments and will limit the scope of created bundle to this +list. + +Example: + +.. image:: /images/catalogs/bundles.png + :width: 100% + :align: center + +Listing of `services.py`: + +.. literalinclude:: ../../examples/catalogs/bundles/services.py + :language: python + +Listing of `views.py`: + +.. literalinclude:: ../../examples/catalogs/bundles/views.py + :language: python + +Listing of `catalogs.py`: + +.. literalinclude:: ../../examples/catalogs/bundles/catalogs.py + :language: python diff --git a/docs/catalogs/index.rst b/docs/catalogs/index.rst index 3106ae5b..cfc3d579 100644 --- a/docs/catalogs/index.rst +++ b/docs/catalogs/index.rst @@ -21,5 +21,5 @@ of providers. writing operating - subsets + bundles overriding diff --git a/docs/catalogs/subsets.rst b/docs/catalogs/subsets.rst deleted file mode 100644 index be27061a..00000000 --- a/docs/catalogs/subsets.rst +++ /dev/null @@ -1,19 +0,0 @@ -Creating catalog subsets ------------------------- - -``di.AbstractCatalog`` subset is a limited collection of catalog providers. -While catalog could be used as a centralized place for particular providers -group, such subsets of catalog providers can be used for creating several -limited scopes that could be passed to different subsystems. - -``di.AbstractCatalog`` subsets could be created by instantiating of particular -catalog with passing provider names to the constructor. - -Example: - -.. image:: /images/catalogs/subsets.png - :width: 100% - :align: center - -.. literalinclude:: ../../examples/catalogs/subsets.py - :language: python diff --git a/docs/images/catalogs/bundles.png b/docs/images/catalogs/bundles.png new file mode 100644 index 00000000..14a20d22 Binary files /dev/null and b/docs/images/catalogs/bundles.png differ diff --git a/docs/images/catalogs/subsets.png b/docs/images/catalogs/subsets.png deleted file mode 100644 index f09c0c0d..00000000 Binary files a/docs/images/catalogs/subsets.png and /dev/null differ diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index b2dfed0e..20c321eb 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -12,12 +12,14 @@ Development version ------------------- - Add functionality for decorating classes with ``@di.inject``. -- Add functionality for creating ``di.AbstractCatalog`` subsets. +- Add functionality for creating ``di.AbstractCatalog`` provider bundles. - Add enhancement for ``di.AbstractCatalog`` inheritance. - Add images for catalog "Writing catalogs" and "Operating with catalogs" examples. - Add support of Python 3.5. - Add support of six 1.10.0. +- Add optimization for ``di.Injection.value`` property that will compute + type of injection once, instead of doing this on every call. - Add minor refactorings and code style fixes. 0.9.5 diff --git a/examples/catalogs/bundles/catalogs.py b/examples/catalogs/bundles/catalogs.py new file mode 100644 index 00000000..3d876c0e --- /dev/null +++ b/examples/catalogs/bundles/catalogs.py @@ -0,0 +1,59 @@ +"""Catalog bundles example.""" + +import dependency_injector as di + +import services +import views + + +# Declaring services catalog: +class Services(di.AbstractCatalog): + """Example catalog of service providers.""" + + users = di.Factory(services.UsersService) + """:type: (di.Provider) -> services.UsersService""" + + auth = di.Factory(services.AuthService) + """:type: (di.Provider) -> services.AuthService""" + + photos = di.Factory(services.PhotosService) + """:type: (di.Provider) -> services.PhotosService""" + + +# Declaring views catalog: +class Views(di.AbstractCatalog): + """Example catalog of web views.""" + + auth = di.Factory(views.AuthView, + services=Services.Bundle(Services.users, + Services.auth)) + """:type: (di.Provider) -> views.AuthView""" + + photos = di.Factory(views.PhotosView, + services=Services.Bundle(Services.users, + Services.photos)) + """:type: (di.Provider) -> views.PhotosView""" + + +# Creating example views: +auth_view = Views.auth() +photos_view = Views.photos() + +# Making some asserts: +assert auth_view.services.users is Services.users +assert auth_view.services.auth is Services.auth +try: + auth_view.services.photos +except di.Error: + # `photos` service provider is not in scope of `auth_view` services bundle, + # so `di.Error` will be raised. + pass + +assert photos_view.services.users is Services.users +assert photos_view.services.photos is Services.photos +try: + photos_view.services.auth +except di.Error as exception: + # `auth` service provider is not in scope of `photo_processing_view` + # services bundle, so `di.Error` will be raised. + pass diff --git a/examples/catalogs/bundles/services.py b/examples/catalogs/bundles/services.py new file mode 100644 index 00000000..91b9cae5 --- /dev/null +++ b/examples/catalogs/bundles/services.py @@ -0,0 +1,17 @@ +"""Example services.""" + + +class BaseService(object): + """Example base class of service.""" + + +class UsersService(BaseService): + """Example users service.""" + + +class AuthService(BaseService): + """Example auth service.""" + + +class PhotosService(BaseService): + """Example photo service.""" diff --git a/examples/catalogs/bundles/views.py b/examples/catalogs/bundles/views.py new file mode 100644 index 00000000..658ff9f0 --- /dev/null +++ b/examples/catalogs/bundles/views.py @@ -0,0 +1,21 @@ +"""Example web views.""" + + +class BaseWebView(object): + """Example base class of web view.""" + + def __init__(self, services): + """Initializer. + + :type services: Services + :param services: Bundle of service providers + """ + self.services = services + + +class AuthView(BaseWebView): + """Example auth web view.""" + + +class PhotosView(BaseWebView): + """Example photo processing web view.""" diff --git a/examples/catalogs/subsets.py b/examples/catalogs/subsets.py deleted file mode 100644 index 814af380..00000000 --- a/examples/catalogs/subsets.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Catalog subsets example.""" - -import dependency_injector as di - - -# Declaring example services catalog: -class Services(di.AbstractCatalog): - """Example catalog of service providers.""" - - users = di.Provider() - - auth = di.Provider() - - photos = di.Provider() - - -# Declaring example base class for some web views: -class BaseWebView(object): - """Example base class of web view.""" - - def __init__(self, services): - """Initializer. - - :type services: Services - :param services: Subset of service providers - """ - self.services = services - - -# Declaring several example web views: -class AuthView(BaseWebView): - """Example auth web view.""" - - -class PhotosView(BaseWebView): - """Example photo processing web view.""" - -# Creating example views with appropriate service provider subsets: -auth_view = AuthView(Services('users', 'auth')) -photos_view = PhotosView(Services('users', 'photos')) - -# Making some asserts: -assert auth_view.services.users is Services.users -assert auth_view.services.auth is Services.auth -try: - auth_view.services.photos -except di.Error: - # `photos` service provider is not in scope of `auth_view` services subset, - # so `di.Error` will be raised. - pass - -assert photos_view.services.users is Services.users -assert photos_view.services.photos is Services.photos -try: - photos_view.services.auth -except di.Error as exception: - # `auth` service provider is not in scope of `photo_processing_view` - # services subset, so `di.Error` will be raised. - pass diff --git a/tests/test_catalog.py b/tests/test_catalog.py index 26f1ec9b..f55cff9e 100644 --- a/tests/test_catalog.py +++ b/tests/test_catalog.py @@ -71,52 +71,98 @@ class CatalogsInheritanceTests(unittest.TestCase): p32=CatalogC.p32)) -class CatalogSubsetTests(unittest.TestCase): - """Catalog subset test cases.""" +class CatalogProvidersBindingTests(unittest.TestCase): + """Catalog providers binding test cases.""" - catalog = None + def test_provider_is_bound(self): + """Test that providers are bound to the catalogs.""" + self.assertIs(CatalogA.p11.bind.catalog, CatalogA) + self.assertEquals(CatalogA.p11.bind.name, 'p11') + + self.assertIs(CatalogA.p12.bind.catalog, CatalogA) + self.assertEquals(CatalogA.p12.bind.name, 'p12') + + def test_provider_rebinding(self): + """Test that provider could not be bound twice.""" + self.assertRaises(di.Error, type, 'TestCatalog', (di.AbstractCatalog,), + dict(some_name=CatalogA.p11)) + + +class CatalogBundleTests(unittest.TestCase): + """Catalog bundle test cases.""" def setUp(self): """Set test environment up.""" - self.subset = CatalogC('p11', 'p12') + self.bundle = CatalogC.Bundle(CatalogC.p11, + CatalogC.p12) - def test_get_attr_from_subset(self): - """Test get providers (attribute) from subset.""" - self.assertIs(self.subset.p11, CatalogC.p11) - self.assertIs(self.subset.p12, CatalogC.p12) + def test_get_attr_from_bundle(self): + """Test get providers (attribute) from catalog bundle.""" + self.assertIs(self.bundle.p11, CatalogC.p11) + self.assertIs(self.bundle.p12, CatalogC.p12) - def test_get_attr_not_from_subset(self): - """Test get providers (attribute) that are not in subset.""" - self.assertRaises(di.Error, getattr, self.subset, 'p21') - self.assertRaises(di.Error, getattr, self.subset, 'p22') - self.assertRaises(di.Error, getattr, self.subset, 'p31') - self.assertRaises(di.Error, getattr, self.subset, 'p32') + def test_get_attr_not_from_bundle(self): + """Test get providers (attribute) that are not in bundle.""" + self.assertRaises(di.Error, getattr, self.bundle, 'p21') + self.assertRaises(di.Error, getattr, self.bundle, 'p22') + self.assertRaises(di.Error, getattr, self.bundle, 'p31') + self.assertRaises(di.Error, getattr, self.bundle, 'p32') - def test_get_method_from_subset(self): - """Test get providers (get() method) from subset.""" - self.assertIs(self.subset.get('p11'), CatalogC.p11) - self.assertIs(self.subset.get('p12'), CatalogC.p12) + def test_get_method_from_bundle(self): + """Test get providers (get() method) from bundle.""" + self.assertIs(self.bundle.get('p11'), CatalogC.p11) + self.assertIs(self.bundle.get('p12'), CatalogC.p12) - def test_get_method_not_from_subset(self): - """Test get providers (get() method) that are not in subset.""" - self.assertRaises(di.Error, self.subset.get, 'p21') - self.assertRaises(di.Error, self.subset.get, 'p22') - self.assertRaises(di.Error, self.subset.get, 'p31') - self.assertRaises(di.Error, self.subset.get, 'p32') + def test_get_method_not_from_bundle(self): + """Test get providers (get() method) that are not in bundle.""" + self.assertRaises(di.Error, self.bundle.get, 'p21') + self.assertRaises(di.Error, self.bundle.get, 'p22') + self.assertRaises(di.Error, self.bundle.get, 'p31') + self.assertRaises(di.Error, self.bundle.get, 'p32') def test_has(self): - """Test checks of providers availability in subsets.""" - self.assertTrue(self.subset.has('p11')) - self.assertTrue(self.subset.has('p12')) + """Test checks of providers availability in bundle.""" + self.assertTrue(self.bundle.has('p11')) + self.assertTrue(self.bundle.has('p12')) - self.assertFalse(self.subset.has('p21')) - self.assertFalse(self.subset.has('p22')) - self.assertFalse(self.subset.has('p31')) - self.assertFalse(self.subset.has('p32')) + self.assertFalse(self.bundle.has('p21')) + self.assertFalse(self.bundle.has('p22')) + self.assertFalse(self.bundle.has('p31')) + self.assertFalse(self.bundle.has('p32')) - def test_creating_with_undefined_provider(self): - """Test subset creation with provider that is not in catalog.""" - self.assertRaises(di.Error, CatalogC, 'undefined_provider') + def test_create_bundle_with_unbound_provider(self): + """Test that bundle is not created with unbound provider.""" + self.assertRaises(di.Error, CatalogC.Bundle, di.Provider()) + + def test_create_bundle_with_another_catalog_provider(self): + """Test that bundle can not contain another catalog's provider.""" + class TestCatalog(di.AbstractCatalog): + """Test catalog.""" + + provider = di.Provider() + + self.assertRaises(di.Error, + CatalogC.Bundle, CatalogC.p31, TestCatalog.provider) + + def test_create_bundle_with_another_catalog_provider_with_same_name(self): + """Test that bundle can not contain another catalog's provider.""" + class TestCatalog(di.AbstractCatalog): + """Test catalog.""" + + p31 = di.Provider() + + self.assertRaises(di.Error, + CatalogC.Bundle, CatalogC.p31, TestCatalog.p31) + + def test_is_bundle_owner(self): + """Test that catalog bundle is owned by catalog.""" + self.assertTrue(CatalogC.is_bundle_owner(self.bundle)) + self.assertFalse(CatalogB.is_bundle_owner(self.bundle)) + self.assertFalse(CatalogA.is_bundle_owner(self.bundle)) + + def test_is_bundle_owner_with_not_bundle_instance(self): + """Test that check of bundle ownership raises error with not bundle.""" + self.assertRaises(di.Error, CatalogC.is_bundle_owner, object()) class CatalogTests(unittest.TestCase): @@ -136,7 +182,7 @@ class CatalogTests(unittest.TestCase): self.assertRaises(di.Error, CatalogC.get, 'undefined') def test_has(self): - """Test checks of providers availability in subsets.""" + """Test checks of providers availability in catalog.""" self.assertTrue(CatalogC.has('p11')) self.assertTrue(CatalogC.has('p12')) self.assertTrue(CatalogC.has('p21')) @@ -145,14 +191,6 @@ class CatalogTests(unittest.TestCase): self.assertTrue(CatalogC.has('p32')) self.assertFalse(CatalogC.has('undefined')) - def test_is_subset_owner(self): - """Test that catalog is subset owner.""" - subset = CatalogA() - - self.assertTrue(CatalogA.is_subset_owner(subset)) - self.assertFalse(CatalogB.is_subset_owner(subset)) - self.assertFalse(CatalogC.is_subset_owner(subset)) - def test_filter_all_providers_by_type(self): """Test getting of all catalog providers of specific type.""" self.assertTrue(len(CatalogC.filter(di.Provider)) == 6) diff --git a/tests/test_injections.py b/tests/test_injections.py index f84ad280..c96567b7 100644 --- a/tests/test_injections.py +++ b/tests/test_injections.py @@ -23,6 +23,17 @@ class InjectionTests(unittest.TestCase): injection = di.Injection('some_arg_name', di.Factory(object)) self.assertIsInstance(injection.value, object) + def test_value_with_catalog_bundle_injectable(self): + """Test Injection value property with catalog bundle.""" + class TestCatalog(di.AbstractCatalog): + """Test catalog.""" + + provider = di.Provider() + injection = di.Injection('some_arg_name', + TestCatalog.Bundle(TestCatalog.provider)) + + self.assertIsInstance(injection.value, TestCatalog.Bundle) + class KwArgTests(unittest.TestCase): """Keyword arg injection test cases.""" diff --git a/tests/test_utils.py b/tests/test_utils.py index 3d86b04d..61705ff1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -210,22 +210,43 @@ class IsCatalogTests(unittest.TestCase): self.assertFalse(di.is_catalog(object())) -class IsCatalogSubsetTests(unittest.TestCase): - """`is_catalog_subset()` test cases.""" +class IsCatalogBundleTests(unittest.TestCase): + """`is_catalog_bundle()` test cases.""" + + def test_with_instance(self): + """Test with instance.""" + self.assertTrue(di.is_catalog_bundle(di.CatalogBundle())) def test_with_cls(self): """Test with class.""" - self.assertFalse(di.is_catalog_subset(di.CatalogSubset)) - - def test_with_instance(self): - """Test with class.""" - self.assertTrue(di.is_catalog_subset( - di.CatalogSubset(catalog=di.AbstractCatalog, providers=tuple()))) + self.assertFalse(di.is_catalog_bundle(di.CatalogBundle)) def test_with_string(self): """Test with string.""" - self.assertFalse(di.is_catalog_subset('some_string')) + self.assertFalse(di.is_catalog_bundle('some_string')) def test_with_object(self): """Test with object.""" - self.assertFalse(di.is_catalog_subset(object())) + self.assertFalse(di.is_catalog_bundle(object())) + + +class EnsureIsCatalogBundleTests(unittest.TestCase): + """`ensure_is_catalog_bundle` test cases.""" + + def test_with_instance(self): + """Test with instance.""" + bundle = di.CatalogBundle() + self.assertIs(di.ensure_is_catalog_bundle(bundle), bundle) + + def test_with_class(self): + """Test with class.""" + self.assertRaises(di.Error, di.ensure_is_catalog_bundle, + di.CatalogBundle) + + def test_with_string(self): + """Test with string.""" + self.assertRaises(di.Error, di.ensure_is_catalog_bundle, 'some_string') + + def test_with_object(self): + """Test with object.""" + self.assertRaises(di.Error, di.ensure_is_catalog_bundle, object())