diff --git a/dependency_injector/catalogs/declarative.py b/dependency_injector/catalogs/declarative.py index 0ef0da13..c6fa485d 100644 --- a/dependency_injector/catalogs/declarative.py +++ b/dependency_injector/catalogs/declarative.py @@ -123,7 +123,7 @@ class DeclarativeCatalogMetaClass(type): :rtype: None """ if is_provider(value): - setattr(cls._catalog, name, value) + cls.bind_provider(name, value, _set_as_attribute=False) return super(DeclarativeCatalogMetaClass, cls).__setattr__(name, value) def __delattr__(cls, name): @@ -348,7 +348,8 @@ class DeclarativeCatalog(object): get = get_provider # Backward compatibility for versions < 0.11.* @classmethod - def bind_provider(cls, name, provider): + def bind_provider(cls, name, provider, force=False, + _set_as_attribute=True): """Bind provider to catalog with specified name. :param name: Name of the provider. @@ -357,14 +358,26 @@ class DeclarativeCatalog(object): :param provider: Provider instance. :type provider: :py:class:`dependency_injector.providers.Provider` + :param force: Force binding of provider. + :type force: bool + :raise: :py:exc:`dependency_injector.errors.Error` :rtype: None """ - setattr(cls, name, provider) + if cls._catalog.is_provider_bound(provider): + bindind_name = cls._catalog.get_provider_bind_name(provider) + if bindind_name == name and not force: + return + + cls._catalog.bind_provider(name, provider, force) + cls.cls_providers[name] = provider + + if _set_as_attribute: + setattr(cls, name, provider) @classmethod - def bind_providers(cls, providers): + def bind_providers(cls, providers, force=False): """Bind providers dictionary to catalog. :param providers: Dictionary of providers, where key is a name @@ -372,12 +385,15 @@ class DeclarativeCatalog(object): :type providers: dict[str, :py:class:`dependency_injector.providers.Provider`] + :param force: Force binding of providers. + :type force: bool + :raise: :py:exc:`dependency_injector.errors.Error` :rtype: None """ for name, provider in six.iteritems(providers): - setattr(cls, name, provider) + cls.bind_provider(name, provider, force=force) @classmethod def has_provider(cls, name): @@ -402,6 +418,7 @@ class DeclarativeCatalog(object): :rtype: None """ delattr(cls, name) + del cls.cls_providers[name] @classmethod def __getattr__(cls, name): # pragma: no cover diff --git a/dependency_injector/catalogs/utils.py b/dependency_injector/catalogs/utils.py index fc7f30dd..367e1b5f 100644 --- a/dependency_injector/catalogs/utils.py +++ b/dependency_injector/catalogs/utils.py @@ -4,42 +4,46 @@ import six from copy import deepcopy +from dependency_injector.errors import UndefinedProviderError + def copy(catalog): """:py:class:`DeclarativeCatalog` copying decorator. + This decorator copy all providers from provided catalog to decorated one. + If one of the decorated catalog providers matches to source catalog + providers by name, it would be replaced by reference. + :param catalog: Catalog that should be copied by decorated catalog. :type catalog: :py:class:`dependency_injector.catalogs.DeclarativeCatalog` - | :py:class:`dependency_injector.catalogs.DynamicCatalog` :return: Declarative catalog's copying decorator. :rtype: - callable(:py:class:`dependency_injector.catalogs.DeclarativeCatalog`) + callable(:py:class:`DeclarativeCatalog`) """ - def decorator(overriding_catalog): - """Overriding decorator. + def decorator(copied_catalog): + """Copying decorator. - :param catalog: Decorated catalog. - :type catalog: - :py:class:`dependency_injector.catalogs.DeclarativeCatalog` + :param copied_catalog: Decorated catalog. + :type copied_catalog: :py:class:`DeclarativeCatalog` :return: Decorated catalog. :rtype: - :py:class:`dependency_injector.catalogs.DeclarativeCatalog` + :py:class:`DeclarativeCatalog` """ memo = dict() + for name, provider in six.iteritems(copied_catalog.cls_providers): + try: + source_provider = catalog.get_provider(name) + except UndefinedProviderError: + pass + else: + memo[id(source_provider)] = provider - for name, provider in six.iteritems(overriding_catalog.providers): - memo[id(catalog.get_provider(name))] = provider + copied_catalog.bind_providers(deepcopy(catalog.providers, memo), + force=True) - dynamic_catalog_copy = deepcopy(catalog._catalog, memo) - - print dynamic_catalog_copy.providers - - for name, provider in six.iteritems(dynamic_catalog_copy.providers): - overriding_catalog.bind_provider(name, provider) - - return overriding_catalog + return copied_catalog return decorator @@ -47,7 +51,7 @@ def override(catalog): """:py:class:`DeclarativeCatalog` overriding decorator. :param catalog: Catalog that should be overridden by decorated catalog. - :type catalog: :py:class:`DeclarativeCatalog` | :py:class:`DynamicCatalog` + :type catalog: :py:class:`DeclarativeCatalog` :return: Declarative catalog's overriding decorator. :rtype: callable(:py:class:`DeclarativeCatalog`) diff --git a/tests/catalogs/test_declarative.py b/tests/catalogs/test_declarative.py index 3871bbab..dcd8243c 100644 --- a/tests/catalogs/test_declarative.py +++ b/tests/catalogs/test_declarative.py @@ -29,6 +29,17 @@ class DeclarativeCatalogTests(unittest.TestCase): def test_cls_providers(self): """Test `di.DeclarativeCatalog.cls_providers` contents.""" + class CatalogA(catalogs.DeclarativeCatalog): + """Test catalog A.""" + + p11 = providers.Provider() + p12 = providers.Provider() + + class CatalogB(CatalogA): + """Test catalog B.""" + + p21 = providers.Provider() + p22 = providers.Provider() self.assertDictEqual(CatalogA.cls_providers, dict(p11=CatalogA.p11, p12=CatalogA.p12)) @@ -71,6 +82,14 @@ class DeclarativeCatalogTests(unittest.TestCase): del CatalogA.px del CatalogA.py + def test_bind_existing_provider(self): + """Test setting of provider via bind_provider() to catalog.""" + with self.assertRaises(errors.Error): + CatalogA.p11 = providers.Provider() + + with self.assertRaises(errors.Error): + CatalogA.bind_provider('p11', providers.Provider()) + def test_bind_provider_with_valid_provided_type(self): """Test setting of provider with provider type restriction.""" class SomeProvider(providers.Provider): @@ -350,3 +369,62 @@ class TestCatalogWithProvidingCallbacks(unittest.TestCase): auth_service = Services.auth() self.assertIsInstance(auth_service, ExtendedAuthService) + + +class CopyingTests(unittest.TestCase): + """Declarative catalogs copying tests.""" + + def test_copy(self): + """Test catalog providers copying.""" + @catalogs.copy(CatalogA) + class CatalogA1(CatalogA): + pass + + @catalogs.copy(CatalogA) + class CatalogA2(CatalogA): + pass + + self.assertIsNot(CatalogA.p11, CatalogA1.p11) + self.assertIsNot(CatalogA.p12, CatalogA1.p12) + + self.assertIsNot(CatalogA.p11, CatalogA2.p11) + self.assertIsNot(CatalogA.p12, CatalogA2.p12) + + self.assertIsNot(CatalogA1.p11, CatalogA2.p11) + self.assertIsNot(CatalogA1.p12, CatalogA2.p12) + + def test_copy_with_replacing(self): + """Test catalog providers copying.""" + class CatalogA(catalogs.DeclarativeCatalog): + p11 = providers.Value(0) + p12 = providers.Factory(dict, p11=p11) + + @catalogs.copy(CatalogA) + class CatalogA1(CatalogA): + p11 = providers.Value(1) + p13 = providers.Value(11) + + @catalogs.copy(CatalogA) + class CatalogA2(CatalogA): + p11 = providers.Value(2) + p13 = providers.Value(22) + + self.assertIsNot(CatalogA.p11, CatalogA1.p11) + self.assertIsNot(CatalogA.p12, CatalogA1.p12) + + self.assertIsNot(CatalogA.p11, CatalogA2.p11) + self.assertIsNot(CatalogA.p12, CatalogA2.p12) + + self.assertIsNot(CatalogA1.p11, CatalogA2.p11) + self.assertIsNot(CatalogA1.p12, CatalogA2.p12) + + self.assertIs(CatalogA.p12.injections[0].injectable, CatalogA.p11) + self.assertIs(CatalogA1.p12.injections[0].injectable, CatalogA1.p11) + self.assertIs(CatalogA2.p12.injections[0].injectable, CatalogA2.p11) + + self.assertEqual(CatalogA.p12(), dict(p11=0)) + self.assertEqual(CatalogA1.p12(), dict(p11=1)) + self.assertEqual(CatalogA2.p12(), dict(p11=2)) + + self.assertEqual(CatalogA1.p13(), 11) + self.assertEqual(CatalogA2.p13(), 22)