Add support of async mode for FactoryAggregate provider + tests

This commit is contained in:
Roman Mogylatov 2020-12-30 22:40:31 -05:00
parent d5a8da8907
commit e4fd36555f
5 changed files with 3737 additions and 9752 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1964,13 +1964,6 @@ cdef class FactoryAggregate(Provider):
return copied
def __call__(self, factory_name, *args, **kwargs):
"""Create new object using factory with provided name.
Callable interface implementation.
"""
return self.__get_factory(factory_name)(*args, **kwargs)
def __getattr__(self, factory_name):
"""Return aggregated factory."""
return self.__get_factory(factory_name)
@ -1998,6 +1991,19 @@ cdef class FactoryAggregate(Provider):
raise Error(
'{0} providers could not be overridden'.format(self.__class__))
cpdef object _provide(self, tuple args, dict kwargs):
try:
factory_name = args[0]
except IndexError:
try:
factory_name = kwargs.pop('factory_name')
except KeyError:
raise TypeError('Factory missing 1 required positional argument: \'factory_name\'')
else:
args = args[1:]
return self.__get_factory(factory_name)(*args, **kwargs)
cdef Factory __get_factory(self, str factory_name):
if factory_name not in self.__factories:
raise NoSuchProviderError(

View File

@ -221,6 +221,33 @@ class FactoryTests(AsyncTestCase):
self.assertIsNot(service1.client, service2.client)
class FactoryAggregateTests(AsyncTestCase):
def test_async_mode(self):
object1 = object()
object2 = object()
async def _get_object1():
return object1
def _get_object2():
return object2
provider = providers.FactoryAggregate(
object1=providers.Factory(_get_object1),
object2=providers.Factory(_get_object2),
)
self.assertTrue(provider.is_async_mode_undefined())
created_object1 = self._run(provider('object1'))
self.assertIs(created_object1, object1)
self.assertTrue(provider.is_async_mode_enabled())
created_object2 = self._run(provider('object2'))
self.assertIs(created_object2, object2)
class SingletonTests(AsyncTestCase):
def test_injections(self):

View File

@ -520,6 +520,24 @@ class FactoryAggregateTests(unittest.TestCase):
self.assertEqual(object_b.init_arg3, 33)
self.assertEqual(object_b.init_arg4, 44)
def test_call_factory_name_as_kwarg(self):
object_a = self.factory_aggregate(
factory_name='example_a',
init_arg1=1,
init_arg2=2,
init_arg3=3,
init_arg4=4,
)
self.assertIsInstance(object_a, self.ExampleA)
self.assertEqual(object_a.init_arg1, 1)
self.assertEqual(object_a.init_arg2, 2)
self.assertEqual(object_a.init_arg3, 3)
self.assertEqual(object_a.init_arg4, 4)
def test_call_no_factory_name(self):
with self.assertRaises(TypeError):
self.factory_aggregate()
def test_call_no_such_provider(self):
with self.assertRaises(errors.NoSuchProviderError):
self.factory_aggregate('unknown')