2020-12-06 05:36:30 +03:00
|
|
|
import asyncio
|
|
|
|
import contextlib
|
|
|
|
import gc
|
|
|
|
import unittest
|
|
|
|
from unittest import mock
|
|
|
|
|
|
|
|
from httpx import AsyncClient
|
|
|
|
|
|
|
|
# Runtime import to avoid syntax errors in samples on Python < 3.5
|
|
|
|
import os
|
|
|
|
_SAMPLES_DIR = os.path.abspath(
|
|
|
|
os.path.sep.join((
|
|
|
|
os.path.dirname(__file__),
|
|
|
|
'../samples/',
|
|
|
|
)),
|
|
|
|
)
|
|
|
|
import sys
|
|
|
|
sys.path.append(_SAMPLES_DIR)
|
|
|
|
|
|
|
|
from wiringfastapi import web
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: Refactor to use common async test case
|
|
|
|
def setup_test_loop(
|
|
|
|
loop_factory=asyncio.new_event_loop
|
|
|
|
) -> asyncio.AbstractEventLoop:
|
|
|
|
loop = loop_factory()
|
|
|
|
try:
|
|
|
|
module = loop.__class__.__module__
|
|
|
|
skip_watcher = 'uvloop' in module
|
|
|
|
except AttributeError: # pragma: no cover
|
|
|
|
# Just in case
|
|
|
|
skip_watcher = True
|
|
|
|
asyncio.set_event_loop(loop)
|
|
|
|
if sys.platform != "win32" and not skip_watcher:
|
|
|
|
policy = asyncio.get_event_loop_policy()
|
|
|
|
watcher = asyncio.SafeChildWatcher() # type: ignore
|
|
|
|
watcher.attach_loop(loop)
|
|
|
|
with contextlib.suppress(NotImplementedError):
|
|
|
|
policy.set_child_watcher(watcher)
|
|
|
|
return loop
|
|
|
|
|
|
|
|
|
|
|
|
def teardown_test_loop(loop: asyncio.AbstractEventLoop,
|
|
|
|
fast: bool=False) -> None:
|
|
|
|
closed = loop.is_closed()
|
|
|
|
if not closed:
|
|
|
|
loop.call_soon(loop.stop)
|
|
|
|
loop.run_forever()
|
|
|
|
loop.close()
|
|
|
|
|
|
|
|
if not fast:
|
|
|
|
gc.collect()
|
|
|
|
|
|
|
|
asyncio.set_event_loop(None)
|
|
|
|
|
|
|
|
|
|
|
|
class AsyncTestCase(unittest.TestCase):
|
|
|
|
|
|
|
|
def setUp(self):
|
|
|
|
self.loop = setup_test_loop()
|
|
|
|
|
|
|
|
def tearDown(self):
|
|
|
|
teardown_test_loop(self.loop)
|
|
|
|
|
|
|
|
def _run(self, f):
|
|
|
|
return self.loop.run_until_complete(f)
|
|
|
|
|
|
|
|
|
|
|
|
class WiringFastAPITest(AsyncTestCase):
|
|
|
|
|
|
|
|
client: AsyncClient
|
|
|
|
|
|
|
|
def setUp(self) -> None:
|
|
|
|
super().setUp()
|
|
|
|
self.client = AsyncClient(app=web.app, base_url='http://test')
|
|
|
|
|
|
|
|
def tearDown(self) -> None:
|
|
|
|
self._run(self.client.aclose())
|
|
|
|
super().tearDown()
|
|
|
|
|
|
|
|
def test_depends_marker_injection(self):
|
2020-12-06 05:58:42 +03:00
|
|
|
class ServiceMock:
|
|
|
|
async def process(self):
|
|
|
|
return 'Foo'
|
2020-12-06 05:36:30 +03:00
|
|
|
|
2020-12-06 05:58:42 +03:00
|
|
|
with web.container.service.override(ServiceMock()):
|
2020-12-06 05:36:30 +03:00
|
|
|
response = self._run(self.client.get('/'))
|
|
|
|
|
|
|
|
self.assertEqual(response.status_code, 200)
|
|
|
|
self.assertEqual(response.json(), {'result': 'Foo'})
|
|
|
|
|
|
|
|
def test_depends_injection(self):
|
|
|
|
response = self._run(self.client.get('/auth', auth=('john_smith', 'secret')))
|
|
|
|
self.assertEqual(response.status_code, 200)
|
|
|
|
self.assertEqual(response.json(), {'username': 'john_smith', 'password': 'secret'})
|