mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-11-24 03:45:48 +03:00
take common code from ContextLocalResource to Resource
This commit is contained in:
parent
066d228ab9
commit
44a6a68647
|
|
@ -3837,31 +3837,27 @@ cdef class Resource(Provider):
|
|||
|
||||
async def _handle_async_cm(self, obj) -> None:
|
||||
try:
|
||||
self._resource = resource = await obj.__aenter__()
|
||||
self._shutdowner = obj.__aexit__
|
||||
resource = await obj.__aenter__()
|
||||
return resource
|
||||
except:
|
||||
self._initialized = False
|
||||
raise
|
||||
|
||||
async def _provide_async(self, future) -> None:
|
||||
try:
|
||||
obj = await future
|
||||
async def _provide_async(self, future):
|
||||
obj = await future
|
||||
|
||||
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
|
||||
self._resource = await obj.__aenter__()
|
||||
self._shutdowner = obj.__aexit__
|
||||
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
|
||||
self._resource = obj.__enter__()
|
||||
self._shutdowner = obj.__exit__
|
||||
else:
|
||||
self._resource = obj
|
||||
self._shutdowner = None
|
||||
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
|
||||
resource = await obj.__aenter__()
|
||||
shutdowner = obj.__aexit__
|
||||
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
|
||||
resource = obj.__enter__()
|
||||
shutdowner = obj.__exit__
|
||||
else:
|
||||
resource = obj
|
||||
shutdowner = None
|
||||
|
||||
return resource, shutdowner
|
||||
|
||||
return self._resource
|
||||
except:
|
||||
self._initialized = False
|
||||
raise
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
if self._initialized:
|
||||
|
|
@ -3880,14 +3876,18 @@ cdef class Resource(Provider):
|
|||
|
||||
if __is_future_or_coroutine(obj):
|
||||
self._initialized = True
|
||||
self._resource = resource = ensure_future(self._provide_async(obj))
|
||||
return resource
|
||||
future_result = asyncio.Future()
|
||||
future = ensure_future(self._provide_async(obj))
|
||||
future.add_done_callback(functools.partial(self._async_init_instance, future_result))
|
||||
self._resource = future_result
|
||||
return self._resource
|
||||
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
|
||||
self._resource = obj.__enter__()
|
||||
self._shutdowner = obj.__exit__
|
||||
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
|
||||
self._initialized = True
|
||||
self._resource = resource = ensure_future(self._handle_async_cm(obj))
|
||||
self._shutdowner = obj.__aexit__
|
||||
return resource
|
||||
else:
|
||||
self._resource = obj
|
||||
|
|
@ -3896,14 +3896,27 @@ cdef class Resource(Provider):
|
|||
self._initialized = True
|
||||
return self._resource
|
||||
|
||||
def _async_init_instance(self, future_result, result):
|
||||
try:
|
||||
resource, shutdowner = result.result()
|
||||
except Exception as exception:
|
||||
self._resource = None
|
||||
self._shutdowner = None
|
||||
self._initialized = False
|
||||
future_result.set_exception(exception)
|
||||
else:
|
||||
self._resource = resource
|
||||
self._shutdowner = shutdowner
|
||||
future_result.set_result(resource)
|
||||
|
||||
|
||||
cdef class ContextLocalResource(Resource):
|
||||
_none = object()
|
||||
|
||||
def __init__(self, provides=None, *args, **kwargs):
|
||||
self._initialized_context_var = ContextVar("_initialized_context_var", default=False)
|
||||
self._resource_context_var = ContextVar("_resource_context_var", default=self._none)
|
||||
self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=self._none)
|
||||
self._resource_context_var = ContextVar("_resource_context_var", default=None)
|
||||
self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=None)
|
||||
super().__init__(provides, *args, **kwargs)
|
||||
|
||||
@property
|
||||
|
|
@ -3945,7 +3958,7 @@ cdef class ContextLocalResource(Resource):
|
|||
return NULL_AWAITABLE
|
||||
return
|
||||
|
||||
if self._shutdowner != self._none:
|
||||
if self._shutdowner != None:
|
||||
future = self._shutdowner(None, None, None)
|
||||
if __is_future_or_coroutine(future):
|
||||
self._reset_all_contex_vars()
|
||||
|
|
@ -3958,79 +3971,8 @@ cdef class ContextLocalResource(Resource):
|
|||
|
||||
def _reset_all_contex_vars(self):
|
||||
self._initialized=False
|
||||
self._resource = self._none
|
||||
self._shutdowner = self._none
|
||||
|
||||
async def _handle_async_cm(self, obj) -> None:
|
||||
resource = await obj.__aenter__()
|
||||
return resource
|
||||
|
||||
async def _provide_async(self, future):
|
||||
obj = await future
|
||||
|
||||
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
|
||||
resource = await obj.__aenter__()
|
||||
shutdowner = obj.__aexit__
|
||||
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
|
||||
resource = obj.__enter__()
|
||||
shutdowner = obj.__exit__
|
||||
else:
|
||||
resource = obj
|
||||
shutdowner = self._none
|
||||
|
||||
return resource, shutdowner
|
||||
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
if self._initialized:
|
||||
return self._resource
|
||||
obj = __call(
|
||||
self._provides,
|
||||
args,
|
||||
self._args,
|
||||
self._args_len,
|
||||
kwargs,
|
||||
self._kwargs,
|
||||
self._kwargs_len,
|
||||
self._async_mode,
|
||||
)
|
||||
|
||||
if __is_future_or_coroutine(obj):
|
||||
future_result = asyncio.Future()
|
||||
future = ensure_future(self._provide_async(obj))
|
||||
future.add_done_callback(functools.partial(self._async_init_instance, future_result))
|
||||
return future_result
|
||||
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
|
||||
resource = obj.__enter__()
|
||||
self._resource = resource
|
||||
self._initialized = True
|
||||
self._shutdowner = obj.__exit__
|
||||
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
|
||||
resource = ensure_future(self._handle_async_cm(obj))
|
||||
self._resource = resource
|
||||
self._initialized = True
|
||||
self._shutdowner = obj.__aexit__
|
||||
return resource
|
||||
else:
|
||||
self._resource = obj
|
||||
self._initialized = True
|
||||
self._shutdowner = self._none
|
||||
|
||||
return self._resource
|
||||
|
||||
def _async_init_instance(self, future_result, result):
|
||||
try:
|
||||
resource, shutdowner = result.result()
|
||||
except Exception as exception:
|
||||
self._resource = self._none
|
||||
self._shutdowner = self._none
|
||||
self._initialized = False
|
||||
future_result.set_exception(exception)
|
||||
else:
|
||||
self._resource = resource
|
||||
self._initialized = True
|
||||
self._shutdowner = shutdowner
|
||||
future_result.set_result(resource)
|
||||
self._resource = None
|
||||
self._shutdowner = None
|
||||
|
||||
|
||||
cdef class Container(Provider):
|
||||
|
|
|
|||
|
|
@ -88,16 +88,27 @@ async def test_injection_in_different_context():
|
|||
context_local_resource = providers.ContextLocalResource(_init)
|
||||
async_context_local_resource = providers.ContextLocalResource(_async_init)
|
||||
|
||||
async def run_in_context():
|
||||
obj = await container.async_context_local_resource()
|
||||
return obj
|
||||
|
||||
container = Container()
|
||||
obj1 = await container.async_context_local_resource()
|
||||
obj2 = await container.async_context_local_resource()
|
||||
|
||||
obj1, obj2 = await asyncio.gather(run_in_context(), run_in_context())
|
||||
assert obj1 != obj2
|
||||
|
||||
obj3 = container.context_local_resource()
|
||||
obj4 = container.context_local_resource()
|
||||
|
||||
obj3 = await container.async_context_local_resource()
|
||||
obj4 = await container.async_context_local_resource()
|
||||
assert obj3 == obj4
|
||||
|
||||
obj5, obj6 = await asyncio.gather(run_in_context(), run_in_context())
|
||||
assert obj5 == obj6 # as context is copied from the current one where async_context_local_resource was initialized
|
||||
|
||||
obj7 = container.context_local_resource()
|
||||
obj8 = container.context_local_resource()
|
||||
|
||||
assert obj7 == obj8
|
||||
|
||||
|
||||
def test_init_function():
|
||||
def _init():
|
||||
|
|
@ -329,37 +340,27 @@ def test_call_with_context_args():
|
|||
|
||||
|
||||
def test_fluent_interface():
|
||||
provider = providers.ContextLocalResource(init_fn) \
|
||||
.add_args(1, 2) \
|
||||
.add_kwargs(a3=3, a4=4)
|
||||
provider = providers.ContextLocalResource(init_fn).add_args(1, 2).add_kwargs(a3=3, a4=4)
|
||||
assert provider() == ((1, 2), {"a3": 3, "a4": 4})
|
||||
|
||||
|
||||
def test_set_args():
|
||||
provider = providers.ContextLocalResource(init_fn) \
|
||||
.add_args(1, 2) \
|
||||
.set_args(3, 4)
|
||||
provider = providers.ContextLocalResource(init_fn).add_args(1, 2).set_args(3, 4)
|
||||
assert provider.args == (3, 4)
|
||||
|
||||
|
||||
def test_clear_args():
|
||||
provider = providers.ContextLocalResource(init_fn) \
|
||||
.add_args(1, 2) \
|
||||
.clear_args()
|
||||
provider = providers.ContextLocalResource(init_fn).add_args(1, 2).clear_args()
|
||||
assert provider.args == tuple()
|
||||
|
||||
|
||||
def test_set_kwargs():
|
||||
provider = providers.ContextLocalResource(init_fn) \
|
||||
.add_kwargs(a1="i1", a2="i2") \
|
||||
.set_kwargs(a3="i3", a4="i4")
|
||||
provider = providers.ContextLocalResource(init_fn).add_kwargs(a1="i1", a2="i2").set_kwargs(a3="i3", a4="i4")
|
||||
assert provider.kwargs == {"a3": "i3", "a4": "i4"}
|
||||
|
||||
|
||||
def test_clear_kwargs():
|
||||
provider = providers.ContextLocalResource(init_fn) \
|
||||
.add_kwargs(a1="i1", a2="i2") \
|
||||
.clear_kwargs()
|
||||
provider = providers.ContextLocalResource(init_fn).add_kwargs(a1="i1", a2="i2").clear_kwargs()
|
||||
assert provider.kwargs == {}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user