From 58d846e5e748367dfab91dc5b73afb563b12dc68 Mon Sep 17 00:00:00 2001 From: NateScarlet Date: Tue, 18 Dec 2018 18:11:57 +0800 Subject: [PATCH] Support use `first` and `last` at same time as `offset` --- graphene_django/fields.py | 70 +++++++++++---------- graphene_django/filter/tests/test_fields.py | 54 +++++++++++++--- graphene_django/tests/test_query.py | 20 +++--- 3 files changed, 92 insertions(+), 52 deletions(-) diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 1ecce45..7261baf 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -1,11 +1,10 @@ from functools import partial from django.db.models.query import QuerySet - +from graphene.relay import ConnectionField, PageInfo +from graphene.types import Field, List from promise import Promise -from graphene.types import Field, List -from graphene.relay import ConnectionField, PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice from .settings import graphene_settings @@ -103,39 +102,44 @@ class DjangoConnectionField(ConnectionField): @classmethod def connection_resolver( - cls, - resolver, - connection, - default_manager, - max_limit, - enforce_first_or_last, - root, - info, - **args - ): - first = args.get("first") - last = args.get("last") + cls, + resolver, + connection, + default_manager, + max_limit, + enforce_first_or_last, + root, + info, + **kwargs): + # pylint: disable=R0913,W0221 - if enforce_first_or_last: - assert first or last, ( - "You must provide a `first` or `last` value to properly paginate the `{}` connection." - ).format(info.field_name) + first = kwargs.get("first") + last = kwargs.get("last") + if not (first is None or first > 0): + raise ValueError( + "`first` argument must be positive, got `{first}`".format(**locals())) + if not (last is None or last > 0): + raise ValueError( + "`last` argument must be positive, got `{last}`".format(**locals())) + if enforce_first_or_last and not (first or last): + raise ValueError( + "You must provide a `first` or `last` value " + "to properly paginate the `{info.field_name}` connection.".format(**locals())) - if max_limit: - if first: - assert first <= max_limit, ( - "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records." - ).format(first, info.field_name, max_limit) - args["first"] = min(first, max_limit) + if not max_limit: + pass + elif first is None and last is None: + kwargs['first'] = max_limit + else: + count = min(i for i in (first, last) if i) + if count > max_limit: + raise ValueError(("Requesting {count} records " + "on the `{info.field_name}` connection " + "exceeds the limit of {max_limit} records.").format(**locals())) - if last: - assert last <= max_limit, ( - "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records." - ).format(last, info.field_name, max_limit) - args["last"] = min(last, max_limit) - - iterable = resolver(root, info, **args) - on_resolve = partial(cls.resolve_connection, connection, default_manager, args) + iterable = resolver(root, info, **kwargs) + on_resolve = partial(cls.resolve_connection, + connection, default_manager, kwargs) if Promise.is_thenable(iterable): return Promise.resolve(iterable).then(on_resolve) diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index f9ef0ae..f5d2fb7 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -1,17 +1,18 @@ from datetime import datetime import pytest - -from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String -from graphene.relay import Node -from graphene_django import DjangoObjectType -from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField -from graphene_django.tests.models import Article, Pet, Reporter -from graphene_django.utils import DJANGO_FILTER_INSTALLED - # for annotation test from django.db.models import TextField, Value from django.db.models.functions import Concat +from graphene import (Argument, Boolean, Field, Float, ObjectType, Schema, + String) +from graphene.relay import Node + +from graphene_django import DjangoObjectType +from graphene_django.forms import (GlobalIDFormField, + GlobalIDMultipleChoiceField) +from graphene_django.tests.models import Article, Pet, Reporter +from graphene_django.utils import DJANGO_FILTER_INSTALLED pytestmark = [] @@ -697,3 +698,40 @@ def test_annotation_is_perserved(): assert not result.errors assert result.data == expected + +def test_filter_with_union(): + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + filter_fields = ("first_name",) + + class Query(ObjectType): + all_reporters = DjangoFilterConnectionField(ReporterType) + + @classmethod + def resolve_all_reporters(cls, root, info, **kwargs): + ret = Reporter.objects.none() | Reporter.objects.filter(first_name="John") + + + Reporter.objects.create(first_name="John", last_name="Doe") + + schema = Schema(query=Query) + + query = """ + query NodeFilteringQuery { + allReporters(firstName: "abc") { + edges { + node { + firstName + } + } + } + } + """ + expected = {"allReporters": {"edges": []}} + + result = schema.execute(query) + + assert not result.errors + assert result.data == expected diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index 1716034..955b94a 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -1,21 +1,19 @@ import datetime +import graphene import pytest from django.db import models +from django.db.models import Q from django.utils.functional import SimpleLazyObject +from graphene.relay import Node from py.test import raises -from django.db.models import Q - -import graphene -from graphene.relay import Node - -from ..utils import DJANGO_FILTER_INSTALLED -from ..compat import MissingType, JSONField +from ..compat import JSONField, MissingType from ..fields import DjangoConnectionField -from ..types import DjangoObjectType from ..settings import graphene_settings -from .models import Article, CNNReporter, Reporter, Film, FilmDetails +from ..types import DjangoObjectType +from ..utils import DJANGO_FILTER_INSTALLED +from .models import Article, CNNReporter, Film, FilmDetails, Reporter pytestmark = pytest.mark.django_db @@ -603,7 +601,7 @@ def test_should_error_if_first_is_greater_than_max(): assert len(result.errors) == 1 assert str(result.errors[0]) == ( "Requesting 101 records on the `allReporters` connection " - "exceeds the `first` limit of 100 records." + "exceeds the limit of 100 records." ) assert result.data == expected @@ -644,7 +642,7 @@ def test_should_error_if_last_is_greater_than_max(): assert len(result.errors) == 1 assert str(result.errors[0]) == ( "Requesting 101 records on the `allReporters` connection " - "exceeds the `last` limit of 100 records." + "exceeds the limit of 100 records." ) assert result.data == expected