mirror of
				https://github.com/graphql-python/graphene-django.git
				synced 2025-11-04 01:47:57 +03:00 
			
		
		
		
	Validate in and range filter inputs (#1090)
Co-authored-by: Thomas Leonard <thomas@loftorbital.com>
This commit is contained in:
		
							parent
							
								
									ea84827ab8
								
							
						
					
					
						commit
						10e48c27b7
					
				| 
						 | 
				
			
			@ -9,7 +9,7 @@ if not DJANGO_FILTER_INSTALLED:
 | 
			
		|||
    )
 | 
			
		||||
else:
 | 
			
		||||
    from .fields import DjangoFilterConnectionField
 | 
			
		||||
    from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
 | 
			
		||||
    from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
 | 
			
		||||
 | 
			
		||||
    __all__ = [
 | 
			
		||||
        "DjangoFilterConnectionField",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										75
									
								
								graphene_django/filter/filters.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								graphene_django/filter/filters.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,75 @@
 | 
			
		|||
from django.core.exceptions import ValidationError
 | 
			
		||||
from django.forms import Field
 | 
			
		||||
 | 
			
		||||
from django_filters import Filter, MultipleChoiceFilter
 | 
			
		||||
 | 
			
		||||
from graphql_relay.node.node import from_global_id
 | 
			
		||||
 | 
			
		||||
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GlobalIDFilter(Filter):
 | 
			
		||||
    """
 | 
			
		||||
    Filter for Relay global ID.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    field_class = GlobalIDFormField
 | 
			
		||||
 | 
			
		||||
    def filter(self, qs, value):
 | 
			
		||||
        """ Convert the filter value to a primary key before filtering """
 | 
			
		||||
        _id = None
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            _, _id = from_global_id(value)
 | 
			
		||||
        return super(GlobalIDFilter, self).filter(qs, _id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
 | 
			
		||||
    field_class = GlobalIDMultipleChoiceField
 | 
			
		||||
 | 
			
		||||
    def filter(self, qs, value):
 | 
			
		||||
        gids = [from_global_id(v)[1] for v in value]
 | 
			
		||||
        return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InFilter(Filter):
 | 
			
		||||
    """
 | 
			
		||||
    Filter for a list of value using the `__in` Django filter.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def filter(self, qs, value):
 | 
			
		||||
        """
 | 
			
		||||
        Override the default filter class to check first weather the list is
 | 
			
		||||
        empty or not.
 | 
			
		||||
        This needs to be done as in this case we expect to get an empty output
 | 
			
		||||
        (if not an exclude filter) but django_filter consider an empty list
 | 
			
		||||
        to be an empty input value (see `EMPTY_VALUES`) meaning that
 | 
			
		||||
        the filter does not need to be applied (hence returning the original
 | 
			
		||||
        queryset).
 | 
			
		||||
        """
 | 
			
		||||
        if value is not None and len(value) == 0:
 | 
			
		||||
            if self.exclude:
 | 
			
		||||
                return qs
 | 
			
		||||
            else:
 | 
			
		||||
                return qs.none()
 | 
			
		||||
        else:
 | 
			
		||||
            return super().filter(qs, value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def validate_range(value):
 | 
			
		||||
    """
 | 
			
		||||
    Validator for range filter input: the list of value must be of length 2.
 | 
			
		||||
    Note that validators are only run if the value is not empty.
 | 
			
		||||
    """
 | 
			
		||||
    if len(value) != 2:
 | 
			
		||||
        raise ValidationError(
 | 
			
		||||
            "Invalid range specified: it needs to contain 2 values.", code="invalid"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RangeField(Field):
 | 
			
		||||
    default_validators = [validate_range]
 | 
			
		||||
    empty_values = [None]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RangeFilter(Filter):
 | 
			
		||||
    field_class = RangeField
 | 
			
		||||
| 
						 | 
				
			
			@ -5,28 +5,7 @@ from django_filters import Filter, MultipleChoiceFilter
 | 
			
		|||
from django_filters.filterset import BaseFilterSet, FilterSet
 | 
			
		||||
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
 | 
			
		||||
 | 
			
		||||
from graphql_relay.node.node import from_global_id
 | 
			
		||||
 | 
			
		||||
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GlobalIDFilter(Filter):
 | 
			
		||||
    field_class = GlobalIDFormField
 | 
			
		||||
 | 
			
		||||
    def filter(self, qs, value):
 | 
			
		||||
        """ Convert the filter value to a primary key before filtering """
 | 
			
		||||
        _id = None
 | 
			
		||||
        if value is not None:
 | 
			
		||||
            _, _id = from_global_id(value)
 | 
			
		||||
        return super(GlobalIDFilter, self).filter(qs, _id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
 | 
			
		||||
    field_class = GlobalIDMultipleChoiceField
 | 
			
		||||
 | 
			
		||||
    def filter(self, qs, value):
 | 
			
		||||
        gids = [from_global_id(v)[1] for v in value]
 | 
			
		||||
        return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
 | 
			
		||||
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
GRAPHENE_FILTER_SET_OVERRIDES = {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -157,20 +157,19 @@ def test_int_in_filter():
 | 
			
		|||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_int_range_filter():
 | 
			
		||||
def test_in_filter_with_empty_list():
 | 
			
		||||
    """
 | 
			
		||||
    Test in filter on an integer field.
 | 
			
		||||
    Check that using a in filter with an empty list provided as input returns no objects.
 | 
			
		||||
    """
 | 
			
		||||
    Pet.objects.create(name="Brutus", age=12)
 | 
			
		||||
    Pet.objects.create(name="Mimi", age=8)
 | 
			
		||||
    Pet.objects.create(name="Jojo, the rabbit", age=3)
 | 
			
		||||
    Pet.objects.create(name="Picotin", age=5)
 | 
			
		||||
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
 | 
			
		||||
    query = """
 | 
			
		||||
    query {
 | 
			
		||||
        pets (age_Range: [4, 9]) {
 | 
			
		||||
        pets (name_In: []) {
 | 
			
		||||
            edges {
 | 
			
		||||
                node {
 | 
			
		||||
                    name
 | 
			
		||||
| 
						 | 
				
			
			@ -181,7 +180,4 @@ def test_int_range_filter():
 | 
			
		|||
    """
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data["pets"]["edges"] == [
 | 
			
		||||
        {"node": {"name": "Mimi"}},
 | 
			
		||||
        {"node": {"name": "Picotin"}},
 | 
			
		||||
    ]
 | 
			
		||||
    assert len(result.data["pets"]["edges"]) == 0
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										114
									
								
								graphene_django/filter/tests/test_range_filter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								graphene_django/filter/tests/test_range_filter.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,114 @@
 | 
			
		|||
import json
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from django_filters import FilterSet
 | 
			
		||||
from django_filters import rest_framework as filters
 | 
			
		||||
from graphene import ObjectType, Schema
 | 
			
		||||
from graphene.relay import Node
 | 
			
		||||
from graphene_django import DjangoObjectType
 | 
			
		||||
from graphene_django.tests.models import Pet
 | 
			
		||||
from graphene_django.utils import DJANGO_FILTER_INSTALLED
 | 
			
		||||
 | 
			
		||||
pytestmark = []
 | 
			
		||||
 | 
			
		||||
if DJANGO_FILTER_INSTALLED:
 | 
			
		||||
    from graphene_django.filter import DjangoFilterConnectionField
 | 
			
		||||
else:
 | 
			
		||||
    pytestmark.append(
 | 
			
		||||
        pytest.mark.skipif(
 | 
			
		||||
            True, reason="django_filters not installed or not compatible"
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PetNode(DjangoObjectType):
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = Pet
 | 
			
		||||
        interfaces = (Node,)
 | 
			
		||||
        filter_fields = {
 | 
			
		||||
            "name": ["exact", "in"],
 | 
			
		||||
            "age": ["exact", "in", "range"],
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Query(ObjectType):
 | 
			
		||||
    pets = DjangoFilterConnectionField(PetNode)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_int_range_filter():
 | 
			
		||||
    """
 | 
			
		||||
    Test range filter on an integer field.
 | 
			
		||||
    """
 | 
			
		||||
    Pet.objects.create(name="Brutus", age=12)
 | 
			
		||||
    Pet.objects.create(name="Mimi", age=8)
 | 
			
		||||
    Pet.objects.create(name="Jojo, the rabbit", age=3)
 | 
			
		||||
    Pet.objects.create(name="Picotin", age=5)
 | 
			
		||||
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
 | 
			
		||||
    query = """
 | 
			
		||||
    query {
 | 
			
		||||
        pets (age_Range: [4, 9]) {
 | 
			
		||||
            edges {
 | 
			
		||||
                node {
 | 
			
		||||
                    name
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    """
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data["pets"]["edges"] == [
 | 
			
		||||
        {"node": {"name": "Mimi"}},
 | 
			
		||||
        {"node": {"name": "Picotin"}},
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_range_filter_with_invalid_input():
 | 
			
		||||
    """
 | 
			
		||||
    Test range filter used with invalid inputs raise an error.
 | 
			
		||||
    """
 | 
			
		||||
    Pet.objects.create(name="Brutus", age=12)
 | 
			
		||||
    Pet.objects.create(name="Mimi", age=8)
 | 
			
		||||
    Pet.objects.create(name="Jojo, the rabbit", age=3)
 | 
			
		||||
    Pet.objects.create(name="Picotin", age=5)
 | 
			
		||||
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
 | 
			
		||||
    query = """
 | 
			
		||||
    query ($rangeValue: [Int]) {
 | 
			
		||||
        pets (age_Range: $rangeValue) {
 | 
			
		||||
            edges {
 | 
			
		||||
                node {
 | 
			
		||||
                    name
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    """
 | 
			
		||||
    expected_error = json.dumps(
 | 
			
		||||
        {
 | 
			
		||||
            "age__range": [
 | 
			
		||||
                {
 | 
			
		||||
                    "message": "Invalid range specified: it needs to contain 2 values.",
 | 
			
		||||
                    "code": "invalid",
 | 
			
		||||
                }
 | 
			
		||||
            ]
 | 
			
		||||
        }
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Empty list
 | 
			
		||||
    result = schema.execute(query, variables={"rangeValue": []})
 | 
			
		||||
    assert len(result.errors) == 1
 | 
			
		||||
    assert result.errors[0].message == f"['{expected_error}']"
 | 
			
		||||
 | 
			
		||||
    # Only one item in the list
 | 
			
		||||
    result = schema.execute(query, variables={"rangeValue": [1]})
 | 
			
		||||
    assert len(result.errors) == 1
 | 
			
		||||
    assert result.errors[0].message == f"['{expected_error}']"
 | 
			
		||||
 | 
			
		||||
    # More than 2 items in the list
 | 
			
		||||
    result = schema.execute(query, variables={"rangeValue": [1, 2, 3]})
 | 
			
		||||
    assert len(result.errors) == 1
 | 
			
		||||
    assert result.errors[0].message == f"['{expected_error}']"
 | 
			
		||||
| 
						 | 
				
			
			@ -4,6 +4,7 @@ from django_filters.utils import get_model_field
 | 
			
		|||
from django_filters.filters import Filter, BaseCSVFilter
 | 
			
		||||
 | 
			
		||||
from .filterset import custom_filterset_factory, setup_filterset
 | 
			
		||||
from .filters import InFilter, RangeFilter
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_filtering_args_from_filterset(filterset_class, type):
 | 
			
		||||
| 
						 | 
				
			
			@ -78,9 +79,20 @@ def replace_csv_filters(filterset_class):
 | 
			
		|||
    """
 | 
			
		||||
    for name, filter_field in list(filterset_class.base_filters.items()):
 | 
			
		||||
        filter_type = filter_field.lookup_expr
 | 
			
		||||
        if filter_type in ["in", "range"]:
 | 
			
		||||
        if filter_type == "in":
 | 
			
		||||
            assert isinstance(filter_field, BaseCSVFilter)
 | 
			
		||||
            filterset_class.base_filters[name] = Filter(
 | 
			
		||||
            filterset_class.base_filters[name] = InFilter(
 | 
			
		||||
                field_name=filter_field.field_name,
 | 
			
		||||
                lookup_expr=filter_field.lookup_expr,
 | 
			
		||||
                label=filter_field.label,
 | 
			
		||||
                method=filter_field.method,
 | 
			
		||||
                exclude=filter_field.exclude,
 | 
			
		||||
                **filter_field.extra
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if filter_type == "range":
 | 
			
		||||
            assert isinstance(filter_field, BaseCSVFilter)
 | 
			
		||||
            filterset_class.base_filters[name] = RangeFilter(
 | 
			
		||||
                field_name=filter_field.field_name,
 | 
			
		||||
                lookup_expr=filter_field.lookup_expr,
 | 
			
		||||
                label=filter_field.label,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user