mirror of
https://github.com/graphql-python/graphene-django.git
synced 2025-01-31 11:48:38 +03:00
Support "contains" and "overlap" filtering (v3) (#1101)
* Support contains/overlap filters * Remove unused fixtures
This commit is contained in:
parent
bcc7f85dad
commit
5dea6ffa41
128
graphene_django/filter/tests/conftest.py
Normal file
128
graphene_django/filter/tests/conftest.py
Normal file
|
@ -0,0 +1,128 @@
|
||||||
|
from mock import MagicMock
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from django.db import models
|
||||||
|
from django.db.models.query import QuerySet
|
||||||
|
from django_filters import filters
|
||||||
|
from django_filters import FilterSet
|
||||||
|
import graphene
|
||||||
|
from graphene.relay import Node
|
||||||
|
from graphene_django import DjangoObjectType
|
||||||
|
from graphene_django.utils import DJANGO_FILTER_INSTALLED
|
||||||
|
|
||||||
|
from ...compat import ArrayField
|
||||||
|
|
||||||
|
pytestmark = []
|
||||||
|
|
||||||
|
if DJANGO_FILTER_INSTALLED:
|
||||||
|
from graphene_django.filter import DjangoFilterConnectionField
|
||||||
|
else:
|
||||||
|
pytestmark.append(
|
||||||
|
pytest.mark.skipif(
|
||||||
|
True, reason="django_filters not installed or not compatible"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
STORE = {"events": []}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def Event():
|
||||||
|
class Event(models.Model):
|
||||||
|
name = models.CharField(max_length=50)
|
||||||
|
tags = ArrayField(models.CharField(max_length=50))
|
||||||
|
|
||||||
|
return Event
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def EventFilterSet(Event):
|
||||||
|
|
||||||
|
from django.contrib.postgres.forms import SimpleArrayField
|
||||||
|
|
||||||
|
class ArrayFilter(filters.Filter):
|
||||||
|
base_field_class = SimpleArrayField
|
||||||
|
|
||||||
|
class EventFilterSet(FilterSet):
|
||||||
|
class Meta:
|
||||||
|
model = Event
|
||||||
|
fields = {
|
||||||
|
"name": ["exact"],
|
||||||
|
}
|
||||||
|
|
||||||
|
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
|
||||||
|
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
|
||||||
|
|
||||||
|
return EventFilterSet
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def EventType(Event, EventFilterSet):
|
||||||
|
class EventType(DjangoObjectType):
|
||||||
|
class Meta:
|
||||||
|
model = Event
|
||||||
|
interfaces = (Node,)
|
||||||
|
filterset_class = EventFilterSet
|
||||||
|
|
||||||
|
return EventType
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def Query(Event, EventType):
|
||||||
|
class Query(graphene.ObjectType):
|
||||||
|
events = DjangoFilterConnectionField(EventType)
|
||||||
|
|
||||||
|
def resolve_events(self, info, **kwargs):
|
||||||
|
|
||||||
|
events = [
|
||||||
|
Event(name="Live Show", tags=["concert", "music", "rock"],),
|
||||||
|
Event(name="Musical", tags=["movie", "music"],),
|
||||||
|
Event(name="Ballet", tags=["concert", "dance"],),
|
||||||
|
]
|
||||||
|
|
||||||
|
STORE["events"] = events
|
||||||
|
|
||||||
|
m_queryset = MagicMock(spec=QuerySet)
|
||||||
|
m_queryset.model = Event
|
||||||
|
|
||||||
|
def filter_events(**kwargs):
|
||||||
|
if "tags__contains" in kwargs:
|
||||||
|
STORE["events"] = list(
|
||||||
|
filter(
|
||||||
|
lambda e: set(kwargs["tags__contains"]).issubset(
|
||||||
|
set(e.tags)
|
||||||
|
),
|
||||||
|
STORE["events"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if "tags__overlap" in kwargs:
|
||||||
|
STORE["events"] = list(
|
||||||
|
filter(
|
||||||
|
lambda e: not set(kwargs["tags__overlap"]).isdisjoint(
|
||||||
|
set(e.tags)
|
||||||
|
),
|
||||||
|
STORE["events"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_queryset_filter(*args, **kwargs):
|
||||||
|
filter_events(**kwargs)
|
||||||
|
return m_queryset
|
||||||
|
|
||||||
|
def mock_queryset_none(*args, **kwargs):
|
||||||
|
STORE["events"] = []
|
||||||
|
return m_queryset
|
||||||
|
|
||||||
|
def mock_queryset_count(*args, **kwargs):
|
||||||
|
return len(STORE["events"])
|
||||||
|
|
||||||
|
m_queryset.all.return_value = m_queryset
|
||||||
|
m_queryset.filter.side_effect = mock_queryset_filter
|
||||||
|
m_queryset.none.side_effect = mock_queryset_none
|
||||||
|
m_queryset.count.side_effect = mock_queryset_count
|
||||||
|
m_queryset.__getitem__.side_effect = STORE["events"].__getitem__
|
||||||
|
|
||||||
|
return m_queryset
|
||||||
|
|
||||||
|
return Query
|
82
graphene_django/filter/tests/test_contains_filter.py
Normal file
82
graphene_django/filter/tests/test_contains_filter.py
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from graphene import Schema
|
||||||
|
|
||||||
|
from ...compat import ArrayField, MissingType
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||||
|
def test_string_contains_multiple(Query):
|
||||||
|
"""
|
||||||
|
Test contains filter on a string field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
events (tags_Contains: ["concert", "music"]) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data["events"]["edges"] == [
|
||||||
|
{"node": {"name": "Live Show"}},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||||
|
def test_string_contains_one(Query):
|
||||||
|
"""
|
||||||
|
Test contains filter on a string field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
events (tags_Contains: ["music"]) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data["events"]["edges"] == [
|
||||||
|
{"node": {"name": "Live Show"}},
|
||||||
|
{"node": {"name": "Musical"}},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||||
|
def test_string_contains_none(Query):
|
||||||
|
"""
|
||||||
|
Test contains filter on a string field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
events (tags_Contains: []) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data["events"]["edges"] == []
|
84
graphene_django/filter/tests/test_overlap_filter.py
Normal file
84
graphene_django/filter/tests/test_overlap_filter.py
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from graphene import Schema
|
||||||
|
|
||||||
|
from ...compat import ArrayField, MissingType
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||||
|
def test_string_overlap_multiple(Query):
|
||||||
|
"""
|
||||||
|
Test overlap filter on a string field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
events (tags_Overlap: ["concert", "music"]) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data["events"]["edges"] == [
|
||||||
|
{"node": {"name": "Live Show"}},
|
||||||
|
{"node": {"name": "Musical"}},
|
||||||
|
{"node": {"name": "Ballet"}},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||||
|
def test_string_overlap_one(Query):
|
||||||
|
"""
|
||||||
|
Test overlap filter on a string field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
events (tags_Overlap: ["music"]) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data["events"]["edges"] == [
|
||||||
|
{"node": {"name": "Live Show"}},
|
||||||
|
{"node": {"name": "Musical"}},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||||
|
def test_string_overlap_none(Query):
|
||||||
|
"""
|
||||||
|
Test overlap filter on a string field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
events (tags_Overlap: []) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data["events"]["edges"] == []
|
|
@ -1,4 +1,4 @@
|
||||||
from graphene import List
|
import graphene
|
||||||
|
|
||||||
from django_filters.utils import get_model_field
|
from django_filters.utils import get_model_field
|
||||||
from django_filters.filters import Filter, BaseCSVFilter
|
from django_filters.filters import Filter, BaseCSVFilter
|
||||||
|
@ -39,11 +39,11 @@ def get_filtering_args_from_filterset(filterset_class, type):
|
||||||
|
|
||||||
field = convert_form_field(form_field)
|
field = convert_form_field(form_field)
|
||||||
|
|
||||||
if filter_type in ["in", "range"]:
|
if filter_type in {"in", "range", "contains", "overlap"}:
|
||||||
# Replace CSV filters (`in`, `range`) argument type to be a list of
|
# Replace CSV filters (`in`, `range`, `contains`, `overlap`) argument type to be a list of
|
||||||
# the same type as the field. See comments in
|
# the same type as the field. See comments in
|
||||||
# `replace_csv_filters` method for more details.
|
# `replace_csv_filters` method for more details.
|
||||||
field = List(field.get_type())
|
field = graphene.List(field.get_type())
|
||||||
|
|
||||||
field_type = field.Argument()
|
field_type = field.Argument()
|
||||||
field_type.description = str(filter_field.label) if filter_field.label else None
|
field_type.description = str(filter_field.label) if filter_field.label else None
|
||||||
|
@ -69,7 +69,7 @@ def get_filterset_class(filterset_class, **meta):
|
||||||
|
|
||||||
def replace_csv_filters(filterset_class):
|
def replace_csv_filters(filterset_class):
|
||||||
"""
|
"""
|
||||||
Replace the "in" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
|
Replace the "in", "contains", "overlap" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
|
||||||
but regular Filter objects that simply use the input value as filter argument on the queryset.
|
but regular Filter objects that simply use the input value as filter argument on the queryset.
|
||||||
|
|
||||||
This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we
|
This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we
|
||||||
|
@ -79,8 +79,7 @@ def replace_csv_filters(filterset_class):
|
||||||
"""
|
"""
|
||||||
for name, filter_field in list(filterset_class.base_filters.items()):
|
for name, filter_field in list(filterset_class.base_filters.items()):
|
||||||
filter_type = filter_field.lookup_expr
|
filter_type = filter_field.lookup_expr
|
||||||
if filter_type == "in":
|
if filter_type in {"in", "contains", "overlap"}:
|
||||||
assert isinstance(filter_field, BaseCSVFilter)
|
|
||||||
filterset_class.base_filters[name] = InFilter(
|
filterset_class.base_filters[name] = InFilter(
|
||||||
field_name=filter_field.field_name,
|
field_name=filter_field.field_name,
|
||||||
lookup_expr=filter_field.lookup_expr,
|
lookup_expr=filter_field.lookup_expr,
|
||||||
|
@ -90,8 +89,7 @@ def replace_csv_filters(filterset_class):
|
||||||
**filter_field.extra
|
**filter_field.extra
|
||||||
)
|
)
|
||||||
|
|
||||||
if filter_type == "range":
|
elif filter_type == "range":
|
||||||
assert isinstance(filter_field, BaseCSVFilter)
|
|
||||||
filterset_class.base_filters[name] = RangeFilter(
|
filterset_class.base_filters[name] = RangeFilter(
|
||||||
field_name=filter_field.field_name,
|
field_name=filter_field.field_name,
|
||||||
lookup_expr=filter_field.lookup_expr,
|
lookup_expr=filter_field.lookup_expr,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user