mirror of
https://github.com/graphql-python/graphene-django.git
synced 2025-07-17 11:42:33 +03:00
Merge c2d601c0e2
into ea45de02ad
This commit is contained in:
commit
a8e6217150
|
@ -1,4 +1,8 @@
|
||||||
from graphene import Node
|
import asyncio
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
|
from graphene import Field, Node, String
|
||||||
from graphene_django.filter import DjangoFilterConnectionField
|
from graphene_django.filter import DjangoFilterConnectionField
|
||||||
from graphene_django.types import DjangoObjectType
|
from graphene_django.types import DjangoObjectType
|
||||||
|
|
||||||
|
@ -6,12 +10,32 @@ from cookbook.recipes.models import Recipe, RecipeIngredient
|
||||||
|
|
||||||
|
|
||||||
class RecipeNode(DjangoObjectType):
|
class RecipeNode(DjangoObjectType):
|
||||||
|
async_field = String()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = Recipe
|
model = Recipe
|
||||||
interfaces = (Node,)
|
interfaces = (Node,)
|
||||||
fields = "__all__"
|
fields = "__all__"
|
||||||
filter_fields = ["title", "amounts"]
|
filter_fields = ["title", "amounts"]
|
||||||
|
|
||||||
|
async def resolve_async_field(self, info):
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
|
||||||
|
class RecipeType(DjangoObjectType):
|
||||||
|
async_field = String()
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = Recipe
|
||||||
|
fields = "__all__"
|
||||||
|
filter_fields = ["title", "amounts"]
|
||||||
|
skip_registry = True
|
||||||
|
|
||||||
|
async def resolve_async_field(self, info):
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
|
||||||
class RecipeIngredientNode(DjangoObjectType):
|
class RecipeIngredientNode(DjangoObjectType):
|
||||||
class Meta:
|
class Meta:
|
||||||
|
@ -28,7 +52,13 @@ class RecipeIngredientNode(DjangoObjectType):
|
||||||
|
|
||||||
class Query:
|
class Query:
|
||||||
recipe = Node.Field(RecipeNode)
|
recipe = Node.Field(RecipeNode)
|
||||||
|
raw_recipe = Field(RecipeType)
|
||||||
all_recipes = DjangoFilterConnectionField(RecipeNode)
|
all_recipes = DjangoFilterConnectionField(RecipeNode)
|
||||||
|
|
||||||
recipeingredient = Node.Field(RecipeIngredientNode)
|
recipeingredient = Node.Field(RecipeIngredientNode)
|
||||||
all_recipeingredients = DjangoFilterConnectionField(RecipeIngredientNode)
|
all_recipeingredients = DjangoFilterConnectionField(RecipeIngredientNode)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@sync_to_async
|
||||||
|
def resolve_raw_recipe(self, info):
|
||||||
|
return Recipe.objects.first()
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
from django.conf.urls import url
|
|
||||||
from django.contrib import admin
|
from django.contrib import admin
|
||||||
|
from django.urls import re_path
|
||||||
|
from django.views.decorators.csrf import csrf_exempt
|
||||||
|
|
||||||
from graphene_django.views import GraphQLView
|
from graphene_django.views import AsyncGraphQLView
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
url(r"^admin/", admin.site.urls),
|
re_path(r"^admin/", admin.site.urls),
|
||||||
url(r"^graphql$", GraphQLView.as_view(graphiql=True)),
|
re_path(r"^graphql$", csrf_exempt(AsyncGraphQLView.as_view(graphiql=True))),
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
from django.db import connections
|
from django.db import connections
|
||||||
|
from graphql.type.definition import GraphQLNonNull
|
||||||
|
|
||||||
|
from ..utils import is_running_async, is_sync_function
|
||||||
from .exception.formating import wrap_exception
|
from .exception.formating import wrap_exception
|
||||||
from .sql.tracking import unwrap_cursor, wrap_cursor
|
from .sql.tracking import unwrap_cursor, wrap_cursor
|
||||||
from .types import DjangoDebug
|
from .types import DjangoDebug
|
||||||
|
@ -67,3 +70,28 @@ class DjangoDebugMiddleware:
|
||||||
return context.django_debug.on_resolve_error(e)
|
return context.django_debug.on_resolve_error(e)
|
||||||
context.django_debug.add_result(result)
|
context.django_debug.add_result(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoSyncRequiredMiddleware:
|
||||||
|
def resolve(self, next, root, info, **args):
|
||||||
|
parent_type = info.parent_type
|
||||||
|
return_type = info.return_type
|
||||||
|
|
||||||
|
if isinstance(parent_type, GraphQLNonNull):
|
||||||
|
parent_type = parent_type.of_type
|
||||||
|
if isinstance(return_type, GraphQLNonNull):
|
||||||
|
return_type = return_type.of_type
|
||||||
|
|
||||||
|
if any(
|
||||||
|
[
|
||||||
|
hasattr(parent_type, "graphene_type")
|
||||||
|
and hasattr(parent_type.graphene_type._meta, "model"),
|
||||||
|
hasattr(return_type, "graphene_type")
|
||||||
|
and hasattr(return_type.graphene_type._meta, "model"),
|
||||||
|
info.parent_type.name == "Mutation",
|
||||||
|
]
|
||||||
|
):
|
||||||
|
if is_sync_function(next) and is_running_async():
|
||||||
|
return sync_to_async(next)(root, info, **args)
|
||||||
|
|
||||||
|
return next(root, info, **args)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
from django.db.models.query import QuerySet
|
from django.db.models.query import QuerySet
|
||||||
from graphql_relay import (
|
from graphql_relay import (
|
||||||
connection_from_array_slice,
|
connection_from_array_slice,
|
||||||
|
@ -7,7 +8,6 @@ from graphql_relay import (
|
||||||
get_offset_with_default,
|
get_offset_with_default,
|
||||||
offset_to_cursor,
|
offset_to_cursor,
|
||||||
)
|
)
|
||||||
from promise import Promise
|
|
||||||
|
|
||||||
from graphene import Int, NonNull
|
from graphene import Int, NonNull
|
||||||
from graphene.relay import ConnectionField
|
from graphene.relay import ConnectionField
|
||||||
|
@ -15,7 +15,7 @@ from graphene.relay.connection import connection_adapter, page_info_adapter
|
||||||
from graphene.types import Field, List
|
from graphene.types import Field, List
|
||||||
|
|
||||||
from .settings import graphene_settings
|
from .settings import graphene_settings
|
||||||
from .utils import maybe_queryset
|
from .utils import is_running_async, is_sync_function, maybe_queryset
|
||||||
|
|
||||||
|
|
||||||
class DjangoListField(Field):
|
class DjangoListField(Field):
|
||||||
|
@ -49,11 +49,36 @@ class DjangoListField(Field):
|
||||||
def get_manager(self):
|
def get_manager(self):
|
||||||
return self.model._default_manager
|
return self.model._default_manager
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def list_resolver(
|
def list_resolver(
|
||||||
django_object_type, resolver, default_manager, root, info, **args
|
cls, django_object_type, resolver, default_manager, root, info, **args
|
||||||
):
|
):
|
||||||
queryset = maybe_queryset(resolver(root, info, **args))
|
if is_running_async():
|
||||||
|
if is_sync_function(resolver):
|
||||||
|
resolver = sync_to_async(resolver)
|
||||||
|
|
||||||
|
iterable = resolver(root, info, **args)
|
||||||
|
|
||||||
|
if info.is_awaitable(iterable):
|
||||||
|
|
||||||
|
async def resolve_list_async(iterable):
|
||||||
|
queryset = maybe_queryset(await iterable)
|
||||||
|
if queryset is None:
|
||||||
|
queryset = maybe_queryset(default_manager)
|
||||||
|
|
||||||
|
if isinstance(queryset, QuerySet):
|
||||||
|
# Pass queryset to the DjangoObjectType get_queryset method
|
||||||
|
queryset = maybe_queryset(
|
||||||
|
await sync_to_async(django_object_type.get_queryset)(
|
||||||
|
queryset, info
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return await sync_to_async(list)(queryset)
|
||||||
|
|
||||||
|
return resolve_list_async(iterable)
|
||||||
|
|
||||||
|
queryset = maybe_queryset(iterable)
|
||||||
if queryset is None:
|
if queryset is None:
|
||||||
queryset = maybe_queryset(default_manager)
|
queryset = maybe_queryset(default_manager)
|
||||||
|
|
||||||
|
@ -61,7 +86,7 @@ class DjangoListField(Field):
|
||||||
# Pass queryset to the DjangoObjectType get_queryset method
|
# Pass queryset to the DjangoObjectType get_queryset method
|
||||||
queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))
|
queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))
|
||||||
|
|
||||||
return queryset
|
return list(queryset)
|
||||||
|
|
||||||
def wrap_resolve(self, parent_resolver):
|
def wrap_resolve(self, parent_resolver):
|
||||||
resolver = super().wrap_resolve(parent_resolver)
|
resolver = super().wrap_resolve(parent_resolver)
|
||||||
|
@ -235,20 +260,36 @@ class DjangoConnectionField(ConnectionField):
|
||||||
|
|
||||||
# eventually leads to DjangoObjectType's get_queryset (accepts queryset)
|
# eventually leads to DjangoObjectType's get_queryset (accepts queryset)
|
||||||
# or a resolve_foo (does not accept queryset)
|
# or a resolve_foo (does not accept queryset)
|
||||||
|
|
||||||
|
if is_running_async():
|
||||||
|
if is_sync_function(resolver):
|
||||||
|
resolver = sync_to_async(resolver)
|
||||||
|
|
||||||
iterable = resolver(root, info, **args)
|
iterable = resolver(root, info, **args)
|
||||||
|
|
||||||
|
if info.is_awaitable(iterable):
|
||||||
|
|
||||||
|
async def resolve_connection_async(iterable):
|
||||||
|
iterable = await iterable
|
||||||
|
if iterable is None:
|
||||||
|
iterable = default_manager
|
||||||
|
|
||||||
|
iterable = await sync_to_async(queryset_resolver)(
|
||||||
|
connection, iterable, info, args
|
||||||
|
)
|
||||||
|
|
||||||
|
return await sync_to_async(cls.resolve_connection)(
|
||||||
|
connection, args, iterable, max_limit=max_limit
|
||||||
|
)
|
||||||
|
|
||||||
|
return resolve_connection_async(iterable)
|
||||||
|
|
||||||
if iterable is None:
|
if iterable is None:
|
||||||
iterable = default_manager
|
iterable = default_manager
|
||||||
# thus the iterable gets refiltered by resolve_queryset
|
# thus the iterable gets refiltered by resolve_queryset
|
||||||
# but iterable might be promise
|
# but iterable might be promise
|
||||||
iterable = queryset_resolver(connection, iterable, info, args)
|
iterable = queryset_resolver(connection, iterable, info, args)
|
||||||
on_resolve = partial(
|
return cls.resolve_connection(connection, args, iterable, max_limit=max_limit)
|
||||||
cls.resolve_connection, connection, args, max_limit=max_limit
|
|
||||||
)
|
|
||||||
|
|
||||||
if Promise.is_thenable(iterable):
|
|
||||||
return Promise.resolve(iterable).then(on_resolve)
|
|
||||||
|
|
||||||
return on_resolve(iterable)
|
|
||||||
|
|
||||||
def wrap_resolve(self, parent_resolver):
|
def wrap_resolve(self, parent_resolver):
|
||||||
return partial(
|
return partial(
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ValidationError
|
||||||
|
|
||||||
from graphene.types.argument import to_arguments
|
from graphene.types.argument import to_arguments
|
||||||
|
@ -92,6 +93,18 @@ class DjangoFilterConnectionField(DjangoConnectionField):
|
||||||
|
|
||||||
qs = super().resolve_queryset(connection, iterable, info, args)
|
qs = super().resolve_queryset(connection, iterable, info, args)
|
||||||
|
|
||||||
|
if info.is_awaitable(qs):
|
||||||
|
|
||||||
|
async def filter_async(qs):
|
||||||
|
filterset = filterset_class(
|
||||||
|
data=filter_kwargs(), queryset=await qs, request=info.context
|
||||||
|
)
|
||||||
|
if await sync_to_async(filterset.is_valid)():
|
||||||
|
return filterset.qs
|
||||||
|
raise ValidationError(filterset.form.errors.as_json())
|
||||||
|
|
||||||
|
return filter_async(qs)
|
||||||
|
|
||||||
filterset = filterset_class(
|
filterset = filterset_class(
|
||||||
data=filter_kwargs(), queryset=qs, request=info.context
|
data=filter_kwargs(), queryset=qs, request=info.context
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
from django.shortcuts import get_object_or_404
|
from django.shortcuts import get_object_or_404
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
@ -11,6 +12,7 @@ from graphene.types.mutation import MutationOptions
|
||||||
from graphene.types.objecttype import yank_fields_from_attrs
|
from graphene.types.objecttype import yank_fields_from_attrs
|
||||||
|
|
||||||
from ..types import ErrorType
|
from ..types import ErrorType
|
||||||
|
from ..utils import is_running_async
|
||||||
from .serializer_converter import convert_serializer_field
|
from .serializer_converter import convert_serializer_field
|
||||||
|
|
||||||
|
|
||||||
|
@ -166,6 +168,17 @@ class SerializerMutation(ClientIDMutation):
|
||||||
kwargs = cls.get_serializer_kwargs(root, info, **input)
|
kwargs = cls.get_serializer_kwargs(root, info, **input)
|
||||||
serializer = cls._meta.serializer_class(**kwargs)
|
serializer = cls._meta.serializer_class(**kwargs)
|
||||||
|
|
||||||
|
if is_running_async():
|
||||||
|
|
||||||
|
async def perform_mutate_async():
|
||||||
|
if await sync_to_async(serializer.is_valid)():
|
||||||
|
return await sync_to_async(cls.perform_mutate)(serializer, info)
|
||||||
|
else:
|
||||||
|
errors = ErrorType.from_errors(serializer.errors)
|
||||||
|
return cls(errors=errors)
|
||||||
|
|
||||||
|
return perform_mutate_async()
|
||||||
|
|
||||||
if serializer.is_valid():
|
if serializer.is_valid():
|
||||||
return cls.perform_mutate(serializer, info)
|
return cls.perform_mutate(serializer, info)
|
||||||
else:
|
else:
|
||||||
|
|
6
graphene_django/tests/async_test_helper.py
Normal file
6
graphene_django/tests/async_test_helper.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from asgiref.sync import async_to_sync
|
||||||
|
|
||||||
|
|
||||||
|
def assert_async_result_equal(schema, query, result, **kwargs):
|
||||||
|
async_result = async_to_sync(schema.execute_async)(query, **kwargs)
|
||||||
|
assert async_result == result
|
|
@ -2,6 +2,7 @@ import datetime
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from asgiref.sync import async_to_sync
|
||||||
from django.db.models import Count, Model, Prefetch
|
from django.db.models import Count, Model, Prefetch
|
||||||
|
|
||||||
from graphene import List, NonNull, ObjectType, Schema, String
|
from graphene import List, NonNull, ObjectType, Schema, String
|
||||||
|
@ -9,6 +10,7 @@ from graphene.relay import Node
|
||||||
|
|
||||||
from ..fields import DjangoConnectionField, DjangoListField
|
from ..fields import DjangoConnectionField, DjangoListField
|
||||||
from ..types import DjangoObjectType
|
from ..types import DjangoObjectType
|
||||||
|
from .async_test_helper import assert_async_result_equal
|
||||||
from .models import (
|
from .models import (
|
||||||
Article as ArticleModel,
|
Article as ArticleModel,
|
||||||
Film as FilmModel,
|
Film as FilmModel,
|
||||||
|
@ -82,6 +84,7 @@ class TestDjangoListField:
|
||||||
|
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
|
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {
|
assert result.data == {
|
||||||
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
||||||
|
@ -109,6 +112,7 @@ class TestDjangoListField:
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {"reporters": []}
|
assert result.data == {"reporters": []}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
ReporterModel.objects.create(first_name="Tara", last_name="West")
|
ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||||
|
@ -119,6 +123,7 @@ class TestDjangoListField:
|
||||||
assert result.data == {
|
assert result.data == {
|
||||||
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
||||||
}
|
}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
def test_override_resolver(self):
|
def test_override_resolver(self):
|
||||||
class Reporter(DjangoObjectType):
|
class Reporter(DjangoObjectType):
|
||||||
|
@ -146,6 +151,35 @@ class TestDjangoListField:
|
||||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||||
|
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||||
|
|
||||||
|
def test_override_resolver_async_execution(self):
|
||||||
|
class Reporter(DjangoObjectType):
|
||||||
|
class Meta:
|
||||||
|
model = ReporterModel
|
||||||
|
fields = ("first_name",)
|
||||||
|
|
||||||
|
class Query(ObjectType):
|
||||||
|
reporters = DjangoListField(Reporter)
|
||||||
|
|
||||||
|
def resolve_reporters(_, info):
|
||||||
|
return ReporterModel.objects.filter(first_name="Tara")
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
reporters {
|
||||||
|
firstName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||||
|
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||||
|
|
||||||
|
result = async_to_sync(schema.execute_async)(query)
|
||||||
|
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||||
|
@ -210,6 +244,7 @@ class TestDjangoListField:
|
||||||
{"firstName": "Debra", "articles": []},
|
{"firstName": "Debra", "articles": []},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
def test_override_resolver_nested_list_field(self):
|
def test_override_resolver_nested_list_field(self):
|
||||||
class Article(DjangoObjectType):
|
class Article(DjangoObjectType):
|
||||||
|
@ -268,6 +303,7 @@ class TestDjangoListField:
|
||||||
{"firstName": "Debra", "articles": []},
|
{"firstName": "Debra", "articles": []},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
def test_same_type_nested_list_field(self):
|
def test_same_type_nested_list_field(self):
|
||||||
class Person(DjangoObjectType):
|
class Person(DjangoObjectType):
|
||||||
|
@ -376,6 +412,7 @@ class TestDjangoListField:
|
||||||
|
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
def test_resolve_list(self):
|
def test_resolve_list(self):
|
||||||
"""Resolving a plain list should work (and not call get_queryset)"""
|
"""Resolving a plain list should work (and not call get_queryset)"""
|
||||||
|
@ -424,6 +461,53 @@ class TestDjangoListField:
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||||
|
|
||||||
|
def test_resolve_list_async(self):
|
||||||
|
"""Resolving a plain list should work (and not call get_queryset) when running under async"""
|
||||||
|
|
||||||
|
class Reporter(DjangoObjectType):
|
||||||
|
class Meta:
|
||||||
|
model = ReporterModel
|
||||||
|
fields = ("first_name", "articles")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_queryset(cls, queryset, info):
|
||||||
|
# Only get reporters with at least 1 article
|
||||||
|
return queryset.annotate(article_count=Count("articles")).filter(
|
||||||
|
article_count__gt=0
|
||||||
|
)
|
||||||
|
|
||||||
|
class Query(ObjectType):
|
||||||
|
reporters = DjangoListField(Reporter)
|
||||||
|
|
||||||
|
def resolve_reporters(_, info):
|
||||||
|
return [ReporterModel.objects.get(first_name="Debra")]
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
reporters {
|
||||||
|
firstName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||||
|
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||||
|
|
||||||
|
ArticleModel.objects.create(
|
||||||
|
headline="Amazing news",
|
||||||
|
reporter=r1,
|
||||||
|
pub_date=datetime.date.today(),
|
||||||
|
pub_date_time=datetime.datetime.now(),
|
||||||
|
editor=r1,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = async_to_sync(schema.execute_async)(query)
|
||||||
|
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||||
|
|
||||||
def test_get_queryset_foreign_key(self):
|
def test_get_queryset_foreign_key(self):
|
||||||
class Article(DjangoObjectType):
|
class Article(DjangoObjectType):
|
||||||
class Meta:
|
class Meta:
|
||||||
|
@ -483,6 +567,7 @@ class TestDjangoListField:
|
||||||
{"firstName": "Debra", "articles": []},
|
{"firstName": "Debra", "articles": []},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
def test_resolve_list_external_resolver(self):
|
def test_resolve_list_external_resolver(self):
|
||||||
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
|
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
|
||||||
|
@ -531,6 +616,53 @@ class TestDjangoListField:
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||||
|
|
||||||
|
def test_resolve_list_external_resolver_async(self):
|
||||||
|
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
|
||||||
|
|
||||||
|
class Reporter(DjangoObjectType):
|
||||||
|
class Meta:
|
||||||
|
model = ReporterModel
|
||||||
|
fields = ("first_name", "articles")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_queryset(cls, queryset, info):
|
||||||
|
# Only get reporters with at least 1 article
|
||||||
|
return queryset.annotate(article_count=Count("articles")).filter(
|
||||||
|
article_count__gt=0
|
||||||
|
)
|
||||||
|
|
||||||
|
def resolve_reporters(_, info):
|
||||||
|
return [ReporterModel.objects.get(first_name="Debra")]
|
||||||
|
|
||||||
|
class Query(ObjectType):
|
||||||
|
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
reporters {
|
||||||
|
firstName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||||
|
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||||
|
|
||||||
|
ArticleModel.objects.create(
|
||||||
|
headline="Amazing news",
|
||||||
|
reporter=r1,
|
||||||
|
pub_date=datetime.date.today(),
|
||||||
|
pub_date_time=datetime.datetime.now(),
|
||||||
|
editor=r1,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = async_to_sync(schema.execute_async)(query)
|
||||||
|
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||||
|
|
||||||
def test_get_queryset_filter_external_resolver(self):
|
def test_get_queryset_filter_external_resolver(self):
|
||||||
class Reporter(DjangoObjectType):
|
class Reporter(DjangoObjectType):
|
||||||
class Meta:
|
class Meta:
|
||||||
|
@ -575,6 +707,7 @@ class TestDjangoListField:
|
||||||
|
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
def test_select_related_and_prefetch_related_are_respected(
|
def test_select_related_and_prefetch_related_are_respected(
|
||||||
self, django_assert_num_queries
|
self, django_assert_num_queries
|
||||||
|
@ -717,6 +850,7 @@ class TestDjangoListField:
|
||||||
r'SELECT .* FROM "tests_film" INNER JOIN "tests_film_reporters" .* LEFT OUTER JOIN "tests_filmdetails"',
|
r'SELECT .* FROM "tests_film" INNER JOIN "tests_film_reporters" .* LEFT OUTER JOIN "tests_filmdetails"',
|
||||||
captured.captured_queries[1]["sql"],
|
captured.captured_queries[1]["sql"],
|
||||||
)
|
)
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
class TestDjangoConnectionField:
|
class TestDjangoConnectionField:
|
||||||
|
|
|
@ -2,6 +2,7 @@ import base64
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from asgiref.sync import async_to_sync
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
from django.utils.functional import SimpleLazyObject
|
from django.utils.functional import SimpleLazyObject
|
||||||
|
@ -15,6 +16,7 @@ from ..compat import IntegerRangeField, MissingType
|
||||||
from ..fields import DjangoConnectionField
|
from ..fields import DjangoConnectionField
|
||||||
from ..types import DjangoObjectType
|
from ..types import DjangoObjectType
|
||||||
from ..utils import DJANGO_FILTER_INSTALLED
|
from ..utils import DJANGO_FILTER_INSTALLED
|
||||||
|
from .async_test_helper import assert_async_result_equal
|
||||||
from .models import (
|
from .models import (
|
||||||
APNewsReporter,
|
APNewsReporter,
|
||||||
Article,
|
Article,
|
||||||
|
@ -43,6 +45,7 @@ def test_should_query_only_fields():
|
||||||
"""
|
"""
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_query_simplelazy_objects():
|
def test_should_query_simplelazy_objects():
|
||||||
|
@ -68,6 +71,7 @@ def test_should_query_simplelazy_objects():
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {"reporter": {"id": "1"}}
|
assert result.data == {"reporter": {"id": "1"}}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_query_wrapped_simplelazy_objects():
|
def test_should_query_wrapped_simplelazy_objects():
|
||||||
|
@ -93,6 +97,7 @@ def test_should_query_wrapped_simplelazy_objects():
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {"reporter": {"id": "1"}}
|
assert result.data == {"reporter": {"id": "1"}}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_query_well():
|
def test_should_query_well():
|
||||||
|
@ -121,6 +126,7 @@ def test_should_query_well():
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(IntegerRangeField is MissingType, reason="RangeField should exist")
|
@pytest.mark.skipif(IntegerRangeField is MissingType, reason="RangeField should exist")
|
||||||
|
@ -175,6 +181,7 @@ def test_should_query_postgres_fields():
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_node():
|
def test_should_node():
|
||||||
|
@ -256,6 +263,7 @@ def test_should_node():
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_query_onetoone_fields():
|
def test_should_query_onetoone_fields():
|
||||||
|
@ -314,6 +322,7 @@ def test_should_query_onetoone_fields():
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_query_connectionfields():
|
def test_should_query_connectionfields():
|
||||||
|
@ -352,6 +361,7 @@ def test_should_query_connectionfields():
|
||||||
"edges": [{"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}],
|
"edges": [{"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_keep_annotations():
|
def test_should_keep_annotations():
|
||||||
|
@ -411,6 +421,7 @@ def test_should_keep_annotations():
|
||||||
"""
|
"""
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
|
@ -492,6 +503,7 @@ def test_should_query_node_filtering():
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
|
@ -537,6 +549,7 @@ def test_should_query_node_filtering_with_distinct_queryset():
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
|
@ -626,6 +639,7 @@ def test_should_query_node_multiple_filtering():
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_enforce_first_or_last(graphene_settings):
|
def test_should_enforce_first_or_last(graphene_settings):
|
||||||
|
@ -666,6 +680,7 @@ def test_should_enforce_first_or_last(graphene_settings):
|
||||||
"paginate the `allReporters` connection.\n"
|
"paginate the `allReporters` connection.\n"
|
||||||
)
|
)
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_error_if_first_is_greater_than_max(graphene_settings):
|
def test_should_error_if_first_is_greater_than_max(graphene_settings):
|
||||||
|
@ -708,6 +723,7 @@ def test_should_error_if_first_is_greater_than_max(graphene_settings):
|
||||||
"exceeds the `first` limit of 100 records.\n"
|
"exceeds the `first` limit of 100 records.\n"
|
||||||
)
|
)
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_error_if_last_is_greater_than_max(graphene_settings):
|
def test_should_error_if_last_is_greater_than_max(graphene_settings):
|
||||||
|
@ -750,6 +766,7 @@ def test_should_error_if_last_is_greater_than_max(graphene_settings):
|
||||||
"exceeds the `last` limit of 100 records.\n"
|
"exceeds the `last` limit of 100 records.\n"
|
||||||
)
|
)
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_query_promise_connectionfields():
|
def test_should_query_promise_connectionfields():
|
||||||
|
@ -785,6 +802,7 @@ def test_should_query_promise_connectionfields():
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_query_connectionfields_with_last():
|
def test_should_query_connectionfields_with_last():
|
||||||
|
@ -822,6 +840,7 @@ def test_should_query_connectionfields_with_last():
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_query_connectionfields_with_manager():
|
def test_should_query_connectionfields_with_manager():
|
||||||
|
@ -863,6 +882,7 @@ def test_should_query_connectionfields_with_manager():
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_query_dataloader_fields():
|
def test_should_query_dataloader_fields():
|
||||||
|
@ -965,6 +985,106 @@ def test_should_query_dataloader_fields():
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_query_dataloader_fields_async():
|
||||||
|
from promise import Promise
|
||||||
|
from promise.dataloader import DataLoader
|
||||||
|
|
||||||
|
def article_batch_load_fn(keys):
|
||||||
|
queryset = Article.objects.filter(reporter_id__in=keys)
|
||||||
|
return Promise.resolve(
|
||||||
|
[
|
||||||
|
[article for article in queryset if article.reporter_id == id]
|
||||||
|
for id in keys
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
article_loader = DataLoader(article_batch_load_fn)
|
||||||
|
|
||||||
|
class ArticleType(DjangoObjectType):
|
||||||
|
class Meta:
|
||||||
|
model = Article
|
||||||
|
interfaces = (Node,)
|
||||||
|
fields = "__all__"
|
||||||
|
|
||||||
|
class ReporterType(DjangoObjectType):
|
||||||
|
class Meta:
|
||||||
|
model = Reporter
|
||||||
|
interfaces = (Node,)
|
||||||
|
use_connection = True
|
||||||
|
fields = "__all__"
|
||||||
|
|
||||||
|
articles = DjangoConnectionField(ArticleType)
|
||||||
|
|
||||||
|
def resolve_articles(self, info, **args):
|
||||||
|
return article_loader.load(self.id).get()
|
||||||
|
|
||||||
|
class Query(graphene.ObjectType):
|
||||||
|
all_reporters = DjangoConnectionField(ReporterType)
|
||||||
|
|
||||||
|
r = Reporter.objects.create(
|
||||||
|
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
|
||||||
|
)
|
||||||
|
|
||||||
|
Article.objects.create(
|
||||||
|
headline="Article Node 1",
|
||||||
|
pub_date=datetime.date.today(),
|
||||||
|
pub_date_time=datetime.datetime.now(),
|
||||||
|
reporter=r,
|
||||||
|
editor=r,
|
||||||
|
lang="es",
|
||||||
|
)
|
||||||
|
Article.objects.create(
|
||||||
|
headline="Article Node 2",
|
||||||
|
pub_date=datetime.date.today(),
|
||||||
|
pub_date_time=datetime.datetime.now(),
|
||||||
|
reporter=r,
|
||||||
|
editor=r,
|
||||||
|
lang="en",
|
||||||
|
)
|
||||||
|
|
||||||
|
schema = graphene.Schema(query=Query)
|
||||||
|
query = """
|
||||||
|
query ReporterPromiseConnectionQuery {
|
||||||
|
allReporters(first: 1) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
id
|
||||||
|
articles(first: 2) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
headline
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"allReporters": {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"node": {
|
||||||
|
"id": "UmVwb3J0ZXJUeXBlOjE=",
|
||||||
|
"articles": {
|
||||||
|
"edges": [
|
||||||
|
{"node": {"headline": "Article Node 1"}},
|
||||||
|
{"node": {"headline": "Article Node 2"}},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = async_to_sync(schema.execute_async)(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data == expected
|
||||||
|
|
||||||
|
|
||||||
def test_should_handle_inherited_choices():
|
def test_should_handle_inherited_choices():
|
||||||
class BaseModel(models.Model):
|
class BaseModel(models.Model):
|
||||||
choice_field = models.IntegerField(choices=((0, "zero"), (1, "one")))
|
choice_field = models.IntegerField(choices=((0, "zero"), (1, "one")))
|
||||||
|
@ -1071,6 +1191,7 @@ def test_proxy_model_support():
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_model_inheritance_support_reverse_relationships():
|
def test_model_inheritance_support_reverse_relationships():
|
||||||
|
@ -1411,6 +1532,7 @@ def test_should_resolve_get_queryset_connectionfields():
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_connection_should_limit_after_to_list_length():
|
def test_connection_should_limit_after_to_list_length():
|
||||||
|
@ -1448,6 +1570,7 @@ def test_connection_should_limit_after_to_list_length():
|
||||||
expected = {"allReporters": {"edges": []}}
|
expected = {"allReporters": {"edges": []}}
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result, variable_values={"after": after})
|
||||||
|
|
||||||
|
|
||||||
REPORTERS = [
|
REPORTERS = [
|
||||||
|
@ -1491,6 +1614,7 @@ def test_should_return_max_limit(graphene_settings):
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert len(result.data["allReporters"]["edges"]) == 4
|
assert len(result.data["allReporters"]["edges"]) == 4
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_have_next_page(graphene_settings):
|
def test_should_have_next_page(graphene_settings):
|
||||||
|
@ -1529,6 +1653,7 @@ def test_should_have_next_page(graphene_settings):
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert len(result.data["allReporters"]["edges"]) == 4
|
assert len(result.data["allReporters"]["edges"]) == 4
|
||||||
assert result.data["allReporters"]["pageInfo"]["hasNextPage"]
|
assert result.data["allReporters"]["pageInfo"]["hasNextPage"]
|
||||||
|
assert_async_result_equal(schema, query, result, variable_values={})
|
||||||
|
|
||||||
last_result = result.data["allReporters"]["pageInfo"]["endCursor"]
|
last_result = result.data["allReporters"]["pageInfo"]["endCursor"]
|
||||||
result2 = schema.execute(query, variable_values={"first": 4, "after": last_result})
|
result2 = schema.execute(query, variable_values={"first": 4, "after": last_result})
|
||||||
|
@ -1542,6 +1667,9 @@ def test_should_have_next_page(graphene_settings):
|
||||||
assert {to_global_id("ReporterType", reporter.id) for reporter in db_reporters} == {
|
assert {to_global_id("ReporterType", reporter.id) for reporter in db_reporters} == {
|
||||||
gql_reporter["node"]["id"] for gql_reporter in gql_reporters
|
gql_reporter["node"]["id"] for gql_reporter in gql_reporters
|
||||||
}
|
}
|
||||||
|
assert_async_result_equal(
|
||||||
|
schema, query, result2, variable_values={"first": 4, "after": last_result}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("max_limit", [100, 4])
|
@pytest.mark.parametrize("max_limit", [100, 4])
|
||||||
|
@ -1565,7 +1693,7 @@ class TestBackwardPagination:
|
||||||
|
|
||||||
def test_query_last(self, graphene_settings, max_limit):
|
def test_query_last(self, graphene_settings, max_limit):
|
||||||
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
|
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
|
||||||
query_last = """
|
query = """
|
||||||
query {
|
query {
|
||||||
allReporters(last: 3) {
|
allReporters(last: 3) {
|
||||||
edges {
|
edges {
|
||||||
|
@ -1577,16 +1705,17 @@ class TestBackwardPagination:
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = schema.execute(query_last)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert len(result.data["allReporters"]["edges"]) == 3
|
assert len(result.data["allReporters"]["edges"]) == 3
|
||||||
assert [
|
assert [
|
||||||
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
|
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
|
||||||
] == ["First 3", "First 4", "First 5"]
|
] == ["First 3", "First 4", "First 5"]
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
def test_query_first_and_last(self, graphene_settings, max_limit):
|
def test_query_first_and_last(self, graphene_settings, max_limit):
|
||||||
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
|
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
|
||||||
query_first_and_last = """
|
query = """
|
||||||
query {
|
query {
|
||||||
allReporters(first: 4, last: 3) {
|
allReporters(first: 4, last: 3) {
|
||||||
edges {
|
edges {
|
||||||
|
@ -1598,12 +1727,13 @@ class TestBackwardPagination:
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = schema.execute(query_first_and_last)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert len(result.data["allReporters"]["edges"]) == 3
|
assert len(result.data["allReporters"]["edges"]) == 3
|
||||||
assert [
|
assert [
|
||||||
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
|
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
|
||||||
] == ["First 1", "First 2", "First 3"]
|
] == ["First 1", "First 2", "First 3"]
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
def test_query_first_last_and_after(self, graphene_settings, max_limit):
|
def test_query_first_last_and_after(self, graphene_settings, max_limit):
|
||||||
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
|
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
|
||||||
|
@ -1629,6 +1759,9 @@ class TestBackwardPagination:
|
||||||
assert [
|
assert [
|
||||||
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
|
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
|
||||||
] == ["First 2", "First 3", "First 4"]
|
] == ["First 2", "First 3", "First 4"]
|
||||||
|
assert_async_result_equal(
|
||||||
|
schema, query_first_last_and_after, result, variable_values={"after": after}
|
||||||
|
)
|
||||||
|
|
||||||
def test_query_last_and_before(self, graphene_settings, max_limit):
|
def test_query_last_and_before(self, graphene_settings, max_limit):
|
||||||
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
|
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
|
||||||
|
@ -1650,6 +1783,7 @@ class TestBackwardPagination:
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert len(result.data["allReporters"]["edges"]) == 1
|
assert len(result.data["allReporters"]["edges"]) == 1
|
||||||
assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 5"
|
assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 5"
|
||||||
|
assert_async_result_equal(schema, query_first_last_and_after, result)
|
||||||
|
|
||||||
before = base64.b64encode(b"arrayconnection:5").decode()
|
before = base64.b64encode(b"arrayconnection:5").decode()
|
||||||
result = schema.execute(
|
result = schema.execute(
|
||||||
|
@ -1659,6 +1793,12 @@ class TestBackwardPagination:
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert len(result.data["allReporters"]["edges"]) == 1
|
assert len(result.data["allReporters"]["edges"]) == 1
|
||||||
assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 4"
|
assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 4"
|
||||||
|
assert_async_result_equal(
|
||||||
|
schema,
|
||||||
|
query_first_last_and_after,
|
||||||
|
result,
|
||||||
|
variable_values={"before": before},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_should_preserve_prefetch_related(django_assert_num_queries):
|
def test_should_preserve_prefetch_related(django_assert_num_queries):
|
||||||
|
@ -1713,6 +1853,7 @@ def test_should_preserve_prefetch_related(django_assert_num_queries):
|
||||||
with django_assert_num_queries(3):
|
with django_assert_num_queries(3):
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_should_preserve_annotations():
|
def test_should_preserve_annotations():
|
||||||
|
@ -1768,6 +1909,7 @@ def test_should_preserve_annotations():
|
||||||
}
|
}
|
||||||
assert result.data == expected, str(result.data)
|
assert result.data == expected, str(result.data)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_connection_should_enable_offset_filtering():
|
def test_connection_should_enable_offset_filtering():
|
||||||
|
@ -1807,6 +1949,7 @@ def test_connection_should_enable_offset_filtering():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_connection_should_enable_offset_filtering_higher_than_max_limit(
|
def test_connection_should_enable_offset_filtering_higher_than_max_limit(
|
||||||
|
@ -1851,6 +1994,7 @@ def test_connection_should_enable_offset_filtering_higher_than_max_limit(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
|
|
||||||
def test_connection_should_forbid_offset_filtering_with_before():
|
def test_connection_should_forbid_offset_filtering_with_before():
|
||||||
|
@ -1881,6 +2025,7 @@ def test_connection_should_forbid_offset_filtering_with_before():
|
||||||
expected_error = "You can't provide a `before` value at the same time as an `offset` value to properly paginate the `allReporters` connection."
|
expected_error = "You can't provide a `before` value at the same time as an `offset` value to properly paginate the `allReporters` connection."
|
||||||
assert len(result.errors) == 1
|
assert len(result.errors) == 1
|
||||||
assert result.errors[0].message == expected_error
|
assert result.errors[0].message == expected_error
|
||||||
|
assert_async_result_equal(schema, query, result, variable_values={"before": before})
|
||||||
|
|
||||||
|
|
||||||
def test_connection_should_allow_offset_filtering_with_after():
|
def test_connection_should_allow_offset_filtering_with_after():
|
||||||
|
@ -1923,6 +2068,7 @@ def test_connection_should_allow_offset_filtering_with_after():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result, variable_values={"after": after})
|
||||||
|
|
||||||
|
|
||||||
def test_connection_should_succeed_if_last_higher_than_number_of_objects():
|
def test_connection_should_succeed_if_last_higher_than_number_of_objects():
|
||||||
|
@ -1953,6 +2099,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
expected = {"allReporters": {"edges": []}}
|
expected = {"allReporters": {"edges": []}}
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result, variable_values={"last": 2})
|
||||||
|
|
||||||
Reporter.objects.create(first_name="John", last_name="Doe")
|
Reporter.objects.create(first_name="John", last_name="Doe")
|
||||||
Reporter.objects.create(first_name="Some", last_name="Guy")
|
Reporter.objects.create(first_name="Some", last_name="Guy")
|
||||||
|
@ -1970,6 +2117,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result, variable_values={"last": 2})
|
||||||
|
|
||||||
result = schema.execute(query, variable_values={"last": 4})
|
result = schema.execute(query, variable_values={"last": 4})
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
|
@ -1984,6 +2132,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result, variable_values={"last": 4})
|
||||||
|
|
||||||
result = schema.execute(query, variable_values={"last": 20})
|
result = schema.execute(query, variable_values={"last": 20})
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
|
@ -1998,6 +2147,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
assert_async_result_equal(schema, query, result, variable_values={"last": 20})
|
||||||
|
|
||||||
|
|
||||||
def test_should_query_nullable_foreign_key():
|
def test_should_query_nullable_foreign_key():
|
||||||
|
|
|
@ -16,6 +16,7 @@ from .utils import (
|
||||||
DJANGO_FILTER_INSTALLED,
|
DJANGO_FILTER_INSTALLED,
|
||||||
camelize,
|
camelize,
|
||||||
get_model_fields,
|
get_model_fields,
|
||||||
|
is_running_async,
|
||||||
is_valid_django_model,
|
is_valid_django_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -288,7 +289,11 @@ class DjangoObjectType(ObjectType):
|
||||||
def get_node(cls, info, id):
|
def get_node(cls, info, id):
|
||||||
queryset = cls.get_queryset(cls._meta.model.objects, info)
|
queryset = cls.get_queryset(cls._meta.model.objects, info)
|
||||||
try:
|
try:
|
||||||
|
if is_running_async():
|
||||||
|
return queryset.aget(pk=id)
|
||||||
|
|
||||||
return queryset.get(pk=id)
|
return queryset.get(pk=id)
|
||||||
|
|
||||||
except cls._meta.model.DoesNotExist:
|
except cls._meta.model.DoesNotExist:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,8 @@ from .utils import (
|
||||||
camelize,
|
camelize,
|
||||||
get_model_fields,
|
get_model_fields,
|
||||||
get_reverse_fields,
|
get_reverse_fields,
|
||||||
|
is_running_async,
|
||||||
|
is_sync_function,
|
||||||
is_valid_django_model,
|
is_valid_django_model,
|
||||||
maybe_queryset,
|
maybe_queryset,
|
||||||
)
|
)
|
||||||
|
@ -17,5 +19,7 @@ __all__ = [
|
||||||
"camelize",
|
"camelize",
|
||||||
"is_valid_django_model",
|
"is_valid_django_model",
|
||||||
"GraphQLTestCase",
|
"GraphQLTestCase",
|
||||||
|
"is_sync_function",
|
||||||
|
"is_running_async",
|
||||||
"bypass_get_queryset",
|
"bypass_get_queryset",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import inspect
|
import inspect
|
||||||
|
from asyncio import get_running_loop
|
||||||
|
|
||||||
import django
|
import django
|
||||||
from django.db import connection, models, transaction
|
from django.db import connection, models, transaction
|
||||||
|
@ -139,6 +140,21 @@ def set_rollback():
|
||||||
transaction.set_rollback(True)
|
transaction.set_rollback(True)
|
||||||
|
|
||||||
|
|
||||||
|
def is_running_async():
|
||||||
|
try:
|
||||||
|
get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def is_sync_function(func):
|
||||||
|
return not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
|
||||||
|
func
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def bypass_get_queryset(resolver):
|
def bypass_get_queryset(resolver):
|
||||||
"""
|
"""
|
||||||
Adds a bypass_get_queryset attribute to the resolver, which is used to
|
Adds a bypass_get_queryset attribute to the resolver, which is used to
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import traceback
|
||||||
|
from asyncio import coroutines, gather
|
||||||
|
|
||||||
from django.db import connection, transaction
|
from django.db import connection, transaction
|
||||||
from django.http import HttpResponse, HttpResponseNotAllowed
|
from django.http import HttpResponse, HttpResponseNotAllowed
|
||||||
from django.http.response import HttpResponseBadRequest
|
from django.http.response import HttpResponseBadRequest
|
||||||
from django.shortcuts import render
|
from django.shortcuts import render
|
||||||
from django.utils.decorators import method_decorator
|
from django.utils.decorators import classonlymethod, method_decorator
|
||||||
from django.views.decorators.csrf import ensure_csrf_cookie
|
from django.views.decorators.csrf import ensure_csrf_cookie
|
||||||
from django.views.generic import View
|
from django.views.generic import View
|
||||||
from graphql import (
|
from graphql import (
|
||||||
|
@ -431,3 +433,336 @@ class GraphQLView(View):
|
||||||
meta = request.META
|
meta = request.META
|
||||||
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
|
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
|
||||||
return content_type.split(";", 1)[0].lower()
|
return content_type.split(";", 1)[0].lower()
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncGraphQLView(GraphQLView):
|
||||||
|
schema = None
|
||||||
|
graphiql = False
|
||||||
|
middleware = None
|
||||||
|
root_value = None
|
||||||
|
pretty = False
|
||||||
|
batch = False
|
||||||
|
subscription_path = None
|
||||||
|
execution_context_class = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
schema=None,
|
||||||
|
middleware=None,
|
||||||
|
root_value=None,
|
||||||
|
graphiql=False,
|
||||||
|
pretty=False,
|
||||||
|
batch=False,
|
||||||
|
subscription_path=None,
|
||||||
|
execution_context_class=None,
|
||||||
|
):
|
||||||
|
if not schema:
|
||||||
|
schema = graphene_settings.SCHEMA
|
||||||
|
|
||||||
|
if middleware is None:
|
||||||
|
middleware = graphene_settings.MIDDLEWARE
|
||||||
|
|
||||||
|
self.schema = self.schema or schema
|
||||||
|
if middleware is not None:
|
||||||
|
if isinstance(middleware, MiddlewareManager):
|
||||||
|
self.middleware = middleware
|
||||||
|
else:
|
||||||
|
self.middleware = list(instantiate_middleware(middleware))
|
||||||
|
self.root_value = root_value
|
||||||
|
self.pretty = self.pretty or pretty
|
||||||
|
self.graphiql = self.graphiql or graphiql
|
||||||
|
self.batch = self.batch or batch
|
||||||
|
self.execution_context_class = execution_context_class
|
||||||
|
if subscription_path is None:
|
||||||
|
self.subscription_path = graphene_settings.SUBSCRIPTION_PATH
|
||||||
|
|
||||||
|
assert isinstance(
|
||||||
|
self.schema, Schema
|
||||||
|
), "A Schema is required to be provided to GraphQLView."
|
||||||
|
assert not all((graphiql, batch)), "Use either graphiql or batch processing"
|
||||||
|
|
||||||
|
# noinspection PyUnusedLocal
|
||||||
|
def get_root_value(self, request):
|
||||||
|
return self.root_value
|
||||||
|
|
||||||
|
def get_middleware(self, request):
|
||||||
|
return self.middleware
|
||||||
|
|
||||||
|
def get_context(self, request):
|
||||||
|
return request
|
||||||
|
|
||||||
|
@classonlymethod
|
||||||
|
def as_view(cls, **initkwargs):
|
||||||
|
view = super().as_view(**initkwargs)
|
||||||
|
view._is_coroutine = coroutines._is_coroutine
|
||||||
|
return view
|
||||||
|
|
||||||
|
async def dispatch(self, request, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
if request.method.lower() not in ("get", "post"):
|
||||||
|
raise HttpError(
|
||||||
|
HttpResponseNotAllowed(
|
||||||
|
["GET", "POST"], "GraphQL only supports GET and POST requests."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
data = self.parse_body(request)
|
||||||
|
show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
|
||||||
|
|
||||||
|
if show_graphiql:
|
||||||
|
return self.render_graphiql(
|
||||||
|
request,
|
||||||
|
# Dependency parameters.
|
||||||
|
whatwg_fetch_version=self.whatwg_fetch_version,
|
||||||
|
whatwg_fetch_sri=self.whatwg_fetch_sri,
|
||||||
|
react_version=self.react_version,
|
||||||
|
react_sri=self.react_sri,
|
||||||
|
react_dom_sri=self.react_dom_sri,
|
||||||
|
graphiql_version=self.graphiql_version,
|
||||||
|
graphiql_sri=self.graphiql_sri,
|
||||||
|
graphiql_css_sri=self.graphiql_css_sri,
|
||||||
|
subscriptions_transport_ws_version=self.subscriptions_transport_ws_version,
|
||||||
|
subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri,
|
||||||
|
graphiql_plugin_explorer_version=self.graphiql_plugin_explorer_version,
|
||||||
|
graphiql_plugin_explorer_sri=self.graphiql_plugin_explorer_sri,
|
||||||
|
# The SUBSCRIPTION_PATH setting.
|
||||||
|
subscription_path=self.subscription_path,
|
||||||
|
# GraphiQL headers tab,
|
||||||
|
graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED,
|
||||||
|
graphiql_should_persist_headers=graphene_settings.GRAPHIQL_SHOULD_PERSIST_HEADERS,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.batch:
|
||||||
|
responses = await gather(
|
||||||
|
*[self.get_response(request, entry) for entry in data]
|
||||||
|
)
|
||||||
|
result = "[{}]".format(
|
||||||
|
",".join([response[0] for response in responses])
|
||||||
|
)
|
||||||
|
status_code = (
|
||||||
|
responses
|
||||||
|
and max(responses, key=lambda response: response[1])[1]
|
||||||
|
or 200
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result, status_code = await self.get_response(
|
||||||
|
request, data, show_graphiql
|
||||||
|
)
|
||||||
|
|
||||||
|
return HttpResponse(
|
||||||
|
status=status_code, content=result, content_type="application/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
except HttpError as e:
|
||||||
|
response = e.response
|
||||||
|
response["Content-Type"] = "application/json"
|
||||||
|
response.content = self.json_encode(
|
||||||
|
request, {"errors": [self.format_error(e)]}
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def get_response(self, request, data, show_graphiql=False):
|
||||||
|
query, variables, operation_name, id = self.get_graphql_params(request, data)
|
||||||
|
|
||||||
|
execution_result = await self.execute_graphql_request(
|
||||||
|
request, data, query, variables, operation_name, show_graphiql
|
||||||
|
)
|
||||||
|
|
||||||
|
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
|
||||||
|
set_rollback()
|
||||||
|
|
||||||
|
status_code = 200
|
||||||
|
if execution_result:
|
||||||
|
response = {}
|
||||||
|
|
||||||
|
if execution_result.errors:
|
||||||
|
for e in execution_result.errors:
|
||||||
|
print(e)
|
||||||
|
traceback.print_tb(e.__traceback__)
|
||||||
|
set_rollback()
|
||||||
|
response["errors"] = [
|
||||||
|
self.format_error(e) for e in execution_result.errors
|
||||||
|
]
|
||||||
|
|
||||||
|
if execution_result.errors and any(
|
||||||
|
not getattr(e, "path", None) for e in execution_result.errors
|
||||||
|
):
|
||||||
|
status_code = 400
|
||||||
|
else:
|
||||||
|
response["data"] = execution_result.data
|
||||||
|
|
||||||
|
if self.batch:
|
||||||
|
response["id"] = id
|
||||||
|
response["status"] = status_code
|
||||||
|
|
||||||
|
result = self.json_encode(request, response, pretty=show_graphiql)
|
||||||
|
else:
|
||||||
|
result = None
|
||||||
|
|
||||||
|
return result, status_code
|
||||||
|
|
||||||
|
def render_graphiql(self, request, **data):
|
||||||
|
return render(request, self.graphiql_template, data)
|
||||||
|
|
||||||
|
def json_encode(self, request, d, pretty=False):
|
||||||
|
if not (self.pretty or pretty) and not request.GET.get("pretty"):
|
||||||
|
return json.dumps(d, separators=(",", ":"))
|
||||||
|
|
||||||
|
return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": "))
|
||||||
|
|
||||||
|
def parse_body(self, request):
|
||||||
|
content_type = self.get_content_type(request)
|
||||||
|
|
||||||
|
if content_type == "application/graphql":
|
||||||
|
return {"query": request.body.decode()}
|
||||||
|
|
||||||
|
elif content_type == "application/json":
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
body = request.body.decode("utf-8")
|
||||||
|
except Exception as e:
|
||||||
|
raise HttpError(HttpResponseBadRequest(str(e)))
|
||||||
|
|
||||||
|
try:
|
||||||
|
request_json = json.loads(body)
|
||||||
|
if self.batch:
|
||||||
|
assert isinstance(request_json, list), (
|
||||||
|
"Batch requests should receive a list, but received {}."
|
||||||
|
).format(repr(request_json))
|
||||||
|
assert (
|
||||||
|
len(request_json) > 0
|
||||||
|
), "Received an empty list in the batch request."
|
||||||
|
else:
|
||||||
|
assert isinstance(
|
||||||
|
request_json, dict
|
||||||
|
), "The received data is not a valid JSON query."
|
||||||
|
return request_json
|
||||||
|
except AssertionError as e:
|
||||||
|
raise HttpError(HttpResponseBadRequest(str(e)))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON."))
|
||||||
|
|
||||||
|
elif content_type in [
|
||||||
|
"application/x-www-form-urlencoded",
|
||||||
|
"multipart/form-data",
|
||||||
|
]:
|
||||||
|
return request.POST
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def execute_graphql_request(
|
||||||
|
self, request, data, query, variables, operation_name, show_graphiql=False
|
||||||
|
):
|
||||||
|
if not query:
|
||||||
|
if show_graphiql:
|
||||||
|
return None
|
||||||
|
raise HttpError(HttpResponseBadRequest("Must provide query string."))
|
||||||
|
|
||||||
|
try:
|
||||||
|
document = parse(query)
|
||||||
|
except Exception as e:
|
||||||
|
return ExecutionResult(errors=[e])
|
||||||
|
|
||||||
|
if request.method.lower() == "get":
|
||||||
|
operation_ast = get_operation_ast(document, operation_name)
|
||||||
|
if operation_ast and operation_ast.operation != OperationType.QUERY:
|
||||||
|
if show_graphiql:
|
||||||
|
return None
|
||||||
|
|
||||||
|
raise HttpError(
|
||||||
|
HttpResponseNotAllowed(
|
||||||
|
["POST"],
|
||||||
|
"Can only perform a {} operation from a POST request.".format(
|
||||||
|
operation_ast.operation.value
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
extra_options = {}
|
||||||
|
if self.execution_context_class:
|
||||||
|
extra_options["execution_context_class"] = self.execution_context_class
|
||||||
|
|
||||||
|
options = {
|
||||||
|
"source": query,
|
||||||
|
"root_value": self.get_root_value(request),
|
||||||
|
"variable_values": variables,
|
||||||
|
"operation_name": operation_name,
|
||||||
|
"context_value": self.get_context(request),
|
||||||
|
"middleware": self.get_middleware(request),
|
||||||
|
}
|
||||||
|
options.update(extra_options)
|
||||||
|
|
||||||
|
operation_ast = get_operation_ast(document, operation_name)
|
||||||
|
if (
|
||||||
|
operation_ast
|
||||||
|
and operation_ast.operation == OperationType.MUTATION
|
||||||
|
and (
|
||||||
|
graphene_settings.ATOMIC_MUTATIONS is True
|
||||||
|
or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True
|
||||||
|
)
|
||||||
|
):
|
||||||
|
with transaction.atomic():
|
||||||
|
result = await self.schema.execute_async(**options)
|
||||||
|
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
|
||||||
|
transaction.set_rollback(True)
|
||||||
|
return result
|
||||||
|
|
||||||
|
return await self.schema.execute_async(**options)
|
||||||
|
except Exception as e:
|
||||||
|
return ExecutionResult(errors=[e])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def can_display_graphiql(cls, request, data):
|
||||||
|
raw = "raw" in request.GET or "raw" in data
|
||||||
|
return not raw and cls.request_wants_html(request)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def request_wants_html(cls, request):
|
||||||
|
accepted = get_accepted_content_types(request)
|
||||||
|
accepted_length = len(accepted)
|
||||||
|
# the list will be ordered in preferred first - so we have to make
|
||||||
|
# sure the most preferred gets the highest number
|
||||||
|
html_priority = (
|
||||||
|
accepted_length - accepted.index("text/html")
|
||||||
|
if "text/html" in accepted
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
json_priority = (
|
||||||
|
accepted_length - accepted.index("application/json")
|
||||||
|
if "application/json" in accepted
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
return html_priority > json_priority
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_graphql_params(request, data):
|
||||||
|
query = request.GET.get("query") or data.get("query")
|
||||||
|
variables = request.GET.get("variables") or data.get("variables")
|
||||||
|
id = request.GET.get("id") or data.get("id")
|
||||||
|
|
||||||
|
if variables and isinstance(variables, str):
|
||||||
|
try:
|
||||||
|
variables = json.loads(variables)
|
||||||
|
except Exception:
|
||||||
|
raise HttpError(HttpResponseBadRequest("Variables are invalid JSON."))
|
||||||
|
|
||||||
|
operation_name = request.GET.get("operationName") or data.get("operationName")
|
||||||
|
if operation_name == "null":
|
||||||
|
operation_name = None
|
||||||
|
|
||||||
|
return query, variables, operation_name, id
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_error(error):
|
||||||
|
if isinstance(error, GraphQLError):
|
||||||
|
return error.formatted
|
||||||
|
|
||||||
|
return {"message": str(error)}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_content_type(request):
|
||||||
|
meta = request.META
|
||||||
|
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
|
||||||
|
return content_type.split(";", 1)[0].lower()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user