mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-06-14 02:23:13 +03:00
Refactor DeclarativeCatalog
This commit is contained in:
parent
4252fbfe4c
commit
be1ff0445d
|
@ -41,6 +41,7 @@ from .utils import is_catalog_bundle
|
||||||
from .utils import ensure_is_catalog_bundle
|
from .utils import ensure_is_catalog_bundle
|
||||||
|
|
||||||
from .errors import Error
|
from .errors import Error
|
||||||
|
from .errors import UndefinedProviderError
|
||||||
|
|
||||||
# Backward compatibility for versions < 0.11.*
|
# Backward compatibility for versions < 0.11.*
|
||||||
from . import catalogs
|
from . import catalogs
|
||||||
|
@ -96,6 +97,7 @@ __all__ = (
|
||||||
|
|
||||||
# Errors
|
# Errors
|
||||||
'Error',
|
'Error',
|
||||||
|
'UndefinedProviderError',
|
||||||
|
|
||||||
# Version
|
# Version
|
||||||
'VERSION'
|
'VERSION'
|
||||||
|
|
|
@ -88,6 +88,7 @@ class DynamicCatalog(object):
|
||||||
self.Bundle = CatalogBundle.sub_cls_factory(self)
|
self.Bundle = CatalogBundle.sub_cls_factory(self)
|
||||||
|
|
||||||
self.bind_providers(providers)
|
self.bind_providers(providers)
|
||||||
|
super(DynamicCatalog, self).__init__()
|
||||||
|
|
||||||
def is_bundle_owner(self, bundle):
|
def is_bundle_owner(self, bundle):
|
||||||
"""Check if catalog is bundle owner."""
|
"""Check if catalog is bundle owner."""
|
||||||
|
@ -181,6 +182,16 @@ class DynamicCatalog(object):
|
||||||
"""Return provider with specified name or raise en error."""
|
"""Return provider with specified name or raise en error."""
|
||||||
return self.get_provider(name)
|
return self.get_provider(name)
|
||||||
|
|
||||||
|
def __setattr__(self, name, value):
|
||||||
|
"""Handle setting of catalog attributes.
|
||||||
|
|
||||||
|
Setting of attributes works as usual, but if value of attribute is
|
||||||
|
provider, this provider will be bound to catalog correctly.
|
||||||
|
"""
|
||||||
|
if is_provider(value):
|
||||||
|
return self.bind_provider(name, value)
|
||||||
|
return super(DynamicCatalog, self).__setattr__(name, value)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""Return Python representation of catalog."""
|
"""Return Python representation of catalog."""
|
||||||
return '<{0}({1})>'.format(self.name,
|
return '<{0}({1})>'.format(self.name,
|
||||||
|
@ -195,8 +206,6 @@ class DeclarativeCatalogMetaClass(type):
|
||||||
|
|
||||||
def __new__(mcs, class_name, bases, attributes):
|
def __new__(mcs, class_name, bases, attributes):
|
||||||
"""Declarative catalog class factory."""
|
"""Declarative catalog class factory."""
|
||||||
cls = type.__new__(mcs, class_name, bases, attributes)
|
|
||||||
|
|
||||||
cls_providers = tuple((name, provider)
|
cls_providers = tuple((name, provider)
|
||||||
for name, provider in six.iteritems(attributes)
|
for name, provider in six.iteritems(attributes)
|
||||||
if is_provider(provider))
|
if is_provider(provider))
|
||||||
|
@ -208,10 +217,10 @@ class DeclarativeCatalogMetaClass(type):
|
||||||
|
|
||||||
providers = cls_providers + inherited_providers
|
providers = cls_providers + inherited_providers
|
||||||
|
|
||||||
cls.name = '.'.join((cls.__module__, cls.__name__))
|
cls = type.__new__(mcs, class_name, bases, attributes)
|
||||||
|
|
||||||
cls.catalog = DynamicCatalog()
|
cls.catalog = DynamicCatalog()
|
||||||
cls.catalog.name = cls.name
|
cls.catalog.name = '.'.join((cls.__module__, cls.__name__))
|
||||||
cls.catalog.bind_providers(dict(providers))
|
cls.catalog.bind_providers(dict(providers))
|
||||||
|
|
||||||
cls.cls_providers = dict(cls_providers)
|
cls.cls_providers = dict(cls_providers)
|
||||||
|
@ -221,6 +230,11 @@ class DeclarativeCatalogMetaClass(type):
|
||||||
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(cls):
|
||||||
|
"""Return catalog's name."""
|
||||||
|
return cls.catalog.name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def providers(cls):
|
def providers(cls):
|
||||||
"""Return dict of catalog's providers."""
|
"""Return dict of catalog's providers."""
|
||||||
|
@ -241,6 +255,21 @@ class DeclarativeCatalogMetaClass(type):
|
||||||
"""Return last overriding catalog."""
|
"""Return last overriding catalog."""
|
||||||
return cls.catalog.last_overriding
|
return cls.catalog.last_overriding
|
||||||
|
|
||||||
|
def __getattr__(cls, name):
|
||||||
|
"""Return provider with specified name or raise en error."""
|
||||||
|
raise UndefinedProviderError('There is no provider "{0}" in '
|
||||||
|
'catalog {1}'.format(name, cls))
|
||||||
|
|
||||||
|
def __setattr__(cls, name, value):
|
||||||
|
"""Handle setting of catalog attributes.
|
||||||
|
|
||||||
|
Setting of attributes works as usual, but if value of attribute is
|
||||||
|
provider, this provider will be bound to catalog correctly.
|
||||||
|
"""
|
||||||
|
if is_provider(value):
|
||||||
|
setattr(cls.catalog, name, value)
|
||||||
|
return super(DeclarativeCatalogMetaClass, cls).__setattr__(name, value)
|
||||||
|
|
||||||
def __repr__(cls):
|
def __repr__(cls):
|
||||||
"""Return string representation of the catalog."""
|
"""Return string representation of the catalog."""
|
||||||
return '<{0}({1})>'.format(cls.name,
|
return '<{0}({1})>'.format(cls.name,
|
||||||
|
@ -326,28 +355,41 @@ class DeclarativeCatalog(object):
|
||||||
|
|
||||||
:type overriding: DeclarativeCatalog | DynamicCatalog
|
:type overriding: DeclarativeCatalog | DynamicCatalog
|
||||||
"""
|
"""
|
||||||
cls.catalog.override(overriding)
|
return cls.catalog.override(overriding)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def reset_last_overriding(cls):
|
def reset_last_overriding(cls):
|
||||||
"""Reset last overriding catalog."""
|
"""Reset last overriding catalog."""
|
||||||
cls.catalog.reset_last_overriding()
|
return cls.catalog.reset_last_overriding()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def reset_override(cls):
|
def reset_override(cls):
|
||||||
"""Reset all overridings for all catalog providers."""
|
"""Reset all overridings for all catalog providers."""
|
||||||
cls.catalog.reset_override()
|
return cls.catalog.reset_override()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, name):
|
def get_provider(cls, name):
|
||||||
"""Return provider with specified name or raises error."""
|
"""Return provider with specified name or raise an error."""
|
||||||
return cls.catalog.get_provider(name)
|
return cls.catalog.get_provider(name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def has(cls, name):
|
def bind_provider(cls, name, provider):
|
||||||
|
"""Bind provider to catalog with specified name."""
|
||||||
|
return cls.catalog.bind_provider(name, provider)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def bind_providers(cls, providers):
|
||||||
|
"""Bind providers dictionary to catalog."""
|
||||||
|
return cls.catalog.bind_providers(providers)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def has_provider(cls, name):
|
||||||
"""Check if there is provider with certain name."""
|
"""Check if there is provider with certain name."""
|
||||||
return cls.catalog.has_provider(name)
|
return cls.catalog.has_provider(name)
|
||||||
|
|
||||||
|
get = get_provider # Backward compatibility for versions < 0.11.*
|
||||||
|
has = has_provider # Backward compatibility for versions < 0.11.*
|
||||||
|
|
||||||
|
|
||||||
# Backward compatibility for versions < 0.11.*
|
# Backward compatibility for versions < 0.11.*
|
||||||
AbstractCatalog = DeclarativeCatalog
|
AbstractCatalog = DeclarativeCatalog
|
||||||
|
|
|
@ -157,6 +157,22 @@ class DeclarativeCatalogTests(unittest.TestCase):
|
||||||
p31=CatalogC.p31,
|
p31=CatalogC.p31,
|
||||||
p32=CatalogC.p32))
|
p32=CatalogC.p32))
|
||||||
|
|
||||||
|
def test_setattr(self):
|
||||||
|
"""Test setting of provider attributes to catalog."""
|
||||||
|
px = di.Provider()
|
||||||
|
py = di.Provider()
|
||||||
|
|
||||||
|
CatalogC.px = px
|
||||||
|
CatalogC.py = py
|
||||||
|
|
||||||
|
self.assertIs(CatalogC.px, px)
|
||||||
|
self.assertIs(CatalogC.get_provider('px'), px)
|
||||||
|
self.assertIs(CatalogC.catalog.px, px)
|
||||||
|
|
||||||
|
self.assertIs(CatalogC.py, py)
|
||||||
|
self.assertIs(CatalogC.get_provider('py'), py)
|
||||||
|
self.assertIs(CatalogC.catalog.py, py)
|
||||||
|
|
||||||
def test_provider_is_bound(self):
|
def test_provider_is_bound(self):
|
||||||
"""Test that providers are bound to the catalogs."""
|
"""Test that providers are bound to the catalogs."""
|
||||||
self.assertTrue(CatalogA.is_provider_bound(CatalogA.p11))
|
self.assertTrue(CatalogA.is_provider_bound(CatalogA.p11))
|
||||||
|
@ -227,9 +243,23 @@ class DeclarativeCatalogTests(unittest.TestCase):
|
||||||
self.assertIs(CatalogC.get('p32'), CatalogC.p32)
|
self.assertIs(CatalogC.get('p32'), CatalogC.p32)
|
||||||
self.assertIs(CatalogC.get('p32'), CatalogC.p32)
|
self.assertIs(CatalogC.get('p32'), CatalogC.p32)
|
||||||
|
|
||||||
|
self.assertIs(CatalogC.get_provider('p11'), CatalogC.p11)
|
||||||
|
self.assertIs(CatalogC.get_provider('p12'), CatalogC.p12)
|
||||||
|
self.assertIs(CatalogC.get_provider('p22'), CatalogC.p22)
|
||||||
|
self.assertIs(CatalogC.get_provider('p22'), CatalogC.p22)
|
||||||
|
self.assertIs(CatalogC.get_provider('p32'), CatalogC.p32)
|
||||||
|
self.assertIs(CatalogC.get_provider('p32'), CatalogC.p32)
|
||||||
|
|
||||||
def test_get_undefined(self):
|
def test_get_undefined(self):
|
||||||
"""Test getting of undefined providers using get() method."""
|
"""Test getting of undefined providers using get() method."""
|
||||||
self.assertRaises(di.Error, CatalogC.get, 'undefined')
|
with self.assertRaises(di.UndefinedProviderError):
|
||||||
|
CatalogC.get('undefined')
|
||||||
|
|
||||||
|
with self.assertRaises(di.UndefinedProviderError):
|
||||||
|
CatalogC.get_provider('undefined')
|
||||||
|
|
||||||
|
with self.assertRaises(di.UndefinedProviderError):
|
||||||
|
CatalogC.undefined
|
||||||
|
|
||||||
def test_has(self):
|
def test_has(self):
|
||||||
"""Test checks of providers availability in catalog."""
|
"""Test checks of providers availability in catalog."""
|
||||||
|
@ -241,6 +271,14 @@ class DeclarativeCatalogTests(unittest.TestCase):
|
||||||
self.assertTrue(CatalogC.has('p32'))
|
self.assertTrue(CatalogC.has('p32'))
|
||||||
self.assertFalse(CatalogC.has('undefined'))
|
self.assertFalse(CatalogC.has('undefined'))
|
||||||
|
|
||||||
|
self.assertTrue(CatalogC.has_provider('p11'))
|
||||||
|
self.assertTrue(CatalogC.has_provider('p12'))
|
||||||
|
self.assertTrue(CatalogC.has_provider('p21'))
|
||||||
|
self.assertTrue(CatalogC.has_provider('p22'))
|
||||||
|
self.assertTrue(CatalogC.has_provider('p31'))
|
||||||
|
self.assertTrue(CatalogC.has_provider('p32'))
|
||||||
|
self.assertFalse(CatalogC.has_provider('undefined'))
|
||||||
|
|
||||||
def test_filter_all_providers_by_type(self):
|
def test_filter_all_providers_by_type(self):
|
||||||
"""Test getting of all catalog providers of specific type."""
|
"""Test getting of all catalog providers of specific type."""
|
||||||
self.assertTrue(len(CatalogC.filter(di.Provider)) == 6)
|
self.assertTrue(len(CatalogC.filter(di.Provider)) == 6)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user