diff --git a/graphene_django/filter/tests/test_in_filter.py b/graphene_django/filter/tests/test_in_filter.py index 3d4034e..7bbee65 100644 --- a/graphene_django/filter/tests/test_in_filter.py +++ b/graphene_django/filter/tests/test_in_filter.py @@ -1,9 +1,11 @@ import pytest +from django_filters import FilterSet +from django_filters import rest_framework as filters from graphene import ObjectType, Schema from graphene.relay import Node from graphene_django import DjangoObjectType -from graphene_django.tests.models import Pet +from graphene_django.tests.models import Pet, Person from graphene_django.utils import DJANGO_FILTER_INSTALLED pytestmark = [] @@ -28,8 +30,27 @@ class PetNode(DjangoObjectType): } +class PersonFilterSet(FilterSet): + class Meta: + model = Person + fields = {} + + names = filters.BaseInFilter(method="filter_names") + + def filter_names(self, qs, name, value): + return qs.filter(name__in=value) + + +class PersonNode(DjangoObjectType): + class Meta: + model = Person + interfaces = (Node,) + filterset_class = PersonFilterSet + + class Query(ObjectType): pets = DjangoFilterConnectionField(PetNode) + people = DjangoFilterConnectionField(PersonNode) def test_string_in_filter(): @@ -61,6 +82,33 @@ def test_string_in_filter(): ] +def test_string_in_filter_with_filterset_class(): + """Test in filter on a string field with a custom filterset class.""" + Person.objects.create(name="John") + Person.objects.create(name="Michael") + Person.objects.create(name="Angela") + + schema = Schema(query=Query) + + query = """ + query { + people (names: ["John", "Michael"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["people"]["edges"] == [ + {"node": {"name": "John"}}, + {"node": {"name": "Michael"}}, + ] + + def test_int_in_filter(): """ Test in filter on an integer field. diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py index becd5f5..71c5b49 100644 --- a/graphene_django/filter/utils.py +++ b/graphene_django/filter/utils.py @@ -19,6 +19,7 @@ def get_filtering_args_from_filterset(filterset_class, type): model = filterset_class._meta.model for name, filter_field in six.iteritems(filterset_class.base_filters): form_field = None + filter_type = filter_field.lookup_expr if name in filterset_class.declared_filters: # Get the filter field from the explicitly declared filter @@ -27,7 +28,6 @@ def get_filtering_args_from_filterset(filterset_class, type): else: # Get the filter field with no explicit type declaration model_field = get_model_field(model, filter_field.field_name) - filter_type = filter_field.lookup_expr if filter_type != "isnull" and hasattr(model_field, "formfield"): form_field = model_field.formfield( required=filter_field.extra.get("required", False) @@ -40,10 +40,11 @@ def get_filtering_args_from_filterset(filterset_class, type): field = convert_form_field(form_field) - if filter_type in ["in", "range"]: - # Replace CSV filters (`in`, `range`) argument type to be a list of the same type as the field. - # See comments in `replace_csv_filters` method for more details. - field = List(field.get_type()) + if filter_type in ["in", "range"]: + # Replace CSV filters (`in`, `range`) argument type to be a list of + # the same type as the field. See comments in + # `replace_csv_filters` method for more details. + field = List(field.get_type()) field_type = field.Argument() field_type.description = filter_field.label @@ -79,10 +80,7 @@ def replace_csv_filters(filterset_class): """ for name, filter_field in six.iteritems(filterset_class.base_filters): filter_type = filter_field.lookup_expr - if ( - filter_type in ["in", "range"] - and name not in filterset_class.declared_filters - ): + if filter_type in ["in", "range"]: assert isinstance(filter_field, BaseCSVFilter) filterset_class.base_filters[name] = Filter( field_name=filter_field.field_name, diff --git a/graphene_django/tests/models.py b/graphene_django/tests/models.py index 44a5d8a..20f509c 100644 --- a/graphene_django/tests/models.py +++ b/graphene_django/tests/models.py @@ -6,6 +6,10 @@ from django.utils.translation import ugettext_lazy as _ CHOICES = ((1, "this"), (2, _("that"))) +class Person(models.Model): + name = models.CharField(max_length=30) + + class Pet(models.Model): name = models.CharField(max_length=30) age = models.PositiveIntegerField()