diff --git a/objects/providers.py b/objects/providers.py index 17cd1e4a..a72fad08 100644 --- a/objects/providers.py +++ b/objects/providers.py @@ -2,7 +2,7 @@ from collections import Iterable -from .utils import is_provider +from .utils import ensure_is_provider from .utils import is_injection from .utils import is_init_arg_injection from .utils import is_attribute_injection @@ -32,10 +32,7 @@ class Provider(object): def override(self, provider): """Override provider with another provider.""" - if not is_provider(provider): - raise Error('Expected provider as an overriding instance, ' - 'got {}'.format(str(provider))) - self.overridden.append(provider) + self.overridden.append(ensure_is_provider(provider)) @property def last_overriding(self): @@ -58,7 +55,7 @@ class ProviderDelegate(Provider): :type delegated: Provider """ - self.delegated = delegated + self.delegated = ensure_is_provider(delegated) super(ProviderDelegate, self).__init__() def __call__(self): diff --git a/objects/utils.py b/objects/utils.py index dfc9add0..c95d7455 100644 --- a/objects/utils.py +++ b/objects/utils.py @@ -2,6 +2,8 @@ from inspect import isclass +from .errors import Error + def is_provider(instance): """Check if instance is provider instance.""" @@ -9,6 +11,14 @@ def is_provider(instance): hasattr(instance, '__IS_OBJECTS_PROVIDER__')) +def ensure_is_provider(instance): + """Check if instance is provider instance, otherwise raise and error.""" + if not is_provider(instance): + raise Error('Expected provider instance, ' + 'got {}'.format(str(instance))) + return instance + + def is_injection(instance): """Check if instance is injection instance.""" return (not isclass(instance) and diff --git a/tests/test_providers.py b/tests/test_providers.py index 36afb49c..e77900b4 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -24,55 +24,85 @@ class ProviderTest(unittest.TestCase): """Provider test cases.""" - def test_init(self): - """Test creating and initialization.""" - self.assertIsInstance(Provider(), Provider) + def setUp(self): + """Set test cases environment up.""" + self.provider = Provider() def test_is_provider(self): """Test `is_provider` check.""" - self.assertTrue(is_provider(Provider())) + self.assertTrue(is_provider(self.provider)) def test_call(self): """Test call.""" - self.assertRaises(NotImplementedError, Provider().__call__) + self.assertRaises(NotImplementedError, self.provider.__call__) def test_delegate(self): """Test creating of provider delegation.""" - provider = Provider() - delegate = provider.delegate() + delegate1 = self.provider.delegate() - self.assertIsInstance(delegate, ProviderDelegate) - self.assertIs(delegate.delegated, provider) + self.assertIsInstance(delegate1, ProviderDelegate) + self.assertIs(delegate1.delegated, self.provider) + + delegate2 = self.provider.delegate() + + self.assertIsInstance(delegate2, ProviderDelegate) + self.assertIs(delegate2.delegated, self.provider) + + self.assertIsNot(delegate1, delegate2) def test_override(self): """Test provider overriding.""" - provider = Provider() overriding_provider = Provider() - provider.override(overriding_provider) - self.assertTrue(provider.overridden) + self.provider.override(overriding_provider) + self.assertTrue(self.provider.overridden) def test_override_with_not_provider(self): """Test provider overriding with not provider instance.""" - self.assertRaises(Error, Provider().override, object()) + self.assertRaises(Error, self.provider.override, object()) def test_last_overriding(self): """Test getting last overriding provider.""" - provider = Provider() overriding_provider1 = Provider() overriding_provider2 = Provider() - provider.override(overriding_provider1) - self.assertIs(provider.last_overriding, overriding_provider1) + self.provider.override(overriding_provider1) + self.assertIs(self.provider.last_overriding, overriding_provider1) - provider.override(overriding_provider2) - self.assertIs(provider.last_overriding, overriding_provider2) + self.provider.override(overriding_provider2) + self.assertIs(self.provider.last_overriding, overriding_provider2) def test_last_overriding_of_not_overridden_provider(self): """Test getting last overriding from not overridden provider.""" try: - Provider().last_overriding + self.provider.last_overriding except Error: pass else: self.fail('Got en error in {}'.format( str(self.test_last_overriding_of_not_overridden_provider))) + + +class ProviderDelegateTest(unittest.TestCase): + + """ProviderDelegate test cases.""" + + def setUp(self): + """Set test cases environment up.""" + self.delegated = Provider() + self.delegate = ProviderDelegate(delegated=self.delegated) + + def test_is_provider(self): + """Test `is_provider` check.""" + self.assertTrue(is_provider(self.delegate)) + + def test_init_with_not_provider(self): + """Test that delegate accepts only another provider as delegated.""" + self.assertRaises(Error, ProviderDelegate, delegated=object()) + + def test_call(self): + """ Test returning of delegated provider.""" + delegated1 = self.delegate() + delegated2 = self.delegate() + + self.assertIs(delegated1, self.delegated) + self.assertIs(delegated2, self.delegated)