diff --git a/graphene_django/filter/tests/conftest.py b/graphene_django/filter/tests/conftest.py new file mode 100644 index 0000000..4a1fead --- /dev/null +++ b/graphene_django/filter/tests/conftest.py @@ -0,0 +1,125 @@ +from mock import MagicMock +import pytest + +from django.db import models +from django.db.models.query import QuerySet +from django_filters import filters +from django_filters import FilterSet +import graphene +from graphene.relay import Node +from graphene_django import DjangoObjectType +from graphene_django.utils import DJANGO_FILTER_INSTALLED + +from ...compat import ArrayField + +pytestmark = [] + +if DJANGO_FILTER_INSTALLED: + from graphene_django.filter import DjangoFilterConnectionField +else: + pytestmark.append( + pytest.mark.skipif( + True, reason="django_filters not installed or not compatible" + ) + ) + + +@pytest.fixture +def Event(): + class Event(models.Model): + name = models.CharField(max_length=50) + tags = ArrayField(models.CharField(max_length=50)) + + return Event + + +@pytest.fixture +def EventFilterSet(Event): + + from django.contrib.postgres.forms import SimpleArrayField + + class ArrayFilter(filters.Filter): + base_field_class = SimpleArrayField + + class EventFilterSet(FilterSet): + class Meta: + model = Event + fields = { + "name": ["exact"], + } + + tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains") + tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap") + + return EventFilterSet + + +@pytest.fixture +def EventType(Event, EventFilterSet): + class EventType(DjangoObjectType): + class Meta: + model = Event + interfaces = (Node,) + filterset_class = EventFilterSet + + return EventType + + +@pytest.fixture +def Query(Event, EventType): + class Query(graphene.ObjectType): + events = DjangoFilterConnectionField(EventType) + + def resolve_events(self, info, **kwargs): + + events = [ + Event(name="Live Show", tags=["concert", "music", "rock"],), + Event(name="Musical", tags=["movie", "music"],), + Event(name="Ballet", tags=["concert", "dance"],), + ] + + m_queryset = MagicMock(spec=QuerySet) + m_queryset.model = Event + + def filter_events(**kwargs): + nonlocal events + if "tags__contains" in kwargs: + events = list( + filter( + lambda e: set(kwargs["tags__contains"]).issubset( + set(e.tags) + ), + events, + ) + ) + if "tags__overlap" in kwargs: + events = list( + filter( + lambda e: not set(kwargs["tags__overlap"]).isdisjoint( + set(e.tags) + ), + events, + ) + ) + + def mock_queryset_filter(*args, **kwargs): + filter_events(**kwargs) + return m_queryset + + def mock_queryset_none(*args, **kwargs): + nonlocal events + events = [] + return m_queryset + + def mock_queryset_count(*args, **kwargs): + return len(events) + + m_queryset.all.return_value = m_queryset + m_queryset.filter.side_effect = mock_queryset_filter + m_queryset.none.side_effect = mock_queryset_none + m_queryset.count.side_effect = mock_queryset_count + m_queryset.__getitem__.side_effect = events.__getitem__ + + return m_queryset + + return Query diff --git a/graphene_django/filter/tests/test_contains_filter.py b/graphene_django/filter/tests/test_contains_filter.py new file mode 100644 index 0000000..3e90a3b --- /dev/null +++ b/graphene_django/filter/tests/test_contains_filter.py @@ -0,0 +1,82 @@ +import pytest + +from graphene import Schema + +from ...compat import ArrayField, MissingType + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_contains_multiple(Event, Query): + """ + Test contains filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Contains: ["concert", "music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Live Show"}}, + ] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_contains_one(Event, Query): + """ + Test contains filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Contains: ["music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Live Show"}}, + {"node": {"name": "Musical"}}, + ] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_contains_none(Event, Query): + """ + Test contains filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Contains: []) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [] diff --git a/graphene_django/filter/tests/test_overlap_filter.py b/graphene_django/filter/tests/test_overlap_filter.py new file mode 100644 index 0000000..90e825f --- /dev/null +++ b/graphene_django/filter/tests/test_overlap_filter.py @@ -0,0 +1,84 @@ +import pytest + +from graphene import Schema + +from ...compat import ArrayField, MissingType + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_overlap_multiple(Event, Query): + """ + Test overlap filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Overlap: ["concert", "music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Live Show"}}, + {"node": {"name": "Musical"}}, + {"node": {"name": "Ballet"}}, + ] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_overlap_one(Event, Query): + """ + Test overlap filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Overlap: ["music"]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [ + {"node": {"name": "Live Show"}}, + {"node": {"name": "Musical"}}, + ] + + +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") +def test_string_overlap_none(Event, Query): + """ + Test overlap filter on a string field. + """ + + schema = Schema(query=Query) + + query = """ + query { + events (tags_Overlap: []) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(query) + assert not result.errors + assert result.data["events"]["edges"] == [] diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py index dce08c7..1be114e 100644 --- a/graphene_django/filter/utils.py +++ b/graphene_django/filter/utils.py @@ -1,6 +1,6 @@ import six -from graphene import List +import graphene from django_filters.utils import get_model_field from django_filters.filters import Filter, BaseCSVFilter @@ -41,11 +41,11 @@ def get_filtering_args_from_filterset(filterset_class, type): field = convert_form_field(form_field) - if filter_type in ["in", "range"]: + if filter_type in {"in", "range", "contains", "overlap"}: # Replace CSV filters (`in`, `range`) argument type to be a list of # the same type as the field. See comments in # `replace_csv_filters` method for more details. - field = List(field.get_type()) + field = graphene.List(field.get_type()) field_type = field.Argument() field_type.description = filter_field.label @@ -81,8 +81,7 @@ def replace_csv_filters(filterset_class): """ for name, filter_field in six.iteritems(filterset_class.base_filters): filter_type = filter_field.lookup_expr - if filter_type == "in": - assert isinstance(filter_field, BaseCSVFilter) + if filter_type in {"in", "contains", "overlap"}: filterset_class.base_filters[name] = InFilter( field_name=filter_field.field_name, lookup_expr=filter_field.lookup_expr, @@ -92,8 +91,7 @@ def replace_csv_filters(filterset_class): **filter_field.extra ) - if filter_type == "range": - assert isinstance(filter_field, BaseCSVFilter) + elif filter_type == "range": filterset_class.base_filters[name] = RangeFilter( field_name=filter_field.field_name, lookup_expr=filter_field.lookup_expr,