diff --git a/tests/unit/providers/test_coroutines_py3.py b/tests/unit/providers/test_coroutines_py3.py index e970b2ee..828dda88 100644 --- a/tests/unit/providers/test_coroutines_py3.py +++ b/tests/unit/providers/test_coroutines_py3.py @@ -1,6 +1,7 @@ """Dependency injector coroutine providers unit tests.""" import asyncio +from asyncio import coroutines, events, tasks import unittest2 as unittest @@ -18,9 +19,48 @@ def _example(arg1, arg2, arg3, arg4): return arg1, arg2, arg3, arg4 -def _run(coro): - loop = asyncio.get_event_loop() - return loop.run_until_complete(coro) +def _run(coro, debug=False): + if events._get_running_loop() is not None: + raise RuntimeError( + "asyncio.run() cannot be called from a running event loop") + + if not coroutines.iscoroutine(coro): + raise ValueError("a coroutine was expected, got {!r}".format(coro)) + + loop = events.new_event_loop() + try: + events.set_event_loop(loop) + loop.set_debug(debug) + return loop.run_until_complete(coro) + finally: + try: + _cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + events.set_event_loop(None) + loop.close() + + +def _cancel_all_tasks(loop): + to_cancel = tasks.all_tasks(loop) + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + loop.run_until_complete( + tasks.gather(*to_cancel, loop=loop, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler({ + 'message': 'unhandled exception during asyncio.run() shutdown', + 'exception': task.exception(), + 'task': task, + }) class CoroutineTests(unittest.TestCase):