mirror of
https://github.com/graphql-python/graphene-django.git
synced 2024-11-15 14:17:55 +03:00
Validate in and range filter inputs (#1092)
Co-authored-by: Thomas Leonard <thomas@loftorbital.com>
This commit is contained in:
parent
1281c1338d
commit
aff56b882b
|
@ -9,7 +9,7 @@ if not DJANGO_FILTER_INSTALLED:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from .fields import DjangoFilterConnectionField
|
from .fields import DjangoFilterConnectionField
|
||||||
from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
|
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DjangoFilterConnectionField",
|
"DjangoFilterConnectionField",
|
||||||
|
|
75
graphene_django/filter/filters.py
Normal file
75
graphene_django/filter/filters.py
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
from django.core.exceptions import ValidationError
|
||||||
|
from django.forms import Field
|
||||||
|
|
||||||
|
from django_filters import Filter, MultipleChoiceFilter
|
||||||
|
|
||||||
|
from graphql_relay.node.node import from_global_id
|
||||||
|
|
||||||
|
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalIDFilter(Filter):
|
||||||
|
"""
|
||||||
|
Filter for Relay global ID.
|
||||||
|
"""
|
||||||
|
|
||||||
|
field_class = GlobalIDFormField
|
||||||
|
|
||||||
|
def filter(self, qs, value):
|
||||||
|
""" Convert the filter value to a primary key before filtering """
|
||||||
|
_id = None
|
||||||
|
if value is not None:
|
||||||
|
_, _id = from_global_id(value)
|
||||||
|
return super(GlobalIDFilter, self).filter(qs, _id)
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
|
||||||
|
field_class = GlobalIDMultipleChoiceField
|
||||||
|
|
||||||
|
def filter(self, qs, value):
|
||||||
|
gids = [from_global_id(v)[1] for v in value]
|
||||||
|
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
|
||||||
|
|
||||||
|
|
||||||
|
class InFilter(Filter):
|
||||||
|
"""
|
||||||
|
Filter for a list of value using the `__in` Django filter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def filter(self, qs, value):
|
||||||
|
"""
|
||||||
|
Override the default filter class to check first weather the list is
|
||||||
|
empty or not.
|
||||||
|
This needs to be done as in this case we expect to get an empty output
|
||||||
|
(if not an exclude filter) but django_filter consider an empty list
|
||||||
|
to be an empty input value (see `EMPTY_VALUES`) meaning that
|
||||||
|
the filter does not need to be applied (hence returning the original
|
||||||
|
queryset).
|
||||||
|
"""
|
||||||
|
if value is not None and len(value) == 0:
|
||||||
|
if self.exclude:
|
||||||
|
return qs
|
||||||
|
else:
|
||||||
|
return qs.none()
|
||||||
|
else:
|
||||||
|
return super(InFilter, self).filter(qs, value)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_range(value):
|
||||||
|
"""
|
||||||
|
Validator for range filter input: the list of value must be of length 2.
|
||||||
|
Note that validators are only run if the value is not empty.
|
||||||
|
"""
|
||||||
|
if len(value) != 2:
|
||||||
|
raise ValidationError(
|
||||||
|
"Invalid range specified: it needs to contain 2 values.", code="invalid"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RangeField(Field):
|
||||||
|
default_validators = [validate_range]
|
||||||
|
empty_values = [None]
|
||||||
|
|
||||||
|
|
||||||
|
class RangeFilter(Filter):
|
||||||
|
field_class = RangeField
|
|
@ -1,32 +1,11 @@
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django_filters import Filter, MultipleChoiceFilter, VERSION
|
from django_filters import VERSION
|
||||||
from django_filters.filterset import BaseFilterSet, FilterSet
|
from django_filters.filterset import BaseFilterSet, FilterSet
|
||||||
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
|
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
|
||||||
|
|
||||||
from graphql_relay.node.node import from_global_id
|
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
|
||||||
|
|
||||||
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
|
|
||||||
|
|
||||||
|
|
||||||
class GlobalIDFilter(Filter):
|
|
||||||
field_class = GlobalIDFormField
|
|
||||||
|
|
||||||
def filter(self, qs, value):
|
|
||||||
""" Convert the filter value to a primary key before filtering """
|
|
||||||
_id = None
|
|
||||||
if value is not None:
|
|
||||||
_, _id = from_global_id(value)
|
|
||||||
return super(GlobalIDFilter, self).filter(qs, _id)
|
|
||||||
|
|
||||||
|
|
||||||
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
|
|
||||||
field_class = GlobalIDMultipleChoiceField
|
|
||||||
|
|
||||||
def filter(self, qs, value):
|
|
||||||
gids = [from_global_id(v)[1] for v in value]
|
|
||||||
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
|
|
||||||
|
|
||||||
|
|
||||||
GRAPHENE_FILTER_SET_OVERRIDES = {
|
GRAPHENE_FILTER_SET_OVERRIDES = {
|
||||||
|
|
|
@ -157,20 +157,19 @@ def test_int_in_filter():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_int_range_filter():
|
def test_in_filter_with_empty_list():
|
||||||
"""
|
"""
|
||||||
Test in filter on an integer field.
|
Check that using a in filter with an empty list provided as input returns no objects.
|
||||||
"""
|
"""
|
||||||
Pet.objects.create(name="Brutus", age=12)
|
Pet.objects.create(name="Brutus", age=12)
|
||||||
Pet.objects.create(name="Mimi", age=8)
|
Pet.objects.create(name="Mimi", age=8)
|
||||||
Pet.objects.create(name="Jojo, the rabbit", age=3)
|
|
||||||
Pet.objects.create(name="Picotin", age=5)
|
Pet.objects.create(name="Picotin", age=5)
|
||||||
|
|
||||||
schema = Schema(query=Query)
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
query = """
|
query = """
|
||||||
query {
|
query {
|
||||||
pets (age_Range: [4, 9]) {
|
pets (name_In: []) {
|
||||||
edges {
|
edges {
|
||||||
node {
|
node {
|
||||||
name
|
name
|
||||||
|
@ -181,7 +180,4 @@ def test_int_range_filter():
|
||||||
"""
|
"""
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data["pets"]["edges"] == [
|
assert len(result.data["pets"]["edges"]) == 0
|
||||||
{"node": {"name": "Mimi"}},
|
|
||||||
{"node": {"name": "Picotin"}},
|
|
||||||
]
|
|
||||||
|
|
115
graphene_django/filter/tests/test_range_filter.py
Normal file
115
graphene_django/filter/tests/test_range_filter.py
Normal file
|
@ -0,0 +1,115 @@
|
||||||
|
import ast
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from django_filters import FilterSet
|
||||||
|
from django_filters import rest_framework as filters
|
||||||
|
from graphene import ObjectType, Schema
|
||||||
|
from graphene.relay import Node
|
||||||
|
from graphene_django import DjangoObjectType
|
||||||
|
from graphene_django.tests.models import Pet
|
||||||
|
from graphene_django.utils import DJANGO_FILTER_INSTALLED
|
||||||
|
|
||||||
|
pytestmark = []
|
||||||
|
|
||||||
|
if DJANGO_FILTER_INSTALLED:
|
||||||
|
from graphene_django.filter import DjangoFilterConnectionField
|
||||||
|
else:
|
||||||
|
pytestmark.append(
|
||||||
|
pytest.mark.skipif(
|
||||||
|
True, reason="django_filters not installed or not compatible"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PetNode(DjangoObjectType):
|
||||||
|
class Meta:
|
||||||
|
model = Pet
|
||||||
|
interfaces = (Node,)
|
||||||
|
filter_fields = {
|
||||||
|
"name": ["exact", "in"],
|
||||||
|
"age": ["exact", "in", "range"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Query(ObjectType):
|
||||||
|
pets = DjangoFilterConnectionField(PetNode)
|
||||||
|
|
||||||
|
|
||||||
|
def test_int_range_filter():
|
||||||
|
"""
|
||||||
|
Test range filter on an integer field.
|
||||||
|
"""
|
||||||
|
Pet.objects.create(name="Brutus", age=12)
|
||||||
|
Pet.objects.create(name="Mimi", age=8)
|
||||||
|
Pet.objects.create(name="Jojo, the rabbit", age=3)
|
||||||
|
Pet.objects.create(name="Picotin", age=5)
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
pets (age_Range: [4, 9]) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data["pets"]["edges"] == [
|
||||||
|
{"node": {"name": "Mimi"}},
|
||||||
|
{"node": {"name": "Picotin"}},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_range_filter_with_invalid_input():
|
||||||
|
"""
|
||||||
|
Test range filter used with invalid inputs raise an error.
|
||||||
|
"""
|
||||||
|
Pet.objects.create(name="Brutus", age=12)
|
||||||
|
Pet.objects.create(name="Mimi", age=8)
|
||||||
|
Pet.objects.create(name="Jojo, the rabbit", age=3)
|
||||||
|
Pet.objects.create(name="Picotin", age=5)
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query ($rangeValue: [Int]) {
|
||||||
|
pets (age_Range: $rangeValue) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
expected_error = json.dumps(
|
||||||
|
{
|
||||||
|
"age__range": [
|
||||||
|
{
|
||||||
|
"message": "Invalid range specified: it needs to contain 2 values.",
|
||||||
|
"code": "invalid",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Empty list
|
||||||
|
result = schema.execute(query, variables={"rangeValue": []})
|
||||||
|
assert len(result.errors) == 1
|
||||||
|
assert ast.literal_eval(result.errors[0].message)[0] == expected_error
|
||||||
|
|
||||||
|
# Only one item in the list
|
||||||
|
result = schema.execute(query, variables={"rangeValue": [1]})
|
||||||
|
assert len(result.errors) == 1
|
||||||
|
assert ast.literal_eval(result.errors[0].message)[0] == expected_error
|
||||||
|
|
||||||
|
# More than 2 items in the list
|
||||||
|
result = schema.execute(query, variables={"rangeValue": [1, 2, 3]})
|
||||||
|
assert len(result.errors) == 1
|
||||||
|
assert ast.literal_eval(result.errors[0].message)[0] == expected_error
|
|
@ -6,6 +6,7 @@ from django_filters.utils import get_model_field
|
||||||
from django_filters.filters import Filter, BaseCSVFilter
|
from django_filters.filters import Filter, BaseCSVFilter
|
||||||
|
|
||||||
from .filterset import custom_filterset_factory, setup_filterset
|
from .filterset import custom_filterset_factory, setup_filterset
|
||||||
|
from .filters import InFilter, RangeFilter
|
||||||
|
|
||||||
|
|
||||||
def get_filtering_args_from_filterset(filterset_class, type):
|
def get_filtering_args_from_filterset(filterset_class, type):
|
||||||
|
@ -80,9 +81,20 @@ def replace_csv_filters(filterset_class):
|
||||||
"""
|
"""
|
||||||
for name, filter_field in six.iteritems(filterset_class.base_filters):
|
for name, filter_field in six.iteritems(filterset_class.base_filters):
|
||||||
filter_type = filter_field.lookup_expr
|
filter_type = filter_field.lookup_expr
|
||||||
if filter_type in ["in", "range"]:
|
if filter_type == "in":
|
||||||
assert isinstance(filter_field, BaseCSVFilter)
|
assert isinstance(filter_field, BaseCSVFilter)
|
||||||
filterset_class.base_filters[name] = Filter(
|
filterset_class.base_filters[name] = InFilter(
|
||||||
|
field_name=filter_field.field_name,
|
||||||
|
lookup_expr=filter_field.lookup_expr,
|
||||||
|
label=filter_field.label,
|
||||||
|
method=filter_field.method,
|
||||||
|
exclude=filter_field.exclude,
|
||||||
|
**filter_field.extra
|
||||||
|
)
|
||||||
|
|
||||||
|
if filter_type == "range":
|
||||||
|
assert isinstance(filter_field, BaseCSVFilter)
|
||||||
|
filterset_class.base_filters[name] = RangeFilter(
|
||||||
field_name=filter_field.field_name,
|
field_name=filter_field.field_name,
|
||||||
lookup_expr=filter_field.lookup_expr,
|
lookup_expr=filter_field.lookup_expr,
|
||||||
label=filter_field.label,
|
label=filter_field.label,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user