mirror of
				https://github.com/graphql-python/graphene-django.git
				synced 2025-11-04 09:57:53 +03:00 
			
		
		
		
	fix: in and range filters on DjangoFilterConnectionField (#1070)
Co-authored-by: Thomas Leonard <thomas@loftorbital.com>
This commit is contained in:
		
							parent
							
								
									7b35695067
								
							
						
					
					
						commit
						99512c53a1
					
				| 
						 | 
				
			
			@ -21,6 +21,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
 | 
			
		|||
        self._fields = fields
 | 
			
		||||
        self._provided_filterset_class = filterset_class
 | 
			
		||||
        self._filterset_class = None
 | 
			
		||||
        self._filtering_args = None
 | 
			
		||||
        self._extra_filter_meta = extra_filter_meta
 | 
			
		||||
        self._base_args = None
 | 
			
		||||
        super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs)
 | 
			
		||||
| 
						 | 
				
			
			@ -50,7 +51,11 @@ class DjangoFilterConnectionField(DjangoConnectionField):
 | 
			
		|||
 | 
			
		||||
    @property
 | 
			
		||||
    def filtering_args(self):
 | 
			
		||||
        return get_filtering_args_from_filterset(self.filterset_class, self.node_type)
 | 
			
		||||
        if not self._filtering_args:
 | 
			
		||||
            self._filtering_args = get_filtering_args_from_filterset(
 | 
			
		||||
                self.filterset_class, self.node_type
 | 
			
		||||
            )
 | 
			
		||||
        return self._filtering_args
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def resolve_queryset(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										139
									
								
								graphene_django/filter/tests/test_in_filter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								graphene_django/filter/tests/test_in_filter.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,139 @@
 | 
			
		|||
import pytest
 | 
			
		||||
 | 
			
		||||
from graphene import ObjectType, Schema
 | 
			
		||||
from graphene.relay import Node
 | 
			
		||||
from graphene_django import DjangoObjectType
 | 
			
		||||
from graphene_django.tests.models import Pet
 | 
			
		||||
from graphene_django.utils import DJANGO_FILTER_INSTALLED
 | 
			
		||||
 | 
			
		||||
pytestmark = []
 | 
			
		||||
 | 
			
		||||
if DJANGO_FILTER_INSTALLED:
 | 
			
		||||
    from graphene_django.filter import DjangoFilterConnectionField
 | 
			
		||||
else:
 | 
			
		||||
    pytestmark.append(
 | 
			
		||||
        pytest.mark.skipif(
 | 
			
		||||
            True, reason="django_filters not installed or not compatible"
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PetNode(DjangoObjectType):
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = Pet
 | 
			
		||||
        interfaces = (Node,)
 | 
			
		||||
        filter_fields = {
 | 
			
		||||
            "name": ["exact", "in"],
 | 
			
		||||
            "age": ["exact", "in", "range"],
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Query(ObjectType):
 | 
			
		||||
    pets = DjangoFilterConnectionField(PetNode)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_string_in_filter():
 | 
			
		||||
    """
 | 
			
		||||
    Test in filter on a string field.
 | 
			
		||||
    """
 | 
			
		||||
    Pet.objects.create(name="Brutus", age=12)
 | 
			
		||||
    Pet.objects.create(name="Mimi", age=3)
 | 
			
		||||
    Pet.objects.create(name="Jojo, the rabbit", age=3)
 | 
			
		||||
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
 | 
			
		||||
    query = """
 | 
			
		||||
    query {
 | 
			
		||||
        pets (name_In: ["Brutus", "Jojo, the rabbit"]) {
 | 
			
		||||
            edges {
 | 
			
		||||
                node {
 | 
			
		||||
                    name
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    """
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data["pets"]["edges"] == [
 | 
			
		||||
        {"node": {"name": "Brutus"}},
 | 
			
		||||
        {"node": {"name": "Jojo, the rabbit"}},
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_int_in_filter():
 | 
			
		||||
    """
 | 
			
		||||
    Test in filter on an integer field.
 | 
			
		||||
    """
 | 
			
		||||
    Pet.objects.create(name="Brutus", age=12)
 | 
			
		||||
    Pet.objects.create(name="Mimi", age=3)
 | 
			
		||||
    Pet.objects.create(name="Jojo, the rabbit", age=3)
 | 
			
		||||
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
 | 
			
		||||
    query = """
 | 
			
		||||
    query {
 | 
			
		||||
        pets (age_In: [3]) {
 | 
			
		||||
            edges {
 | 
			
		||||
                node {
 | 
			
		||||
                    name
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    """
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data["pets"]["edges"] == [
 | 
			
		||||
        {"node": {"name": "Mimi"}},
 | 
			
		||||
        {"node": {"name": "Jojo, the rabbit"}},
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    query = """
 | 
			
		||||
    query {
 | 
			
		||||
        pets (age_In: [3, 12]) {
 | 
			
		||||
            edges {
 | 
			
		||||
                node {
 | 
			
		||||
                    name
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    """
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data["pets"]["edges"] == [
 | 
			
		||||
        {"node": {"name": "Brutus"}},
 | 
			
		||||
        {"node": {"name": "Mimi"}},
 | 
			
		||||
        {"node": {"name": "Jojo, the rabbit"}},
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_int_range_filter():
 | 
			
		||||
    """
 | 
			
		||||
    Test in filter on an integer field.
 | 
			
		||||
    """
 | 
			
		||||
    Pet.objects.create(name="Brutus", age=12)
 | 
			
		||||
    Pet.objects.create(name="Mimi", age=8)
 | 
			
		||||
    Pet.objects.create(name="Jojo, the rabbit", age=3)
 | 
			
		||||
    Pet.objects.create(name="Picotin", age=5)
 | 
			
		||||
 | 
			
		||||
    schema = Schema(query=Query)
 | 
			
		||||
 | 
			
		||||
    query = """
 | 
			
		||||
    query {
 | 
			
		||||
        pets (age_Range: [4, 9]) {
 | 
			
		||||
            edges {
 | 
			
		||||
                node {
 | 
			
		||||
                    name
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    """
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data["pets"]["edges"] == [
 | 
			
		||||
        {"node": {"name": "Mimi"}},
 | 
			
		||||
        {"node": {"name": "Picotin"}},
 | 
			
		||||
    ]
 | 
			
		||||
| 
						 | 
				
			
			@ -1,6 +1,10 @@
 | 
			
		|||
import six
 | 
			
		||||
 | 
			
		||||
from graphene import List
 | 
			
		||||
 | 
			
		||||
from django_filters.utils import get_model_field
 | 
			
		||||
from django_filters.filters import Filter, BaseCSVFilter
 | 
			
		||||
 | 
			
		||||
from .filterset import custom_filterset_factory, setup_filterset
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -17,8 +21,11 @@ def get_filtering_args_from_filterset(filterset_class, type):
 | 
			
		|||
        form_field = None
 | 
			
		||||
 | 
			
		||||
        if name in filterset_class.declared_filters:
 | 
			
		||||
            # Get the filter field from the explicitly declared filter
 | 
			
		||||
            form_field = filter_field.field
 | 
			
		||||
            field = convert_form_field(form_field)
 | 
			
		||||
        else:
 | 
			
		||||
            # Get the filter field with no explicit type declaration
 | 
			
		||||
            model_field = get_model_field(model, filter_field.field_name)
 | 
			
		||||
            filter_type = filter_field.lookup_expr
 | 
			
		||||
            if filter_type != "isnull" and hasattr(model_field, "formfield"):
 | 
			
		||||
| 
						 | 
				
			
			@ -26,12 +33,19 @@ def get_filtering_args_from_filterset(filterset_class, type):
 | 
			
		|||
                    required=filter_field.extra.get("required", False)
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        # Fallback to field defined on filter if we can't get it from the
 | 
			
		||||
        # model field
 | 
			
		||||
        if not form_field:
 | 
			
		||||
            form_field = filter_field.field
 | 
			
		||||
            # Fallback to field defined on filter if we can't get it from the
 | 
			
		||||
            # model field
 | 
			
		||||
            if not form_field:
 | 
			
		||||
                form_field = filter_field.field
 | 
			
		||||
 | 
			
		||||
        field_type = convert_form_field(form_field).Argument()
 | 
			
		||||
            field = convert_form_field(form_field)
 | 
			
		||||
 | 
			
		||||
            if filter_type in ["in", "range"]:
 | 
			
		||||
                # Replace CSV filters (`in`, `range`) argument type to be a list of the same type as the field.
 | 
			
		||||
                # See comments in `replace_csv_filters` method for more details.
 | 
			
		||||
                field = List(field.get_type())
 | 
			
		||||
 | 
			
		||||
        field_type = field.Argument()
 | 
			
		||||
        field_type.description = filter_field.label
 | 
			
		||||
        args[name] = field_type
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -39,9 +53,42 @@ def get_filtering_args_from_filterset(filterset_class, type):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def get_filterset_class(filterset_class, **meta):
 | 
			
		||||
    """Get the class to be used as the FilterSet"""
 | 
			
		||||
    """
 | 
			
		||||
    Get the class to be used as the FilterSet.
 | 
			
		||||
    """
 | 
			
		||||
    if filterset_class:
 | 
			
		||||
        # If were given a FilterSet class, then set it up and
 | 
			
		||||
        # return it
 | 
			
		||||
        return setup_filterset(filterset_class)
 | 
			
		||||
    return custom_filterset_factory(**meta)
 | 
			
		||||
        # If were given a FilterSet class, then set it up.
 | 
			
		||||
        graphene_filterset_class = setup_filterset(filterset_class)
 | 
			
		||||
    else:
 | 
			
		||||
        # Otherwise create one.
 | 
			
		||||
        graphene_filterset_class = custom_filterset_factory(**meta)
 | 
			
		||||
 | 
			
		||||
    replace_csv_filters(graphene_filterset_class)
 | 
			
		||||
    return graphene_filterset_class
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def replace_csv_filters(filterset_class):
 | 
			
		||||
    """
 | 
			
		||||
    Replace the "in" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
 | 
			
		||||
    but regular Filter objects that simply use the input value as filter argument on the queryset.
 | 
			
		||||
 | 
			
		||||
    This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we
 | 
			
		||||
    can actually have a list as input and have a proper type verification of each value in the list.
 | 
			
		||||
 | 
			
		||||
    See issue https://github.com/graphql-python/graphene-django/issues/1068.
 | 
			
		||||
    """
 | 
			
		||||
    for name, filter_field in six.iteritems(filterset_class.base_filters):
 | 
			
		||||
        filter_type = filter_field.lookup_expr
 | 
			
		||||
        if (
 | 
			
		||||
            filter_type in ["in", "range"]
 | 
			
		||||
            and name not in filterset_class.declared_filters
 | 
			
		||||
        ):
 | 
			
		||||
            assert isinstance(filter_field, BaseCSVFilter)
 | 
			
		||||
            filterset_class.base_filters[name] = Filter(
 | 
			
		||||
                field_name=filter_field.field_name,
 | 
			
		||||
                lookup_expr=filter_field.lookup_expr,
 | 
			
		||||
                label=filter_field.label,
 | 
			
		||||
                method=filter_field.method,
 | 
			
		||||
                exclude=filter_field.exclude,
 | 
			
		||||
                **filter_field.extra
 | 
			
		||||
            )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user