From aa30750d395dc1cc5f550d933506d978c20d285e Mon Sep 17 00:00:00 2001 From: Jonathan Kim Date: Sun, 7 Jul 2019 20:11:27 +0100 Subject: [PATCH] Bugfix: Correct filter types for DjangoFilterConnectionFields (#682) * Get form field from Django model before defaulting to django-filter * Add test * Cleanup some flake8 warnings and pytest warnings * Run isort and add black compatible config --- graphene_django/filter/tests/test_fields.py | 73 +++++++++++++++++---- graphene_django/filter/utils.py | 19 +++++- setup.cfg | 5 ++ 3 files changed, 85 insertions(+), 12 deletions(-) diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index b9bc599..d163ff3 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -1,18 +1,17 @@ from datetime import datetime +from textwrap import dedent import pytest +from django.db.models import TextField, Value +from django.db.models.functions import Concat -from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String +from graphene import Argument, Boolean, Field, Float, ObjectType, Schema, String from graphene.relay import Node from graphene_django import DjangoObjectType from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField from graphene_django.tests.models import Article, Pet, Reporter from graphene_django.utils import DJANGO_FILTER_INSTALLED -# for annotation test -from django.db.models import TextField, Value -from django.db.models.functions import Concat - pytestmark = [] if DJANGO_FILTER_INSTALLED: @@ -183,7 +182,7 @@ def test_filter_shortcut_filterset_context(): } """ schema = Schema(query=Query) - result = schema.execute(query, context_value=context()) + result = schema.execute(query, context=context()) assert not result.errors assert len(result.data["contextArticles"]["edges"]) == 1 @@ -462,15 +461,15 @@ def test_filter_filterset_related_results_with_filter(): class Query(ObjectType): all_reporters = DjangoFilterConnectionField(ReporterFilterNode) - r1 = Reporter.objects.create( + Reporter.objects.create( first_name="A test user", last_name="Last Name", email="test1@test.com" ) - r2 = Reporter.objects.create( + Reporter.objects.create( first_name="Other test user", last_name="Other Last Name", email="test2@test.com", ) - r3 = Reporter.objects.create( + Reporter.objects.create( first_name="Random", last_name="RandomLast", email="random@test.com" ) @@ -638,7 +637,7 @@ def test_should_query_filter_node_double_limit_raises(): Reporter.objects.create( first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2 ) - r = Reporter.objects.create( + Reporter.objects.create( first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) @@ -684,7 +683,7 @@ def test_order_by_is_perserved(): return reporters Reporter.objects.create(first_name="b") - r = Reporter.objects.create(first_name="a") + Reporter.objects.create(first_name="a") schema = Schema(query=Query) query = """ @@ -767,3 +766,55 @@ def test_annotation_is_perserved(): assert not result.errors assert result.data == expected + + +def test_integer_field_filter_type(): + class PetType(DjangoObjectType): + class Meta: + model = Pet + interfaces = (Node,) + filter_fields = {"age": ["exact"]} + only_fields = ["age"] + + class Query(ObjectType): + pets = DjangoFilterConnectionField(PetType) + + schema = Schema(query=Query) + + assert str(schema) == dedent( + """\ + schema { + query: Query + } + + interface Node { + id: ID! + } + + type PageInfo { + hasNextPage: Boolean! + hasPreviousPage: Boolean! + startCursor: String + endCursor: String + } + + type PetType implements Node { + age: Int! + id: ID! + } + + type PetTypeConnection { + pageInfo: PageInfo! + edges: [PetTypeEdge]! + } + + type PetTypeEdge { + node: PetType + cursor: String! + } + + type Query { + pets(before: String, after: String, first: Int, last: Int, age: Int): PetTypeConnection + } + """ + ) diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py index cfa5621..00030a0 100644 --- a/graphene_django/filter/utils.py +++ b/graphene_django/filter/utils.py @@ -11,8 +11,25 @@ def get_filtering_args_from_filterset(filterset_class, type): from ..forms.converter import convert_form_field args = {} + model = filterset_class._meta.model for name, filter_field in six.iteritems(filterset_class.base_filters): - field_type = convert_form_field(filter_field.field).Argument() + if name in filterset_class.declared_filters: + form_field = filter_field.field + else: + field_name = name.split("__", 1)[0] + model_field = model._meta.get_field(field_name) + + if hasattr(model_field, "formfield"): + form_field = model_field.formfield( + required=filter_field.extra.get("required", False) + ) + + # Fallback to field defined on filter if we can't get it from the + # model field + if not form_field: + form_field = filter_field.field + + field_type = convert_form_field(form_field).Argument() field_type.description = filter_field.label args[name] = field_type diff --git a/setup.cfg b/setup.cfg index 7d93d3e..def0b67 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,3 +38,8 @@ omit = */tests/* [isort] known_first_party=graphene,graphene_django +multi_line_output=3 +include_trailing_comma=True +force_grid_wrap=0 +use_parentheses=True +line_length=88