Fix issue with wiring and resource initialization

This commit is contained in:
Roman Mogylatov 2020-10-30 16:40:27 -04:00
parent 4c46d34b20
commit c5f799a1ec
3 changed files with 50 additions and 19 deletions

View File

@ -5,7 +5,7 @@ import inspect
import pkgutil import pkgutil
import sys import sys
from types import ModuleType from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Tuple, List, Dict, Generic, TypeVar, cast from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, cast
if sys.version_info < (3, 7): if sys.version_info < (3, 7):
from typing import GenericMeta from typing import GenericMeta
@ -226,11 +226,11 @@ def _unpatch_fn(
def _resolve_injections( def _resolve_injections(
fn: Callable[..., Any], fn: Callable[..., Any],
providers_map: ProvidersMap, providers_map: ProvidersMap,
) -> Tuple[Dict[str, Any], List[Any]]: ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
signature = inspect.signature(fn) signature = inspect.signature(fn)
injections = {} injections = {}
closing = [] closing = {}
for parameter_name, parameter in signature.parameters.items(): for parameter_name, parameter in signature.parameters.items():
if not isinstance(parameter.default, _Marker): if not isinstance(parameter.default, _Marker):
continue continue
@ -246,7 +246,7 @@ def _resolve_injections(
continue continue
if closing_modifier: if closing_modifier:
closing.append(provider) closing[parameter_name] = provider
if isinstance(marker, Provide): if isinstance(marker, Provide):
injections[parameter_name] = provider injections[parameter_name] = provider
@ -275,32 +275,36 @@ def _patch_with_injections(fn, injections, closing):
if inspect.iscoroutinefunction(fn): if inspect.iscoroutinefunction(fn):
@functools.wraps(fn) @functools.wraps(fn)
async def _patched(*args, **kwargs): async def _patched(*args, **kwargs):
to_inject = {} to_inject = kwargs.copy()
for injection, provider in injections.items(): for injection, provider in injections.items():
if injection not in kwargs:
to_inject[injection] = provider() to_inject[injection] = provider()
to_inject.update(kwargs)
result = await fn(*args, **to_inject) result = await fn(*args, **to_inject)
for provider in closing: for injection, provider in closing.items():
if isinstance(provider, providers.Resource): if injection in kwargs:
continue
if not isinstance(provider, providers.Resource):
continue
provider.shutdown() provider.shutdown()
return result return result
else: else:
@functools.wraps(fn) @functools.wraps(fn)
def _patched(*args, **kwargs): def _patched(*args, **kwargs):
to_inject = {} to_inject = kwargs.copy()
for injection, provider in injections.items(): for injection, provider in injections.items():
if injection not in kwargs:
to_inject[injection] = provider() to_inject[injection] = provider()
to_inject.update(kwargs)
result = fn(*args, **to_inject) result = fn(*args, **to_inject)
for provider in closing: for injection, provider in closing.items():
if isinstance(provider, providers.Resource): if injection in kwargs:
continue
if not isinstance(provider, providers.Resource):
continue
provider.shutdown() provider.shutdown()
return result return result
@ -308,7 +312,7 @@ def _patch_with_injections(fn, injections, closing):
_patched.__wired__ = True _patched.__wired__ = True
_patched.__original__ = fn _patched.__original__ = fn
_patched.__injections__ = injections _patched.__injections__ = injections
_patched.__closing__ = [] _patched.__closing__ = closing
return _patched return _patched

View File

@ -6,6 +6,11 @@ class Service:
init_counter: int = 0 init_counter: int = 0
shutdown_counter: int = 0 shutdown_counter: int = 0
@classmethod
def reset_counter(cls):
cls.init_counter = 0
cls.shutdown_counter = 0
@classmethod @classmethod
def init(cls): def init(cls):
cls.init_counter += 1 cls.init_counter += 1

View File

@ -178,6 +178,8 @@ class WiringTest(unittest.TestCase):
def test_closing_resource(self): def test_closing_resource(self):
from wiringsamples import resourceclosing from wiringsamples import resourceclosing
resourceclosing.Service.reset_counter()
container = resourceclosing.Container() container = resourceclosing.Container()
container.wire(modules=[resourceclosing]) container.wire(modules=[resourceclosing])
self.addCleanup(container.unwire) self.addCleanup(container.unwire)
@ -193,3 +195,23 @@ class WiringTest(unittest.TestCase):
self.assertEqual(result_2.shutdown_counter, 2) self.assertEqual(result_2.shutdown_counter, 2)
self.assertIsNot(result_1, result_2) self.assertIsNot(result_1, result_2)
def test_closing_resource_context(self):
from wiringsamples import resourceclosing
resourceclosing.Service.reset_counter()
service = resourceclosing.Service()
container = resourceclosing.Container()
container.wire(modules=[resourceclosing])
self.addCleanup(container.unwire)
result_1 = resourceclosing.test_function(service=service)
self.assertIs(result_1, service)
self.assertEqual(result_1.init_counter, 0)
self.assertEqual(result_1.shutdown_counter, 0)
result_2 = resourceclosing.test_function(service=service)
self.assertIs(result_2, service)
self.assertEqual(result_2.init_counter, 0)
self.assertEqual(result_2.shutdown_counter, 0)