mirror of
https://github.com/graphql-python/graphene-django.git
synced 2024-11-29 04:53:43 +03:00
feat: add TypedFilter which allow to explicitly give a filter input GraphQL type (#1142)
Co-authored-by: Thomas Leonard <thomas@loftorbital.com>
This commit is contained in:
parent
6f1389c039
commit
998ed89a4e
|
@ -16,7 +16,7 @@ You will need to install it manually, which can be done as follows:
|
|||
|
||||
# You'll need to install django-filter
|
||||
pip install django-filter>=2
|
||||
|
||||
|
||||
After installing ``django-filter`` you'll need to add the application in the ``settings.py`` file:
|
||||
|
||||
.. code:: python
|
||||
|
@ -271,3 +271,41 @@ with this set up, you can now filter events by tags:
|
|||
name
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
`TypedFilter`
|
||||
-------------
|
||||
|
||||
Sometimes the automatic detection of the filter input type is not satisfactory for what you are trying to achieve.
|
||||
You can then explicitly specify the input type you want for your filter by using a `TypedFilter`:
|
||||
|
||||
.. code:: python
|
||||
|
||||
from django.db import models
|
||||
from django_filters import FilterSet, OrderingFilter
|
||||
import graphene
|
||||
from graphene_django.filter import TypedFilter
|
||||
|
||||
class Event(models.Model):
|
||||
name = models.CharField(max_length=50)
|
||||
|
||||
class EventFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = Event
|
||||
fields = {
|
||||
"name": ["exact", "contains"],
|
||||
}
|
||||
|
||||
only_first = TypedFilter(input_type=graphene.Boolean, method="only_first_filter")
|
||||
|
||||
def only_first_filter(self, queryset, _name, value):
|
||||
if value:
|
||||
return queryset[:1]
|
||||
else:
|
||||
return queryset
|
||||
|
||||
class EventType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Event
|
||||
interfaces = (Node,)
|
||||
filterset_class = EventFilterSet
|
||||
|
|
|
@ -15,6 +15,7 @@ else:
|
|||
GlobalIDMultipleChoiceFilter,
|
||||
ListFilter,
|
||||
RangeFilter,
|
||||
TypedFilter,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
@ -24,4 +25,5 @@ else:
|
|||
"ArrayFilter",
|
||||
"ListFilter",
|
||||
"RangeFilter",
|
||||
"TypedFilter",
|
||||
]
|
||||
|
|
|
@ -1,101 +0,0 @@
|
|||
from django.core.exceptions import ValidationError
|
||||
from django.forms import Field
|
||||
|
||||
from django_filters import Filter, MultipleChoiceFilter
|
||||
from django_filters.constants import EMPTY_VALUES
|
||||
|
||||
from graphql_relay.node.node import from_global_id
|
||||
|
||||
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
|
||||
|
||||
|
||||
class GlobalIDFilter(Filter):
|
||||
"""
|
||||
Filter for Relay global ID.
|
||||
"""
|
||||
|
||||
field_class = GlobalIDFormField
|
||||
|
||||
def filter(self, qs, value):
|
||||
""" Convert the filter value to a primary key before filtering """
|
||||
_id = None
|
||||
if value is not None:
|
||||
_, _id = from_global_id(value)
|
||||
return super(GlobalIDFilter, self).filter(qs, _id)
|
||||
|
||||
|
||||
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
|
||||
field_class = GlobalIDMultipleChoiceField
|
||||
|
||||
def filter(self, qs, value):
|
||||
gids = [from_global_id(v)[1] for v in value]
|
||||
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
|
||||
|
||||
|
||||
class ListFilter(Filter):
|
||||
"""
|
||||
Filter that takes a list of value as input.
|
||||
It is for example used for `__in` filters.
|
||||
"""
|
||||
|
||||
def filter(self, qs, value):
|
||||
"""
|
||||
Override the default filter class to check first whether the list is
|
||||
empty or not.
|
||||
This needs to be done as in this case we expect to get an empty output
|
||||
(if not an exclude filter) but django_filter consider an empty list
|
||||
to be an empty input value (see `EMPTY_VALUES`) meaning that
|
||||
the filter does not need to be applied (hence returning the original
|
||||
queryset).
|
||||
"""
|
||||
if value is not None and len(value) == 0:
|
||||
if self.exclude:
|
||||
return qs
|
||||
else:
|
||||
return qs.none()
|
||||
else:
|
||||
return super(ListFilter, self).filter(qs, value)
|
||||
|
||||
|
||||
def validate_range(value):
|
||||
"""
|
||||
Validator for range filter input: the list of value must be of length 2.
|
||||
Note that validators are only run if the value is not empty.
|
||||
"""
|
||||
if len(value) != 2:
|
||||
raise ValidationError(
|
||||
"Invalid range specified: it needs to contain 2 values.", code="invalid"
|
||||
)
|
||||
|
||||
|
||||
class RangeField(Field):
|
||||
default_validators = [validate_range]
|
||||
empty_values = [None]
|
||||
|
||||
|
||||
class RangeFilter(Filter):
|
||||
field_class = RangeField
|
||||
|
||||
|
||||
class ArrayFilter(Filter):
|
||||
"""
|
||||
Filter made for PostgreSQL ArrayField.
|
||||
"""
|
||||
|
||||
def filter(self, qs, value):
|
||||
"""
|
||||
Override the default filter class to check first whether the list is
|
||||
empty or not.
|
||||
This needs to be done as in this case we expect to get the filter applied with
|
||||
an empty list since it's a valid value but django_filter consider an empty list
|
||||
to be an empty input value (see `EMPTY_VALUES`) meaning that
|
||||
the filter does not need to be applied (hence returning the original
|
||||
queryset).
|
||||
"""
|
||||
if value in EMPTY_VALUES and value != []:
|
||||
return qs
|
||||
if self.distinct:
|
||||
qs = qs.distinct()
|
||||
lookup = "%s__%s" % (self.field_name, self.lookup_expr)
|
||||
qs = self.get_method(qs)(**{lookup: value})
|
||||
return qs
|
25
graphene_django/filter/filters/__init__.py
Normal file
25
graphene_django/filter/filters/__init__.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
import warnings
|
||||
from ...utils import DJANGO_FILTER_INSTALLED
|
||||
|
||||
if not DJANGO_FILTER_INSTALLED:
|
||||
warnings.warn(
|
||||
"Use of django filtering requires the django-filter package "
|
||||
"be installed. You can do so using `pip install django-filter`",
|
||||
ImportWarning,
|
||||
)
|
||||
else:
|
||||
from .array_filter import ArrayFilter
|
||||
from .global_id_filter import GlobalIDFilter, GlobalIDMultipleChoiceFilter
|
||||
from .list_filter import ListFilter
|
||||
from .range_filter import RangeFilter
|
||||
from .typed_filter import TypedFilter
|
||||
|
||||
__all__ = [
|
||||
"DjangoFilterConnectionField",
|
||||
"GlobalIDFilter",
|
||||
"GlobalIDMultipleChoiceFilter",
|
||||
"ArrayFilter",
|
||||
"ListFilter",
|
||||
"RangeFilter",
|
||||
"TypedFilter",
|
||||
]
|
27
graphene_django/filter/filters/array_filter.py
Normal file
27
graphene_django/filter/filters/array_filter.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
from django_filters.constants import EMPTY_VALUES
|
||||
|
||||
from .typed_filter import TypedFilter
|
||||
|
||||
|
||||
class ArrayFilter(TypedFilter):
|
||||
"""
|
||||
Filter made for PostgreSQL ArrayField.
|
||||
"""
|
||||
|
||||
def filter(self, qs, value):
|
||||
"""
|
||||
Override the default filter class to check first whether the list is
|
||||
empty or not.
|
||||
This needs to be done as in this case we expect to get the filter applied with
|
||||
an empty list since it's a valid value but django_filter consider an empty list
|
||||
to be an empty input value (see `EMPTY_VALUES`) meaning that
|
||||
the filter does not need to be applied (hence returning the original
|
||||
queryset).
|
||||
"""
|
||||
if value in EMPTY_VALUES and value != []:
|
||||
return qs
|
||||
if self.distinct:
|
||||
qs = qs.distinct()
|
||||
lookup = "%s__%s" % (self.field_name, self.lookup_expr)
|
||||
qs = self.get_method(qs)(**{lookup: value})
|
||||
return qs
|
28
graphene_django/filter/filters/global_id_filter.py
Normal file
28
graphene_django/filter/filters/global_id_filter.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
from django_filters import Filter, MultipleChoiceFilter
|
||||
|
||||
from graphql_relay.node.node import from_global_id
|
||||
|
||||
from ...forms import GlobalIDFormField, GlobalIDMultipleChoiceField
|
||||
|
||||
|
||||
class GlobalIDFilter(Filter):
|
||||
"""
|
||||
Filter for Relay global ID.
|
||||
"""
|
||||
|
||||
field_class = GlobalIDFormField
|
||||
|
||||
def filter(self, qs, value):
|
||||
""" Convert the filter value to a primary key before filtering """
|
||||
_id = None
|
||||
if value is not None:
|
||||
_, _id = from_global_id(value)
|
||||
return super(GlobalIDFilter, self).filter(qs, _id)
|
||||
|
||||
|
||||
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
|
||||
field_class = GlobalIDMultipleChoiceField
|
||||
|
||||
def filter(self, qs, value):
|
||||
gids = [from_global_id(v)[1] for v in value]
|
||||
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
|
26
graphene_django/filter/filters/list_filter.py
Normal file
26
graphene_django/filter/filters/list_filter.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
from .typed_filter import TypedFilter
|
||||
|
||||
|
||||
class ListFilter(TypedFilter):
|
||||
"""
|
||||
Filter that takes a list of value as input.
|
||||
It is for example used for `__in` filters.
|
||||
"""
|
||||
|
||||
def filter(self, qs, value):
|
||||
"""
|
||||
Override the default filter class to check first whether the list is
|
||||
empty or not.
|
||||
This needs to be done as in this case we expect to get an empty output
|
||||
(if not an exclude filter) but django_filter consider an empty list
|
||||
to be an empty input value (see `EMPTY_VALUES`) meaning that
|
||||
the filter does not need to be applied (hence returning the original
|
||||
queryset).
|
||||
"""
|
||||
if value is not None and len(value) == 0:
|
||||
if self.exclude:
|
||||
return qs
|
||||
else:
|
||||
return qs.none()
|
||||
else:
|
||||
return super(ListFilter, self).filter(qs, value)
|
24
graphene_django/filter/filters/range_filter.py
Normal file
24
graphene_django/filter/filters/range_filter.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
from django.core.exceptions import ValidationError
|
||||
from django.forms import Field
|
||||
|
||||
from .typed_filter import TypedFilter
|
||||
|
||||
|
||||
def validate_range(value):
|
||||
"""
|
||||
Validator for range filter input: the list of value must be of length 2.
|
||||
Note that validators are only run if the value is not empty.
|
||||
"""
|
||||
if len(value) != 2:
|
||||
raise ValidationError(
|
||||
"Invalid range specified: it needs to contain 2 values.", code="invalid"
|
||||
)
|
||||
|
||||
|
||||
class RangeField(Field):
|
||||
default_validators = [validate_range]
|
||||
empty_values = [None]
|
||||
|
||||
|
||||
class RangeFilter(TypedFilter):
|
||||
field_class = RangeField
|
27
graphene_django/filter/filters/typed_filter.py
Normal file
27
graphene_django/filter/filters/typed_filter.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
from django_filters import Filter
|
||||
|
||||
from graphene.types.utils import get_type
|
||||
|
||||
|
||||
class TypedFilter(Filter):
|
||||
"""
|
||||
Filter class for which the input GraphQL type can explicitly be provided.
|
||||
If it is not provided, when building the schema, it will try to guess
|
||||
it from the field.
|
||||
"""
|
||||
|
||||
def __init__(self, input_type=None, *args, **kwargs):
|
||||
self._input_type = input_type
|
||||
super(TypedFilter, self).__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def input_type(self):
|
||||
input_type = get_type(self._input_type)
|
||||
if input_type is not None:
|
||||
if not callable(getattr(input_type, "get_type", None)):
|
||||
raise ValueError(
|
||||
"Wrong `input_type` for {}: it only accepts graphene types, got {}".format(
|
||||
self.__class__.__name__, input_type
|
||||
)
|
||||
)
|
||||
return input_type
|
156
graphene_django/filter/tests/test_typed_filter.py
Normal file
156
graphene_django/filter/tests/test_typed_filter.py
Normal file
|
@ -0,0 +1,156 @@
|
|||
import pytest
|
||||
|
||||
from django_filters import FilterSet
|
||||
|
||||
import graphene
|
||||
from graphene.relay import Node
|
||||
|
||||
from graphene_django import DjangoObjectType
|
||||
from graphene_django.tests.models import Article, Reporter
|
||||
from graphene_django.utils import DJANGO_FILTER_INSTALLED
|
||||
|
||||
pytestmark = []
|
||||
|
||||
if DJANGO_FILTER_INSTALLED:
|
||||
from graphene_django.filter import (
|
||||
DjangoFilterConnectionField,
|
||||
TypedFilter,
|
||||
ListFilter,
|
||||
)
|
||||
else:
|
||||
pytestmark.append(
|
||||
pytest.mark.skipif(
|
||||
True, reason="django_filters not installed or not compatible"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def schema():
|
||||
class ArticleFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = Article
|
||||
fields = {
|
||||
"lang": ["exact", "in"],
|
||||
}
|
||||
|
||||
lang__contains = TypedFilter(
|
||||
field_name="lang", lookup_expr="icontains", input_type=graphene.String
|
||||
)
|
||||
lang__in_str = ListFilter(
|
||||
field_name="lang",
|
||||
lookup_expr="in",
|
||||
input_type=graphene.List(graphene.String),
|
||||
)
|
||||
first_n = TypedFilter(input_type=graphene.Int, method="first_n_filter")
|
||||
only_first = TypedFilter(
|
||||
input_type=graphene.Boolean, method="only_first_filter"
|
||||
)
|
||||
|
||||
def first_n_filter(self, queryset, _name, value):
|
||||
return queryset[:value]
|
||||
|
||||
def only_first_filter(self, queryset, _name, value):
|
||||
if value:
|
||||
return queryset[:1]
|
||||
else:
|
||||
return queryset
|
||||
|
||||
class ArticleType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Article
|
||||
interfaces = (Node,)
|
||||
filterset_class = ArticleFilterSet
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
articles = DjangoFilterConnectionField(ArticleType)
|
||||
|
||||
schema = graphene.Schema(query=Query)
|
||||
return schema
|
||||
|
||||
|
||||
def test_typed_filter_schema(schema):
|
||||
"""
|
||||
Check that the type provided in the filter is reflected in the schema.
|
||||
"""
|
||||
|
||||
schema_str = str(schema)
|
||||
|
||||
filters = {
|
||||
"offset": "Int",
|
||||
"before": "String",
|
||||
"after": "String",
|
||||
"first": "Int",
|
||||
"last": "Int",
|
||||
"lang": "ArticleLang",
|
||||
"lang_In": "[ArticleLang]",
|
||||
"lang_Contains": "String",
|
||||
"lang_InStr": "[String]",
|
||||
"firstN": "Int",
|
||||
"onlyFirst": "Boolean",
|
||||
}
|
||||
|
||||
all_articles_filters = (
|
||||
schema_str.split(" articles(")[1]
|
||||
.split("): ArticleTypeConnection\n")[0]
|
||||
.split(", ")
|
||||
)
|
||||
|
||||
for filter_field, gql_type in filters.items():
|
||||
assert "{}: {}".format(filter_field, gql_type) in all_articles_filters
|
||||
|
||||
|
||||
def test_typed_filters_work(schema):
|
||||
reporter = Reporter.objects.create(first_name="John", last_name="Doe", email="")
|
||||
Article.objects.create(
|
||||
headline="A", reporter=reporter, editor=reporter, lang="es",
|
||||
)
|
||||
Article.objects.create(
|
||||
headline="B", reporter=reporter, editor=reporter, lang="es",
|
||||
)
|
||||
Article.objects.create(
|
||||
headline="C", reporter=reporter, editor=reporter, lang="en",
|
||||
)
|
||||
|
||||
query = "query { articles (lang_In: [ES]) { edges { node { headline } } } }"
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["articles"]["edges"] == [
|
||||
{"node": {"headline": "A"}},
|
||||
{"node": {"headline": "B"}},
|
||||
]
|
||||
|
||||
query = 'query { articles (lang_InStr: ["es"]) { edges { node { headline } } } }'
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["articles"]["edges"] == [
|
||||
{"node": {"headline": "A"}},
|
||||
{"node": {"headline": "B"}},
|
||||
]
|
||||
|
||||
query = 'query { articles (lang_Contains: "n") { edges { node { headline } } } }'
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["articles"]["edges"] == [
|
||||
{"node": {"headline": "C"}},
|
||||
]
|
||||
|
||||
query = "query { articles (firstN: 2) { edges { node { headline } } } }"
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["articles"]["edges"] == [
|
||||
{"node": {"headline": "A"}},
|
||||
{"node": {"headline": "B"}},
|
||||
]
|
||||
|
||||
query = "query { articles (onlyFirst: true) { edges { node { headline } } } }"
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["articles"]["edges"] == [
|
||||
{"node": {"headline": "A"}},
|
||||
]
|
|
@ -8,7 +8,7 @@ from django_filters.utils import get_model_field, get_field_parts
|
|||
from django_filters.filters import Filter, BaseCSVFilter
|
||||
|
||||
from .filterset import custom_filterset_factory, setup_filterset
|
||||
from .filters import ArrayFilter, ListFilter, RangeFilter
|
||||
from .filters import ArrayFilter, ListFilter, RangeFilter, TypedFilter
|
||||
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
|
||||
|
||||
|
||||
|
@ -44,60 +44,62 @@ def get_filtering_args_from_filterset(filterset_class, type):
|
|||
form_field = None
|
||||
|
||||
if (
|
||||
name not in filterset_class.declared_filters
|
||||
or isinstance(filter_field, ListFilter)
|
||||
or isinstance(filter_field, RangeFilter)
|
||||
or isinstance(filter_field, ArrayFilter)
|
||||
isinstance(filter_field, TypedFilter)
|
||||
and filter_field.input_type is not None
|
||||
):
|
||||
# Get the filter field for filters that are no explicitly declared.
|
||||
if filter_type == "isnull":
|
||||
field = graphene.Boolean(required=required)
|
||||
else:
|
||||
model_field = get_model_field(model, filter_field.field_name)
|
||||
# First check if the filter input type has been explicitely given
|
||||
field_type = filter_field.input_type
|
||||
else:
|
||||
if name not in filterset_class.declared_filters or isinstance(
|
||||
filter_field, TypedFilter
|
||||
):
|
||||
# Get the filter field for filters that are no explicitly declared.
|
||||
if filter_type == "isnull":
|
||||
field = graphene.Boolean(required=required)
|
||||
else:
|
||||
model_field = get_model_field(model, filter_field.field_name)
|
||||
|
||||
# Get the form field either from:
|
||||
# 1. the formfield corresponding to the model field
|
||||
# 2. the field defined on filter
|
||||
if hasattr(model_field, "formfield"):
|
||||
form_field = model_field.formfield(required=required)
|
||||
if not form_field:
|
||||
form_field = filter_field.field
|
||||
# Get the form field either from:
|
||||
# 1. the formfield corresponding to the model field
|
||||
# 2. the field defined on filter
|
||||
if hasattr(model_field, "formfield"):
|
||||
form_field = model_field.formfield(required=required)
|
||||
if not form_field:
|
||||
form_field = filter_field.field
|
||||
|
||||
# First try to get the matching field type from the GraphQL DjangoObjectType
|
||||
if model_field:
|
||||
if (
|
||||
isinstance(form_field, forms.ModelChoiceField)
|
||||
or isinstance(form_field, forms.ModelMultipleChoiceField)
|
||||
or isinstance(form_field, GlobalIDMultipleChoiceField)
|
||||
or isinstance(form_field, GlobalIDFormField)
|
||||
):
|
||||
# Foreign key have dynamic types and filtering on a foreign key actually means filtering on its ID.
|
||||
field_type = get_field_type(
|
||||
registry, model_field.related_model, "id"
|
||||
)
|
||||
else:
|
||||
field_type = get_field_type(
|
||||
registry, model_field.model, model_field.name
|
||||
)
|
||||
# First try to get the matching field type from the GraphQL DjangoObjectType
|
||||
if model_field:
|
||||
if (
|
||||
isinstance(form_field, forms.ModelChoiceField)
|
||||
or isinstance(form_field, forms.ModelMultipleChoiceField)
|
||||
or isinstance(form_field, GlobalIDMultipleChoiceField)
|
||||
or isinstance(form_field, GlobalIDFormField)
|
||||
):
|
||||
# Foreign key have dynamic types and filtering on a foreign key actually means filtering on its ID.
|
||||
field_type = get_field_type(
|
||||
registry, model_field.related_model, "id"
|
||||
)
|
||||
else:
|
||||
field_type = get_field_type(
|
||||
registry, model_field.model, model_field.name
|
||||
)
|
||||
|
||||
if not field_type:
|
||||
# Fallback on converting the form field either because:
|
||||
# - it's an explicitly declared filters
|
||||
# - we did not manage to get the type from the model type
|
||||
form_field = form_field or filter_field.field
|
||||
field_type = convert_form_field(form_field)
|
||||
if not field_type:
|
||||
# Fallback on converting the form field either because:
|
||||
# - it's an explicitly declared filters
|
||||
# - we did not manage to get the type from the model type
|
||||
form_field = form_field or filter_field.field
|
||||
field_type = convert_form_field(form_field).get_type()
|
||||
|
||||
if isinstance(filter_field, ListFilter) or isinstance(
|
||||
filter_field, RangeFilter
|
||||
):
|
||||
# Replace InFilter/RangeFilter filters (`in`, `range`) argument type to be a list of
|
||||
# the same type as the field. See comments in `replace_csv_filters` method for more details.
|
||||
field_type = graphene.List(field_type.get_type())
|
||||
if isinstance(filter_field, ListFilter) or isinstance(
|
||||
filter_field, RangeFilter
|
||||
):
|
||||
# Replace InFilter/RangeFilter filters (`in`, `range`) argument type to be a list of
|
||||
# the same type as the field. See comments in `replace_csv_filters` method for more details.
|
||||
field_type = graphene.List(field_type)
|
||||
|
||||
args[name] = graphene.Argument(
|
||||
type=field_type.get_type(),
|
||||
description=filter_field.label,
|
||||
required=required,
|
||||
type=field_type, description=filter_field.label, required=required,
|
||||
)
|
||||
|
||||
return args
|
||||
|
|
Loading…
Reference in New Issue
Block a user