diff --git a/dependency_injector/__init__.py b/dependency_injector/__init__.py index f659392d..1332f8b5 100644 --- a/dependency_injector/__init__.py +++ b/dependency_injector/__init__.py @@ -41,6 +41,7 @@ from .utils import is_catalog_bundle from .utils import ensure_is_catalog_bundle from .errors import Error +from .errors import UndefinedProviderError # Backward compatibility for versions < 0.11.* from . import catalogs @@ -96,6 +97,7 @@ __all__ = ( # Errors 'Error', + 'UndefinedProviderError', # Version 'VERSION' diff --git a/dependency_injector/catalogs.py b/dependency_injector/catalogs.py index 42cf44c1..843da8b0 100644 --- a/dependency_injector/catalogs.py +++ b/dependency_injector/catalogs.py @@ -88,6 +88,7 @@ class DynamicCatalog(object): self.Bundle = CatalogBundle.sub_cls_factory(self) self.bind_providers(providers) + super(DynamicCatalog, self).__init__() def is_bundle_owner(self, bundle): """Check if catalog is bundle owner.""" @@ -181,6 +182,16 @@ class DynamicCatalog(object): """Return provider with specified name or raise en error.""" return self.get_provider(name) + def __setattr__(self, name, value): + """Handle setting of catalog attributes. + + Setting of attributes works as usual, but if value of attribute is + provider, this provider will be bound to catalog correctly. + """ + if is_provider(value): + return self.bind_provider(name, value) + return super(DynamicCatalog, self).__setattr__(name, value) + def __repr__(self): """Return Python representation of catalog.""" return '<{0}({1})>'.format(self.name, @@ -195,8 +206,6 @@ class DeclarativeCatalogMetaClass(type): def __new__(mcs, class_name, bases, attributes): """Declarative catalog class factory.""" - cls = type.__new__(mcs, class_name, bases, attributes) - cls_providers = tuple((name, provider) for name, provider in six.iteritems(attributes) if is_provider(provider)) @@ -208,10 +217,10 @@ class DeclarativeCatalogMetaClass(type): providers = cls_providers + inherited_providers - cls.name = '.'.join((cls.__module__, cls.__name__)) + cls = type.__new__(mcs, class_name, bases, attributes) cls.catalog = DynamicCatalog() - cls.catalog.name = cls.name + cls.catalog.name = '.'.join((cls.__module__, cls.__name__)) cls.catalog.bind_providers(dict(providers)) cls.cls_providers = dict(cls_providers) @@ -221,6 +230,11 @@ class DeclarativeCatalogMetaClass(type): return cls + @property + def name(cls): + """Return catalog's name.""" + return cls.catalog.name + @property def providers(cls): """Return dict of catalog's providers.""" @@ -241,6 +255,21 @@ class DeclarativeCatalogMetaClass(type): """Return last overriding catalog.""" return cls.catalog.last_overriding + def __getattr__(cls, name): + """Return provider with specified name or raise en error.""" + raise UndefinedProviderError('There is no provider "{0}" in ' + 'catalog {1}'.format(name, cls)) + + def __setattr__(cls, name, value): + """Handle setting of catalog attributes. + + Setting of attributes works as usual, but if value of attribute is + provider, this provider will be bound to catalog correctly. + """ + if is_provider(value): + setattr(cls.catalog, name, value) + return super(DeclarativeCatalogMetaClass, cls).__setattr__(name, value) + def __repr__(cls): """Return string representation of the catalog.""" return '<{0}({1})>'.format(cls.name, @@ -326,28 +355,41 @@ class DeclarativeCatalog(object): :type overriding: DeclarativeCatalog | DynamicCatalog """ - cls.catalog.override(overriding) + return cls.catalog.override(overriding) @classmethod def reset_last_overriding(cls): """Reset last overriding catalog.""" - cls.catalog.reset_last_overriding() + return cls.catalog.reset_last_overriding() @classmethod def reset_override(cls): """Reset all overridings for all catalog providers.""" - cls.catalog.reset_override() + return cls.catalog.reset_override() @classmethod - def get(cls, name): - """Return provider with specified name or raises error.""" + def get_provider(cls, name): + """Return provider with specified name or raise an error.""" return cls.catalog.get_provider(name) @classmethod - def has(cls, name): + def bind_provider(cls, name, provider): + """Bind provider to catalog with specified name.""" + return cls.catalog.bind_provider(name, provider) + + @classmethod + def bind_providers(cls, providers): + """Bind providers dictionary to catalog.""" + return cls.catalog.bind_providers(providers) + + @classmethod + def has_provider(cls, name): """Check if there is provider with certain name.""" return cls.catalog.has_provider(name) + get = get_provider # Backward compatibility for versions < 0.11.* + has = has_provider # Backward compatibility for versions < 0.11.* + # Backward compatibility for versions < 0.11.* AbstractCatalog = DeclarativeCatalog diff --git a/tests/test_catalogs.py b/tests/test_catalogs.py index d0fdb910..081ca80f 100644 --- a/tests/test_catalogs.py +++ b/tests/test_catalogs.py @@ -157,6 +157,22 @@ class DeclarativeCatalogTests(unittest.TestCase): p31=CatalogC.p31, p32=CatalogC.p32)) + def test_setattr(self): + """Test setting of provider attributes to catalog.""" + px = di.Provider() + py = di.Provider() + + CatalogC.px = px + CatalogC.py = py + + self.assertIs(CatalogC.px, px) + self.assertIs(CatalogC.get_provider('px'), px) + self.assertIs(CatalogC.catalog.px, px) + + self.assertIs(CatalogC.py, py) + self.assertIs(CatalogC.get_provider('py'), py) + self.assertIs(CatalogC.catalog.py, py) + def test_provider_is_bound(self): """Test that providers are bound to the catalogs.""" self.assertTrue(CatalogA.is_provider_bound(CatalogA.p11)) @@ -227,9 +243,23 @@ class DeclarativeCatalogTests(unittest.TestCase): self.assertIs(CatalogC.get('p32'), CatalogC.p32) self.assertIs(CatalogC.get('p32'), CatalogC.p32) + self.assertIs(CatalogC.get_provider('p11'), CatalogC.p11) + self.assertIs(CatalogC.get_provider('p12'), CatalogC.p12) + self.assertIs(CatalogC.get_provider('p22'), CatalogC.p22) + self.assertIs(CatalogC.get_provider('p22'), CatalogC.p22) + self.assertIs(CatalogC.get_provider('p32'), CatalogC.p32) + self.assertIs(CatalogC.get_provider('p32'), CatalogC.p32) + def test_get_undefined(self): """Test getting of undefined providers using get() method.""" - self.assertRaises(di.Error, CatalogC.get, 'undefined') + with self.assertRaises(di.UndefinedProviderError): + CatalogC.get('undefined') + + with self.assertRaises(di.UndefinedProviderError): + CatalogC.get_provider('undefined') + + with self.assertRaises(di.UndefinedProviderError): + CatalogC.undefined def test_has(self): """Test checks of providers availability in catalog.""" @@ -241,6 +271,14 @@ class DeclarativeCatalogTests(unittest.TestCase): self.assertTrue(CatalogC.has('p32')) self.assertFalse(CatalogC.has('undefined')) + self.assertTrue(CatalogC.has_provider('p11')) + self.assertTrue(CatalogC.has_provider('p12')) + self.assertTrue(CatalogC.has_provider('p21')) + self.assertTrue(CatalogC.has_provider('p22')) + self.assertTrue(CatalogC.has_provider('p31')) + self.assertTrue(CatalogC.has_provider('p32')) + self.assertFalse(CatalogC.has_provider('undefined')) + def test_filter_all_providers_by_type(self): """Test getting of all catalog providers of specific type.""" self.assertTrue(len(CatalogC.filter(di.Provider)) == 6)