Add enum support to filters and fix filter typing (v3) (#1119)

* - Add filtering support for choice fields converted to graphql Enum (or not)
- Fix type of various filters (used to default to String)
- Fix bug with contains introduced in previous PR
- Fix bug with declared filters being overridden (see PR #1108)
- Fix support for ArrayField and add documentation

* Fix for v3

Co-authored-by: Thomas Leonard <thomas@loftorbital.com>
This commit is contained in:
Thomas Leonard 2021-02-23 05:21:32 +01:00 committed by GitHub
parent 5ce4553244
commit 2d4ca0ac7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 916 additions and 126 deletions

View File

@ -258,3 +258,46 @@ with this set up, you can now order the users under group:
}
}
}
PostgreSQL `ArrayField`
-----------------------
Graphene provides an easy to implement filters on `ArrayField` as they are not natively supported by django_filters:
.. code:: python
from django.db import models
from django_filters import FilterSet, OrderingFilter
from graphene_django.filter import ArrayFilter
class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))
class EventFilterSet(FilterSet):
class Meta:
model = Event
fields = {
"name": ["exact", "contains"],
}
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
tags = ArrayFilter(field_name="tags", lookup_expr="exact")
class EventType(DjangoObjectType):
class Meta:
model = Event
interfaces = (Node,)
filterset_class = EventFilterSet
with this set up, you can now filter events by tags:
.. code::
query {
events(tags_Overlap: ["concert", "festival"]) {
name
}
}

View File

@ -35,7 +35,7 @@ from .utils.str_converters import to_const
class BlankValueField(Field):
def get_resolver(self, parent_resolver):
def wrap_resolve(self, parent_resolver):
resolver = self.resolver or parent_resolver
# create custom resolver

View File

@ -9,10 +9,19 @@ if not DJANGO_FILTER_INSTALLED:
)
else:
from .fields import DjangoFilterConnectionField
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
from .filters import (
ArrayFilter,
GlobalIDFilter,
GlobalIDMultipleChoiceFilter,
ListFilter,
RangeFilter,
)
__all__ = [
"DjangoFilterConnectionField",
"GlobalIDFilter",
"GlobalIDMultipleChoiceFilter",
"ArrayFilter",
"ListFilter",
"RangeFilter",
]

View File

@ -2,12 +2,31 @@ from collections import OrderedDict
from functools import partial
from django.core.exceptions import ValidationError
from graphene.types.enum import EnumType
from graphene.types.argument import to_arguments
from graphene.utils.str_converters import to_snake_case
from ..fields import DjangoConnectionField
from .utils import get_filtering_args_from_filterset, get_filterset_class
def convert_enum(data):
"""
Check if the data is a enum option (or potentially nested list of enum option)
and convert it to its value.
This method is used to pre-process the data for the filters as they can take an
graphene.Enum as argument, but filters (from django_filters) expect a simple value.
"""
if isinstance(data, list):
return [convert_enum(item) for item in data]
if isinstance(type(data), EnumType):
return data.value
else:
return data
class DjangoFilterConnectionField(DjangoConnectionField):
def __init__(
self,
@ -43,8 +62,8 @@ class DjangoFilterConnectionField(DjangoConnectionField):
if self._extra_filter_meta:
meta.update(self._extra_filter_meta)
filterset_class = self._provided_filterset_class or (
self.node_type._meta.filterset_class
filterset_class = (
self._provided_filterset_class or self.node_type._meta.filterset_class
)
self._filterset_class = get_filterset_class(filterset_class, **meta)
@ -68,7 +87,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
if k in filtering_args:
if k == "order_by" and v is not None:
v = to_snake_case(v)
kwargs[k] = v
kwargs[k] = convert_enum(v)
return kwargs
qs = super(DjangoFilterConnectionField, cls).resolve_queryset(
@ -78,7 +97,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
filterset = filterset_class(
data=filter_kwargs(), queryset=qs, request=info.context
)
if filterset.form.is_valid():
if filterset.is_valid():
return filterset.qs
raise ValidationError(filterset.form.errors.as_json())

View File

@ -2,6 +2,7 @@ from django.core.exceptions import ValidationError
from django.forms import Field
from django_filters import Filter, MultipleChoiceFilter
from django_filters.constants import EMPTY_VALUES
from graphql_relay.node.node import from_global_id
@ -31,14 +32,15 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
class InFilter(Filter):
class ListFilter(Filter):
"""
Filter for a list of value using the `__in` Django filter.
Filter that takes a list of value as input.
It is for example used for `__in` filters.
"""
def filter(self, qs, value):
"""
Override the default filter class to check first weather the list is
Override the default filter class to check first whether the list is
empty or not.
This needs to be done as in this case we expect to get an empty output
(if not an exclude filter) but django_filter consider an empty list
@ -73,3 +75,27 @@ class RangeField(Field):
class RangeFilter(Filter):
field_class = RangeField
class ArrayFilter(Filter):
"""
Filter made for PostgreSQL ArrayField.
"""
def filter(self, qs, value):
"""
Override the default filter class to check first whether the list is
empty or not.
This needs to be done as in this case we expect to get the filter applied with
an empty list since it's a valid value but django_filter consider an empty list
to be an empty input value (see `EMPTY_VALUES`) meaning that
the filter does not need to be applied (hence returning the original
queryset).
"""
if value in EMPTY_VALUES and value != []:
return qs
if self.distinct:
qs = qs.distinct()
lookup = "%s__%s" % (self.field_name, self.lookup_expr)
qs = self.get_method(qs)(**{lookup: value})
return qs

View File

@ -9,6 +9,7 @@ import graphene
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.utils import DJANGO_FILTER_INSTALLED
from graphene_django.filter import ArrayFilter, ListFilter
from ...compat import ArrayField
@ -27,49 +28,61 @@ else:
STORE = {"events": []}
@pytest.fixture
def Event():
class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))
return Event
class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))
tag_ids = ArrayField(models.IntegerField())
random_field = ArrayField(models.BooleanField())
@pytest.fixture
def EventFilterSet(Event):
from django.contrib.postgres.forms import SimpleArrayField
class ArrayFilter(filters.Filter):
base_field_class = SimpleArrayField
def EventFilterSet():
class EventFilterSet(FilterSet):
class Meta:
model = Event
fields = {
"name": ["exact"],
"name": ["exact", "contains"],
}
# Those are actually usable with our Query fixture bellow
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
tags = ArrayFilter(field_name="tags", lookup_expr="exact")
# Those are actually not usable and only to check type declarations
tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains")
tags_ids__overlap = ArrayFilter(field_name="tag_ids", lookup_expr="overlap")
tags_ids = ArrayFilter(field_name="tag_ids", lookup_expr="exact")
random_field__contains = ArrayFilter(
field_name="random_field", lookup_expr="contains"
)
random_field__overlap = ArrayFilter(
field_name="random_field", lookup_expr="overlap"
)
random_field = ArrayFilter(field_name="random_field", lookup_expr="exact")
return EventFilterSet
@pytest.fixture
def EventType(Event, EventFilterSet):
def EventType(EventFilterSet):
class EventType(DjangoObjectType):
class Meta:
model = Event
interfaces = (Node,)
fields = "__all__"
filterset_class = EventFilterSet
return EventType
@pytest.fixture
def Query(Event, EventType):
def Query(EventType):
"""
Note that we have to use a custom resolver to replicate the arrayfield filter behavior as
we are running unit tests in sqlite which does not have ArrayFields.
"""
class Query(graphene.ObjectType):
events = DjangoFilterConnectionField(EventType)
@ -79,6 +92,7 @@ def Query(Event, EventType):
Event(name="Live Show", tags=["concert", "music", "rock"],),
Event(name="Musical", tags=["movie", "music"],),
Event(name="Ballet", tags=["concert", "dance"],),
Event(name="Speech", tags=[],),
]
STORE["events"] = events
@ -105,6 +119,13 @@ def Query(Event, EventType):
STORE["events"],
)
)
if "tags__exact" in kwargs:
STORE["events"] = list(
filter(
lambda e: set(kwargs["tags__exact"]) == set(e.tags),
STORE["events"],
)
)
def mock_queryset_filter(*args, **kwargs):
filter_events(**kwargs)
@ -121,7 +142,9 @@ def Query(Event, EventType):
m_queryset.filter.side_effect = mock_queryset_filter
m_queryset.none.side_effect = mock_queryset_none
m_queryset.count.side_effect = mock_queryset_count
m_queryset.__getitem__.side_effect = STORE["events"].__getitem__
m_queryset.__getitem__.side_effect = lambda index: STORE[
"events"
].__getitem__(index)
return m_queryset

View File

@ -10,7 +10,7 @@ class ArticleFilter(django_filters.FilterSet):
fields = {
"headline": ["exact", "icontains"],
"pub_date": ["gt", "lt", "exact"],
"reporter": ["exact"],
"reporter": ["exact", "in"],
}
order_by = OrderingFilter(fields=("pub_date",))

View File

@ -6,9 +6,9 @@ from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_multiple(Query):
def test_array_field_contains_multiple(Query):
"""
Test contains filter on a string field.
Test contains filter on a array field of string.
"""
schema = Schema(query=Query)
@ -32,9 +32,9 @@ def test_string_contains_multiple(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_one(Query):
def test_array_field_contains_one(Query):
"""
Test contains filter on a string field.
Test contains filter on a array field of string.
"""
schema = Schema(query=Query)
@ -59,9 +59,9 @@ def test_string_contains_one(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_none(Query):
def test_array_field_contains_empty_list(Query):
"""
Test contains filter on a string field.
Test contains filter on a array field of string.
"""
schema = Schema(query=Query)
@ -79,4 +79,9 @@ def test_string_contains_none(Query):
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []
assert result.data["events"]["edges"] == [
{"node": {"name": "Live Show"}},
{"node": {"name": "Musical"}},
{"node": {"name": "Ballet"}},
{"node": {"name": "Speech"}},
]

View File

@ -0,0 +1,129 @@
import pytest
from graphene import Schema
from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_no_match(Query):
"""
Test exact filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags: ["concert", "music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_match(Query):
"""
Test exact filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags: ["movie", "music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Musical"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_empty_list(Query):
"""
Test exact filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags: []) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Speech"}},
]
def test_array_field_filter_schema_type(Query):
"""
Check that the type in the filter is an array field like on the object type.
"""
schema = Schema(query=Query)
schema_str = str(schema)
assert (
'''type EventType implements Node {
"""The ID of the object"""
id: ID!
name: String!
tags: [String!]!
tagIds: [Int!]!
randomField: [Boolean!]!
}'''
in schema_str
)
filters = {
"offset": "Int",
"before": "String",
"after": "String",
"first": "Int",
"last": "Int",
"name": "String",
"name_Contains": "String",
"tags_Contains": "[String!]",
"tags_Overlap": "[String!]",
"tags": "[String!]",
"tagsIds_Contains": "[Int!]",
"tagsIds_Overlap": "[Int!]",
"tagsIds": "[Int!]",
"randomField_Contains": "[Boolean!]",
"randomField_Overlap": "[Boolean!]",
"randomField": "[Boolean!]",
}
filters_str = ", ".join(
[
f"{filter_field}: {gql_type} = null"
for filter_field, gql_type in filters.items()
]
)
assert (
f"type Query {{\n events({filters_str}): EventTypeConnection\n}}" in schema_str
)

View File

@ -6,9 +6,9 @@ from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_overlap_multiple(Query):
def test_array_field_overlap_multiple(Query):
"""
Test overlap filter on a string field.
Test overlap filter on a array field of string.
"""
schema = Schema(query=Query)
@ -34,9 +34,9 @@ def test_string_overlap_multiple(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_overlap_one(Query):
def test_array_field_overlap_one(Query):
"""
Test overlap filter on a string field.
Test overlap filter on a array field of string.
"""
schema = Schema(query=Query)
@ -61,9 +61,9 @@ def test_string_overlap_one(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_overlap_none(Query):
def test_array_field_overlap_empty_list(Query):
"""
Test overlap filter on a string field.
Test overlap filter on a array field of string.
"""
schema = Schema(query=Query)

View File

@ -0,0 +1,160 @@
import pytest
import graphene
from graphene.relay import Node
from graphene_django import DjangoObjectType, DjangoConnectionField
from graphene_django.tests.models import Article, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import DjangoFilterConnectionField
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
@pytest.fixture
def schema():
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class ArticleType(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
fields = "__all__"
filter_fields = {
"lang": ["exact", "in"],
"reporter__a_choice": ["exact", "in"],
}
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
all_articles = DjangoFilterConnectionField(ArticleType)
schema = graphene.Schema(query=Query)
return schema
@pytest.fixture
def reporter_article_data():
john = Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
jane = Reporter.objects.create(
first_name="Jane", last_name="Doe", email="janedoe@example.com", a_choice=2
)
Article.objects.create(
headline="Article Node 1", reporter=john, editor=john, lang="es",
)
Article.objects.create(
headline="Article Node 2", reporter=john, editor=john, lang="en",
)
Article.objects.create(
headline="Article Node 3", reporter=jane, editor=jane, lang="en",
)
def test_filter_enum_on_connection(schema, reporter_article_data):
"""
Check that we can filter with enums on a connection.
"""
query = """
query {
allArticles(lang: ES) {
edges {
node {
headline
}
}
}
}
"""
expected = {"allArticles": {"edges": [{"node": {"headline": "Article Node 1"}},]}}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_filter_on_foreign_key_enum_field(schema, reporter_article_data):
"""
Check that we can filter with enums on a field from a foreign key.
"""
query = """
query {
allArticles(reporter_AChoice: A_1) {
edges {
node {
headline
}
}
}
}
"""
expected = {
"allArticles": {
"edges": [
{"node": {"headline": "Article Node 1"}},
{"node": {"headline": "Article Node 2"}},
]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_filter_enum_field_schema_type(schema):
"""
Check that the type in the filter is an enum like on the object type.
"""
schema_str = str(schema)
assert (
'''type ArticleType implements Node {
"""The ID of the object"""
id: ID!
headline: String!
pubDate: Date!
pubDateTime: DateTime!
reporter: ReporterType!
editor: ReporterType!
"""Language"""
lang: TestsArticleLangChoices!
importance: TestsArticleImportanceChoices
}'''
in schema_str
)
filters = {
"offset": "Int",
"before": "String",
"after": "String",
"first": "Int",
"last": "Int",
"lang": "TestsArticleLangChoices",
"lang_In": "[TestsArticleLangChoices]",
"reporter_AChoice": "TestsReporterAChoiceChoices",
"reporter_AChoice_In": "[TestsReporterAChoiceChoices]",
}
filters_str = ", ".join(
[
f"{filter_field}: {gql_type} = null"
for filter_field, gql_type in filters.items()
]
)
assert f" allArticles({filters_str}): ArticleTypeConnection\n" in schema_str

View File

@ -9,7 +9,7 @@ from graphene import Argument, Boolean, Field, Float, ObjectType, Schema, String
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from graphene_django.tests.models import Article, Pet, Reporter
from graphene_django.tests.models import Article, Person, Pet, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
@ -90,6 +90,7 @@ def test_filter_explicit_filterset_arguments():
"pub_date__gt",
"pub_date__lt",
"reporter",
"reporter__in",
)
@ -696,7 +697,7 @@ def test_should_query_filter_node_limit():
node {
id
firstName
articles(lang: "es") {
articles(lang: ES) {
edges {
node {
id
@ -738,6 +739,7 @@ def test_order_by():
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(
@ -1143,7 +1145,7 @@ def test_filter_filterset_based_on_mixin():
return filters
def filter_email_in(cls, queryset, name, value):
def filter_email_in(self, queryset, name, value):
return queryset.filter(**{name: [value]})
class NewArticleFilter(ArticleFilterMixin, ArticleFilter):
@ -1228,3 +1230,48 @@ def test_filter_filterset_based_on_mixin():
assert not result.errors
assert result.data == expected
def test_filter_string_contains():
class PersonType(DjangoObjectType):
class Meta:
model = Person
interfaces = (Node,)
fields = "__all__"
filter_fields = {"name": ["exact", "in", "contains", "icontains"]}
class Query(ObjectType):
people = DjangoFilterConnectionField(PersonType)
schema = Schema(query=Query)
Person.objects.bulk_create(
[
Person(name="Jack"),
Person(name="Joe"),
Person(name="Jane"),
Person(name="Peter"),
Person(name="Bob"),
]
)
query = """query nameContain($filter: String) {
people(name_Contains: $filter) {
edges {
node {
name
}
}
}
}"""
result = schema.execute(query, variables={"filter": "Ja"})
assert not result.errors
assert result.data == {
"people": {"edges": [{"node": {"name": "Jack"}}, {"node": {"name": "Jane"}},]}
}
result = schema.execute(query, variables={"filter": "o"})
assert not result.errors
assert result.data == {
"people": {"edges": [{"node": {"name": "Joe"}}, {"node": {"name": "Bob"}},]}
}

View File

@ -1,3 +1,5 @@
from datetime import datetime
import pytest
from django_filters import FilterSet
@ -5,7 +7,8 @@ from django_filters import rest_framework as filters
from graphene import ObjectType, Schema
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.tests.models import Pet, Person
from graphene_django.tests.models import Pet, Person, Reporter, Article, Film
from graphene_django.filter.tests.filters import ArticleFilter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
@ -20,40 +23,77 @@ else:
)
class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node,)
filter_fields = {
"name": ["exact", "in"],
"age": ["exact", "in", "range"],
}
@pytest.fixture
def query():
class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node,)
fields = "__all__"
filter_fields = {
"id": ["exact", "in"],
"name": ["exact", "in"],
"age": ["exact", "in", "range"],
}
class ReporterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
# choice filter using enum
filter_fields = {"reporter_type": ["exact", "in"]}
class ArticleNode(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
fields = "__all__"
filterset_class = ArticleFilter
class FilmNode(DjangoObjectType):
class Meta:
model = Film
interfaces = (Node,)
fields = "__all__"
# choice filter not using enum
filter_fields = {
"genre": ["exact", "in"],
}
convert_choices_to_enum = False
class PersonFilterSet(FilterSet):
class Meta:
model = Person
fields = {"name": ["in"]}
names = filters.BaseInFilter(method="filter_names")
def filter_names(self, qs, name, value):
"""
This custom filter take a string as input with comma separated values.
Note that the value here is already a list as it has been transformed by the BaseInFilter class.
"""
return qs.filter(name__in=value)
class PersonNode(DjangoObjectType):
class Meta:
model = Person
interfaces = (Node,)
filterset_class = PersonFilterSet
fields = "__all__"
class Query(ObjectType):
pets = DjangoFilterConnectionField(PetNode)
people = DjangoFilterConnectionField(PersonNode)
articles = DjangoFilterConnectionField(ArticleNode)
films = DjangoFilterConnectionField(FilmNode)
reporters = DjangoFilterConnectionField(ReporterNode)
return Query
class PersonFilterSet(FilterSet):
class Meta:
model = Person
fields = {}
names = filters.BaseInFilter(method="filter_names")
def filter_names(self, qs, name, value):
return qs.filter(name__in=value)
class PersonNode(DjangoObjectType):
class Meta:
model = Person
interfaces = (Node,)
filterset_class = PersonFilterSet
class Query(ObjectType):
pets = DjangoFilterConnectionField(PetNode)
people = DjangoFilterConnectionField(PersonNode)
def test_string_in_filter():
def test_string_in_filter(query):
"""
Test in filter on a string field.
"""
@ -61,7 +101,7 @@ def test_string_in_filter():
Pet.objects.create(name="Mimi", age=3)
Pet.objects.create(name="Jojo, the rabbit", age=3)
schema = Schema(query=Query)
schema = Schema(query=query)
query = """
query {
@ -82,17 +122,19 @@ def test_string_in_filter():
]
def test_string_in_filter_with_filterset_class():
"""Test in filter on a string field with a custom filterset class."""
def test_string_in_filter_with_otjer_filter(query):
"""
Test in filter on a string field which has also a custom filter doing a similar operation.
"""
Person.objects.create(name="John")
Person.objects.create(name="Michael")
Person.objects.create(name="Angela")
schema = Schema(query=Query)
schema = Schema(query=query)
query = """
query {
people (names: ["John", "Michael"]) {
people (name_In: ["John", "Michael"]) {
edges {
node {
name
@ -109,7 +151,36 @@ def test_string_in_filter_with_filterset_class():
]
def test_int_in_filter():
def test_string_in_filter_with_declared_filter(query):
"""
Test in filter on a string field with a custom filterset class.
"""
Person.objects.create(name="John")
Person.objects.create(name="Michael")
Person.objects.create(name="Angela")
schema = Schema(query=query)
query = """
query {
people (names: "John,Michael") {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["people"]["edges"] == [
{"node": {"name": "John"}},
{"node": {"name": "Michael"}},
]
def test_int_in_filter(query):
"""
Test in filter on an integer field.
"""
@ -117,7 +188,7 @@ def test_int_in_filter():
Pet.objects.create(name="Mimi", age=3)
Pet.objects.create(name="Jojo, the rabbit", age=3)
schema = Schema(query=Query)
schema = Schema(query=query)
query = """
query {
@ -157,7 +228,7 @@ def test_int_in_filter():
]
def test_in_filter_with_empty_list():
def test_in_filter_with_empty_list(query):
"""
Check that using a in filter with an empty list provided as input returns no objects.
"""
@ -165,7 +236,7 @@ def test_in_filter_with_empty_list():
Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Picotin", age=5)
schema = Schema(query=Query)
schema = Schema(query=query)
query = """
query {
@ -181,3 +252,197 @@ def test_in_filter_with_empty_list():
result = schema.execute(query)
assert not result.errors
assert len(result.data["pets"]["edges"]) == 0
def test_choice_in_filter_without_enum(query):
"""
Test in filter o an choice field not using an enum (Film.genre).
"""
john_doe = Reporter.objects.create(
first_name="John", last_name="Doe", email="john@doe.com"
)
jean_bon = Reporter.objects.create(
first_name="Jean", last_name="Bon", email="jean@bon.com"
)
documentary_film = Film.objects.create(genre="do")
documentary_film.reporters.add(john_doe)
action_film = Film.objects.create(genre="ac")
action_film.reporters.add(john_doe)
other_film = Film.objects.create(genre="ot")
other_film.reporters.add(john_doe)
other_film.reporters.add(jean_bon)
schema = Schema(query=query)
query = """
query {
films (genre_In: ["do", "ac"]) {
edges {
node {
genre
reporters {
edges {
node {
lastName
}
}
}
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["films"]["edges"] == [
{
"node": {
"genre": "do",
"reporters": {"edges": [{"node": {"lastName": "Doe"}}]},
}
},
{
"node": {
"genre": "ac",
"reporters": {"edges": [{"node": {"lastName": "Doe"}}]},
}
},
]
def test_fk_id_in_filter(query):
"""
Test in filter on an foreign key relationship.
"""
john_doe = Reporter.objects.create(
first_name="John", last_name="Doe", email="john@doe.com"
)
jean_bon = Reporter.objects.create(
first_name="Jean", last_name="Bon", email="jean@bon.com"
)
sara_croche = Reporter.objects.create(
first_name="Sara", last_name="Croche", email="sara@croche.com"
)
Article.objects.create(
headline="A",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=john_doe,
editor=john_doe,
)
Article.objects.create(
headline="B",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=jean_bon,
editor=jean_bon,
)
Article.objects.create(
headline="C",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=sara_croche,
editor=sara_croche,
)
schema = Schema(query=query)
query = """
query {
articles (reporter_In: [%s, %s]) {
edges {
node {
headline
reporter {
lastName
}
}
}
}
}
""" % (
john_doe.id,
jean_bon.id,
)
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A", "reporter": {"lastName": "Doe"}}},
{"node": {"headline": "B", "reporter": {"lastName": "Bon"}}},
]
def test_enum_in_filter(query):
"""
Test in filter on a choice field using an enum (Reporter.reporter_type).
"""
Reporter.objects.create(
first_name="John", last_name="Doe", email="john@doe.com", reporter_type=1
)
Reporter.objects.create(
first_name="Jean", last_name="Bon", email="jean@bon.com", reporter_type=2
)
Reporter.objects.create(
first_name="Jane", last_name="Doe", email="jane@doe.com", reporter_type=2
)
Reporter.objects.create(
first_name="Jack", last_name="Black", email="jack@black.com", reporter_type=None
)
schema = Schema(query=query)
query = """
query {
reporters (reporterType_In: [A_1]) {
edges {
node {
email
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["reporters"]["edges"] == [
{"node": {"email": "john@doe.com"}},
]
query = """
query {
reporters (reporterType_In: [A_2]) {
edges {
node {
email
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["reporters"]["edges"] == [
{"node": {"email": "jean@bon.com"}},
{"node": {"email": "jane@doe.com"}},
]
query = """
query {
reporters (reporterType_In: [A_2, A_1]) {
edges {
node {
email
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["reporters"]["edges"] == [
{"node": {"email": "john@doe.com"}},
{"node": {"email": "jean@bon.com"}},
{"node": {"email": "jane@doe.com"}},
]

View File

@ -25,6 +25,7 @@ class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node,)
fields = "__all__"
filter_fields = {
"name": ["exact", "in"],
"age": ["exact", "in", "range"],

View File

@ -1,53 +1,101 @@
import graphene
from django_filters.utils import get_model_field
from django import forms
from django_filters.utils import get_model_field, get_field_parts
from django_filters.filters import Filter, BaseCSVFilter
from .filterset import custom_filterset_factory, setup_filterset
from .filters import InFilter, RangeFilter
from .filters import ArrayFilter, ListFilter, RangeFilter
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
def get_field_type(registry, model, field_name):
"""
Try to get a model field corresponding Graphql type from the DjangoObjectType.
"""
object_type = registry.get_type_for_model(model)
if object_type:
object_type_field = object_type._meta.fields.get(field_name)
if object_type_field:
field_type = object_type_field.type
if isinstance(field_type, graphene.NonNull):
field_type = field_type.of_type
return field_type
return None
def get_filtering_args_from_filterset(filterset_class, type):
""" Inspect a FilterSet and produce the arguments to pass to
a Graphene Field. These arguments will be available to
filter against in the GraphQL
"""
Inspect a FilterSet and produce the arguments to pass to a Graphene Field.
These arguments will be available to filter against in the GraphQL API.
"""
from ..forms.converter import convert_form_field
args = {}
model = filterset_class._meta.model
registry = type._meta.registry
for name, filter_field in filterset_class.base_filters.items():
form_field = None
filter_type = filter_field.lookup_expr
field_type = None
form_field = None
if name in filterset_class.declared_filters:
# Get the filter field from the explicitly declared filter
form_field = filter_field.field
field = convert_form_field(form_field)
else:
# Get the filter field with no explicit type declaration
model_field = get_model_field(model, filter_field.field_name)
if filter_type != "isnull" and hasattr(model_field, "formfield"):
form_field = model_field.formfield(
required=filter_field.extra.get("required", False)
)
if (
name not in filterset_class.declared_filters
or isinstance(filter_field, ListFilter)
or isinstance(filter_field, RangeFilter)
or isinstance(filter_field, ArrayFilter)
):
# Get the filter field for filters that are no explicitly declared.
# Fallback to field defined on filter if we can't get it from the
# model field
if not form_field:
form_field = filter_field.field
required = filter_field.extra.get("required", False)
if filter_type == "isnull":
field = graphene.Boolean(required=required)
else:
model_field = get_model_field(model, filter_field.field_name)
field = convert_form_field(form_field)
# Get the form field either from:
# 1. the formfield corresponding to the model field
# 2. the field defined on filter
if hasattr(model_field, "formfield"):
form_field = model_field.formfield(required=required)
if not form_field:
form_field = filter_field.field
if filter_type in {"in", "range", "contains", "overlap"}:
# Replace CSV filters (`in`, `range`, `contains`, `overlap`) argument type to be a list of
# the same type as the field. See comments in
# `replace_csv_filters` method for more details.
field = graphene.List(field.get_type())
# First try to get the matching field type from the GraphQL DjangoObjectType
if model_field:
if (
isinstance(form_field, forms.ModelChoiceField)
or isinstance(form_field, forms.ModelMultipleChoiceField)
or isinstance(form_field, GlobalIDMultipleChoiceField)
or isinstance(form_field, GlobalIDFormField)
):
# Foreign key have dynamic types and filtering on a foreign key actually means filtering on its ID.
field_type = get_field_type(
registry, model_field.related_model, "id"
)
else:
field_type = get_field_type(
registry, model_field.model, model_field.name
)
field_type = field.Argument()
field_type.description = str(filter_field.label) if filter_field.label else None
args[name] = field_type
if not field_type:
# Fallback on converting the form field either because:
# - it's an explicitly declared filters
# - we did not manage to get the type from the model type
form_field = form_field or filter_field.field
field_type = convert_form_field(form_field)
if isinstance(filter_field, ListFilter) or isinstance(
filter_field, RangeFilter
):
# Replace InFilter/RangeFilter filters (`in`, `range`) argument type to be a list of
# the same type as the field. See comments in `replace_csv_filters` method for more details.
field_type = graphene.List(field_type.get_type())
args[name] = graphene.Argument(
field_type.get_type(), description=filter_field.label, required=required,
)
return args
@ -69,18 +117,26 @@ def get_filterset_class(filterset_class, **meta):
def replace_csv_filters(filterset_class):
"""
Replace the "in", "contains", "overlap" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
but regular Filter objects that simply use the input value as filter argument on the queryset.
Replace the "in" and "range" filters (that are not explicitly declared)
to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
but our custom InFilter/RangeFilter filter class that use the input
value as filter argument on the queryset.
This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we
can actually have a list as input and have a proper type verification of each value in the list.
This is because those BaseCSVFilter are expecting a string as input with
comma separated values.
But with GraphQl we can actually have a list as input and have a proper
type verification of each value in the list.
See issue https://github.com/graphql-python/graphene-django/issues/1068.
"""
for name, filter_field in list(filterset_class.base_filters.items()):
# Do not touch any declared filters
if name in filterset_class.declared_filters:
continue
filter_type = filter_field.lookup_expr
if filter_type in {"in", "contains", "overlap"}:
filterset_class.base_filters[name] = InFilter(
if filter_type == "in":
filterset_class.base_filters[name] = ListFilter(
field_name=filter_field.field_name,
lookup_expr=filter_field.lookup_expr,
label=filter_field.label,
@ -88,7 +144,6 @@ def replace_csv_filters(filterset_class):
exclude=filter_field.exclude,
**filter_field.extra
)
elif filter_type == "range":
filterset_class.base_filters[name] = RangeFilter(
field_name=filter_field.field_name,

View File

@ -26,7 +26,7 @@ class Film(models.Model):
genre = models.CharField(
max_length=2,
help_text="Genre",
choices=[("do", "Documentary"), ("ot", "Other")],
choices=[("do", "Documentary"), ("ac", "Action"), ("ot", "Other")],
default="ot",
)
reporters = models.ManyToManyField("Reporter", related_name="films")
@ -91,8 +91,8 @@ class CNNReporter(Reporter):
class Article(models.Model):
headline = models.CharField(max_length=100)
pub_date = models.DateField()
pub_date_time = models.DateTimeField()
pub_date = models.DateField(auto_now_add=True)
pub_date_time = models.DateTimeField(auto_now_add=True)
reporter = models.ForeignKey(
Reporter, on_delete=models.CASCADE, related_name="articles"
)

View File

@ -421,6 +421,7 @@ def test_should_query_node_filtering():
interfaces = (Node,)
fields = "__all__"
filter_fields = ("lang",)
convert_choices_to_enum = False
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -546,6 +547,7 @@ def test_should_query_node_multiple_filtering():
interfaces = (Node,)
fields = "__all__"
filter_fields = ("lang", "headline")
convert_choices_to_enum = False
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1251,6 +1253,7 @@ class TestBackwardPagination:
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1455,6 +1458,7 @@ def test_connection_should_enable_offset_filtering():
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1494,6 +1498,7 @@ def test_connection_should_enable_offset_filtering_higher_than_max_limit(
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1527,6 +1532,7 @@ def test_connection_should_forbid_offset_filtering_with_before():
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1561,6 +1567,7 @@ def test_connection_should_allow_offset_filtering_with_after():
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)

View File

@ -671,6 +671,7 @@ def test_django_objecttype_name_connection_propagation():
class Meta:
model = ReporterModel
name = "CustomReporterName"
fields = "__all__"
filter_fields = ["email"]
interfaces = (Node,)