take common code from ContextLocalResource to Resource

This commit is contained in:
elina-israyelyan 2025-11-02 02:02:56 +04:00
parent 066d228ab9
commit 44a6a68647
2 changed files with 59 additions and 116 deletions

View File

@ -3837,31 +3837,27 @@ cdef class Resource(Provider):
async def _handle_async_cm(self, obj) -> None:
try:
self._resource = resource = await obj.__aenter__()
self._shutdowner = obj.__aexit__
resource = await obj.__aenter__()
return resource
except:
self._initialized = False
raise
async def _provide_async(self, future) -> None:
try:
obj = await future
async def _provide_async(self, future):
obj = await future
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
self._resource = await obj.__aenter__()
self._shutdowner = obj.__aexit__
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
self._resource = obj.__enter__()
self._shutdowner = obj.__exit__
else:
self._resource = obj
self._shutdowner = None
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
resource = await obj.__aenter__()
shutdowner = obj.__aexit__
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
resource = obj.__enter__()
shutdowner = obj.__exit__
else:
resource = obj
shutdowner = None
return resource, shutdowner
return self._resource
except:
self._initialized = False
raise
cpdef object _provide(self, tuple args, dict kwargs):
if self._initialized:
@ -3880,14 +3876,18 @@ cdef class Resource(Provider):
if __is_future_or_coroutine(obj):
self._initialized = True
self._resource = resource = ensure_future(self._provide_async(obj))
return resource
future_result = asyncio.Future()
future = ensure_future(self._provide_async(obj))
future.add_done_callback(functools.partial(self._async_init_instance, future_result))
self._resource = future_result
return self._resource
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
self._resource = obj.__enter__()
self._shutdowner = obj.__exit__
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
self._initialized = True
self._resource = resource = ensure_future(self._handle_async_cm(obj))
self._shutdowner = obj.__aexit__
return resource
else:
self._resource = obj
@ -3896,14 +3896,27 @@ cdef class Resource(Provider):
self._initialized = True
return self._resource
def _async_init_instance(self, future_result, result):
try:
resource, shutdowner = result.result()
except Exception as exception:
self._resource = None
self._shutdowner = None
self._initialized = False
future_result.set_exception(exception)
else:
self._resource = resource
self._shutdowner = shutdowner
future_result.set_result(resource)
cdef class ContextLocalResource(Resource):
_none = object()
def __init__(self, provides=None, *args, **kwargs):
self._initialized_context_var = ContextVar("_initialized_context_var", default=False)
self._resource_context_var = ContextVar("_resource_context_var", default=self._none)
self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=self._none)
self._resource_context_var = ContextVar("_resource_context_var", default=None)
self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=None)
super().__init__(provides, *args, **kwargs)
@property
@ -3945,7 +3958,7 @@ cdef class ContextLocalResource(Resource):
return NULL_AWAITABLE
return
if self._shutdowner != self._none:
if self._shutdowner != None:
future = self._shutdowner(None, None, None)
if __is_future_or_coroutine(future):
self._reset_all_contex_vars()
@ -3958,79 +3971,8 @@ cdef class ContextLocalResource(Resource):
def _reset_all_contex_vars(self):
self._initialized=False
self._resource = self._none
self._shutdowner = self._none
async def _handle_async_cm(self, obj) -> None:
resource = await obj.__aenter__()
return resource
async def _provide_async(self, future):
obj = await future
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
resource = await obj.__aenter__()
shutdowner = obj.__aexit__
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
resource = obj.__enter__()
shutdowner = obj.__exit__
else:
resource = obj
shutdowner = self._none
return resource, shutdowner
cpdef object _provide(self, tuple args, dict kwargs):
if self._initialized:
return self._resource
obj = __call(
self._provides,
args,
self._args,
self._args_len,
kwargs,
self._kwargs,
self._kwargs_len,
self._async_mode,
)
if __is_future_or_coroutine(obj):
future_result = asyncio.Future()
future = ensure_future(self._provide_async(obj))
future.add_done_callback(functools.partial(self._async_init_instance, future_result))
return future_result
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
resource = obj.__enter__()
self._resource = resource
self._initialized = True
self._shutdowner = obj.__exit__
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
resource = ensure_future(self._handle_async_cm(obj))
self._resource = resource
self._initialized = True
self._shutdowner = obj.__aexit__
return resource
else:
self._resource = obj
self._initialized = True
self._shutdowner = self._none
return self._resource
def _async_init_instance(self, future_result, result):
try:
resource, shutdowner = result.result()
except Exception as exception:
self._resource = self._none
self._shutdowner = self._none
self._initialized = False
future_result.set_exception(exception)
else:
self._resource = resource
self._initialized = True
self._shutdowner = shutdowner
future_result.set_result(resource)
self._resource = None
self._shutdowner = None
cdef class Container(Provider):

View File

@ -88,16 +88,27 @@ async def test_injection_in_different_context():
context_local_resource = providers.ContextLocalResource(_init)
async_context_local_resource = providers.ContextLocalResource(_async_init)
async def run_in_context():
obj = await container.async_context_local_resource()
return obj
container = Container()
obj1 = await container.async_context_local_resource()
obj2 = await container.async_context_local_resource()
obj1, obj2 = await asyncio.gather(run_in_context(), run_in_context())
assert obj1 != obj2
obj3 = container.context_local_resource()
obj4 = container.context_local_resource()
obj3 = await container.async_context_local_resource()
obj4 = await container.async_context_local_resource()
assert obj3 == obj4
obj5, obj6 = await asyncio.gather(run_in_context(), run_in_context())
assert obj5 == obj6 # as context is copied from the current one where async_context_local_resource was initialized
obj7 = container.context_local_resource()
obj8 = container.context_local_resource()
assert obj7 == obj8
def test_init_function():
def _init():
@ -329,37 +340,27 @@ def test_call_with_context_args():
def test_fluent_interface():
provider = providers.ContextLocalResource(init_fn) \
.add_args(1, 2) \
.add_kwargs(a3=3, a4=4)
provider = providers.ContextLocalResource(init_fn).add_args(1, 2).add_kwargs(a3=3, a4=4)
assert provider() == ((1, 2), {"a3": 3, "a4": 4})
def test_set_args():
provider = providers.ContextLocalResource(init_fn) \
.add_args(1, 2) \
.set_args(3, 4)
provider = providers.ContextLocalResource(init_fn).add_args(1, 2).set_args(3, 4)
assert provider.args == (3, 4)
def test_clear_args():
provider = providers.ContextLocalResource(init_fn) \
.add_args(1, 2) \
.clear_args()
provider = providers.ContextLocalResource(init_fn).add_args(1, 2).clear_args()
assert provider.args == tuple()
def test_set_kwargs():
provider = providers.ContextLocalResource(init_fn) \
.add_kwargs(a1="i1", a2="i2") \
.set_kwargs(a3="i3", a4="i4")
provider = providers.ContextLocalResource(init_fn).add_kwargs(a1="i1", a2="i2").set_kwargs(a3="i3", a4="i4")
assert provider.kwargs == {"a3": "i3", "a4": "i4"}
def test_clear_kwargs():
provider = providers.ContextLocalResource(init_fn) \
.add_kwargs(a1="i1", a2="i2") \
.clear_kwargs()
provider = providers.ContextLocalResource(init_fn).add_kwargs(a1="i1", a2="i2").clear_kwargs()
assert provider.kwargs == {}