mirror of
				https://github.com/graphql-python/graphene-django.git
				synced 2025-11-04 09:57:53 +03:00 
			
		
		
		
	Validate in and range filter inputs (#1090)
Co-authored-by: Thomas Leonard <thomas@loftorbital.com>
This commit is contained in:
		
							parent
							
								
									ea84827ab8
								
							
						
					
					
						commit
						10e48c27b7
					
				| 
						 | 
					@ -9,7 +9,7 @@ if not DJANGO_FILTER_INSTALLED:
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
else:
 | 
					else:
 | 
				
			||||||
    from .fields import DjangoFilterConnectionField
 | 
					    from .fields import DjangoFilterConnectionField
 | 
				
			||||||
    from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
 | 
					    from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    __all__ = [
 | 
					    __all__ = [
 | 
				
			||||||
        "DjangoFilterConnectionField",
 | 
					        "DjangoFilterConnectionField",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										75
									
								
								graphene_django/filter/filters.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								graphene_django/filter/filters.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,75 @@
 | 
				
			||||||
 | 
					from django.core.exceptions import ValidationError
 | 
				
			||||||
 | 
					from django.forms import Field
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django_filters import Filter, MultipleChoiceFilter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from graphql_relay.node.node import from_global_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GlobalIDFilter(Filter):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Filter for Relay global ID.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    field_class = GlobalIDFormField
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def filter(self, qs, value):
 | 
				
			||||||
 | 
					        """ Convert the filter value to a primary key before filtering """
 | 
				
			||||||
 | 
					        _id = None
 | 
				
			||||||
 | 
					        if value is not None:
 | 
				
			||||||
 | 
					            _, _id = from_global_id(value)
 | 
				
			||||||
 | 
					        return super(GlobalIDFilter, self).filter(qs, _id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
 | 
				
			||||||
 | 
					    field_class = GlobalIDMultipleChoiceField
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def filter(self, qs, value):
 | 
				
			||||||
 | 
					        gids = [from_global_id(v)[1] for v in value]
 | 
				
			||||||
 | 
					        return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class InFilter(Filter):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Filter for a list of value using the `__in` Django filter.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def filter(self, qs, value):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Override the default filter class to check first weather the list is
 | 
				
			||||||
 | 
					        empty or not.
 | 
				
			||||||
 | 
					        This needs to be done as in this case we expect to get an empty output
 | 
				
			||||||
 | 
					        (if not an exclude filter) but django_filter consider an empty list
 | 
				
			||||||
 | 
					        to be an empty input value (see `EMPTY_VALUES`) meaning that
 | 
				
			||||||
 | 
					        the filter does not need to be applied (hence returning the original
 | 
				
			||||||
 | 
					        queryset).
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if value is not None and len(value) == 0:
 | 
				
			||||||
 | 
					            if self.exclude:
 | 
				
			||||||
 | 
					                return qs
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                return qs.none()
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return super().filter(qs, value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def validate_range(value):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Validator for range filter input: the list of value must be of length 2.
 | 
				
			||||||
 | 
					    Note that validators are only run if the value is not empty.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if len(value) != 2:
 | 
				
			||||||
 | 
					        raise ValidationError(
 | 
				
			||||||
 | 
					            "Invalid range specified: it needs to contain 2 values.", code="invalid"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RangeField(Field):
 | 
				
			||||||
 | 
					    default_validators = [validate_range]
 | 
				
			||||||
 | 
					    empty_values = [None]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RangeFilter(Filter):
 | 
				
			||||||
 | 
					    field_class = RangeField
 | 
				
			||||||
| 
						 | 
					@ -5,28 +5,7 @@ from django_filters import Filter, MultipleChoiceFilter
 | 
				
			||||||
from django_filters.filterset import BaseFilterSet, FilterSet
 | 
					from django_filters.filterset import BaseFilterSet, FilterSet
 | 
				
			||||||
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
 | 
					from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from graphql_relay.node.node import from_global_id
 | 
					from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
 | 
				
			||||||
 | 
					 | 
				
			||||||
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class GlobalIDFilter(Filter):
 | 
					 | 
				
			||||||
    field_class = GlobalIDFormField
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def filter(self, qs, value):
 | 
					 | 
				
			||||||
        """ Convert the filter value to a primary key before filtering """
 | 
					 | 
				
			||||||
        _id = None
 | 
					 | 
				
			||||||
        if value is not None:
 | 
					 | 
				
			||||||
            _, _id = from_global_id(value)
 | 
					 | 
				
			||||||
        return super(GlobalIDFilter, self).filter(qs, _id)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
 | 
					 | 
				
			||||||
    field_class = GlobalIDMultipleChoiceField
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def filter(self, qs, value):
 | 
					 | 
				
			||||||
        gids = [from_global_id(v)[1] for v in value]
 | 
					 | 
				
			||||||
        return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
GRAPHENE_FILTER_SET_OVERRIDES = {
 | 
					GRAPHENE_FILTER_SET_OVERRIDES = {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -157,20 +157,19 @@ def test_int_in_filter():
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_int_range_filter():
 | 
					def test_in_filter_with_empty_list():
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Test in filter on an integer field.
 | 
					    Check that using a in filter with an empty list provided as input returns no objects.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Pet.objects.create(name="Brutus", age=12)
 | 
					    Pet.objects.create(name="Brutus", age=12)
 | 
				
			||||||
    Pet.objects.create(name="Mimi", age=8)
 | 
					    Pet.objects.create(name="Mimi", age=8)
 | 
				
			||||||
    Pet.objects.create(name="Jojo, the rabbit", age=3)
 | 
					 | 
				
			||||||
    Pet.objects.create(name="Picotin", age=5)
 | 
					    Pet.objects.create(name="Picotin", age=5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    schema = Schema(query=Query)
 | 
					    schema = Schema(query=Query)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    query = """
 | 
					    query = """
 | 
				
			||||||
    query {
 | 
					    query {
 | 
				
			||||||
        pets (age_Range: [4, 9]) {
 | 
					        pets (name_In: []) {
 | 
				
			||||||
            edges {
 | 
					            edges {
 | 
				
			||||||
                node {
 | 
					                node {
 | 
				
			||||||
                    name
 | 
					                    name
 | 
				
			||||||
| 
						 | 
					@ -181,7 +180,4 @@ def test_int_range_filter():
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    result = schema.execute(query)
 | 
					    result = schema.execute(query)
 | 
				
			||||||
    assert not result.errors
 | 
					    assert not result.errors
 | 
				
			||||||
    assert result.data["pets"]["edges"] == [
 | 
					    assert len(result.data["pets"]["edges"]) == 0
 | 
				
			||||||
        {"node": {"name": "Mimi"}},
 | 
					 | 
				
			||||||
        {"node": {"name": "Picotin"}},
 | 
					 | 
				
			||||||
    ]
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										114
									
								
								graphene_django/filter/tests/test_range_filter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								graphene_django/filter/tests/test_range_filter.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,114 @@
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django_filters import FilterSet
 | 
				
			||||||
 | 
					from django_filters import rest_framework as filters
 | 
				
			||||||
 | 
					from graphene import ObjectType, Schema
 | 
				
			||||||
 | 
					from graphene.relay import Node
 | 
				
			||||||
 | 
					from graphene_django import DjangoObjectType
 | 
				
			||||||
 | 
					from graphene_django.tests.models import Pet
 | 
				
			||||||
 | 
					from graphene_django.utils import DJANGO_FILTER_INSTALLED
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					pytestmark = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if DJANGO_FILTER_INSTALLED:
 | 
				
			||||||
 | 
					    from graphene_django.filter import DjangoFilterConnectionField
 | 
				
			||||||
 | 
					else:
 | 
				
			||||||
 | 
					    pytestmark.append(
 | 
				
			||||||
 | 
					        pytest.mark.skipif(
 | 
				
			||||||
 | 
					            True, reason="django_filters not installed or not compatible"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PetNode(DjangoObjectType):
 | 
				
			||||||
 | 
					    class Meta:
 | 
				
			||||||
 | 
					        model = Pet
 | 
				
			||||||
 | 
					        interfaces = (Node,)
 | 
				
			||||||
 | 
					        filter_fields = {
 | 
				
			||||||
 | 
					            "name": ["exact", "in"],
 | 
				
			||||||
 | 
					            "age": ["exact", "in", "range"],
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Query(ObjectType):
 | 
				
			||||||
 | 
					    pets = DjangoFilterConnectionField(PetNode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_int_range_filter():
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Test range filter on an integer field.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Pet.objects.create(name="Brutus", age=12)
 | 
				
			||||||
 | 
					    Pet.objects.create(name="Mimi", age=8)
 | 
				
			||||||
 | 
					    Pet.objects.create(name="Jojo, the rabbit", age=3)
 | 
				
			||||||
 | 
					    Pet.objects.create(name="Picotin", age=5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    schema = Schema(query=Query)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    query = """
 | 
				
			||||||
 | 
					    query {
 | 
				
			||||||
 | 
					        pets (age_Range: [4, 9]) {
 | 
				
			||||||
 | 
					            edges {
 | 
				
			||||||
 | 
					                node {
 | 
				
			||||||
 | 
					                    name
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    result = schema.execute(query)
 | 
				
			||||||
 | 
					    assert not result.errors
 | 
				
			||||||
 | 
					    assert result.data["pets"]["edges"] == [
 | 
				
			||||||
 | 
					        {"node": {"name": "Mimi"}},
 | 
				
			||||||
 | 
					        {"node": {"name": "Picotin"}},
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_range_filter_with_invalid_input():
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Test range filter used with invalid inputs raise an error.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Pet.objects.create(name="Brutus", age=12)
 | 
				
			||||||
 | 
					    Pet.objects.create(name="Mimi", age=8)
 | 
				
			||||||
 | 
					    Pet.objects.create(name="Jojo, the rabbit", age=3)
 | 
				
			||||||
 | 
					    Pet.objects.create(name="Picotin", age=5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    schema = Schema(query=Query)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    query = """
 | 
				
			||||||
 | 
					    query ($rangeValue: [Int]) {
 | 
				
			||||||
 | 
					        pets (age_Range: $rangeValue) {
 | 
				
			||||||
 | 
					            edges {
 | 
				
			||||||
 | 
					                node {
 | 
				
			||||||
 | 
					                    name
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    expected_error = json.dumps(
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "age__range": [
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "message": "Invalid range specified: it needs to contain 2 values.",
 | 
				
			||||||
 | 
					                    "code": "invalid",
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Empty list
 | 
				
			||||||
 | 
					    result = schema.execute(query, variables={"rangeValue": []})
 | 
				
			||||||
 | 
					    assert len(result.errors) == 1
 | 
				
			||||||
 | 
					    assert result.errors[0].message == f"['{expected_error}']"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Only one item in the list
 | 
				
			||||||
 | 
					    result = schema.execute(query, variables={"rangeValue": [1]})
 | 
				
			||||||
 | 
					    assert len(result.errors) == 1
 | 
				
			||||||
 | 
					    assert result.errors[0].message == f"['{expected_error}']"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # More than 2 items in the list
 | 
				
			||||||
 | 
					    result = schema.execute(query, variables={"rangeValue": [1, 2, 3]})
 | 
				
			||||||
 | 
					    assert len(result.errors) == 1
 | 
				
			||||||
 | 
					    assert result.errors[0].message == f"['{expected_error}']"
 | 
				
			||||||
| 
						 | 
					@ -4,6 +4,7 @@ from django_filters.utils import get_model_field
 | 
				
			||||||
from django_filters.filters import Filter, BaseCSVFilter
 | 
					from django_filters.filters import Filter, BaseCSVFilter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .filterset import custom_filterset_factory, setup_filterset
 | 
					from .filterset import custom_filterset_factory, setup_filterset
 | 
				
			||||||
 | 
					from .filters import InFilter, RangeFilter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_filtering_args_from_filterset(filterset_class, type):
 | 
					def get_filtering_args_from_filterset(filterset_class, type):
 | 
				
			||||||
| 
						 | 
					@ -78,9 +79,20 @@ def replace_csv_filters(filterset_class):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    for name, filter_field in list(filterset_class.base_filters.items()):
 | 
					    for name, filter_field in list(filterset_class.base_filters.items()):
 | 
				
			||||||
        filter_type = filter_field.lookup_expr
 | 
					        filter_type = filter_field.lookup_expr
 | 
				
			||||||
        if filter_type in ["in", "range"]:
 | 
					        if filter_type == "in":
 | 
				
			||||||
            assert isinstance(filter_field, BaseCSVFilter)
 | 
					            assert isinstance(filter_field, BaseCSVFilter)
 | 
				
			||||||
            filterset_class.base_filters[name] = Filter(
 | 
					            filterset_class.base_filters[name] = InFilter(
 | 
				
			||||||
 | 
					                field_name=filter_field.field_name,
 | 
				
			||||||
 | 
					                lookup_expr=filter_field.lookup_expr,
 | 
				
			||||||
 | 
					                label=filter_field.label,
 | 
				
			||||||
 | 
					                method=filter_field.method,
 | 
				
			||||||
 | 
					                exclude=filter_field.exclude,
 | 
				
			||||||
 | 
					                **filter_field.extra
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if filter_type == "range":
 | 
				
			||||||
 | 
					            assert isinstance(filter_field, BaseCSVFilter)
 | 
				
			||||||
 | 
					            filterset_class.base_filters[name] = RangeFilter(
 | 
				
			||||||
                field_name=filter_field.field_name,
 | 
					                field_name=filter_field.field_name,
 | 
				
			||||||
                lookup_expr=filter_field.lookup_expr,
 | 
					                lookup_expr=filter_field.lookup_expr,
 | 
				
			||||||
                label=filter_field.label,
 | 
					                label=filter_field.label,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user