Add closing marker

This commit is contained in:
Roman Mogylatov 2020-10-29 21:11:10 -04:00
parent b18385a867
commit 9cfadce81a

View File

@ -5,7 +5,7 @@ import inspect
import pkgutil
import sys
from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Dict, Generic, TypeVar, cast
from typing import Optional, Iterable, Callable, Any, Tuple, List, Dict, Generic, TypeVar, cast
if sys.version_info < (3, 7):
from typing import GenericMeta
@ -22,6 +22,7 @@ __all__ = (
'unwire',
'Provide',
'Provider',
'Closing',
)
T = TypeVar('T')
@ -206,10 +207,10 @@ def _patch_fn(
fn: Callable[..., Any],
providers_map: ProvidersMap,
) -> None:
injections = _resolve_injections(fn, providers_map)
injections, closing = _resolve_injections(fn, providers_map)
if not injections:
return
setattr(module, name, _patch_with_injections(fn, injections))
setattr(module, name, _patch_with_injections(fn, injections, closing))
def _unpatch_fn(
@ -222,25 +223,34 @@ def _unpatch_fn(
setattr(module, name, _get_original_from_patched(fn))
def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Dict[str, Any]:
def _resolve_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Tuple[Dict[str, Any], List[Any]]:
signature = inspect.signature(fn)
injections = {}
closing = []
for parameter_name, parameter in signature.parameters.items():
if not isinstance(parameter.default, _Marker):
continue
marker = parameter.default
closing_modifier = False
if isinstance(marker, Closing):
closing_modifier = True
marker = marker.provider
provider = providers_map.resolve_provider(marker.provider)
if provider is None:
continue
if closing_modifier:
closing.append(provider)
if isinstance(marker, Provide):
injections[parameter_name] = provider
elif isinstance(marker, Provider):
injections[parameter_name] = provider.provider
return injections
return injections, closing
def _fetch_modules(package):
@ -258,7 +268,7 @@ def _is_method(member):
return inspect.ismethod(member) or inspect.isfunction(member)
def _patch_with_injections(fn, injections):
def _patch_with_injections(fn, injections, closing):
if inspect.iscoroutinefunction(fn):
@functools.wraps(fn)
async def _patched(*args, **kwargs):
@ -268,7 +278,13 @@ def _patch_with_injections(fn, injections):
to_inject.update(kwargs)
return await fn(*args, **to_inject)
result = await fn(*args, **to_inject)
for provider in closing:
if isinstance(provider, providers.Resource):
provider.shutdown()
return result
else:
@functools.wraps(fn)
def _patched(*args, **kwargs):
@ -278,11 +294,18 @@ def _patch_with_injections(fn, injections):
to_inject.update(kwargs)
return fn(*args, **to_inject)
result = fn(*args, **to_inject)
for provider in closing:
if isinstance(provider, providers.Resource):
provider.shutdown()
return result
_patched.__wired__ = True
_patched.__original__ = fn
_patched.__injections__ = injections
_patched.__closing__ = []
return _patched
@ -322,3 +345,7 @@ class Provide(_Marker):
class Provider(_Marker):
...
class Closing(_Marker):
...