Add force binding and copying functionaloty for DynamicCatalog

This commit is contained in:
Roman Mogilatov 2016-04-10 16:41:54 +03:00
parent 3a7b9c1e98
commit 43258e5fd9
2 changed files with 75 additions and 9 deletions

View File

@ -1,5 +1,7 @@
"""Dependency injector dynamic catalog module.""" """Dependency injector dynamic catalog module."""
import copy
import six import six
from dependency_injector.catalogs.bundle import CatalogBundle from dependency_injector.catalogs.bundle import CatalogBundle
@ -211,7 +213,7 @@ class DynamicCatalog(object):
raise UndefinedProviderError('{0} has no provider with such ' raise UndefinedProviderError('{0} has no provider with such '
'name - {1}'.format(self, name)) 'name - {1}'.format(self, name))
def bind_provider(self, name, provider): def bind_provider(self, name, provider, force=False):
"""Bind provider to catalog with specified name. """Bind provider to catalog with specified name.
:param name: Name of the provider. :param name: Name of the provider.
@ -220,6 +222,9 @@ class DynamicCatalog(object):
:param provider: Provider instance. :param provider: Provider instance.
:type provider: :py:class:`dependency_injector.providers.Provider` :type provider: :py:class:`dependency_injector.providers.Provider`
:param force: Force binding of provider.
:type force: bool
:raise: :py:exc:`dependency_injector.errors.Error` :raise: :py:exc:`dependency_injector.errors.Error`
:rtype: None :rtype: None
@ -231,17 +236,18 @@ class DynamicCatalog(object):
raise Error('{0} can contain only {1} instances'.format( raise Error('{0} can contain only {1} instances'.format(
self, self.__class__.provider_type)) self, self.__class__.provider_type))
if name in self.providers: if not force:
raise Error('Catalog {0} already has provider with ' if name in self.providers:
'such name - {1}'.format(self, name)) raise Error('Catalog {0} already has provider with '
if provider in self.provider_names: 'such name - {1}'.format(self, name))
raise Error('Catalog {0} already has such provider ' if provider in self.provider_names:
'instance - {1}'.format(self, provider)) raise Error('Catalog {0} already has such provider '
'instance - {1}'.format(self, provider))
self.providers[name] = provider self.providers[name] = provider
self.provider_names[provider] = name self.provider_names[provider] = name
def bind_providers(self, providers): def bind_providers(self, providers, force=False):
"""Bind providers dictionary to catalog. """Bind providers dictionary to catalog.
:param providers: Dictionary of providers, where key is a name :param providers: Dictionary of providers, where key is a name
@ -249,12 +255,15 @@ class DynamicCatalog(object):
:type providers: :type providers:
dict[str, :py:class:`dependency_injector.providers.Provider`] dict[str, :py:class:`dependency_injector.providers.Provider`]
:param force: Force binding of providers.
:type force: bool
:raise: :py:exc:`dependency_injector.errors.Error` :raise: :py:exc:`dependency_injector.errors.Error`
:rtype: None :rtype: None
""" """
for name, provider in six.iteritems(providers): for name, provider in six.iteritems(providers):
self.bind_provider(name, provider) self.bind_provider(name, provider, force)
def has_provider(self, name): def has_provider(self, name):
"""Check if there is provider with certain name. """Check if there is provider with certain name.
@ -278,6 +287,25 @@ class DynamicCatalog(object):
del self.providers[name] del self.providers[name]
del self.provider_names[provider] del self.provider_names[provider]
def copy(self):
"""Copy catalog instance and return it.
:rtype: py:class:`DynamicCatalog`
:return: Copied catalog.
"""
return copy.copy(self)
def deepcopy(self, memo=None):
"""Copy catalog instance and it's providers and return it.
:param memo: Memorized instances
:type memo: dict[int, object]
:rtype: py:class:`DynamicCatalog`
:return: Copied catalog.
"""
return copy.deepcopy(self, memo)
def __getattr__(self, name): def __getattr__(self, name):
"""Return provider with specified name or raise en error. """Return provider with specified name or raise en error.

View File

@ -41,6 +41,17 @@ class DynamicCatalogTests(unittest.TestCase):
self.assertIs(self.catalog.py, py) self.assertIs(self.catalog.py, py)
self.assertIs(self.catalog.get_provider('py'), py) self.assertIs(self.catalog.get_provider('py'), py)
def test_bind_existing_provider(self):
"""Test setting of provider via bind_provider() to catalog."""
with self.assertRaises(errors.Error):
self.catalog.bind_provider('p1', providers.Factory(object))
def test_force_bind_existing_provider(self):
"""Test setting of provider via bind_provider() to catalog."""
p1 = providers.Factory(object)
self.catalog.bind_provider('p1', p1, force=True)
self.assertIs(self.catalog.p1, p1)
def test_bind_provider_with_valid_provided_type(self): def test_bind_provider_with_valid_provided_type(self):
"""Test setting of provider with provider type restriction.""" """Test setting of provider with provider type restriction."""
class SomeProvider(providers.Provider): class SomeProvider(providers.Provider):
@ -99,6 +110,17 @@ class DynamicCatalogTests(unittest.TestCase):
self.assertIs(self.catalog.py, py) self.assertIs(self.catalog.py, py)
self.assertIs(self.catalog.get_provider('py'), py) self.assertIs(self.catalog.get_provider('py'), py)
def test_bind_providers_with_existing(self):
"""Test setting of provider via bind_providers() to catalog."""
with self.assertRaises(errors.Error):
self.catalog.bind_providers(dict(p1=providers.Factory(object)))
def test_bind_providers_force(self):
"""Test setting of provider via bind_providers() to catalog."""
p1 = providers.Factory(object)
self.catalog.bind_providers(dict(p1=p1), force=True)
self.assertIs(self.catalog.p1, p1)
def test_setattr(self): def test_setattr(self):
"""Test setting of providers via attributes to catalog.""" """Test setting of providers via attributes to catalog."""
px = providers.Provider() px = providers.Provider()
@ -190,6 +212,22 @@ class DynamicCatalogTests(unittest.TestCase):
self.assertTrue(len(self.catalog.filter(providers.Provider)) == 2) self.assertTrue(len(self.catalog.filter(providers.Provider)) == 2)
self.assertTrue(len(self.catalog.filter(providers.Value)) == 0) self.assertTrue(len(self.catalog.filter(providers.Value)) == 0)
def test_copy(self):
"""Test copying of catalog."""
catalog_copy = self.catalog.copy()
self.assertIsNot(self.catalog, catalog_copy)
self.assertIs(self.catalog.p1, catalog_copy.p1)
self.assertIs(self.catalog.p2, catalog_copy.p2)
def test_deepcopy(self):
"""Test copying of catalog."""
catalog_copy = self.catalog.deepcopy()
self.assertIsNot(self.catalog, catalog_copy)
self.assertIsNot(self.catalog.p1, catalog_copy.p1)
self.assertIsNot(self.catalog.p2, catalog_copy.p2)
def test_repr(self): def test_repr(self):
"""Test catalog representation.""" """Test catalog representation."""
self.assertIn('TestCatalog', repr(self.catalog)) self.assertIn('TestCatalog', repr(self.catalog))