From 32f00ef7114ec5bd31de19d5a7c9e7e84853ae2e Mon Sep 17 00:00:00 2001 From: Fabien Le Frapper Date: Thu, 26 Aug 2021 08:56:02 +0200 Subject: [PATCH] Add async view --- graphene_django/views.py | 136 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 128 insertions(+), 8 deletions(-) diff --git a/graphene_django/views.py b/graphene_django/views.py index c23b020..df75be4 100644 --- a/graphene_django/views.py +++ b/graphene_django/views.py @@ -1,4 +1,5 @@ import inspect +import asyncio import json import re @@ -6,22 +7,19 @@ from django.db import connection, transaction from django.http import HttpResponse, HttpResponseNotAllowed from django.http.response import HttpResponseBadRequest from django.shortcuts import render -from django.utils.decorators import method_decorator +from django.utils.decorators import method_decorator, classonlymethod from django.views.decorators.csrf import ensure_csrf_cookie from django.views.generic import View +from graphene import Schema +from graphene_django.constants import MUTATION_ERRORS_FLAG +from graphene_django.settings import graphene_settings +from graphene_django.utils.utils import set_rollback from graphql import OperationType, get_operation_ast, parse, validate from graphql.error import GraphQLError from graphql.error import format_error as format_graphql_error from graphql.execution import ExecutionResult - -from graphene import Schema from graphql.execution.middleware import MiddlewareManager -from graphene_django.constants import MUTATION_ERRORS_FLAG -from graphene_django.utils.utils import set_rollback - -from .settings import graphene_settings - class HttpError(Exception): def __init__(self, response, message=None, *args, **kwargs): @@ -396,3 +394,125 @@ class GraphQLView(View): meta = request.META content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", "")) return content_type.split(";", 1)[0].lower() + +class AsyncGraphQLView(GraphQLView): + @classonlymethod + def as_view(cls, **initkwargs): + # This code tells django that this view is async, see docs here: + # https://docs.djangoproject.com/en/3.1/topics/async/#async-views + + view = super().as_view(**initkwargs) + view._is_coroutine = asyncio.coroutines._is_coroutine + return view + + async def dispatch(self, request, *args, **kwargs): + try: + if request.method.lower() not in ("get", "post"): + raise HttpError( + HttpResponseNotAllowed( + ["GET", "POST"], "GraphQL only supports GET and POST requests." + ) + ) + + data = self.parse_body(request) + show_graphiql = self.graphiql and self.can_display_graphiql(request, data) + + if show_graphiql: + return self.render_graphiql( + request, + # Dependency parameters. + whatwg_fetch_version=self.whatwg_fetch_version, + whatwg_fetch_sri=self.whatwg_fetch_sri, + react_version=self.react_version, + react_sri=self.react_sri, + react_dom_sri=self.react_dom_sri, + graphiql_version=self.graphiql_version, + graphiql_sri=self.graphiql_sri, + graphiql_css_sri=self.graphiql_css_sri, + subscriptions_transport_ws_version=self.subscriptions_transport_ws_version, + subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri, + # The SUBSCRIPTION_PATH setting. + subscription_path=self.subscription_path, + # GraphiQL headers tab, + graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED, + ) + + result, status_code = await self.get_response(request, data, show_graphiql) + + return HttpResponse( + status=status_code, content=result, content_type="application/json" + ) + + except HttpError as e: + response = e.response + response["Content-Type"] = "application/json" + response.content = self.json_encode( + request, {"errors": [self.format_error(e)]} + ) + return response + + async def get_response(self, request, data, show_graphiql=False): + query, variables, operation_name, id = self.get_graphql_params(request, data) + + execution_result = await self.execute_graphql_request( + request, data, query, variables, operation_name, show_graphiql + ) + + status_code = 200 + if execution_result: + response = {} + + if execution_result.errors: + response["errors"] = [ + self.format_error(e) for e in execution_result.errors + ] + + if execution_result.errors and any( + not getattr(e, "path", None) for e in execution_result.errors + ): + status_code = 400 + else: + response["data"] = execution_result.data + + result = self.json_encode(request, response, pretty=show_graphiql) + else: + result = None + + return result, status_code + + async def execute_graphql_request( + self, request, data, query, variables, operation_name, show_graphiql=False + ): + if not query: + if show_graphiql: + return None + raise HttpError(HttpResponseBadRequest("Must provide query string.")) + + try: + document = parse(query) + except Exception as e: + return ExecutionResult(errors=[e]) + + validation_errors = validate(self.schema.graphql_schema, document) + if validation_errors: + return ExecutionResult(data=None, errors=validation_errors) + + try: + extra_options = {} + if self.execution_context_class: + extra_options["execution_context_class"] = self.execution_context_class + + options = { + "source": query, + "root_value": self.get_root_value(request), + "variable_values": variables, + "operation_name": operation_name, + "context_value": self.get_context(request), + "middleware": self.get_middleware(request), + } + options.update(extra_options) + + return await self.schema.execute_async(**options) + except Exception as e: + return ExecutionResult(errors=[e]) +