diff --git a/tests/unit/asyncutils.py b/tests/unit/asyncutils.py new file mode 100644 index 00000000..ca1b96d2 --- /dev/null +++ b/tests/unit/asyncutils.py @@ -0,0 +1,57 @@ +"""Test utils.""" + +import asyncio +import contextlib +import sys +import gc +import unittest + + +def run(main): + loop = asyncio.get_event_loop() + return loop.run_until_complete(main) + + +def setup_test_loop( + loop_factory=asyncio.new_event_loop +) -> asyncio.AbstractEventLoop: + loop = loop_factory() + try: + module = loop.__class__.__module__ + skip_watcher = 'uvloop' in module + except AttributeError: # pragma: no cover + # Just in case + skip_watcher = True + asyncio.set_event_loop(loop) + if sys.platform != 'win32' and not skip_watcher: + policy = asyncio.get_event_loop_policy() + watcher = asyncio.SafeChildWatcher() # type: ignore + watcher.attach_loop(loop) + with contextlib.suppress(NotImplementedError): + policy.set_child_watcher(watcher) + return loop + + +def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None: + closed = loop.is_closed() + if not closed: + loop.call_soon(loop.stop) + loop.run_forever() + loop.close() + + if not fast: + gc.collect() + + asyncio.set_event_loop(None) + + +class AsyncTestCase(unittest.TestCase): + + def setUp(self): + self.loop = setup_test_loop() + + def tearDown(self): + teardown_test_loop(self.loop) + + def _run(self, f): + return self.loop.run_until_complete(f) diff --git a/tests/unit/samples/wiringsamples/asyncinjections.py b/tests/unit/samples/wiringsamples/asyncinjections.py new file mode 100644 index 00000000..204300e3 --- /dev/null +++ b/tests/unit/samples/wiringsamples/asyncinjections.py @@ -0,0 +1,50 @@ +import asyncio + +from dependency_injector import containers, providers +from dependency_injector.wiring import inject, Provide, Closing + + +class TestResource: + def __init__(self): + self.init_counter = 0 + self.shutdown_counter = 0 + + def reset_counters(self): + self.init_counter = 0 + self.shutdown_counter = 0 + + +resource1 = TestResource() +resource2 = TestResource() + + +async def async_resource(resource): + await asyncio.sleep(0.001) + resource.init_counter += 1 + + yield resource + + await asyncio.sleep(0.001) + resource.shutdown_counter += 1 + + +class Container(containers.DeclarativeContainer): + + resource1 = providers.Resource(async_resource, providers.Object(resource1)) + resource2 = providers.Resource(async_resource, providers.Object(resource2)) + + +@inject +async def async_injection( + resource1: object = Provide[Container.resource1], + resource2: object = Provide[Container.resource2], +): + return resource1, resource2 + + +@inject +async def async_injection_with_closing( + resource1: object = Closing[Provide[Container.resource1]], + resource2: object = Closing[Provide[Container.resource2]], +): + return resource1, resource2 diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py index 7c061ad9..ada0e8d9 100644 --- a/tests/unit/wiring/test_wiring_py36.py +++ b/tests/unit/wiring/test_wiring_py36.py @@ -5,6 +5,12 @@ from dependency_injector.wiring import wire, Provide, Closing # Runtime import to avoid syntax errors in samples on Python < 3.5 import os +_TOP_DIR = os.path.abspath( + os.path.sep.join(( + os.path.dirname(__file__), + '../', + )), +) _SAMPLES_DIR = os.path.abspath( os.path.sep.join(( os.path.dirname(__file__), @@ -12,8 +18,11 @@ _SAMPLES_DIR = os.path.abspath( )), ) import sys +sys.path.append(_TOP_DIR) sys.path.append(_SAMPLES_DIR) +from asyncutils import AsyncTestCase + from wiringsamples import module, package from wiringsamples.service import Service from wiringsamples.container import Container, SubContainer @@ -267,3 +276,56 @@ class WiringAndFastAPITest(unittest.TestCase): self.assertEqual(result_2.shutdown_counter, 2) self.assertIsNot(result_1, result_2) + + +class WiringAsyncInjectionsTest(AsyncTestCase): + + def test_async_injections(self): + from wiringsamples import asyncinjections + + container = asyncinjections.Container() + container.wire(modules=[asyncinjections]) + self.addCleanup(container.unwire) + + asyncinjections.resource1.reset_counters() + asyncinjections.resource2.reset_counters() + + resource1, resource2 = self._run(asyncinjections.async_injection()) + + self.assertIs(resource1, asyncinjections.resource1) + self.assertEqual(asyncinjections.resource1.init_counter, 1) + self.assertEqual(asyncinjections.resource1.shutdown_counter, 0) + + self.assertIs(resource2, asyncinjections.resource2) + self.assertEqual(asyncinjections.resource2.init_counter, 1) + self.assertEqual(asyncinjections.resource2.shutdown_counter, 0) + + def test_async_injections_with_closing(self): + from wiringsamples import asyncinjections + + container = asyncinjections.Container() + container.wire(modules=[asyncinjections]) + self.addCleanup(container.unwire) + + asyncinjections.resource1.reset_counters() + asyncinjections.resource2.reset_counters() + + resource1, resource2 = self._run(asyncinjections.async_injection_with_closing()) + + self.assertIs(resource1, asyncinjections.resource1) + self.assertEqual(asyncinjections.resource1.init_counter, 1) + self.assertEqual(asyncinjections.resource1.shutdown_counter, 1) + + self.assertIs(resource2, asyncinjections.resource2) + self.assertEqual(asyncinjections.resource2.init_counter, 1) + self.assertEqual(asyncinjections.resource2.shutdown_counter, 1) + + resource1, resource2 = self._run(asyncinjections.async_injection_with_closing()) + + self.assertIs(resource1, asyncinjections.resource1) + self.assertEqual(asyncinjections.resource1.init_counter, 2) + self.assertEqual(asyncinjections.resource1.shutdown_counter, 2) + + self.assertIs(resource2, asyncinjections.resource2) + self.assertEqual(asyncinjections.resource2.init_counter, 2) + self.assertEqual(asyncinjections.resource2.shutdown_counter, 2)