take common code from ContextLocalResource to Resource

This commit is contained in:
elina-israyelyan 2025-11-02 02:02:56 +04:00
parent 066d228ab9
commit 44a6a68647
2 changed files with 59 additions and 116 deletions

View File

@ -3837,32 +3837,28 @@ cdef class Resource(Provider):
async def _handle_async_cm(self, obj) -> None: async def _handle_async_cm(self, obj) -> None:
try: try:
self._resource = resource = await obj.__aenter__() resource = await obj.__aenter__()
self._shutdowner = obj.__aexit__
return resource return resource
except: except:
self._initialized = False self._initialized = False
raise raise
async def _provide_async(self, future) -> None: async def _provide_async(self, future):
try: obj = await future
obj = await future
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
self._resource = await obj.__aenter__() resource = await obj.__aenter__()
self._shutdowner = obj.__aexit__ shutdowner = obj.__aexit__
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
self._resource = obj.__enter__() resource = obj.__enter__()
self._shutdowner = obj.__exit__ shutdowner = obj.__exit__
else: else:
self._resource = obj resource = obj
self._shutdowner = None shutdowner = None
return self._resource return resource, shutdowner
except:
self._initialized = False
raise
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
if self._initialized: if self._initialized:
return self._resource return self._resource
@ -3880,14 +3876,18 @@ cdef class Resource(Provider):
if __is_future_or_coroutine(obj): if __is_future_or_coroutine(obj):
self._initialized = True self._initialized = True
self._resource = resource = ensure_future(self._provide_async(obj)) future_result = asyncio.Future()
return resource future = ensure_future(self._provide_async(obj))
future.add_done_callback(functools.partial(self._async_init_instance, future_result))
self._resource = future_result
return self._resource
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
self._resource = obj.__enter__() self._resource = obj.__enter__()
self._shutdowner = obj.__exit__ self._shutdowner = obj.__exit__
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
self._initialized = True self._initialized = True
self._resource = resource = ensure_future(self._handle_async_cm(obj)) self._resource = resource = ensure_future(self._handle_async_cm(obj))
self._shutdowner = obj.__aexit__
return resource return resource
else: else:
self._resource = obj self._resource = obj
@ -3896,14 +3896,27 @@ cdef class Resource(Provider):
self._initialized = True self._initialized = True
return self._resource return self._resource
def _async_init_instance(self, future_result, result):
try:
resource, shutdowner = result.result()
except Exception as exception:
self._resource = None
self._shutdowner = None
self._initialized = False
future_result.set_exception(exception)
else:
self._resource = resource
self._shutdowner = shutdowner
future_result.set_result(resource)
cdef class ContextLocalResource(Resource): cdef class ContextLocalResource(Resource):
_none = object() _none = object()
def __init__(self, provides=None, *args, **kwargs): def __init__(self, provides=None, *args, **kwargs):
self._initialized_context_var = ContextVar("_initialized_context_var", default=False) self._initialized_context_var = ContextVar("_initialized_context_var", default=False)
self._resource_context_var = ContextVar("_resource_context_var", default=self._none) self._resource_context_var = ContextVar("_resource_context_var", default=None)
self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=self._none) self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=None)
super().__init__(provides, *args, **kwargs) super().__init__(provides, *args, **kwargs)
@property @property
@ -3945,7 +3958,7 @@ cdef class ContextLocalResource(Resource):
return NULL_AWAITABLE return NULL_AWAITABLE
return return
if self._shutdowner != self._none: if self._shutdowner != None:
future = self._shutdowner(None, None, None) future = self._shutdowner(None, None, None)
if __is_future_or_coroutine(future): if __is_future_or_coroutine(future):
self._reset_all_contex_vars() self._reset_all_contex_vars()
@ -3958,79 +3971,8 @@ cdef class ContextLocalResource(Resource):
def _reset_all_contex_vars(self): def _reset_all_contex_vars(self):
self._initialized=False self._initialized=False
self._resource = self._none self._resource = None
self._shutdowner = self._none self._shutdowner = None
async def _handle_async_cm(self, obj) -> None:
resource = await obj.__aenter__()
return resource
async def _provide_async(self, future):
obj = await future
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
resource = await obj.__aenter__()
shutdowner = obj.__aexit__
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
resource = obj.__enter__()
shutdowner = obj.__exit__
else:
resource = obj
shutdowner = self._none
return resource, shutdowner
cpdef object _provide(self, tuple args, dict kwargs):
if self._initialized:
return self._resource
obj = __call(
self._provides,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
if __is_future_or_coroutine(obj):
future_result = asyncio.Future()
future = ensure_future(self._provide_async(obj))
future.add_done_callback(functools.partial(self._async_init_instance, future_result))
return future_result
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
resource = obj.__enter__()
self._resource = resource
self._initialized = True
self._shutdowner = obj.__exit__
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
resource = ensure_future(self._handle_async_cm(obj))
self._resource = resource
self._initialized = True
self._shutdowner = obj.__aexit__
return resource
else:
self._resource = obj
self._initialized = True
self._shutdowner = self._none
return self._resource
def _async_init_instance(self, future_result, result):
try:
resource, shutdowner = result.result()
except Exception as exception:
self._resource = self._none
self._shutdowner = self._none
self._initialized = False
future_result.set_exception(exception)
else:
self._resource = resource
self._initialized = True
self._shutdowner = shutdowner
future_result.set_result(resource)
cdef class Container(Provider): cdef class Container(Provider):

View File

@ -88,16 +88,27 @@ async def test_injection_in_different_context():
context_local_resource = providers.ContextLocalResource(_init) context_local_resource = providers.ContextLocalResource(_init)
async_context_local_resource = providers.ContextLocalResource(_async_init) async_context_local_resource = providers.ContextLocalResource(_async_init)
async def run_in_context():
obj = await container.async_context_local_resource()
return obj
container = Container() container = Container()
obj1 = await container.async_context_local_resource()
obj2 = await container.async_context_local_resource() obj1, obj2 = await asyncio.gather(run_in_context(), run_in_context())
assert obj1 != obj2 assert obj1 != obj2
obj3 = container.context_local_resource() obj3 = await container.async_context_local_resource()
obj4 = container.context_local_resource() obj4 = await container.async_context_local_resource()
assert obj3 == obj4 assert obj3 == obj4
obj5, obj6 = await asyncio.gather(run_in_context(), run_in_context())
assert obj5 == obj6 # as context is copied from the current one where async_context_local_resource was initialized
obj7 = container.context_local_resource()
obj8 = container.context_local_resource()
assert obj7 == obj8
def test_init_function(): def test_init_function():
def _init(): def _init():
@ -329,37 +340,27 @@ def test_call_with_context_args():
def test_fluent_interface(): def test_fluent_interface():
provider = providers.ContextLocalResource(init_fn) \ provider = providers.ContextLocalResource(init_fn).add_args(1, 2).add_kwargs(a3=3, a4=4)
.add_args(1, 2) \
.add_kwargs(a3=3, a4=4)
assert provider() == ((1, 2), {"a3": 3, "a4": 4}) assert provider() == ((1, 2), {"a3": 3, "a4": 4})
def test_set_args(): def test_set_args():
provider = providers.ContextLocalResource(init_fn) \ provider = providers.ContextLocalResource(init_fn).add_args(1, 2).set_args(3, 4)
.add_args(1, 2) \
.set_args(3, 4)
assert provider.args == (3, 4) assert provider.args == (3, 4)
def test_clear_args(): def test_clear_args():
provider = providers.ContextLocalResource(init_fn) \ provider = providers.ContextLocalResource(init_fn).add_args(1, 2).clear_args()
.add_args(1, 2) \
.clear_args()
assert provider.args == tuple() assert provider.args == tuple()
def test_set_kwargs(): def test_set_kwargs():
provider = providers.ContextLocalResource(init_fn) \ provider = providers.ContextLocalResource(init_fn).add_kwargs(a1="i1", a2="i2").set_kwargs(a3="i3", a4="i4")
.add_kwargs(a1="i1", a2="i2") \
.set_kwargs(a3="i3", a4="i4")
assert provider.kwargs == {"a3": "i3", "a4": "i4"} assert provider.kwargs == {"a3": "i3", "a4": "i4"}
def test_clear_kwargs(): def test_clear_kwargs():
provider = providers.ContextLocalResource(init_fn) \ provider = providers.ContextLocalResource(init_fn).add_kwargs(a1="i1", a2="i2").clear_kwargs()
.add_kwargs(a1="i1", a2="i2") \
.clear_kwargs()
assert provider.kwargs == {} assert provider.kwargs == {}