mirror of
				https://github.com/ets-labs/python-dependency-injector.git
				synced 2025-11-04 01:47:36 +03:00 
			
		
		
		
	Fix Singleton and ThreadLocalSingleton to handle initialization errors
This commit is contained in:
		
							parent
							
								
									ea5af60669
								
							
						
					
					
						commit
						dcea50b3a3
					
				
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| 
						 | 
					@ -2182,9 +2182,14 @@ cdef class BaseSingleton(Provider):
 | 
				
			||||||
        raise NotImplementedError()
 | 
					        raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _async_init_instance(self, future_result, result):
 | 
					    def _async_init_instance(self, future_result, result):
 | 
				
			||||||
        instance = result.result()
 | 
					        try:
 | 
				
			||||||
        self.__storage = instance
 | 
					            instance = result.result()
 | 
				
			||||||
        future_result.set_result(instance)
 | 
					        except Exception as exception:
 | 
				
			||||||
 | 
					            self.__storage = None
 | 
				
			||||||
 | 
					            future_result.set_exception(exception)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.__storage = instance
 | 
				
			||||||
 | 
					            future_result.set_result(instance)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class Singleton(BaseSingleton):
 | 
					cdef class Singleton(BaseSingleton):
 | 
				
			||||||
| 
						 | 
					@ -2407,9 +2412,14 @@ cdef class ThreadLocalSingleton(BaseSingleton):
 | 
				
			||||||
            return instance
 | 
					            return instance
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _async_init_instance(self, future_result, result):
 | 
					    def _async_init_instance(self, future_result, result):
 | 
				
			||||||
        instance = result.result()
 | 
					        try:
 | 
				
			||||||
        self.__storage.instance = instance
 | 
					            instance = result.result()
 | 
				
			||||||
        future_result.set_result(instance)
 | 
					        except Exception as exception:
 | 
				
			||||||
 | 
					            del self.__storage.instance
 | 
				
			||||||
 | 
					            future_result.set_exception(exception)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.__storage.instance = instance
 | 
				
			||||||
 | 
					            future_result.set_result(instance)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class DelegatedThreadLocalSingleton(ThreadLocalSingleton):
 | 
					cdef class DelegatedThreadLocalSingleton(ThreadLocalSingleton):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -329,6 +329,37 @@ class SingletonTests(AsyncTestCase):
 | 
				
			||||||
        self.assertIs(instance1, instance2)
 | 
					        self.assertIs(instance1, instance2)
 | 
				
			||||||
        self.assertIs(instance, instance)
 | 
					        self.assertIs(instance, instance)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_async_init_with_error(self):
 | 
				
			||||||
 | 
					        # Disable default exception handling to prevent output
 | 
				
			||||||
 | 
					        asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        async def create_instance():
 | 
				
			||||||
 | 
					            create_instance.counter += 1
 | 
				
			||||||
 | 
					            raise RuntimeError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        create_instance.counter = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        provider = providers.Singleton(create_instance)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        future = provider()
 | 
				
			||||||
 | 
					        self.assertTrue(provider.is_async_mode_enabled())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with self.assertRaises(RuntimeError):
 | 
				
			||||||
 | 
					            self._run(future)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(create_instance.counter, 1)
 | 
				
			||||||
 | 
					        self.assertTrue(provider.is_async_mode_enabled())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with self.assertRaises(RuntimeError):
 | 
				
			||||||
 | 
					            self._run(provider())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(create_instance.counter, 2)
 | 
				
			||||||
 | 
					        self.assertTrue(provider.is_async_mode_enabled())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Restore default exception handling
 | 
				
			||||||
 | 
					        asyncio.get_event_loop().set_exception_handler(None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DelegatedSingletonTests(AsyncTestCase):
 | 
					class DelegatedSingletonTests(AsyncTestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -398,6 +429,36 @@ class ThreadLocalSingletonTests(AsyncTestCase):
 | 
				
			||||||
        self.assertIs(instance, instance)
 | 
					        self.assertIs(instance, instance)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_async_init_with_error(self):
 | 
				
			||||||
 | 
					        # Disable default exception handling to prevent output
 | 
				
			||||||
 | 
					        asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        async def create_instance():
 | 
				
			||||||
 | 
					            create_instance.counter += 1
 | 
				
			||||||
 | 
					            raise RuntimeError()
 | 
				
			||||||
 | 
					        create_instance.counter = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        provider = providers.ThreadLocalSingleton(create_instance)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        future = provider()
 | 
				
			||||||
 | 
					        self.assertTrue(provider.is_async_mode_enabled())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with self.assertRaises(RuntimeError):
 | 
				
			||||||
 | 
					            self._run(future)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(create_instance.counter, 1)
 | 
				
			||||||
 | 
					        self.assertTrue(provider.is_async_mode_enabled())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with self.assertRaises(RuntimeError):
 | 
				
			||||||
 | 
					            self._run(provider())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(create_instance.counter, 2)
 | 
				
			||||||
 | 
					        self.assertTrue(provider.is_async_mode_enabled())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Restore default exception handling
 | 
				
			||||||
 | 
					        asyncio.get_event_loop().set_exception_handler(None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DelegatedThreadLocalSingletonTests(AsyncTestCase):
 | 
					class DelegatedThreadLocalSingletonTests(AsyncTestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_async_mode(self):
 | 
					    def test_async_mode(self):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -447,6 +447,7 @@ class AsyncResourceTest(AsyncTestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        future = provider()
 | 
					        future = provider()
 | 
				
			||||||
        self.assertTrue(provider.initialized)
 | 
					        self.assertTrue(provider.initialized)
 | 
				
			||||||
 | 
					        self.assertTrue(provider.is_async_mode_enabled())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Disable default exception handling to prevent output
 | 
					        # Disable default exception handling to prevent output
 | 
				
			||||||
        asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...)
 | 
					        asyncio.get_event_loop().set_exception_handler(lambda loop, context: ...)
 | 
				
			||||||
| 
						 | 
					@ -458,6 +459,7 @@ class AsyncResourceTest(AsyncTestCase):
 | 
				
			||||||
        asyncio.get_event_loop().set_exception_handler(None)
 | 
					        asyncio.get_event_loop().set_exception_handler(None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.assertFalse(provider.initialized)
 | 
					        self.assertFalse(provider.initialized)
 | 
				
			||||||
 | 
					        self.assertTrue(provider.is_async_mode_enabled())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_init_and_shutdown_methods(self):
 | 
					    def test_init_and_shutdown_methods(self):
 | 
				
			||||||
        async def _init():
 | 
					        async def _init():
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user