mirror of
				https://github.com/graphql-python/graphene-django.git
				synced 2025-10-25 13:10:59 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			324 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			324 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import inspect
 | |
| import json
 | |
| import re
 | |
| 
 | |
| import six
 | |
| from django.http import HttpResponse, HttpResponseNotAllowed
 | |
| from django.http.response import HttpResponseBadRequest
 | |
| from django.shortcuts import render
 | |
| from django.utils.decorators import method_decorator
 | |
| from django.views.generic import View
 | |
| from django.views.decorators.csrf import ensure_csrf_cookie
 | |
| 
 | |
| from graphql import Source, execute, parse, validate
 | |
| from graphql.error import format_error as format_graphql_error
 | |
| from graphql.error import GraphQLError
 | |
| from graphql.execution import ExecutionResult
 | |
| from graphql.type.schema import GraphQLSchema
 | |
| from graphql.utils.get_operation_ast import get_operation_ast
 | |
| 
 | |
| from .settings import graphene_settings
 | |
| 
 | |
| 
 | |
class HttpError(Exception):
    """Exception that carries a ready-made Django response for an HTTP failure.

    :param response: the response object to hand back to the client; it is
        stored unchanged on ``self.response``.
    :param message: human-readable description. When falsy, the decoded
        ``response.content`` is used instead.
    """

    def __init__(self, response, message=None, *args, **kwargs):
        self.response = response
        if not message:
            # Fall back to the body of the response itself.
            message = response.content.decode()
        self.message = message
        super(HttpError, self).__init__(message, *args, **kwargs)
 | |
| 
 | |
| 
 | |
def get_accepted_content_types(request):
    """Return the media types listed in the ``Accept`` header, most preferred first.

    Each comma-separated entry is reduced to its media type, with surrounding
    whitespace stripped. Entries are ranked by their ``q`` parameter, which
    defaults to 1 when absent or malformed. The sort is stable, so entries of
    equal quality keep their header order. A missing header is treated as
    ``*/*``.

    :param request: object exposing a ``META`` mapping (a Django request).
    :returns: list of media-type strings.
    """
    def qualify(entry):
        media_type, _, params = entry.partition(';')
        # ``q`` may follow other parameters and optional whitespace
        # (e.g. ``text/html; level=1; q=0.5``), so search the whole
        # parameter string instead of matching only at its start.
        match = re.search(
            r'(?:^|;)\s*q=(0(?:\.\d{,3})?|1(?:\.0{,3})?)\s*(?:;|$)',
            params.strip())
        quality = float(match.group(1)) if match else 1
        return media_type.strip(), quality

    raw_content_types = request.META.get('HTTP_ACCEPT', '*/*').split(',')
    qualified_content_types = [qualify(entry) for entry in raw_content_types]
    return [media_type for media_type, _ in sorted(
        qualified_content_types, key=lambda item: item[1], reverse=True)]
 | |
| 
 | |
| 
 | |
def instantiate_middleware(middlewares):
    """Lazily yield usable middleware objects from ``middlewares``.

    Entries that are classes are instantiated with no arguments; every other
    entry (an existing instance, a plain function) is yielded unchanged.
    """
    for entry in middlewares:
        yield entry() if inspect.isclass(entry) else entry
 | |
| 
 | |
| 
 | |
class GraphQLView(View):
    """Django view that executes GraphQL requests against a schema.

    Handles GET and POST. It can optionally process batched POST requests
    (a JSON list of operations). When ``graphiql`` is enabled, it renders the
    GraphiQL template for clients that prefer HTML.

    Options may be passed as ``as_view()``/constructor keyword arguments or
    set as class attributes on a subclass. An explicitly passed argument
    takes precedence over the class attribute.
    """

    graphiql_version = '0.10.2'
    graphiql_template = 'graphene/graphiql.html'

    schema = None
    graphiql = False
    executor = None
    middleware = None
    root_value = None
    pretty = False
    batch = False

    def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False,
                 batch=False):
        # Precedence: explicit argument > class attribute > project setting.
        # Previously the argument defaults silently overwrote class attributes
        # defined on subclasses.
        self.schema = schema or self.schema or graphene_settings.SCHEMA

        if middleware is None:
            middleware = graphene_settings.MIDDLEWARE
        if middleware is not None:
            # Classes in the list are instantiated; other entries are kept.
            self.middleware = list(instantiate_middleware(middleware))

        self.executor = executor if executor is not None else self.executor
        self.root_value = root_value if root_value is not None else self.root_value
        self.pretty = pretty or self.pretty
        self.graphiql = graphiql or self.graphiql
        self.batch = batch or self.batch

        assert isinstance(
            self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.'
        assert not all((self.graphiql, self.batch)
                       ), 'Use either graphiql or batch processing'

    # noinspection PyUnusedLocal
    def get_root_value(self, request):
        """Return the root value passed to execution (hook for subclasses)."""
        return self.root_value

    def get_middleware(self, request):
        """Return the middleware list passed to execution (hook for subclasses)."""
        return self.middleware

    def get_context(self, request):
        """Return the execution context; by default the request itself."""
        return request

    @method_decorator(ensure_csrf_cookie)
    def dispatch(self, request, *args, **kwargs):
        """Entry point: parse the request, execute it and build the response.

        Returns the GraphiQL page when it should be shown. Otherwise returns a
        JSON ``HttpResponse``. Any ``HttpError`` raised along the way is
        turned into its carried response with a JSON ``errors`` body.
        """
        try:
            if request.method.lower() not in ('get', 'post'):
                raise HttpError(HttpResponseNotAllowed(
                    ['GET', 'POST'], 'GraphQL only supports GET and POST requests.'))

            data = self.parse_body(request)
            show_graphiql = self.graphiql and self.can_display_graphiql(
                request, data)

            if self.batch:
                # Non-JSON bodies (GET, form posts) produce a mapping, and an
                # empty list would make max() below raise ValueError; reject
                # both as a client error instead of failing with a 500.
                if not isinstance(data, list) or not data:
                    raise HttpError(HttpResponseBadRequest(
                        'Batch requests should receive a non-empty list.'))
                responses = [self.get_response(request, entry)
                             for entry in data]
                result = '[{}]'.format(
                    ','.join([response[0] for response in responses]))
                # The overall status is the worst (highest) of the entries.
                status_code = max(response[1] for response in responses)
            else:
                result, status_code = self.get_response(
                    request, data, show_graphiql)

            if show_graphiql:
                query, variables, operation_name, _ = self.get_graphql_params(
                    request, data)
                return self.render_graphiql(
                    request,
                    graphiql_version=self.graphiql_version,
                    query=query or '',
                    # json.dumps(None) is the truthy string 'null', so test
                    # the value itself rather than the dumped text.
                    variables=json.dumps(variables) if variables else '',
                    operation_name=operation_name or '',
                    result=result or ''
                )

            return HttpResponse(
                status=status_code,
                content=result,
                content_type='application/json'
            )

        except HttpError as e:
            response = e.response
            response['Content-Type'] = 'application/json'
            response.content = self.json_encode(request, {
                'errors': [self.format_error(e)]
            })
            return response

    def get_response(self, request, data, show_graphiql=False):
        """Execute one operation described by ``data`` and serialise the result.

        :returns: ``(json_text_or_None, status_code)``. The text is ``None``
            when execution was skipped (``execute_graphql_request`` returned
            a falsy value). The status is 400 when the result is marked
            ``invalid`` and 200 otherwise. In batch mode the operation ``id``
            and per-entry status are included in the payload.
        """
        query, variables, operation_name, request_id = self.get_graphql_params(
            request, data)

        execution_result = self.execute_graphql_request(
            request,
            data,
            query,
            variables,
            operation_name,
            show_graphiql
        )

        status_code = 200
        if not execution_result:
            return None, status_code

        response = {}
        if execution_result.errors:
            response['errors'] = [self.format_error(e)
                                  for e in execution_result.errors]

        if execution_result.invalid:
            status_code = 400
        else:
            response['data'] = execution_result.data

        if self.batch:
            response['id'] = request_id
            response['status'] = status_code

        return self.json_encode(request, response, pretty=show_graphiql), status_code

    def render_graphiql(self, request, **data):
        """Render the GraphiQL template with ``data`` as the template context."""
        return render(request, self.graphiql_template, data)

    def json_encode(self, request, d, pretty=False):
        """Serialise ``d`` to JSON.

        Output is compact unless ``self.pretty``, the ``pretty`` argument, or a
        truthy ``?pretty=`` query parameter asks for indented, key-sorted
        output.
        """
        if not (self.pretty or pretty) and not request.GET.get('pretty'):
            return json.dumps(d, separators=(',', ':'))

        return json.dumps(d, sort_keys=True,
                          indent=2, separators=(',', ': '))

    def parse_body(self, request):
        """Extract the GraphQL payload from the request body.

        :returns: a dict for ``application/graphql`` and single JSON
            requests, a list for batched JSON requests, ``request.POST`` for
            form posts, and ``{}`` for anything else.
        :raises HttpError: 400 for undecodable or malformed bodies.
        """
        content_type = self.get_content_type(request)

        if content_type == 'application/graphql':
            try:
                return {'query': request.body.decode('utf-8')}
            except UnicodeDecodeError as e:
                raise HttpError(HttpResponseBadRequest(str(e)))

        elif content_type == 'application/json':
            try:
                body = request.body.decode('utf-8')
            except UnicodeDecodeError as e:
                raise HttpError(HttpResponseBadRequest(str(e)))

            try:
                request_json = json.loads(body)
            except (TypeError, ValueError):
                raise HttpError(HttpResponseBadRequest(
                    'POST body sent invalid JSON.'))

            # Explicit checks rather than ``assert`` so validation still runs
            # under ``python -O``.
            if self.batch:
                if not isinstance(request_json, list):
                    raise HttpError(HttpResponseBadRequest(
                        'Batch requests should receive a list, but received {}.'.format(
                            repr(request_json))))
                if not request_json:
                    raise HttpError(HttpResponseBadRequest(
                        'Received an empty list in the batch request.'))
                if not all(isinstance(entry, dict) for entry in request_json):
                    raise HttpError(HttpResponseBadRequest(
                        'Each batch entry should be a JSON object.'))
            elif not isinstance(request_json, dict):
                raise HttpError(HttpResponseBadRequest(
                    'The received data is not a valid JSON query.'))
            return request_json

        elif content_type in ['application/x-www-form-urlencoded', 'multipart/form-data']:
            return request.POST

        return {}

    def execute(self, *args, **kwargs):
        """Run graphql ``execute`` against ``self.schema`` (hook for subclasses)."""
        return execute(self.schema, *args, **kwargs)

    def execute_graphql_request(self, request, data, query, variables, operation_name, show_graphiql=False):
        """Parse, validate and execute ``query``.

        :returns: an ``ExecutionResult``, or ``None`` when GraphiQL is being
            shown and there is nothing executable (no query, or a non-query
            operation over GET). Parse, validation and execution exceptions
            are returned as an ``ExecutionResult`` with ``invalid=True``.
        :raises HttpError: 400 when no query is given, 405 when a non-query
            operation is sent over GET.
        """
        if not query:
            if show_graphiql:
                return None
            raise HttpError(HttpResponseBadRequest(
                'Must provide query string.'))

        source = Source(query, name='GraphQL request')

        try:
            document_ast = parse(source)
            validation_errors = validate(self.schema, document_ast)
            if validation_errors:
                return ExecutionResult(
                    errors=validation_errors,
                    invalid=True,
                )
        except Exception as e:
            return ExecutionResult(errors=[e], invalid=True)

        if request.method.lower() == 'get':
            # Only read-only ``query`` operations may run over GET.
            operation_ast = get_operation_ast(document_ast, operation_name)
            if operation_ast and operation_ast.operation != 'query':
                if show_graphiql:
                    return None

                raise HttpError(HttpResponseNotAllowed(
                    ['POST'], 'Can only perform a {} operation from a POST request.'.format(
                        operation_ast.operation)
                ))

        try:
            return self.execute(
                document_ast,
                root_value=self.get_root_value(request),
                variable_values=variables,
                operation_name=operation_name,
                context_value=self.get_context(request),
                middleware=self.get_middleware(request),
                executor=self.executor,
            )
        except Exception as e:
            return ExecutionResult(errors=[e], invalid=True)

    @classmethod
    def can_display_graphiql(cls, request, data):
        """Return True unless ``raw`` was requested or the client prefers non-HTML."""
        raw = 'raw' in request.GET or 'raw' in data
        return not raw and cls.request_wants_html(request)

    @classmethod
    def request_wants_html(cls, request):
        """Return True when ``text/html`` ranks above ``application/json``.

        Rank is the position in the preference-ordered ``Accept`` list. A
        type that is not listed ranks below every listed type.
        """
        accepted = get_accepted_content_types(request)

        def rank(media_type):
            # Smaller is more preferred; absent types sort last.
            try:
                return accepted.index(media_type)
            except ValueError:
                return len(accepted)

        return rank('text/html') < rank('application/json')

    @staticmethod
    def get_graphql_params(request, data):
        """Collect ``(query, variables, operation_name, id)`` from GET params or ``data``.

        Query-string values take precedence. A JSON-string ``variables`` value
        is decoded, and an ``operationName`` of ``"null"`` becomes ``None``.

        :raises HttpError: 400 when ``variables`` is not valid JSON.
        """
        query = request.GET.get('query') or data.get('query')
        variables = request.GET.get('variables') or data.get('variables')
        request_id = request.GET.get('id') or data.get('id')

        if variables and isinstance(variables, six.text_type):
            try:
                variables = json.loads(variables)
            except Exception:
                raise HttpError(HttpResponseBadRequest(
                    'Variables are invalid JSON.'))

        operation_name = request.GET.get(
            'operationName') or data.get('operationName')
        if operation_name == "null":
            operation_name = None

        return query, variables, operation_name, request_id

    @staticmethod
    def format_error(error):
        """Format an error for the JSON ``errors`` list.

        ``GraphQLError`` instances use graphql's formatter; any other
        exception becomes ``{'message': text}``.
        """
        if isinstance(error, GraphQLError):
            return format_graphql_error(error)

        return {'message': six.text_type(error)}

    @staticmethod
    def get_content_type(request):
        """Return the lower-cased media type of the request, without parameters."""
        meta = request.META
        content_type = meta.get(
            'CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', ''))
        return content_type.split(';', 1)[0].lower()
 |