From 99512c53a100e8fcc515f62caa4311956ca6847f Mon Sep 17 00:00:00 2001 From: Thomas Leonard <64223923+tcleonard@users.noreply.github.com> Date: Wed, 23 Dec 2020 05:10:39 +0100 Subject: [PATCH] fix: in and range filters on DjangoFilterConnectionField (#1070) Co-authored-by: Thomas Leonard --- graphene_django/filter/fields.py | 7 +- .../filter/tests/test_in_filter.py | 139 ++++++++++++++++++ graphene_django/filter/utils.py | 67 +++++++-- 3 files changed, 202 insertions(+), 11 deletions(-) create mode 100644 graphene_django/filter/tests/test_in_filter.py diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index 3a98e8d..2ee374c 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -21,6 +21,7 @@ class DjangoFilterConnectionField(DjangoConnectionField): self._fields = fields self._provided_filterset_class = filterset_class self._filterset_class = None + self._filtering_args = None self._extra_filter_meta = extra_filter_meta self._base_args = None super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs) @@ -50,7 +51,11 @@ class DjangoFilterConnectionField(DjangoConnectionField): @property def filtering_args(self): - return get_filtering_args_from_filterset(self.filterset_class, self.node_type) + if not self._filtering_args: + self._filtering_args = get_filtering_args_from_filterset( + self.filterset_class, self.node_type + ) + return self._filtering_args @classmethod def resolve_queryset( diff --git a/graphene_django/filter/tests/test_in_filter.py b/graphene_django/filter/tests/test_in_filter.py new file mode 100644 index 0000000..3d4034e --- /dev/null +++ b/graphene_django/filter/tests/test_in_filter.py @@ -0,0 +1,139 @@ +import pytest + +from graphene import ObjectType, Schema +from graphene.relay import Node +from graphene_django import DjangoObjectType +from graphene_django.tests.models import Pet +from graphene_django.utils import DJANGO_FILTER_INSTALLED + +pytestmark = [] + +if DJANGO_FILTER_INSTALLED: + from graphene_django.filter import DjangoFilterConnectionField +else: + pytestmark.append( + pytest.mark.skipif( + True, reason="django_filters not installed or not compatible" + ) + ) + + +class PetNode(DjangoObjectType): + class Meta: + model = Pet + interfaces = (Node,) + filter_fields = { + "name": ["exact", "in"], + "age": ["exact", "in", "range"], + } + + +class Query(ObjectType): + pets = DjangoFilterConnectionField(PetNode) + + +def test_string_in_filter(): + """ + Test in filter on a string field. + """ + Pet.objects.create(name="Brutus", age=12) + Pet.objects.create(name="Mimi", age=3) + Pet.objects.create(name="Jojo, the rabbit", age=3) + + schema = Schema(query=Query) + + query = """ + query { + pets (name_In: ["Brutus", "Jojo, the rabbit"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["pets"]["edges"] == [ + {"node": {"name": "Brutus"}}, + {"node": {"name": "Jojo, the rabbit"}}, + ] + + +def test_int_in_filter(): + """ + Test in filter on an integer field. + """ + Pet.objects.create(name="Brutus", age=12) + Pet.objects.create(name="Mimi", age=3) + Pet.objects.create(name="Jojo, the rabbit", age=3) + + schema = Schema(query=Query) + + query = """ + query { + pets (age_In: [3]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["pets"]["edges"] == [ + {"node": {"name": "Mimi"}}, + {"node": {"name": "Jojo, the rabbit"}}, + ] + + query = """ + query { + pets (age_In: [3, 12]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["pets"]["edges"] == [ + {"node": {"name": "Brutus"}}, + {"node": {"name": "Mimi"}}, + {"node": {"name": "Jojo, the rabbit"}}, + ] + + +def test_int_range_filter(): + """ + Test in filter on an integer field. + """ + Pet.objects.create(name="Brutus", age=12) + Pet.objects.create(name="Mimi", age=8) + Pet.objects.create(name="Jojo, the rabbit", age=3) + Pet.objects.create(name="Picotin", age=5) + + schema = Schema(query=Query) + + query = """ + query { + pets (age_Range: [4, 9]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["pets"]["edges"] == [ + {"node": {"name": "Mimi"}}, + {"node": {"name": "Picotin"}}, + ] diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py index c5f18e2..becd5f5 100644 --- a/graphene_django/filter/utils.py +++ b/graphene_django/filter/utils.py @@ -1,6 +1,10 @@ import six +from graphene import List + from django_filters.utils import get_model_field +from django_filters.filters import Filter, BaseCSVFilter + from .filterset import custom_filterset_factory, setup_filterset @@ -17,8 +21,11 @@ def get_filtering_args_from_filterset(filterset_class, type): form_field = None if name in filterset_class.declared_filters: + # Get the filter field from the explicitly declared filter form_field = filter_field.field + field = convert_form_field(form_field) else: + # Get the filter field with no explicit type declaration model_field = get_model_field(model, filter_field.field_name) filter_type = filter_field.lookup_expr if filter_type != "isnull" and hasattr(model_field, "formfield"): @@ -26,12 +33,19 @@ def get_filtering_args_from_filterset(filterset_class, type): required=filter_field.extra.get("required", False) ) - # Fallback to field defined on filter if we can't get it from the - # model field - if not form_field: - form_field = filter_field.field + # Fallback to field defined on filter if we can't get it from the + # model field + if not form_field: + form_field = filter_field.field - field_type = convert_form_field(form_field).Argument() + field = convert_form_field(form_field) + + if filter_type in ["in", "range"]: + # Replace CSV filters (`in`, `range`) argument type to be a list of the same type as the field. + # See comments in `replace_csv_filters` method for more details. + field = List(field.get_type()) + + field_type = field.Argument() field_type.description = filter_field.label args[name] = field_type @@ -39,9 +53,42 @@ def get_filtering_args_from_filterset(filterset_class, type): def get_filterset_class(filterset_class, **meta): - """Get the class to be used as the FilterSet""" + """ + Get the class to be used as the FilterSet. + """ if filterset_class: - # If were given a FilterSet class, then set it up and - # return it - return setup_filterset(filterset_class) - return custom_filterset_factory(**meta) + # If were given a FilterSet class, then set it up. + graphene_filterset_class = setup_filterset(filterset_class) + else: + # Otherwise create one. + graphene_filterset_class = custom_filterset_factory(**meta) + + replace_csv_filters(graphene_filterset_class) + return graphene_filterset_class + + +def replace_csv_filters(filterset_class): + """ + Replace the "in" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore + but regular Filter objects that simply use the input value as filter argument on the queryset. + + This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we + can actually have a list as input and have a proper type verification of each value in the list. + + See issue https://github.com/graphql-python/graphene-django/issues/1068. + """ + for name, filter_field in six.iteritems(filterset_class.base_filters): + filter_type = filter_field.lookup_expr + if ( + filter_type in ["in", "range"] + and name not in filterset_class.declared_filters + ): + assert isinstance(filter_field, BaseCSVFilter) + filterset_class.base_filters[name] = Filter( + field_name=filter_field.field_name, + lookup_expr=filter_field.lookup_expr, + label=filter_field.label, + method=filter_field.method, + exclude=filter_field.exclude, + **filter_field.extra + )