graphene-django/graphene_django/views.py

752 lines
27 KiB
Python
Raw Normal View History

2016-09-20 08:33:46 +03:00
import inspect
import json
2016-09-20 08:15:10 +03:00
import re
2023-03-30 23:01:42 +03:00
from asyncio import gather, coroutines
from django.db import connection, transaction
from django.http import HttpResponse, HttpResponseNotAllowed
from django.http.response import HttpResponseBadRequest
from django.shortcuts import render
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import ensure_csrf_cookie
2023-03-30 23:01:42 +03:00
from django.utils.decorators import classonlymethod
from django.views.generic import View
from graphql import OperationType, get_operation_ast, parse, validate
2016-09-20 08:15:10 +03:00
from graphql.error import GraphQLError
from graphql.execution import ExecutionResult
from graphene import Schema
from graphql.execution.middleware import MiddlewareManager
from graphene_django.constants import MUTATION_ERRORS_FLAG
from graphene_django.utils.utils import set_rollback
from .settings import graphene_settings
class HttpError(Exception):
    """Exception that carries the HTTP response to hand back to the client.

    Args:
        response: Response object stored on the exception as ``response``.
            Its ``content`` is decoded and used as the message when no
            message is supplied.
        message: Optional error text; a falsy value falls back to the
            decoded response content.
    """

    def __init__(self, response, message=None, *args, **kwargs):
        if not message:
            # No usable message was given: reuse the response body as text.
            message = response.content.decode()
        self.response = response
        self.message = message
        super().__init__(message, *args, **kwargs)
# Finds the ``q`` (quality) parameter anywhere in a media range's parameter
# list, tolerating optional whitespace around the ``;`` separators. The value
# must be 0-1 with at most three decimals; anything else is ignored.
_ACCEPT_QUALITY_RE = re.compile(
    r"(?:^|;)\s*q=(0(?:\.\d{0,3})?|1(?:\.0{0,3})?)\s*(?:;|$)"
)


def get_accepted_content_types(request):
    """Return the media types of the ``Accept`` header, most preferred first.

    The header is read from ``request.META["HTTP_ACCEPT"]`` and defaults to
    ``*/*`` when absent. Each comma-separated entry is ranked by its ``q``
    parameter; entries without a valid ``q`` are ranked as 1. The sort is
    stable, so entries of equal quality keep their header order.

    Args:
        request: Object exposing a ``META`` mapping.

    Returns:
        list: Media type strings (parameters stripped), ordered by descending
        quality.
    """

    def qualify(raw_media_range):
        # Split "type/subtype;params" once; params is "" when there is no ";".
        media_type, _, params = raw_media_range.partition(";")
        # search (not match) so that q is honoured even when it is preceded
        # by whitespace or by other parameters such as "level=1".
        found = _ACCEPT_QUALITY_RE.search(params)
        quality = float(found.group(1)) if found else 1
        return media_type.strip(), quality

    raw_content_types = request.META.get("HTTP_ACCEPT", "*/*").split(",")
    ranked = sorted(
        map(qualify, raw_content_types), key=lambda item: item[1], reverse=True
    )
    return [media_type for media_type, _ in ranked]
2016-09-20 08:33:46 +03:00
def instantiate_middleware(middlewares):
    """Yield every middleware, instantiating the entries that are classes.

    Args:
        middlewares: Iterable of middleware classes and/or ready-made
            middleware objects.

    Yields:
        A new no-argument instance for each class entry; every other entry
        is yielded unchanged.
    """
    for entry in middlewares:
        yield entry() if inspect.isclass(entry) else entry
class GraphQLView(View):
    """Django class-based view that executes GraphQL requests over GET/POST.

    Configuration can be supplied as class attributes or as ``__init__``
    arguments. When ``graphiql`` is enabled and the client prefers HTML, the
    GraphiQL template is rendered instead of executing a query. When
    ``batch`` is enabled, the JSON body must be a non-empty list of
    operations and the response is a JSON list.
    """

    graphiql_template = "graphene/graphiql.html"

    # Polyfill for window.fetch.
    whatwg_fetch_version = "3.6.2"
    whatwg_fetch_sri = "sha256-+pQdxwAcHJdQ3e/9S4RK6g8ZkwdMgFQuHvLuN5uyk5c="

    # React and ReactDOM.
    react_version = "17.0.2"
    react_sri = "sha256-Ipu/TQ50iCCVZBUsZyNJfxrDk0E2yhaEIz0vqI+kFG8="
    react_dom_sri = "sha256-nbMykgB6tsOFJ7OdVmPpdqMFVk4ZsqWocT6issAPUF0="

    # The GraphiQL React app.
    graphiql_version = "2.4.1"  # "1.0.3"
    graphiql_sri = "sha256-s+f7CFAPSUIygFnRC2nfoiEKd3liCUy+snSdYFAoLUc="  # "sha256-VR4buIDY9ZXSyCNFHFNik6uSe0MhigCzgN4u7moCOTk="
    graphiql_css_sri = "sha256-88yn8FJMyGboGs4Bj+Pbb3kWOWXo7jmb+XCRHE+282k="  # "sha256-LwqxjyZgqXDYbpxQJ5zLQeNcf7WVNSJ+r8yp2rnWE/E="

    # The websocket transport library for subscriptions.
    subscriptions_transport_ws_version = "5.12.1"
    subscriptions_transport_ws_sri = (
        "sha256-EZhvg6ANJrBsgLvLAa0uuHNLepLJVCFYS+xlb5U/bqw="
    )

    schema = None
    graphiql = False
    middleware = None
    root_value = None
    pretty = False
    batch = False
    subscription_path = None
    execution_context_class = None

    def __init__(
        self,
        schema=None,
        middleware=None,
        root_value=None,
        graphiql=False,
        pretty=False,
        batch=False,
        subscription_path=None,
        execution_context_class=None,
    ):
        """Configure the view.

        Args:
            schema: Schema to execute against; falls back to
                ``graphene_settings.SCHEMA`` when falsy. A truthy class-level
                ``schema`` takes precedence over this argument.
            middleware: ``MiddlewareManager`` or iterable of middleware
                classes/instances; falls back to
                ``graphene_settings.MIDDLEWARE`` when None.
            root_value: Value returned by ``get_root_value``.
            graphiql: Enable rendering of the GraphiQL page.
            pretty: Always pretty-print JSON responses.
            batch: Enable batch mode.
            subscription_path: Path passed to the GraphiQL template; falls
                back to ``graphene_settings.SUBSCRIPTION_PATH`` when None.
            execution_context_class: Optional execution context class
                forwarded to schema execution.

        Raises:
            AssertionError: If the resolved schema is not a ``Schema`` or if
                both ``graphiql`` and ``batch`` arguments are truthy.
        """
        if not schema:
            schema = graphene_settings.SCHEMA

        if middleware is None:
            middleware = graphene_settings.MIDDLEWARE

        self.schema = self.schema or schema
        if middleware is not None:
            if isinstance(middleware, MiddlewareManager):
                self.middleware = middleware
            else:
                self.middleware = list(instantiate_middleware(middleware))
        self.root_value = root_value
        self.pretty = self.pretty or pretty
        self.graphiql = self.graphiql or graphiql
        self.batch = self.batch or batch
        self.execution_context_class = execution_context_class
        # An explicitly passed path must be stored; previously it was
        # discarded and only the settings fallback was ever assigned.
        if subscription_path is None:
            subscription_path = graphene_settings.SUBSCRIPTION_PATH
        self.subscription_path = subscription_path

        assert isinstance(
            self.schema, Schema
        ), "A Schema is required to be provided to GraphQLView."
        assert not all((graphiql, batch)), "Use either graphiql or batch processing"

    # noinspection PyUnusedLocal
    def get_root_value(self, request):
        """Return the root value used for execution (hook for subclasses)."""
        return self.root_value

    def get_middleware(self, request):
        """Return the middleware used for execution (hook for subclasses)."""
        return self.middleware

    def get_context(self, request):
        """Return the execution context; by default the request itself."""
        return request

    @method_decorator(ensure_csrf_cookie)
    def dispatch(self, request, *args, **kwargs):
        """Handle a request: render GraphiQL or execute the operation(s).

        Returns:
            The rendered GraphiQL page, or an ``HttpResponse`` with a JSON
            body. An ``HttpError`` raised while handling the request is
            converted into its response with a JSON ``errors`` body.
        """
        try:
            if request.method.lower() not in ("get", "post"):
                raise HttpError(
                    HttpResponseNotAllowed(
                        ["GET", "POST"], "GraphQL only supports GET and POST requests."
                    )
                )

            data = self.parse_body(request)
            show_graphiql = self.graphiql and self.can_display_graphiql(request, data)

            if show_graphiql:
                return self.render_graphiql(
                    request,
                    # Dependency parameters.
                    whatwg_fetch_version=self.whatwg_fetch_version,
                    whatwg_fetch_sri=self.whatwg_fetch_sri,
                    react_version=self.react_version,
                    react_sri=self.react_sri,
                    react_dom_sri=self.react_dom_sri,
                    graphiql_version=self.graphiql_version,
                    graphiql_sri=self.graphiql_sri,
                    graphiql_css_sri=self.graphiql_css_sri,
                    subscriptions_transport_ws_version=self.subscriptions_transport_ws_version,
                    subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri,
                    # The SUBSCRIPTION_PATH setting.
                    subscription_path=self.subscription_path,
                    # GraphiQL headers tab,
                    graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED,
                    graphiql_should_persist_headers=graphene_settings.GRAPHIQL_SHOULD_PERSIST_HEADERS,
                )

            if self.batch:
                responses = [self.get_response(request, entry) for entry in data]
                result = "[{}]".format(
                    ",".join([response[0] for response in responses])
                )
                # The batch status is the highest status of its entries.
                status_code = (
                    max(response[1] for response in responses) if responses else 200
                )
            else:
                result, status_code = self.get_response(request, data, show_graphiql)

            return HttpResponse(
                status=status_code, content=result, content_type="application/json"
            )

        except HttpError as e:
            response = e.response
            response["Content-Type"] = "application/json"
            response.content = self.json_encode(
                request, {"errors": [self.format_error(e)]}
            )
            return response

    def get_response(self, request, data, show_graphiql=False):
        """Execute one operation and serialize its result.

        Returns:
            tuple: ``(result, status_code)`` where ``result`` is the encoded
            JSON string, or None when execution produced no result. The
            status is 400 when any error has no ``path``, otherwise 200.
        """
        query, variables, operation_name, request_id = self.get_graphql_params(
            request, data
        )

        execution_result = self.execute_graphql_request(
            request, data, query, variables, operation_name, show_graphiql
        )

        if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
            set_rollback()

        status_code = 200
        if execution_result:
            response = {}

            if execution_result.errors:
                set_rollback()
                response["errors"] = [
                    self.format_error(e) for e in execution_result.errors
                ]

            if execution_result.errors and any(
                not getattr(e, "path", None) for e in execution_result.errors
            ):
                status_code = 400
            else:
                response["data"] = execution_result.data

            if self.batch:
                response["id"] = request_id
                response["status"] = status_code

            result = self.json_encode(request, response, pretty=show_graphiql)
        else:
            result = None

        return result, status_code

    def render_graphiql(self, request, **data):
        """Render the GraphiQL template with ``data`` as its context."""
        return render(request, self.graphiql_template, data)

    def json_encode(self, request, d, pretty=False):
        """Serialize ``d`` to JSON.

        Output is compact unless ``self.pretty``, the ``pretty`` argument or
        a truthy ``pretty`` query parameter asks for indented, key-sorted
        output.
        """
        if not (self.pretty or pretty) and not request.GET.get("pretty"):
            return json.dumps(d, separators=(",", ":"))

        return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": "))

    def parse_body(self, request):
        """Extract the request payload according to its content type.

        Returns:
            A dict with a ``query`` key for ``application/graphql``; the
            decoded JSON (a list in batch mode, otherwise a dict) for
            ``application/json``; ``request.POST`` for form content types;
            an empty dict for anything else.

        Raises:
            HttpError: With a 400 response when the body cannot be decoded,
                is not valid JSON, or has the wrong JSON shape.
        """
        content_type = self.get_content_type(request)

        if content_type == "application/graphql":
            try:
                return {"query": request.body.decode()}
            except UnicodeDecodeError as e:
                # Report undecodable bodies as a client error, matching the
                # JSON branch below.
                raise HttpError(HttpResponseBadRequest(str(e))) from e

        if content_type == "application/json":
            # noinspection PyBroadException
            try:
                body = request.body.decode("utf-8")
            except Exception as e:
                raise HttpError(HttpResponseBadRequest(str(e)))

            try:
                request_json = json.loads(body)
            except (TypeError, ValueError):
                raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON."))

            # Explicit checks instead of ``assert``: assertions are removed
            # when Python runs with -O, which would skip this validation.
            if self.batch:
                if not isinstance(request_json, list):
                    raise HttpError(
                        HttpResponseBadRequest(
                            "Batch requests should receive a list, but received {}.".format(
                                repr(request_json)
                            )
                        )
                    )
                if not request_json:
                    raise HttpError(
                        HttpResponseBadRequest(
                            "Received an empty list in the batch request."
                        )
                    )
            elif not isinstance(request_json, dict):
                raise HttpError(
                    HttpResponseBadRequest(
                        "The received data is not a valid JSON query."
                    )
                )
            return request_json

        if content_type in [
            "application/x-www-form-urlencoded",
            "multipart/form-data",
        ]:
            return request.POST

        return {}

    def execute_graphql_request(
        self, request, data, query, variables, operation_name, show_graphiql=False
    ):
        """Parse, validate and execute ``query`` against the schema.

        Returns:
            The execution result; an ``ExecutionResult`` carrying the errors
            when parsing, validation or execution fails; or None when there
            is nothing to execute and GraphiQL is being shown.

        Raises:
            HttpError: 400 when no query is given, 405 when a non-query
                operation is sent via GET (both only when GraphiQL is not
                being shown).
        """
        if not query:
            if show_graphiql:
                return None
            raise HttpError(HttpResponseBadRequest("Must provide query string."))

        try:
            document = parse(query)
        except Exception as e:
            return ExecutionResult(errors=[e])

        if request.method.lower() == "get":
            operation_ast = get_operation_ast(document, operation_name)
            if operation_ast and operation_ast.operation != OperationType.QUERY:
                if show_graphiql:
                    return None

                raise HttpError(
                    HttpResponseNotAllowed(
                        ["POST"],
                        "Can only perform a {} operation from a POST request.".format(
                            operation_ast.operation.value
                        ),
                    )
                )

        validation_errors = validate(self.schema.graphql_schema, document)
        if validation_errors:
            return ExecutionResult(data=None, errors=validation_errors)

        try:
            extra_options = {}
            if self.execution_context_class:
                extra_options["execution_context_class"] = self.execution_context_class

            options = {
                "source": query,
                "root_value": self.get_root_value(request),
                "variable_values": variables,
                "operation_name": operation_name,
                "context_value": self.get_context(request),
                "middleware": self.get_middleware(request),
            }
            options.update(extra_options)

            operation_ast = get_operation_ast(document, operation_name)
            if (
                operation_ast
                and operation_ast.operation == OperationType.MUTATION
                and (
                    graphene_settings.ATOMIC_MUTATIONS is True
                    or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True
                )
            ):
                # Run the mutation in a transaction and mark it for rollback
                # when the request was flagged with mutation errors.
                with transaction.atomic():
                    result = self.schema.execute(**options)
                    if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
                        transaction.set_rollback(True)
                return result

            return self.schema.execute(**options)
        except Exception as e:
            return ExecutionResult(errors=[e])

    @classmethod
    def can_display_graphiql(cls, request, data):
        """Return True when GraphiQL may be shown for this request.

        GraphiQL is suppressed when ``raw`` appears in the query string or
        in ``data``; otherwise the client must prefer HTML over JSON.
        """
        raw = "raw" in request.GET or "raw" in data
        return not raw and cls.request_wants_html(request)

    @classmethod
    def request_wants_html(cls, request):
        """Return True when ``text/html`` is preferred over JSON."""
        accepted = get_accepted_content_types(request)
        accepted_length = len(accepted)
        # the list will be ordered in preferred first - so we have to make
        # sure the most preferred gets the highest number
        html_priority = (
            accepted_length - accepted.index("text/html")
            if "text/html" in accepted
            else 0
        )
        json_priority = (
            accepted_length - accepted.index("application/json")
            if "application/json" in accepted
            else 0
        )

        return html_priority > json_priority

    @staticmethod
    def get_graphql_params(request, data):
        """Collect the operation parameters from the query string and body.

        Query-string values take precedence over ``data``. ``variables``
        given as a string is decoded from JSON, and an ``operationName`` of
        the literal string ``"null"`` is treated as None.

        Returns:
            tuple: ``(query, variables, operation_name, id)``.

        Raises:
            HttpError: 400 when ``variables`` is a string that is not valid
                JSON.
        """
        query = request.GET.get("query") or data.get("query")
        variables = request.GET.get("variables") or data.get("variables")
        request_id = request.GET.get("id") or data.get("id")

        if variables and isinstance(variables, str):
            try:
                variables = json.loads(variables)
            except Exception:
                raise HttpError(HttpResponseBadRequest("Variables are invalid JSON."))

        operation_name = request.GET.get("operationName") or data.get("operationName")
        if operation_name == "null":
            operation_name = None

        return query, variables, operation_name, request_id

    @staticmethod
    def format_error(error):
        """Return a JSON-serializable dict describing ``error``."""
        if isinstance(error, GraphQLError):
            return error.formatted

        return {"message": str(error)}

    @staticmethod
    def get_content_type(request):
        """Return the lower-cased request content type without parameters."""
        meta = request.META
        content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
        return content_type.split(";", 1)[0].lower()
class AsyncGraphQLView(GraphQLView):
    """Asynchronous variant of ``GraphQLView``.

    Only what differs from the parent is defined here: the GraphiQL and
    websocket-transport versions, the coroutine marking in ``as_view`` and
    the async ``dispatch`` / ``get_response`` / ``execute_graphql_request``.
    Constructor, hooks, body parsing, JSON encoding and parameter handling
    are inherited from ``GraphQLView``.
    """

    # The GraphiQL React app.
    graphiql_version = "1.4.7"  # "1.0.3"
    graphiql_sri = "sha256-cpZ8w9D/i6XdEbY/Eu7yAXeYzReVw0mxYd7OU3gUcsc="  # "sha256-VR4buIDY9ZXSyCNFHFNik6uSe0MhigCzgN4u7moCOTk="
    graphiql_css_sri = "sha256-HADQowUuFum02+Ckkv5Yu5ygRoLllHZqg0TFZXY7NHI="  # "sha256-LwqxjyZgqXDYbpxQJ5zLQeNcf7WVNSJ+r8yp2rnWE/E="

    # The websocket transport library for subscriptions.
    subscriptions_transport_ws_version = "0.9.18"
    subscriptions_transport_ws_sri = (
        "sha256-i0hAXd4PdJ/cHX3/8tIy/Q/qKiWr5WSTxMFuL9tACkw="
    )

    @classonlymethod
    def as_view(cls, **initkwargs):
        """Return the view callable, marked with asyncio's coroutine marker."""
        view = super().as_view(**initkwargs)
        view._is_coroutine = coroutines._is_coroutine
        return view

    async def dispatch(self, request, *args, **kwargs):
        """Handle a request: render GraphiQL or execute the operation(s).

        In batch mode all entries are executed concurrently with
        ``asyncio.gather``.

        Returns:
            The rendered GraphiQL page, or an ``HttpResponse`` with a JSON
            body. An ``HttpError`` raised while handling the request is
            converted into its response with a JSON ``errors`` body.
        """
        try:
            if request.method.lower() not in ("get", "post"):
                raise HttpError(
                    HttpResponseNotAllowed(
                        ["GET", "POST"], "GraphQL only supports GET and POST requests."
                    )
                )

            data = self.parse_body(request)
            show_graphiql = self.graphiql and self.can_display_graphiql(request, data)

            if show_graphiql:
                return self.render_graphiql(
                    request,
                    # Dependency parameters.
                    whatwg_fetch_version=self.whatwg_fetch_version,
                    whatwg_fetch_sri=self.whatwg_fetch_sri,
                    react_version=self.react_version,
                    react_sri=self.react_sri,
                    react_dom_sri=self.react_dom_sri,
                    graphiql_version=self.graphiql_version,
                    graphiql_sri=self.graphiql_sri,
                    graphiql_css_sri=self.graphiql_css_sri,
                    subscriptions_transport_ws_version=self.subscriptions_transport_ws_version,
                    subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri,
                    # The SUBSCRIPTION_PATH setting.
                    subscription_path=self.subscription_path,
                    # GraphiQL headers tab,
                    graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED,
                    graphiql_should_persist_headers=graphene_settings.GRAPHIQL_SHOULD_PERSIST_HEADERS,
                )

            if self.batch:
                responses = await gather(
                    *[self.get_response(request, entry) for entry in data]
                )
                result = "[{}]".format(
                    ",".join([response[0] for response in responses])
                )
                # The batch status is the highest status of its entries.
                status_code = (
                    max(response[1] for response in responses) if responses else 200
                )
            else:
                result, status_code = await self.get_response(
                    request, data, show_graphiql
                )

            return HttpResponse(
                status=status_code, content=result, content_type="application/json"
            )

        except HttpError as e:
            response = e.response
            response["Content-Type"] = "application/json"
            response.content = self.json_encode(
                request, {"errors": [self.format_error(e)]}
            )
            return response

    async def get_response(self, request, data, show_graphiql=False):
        """Execute one operation and serialize its result.

        Returns:
            tuple: ``(result, status_code)`` where ``result`` is the encoded
            JSON string, or None when execution produced no result. The
            status is 400 when any error has no ``path``, otherwise 200.
        """
        query, variables, operation_name, request_id = self.get_graphql_params(
            request, data
        )

        execution_result = await self.execute_graphql_request(
            request, data, query, variables, operation_name, show_graphiql
        )

        if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
            set_rollback()

        status_code = 200
        if execution_result:
            response = {}

            if execution_result.errors:
                set_rollback()
                response["errors"] = [
                    self.format_error(e) for e in execution_result.errors
                ]

            if execution_result.errors and any(
                not getattr(e, "path", None) for e in execution_result.errors
            ):
                status_code = 400
            else:
                response["data"] = execution_result.data

            if self.batch:
                response["id"] = request_id
                response["status"] = status_code

            result = self.json_encode(request, response, pretty=show_graphiql)
        else:
            result = None

        return result, status_code

    async def execute_graphql_request(
        self, request, data, query, variables, operation_name, show_graphiql=False
    ):
        """Parse, validate and asynchronously execute ``query``.

        Returns:
            The execution result; an ``ExecutionResult`` carrying the errors
            when parsing, validation or execution fails; or None when there
            is nothing to execute and GraphiQL is being shown.

        Raises:
            HttpError: 400 when no query is given, 405 when a non-query
                operation is sent via GET (both only when GraphiQL is not
                being shown).
        """
        if not query:
            if show_graphiql:
                return None
            raise HttpError(HttpResponseBadRequest("Must provide query string."))

        try:
            document = parse(query)
        except Exception as e:
            return ExecutionResult(errors=[e])

        if request.method.lower() == "get":
            operation_ast = get_operation_ast(document, operation_name)
            if operation_ast and operation_ast.operation != OperationType.QUERY:
                if show_graphiql:
                    return None

                raise HttpError(
                    HttpResponseNotAllowed(
                        ["POST"],
                        "Can only perform a {} operation from a POST request.".format(
                            operation_ast.operation.value
                        ),
                    )
                )

        validation_errors = validate(self.schema.graphql_schema, document)
        if validation_errors:
            return ExecutionResult(data=None, errors=validation_errors)

        try:
            extra_options = {}
            if self.execution_context_class:
                extra_options["execution_context_class"] = self.execution_context_class

            options = {
                "source": query,
                "root_value": self.get_root_value(request),
                "variable_values": variables,
                "operation_name": operation_name,
                "context_value": self.get_context(request),
                "middleware": self.get_middleware(request),
            }
            options.update(extra_options)

            operation_ast = get_operation_ast(document, operation_name)
            if (
                operation_ast
                and operation_ast.operation == OperationType.MUTATION
                and (
                    graphene_settings.ATOMIC_MUTATIONS is True
                    or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True
                )
            ):
                # NOTE(review): Django documents that transactions are not
                # supported in async mode; confirm transaction.atomic() is
                # usable around an awaited execution here. Any exception it
                # raises is reported through the ``except`` below.
                with transaction.atomic():
                    result = await self.schema.execute_async(**options)
                    if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
                        transaction.set_rollback(True)
                return result

            return await self.schema.execute_async(**options)
        except Exception as e:
            return ExecutionResult(errors=[e])