diff --git a/examples/cookbook-plain/requirements.txt b/examples/cookbook-plain/requirements.txt index 480f757..ae9ecc9 100644 --- a/examples/cookbook-plain/requirements.txt +++ b/examples/cookbook-plain/requirements.txt @@ -1,4 +1,4 @@ graphene>=2.1,<3 graphene-django>=2.1,<3 graphql-core>=2.1,<3 -django==3.0.3 +django==3.0.7 diff --git a/examples/cookbook/requirements.txt b/examples/cookbook/requirements.txt index 4375fcc..7ae2d89 100644 --- a/examples/cookbook/requirements.txt +++ b/examples/cookbook/requirements.txt @@ -1,5 +1,5 @@ graphene>=2.1,<3 graphene-django>=2.1,<3 graphql-core>=2.1,<3 -django==3.0.3 +django==3.0.7 django-filter>=2 diff --git a/graphene_django/converter.py b/graphene_django/converter.py index 1cdc8cc..860887f 100644 --- a/graphene_django/converter.py +++ b/graphene_django/converter.py @@ -281,11 +281,15 @@ def convert_field_to_djangomodel(field, registry=None): @convert_django_field.register(ArrayField) def convert_postgres_array_to_list(field, registry=None): - base_type = convert_django_field(field.base_field) - if not isinstance(base_type, (List, NonNull)): - base_type = type(base_type) + inner_type = convert_django_field(field.base_field) + if not isinstance(inner_type, (List, NonNull)): + inner_type = ( + NonNull(type(inner_type)) + if inner_type.kwargs["required"] + else type(inner_type) + ) return List( - base_type, + inner_type, description=get_django_field_description(field), required=not field.null, ) @@ -303,7 +307,11 @@ def convert_postgres_field_to_string(field, registry=None): def convert_postgres_range_to_string(field, registry=None): inner_type = convert_django_field(field.base_field) if not isinstance(inner_type, (List, NonNull)): - inner_type = type(inner_type) + inner_type = ( + NonNull(type(inner_type)) + if inner_type.kwargs["required"] + else type(inner_type) + ) return List( inner_type, description=get_django_field_description(field), diff --git a/graphene_django/debug/tests/test_query.py b/graphene_django/debug/tests/test_query.py index fcfcd71..d963b9c 100644 --- a/graphene_django/debug/tests/test_query.py +++ b/graphene_django/debug/tests/test_query.py @@ -1,4 +1,5 @@ import graphene +import pytest from graphene.relay import Node from graphene_django import DjangoConnectionField, DjangoObjectType @@ -54,7 +55,10 @@ def test_should_query_field(): assert result.data == expected -def test_should_query_nested_field(): +@pytest.mark.parametrize("max_limit", [None, 100]) +def test_should_query_nested_field(graphene_settings, max_limit): + graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit + r1 = Reporter(last_name="ABA") r1.save() r2 = Reporter(last_name="Griffin") @@ -165,7 +169,10 @@ def test_should_query_list(): assert result.data == expected -def test_should_query_connection(): +@pytest.mark.parametrize("max_limit", [None, 100]) +def test_should_query_connection(graphene_settings, max_limit): + graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit + r1 = Reporter(last_name="ABA") r1.save() r2 = Reporter(last_name="Griffin") @@ -207,12 +214,16 @@ def test_should_query_connection(): ) assert not result.errors assert result.data["allReporters"] == expected["allReporters"] + assert len(result.data["_debug"]["sql"]) == 2 assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"] query = str(Reporter.objects.all()[:1].query) assert result.data["_debug"]["sql"][1]["rawSql"] == query -def test_should_query_connectionfilter(): +@pytest.mark.parametrize("max_limit", [None, 100]) +def test_should_query_connectionfilter(graphene_settings, max_limit): + graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit + from ...filter import DjangoFilterConnectionField r1 = Reporter(last_name="ABA") @@ -257,6 +268,7 @@ def test_should_query_connectionfilter(): ) assert not result.errors assert result.data["allReporters"] == expected["allReporters"] + assert len(result.data["_debug"]["sql"]) == 2 assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"] query = str(Reporter.objects.all()[:1].query) assert result.data["_debug"]["sql"][1]["rawSql"] == query diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 7ad2040..415d792 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -1,7 +1,10 @@ from functools import partial from django.db.models.query import QuerySet -from graphql_relay.connection.arrayconnection import connection_from_array_slice +from graphql_relay.connection.arrayconnection import ( + connection_from_array_slice, + get_offset_with_default, +) from promise import Promise from graphene import NonNull @@ -127,24 +130,37 @@ class DjangoConnectionField(ConnectionField): return connection._meta.node.get_queryset(queryset, info) @classmethod - def resolve_connection(cls, connection, args, iterable): + def resolve_connection(cls, connection, args, iterable, max_limit=None): iterable = maybe_queryset(iterable) + if isinstance(iterable, QuerySet): - _len = iterable.count() + list_length = iterable.count() + list_slice_length = ( + min(max_limit, list_length) if max_limit is not None else list_length + ) else: - _len = len(iterable) + list_length = len(iterable) + list_slice_length = ( + min(max_limit, list_length) if max_limit is not None else list_length + ) + + after = get_offset_with_default(args.get("after"), -1) + 1 + + if max_limit is not None and args.get("first", None) == None: + args["first"] = max_limit + connection = connection_from_array_slice( - iterable, + iterable[after:], args, - slice_start=0, - array_length=_len, - array_slice_length=_len, + slice_start=after, + array_length=list_length, + array_slice_length=list_slice_length, connection_type=partial(connection_adapter, connection), edge_type=connection.Edge, page_info_type=page_info_adapter, ) connection.iterable = iterable - connection.length = _len + connection.length = list_length return connection @classmethod @@ -189,7 +205,9 @@ class DjangoConnectionField(ConnectionField): # thus the iterable gets refiltered by resolve_queryset # but iterable might be promise iterable = queryset_resolver(connection, iterable, info, args) - on_resolve = partial(cls.resolve_connection, connection, args) + on_resolve = partial( + cls.resolve_connection, connection, args, max_limit=max_limit + ) if Promise.is_thenable(iterable): return Promise.resolve(iterable).then(on_resolve) diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index a46a4b7..3a98e8d 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -1,6 +1,7 @@ from collections import OrderedDict from functools import partial +from django.core.exceptions import ValidationError from graphene.types.argument import to_arguments from ..fields import DjangoConnectionField from .utils import get_filtering_args_from_filterset, get_filterset_class @@ -59,7 +60,12 @@ class DjangoFilterConnectionField(DjangoConnectionField): connection, iterable, info, args ) filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} - return filterset_class(data=filter_kwargs, queryset=qs, request=info.context).qs + filterset = filterset_class( + data=filter_kwargs, queryset=qs, request=info.context + ) + if filterset.form.is_valid(): + return filterset.qs + raise ValidationError(filterset.form.errors.as_json()) def get_queryset_resolver(self): return partial( diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index d0460db..0f5f024 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -412,6 +412,114 @@ def test_global_id_field_relation(): assert id_filter.field_class == GlobalIDFormField +def test_global_id_field_relation_with_filter(): + class ReporterFilterNode(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + filter_fields = ["first_name", "articles"] + + class ArticleFilterNode(DjangoObjectType): + class Meta: + model = Article + interfaces = (Node,) + filter_fields = ["headline", "reporter"] + + class Query(ObjectType): + all_reporters = DjangoFilterConnectionField(ReporterFilterNode) + all_articles = DjangoFilterConnectionField(ArticleFilterNode) + reporter = Field(ReporterFilterNode) + article = Field(ArticleFilterNode) + + r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com") + r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com") + Article.objects.create( + headline="a1", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r1, + editor=r1, + ) + Article.objects.create( + headline="a2", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r2, + editor=r2, + ) + + # Query articles created by the reporter `r1` + query = """ + query { + allArticles (reporter: "UmVwb3J0ZXJGaWx0ZXJOb2RlOjE=") { + edges { + node { + id + } + } + } + } + """ + schema = Schema(query=Query) + result = schema.execute(query) + assert not result.errors + # We should only get back a single article + assert len(result.data["allArticles"]["edges"]) == 1 + + +def test_global_id_field_relation_with_filter_not_valid_id(): + class ReporterFilterNode(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + filter_fields = ["first_name", "articles"] + + class ArticleFilterNode(DjangoObjectType): + class Meta: + model = Article + interfaces = (Node,) + filter_fields = ["headline", "reporter"] + + class Query(ObjectType): + all_reporters = DjangoFilterConnectionField(ReporterFilterNode) + all_articles = DjangoFilterConnectionField(ArticleFilterNode) + reporter = Field(ReporterFilterNode) + article = Field(ArticleFilterNode) + + r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com") + r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com") + Article.objects.create( + headline="a1", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r1, + editor=r1, + ) + Article.objects.create( + headline="a2", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r2, + editor=r2, + ) + + # Filter by the global ID that does not exist + query = """ + query { + allArticles (reporter: "fake_global_id") { + edges { + node { + id + } + } + } + } + """ + schema = Schema(query=Query) + result = schema.execute(query) + assert "Invalid ID specified." in result.errors[0].message + + def test_global_id_multiple_field_implicit(): field = DjangoFilterConnectionField(ReporterNode, fields=["pets"]) filterset_class = field.filterset_class diff --git a/graphene_django/tests/test_converter.py b/graphene_django/tests/test_converter.py index 0d55fc4..501a4f8 100644 --- a/graphene_django/tests/test_converter.py +++ b/graphene_django/tests/test_converter.py @@ -314,6 +314,14 @@ def test_should_postgres_array_convert_list(): ) assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type.of_type, graphene.List) + assert isinstance(field.type.of_type.of_type, graphene.NonNull) + assert field.type.of_type.of_type.of_type == graphene.String + + field = assert_conversion( + ArrayField, graphene.List, models.CharField(max_length=100, null=True) + ) + assert isinstance(field.type, graphene.NonNull) + assert isinstance(field.type.of_type, graphene.List) assert field.type.of_type.of_type == graphene.String @@ -325,6 +333,17 @@ def test_should_postgres_array_multiple_convert_list(): assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type.of_type, graphene.List) assert isinstance(field.type.of_type.of_type, graphene.List) + assert isinstance(field.type.of_type.of_type.of_type, graphene.NonNull) + assert field.type.of_type.of_type.of_type.of_type == graphene.String + + field = assert_conversion( + ArrayField, + graphene.List, + ArrayField(models.CharField(max_length=100, null=True)), + ) + assert isinstance(field.type, graphene.NonNull) + assert isinstance(field.type.of_type, graphene.List) + assert isinstance(field.type.of_type.of_type, graphene.List) assert field.type.of_type.of_type.of_type == graphene.String @@ -345,7 +364,8 @@ def test_should_postgres_range_convert_list(): field = assert_conversion(IntegerRangeField, graphene.List) assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type.of_type, graphene.List) - assert field.type.of_type.of_type == graphene.Int + assert isinstance(field.type.of_type.of_type, graphene.NonNull) + assert field.type.of_type.of_type.of_type == graphene.Int def test_generate_enum_name(): diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index ad53230..f81cee0 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -1109,6 +1109,98 @@ def test_should_resolve_get_queryset_connectionfields(): assert result.data == expected +REPORTERS = [ + dict( + first_name="First {}".format(i), + last_name="Last {}".format(i), + email="johndoe+{}@example.com".format(i), + a_choice=1, + ) + for i in range(6) +] + + +def test_should_return_max_limit(graphene_settings): + graphene_settings.RELAY_CONNECTION_MAX_LIMIT = 4 + reporters = [Reporter(**kwargs) for kwargs in REPORTERS] + Reporter.objects.bulk_create(reporters) + + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + + schema = graphene.Schema(query=Query) + query = """ + query AllReporters { + allReporters { + edges { + node { + id + } + } + } + } + """ + + result = schema.execute(query) + assert not result.errors + assert len(result.data["allReporters"]["edges"]) == 4 + + +def test_should_have_next_page(graphene_settings): + graphene_settings.RELAY_CONNECTION_MAX_LIMIT = 4 + reporters = [Reporter(**kwargs) for kwargs in REPORTERS] + Reporter.objects.bulk_create(reporters) + db_reporters = Reporter.objects.all() + + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + + schema = graphene.Schema(query=Query) + query = """ + query AllReporters($first: Int, $after: String) { + allReporters(first: $first, after: $after) { + pageInfo { + hasNextPage + endCursor + } + edges { + node { + id + } + } + } + } + """ + + result = schema.execute(query, variable_values={}) + assert not result.errors + assert len(result.data["allReporters"]["edges"]) == 4 + assert result.data["allReporters"]["pageInfo"]["hasNextPage"] + + last_result = result.data["allReporters"]["pageInfo"]["endCursor"] + result2 = schema.execute(query, variable_values=dict(first=4, after=last_result)) + assert not result2.errors + assert len(result2.data["allReporters"]["edges"]) == 2 + assert not result2.data["allReporters"]["pageInfo"]["hasNextPage"] + gql_reporters = ( + result.data["allReporters"]["edges"] + result2.data["allReporters"]["edges"] + ) + + assert {to_global_id("ReporterType", reporter.id) for reporter in db_reporters} == { + gql_reporter["node"]["id"] for gql_reporter in gql_reporters + } + + def test_should_preserve_prefetch_related(django_assert_num_queries): class ReporterType(DjangoObjectType): class Meta: diff --git a/setup.py b/setup.py index 5e90b7e..2ddbe08 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,10 @@ setup( "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: Implementation :: PyPy", + "Framework :: Django", + "Framework :: Django :: 1.11", + "Framework :: Django :: 2.2", + "Framework :: Django :: 3.0", ], keywords="api graphql protocol rest relay graphene", packages=find_packages(exclude=["tests"]),