Merge remote-tracking branch 'origin/catalog_bundles'

This commit is contained in:
Roman Mogilatov 2015-10-19 12:14:55 +03:00
commit 2a5ed703bf
18 changed files with 391 additions and 205 deletions

View File

@ -1,7 +1,7 @@
"""Dependency injector.""" """Dependency injector."""
from .catalog import AbstractCatalog from .catalog import AbstractCatalog
from .catalog import CatalogSubset from .catalog import CatalogBundle
from .catalog import override from .catalog import override
from .providers import Provider from .providers import Provider
@ -31,7 +31,8 @@ from .utils import is_kwarg_injection
from .utils import is_attribute_injection from .utils import is_attribute_injection
from .utils import is_method_injection from .utils import is_method_injection
from .utils import is_catalog from .utils import is_catalog
from .utils import is_catalog_subset from .utils import is_catalog_bundle
from .utils import ensure_is_catalog_bundle
from .errors import Error from .errors import Error
@ -39,7 +40,7 @@ from .errors import Error
__all__ = ( __all__ = (
# Catalogs # Catalogs
'AbstractCatalog', 'AbstractCatalog',
'CatalogSubset', 'CatalogBundle',
'override', 'override',
# Providers # Providers
@ -72,7 +73,8 @@ __all__ = (
'is_attribute_injection', 'is_attribute_injection',
'is_method_injection', 'is_method_injection',
'is_catalog', 'is_catalog',
'is_catalog_subset', 'is_catalog_bundle',
'ensure_is_catalog_bundle',
# Errors # Errors
'Error', 'Error',

View File

@ -6,13 +6,68 @@ from .errors import Error
from .utils import is_provider from .utils import is_provider
from .utils import is_catalog from .utils import is_catalog
from .utils import ensure_is_catalog_bundle
class CatalogBundle(object):
"""Bundle of catalog providers."""
catalog = None
""":type: AbstractCatalog"""
__IS_CATALOG_BUNDLE__ = True
__slots__ = ('providers', '__dict__')
def __init__(self, *providers):
"""Initializer."""
self.providers = dict((provider.bind.name, provider)
for provider in providers
if self._ensure_provider_is_bound(provider))
self.__dict__.update(self.providers)
super(CatalogBundle, self).__init__()
def get(self, name):
"""Return provider with specified name or raises error."""
try:
return self.providers[name]
except KeyError:
self._raise_undefined_provider_error(name)
def has(self, name):
"""Check if there is provider with certain name."""
return name in self.providers
def _ensure_provider_is_bound(self, provider):
"""Check that provider is bound to the bundle's catalog."""
if not provider.is_bound:
raise Error('Provider {0} is not bound to '
'any catalog'.format(provider))
if provider is not self.catalog.get(provider.bind.name):
raise Error('{0} can contain providers from '
'catalog {0}'.format(self.__class__, self.catalog))
return True
def _raise_undefined_provider_error(self, name):
"""Raise error for cases when there is no such provider in bundle."""
raise Error('Provider "{0}" is not a part of {1}'.format(name, self))
def __getattr__(self, item):
"""Raise an error on every attempt to get undefined provider."""
if item.startswith('__') and item.endswith('__'):
return super(CatalogBundle, self).__getattr__(item)
self._raise_undefined_provider_error(item)
def __repr__(self):
"""Return string representation of bundle."""
return '<Bundle of {0} providers ({1})>'.format(
self.catalog, ', '.join(six.iterkeys(self.providers)))
class CatalogMetaClass(type): class CatalogMetaClass(type):
"""Providers catalog meta class.""" """Catalog meta class."""
def __new__(mcs, class_name, bases, attributes): def __new__(mcs, class_name, bases, attributes):
"""Meta class factory.""" """Catalog class factory."""
cls_providers = dict((name, provider) cls_providers = dict((name, provider)
for name, provider in six.iteritems(attributes) for name, provider in six.iteritems(attributes)
if is_provider(provider)) if is_provider(provider))
@ -26,10 +81,28 @@ class CatalogMetaClass(type):
providers.update(cls_providers) providers.update(cls_providers)
providers.update(inherited_providers) providers.update(inherited_providers)
attributes['cls_providers'] = cls_providers cls = type.__new__(mcs, class_name, bases, attributes)
attributes['inherited_providers'] = inherited_providers
attributes['providers'] = providers cls.cls_providers = cls_providers
return type.__new__(mcs, class_name, bases, attributes) cls.inherited_providers = inherited_providers
cls.providers = providers
cls.Bundle = mcs.bundle_cls_factory(cls)
for name, provider in six.iteritems(cls_providers):
if provider.is_bound:
raise Error('Provider {0} has been already bound to catalog'
'{1} as "{2}"'.format(provider,
provider.bind.catalog,
provider.bind.name))
provider.bind = ProviderBinding(cls, name)
return cls
@classmethod
def bundle_cls_factory(mcs, cls):
"""Create bundle class for catalog."""
return type('{0}Bundle', (CatalogBundle,), dict(catalog=cls))
def __repr__(cls): def __repr__(cls):
"""Return string representation of the catalog class.""" """Return string representation of the catalog class."""
@ -50,27 +123,23 @@ class AbstractCatalog(object):
:type inherited_providers: dict[str, dependency_injector.Provider] :type inherited_providers: dict[str, dependency_injector.Provider]
:param inherited_providers: Dict of providers, that are inherited from :param inherited_providers: Dict of providers, that are inherited from
parent catalogs parent catalogs
:type Bundle: CatalogBundle
:param Bundle: Catalog's bundle class
""" """
providers = dict() Bundle = CatalogBundle
cls_providers = dict() cls_providers = dict()
inherited_providers = dict() inherited_providers = dict()
providers = dict()
__IS_CATALOG__ = True __IS_CATALOG__ = True
def __new__(cls, *providers):
"""Catalog constructor.
Catalogs are declaratives entities that could not be instantiated.
Catalog constructor is designed to produce subsets of catalog
providers.
"""
return CatalogSubset(catalog=cls, providers=providers)
@classmethod @classmethod
def is_subset_owner(cls, subset): def is_bundle_owner(cls, bundle):
"""Check if catalog is subset owner.""" """Check if catalog is bundle owner."""
return subset.catalog is cls return ensure_is_catalog_bundle(bundle) and bundle.catalog is cls
@classmethod @classmethod
def filter(cls, provider_type): def filter(cls, provider_type):
@ -103,52 +172,15 @@ class AbstractCatalog(object):
return name in cls.providers return name in cls.providers
class CatalogSubset(object): class ProviderBinding(object):
"""Subset of catalog providers.""" """Catalog provider binding."""
__IS_SUBSET__ = True __slots__ = ('catalog', 'name')
__slots__ = ('catalog', 'available_providers', 'providers', '__dict__')
def __init__(self, catalog, providers): def __init__(self, catalog, name):
"""Initializer.""" """Initializer."""
self.catalog = catalog self.catalog = catalog
self.available_providers = set(providers) self.name = name
self.providers = dict()
for provider_name in self.available_providers:
try:
provider = self.catalog.providers[provider_name]
except KeyError:
raise Error('Subset could not add "{0}" provider in scope, '
'because {1} has no provider with '
'such name'.format(provider_name, self.catalog))
else:
self.providers[provider_name] = provider
self.__dict__.update(self.providers)
super(CatalogSubset, self).__init__()
def get(self, name):
"""Return provider with specified name or raises error."""
try:
return self.providers[name]
except KeyError:
self._raise_undefined_provider_error(name)
def has(self, name):
"""Check if there is provider with certain name."""
return name in self.providers
def __getattr__(self, item):
"""Raise an error on every attempt to get undefined provider."""
self._raise_undefined_provider_error(item)
def __repr__(self):
"""Return string representation of subset."""
return '<Subset ({0}), {1}>'.format(
', '.join(self.available_providers), self.catalog)
def _raise_undefined_provider_error(self, name):
"""Raise error for cases when there is no such provider in subset."""
raise Error('Provider "{0}" is not a part of {1}'.format(name, self))
def override(catalog): def override(catalog):

View File

@ -21,17 +21,18 @@ class Injection(object):
"""Base injection class.""" """Base injection class."""
__IS_INJECTION__ = True __IS_INJECTION__ = True
__slots__ = ('name', 'injectable') __slots__ = ('name', 'injectable', 'is_provider')
def __init__(self, name, injectable): def __init__(self, name, injectable):
"""Initializer.""" """Initializer."""
self.name = name self.name = name
self.injectable = injectable self.injectable = injectable
self.is_provider = is_provider(injectable)
@property @property
def value(self): def value(self):
"""Return injectable value.""" """Return injectable value."""
if is_provider(self.injectable): if self.is_provider:
return self.injectable() return self.injectable()
return self.injectable return self.injectable

View File

@ -18,11 +18,12 @@ class Provider(object):
"""Base provider class.""" """Base provider class."""
__IS_PROVIDER__ = True __IS_PROVIDER__ = True
__slots__ = ('overridden_by',) __slots__ = ('overridden_by', 'bind')
def __init__(self): def __init__(self):
"""Initializer.""" """Initializer."""
self.overridden_by = None self.overridden_by = None
self.bind = None
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
"""Return provided instance.""" """Return provided instance."""
@ -73,6 +74,11 @@ class Provider(object):
"""Reset all overriding providers.""" """Reset all overriding providers."""
self.overridden_by = None self.overridden_by = None
@property
def is_bound(self):
"""Check if provider is bound to any catalog."""
return bool(self.bind)
class Delegate(Provider): class Delegate(Provider):
"""Provider's delegate.""" """Provider's delegate."""

View File

@ -17,7 +17,10 @@ def is_provider(instance):
def ensure_is_provider(instance): def ensure_is_provider(instance):
"""Check if instance is provider instance, otherwise raise and error.""" """Check if instance is provider instance and return it.
:raise: Error if provided instance is not provider.
"""
if not is_provider(instance): if not is_provider(instance):
raise Error('Expected provider instance, ' raise Error('Expected provider instance, '
'got {0}'.format(str(instance))) 'got {0}'.format(str(instance)))
@ -62,10 +65,21 @@ def is_catalog(instance):
getattr(instance, '__IS_CATALOG__', False) is True) getattr(instance, '__IS_CATALOG__', False) is True)
def is_catalog_subset(instance): def is_catalog_bundle(instance):
"""Check if instance is catalog subset instance.""" """Check if instance is catalog bundle instance."""
return (not isinstance(instance, six.class_types) and return (not isinstance(instance, six.class_types) and
getattr(instance, '__IS_SUBSET__', False) is True) getattr(instance, '__IS_CATALOG_BUNDLE__', False) is True)
def ensure_is_catalog_bundle(instance):
"""Check if instance is catalog bundle instance and return it.
:raise: Error if provided instance is not catalog bundle.
"""
if not is_catalog_bundle(instance):
raise Error('Expected catalog bundle instance, '
'got {0}'.format(str(instance)))
return instance
def get_injectable_kwargs(kwargs, injections): def get_injectable_kwargs(kwargs, injections):

40
docs/catalogs/bundles.rst Normal file
View File

@ -0,0 +1,40 @@
Creating catalog provider bundles
---------------------------------
``di.AbstractCatalog.Bundle`` is a limited collection of catalog providers.
While catalog could be used as a centralized place for particular providers
group, such bundles of catalog providers can be used for creating several
limited scopes that could be passed to different subsystems.
``di.AbstractCatalog.Bundle`` has exactly the same API as
``di.AbstractCatalog`` except of the limitations on getting providers.
Each ``di.AbstractCatalog`` has a reference to its bundle class -
``di.AbstractCatalog.Bundle``. For example, if some concrete catalog has name
``SomeCatalog``, then its bundle class could be reached as
``SomeCatalog.Bundle``.
``di.AbstractCatalog.Bundle`` expects to get the list of its catalog providers
as positional arguments and will limit the scope of created bundle to this
list.
Example:
.. image:: /images/catalogs/bundles.png
:width: 100%
:align: center
Listing of `services.py`:
.. literalinclude:: ../../examples/catalogs/bundles/services.py
:language: python
Listing of `views.py`:
.. literalinclude:: ../../examples/catalogs/bundles/views.py
:language: python
Listing of `catalogs.py`:
.. literalinclude:: ../../examples/catalogs/bundles/catalogs.py
:language: python

View File

@ -21,5 +21,5 @@ of providers.
writing writing
operating operating
subsets bundles
overriding overriding

View File

@ -1,19 +0,0 @@
Creating catalog subsets
------------------------
``di.AbstractCatalog`` subset is a limited collection of catalog providers.
While catalog could be used as a centralized place for particular providers
group, such subsets of catalog providers can be used for creating several
limited scopes that could be passed to different subsystems.
``di.AbstractCatalog`` subsets could be created by instantiating of particular
catalog with passing provider names to the constructor.
Example:
.. image:: /images/catalogs/subsets.png
:width: 100%
:align: center
.. literalinclude:: ../../examples/catalogs/subsets.py
:language: python

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 96 KiB

View File

@ -12,12 +12,14 @@ Development version
------------------- -------------------
- Add functionality for decorating classes with ``@di.inject``. - Add functionality for decorating classes with ``@di.inject``.
- Add functionality for creating ``di.AbstractCatalog`` subsets. - Add functionality for creating ``di.AbstractCatalog`` provider bundles.
- Add enhancement for ``di.AbstractCatalog`` inheritance. - Add enhancement for ``di.AbstractCatalog`` inheritance.
- Add images for catalog "Writing catalogs" and "Operating with catalogs" - Add images for catalog "Writing catalogs" and "Operating with catalogs"
examples. examples.
- Add support of Python 3.5. - Add support of Python 3.5.
- Add support of six 1.10.0. - Add support of six 1.10.0.
- Add optimization for ``di.Injection.value`` property that will compute
type of injection once, instead of doing this on every call.
- Add minor refactorings and code style fixes. - Add minor refactorings and code style fixes.
0.9.5 0.9.5

View File

@ -0,0 +1,59 @@
"""Catalog bundles example."""
import dependency_injector as di
import services
import views
# Declaring services catalog:
class Services(di.AbstractCatalog):
"""Example catalog of service providers."""
users = di.Factory(services.UsersService)
""":type: (di.Provider) -> services.UsersService"""
auth = di.Factory(services.AuthService)
""":type: (di.Provider) -> services.AuthService"""
photos = di.Factory(services.PhotosService)
""":type: (di.Provider) -> services.PhotosService"""
# Declaring views catalog:
class Views(di.AbstractCatalog):
"""Example catalog of web views."""
auth = di.Factory(views.AuthView,
services=Services.Bundle(Services.users,
Services.auth))
""":type: (di.Provider) -> views.AuthView"""
photos = di.Factory(views.PhotosView,
services=Services.Bundle(Services.users,
Services.photos))
""":type: (di.Provider) -> views.PhotosView"""
# Creating example views:
auth_view = Views.auth()
photos_view = Views.photos()
# Making some asserts:
assert auth_view.services.users is Services.users
assert auth_view.services.auth is Services.auth
try:
auth_view.services.photos
except di.Error:
# `photos` service provider is not in scope of `auth_view` services bundle,
# so `di.Error` will be raised.
pass
assert photos_view.services.users is Services.users
assert photos_view.services.photos is Services.photos
try:
photos_view.services.auth
except di.Error as exception:
# `auth` service provider is not in scope of `photo_processing_view`
# services bundle, so `di.Error` will be raised.
pass

View File

@ -0,0 +1,17 @@
"""Example services."""
class BaseService(object):
"""Example base class of service."""
class UsersService(BaseService):
"""Example users service."""
class AuthService(BaseService):
"""Example auth service."""
class PhotosService(BaseService):
"""Example photo service."""

View File

@ -0,0 +1,21 @@
"""Example web views."""
class BaseWebView(object):
"""Example base class of web view."""
def __init__(self, services):
"""Initializer.
:type services: Services
:param services: Bundle of service providers
"""
self.services = services
class AuthView(BaseWebView):
"""Example auth web view."""
class PhotosView(BaseWebView):
"""Example photo processing web view."""

View File

@ -1,59 +0,0 @@
"""Catalog subsets example."""
import dependency_injector as di
# Declaring example services catalog:
class Services(di.AbstractCatalog):
"""Example catalog of service providers."""
users = di.Provider()
auth = di.Provider()
photos = di.Provider()
# Declaring example base class for some web views:
class BaseWebView(object):
"""Example base class of web view."""
def __init__(self, services):
"""Initializer.
:type services: Services
:param services: Subset of service providers
"""
self.services = services
# Declaring several example web views:
class AuthView(BaseWebView):
"""Example auth web view."""
class PhotosView(BaseWebView):
"""Example photo processing web view."""
# Creating example views with appropriate service provider subsets:
auth_view = AuthView(Services('users', 'auth'))
photos_view = PhotosView(Services('users', 'photos'))
# Making some asserts:
assert auth_view.services.users is Services.users
assert auth_view.services.auth is Services.auth
try:
auth_view.services.photos
except di.Error:
# `photos` service provider is not in scope of `auth_view` services subset,
# so `di.Error` will be raised.
pass
assert photos_view.services.users is Services.users
assert photos_view.services.photos is Services.photos
try:
photos_view.services.auth
except di.Error as exception:
# `auth` service provider is not in scope of `photo_processing_view`
# services subset, so `di.Error` will be raised.
pass

View File

@ -71,52 +71,98 @@ class CatalogsInheritanceTests(unittest.TestCase):
p32=CatalogC.p32)) p32=CatalogC.p32))
class CatalogSubsetTests(unittest.TestCase): class CatalogProvidersBindingTests(unittest.TestCase):
"""Catalog subset test cases.""" """Catalog providers binding test cases."""
catalog = None def test_provider_is_bound(self):
"""Test that providers are bound to the catalogs."""
self.assertIs(CatalogA.p11.bind.catalog, CatalogA)
self.assertEquals(CatalogA.p11.bind.name, 'p11')
self.assertIs(CatalogA.p12.bind.catalog, CatalogA)
self.assertEquals(CatalogA.p12.bind.name, 'p12')
def test_provider_rebinding(self):
"""Test that provider could not be bound twice."""
self.assertRaises(di.Error, type, 'TestCatalog', (di.AbstractCatalog,),
dict(some_name=CatalogA.p11))
class CatalogBundleTests(unittest.TestCase):
"""Catalog bundle test cases."""
def setUp(self): def setUp(self):
"""Set test environment up.""" """Set test environment up."""
self.subset = CatalogC('p11', 'p12') self.bundle = CatalogC.Bundle(CatalogC.p11,
CatalogC.p12)
def test_get_attr_from_subset(self): def test_get_attr_from_bundle(self):
"""Test get providers (attribute) from subset.""" """Test get providers (attribute) from catalog bundle."""
self.assertIs(self.subset.p11, CatalogC.p11) self.assertIs(self.bundle.p11, CatalogC.p11)
self.assertIs(self.subset.p12, CatalogC.p12) self.assertIs(self.bundle.p12, CatalogC.p12)
def test_get_attr_not_from_subset(self): def test_get_attr_not_from_bundle(self):
"""Test get providers (attribute) that are not in subset.""" """Test get providers (attribute) that are not in bundle."""
self.assertRaises(di.Error, getattr, self.subset, 'p21') self.assertRaises(di.Error, getattr, self.bundle, 'p21')
self.assertRaises(di.Error, getattr, self.subset, 'p22') self.assertRaises(di.Error, getattr, self.bundle, 'p22')
self.assertRaises(di.Error, getattr, self.subset, 'p31') self.assertRaises(di.Error, getattr, self.bundle, 'p31')
self.assertRaises(di.Error, getattr, self.subset, 'p32') self.assertRaises(di.Error, getattr, self.bundle, 'p32')
def test_get_method_from_subset(self): def test_get_method_from_bundle(self):
"""Test get providers (get() method) from subset.""" """Test get providers (get() method) from bundle."""
self.assertIs(self.subset.get('p11'), CatalogC.p11) self.assertIs(self.bundle.get('p11'), CatalogC.p11)
self.assertIs(self.subset.get('p12'), CatalogC.p12) self.assertIs(self.bundle.get('p12'), CatalogC.p12)
def test_get_method_not_from_subset(self): def test_get_method_not_from_bundle(self):
"""Test get providers (get() method) that are not in subset.""" """Test get providers (get() method) that are not in bundle."""
self.assertRaises(di.Error, self.subset.get, 'p21') self.assertRaises(di.Error, self.bundle.get, 'p21')
self.assertRaises(di.Error, self.subset.get, 'p22') self.assertRaises(di.Error, self.bundle.get, 'p22')
self.assertRaises(di.Error, self.subset.get, 'p31') self.assertRaises(di.Error, self.bundle.get, 'p31')
self.assertRaises(di.Error, self.subset.get, 'p32') self.assertRaises(di.Error, self.bundle.get, 'p32')
def test_has(self): def test_has(self):
"""Test checks of providers availability in subsets.""" """Test checks of providers availability in bundle."""
self.assertTrue(self.subset.has('p11')) self.assertTrue(self.bundle.has('p11'))
self.assertTrue(self.subset.has('p12')) self.assertTrue(self.bundle.has('p12'))
self.assertFalse(self.subset.has('p21')) self.assertFalse(self.bundle.has('p21'))
self.assertFalse(self.subset.has('p22')) self.assertFalse(self.bundle.has('p22'))
self.assertFalse(self.subset.has('p31')) self.assertFalse(self.bundle.has('p31'))
self.assertFalse(self.subset.has('p32')) self.assertFalse(self.bundle.has('p32'))
def test_creating_with_undefined_provider(self): def test_create_bundle_with_unbound_provider(self):
"""Test subset creation with provider that is not in catalog.""" """Test that bundle is not created with unbound provider."""
self.assertRaises(di.Error, CatalogC, 'undefined_provider') self.assertRaises(di.Error, CatalogC.Bundle, di.Provider())
def test_create_bundle_with_another_catalog_provider(self):
"""Test that bundle can not contain another catalog's provider."""
class TestCatalog(di.AbstractCatalog):
"""Test catalog."""
provider = di.Provider()
self.assertRaises(di.Error,
CatalogC.Bundle, CatalogC.p31, TestCatalog.provider)
def test_create_bundle_with_another_catalog_provider_with_same_name(self):
"""Test that bundle can not contain another catalog's provider."""
class TestCatalog(di.AbstractCatalog):
"""Test catalog."""
p31 = di.Provider()
self.assertRaises(di.Error,
CatalogC.Bundle, CatalogC.p31, TestCatalog.p31)
def test_is_bundle_owner(self):
"""Test that catalog bundle is owned by catalog."""
self.assertTrue(CatalogC.is_bundle_owner(self.bundle))
self.assertFalse(CatalogB.is_bundle_owner(self.bundle))
self.assertFalse(CatalogA.is_bundle_owner(self.bundle))
def test_is_bundle_owner_with_not_bundle_instance(self):
"""Test that check of bundle ownership raises error with not bundle."""
self.assertRaises(di.Error, CatalogC.is_bundle_owner, object())
class CatalogTests(unittest.TestCase): class CatalogTests(unittest.TestCase):
@ -136,7 +182,7 @@ class CatalogTests(unittest.TestCase):
self.assertRaises(di.Error, CatalogC.get, 'undefined') self.assertRaises(di.Error, CatalogC.get, 'undefined')
def test_has(self): def test_has(self):
"""Test checks of providers availability in subsets.""" """Test checks of providers availability in catalog."""
self.assertTrue(CatalogC.has('p11')) self.assertTrue(CatalogC.has('p11'))
self.assertTrue(CatalogC.has('p12')) self.assertTrue(CatalogC.has('p12'))
self.assertTrue(CatalogC.has('p21')) self.assertTrue(CatalogC.has('p21'))
@ -145,14 +191,6 @@ class CatalogTests(unittest.TestCase):
self.assertTrue(CatalogC.has('p32')) self.assertTrue(CatalogC.has('p32'))
self.assertFalse(CatalogC.has('undefined')) self.assertFalse(CatalogC.has('undefined'))
def test_is_subset_owner(self):
"""Test that catalog is subset owner."""
subset = CatalogA()
self.assertTrue(CatalogA.is_subset_owner(subset))
self.assertFalse(CatalogB.is_subset_owner(subset))
self.assertFalse(CatalogC.is_subset_owner(subset))
def test_filter_all_providers_by_type(self): def test_filter_all_providers_by_type(self):
"""Test getting of all catalog providers of specific type.""" """Test getting of all catalog providers of specific type."""
self.assertTrue(len(CatalogC.filter(di.Provider)) == 6) self.assertTrue(len(CatalogC.filter(di.Provider)) == 6)

View File

@ -23,6 +23,17 @@ class InjectionTests(unittest.TestCase):
injection = di.Injection('some_arg_name', di.Factory(object)) injection = di.Injection('some_arg_name', di.Factory(object))
self.assertIsInstance(injection.value, object) self.assertIsInstance(injection.value, object)
def test_value_with_catalog_bundle_injectable(self):
"""Test Injection value property with catalog bundle."""
class TestCatalog(di.AbstractCatalog):
"""Test catalog."""
provider = di.Provider()
injection = di.Injection('some_arg_name',
TestCatalog.Bundle(TestCatalog.provider))
self.assertIsInstance(injection.value, TestCatalog.Bundle)
class KwArgTests(unittest.TestCase): class KwArgTests(unittest.TestCase):
"""Keyword arg injection test cases.""" """Keyword arg injection test cases."""

View File

@ -210,22 +210,43 @@ class IsCatalogTests(unittest.TestCase):
self.assertFalse(di.is_catalog(object())) self.assertFalse(di.is_catalog(object()))
class IsCatalogSubsetTests(unittest.TestCase): class IsCatalogBundleTests(unittest.TestCase):
"""`is_catalog_subset()` test cases.""" """`is_catalog_bundle()` test cases."""
def test_with_instance(self):
"""Test with instance."""
self.assertTrue(di.is_catalog_bundle(di.CatalogBundle()))
def test_with_cls(self): def test_with_cls(self):
"""Test with class.""" """Test with class."""
self.assertFalse(di.is_catalog_subset(di.CatalogSubset)) self.assertFalse(di.is_catalog_bundle(di.CatalogBundle))
def test_with_instance(self):
"""Test with class."""
self.assertTrue(di.is_catalog_subset(
di.CatalogSubset(catalog=di.AbstractCatalog, providers=tuple())))
def test_with_string(self): def test_with_string(self):
"""Test with string.""" """Test with string."""
self.assertFalse(di.is_catalog_subset('some_string')) self.assertFalse(di.is_catalog_bundle('some_string'))
def test_with_object(self): def test_with_object(self):
"""Test with object.""" """Test with object."""
self.assertFalse(di.is_catalog_subset(object())) self.assertFalse(di.is_catalog_bundle(object()))
class EnsureIsCatalogBundleTests(unittest.TestCase):
"""`ensure_is_catalog_bundle` test cases."""
def test_with_instance(self):
"""Test with instance."""
bundle = di.CatalogBundle()
self.assertIs(di.ensure_is_catalog_bundle(bundle), bundle)
def test_with_class(self):
"""Test with class."""
self.assertRaises(di.Error, di.ensure_is_catalog_bundle,
di.CatalogBundle)
def test_with_string(self):
"""Test with string."""
self.assertRaises(di.Error, di.ensure_is_catalog_bundle, 'some_string')
def test_with_object(self):
"""Test with object."""
self.assertRaises(di.Error, di.ensure_is_catalog_bundle, object())