Add tests for async injections in wiring @inject

This commit is contained in:
Roman Mogylatov 2020-11-30 18:15:07 -05:00
parent c4825956f7
commit 3f8af1cd56
3 changed files with 169 additions and 0 deletions

57
tests/unit/asyncutils.py Normal file
View File

@ -0,0 +1,57 @@
"""Test utils."""
import asyncio
import contextlib
import sys
import gc
import unittest
def run(main):
loop = asyncio.get_event_loop()
return loop.run_until_complete(main)
def setup_test_loop(
loop_factory=asyncio.new_event_loop
) -> asyncio.AbstractEventLoop:
loop = loop_factory()
try:
module = loop.__class__.__module__
skip_watcher = 'uvloop' in module
except AttributeError: # pragma: no cover
# Just in case
skip_watcher = True
asyncio.set_event_loop(loop)
if sys.platform != 'win32' and not skip_watcher:
policy = asyncio.get_event_loop_policy()
watcher = asyncio.SafeChildWatcher() # type: ignore
watcher.attach_loop(loop)
with contextlib.suppress(NotImplementedError):
policy.set_child_watcher(watcher)
return loop
def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
closed = loop.is_closed()
if not closed:
loop.call_soon(loop.stop)
loop.run_forever()
loop.close()
if not fast:
gc.collect()
asyncio.set_event_loop(None)
class AsyncTestCase(unittest.TestCase):
def setUp(self):
self.loop = setup_test_loop()
def tearDown(self):
teardown_test_loop(self.loop)
def _run(self, f):
return self.loop.run_until_complete(f)

View File

@ -0,0 +1,50 @@
import asyncio
from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide, Closing
class TestResource:
def __init__(self):
self.init_counter = 0
self.shutdown_counter = 0
def reset_counters(self):
self.init_counter = 0
self.shutdown_counter = 0
resource1 = TestResource()
resource2 = TestResource()
async def async_resource(resource):
await asyncio.sleep(0.001)
resource.init_counter += 1
yield resource
await asyncio.sleep(0.001)
resource.shutdown_counter += 1
class Container(containers.DeclarativeContainer):
resource1 = providers.Resource(async_resource, providers.Object(resource1))
resource2 = providers.Resource(async_resource, providers.Object(resource2))
@inject
async def async_injection(
resource1: object = Provide[Container.resource1],
resource2: object = Provide[Container.resource2],
):
return resource1, resource2
@inject
async def async_injection_with_closing(
resource1: object = Closing[Provide[Container.resource1]],
resource2: object = Closing[Provide[Container.resource2]],
):
return resource1, resource2

View File

@ -5,6 +5,12 @@ from dependency_injector.wiring import wire, Provide, Closing
# Runtime import to avoid syntax errors in samples on Python < 3.5 # Runtime import to avoid syntax errors in samples on Python < 3.5
import os import os
_TOP_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
'../',
)),
)
_SAMPLES_DIR = os.path.abspath( _SAMPLES_DIR = os.path.abspath(
os.path.sep.join(( os.path.sep.join((
os.path.dirname(__file__), os.path.dirname(__file__),
@ -12,8 +18,11 @@ _SAMPLES_DIR = os.path.abspath(
)), )),
) )
import sys import sys
sys.path.append(_TOP_DIR)
sys.path.append(_SAMPLES_DIR) sys.path.append(_SAMPLES_DIR)
from asyncutils import AsyncTestCase
from wiringsamples import module, package from wiringsamples import module, package
from wiringsamples.service import Service from wiringsamples.service import Service
from wiringsamples.container import Container, SubContainer from wiringsamples.container import Container, SubContainer
@ -267,3 +276,56 @@ class WiringAndFastAPITest(unittest.TestCase):
self.assertEqual(result_2.shutdown_counter, 2) self.assertEqual(result_2.shutdown_counter, 2)
self.assertIsNot(result_1, result_2) self.assertIsNot(result_1, result_2)
class WiringAsyncInjectionsTest(AsyncTestCase):
def test_async_injections(self):
from wiringsamples import asyncinjections
container = asyncinjections.Container()
container.wire(modules=[asyncinjections])
self.addCleanup(container.unwire)
asyncinjections.resource1.reset_counters()
asyncinjections.resource2.reset_counters()
resource1, resource2 = self._run(asyncinjections.async_injection())
self.assertIs(resource1, asyncinjections.resource1)
self.assertEqual(asyncinjections.resource1.init_counter, 1)
self.assertEqual(asyncinjections.resource1.shutdown_counter, 0)
self.assertIs(resource2, asyncinjections.resource2)
self.assertEqual(asyncinjections.resource2.init_counter, 1)
self.assertEqual(asyncinjections.resource2.shutdown_counter, 0)
def test_async_injections_with_closing(self):
from wiringsamples import asyncinjections
container = asyncinjections.Container()
container.wire(modules=[asyncinjections])
self.addCleanup(container.unwire)
asyncinjections.resource1.reset_counters()
asyncinjections.resource2.reset_counters()
resource1, resource2 = self._run(asyncinjections.async_injection_with_closing())
self.assertIs(resource1, asyncinjections.resource1)
self.assertEqual(asyncinjections.resource1.init_counter, 1)
self.assertEqual(asyncinjections.resource1.shutdown_counter, 1)
self.assertIs(resource2, asyncinjections.resource2)
self.assertEqual(asyncinjections.resource2.init_counter, 1)
self.assertEqual(asyncinjections.resource2.shutdown_counter, 1)
resource1, resource2 = self._run(asyncinjections.async_injection_with_closing())
self.assertIs(resource1, asyncinjections.resource1)
self.assertEqual(asyncinjections.resource1.init_counter, 2)
self.assertEqual(asyncinjections.resource1.shutdown_counter, 2)
self.assertIs(resource2, asyncinjections.resource2)
self.assertEqual(asyncinjections.resource2.init_counter, 2)
self.assertEqual(asyncinjections.resource2.shutdown_counter, 2)