From 73b8a4aac4735e4acbbad818d6d41a51966a5116 Mon Sep 17 00:00:00 2001 From: Roman Mogylatov Date: Sat, 27 Feb 2021 09:45:49 -0500 Subject: [PATCH] Introduce wiring inspect filter (#412) * Introduce wiring inspect filter * Upgrade exclusion filter * Refactor wiring --- docs/main/changelog.rst | 7 ++- src/dependency_injector/wiring.py | 54 ++++++++++++++++++---- tests/unit/samples/wiringflask/web.py | 34 ++++++++++++++ tests/unit/wiring/test_wiringflask_py36.py | 33 +++++++++++++ 4 files changed, 117 insertions(+), 11 deletions(-) create mode 100644 tests/unit/samples/wiringflask/web.py create mode 100644 tests/unit/wiring/test_wiringflask_py36.py diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index bfc873e8..80a6c133 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -9,13 +9,16 @@ follows `Semantic versioning`_ Development version ------------------- +- Introduce wiring inspect filter to filter out ``flask.request`` and other local proxy objects + from the inspection. + See issue: `#408 `_. + Many thanks to `@bvanfleet `_ for reporting the issue and + help in finding the root cause. - Add ``boto3`` example. - Add tests for ``.as_float()`` modifier usage with wiring. - Make refactoring of wiring module and tests. See PR # `#406 `_. Thanks to `@withshubh `_ for the contribution: - - Refactor unnecessary ``else`` / ``elif`` in ``wiring`` module when ``if`` block has a - return statement. - Remove unused imports in tests. - Use literal syntax to create data structure in tests. - Add integration with a static analysis tool `DeepSource `_. diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 96d1e238..c5a591da 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -37,10 +37,21 @@ else: try: - from fastapi.params import Depends as FastAPIDepends - fastapi_installed = True + import fastapi.params except ImportError: - fastapi_installed = False + fastapi = None + + +try: + import starlette.requests +except ImportError: + starlette = None + + +try: + import werkzeug.local +except ImportError: + werkzeug = None from . import providers @@ -111,20 +122,21 @@ class ProvidersMap: ) -> Optional[providers.Provider]: if isinstance(provider, providers.Delegate): return self._resolve_delegate(provider) - if isinstance(provider, ( + elif isinstance(provider, ( providers.ProvidedInstance, providers.AttributeGetter, providers.ItemGetter, providers.MethodCaller, )): return self._resolve_provided_instance(provider) - if isinstance(provider, providers.ConfigurationOption): + elif isinstance(provider, providers.ConfigurationOption): return self._resolve_config_option(provider) - if isinstance(provider, providers.TypedConfigurationOption): + elif isinstance(provider, providers.TypedConfigurationOption): return self._resolve_config_option(provider.option, as_=provider.provides) - if isinstance(provider, str): + elif isinstance(provider, str): return self._resolve_string_id(provider, modifier) - return self._resolve_provider(provider) + else: + return self._resolve_provider(provider) def _resolve_string_id( self, @@ -247,6 +259,28 @@ class ProvidersMap: return providers_map +class InspectFilter: + + def is_excluded(self, instance: object) -> bool: + if self._is_werkzeug_local_proxy(instance): + return True + elif self._is_starlette_request_cls(instance): + return True + else: + return False + + def _is_werkzeug_local_proxy(self, instance: object) -> bool: + return werkzeug and isinstance(instance, werkzeug.local.LocalProxy) + + def _is_starlette_request_cls(self, instance: object) -> bool: + return starlette \ + and isinstance(instance, type) \ + and issubclass(instance, starlette.requests.Request) + + +inspect_filter = InspectFilter() + + def wire( # noqa: C901 container: Container, *, @@ -268,6 +302,8 @@ def wire( # noqa: C901 for module in modules: for name, member in inspect.getmembers(module): + if inspect_filter.is_excluded(member): + continue if inspect.isfunction(member): _patch_fn(module, name, member, providers_map) elif inspect.isclass(member): @@ -530,7 +566,7 @@ def _is_fastapi_default_arg_injection(injection, kwargs): def _is_fastapi_depends(param: Any) -> bool: - return fastapi_installed and isinstance(param, FastAPIDepends) + return fastapi and isinstance(param, fastapi.params.Depends) def _is_patched(fn): diff --git a/tests/unit/samples/wiringflask/web.py b/tests/unit/samples/wiringflask/web.py new file mode 100644 index 00000000..59b5d004 --- /dev/null +++ b/tests/unit/samples/wiringflask/web.py @@ -0,0 +1,34 @@ +import sys + +from flask import Flask, jsonify, request, current_app, session, g +from flask import _request_ctx_stack, _app_ctx_stack +from dependency_injector import containers, providers +from dependency_injector.wiring import inject, Provide + +# This is here for testing wiring bypasses these objects without crashing +request, current_app, session, g # noqa +_request_ctx_stack, _app_ctx_stack # noqa + + +class Service: + def process(self) -> str: + return 'Ok' + + +class Container(containers.DeclarativeContainer): + + service = providers.Factory(Service) + + +app = Flask(__name__) + + +@app.route('/') +@inject +def index(service: Service = Provide[Container.service]): + result = service.process() + return jsonify({'result': result}) + + +container = Container() +container.wire(modules=[sys.modules[__name__]]) diff --git a/tests/unit/wiring/test_wiringflask_py36.py b/tests/unit/wiring/test_wiringflask_py36.py new file mode 100644 index 00000000..1eaaa4d8 --- /dev/null +++ b/tests/unit/wiring/test_wiringflask_py36.py @@ -0,0 +1,33 @@ +import unittest + +# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir +import os +_TOP_DIR = os.path.abspath( + os.path.sep.join(( + os.path.dirname(__file__), + '../', + )), +) +_SAMPLES_DIR = os.path.abspath( + os.path.sep.join(( + os.path.dirname(__file__), + '../samples/', + )), +) +import sys +sys.path.append(_TOP_DIR) +sys.path.append(_SAMPLES_DIR) + +from wiringflask import web + + +class WiringFlaskTest(unittest.TestCase): + + def test(self): + client = web.app.test_client() + + with web.app.app_context(): + response = client.get('/') + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, b'{"result":"Ok"}\n')