diff --git a/dependency_injector/catalogs/__init__.py b/dependency_injector/catalogs/__init__.py index 7b2baae2..a7405e76 100644 --- a/dependency_injector/catalogs/__init__.py +++ b/dependency_injector/catalogs/__init__.py @@ -1,30 +1,16 @@ """Dependency injector catalogs package.""" from dependency_injector.catalogs.bundle import CatalogBundle - from dependency_injector.catalogs.dynamic import DynamicCatalog - from dependency_injector.catalogs.declarative import ( DeclarativeCatalogMetaClass, DeclarativeCatalog, AbstractCatalog, ) - - -def override(catalog): - """:py:class:`DeclarativeCatalog` overriding decorator. - - :param catalog: Catalog that should be overridden by decorated catalog. - :type catalog: :py:class:`DeclarativeCatalog` | :py:class:`DynamicCatalog` - - :return: Declarative catalog's overriding decorator. - :rtype: callable(:py:class:`DeclarativeCatalog`) - """ - def decorator(overriding_catalog): - """Overriding decorator.""" - catalog.override(overriding_catalog) - return overriding_catalog - return decorator +from dependency_injector.catalogs.utils import ( + copy, + override +) __all__ = ( @@ -33,5 +19,6 @@ __all__ = ( 'DeclarativeCatalogMetaClass', 'DeclarativeCatalog', 'AbstractCatalog', + 'copy', 'override', ) diff --git a/dependency_injector/catalogs/declarative.py b/dependency_injector/catalogs/declarative.py index c1ad3310..c6fa485d 100644 --- a/dependency_injector/catalogs/declarative.py +++ b/dependency_injector/catalogs/declarative.py @@ -4,13 +4,11 @@ import six from dependency_injector.catalogs.dynamic import DynamicCatalog from dependency_injector.catalogs.bundle import CatalogBundle - from dependency_injector.utils import ( is_provider, is_catalog, is_declarative_catalog, ) - from dependency_injector.errors import ( Error, UndefinedProviderError, @@ -125,7 +123,7 @@ class DeclarativeCatalogMetaClass(type): :rtype: None """ if is_provider(value): - setattr(cls._catalog, name, value) + cls.bind_provider(name, value, _set_as_attribute=False) return super(DeclarativeCatalogMetaClass, cls).__setattr__(name, value) def __delattr__(cls, name): @@ -350,7 +348,8 @@ class DeclarativeCatalog(object): get = get_provider # Backward compatibility for versions < 0.11.* @classmethod - def bind_provider(cls, name, provider): + def bind_provider(cls, name, provider, force=False, + _set_as_attribute=True): """Bind provider to catalog with specified name. :param name: Name of the provider. @@ -359,14 +358,26 @@ class DeclarativeCatalog(object): :param provider: Provider instance. :type provider: :py:class:`dependency_injector.providers.Provider` + :param force: Force binding of provider. + :type force: bool + :raise: :py:exc:`dependency_injector.errors.Error` :rtype: None """ - setattr(cls, name, provider) + if cls._catalog.is_provider_bound(provider): + bindind_name = cls._catalog.get_provider_bind_name(provider) + if bindind_name == name and not force: + return + + cls._catalog.bind_provider(name, provider, force) + cls.cls_providers[name] = provider + + if _set_as_attribute: + setattr(cls, name, provider) @classmethod - def bind_providers(cls, providers): + def bind_providers(cls, providers, force=False): """Bind providers dictionary to catalog. :param providers: Dictionary of providers, where key is a name @@ -374,12 +385,15 @@ class DeclarativeCatalog(object): :type providers: dict[str, :py:class:`dependency_injector.providers.Provider`] + :param force: Force binding of providers. + :type force: bool + :raise: :py:exc:`dependency_injector.errors.Error` :rtype: None """ for name, provider in six.iteritems(providers): - setattr(cls, name, provider) + cls.bind_provider(name, provider, force=force) @classmethod def has_provider(cls, name): @@ -404,6 +418,7 @@ class DeclarativeCatalog(object): :rtype: None """ delattr(cls, name) + del cls.cls_providers[name] @classmethod def __getattr__(cls, name): # pragma: no cover diff --git a/dependency_injector/catalogs/dynamic.py b/dependency_injector/catalogs/dynamic.py index 402a090b..607b8124 100644 --- a/dependency_injector/catalogs/dynamic.py +++ b/dependency_injector/catalogs/dynamic.py @@ -3,13 +3,11 @@ import six from dependency_injector.catalogs.bundle import CatalogBundle - from dependency_injector.utils import ( is_provider, ensure_is_provider, ensure_is_catalog_bundle, ) - from dependency_injector.errors import ( Error, UndefinedProviderError, @@ -213,7 +211,7 @@ class DynamicCatalog(object): raise UndefinedProviderError('{0} has no provider with such ' 'name - {1}'.format(self, name)) - def bind_provider(self, name, provider): + def bind_provider(self, name, provider, force=False): """Bind provider to catalog with specified name. :param name: Name of the provider. @@ -222,6 +220,9 @@ class DynamicCatalog(object): :param provider: Provider instance. :type provider: :py:class:`dependency_injector.providers.Provider` + :param force: Force binding of provider. + :type force: bool + :raise: :py:exc:`dependency_injector.errors.Error` :rtype: None @@ -233,17 +234,18 @@ class DynamicCatalog(object): raise Error('{0} can contain only {1} instances'.format( self, self.__class__.provider_type)) - if name in self.providers: - raise Error('Catalog {0} already has provider with ' - 'such name - {1}'.format(self, name)) - if provider in self.provider_names: - raise Error('Catalog {0} already has such provider ' - 'instance - {1}'.format(self, provider)) + if not force: + if name in self.providers: + raise Error('Catalog {0} already has provider with ' + 'such name - {1}'.format(self, name)) + if provider in self.provider_names: + raise Error('Catalog {0} already has such provider ' + 'instance - {1}'.format(self, provider)) self.providers[name] = provider self.provider_names[provider] = name - def bind_providers(self, providers): + def bind_providers(self, providers, force=False): """Bind providers dictionary to catalog. :param providers: Dictionary of providers, where key is a name @@ -251,12 +253,15 @@ class DynamicCatalog(object): :type providers: dict[str, :py:class:`dependency_injector.providers.Provider`] + :param force: Force binding of providers. + :type force: bool + :raise: :py:exc:`dependency_injector.errors.Error` :rtype: None """ for name, provider in six.iteritems(providers): - self.bind_provider(name, provider) + self.bind_provider(name, provider, force) def has_provider(self, name): """Check if there is provider with certain name. diff --git a/dependency_injector/catalogs/utils.py b/dependency_injector/catalogs/utils.py new file mode 100644 index 00000000..150c2ffc --- /dev/null +++ b/dependency_injector/catalogs/utils.py @@ -0,0 +1,62 @@ +"""Dependency injector catalog utils.""" + +import six + +from dependency_injector.utils import _copy_providers +from dependency_injector.errors import UndefinedProviderError + + +def copy(catalog): + """:py:class:`DeclarativeCatalog` copying decorator. + + This decorator copy all providers from provided catalog to decorated one. + If one of the decorated catalog providers matches to source catalog + providers by name, it would be replaced by reference. + + :param catalog: Catalog that should be copied by decorated catalog. + :type catalog: :py:class:`dependency_injector.catalogs.DeclarativeCatalog` + + :return: Declarative catalog's copying decorator. + :rtype: + callable(:py:class:`DeclarativeCatalog`) + """ + def decorator(copied_catalog): + """Copying decorator. + + :param copied_catalog: Decorated catalog. + :type copied_catalog: :py:class:`DeclarativeCatalog` + + :return: Decorated catalog. + :rtype: + :py:class:`DeclarativeCatalog` + """ + memo = dict() + for name, provider in six.iteritems(copied_catalog.cls_providers): + try: + source_provider = catalog.get_provider(name) + except UndefinedProviderError: + pass + else: + memo[id(source_provider)] = provider + + copied_catalog.bind_providers(_copy_providers(catalog.providers, memo), + force=True) + + return copied_catalog + return decorator + + +def override(catalog): + """:py:class:`DeclarativeCatalog` overriding decorator. + + :param catalog: Catalog that should be overridden by decorated catalog. + :type catalog: :py:class:`DeclarativeCatalog` + + :return: Declarative catalog's overriding decorator. + :rtype: callable(:py:class:`DeclarativeCatalog`) + """ + def decorator(overriding_catalog): + """Overriding decorator.""" + catalog.override(overriding_catalog) + return overriding_catalog + return decorator diff --git a/dependency_injector/providers/__init__.py b/dependency_injector/providers/__init__.py index 4d927de8..99324b3c 100644 --- a/dependency_injector/providers/__init__.py +++ b/dependency_injector/providers/__init__.py @@ -6,8 +6,6 @@ from dependency_injector.providers.base import ( Static, StaticProvider, ExternalDependency, - OverridingContext, - override, ) from dependency_injector.providers.callable import ( Callable, @@ -29,6 +27,10 @@ from dependency_injector.providers.config import ( Config, ChildConfig, ) +from dependency_injector.providers.utils import ( + OverridingContext, + override, +) __all__ = ( diff --git a/dependency_injector/providers/base.py b/dependency_injector/providers/base.py index 093e02fe..ba21300d 100644 --- a/dependency_injector/providers/base.py +++ b/dependency_injector/providers/base.py @@ -2,8 +2,8 @@ import six +from dependency_injector.providers.utils import OverridingContext from dependency_injector.errors import Error - from dependency_injector.utils import ( is_provider, ensure_is_provider, @@ -356,67 +356,3 @@ class ExternalDependency(Provider): return represent_provider(provider=self, provides=self.instance_of) __repr__ = __str__ - - -class OverridingContext(object): - """Provider overriding context. - - :py:class:`OverridingContext` is used by :py:meth:`Provider.override` for - implemeting ``with`` contexts. When :py:class:`OverridingContext` is - closed, overriding that was created in this context is dropped also. - - .. code-block:: python - - with provider.override(another_provider): - assert provider.is_overridden - assert not provider.is_overridden - """ - - def __init__(self, overridden, overriding): - """Initializer. - - :param overridden: Overridden provider. - :type overridden: :py:class:`Provider` - - :param overriding: Overriding provider. - :type overriding: :py:class:`Provider` - """ - self.overridden = overridden - self.overriding = overriding - - def __enter__(self): - """Do nothing.""" - return self.overriding - - def __exit__(self, *_): - """Exit overriding context.""" - self.overridden.reset_last_overriding() - - -def override(overridden): - """Decorator for overriding providers. - - This decorator overrides ``overridden`` provider by decorated one. - - .. code-block:: python - - @Factory - class SomeClass(object): - pass - - - @override(SomeClass) - @Factory - class ExtendedSomeClass(SomeClass.cls): - pass - - :param overridden: Provider that should be overridden. - :type overridden: :py:class:`Provider` - - :return: Overriding provider. - :rtype: :py:class:`Provider` - """ - def decorator(overriding): - overridden.override(overriding) - return overriding - return decorator diff --git a/dependency_injector/providers/callable.py b/dependency_injector/providers/callable.py index 65b9abf8..1a21f6ab 100644 --- a/dependency_injector/providers/callable.py +++ b/dependency_injector/providers/callable.py @@ -3,14 +3,11 @@ import six from dependency_injector.providers.base import Provider - from dependency_injector.injections import ( _parse_args_injections, _parse_kwargs_injections, ) - from dependency_injector.utils import represent_provider - from dependency_injector.errors import Error diff --git a/dependency_injector/providers/config.py b/dependency_injector/providers/config.py index c3b004fa..f629f4df 100644 --- a/dependency_injector/providers/config.py +++ b/dependency_injector/providers/config.py @@ -3,10 +3,8 @@ import six from dependency_injector.providers.base import Provider - -from dependency_injector.errors import Error - from dependency_injector.utils import represent_provider +from dependency_injector.errors import Error @six.python_2_unicode_compatible diff --git a/dependency_injector/providers/creational.py b/dependency_injector/providers/creational.py index 322cc519..06a2419c 100644 --- a/dependency_injector/providers/creational.py +++ b/dependency_injector/providers/creational.py @@ -1,13 +1,11 @@ """Dependency injector creational providers.""" from dependency_injector.providers.callable import Callable - from dependency_injector.utils import ( is_attribute_injection, is_method_injection, GLOBAL_LOCK, ) - from dependency_injector.errors import Error diff --git a/dependency_injector/providers/utils.py b/dependency_injector/providers/utils.py new file mode 100644 index 00000000..40b9278e --- /dev/null +++ b/dependency_injector/providers/utils.py @@ -0,0 +1,65 @@ +"""Dependency injector provider utils.""" + + +class OverridingContext(object): + """Provider overriding context. + + :py:class:`OverridingContext` is used by :py:meth:`Provider.override` for + implemeting ``with`` contexts. When :py:class:`OverridingContext` is + closed, overriding that was created in this context is dropped also. + + .. code-block:: python + + with provider.override(another_provider): + assert provider.is_overridden + assert not provider.is_overridden + """ + + def __init__(self, overridden, overriding): + """Initializer. + + :param overridden: Overridden provider. + :type overridden: :py:class:`Provider` + + :param overriding: Overriding provider. + :type overriding: :py:class:`Provider` + """ + self.overridden = overridden + self.overriding = overriding + + def __enter__(self): + """Do nothing.""" + return self.overriding + + def __exit__(self, *_): + """Exit overriding context.""" + self.overridden.reset_last_overriding() + + +def override(overridden): + """Decorator for overriding providers. + + This decorator overrides ``overridden`` provider by decorated one. + + .. code-block:: python + + @Factory + class SomeClass(object): + pass + + + @override(SomeClass) + @Factory + class ExtendedSomeClass(SomeClass.cls): + pass + + :param overridden: Provider that should be overridden. + :type overridden: :py:class:`Provider` + + :return: Overriding provider. + :rtype: :py:class:`Provider` + """ + def decorator(overriding): + overridden.override(overriding) + return overriding + return decorator diff --git a/dependency_injector/utils.py b/dependency_injector/utils.py index c677d162..15941129 100644 --- a/dependency_injector/utils.py +++ b/dependency_injector/utils.py @@ -1,6 +1,8 @@ """Utils module.""" import sys +import copy +import types import threading import six @@ -20,6 +22,12 @@ if _IS_PYPY or six.PY3: # pragma: no cover else: # pragma: no cover _OBJECT_INIT = None +if six.PY2: # pragma: no cover + copy._deepcopy_dispatch[types.MethodType] = \ + lambda obj, memo: type(obj)(obj.im_func, + copy.deepcopy(obj.im_self, memo), + obj.im_class) + def is_provider(instance): """Check if instance is provider instance. @@ -245,3 +253,8 @@ def fetch_cls_init(cls): return None else: return cls_init + + +def _copy_providers(providers, memo=None): + """Make full copy of providers dictionary.""" + return copy.deepcopy(providers, memo) diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index 3f5b15ef..e30ce304 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -9,7 +9,7 @@ follows `Semantic versioning`_ Development version ------------------- -- No features. +- Add ``@copy`` decorator for copying declarative catalog providers. 1.15.2 ------ diff --git a/examples/stories/movie_lister/apps_db_csv.py b/examples/stories/movie_lister/apps_db_csv.py new file mode 100644 index 00000000..e6de7745 --- /dev/null +++ b/examples/stories/movie_lister/apps_db_csv.py @@ -0,0 +1,76 @@ +"""A naive example of dependency injection in Python. + +Example implementation of dependency injection in Python from Martin Fowler's +article about dependency injection and inversion of control: + +http://www.martinfowler.com/articles/injection.html + +This mini application uses ``movies`` library, that is configured to work with +csv file movies database. +""" + +import sqlite3 + +from dependency_injector import catalogs +from dependency_injector import providers +from dependency_injector import injections + +from movies import MoviesModule +from movies import finders + +from settings import MOVIES_CSV_PATH +from settings import MOVIES_DB_PATH + + +class ApplicationModule(catalogs.DeclarativeCatalog): + """Catalog of application component providers.""" + + database = providers.Singleton(sqlite3.connect, MOVIES_DB_PATH) + + +@catalogs.copy(MoviesModule) +class DbMoviesModule(MoviesModule): + """Customized catalog of movies module component providers.""" + + movie_finder = providers.Factory(finders.SqliteMovieFinder, + *MoviesModule.movie_finder.injections, + database=ApplicationModule.database) + + +@catalogs.copy(MoviesModule) +class CsvMoviesModule(MoviesModule): + """Customized catalog of movies module component providers.""" + + movie_finder = providers.Factory(finders.CsvMovieFinder, + *MoviesModule.movie_finder.injections, + csv_file=MOVIES_CSV_PATH, + delimeter=',') + + +@injections.inject(db_movie_lister=DbMoviesModule.movie_lister) +@injections.inject(csv_movie_lister=CsvMoviesModule.movie_lister) +def main(db_movie_lister, csv_movie_lister): + """Main function. + + This program prints info about all movies that were directed by different + persons and then prints all movies that were released in 2015. + + :param db_movie_lister: Database movie lister instance + :type db_movie_lister: movies.listers.MovieLister + + :param csv_movie_lister: Database movie lister instance + :type csv_movie_lister: movies.listers.MovieLister + """ + print db_movie_lister.movies_directed_by('Francis Lawrence') + print db_movie_lister.movies_directed_by('Patricia Riggen') + print db_movie_lister.movies_directed_by('JJ Abrams') + print db_movie_lister.movies_released_in(2015) + + print csv_movie_lister.movies_directed_by('Francis Lawrence') + print csv_movie_lister.movies_directed_by('Patricia Riggen') + print csv_movie_lister.movies_directed_by('JJ Abrams') + print csv_movie_lister.movies_released_in(2015) + + +if __name__ == '__main__': + main() diff --git a/tests/catalogs/test_declarative.py b/tests/catalogs/test_declarative.py index 3871bbab..dcd8243c 100644 --- a/tests/catalogs/test_declarative.py +++ b/tests/catalogs/test_declarative.py @@ -29,6 +29,17 @@ class DeclarativeCatalogTests(unittest.TestCase): def test_cls_providers(self): """Test `di.DeclarativeCatalog.cls_providers` contents.""" + class CatalogA(catalogs.DeclarativeCatalog): + """Test catalog A.""" + + p11 = providers.Provider() + p12 = providers.Provider() + + class CatalogB(CatalogA): + """Test catalog B.""" + + p21 = providers.Provider() + p22 = providers.Provider() self.assertDictEqual(CatalogA.cls_providers, dict(p11=CatalogA.p11, p12=CatalogA.p12)) @@ -71,6 +82,14 @@ class DeclarativeCatalogTests(unittest.TestCase): del CatalogA.px del CatalogA.py + def test_bind_existing_provider(self): + """Test setting of provider via bind_provider() to catalog.""" + with self.assertRaises(errors.Error): + CatalogA.p11 = providers.Provider() + + with self.assertRaises(errors.Error): + CatalogA.bind_provider('p11', providers.Provider()) + def test_bind_provider_with_valid_provided_type(self): """Test setting of provider with provider type restriction.""" class SomeProvider(providers.Provider): @@ -350,3 +369,62 @@ class TestCatalogWithProvidingCallbacks(unittest.TestCase): auth_service = Services.auth() self.assertIsInstance(auth_service, ExtendedAuthService) + + +class CopyingTests(unittest.TestCase): + """Declarative catalogs copying tests.""" + + def test_copy(self): + """Test catalog providers copying.""" + @catalogs.copy(CatalogA) + class CatalogA1(CatalogA): + pass + + @catalogs.copy(CatalogA) + class CatalogA2(CatalogA): + pass + + self.assertIsNot(CatalogA.p11, CatalogA1.p11) + self.assertIsNot(CatalogA.p12, CatalogA1.p12) + + self.assertIsNot(CatalogA.p11, CatalogA2.p11) + self.assertIsNot(CatalogA.p12, CatalogA2.p12) + + self.assertIsNot(CatalogA1.p11, CatalogA2.p11) + self.assertIsNot(CatalogA1.p12, CatalogA2.p12) + + def test_copy_with_replacing(self): + """Test catalog providers copying.""" + class CatalogA(catalogs.DeclarativeCatalog): + p11 = providers.Value(0) + p12 = providers.Factory(dict, p11=p11) + + @catalogs.copy(CatalogA) + class CatalogA1(CatalogA): + p11 = providers.Value(1) + p13 = providers.Value(11) + + @catalogs.copy(CatalogA) + class CatalogA2(CatalogA): + p11 = providers.Value(2) + p13 = providers.Value(22) + + self.assertIsNot(CatalogA.p11, CatalogA1.p11) + self.assertIsNot(CatalogA.p12, CatalogA1.p12) + + self.assertIsNot(CatalogA.p11, CatalogA2.p11) + self.assertIsNot(CatalogA.p12, CatalogA2.p12) + + self.assertIsNot(CatalogA1.p11, CatalogA2.p11) + self.assertIsNot(CatalogA1.p12, CatalogA2.p12) + + self.assertIs(CatalogA.p12.injections[0].injectable, CatalogA.p11) + self.assertIs(CatalogA1.p12.injections[0].injectable, CatalogA1.p11) + self.assertIs(CatalogA2.p12.injections[0].injectable, CatalogA2.p11) + + self.assertEqual(CatalogA.p12(), dict(p11=0)) + self.assertEqual(CatalogA1.p12(), dict(p11=1)) + self.assertEqual(CatalogA2.p12(), dict(p11=2)) + + self.assertEqual(CatalogA1.p13(), 11) + self.assertEqual(CatalogA2.p13(), 22) diff --git a/tests/catalogs/test_dynamic.py b/tests/catalogs/test_dynamic.py index 15712884..7d74e222 100644 --- a/tests/catalogs/test_dynamic.py +++ b/tests/catalogs/test_dynamic.py @@ -41,6 +41,17 @@ class DynamicCatalogTests(unittest.TestCase): self.assertIs(self.catalog.py, py) self.assertIs(self.catalog.get_provider('py'), py) + def test_bind_existing_provider(self): + """Test setting of provider via bind_provider() to catalog.""" + with self.assertRaises(errors.Error): + self.catalog.bind_provider('p1', providers.Factory(object)) + + def test_force_bind_existing_provider(self): + """Test setting of provider via bind_provider() to catalog.""" + p1 = providers.Factory(object) + self.catalog.bind_provider('p1', p1, force=True) + self.assertIs(self.catalog.p1, p1) + def test_bind_provider_with_valid_provided_type(self): """Test setting of provider with provider type restriction.""" class SomeProvider(providers.Provider): @@ -99,6 +110,17 @@ class DynamicCatalogTests(unittest.TestCase): self.assertIs(self.catalog.py, py) self.assertIs(self.catalog.get_provider('py'), py) + def test_bind_providers_with_existing(self): + """Test setting of provider via bind_providers() to catalog.""" + with self.assertRaises(errors.Error): + self.catalog.bind_providers(dict(p1=providers.Factory(object))) + + def test_bind_providers_force(self): + """Test setting of provider via bind_providers() to catalog.""" + p1 = providers.Factory(object) + self.catalog.bind_providers(dict(p1=p1), force=True) + self.assertIs(self.catalog.p1, p1) + def test_setattr(self): """Test setting of providers via attributes to catalog.""" px = providers.Provider()