Add closing marker

This commit is contained in:
Roman Mogylatov 2020-10-29 21:11:10 -04:00
parent b18385a867
commit 9cfadce81a

View File

@ -5,7 +5,7 @@ import inspect
import pkgutil import pkgutil
import sys import sys
from types import ModuleType from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Dict, Generic, TypeVar, cast from typing import Optional, Iterable, Callable, Any, Tuple, List, Dict, Generic, TypeVar, cast
if sys.version_info < (3, 7): if sys.version_info < (3, 7):
from typing import GenericMeta from typing import GenericMeta
@ -22,6 +22,7 @@ __all__ = (
'unwire', 'unwire',
'Provide', 'Provide',
'Provider', 'Provider',
'Closing',
) )
T = TypeVar('T') T = TypeVar('T')
@ -206,10 +207,10 @@ def _patch_fn(
fn: Callable[..., Any], fn: Callable[..., Any],
providers_map: ProvidersMap, providers_map: ProvidersMap,
) -> None: ) -> None:
injections = _resolve_injections(fn, providers_map) injections, closing = _resolve_injections(fn, providers_map)
if not injections: if not injections:
return return
setattr(module, name, _patch_with_injections(fn, injections)) setattr(module, name, _patch_with_injections(fn, injections, closing))
def _unpatch_fn( def _unpatch_fn(
@ -222,25 +223,34 @@ def _unpatch_fn(
setattr(module, name, _get_original_from_patched(fn)) setattr(module, name, _get_original_from_patched(fn))
def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Dict[str, Any]: def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Tuple[Dict[str, Any], List[Any]]:
signature = inspect.signature(fn) signature = inspect.signature(fn)
injections = {} injections = {}
closing = []
for parameter_name, parameter in signature.parameters.items(): for parameter_name, parameter in signature.parameters.items():
if not isinstance(parameter.default, _Marker): if not isinstance(parameter.default, _Marker):
continue continue
marker = parameter.default marker = parameter.default
closing_modifier = False
if isinstance(marker, Closing):
closing_modifier = True
marker = marker.provider
provider = providers_map.resolve_provider(marker.provider) provider = providers_map.resolve_provider(marker.provider)
if provider is None: if provider is None:
continue continue
if closing_modifier:
closing.append(provider)
if isinstance(marker, Provide): if isinstance(marker, Provide):
injections[parameter_name] = provider injections[parameter_name] = provider
elif isinstance(marker, Provider): elif isinstance(marker, Provider):
injections[parameter_name] = provider.provider injections[parameter_name] = provider.provider
return injections return injections, closing
def _fetch_modules(package): def _fetch_modules(package):
@ -258,7 +268,7 @@ def _is_method(member):
return inspect.ismethod(member) or inspect.isfunction(member) return inspect.ismethod(member) or inspect.isfunction(member)
def _patch_with_injections(fn, injections): def _patch_with_injections(fn, injections, closing):
if inspect.iscoroutinefunction(fn): if inspect.iscoroutinefunction(fn):
@functools.wraps(fn) @functools.wraps(fn)
async def _patched(*args, **kwargs): async def _patched(*args, **kwargs):
@ -268,7 +278,13 @@ def _patch_with_injections(fn, injections):
to_inject.update(kwargs) to_inject.update(kwargs)
return await fn(*args, **to_inject) result = await fn(*args, **to_inject)
for provider in closing:
if isinstance(provider, providers.Resource):
provider.shutdown()
return result
else: else:
@functools.wraps(fn) @functools.wraps(fn)
def _patched(*args, **kwargs): def _patched(*args, **kwargs):
@ -278,11 +294,18 @@ def _patch_with_injections(fn, injections):
to_inject.update(kwargs) to_inject.update(kwargs)
return fn(*args, **to_inject) result = fn(*args, **to_inject)
for provider in closing:
if isinstance(provider, providers.Resource):
provider.shutdown()
return result
_patched.__wired__ = True _patched.__wired__ = True
_patched.__original__ = fn _patched.__original__ = fn
_patched.__injections__ = injections _patched.__injections__ = injections
_patched.__closing__ = []
return _patched return _patched
@ -322,3 +345,7 @@ class Provide(_Marker):
class Provider(_Marker): class Provider(_Marker):
... ...
class Closing(_Marker):
...