From ec49d5675144e9cb438f2f1e1a79464a2f134263 Mon Sep 17 00:00:00 2001 From: Roman Mogylatov Date: Sat, 5 Dec 2020 21:36:30 -0500 Subject: [PATCH] Add tests for FastAPI wiring --- tests/unit/samples/wiringfastapi/web.py | 39 ++++++++ tests/unit/wiring/test_wiringfastapi_py36.py | 95 ++++++++++++++++++++ tox.ini | 1 + 3 files changed, 135 insertions(+) create mode 100644 tests/unit/samples/wiringfastapi/web.py create mode 100644 tests/unit/wiring/test_wiringfastapi_py36.py diff --git a/tests/unit/samples/wiringfastapi/web.py b/tests/unit/samples/wiringfastapi/web.py new file mode 100644 index 00000000..2e563d7d --- /dev/null +++ b/tests/unit/samples/wiringfastapi/web.py @@ -0,0 +1,39 @@ +import sys + +from fastapi import FastAPI, Depends +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from dependency_injector import containers, providers +from dependency_injector.wiring import inject, Provide + + +class Service: + async def process(self) -> str: + return 'Ok' + + +class Container(containers.DeclarativeContainer): + + service = providers.Factory(Service) + + +app = FastAPI() +security = HTTPBasic() + + +@app.api_route('/') +@inject +async def index(service: Service = Depends(Provide[Container.service])): + result = await service.process() + return {'result': result} + + +@app.get('/auth') +@inject +def read_current_user( + credentials: HTTPBasicCredentials = Depends(security) +): + return {'username': credentials.username, 'password': credentials.password} + + +container = Container() +container.wire(modules=[sys.modules[__name__]]) diff --git a/tests/unit/wiring/test_wiringfastapi_py36.py b/tests/unit/wiring/test_wiringfastapi_py36.py new file mode 100644 index 00000000..0e349409 --- /dev/null +++ b/tests/unit/wiring/test_wiringfastapi_py36.py @@ -0,0 +1,95 @@ +import asyncio +import contextlib +import gc +import unittest +from unittest import mock + +from httpx import AsyncClient + +# Runtime import to avoid syntax errors in samples on Python < 3.5 +import os +_SAMPLES_DIR = os.path.abspath( + os.path.sep.join(( + os.path.dirname(__file__), + '../samples/', + )), +) +import sys +sys.path.append(_SAMPLES_DIR) + +from wiringfastapi import web + + +# TODO: Refactor to use common async test case +def setup_test_loop( + loop_factory=asyncio.new_event_loop +) -> asyncio.AbstractEventLoop: + loop = loop_factory() + try: + module = loop.__class__.__module__ + skip_watcher = 'uvloop' in module + except AttributeError: # pragma: no cover + # Just in case + skip_watcher = True + asyncio.set_event_loop(loop) + if sys.platform != "win32" and not skip_watcher: + policy = asyncio.get_event_loop_policy() + watcher = asyncio.SafeChildWatcher() # type: ignore + watcher.attach_loop(loop) + with contextlib.suppress(NotImplementedError): + policy.set_child_watcher(watcher) + return loop + + +def teardown_test_loop(loop: asyncio.AbstractEventLoop, + fast: bool=False) -> None: + closed = loop.is_closed() + if not closed: + loop.call_soon(loop.stop) + loop.run_forever() + loop.close() + + if not fast: + gc.collect() + + asyncio.set_event_loop(None) + + +class AsyncTestCase(unittest.TestCase): + + def setUp(self): + self.loop = setup_test_loop() + + def tearDown(self): + teardown_test_loop(self.loop) + + def _run(self, f): + return self.loop.run_until_complete(f) + + +class WiringFastAPITest(AsyncTestCase): + + client: AsyncClient + + def setUp(self) -> None: + super().setUp() + self.client = AsyncClient(app=web.app, base_url='http://test') + + def tearDown(self) -> None: + self._run(self.client.aclose()) + super().tearDown() + + def test_depends_marker_injection(self): + service_mock = mock.AsyncMock(spec=web.Service) + service_mock.process.return_value = 'Foo' + + with web.container.service.override(service_mock): + response = self._run(self.client.get('/')) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {'result': 'Foo'}) + + def test_depends_injection(self): + response = self._run(self.client.get('/auth', auth=('john_smith', 'secret'))) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {'username': 'john_smith', 'password': 'secret'}) diff --git a/tox.ini b/tox.ini index b73fff22..6c889816 100644 --- a/tox.ini +++ b/tox.ini @@ -7,6 +7,7 @@ deps= unittest2 # TODO: Hotfix, remove when fixed https://github.com/aio-libs/aiohttp/issues/5107 typing_extensions + httpx extras= yaml flask