import inspect import json import re from asyncio import gather, coroutines from django.db import connection, transaction from django.http import HttpResponse, HttpResponseNotAllowed from django.http.response import HttpResponseBadRequest from django.shortcuts import render from django.utils.decorators import method_decorator from django.views.decorators.csrf import ensure_csrf_cookie from django.utils.decorators import classonlymethod from django.views.generic import View from graphql import OperationType, get_operation_ast, parse, validate from graphql.error import GraphQLError from graphql.execution import ExecutionResult from graphene import Schema from graphql.execution.middleware import MiddlewareManager from graphene_django.constants import MUTATION_ERRORS_FLAG from graphene_django.utils.utils import set_rollback from .settings import graphene_settings class HttpError(Exception): def __init__(self, response, message=None, *args, **kwargs): self.response = response self.message = message = message or response.content.decode() super().__init__(message, *args, **kwargs) def get_accepted_content_types(request): def qualify(x): parts = x.split(";", 1) if len(parts) == 2: match = re.match(r"(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)", parts[1]) if match: return parts[0].strip(), float(match.group(2)) return parts[0].strip(), 1 raw_content_types = request.META.get("HTTP_ACCEPT", "*/*").split(",") qualified_content_types = map(qualify, raw_content_types) return list( x[0] for x in sorted(qualified_content_types, key=lambda x: x[1], reverse=True) ) def instantiate_middleware(middlewares): for middleware in middlewares: if inspect.isclass(middleware): yield middleware() continue yield middleware class GraphQLView(View): graphiql_template = "graphene/graphiql.html" # Polyfill for window.fetch. whatwg_fetch_version = "3.6.2" whatwg_fetch_sri = "sha256-+pQdxwAcHJdQ3e/9S4RK6g8ZkwdMgFQuHvLuN5uyk5c=" # React and ReactDOM. react_version = "17.0.2" react_sri = "sha256-Ipu/TQ50iCCVZBUsZyNJfxrDk0E2yhaEIz0vqI+kFG8=" react_dom_sri = "sha256-nbMykgB6tsOFJ7OdVmPpdqMFVk4ZsqWocT6issAPUF0=" # The GraphiQL React app. graphiql_version = "2.4.1" # "1.0.3" graphiql_sri = "sha256-s+f7CFAPSUIygFnRC2nfoiEKd3liCUy+snSdYFAoLUc=" # "sha256-VR4buIDY9ZXSyCNFHFNik6uSe0MhigCzgN4u7moCOTk=" graphiql_css_sri = "sha256-88yn8FJMyGboGs4Bj+Pbb3kWOWXo7jmb+XCRHE+282k=" # "sha256-LwqxjyZgqXDYbpxQJ5zLQeNcf7WVNSJ+r8yp2rnWE/E=" # The websocket transport library for subscriptions. subscriptions_transport_ws_version = "5.12.1" subscriptions_transport_ws_sri = ( "sha256-EZhvg6ANJrBsgLvLAa0uuHNLepLJVCFYS+xlb5U/bqw=" ) schema = None graphiql = False middleware = None root_value = None pretty = False batch = False subscription_path = None execution_context_class = None def __init__( self, schema=None, middleware=None, root_value=None, graphiql=False, pretty=False, batch=False, subscription_path=None, execution_context_class=None, ): if not schema: schema = graphene_settings.SCHEMA if middleware is None: middleware = graphene_settings.MIDDLEWARE self.schema = self.schema or schema if middleware is not None: if isinstance(middleware, MiddlewareManager): self.middleware = middleware else: self.middleware = list(instantiate_middleware(middleware)) self.root_value = root_value self.pretty = self.pretty or pretty self.graphiql = self.graphiql or graphiql self.batch = self.batch or batch self.execution_context_class = execution_context_class if subscription_path is None: self.subscription_path = graphene_settings.SUBSCRIPTION_PATH assert isinstance( self.schema, Schema ), "A Schema is required to be provided to GraphQLView." assert not all((graphiql, batch)), "Use either graphiql or batch processing" # noinspection PyUnusedLocal def get_root_value(self, request): return self.root_value def get_middleware(self, request): return self.middleware def get_context(self, request): return request @method_decorator(ensure_csrf_cookie) def dispatch(self, request, *args, **kwargs): try: if request.method.lower() not in ("get", "post"): raise HttpError( HttpResponseNotAllowed( ["GET", "POST"], "GraphQL only supports GET and POST requests." ) ) data = self.parse_body(request) show_graphiql = self.graphiql and self.can_display_graphiql(request, data) if show_graphiql: return self.render_graphiql( request, # Dependency parameters. whatwg_fetch_version=self.whatwg_fetch_version, whatwg_fetch_sri=self.whatwg_fetch_sri, react_version=self.react_version, react_sri=self.react_sri, react_dom_sri=self.react_dom_sri, graphiql_version=self.graphiql_version, graphiql_sri=self.graphiql_sri, graphiql_css_sri=self.graphiql_css_sri, subscriptions_transport_ws_version=self.subscriptions_transport_ws_version, subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri, # The SUBSCRIPTION_PATH setting. subscription_path=self.subscription_path, # GraphiQL headers tab, graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED, graphiql_should_persist_headers=graphene_settings.GRAPHIQL_SHOULD_PERSIST_HEADERS, ) if self.batch: responses = [self.get_response(request, entry) for entry in data] result = "[{}]".format( ",".join([response[0] for response in responses]) ) status_code = ( responses and max(responses, key=lambda response: response[1])[1] or 200 ) else: result, status_code = self.get_response(request, data, show_graphiql) return HttpResponse( status=status_code, content=result, content_type="application/json" ) except HttpError as e: response = e.response response["Content-Type"] = "application/json" response.content = self.json_encode( request, {"errors": [self.format_error(e)]} ) return response def get_response(self, request, data, show_graphiql=False): query, variables, operation_name, id = self.get_graphql_params(request, data) execution_result = self.execute_graphql_request( request, data, query, variables, operation_name, show_graphiql ) if getattr(request, MUTATION_ERRORS_FLAG, False) is True: set_rollback() status_code = 200 if execution_result: response = {} if execution_result.errors: set_rollback() response["errors"] = [ self.format_error(e) for e in execution_result.errors ] if execution_result.errors and any( not getattr(e, "path", None) for e in execution_result.errors ): status_code = 400 else: response["data"] = execution_result.data if self.batch: response["id"] = id response["status"] = status_code result = self.json_encode(request, response, pretty=show_graphiql) else: result = None return result, status_code def render_graphiql(self, request, **data): return render(request, self.graphiql_template, data) def json_encode(self, request, d, pretty=False): if not (self.pretty or pretty) and not request.GET.get("pretty"): return json.dumps(d, separators=(",", ":")) return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": ")) def parse_body(self, request): content_type = self.get_content_type(request) if content_type == "application/graphql": return {"query": request.body.decode()} elif content_type == "application/json": # noinspection PyBroadException try: body = request.body.decode("utf-8") except Exception as e: raise HttpError(HttpResponseBadRequest(str(e))) try: request_json = json.loads(body) if self.batch: assert isinstance(request_json, list), ( "Batch requests should receive a list, but received {}." ).format(repr(request_json)) assert ( len(request_json) > 0 ), "Received an empty list in the batch request." else: assert isinstance( request_json, dict ), "The received data is not a valid JSON query." return request_json except AssertionError as e: raise HttpError(HttpResponseBadRequest(str(e))) except (TypeError, ValueError): raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON.")) elif content_type in [ "application/x-www-form-urlencoded", "multipart/form-data", ]: return request.POST return {} def execute_graphql_request( self, request, data, query, variables, operation_name, show_graphiql=False ): if not query: if show_graphiql: return None raise HttpError(HttpResponseBadRequest("Must provide query string.")) try: document = parse(query) except Exception as e: return ExecutionResult(errors=[e]) if request.method.lower() == "get": operation_ast = get_operation_ast(document, operation_name) if operation_ast and operation_ast.operation != OperationType.QUERY: if show_graphiql: return None raise HttpError( HttpResponseNotAllowed( ["POST"], "Can only perform a {} operation from a POST request.".format( operation_ast.operation.value ), ) ) validation_errors = validate(self.schema.graphql_schema, document) if validation_errors: return ExecutionResult(data=None, errors=validation_errors) try: extra_options = {} if self.execution_context_class: extra_options["execution_context_class"] = self.execution_context_class options = { "source": query, "root_value": self.get_root_value(request), "variable_values": variables, "operation_name": operation_name, "context_value": self.get_context(request), "middleware": self.get_middleware(request), } options.update(extra_options) operation_ast = get_operation_ast(document, operation_name) if ( operation_ast and operation_ast.operation == OperationType.MUTATION and ( graphene_settings.ATOMIC_MUTATIONS is True or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True ) ): with transaction.atomic(): result = self.schema.execute(**options) if getattr(request, MUTATION_ERRORS_FLAG, False) is True: transaction.set_rollback(True) return result return self.schema.execute(**options) except Exception as e: return ExecutionResult(errors=[e]) @classmethod def can_display_graphiql(cls, request, data): raw = "raw" in request.GET or "raw" in data return not raw and cls.request_wants_html(request) @classmethod def request_wants_html(cls, request): accepted = get_accepted_content_types(request) accepted_length = len(accepted) # the list will be ordered in preferred first - so we have to make # sure the most preferred gets the highest number html_priority = ( accepted_length - accepted.index("text/html") if "text/html" in accepted else 0 ) json_priority = ( accepted_length - accepted.index("application/json") if "application/json" in accepted else 0 ) return html_priority > json_priority @staticmethod def get_graphql_params(request, data): query = request.GET.get("query") or data.get("query") variables = request.GET.get("variables") or data.get("variables") id = request.GET.get("id") or data.get("id") if variables and isinstance(variables, str): try: variables = json.loads(variables) except Exception: raise HttpError(HttpResponseBadRequest("Variables are invalid JSON.")) operation_name = request.GET.get("operationName") or data.get("operationName") if operation_name == "null": operation_name = None return query, variables, operation_name, id @staticmethod def format_error(error): if isinstance(error, GraphQLError): return error.formatted return {"message": str(error)} @staticmethod def get_content_type(request): meta = request.META content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", "")) return content_type.split(";", 1)[0].lower() class AsyncGraphQLView(GraphQLView): graphiql_template = "graphene/graphiql.html" # Polyfill for window.fetch. whatwg_fetch_version = "3.6.2" whatwg_fetch_sri = "sha256-+pQdxwAcHJdQ3e/9S4RK6g8ZkwdMgFQuHvLuN5uyk5c=" # React and ReactDOM. react_version = "17.0.2" react_sri = "sha256-Ipu/TQ50iCCVZBUsZyNJfxrDk0E2yhaEIz0vqI+kFG8=" react_dom_sri = "sha256-nbMykgB6tsOFJ7OdVmPpdqMFVk4ZsqWocT6issAPUF0=" # The GraphiQL React app. graphiql_version = "1.4.7" # "1.0.3" graphiql_sri = "sha256-cpZ8w9D/i6XdEbY/Eu7yAXeYzReVw0mxYd7OU3gUcsc=" # "sha256-VR4buIDY9ZXSyCNFHFNik6uSe0MhigCzgN4u7moCOTk=" graphiql_css_sri = "sha256-HADQowUuFum02+Ckkv5Yu5ygRoLllHZqg0TFZXY7NHI=" # "sha256-LwqxjyZgqXDYbpxQJ5zLQeNcf7WVNSJ+r8yp2rnWE/E=" # The websocket transport library for subscriptions. subscriptions_transport_ws_version = "0.9.18" subscriptions_transport_ws_sri = ( "sha256-i0hAXd4PdJ/cHX3/8tIy/Q/qKiWr5WSTxMFuL9tACkw=" ) schema = None graphiql = False middleware = None root_value = None pretty = False batch = False subscription_path = None execution_context_class = None def __init__( self, schema=None, middleware=None, root_value=None, graphiql=False, pretty=False, batch=False, subscription_path=None, execution_context_class=None, ): if not schema: schema = graphene_settings.SCHEMA if middleware is None: middleware = graphene_settings.MIDDLEWARE self.schema = self.schema or schema if middleware is not None: if isinstance(middleware, MiddlewareManager): self.middleware = middleware else: self.middleware = list(instantiate_middleware(middleware)) self.root_value = root_value self.pretty = self.pretty or pretty self.graphiql = self.graphiql or graphiql self.batch = self.batch or batch self.execution_context_class = execution_context_class if subscription_path is None: self.subscription_path = graphene_settings.SUBSCRIPTION_PATH assert isinstance( self.schema, Schema ), "A Schema is required to be provided to GraphQLView." assert not all((graphiql, batch)), "Use either graphiql or batch processing" # noinspection PyUnusedLocal def get_root_value(self, request): return self.root_value def get_middleware(self, request): return self.middleware def get_context(self, request): return request @classonlymethod def as_view(cls, **initkwargs): view = super().as_view(**initkwargs) view._is_coroutine = coroutines._is_coroutine return view async def dispatch(self, request, *args, **kwargs): try: if request.method.lower() not in ("get", "post"): raise HttpError( HttpResponseNotAllowed( ["GET", "POST"], "GraphQL only supports GET and POST requests." ) ) data = self.parse_body(request) show_graphiql = self.graphiql and self.can_display_graphiql(request, data) if show_graphiql: return self.render_graphiql( request, # Dependency parameters. whatwg_fetch_version=self.whatwg_fetch_version, whatwg_fetch_sri=self.whatwg_fetch_sri, react_version=self.react_version, react_sri=self.react_sri, react_dom_sri=self.react_dom_sri, graphiql_version=self.graphiql_version, graphiql_sri=self.graphiql_sri, graphiql_css_sri=self.graphiql_css_sri, subscriptions_transport_ws_version=self.subscriptions_transport_ws_version, subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri, # The SUBSCRIPTION_PATH setting. subscription_path=self.subscription_path, # GraphiQL headers tab, graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED, graphiql_should_persist_headers=graphene_settings.GRAPHIQL_SHOULD_PERSIST_HEADERS, ) if self.batch: responses = await gather(*[self.get_response(request, entry) for entry in data]) result = "[{}]".format( ",".join([response[0] for response in responses]) ) status_code = ( responses and max(responses, key=lambda response: response[1])[1] or 200 ) else: result, status_code = await self.get_response(request, data, show_graphiql) return HttpResponse( status=status_code, content=result, content_type="application/json" ) except HttpError as e: response = e.response response["Content-Type"] = "application/json" response.content = self.json_encode( request, {"errors": [self.format_error(e)]} ) return response async def get_response(self, request, data, show_graphiql=False): query, variables, operation_name, id = self.get_graphql_params(request, data) execution_result = await self.execute_graphql_request( request, data, query, variables, operation_name, show_graphiql ) if getattr(request, MUTATION_ERRORS_FLAG, False) is True: set_rollback() status_code = 200 if execution_result: response = {} if execution_result.errors: set_rollback() response["errors"] = [ self.format_error(e) for e in execution_result.errors ] if execution_result.errors and any( not getattr(e, "path", None) for e in execution_result.errors ): status_code = 400 else: response["data"] = execution_result.data if self.batch: response["id"] = id response["status"] = status_code result = self.json_encode(request, response, pretty=show_graphiql) else: result = None return result, status_code def render_graphiql(self, request, **data): return render(request, self.graphiql_template, data) def json_encode(self, request, d, pretty=False): if not (self.pretty or pretty) and not request.GET.get("pretty"): return json.dumps(d, separators=(",", ":")) return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": ")) def parse_body(self, request): content_type = self.get_content_type(request) if content_type == "application/graphql": return {"query": request.body.decode()} elif content_type == "application/json": # noinspection PyBroadException try: body = request.body.decode("utf-8") except Exception as e: raise HttpError(HttpResponseBadRequest(str(e))) try: request_json = json.loads(body) if self.batch: assert isinstance(request_json, list), ( "Batch requests should receive a list, but received {}." ).format(repr(request_json)) assert ( len(request_json) > 0 ), "Received an empty list in the batch request." else: assert isinstance( request_json, dict ), "The received data is not a valid JSON query." return request_json except AssertionError as e: raise HttpError(HttpResponseBadRequest(str(e))) except (TypeError, ValueError): raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON.")) elif content_type in [ "application/x-www-form-urlencoded", "multipart/form-data", ]: return request.POST return {} async def execute_graphql_request( self, request, data, query, variables, operation_name, show_graphiql=False ): if not query: if show_graphiql: return None raise HttpError(HttpResponseBadRequest("Must provide query string.")) try: document = parse(query) except Exception as e: return ExecutionResult(errors=[e]) if request.method.lower() == "get": operation_ast = get_operation_ast(document, operation_name) if operation_ast and operation_ast.operation != OperationType.QUERY: if show_graphiql: return None raise HttpError( HttpResponseNotAllowed( ["POST"], "Can only perform a {} operation from a POST request.".format( operation_ast.operation.value ), ) ) validation_errors = validate(self.schema.graphql_schema, document) if validation_errors: return ExecutionResult(data=None, errors=validation_errors) try: extra_options = {} if self.execution_context_class: extra_options["execution_context_class"] = self.execution_context_class options = { "source": query, "root_value": self.get_root_value(request), "variable_values": variables, "operation_name": operation_name, "context_value": self.get_context(request), "middleware": self.get_middleware(request), } options.update(extra_options) operation_ast = get_operation_ast(document, operation_name) if ( operation_ast and operation_ast.operation == OperationType.MUTATION and ( graphene_settings.ATOMIC_MUTATIONS is True or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True ) ): with transaction.atomic(): result = await self.schema.execute_async(**options) if getattr(request, MUTATION_ERRORS_FLAG, False) is True: transaction.set_rollback(True) return result return await self.schema.execute_async(**options) except Exception as e: return ExecutionResult(errors=[e]) @classmethod def can_display_graphiql(cls, request, data): raw = "raw" in request.GET or "raw" in data return not raw and cls.request_wants_html(request) @classmethod def request_wants_html(cls, request): accepted = get_accepted_content_types(request) accepted_length = len(accepted) # the list will be ordered in preferred first - so we have to make # sure the most preferred gets the highest number html_priority = ( accepted_length - accepted.index("text/html") if "text/html" in accepted else 0 ) json_priority = ( accepted_length - accepted.index("application/json") if "application/json" in accepted else 0 ) return html_priority > json_priority @staticmethod def get_graphql_params(request, data): query = request.GET.get("query") or data.get("query") variables = request.GET.get("variables") or data.get("variables") id = request.GET.get("id") or data.get("id") if variables and isinstance(variables, str): try: variables = json.loads(variables) except Exception: raise HttpError(HttpResponseBadRequest("Variables are invalid JSON.")) operation_name = request.GET.get("operationName") or data.get("operationName") if operation_name == "null": operation_name = None return query, variables, operation_name, id @staticmethod def format_error(error): if isinstance(error, GraphQLError): return error.formatted return {"message": str(error)} @staticmethod def get_content_type(request): meta = request.META content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", "")) return content_type.split(";", 1)[0].lower()