Merge branch 'master' into v3

This commit is contained in:
Jonathan Kim 2020-06-27 11:05:56 +01:00
commit 965ebdee13
10 changed files with 290 additions and 22 deletions

View File

@ -1,4 +1,4 @@
graphene>=2.1,<3 graphene>=2.1,<3
graphene-django>=2.1,<3 graphene-django>=2.1,<3
graphql-core>=2.1,<3 graphql-core>=2.1,<3
django==3.0.3 django==3.0.7

View File

@ -1,5 +1,5 @@
graphene>=2.1,<3 graphene>=2.1,<3
graphene-django>=2.1,<3 graphene-django>=2.1,<3
graphql-core>=2.1,<3 graphql-core>=2.1,<3
django==3.0.3 django==3.0.7
django-filter>=2 django-filter>=2

View File

@ -281,11 +281,15 @@ def convert_field_to_djangomodel(field, registry=None):
@convert_django_field.register(ArrayField) @convert_django_field.register(ArrayField)
def convert_postgres_array_to_list(field, registry=None): def convert_postgres_array_to_list(field, registry=None):
base_type = convert_django_field(field.base_field) inner_type = convert_django_field(field.base_field)
if not isinstance(base_type, (List, NonNull)): if not isinstance(inner_type, (List, NonNull)):
base_type = type(base_type) inner_type = (
NonNull(type(inner_type))
if inner_type.kwargs["required"]
else type(inner_type)
)
return List( return List(
base_type, inner_type,
description=get_django_field_description(field), description=get_django_field_description(field),
required=not field.null, required=not field.null,
) )
@ -303,7 +307,11 @@ def convert_postgres_field_to_string(field, registry=None):
def convert_postgres_range_to_string(field, registry=None): def convert_postgres_range_to_string(field, registry=None):
inner_type = convert_django_field(field.base_field) inner_type = convert_django_field(field.base_field)
if not isinstance(inner_type, (List, NonNull)): if not isinstance(inner_type, (List, NonNull)):
inner_type = type(inner_type) inner_type = (
NonNull(type(inner_type))
if inner_type.kwargs["required"]
else type(inner_type)
)
return List( return List(
inner_type, inner_type,
description=get_django_field_description(field), description=get_django_field_description(field),

View File

@ -1,4 +1,5 @@
import graphene import graphene
import pytest
from graphene.relay import Node from graphene.relay import Node
from graphene_django import DjangoConnectionField, DjangoObjectType from graphene_django import DjangoConnectionField, DjangoObjectType
@ -54,7 +55,10 @@ def test_should_query_field():
assert result.data == expected assert result.data == expected
def test_should_query_nested_field(): @pytest.mark.parametrize("max_limit", [None, 100])
def test_should_query_nested_field(graphene_settings, max_limit):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit
r1 = Reporter(last_name="ABA") r1 = Reporter(last_name="ABA")
r1.save() r1.save()
r2 = Reporter(last_name="Griffin") r2 = Reporter(last_name="Griffin")
@ -165,7 +169,10 @@ def test_should_query_list():
assert result.data == expected assert result.data == expected
def test_should_query_connection(): @pytest.mark.parametrize("max_limit", [None, 100])
def test_should_query_connection(graphene_settings, max_limit):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit
r1 = Reporter(last_name="ABA") r1 = Reporter(last_name="ABA")
r1.save() r1.save()
r2 = Reporter(last_name="Griffin") r2 = Reporter(last_name="Griffin")
@ -207,12 +214,16 @@ def test_should_query_connection():
) )
assert not result.errors assert not result.errors
assert result.data["allReporters"] == expected["allReporters"] assert result.data["allReporters"] == expected["allReporters"]
assert len(result.data["_debug"]["sql"]) == 2
assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"] assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query) query = str(Reporter.objects.all()[:1].query)
assert result.data["_debug"]["sql"][1]["rawSql"] == query assert result.data["_debug"]["sql"][1]["rawSql"] == query
def test_should_query_connectionfilter(): @pytest.mark.parametrize("max_limit", [None, 100])
def test_should_query_connectionfilter(graphene_settings, max_limit):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit
from ...filter import DjangoFilterConnectionField from ...filter import DjangoFilterConnectionField
r1 = Reporter(last_name="ABA") r1 = Reporter(last_name="ABA")
@ -257,6 +268,7 @@ def test_should_query_connectionfilter():
) )
assert not result.errors assert not result.errors
assert result.data["allReporters"] == expected["allReporters"] assert result.data["allReporters"] == expected["allReporters"]
assert len(result.data["_debug"]["sql"]) == 2
assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"] assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query) query = str(Reporter.objects.all()[:1].query)
assert result.data["_debug"]["sql"][1]["rawSql"] == query assert result.data["_debug"]["sql"][1]["rawSql"] == query

View File

@ -1,7 +1,10 @@
from functools import partial from functools import partial
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from graphql_relay.connection.arrayconnection import connection_from_array_slice from graphql_relay.connection.arrayconnection import (
connection_from_array_slice,
get_offset_with_default,
)
from promise import Promise from promise import Promise
from graphene import NonNull from graphene import NonNull
@ -127,24 +130,37 @@ class DjangoConnectionField(ConnectionField):
return connection._meta.node.get_queryset(queryset, info) return connection._meta.node.get_queryset(queryset, info)
@classmethod @classmethod
def resolve_connection(cls, connection, args, iterable): def resolve_connection(cls, connection, args, iterable, max_limit=None):
iterable = maybe_queryset(iterable) iterable = maybe_queryset(iterable)
if isinstance(iterable, QuerySet): if isinstance(iterable, QuerySet):
_len = iterable.count() list_length = iterable.count()
list_slice_length = (
min(max_limit, list_length) if max_limit is not None else list_length
)
else: else:
_len = len(iterable) list_length = len(iterable)
list_slice_length = (
min(max_limit, list_length) if max_limit is not None else list_length
)
after = get_offset_with_default(args.get("after"), -1) + 1
if max_limit is not None and args.get("first", None) == None:
args["first"] = max_limit
connection = connection_from_array_slice( connection = connection_from_array_slice(
iterable, iterable[after:],
args, args,
slice_start=0, slice_start=after,
array_length=_len, array_length=list_length,
array_slice_length=_len, array_slice_length=list_slice_length,
connection_type=partial(connection_adapter, connection), connection_type=partial(connection_adapter, connection),
edge_type=connection.Edge, edge_type=connection.Edge,
page_info_type=page_info_adapter, page_info_type=page_info_adapter,
) )
connection.iterable = iterable connection.iterable = iterable
connection.length = _len connection.length = list_length
return connection return connection
@classmethod @classmethod
@ -189,7 +205,9 @@ class DjangoConnectionField(ConnectionField):
# thus the iterable gets refiltered by resolve_queryset # thus the iterable gets refiltered by resolve_queryset
# but iterable might be promise # but iterable might be promise
iterable = queryset_resolver(connection, iterable, info, args) iterable = queryset_resolver(connection, iterable, info, args)
on_resolve = partial(cls.resolve_connection, connection, args) on_resolve = partial(
cls.resolve_connection, connection, args, max_limit=max_limit
)
if Promise.is_thenable(iterable): if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve) return Promise.resolve(iterable).then(on_resolve)

View File

@ -1,6 +1,7 @@
from collections import OrderedDict from collections import OrderedDict
from functools import partial from functools import partial
from django.core.exceptions import ValidationError
from graphene.types.argument import to_arguments from graphene.types.argument import to_arguments
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
from .utils import get_filtering_args_from_filterset, get_filterset_class from .utils import get_filtering_args_from_filterset, get_filterset_class
@ -59,7 +60,12 @@ class DjangoFilterConnectionField(DjangoConnectionField):
connection, iterable, info, args connection, iterable, info, args
) )
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
return filterset_class(data=filter_kwargs, queryset=qs, request=info.context).qs filterset = filterset_class(
data=filter_kwargs, queryset=qs, request=info.context
)
if filterset.form.is_valid():
return filterset.qs
raise ValidationError(filterset.form.errors.as_json())
def get_queryset_resolver(self): def get_queryset_resolver(self):
return partial( return partial(

View File

@ -412,6 +412,114 @@ def test_global_id_field_relation():
assert id_filter.field_class == GlobalIDFormField assert id_filter.field_class == GlobalIDFormField
def test_global_id_field_relation_with_filter():
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
filter_fields = ["first_name", "articles"]
class ArticleFilterNode(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
filter_fields = ["headline", "reporter"]
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
all_articles = DjangoFilterConnectionField(ArticleFilterNode)
reporter = Field(ReporterFilterNode)
article = Field(ArticleFilterNode)
r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
Article.objects.create(
headline="a1",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r1,
editor=r1,
)
Article.objects.create(
headline="a2",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r2,
editor=r2,
)
# Query articles created by the reporter `r1`
query = """
query {
allArticles (reporter: "UmVwb3J0ZXJGaWx0ZXJOb2RlOjE=") {
edges {
node {
id
}
}
}
}
"""
schema = Schema(query=Query)
result = schema.execute(query)
assert not result.errors
# We should only get back a single article
assert len(result.data["allArticles"]["edges"]) == 1
def test_global_id_field_relation_with_filter_not_valid_id():
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
filter_fields = ["first_name", "articles"]
class ArticleFilterNode(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
filter_fields = ["headline", "reporter"]
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
all_articles = DjangoFilterConnectionField(ArticleFilterNode)
reporter = Field(ReporterFilterNode)
article = Field(ArticleFilterNode)
r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
Article.objects.create(
headline="a1",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r1,
editor=r1,
)
Article.objects.create(
headline="a2",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r2,
editor=r2,
)
# Filter by the global ID that does not exist
query = """
query {
allArticles (reporter: "fake_global_id") {
edges {
node {
id
}
}
}
}
"""
schema = Schema(query=Query)
result = schema.execute(query)
assert "Invalid ID specified." in result.errors[0].message
def test_global_id_multiple_field_implicit(): def test_global_id_multiple_field_implicit():
field = DjangoFilterConnectionField(ReporterNode, fields=["pets"]) field = DjangoFilterConnectionField(ReporterNode, fields=["pets"])
filterset_class = field.filterset_class filterset_class = field.filterset_class

View File

@ -314,6 +314,14 @@ def test_should_postgres_array_convert_list():
) )
assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List) assert isinstance(field.type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type, graphene.NonNull)
assert field.type.of_type.of_type.of_type == graphene.String
field = assert_conversion(
ArrayField, graphene.List, models.CharField(max_length=100, null=True)
)
assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List)
assert field.type.of_type.of_type == graphene.String assert field.type.of_type.of_type == graphene.String
@ -325,6 +333,17 @@ def test_should_postgres_array_multiple_convert_list():
assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List) assert isinstance(field.type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type, graphene.List) assert isinstance(field.type.of_type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type.of_type, graphene.NonNull)
assert field.type.of_type.of_type.of_type.of_type == graphene.String
field = assert_conversion(
ArrayField,
graphene.List,
ArrayField(models.CharField(max_length=100, null=True)),
)
assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type, graphene.List)
assert field.type.of_type.of_type.of_type == graphene.String assert field.type.of_type.of_type.of_type == graphene.String
@ -345,7 +364,8 @@ def test_should_postgres_range_convert_list():
field = assert_conversion(IntegerRangeField, graphene.List) field = assert_conversion(IntegerRangeField, graphene.List)
assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List) assert isinstance(field.type.of_type, graphene.List)
assert field.type.of_type.of_type == graphene.Int assert isinstance(field.type.of_type.of_type, graphene.NonNull)
assert field.type.of_type.of_type.of_type == graphene.Int
def test_generate_enum_name(): def test_generate_enum_name():

View File

@ -1109,6 +1109,98 @@ def test_should_resolve_get_queryset_connectionfields():
assert result.data == expected assert result.data == expected
REPORTERS = [
dict(
first_name="First {}".format(i),
last_name="Last {}".format(i),
email="johndoe+{}@example.com".format(i),
a_choice=1,
)
for i in range(6)
]
def test_should_return_max_limit(graphene_settings):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = 4
reporters = [Reporter(**kwargs) for kwargs in REPORTERS]
Reporter.objects.bulk_create(reporters)
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
schema = graphene.Schema(query=Query)
query = """
query AllReporters {
allReporters {
edges {
node {
id
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 4
def test_should_have_next_page(graphene_settings):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = 4
reporters = [Reporter(**kwargs) for kwargs in REPORTERS]
Reporter.objects.bulk_create(reporters)
db_reporters = Reporter.objects.all()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
schema = graphene.Schema(query=Query)
query = """
query AllReporters($first: Int, $after: String) {
allReporters(first: $first, after: $after) {
pageInfo {
hasNextPage
endCursor
}
edges {
node {
id
}
}
}
}
"""
result = schema.execute(query, variable_values={})
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 4
assert result.data["allReporters"]["pageInfo"]["hasNextPage"]
last_result = result.data["allReporters"]["pageInfo"]["endCursor"]
result2 = schema.execute(query, variable_values=dict(first=4, after=last_result))
assert not result2.errors
assert len(result2.data["allReporters"]["edges"]) == 2
assert not result2.data["allReporters"]["pageInfo"]["hasNextPage"]
gql_reporters = (
result.data["allReporters"]["edges"] + result2.data["allReporters"]["edges"]
)
assert {to_global_id("ReporterType", reporter.id) for reporter in db_reporters} == {
gql_reporter["node"]["id"] for gql_reporter in gql_reporters
}
def test_should_preserve_prefetch_related(django_assert_num_queries): def test_should_preserve_prefetch_related(django_assert_num_queries):
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:

View File

@ -50,6 +50,10 @@ setup(
"Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: Implementation :: PyPy", "Programming Language :: Python :: Implementation :: PyPy",
"Framework :: Django",
"Framework :: Django :: 1.11",
"Framework :: Django :: 2.2",
"Framework :: Django :: 3.0",
], ],
keywords="api graphql protocol rest relay graphene", keywords="api graphql protocol rest relay graphene",
packages=find_packages(exclude=["tests"]), packages=find_packages(exclude=["tests"]),