diff --git a/dependency_injector/catalog.py b/dependency_injector/catalog.py index f2ca1136..2725f7c3 100644 --- a/dependency_injector/catalog.py +++ b/dependency_injector/catalog.py @@ -8,32 +8,21 @@ from .utils import is_provider from .utils import is_catalog -class CatalogPackFactory(object): - """Factory of catalog packs.""" +class CatalogBundle(object): + """Bundle of catalog providers.""" - def __init__(self, catalog): + catalog = None + """:type: AbstractCatalog""" + + __slots__ = ('providers', '__dict__') + + def __init__(self, *providers): """Initializer.""" - self.catalog = catalog - - def __call__(self, *providers): - """Create catalog pack with specified providers.""" - return CatalogPack(self.catalog, *providers) - - -class CatalogPack(object): - """Pack of catalog providers.""" - - __slots__ = ('catalog', 'providers', '__dict__') - - def __init__(self, catalog, *providers): - """Initializer.""" - self.catalog = catalog - self.providers = dict() - for provider in providers: - self._ensure_provider_is_bound(provider) - self.__dict__[provider.bind.name] = provider - self.providers[provider.bind.name] = provider - super(CatalogPack, self).__init__() + self.providers = dict((provider.bind.name, provider) + for provider in providers + if self._ensure_provider_is_bound(provider)) + self.__dict__.update(self.providers) + super(CatalogBundle, self).__init__() def get(self, name): """Return provider with specified name or raises error.""" @@ -47,16 +36,17 @@ class CatalogPack(object): return name in self.providers def _ensure_provider_is_bound(self, provider): - """Check that provider is bound.""" - if not provider.bind: + """Check that provider is bound to the bundle's catalog.""" + if not provider.is_bound: raise Error('Provider {0} is not bound to ' 'any catalog'.format(provider)) if provider is not self.catalog.get(provider.bind.name): raise Error('{0} can contain providers from ' - 'catalog {0}' .format(self, self.catalog)) + 'catalog {0}'.format(self.__class__, self.catalog)) + return True def _raise_undefined_provider_error(self, name): - """Raise error for cases when there is no such provider in pack.""" + """Raise error for cases when there is no such provider in bundle.""" raise Error('Provider "{0}" is not a part of {1}'.format(name, self)) def __getattr__(self, item): @@ -64,16 +54,16 @@ class CatalogPack(object): self._raise_undefined_provider_error(item) def __repr__(self): - """Return string representation of pack.""" - return ''.format( - ', '.join(six.iterkeys(self.providers)), self.catalog) + """Return string representation of bundle.""" + return ''.format( + self.catalog, ', '.join(six.iterkeys(self.providers))) class CatalogMetaClass(type): - """Providers catalog meta class.""" + """Catalog meta class.""" def __new__(mcs, class_name, bases, attributes): - """Meta class factory.""" + """Catalog class factory.""" cls_providers = dict((name, provider) for name, provider in six.iteritems(attributes) if is_provider(provider)) @@ -93,13 +83,18 @@ class CatalogMetaClass(type): cls.inherited_providers = inherited_providers cls.providers = providers - cls.Pack = CatalogPackFactory(cls) + cls.Bundle = mcs.bundle_cls_factory(cls) for name, provider in six.iteritems(cls_providers): provider.bind = ProviderBinding(cls, name) return cls + @classmethod + def bundle_cls_factory(mcs, cls): + """Create bundle class for catalog.""" + return type('{0}Bundle', (CatalogBundle,), dict(catalog=cls)) + def __repr__(cls): """Return string representation of the catalog class.""" return '' @@ -120,11 +115,11 @@ class AbstractCatalog(object): :param inherited_providers: Dict of providers, that are inherited from parent catalogs - :type Pack: CatalogPack - :param Pack: Catalog's pack class + :type Bundle: CatalogBundle + :param Bundle: Catalog's bundle class """ - Pack = CatalogPackFactory + Bundle = CatalogBundle cls_providers = dict() inherited_providers = dict() diff --git a/dependency_injector/providers.py b/dependency_injector/providers.py index 501b90d8..b510bd6b 100644 --- a/dependency_injector/providers.py +++ b/dependency_injector/providers.py @@ -74,6 +74,11 @@ class Provider(object): """Reset all overriding providers.""" self.overridden_by = None + @property + def is_bound(self): + """Check if provider is bound to any catalog.""" + return bool(self.bind) + class Delegate(Provider): """Provider's delegate.""" diff --git a/tests/test_catalog.py b/tests/test_catalog.py index 8572f367..fc7e031a 100644 --- a/tests/test_catalog.py +++ b/tests/test_catalog.py @@ -25,13 +25,6 @@ class CatalogC(CatalogB): p32 = di.Provider() -class CatalogD(di.AbstractCatalog): - """Test catalog D.""" - - p11 = di.Provider() - p12 = di.Provider() - - class CatalogsInheritanceTests(unittest.TestCase): """Catalogs inheritance tests.""" @@ -78,55 +71,71 @@ class CatalogsInheritanceTests(unittest.TestCase): p32=CatalogC.p32)) -class CatalogPackTests(unittest.TestCase): - """Catalog pack test cases.""" +class CatalogBundleTests(unittest.TestCase): + """Catalog bundle test cases.""" def setUp(self): """Set test environment up.""" - self.pack = CatalogC.Pack(CatalogC.p11, - CatalogC.p12) + self.bundle = CatalogC.Bundle(CatalogC.p11, + CatalogC.p12) - def test_get_attr_from_pack(self): - """Test get providers (attribute) from pack.""" - self.assertIs(self.pack.p11, CatalogC.p11) - self.assertIs(self.pack.p12, CatalogC.p12) + def test_get_attr_from_bundle(self): + """Test get providers (attribute) from catalog bundle.""" + self.assertIs(self.bundle.p11, CatalogC.p11) + self.assertIs(self.bundle.p12, CatalogC.p12) - def test_get_attr_not_from_pack(self): - """Test get providers (attribute) that are not in pack.""" - self.assertRaises(di.Error, getattr, self.pack, 'p21') - self.assertRaises(di.Error, getattr, self.pack, 'p22') - self.assertRaises(di.Error, getattr, self.pack, 'p31') - self.assertRaises(di.Error, getattr, self.pack, 'p32') + def test_get_attr_not_from_bundle(self): + """Test get providers (attribute) that are not in bundle.""" + self.assertRaises(di.Error, getattr, self.bundle, 'p21') + self.assertRaises(di.Error, getattr, self.bundle, 'p22') + self.assertRaises(di.Error, getattr, self.bundle, 'p31') + self.assertRaises(di.Error, getattr, self.bundle, 'p32') - def test_get_method_from_pack(self): - """Test get providers (get() method) from pack.""" - self.assertIs(self.pack.get('p11'), CatalogC.p11) - self.assertIs(self.pack.get('p12'), CatalogC.p12) + def test_get_method_from_bundle(self): + """Test get providers (get() method) from bundle.""" + self.assertIs(self.bundle.get('p11'), CatalogC.p11) + self.assertIs(self.bundle.get('p12'), CatalogC.p12) - def test_get_method_not_from_pack(self): - """Test get providers (get() method) that are not in pack.""" - self.assertRaises(di.Error, self.pack.get, 'p21') - self.assertRaises(di.Error, self.pack.get, 'p22') - self.assertRaises(di.Error, self.pack.get, 'p31') - self.assertRaises(di.Error, self.pack.get, 'p32') + def test_get_method_not_from_bundle(self): + """Test get providers (get() method) that are not in bundle.""" + self.assertRaises(di.Error, self.bundle.get, 'p21') + self.assertRaises(di.Error, self.bundle.get, 'p22') + self.assertRaises(di.Error, self.bundle.get, 'p31') + self.assertRaises(di.Error, self.bundle.get, 'p32') def test_has(self): - """Test checks of providers availability in pack.""" - self.assertTrue(self.pack.has('p11')) - self.assertTrue(self.pack.has('p12')) + """Test checks of providers availability in bundle.""" + self.assertTrue(self.bundle.has('p11')) + self.assertTrue(self.bundle.has('p12')) - self.assertFalse(self.pack.has('p21')) - self.assertFalse(self.pack.has('p22')) - self.assertFalse(self.pack.has('p31')) - self.assertFalse(self.pack.has('p32')) + self.assertFalse(self.bundle.has('p21')) + self.assertFalse(self.bundle.has('p22')) + self.assertFalse(self.bundle.has('p31')) + self.assertFalse(self.bundle.has('p32')) - def test_create_pack_with_another_catalog_provider(self): - """Test that pack is not created with provider from another catalog.""" - self.assertRaises(di.Error, CatalogC.Pack, CatalogC.p31, CatalogD.p11) + def test_create_bundle_with_unbound_provider(self): + """Test that bundle is not created with unbound provider.""" + self.assertRaises(di.Error, CatalogC.Bundle, di.Provider()) - def test_create_pack_with_unbound_provider(self): - """Test that pack is not created with unbound provider.""" - self.assertRaises(di.Error, CatalogC.Pack, di.Provider()) + def test_create_bundle_with_another_catalog_provider(self): + """Test that bundle can not contain another catalog's provider.""" + class TestCatalog(di.AbstractCatalog): + """Test catalog.""" + + provider = di.Provider() + + self.assertRaises(di.Error, + CatalogC.Bundle, CatalogC.p31, TestCatalog.provider) + + def test_create_bundle_with_another_catalog_provider_with_same_name(self): + """Test that bundle can not contain another catalog's provider.""" + class TestCatalog(di.AbstractCatalog): + """Test catalog.""" + + p31 = di.Provider() + + self.assertRaises(di.Error, + CatalogC.Bundle, CatalogC.p31, TestCatalog.p31) class CatalogTests(unittest.TestCase):