From 0a18558bf6f92d5ddf38b8d8937d5194e1490579 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Ochman?= Date: Mon, 31 Oct 2016 11:56:51 +0100 Subject: [PATCH] Add support for batching several requests into one Batch format compatible with ReactRelayNetworkLayer (https://github.com/nodkz/react-relay-network-layer) --- graphene_django/tests/test_views.py | 78 +++++++++++++++++++++++++-- graphene_django/tests/urls.py | 1 + graphene_django/views.py | 81 +++++++++++++++++++---------- 3 files changed, 129 insertions(+), 31 deletions(-) diff --git a/graphene_django/tests/test_views.py b/graphene_django/tests/test_views.py index 1f9b5de..e7ec187 100644 --- a/graphene_django/tests/test_views.py +++ b/graphene_django/tests/test_views.py @@ -8,20 +8,23 @@ except ImportError: from urllib.parse import urlencode -def url_string(**url_params): - string = '/graphql' - +def url_string(string='/graphql', **url_params): if url_params: string += '?' + urlencode(url_params) return string +def batch_url_string(**url_params): + return url_string('/graphql/batch', **url_params) + + def response_json(response): return json.loads(response.content.decode()) j = lambda **kwargs: json.dumps(kwargs) +jl = lambda **kwargs: json.dumps([kwargs]) def test_graphiql_is_enabled(client): @@ -169,6 +172,17 @@ def test_allows_post_with_json_encoding(client): } +def test_batch_allows_post_with_json_encoding(client): + response = client.post(batch_url_string(), jl(id=1, query='{test}'), 'application/json') + + assert response.status_code == 200 + assert response_json(response) == [{ + 'id': 1, + 'payload': { 'data': {'test': "Hello World"} }, + 'status': 200, + }] + + def test_allows_sending_a_mutation_via_post(client): response = client.post(url_string(), j(query='mutation TestMutation { writeTest { test } }'), 'application/json') @@ -199,6 +213,22 @@ def test_supports_post_json_query_with_string_variables(client): } + +def test_batch_supports_post_json_query_with_string_variables(client): + response = client.post(batch_url_string(), jl( + id=1, + query='query helloWho($who: String){ test(who: $who) }', + variables=json.dumps({'who': "Dolly"}) + ), 'application/json') + + assert response.status_code == 200 + assert response_json(response) == [{ + 'id': 1, + 'payload': { 'data': {'test': "Hello Dolly"} }, + 'status': 200, + }] + + def test_supports_post_json_query_with_json_variables(client): response = client.post(url_string(), j( query='query helloWho($who: String){ test(who: $who) }', @@ -211,6 +241,21 @@ def test_supports_post_json_query_with_json_variables(client): } +def test_batch_supports_post_json_query_with_json_variables(client): + response = client.post(batch_url_string(), jl( + id=1, + query='query helloWho($who: String){ test(who: $who) }', + variables={'who': "Dolly"} + ), 'application/json') + + assert response.status_code == 200 + assert response_json(response) == [{ + 'id': 1, + 'payload': { 'data': {'test': "Hello Dolly"} }, + 'status': 200, + }] + + def test_supports_post_url_encoded_query_with_string_variables(client): response = client.post(url_string(), urlencode(dict( query='query helloWho($who: String){ test(who: $who) }', @@ -285,6 +330,33 @@ def test_allows_post_with_operation_name(client): } +def test_batch_allows_post_with_operation_name(client): + response = client.post(batch_url_string(), jl( + id=1, + query=''' + query helloYou { test(who: "You"), ...shared } + query helloWorld { test(who: "World"), ...shared } + query helloDolly { test(who: "Dolly"), ...shared } + fragment shared on QueryRoot { + shared: test(who: "Everyone") + } + ''', + operationName='helloWorld' + ), 'application/json') + + assert response.status_code == 200 + assert response_json(response) == [{ + 'id': 1, + 'payload': { + 'data': { + 'test': 'Hello World', + 'shared': 'Hello Everyone' + } + }, + 'status': 200, + }] + + def test_allows_post_with_get_operation_name(client): response = client.post(url_string( operationName='helloWorld' diff --git a/graphene_django/tests/urls.py b/graphene_django/tests/urls.py index ff4459e..8597baa 100644 --- a/graphene_django/tests/urls.py +++ b/graphene_django/tests/urls.py @@ -3,5 +3,6 @@ from django.conf.urls import url from ..views import GraphQLView urlpatterns = [ + url(r'^graphql/batch', GraphQLView.as_view(batch=True)), url(r'^graphql', GraphQLView.as_view(graphiql=True)), ] diff --git a/graphene_django/views.py b/graphene_django/views.py index cec3aab..b6344de 100644 --- a/graphene_django/views.py +++ b/graphene_django/views.py @@ -62,8 +62,10 @@ class GraphQLView(View): middleware = None root_value = None pretty = False + batch = False - def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False): + def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False, + batch=False): if not schema: schema = graphene_settings.SCHEMA @@ -77,8 +79,10 @@ class GraphQLView(View): self.root_value = root_value self.pretty = pretty self.graphiql = graphiql + self.batch = batch assert isinstance(self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.' + assert not all((graphiql, batch)), 'Use either graphiql or batch processing' # noinspection PyUnusedLocal def get_root_value(self, request): @@ -99,32 +103,12 @@ class GraphQLView(View): data = self.parse_body(request) show_graphiql = self.graphiql and self.can_display_graphiql(request, data) - query, variables, operation_name = self.get_graphql_params(request, data) - - execution_result = self.execute_graphql_request( - request, - data, - query, - variables, - operation_name, - show_graphiql - ) - - if execution_result: - response = {} - - if execution_result.errors: - response['errors'] = [self.format_error(e) for e in execution_result.errors] - - if execution_result.invalid: - status_code = 400 - else: - status_code = 200 - response['data'] = execution_result.data - - result = self.json_encode(request, response, pretty=show_graphiql) + if self.batch: + responses = [self.get_response(request, entry) for entry in data] + result = '[{}]'.format(','.join([response[0] for response in responses])) + status_code = max(responses, key=lambda response: response[1])[1] else: - result = None + result, status_code = self.get_response(request, data, show_graphiql) if show_graphiql: return self.render_graphiql( @@ -150,6 +134,43 @@ class GraphQLView(View): }) return response + def get_response(self, request, data, show_graphiql=False): + query, variables, operation_name, id = self.get_graphql_params(request, data) + + execution_result = self.execute_graphql_request( + request, + data, + query, + variables, + operation_name, + show_graphiql + ) + + if execution_result: + response = {} + + if execution_result.errors: + response['errors'] = [self.format_error(e) for e in execution_result.errors] + + if execution_result.invalid: + status_code = 400 + else: + status_code = 200 + response['data'] = execution_result.data + + if self.batch: + response = { + 'id': id, + 'payload': response, + 'status': status_code, + } + + result = self.json_encode(request, response, pretty=show_graphiql) + else: + result = None + + return result, status_code + def render_graphiql(self, request, **data): return render(request, self.graphiql_template, data) @@ -170,7 +191,10 @@ class GraphQLView(View): elif content_type == 'application/json': try: request_json = json.loads(request.body.decode('utf-8')) - assert isinstance(request_json, dict) + if self.batch: + assert isinstance(request_json, list) + else: + assert isinstance(request_json, dict) return request_json except: raise HttpError(HttpResponseBadRequest('POST body sent invalid JSON.')) @@ -242,6 +266,7 @@ class GraphQLView(View): def get_graphql_params(request, data): query = request.GET.get('query') or data.get('query') variables = request.GET.get('variables') or data.get('variables') + id = request.GET.get('id') or data.get('id') if variables and isinstance(variables, six.text_type): try: @@ -251,7 +276,7 @@ class GraphQLView(View): operation_name = request.GET.get('operationName') or data.get('operationName') - return query, variables, operation_name + return query, variables, operation_name, id @staticmethod def format_error(error):