mirror of
https://github.com/graphql-python/graphene-django.git
synced 2025-07-09 15:52:26 +03:00
Merge c2d601c0e2
into ea45de02ad
This commit is contained in:
commit
a8e6217150
|
@ -1,4 +1,8 @@
|
|||
from graphene import Node
|
||||
import asyncio
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
from graphene import Field, Node, String
|
||||
from graphene_django.filter import DjangoFilterConnectionField
|
||||
from graphene_django.types import DjangoObjectType
|
||||
|
||||
|
@ -6,12 +10,32 @@ from cookbook.recipes.models import Recipe, RecipeIngredient
|
|||
|
||||
|
||||
class RecipeNode(DjangoObjectType):
|
||||
async_field = String()
|
||||
|
||||
class Meta:
|
||||
model = Recipe
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
filter_fields = ["title", "amounts"]
|
||||
|
||||
async def resolve_async_field(self, info):
|
||||
await asyncio.sleep(2)
|
||||
return "success"
|
||||
|
||||
|
||||
class RecipeType(DjangoObjectType):
|
||||
async_field = String()
|
||||
|
||||
class Meta:
|
||||
model = Recipe
|
||||
fields = "__all__"
|
||||
filter_fields = ["title", "amounts"]
|
||||
skip_registry = True
|
||||
|
||||
async def resolve_async_field(self, info):
|
||||
await asyncio.sleep(2)
|
||||
return "success"
|
||||
|
||||
|
||||
class RecipeIngredientNode(DjangoObjectType):
|
||||
class Meta:
|
||||
|
@ -28,7 +52,13 @@ class RecipeIngredientNode(DjangoObjectType):
|
|||
|
||||
class Query:
|
||||
recipe = Node.Field(RecipeNode)
|
||||
raw_recipe = Field(RecipeType)
|
||||
all_recipes = DjangoFilterConnectionField(RecipeNode)
|
||||
|
||||
recipeingredient = Node.Field(RecipeIngredientNode)
|
||||
all_recipeingredients = DjangoFilterConnectionField(RecipeIngredientNode)
|
||||
|
||||
@staticmethod
|
||||
@sync_to_async
|
||||
def resolve_raw_recipe(self, info):
|
||||
return Recipe.objects.first()
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
from django.conf.urls import url
|
||||
from django.contrib import admin
|
||||
from django.urls import re_path
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
|
||||
from graphene_django.views import GraphQLView
|
||||
from graphene_django.views import AsyncGraphQLView
|
||||
|
||||
urlpatterns = [
|
||||
url(r"^admin/", admin.site.urls),
|
||||
url(r"^graphql$", GraphQLView.as_view(graphiql=True)),
|
||||
re_path(r"^admin/", admin.site.urls),
|
||||
re_path(r"^graphql$", csrf_exempt(AsyncGraphQLView.as_view(graphiql=True))),
|
||||
]
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
from asgiref.sync import sync_to_async
|
||||
from django.db import connections
|
||||
from graphql.type.definition import GraphQLNonNull
|
||||
|
||||
from ..utils import is_running_async, is_sync_function
|
||||
from .exception.formating import wrap_exception
|
||||
from .sql.tracking import unwrap_cursor, wrap_cursor
|
||||
from .types import DjangoDebug
|
||||
|
@ -67,3 +70,28 @@ class DjangoDebugMiddleware:
|
|||
return context.django_debug.on_resolve_error(e)
|
||||
context.django_debug.add_result(result)
|
||||
return result
|
||||
|
||||
|
||||
class DjangoSyncRequiredMiddleware:
|
||||
def resolve(self, next, root, info, **args):
|
||||
parent_type = info.parent_type
|
||||
return_type = info.return_type
|
||||
|
||||
if isinstance(parent_type, GraphQLNonNull):
|
||||
parent_type = parent_type.of_type
|
||||
if isinstance(return_type, GraphQLNonNull):
|
||||
return_type = return_type.of_type
|
||||
|
||||
if any(
|
||||
[
|
||||
hasattr(parent_type, "graphene_type")
|
||||
and hasattr(parent_type.graphene_type._meta, "model"),
|
||||
hasattr(return_type, "graphene_type")
|
||||
and hasattr(return_type.graphene_type._meta, "model"),
|
||||
info.parent_type.name == "Mutation",
|
||||
]
|
||||
):
|
||||
if is_sync_function(next) and is_running_async():
|
||||
return sync_to_async(next)(root, info, **args)
|
||||
|
||||
return next(root, info, **args)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from functools import partial
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.db.models.query import QuerySet
|
||||
from graphql_relay import (
|
||||
connection_from_array_slice,
|
||||
|
@ -7,7 +8,6 @@ from graphql_relay import (
|
|||
get_offset_with_default,
|
||||
offset_to_cursor,
|
||||
)
|
||||
from promise import Promise
|
||||
|
||||
from graphene import Int, NonNull
|
||||
from graphene.relay import ConnectionField
|
||||
|
@ -15,7 +15,7 @@ from graphene.relay.connection import connection_adapter, page_info_adapter
|
|||
from graphene.types import Field, List
|
||||
|
||||
from .settings import graphene_settings
|
||||
from .utils import maybe_queryset
|
||||
from .utils import is_running_async, is_sync_function, maybe_queryset
|
||||
|
||||
|
||||
class DjangoListField(Field):
|
||||
|
@ -49,11 +49,36 @@ class DjangoListField(Field):
|
|||
def get_manager(self):
|
||||
return self.model._default_manager
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def list_resolver(
|
||||
django_object_type, resolver, default_manager, root, info, **args
|
||||
cls, django_object_type, resolver, default_manager, root, info, **args
|
||||
):
|
||||
queryset = maybe_queryset(resolver(root, info, **args))
|
||||
if is_running_async():
|
||||
if is_sync_function(resolver):
|
||||
resolver = sync_to_async(resolver)
|
||||
|
||||
iterable = resolver(root, info, **args)
|
||||
|
||||
if info.is_awaitable(iterable):
|
||||
|
||||
async def resolve_list_async(iterable):
|
||||
queryset = maybe_queryset(await iterable)
|
||||
if queryset is None:
|
||||
queryset = maybe_queryset(default_manager)
|
||||
|
||||
if isinstance(queryset, QuerySet):
|
||||
# Pass queryset to the DjangoObjectType get_queryset method
|
||||
queryset = maybe_queryset(
|
||||
await sync_to_async(django_object_type.get_queryset)(
|
||||
queryset, info
|
||||
)
|
||||
)
|
||||
|
||||
return await sync_to_async(list)(queryset)
|
||||
|
||||
return resolve_list_async(iterable)
|
||||
|
||||
queryset = maybe_queryset(iterable)
|
||||
if queryset is None:
|
||||
queryset = maybe_queryset(default_manager)
|
||||
|
||||
|
@ -61,7 +86,7 @@ class DjangoListField(Field):
|
|||
# Pass queryset to the DjangoObjectType get_queryset method
|
||||
queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))
|
||||
|
||||
return queryset
|
||||
return list(queryset)
|
||||
|
||||
def wrap_resolve(self, parent_resolver):
|
||||
resolver = super().wrap_resolve(parent_resolver)
|
||||
|
@ -235,20 +260,36 @@ class DjangoConnectionField(ConnectionField):
|
|||
|
||||
# eventually leads to DjangoObjectType's get_queryset (accepts queryset)
|
||||
# or a resolve_foo (does not accept queryset)
|
||||
|
||||
if is_running_async():
|
||||
if is_sync_function(resolver):
|
||||
resolver = sync_to_async(resolver)
|
||||
|
||||
iterable = resolver(root, info, **args)
|
||||
|
||||
if info.is_awaitable(iterable):
|
||||
|
||||
async def resolve_connection_async(iterable):
|
||||
iterable = await iterable
|
||||
if iterable is None:
|
||||
iterable = default_manager
|
||||
|
||||
iterable = await sync_to_async(queryset_resolver)(
|
||||
connection, iterable, info, args
|
||||
)
|
||||
|
||||
return await sync_to_async(cls.resolve_connection)(
|
||||
connection, args, iterable, max_limit=max_limit
|
||||
)
|
||||
|
||||
return resolve_connection_async(iterable)
|
||||
|
||||
if iterable is None:
|
||||
iterable = default_manager
|
||||
# thus the iterable gets refiltered by resolve_queryset
|
||||
# but iterable might be promise
|
||||
iterable = queryset_resolver(connection, iterable, info, args)
|
||||
on_resolve = partial(
|
||||
cls.resolve_connection, connection, args, max_limit=max_limit
|
||||
)
|
||||
|
||||
if Promise.is_thenable(iterable):
|
||||
return Promise.resolve(iterable).then(on_resolve)
|
||||
|
||||
return on_resolve(iterable)
|
||||
return cls.resolve_connection(connection, args, iterable, max_limit=max_limit)
|
||||
|
||||
def wrap_resolve(self, parent_resolver):
|
||||
return partial(
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.core.exceptions import ValidationError
|
||||
|
||||
from graphene.types.argument import to_arguments
|
||||
|
@ -92,6 +93,18 @@ class DjangoFilterConnectionField(DjangoConnectionField):
|
|||
|
||||
qs = super().resolve_queryset(connection, iterable, info, args)
|
||||
|
||||
if info.is_awaitable(qs):
|
||||
|
||||
async def filter_async(qs):
|
||||
filterset = filterset_class(
|
||||
data=filter_kwargs(), queryset=await qs, request=info.context
|
||||
)
|
||||
if await sync_to_async(filterset.is_valid)():
|
||||
return filterset.qs
|
||||
raise ValidationError(filterset.form.errors.as_json())
|
||||
|
||||
return filter_async(qs)
|
||||
|
||||
filterset = filterset_class(
|
||||
data=filter_kwargs(), queryset=qs, request=info.context
|
||||
)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from collections import OrderedDict
|
||||
from enum import Enum
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.shortcuts import get_object_or_404
|
||||
from rest_framework import serializers
|
||||
|
||||
|
@ -11,6 +12,7 @@ from graphene.types.mutation import MutationOptions
|
|||
from graphene.types.objecttype import yank_fields_from_attrs
|
||||
|
||||
from ..types import ErrorType
|
||||
from ..utils import is_running_async
|
||||
from .serializer_converter import convert_serializer_field
|
||||
|
||||
|
||||
|
@ -166,6 +168,17 @@ class SerializerMutation(ClientIDMutation):
|
|||
kwargs = cls.get_serializer_kwargs(root, info, **input)
|
||||
serializer = cls._meta.serializer_class(**kwargs)
|
||||
|
||||
if is_running_async():
|
||||
|
||||
async def perform_mutate_async():
|
||||
if await sync_to_async(serializer.is_valid)():
|
||||
return await sync_to_async(cls.perform_mutate)(serializer, info)
|
||||
else:
|
||||
errors = ErrorType.from_errors(serializer.errors)
|
||||
return cls(errors=errors)
|
||||
|
||||
return perform_mutate_async()
|
||||
|
||||
if serializer.is_valid():
|
||||
return cls.perform_mutate(serializer, info)
|
||||
else:
|
||||
|
|
6
graphene_django/tests/async_test_helper.py
Normal file
6
graphene_django/tests/async_test_helper.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from asgiref.sync import async_to_sync
|
||||
|
||||
|
||||
def assert_async_result_equal(schema, query, result, **kwargs):
|
||||
async_result = async_to_sync(schema.execute_async)(query, **kwargs)
|
||||
assert async_result == result
|
|
@ -2,6 +2,7 @@ import datetime
|
|||
import re
|
||||
|
||||
import pytest
|
||||
from asgiref.sync import async_to_sync
|
||||
from django.db.models import Count, Model, Prefetch
|
||||
|
||||
from graphene import List, NonNull, ObjectType, Schema, String
|
||||
|
@ -9,6 +10,7 @@ from graphene.relay import Node
|
|||
|
||||
from ..fields import DjangoConnectionField, DjangoListField
|
||||
from ..types import DjangoObjectType
|
||||
from .async_test_helper import assert_async_result_equal
|
||||
from .models import (
|
||||
Article as ArticleModel,
|
||||
Film as FilmModel,
|
||||
|
@ -82,6 +84,7 @@ class TestDjangoListField:
|
|||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert_async_result_equal(schema, query, result)
|
||||
assert not result.errors
|
||||
assert result.data == {
|
||||
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
||||
|
@ -109,6 +112,7 @@ class TestDjangoListField:
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": []}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
@ -119,6 +123,7 @@ class TestDjangoListField:
|
|||
assert result.data == {
|
||||
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
||||
}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
def test_override_resolver(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
|
@ -146,6 +151,35 @@ class TestDjangoListField:
|
|||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||
|
||||
def test_override_resolver_async_execution(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name",)
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter)
|
||||
|
||||
def resolve_reporters(_, info):
|
||||
return ReporterModel.objects.filter(first_name="Tara")
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
result = async_to_sync(schema.execute_async)(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||
|
@ -210,6 +244,7 @@ class TestDjangoListField:
|
|||
{"firstName": "Debra", "articles": []},
|
||||
]
|
||||
}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
def test_override_resolver_nested_list_field(self):
|
||||
class Article(DjangoObjectType):
|
||||
|
@ -268,6 +303,7 @@ class TestDjangoListField:
|
|||
{"firstName": "Debra", "articles": []},
|
||||
]
|
||||
}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
def test_same_type_nested_list_field(self):
|
||||
class Person(DjangoObjectType):
|
||||
|
@ -376,6 +412,7 @@ class TestDjangoListField:
|
|||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
def test_resolve_list(self):
|
||||
"""Resolving a plain list should work (and not call get_queryset)"""
|
||||
|
@ -424,6 +461,53 @@ class TestDjangoListField:
|
|||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||
|
||||
def test_resolve_list_async(self):
|
||||
"""Resolving a plain list should work (and not call get_queryset) when running under async"""
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name", "articles")
|
||||
|
||||
@classmethod
|
||||
def get_queryset(cls, queryset, info):
|
||||
# Only get reporters with at least 1 article
|
||||
return queryset.annotate(article_count=Count("articles")).filter(
|
||||
article_count__gt=0
|
||||
)
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter)
|
||||
|
||||
def resolve_reporters(_, info):
|
||||
return [ReporterModel.objects.get(first_name="Debra")]
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
ArticleModel.objects.create(
|
||||
headline="Amazing news",
|
||||
reporter=r1,
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
editor=r1,
|
||||
)
|
||||
|
||||
result = async_to_sync(schema.execute_async)(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||
|
||||
def test_get_queryset_foreign_key(self):
|
||||
class Article(DjangoObjectType):
|
||||
class Meta:
|
||||
|
@ -483,6 +567,7 @@ class TestDjangoListField:
|
|||
{"firstName": "Debra", "articles": []},
|
||||
]
|
||||
}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
def test_resolve_list_external_resolver(self):
|
||||
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
|
||||
|
@ -531,6 +616,53 @@ class TestDjangoListField:
|
|||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||
|
||||
def test_resolve_list_external_resolver_async(self):
|
||||
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name", "articles")
|
||||
|
||||
@classmethod
|
||||
def get_queryset(cls, queryset, info):
|
||||
# Only get reporters with at least 1 article
|
||||
return queryset.annotate(article_count=Count("articles")).filter(
|
||||
article_count__gt=0
|
||||
)
|
||||
|
||||
def resolve_reporters(_, info):
|
||||
return [ReporterModel.objects.get(first_name="Debra")]
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
ArticleModel.objects.create(
|
||||
headline="Amazing news",
|
||||
reporter=r1,
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
editor=r1,
|
||||
)
|
||||
|
||||
result = async_to_sync(schema.execute_async)(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||
|
||||
def test_get_queryset_filter_external_resolver(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
|
@ -575,6 +707,7 @@ class TestDjangoListField:
|
|||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
def test_select_related_and_prefetch_related_are_respected(
|
||||
self, django_assert_num_queries
|
||||
|
@ -717,6 +850,7 @@ class TestDjangoListField:
|
|||
r'SELECT .* FROM "tests_film" INNER JOIN "tests_film_reporters" .* LEFT OUTER JOIN "tests_filmdetails"',
|
||||
captured.captured_queries[1]["sql"],
|
||||
)
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
class TestDjangoConnectionField:
|
||||
|
|
|
@ -2,6 +2,7 @@ import base64
|
|||
import datetime
|
||||
|
||||
import pytest
|
||||
from asgiref.sync import async_to_sync
|
||||
from django.db import models
|
||||
from django.db.models import Q
|
||||
from django.utils.functional import SimpleLazyObject
|
||||
|
@ -15,6 +16,7 @@ from ..compat import IntegerRangeField, MissingType
|
|||
from ..fields import DjangoConnectionField
|
||||
from ..types import DjangoObjectType
|
||||
from ..utils import DJANGO_FILTER_INSTALLED
|
||||
from .async_test_helper import assert_async_result_equal
|
||||
from .models import (
|
||||
APNewsReporter,
|
||||
Article,
|
||||
|
@ -43,6 +45,7 @@ def test_should_query_only_fields():
|
|||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_query_simplelazy_objects():
|
||||
|
@ -68,6 +71,7 @@ def test_should_query_simplelazy_objects():
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == {"reporter": {"id": "1"}}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_query_wrapped_simplelazy_objects():
|
||||
|
@ -93,6 +97,7 @@ def test_should_query_wrapped_simplelazy_objects():
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == {"reporter": {"id": "1"}}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_query_well():
|
||||
|
@ -121,6 +126,7 @@ def test_should_query_well():
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(IntegerRangeField is MissingType, reason="RangeField should exist")
|
||||
|
@ -175,6 +181,7 @@ def test_should_query_postgres_fields():
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_node():
|
||||
|
@ -256,6 +263,7 @@ def test_should_node():
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_query_onetoone_fields():
|
||||
|
@ -314,6 +322,7 @@ def test_should_query_onetoone_fields():
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_query_connectionfields():
|
||||
|
@ -352,6 +361,7 @@ def test_should_query_connectionfields():
|
|||
"edges": [{"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}],
|
||||
}
|
||||
}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_keep_annotations():
|
||||
|
@ -411,6 +421,7 @@ def test_should_keep_annotations():
|
|||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
@ -492,6 +503,7 @@ def test_should_query_node_filtering():
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
@ -537,6 +549,7 @@ def test_should_query_node_filtering_with_distinct_queryset():
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
@ -626,6 +639,7 @@ def test_should_query_node_multiple_filtering():
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_enforce_first_or_last(graphene_settings):
|
||||
|
@ -666,6 +680,7 @@ def test_should_enforce_first_or_last(graphene_settings):
|
|||
"paginate the `allReporters` connection.\n"
|
||||
)
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_error_if_first_is_greater_than_max(graphene_settings):
|
||||
|
@ -708,6 +723,7 @@ def test_should_error_if_first_is_greater_than_max(graphene_settings):
|
|||
"exceeds the `first` limit of 100 records.\n"
|
||||
)
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_error_if_last_is_greater_than_max(graphene_settings):
|
||||
|
@ -750,6 +766,7 @@ def test_should_error_if_last_is_greater_than_max(graphene_settings):
|
|||
"exceeds the `last` limit of 100 records.\n"
|
||||
)
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_query_promise_connectionfields():
|
||||
|
@ -785,6 +802,7 @@ def test_should_query_promise_connectionfields():
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_query_connectionfields_with_last():
|
||||
|
@ -822,6 +840,7 @@ def test_should_query_connectionfields_with_last():
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_query_connectionfields_with_manager():
|
||||
|
@ -863,6 +882,7 @@ def test_should_query_connectionfields_with_manager():
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_query_dataloader_fields():
|
||||
|
@ -965,6 +985,106 @@ def test_should_query_dataloader_fields():
|
|||
assert result.data == expected
|
||||
|
||||
|
||||
def test_should_query_dataloader_fields_async():
|
||||
from promise import Promise
|
||||
from promise.dataloader import DataLoader
|
||||
|
||||
def article_batch_load_fn(keys):
|
||||
queryset = Article.objects.filter(reporter_id__in=keys)
|
||||
return Promise.resolve(
|
||||
[
|
||||
[article for article in queryset if article.reporter_id == id]
|
||||
for id in keys
|
||||
]
|
||||
)
|
||||
|
||||
article_loader = DataLoader(article_batch_load_fn)
|
||||
|
||||
class ArticleType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Article
|
||||
interfaces = (Node,)
|
||||
fields = "__all__"
|
||||
|
||||
class ReporterType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node,)
|
||||
use_connection = True
|
||||
fields = "__all__"
|
||||
|
||||
articles = DjangoConnectionField(ArticleType)
|
||||
|
||||
def resolve_articles(self, info, **args):
|
||||
return article_loader.load(self.id).get()
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
all_reporters = DjangoConnectionField(ReporterType)
|
||||
|
||||
r = Reporter.objects.create(
|
||||
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
|
||||
)
|
||||
|
||||
Article.objects.create(
|
||||
headline="Article Node 1",
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
reporter=r,
|
||||
editor=r,
|
||||
lang="es",
|
||||
)
|
||||
Article.objects.create(
|
||||
headline="Article Node 2",
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
reporter=r,
|
||||
editor=r,
|
||||
lang="en",
|
||||
)
|
||||
|
||||
schema = graphene.Schema(query=Query)
|
||||
query = """
|
||||
query ReporterPromiseConnectionQuery {
|
||||
allReporters(first: 1) {
|
||||
edges {
|
||||
node {
|
||||
id
|
||||
articles(first: 2) {
|
||||
edges {
|
||||
node {
|
||||
headline
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
expected = {
|
||||
"allReporters": {
|
||||
"edges": [
|
||||
{
|
||||
"node": {
|
||||
"id": "UmVwb3J0ZXJUeXBlOjE=",
|
||||
"articles": {
|
||||
"edges": [
|
||||
{"node": {"headline": "Article Node 1"}},
|
||||
{"node": {"headline": "Article Node 2"}},
|
||||
]
|
||||
},
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
result = async_to_sync(schema.execute_async)(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
|
||||
|
||||
def test_should_handle_inherited_choices():
|
||||
class BaseModel(models.Model):
|
||||
choice_field = models.IntegerField(choices=((0, "zero"), (1, "one")))
|
||||
|
@ -1071,6 +1191,7 @@ def test_proxy_model_support():
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_model_inheritance_support_reverse_relationships():
|
||||
|
@ -1411,6 +1532,7 @@ def test_should_resolve_get_queryset_connectionfields():
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_connection_should_limit_after_to_list_length():
|
||||
|
@ -1448,6 +1570,7 @@ def test_connection_should_limit_after_to_list_length():
|
|||
expected = {"allReporters": {"edges": []}}
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result, variable_values={"after": after})
|
||||
|
||||
|
||||
REPORTERS = [
|
||||
|
@ -1491,6 +1614,7 @@ def test_should_return_max_limit(graphene_settings):
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert len(result.data["allReporters"]["edges"]) == 4
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_have_next_page(graphene_settings):
|
||||
|
@ -1529,6 +1653,7 @@ def test_should_have_next_page(graphene_settings):
|
|||
assert not result.errors
|
||||
assert len(result.data["allReporters"]["edges"]) == 4
|
||||
assert result.data["allReporters"]["pageInfo"]["hasNextPage"]
|
||||
assert_async_result_equal(schema, query, result, variable_values={})
|
||||
|
||||
last_result = result.data["allReporters"]["pageInfo"]["endCursor"]
|
||||
result2 = schema.execute(query, variable_values={"first": 4, "after": last_result})
|
||||
|
@ -1542,6 +1667,9 @@ def test_should_have_next_page(graphene_settings):
|
|||
assert {to_global_id("ReporterType", reporter.id) for reporter in db_reporters} == {
|
||||
gql_reporter["node"]["id"] for gql_reporter in gql_reporters
|
||||
}
|
||||
assert_async_result_equal(
|
||||
schema, query, result2, variable_values={"first": 4, "after": last_result}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_limit", [100, 4])
|
||||
|
@ -1565,7 +1693,7 @@ class TestBackwardPagination:
|
|||
|
||||
def test_query_last(self, graphene_settings, max_limit):
|
||||
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
|
||||
query_last = """
|
||||
query = """
|
||||
query {
|
||||
allReporters(last: 3) {
|
||||
edges {
|
||||
|
@ -1577,16 +1705,17 @@ class TestBackwardPagination:
|
|||
}
|
||||
"""
|
||||
|
||||
result = schema.execute(query_last)
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert len(result.data["allReporters"]["edges"]) == 3
|
||||
assert [
|
||||
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
|
||||
] == ["First 3", "First 4", "First 5"]
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
def test_query_first_and_last(self, graphene_settings, max_limit):
|
||||
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
|
||||
query_first_and_last = """
|
||||
query = """
|
||||
query {
|
||||
allReporters(first: 4, last: 3) {
|
||||
edges {
|
||||
|
@ -1598,12 +1727,13 @@ class TestBackwardPagination:
|
|||
}
|
||||
"""
|
||||
|
||||
result = schema.execute(query_first_and_last)
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert len(result.data["allReporters"]["edges"]) == 3
|
||||
assert [
|
||||
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
|
||||
] == ["First 1", "First 2", "First 3"]
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
def test_query_first_last_and_after(self, graphene_settings, max_limit):
|
||||
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
|
||||
|
@ -1629,6 +1759,9 @@ class TestBackwardPagination:
|
|||
assert [
|
||||
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
|
||||
] == ["First 2", "First 3", "First 4"]
|
||||
assert_async_result_equal(
|
||||
schema, query_first_last_and_after, result, variable_values={"after": after}
|
||||
)
|
||||
|
||||
def test_query_last_and_before(self, graphene_settings, max_limit):
|
||||
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
|
||||
|
@ -1650,6 +1783,7 @@ class TestBackwardPagination:
|
|||
assert not result.errors
|
||||
assert len(result.data["allReporters"]["edges"]) == 1
|
||||
assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 5"
|
||||
assert_async_result_equal(schema, query_first_last_and_after, result)
|
||||
|
||||
before = base64.b64encode(b"arrayconnection:5").decode()
|
||||
result = schema.execute(
|
||||
|
@ -1659,6 +1793,12 @@ class TestBackwardPagination:
|
|||
assert not result.errors
|
||||
assert len(result.data["allReporters"]["edges"]) == 1
|
||||
assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 4"
|
||||
assert_async_result_equal(
|
||||
schema,
|
||||
query_first_last_and_after,
|
||||
result,
|
||||
variable_values={"before": before},
|
||||
)
|
||||
|
||||
|
||||
def test_should_preserve_prefetch_related(django_assert_num_queries):
|
||||
|
@ -1713,6 +1853,7 @@ def test_should_preserve_prefetch_related(django_assert_num_queries):
|
|||
with django_assert_num_queries(3):
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_should_preserve_annotations():
|
||||
|
@ -1768,6 +1909,7 @@ def test_should_preserve_annotations():
|
|||
}
|
||||
assert result.data == expected, str(result.data)
|
||||
assert not result.errors
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_connection_should_enable_offset_filtering():
|
||||
|
@ -1807,6 +1949,7 @@ def test_connection_should_enable_offset_filtering():
|
|||
}
|
||||
}
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_connection_should_enable_offset_filtering_higher_than_max_limit(
|
||||
|
@ -1851,6 +1994,7 @@ def test_connection_should_enable_offset_filtering_higher_than_max_limit(
|
|||
}
|
||||
}
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
|
||||
def test_connection_should_forbid_offset_filtering_with_before():
|
||||
|
@ -1881,6 +2025,7 @@ def test_connection_should_forbid_offset_filtering_with_before():
|
|||
expected_error = "You can't provide a `before` value at the same time as an `offset` value to properly paginate the `allReporters` connection."
|
||||
assert len(result.errors) == 1
|
||||
assert result.errors[0].message == expected_error
|
||||
assert_async_result_equal(schema, query, result, variable_values={"before": before})
|
||||
|
||||
|
||||
def test_connection_should_allow_offset_filtering_with_after():
|
||||
|
@ -1923,6 +2068,7 @@ def test_connection_should_allow_offset_filtering_with_after():
|
|||
}
|
||||
}
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result, variable_values={"after": after})
|
||||
|
||||
|
||||
def test_connection_should_succeed_if_last_higher_than_number_of_objects():
|
||||
|
@ -1953,6 +2099,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
|
|||
assert not result.errors
|
||||
expected = {"allReporters": {"edges": []}}
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result, variable_values={"last": 2})
|
||||
|
||||
Reporter.objects.create(first_name="John", last_name="Doe")
|
||||
Reporter.objects.create(first_name="Some", last_name="Guy")
|
||||
|
@ -1970,6 +2117,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
|
|||
}
|
||||
}
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result, variable_values={"last": 2})
|
||||
|
||||
result = schema.execute(query, variable_values={"last": 4})
|
||||
assert not result.errors
|
||||
|
@ -1984,6 +2132,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
|
|||
}
|
||||
}
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result, variable_values={"last": 4})
|
||||
|
||||
result = schema.execute(query, variable_values={"last": 20})
|
||||
assert not result.errors
|
||||
|
@ -1998,6 +2147,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
|
|||
}
|
||||
}
|
||||
assert result.data == expected
|
||||
assert_async_result_equal(schema, query, result, variable_values={"last": 20})
|
||||
|
||||
|
||||
def test_should_query_nullable_foreign_key():
|
||||
|
|
|
@ -16,6 +16,7 @@ from .utils import (
|
|||
DJANGO_FILTER_INSTALLED,
|
||||
camelize,
|
||||
get_model_fields,
|
||||
is_running_async,
|
||||
is_valid_django_model,
|
||||
)
|
||||
|
||||
|
@ -288,7 +289,11 @@ class DjangoObjectType(ObjectType):
|
|||
def get_node(cls, info, id):
|
||||
queryset = cls.get_queryset(cls._meta.model.objects, info)
|
||||
try:
|
||||
if is_running_async():
|
||||
return queryset.aget(pk=id)
|
||||
|
||||
return queryset.get(pk=id)
|
||||
|
||||
except cls._meta.model.DoesNotExist:
|
||||
return None
|
||||
|
||||
|
|
|
@ -5,6 +5,8 @@ from .utils import (
|
|||
camelize,
|
||||
get_model_fields,
|
||||
get_reverse_fields,
|
||||
is_running_async,
|
||||
is_sync_function,
|
||||
is_valid_django_model,
|
||||
maybe_queryset,
|
||||
)
|
||||
|
@ -17,5 +19,7 @@ __all__ = [
|
|||
"camelize",
|
||||
"is_valid_django_model",
|
||||
"GraphQLTestCase",
|
||||
"is_sync_function",
|
||||
"is_running_async",
|
||||
"bypass_get_queryset",
|
||||
]
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import inspect
|
||||
from asyncio import get_running_loop
|
||||
|
||||
import django
|
||||
from django.db import connection, models, transaction
|
||||
|
@ -139,6 +140,21 @@ def set_rollback():
|
|||
transaction.set_rollback(True)
|
||||
|
||||
|
||||
def is_running_async():
|
||||
try:
|
||||
get_running_loop()
|
||||
except RuntimeError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def is_sync_function(func):
|
||||
return not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
|
||||
func
|
||||
)
|
||||
|
||||
|
||||
def bypass_get_queryset(resolver):
|
||||
"""
|
||||
Adds a bypass_get_queryset attribute to the resolver, which is used to
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
import inspect
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
from asyncio import coroutines, gather
|
||||
|
||||
from django.db import connection, transaction
|
||||
from django.http import HttpResponse, HttpResponseNotAllowed
|
||||
from django.http.response import HttpResponseBadRequest
|
||||
from django.shortcuts import render
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.utils.decorators import classonlymethod, method_decorator
|
||||
from django.views.decorators.csrf import ensure_csrf_cookie
|
||||
from django.views.generic import View
|
||||
from graphql import (
|
||||
|
@ -431,3 +433,336 @@ class GraphQLView(View):
|
|||
meta = request.META
|
||||
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
|
||||
return content_type.split(";", 1)[0].lower()
|
||||
|
||||
|
||||
class AsyncGraphQLView(GraphQLView):
|
||||
schema = None
|
||||
graphiql = False
|
||||
middleware = None
|
||||
root_value = None
|
||||
pretty = False
|
||||
batch = False
|
||||
subscription_path = None
|
||||
execution_context_class = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema=None,
|
||||
middleware=None,
|
||||
root_value=None,
|
||||
graphiql=False,
|
||||
pretty=False,
|
||||
batch=False,
|
||||
subscription_path=None,
|
||||
execution_context_class=None,
|
||||
):
|
||||
if not schema:
|
||||
schema = graphene_settings.SCHEMA
|
||||
|
||||
if middleware is None:
|
||||
middleware = graphene_settings.MIDDLEWARE
|
||||
|
||||
self.schema = self.schema or schema
|
||||
if middleware is not None:
|
||||
if isinstance(middleware, MiddlewareManager):
|
||||
self.middleware = middleware
|
||||
else:
|
||||
self.middleware = list(instantiate_middleware(middleware))
|
||||
self.root_value = root_value
|
||||
self.pretty = self.pretty or pretty
|
||||
self.graphiql = self.graphiql or graphiql
|
||||
self.batch = self.batch or batch
|
||||
self.execution_context_class = execution_context_class
|
||||
if subscription_path is None:
|
||||
self.subscription_path = graphene_settings.SUBSCRIPTION_PATH
|
||||
|
||||
assert isinstance(
|
||||
self.schema, Schema
|
||||
), "A Schema is required to be provided to GraphQLView."
|
||||
assert not all((graphiql, batch)), "Use either graphiql or batch processing"
|
||||
|
||||
# noinspection PyUnusedLocal
|
||||
def get_root_value(self, request):
|
||||
return self.root_value
|
||||
|
||||
def get_middleware(self, request):
|
||||
return self.middleware
|
||||
|
||||
def get_context(self, request):
|
||||
return request
|
||||
|
||||
@classonlymethod
|
||||
def as_view(cls, **initkwargs):
|
||||
view = super().as_view(**initkwargs)
|
||||
view._is_coroutine = coroutines._is_coroutine
|
||||
return view
|
||||
|
||||
async def dispatch(self, request, *args, **kwargs):
|
||||
try:
|
||||
if request.method.lower() not in ("get", "post"):
|
||||
raise HttpError(
|
||||
HttpResponseNotAllowed(
|
||||
["GET", "POST"], "GraphQL only supports GET and POST requests."
|
||||
)
|
||||
)
|
||||
|
||||
data = self.parse_body(request)
|
||||
show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
|
||||
|
||||
if show_graphiql:
|
||||
return self.render_graphiql(
|
||||
request,
|
||||
# Dependency parameters.
|
||||
whatwg_fetch_version=self.whatwg_fetch_version,
|
||||
whatwg_fetch_sri=self.whatwg_fetch_sri,
|
||||
react_version=self.react_version,
|
||||
react_sri=self.react_sri,
|
||||
react_dom_sri=self.react_dom_sri,
|
||||
graphiql_version=self.graphiql_version,
|
||||
graphiql_sri=self.graphiql_sri,
|
||||
graphiql_css_sri=self.graphiql_css_sri,
|
||||
subscriptions_transport_ws_version=self.subscriptions_transport_ws_version,
|
||||
subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri,
|
||||
graphiql_plugin_explorer_version=self.graphiql_plugin_explorer_version,
|
||||
graphiql_plugin_explorer_sri=self.graphiql_plugin_explorer_sri,
|
||||
# The SUBSCRIPTION_PATH setting.
|
||||
subscription_path=self.subscription_path,
|
||||
# GraphiQL headers tab,
|
||||
graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED,
|
||||
graphiql_should_persist_headers=graphene_settings.GRAPHIQL_SHOULD_PERSIST_HEADERS,
|
||||
)
|
||||
|
||||
if self.batch:
|
||||
responses = await gather(
|
||||
*[self.get_response(request, entry) for entry in data]
|
||||
)
|
||||
result = "[{}]".format(
|
||||
",".join([response[0] for response in responses])
|
||||
)
|
||||
status_code = (
|
||||
responses
|
||||
and max(responses, key=lambda response: response[1])[1]
|
||||
or 200
|
||||
)
|
||||
else:
|
||||
result, status_code = await self.get_response(
|
||||
request, data, show_graphiql
|
||||
)
|
||||
|
||||
return HttpResponse(
|
||||
status=status_code, content=result, content_type="application/json"
|
||||
)
|
||||
|
||||
except HttpError as e:
|
||||
response = e.response
|
||||
response["Content-Type"] = "application/json"
|
||||
response.content = self.json_encode(
|
||||
request, {"errors": [self.format_error(e)]}
|
||||
)
|
||||
return response
|
||||
|
||||
async def get_response(self, request, data, show_graphiql=False):
|
||||
query, variables, operation_name, id = self.get_graphql_params(request, data)
|
||||
|
||||
execution_result = await self.execute_graphql_request(
|
||||
request, data, query, variables, operation_name, show_graphiql
|
||||
)
|
||||
|
||||
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
|
||||
set_rollback()
|
||||
|
||||
status_code = 200
|
||||
if execution_result:
|
||||
response = {}
|
||||
|
||||
if execution_result.errors:
|
||||
for e in execution_result.errors:
|
||||
print(e)
|
||||
traceback.print_tb(e.__traceback__)
|
||||
set_rollback()
|
||||
response["errors"] = [
|
||||
self.format_error(e) for e in execution_result.errors
|
||||
]
|
||||
|
||||
if execution_result.errors and any(
|
||||
not getattr(e, "path", None) for e in execution_result.errors
|
||||
):
|
||||
status_code = 400
|
||||
else:
|
||||
response["data"] = execution_result.data
|
||||
|
||||
if self.batch:
|
||||
response["id"] = id
|
||||
response["status"] = status_code
|
||||
|
||||
result = self.json_encode(request, response, pretty=show_graphiql)
|
||||
else:
|
||||
result = None
|
||||
|
||||
return result, status_code
|
||||
|
||||
def render_graphiql(self, request, **data):
|
||||
return render(request, self.graphiql_template, data)
|
||||
|
||||
def json_encode(self, request, d, pretty=False):
|
||||
if not (self.pretty or pretty) and not request.GET.get("pretty"):
|
||||
return json.dumps(d, separators=(",", ":"))
|
||||
|
||||
return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": "))
|
||||
|
||||
def parse_body(self, request):
|
||||
content_type = self.get_content_type(request)
|
||||
|
||||
if content_type == "application/graphql":
|
||||
return {"query": request.body.decode()}
|
||||
|
||||
elif content_type == "application/json":
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
body = request.body.decode("utf-8")
|
||||
except Exception as e:
|
||||
raise HttpError(HttpResponseBadRequest(str(e)))
|
||||
|
||||
try:
|
||||
request_json = json.loads(body)
|
||||
if self.batch:
|
||||
assert isinstance(request_json, list), (
|
||||
"Batch requests should receive a list, but received {}."
|
||||
).format(repr(request_json))
|
||||
assert (
|
||||
len(request_json) > 0
|
||||
), "Received an empty list in the batch request."
|
||||
else:
|
||||
assert isinstance(
|
||||
request_json, dict
|
||||
), "The received data is not a valid JSON query."
|
||||
return request_json
|
||||
except AssertionError as e:
|
||||
raise HttpError(HttpResponseBadRequest(str(e)))
|
||||
except (TypeError, ValueError):
|
||||
raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON."))
|
||||
|
||||
elif content_type in [
|
||||
"application/x-www-form-urlencoded",
|
||||
"multipart/form-data",
|
||||
]:
|
||||
return request.POST
|
||||
|
||||
return {}
|
||||
|
||||
async def execute_graphql_request(
|
||||
self, request, data, query, variables, operation_name, show_graphiql=False
|
||||
):
|
||||
if not query:
|
||||
if show_graphiql:
|
||||
return None
|
||||
raise HttpError(HttpResponseBadRequest("Must provide query string."))
|
||||
|
||||
try:
|
||||
document = parse(query)
|
||||
except Exception as e:
|
||||
return ExecutionResult(errors=[e])
|
||||
|
||||
if request.method.lower() == "get":
|
||||
operation_ast = get_operation_ast(document, operation_name)
|
||||
if operation_ast and operation_ast.operation != OperationType.QUERY:
|
||||
if show_graphiql:
|
||||
return None
|
||||
|
||||
raise HttpError(
|
||||
HttpResponseNotAllowed(
|
||||
["POST"],
|
||||
"Can only perform a {} operation from a POST request.".format(
|
||||
operation_ast.operation.value
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
extra_options = {}
|
||||
if self.execution_context_class:
|
||||
extra_options["execution_context_class"] = self.execution_context_class
|
||||
|
||||
options = {
|
||||
"source": query,
|
||||
"root_value": self.get_root_value(request),
|
||||
"variable_values": variables,
|
||||
"operation_name": operation_name,
|
||||
"context_value": self.get_context(request),
|
||||
"middleware": self.get_middleware(request),
|
||||
}
|
||||
options.update(extra_options)
|
||||
|
||||
operation_ast = get_operation_ast(document, operation_name)
|
||||
if (
|
||||
operation_ast
|
||||
and operation_ast.operation == OperationType.MUTATION
|
||||
and (
|
||||
graphene_settings.ATOMIC_MUTATIONS is True
|
||||
or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True
|
||||
)
|
||||
):
|
||||
with transaction.atomic():
|
||||
result = await self.schema.execute_async(**options)
|
||||
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
|
||||
transaction.set_rollback(True)
|
||||
return result
|
||||
|
||||
return await self.schema.execute_async(**options)
|
||||
except Exception as e:
|
||||
return ExecutionResult(errors=[e])
|
||||
|
||||
@classmethod
|
||||
def can_display_graphiql(cls, request, data):
|
||||
raw = "raw" in request.GET or "raw" in data
|
||||
return not raw and cls.request_wants_html(request)
|
||||
|
||||
@classmethod
|
||||
def request_wants_html(cls, request):
|
||||
accepted = get_accepted_content_types(request)
|
||||
accepted_length = len(accepted)
|
||||
# the list will be ordered in preferred first - so we have to make
|
||||
# sure the most preferred gets the highest number
|
||||
html_priority = (
|
||||
accepted_length - accepted.index("text/html")
|
||||
if "text/html" in accepted
|
||||
else 0
|
||||
)
|
||||
json_priority = (
|
||||
accepted_length - accepted.index("application/json")
|
||||
if "application/json" in accepted
|
||||
else 0
|
||||
)
|
||||
|
||||
return html_priority > json_priority
|
||||
|
||||
@staticmethod
|
||||
def get_graphql_params(request, data):
|
||||
query = request.GET.get("query") or data.get("query")
|
||||
variables = request.GET.get("variables") or data.get("variables")
|
||||
id = request.GET.get("id") or data.get("id")
|
||||
|
||||
if variables and isinstance(variables, str):
|
||||
try:
|
||||
variables = json.loads(variables)
|
||||
except Exception:
|
||||
raise HttpError(HttpResponseBadRequest("Variables are invalid JSON."))
|
||||
|
||||
operation_name = request.GET.get("operationName") or data.get("operationName")
|
||||
if operation_name == "null":
|
||||
operation_name = None
|
||||
|
||||
return query, variables, operation_name, id
|
||||
|
||||
@staticmethod
|
||||
def format_error(error):
|
||||
if isinstance(error, GraphQLError):
|
||||
return error.formatted
|
||||
|
||||
return {"message": str(error)}
|
||||
|
||||
@staticmethod
|
||||
def get_content_type(request):
|
||||
meta = request.META
|
||||
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
|
||||
return content_type.split(";", 1)[0].lower()
|
||||
|
|
Loading…
Reference in New Issue
Block a user