diff --git a/graphene_django/fields.py b/graphene_django/fields.py index e6daa88..ad9bff3 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -154,32 +154,42 @@ class DjangoConnectionField(ConnectionField): enforce_first_or_last, root, info, - **args + **kwargs ): - first = args.get("first") - last = args.get("last") + first = kwargs.get("first") + last = kwargs.get("last") - if enforce_first_or_last: - assert first or last, ( - "You must provide a `first` or `last` value to properly paginate the `{}` connection." - ).format(info.field_name) + if not (first is None or first > 0): + raise ValueError( + "`first` argument must be positive, got `{first}`".format(**locals())) + if not (last is None or last > 0): + raise ValueError( + "`last` argument must be positive, got `{last}`".format(**locals())) + + if enforce_first_or_last and not (first or last): + raise ValueError( + "You must provide a `first` or `last` value " + "to properly paginate the `{info.field_name}` connection.".format(**locals())) if max_limit: if first: assert first <= max_limit, ( "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records." ).format(first, info.field_name, max_limit) - args["first"] = min(first, max_limit) + kwargs["first"] = min(first, max_limit) if last: assert last <= max_limit, ( "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records." ).format(last, info.field_name, max_limit) - args["last"] = min(last, max_limit) + kwargs["last"] = min(last, max_limit) - iterable = resolver(root, info, **args) - queryset = cls.resolve_queryset(connection, default_manager, info, args) - on_resolve = partial(cls.resolve_connection, connection, queryset, args) + if first is None and last is None: + kwargs['first'] = max_limit + + iterable = resolver(root, info, **kwargs) + queryset = cls.resolve_queryset(connection, default_manager, info, kwargs) + on_resolve = partial(cls.resolve_connection, connection, queryset, kwargs) if Promise.is_thenable(iterable): return Promise.resolve(iterable).then(on_resolve) diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index 1ffa0f4..0423e57 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -950,7 +950,7 @@ def test_filter_filterset_based_on_mixin(): } } } - """ + """ % reporter_1.email ) @@ -971,3 +971,41 @@ def test_filter_filterset_based_on_mixin(): assert not result.errors assert result.data == expected + + +def test_filter_with_union(): + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + filter_fields = ("first_name",) + + class Query(ObjectType): + all_reporters = DjangoFilterConnectionField(ReporterType) + + @classmethod + def resolve_all_reporters(cls, root, info, **kwargs): + ret = Reporter.objects.none() | Reporter.objects.filter(first_name="John") + + + Reporter.objects.create(first_name="John", last_name="Doe") + + schema = Schema(query=Query) + + query = """ + query NodeFilteringQuery { + allReporters(firstName: "abc") { + edges { + node { + firstName + } + } + } + } + """ + expected = {"allReporters": {"edges": []}} + + result = schema.execute(query) + + assert not result.errors + assert result.data == expected diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index f24f84b..15345a2 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -1,23 +1,23 @@ import base64 import datetime +import graphene import pytest from django.db import models -from django.utils.functional import SimpleLazyObject -from py.test import raises - from django.db.models import Q from graphql_relay import to_global_id import graphene +from django.utils.functional import SimpleLazyObject from graphene.relay import Node +from py.test import raises -from ..utils import DJANGO_FILTER_INSTALLED -from ..compat import MissingType, JSONField +from ..compat import JSONField, MissingType from ..fields import DjangoConnectionField -from ..types import DjangoObjectType from ..settings import graphene_settings -from .models import Article, CNNReporter, Reporter, Film, FilmDetails +from ..types import DjangoObjectType +from ..utils import DJANGO_FILTER_INSTALLED +from .models import Article, CNNReporter, Film, FilmDetails, Reporter pytestmark = pytest.mark.django_db @@ -661,7 +661,7 @@ def test_should_error_if_first_is_greater_than_max(): assert len(result.errors) == 1 assert str(result.errors[0]) == ( "Requesting 101 records on the `allReporters` connection " - "exceeds the `first` limit of 100 records." + "exceeds the limit of 100 records." ) assert result.data == expected @@ -702,7 +702,7 @@ def test_should_error_if_last_is_greater_than_max(): assert len(result.errors) == 1 assert str(result.errors[0]) == ( "Requesting 101 records on the `allReporters` connection " - "exceeds the `last` limit of 100 records." + "exceeds the limit of 100 records." ) assert result.data == expected