Add declarative catalog @copy decorator

This commit is contained in:
Roman Mogilatov 2016-04-10 16:52:37 +03:00
parent 43258e5fd9
commit 7cdeede38a
3 changed files with 123 additions and 24 deletions

View File

@ -123,7 +123,7 @@ class DeclarativeCatalogMetaClass(type):
:rtype: None :rtype: None
""" """
if is_provider(value): if is_provider(value):
setattr(cls._catalog, name, value) cls.bind_provider(name, value, _set_as_attribute=False)
return super(DeclarativeCatalogMetaClass, cls).__setattr__(name, value) return super(DeclarativeCatalogMetaClass, cls).__setattr__(name, value)
def __delattr__(cls, name): def __delattr__(cls, name):
@ -348,7 +348,8 @@ class DeclarativeCatalog(object):
get = get_provider # Backward compatibility for versions < 0.11.* get = get_provider # Backward compatibility for versions < 0.11.*
@classmethod @classmethod
def bind_provider(cls, name, provider): def bind_provider(cls, name, provider, force=False,
_set_as_attribute=True):
"""Bind provider to catalog with specified name. """Bind provider to catalog with specified name.
:param name: Name of the provider. :param name: Name of the provider.
@ -357,14 +358,26 @@ class DeclarativeCatalog(object):
:param provider: Provider instance. :param provider: Provider instance.
:type provider: :py:class:`dependency_injector.providers.Provider` :type provider: :py:class:`dependency_injector.providers.Provider`
:param force: Force binding of provider.
:type force: bool
:raise: :py:exc:`dependency_injector.errors.Error` :raise: :py:exc:`dependency_injector.errors.Error`
:rtype: None :rtype: None
""" """
setattr(cls, name, provider) if cls._catalog.is_provider_bound(provider):
bindind_name = cls._catalog.get_provider_bind_name(provider)
if bindind_name == name and not force:
return
cls._catalog.bind_provider(name, provider, force)
cls.cls_providers[name] = provider
if _set_as_attribute:
setattr(cls, name, provider)
@classmethod @classmethod
def bind_providers(cls, providers): def bind_providers(cls, providers, force=False):
"""Bind providers dictionary to catalog. """Bind providers dictionary to catalog.
:param providers: Dictionary of providers, where key is a name :param providers: Dictionary of providers, where key is a name
@ -372,12 +385,15 @@ class DeclarativeCatalog(object):
:type providers: :type providers:
dict[str, :py:class:`dependency_injector.providers.Provider`] dict[str, :py:class:`dependency_injector.providers.Provider`]
:param force: Force binding of providers.
:type force: bool
:raise: :py:exc:`dependency_injector.errors.Error` :raise: :py:exc:`dependency_injector.errors.Error`
:rtype: None :rtype: None
""" """
for name, provider in six.iteritems(providers): for name, provider in six.iteritems(providers):
setattr(cls, name, provider) cls.bind_provider(name, provider, force=force)
@classmethod @classmethod
def has_provider(cls, name): def has_provider(cls, name):
@ -402,6 +418,7 @@ class DeclarativeCatalog(object):
:rtype: None :rtype: None
""" """
delattr(cls, name) delattr(cls, name)
del cls.cls_providers[name]
@classmethod @classmethod
def __getattr__(cls, name): # pragma: no cover def __getattr__(cls, name): # pragma: no cover

View File

@ -4,42 +4,46 @@ import six
from copy import deepcopy from copy import deepcopy
from dependency_injector.errors import UndefinedProviderError
def copy(catalog): def copy(catalog):
""":py:class:`DeclarativeCatalog` copying decorator. """:py:class:`DeclarativeCatalog` copying decorator.
This decorator copy all providers from provided catalog to decorated one.
If one of the decorated catalog providers matches to source catalog
providers by name, it would be replaced by reference.
:param catalog: Catalog that should be copied by decorated catalog. :param catalog: Catalog that should be copied by decorated catalog.
:type catalog: :py:class:`dependency_injector.catalogs.DeclarativeCatalog` :type catalog: :py:class:`dependency_injector.catalogs.DeclarativeCatalog`
| :py:class:`dependency_injector.catalogs.DynamicCatalog`
:return: Declarative catalog's copying decorator. :return: Declarative catalog's copying decorator.
:rtype: :rtype:
callable(:py:class:`dependency_injector.catalogs.DeclarativeCatalog`) callable(:py:class:`DeclarativeCatalog`)
""" """
def decorator(overriding_catalog): def decorator(copied_catalog):
"""Overriding decorator. """Copying decorator.
:param catalog: Decorated catalog. :param copied_catalog: Decorated catalog.
:type catalog: :type copied_catalog: :py:class:`DeclarativeCatalog`
:py:class:`dependency_injector.catalogs.DeclarativeCatalog`
:return: Decorated catalog. :return: Decorated catalog.
:rtype: :rtype:
:py:class:`dependency_injector.catalogs.DeclarativeCatalog` :py:class:`DeclarativeCatalog`
""" """
memo = dict() memo = dict()
for name, provider in six.iteritems(copied_catalog.cls_providers):
try:
source_provider = catalog.get_provider(name)
except UndefinedProviderError:
pass
else:
memo[id(source_provider)] = provider
for name, provider in six.iteritems(overriding_catalog.providers): copied_catalog.bind_providers(deepcopy(catalog.providers, memo),
memo[id(catalog.get_provider(name))] = provider force=True)
dynamic_catalog_copy = deepcopy(catalog._catalog, memo) return copied_catalog
print dynamic_catalog_copy.providers
for name, provider in six.iteritems(dynamic_catalog_copy.providers):
overriding_catalog.bind_provider(name, provider)
return overriding_catalog
return decorator return decorator
@ -47,7 +51,7 @@ def override(catalog):
""":py:class:`DeclarativeCatalog` overriding decorator. """:py:class:`DeclarativeCatalog` overriding decorator.
:param catalog: Catalog that should be overridden by decorated catalog. :param catalog: Catalog that should be overridden by decorated catalog.
:type catalog: :py:class:`DeclarativeCatalog` | :py:class:`DynamicCatalog` :type catalog: :py:class:`DeclarativeCatalog`
:return: Declarative catalog's overriding decorator. :return: Declarative catalog's overriding decorator.
:rtype: callable(:py:class:`DeclarativeCatalog`) :rtype: callable(:py:class:`DeclarativeCatalog`)

View File

@ -29,6 +29,17 @@ class DeclarativeCatalogTests(unittest.TestCase):
def test_cls_providers(self): def test_cls_providers(self):
"""Test `di.DeclarativeCatalog.cls_providers` contents.""" """Test `di.DeclarativeCatalog.cls_providers` contents."""
class CatalogA(catalogs.DeclarativeCatalog):
"""Test catalog A."""
p11 = providers.Provider()
p12 = providers.Provider()
class CatalogB(CatalogA):
"""Test catalog B."""
p21 = providers.Provider()
p22 = providers.Provider()
self.assertDictEqual(CatalogA.cls_providers, self.assertDictEqual(CatalogA.cls_providers,
dict(p11=CatalogA.p11, dict(p11=CatalogA.p11,
p12=CatalogA.p12)) p12=CatalogA.p12))
@ -71,6 +82,14 @@ class DeclarativeCatalogTests(unittest.TestCase):
del CatalogA.px del CatalogA.px
del CatalogA.py del CatalogA.py
def test_bind_existing_provider(self):
"""Test setting of provider via bind_provider() to catalog."""
with self.assertRaises(errors.Error):
CatalogA.p11 = providers.Provider()
with self.assertRaises(errors.Error):
CatalogA.bind_provider('p11', providers.Provider())
def test_bind_provider_with_valid_provided_type(self): def test_bind_provider_with_valid_provided_type(self):
"""Test setting of provider with provider type restriction.""" """Test setting of provider with provider type restriction."""
class SomeProvider(providers.Provider): class SomeProvider(providers.Provider):
@ -350,3 +369,62 @@ class TestCatalogWithProvidingCallbacks(unittest.TestCase):
auth_service = Services.auth() auth_service = Services.auth()
self.assertIsInstance(auth_service, ExtendedAuthService) self.assertIsInstance(auth_service, ExtendedAuthService)
class CopyingTests(unittest.TestCase):
"""Declarative catalogs copying tests."""
def test_copy(self):
"""Test catalog providers copying."""
@catalogs.copy(CatalogA)
class CatalogA1(CatalogA):
pass
@catalogs.copy(CatalogA)
class CatalogA2(CatalogA):
pass
self.assertIsNot(CatalogA.p11, CatalogA1.p11)
self.assertIsNot(CatalogA.p12, CatalogA1.p12)
self.assertIsNot(CatalogA.p11, CatalogA2.p11)
self.assertIsNot(CatalogA.p12, CatalogA2.p12)
self.assertIsNot(CatalogA1.p11, CatalogA2.p11)
self.assertIsNot(CatalogA1.p12, CatalogA2.p12)
def test_copy_with_replacing(self):
"""Test catalog providers copying."""
class CatalogA(catalogs.DeclarativeCatalog):
p11 = providers.Value(0)
p12 = providers.Factory(dict, p11=p11)
@catalogs.copy(CatalogA)
class CatalogA1(CatalogA):
p11 = providers.Value(1)
p13 = providers.Value(11)
@catalogs.copy(CatalogA)
class CatalogA2(CatalogA):
p11 = providers.Value(2)
p13 = providers.Value(22)
self.assertIsNot(CatalogA.p11, CatalogA1.p11)
self.assertIsNot(CatalogA.p12, CatalogA1.p12)
self.assertIsNot(CatalogA.p11, CatalogA2.p11)
self.assertIsNot(CatalogA.p12, CatalogA2.p12)
self.assertIsNot(CatalogA1.p11, CatalogA2.p11)
self.assertIsNot(CatalogA1.p12, CatalogA2.p12)
self.assertIs(CatalogA.p12.injections[0].injectable, CatalogA.p11)
self.assertIs(CatalogA1.p12.injections[0].injectable, CatalogA1.p11)
self.assertIs(CatalogA2.p12.injections[0].injectable, CatalogA2.p11)
self.assertEqual(CatalogA.p12(), dict(p11=0))
self.assertEqual(CatalogA1.p12(), dict(p11=1))
self.assertEqual(CatalogA2.p12(), dict(p11=2))
self.assertEqual(CatalogA1.p13(), 11)
self.assertEqual(CatalogA2.p13(), 22)