From e323e2bc0bef36955a32ae4becd828144e866c44 Mon Sep 17 00:00:00 2001 From: Thomas Leonard <64223923+tcleonard@users.noreply.github.com> Date: Tue, 23 Feb 2021 05:22:09 +0100 Subject: [PATCH] Add enum support to filters and fix filter typing (v2) (#1114) * - Add filtering support for choice fields converted to graphql Enum (or not) - Fix type of various filters (used to default to String) - Fix bug with contains introduced in previous PR - Fix bug with declared filters being overridden (see PR #1108) - Fix support for ArrayField and add documentation * Fix tests Co-authored-by: Thomas Leonard --- docs/filtering.rst | 43 +++ graphene_django/filter/__init__.py | 11 +- graphene_django/filter/fields.py | 4 +- graphene_django/filter/filters.py | 34 +- graphene_django/filter/tests/conftest.py | 42 ++- graphene_django/filter/tests/filters.py | 2 +- ...py => test_array_field_contains_filter.py} | 19 +- .../tests/test_array_field_exact_filter.py | 107 ++++++ ....py => test_array_field_overlap_filter.py} | 12 +- .../filter/tests/test_enum_filtering.py | 144 ++++++++ graphene_django/filter/tests/test_fields.py | 51 ++- .../filter/tests/test_in_filter.py | 344 +++++++++++++++--- graphene_django/filter/utils.py | 131 +++++-- graphene_django/tests/models.py | 6 +- graphene_django/tests/test_query.py | 2 + 15 files changed, 838 insertions(+), 114 deletions(-) rename graphene_django/filter/tests/{test_contains_filter.py => test_array_field_contains_filter.py} (74%) create mode 100644 graphene_django/filter/tests/test_array_field_exact_filter.py rename graphene_django/filter/tests/{test_overlap_filter.py => test_array_field_overlap_filter.py} (84%) create mode 100644 graphene_django/filter/tests/test_enum_filtering.py diff --git a/docs/filtering.rst b/docs/filtering.rst index e366fe2..f197b30 100644 --- a/docs/filtering.rst +++ b/docs/filtering.rst @@ -228,3 +228,46 @@ with this set up, you can now order the users under group: } } } + + +PostgreSQL `ArrayField` +----------------------- + +Graphene provides an easy to implement filters on `ArrayField` as they are not natively supported by django_filters: + +.. code:: python + + from django.db import models + from django_filters import FilterSet, OrderingFilter + from graphene_django.filter import ArrayFilter + + class Event(models.Model): + name = models.CharField(max_length=50) + tags = ArrayField(models.CharField(max_length=50)) + + class EventFilterSet(FilterSet): + class Meta: + model = Event + fields = { + "name": ["exact", "contains"], + } + + tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains") + tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap") + tags = ArrayFilter(field_name="tags", lookup_expr="exact") + + class EventType(DjangoObjectType): + class Meta: + model = Event + interfaces = (Node,) + filterset_class = EventFilterSet + +with this set up, you can now filter events by tags: + +.. code:: + + query { + events(tags_Overlap: ["concert", "festival"]) { + name + } + } diff --git a/graphene_django/filter/__init__.py b/graphene_django/filter/__init__.py index 5de36ad..94570c9 100644 --- a/graphene_django/filter/__init__.py +++ b/graphene_django/filter/__init__.py @@ -9,10 +9,19 @@ if not DJANGO_FILTER_INSTALLED: ) else: from .fields import DjangoFilterConnectionField - from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter + from .filters import ( + ArrayFilter, + GlobalIDFilter, + GlobalIDMultipleChoiceFilter, + ListFilter, + RangeFilter, + ) __all__ = [ "DjangoFilterConnectionField", "GlobalIDFilter", "GlobalIDMultipleChoiceFilter", + "ArrayFilter", + "ListFilter", + "RangeFilter", ] diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index 7d8d2d8..9a4cf36 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -43,8 +43,8 @@ class DjangoFilterConnectionField(DjangoConnectionField): if self._extra_filter_meta: meta.update(self._extra_filter_meta) - filterset_class = self._provided_filterset_class or ( - self.node_type._meta.filterset_class + filterset_class = ( + self._provided_filterset_class or self.node_type._meta.filterset_class ) self._filterset_class = get_filterset_class(filterset_class, **meta) diff --git a/graphene_django/filter/filters.py b/graphene_django/filter/filters.py index 44832b5..3275ebf 100644 --- a/graphene_django/filter/filters.py +++ b/graphene_django/filter/filters.py @@ -2,6 +2,7 @@ from django.core.exceptions import ValidationError from django.forms import Field from django_filters import Filter, MultipleChoiceFilter +from django_filters.constants import EMPTY_VALUES from graphql_relay.node.node import from_global_id @@ -31,14 +32,15 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter): return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids) -class InFilter(Filter): +class ListFilter(Filter): """ - Filter for a list of value using the `__in` Django filter. + Filter that takes a list of value as input. + It is for example used for `__in` filters. """ def filter(self, qs, value): """ - Override the default filter class to check first weather the list is + Override the default filter class to check first whether the list is empty or not. This needs to be done as in this case we expect to get an empty output (if not an exclude filter) but django_filter consider an empty list @@ -52,7 +54,7 @@ class InFilter(Filter): else: return qs.none() else: - return super(InFilter, self).filter(qs, value) + return super(ListFilter, self).filter(qs, value) def validate_range(value): @@ -73,3 +75,27 @@ class RangeField(Field): class RangeFilter(Filter): field_class = RangeField + + +class ArrayFilter(Filter): + """ + Filter made for PostgreSQL ArrayField. + """ + + def filter(self, qs, value): + """ + Override the default filter class to check first whether the list is + empty or not. + This needs to be done as in this case we expect to get the filter applied with + an empty list since it's a valid value but django_filter consider an empty list + to be an empty input value (see `EMPTY_VALUES`) meaning that + the filter does not need to be applied (hence returning the original + queryset). + """ + if value in EMPTY_VALUES and value != []: + return qs + if self.distinct: + qs = qs.distinct() + lookup = "%s__%s" % (self.field_name, self.lookup_expr) + qs = self.get_method(qs)(**{lookup: value}) + return qs diff --git a/graphene_django/filter/tests/conftest.py b/graphene_django/filter/tests/conftest.py index 0313645..710234f 100644 --- a/graphene_django/filter/tests/conftest.py +++ b/graphene_django/filter/tests/conftest.py @@ -9,6 +9,7 @@ import graphene from graphene.relay import Node from graphene_django import DjangoObjectType from graphene_django.utils import DJANGO_FILTER_INSTALLED +from graphene_django.filter import ArrayFilter, ListFilter from ...compat import ArrayField @@ -32,27 +33,37 @@ def Event(): class Event(models.Model): name = models.CharField(max_length=50) tags = ArrayField(models.CharField(max_length=50)) + tag_ids = ArrayField(models.IntegerField()) + random_field = ArrayField(models.BooleanField()) return Event @pytest.fixture def EventFilterSet(Event): - - from django.contrib.postgres.forms import SimpleArrayField - - class ArrayFilter(filters.Filter): - base_field_class = SimpleArrayField - class EventFilterSet(FilterSet): class Meta: model = Event fields = { - "name": ["exact"], + "name": ["exact", "contains"], } + # Those are actually usable with our Query fixture bellow tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains") tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap") + tags = ArrayFilter(field_name="tags", lookup_expr="exact") + + # Those are actually not usable and only to check type declarations + tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains") + tags_ids__overlap = ArrayFilter(field_name="tag_ids", lookup_expr="overlap") + tags_ids = ArrayFilter(field_name="tag_ids", lookup_expr="exact") + random_field__contains = ArrayFilter( + field_name="random_field", lookup_expr="contains" + ) + random_field__overlap = ArrayFilter( + field_name="random_field", lookup_expr="overlap" + ) + random_field = ArrayFilter(field_name="random_field", lookup_expr="exact") return EventFilterSet @@ -70,6 +81,11 @@ def EventType(Event, EventFilterSet): @pytest.fixture def Query(Event, EventType): + """ + Note that we have to use a custom resolver to replicate the arrayfield filter behavior as + we are running unit tests in sqlite which does not have ArrayFields. + """ + class Query(graphene.ObjectType): events = DjangoFilterConnectionField(EventType) @@ -79,6 +95,7 @@ def Query(Event, EventType): Event(name="Live Show", tags=["concert", "music", "rock"],), Event(name="Musical", tags=["movie", "music"],), Event(name="Ballet", tags=["concert", "dance"],), + Event(name="Speech", tags=[],), ] STORE["events"] = events @@ -105,6 +122,13 @@ def Query(Event, EventType): STORE["events"], ) ) + if "tags__exact" in kwargs: + STORE["events"] = list( + filter( + lambda e: set(kwargs["tags__exact"]) == set(e.tags), + STORE["events"], + ) + ) def mock_queryset_filter(*args, **kwargs): filter_events(**kwargs) @@ -121,7 +145,9 @@ def Query(Event, EventType): m_queryset.filter.side_effect = mock_queryset_filter m_queryset.none.side_effect = mock_queryset_none m_queryset.count.side_effect = mock_queryset_count - m_queryset.__getitem__.side_effect = STORE["events"].__getitem__ + m_queryset.__getitem__.side_effect = lambda index: STORE[ + "events" + ].__getitem__(index) return m_queryset diff --git a/graphene_django/filter/tests/filters.py b/graphene_django/filter/tests/filters.py index 43b6a87..a7443c0 100644 --- a/graphene_django/filter/tests/filters.py +++ b/graphene_django/filter/tests/filters.py @@ -10,7 +10,7 @@ class ArticleFilter(django_filters.FilterSet): fields = { "headline": ["exact", "icontains"], "pub_date": ["gt", "lt", "exact"], - "reporter": ["exact"], + "reporter": ["exact", "in"], } order_by = OrderingFilter(fields=("pub_date",)) diff --git a/graphene_django/filter/tests/test_contains_filter.py b/graphene_django/filter/tests/test_array_field_contains_filter.py similarity index 74% rename from graphene_django/filter/tests/test_contains_filter.py rename to graphene_django/filter/tests/test_array_field_contains_filter.py index 35e775e..4144614 100644 --- a/graphene_django/filter/tests/test_contains_filter.py +++ b/graphene_django/filter/tests/test_array_field_contains_filter.py @@ -6,9 +6,9 @@ from ...compat import ArrayField, MissingType @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_contains_multiple(Query): +def test_array_field_contains_multiple(Query): """ - Test contains filter on a string field. + Test contains filter on a array field of string. """ schema = Schema(query=Query) @@ -32,9 +32,9 @@ def test_string_contains_multiple(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_contains_one(Query): +def test_array_field_contains_one(Query): """ - Test contains filter on a string field. + Test contains filter on a array field of string. """ schema = Schema(query=Query) @@ -59,9 +59,9 @@ def test_string_contains_one(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_contains_none(Query): +def test_array_field_contains_empty_list(Query): """ - Test contains filter on a string field. + Test contains filter on a array field of string. """ schema = Schema(query=Query) @@ -79,4 +79,9 @@ def test_string_contains_none(Query): """ result = schema.execute(query) assert not result.errors - assert result.data["events"]["edges"] == [] + assert result.data["events"]["edges"] == [ + {"node": {"name": "Live Show"}}, + {"node": {"name": "Musical"}}, + {"node": {"name": "Ballet"}}, + {"node": {"name": "Speech"}}, + ] diff --git a/graphene_django/filter/tests/test_array_field_exact_filter.py b/graphene_django/filter/tests/test_array_field_exact_filter.py new file mode 100644 index 0000000..814fd33 --- /dev/null +++ b/graphene_django/filter/tests/test_array_field_exact_filter.py @@ -0,0 +1,107 @@ +import pytest + +from graphene import Schema + +from ...compat import ArrayField, MissingType + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_array_field_exact_no_match(Query): + """ + Test exact filter on a array field of string. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags: ["concert", "music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_array_field_exact_match(Query): + """ + Test exact filter on a array field of string. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags: ["movie", "music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Musical"}}, + ] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_array_field_exact_empty_list(Query): + """ + Test exact filter on a array field of string. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags: []) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Speech"}}, + ] + + +def test_array_field_filter_schema_type(Query): + """ + Check that the type in the filter is an array field like on the object type. + """ + schema = Schema(query=Query) + schema_str = str(schema) + + assert ( + """type EventType implements Node { + id: ID! + name: String! + tags: [String!]! + tagIds: [Int!]! + randomField: [Boolean!]! +}""" + in schema_str + ) + + assert ( + """type Query { + events(offset: Int, before: String, after: String, first: Int, last: Int, name: String, name_Contains: String, tags_Contains: [String!], tags_Overlap: [String!], tags: [String!], tagsIds_Contains: [Int!], tagsIds_Overlap: [Int!], tagsIds: [Int!], randomField_Contains: [Boolean!], randomField_Overlap: [Boolean!], randomField: [Boolean!]): EventTypeConnection +}""" + in schema_str + ) diff --git a/graphene_django/filter/tests/test_overlap_filter.py b/graphene_django/filter/tests/test_array_field_overlap_filter.py similarity index 84% rename from graphene_django/filter/tests/test_overlap_filter.py rename to graphene_django/filter/tests/test_array_field_overlap_filter.py index 32dfa44..5ce1576 100644 --- a/graphene_django/filter/tests/test_overlap_filter.py +++ b/graphene_django/filter/tests/test_array_field_overlap_filter.py @@ -6,9 +6,9 @@ from ...compat import ArrayField, MissingType @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_overlap_multiple(Query): +def test_array_field_overlap_multiple(Query): """ - Test overlap filter on a string field. + Test overlap filter on a array field of string. """ schema = Schema(query=Query) @@ -34,9 +34,9 @@ def test_string_overlap_multiple(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_overlap_one(Query): +def test_array_field_overlap_one(Query): """ - Test overlap filter on a string field. + Test overlap filter on a array field of string. """ schema = Schema(query=Query) @@ -61,9 +61,9 @@ def test_string_overlap_one(Query): @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") -def test_string_overlap_none(Query): +def test_array_field_overlap_empty_list(Query): """ - Test overlap filter on a string field. + Test overlap filter on a array field of string. """ schema = Schema(query=Query) diff --git a/graphene_django/filter/tests/test_enum_filtering.py b/graphene_django/filter/tests/test_enum_filtering.py new file mode 100644 index 0000000..650c55e --- /dev/null +++ b/graphene_django/filter/tests/test_enum_filtering.py @@ -0,0 +1,144 @@ +import pytest + +import graphene +from graphene.relay import Node + +from graphene_django import DjangoObjectType, DjangoConnectionField +from graphene_django.tests.models import Article, Reporter +from graphene_django.utils import DJANGO_FILTER_INSTALLED + +pytestmark = [] + +if DJANGO_FILTER_INSTALLED: + from graphene_django.filter import DjangoFilterConnectionField +else: + pytestmark.append( + pytest.mark.skipif( + True, reason="django_filters not installed or not compatible" + ) + ) + + +@pytest.fixture +def schema(): + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + class ArticleType(DjangoObjectType): + class Meta: + model = Article + interfaces = (Node,) + filter_fields = { + "lang": ["exact", "in"], + "reporter__a_choice": ["exact", "in"], + } + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + all_articles = DjangoFilterConnectionField(ArticleType) + + schema = graphene.Schema(query=Query) + return schema + + +@pytest.fixture +def reporter_article_data(): + john = Reporter.objects.create( + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 + ) + jane = Reporter.objects.create( + first_name="Jane", last_name="Doe", email="janedoe@example.com", a_choice=2 + ) + Article.objects.create( + headline="Article Node 1", reporter=john, editor=john, lang="es", + ) + Article.objects.create( + headline="Article Node 2", reporter=john, editor=john, lang="en", + ) + Article.objects.create( + headline="Article Node 3", reporter=jane, editor=jane, lang="en", + ) + + +def test_filter_enum_on_connection(schema, reporter_article_data): + """ + Check that we can filter with enums on a connection. + """ + query = """ + query { + allArticles(lang: ES) { + edges { + node { + headline + } + } + } + } + """ + + expected = {"allArticles": {"edges": [{"node": {"headline": "Article Node 1"}},]}} + + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_filter_on_foreign_key_enum_field(schema, reporter_article_data): + """ + Check that we can filter with enums on a field from a foreign key. + """ + query = """ + query { + allArticles(reporter_AChoice: A_1) { + edges { + node { + headline + } + } + } + } + """ + + expected = { + "allArticles": { + "edges": [ + {"node": {"headline": "Article Node 1"}}, + {"node": {"headline": "Article Node 2"}}, + ] + } + } + + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_filter_enum_field_schema_type(schema): + """ + Check that the type in the filter is an enum like on the object type. + """ + schema_str = str(schema) + + assert ( + """type ArticleType implements Node { + id: ID! + headline: String! + pubDate: Date! + pubDateTime: DateTime! + reporter: ReporterType! + editor: ReporterType! + lang: ArticleLang! + importance: ArticleImportance +}""" + in schema_str + ) + + assert ( + """type Query { + allReporters(offset: Int, before: String, after: String, first: Int, last: Int): ReporterTypeConnection + allArticles(offset: Int, before: String, after: String, first: Int, last: Int, lang: ArticleLang, lang_In: [ArticleLang], reporter_AChoice: ReporterAChoice, reporter_AChoice_In: [ReporterAChoice]): ArticleTypeConnection +}""" + in schema_str + ) diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index 6de8361..d3e86a5 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -9,7 +9,7 @@ from graphene import Argument, Boolean, Decimal, Field, ObjectType, Schema, Stri from graphene.relay import Node from graphene_django import DjangoObjectType from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField -from graphene_django.tests.models import Article, Pet, Reporter +from graphene_django.tests.models import Article, Person, Pet, Reporter from graphene_django.utils import DJANGO_FILTER_INSTALLED pytestmark = [] @@ -87,6 +87,7 @@ def test_filter_explicit_filterset_arguments(): "pub_date__gt", "pub_date__lt", "reporter", + "reporter__in", ) @@ -676,7 +677,7 @@ def test_should_query_filter_node_limit(): node { id firstName - articles(lang: "es") { + articles(lang: ES) { edges { node { id @@ -1085,7 +1086,7 @@ def test_filter_filterset_based_on_mixin(): return filters - def filter_email_in(cls, queryset, name, value): + def filter_email_in(self, queryset, name, value): return queryset.filter(**{name: [value]}) class NewArticleFilter(ArticleFilterMixin, ArticleFilter): @@ -1171,3 +1172,47 @@ def test_filter_filterset_based_on_mixin(): assert not result.errors assert result.data == expected + + +def test_filter_string_contains(): + class PersonType(DjangoObjectType): + class Meta: + model = Person + interfaces = (Node,) + filter_fields = {"name": ["exact", "in", "contains", "icontains"]} + + class Query(ObjectType): + people = DjangoFilterConnectionField(PersonType) + + schema = Schema(query=Query) + + Person.objects.bulk_create( + [ + Person(name="Jack"), + Person(name="Joe"), + Person(name="Jane"), + Person(name="Peter"), + Person(name="Bob"), + ] + ) + query = """query nameContain($filter: String) { + people(name_Contains: $filter) { + edges { + node { + name + } + } + } + }""" + + result = schema.execute(query, variables={"filter": "Ja"}) + assert not result.errors + assert result.data == { + "people": {"edges": [{"node": {"name": "Jack"}}, {"node": {"name": "Jane"}},]} + } + + result = schema.execute(query, variables={"filter": "o"}) + assert not result.errors + assert result.data == { + "people": {"edges": [{"node": {"name": "Joe"}}, {"node": {"name": "Bob"}},]} + } diff --git a/graphene_django/filter/tests/test_in_filter.py b/graphene_django/filter/tests/test_in_filter.py index 9e9c323..f0015b6 100644 --- a/graphene_django/filter/tests/test_in_filter.py +++ b/graphene_django/filter/tests/test_in_filter.py @@ -1,3 +1,5 @@ +from datetime import datetime + import pytest from django_filters import FilterSet @@ -5,7 +7,8 @@ from django_filters import rest_framework as filters from graphene import ObjectType, Schema from graphene.relay import Node from graphene_django import DjangoObjectType -from graphene_django.tests.models import Pet, Person +from graphene_django.tests.models import Pet, Person, Reporter, Article, Film +from graphene_django.filter.tests.filters import ArticleFilter from graphene_django.utils import DJANGO_FILTER_INSTALLED pytestmark = [] @@ -20,40 +23,72 @@ else: ) -class PetNode(DjangoObjectType): - class Meta: - model = Pet - interfaces = (Node,) - filter_fields = { - "name": ["exact", "in"], - "age": ["exact", "in", "range"], - } +@pytest.fixture +def query(): + class PetNode(DjangoObjectType): + class Meta: + model = Pet + interfaces = (Node,) + filter_fields = { + "id": ["exact", "in"], + "name": ["exact", "in"], + "age": ["exact", "in", "range"], + } + + class ReporterNode(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + # choice filter using enum + filter_fields = {"reporter_type": ["exact", "in"]} + + class ArticleNode(DjangoObjectType): + class Meta: + model = Article + interfaces = (Node,) + filterset_class = ArticleFilter + + class FilmNode(DjangoObjectType): + class Meta: + model = Film + interfaces = (Node,) + # choice filter not using enum + filter_fields = { + "genre": ["exact", "in"], + } + convert_choices_to_enum = False + + class PersonFilterSet(FilterSet): + class Meta: + model = Person + fields = {"name": ["in"]} + + names = filters.BaseInFilter(method="filter_names") + + def filter_names(self, qs, name, value): + """ + This custom filter take a string as input with comma separated values. + Note that the value here is already a list as it has been transformed by the BaseInFilter class. + """ + return qs.filter(name__in=value) + + class PersonNode(DjangoObjectType): + class Meta: + model = Person + interfaces = (Node,) + filterset_class = PersonFilterSet + + class Query(ObjectType): + pets = DjangoFilterConnectionField(PetNode) + people = DjangoFilterConnectionField(PersonNode) + articles = DjangoFilterConnectionField(ArticleNode) + films = DjangoFilterConnectionField(FilmNode) + reporters = DjangoFilterConnectionField(ReporterNode) + + return Query -class PersonFilterSet(FilterSet): - class Meta: - model = Person - fields = {} - - names = filters.BaseInFilter(method="filter_names") - - def filter_names(self, qs, name, value): - return qs.filter(name__in=value) - - -class PersonNode(DjangoObjectType): - class Meta: - model = Person - interfaces = (Node,) - filterset_class = PersonFilterSet - - -class Query(ObjectType): - pets = DjangoFilterConnectionField(PetNode) - people = DjangoFilterConnectionField(PersonNode) - - -def test_string_in_filter(): +def test_string_in_filter(query): """ Test in filter on a string field. """ @@ -61,7 +96,7 @@ def test_string_in_filter(): Pet.objects.create(name="Mimi", age=3) Pet.objects.create(name="Jojo, the rabbit", age=3) - schema = Schema(query=Query) + schema = Schema(query=query) query = """ query { @@ -82,17 +117,19 @@ def test_string_in_filter(): ] -def test_string_in_filter_with_filterset_class(): - """Test in filter on a string field with a custom filterset class.""" +def test_string_in_filter_with_otjer_filter(query): + """ + Test in filter on a string field which has also a custom filter doing a similar operation. + """ Person.objects.create(name="John") Person.objects.create(name="Michael") Person.objects.create(name="Angela") - schema = Schema(query=Query) + schema = Schema(query=query) query = """ query { - people (names: ["John", "Michael"]) { + people (name_In: ["John", "Michael"]) { edges { node { name @@ -109,7 +146,36 @@ def test_string_in_filter_with_filterset_class(): ] -def test_int_in_filter(): +def test_string_in_filter_with_declared_filter(query): + """ + Test in filter on a string field with a custom filterset class. + """ + Person.objects.create(name="John") + Person.objects.create(name="Michael") + Person.objects.create(name="Angela") + + schema = Schema(query=query) + + query = """ + query { + people (names: "John,Michael") { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["people"]["edges"] == [ + {"node": {"name": "John"}}, + {"node": {"name": "Michael"}}, + ] + + +def test_int_in_filter(query): """ Test in filter on an integer field. """ @@ -117,7 +183,7 @@ def test_int_in_filter(): Pet.objects.create(name="Mimi", age=3) Pet.objects.create(name="Jojo, the rabbit", age=3) - schema = Schema(query=Query) + schema = Schema(query=query) query = """ query { @@ -157,7 +223,7 @@ def test_int_in_filter(): ] -def test_in_filter_with_empty_list(): +def test_in_filter_with_empty_list(query): """ Check that using a in filter with an empty list provided as input returns no objects. """ @@ -165,7 +231,7 @@ def test_in_filter_with_empty_list(): Pet.objects.create(name="Mimi", age=8) Pet.objects.create(name="Picotin", age=5) - schema = Schema(query=Query) + schema = Schema(query=query) query = """ query { @@ -181,3 +247,197 @@ def test_in_filter_with_empty_list(): result = schema.execute(query) assert not result.errors assert len(result.data["pets"]["edges"]) == 0 + + +def test_choice_in_filter_without_enum(query): + """ + Test in filter o an choice field not using an enum (Film.genre). + """ + + john_doe = Reporter.objects.create( + first_name="John", last_name="Doe", email="john@doe.com" + ) + jean_bon = Reporter.objects.create( + first_name="Jean", last_name="Bon", email="jean@bon.com" + ) + documentary_film = Film.objects.create(genre="do") + documentary_film.reporters.add(john_doe) + action_film = Film.objects.create(genre="ac") + action_film.reporters.add(john_doe) + other_film = Film.objects.create(genre="ot") + other_film.reporters.add(john_doe) + other_film.reporters.add(jean_bon) + + schema = Schema(query=query) + + query = """ + query { + films (genre_In: ["do", "ac"]) { + edges { + node { + genre + reporters { + edges { + node { + lastName + } + } + } + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["films"]["edges"] == [ + { + "node": { + "genre": "do", + "reporters": {"edges": [{"node": {"lastName": "Doe"}}]}, + } + }, + { + "node": { + "genre": "ac", + "reporters": {"edges": [{"node": {"lastName": "Doe"}}]}, + } + }, + ] + + +def test_fk_id_in_filter(query): + """ + Test in filter on an foreign key relationship. + """ + john_doe = Reporter.objects.create( + first_name="John", last_name="Doe", email="john@doe.com" + ) + jean_bon = Reporter.objects.create( + first_name="Jean", last_name="Bon", email="jean@bon.com" + ) + sara_croche = Reporter.objects.create( + first_name="Sara", last_name="Croche", email="sara@croche.com" + ) + Article.objects.create( + headline="A", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=john_doe, + editor=john_doe, + ) + Article.objects.create( + headline="B", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=jean_bon, + editor=jean_bon, + ) + Article.objects.create( + headline="C", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=sara_croche, + editor=sara_croche, + ) + + schema = Schema(query=query) + + query = """ + query { + articles (reporter_In: [%s, %s]) { + edges { + node { + headline + reporter { + lastName + } + } + } + } + } + """ % ( + john_doe.id, + jean_bon.id, + ) + result = schema.execute(query) + assert not result.errors + assert result.data["articles"]["edges"] == [ + {"node": {"headline": "A", "reporter": {"lastName": "Doe"}}}, + {"node": {"headline": "B", "reporter": {"lastName": "Bon"}}}, + ] + + +def test_enum_in_filter(query): + """ + Test in filter on a choice field using an enum (Reporter.reporter_type). + """ + + Reporter.objects.create( + first_name="John", last_name="Doe", email="john@doe.com", reporter_type=1 + ) + Reporter.objects.create( + first_name="Jean", last_name="Bon", email="jean@bon.com", reporter_type=2 + ) + Reporter.objects.create( + first_name="Jane", last_name="Doe", email="jane@doe.com", reporter_type=2 + ) + Reporter.objects.create( + first_name="Jack", last_name="Black", email="jack@black.com", reporter_type=None + ) + + schema = Schema(query=query) + + query = """ + query { + reporters (reporterType_In: [A_1]) { + edges { + node { + email + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["reporters"]["edges"] == [ + {"node": {"email": "john@doe.com"}}, + ] + + query = """ + query { + reporters (reporterType_In: [A_2]) { + edges { + node { + email + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["reporters"]["edges"] == [ + {"node": {"email": "jean@bon.com"}}, + {"node": {"email": "jane@doe.com"}}, + ] + + query = """ + query { + reporters (reporterType_In: [A_2, A_1]) { + edges { + node { + email + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["reporters"]["edges"] == [ + {"node": {"email": "john@doe.com"}}, + {"node": {"email": "jean@bon.com"}}, + {"node": {"email": "jane@doe.com"}}, + ] diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py index 2be3778..2638656 100644 --- a/graphene_django/filter/utils.py +++ b/graphene_django/filter/utils.py @@ -2,54 +2,104 @@ import six import graphene -from django_filters.utils import get_model_field +from django import forms + +from django_filters.utils import get_model_field, get_field_parts from django_filters.filters import Filter, BaseCSVFilter from .filterset import custom_filterset_factory, setup_filterset -from .filters import InFilter, RangeFilter +from .filters import ArrayFilter, ListFilter, RangeFilter +from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField + + +def get_field_type(registry, model, field_name): + """ + Try to get a model field corresponding Graphql type from the DjangoObjectType. + """ + object_type = registry.get_type_for_model(model) + if object_type: + object_type_field = object_type._meta.fields.get(field_name) + if object_type_field: + field_type = object_type_field.type + if isinstance(field_type, graphene.NonNull): + field_type = field_type.of_type + return field_type + return None def get_filtering_args_from_filterset(filterset_class, type): - """ Inspect a FilterSet and produce the arguments to pass to - a Graphene Field. These arguments will be available to - filter against in the GraphQL + """ + Inspect a FilterSet and produce the arguments to pass to a Graphene Field. + These arguments will be available to filter against in the GraphQL API. """ from ..forms.converter import convert_form_field args = {} model = filterset_class._meta.model + registry = type._meta.registry for name, filter_field in six.iteritems(filterset_class.base_filters): - form_field = None filter_type = filter_field.lookup_expr + field_type = None + form_field = None - if name in filterset_class.declared_filters: - # Get the filter field from the explicitly declared filter - form_field = filter_field.field - field = convert_form_field(form_field) - else: - # Get the filter field with no explicit type declaration - model_field = get_model_field(model, filter_field.field_name) - if filter_type != "isnull" and hasattr(model_field, "formfield"): - form_field = model_field.formfield( - required=filter_field.extra.get("required", False) - ) + if ( + name not in filterset_class.declared_filters + or isinstance(filter_field, ListFilter) + or isinstance(filter_field, RangeFilter) + or isinstance(filter_field, ArrayFilter) + ): + # Get the filter field for filters that are no explicitly declared. - # Fallback to field defined on filter if we can't get it from the - # model field - if not form_field: - form_field = filter_field.field + required = filter_field.extra.get("required", False) + if filter_type == "isnull": + field = graphene.Boolean(required=required) + else: + model_field = get_model_field(model, filter_field.field_name) - field = convert_form_field(form_field) + # Get the form field either from: + # 1. the formfield corresponding to the model field + # 2. the field defined on filter + if hasattr(model_field, "formfield"): + form_field = model_field.formfield(required=required) + if not form_field: + form_field = filter_field.field - if filter_type in {"in", "range", "contains", "overlap"}: - # Replace CSV filters (`in`, `range`, `contains`, `overlap`) argument type to be a list of - # the same type as the field. See comments in - # `replace_csv_filters` method for more details. - field = graphene.List(field.get_type()) + # First try to get the matching field type from the GraphQL DjangoObjectType + if model_field: + if ( + isinstance(form_field, forms.ModelChoiceField) + or isinstance(form_field, forms.ModelMultipleChoiceField) + or isinstance(form_field, GlobalIDMultipleChoiceField) + or isinstance(form_field, GlobalIDFormField) + ): + # Foreign key have dynamic types and filtering on a foreign key actually means filtering on its ID. + field_type = get_field_type( + registry, model_field.related_model, "id" + ) + else: + field_type = get_field_type( + registry, model_field.model, model_field.name + ) - field_type = field.Argument() - field_type.description = filter_field.label - args[name] = field_type + if not field_type: + # Fallback on converting the form field either because: + # - it's an explicitly declared filters + # - we did not manage to get the type from the model type + form_field = form_field or filter_field.field + field_type = convert_form_field(form_field) + + if isinstance(filter_field, ListFilter) or isinstance( + filter_field, RangeFilter + ): + # Replace InFilter/RangeFilter filters (`in`, `range`) argument type to be a list of + # the same type as the field. See comments in `replace_csv_filters` method for more details. + field_type = graphene.List(field_type.get_type()) + + args[name] = graphene.Argument( + type=field_type.get_type(), + description=filter_field.label, + required=required, + ) return args @@ -71,18 +121,26 @@ def get_filterset_class(filterset_class, **meta): def replace_csv_filters(filterset_class): """ - Replace the "in", "contains", "overlap" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore - but regular Filter objects that simply use the input value as filter argument on the queryset. + Replace the "in" and "range" filters (that are not explicitly declared) + to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore + but our custom InFilter/RangeFilter filter class that use the input + value as filter argument on the queryset. - This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we - can actually have a list as input and have a proper type verification of each value in the list. + This is because those BaseCSVFilter are expecting a string as input with + comma separated values. + But with GraphQl we can actually have a list as input and have a proper + type verification of each value in the list. See issue https://github.com/graphql-python/graphene-django/issues/1068. """ for name, filter_field in six.iteritems(filterset_class.base_filters): + # Do not touch any declared filters + if name in filterset_class.declared_filters: + continue + filter_type = filter_field.lookup_expr - if filter_type in {"in", "contains", "overlap"}: - filterset_class.base_filters[name] = InFilter( + if filter_type == "in": + filterset_class.base_filters[name] = ListFilter( field_name=filter_field.field_name, lookup_expr=filter_field.lookup_expr, label=filter_field.label, @@ -90,7 +148,6 @@ def replace_csv_filters(filterset_class): exclude=filter_field.exclude, **filter_field.extra ) - elif filter_type == "range": filterset_class.base_filters[name] = RangeFilter( field_name=filter_field.field_name, diff --git a/graphene_django/tests/models.py b/graphene_django/tests/models.py index 20f509c..9e7be29 100644 --- a/graphene_django/tests/models.py +++ b/graphene_django/tests/models.py @@ -26,7 +26,7 @@ class Film(models.Model): genre = models.CharField( max_length=2, help_text="Genre", - choices=[("do", "Documentary"), ("ot", "Other")], + choices=[("do", "Documentary"), ("ac", "Action"), ("ot", "Other")], default="ot", ) reporters = models.ManyToManyField("Reporter", related_name="films") @@ -91,8 +91,8 @@ class CNNReporter(Reporter): class Article(models.Model): headline = models.CharField(max_length=100) - pub_date = models.DateField() - pub_date_time = models.DateTimeField() + pub_date = models.DateField(auto_now_add=True) + pub_date_time = models.DateTimeField(auto_now_add=True) reporter = models.ForeignKey( Reporter, on_delete=models.CASCADE, related_name="articles" ) diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index 5ff4466..9d83f3f 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -412,6 +412,7 @@ def test_should_query_node_filtering(): model = Article interfaces = (Node,) filter_fields = ("lang",) + convert_choices_to_enum = False class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -534,6 +535,7 @@ def test_should_query_node_multiple_filtering(): model = Article interfaces = (Node,) filter_fields = ("lang", "headline") + convert_choices_to_enum = False class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType)