This commit is contained in:
Josh Warwick 2024-04-09 03:43:42 +03:00 committed by GitHub
commit a8e6217150
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 801 additions and 24 deletions

View File

@ -1,4 +1,8 @@
from graphene import Node
import asyncio
from asgiref.sync import sync_to_async
from graphene import Field, Node, String
from graphene_django.filter import DjangoFilterConnectionField
from graphene_django.types import DjangoObjectType
@ -6,12 +10,32 @@ from cookbook.recipes.models import Recipe, RecipeIngredient
class RecipeNode(DjangoObjectType):
async_field = String()
class Meta:
model = Recipe
interfaces = (Node,)
fields = "__all__"
filter_fields = ["title", "amounts"]
async def resolve_async_field(self, info):
await asyncio.sleep(2)
return "success"
class RecipeType(DjangoObjectType):
async_field = String()
class Meta:
model = Recipe
fields = "__all__"
filter_fields = ["title", "amounts"]
skip_registry = True
async def resolve_async_field(self, info):
await asyncio.sleep(2)
return "success"
class RecipeIngredientNode(DjangoObjectType):
class Meta:
@ -28,7 +52,13 @@ class RecipeIngredientNode(DjangoObjectType):
class Query:
recipe = Node.Field(RecipeNode)
raw_recipe = Field(RecipeType)
all_recipes = DjangoFilterConnectionField(RecipeNode)
recipeingredient = Node.Field(RecipeIngredientNode)
all_recipeingredients = DjangoFilterConnectionField(RecipeIngredientNode)
@staticmethod
@sync_to_async
def resolve_raw_recipe(self, info):
return Recipe.objects.first()

View File

@ -1,9 +1,10 @@
from django.conf.urls import url
from django.contrib import admin
from django.urls import re_path
from django.views.decorators.csrf import csrf_exempt
from graphene_django.views import GraphQLView
from graphene_django.views import AsyncGraphQLView
urlpatterns = [
url(r"^admin/", admin.site.urls),
url(r"^graphql$", GraphQLView.as_view(graphiql=True)),
re_path(r"^admin/", admin.site.urls),
re_path(r"^graphql$", csrf_exempt(AsyncGraphQLView.as_view(graphiql=True))),
]

View File

@ -1,5 +1,8 @@
from asgiref.sync import sync_to_async
from django.db import connections
from graphql.type.definition import GraphQLNonNull
from ..utils import is_running_async, is_sync_function
from .exception.formating import wrap_exception
from .sql.tracking import unwrap_cursor, wrap_cursor
from .types import DjangoDebug
@ -67,3 +70,28 @@ class DjangoDebugMiddleware:
return context.django_debug.on_resolve_error(e)
context.django_debug.add_result(result)
return result
class DjangoSyncRequiredMiddleware:
def resolve(self, next, root, info, **args):
parent_type = info.parent_type
return_type = info.return_type
if isinstance(parent_type, GraphQLNonNull):
parent_type = parent_type.of_type
if isinstance(return_type, GraphQLNonNull):
return_type = return_type.of_type
if any(
[
hasattr(parent_type, "graphene_type")
and hasattr(parent_type.graphene_type._meta, "model"),
hasattr(return_type, "graphene_type")
and hasattr(return_type.graphene_type._meta, "model"),
info.parent_type.name == "Mutation",
]
):
if is_sync_function(next) and is_running_async():
return sync_to_async(next)(root, info, **args)
return next(root, info, **args)

View File

@ -1,5 +1,6 @@
from functools import partial
from asgiref.sync import sync_to_async
from django.db.models.query import QuerySet
from graphql_relay import (
connection_from_array_slice,
@ -7,7 +8,6 @@ from graphql_relay import (
get_offset_with_default,
offset_to_cursor,
)
from promise import Promise
from graphene import Int, NonNull
from graphene.relay import ConnectionField
@ -15,7 +15,7 @@ from graphene.relay.connection import connection_adapter, page_info_adapter
from graphene.types import Field, List
from .settings import graphene_settings
from .utils import maybe_queryset
from .utils import is_running_async, is_sync_function, maybe_queryset
class DjangoListField(Field):
@ -49,11 +49,36 @@ class DjangoListField(Field):
def get_manager(self):
return self.model._default_manager
@staticmethod
@classmethod
def list_resolver(
django_object_type, resolver, default_manager, root, info, **args
cls, django_object_type, resolver, default_manager, root, info, **args
):
queryset = maybe_queryset(resolver(root, info, **args))
if is_running_async():
if is_sync_function(resolver):
resolver = sync_to_async(resolver)
iterable = resolver(root, info, **args)
if info.is_awaitable(iterable):
async def resolve_list_async(iterable):
queryset = maybe_queryset(await iterable)
if queryset is None:
queryset = maybe_queryset(default_manager)
if isinstance(queryset, QuerySet):
# Pass queryset to the DjangoObjectType get_queryset method
queryset = maybe_queryset(
await sync_to_async(django_object_type.get_queryset)(
queryset, info
)
)
return await sync_to_async(list)(queryset)
return resolve_list_async(iterable)
queryset = maybe_queryset(iterable)
if queryset is None:
queryset = maybe_queryset(default_manager)
@ -61,7 +86,7 @@ class DjangoListField(Field):
# Pass queryset to the DjangoObjectType get_queryset method
queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))
return queryset
return list(queryset)
def wrap_resolve(self, parent_resolver):
resolver = super().wrap_resolve(parent_resolver)
@ -235,20 +260,36 @@ class DjangoConnectionField(ConnectionField):
# eventually leads to DjangoObjectType's get_queryset (accepts queryset)
# or a resolve_foo (does not accept queryset)
if is_running_async():
if is_sync_function(resolver):
resolver = sync_to_async(resolver)
iterable = resolver(root, info, **args)
if info.is_awaitable(iterable):
async def resolve_connection_async(iterable):
iterable = await iterable
if iterable is None:
iterable = default_manager
iterable = await sync_to_async(queryset_resolver)(
connection, iterable, info, args
)
return await sync_to_async(cls.resolve_connection)(
connection, args, iterable, max_limit=max_limit
)
return resolve_connection_async(iterable)
if iterable is None:
iterable = default_manager
# thus the iterable gets refiltered by resolve_queryset
# but iterable might be promise
iterable = queryset_resolver(connection, iterable, info, args)
on_resolve = partial(
cls.resolve_connection, connection, args, max_limit=max_limit
)
if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve)
return on_resolve(iterable)
return cls.resolve_connection(connection, args, iterable, max_limit=max_limit)
def wrap_resolve(self, parent_resolver):
return partial(

View File

@ -1,6 +1,7 @@
from collections import OrderedDict
from functools import partial
from asgiref.sync import sync_to_async
from django.core.exceptions import ValidationError
from graphene.types.argument import to_arguments
@ -92,6 +93,18 @@ class DjangoFilterConnectionField(DjangoConnectionField):
qs = super().resolve_queryset(connection, iterable, info, args)
if info.is_awaitable(qs):
async def filter_async(qs):
filterset = filterset_class(
data=filter_kwargs(), queryset=await qs, request=info.context
)
if await sync_to_async(filterset.is_valid)():
return filterset.qs
raise ValidationError(filterset.form.errors.as_json())
return filter_async(qs)
filterset = filterset_class(
data=filter_kwargs(), queryset=qs, request=info.context
)

View File

@ -1,6 +1,7 @@
from collections import OrderedDict
from enum import Enum
from asgiref.sync import sync_to_async
from django.shortcuts import get_object_or_404
from rest_framework import serializers
@ -11,6 +12,7 @@ from graphene.types.mutation import MutationOptions
from graphene.types.objecttype import yank_fields_from_attrs
from ..types import ErrorType
from ..utils import is_running_async
from .serializer_converter import convert_serializer_field
@ -166,6 +168,17 @@ class SerializerMutation(ClientIDMutation):
kwargs = cls.get_serializer_kwargs(root, info, **input)
serializer = cls._meta.serializer_class(**kwargs)
if is_running_async():
async def perform_mutate_async():
if await sync_to_async(serializer.is_valid)():
return await sync_to_async(cls.perform_mutate)(serializer, info)
else:
errors = ErrorType.from_errors(serializer.errors)
return cls(errors=errors)
return perform_mutate_async()
if serializer.is_valid():
return cls.perform_mutate(serializer, info)
else:

View File

@ -0,0 +1,6 @@
from asgiref.sync import async_to_sync
def assert_async_result_equal(schema, query, result, **kwargs):
async_result = async_to_sync(schema.execute_async)(query, **kwargs)
assert async_result == result

View File

@ -2,6 +2,7 @@ import datetime
import re
import pytest
from asgiref.sync import async_to_sync
from django.db.models import Count, Model, Prefetch
from graphene import List, NonNull, ObjectType, Schema, String
@ -9,6 +10,7 @@ from graphene.relay import Node
from ..fields import DjangoConnectionField, DjangoListField
from ..types import DjangoObjectType
from .async_test_helper import assert_async_result_equal
from .models import (
Article as ArticleModel,
Film as FilmModel,
@ -82,6 +84,7 @@ class TestDjangoListField:
result = schema.execute(query)
assert_async_result_equal(schema, query, result)
assert not result.errors
assert result.data == {
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
@ -109,6 +112,7 @@ class TestDjangoListField:
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": []}
assert_async_result_equal(schema, query, result)
ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
@ -119,6 +123,7 @@ class TestDjangoListField:
assert result.data == {
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
}
assert_async_result_equal(schema, query, result)
def test_override_resolver(self):
class Reporter(DjangoObjectType):
@ -146,6 +151,35 @@ class TestDjangoListField:
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}
def test_override_resolver_async_execution(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name",)
class Query(ObjectType):
reporters = DjangoListField(Reporter)
def resolve_reporters(_, info):
return ReporterModel.objects.filter(first_name="Tara")
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
result = async_to_sync(schema.execute_async)(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}
@ -210,6 +244,7 @@ class TestDjangoListField:
{"firstName": "Debra", "articles": []},
]
}
assert_async_result_equal(schema, query, result)
def test_override_resolver_nested_list_field(self):
class Article(DjangoObjectType):
@ -268,6 +303,7 @@ class TestDjangoListField:
{"firstName": "Debra", "articles": []},
]
}
assert_async_result_equal(schema, query, result)
def test_same_type_nested_list_field(self):
class Person(DjangoObjectType):
@ -376,6 +412,7 @@ class TestDjangoListField:
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}
assert_async_result_equal(schema, query, result)
def test_resolve_list(self):
"""Resolving a plain list should work (and not call get_queryset)"""
@ -424,6 +461,53 @@ class TestDjangoListField:
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}
def test_resolve_list_async(self):
"""Resolving a plain list should work (and not call get_queryset) when running under async"""
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)
class Query(ObjectType):
reporters = DjangoListField(Reporter)
def resolve_reporters(_, info):
return [ReporterModel.objects.get(first_name="Debra")]
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = async_to_sync(schema.execute_async)(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}
def test_get_queryset_foreign_key(self):
class Article(DjangoObjectType):
class Meta:
@ -483,6 +567,7 @@ class TestDjangoListField:
{"firstName": "Debra", "articles": []},
]
}
assert_async_result_equal(schema, query, result)
def test_resolve_list_external_resolver(self):
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
@ -531,6 +616,53 @@ class TestDjangoListField:
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}
def test_resolve_list_external_resolver_async(self):
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)
def resolve_reporters(_, info):
return [ReporterModel.objects.get(first_name="Debra")]
class Query(ObjectType):
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = async_to_sync(schema.execute_async)(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}
def test_get_queryset_filter_external_resolver(self):
class Reporter(DjangoObjectType):
class Meta:
@ -575,6 +707,7 @@ class TestDjangoListField:
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}
assert_async_result_equal(schema, query, result)
def test_select_related_and_prefetch_related_are_respected(
self, django_assert_num_queries
@ -717,6 +850,7 @@ class TestDjangoListField:
r'SELECT .* FROM "tests_film" INNER JOIN "tests_film_reporters" .* LEFT OUTER JOIN "tests_filmdetails"',
captured.captured_queries[1]["sql"],
)
assert_async_result_equal(schema, query, result)
class TestDjangoConnectionField:

View File

@ -2,6 +2,7 @@ import base64
import datetime
import pytest
from asgiref.sync import async_to_sync
from django.db import models
from django.db.models import Q
from django.utils.functional import SimpleLazyObject
@ -15,6 +16,7 @@ from ..compat import IntegerRangeField, MissingType
from ..fields import DjangoConnectionField
from ..types import DjangoObjectType
from ..utils import DJANGO_FILTER_INSTALLED
from .async_test_helper import assert_async_result_equal
from .models import (
APNewsReporter,
Article,
@ -43,6 +45,7 @@ def test_should_query_only_fields():
"""
result = schema.execute(query)
assert not result.errors
assert_async_result_equal(schema, query, result)
def test_should_query_simplelazy_objects():
@ -68,6 +71,7 @@ def test_should_query_simplelazy_objects():
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporter": {"id": "1"}}
assert_async_result_equal(schema, query, result)
def test_should_query_wrapped_simplelazy_objects():
@ -93,6 +97,7 @@ def test_should_query_wrapped_simplelazy_objects():
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporter": {"id": "1"}}
assert_async_result_equal(schema, query, result)
def test_should_query_well():
@ -121,6 +126,7 @@ def test_should_query_well():
result = schema.execute(query)
assert not result.errors
assert result.data == expected
assert_async_result_equal(schema, query, result)
@pytest.mark.skipif(IntegerRangeField is MissingType, reason="RangeField should exist")
@ -175,6 +181,7 @@ def test_should_query_postgres_fields():
result = schema.execute(query)
assert not result.errors
assert result.data == expected
assert_async_result_equal(schema, query, result)
def test_should_node():
@ -256,6 +263,7 @@ def test_should_node():
result = schema.execute(query)
assert not result.errors
assert result.data == expected
assert_async_result_equal(schema, query, result)
def test_should_query_onetoone_fields():
@ -314,6 +322,7 @@ def test_should_query_onetoone_fields():
result = schema.execute(query)
assert not result.errors
assert result.data == expected
assert_async_result_equal(schema, query, result)
def test_should_query_connectionfields():
@ -352,6 +361,7 @@ def test_should_query_connectionfields():
"edges": [{"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}],
}
}
assert_async_result_equal(schema, query, result)
def test_should_keep_annotations():
@ -411,6 +421,7 @@ def test_should_keep_annotations():
"""
result = schema.execute(query)
assert not result.errors
assert_async_result_equal(schema, query, result)
@pytest.mark.skipif(
@ -492,6 +503,7 @@ def test_should_query_node_filtering():
result = schema.execute(query)
assert not result.errors
assert result.data == expected
assert_async_result_equal(schema, query, result)
@pytest.mark.skipif(
@ -537,6 +549,7 @@ def test_should_query_node_filtering_with_distinct_queryset():
result = schema.execute(query)
assert not result.errors
assert result.data == expected
assert_async_result_equal(schema, query, result)
@pytest.mark.skipif(
@ -626,6 +639,7 @@ def test_should_query_node_multiple_filtering():
result = schema.execute(query)
assert not result.errors
assert result.data == expected
assert_async_result_equal(schema, query, result)
def test_should_enforce_first_or_last(graphene_settings):
@ -666,6 +680,7 @@ def test_should_enforce_first_or_last(graphene_settings):
"paginate the `allReporters` connection.\n"
)
assert result.data == expected
assert_async_result_equal(schema, query, result)
def test_should_error_if_first_is_greater_than_max(graphene_settings):
@ -708,6 +723,7 @@ def test_should_error_if_first_is_greater_than_max(graphene_settings):
"exceeds the `first` limit of 100 records.\n"
)
assert result.data == expected
assert_async_result_equal(schema, query, result)
def test_should_error_if_last_is_greater_than_max(graphene_settings):
@ -750,6 +766,7 @@ def test_should_error_if_last_is_greater_than_max(graphene_settings):
"exceeds the `last` limit of 100 records.\n"
)
assert result.data == expected
assert_async_result_equal(schema, query, result)
def test_should_query_promise_connectionfields():
@ -785,6 +802,7 @@ def test_should_query_promise_connectionfields():
result = schema.execute(query)
assert not result.errors
assert result.data == expected
assert_async_result_equal(schema, query, result)
def test_should_query_connectionfields_with_last():
@ -822,6 +840,7 @@ def test_should_query_connectionfields_with_last():
result = schema.execute(query)
assert not result.errors
assert result.data == expected
assert_async_result_equal(schema, query, result)
def test_should_query_connectionfields_with_manager():
@ -863,6 +882,7 @@ def test_should_query_connectionfields_with_manager():
result = schema.execute(query)
assert not result.errors
assert result.data == expected
assert_async_result_equal(schema, query, result)
def test_should_query_dataloader_fields():
@ -965,6 +985,106 @@ def test_should_query_dataloader_fields():
assert result.data == expected
def test_should_query_dataloader_fields_async():
from promise import Promise
from promise.dataloader import DataLoader
def article_batch_load_fn(keys):
queryset = Article.objects.filter(reporter_id__in=keys)
return Promise.resolve(
[
[article for article in queryset if article.reporter_id == id]
for id in keys
]
)
article_loader = DataLoader(article_batch_load_fn)
class ArticleType(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
fields = "__all__"
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
articles = DjangoConnectionField(ArticleType)
def resolve_articles(self, info, **args):
return article_loader.load(self.id).get()
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
r = Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
Article.objects.create(
headline="Article Node 1",
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
reporter=r,
editor=r,
lang="es",
)
Article.objects.create(
headline="Article Node 2",
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
reporter=r,
editor=r,
lang="en",
)
schema = graphene.Schema(query=Query)
query = """
query ReporterPromiseConnectionQuery {
allReporters(first: 1) {
edges {
node {
id
articles(first: 2) {
edges {
node {
headline
}
}
}
}
}
}
}
"""
expected = {
"allReporters": {
"edges": [
{
"node": {
"id": "UmVwb3J0ZXJUeXBlOjE=",
"articles": {
"edges": [
{"node": {"headline": "Article Node 1"}},
{"node": {"headline": "Article Node 2"}},
]
},
}
}
]
}
}
result = async_to_sync(schema.execute_async)(query)
assert not result.errors
assert result.data == expected
def test_should_handle_inherited_choices():
class BaseModel(models.Model):
choice_field = models.IntegerField(choices=((0, "zero"), (1, "one")))
@ -1071,6 +1191,7 @@ def test_proxy_model_support():
result = schema.execute(query)
assert not result.errors
assert result.data == expected
assert_async_result_equal(schema, query, result)
def test_model_inheritance_support_reverse_relationships():
@ -1411,6 +1532,7 @@ def test_should_resolve_get_queryset_connectionfields():
result = schema.execute(query)
assert not result.errors
assert result.data == expected
assert_async_result_equal(schema, query, result)
def test_connection_should_limit_after_to_list_length():
@ -1448,6 +1570,7 @@ def test_connection_should_limit_after_to_list_length():
expected = {"allReporters": {"edges": []}}
assert not result.errors
assert result.data == expected
assert_async_result_equal(schema, query, result, variable_values={"after": after})
REPORTERS = [
@ -1491,6 +1614,7 @@ def test_should_return_max_limit(graphene_settings):
result = schema.execute(query)
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 4
assert_async_result_equal(schema, query, result)
def test_should_have_next_page(graphene_settings):
@ -1529,6 +1653,7 @@ def test_should_have_next_page(graphene_settings):
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 4
assert result.data["allReporters"]["pageInfo"]["hasNextPage"]
assert_async_result_equal(schema, query, result, variable_values={})
last_result = result.data["allReporters"]["pageInfo"]["endCursor"]
result2 = schema.execute(query, variable_values={"first": 4, "after": last_result})
@ -1542,6 +1667,9 @@ def test_should_have_next_page(graphene_settings):
assert {to_global_id("ReporterType", reporter.id) for reporter in db_reporters} == {
gql_reporter["node"]["id"] for gql_reporter in gql_reporters
}
assert_async_result_equal(
schema, query, result2, variable_values={"first": 4, "after": last_result}
)
@pytest.mark.parametrize("max_limit", [100, 4])
@ -1565,7 +1693,7 @@ class TestBackwardPagination:
def test_query_last(self, graphene_settings, max_limit):
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
query_last = """
query = """
query {
allReporters(last: 3) {
edges {
@ -1577,16 +1705,17 @@ class TestBackwardPagination:
}
"""
result = schema.execute(query_last)
result = schema.execute(query)
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 3
assert [
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 3", "First 4", "First 5"]
assert_async_result_equal(schema, query, result)
def test_query_first_and_last(self, graphene_settings, max_limit):
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
query_first_and_last = """
query = """
query {
allReporters(first: 4, last: 3) {
edges {
@ -1598,12 +1727,13 @@ class TestBackwardPagination:
}
"""
result = schema.execute(query_first_and_last)
result = schema.execute(query)
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 3
assert [
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 1", "First 2", "First 3"]
assert_async_result_equal(schema, query, result)
def test_query_first_last_and_after(self, graphene_settings, max_limit):
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
@ -1629,6 +1759,9 @@ class TestBackwardPagination:
assert [
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 2", "First 3", "First 4"]
assert_async_result_equal(
schema, query_first_last_and_after, result, variable_values={"after": after}
)
def test_query_last_and_before(self, graphene_settings, max_limit):
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
@ -1650,6 +1783,7 @@ class TestBackwardPagination:
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 1
assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 5"
assert_async_result_equal(schema, query_first_last_and_after, result)
before = base64.b64encode(b"arrayconnection:5").decode()
result = schema.execute(
@ -1659,6 +1793,12 @@ class TestBackwardPagination:
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 1
assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 4"
assert_async_result_equal(
schema,
query_first_last_and_after,
result,
variable_values={"before": before},
)
def test_should_preserve_prefetch_related(django_assert_num_queries):
@ -1713,6 +1853,7 @@ def test_should_preserve_prefetch_related(django_assert_num_queries):
with django_assert_num_queries(3):
result = schema.execute(query)
assert not result.errors
assert_async_result_equal(schema, query, result)
def test_should_preserve_annotations():
@ -1768,6 +1909,7 @@ def test_should_preserve_annotations():
}
assert result.data == expected, str(result.data)
assert not result.errors
assert_async_result_equal(schema, query, result)
def test_connection_should_enable_offset_filtering():
@ -1807,6 +1949,7 @@ def test_connection_should_enable_offset_filtering():
}
}
assert result.data == expected
assert_async_result_equal(schema, query, result)
def test_connection_should_enable_offset_filtering_higher_than_max_limit(
@ -1851,6 +1994,7 @@ def test_connection_should_enable_offset_filtering_higher_than_max_limit(
}
}
assert result.data == expected
assert_async_result_equal(schema, query, result)
def test_connection_should_forbid_offset_filtering_with_before():
@ -1881,6 +2025,7 @@ def test_connection_should_forbid_offset_filtering_with_before():
expected_error = "You can't provide a `before` value at the same time as an `offset` value to properly paginate the `allReporters` connection."
assert len(result.errors) == 1
assert result.errors[0].message == expected_error
assert_async_result_equal(schema, query, result, variable_values={"before": before})
def test_connection_should_allow_offset_filtering_with_after():
@ -1923,6 +2068,7 @@ def test_connection_should_allow_offset_filtering_with_after():
}
}
assert result.data == expected
assert_async_result_equal(schema, query, result, variable_values={"after": after})
def test_connection_should_succeed_if_last_higher_than_number_of_objects():
@ -1953,6 +2099,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
assert not result.errors
expected = {"allReporters": {"edges": []}}
assert result.data == expected
assert_async_result_equal(schema, query, result, variable_values={"last": 2})
Reporter.objects.create(first_name="John", last_name="Doe")
Reporter.objects.create(first_name="Some", last_name="Guy")
@ -1970,6 +2117,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
}
}
assert result.data == expected
assert_async_result_equal(schema, query, result, variable_values={"last": 2})
result = schema.execute(query, variable_values={"last": 4})
assert not result.errors
@ -1984,6 +2132,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
}
}
assert result.data == expected
assert_async_result_equal(schema, query, result, variable_values={"last": 4})
result = schema.execute(query, variable_values={"last": 20})
assert not result.errors
@ -1998,6 +2147,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
}
}
assert result.data == expected
assert_async_result_equal(schema, query, result, variable_values={"last": 20})
def test_should_query_nullable_foreign_key():

View File

@ -16,6 +16,7 @@ from .utils import (
DJANGO_FILTER_INSTALLED,
camelize,
get_model_fields,
is_running_async,
is_valid_django_model,
)
@ -288,7 +289,11 @@ class DjangoObjectType(ObjectType):
def get_node(cls, info, id):
queryset = cls.get_queryset(cls._meta.model.objects, info)
try:
if is_running_async():
return queryset.aget(pk=id)
return queryset.get(pk=id)
except cls._meta.model.DoesNotExist:
return None

View File

@ -5,6 +5,8 @@ from .utils import (
camelize,
get_model_fields,
get_reverse_fields,
is_running_async,
is_sync_function,
is_valid_django_model,
maybe_queryset,
)
@ -17,5 +19,7 @@ __all__ = [
"camelize",
"is_valid_django_model",
"GraphQLTestCase",
"is_sync_function",
"is_running_async",
"bypass_get_queryset",
]

View File

@ -1,4 +1,5 @@
import inspect
from asyncio import get_running_loop
import django
from django.db import connection, models, transaction
@ -139,6 +140,21 @@ def set_rollback():
transaction.set_rollback(True)
def is_running_async():
try:
get_running_loop()
except RuntimeError:
return False
else:
return True
def is_sync_function(func):
return not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
func
)
def bypass_get_queryset(resolver):
"""
Adds a bypass_get_queryset attribute to the resolver, which is used to

View File

@ -1,12 +1,14 @@
import inspect
import json
import re
import traceback
from asyncio import coroutines, gather
from django.db import connection, transaction
from django.http import HttpResponse, HttpResponseNotAllowed
from django.http.response import HttpResponseBadRequest
from django.shortcuts import render
from django.utils.decorators import method_decorator
from django.utils.decorators import classonlymethod, method_decorator
from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.generic import View
from graphql import (
@ -431,3 +433,336 @@ class GraphQLView(View):
meta = request.META
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
return content_type.split(";", 1)[0].lower()
class AsyncGraphQLView(GraphQLView):
schema = None
graphiql = False
middleware = None
root_value = None
pretty = False
batch = False
subscription_path = None
execution_context_class = None
def __init__(
self,
schema=None,
middleware=None,
root_value=None,
graphiql=False,
pretty=False,
batch=False,
subscription_path=None,
execution_context_class=None,
):
if not schema:
schema = graphene_settings.SCHEMA
if middleware is None:
middleware = graphene_settings.MIDDLEWARE
self.schema = self.schema or schema
if middleware is not None:
if isinstance(middleware, MiddlewareManager):
self.middleware = middleware
else:
self.middleware = list(instantiate_middleware(middleware))
self.root_value = root_value
self.pretty = self.pretty or pretty
self.graphiql = self.graphiql or graphiql
self.batch = self.batch or batch
self.execution_context_class = execution_context_class
if subscription_path is None:
self.subscription_path = graphene_settings.SUBSCRIPTION_PATH
assert isinstance(
self.schema, Schema
), "A Schema is required to be provided to GraphQLView."
assert not all((graphiql, batch)), "Use either graphiql or batch processing"
# noinspection PyUnusedLocal
def get_root_value(self, request):
return self.root_value
def get_middleware(self, request):
return self.middleware
def get_context(self, request):
return request
@classonlymethod
def as_view(cls, **initkwargs):
view = super().as_view(**initkwargs)
view._is_coroutine = coroutines._is_coroutine
return view
async def dispatch(self, request, *args, **kwargs):
try:
if request.method.lower() not in ("get", "post"):
raise HttpError(
HttpResponseNotAllowed(
["GET", "POST"], "GraphQL only supports GET and POST requests."
)
)
data = self.parse_body(request)
show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
if show_graphiql:
return self.render_graphiql(
request,
# Dependency parameters.
whatwg_fetch_version=self.whatwg_fetch_version,
whatwg_fetch_sri=self.whatwg_fetch_sri,
react_version=self.react_version,
react_sri=self.react_sri,
react_dom_sri=self.react_dom_sri,
graphiql_version=self.graphiql_version,
graphiql_sri=self.graphiql_sri,
graphiql_css_sri=self.graphiql_css_sri,
subscriptions_transport_ws_version=self.subscriptions_transport_ws_version,
subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri,
graphiql_plugin_explorer_version=self.graphiql_plugin_explorer_version,
graphiql_plugin_explorer_sri=self.graphiql_plugin_explorer_sri,
# The SUBSCRIPTION_PATH setting.
subscription_path=self.subscription_path,
# GraphiQL headers tab,
graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED,
graphiql_should_persist_headers=graphene_settings.GRAPHIQL_SHOULD_PERSIST_HEADERS,
)
if self.batch:
responses = await gather(
*[self.get_response(request, entry) for entry in data]
)
result = "[{}]".format(
",".join([response[0] for response in responses])
)
status_code = (
responses
and max(responses, key=lambda response: response[1])[1]
or 200
)
else:
result, status_code = await self.get_response(
request, data, show_graphiql
)
return HttpResponse(
status=status_code, content=result, content_type="application/json"
)
except HttpError as e:
response = e.response
response["Content-Type"] = "application/json"
response.content = self.json_encode(
request, {"errors": [self.format_error(e)]}
)
return response
async def get_response(self, request, data, show_graphiql=False):
query, variables, operation_name, id = self.get_graphql_params(request, data)
execution_result = await self.execute_graphql_request(
request, data, query, variables, operation_name, show_graphiql
)
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
set_rollback()
status_code = 200
if execution_result:
response = {}
if execution_result.errors:
for e in execution_result.errors:
print(e)
traceback.print_tb(e.__traceback__)
set_rollback()
response["errors"] = [
self.format_error(e) for e in execution_result.errors
]
if execution_result.errors and any(
not getattr(e, "path", None) for e in execution_result.errors
):
status_code = 400
else:
response["data"] = execution_result.data
if self.batch:
response["id"] = id
response["status"] = status_code
result = self.json_encode(request, response, pretty=show_graphiql)
else:
result = None
return result, status_code
def render_graphiql(self, request, **data):
return render(request, self.graphiql_template, data)
def json_encode(self, request, d, pretty=False):
if not (self.pretty or pretty) and not request.GET.get("pretty"):
return json.dumps(d, separators=(",", ":"))
return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": "))
def parse_body(self, request):
content_type = self.get_content_type(request)
if content_type == "application/graphql":
return {"query": request.body.decode()}
elif content_type == "application/json":
# noinspection PyBroadException
try:
body = request.body.decode("utf-8")
except Exception as e:
raise HttpError(HttpResponseBadRequest(str(e)))
try:
request_json = json.loads(body)
if self.batch:
assert isinstance(request_json, list), (
"Batch requests should receive a list, but received {}."
).format(repr(request_json))
assert (
len(request_json) > 0
), "Received an empty list in the batch request."
else:
assert isinstance(
request_json, dict
), "The received data is not a valid JSON query."
return request_json
except AssertionError as e:
raise HttpError(HttpResponseBadRequest(str(e)))
except (TypeError, ValueError):
raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON."))
elif content_type in [
"application/x-www-form-urlencoded",
"multipart/form-data",
]:
return request.POST
return {}
async def execute_graphql_request(
self, request, data, query, variables, operation_name, show_graphiql=False
):
if not query:
if show_graphiql:
return None
raise HttpError(HttpResponseBadRequest("Must provide query string."))
try:
document = parse(query)
except Exception as e:
return ExecutionResult(errors=[e])
if request.method.lower() == "get":
operation_ast = get_operation_ast(document, operation_name)
if operation_ast and operation_ast.operation != OperationType.QUERY:
if show_graphiql:
return None
raise HttpError(
HttpResponseNotAllowed(
["POST"],
"Can only perform a {} operation from a POST request.".format(
operation_ast.operation.value
),
)
)
try:
extra_options = {}
if self.execution_context_class:
extra_options["execution_context_class"] = self.execution_context_class
options = {
"source": query,
"root_value": self.get_root_value(request),
"variable_values": variables,
"operation_name": operation_name,
"context_value": self.get_context(request),
"middleware": self.get_middleware(request),
}
options.update(extra_options)
operation_ast = get_operation_ast(document, operation_name)
if (
operation_ast
and operation_ast.operation == OperationType.MUTATION
and (
graphene_settings.ATOMIC_MUTATIONS is True
or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True
)
):
with transaction.atomic():
result = await self.schema.execute_async(**options)
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
transaction.set_rollback(True)
return result
return await self.schema.execute_async(**options)
except Exception as e:
return ExecutionResult(errors=[e])
@classmethod
def can_display_graphiql(cls, request, data):
raw = "raw" in request.GET or "raw" in data
return not raw and cls.request_wants_html(request)
@classmethod
def request_wants_html(cls, request):
accepted = get_accepted_content_types(request)
accepted_length = len(accepted)
# the list will be ordered in preferred first - so we have to make
# sure the most preferred gets the highest number
html_priority = (
accepted_length - accepted.index("text/html")
if "text/html" in accepted
else 0
)
json_priority = (
accepted_length - accepted.index("application/json")
if "application/json" in accepted
else 0
)
return html_priority > json_priority
@staticmethod
def get_graphql_params(request, data):
query = request.GET.get("query") or data.get("query")
variables = request.GET.get("variables") or data.get("variables")
id = request.GET.get("id") or data.get("id")
if variables and isinstance(variables, str):
try:
variables = json.loads(variables)
except Exception:
raise HttpError(HttpResponseBadRequest("Variables are invalid JSON."))
operation_name = request.GET.get("operationName") or data.get("operationName")
if operation_name == "null":
operation_name = None
return query, variables, operation_name, id
@staticmethod
def format_error(error):
if isinstance(error, GraphQLError):
return error.formatted
return {"message": str(error)}
@staticmethod
def get_content_type(request):
meta = request.META
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
return content_type.split(";", 1)[0].lower()

View File

@ -22,6 +22,7 @@ tests_require = [
"pytz",
"django-filter>=22.1",
"pytest-django>=4.5.2",
"pytest-asyncio>=0.16,<2",
] + rest_framework_require