diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index e58a97cd..d9ebb9d2 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -5,7 +5,7 @@ import inspect import pkgutil import sys from types import ModuleType -from typing import Optional, Iterable, Callable, Any, Dict, Generic, TypeVar, cast +from typing import Optional, Iterable, Callable, Any, Tuple, List, Dict, Generic, TypeVar, cast if sys.version_info < (3, 7): from typing import GenericMeta @@ -22,6 +22,7 @@ __all__ = ( 'unwire', 'Provide', 'Provider', + 'Closing', ) T = TypeVar('T') @@ -206,10 +207,10 @@ def _patch_fn( fn: Callable[..., Any], providers_map: ProvidersMap, ) -> None: - injections = _resolve_injections(fn, providers_map) + injections, closing = _resolve_injections(fn, providers_map) if not injections: return - setattr(module, name, _patch_with_injections(fn, injections)) + setattr(module, name, _patch_with_injections(fn, injections, closing)) def _unpatch_fn( @@ -222,25 +223,34 @@ def _unpatch_fn( setattr(module, name, _get_original_from_patched(fn)) -def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Dict[str, Any]: +def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Tuple[Dict[str, Any], List[Any]]: signature = inspect.signature(fn) injections = {} + closing = [] for parameter_name, parameter in signature.parameters.items(): if not isinstance(parameter.default, _Marker): continue marker = parameter.default + closing_modifier = False + if isinstance(marker, Closing): + closing_modifier = True + marker = marker.provider + provider = providers_map.resolve_provider(marker.provider) if provider is None: continue + if closing_modifier: + closing.append(provider) + if isinstance(marker, Provide): injections[parameter_name] = provider elif isinstance(marker, Provider): injections[parameter_name] = provider.provider - return injections + return injections, closing def _fetch_modules(package): @@ -258,7 +268,7 @@ def _is_method(member): return inspect.ismethod(member) or inspect.isfunction(member) -def _patch_with_injections(fn, injections): +def _patch_with_injections(fn, injections, closing): if inspect.iscoroutinefunction(fn): @functools.wraps(fn) async def _patched(*args, **kwargs): @@ -268,7 +278,13 @@ def _patch_with_injections(fn, injections): to_inject.update(kwargs) - return await fn(*args, **to_inject) + result = await fn(*args, **to_inject) + + for provider in closing: + if isinstance(provider, providers.Resource): + provider.shutdown() + + return result else: @functools.wraps(fn) def _patched(*args, **kwargs): @@ -278,11 +294,18 @@ def _patch_with_injections(fn, injections): to_inject.update(kwargs) - return fn(*args, **to_inject) + result = fn(*args, **to_inject) + + for provider in closing: + if isinstance(provider, providers.Resource): + provider.shutdown() + + return result _patched.__wired__ = True _patched.__original__ = fn _patched.__injections__ = injections + _patched.__closing__ = [] return _patched @@ -322,3 +345,7 @@ class Provide(_Marker): class Provider(_Marker): ... + + +class Closing(_Marker): + ...