Fix issue with wiring and resource initialization

This commit is contained in:
Roman Mogylatov 2020-10-30 16:40:27 -04:00
parent 4c46d34b20
commit c5f799a1ec
3 changed files with 50 additions and 19 deletions

View File

@ -5,7 +5,7 @@ import inspect
import pkgutil
import sys
from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Tuple, List, Dict, Generic, TypeVar, cast
from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, cast
if sys.version_info < (3, 7):
from typing import GenericMeta
@ -226,11 +226,11 @@ def _unpatch_fn(
def _resolve_injections(
fn: Callable[..., Any],
providers_map: ProvidersMap,
) -> Tuple[Dict[str, Any], List[Any]]:
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
signature = inspect.signature(fn)
injections = {}
closing = []
closing = {}
for parameter_name, parameter in signature.parameters.items():
if not isinstance(parameter.default, _Marker):
continue
@ -246,7 +246,7 @@ def _resolve_injections(
continue
if closing_modifier:
closing.append(provider)
closing[parameter_name] = provider
if isinstance(marker, Provide):
injections[parameter_name] = provider
@ -275,40 +275,44 @@ def _patch_with_injections(fn, injections, closing):
if inspect.iscoroutinefunction(fn):
@functools.wraps(fn)
async def _patched(*args, **kwargs):
to_inject = {}
to_inject = kwargs.copy()
for injection, provider in injections.items():
to_inject[injection] = provider()
to_inject.update(kwargs)
if injection not in kwargs:
to_inject[injection] = provider()
result = await fn(*args, **to_inject)
for provider in closing:
if isinstance(provider, providers.Resource):
provider.shutdown()
for injection, provider in closing.items():
if injection in kwargs:
continue
if not isinstance(provider, providers.Resource):
continue
provider.shutdown()
return result
else:
@functools.wraps(fn)
def _patched(*args, **kwargs):
to_inject = {}
to_inject = kwargs.copy()
for injection, provider in injections.items():
to_inject[injection] = provider()
to_inject.update(kwargs)
if injection not in kwargs:
to_inject[injection] = provider()
result = fn(*args, **to_inject)
for provider in closing:
if isinstance(provider, providers.Resource):
provider.shutdown()
for injection, provider in closing.items():
if injection in kwargs:
continue
if not isinstance(provider, providers.Resource):
continue
provider.shutdown()
return result
_patched.__wired__ = True
_patched.__original__ = fn
_patched.__injections__ = injections
_patched.__closing__ = []
_patched.__closing__ = closing
return _patched

View File

@ -6,6 +6,11 @@ class Service:
init_counter: int = 0
shutdown_counter: int = 0
@classmethod
def reset_counter(cls):
cls.init_counter = 0
cls.shutdown_counter = 0
@classmethod
def init(cls):
cls.init_counter += 1

View File

@ -178,6 +178,8 @@ class WiringTest(unittest.TestCase):
def test_closing_resource(self):
from wiringsamples import resourceclosing
resourceclosing.Service.reset_counter()
container = resourceclosing.Container()
container.wire(modules=[resourceclosing])
self.addCleanup(container.unwire)
@ -193,3 +195,23 @@ class WiringTest(unittest.TestCase):
self.assertEqual(result_2.shutdown_counter, 2)
self.assertIsNot(result_1, result_2)
def test_closing_resource_context(self):
from wiringsamples import resourceclosing
resourceclosing.Service.reset_counter()
service = resourceclosing.Service()
container = resourceclosing.Container()
container.wire(modules=[resourceclosing])
self.addCleanup(container.unwire)
result_1 = resourceclosing.test_function(service=service)
self.assertIs(result_1, service)
self.assertEqual(result_1.init_counter, 0)
self.assertEqual(result_1.shutdown_counter, 0)
result_2 = resourceclosing.test_function(service=service)
self.assertIs(result_2, service)
self.assertEqual(result_2.init_counter, 0)
self.assertEqual(result_2.shutdown_counter, 0)