diff --git a/dependency_injector/injections.py b/dependency_injector/injections.py index d041a50c..964d648c 100644 --- a/dependency_injector/injections.py +++ b/dependency_injector/injections.py @@ -60,7 +60,7 @@ class Injection(object): :rtype: object """ if self.call_injectable: - return self.injectable() + return self.injectable.provide() return self.injectable def __str__(self): diff --git a/dependency_injector/providers.py b/dependency_injector/providers.py index b88f2b19..f2b1c14a 100644 --- a/dependency_injector/providers.py +++ b/dependency_injector/providers.py @@ -64,7 +64,7 @@ class Provider(object): __IS_PROVIDER__ = True __OPTIMIZED_CALLS__ = True - __slots__ = ('overridden_by', '__call__') + __slots__ = ('overridden_by', 'provide', '__call__') def __init__(self): """Initializer.""" @@ -72,7 +72,7 @@ class Provider(object): super(Provider, self).__init__() # Enable __call__() / _provide() optimization if self.__class__.__OPTIMIZED_CALLS__: - self.__call__ = self._provide + self.__call__ = self.provide = self._provide def _provide(self, *args, **kwargs): """Providing strategy implementation. @@ -124,9 +124,9 @@ class Provider(object): # Disable __call__() / _provide() optimization if self.__class__.__OPTIMIZED_CALLS__: - self.__call__ = self._call_last_overriding + self.__call__ = self.provide = self._call_last_overriding - return provider + return OverridingContext(self) def reset_last_overriding(self): """Reset last overriding provider. @@ -143,7 +143,7 @@ class Provider(object): if not self.is_overridden: # Enable __call__() / _provide() optimization if self.__class__.__OPTIMIZED_CALLS__: - self.__call__ = self._provide + self.__call__ = self.provide = self._provide def reset_override(self): """Reset all overriding providers. @@ -154,7 +154,7 @@ class Provider(object): # Enable __call__() / _provide() optimization if self.__class__.__OPTIMIZED_CALLS__: - self.__call__ = self._provide + self.__call__ = self.provide = self._provide def delegate(self): """Return provider's delegate. @@ -1042,6 +1042,36 @@ class ChildConfig(Provider): __repr__ = __str__ +class OverridingContext(object): + """Provider overriding context. + + :py:class:`OverridingContext` is used by :py:meth:`Provider.override` for + implemeting ``with`` contexts. When :py:class:`OverridingContext` is + closed, overriding that was created in this context is dropped also. + + .. code-block:: python + + with provider.override(another_provider): + assert provider.is_overridden + assert not provider.is_overridden + """ + + def __init__(self, overridden): + """Initializer. + + :param overridden: Overridden provider + :type overridden: :py:class:`Provider` + """ + self.overridden = overridden + + def __enter__(self): + """Do nothing.""" + + def __exit__(self, *_): + """Exit overriding context.""" + self.overridden.reset_last_overriding() + + def override(overridden): """Decorator for overriding providers. diff --git a/tests/test_providers.py b/tests/test_providers.py index b5aca482..47dd666d 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -69,6 +69,13 @@ class ProviderTests(unittest.TestCase): self.provider.override(overriding_provider) self.assertTrue(self.provider.is_overridden) + def test_overriding_context(self): + """Test provider overriding context.""" + overriding_provider = providers.Provider() + with self.provider.override(overriding_provider): + self.assertTrue(self.provider.is_overridden) + self.assertFalse(self.provider.is_overridden) + def test_override_with_itself(self): """Test provider overriding with itself.""" self.assertRaises(errors.Error, self.provider.override, self.provider)