mirror of
				https://github.com/ets-labs/python-dependency-injector.git
				synced 2025-11-01 16:37:56 +03:00 
			
		
		
		
	Fix Singleton and ThreadLocalSingleton to handle initialization errors
This commit is contained in:
		
							parent
							
								
									ea5af60669
								
							
						
					
					
						commit
						dcea50b3a3
					
				
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -2182,9 +2182,14 @@ cdef class BaseSingleton(Provider): | |||
|         raise NotImplementedError() | ||||
| 
 | ||||
|     def _async_init_instance(self, future_result, result): | ||||
|         instance = result.result() | ||||
|         self.__storage = instance | ||||
|         future_result.set_result(instance) | ||||
|         try: | ||||
|             instance = result.result() | ||||
|         except Exception as exception: | ||||
|             self.__storage = None | ||||
|             future_result.set_exception(exception) | ||||
|         else: | ||||
|             self.__storage = instance | ||||
|             future_result.set_result(instance) | ||||
| 
 | ||||
| 
 | ||||
| cdef class Singleton(BaseSingleton): | ||||
|  | @ -2407,9 +2412,14 @@ cdef class ThreadLocalSingleton(BaseSingleton): | |||
|             return instance | ||||
| 
 | ||||
|     def _async_init_instance(self, future_result, result): | ||||
|         instance = result.result() | ||||
|         self.__storage.instance = instance | ||||
|         future_result.set_result(instance) | ||||
|         try: | ||||
|             instance = result.result() | ||||
|         except Exception as exception: | ||||
|             del self.__storage.instance | ||||
|             future_result.set_exception(exception) | ||||
|         else: | ||||
|             self.__storage.instance = instance | ||||
|             future_result.set_result(instance) | ||||
| 
 | ||||
| 
 | ||||
| cdef class DelegatedThreadLocalSingleton(ThreadLocalSingleton): | ||||
|  |  | |||
|  | @ -329,6 +329,37 @@ class SingletonTests(AsyncTestCase): | |||
|         self.assertIs(instance1, instance2) | ||||
|         self.assertIs(instance, instance) | ||||
| 
 | ||||
|     def test_async_init_with_error(self): | ||||
|         # Disable default exception handling to prevent output | ||||
|         asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...) | ||||
| 
 | ||||
|         async def create_instance(): | ||||
|             create_instance.counter += 1 | ||||
|             raise RuntimeError() | ||||
| 
 | ||||
|         create_instance.counter = 0 | ||||
| 
 | ||||
|         provider = providers.Singleton(create_instance) | ||||
| 
 | ||||
| 
 | ||||
|         future = provider() | ||||
|         self.assertTrue(provider.is_async_mode_enabled()) | ||||
| 
 | ||||
|         with self.assertRaises(RuntimeError): | ||||
|             self._run(future) | ||||
| 
 | ||||
|         self.assertEqual(create_instance.counter, 1) | ||||
|         self.assertTrue(provider.is_async_mode_enabled()) | ||||
| 
 | ||||
|         with self.assertRaises(RuntimeError): | ||||
|             self._run(provider()) | ||||
| 
 | ||||
|         self.assertEqual(create_instance.counter, 2) | ||||
|         self.assertTrue(provider.is_async_mode_enabled()) | ||||
| 
 | ||||
|         # Restore default exception handling | ||||
|         asyncio.get_event_loop().set_exception_handler(None) | ||||
| 
 | ||||
| 
 | ||||
| class DelegatedSingletonTests(AsyncTestCase): | ||||
| 
 | ||||
|  | @ -398,6 +429,36 @@ class ThreadLocalSingletonTests(AsyncTestCase): | |||
|         self.assertIs(instance, instance) | ||||
| 
 | ||||
| 
 | ||||
|     def test_async_init_with_error(self): | ||||
|         # Disable default exception handling to prevent output | ||||
|         asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...) | ||||
| 
 | ||||
|         async def create_instance(): | ||||
|             create_instance.counter += 1 | ||||
|             raise RuntimeError() | ||||
|         create_instance.counter = 0 | ||||
| 
 | ||||
|         provider = providers.ThreadLocalSingleton(create_instance) | ||||
| 
 | ||||
|         future = provider() | ||||
|         self.assertTrue(provider.is_async_mode_enabled()) | ||||
| 
 | ||||
|         with self.assertRaises(RuntimeError): | ||||
|             self._run(future) | ||||
| 
 | ||||
|         self.assertEqual(create_instance.counter, 1) | ||||
|         self.assertTrue(provider.is_async_mode_enabled()) | ||||
| 
 | ||||
|         with self.assertRaises(RuntimeError): | ||||
|             self._run(provider()) | ||||
| 
 | ||||
|         self.assertEqual(create_instance.counter, 2) | ||||
|         self.assertTrue(provider.is_async_mode_enabled()) | ||||
| 
 | ||||
|         # Restore default exception handling | ||||
|         asyncio.get_event_loop().set_exception_handler(None) | ||||
| 
 | ||||
| 
 | ||||
| class DelegatedThreadLocalSingletonTests(AsyncTestCase): | ||||
| 
 | ||||
|     def test_async_mode(self): | ||||
|  |  | |||
|  | @ -447,6 +447,7 @@ class AsyncResourceTest(AsyncTestCase): | |||
| 
 | ||||
|         future = provider() | ||||
|         self.assertTrue(provider.initialized) | ||||
|         self.assertTrue(provider.is_async_mode_enabled()) | ||||
| 
 | ||||
|         # Disable default exception handling to prevent output | ||||
|         asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...) | ||||
|  | @ -458,6 +459,7 @@ class AsyncResourceTest(AsyncTestCase): | |||
|         asyncio.get_event_loop().set_exception_handler(None) | ||||
| 
 | ||||
|         self.assertFalse(provider.initialized) | ||||
|         self.assertTrue(provider.is_async_mode_enabled()) | ||||
| 
 | ||||
|     def test_init_and_shutdown_methods(self): | ||||
|         async def _init(): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user