Fix Singleton and ThreadLocalSingleton to handle initialization errors

This commit is contained in:
Roman Mogylatov 2021-01-03 18:07:37 -05:00
parent ea5af60669
commit dcea50b3a3
4 changed files with 2659 additions and 2330 deletions

File diff suppressed because it is too large Load Diff

View File

@ -2182,7 +2182,12 @@ cdef class BaseSingleton(Provider):
raise NotImplementedError()
def _async_init_instance(self, future_result, result):
try:
instance = result.result()
except Exception as exception:
self.__storage = None
future_result.set_exception(exception)
else:
self.__storage = instance
future_result.set_result(instance)
@ -2407,7 +2412,12 @@ cdef class ThreadLocalSingleton(BaseSingleton):
return instance
def _async_init_instance(self, future_result, result):
try:
instance = result.result()
except Exception as exception:
del self.__storage.instance
future_result.set_exception(exception)
else:
self.__storage.instance = instance
future_result.set_result(instance)

View File

@ -329,6 +329,37 @@ class SingletonTests(AsyncTestCase):
self.assertIs(instance1, instance2)
self.assertIs(instance, instance)
def test_async_init_with_error(self):
# Disable default exception handling to prevent output
asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...)
async def create_instance():
create_instance.counter += 1
raise RuntimeError()
create_instance.counter = 0
provider = providers.Singleton(create_instance)
future = provider()
self.assertTrue(provider.is_async_mode_enabled())
with self.assertRaises(RuntimeError):
self._run(future)
self.assertEqual(create_instance.counter, 1)
self.assertTrue(provider.is_async_mode_enabled())
with self.assertRaises(RuntimeError):
self._run(provider())
self.assertEqual(create_instance.counter, 2)
self.assertTrue(provider.is_async_mode_enabled())
# Restore default exception handling
asyncio.get_event_loop().set_exception_handler(None)
class DelegatedSingletonTests(AsyncTestCase):
@ -398,6 +429,36 @@ class ThreadLocalSingletonTests(AsyncTestCase):
self.assertIs(instance, instance)
def test_async_init_with_error(self):
# Disable default exception handling to prevent output
asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...)
async def create_instance():
create_instance.counter += 1
raise RuntimeError()
create_instance.counter = 0
provider = providers.ThreadLocalSingleton(create_instance)
future = provider()
self.assertTrue(provider.is_async_mode_enabled())
with self.assertRaises(RuntimeError):
self._run(future)
self.assertEqual(create_instance.counter, 1)
self.assertTrue(provider.is_async_mode_enabled())
with self.assertRaises(RuntimeError):
self._run(provider())
self.assertEqual(create_instance.counter, 2)
self.assertTrue(provider.is_async_mode_enabled())
# Restore default exception handling
asyncio.get_event_loop().set_exception_handler(None)
class DelegatedThreadLocalSingletonTests(AsyncTestCase):
def test_async_mode(self):

View File

@ -447,6 +447,7 @@ class AsyncResourceTest(AsyncTestCase):
future = provider()
self.assertTrue(provider.initialized)
self.assertTrue(provider.is_async_mode_enabled())
# Disable default exception handling to prevent output
asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...)
@ -458,6 +459,7 @@ class AsyncResourceTest(AsyncTestCase):
asyncio.get_event_loop().set_exception_handler(None)
self.assertFalse(provider.initialized)
self.assertTrue(provider.is_async_mode_enabled())
def test_init_and_shutdown_methods(self):
async def _init():