mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2024-11-22 09:36:48 +03:00
Add declarative catalog @copy decorator
This commit is contained in:
parent
43258e5fd9
commit
7cdeede38a
|
@ -123,7 +123,7 @@ class DeclarativeCatalogMetaClass(type):
|
||||||
:rtype: None
|
:rtype: None
|
||||||
"""
|
"""
|
||||||
if is_provider(value):
|
if is_provider(value):
|
||||||
setattr(cls._catalog, name, value)
|
cls.bind_provider(name, value, _set_as_attribute=False)
|
||||||
return super(DeclarativeCatalogMetaClass, cls).__setattr__(name, value)
|
return super(DeclarativeCatalogMetaClass, cls).__setattr__(name, value)
|
||||||
|
|
||||||
def __delattr__(cls, name):
|
def __delattr__(cls, name):
|
||||||
|
@ -348,7 +348,8 @@ class DeclarativeCatalog(object):
|
||||||
get = get_provider # Backward compatibility for versions < 0.11.*
|
get = get_provider # Backward compatibility for versions < 0.11.*
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def bind_provider(cls, name, provider):
|
def bind_provider(cls, name, provider, force=False,
|
||||||
|
_set_as_attribute=True):
|
||||||
"""Bind provider to catalog with specified name.
|
"""Bind provider to catalog with specified name.
|
||||||
|
|
||||||
:param name: Name of the provider.
|
:param name: Name of the provider.
|
||||||
|
@ -357,14 +358,26 @@ class DeclarativeCatalog(object):
|
||||||
:param provider: Provider instance.
|
:param provider: Provider instance.
|
||||||
:type provider: :py:class:`dependency_injector.providers.Provider`
|
:type provider: :py:class:`dependency_injector.providers.Provider`
|
||||||
|
|
||||||
|
:param force: Force binding of provider.
|
||||||
|
:type force: bool
|
||||||
|
|
||||||
:raise: :py:exc:`dependency_injector.errors.Error`
|
:raise: :py:exc:`dependency_injector.errors.Error`
|
||||||
|
|
||||||
:rtype: None
|
:rtype: None
|
||||||
"""
|
"""
|
||||||
setattr(cls, name, provider)
|
if cls._catalog.is_provider_bound(provider):
|
||||||
|
bindind_name = cls._catalog.get_provider_bind_name(provider)
|
||||||
|
if bindind_name == name and not force:
|
||||||
|
return
|
||||||
|
|
||||||
|
cls._catalog.bind_provider(name, provider, force)
|
||||||
|
cls.cls_providers[name] = provider
|
||||||
|
|
||||||
|
if _set_as_attribute:
|
||||||
|
setattr(cls, name, provider)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def bind_providers(cls, providers):
|
def bind_providers(cls, providers, force=False):
|
||||||
"""Bind providers dictionary to catalog.
|
"""Bind providers dictionary to catalog.
|
||||||
|
|
||||||
:param providers: Dictionary of providers, where key is a name
|
:param providers: Dictionary of providers, where key is a name
|
||||||
|
@ -372,12 +385,15 @@ class DeclarativeCatalog(object):
|
||||||
:type providers:
|
:type providers:
|
||||||
dict[str, :py:class:`dependency_injector.providers.Provider`]
|
dict[str, :py:class:`dependency_injector.providers.Provider`]
|
||||||
|
|
||||||
|
:param force: Force binding of providers.
|
||||||
|
:type force: bool
|
||||||
|
|
||||||
:raise: :py:exc:`dependency_injector.errors.Error`
|
:raise: :py:exc:`dependency_injector.errors.Error`
|
||||||
|
|
||||||
:rtype: None
|
:rtype: None
|
||||||
"""
|
"""
|
||||||
for name, provider in six.iteritems(providers):
|
for name, provider in six.iteritems(providers):
|
||||||
setattr(cls, name, provider)
|
cls.bind_provider(name, provider, force=force)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def has_provider(cls, name):
|
def has_provider(cls, name):
|
||||||
|
@ -402,6 +418,7 @@ class DeclarativeCatalog(object):
|
||||||
:rtype: None
|
:rtype: None
|
||||||
"""
|
"""
|
||||||
delattr(cls, name)
|
delattr(cls, name)
|
||||||
|
del cls.cls_providers[name]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __getattr__(cls, name): # pragma: no cover
|
def __getattr__(cls, name): # pragma: no cover
|
||||||
|
|
|
@ -4,42 +4,46 @@ import six
|
||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
from dependency_injector.errors import UndefinedProviderError
|
||||||
|
|
||||||
|
|
||||||
def copy(catalog):
|
def copy(catalog):
|
||||||
""":py:class:`DeclarativeCatalog` copying decorator.
|
""":py:class:`DeclarativeCatalog` copying decorator.
|
||||||
|
|
||||||
|
This decorator copy all providers from provided catalog to decorated one.
|
||||||
|
If one of the decorated catalog providers matches to source catalog
|
||||||
|
providers by name, it would be replaced by reference.
|
||||||
|
|
||||||
:param catalog: Catalog that should be copied by decorated catalog.
|
:param catalog: Catalog that should be copied by decorated catalog.
|
||||||
:type catalog: :py:class:`dependency_injector.catalogs.DeclarativeCatalog`
|
:type catalog: :py:class:`dependency_injector.catalogs.DeclarativeCatalog`
|
||||||
| :py:class:`dependency_injector.catalogs.DynamicCatalog`
|
|
||||||
|
|
||||||
:return: Declarative catalog's copying decorator.
|
:return: Declarative catalog's copying decorator.
|
||||||
:rtype:
|
:rtype:
|
||||||
callable(:py:class:`dependency_injector.catalogs.DeclarativeCatalog`)
|
callable(:py:class:`DeclarativeCatalog`)
|
||||||
"""
|
"""
|
||||||
def decorator(overriding_catalog):
|
def decorator(copied_catalog):
|
||||||
"""Overriding decorator.
|
"""Copying decorator.
|
||||||
|
|
||||||
:param catalog: Decorated catalog.
|
:param copied_catalog: Decorated catalog.
|
||||||
:type catalog:
|
:type copied_catalog: :py:class:`DeclarativeCatalog`
|
||||||
:py:class:`dependency_injector.catalogs.DeclarativeCatalog`
|
|
||||||
|
|
||||||
:return: Decorated catalog.
|
:return: Decorated catalog.
|
||||||
:rtype:
|
:rtype:
|
||||||
:py:class:`dependency_injector.catalogs.DeclarativeCatalog`
|
:py:class:`DeclarativeCatalog`
|
||||||
"""
|
"""
|
||||||
memo = dict()
|
memo = dict()
|
||||||
|
for name, provider in six.iteritems(copied_catalog.cls_providers):
|
||||||
|
try:
|
||||||
|
source_provider = catalog.get_provider(name)
|
||||||
|
except UndefinedProviderError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
memo[id(source_provider)] = provider
|
||||||
|
|
||||||
for name, provider in six.iteritems(overriding_catalog.providers):
|
copied_catalog.bind_providers(deepcopy(catalog.providers, memo),
|
||||||
memo[id(catalog.get_provider(name))] = provider
|
force=True)
|
||||||
|
|
||||||
dynamic_catalog_copy = deepcopy(catalog._catalog, memo)
|
return copied_catalog
|
||||||
|
|
||||||
print dynamic_catalog_copy.providers
|
|
||||||
|
|
||||||
for name, provider in six.iteritems(dynamic_catalog_copy.providers):
|
|
||||||
overriding_catalog.bind_provider(name, provider)
|
|
||||||
|
|
||||||
return overriding_catalog
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,7 +51,7 @@ def override(catalog):
|
||||||
""":py:class:`DeclarativeCatalog` overriding decorator.
|
""":py:class:`DeclarativeCatalog` overriding decorator.
|
||||||
|
|
||||||
:param catalog: Catalog that should be overridden by decorated catalog.
|
:param catalog: Catalog that should be overridden by decorated catalog.
|
||||||
:type catalog: :py:class:`DeclarativeCatalog` | :py:class:`DynamicCatalog`
|
:type catalog: :py:class:`DeclarativeCatalog`
|
||||||
|
|
||||||
:return: Declarative catalog's overriding decorator.
|
:return: Declarative catalog's overriding decorator.
|
||||||
:rtype: callable(:py:class:`DeclarativeCatalog`)
|
:rtype: callable(:py:class:`DeclarativeCatalog`)
|
||||||
|
|
|
@ -29,6 +29,17 @@ class DeclarativeCatalogTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_cls_providers(self):
|
def test_cls_providers(self):
|
||||||
"""Test `di.DeclarativeCatalog.cls_providers` contents."""
|
"""Test `di.DeclarativeCatalog.cls_providers` contents."""
|
||||||
|
class CatalogA(catalogs.DeclarativeCatalog):
|
||||||
|
"""Test catalog A."""
|
||||||
|
|
||||||
|
p11 = providers.Provider()
|
||||||
|
p12 = providers.Provider()
|
||||||
|
|
||||||
|
class CatalogB(CatalogA):
|
||||||
|
"""Test catalog B."""
|
||||||
|
|
||||||
|
p21 = providers.Provider()
|
||||||
|
p22 = providers.Provider()
|
||||||
self.assertDictEqual(CatalogA.cls_providers,
|
self.assertDictEqual(CatalogA.cls_providers,
|
||||||
dict(p11=CatalogA.p11,
|
dict(p11=CatalogA.p11,
|
||||||
p12=CatalogA.p12))
|
p12=CatalogA.p12))
|
||||||
|
@ -71,6 +82,14 @@ class DeclarativeCatalogTests(unittest.TestCase):
|
||||||
del CatalogA.px
|
del CatalogA.px
|
||||||
del CatalogA.py
|
del CatalogA.py
|
||||||
|
|
||||||
|
def test_bind_existing_provider(self):
|
||||||
|
"""Test setting of provider via bind_provider() to catalog."""
|
||||||
|
with self.assertRaises(errors.Error):
|
||||||
|
CatalogA.p11 = providers.Provider()
|
||||||
|
|
||||||
|
with self.assertRaises(errors.Error):
|
||||||
|
CatalogA.bind_provider('p11', providers.Provider())
|
||||||
|
|
||||||
def test_bind_provider_with_valid_provided_type(self):
|
def test_bind_provider_with_valid_provided_type(self):
|
||||||
"""Test setting of provider with provider type restriction."""
|
"""Test setting of provider with provider type restriction."""
|
||||||
class SomeProvider(providers.Provider):
|
class SomeProvider(providers.Provider):
|
||||||
|
@ -350,3 +369,62 @@ class TestCatalogWithProvidingCallbacks(unittest.TestCase):
|
||||||
auth_service = Services.auth()
|
auth_service = Services.auth()
|
||||||
|
|
||||||
self.assertIsInstance(auth_service, ExtendedAuthService)
|
self.assertIsInstance(auth_service, ExtendedAuthService)
|
||||||
|
|
||||||
|
|
||||||
|
class CopyingTests(unittest.TestCase):
|
||||||
|
"""Declarative catalogs copying tests."""
|
||||||
|
|
||||||
|
def test_copy(self):
|
||||||
|
"""Test catalog providers copying."""
|
||||||
|
@catalogs.copy(CatalogA)
|
||||||
|
class CatalogA1(CatalogA):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@catalogs.copy(CatalogA)
|
||||||
|
class CatalogA2(CatalogA):
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.assertIsNot(CatalogA.p11, CatalogA1.p11)
|
||||||
|
self.assertIsNot(CatalogA.p12, CatalogA1.p12)
|
||||||
|
|
||||||
|
self.assertIsNot(CatalogA.p11, CatalogA2.p11)
|
||||||
|
self.assertIsNot(CatalogA.p12, CatalogA2.p12)
|
||||||
|
|
||||||
|
self.assertIsNot(CatalogA1.p11, CatalogA2.p11)
|
||||||
|
self.assertIsNot(CatalogA1.p12, CatalogA2.p12)
|
||||||
|
|
||||||
|
def test_copy_with_replacing(self):
|
||||||
|
"""Test catalog providers copying."""
|
||||||
|
class CatalogA(catalogs.DeclarativeCatalog):
|
||||||
|
p11 = providers.Value(0)
|
||||||
|
p12 = providers.Factory(dict, p11=p11)
|
||||||
|
|
||||||
|
@catalogs.copy(CatalogA)
|
||||||
|
class CatalogA1(CatalogA):
|
||||||
|
p11 = providers.Value(1)
|
||||||
|
p13 = providers.Value(11)
|
||||||
|
|
||||||
|
@catalogs.copy(CatalogA)
|
||||||
|
class CatalogA2(CatalogA):
|
||||||
|
p11 = providers.Value(2)
|
||||||
|
p13 = providers.Value(22)
|
||||||
|
|
||||||
|
self.assertIsNot(CatalogA.p11, CatalogA1.p11)
|
||||||
|
self.assertIsNot(CatalogA.p12, CatalogA1.p12)
|
||||||
|
|
||||||
|
self.assertIsNot(CatalogA.p11, CatalogA2.p11)
|
||||||
|
self.assertIsNot(CatalogA.p12, CatalogA2.p12)
|
||||||
|
|
||||||
|
self.assertIsNot(CatalogA1.p11, CatalogA2.p11)
|
||||||
|
self.assertIsNot(CatalogA1.p12, CatalogA2.p12)
|
||||||
|
|
||||||
|
self.assertIs(CatalogA.p12.injections[0].injectable, CatalogA.p11)
|
||||||
|
self.assertIs(CatalogA1.p12.injections[0].injectable, CatalogA1.p11)
|
||||||
|
self.assertIs(CatalogA2.p12.injections[0].injectable, CatalogA2.p11)
|
||||||
|
|
||||||
|
self.assertEqual(CatalogA.p12(), dict(p11=0))
|
||||||
|
self.assertEqual(CatalogA1.p12(), dict(p11=1))
|
||||||
|
self.assertEqual(CatalogA2.p12(), dict(p11=2))
|
||||||
|
|
||||||
|
self.assertEqual(CatalogA1.p13(), 11)
|
||||||
|
self.assertEqual(CatalogA2.p13(), 22)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user