Validate in and range filter inputs (#1090)

Co-authored-by: Thomas Leonard <thomas@loftorbital.com>
This commit is contained in:
Thomas Leonard 2021-01-10 04:15:21 +01:00 committed by GitHub
parent ea84827ab8
commit 10e48c27b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 209 additions and 33 deletions

View File

@ -9,7 +9,7 @@ if not DJANGO_FILTER_INSTALLED:
)
else:
from .fields import DjangoFilterConnectionField
from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
__all__ = [
"DjangoFilterConnectionField",

View File

@ -0,0 +1,75 @@
from django.core.exceptions import ValidationError
from django.forms import Field
from django_filters import Filter, MultipleChoiceFilter
from graphql_relay.node.node import from_global_id
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
class GlobalIDFilter(Filter):
"""
Filter for Relay global ID.
"""
field_class = GlobalIDFormField
def filter(self, qs, value):
""" Convert the filter value to a primary key before filtering """
_id = None
if value is not None:
_, _id = from_global_id(value)
return super(GlobalIDFilter, self).filter(qs, _id)
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
field_class = GlobalIDMultipleChoiceField
def filter(self, qs, value):
gids = [from_global_id(v)[1] for v in value]
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
class InFilter(Filter):
"""
Filter for a list of value using the `__in` Django filter.
"""
def filter(self, qs, value):
"""
Override the default filter class to check first weather the list is
empty or not.
This needs to be done as in this case we expect to get an empty output
(if not an exclude filter) but django_filter consider an empty list
to be an empty input value (see `EMPTY_VALUES`) meaning that
the filter does not need to be applied (hence returning the original
queryset).
"""
if value is not None and len(value) == 0:
if self.exclude:
return qs
else:
return qs.none()
else:
return super().filter(qs, value)
def validate_range(value):
"""
Validator for range filter input: the list of value must be of length 2.
Note that validators are only run if the value is not empty.
"""
if len(value) != 2:
raise ValidationError(
"Invalid range specified: it needs to contain 2 values.", code="invalid"
)
class RangeField(Field):
default_validators = [validate_range]
empty_values = [None]
class RangeFilter(Filter):
field_class = RangeField

View File

@ -5,28 +5,7 @@ from django_filters import Filter, MultipleChoiceFilter
from django_filters.filterset import BaseFilterSet, FilterSet
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
from graphql_relay.node.node import from_global_id
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
class GlobalIDFilter(Filter):
field_class = GlobalIDFormField
def filter(self, qs, value):
""" Convert the filter value to a primary key before filtering """
_id = None
if value is not None:
_, _id = from_global_id(value)
return super(GlobalIDFilter, self).filter(qs, _id)
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
field_class = GlobalIDMultipleChoiceField
def filter(self, qs, value):
gids = [from_global_id(v)[1] for v in value]
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
GRAPHENE_FILTER_SET_OVERRIDES = {

View File

@ -157,20 +157,19 @@ def test_int_in_filter():
]
def test_int_range_filter():
def test_in_filter_with_empty_list():
"""
Test in filter on an integer field.
Check that using a in filter with an empty list provided as input returns no objects.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Jojo, the rabbit", age=3)
Pet.objects.create(name="Picotin", age=5)
schema = Schema(query=Query)
query = """
query {
pets (age_Range: [4, 9]) {
pets (name_In: []) {
edges {
node {
name
@ -181,7 +180,4 @@ def test_int_range_filter():
"""
result = schema.execute(query)
assert not result.errors
assert result.data["pets"]["edges"] == [
{"node": {"name": "Mimi"}},
{"node": {"name": "Picotin"}},
]
assert len(result.data["pets"]["edges"]) == 0

View File

@ -0,0 +1,114 @@
import json
import pytest
from django_filters import FilterSet
from django_filters import rest_framework as filters
from graphene import ObjectType, Schema
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.tests.models import Pet
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import DjangoFilterConnectionField
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node,)
filter_fields = {
"name": ["exact", "in"],
"age": ["exact", "in", "range"],
}
class Query(ObjectType):
pets = DjangoFilterConnectionField(PetNode)
def test_int_range_filter():
"""
Test range filter on an integer field.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Jojo, the rabbit", age=3)
Pet.objects.create(name="Picotin", age=5)
schema = Schema(query=Query)
query = """
query {
pets (age_Range: [4, 9]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["pets"]["edges"] == [
{"node": {"name": "Mimi"}},
{"node": {"name": "Picotin"}},
]
def test_range_filter_with_invalid_input():
"""
Test range filter used with invalid inputs raise an error.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Jojo, the rabbit", age=3)
Pet.objects.create(name="Picotin", age=5)
schema = Schema(query=Query)
query = """
query ($rangeValue: [Int]) {
pets (age_Range: $rangeValue) {
edges {
node {
name
}
}
}
}
"""
expected_error = json.dumps(
{
"age__range": [
{
"message": "Invalid range specified: it needs to contain 2 values.",
"code": "invalid",
}
]
}
)
# Empty list
result = schema.execute(query, variables={"rangeValue": []})
assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']"
# Only one item in the list
result = schema.execute(query, variables={"rangeValue": [1]})
assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']"
# More than 2 items in the list
result = schema.execute(query, variables={"rangeValue": [1, 2, 3]})
assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']"

View File

@ -4,6 +4,7 @@ from django_filters.utils import get_model_field
from django_filters.filters import Filter, BaseCSVFilter
from .filterset import custom_filterset_factory, setup_filterset
from .filters import InFilter, RangeFilter
def get_filtering_args_from_filterset(filterset_class, type):
@ -78,9 +79,20 @@ def replace_csv_filters(filterset_class):
"""
for name, filter_field in list(filterset_class.base_filters.items()):
filter_type = filter_field.lookup_expr
if filter_type in ["in", "range"]:
if filter_type == "in":
assert isinstance(filter_field, BaseCSVFilter)
filterset_class.base_filters[name] = Filter(
filterset_class.base_filters[name] = InFilter(
field_name=filter_field.field_name,
lookup_expr=filter_field.lookup_expr,
label=filter_field.label,
method=filter_field.method,
exclude=filter_field.exclude,
**filter_field.extra
)
if filter_type == "range":
assert isinstance(filter_field, BaseCSVFilter)
filterset_class.base_filters[name] = RangeFilter(
field_name=filter_field.field_name,
lookup_expr=filter_field.lookup_expr,
label=filter_field.label,