Support contains/overlap filters

This commit is contained in:
Lucas Brémond 2021-01-17 18:46:33 -08:00
parent fc57054069
commit 38cf281b5d
4 changed files with 296 additions and 7 deletions

View File

@ -0,0 +1,125 @@
from mock import MagicMock
import pytest
from django.db import models
from django.db.models.query import QuerySet
from django_filters import filters
from django_filters import FilterSet
import graphene
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.utils import DJANGO_FILTER_INSTALLED
from ...compat import ArrayField
pytestmark = []
if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import DjangoFilterConnectionField
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
@pytest.fixture
def Event():
class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))
return Event
@pytest.fixture
def EventFilterSet(Event):
from django.contrib.postgres.forms import SimpleArrayField
class ArrayFilter(filters.Filter):
base_field_class = SimpleArrayField
class EventFilterSet(FilterSet):
class Meta:
model = Event
fields = {
"name": ["exact"],
}
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
return EventFilterSet
@pytest.fixture
def EventType(Event, EventFilterSet):
class EventType(DjangoObjectType):
class Meta:
model = Event
interfaces = (Node,)
filterset_class = EventFilterSet
return EventType
@pytest.fixture
def Query(Event, EventType):
class Query(graphene.ObjectType):
events = DjangoFilterConnectionField(EventType)
def resolve_events(self, info, **kwargs):
events = [
Event(name="Live Show", tags=["concert", "music", "rock"],),
Event(name="Musical", tags=["movie", "music"],),
Event(name="Ballet", tags=["concert", "dance"],),
]
m_queryset = MagicMock(spec=QuerySet)
m_queryset.model = Event
def filter_events(**kwargs):
nonlocal events
if "tags__contains" in kwargs:
events = list(
filter(
lambda e: set(kwargs["tags__contains"]).issubset(
set(e.tags)
),
events,
)
)
if "tags__overlap" in kwargs:
events = list(
filter(
lambda e: not set(kwargs["tags__overlap"]).isdisjoint(
set(e.tags)
),
events,
)
)
def mock_queryset_filter(*args, **kwargs):
filter_events(**kwargs)
return m_queryset
def mock_queryset_none(*args, **kwargs):
nonlocal events
events = []
return m_queryset
def mock_queryset_count(*args, **kwargs):
return len(events)
m_queryset.all.return_value = m_queryset
m_queryset.filter.side_effect = mock_queryset_filter
m_queryset.none.side_effect = mock_queryset_none
m_queryset.count.side_effect = mock_queryset_count
m_queryset.__getitem__.side_effect = events.__getitem__
return m_queryset
return Query

View File

@ -0,0 +1,82 @@
import pytest
from graphene import Schema
from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_multiple(Event, Query):
"""
Test contains filter on a string field.
"""
schema = Schema(query=Query)
query = """
query {
events (tags_Contains: ["concert", "music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Live Show"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_one(Event, Query):
"""
Test contains filter on a string field.
"""
schema = Schema(query=Query)
query = """
query {
events (tags_Contains: ["music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Live Show"}},
{"node": {"name": "Musical"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_none(Event, Query):
"""
Test contains filter on a string field.
"""
schema = Schema(query=Query)
query = """
query {
events (tags_Contains: []) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []

View File

@ -0,0 +1,84 @@
import pytest
from graphene import Schema
from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_overlap_multiple(Event, Query):
"""
Test overlap filter on a string field.
"""
schema = Schema(query=Query)
query = """
query {
events (tags_Overlap: ["concert", "music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Live Show"}},
{"node": {"name": "Musical"}},
{"node": {"name": "Ballet"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_overlap_one(Event, Query):
"""
Test overlap filter on a string field.
"""
schema = Schema(query=Query)
query = """
query {
events (tags_Overlap: ["music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Live Show"}},
{"node": {"name": "Musical"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_overlap_none(Event, Query):
"""
Test overlap filter on a string field.
"""
schema = Schema(query=Query)
query = """
query {
events (tags_Overlap: []) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []

View File

@ -1,6 +1,6 @@
import six import six
from graphene import List import graphene
from django_filters.utils import get_model_field from django_filters.utils import get_model_field
from django_filters.filters import Filter, BaseCSVFilter from django_filters.filters import Filter, BaseCSVFilter
@ -41,11 +41,11 @@ def get_filtering_args_from_filterset(filterset_class, type):
field = convert_form_field(form_field) field = convert_form_field(form_field)
if filter_type in ["in", "range"]: if filter_type in {"in", "range", "contains", "overlap"}:
# Replace CSV filters (`in`, `range`) argument type to be a list of # Replace CSV filters (`in`, `range`) argument type to be a list of
# the same type as the field. See comments in # the same type as the field. See comments in
# `replace_csv_filters` method for more details. # `replace_csv_filters` method for more details.
field = List(field.get_type()) field = graphene.List(field.get_type())
field_type = field.Argument() field_type = field.Argument()
field_type.description = filter_field.label field_type.description = filter_field.label
@ -81,8 +81,7 @@ def replace_csv_filters(filterset_class):
""" """
for name, filter_field in six.iteritems(filterset_class.base_filters): for name, filter_field in six.iteritems(filterset_class.base_filters):
filter_type = filter_field.lookup_expr filter_type = filter_field.lookup_expr
if filter_type == "in": if filter_type in {"in", "contains", "overlap"}:
assert isinstance(filter_field, BaseCSVFilter)
filterset_class.base_filters[name] = InFilter( filterset_class.base_filters[name] = InFilter(
field_name=filter_field.field_name, field_name=filter_field.field_name,
lookup_expr=filter_field.lookup_expr, lookup_expr=filter_field.lookup_expr,
@ -92,8 +91,7 @@ def replace_csv_filters(filterset_class):
**filter_field.extra **filter_field.extra
) )
if filter_type == "range": elif filter_type == "range":
assert isinstance(filter_field, BaseCSVFilter)
filterset_class.base_filters[name] = RangeFilter( filterset_class.base_filters[name] = RangeFilter(
field_name=filter_field.field_name, field_name=filter_field.field_name,
lookup_expr=filter_field.lookup_expr, lookup_expr=filter_field.lookup_expr,