Add declarative catalog @copy decorator

This commit is contained in:
Roman Mogilatov 2016-04-10 16:52:37 +03:00
parent 43258e5fd9
commit 7cdeede38a
3 changed files with 123 additions and 24 deletions

View File

@ -123,7 +123,7 @@ class DeclarativeCatalogMetaClass(type):
:rtype: None
"""
if is_provider(value):
setattr(cls._catalog, name, value)
cls.bind_provider(name, value, _set_as_attribute=False)
return super(DeclarativeCatalogMetaClass, cls).__setattr__(name, value)
def __delattr__(cls, name):
@ -348,7 +348,8 @@ class DeclarativeCatalog(object):
get = get_provider # Backward compatibility for versions < 0.11.*
@classmethod
def bind_provider(cls, name, provider):
def bind_provider(cls, name, provider, force=False,
_set_as_attribute=True):
"""Bind provider to catalog with specified name.
:param name: Name of the provider.
@ -357,14 +358,26 @@ class DeclarativeCatalog(object):
:param provider: Provider instance.
:type provider: :py:class:`dependency_injector.providers.Provider`
:param force: Force binding of provider.
:type force: bool
:raise: :py:exc:`dependency_injector.errors.Error`
:rtype: None
"""
if cls._catalog.is_provider_bound(provider):
bindind_name = cls._catalog.get_provider_bind_name(provider)
if bindind_name == name and not force:
return
cls._catalog.bind_provider(name, provider, force)
cls.cls_providers[name] = provider
if _set_as_attribute:
setattr(cls, name, provider)
@classmethod
def bind_providers(cls, providers):
def bind_providers(cls, providers, force=False):
"""Bind providers dictionary to catalog.
:param providers: Dictionary of providers, where key is a name
@ -372,12 +385,15 @@ class DeclarativeCatalog(object):
:type providers:
dict[str, :py:class:`dependency_injector.providers.Provider`]
:param force: Force binding of providers.
:type force: bool
:raise: :py:exc:`dependency_injector.errors.Error`
:rtype: None
"""
for name, provider in six.iteritems(providers):
setattr(cls, name, provider)
cls.bind_provider(name, provider, force=force)
@classmethod
def has_provider(cls, name):
@ -402,6 +418,7 @@ class DeclarativeCatalog(object):
:rtype: None
"""
delattr(cls, name)
del cls.cls_providers[name]
@classmethod
def __getattr__(cls, name): # pragma: no cover

View File

@ -4,42 +4,46 @@ import six
from copy import deepcopy
from dependency_injector.errors import UndefinedProviderError
def copy(catalog):
""":py:class:`DeclarativeCatalog` copying decorator.
This decorator copy all providers from provided catalog to decorated one.
If one of the decorated catalog providers matches to source catalog
providers by name, it would be replaced by reference.
:param catalog: Catalog that should be copied by decorated catalog.
:type catalog: :py:class:`dependency_injector.catalogs.DeclarativeCatalog`
| :py:class:`dependency_injector.catalogs.DynamicCatalog`
:return: Declarative catalog's copying decorator.
:rtype:
callable(:py:class:`dependency_injector.catalogs.DeclarativeCatalog`)
callable(:py:class:`DeclarativeCatalog`)
"""
def decorator(overriding_catalog):
"""Overriding decorator.
def decorator(copied_catalog):
"""Copying decorator.
:param catalog: Decorated catalog.
:type catalog:
:py:class:`dependency_injector.catalogs.DeclarativeCatalog`
:param copied_catalog: Decorated catalog.
:type copied_catalog: :py:class:`DeclarativeCatalog`
:return: Decorated catalog.
:rtype:
:py:class:`dependency_injector.catalogs.DeclarativeCatalog`
:py:class:`DeclarativeCatalog`
"""
memo = dict()
for name, provider in six.iteritems(copied_catalog.cls_providers):
try:
source_provider = catalog.get_provider(name)
except UndefinedProviderError:
pass
else:
memo[id(source_provider)] = provider
for name, provider in six.iteritems(overriding_catalog.providers):
memo[id(catalog.get_provider(name))] = provider
copied_catalog.bind_providers(deepcopy(catalog.providers, memo),
force=True)
dynamic_catalog_copy = deepcopy(catalog._catalog, memo)
print dynamic_catalog_copy.providers
for name, provider in six.iteritems(dynamic_catalog_copy.providers):
overriding_catalog.bind_provider(name, provider)
return overriding_catalog
return copied_catalog
return decorator
@ -47,7 +51,7 @@ def override(catalog):
""":py:class:`DeclarativeCatalog` overriding decorator.
:param catalog: Catalog that should be overridden by decorated catalog.
:type catalog: :py:class:`DeclarativeCatalog` | :py:class:`DynamicCatalog`
:type catalog: :py:class:`DeclarativeCatalog`
:return: Declarative catalog's overriding decorator.
:rtype: callable(:py:class:`DeclarativeCatalog`)

View File

@ -29,6 +29,17 @@ class DeclarativeCatalogTests(unittest.TestCase):
def test_cls_providers(self):
"""Test `di.DeclarativeCatalog.cls_providers` contents."""
class CatalogA(catalogs.DeclarativeCatalog):
"""Test catalog A."""
p11 = providers.Provider()
p12 = providers.Provider()
class CatalogB(CatalogA):
"""Test catalog B."""
p21 = providers.Provider()
p22 = providers.Provider()
self.assertDictEqual(CatalogA.cls_providers,
dict(p11=CatalogA.p11,
p12=CatalogA.p12))
@ -71,6 +82,14 @@ class DeclarativeCatalogTests(unittest.TestCase):
del CatalogA.px
del CatalogA.py
def test_bind_existing_provider(self):
"""Test setting of provider via bind_provider() to catalog."""
with self.assertRaises(errors.Error):
CatalogA.p11 = providers.Provider()
with self.assertRaises(errors.Error):
CatalogA.bind_provider('p11', providers.Provider())
def test_bind_provider_with_valid_provided_type(self):
"""Test setting of provider with provider type restriction."""
class SomeProvider(providers.Provider):
@ -350,3 +369,62 @@ class TestCatalogWithProvidingCallbacks(unittest.TestCase):
auth_service = Services.auth()
self.assertIsInstance(auth_service, ExtendedAuthService)
class CopyingTests(unittest.TestCase):
"""Declarative catalogs copying tests."""
def test_copy(self):
"""Test catalog providers copying."""
@catalogs.copy(CatalogA)
class CatalogA1(CatalogA):
pass
@catalogs.copy(CatalogA)
class CatalogA2(CatalogA):
pass
self.assertIsNot(CatalogA.p11, CatalogA1.p11)
self.assertIsNot(CatalogA.p12, CatalogA1.p12)
self.assertIsNot(CatalogA.p11, CatalogA2.p11)
self.assertIsNot(CatalogA.p12, CatalogA2.p12)
self.assertIsNot(CatalogA1.p11, CatalogA2.p11)
self.assertIsNot(CatalogA1.p12, CatalogA2.p12)
def test_copy_with_replacing(self):
"""Test catalog providers copying."""
class CatalogA(catalogs.DeclarativeCatalog):
p11 = providers.Value(0)
p12 = providers.Factory(dict, p11=p11)
@catalogs.copy(CatalogA)
class CatalogA1(CatalogA):
p11 = providers.Value(1)
p13 = providers.Value(11)
@catalogs.copy(CatalogA)
class CatalogA2(CatalogA):
p11 = providers.Value(2)
p13 = providers.Value(22)
self.assertIsNot(CatalogA.p11, CatalogA1.p11)
self.assertIsNot(CatalogA.p12, CatalogA1.p12)
self.assertIsNot(CatalogA.p11, CatalogA2.p11)
self.assertIsNot(CatalogA.p12, CatalogA2.p12)
self.assertIsNot(CatalogA1.p11, CatalogA2.p11)
self.assertIsNot(CatalogA1.p12, CatalogA2.p12)
self.assertIs(CatalogA.p12.injections[0].injectable, CatalogA.p11)
self.assertIs(CatalogA1.p12.injections[0].injectable, CatalogA1.p11)
self.assertIs(CatalogA2.p12.injections[0].injectable, CatalogA2.p11)
self.assertEqual(CatalogA.p12(), dict(p11=0))
self.assertEqual(CatalogA1.p12(), dict(p11=1))
self.assertEqual(CatalogA2.p12(), dict(p11=2))
self.assertEqual(CatalogA1.p13(), 11)
self.assertEqual(CatalogA2.p13(), 22)