diff --git a/objects/providers.py b/objects/providers.py index 48e02c41..6acce9ec 100644 --- a/objects/providers.py +++ b/objects/providers.py @@ -34,6 +34,10 @@ class Provider(object): """Override provider with another provider.""" self.overridden.append(ensure_is_provider(provider)) + def reset_override(self): + """Reset all overriding providers.""" + self.overridden = list() + @property def last_overriding(self): """Return last overriding provider.""" diff --git a/tests/test_providers.py b/tests/test_providers.py index c2f14b12..f61e09df 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -65,6 +65,25 @@ class ProviderTests(unittest.TestCase): """Test provider overriding with not provider instance.""" self.assertRaises(Error, self.provider.override, object()) + def test_reset_override(self): + """Test reset of provider's override.""" + overriding_provider = Provider() + self.provider.override(overriding_provider) + + self.assertTrue(self.provider.overridden) + self.assertIs(self.provider.last_overriding, overriding_provider) + + self.provider.reset_override() + + self.assertFalse(self.provider.overridden) + try: + self.provider.last_overriding + except Error: + pass + else: + self.fail('Got en error in {}'.format( + str(self.test_last_overriding_of_not_overridden_provider))) + def test_last_overriding(self): """Test getting last overriding provider.""" overriding_provider1 = Provider()