mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-02-22 22:53:00 +03:00
Add support of async mode for FactoryAggregate provider + tests
This commit is contained in:
parent
d5a8da8907
commit
e4fd36555f
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -1964,13 +1964,6 @@ cdef class FactoryAggregate(Provider):
|
|||
|
||||
return copied
|
||||
|
||||
def __call__(self, factory_name, *args, **kwargs):
|
||||
"""Create new object using factory with provided name.
|
||||
|
||||
Callable interface implementation.
|
||||
"""
|
||||
return self.__get_factory(factory_name)(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, factory_name):
|
||||
"""Return aggregated factory."""
|
||||
return self.__get_factory(factory_name)
|
||||
|
@ -1998,6 +1991,19 @@ cdef class FactoryAggregate(Provider):
|
|||
raise Error(
|
||||
'{0} providers could not be overridden'.format(self.__class__))
|
||||
|
||||
cpdef object _provide(self, tuple args, dict kwargs):
|
||||
try:
|
||||
factory_name = args[0]
|
||||
except IndexError:
|
||||
try:
|
||||
factory_name = kwargs.pop('factory_name')
|
||||
except KeyError:
|
||||
raise TypeError('Factory missing 1 required positional argument: \'factory_name\'')
|
||||
else:
|
||||
args = args[1:]
|
||||
|
||||
return self.__get_factory(factory_name)(*args, **kwargs)
|
||||
|
||||
cdef Factory __get_factory(self, str factory_name):
|
||||
if factory_name not in self.__factories:
|
||||
raise NoSuchProviderError(
|
||||
|
|
|
@ -221,6 +221,33 @@ class FactoryTests(AsyncTestCase):
|
|||
self.assertIsNot(service1.client, service2.client)
|
||||
|
||||
|
||||
class FactoryAggregateTests(AsyncTestCase):
|
||||
|
||||
def test_async_mode(self):
|
||||
object1 = object()
|
||||
object2 = object()
|
||||
|
||||
async def _get_object1():
|
||||
return object1
|
||||
|
||||
def _get_object2():
|
||||
return object2
|
||||
|
||||
provider = providers.FactoryAggregate(
|
||||
object1=providers.Factory(_get_object1),
|
||||
object2=providers.Factory(_get_object2),
|
||||
)
|
||||
|
||||
self.assertTrue(provider.is_async_mode_undefined())
|
||||
|
||||
created_object1 = self._run(provider('object1'))
|
||||
self.assertIs(created_object1, object1)
|
||||
self.assertTrue(provider.is_async_mode_enabled())
|
||||
|
||||
created_object2 = self._run(provider('object2'))
|
||||
self.assertIs(created_object2, object2)
|
||||
|
||||
|
||||
class SingletonTests(AsyncTestCase):
|
||||
|
||||
def test_injections(self):
|
||||
|
|
|
@ -520,6 +520,24 @@ class FactoryAggregateTests(unittest.TestCase):
|
|||
self.assertEqual(object_b.init_arg3, 33)
|
||||
self.assertEqual(object_b.init_arg4, 44)
|
||||
|
||||
def test_call_factory_name_as_kwarg(self):
|
||||
object_a = self.factory_aggregate(
|
||||
factory_name='example_a',
|
||||
init_arg1=1,
|
||||
init_arg2=2,
|
||||
init_arg3=3,
|
||||
init_arg4=4,
|
||||
)
|
||||
self.assertIsInstance(object_a, self.ExampleA)
|
||||
self.assertEqual(object_a.init_arg1, 1)
|
||||
self.assertEqual(object_a.init_arg2, 2)
|
||||
self.assertEqual(object_a.init_arg3, 3)
|
||||
self.assertEqual(object_a.init_arg4, 4)
|
||||
|
||||
def test_call_no_factory_name(self):
|
||||
with self.assertRaises(TypeError):
|
||||
self.factory_aggregate()
|
||||
|
||||
def test_call_no_such_provider(self):
|
||||
with self.assertRaises(errors.NoSuchProviderError):
|
||||
self.factory_aggregate('unknown')
|
||||
|
|
Loading…
Reference in New Issue
Block a user