diff --git a/dependency_injector/catalogs/dynamic.py b/dependency_injector/catalogs/dynamic.py index d8e1618f..97340320 100644 --- a/dependency_injector/catalogs/dynamic.py +++ b/dependency_injector/catalogs/dynamic.py @@ -1,5 +1,7 @@ """Dependency injector dynamic catalog module.""" +import copy + import six from dependency_injector.catalogs.bundle import CatalogBundle @@ -211,7 +213,7 @@ class DynamicCatalog(object): raise UndefinedProviderError('{0} has no provider with such ' 'name - {1}'.format(self, name)) - def bind_provider(self, name, provider): + def bind_provider(self, name, provider, force=False): """Bind provider to catalog with specified name. :param name: Name of the provider. @@ -220,6 +222,9 @@ class DynamicCatalog(object): :param provider: Provider instance. :type provider: :py:class:`dependency_injector.providers.Provider` + :param force: Force binding of provider. + :type force: bool + :raise: :py:exc:`dependency_injector.errors.Error` :rtype: None @@ -231,17 +236,18 @@ class DynamicCatalog(object): raise Error('{0} can contain only {1} instances'.format( self, self.__class__.provider_type)) - if name in self.providers: - raise Error('Catalog {0} already has provider with ' - 'such name - {1}'.format(self, name)) - if provider in self.provider_names: - raise Error('Catalog {0} already has such provider ' - 'instance - {1}'.format(self, provider)) + if not force: + if name in self.providers: + raise Error('Catalog {0} already has provider with ' + 'such name - {1}'.format(self, name)) + if provider in self.provider_names: + raise Error('Catalog {0} already has such provider ' + 'instance - {1}'.format(self, provider)) self.providers[name] = provider self.provider_names[provider] = name - def bind_providers(self, providers): + def bind_providers(self, providers, force=False): """Bind providers dictionary to catalog. :param providers: Dictionary of providers, where key is a name @@ -249,12 +255,15 @@ class DynamicCatalog(object): :type providers: dict[str, :py:class:`dependency_injector.providers.Provider`] + :param force: Force binding of providers. + :type force: bool + :raise: :py:exc:`dependency_injector.errors.Error` :rtype: None """ for name, provider in six.iteritems(providers): - self.bind_provider(name, provider) + self.bind_provider(name, provider, force) def has_provider(self, name): """Check if there is provider with certain name. @@ -278,6 +287,25 @@ class DynamicCatalog(object): del self.providers[name] del self.provider_names[provider] + def copy(self): + """Copy catalog instance and return it. + + :rtype: py:class:`DynamicCatalog` + :return: Copied catalog. + """ + return copy.copy(self) + + def deepcopy(self, memo=None): + """Copy catalog instance and it's providers and return it. + + :param memo: Memorized instances + :type memo: dict[int, object] + + :rtype: py:class:`DynamicCatalog` + :return: Copied catalog. + """ + return copy.deepcopy(self, memo) + def __getattr__(self, name): """Return provider with specified name or raise en error. diff --git a/tests/catalogs/test_dynamic.py b/tests/catalogs/test_dynamic.py index 15712884..de8ed1f2 100644 --- a/tests/catalogs/test_dynamic.py +++ b/tests/catalogs/test_dynamic.py @@ -41,6 +41,17 @@ class DynamicCatalogTests(unittest.TestCase): self.assertIs(self.catalog.py, py) self.assertIs(self.catalog.get_provider('py'), py) + def test_bind_existing_provider(self): + """Test setting of provider via bind_provider() to catalog.""" + with self.assertRaises(errors.Error): + self.catalog.bind_provider('p1', providers.Factory(object)) + + def test_force_bind_existing_provider(self): + """Test setting of provider via bind_provider() to catalog.""" + p1 = providers.Factory(object) + self.catalog.bind_provider('p1', p1, force=True) + self.assertIs(self.catalog.p1, p1) + def test_bind_provider_with_valid_provided_type(self): """Test setting of provider with provider type restriction.""" class SomeProvider(providers.Provider): @@ -99,6 +110,17 @@ class DynamicCatalogTests(unittest.TestCase): self.assertIs(self.catalog.py, py) self.assertIs(self.catalog.get_provider('py'), py) + def test_bind_providers_with_existing(self): + """Test setting of provider via bind_providers() to catalog.""" + with self.assertRaises(errors.Error): + self.catalog.bind_providers(dict(p1=providers.Factory(object))) + + def test_bind_providers_force(self): + """Test setting of provider via bind_providers() to catalog.""" + p1 = providers.Factory(object) + self.catalog.bind_providers(dict(p1=p1), force=True) + self.assertIs(self.catalog.p1, p1) + def test_setattr(self): """Test setting of providers via attributes to catalog.""" px = providers.Provider() @@ -190,6 +212,22 @@ class DynamicCatalogTests(unittest.TestCase): self.assertTrue(len(self.catalog.filter(providers.Provider)) == 2) self.assertTrue(len(self.catalog.filter(providers.Value)) == 0) + def test_copy(self): + """Test copying of catalog.""" + catalog_copy = self.catalog.copy() + + self.assertIsNot(self.catalog, catalog_copy) + self.assertIs(self.catalog.p1, catalog_copy.p1) + self.assertIs(self.catalog.p2, catalog_copy.p2) + + def test_deepcopy(self): + """Test copying of catalog.""" + catalog_copy = self.catalog.deepcopy() + + self.assertIsNot(self.catalog, catalog_copy) + self.assertIsNot(self.catalog.p1, catalog_copy.p1) + self.assertIsNot(self.catalog.p2, catalog_copy.p2) + def test_repr(self): """Test catalog representation.""" self.assertIn('TestCatalog', repr(self.catalog))