diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index 14e3b652..e22993d1 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -7,6 +7,11 @@ that were made in every particular version. From version 0.7.6 *Dependency Injector* framework strictly follows `Semantic versioning`_ +4.3.2 +----- +- Fix a bug in ``wiring`` with mistakenly initialized and shutdown resource with ``Closing`` + marker on context argument providing. + 4.3.1 ----- - Fix README. diff --git a/src/dependency_injector/__init__.py b/src/dependency_injector/__init__.py index dabe6454..af3d33b5 100644 --- a/src/dependency_injector/__init__.py +++ b/src/dependency_injector/__init__.py @@ -1,6 +1,6 @@ """Top-level package.""" -__version__ = '4.3.1' +__version__ = '4.3.2' """Version number. :type: str diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 1b7c8f2d..f169a7e8 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -5,7 +5,7 @@ import inspect import pkgutil import sys from types import ModuleType -from typing import Optional, Iterable, Callable, Any, Tuple, List, Dict, Generic, TypeVar, cast +from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, cast if sys.version_info < (3, 7): from typing import GenericMeta @@ -226,11 +226,11 @@ def _unpatch_fn( def _resolve_injections( fn: Callable[..., Any], providers_map: ProvidersMap, -) -> Tuple[Dict[str, Any], List[Any]]: +) -> Tuple[Dict[str, Any], Dict[str, Any]]: signature = inspect.signature(fn) injections = {} - closing = [] + closing = {} for parameter_name, parameter in signature.parameters.items(): if not isinstance(parameter.default, _Marker): continue @@ -246,7 +246,7 @@ def _resolve_injections( continue if closing_modifier: - closing.append(provider) + closing[parameter_name] = provider if isinstance(marker, Provide): injections[parameter_name] = provider @@ -273,46 +273,60 @@ def _is_method(member): def _patch_with_injections(fn, injections, closing): if inspect.iscoroutinefunction(fn): - @functools.wraps(fn) - async def _patched(*args, **kwargs): - to_inject = {} - for injection, provider in injections.items(): - to_inject[injection] = provider() - - to_inject.update(kwargs) - - result = await fn(*args, **to_inject) - - for provider in closing: - if isinstance(provider, providers.Resource): - provider.shutdown() - - return result + _patched = _get_async_patched(fn, injections, closing) else: - @functools.wraps(fn) - def _patched(*args, **kwargs): - to_inject = {} - for injection, provider in injections.items(): - to_inject[injection] = provider() - - to_inject.update(kwargs) - - result = fn(*args, **to_inject) - - for provider in closing: - if isinstance(provider, providers.Resource): - provider.shutdown() - - return result + _patched = _get_patched(fn, injections, closing) _patched.__wired__ = True _patched.__original__ = fn _patched.__injections__ = injections - _patched.__closing__ = [] + _patched.__closing__ = closing return _patched +def _get_patched(fn, injections, closing): + @functools.wraps(fn) + def _patched(*args, **kwargs): + to_inject = kwargs.copy() + for injection, provider in injections.items(): + if injection not in kwargs: + to_inject[injection] = provider() + + result = fn(*args, **to_inject) + + for injection, provider in closing.items(): + if injection in kwargs: + continue + if not isinstance(provider, providers.Resource): + continue + provider.shutdown() + + return result + return _patched + + +def _get_async_patched(fn, injections, closing): + @functools.wraps(fn) + async def _patched(*args, **kwargs): + to_inject = kwargs.copy() + for injection, provider in injections.items(): + if injection not in kwargs: + to_inject[injection] = provider() + + result = await fn(*args, **to_inject) + + for injection, provider in closing.items(): + if injection in kwargs: + continue + if not isinstance(provider, providers.Resource): + continue + provider.shutdown() + + return result + return _patched + + def _is_patched(fn): return getattr(fn, '__wired__', False) is True diff --git a/tests/unit/samples/wiringsamples/resourceclosing.py b/tests/unit/samples/wiringsamples/resourceclosing.py index 33a160ca..f7f35bd1 100644 --- a/tests/unit/samples/wiringsamples/resourceclosing.py +++ b/tests/unit/samples/wiringsamples/resourceclosing.py @@ -6,6 +6,11 @@ class Service: init_counter: int = 0 shutdown_counter: int = 0 + @classmethod + def reset_counter(cls): + cls.init_counter = 0 + cls.shutdown_counter = 0 + @classmethod def init(cls): cls.init_counter += 1 diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py index f5b1a509..d8e78c1d 100644 --- a/tests/unit/wiring/test_wiring_py36.py +++ b/tests/unit/wiring/test_wiring_py36.py @@ -178,6 +178,8 @@ class WiringTest(unittest.TestCase): def test_closing_resource(self): from wiringsamples import resourceclosing + resourceclosing.Service.reset_counter() + container = resourceclosing.Container() container.wire(modules=[resourceclosing]) self.addCleanup(container.unwire) @@ -193,3 +195,23 @@ class WiringTest(unittest.TestCase): self.assertEqual(result_2.shutdown_counter, 2) self.assertIsNot(result_1, result_2) + + def test_closing_resource_context(self): + from wiringsamples import resourceclosing + + resourceclosing.Service.reset_counter() + service = resourceclosing.Service() + + container = resourceclosing.Container() + container.wire(modules=[resourceclosing]) + self.addCleanup(container.unwire) + + result_1 = resourceclosing.test_function(service=service) + self.assertIs(result_1, service) + self.assertEqual(result_1.init_counter, 0) + self.assertEqual(result_1.shutdown_counter, 0) + + result_2 = resourceclosing.test_function(service=service) + self.assertIs(result_2, service) + self.assertEqual(result_2.init_counter, 0) + self.assertEqual(result_2.shutdown_counter, 0)