mirror of
				https://github.com/graphql-python/graphene-django.git
				synced 2025-11-04 01:47:57 +03:00 
			
		
		
		
	Add support for batching several requests into one
Batch format compatible with ReactRelayNetworkLayer (https://github.com/nodkz/react-relay-network-layer)
This commit is contained in:
		
							parent
							
								
									d348ec89c5
								
							
						
					
					
						commit
						0a18558bf6
					
				| 
						 | 
					@ -8,20 +8,23 @@ except ImportError:
 | 
				
			||||||
    from urllib.parse import urlencode
 | 
					    from urllib.parse import urlencode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def url_string(**url_params):
 | 
					def url_string(string='/graphql', **url_params):
 | 
				
			||||||
    string = '/graphql'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if url_params:
 | 
					    if url_params:
 | 
				
			||||||
        string += '?' + urlencode(url_params)
 | 
					        string += '?' + urlencode(url_params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return string
 | 
					    return string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def batch_url_string(**url_params):
 | 
				
			||||||
 | 
					    return url_string('/graphql/batch', **url_params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def response_json(response):
 | 
					def response_json(response):
 | 
				
			||||||
    return json.loads(response.content.decode())
 | 
					    return json.loads(response.content.decode())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
j = lambda **kwargs: json.dumps(kwargs)
 | 
					j = lambda **kwargs: json.dumps(kwargs)
 | 
				
			||||||
 | 
					jl = lambda **kwargs: json.dumps([kwargs])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_graphiql_is_enabled(client):
 | 
					def test_graphiql_is_enabled(client):
 | 
				
			||||||
| 
						 | 
					@ -169,6 +172,17 @@ def test_allows_post_with_json_encoding(client):
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_batch_allows_post_with_json_encoding(client):
 | 
				
			||||||
 | 
					    response = client.post(batch_url_string(), jl(id=1, query='{test}'), 'application/json')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					    assert response_json(response) == [{
 | 
				
			||||||
 | 
					        'id': 1,
 | 
				
			||||||
 | 
					        'payload': { 'data': {'test': "Hello World"} },
 | 
				
			||||||
 | 
					        'status': 200,
 | 
				
			||||||
 | 
					    }]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_allows_sending_a_mutation_via_post(client):
 | 
					def test_allows_sending_a_mutation_via_post(client):
 | 
				
			||||||
    response = client.post(url_string(), j(query='mutation TestMutation { writeTest { test } }'), 'application/json')
 | 
					    response = client.post(url_string(), j(query='mutation TestMutation { writeTest { test } }'), 'application/json')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -199,6 +213,22 @@ def test_supports_post_json_query_with_string_variables(client):
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_batch_supports_post_json_query_with_string_variables(client):
 | 
				
			||||||
 | 
					    response = client.post(batch_url_string(), jl(
 | 
				
			||||||
 | 
					        id=1,
 | 
				
			||||||
 | 
					        query='query helloWho($who: String){ test(who: $who) }',
 | 
				
			||||||
 | 
					        variables=json.dumps({'who': "Dolly"})
 | 
				
			||||||
 | 
					    ), 'application/json')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					    assert response_json(response) == [{
 | 
				
			||||||
 | 
					        'id': 1,
 | 
				
			||||||
 | 
					        'payload': { 'data': {'test': "Hello Dolly"} },
 | 
				
			||||||
 | 
					        'status': 200,
 | 
				
			||||||
 | 
					    }]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_supports_post_json_query_with_json_variables(client):
 | 
					def test_supports_post_json_query_with_json_variables(client):
 | 
				
			||||||
    response = client.post(url_string(), j(
 | 
					    response = client.post(url_string(), j(
 | 
				
			||||||
        query='query helloWho($who: String){ test(who: $who) }',
 | 
					        query='query helloWho($who: String){ test(who: $who) }',
 | 
				
			||||||
| 
						 | 
					@ -211,6 +241,21 @@ def test_supports_post_json_query_with_json_variables(client):
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_batch_supports_post_json_query_with_json_variables(client):
 | 
				
			||||||
 | 
					    response = client.post(batch_url_string(), jl(
 | 
				
			||||||
 | 
					        id=1,
 | 
				
			||||||
 | 
					        query='query helloWho($who: String){ test(who: $who) }',
 | 
				
			||||||
 | 
					        variables={'who': "Dolly"}
 | 
				
			||||||
 | 
					    ), 'application/json')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					    assert response_json(response) == [{
 | 
				
			||||||
 | 
					        'id': 1,
 | 
				
			||||||
 | 
					        'payload': { 'data': {'test': "Hello Dolly"} },
 | 
				
			||||||
 | 
					        'status': 200,
 | 
				
			||||||
 | 
					    }]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_supports_post_url_encoded_query_with_string_variables(client):
 | 
					def test_supports_post_url_encoded_query_with_string_variables(client):
 | 
				
			||||||
    response = client.post(url_string(), urlencode(dict(
 | 
					    response = client.post(url_string(), urlencode(dict(
 | 
				
			||||||
        query='query helloWho($who: String){ test(who: $who) }',
 | 
					        query='query helloWho($who: String){ test(who: $who) }',
 | 
				
			||||||
| 
						 | 
					@ -285,6 +330,33 @@ def test_allows_post_with_operation_name(client):
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_batch_allows_post_with_operation_name(client):
 | 
				
			||||||
 | 
					    response = client.post(batch_url_string(), jl(
 | 
				
			||||||
 | 
					        id=1,
 | 
				
			||||||
 | 
					        query='''
 | 
				
			||||||
 | 
					        query helloYou { test(who: "You"), ...shared }
 | 
				
			||||||
 | 
					        query helloWorld { test(who: "World"), ...shared }
 | 
				
			||||||
 | 
					        query helloDolly { test(who: "Dolly"), ...shared }
 | 
				
			||||||
 | 
					        fragment shared on QueryRoot {
 | 
				
			||||||
 | 
					          shared: test(who: "Everyone")
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        ''',
 | 
				
			||||||
 | 
					        operationName='helloWorld'
 | 
				
			||||||
 | 
					    ), 'application/json')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					    assert response_json(response) == [{
 | 
				
			||||||
 | 
					        'id': 1,
 | 
				
			||||||
 | 
					        'payload': {
 | 
				
			||||||
 | 
					            'data': {
 | 
				
			||||||
 | 
					                'test': 'Hello World',
 | 
				
			||||||
 | 
					                'shared': 'Hello Everyone'
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        'status': 200,
 | 
				
			||||||
 | 
					    }]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_allows_post_with_get_operation_name(client):
 | 
					def test_allows_post_with_get_operation_name(client):
 | 
				
			||||||
    response = client.post(url_string(
 | 
					    response = client.post(url_string(
 | 
				
			||||||
        operationName='helloWorld'
 | 
					        operationName='helloWorld'
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,5 +3,6 @@ from django.conf.urls import url
 | 
				
			||||||
from ..views import GraphQLView
 | 
					from ..views import GraphQLView
 | 
				
			||||||
 | 
					
 | 
				
			||||||
urlpatterns = [
 | 
					urlpatterns = [
 | 
				
			||||||
 | 
					    url(r'^graphql/batch', GraphQLView.as_view(batch=True)),
 | 
				
			||||||
    url(r'^graphql', GraphQLView.as_view(graphiql=True)),
 | 
					    url(r'^graphql', GraphQLView.as_view(graphiql=True)),
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -62,8 +62,10 @@ class GraphQLView(View):
 | 
				
			||||||
    middleware = None
 | 
					    middleware = None
 | 
				
			||||||
    root_value = None
 | 
					    root_value = None
 | 
				
			||||||
    pretty = False
 | 
					    pretty = False
 | 
				
			||||||
 | 
					    batch = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False):
 | 
					    def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False,
 | 
				
			||||||
 | 
					                 batch=False):
 | 
				
			||||||
        if not schema:
 | 
					        if not schema:
 | 
				
			||||||
            schema = graphene_settings.SCHEMA
 | 
					            schema = graphene_settings.SCHEMA
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -77,8 +79,10 @@ class GraphQLView(View):
 | 
				
			||||||
        self.root_value = root_value
 | 
					        self.root_value = root_value
 | 
				
			||||||
        self.pretty = pretty
 | 
					        self.pretty = pretty
 | 
				
			||||||
        self.graphiql = graphiql
 | 
					        self.graphiql = graphiql
 | 
				
			||||||
 | 
					        self.batch = batch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        assert isinstance(self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.'
 | 
					        assert isinstance(self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.'
 | 
				
			||||||
 | 
					        assert not all((graphiql, batch)), 'Use either graphiql or batch processing'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # noinspection PyUnusedLocal
 | 
					    # noinspection PyUnusedLocal
 | 
				
			||||||
    def get_root_value(self, request):
 | 
					    def get_root_value(self, request):
 | 
				
			||||||
| 
						 | 
					@ -99,32 +103,12 @@ class GraphQLView(View):
 | 
				
			||||||
            data = self.parse_body(request)
 | 
					            data = self.parse_body(request)
 | 
				
			||||||
            show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
 | 
					            show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            query, variables, operation_name = self.get_graphql_params(request, data)
 | 
					            if self.batch:
 | 
				
			||||||
 | 
					                responses = [self.get_response(request, entry) for entry in data]
 | 
				
			||||||
            execution_result = self.execute_graphql_request(
 | 
					                result = '[{}]'.format(','.join([response[0] for response in responses]))
 | 
				
			||||||
                request,
 | 
					                status_code = max(responses, key=lambda response: response[1])[1]
 | 
				
			||||||
                data,
 | 
					 | 
				
			||||||
                query,
 | 
					 | 
				
			||||||
                variables,
 | 
					 | 
				
			||||||
                operation_name,
 | 
					 | 
				
			||||||
                show_graphiql
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if execution_result:
 | 
					 | 
				
			||||||
                response = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if execution_result.errors:
 | 
					 | 
				
			||||||
                    response['errors'] = [self.format_error(e) for e in execution_result.errors]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if execution_result.invalid:
 | 
					 | 
				
			||||||
                    status_code = 400
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    status_code = 200
 | 
					 | 
				
			||||||
                    response['data'] = execution_result.data
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                result = self.json_encode(request, response, pretty=show_graphiql)
 | 
					 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                result = None
 | 
					                result, status_code = self.get_response(request, data, show_graphiql)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if show_graphiql:
 | 
					            if show_graphiql:
 | 
				
			||||||
                return self.render_graphiql(
 | 
					                return self.render_graphiql(
 | 
				
			||||||
| 
						 | 
					@ -150,6 +134,43 @@ class GraphQLView(View):
 | 
				
			||||||
            })
 | 
					            })
 | 
				
			||||||
            return response
 | 
					            return response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_response(self, request, data, show_graphiql=False):
 | 
				
			||||||
 | 
					        query, variables, operation_name, id = self.get_graphql_params(request, data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        execution_result = self.execute_graphql_request(
 | 
				
			||||||
 | 
					            request,
 | 
				
			||||||
 | 
					            data,
 | 
				
			||||||
 | 
					            query,
 | 
				
			||||||
 | 
					            variables,
 | 
				
			||||||
 | 
					            operation_name,
 | 
				
			||||||
 | 
					            show_graphiql
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if execution_result:
 | 
				
			||||||
 | 
					            response = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if execution_result.errors:
 | 
				
			||||||
 | 
					                response['errors'] = [self.format_error(e) for e in execution_result.errors]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if execution_result.invalid:
 | 
				
			||||||
 | 
					                status_code = 400
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                status_code = 200
 | 
				
			||||||
 | 
					                response['data'] = execution_result.data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if self.batch:
 | 
				
			||||||
 | 
					                response = {
 | 
				
			||||||
 | 
					                    'id': id,
 | 
				
			||||||
 | 
					                    'payload': response,
 | 
				
			||||||
 | 
					                    'status': status_code,
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            result = self.json_encode(request, response, pretty=show_graphiql)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            result = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return result, status_code
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def render_graphiql(self, request, **data):
 | 
					    def render_graphiql(self, request, **data):
 | 
				
			||||||
        return render(request, self.graphiql_template, data)
 | 
					        return render(request, self.graphiql_template, data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -170,7 +191,10 @@ class GraphQLView(View):
 | 
				
			||||||
        elif content_type == 'application/json':
 | 
					        elif content_type == 'application/json':
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                request_json = json.loads(request.body.decode('utf-8'))
 | 
					                request_json = json.loads(request.body.decode('utf-8'))
 | 
				
			||||||
                assert isinstance(request_json, dict)
 | 
					                if self.batch:
 | 
				
			||||||
 | 
					                    assert isinstance(request_json, list)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    assert isinstance(request_json, dict)
 | 
				
			||||||
                return request_json
 | 
					                return request_json
 | 
				
			||||||
            except:
 | 
					            except:
 | 
				
			||||||
                raise HttpError(HttpResponseBadRequest('POST body sent invalid JSON.'))
 | 
					                raise HttpError(HttpResponseBadRequest('POST body sent invalid JSON.'))
 | 
				
			||||||
| 
						 | 
					@ -242,6 +266,7 @@ class GraphQLView(View):
 | 
				
			||||||
    def get_graphql_params(request, data):
 | 
					    def get_graphql_params(request, data):
 | 
				
			||||||
        query = request.GET.get('query') or data.get('query')
 | 
					        query = request.GET.get('query') or data.get('query')
 | 
				
			||||||
        variables = request.GET.get('variables') or data.get('variables')
 | 
					        variables = request.GET.get('variables') or data.get('variables')
 | 
				
			||||||
 | 
					        id = request.GET.get('id') or data.get('id')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if variables and isinstance(variables, six.text_type):
 | 
					        if variables and isinstance(variables, six.text_type):
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
| 
						 | 
					@ -251,7 +276,7 @@ class GraphQLView(View):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        operation_name = request.GET.get('operationName') or data.get('operationName')
 | 
					        operation_name = request.GET.get('operationName') or data.get('operationName')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return query, variables, operation_name
 | 
					        return query, variables, operation_name, id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def format_error(error):
 | 
					    def format_error(error):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user