diff --git a/objects/providers.py b/objects/providers.py index b7b4229f..f08efb99 100644 --- a/objects/providers.py +++ b/objects/providers.py @@ -47,18 +47,23 @@ class Provider(object): else: self.overridden = self.overridden + (ensure_is_provider(provider),) - def reset_override(self): - """Reset all overriding providers.""" - self.overridden = None - @property def last_overriding(self): """Return last overriding provider.""" try: return self.overridden[-1] except (TypeError, IndexError): - raise Error('Provider {0} '.format(str(self)) + - 'is not overridden') + raise Error('Provider {0} is not overridden'.format(str(self))) + + def reset_last_overriding(self): + """Reset last overriding provider.""" + if not self.overridden: + raise Error('Provider {0} is not overridden'.format(str(self))) + self.overridden = self.overridden[:-1] + + def reset_override(self): + """Reset all overriding providers.""" + self.overridden = None class Delegate(Provider): diff --git a/tests/test_providers.py b/tests/test_providers.py index 26bddfa4..65c93676 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -63,25 +63,6 @@ class ProviderTests(unittest.TestCase): """Test provider overriding with not provider instance.""" self.assertRaises(Error, self.provider.override, object()) - def test_reset_override(self): - """Test reset of provider's override.""" - overriding_provider = Provider() - self.provider.override(overriding_provider) - - self.assertTrue(self.provider.overridden) - self.assertIs(self.provider.last_overriding, overriding_provider) - - self.provider.reset_override() - - self.assertFalse(self.provider.overridden) - try: - self.provider.last_overriding - except Error: - pass - else: - self.fail('Got en error in {}'.format( - str(self.test_last_overriding_of_not_overridden_provider))) - def test_last_overriding(self): """Test getting last overriding provider.""" overriding_provider1 = Provider() @@ -103,6 +84,45 @@ class ProviderTests(unittest.TestCase): self.fail('Got en error in {}'.format( str(self.test_last_overriding_of_not_overridden_provider))) + def test_reset_last_overriding(self): + """Test reseting of last overriding provider.""" + overriding_provider1 = Provider() + overriding_provider2 = Provider() + + self.provider.override(overriding_provider1) + self.provider.override(overriding_provider2) + + self.assertIs(self.provider.last_overriding, overriding_provider2) + + self.provider.reset_last_overriding() + self.assertIs(self.provider.last_overriding, overriding_provider1) + + self.provider.reset_last_overriding() + self.assertFalse(self.provider.overridden) + + def test_reset_last_overriding_of_not_overridden_provider(self): + """Test resetting of last overriding on not overridden provier.""" + self.assertRaises(Error, self.provider.reset_last_overriding) + + def test_reset_override(self): + """Test reset of provider's override.""" + overriding_provider = Provider() + self.provider.override(overriding_provider) + + self.assertTrue(self.provider.overridden) + self.assertIs(self.provider.last_overriding, overriding_provider) + + self.provider.reset_override() + + self.assertFalse(self.provider.overridden) + try: + self.provider.last_overriding + except Error: + pass + else: + self.fail('Got en error in {}'.format( + str(self.test_last_overriding_of_not_overridden_provider))) + class DelegateTests(unittest.TestCase):