Add tests for Dependency provider async mode

This commit is contained in:
Roman Mogylatov 2020-12-29 23:10:50 -05:00
parent c9242e51f6
commit d5a8da8907

View File

@ -473,7 +473,12 @@ class DependencyTests(AsyncTestCase):
provider = providers.Dependency(instance_of=float) provider = providers.Dependency(instance_of=float)
provider.override(providers.Callable(get_async)) provider.override(providers.Callable(get_async))
self.assertTrue(provider.is_async_mode_undefined())
dependency1 = self._run(provider()) dependency1 = self._run(provider())
self.assertTrue(provider.is_async_mode_enabled())
dependency2 = self._run(provider()) dependency2 = self._run(provider())
self.assertEqual(dependency1, dependency) self.assertEqual(dependency1, dependency)
@ -486,9 +491,13 @@ class DependencyTests(AsyncTestCase):
provider = providers.Dependency(instance_of=float) provider = providers.Dependency(instance_of=float)
provider.override(providers.Callable(get_async)) provider.override(providers.Callable(get_async))
self.assertTrue(provider.is_async_mode_undefined())
with self.assertRaises(errors.Error): with self.assertRaises(errors.Error):
self._run(provider()) self._run(provider())
self.assertTrue(provider.is_async_mode_enabled())
def test_async_mode(self): def test_async_mode(self):
dependency = 123 dependency = 123
@ -501,7 +510,12 @@ class DependencyTests(AsyncTestCase):
provider = providers.Dependency(instance_of=int) provider = providers.Dependency(instance_of=int)
provider.override(providers.Factory(get_async)) provider.override(providers.Factory(get_async))
self.assertTrue(provider.is_async_mode_undefined())
dependency1 = self._run(provider()) dependency1 = self._run(provider())
self.assertTrue(provider.is_async_mode_enabled())
dependency2 = self._run(provider()) dependency2 = self._run(provider())
self.assertEqual(dependency1, dependency) self.assertEqual(dependency1, dependency)
self.assertEqual(dependency2, dependency) self.assertEqual(dependency2, dependency)
@ -509,6 +523,9 @@ class DependencyTests(AsyncTestCase):
provider.override(providers.Factory(get_sync)) provider.override(providers.Factory(get_sync))
dependency3 = self._run(provider()) dependency3 = self._run(provider())
self.assertTrue(provider.is_async_mode_enabled())
dependency4 = self._run(provider()) dependency4 = self._run(provider())
self.assertEqual(dependency3, dependency) self.assertEqual(dependency3, dependency)
self.assertEqual(dependency4, dependency) self.assertEqual(dependency4, dependency)