Merge remote-tracking branch 'origin/dynamic_catalogs'

This commit is contained in:
Roman Mogilatov 2015-11-13 11:33:15 +02:00
commit 176ff06778
8 changed files with 1052 additions and 605 deletions

View File

@ -1,9 +1,10 @@
"""Dependency injector."""
from .catalog import DeclarativeCatalog
from .catalog import AbstractCatalog
from .catalog import CatalogBundle
from .catalog import override
from .catalogs import DeclarativeCatalog
from .catalogs import AbstractCatalog
from .catalogs import DynamicCatalog
from .catalogs import CatalogBundle
from .catalogs import override
from .providers import Provider
from .providers import Delegate
@ -34,11 +35,17 @@ from .utils import is_kwarg_injection
from .utils import is_attribute_injection
from .utils import is_method_injection
from .utils import is_catalog
from .utils import is_dynamic_catalog
from .utils import is_declarative_catalog
from .utils import is_catalog_bundle
from .utils import ensure_is_catalog_bundle
from .errors import Error
from .errors import UndefinedProviderError
# Backward compatibility for versions < 0.11.*
from . import catalogs
catalog = catalogs
VERSION = '0.10.5'
@ -47,6 +54,7 @@ __all__ = (
# Catalogs
'DeclarativeCatalog',
'AbstractCatalog',
'DynamicCatalog',
'CatalogBundle',
'override',
@ -82,11 +90,14 @@ __all__ = (
'is_attribute_injection',
'is_method_injection',
'is_catalog',
'is_dynamic_catalog',
'is_declarative_catalog',
'is_catalog_bundle',
'ensure_is_catalog_bundle',
# Errors
'Error',
'UndefinedProviderError',
# Version
'VERSION'

View File

@ -1,242 +0,0 @@
"""Catalog module."""
import six
from .errors import Error
from .utils import is_provider
from .utils import is_catalog
from .utils import ensure_is_catalog_bundle
@six.python_2_unicode_compatible
class CatalogBundle(object):
"""Bundle of catalog providers."""
catalog = None
""":type: DeclarativeCatalog"""
__IS_CATALOG_BUNDLE__ = True
__slots__ = ('providers', '__dict__')
def __init__(self, *providers):
"""Initializer."""
self.providers = dict((self.catalog.get_provider_bind_name(provider),
provider)
for provider in providers)
self.__dict__.update(self.providers)
super(CatalogBundle, self).__init__()
def get(self, name):
"""Return provider with specified name or raise an error."""
try:
return self.providers[name]
except KeyError:
self._raise_undefined_provider_error(name)
def has(self, name):
"""Check if there is provider with certain name."""
return name in self.providers
def _raise_undefined_provider_error(self, name):
"""Raise error for cases when there is no such provider in bundle."""
raise Error('Provider "{0}" is not a part of {1}'.format(name, self))
def __getattr__(self, item):
"""Raise an error on every attempt to get undefined provider."""
if item.startswith('__') and item.endswith('__'):
return super(CatalogBundle, self).__getattr__(item)
self._raise_undefined_provider_error(item)
def __repr__(self):
"""Return string representation of catalog bundle."""
return '<{0}.{1}.Bundle({2})>'.format(
self.catalog.__module__, self.catalog.__name__,
', '.join(six.iterkeys(self.providers)))
__str__ = __repr__
@six.python_2_unicode_compatible
class DeclarativeCatalogMetaClass(type):
"""Declarative catalog meta class."""
def __new__(mcs, class_name, bases, attributes):
"""Declarative catalog class factory."""
cls = type.__new__(mcs, class_name, bases, attributes)
cls.Bundle = mcs.bundle_cls_factory(cls)
cls.overridden_by = tuple()
cls_providers = tuple((name, provider)
for name, provider in six.iteritems(attributes)
if is_provider(provider))
inherited_providers = tuple((name, provider)
for base in bases if is_catalog(base)
for name, provider in six.iteritems(
base.providers))
providers = cls_providers + inherited_providers
cls.cls_providers = dict(cls_providers)
cls.inherited_providers = dict(inherited_providers)
cls.providers = dict(providers)
cls.provider_names = dict()
for name, provider in providers:
if provider in cls.provider_names:
raise Error('Provider {0} could not be bound to the same '
'catalog (or catalogs hierarchy) more '
'than once'.format(provider))
cls.provider_names[provider] = name
return cls
@classmethod
def bundle_cls_factory(mcs, cls):
"""Create bundle class for catalog."""
return type('{0}Bundle', (CatalogBundle,), dict(catalog=cls))
@property
def is_overridden(cls):
"""Check if catalog is overridden by another catalog."""
return bool(cls.overridden_by)
@property
def last_overriding(cls):
"""Return last overriding catalog."""
try:
return cls.overridden_by[-1]
except (TypeError, IndexError):
raise Error('Catalog {0} is not overridden'.format(str(cls)))
def __repr__(cls):
"""Return string representation of the catalog class."""
return '<{0}.{1}>'.format(cls.__module__, cls.__name__)
__str__ = __repr__
@six.add_metaclass(DeclarativeCatalogMetaClass)
class DeclarativeCatalog(object):
"""Declarative catalog catalog of providers.
:type Bundle: CatalogBundle
:param Bundle: Catalog's bundle class
:type providers: dict[str, dependency_injector.Provider]
:param providers: Dict of all catalog providers, including inherited from
parent catalogs
:type provider_names: dict[dependency_injector.Provider, str]
:param provider_names: Dict of all catalog providers, including inherited
from parent catalogs
:type cls_providers: dict[str, dependency_injector.Provider]
:param cls_providers: Dict of current catalog providers
:type inherited_providers: dict[str, dependency_injector.Provider]
:param inherited_providers: Dict of providers, that are inherited from
parent catalogs
:type overridden_by: tuple[DeclarativeCatalog]
:param overridden_by: Tuple of overriding catalogs
:type is_overridden: bool
:param is_overridden: Read-only, evaluated in runtime, property that is
set to True if catalog is overridden
:type last_overriding: DeclarativeCatalog | None
:param last_overriding: Reference to the last overriding catalog, if any
"""
Bundle = CatalogBundle
cls_providers = dict()
inherited_providers = dict()
providers = dict()
provider_names = dict()
overridden_by = tuple()
is_overridden = bool
last_overriding = None
__IS_CATALOG__ = True
@classmethod
def is_bundle_owner(cls, bundle):
"""Check if catalog is bundle owner."""
return ensure_is_catalog_bundle(bundle) and bundle.catalog is cls
@classmethod
def get_provider_bind_name(cls, provider):
"""Return provider's name in catalog."""
if not cls.is_provider_bound(provider):
raise Error('Can not find bind name for {0} in catalog {1}'.format(
provider, cls))
return cls.provider_names[provider]
@classmethod
def is_provider_bound(cls, provider):
"""Check if provider is bound to the catalog."""
return provider in cls.provider_names
@classmethod
def filter(cls, provider_type):
"""Return dict of providers, that are instance of provided type."""
return dict((name, provider)
for name, provider in six.iteritems(cls.providers)
if isinstance(provider, provider_type))
@classmethod
def override(cls, overriding):
"""Override current catalog providers by overriding catalog providers.
:type overriding: DeclarativeCatalog
"""
cls.overridden_by += (overriding,)
for name, provider in six.iteritems(overriding.cls_providers):
cls.providers[name].override(provider)
@classmethod
def reset_last_overriding(cls):
"""Reset last overriding catalog."""
if not cls.is_overridden:
raise Error('Catalog {0} is not overridden'.format(str(cls)))
cls.overridden_by = cls.overridden_by[:-1]
for provider in six.itervalues(cls.providers):
provider.reset_last_overriding()
@classmethod
def reset_override(cls):
"""Reset all overridings for all catalog providers."""
cls.overridden_by = tuple()
for provider in six.itervalues(cls.providers):
provider.reset_override()
@classmethod
def get(cls, name):
"""Return provider with specified name or raises error."""
try:
return cls.providers[name]
except KeyError:
raise Error('{0} has no provider with such name - {1}'.format(
cls, name))
@classmethod
def has(cls, name):
"""Check if there is provider with certain name."""
return name in cls.providers
# Backward compatibility for versions < 0.11.*
AbstractCatalog = DeclarativeCatalog
def override(catalog):
"""Catalog overriding decorator."""
def decorator(overriding_catalog):
"""Overriding decorator."""
catalog.override(overriding_catalog)
return overriding_catalog
return decorator

View File

@ -0,0 +1,434 @@
"""Catalogs module."""
import six
from .errors import Error
from .errors import UndefinedProviderError
from .utils import is_provider
from .utils import is_catalog
from .utils import ensure_is_provider
from .utils import ensure_is_catalog_bundle
@six.python_2_unicode_compatible
class CatalogBundle(object):
"""Bundle of catalog providers."""
catalog = None
""":type: DeclarativeCatalog"""
__IS_CATALOG_BUNDLE__ = True
__slots__ = ('providers', '__dict__')
def __init__(self, *providers):
"""Initializer."""
self.providers = dict()
for provider in providers:
provider_name = self.catalog.get_provider_bind_name(provider)
self.providers[provider_name] = provider
self.__dict__.update(self.providers)
super(CatalogBundle, self).__init__()
@classmethod
def sub_cls_factory(cls, catalog):
"""Create bundle class for catalog.
:rtype: CatalogBundle
:return: Subclass of CatalogBundle
"""
return type('BundleSubclass', (cls,), dict(catalog=catalog))
def get_provider(self, name):
"""Return provider with specified name or raise an error."""
try:
return self.providers[name]
except KeyError:
raise Error('Provider "{0}" is not a part of {1}'.format(name,
self))
def has_provider(self, name):
"""Check if there is provider with certain name."""
return name in self.providers
def __getattr__(self, item):
"""Raise an error on every attempt to get undefined provider."""
if item.startswith('__') and item.endswith('__'):
return super(CatalogBundle, self).__getattr__(item)
raise UndefinedProviderError('Provider "{0}" is not a part '
'of {1}'.format(item, self))
def __repr__(self):
"""Return string representation of catalog bundle."""
return '<{0}.Bundle({1})>'.format(
self.catalog.name, ', '.join(six.iterkeys(self.providers)))
__str__ = __repr__
@six.python_2_unicode_compatible
class DynamicCatalog(object):
"""Catalog of providers."""
__IS_CATALOG__ = True
__slots__ = ('name', 'providers', 'provider_names', 'overridden_by',
'Bundle')
def __init__(self, **providers):
"""Initializer.
:type providers: dict[str, dependency_injector.providers.Provider]
"""
self.name = '.'.join((self.__class__.__module__,
self.__class__.__name__))
self.providers = dict()
self.provider_names = dict()
self.overridden_by = tuple()
self.Bundle = CatalogBundle.sub_cls_factory(self)
self.bind_providers(providers)
super(DynamicCatalog, self).__init__()
def is_bundle_owner(self, bundle):
"""Check if catalog is bundle owner."""
return ensure_is_catalog_bundle(bundle) and bundle.catalog is self
def get_provider_bind_name(self, provider):
"""Return provider's name in catalog."""
if not self.is_provider_bound(provider):
raise Error('Can not find bind name for {0} in catalog {1}'.format(
provider, self))
return self.provider_names[provider]
def is_provider_bound(self, provider):
"""Check if provider is bound to the catalog."""
return provider in self.provider_names
def filter(self, provider_type):
"""Return dict of providers, that are instance of provided type."""
return dict((name, provider)
for name, provider in six.iteritems(self.providers)
if isinstance(provider, provider_type))
@property
def is_overridden(self):
"""Check if catalog is overridden by another catalog."""
return bool(self.overridden_by)
@property
def last_overriding(self):
"""Return last overriding catalog."""
try:
return self.overridden_by[-1]
except (TypeError, IndexError):
raise Error('Catalog {0} is not overridden'.format(self))
def override(self, overriding):
"""Override current catalog providers by overriding catalog providers.
:type overriding: DynamicCatalog
"""
self.overridden_by += (overriding,)
for name, provider in six.iteritems(overriding.providers):
self.get_provider(name).override(provider)
def reset_last_overriding(self):
"""Reset last overriding catalog."""
if not self.is_overridden:
raise Error('Catalog {0} is not overridden'.format(self))
self.overridden_by = self.overridden_by[:-1]
for provider in six.itervalues(self.providers):
provider.reset_last_overriding()
def reset_override(self):
"""Reset all overridings for all catalog providers."""
self.overridden_by = tuple()
for provider in six.itervalues(self.providers):
provider.reset_override()
def get_provider(self, name):
"""Return provider with specified name or raise an error."""
try:
return self.providers[name]
except KeyError:
raise UndefinedProviderError('{0} has no provider with such '
'name - {1}'.format(self, name))
def bind_provider(self, name, provider):
"""Bind provider to catalog with specified name."""
provider = ensure_is_provider(provider)
if name in self.providers:
raise Error('Catalog {0} already has provider with '
'such name - {1}'.format(self, name))
if provider in self.provider_names:
raise Error('Catalog {0} already has such provider '
'instance - {1}'.format(self, provider))
self.providers[name] = provider
self.provider_names[provider] = name
def bind_providers(self, providers):
"""Bind providers dictionary to catalog."""
for name, provider in six.iteritems(providers):
self.bind_provider(name, provider)
def has_provider(self, name):
"""Check if there is provider with certain name."""
return name in self.providers
def unbind_provider(self, name):
"""Remove provider binding."""
provider = self.get_provider(name)
del self.providers[name]
del self.provider_names[provider]
def __getattr__(self, name):
"""Return provider with specified name or raise en error."""
return self.get_provider(name)
def __setattr__(self, name, value):
"""Handle setting of catalog attributes.
Setting of attributes works as usual, but if value of attribute is
provider, this provider will be bound to catalog correctly.
"""
if is_provider(value):
return self.bind_provider(name, value)
return super(DynamicCatalog, self).__setattr__(name, value)
def __delattr__(self, name):
"""Handle deleting of catalog attibute.
Deleting of attributes works as usual, but if value of attribute is
provider, this provider will be unbound from catalog correctly.
"""
self.unbind_provider(name)
def __repr__(self):
"""Return Python representation of catalog."""
return '<{0}({1})>'.format(self.name,
', '.join(six.iterkeys(self.providers)))
__str__ = __repr__
@six.python_2_unicode_compatible
class DeclarativeCatalogMetaClass(type):
"""Declarative catalog meta class."""
def __new__(mcs, class_name, bases, attributes):
"""Declarative catalog class factory."""
cls_providers = tuple((name, provider)
for name, provider in six.iteritems(attributes)
if is_provider(provider))
inherited_providers = tuple((name, provider)
for base in bases if is_catalog(base)
for name, provider in six.iteritems(
base.providers))
providers = cls_providers + inherited_providers
cls = type.__new__(mcs, class_name, bases, attributes)
cls.catalog = DynamicCatalog()
cls.catalog.name = '.'.join((cls.__module__, cls.__name__))
cls.catalog.bind_providers(dict(providers))
cls.cls_providers = dict(cls_providers)
cls.inherited_providers = dict(inherited_providers)
cls.Bundle = cls.catalog.Bundle
return cls
@property
def name(cls):
"""Return catalog's name."""
return cls.catalog.name
@property
def providers(cls):
"""Return dict of catalog's providers."""
return cls.catalog.providers
@property
def overridden_by(cls):
"""Return tuple of overriding catalogs."""
return cls.catalog.overridden_by
@property
def is_overridden(cls):
"""Check if catalog is overridden by another catalog."""
return cls.catalog.is_overridden
@property
def last_overriding(cls):
"""Return last overriding catalog."""
return cls.catalog.last_overriding
def __getattr__(cls, name):
"""Return provider with specified name or raise en error."""
raise UndefinedProviderError('There is no provider "{0}" in '
'catalog {1}'.format(name, cls))
def __setattr__(cls, name, value):
"""Handle setting of catalog attributes.
Setting of attributes works as usual, but if value of attribute is
provider, this provider will be bound to catalog correctly.
"""
if is_provider(value):
setattr(cls.catalog, name, value)
return super(DeclarativeCatalogMetaClass, cls).__setattr__(name, value)
def __delattr__(cls, name):
"""Handle deleting of catalog attibute.
Deleting of attributes works as usual, but if value of attribute is
provider, this provider will be unbound from catalog correctly.
"""
if is_provider(getattr(cls, name)):
delattr(cls.catalog, name)
return super(DeclarativeCatalogMetaClass, cls).__delattr__(name)
def __repr__(cls):
"""Return string representation of the catalog."""
return '<{0}({1})>'.format(cls.name,
', '.join(six.iterkeys(cls.providers)))
__str__ = __repr__
@six.add_metaclass(DeclarativeCatalogMetaClass)
class DeclarativeCatalog(object):
"""Declarative catalog catalog of providers.
:type name: str
:param name: Catalog's name
:type catalog: DynamicCatalog
:param catalog: Instance of dynamic catalog
:type Bundle: CatalogBundle
:param Bundle: Catalog's bundle class
:type providers: dict[str, dependency_injector.Provider]
:param providers: Dict of all catalog providers, including inherited from
parent catalogs
:type cls_providers: dict[str, dependency_injector.Provider]
:param cls_providers: Dict of current catalog providers
:type inherited_providers: dict[str, dependency_injector.Provider]
:param inherited_providers: Dict of providers, that are inherited from
parent catalogs
:type overridden_by: tuple[DeclarativeCatalog]
:param overridden_by: Tuple of overriding catalogs
:type is_overridden: bool
:param is_overridden: Read-only, evaluated in runtime, property that is
set to True if catalog is overridden
:type last_overriding: DeclarativeCatalog | None
:param last_overriding: Reference to the last overriding catalog, if any
"""
Bundle = CatalogBundle
name = str()
cls_providers = dict()
inherited_providers = dict()
providers = dict()
overridden_by = tuple()
is_overridden = bool
last_overriding = None
catalog = DynamicCatalog
__IS_CATALOG__ = True
@classmethod
def is_bundle_owner(cls, bundle):
"""Check if catalog is bundle owner."""
return cls.catalog.is_bundle_owner(bundle)
@classmethod
def get_provider_bind_name(cls, provider):
"""Return provider's name in catalog."""
return cls.catalog.get_provider_bind_name(provider)
@classmethod
def is_provider_bound(cls, provider):
"""Check if provider is bound to the catalog."""
return cls.catalog.is_provider_bound(provider)
@classmethod
def filter(cls, provider_type):
"""Return dict of providers, that are instance of provided type."""
return cls.catalog.filter(provider_type)
@classmethod
def override(cls, overriding):
"""Override current catalog providers by overriding catalog providers.
:type overriding: DeclarativeCatalog | DynamicCatalog
"""
return cls.catalog.override(overriding)
@classmethod
def reset_last_overriding(cls):
"""Reset last overriding catalog."""
cls.catalog.reset_last_overriding()
@classmethod
def reset_override(cls):
"""Reset all overridings for all catalog providers."""
cls.catalog.reset_override()
@classmethod
def get_provider(cls, name):
"""Return provider with specified name or raise an error."""
return cls.catalog.get_provider(name)
@classmethod
def bind_provider(cls, name, provider):
"""Bind provider to catalog with specified name."""
setattr(cls, name, provider)
@classmethod
def bind_providers(cls, providers):
"""Bind providers dictionary to catalog."""
for name, provider in six.iteritems(providers):
setattr(cls, name, provider)
@classmethod
def has_provider(cls, name):
"""Check if there is provider with certain name."""
return hasattr(cls, name)
@classmethod
def unbind_provider(cls, name):
"""Remove provider binding."""
delattr(cls, name)
get = get_provider # Backward compatibility for versions < 0.11.*
has = has_provider # Backward compatibility for versions < 0.11.*
# Backward compatibility for versions < 0.11.*
AbstractCatalog = DeclarativeCatalog
def override(catalog):
"""Catalog overriding decorator."""
def decorator(overriding_catalog):
"""Overriding decorator."""
catalog.override(overriding_catalog)
return overriding_catalog
return decorator

View File

@ -3,3 +3,7 @@
class Error(Exception):
"""Base error."""
class UndefinedProviderError(Error, AttributeError):
"""Undefined provider error."""

View File

@ -73,11 +73,20 @@ def is_method_injection(instance):
def is_catalog(instance):
"""Check if instance is catalog instance."""
return (isinstance(instance, six.class_types) and
hasattr(instance, '__IS_CATALOG__') and
return (hasattr(instance, '__IS_CATALOG__') and
getattr(instance, '__IS_CATALOG__', False) is True)
def is_dynamic_catalog(instance):
"""Check if instance is dynamic catalog instance."""
return (not isinstance(instance, six.class_types) and is_catalog(instance))
def is_declarative_catalog(instance):
"""Check if instance is declarative catalog instance."""
return (isinstance(instance, six.class_types) and is_catalog(instance))
def is_catalog_bundle(instance):
"""Check if instance is catalog bundle instance."""
return (not isinstance(instance, six.class_types) and

View File

@ -1,353 +0,0 @@
"""Dependency injector catalog unittests."""
import unittest2 as unittest
import dependency_injector as di
class CatalogA(di.DeclarativeCatalog):
"""Test catalog A."""
p11 = di.Provider()
p12 = di.Provider()
class CatalogB(CatalogA):
"""Test catalog B."""
p21 = di.Provider()
p22 = di.Provider()
class CatalogC(CatalogB):
"""Test catalog C."""
p31 = di.Provider()
p32 = di.Provider()
class CatalogsInheritanceTests(unittest.TestCase):
"""Catalogs inheritance tests."""
def test_cls_providers(self):
"""Test `di.DeclarativeCatalog.cls_providers` contents."""
self.assertDictEqual(CatalogA.cls_providers,
dict(p11=CatalogA.p11,
p12=CatalogA.p12))
self.assertDictEqual(CatalogB.cls_providers,
dict(p21=CatalogB.p21,
p22=CatalogB.p22))
self.assertDictEqual(CatalogC.cls_providers,
dict(p31=CatalogC.p31,
p32=CatalogC.p32))
def test_inherited_providers(self):
"""Test `di.DeclarativeCatalog.inherited_providers` contents."""
self.assertDictEqual(CatalogA.inherited_providers, dict())
self.assertDictEqual(CatalogB.inherited_providers,
dict(p11=CatalogA.p11,
p12=CatalogA.p12))
self.assertDictEqual(CatalogC.inherited_providers,
dict(p11=CatalogA.p11,
p12=CatalogA.p12,
p21=CatalogB.p21,
p22=CatalogB.p22))
def test_providers(self):
"""Test `di.DeclarativeCatalog.inherited_providers` contents."""
self.assertDictEqual(CatalogA.providers,
dict(p11=CatalogA.p11,
p12=CatalogA.p12))
self.assertDictEqual(CatalogB.providers,
dict(p11=CatalogA.p11,
p12=CatalogA.p12,
p21=CatalogB.p21,
p22=CatalogB.p22))
self.assertDictEqual(CatalogC.providers,
dict(p11=CatalogA.p11,
p12=CatalogA.p12,
p21=CatalogB.p21,
p22=CatalogB.p22,
p31=CatalogC.p31,
p32=CatalogC.p32))
class CatalogProvidersBindingTests(unittest.TestCase):
"""Catalog providers binding test cases."""
def test_provider_is_bound(self):
"""Test that providers are bound to the catalogs."""
self.assertTrue(CatalogA.is_provider_bound(CatalogA.p11))
self.assertEquals(CatalogA.get_provider_bind_name(CatalogA.p11), 'p11')
self.assertTrue(CatalogA.is_provider_bound(CatalogA.p12))
self.assertEquals(CatalogA.get_provider_bind_name(CatalogA.p12), 'p12')
def test_provider_binding_to_different_catalogs(self):
"""Test that provider could be bound to different catalogs."""
p11 = CatalogA.p11
p12 = CatalogA.p12
class CatalogD(di.DeclarativeCatalog):
"""Test catalog."""
pd1 = p11
pd2 = p12
class CatalogE(di.DeclarativeCatalog):
"""Test catalog."""
pe1 = p11
pe2 = p12
self.assertTrue(CatalogA.is_provider_bound(p11))
self.assertTrue(CatalogD.is_provider_bound(p11))
self.assertTrue(CatalogE.is_provider_bound(p11))
self.assertEquals(CatalogA.get_provider_bind_name(p11), 'p11')
self.assertEquals(CatalogD.get_provider_bind_name(p11), 'pd1')
self.assertEquals(CatalogE.get_provider_bind_name(p11), 'pe1')
self.assertTrue(CatalogA.is_provider_bound(p12))
self.assertTrue(CatalogD.is_provider_bound(p12))
self.assertTrue(CatalogE.is_provider_bound(p12))
self.assertEquals(CatalogA.get_provider_bind_name(p12), 'p12')
self.assertEquals(CatalogD.get_provider_bind_name(p12), 'pd2')
self.assertEquals(CatalogE.get_provider_bind_name(p12), 'pe2')
def test_provider_rebinding_to_the_same_catalog(self):
"""Test provider rebinding to the same catalog."""
with self.assertRaises(di.Error):
class TestCatalog(di.DeclarativeCatalog):
"""Test catalog."""
p1 = di.Provider()
p2 = p1
def test_provider_rebinding_to_the_same_catalogs_hierarchy(self):
"""Test provider rebinding to the same catalogs hierarchy."""
class TestCatalog1(di.DeclarativeCatalog):
"""Test catalog."""
p1 = di.Provider()
with self.assertRaises(di.Error):
class TestCatalog2(TestCatalog1):
"""Test catalog."""
p2 = TestCatalog1.p1
class CatalogBundleTests(unittest.TestCase):
"""Catalog bundle test cases."""
def setUp(self):
"""Set test environment up."""
self.bundle = CatalogC.Bundle(CatalogC.p11,
CatalogC.p12)
def test_get_attr_from_bundle(self):
"""Test get providers (attribute) from catalog bundle."""
self.assertIs(self.bundle.p11, CatalogC.p11)
self.assertIs(self.bundle.p12, CatalogC.p12)
def test_get_attr_not_from_bundle(self):
"""Test get providers (attribute) that are not in bundle."""
self.assertRaises(di.Error, getattr, self.bundle, 'p21')
self.assertRaises(di.Error, getattr, self.bundle, 'p22')
self.assertRaises(di.Error, getattr, self.bundle, 'p31')
self.assertRaises(di.Error, getattr, self.bundle, 'p32')
def test_get_method_from_bundle(self):
"""Test get providers (get() method) from bundle."""
self.assertIs(self.bundle.get('p11'), CatalogC.p11)
self.assertIs(self.bundle.get('p12'), CatalogC.p12)
def test_get_method_not_from_bundle(self):
"""Test get providers (get() method) that are not in bundle."""
self.assertRaises(di.Error, self.bundle.get, 'p21')
self.assertRaises(di.Error, self.bundle.get, 'p22')
self.assertRaises(di.Error, self.bundle.get, 'p31')
self.assertRaises(di.Error, self.bundle.get, 'p32')
def test_has(self):
"""Test checks of providers availability in bundle."""
self.assertTrue(self.bundle.has('p11'))
self.assertTrue(self.bundle.has('p12'))
self.assertFalse(self.bundle.has('p21'))
self.assertFalse(self.bundle.has('p22'))
self.assertFalse(self.bundle.has('p31'))
self.assertFalse(self.bundle.has('p32'))
def test_create_bundle_with_unbound_provider(self):
"""Test that bundle is not created with unbound provider."""
self.assertRaises(di.Error, CatalogC.Bundle, di.Provider())
def test_create_bundle_with_another_catalog_provider(self):
"""Test that bundle can not contain another catalog's provider."""
class TestCatalog(di.DeclarativeCatalog):
"""Test catalog."""
provider = di.Provider()
self.assertRaises(di.Error,
CatalogC.Bundle, CatalogC.p31, TestCatalog.provider)
def test_create_bundle_with_another_catalog_provider_with_same_name(self):
"""Test that bundle can not contain another catalog's provider."""
class TestCatalog(di.DeclarativeCatalog):
"""Test catalog."""
p31 = di.Provider()
self.assertRaises(di.Error,
CatalogC.Bundle, CatalogC.p31, TestCatalog.p31)
def test_is_bundle_owner(self):
"""Test that catalog bundle is owned by catalog."""
self.assertTrue(CatalogC.is_bundle_owner(self.bundle))
self.assertFalse(CatalogB.is_bundle_owner(self.bundle))
self.assertFalse(CatalogA.is_bundle_owner(self.bundle))
def test_is_bundle_owner_with_not_bundle_instance(self):
"""Test that check of bundle ownership raises error with not bundle."""
self.assertRaises(di.Error, CatalogC.is_bundle_owner, object())
class CatalogTests(unittest.TestCase):
"""Catalog test cases."""
def test_get(self):
"""Test getting of providers using get() method."""
self.assertIs(CatalogC.get('p11'), CatalogC.p11)
self.assertIs(CatalogC.get('p12'), CatalogC.p12)
self.assertIs(CatalogC.get('p22'), CatalogC.p22)
self.assertIs(CatalogC.get('p22'), CatalogC.p22)
self.assertIs(CatalogC.get('p32'), CatalogC.p32)
self.assertIs(CatalogC.get('p32'), CatalogC.p32)
def test_get_undefined(self):
"""Test getting of undefined providers using get() method."""
self.assertRaises(di.Error, CatalogC.get, 'undefined')
def test_has(self):
"""Test checks of providers availability in catalog."""
self.assertTrue(CatalogC.has('p11'))
self.assertTrue(CatalogC.has('p12'))
self.assertTrue(CatalogC.has('p21'))
self.assertTrue(CatalogC.has('p22'))
self.assertTrue(CatalogC.has('p31'))
self.assertTrue(CatalogC.has('p32'))
self.assertFalse(CatalogC.has('undefined'))
def test_filter_all_providers_by_type(self):
"""Test getting of all catalog providers of specific type."""
self.assertTrue(len(CatalogC.filter(di.Provider)) == 6)
self.assertTrue(len(CatalogC.filter(di.Value)) == 0)
class OverrideTests(unittest.TestCase):
"""Catalog overriding and override decorator test cases."""
class Catalog(di.DeclarativeCatalog):
"""Test catalog."""
obj = di.Object(object())
another_obj = di.Object(object())
def tearDown(self):
"""Tear test environment down."""
self.Catalog.reset_override()
def test_overriding(self):
"""Test catalog overriding with another catalog."""
@di.override(self.Catalog)
class OverridingCatalog(self.Catalog):
"""Overriding catalog."""
obj = di.Value(1)
another_obj = di.Value(2)
self.assertEqual(self.Catalog.obj(), 1)
self.assertEqual(self.Catalog.another_obj(), 2)
def test_is_overridden(self):
"""Test catalog is_overridden property."""
self.assertFalse(self.Catalog.is_overridden)
@di.override(self.Catalog)
class OverridingCatalog(self.Catalog):
"""Overriding catalog."""
self.assertTrue(self.Catalog.is_overridden)
def test_last_overriding(self):
"""Test catalog last_overriding property."""
@di.override(self.Catalog)
class OverridingCatalog1(self.Catalog):
"""Overriding catalog."""
@di.override(self.Catalog)
class OverridingCatalog2(self.Catalog):
"""Overriding catalog."""
self.assertIs(self.Catalog.last_overriding, OverridingCatalog2)
def test_last_overriding_on_not_overridden(self):
"""Test catalog last_overriding property on not overridden catalog."""
with self.assertRaises(di.Error):
self.Catalog.last_overriding
def test_reset_last_overriding(self):
"""Test resetting last overriding catalog."""
@di.override(self.Catalog)
class OverridingCatalog1(self.Catalog):
"""Overriding catalog."""
obj = di.Value(1)
another_obj = di.Value(2)
@di.override(self.Catalog)
class OverridingCatalog2(self.Catalog):
"""Overriding catalog."""
obj = di.Value(3)
another_obj = di.Value(4)
self.Catalog.reset_last_overriding()
self.assertEqual(self.Catalog.obj(), 1)
self.assertEqual(self.Catalog.another_obj(), 2)
def test_reset_last_overriding_when_not_overridden(self):
"""Test resetting last overriding catalog when it is not overridden."""
with self.assertRaises(di.Error):
self.Catalog.reset_last_overriding()
def test_reset_override(self):
"""Test resetting all catalog overrides."""
@di.override(self.Catalog)
class OverridingCatalog1(self.Catalog):
"""Overriding catalog."""
obj = di.Value(1)
another_obj = di.Value(2)
@di.override(self.Catalog)
class OverridingCatalog2(self.Catalog):
"""Overriding catalog."""
obj = di.Value(3)
another_obj = di.Value(4)
self.Catalog.reset_override()
self.assertIsInstance(self.Catalog.obj(), object)
self.assertIsInstance(self.Catalog.another_obj(), object)
class AbstractCatalogCompatibilityTest(unittest.TestCase):
"""Test backward compatibility with di.AbstractCatalog."""
def test_compatibility(self):
"""Test that di.AbstractCatalog is available."""
self.assertIs(di.DeclarativeCatalog, di.AbstractCatalog)

560
tests/test_catalogs.py Normal file
View File

@ -0,0 +1,560 @@
"""Dependency injector catalogs unittests."""
import unittest2 as unittest
import dependency_injector as di
class CatalogA(di.DeclarativeCatalog):
"""Test catalog A."""
p11 = di.Provider()
p12 = di.Provider()
class CatalogB(CatalogA):
"""Test catalog B."""
p21 = di.Provider()
p22 = di.Provider()
class CatalogBundleTests(unittest.TestCase):
"""Catalog bundle test cases."""
def setUp(self):
"""Set test environment up."""
self.bundle = CatalogB.Bundle(CatalogB.p11,
CatalogB.p12)
def test_get_attr_from_bundle(self):
"""Test get providers (attribute) from catalog bundle."""
self.assertIs(self.bundle.p11, CatalogA.p11)
self.assertIs(self.bundle.p12, CatalogA.p12)
def test_get_attr_not_from_bundle(self):
"""Test get providers (attribute) that are not in bundle."""
self.assertRaises(di.Error, getattr, self.bundle, 'p21')
self.assertRaises(di.Error, getattr, self.bundle, 'p22')
def test_get_method_from_bundle(self):
"""Test get providers (get() method) from bundle."""
self.assertIs(self.bundle.get_provider('p11'), CatalogB.p11)
self.assertIs(self.bundle.get_provider('p12'), CatalogB.p12)
def test_get_method_not_from_bundle(self):
"""Test get providers (get() method) that are not in bundle."""
self.assertRaises(di.Error, self.bundle.get_provider, 'p21')
self.assertRaises(di.Error, self.bundle.get_provider, 'p22')
def test_has(self):
"""Test checks of providers availability in bundle."""
self.assertTrue(self.bundle.has_provider('p11'))
self.assertTrue(self.bundle.has_provider('p12'))
self.assertFalse(self.bundle.has_provider('p21'))
self.assertFalse(self.bundle.has_provider('p22'))
def test_hasattr(self):
"""Test checks of providers availability in bundle."""
self.assertTrue(hasattr(self.bundle, 'p11'))
self.assertTrue(hasattr(self.bundle, 'p12'))
self.assertFalse(hasattr(self.bundle, 'p21'))
self.assertFalse(hasattr(self.bundle, 'p22'))
def test_create_bundle_with_unbound_provider(self):
"""Test that bundle is not created with unbound provider."""
self.assertRaises(di.Error, CatalogB.Bundle, di.Provider())
def test_create_bundle_with_another_catalog_provider(self):
"""Test that bundle can not contain another catalog's provider."""
class TestCatalog(di.DeclarativeCatalog):
"""Test catalog."""
provider = di.Provider()
self.assertRaises(di.Error,
CatalogB.Bundle, CatalogB.p21, TestCatalog.provider)
def test_create_bundle_with_another_catalog_provider_with_same_name(self):
"""Test that bundle can not contain another catalog's provider."""
class TestCatalog(di.DeclarativeCatalog):
"""Test catalog."""
p21 = di.Provider()
self.assertRaises(di.Error,
CatalogB.Bundle, CatalogB.p21, TestCatalog.p21)
def test_is_bundle_owner(self):
"""Test that catalog bundle is owned by catalog."""
self.assertTrue(CatalogB.is_bundle_owner(self.bundle))
self.assertFalse(CatalogA.is_bundle_owner(self.bundle))
def test_is_bundle_owner_with_not_bundle_instance(self):
"""Test that check of bundle ownership raises error with not bundle."""
self.assertRaises(di.Error, CatalogB.is_bundle_owner, object())
class DynamicCatalogTests(unittest.TestCase):
"""Dynamic catalog tests."""
catalog = None
""":type: di.DynamicCatalog"""
def setUp(self):
"""Set test environment up."""
self.catalog = di.DynamicCatalog(p1=di.Provider(),
p2=di.Provider())
self.catalog.name = 'TestCatalog'
def test_providers(self):
"""Test `di.DeclarativeCatalog.inherited_providers` contents."""
self.assertDictEqual(self.catalog.providers,
dict(p1=self.catalog.p1,
p2=self.catalog.p2))
def test_bind_provider(self):
"""Test setting of provider via bind_provider() to catalog."""
px = di.Provider()
py = di.Provider()
self.catalog.bind_provider('px', px)
self.catalog.bind_provider('py', py)
self.assertIs(self.catalog.px, px)
self.assertIs(self.catalog.get_provider('px'), px)
self.assertIs(self.catalog.py, py)
self.assertIs(self.catalog.get_provider('py'), py)
def test_bind_providers(self):
"""Test setting of provider via bind_providers() to catalog."""
px = di.Provider()
py = di.Provider()
self.catalog.bind_providers(dict(px=px, py=py))
self.assertIs(self.catalog.px, px)
self.assertIs(self.catalog.get_provider('px'), px)
self.assertIs(self.catalog.py, py)
self.assertIs(self.catalog.get_provider('py'), py)
def test_setattr(self):
"""Test setting of providers via attributes to catalog."""
px = di.Provider()
py = di.Provider()
self.catalog.px = px
self.catalog.py = py
self.assertIs(self.catalog.px, px)
self.assertIs(self.catalog.get_provider('px'), px)
self.assertIs(self.catalog.py, py)
self.assertIs(self.catalog.get_provider('py'), py)
def test_unbind_provider(self):
"""Test that catalog unbinds provider correct."""
self.catalog.px = di.Provider()
self.catalog.unbind_provider('px')
self.assertFalse(self.catalog.has_provider('px'))
def test_unbind_via_delattr(self):
"""Test that catalog unbinds provider correct."""
self.catalog.px = di.Provider()
del self.catalog.px
self.assertFalse(self.catalog.has_provider('px'))
def test_provider_is_bound(self):
"""Test that providers are bound to the catalogs."""
self.assertTrue(self.catalog.is_provider_bound(self.catalog.p1))
self.assertEquals(
self.catalog.get_provider_bind_name(self.catalog.p1), 'p1')
self.assertTrue(self.catalog.is_provider_bound(self.catalog.p2))
self.assertEquals(
self.catalog.get_provider_bind_name(self.catalog.p2), 'p2')
def test_provider_binding_to_different_catalogs(self):
"""Test that provider could be bound to different catalogs."""
p1 = self.catalog.p1
p2 = self.catalog.p2
catalog_a = di.DynamicCatalog(pa1=p1, pa2=p2)
catalog_b = di.DynamicCatalog(pb1=p1, pb2=p2)
self.assertTrue(self.catalog.is_provider_bound(p1))
self.assertTrue(catalog_a.is_provider_bound(p1))
self.assertTrue(catalog_b.is_provider_bound(p1))
self.assertEquals(self.catalog.get_provider_bind_name(p1), 'p1')
self.assertEquals(catalog_a.get_provider_bind_name(p1), 'pa1')
self.assertEquals(catalog_b.get_provider_bind_name(p1), 'pb1')
self.assertTrue(self.catalog.is_provider_bound(p2))
self.assertTrue(catalog_a.is_provider_bound(p2))
self.assertTrue(catalog_b.is_provider_bound(p2))
self.assertEquals(self.catalog.get_provider_bind_name(p2), 'p2')
self.assertEquals(catalog_a.get_provider_bind_name(p2), 'pa2')
self.assertEquals(catalog_b.get_provider_bind_name(p2), 'pb2')
def test_provider_rebinding_to_the_same_catalog(self):
"""Test provider rebinding to the same catalog."""
with self.assertRaises(di.Error):
self.catalog.p3 = self.catalog.p1
def test_provider_binding_with_the_same_name(self):
"""Test binding of provider with the same name."""
with self.assertRaises(di.Error):
self.catalog.bind_provider('p1', di.Provider())
def test_get(self):
"""Test getting of providers using get() method."""
self.assertIs(self.catalog.get_provider('p1'), self.catalog.p1)
self.assertIs(self.catalog.get_provider('p2'), self.catalog.p2)
def test_get_undefined(self):
"""Test getting of undefined providers using get() method."""
with self.assertRaises(di.UndefinedProviderError):
self.catalog.get_provider('undefined')
with self.assertRaises(di.UndefinedProviderError):
self.catalog.undefined
def test_has_provider(self):
"""Test checks of providers availability in catalog."""
self.assertTrue(self.catalog.has_provider('p1'))
self.assertTrue(self.catalog.has_provider('p2'))
self.assertFalse(self.catalog.has_provider('undefined'))
def test_filter_all_providers_by_type(self):
"""Test getting of all catalog providers of specific type."""
self.assertTrue(len(self.catalog.filter(di.Provider)) == 2)
self.assertTrue(len(self.catalog.filter(di.Value)) == 0)
def test_repr(self):
"""Test catalog representation."""
self.assertIn('TestCatalog', repr(self.catalog))
self.assertIn('p1', repr(self.catalog))
self.assertIn('p2', repr(self.catalog))
class DeclarativeCatalogTests(unittest.TestCase):
"""Declarative catalog tests."""
def test_cls_providers(self):
"""Test `di.DeclarativeCatalog.cls_providers` contents."""
self.assertDictEqual(CatalogA.cls_providers,
dict(p11=CatalogA.p11,
p12=CatalogA.p12))
self.assertDictEqual(CatalogB.cls_providers,
dict(p21=CatalogB.p21,
p22=CatalogB.p22))
def test_inherited_providers(self):
"""Test `di.DeclarativeCatalog.inherited_providers` contents."""
self.assertDictEqual(CatalogA.inherited_providers, dict())
self.assertDictEqual(CatalogB.inherited_providers,
dict(p11=CatalogA.p11,
p12=CatalogA.p12))
def test_providers(self):
"""Test `di.DeclarativeCatalog.inherited_providers` contents."""
self.assertDictEqual(CatalogA.providers,
dict(p11=CatalogA.p11,
p12=CatalogA.p12))
self.assertDictEqual(CatalogB.providers,
dict(p11=CatalogA.p11,
p12=CatalogA.p12,
p21=CatalogB.p21,
p22=CatalogB.p22))
def test_bind_provider(self):
"""Test setting of provider via bind_provider() to catalog."""
px = di.Provider()
py = di.Provider()
CatalogA.bind_provider('px', px)
CatalogA.bind_provider('py', py)
self.assertIs(CatalogA.px, px)
self.assertIs(CatalogA.get_provider('px'), px)
self.assertIs(CatalogA.catalog.px, px)
self.assertIs(CatalogA.py, py)
self.assertIs(CatalogA.get_provider('py'), py)
self.assertIs(CatalogA.catalog.py, py)
del CatalogA.px
del CatalogA.py
def test_bind_providers(self):
"""Test setting of provider via bind_providers() to catalog."""
px = di.Provider()
py = di.Provider()
CatalogB.bind_providers(dict(px=px, py=py))
self.assertIs(CatalogB.px, px)
self.assertIs(CatalogB.get_provider('px'), px)
self.assertIs(CatalogB.catalog.px, px)
self.assertIs(CatalogB.py, py)
self.assertIs(CatalogB.get_provider('py'), py)
self.assertIs(CatalogB.catalog.py, py)
del CatalogB.px
del CatalogB.py
def test_setattr(self):
"""Test setting of providers via attributes to catalog."""
px = di.Provider()
py = di.Provider()
CatalogB.px = px
CatalogB.py = py
self.assertIs(CatalogB.px, px)
self.assertIs(CatalogB.get_provider('px'), px)
self.assertIs(CatalogB.catalog.px, px)
self.assertIs(CatalogB.py, py)
self.assertIs(CatalogB.get_provider('py'), py)
self.assertIs(CatalogB.catalog.py, py)
del CatalogB.px
del CatalogB.py
def test_unbind_provider(self):
"""Test that catalog unbinds provider correct."""
CatalogB.px = di.Provider()
CatalogB.unbind_provider('px')
self.assertFalse(CatalogB.has_provider('px'))
def test_unbind_via_delattr(self):
"""Test that catalog unbinds provider correct."""
CatalogB.px = di.Provider()
del CatalogB.px
self.assertFalse(CatalogB.has_provider('px'))
def test_provider_is_bound(self):
"""Test that providers are bound to the catalogs."""
self.assertTrue(CatalogA.is_provider_bound(CatalogA.p11))
self.assertEquals(CatalogA.get_provider_bind_name(CatalogA.p11), 'p11')
self.assertTrue(CatalogA.is_provider_bound(CatalogA.p12))
self.assertEquals(CatalogA.get_provider_bind_name(CatalogA.p12), 'p12')
def test_provider_binding_to_different_catalogs(self):
"""Test that provider could be bound to different catalogs."""
p11 = CatalogA.p11
p12 = CatalogA.p12
class CatalogD(di.DeclarativeCatalog):
"""Test catalog."""
pd1 = p11
pd2 = p12
class CatalogE(di.DeclarativeCatalog):
"""Test catalog."""
pe1 = p11
pe2 = p12
self.assertTrue(CatalogA.is_provider_bound(p11))
self.assertTrue(CatalogD.is_provider_bound(p11))
self.assertTrue(CatalogE.is_provider_bound(p11))
self.assertEquals(CatalogA.get_provider_bind_name(p11), 'p11')
self.assertEquals(CatalogD.get_provider_bind_name(p11), 'pd1')
self.assertEquals(CatalogE.get_provider_bind_name(p11), 'pe1')
self.assertTrue(CatalogA.is_provider_bound(p12))
self.assertTrue(CatalogD.is_provider_bound(p12))
self.assertTrue(CatalogE.is_provider_bound(p12))
self.assertEquals(CatalogA.get_provider_bind_name(p12), 'p12')
self.assertEquals(CatalogD.get_provider_bind_name(p12), 'pd2')
self.assertEquals(CatalogE.get_provider_bind_name(p12), 'pe2')
def test_provider_rebinding_to_the_same_catalog(self):
"""Test provider rebinding to the same catalog."""
with self.assertRaises(di.Error):
class TestCatalog(di.DeclarativeCatalog):
"""Test catalog."""
p1 = di.Provider()
p2 = p1
def test_provider_rebinding_to_the_same_catalogs_hierarchy(self):
"""Test provider rebinding to the same catalogs hierarchy."""
class TestCatalog1(di.DeclarativeCatalog):
"""Test catalog."""
p1 = di.Provider()
with self.assertRaises(di.Error):
class TestCatalog2(TestCatalog1):
"""Test catalog."""
p2 = TestCatalog1.p1
def test_get(self):
"""Test getting of providers using get() method."""
self.assertIs(CatalogB.get('p11'), CatalogB.p11)
self.assertIs(CatalogB.get('p12'), CatalogB.p12)
self.assertIs(CatalogB.get('p22'), CatalogB.p22)
self.assertIs(CatalogB.get('p22'), CatalogB.p22)
self.assertIs(CatalogB.get_provider('p11'), CatalogB.p11)
self.assertIs(CatalogB.get_provider('p12'), CatalogB.p12)
self.assertIs(CatalogB.get_provider('p22'), CatalogB.p22)
self.assertIs(CatalogB.get_provider('p22'), CatalogB.p22)
def test_get_undefined(self):
"""Test getting of undefined providers using get() method."""
with self.assertRaises(di.UndefinedProviderError):
CatalogB.get('undefined')
with self.assertRaises(di.UndefinedProviderError):
CatalogB.get_provider('undefined')
with self.assertRaises(di.UndefinedProviderError):
CatalogB.undefined
def test_has(self):
"""Test checks of providers availability in catalog."""
self.assertTrue(CatalogB.has('p11'))
self.assertTrue(CatalogB.has('p12'))
self.assertTrue(CatalogB.has('p21'))
self.assertTrue(CatalogB.has('p22'))
self.assertFalse(CatalogB.has('undefined'))
self.assertTrue(CatalogB.has_provider('p11'))
self.assertTrue(CatalogB.has_provider('p12'))
self.assertTrue(CatalogB.has_provider('p21'))
self.assertTrue(CatalogB.has_provider('p22'))
self.assertFalse(CatalogB.has_provider('undefined'))
def test_filter_all_providers_by_type(self):
"""Test getting of all catalog providers of specific type."""
self.assertTrue(len(CatalogB.filter(di.Provider)) == 4)
self.assertTrue(len(CatalogB.filter(di.Value)) == 0)
def test_repr(self):
"""Test catalog representation."""
self.assertIn('CatalogA', repr(CatalogA))
self.assertIn('p11', repr(CatalogA))
self.assertIn('p12', repr(CatalogA))
self.assertIn('CatalogB', repr(CatalogB))
self.assertIn('p11', repr(CatalogB))
self.assertIn('p12', repr(CatalogB))
self.assertIn('p21', repr(CatalogB))
self.assertIn('p22', repr(CatalogB))
def test_abstract_catalog_backward_compatibility(self):
"""Test that di.AbstractCatalog is available."""
self.assertIs(di.DeclarativeCatalog, di.AbstractCatalog)
class OverrideTests(unittest.TestCase):
"""Catalog overriding and override decorator test cases."""
def tearDown(self):
"""Tear test environment down."""
CatalogA.reset_override()
def test_overriding(self):
"""Test catalog overriding with another catalog."""
@di.override(CatalogA)
class OverridingCatalog(di.DeclarativeCatalog):
"""Overriding catalog."""
p11 = di.Value(1)
p12 = di.Value(2)
self.assertEqual(CatalogA.p11(), 1)
self.assertEqual(CatalogA.p12(), 2)
self.assertEqual(len(CatalogA.overridden_by), 1)
def test_overriding_with_dynamic_catalog(self):
"""Test catalog overriding with another dynamic catalog."""
CatalogA.override(di.DynamicCatalog(p11=di.Value(1),
p12=di.Value(2)))
self.assertEqual(CatalogA.p11(), 1)
self.assertEqual(CatalogA.p12(), 2)
self.assertEqual(len(CatalogA.overridden_by), 1)
def test_is_overridden(self):
"""Test catalog is_overridden property."""
self.assertFalse(CatalogA.is_overridden)
@di.override(CatalogA)
class OverridingCatalog(di.DeclarativeCatalog):
"""Overriding catalog."""
self.assertTrue(CatalogA.is_overridden)
def test_last_overriding(self):
"""Test catalog last_overriding property."""
@di.override(CatalogA)
class OverridingCatalog1(di.DeclarativeCatalog):
"""Overriding catalog."""
@di.override(CatalogA)
class OverridingCatalog2(di.DeclarativeCatalog):
"""Overriding catalog."""
self.assertIs(CatalogA.last_overriding, OverridingCatalog2)
def test_last_overriding_on_not_overridden(self):
"""Test catalog last_overriding property on not overridden catalog."""
with self.assertRaises(di.Error):
CatalogA.last_overriding
def test_reset_last_overriding(self):
"""Test resetting last overriding catalog."""
@di.override(CatalogA)
class OverridingCatalog1(di.DeclarativeCatalog):
"""Overriding catalog."""
p11 = di.Value(1)
p12 = di.Value(2)
@di.override(CatalogA)
class OverridingCatalog2(di.DeclarativeCatalog):
"""Overriding catalog."""
p11 = di.Value(3)
p12 = di.Value(4)
CatalogA.reset_last_overriding()
self.assertEqual(CatalogA.p11(), 1)
self.assertEqual(CatalogA.p12(), 2)
def test_reset_last_overriding_when_not_overridden(self):
"""Test resetting last overriding catalog when it is not overridden."""
with self.assertRaises(di.Error):
CatalogA.reset_last_overriding()
def test_reset_override(self):
"""Test resetting all catalog overrides."""
@di.override(CatalogA)
class OverridingCatalog1(di.DeclarativeCatalog):
"""Overriding catalog."""
p11 = di.Value(1)
p12 = di.Value(2)
@di.override(CatalogA)
class OverridingCatalog2(di.DeclarativeCatalog):
"""Overriding catalog."""
p11 = di.Value(3)
p12 = di.Value(4)
CatalogA.reset_override()
self.assertFalse(CatalogA.p11.is_overridden)
self.assertFalse(CatalogA.p12.is_overridden)

View File

@ -211,13 +211,13 @@ class IsMethodInjectionTests(unittest.TestCase):
class IsCatalogTests(unittest.TestCase):
"""`is_catalog()` test cases."""
def test_with_cls(self):
def test_with_declarative_catalog(self):
"""Test with class."""
self.assertTrue(di.is_catalog(di.AbstractCatalog))
self.assertTrue(di.is_catalog(di.DeclarativeCatalog))
def test_with_instance(self):
def test_with_dynamic_catalog(self):
"""Test with class."""
self.assertFalse(di.is_catalog(di.AbstractCatalog()))
self.assertTrue(di.is_catalog(di.DynamicCatalog()))
def test_with_child_class(self):
"""Test with parent class."""
@ -235,6 +235,30 @@ class IsCatalogTests(unittest.TestCase):
self.assertFalse(di.is_catalog(object()))
class IsDynamicCatalogTests(unittest.TestCase):
"""`is_dynamic_catalog()` test cases."""
def test_with_declarative_catalog(self):
"""Test with declarative catalog."""
self.assertFalse(di.is_dynamic_catalog(di.DeclarativeCatalog))
def test_with_dynamic_catalog(self):
"""Test with dynamic catalog."""
self.assertTrue(di.is_dynamic_catalog(di.DynamicCatalog()))
class IsDeclarativeCatalogTests(unittest.TestCase):
"""`is_declarative_catalog()` test cases."""
def test_with_declarative_catalog(self):
"""Test with declarative catalog."""
self.assertTrue(di.is_declarative_catalog(di.DeclarativeCatalog))
def test_with_dynamic_catalog(self):
"""Test with dynamic catalog."""
self.assertFalse(di.is_declarative_catalog(di.DynamicCatalog()))
class IsCatalogBundleTests(unittest.TestCase):
"""`is_catalog_bundle()` test cases."""