Add async view

This commit is contained in:
Fabien Le Frapper 2021-08-26 08:56:02 +02:00
parent e7f7d8da07
commit 32f00ef711

View File

@ -1,4 +1,5 @@
import inspect
import asyncio
import json
import re
@ -6,22 +7,19 @@ from django.db import connection, transaction
from django.http import HttpResponse, HttpResponseNotAllowed
from django.http.response import HttpResponseBadRequest
from django.shortcuts import render
from django.utils.decorators import method_decorator
from django.utils.decorators import method_decorator, classonlymethod
from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.generic import View
from graphene import Schema
from graphene_django.constants import MUTATION_ERRORS_FLAG
from graphene_django.settings import graphene_settings
from graphene_django.utils.utils import set_rollback
from graphql import OperationType, get_operation_ast, parse, validate
from graphql.error import GraphQLError
from graphql.error import format_error as format_graphql_error
from graphql.execution import ExecutionResult
from graphene import Schema
from graphql.execution.middleware import MiddlewareManager
from graphene_django.constants import MUTATION_ERRORS_FLAG
from graphene_django.utils.utils import set_rollback
from .settings import graphene_settings
class HttpError(Exception):
def __init__(self, response, message=None, *args, **kwargs):
@ -396,3 +394,125 @@ class GraphQLView(View):
meta = request.META
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
return content_type.split(";", 1)[0].lower()
class AsyncGraphQLView(GraphQLView):
@classonlymethod
def as_view(cls, **initkwargs):
# This code tells django that this view is async, see docs here:
# https://docs.djangoproject.com/en/3.1/topics/async/#async-views
view = super().as_view(**initkwargs)
view._is_coroutine = asyncio.coroutines._is_coroutine
return view
async def dispatch(self, request, *args, **kwargs):
try:
if request.method.lower() not in ("get", "post"):
raise HttpError(
HttpResponseNotAllowed(
["GET", "POST"], "GraphQL only supports GET and POST requests."
)
)
data = self.parse_body(request)
show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
if show_graphiql:
return self.render_graphiql(
request,
# Dependency parameters.
whatwg_fetch_version=self.whatwg_fetch_version,
whatwg_fetch_sri=self.whatwg_fetch_sri,
react_version=self.react_version,
react_sri=self.react_sri,
react_dom_sri=self.react_dom_sri,
graphiql_version=self.graphiql_version,
graphiql_sri=self.graphiql_sri,
graphiql_css_sri=self.graphiql_css_sri,
subscriptions_transport_ws_version=self.subscriptions_transport_ws_version,
subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri,
# The SUBSCRIPTION_PATH setting.
subscription_path=self.subscription_path,
# GraphiQL headers tab,
graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED,
)
result, status_code = await self.get_response(request, data, show_graphiql)
return HttpResponse(
status=status_code, content=result, content_type="application/json"
)
except HttpError as e:
response = e.response
response["Content-Type"] = "application/json"
response.content = self.json_encode(
request, {"errors": [self.format_error(e)]}
)
return response
async def get_response(self, request, data, show_graphiql=False):
query, variables, operation_name, id = self.get_graphql_params(request, data)
execution_result = await self.execute_graphql_request(
request, data, query, variables, operation_name, show_graphiql
)
status_code = 200
if execution_result:
response = {}
if execution_result.errors:
response["errors"] = [
self.format_error(e) for e in execution_result.errors
]
if execution_result.errors and any(
not getattr(e, "path", None) for e in execution_result.errors
):
status_code = 400
else:
response["data"] = execution_result.data
result = self.json_encode(request, response, pretty=show_graphiql)
else:
result = None
return result, status_code
async def execute_graphql_request(
self, request, data, query, variables, operation_name, show_graphiql=False
):
if not query:
if show_graphiql:
return None
raise HttpError(HttpResponseBadRequest("Must provide query string."))
try:
document = parse(query)
except Exception as e:
return ExecutionResult(errors=[e])
validation_errors = validate(self.schema.graphql_schema, document)
if validation_errors:
return ExecutionResult(data=None, errors=validation_errors)
try:
extra_options = {}
if self.execution_context_class:
extra_options["execution_context_class"] = self.execution_context_class
options = {
"source": query,
"root_value": self.get_root_value(request),
"variable_values": variables,
"operation_name": operation_name,
"context_value": self.get_context(request),
"middleware": self.get_middleware(request),
}
options.update(extra_options)
return await self.schema.execute_async(**options)
except Exception as e:
return ExecutionResult(errors=[e])