mirror of
				https://github.com/ets-labs/python-dependency-injector.git
				synced 2025-11-04 01:47:36 +03:00 
			
		
		
		
	Add support of async mode for FactoryAggregate provider + tests
This commit is contained in:
		
							parent
							
								
									d5a8da8907
								
							
						
					
					
						commit
						e4fd36555f
					
				
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| 
						 | 
				
			
			@ -1964,13 +1964,6 @@ cdef class FactoryAggregate(Provider):
 | 
			
		|||
 | 
			
		||||
        return copied
 | 
			
		||||
 | 
			
		||||
    def __call__(self, factory_name, *args, **kwargs):
 | 
			
		||||
        """Create new object using factory with provided name.
 | 
			
		||||
 | 
			
		||||
        Callable interface implementation.
 | 
			
		||||
        """
 | 
			
		||||
        return self.__get_factory(factory_name)(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def __getattr__(self, factory_name):
 | 
			
		||||
        """Return aggregated factory."""
 | 
			
		||||
        return self.__get_factory(factory_name)
 | 
			
		||||
| 
						 | 
				
			
			@ -1998,6 +1991,19 @@ cdef class FactoryAggregate(Provider):
 | 
			
		|||
        raise Error(
 | 
			
		||||
            '{0} providers could not be overridden'.format(self.__class__))
 | 
			
		||||
 | 
			
		||||
    cpdef object _provide(self, tuple args, dict kwargs):
 | 
			
		||||
        try:
 | 
			
		||||
            factory_name = args[0]
 | 
			
		||||
        except IndexError:
 | 
			
		||||
            try:
 | 
			
		||||
                factory_name = kwargs.pop('factory_name')
 | 
			
		||||
            except KeyError:
 | 
			
		||||
                raise TypeError('Factory missing 1 required positional argument: \'factory_name\'')
 | 
			
		||||
        else:
 | 
			
		||||
            args = args[1:]
 | 
			
		||||
 | 
			
		||||
        return self.__get_factory(factory_name)(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    cdef Factory __get_factory(self, str factory_name):
 | 
			
		||||
        if factory_name not in self.__factories:
 | 
			
		||||
            raise NoSuchProviderError(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -221,6 +221,33 @@ class FactoryTests(AsyncTestCase):
 | 
			
		|||
        self.assertIsNot(service1.client, service2.client)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FactoryAggregateTests(AsyncTestCase):
 | 
			
		||||
 | 
			
		||||
    def test_async_mode(self):
 | 
			
		||||
        object1 = object()
 | 
			
		||||
        object2 = object()
 | 
			
		||||
 | 
			
		||||
        async def _get_object1():
 | 
			
		||||
            return object1
 | 
			
		||||
 | 
			
		||||
        def _get_object2():
 | 
			
		||||
            return object2
 | 
			
		||||
 | 
			
		||||
        provider = providers.FactoryAggregate(
 | 
			
		||||
            object1=providers.Factory(_get_object1),
 | 
			
		||||
            object2=providers.Factory(_get_object2),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(provider.is_async_mode_undefined())
 | 
			
		||||
 | 
			
		||||
        created_object1 = self._run(provider('object1'))
 | 
			
		||||
        self.assertIs(created_object1, object1)
 | 
			
		||||
        self.assertTrue(provider.is_async_mode_enabled())
 | 
			
		||||
 | 
			
		||||
        created_object2 = self._run(provider('object2'))
 | 
			
		||||
        self.assertIs(created_object2, object2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SingletonTests(AsyncTestCase):
 | 
			
		||||
 | 
			
		||||
    def test_injections(self):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -520,6 +520,24 @@ class FactoryAggregateTests(unittest.TestCase):
 | 
			
		|||
        self.assertEqual(object_b.init_arg3, 33)
 | 
			
		||||
        self.assertEqual(object_b.init_arg4, 44)
 | 
			
		||||
 | 
			
		||||
    def test_call_factory_name_as_kwarg(self):
 | 
			
		||||
        object_a = self.factory_aggregate(
 | 
			
		||||
            factory_name='example_a',
 | 
			
		||||
            init_arg1=1,
 | 
			
		||||
            init_arg2=2,
 | 
			
		||||
            init_arg3=3,
 | 
			
		||||
            init_arg4=4,
 | 
			
		||||
        )
 | 
			
		||||
        self.assertIsInstance(object_a, self.ExampleA)
 | 
			
		||||
        self.assertEqual(object_a.init_arg1, 1)
 | 
			
		||||
        self.assertEqual(object_a.init_arg2, 2)
 | 
			
		||||
        self.assertEqual(object_a.init_arg3, 3)
 | 
			
		||||
        self.assertEqual(object_a.init_arg4, 4)
 | 
			
		||||
 | 
			
		||||
    def test_call_no_factory_name(self):
 | 
			
		||||
        with self.assertRaises(TypeError):
 | 
			
		||||
            self.factory_aggregate()
 | 
			
		||||
 | 
			
		||||
    def test_call_no_such_provider(self):
 | 
			
		||||
        with self.assertRaises(errors.NoSuchProviderError):
 | 
			
		||||
            self.factory_aggregate('unknown')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user