mirror of
https://github.com/graphql-python/graphene-django.git
synced 2025-07-11 08:42:32 +03:00
Support contains/overlap filters
This commit is contained in:
parent
fc57054069
commit
38cf281b5d
125
graphene_django/filter/tests/conftest.py
Normal file
125
graphene_django/filter/tests/conftest.py
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
from mock import MagicMock
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from django.db import models
|
||||||
|
from django.db.models.query import QuerySet
|
||||||
|
from django_filters import filters
|
||||||
|
from django_filters import FilterSet
|
||||||
|
import graphene
|
||||||
|
from graphene.relay import Node
|
||||||
|
from graphene_django import DjangoObjectType
|
||||||
|
from graphene_django.utils import DJANGO_FILTER_INSTALLED
|
||||||
|
|
||||||
|
from ...compat import ArrayField
|
||||||
|
|
||||||
|
pytestmark = []
|
||||||
|
|
||||||
|
if DJANGO_FILTER_INSTALLED:
|
||||||
|
from graphene_django.filter import DjangoFilterConnectionField
|
||||||
|
else:
|
||||||
|
pytestmark.append(
|
||||||
|
pytest.mark.skipif(
|
||||||
|
True, reason="django_filters not installed or not compatible"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def Event():
|
||||||
|
class Event(models.Model):
|
||||||
|
name = models.CharField(max_length=50)
|
||||||
|
tags = ArrayField(models.CharField(max_length=50))
|
||||||
|
|
||||||
|
return Event
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def EventFilterSet(Event):
|
||||||
|
|
||||||
|
from django.contrib.postgres.forms import SimpleArrayField
|
||||||
|
|
||||||
|
class ArrayFilter(filters.Filter):
|
||||||
|
base_field_class = SimpleArrayField
|
||||||
|
|
||||||
|
class EventFilterSet(FilterSet):
|
||||||
|
class Meta:
|
||||||
|
model = Event
|
||||||
|
fields = {
|
||||||
|
"name": ["exact"],
|
||||||
|
}
|
||||||
|
|
||||||
|
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
|
||||||
|
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
|
||||||
|
|
||||||
|
return EventFilterSet
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def EventType(Event, EventFilterSet):
|
||||||
|
class EventType(DjangoObjectType):
|
||||||
|
class Meta:
|
||||||
|
model = Event
|
||||||
|
interfaces = (Node,)
|
||||||
|
filterset_class = EventFilterSet
|
||||||
|
|
||||||
|
return EventType
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def Query(Event, EventType):
|
||||||
|
class Query(graphene.ObjectType):
|
||||||
|
events = DjangoFilterConnectionField(EventType)
|
||||||
|
|
||||||
|
def resolve_events(self, info, **kwargs):
|
||||||
|
|
||||||
|
events = [
|
||||||
|
Event(name="Live Show", tags=["concert", "music", "rock"],),
|
||||||
|
Event(name="Musical", tags=["movie", "music"],),
|
||||||
|
Event(name="Ballet", tags=["concert", "dance"],),
|
||||||
|
]
|
||||||
|
|
||||||
|
m_queryset = MagicMock(spec=QuerySet)
|
||||||
|
m_queryset.model = Event
|
||||||
|
|
||||||
|
def filter_events(**kwargs):
|
||||||
|
nonlocal events
|
||||||
|
if "tags__contains" in kwargs:
|
||||||
|
events = list(
|
||||||
|
filter(
|
||||||
|
lambda e: set(kwargs["tags__contains"]).issubset(
|
||||||
|
set(e.tags)
|
||||||
|
),
|
||||||
|
events,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if "tags__overlap" in kwargs:
|
||||||
|
events = list(
|
||||||
|
filter(
|
||||||
|
lambda e: not set(kwargs["tags__overlap"]).isdisjoint(
|
||||||
|
set(e.tags)
|
||||||
|
),
|
||||||
|
events,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_queryset_filter(*args, **kwargs):
|
||||||
|
filter_events(**kwargs)
|
||||||
|
return m_queryset
|
||||||
|
|
||||||
|
def mock_queryset_none(*args, **kwargs):
|
||||||
|
nonlocal events
|
||||||
|
events = []
|
||||||
|
return m_queryset
|
||||||
|
|
||||||
|
def mock_queryset_count(*args, **kwargs):
|
||||||
|
return len(events)
|
||||||
|
|
||||||
|
m_queryset.all.return_value = m_queryset
|
||||||
|
m_queryset.filter.side_effect = mock_queryset_filter
|
||||||
|
m_queryset.none.side_effect = mock_queryset_none
|
||||||
|
m_queryset.count.side_effect = mock_queryset_count
|
||||||
|
m_queryset.__getitem__.side_effect = events.__getitem__
|
||||||
|
|
||||||
|
return m_queryset
|
||||||
|
|
||||||
|
return Query
|
82
graphene_django/filter/tests/test_contains_filter.py
Normal file
82
graphene_django/filter/tests/test_contains_filter.py
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from graphene import Schema
|
||||||
|
|
||||||
|
from ...compat import ArrayField, MissingType
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||||
|
def test_string_contains_multiple(Event, Query):
|
||||||
|
"""
|
||||||
|
Test contains filter on a string field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
events (tags_Contains: ["concert", "music"]) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data["events"]["edges"] == [
|
||||||
|
{"node": {"name": "Live Show"}},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||||
|
def test_string_contains_one(Event, Query):
|
||||||
|
"""
|
||||||
|
Test contains filter on a string field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
events (tags_Contains: ["music"]) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data["events"]["edges"] == [
|
||||||
|
{"node": {"name": "Live Show"}},
|
||||||
|
{"node": {"name": "Musical"}},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||||
|
def test_string_contains_none(Event, Query):
|
||||||
|
"""
|
||||||
|
Test contains filter on a string field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
events (tags_Contains: []) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data["events"]["edges"] == []
|
84
graphene_django/filter/tests/test_overlap_filter.py
Normal file
84
graphene_django/filter/tests/test_overlap_filter.py
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from graphene import Schema
|
||||||
|
|
||||||
|
from ...compat import ArrayField, MissingType
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||||
|
def test_string_overlap_multiple(Event, Query):
|
||||||
|
"""
|
||||||
|
Test overlap filter on a string field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
events (tags_Overlap: ["concert", "music"]) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data["events"]["edges"] == [
|
||||||
|
{"node": {"name": "Live Show"}},
|
||||||
|
{"node": {"name": "Musical"}},
|
||||||
|
{"node": {"name": "Ballet"}},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||||
|
def test_string_overlap_one(Event, Query):
|
||||||
|
"""
|
||||||
|
Test overlap filter on a string field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
events (tags_Overlap: ["music"]) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data["events"]["edges"] == [
|
||||||
|
{"node": {"name": "Live Show"}},
|
||||||
|
{"node": {"name": "Musical"}},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
|
||||||
|
def test_string_overlap_none(Event, Query):
|
||||||
|
"""
|
||||||
|
Test overlap filter on a string field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
events (tags_Overlap: []) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data["events"]["edges"] == []
|
|
@ -1,6 +1,6 @@
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from graphene import List
|
import graphene
|
||||||
|
|
||||||
from django_filters.utils import get_model_field
|
from django_filters.utils import get_model_field
|
||||||
from django_filters.filters import Filter, BaseCSVFilter
|
from django_filters.filters import Filter, BaseCSVFilter
|
||||||
|
@ -41,11 +41,11 @@ def get_filtering_args_from_filterset(filterset_class, type):
|
||||||
|
|
||||||
field = convert_form_field(form_field)
|
field = convert_form_field(form_field)
|
||||||
|
|
||||||
if filter_type in ["in", "range"]:
|
if filter_type in {"in", "range", "contains", "overlap"}:
|
||||||
# Replace CSV filters (`in`, `range`) argument type to be a list of
|
# Replace CSV filters (`in`, `range`) argument type to be a list of
|
||||||
# the same type as the field. See comments in
|
# the same type as the field. See comments in
|
||||||
# `replace_csv_filters` method for more details.
|
# `replace_csv_filters` method for more details.
|
||||||
field = List(field.get_type())
|
field = graphene.List(field.get_type())
|
||||||
|
|
||||||
field_type = field.Argument()
|
field_type = field.Argument()
|
||||||
field_type.description = filter_field.label
|
field_type.description = filter_field.label
|
||||||
|
@ -81,8 +81,7 @@ def replace_csv_filters(filterset_class):
|
||||||
"""
|
"""
|
||||||
for name, filter_field in six.iteritems(filterset_class.base_filters):
|
for name, filter_field in six.iteritems(filterset_class.base_filters):
|
||||||
filter_type = filter_field.lookup_expr
|
filter_type = filter_field.lookup_expr
|
||||||
if filter_type == "in":
|
if filter_type in {"in", "contains", "overlap"}:
|
||||||
assert isinstance(filter_field, BaseCSVFilter)
|
|
||||||
filterset_class.base_filters[name] = InFilter(
|
filterset_class.base_filters[name] = InFilter(
|
||||||
field_name=filter_field.field_name,
|
field_name=filter_field.field_name,
|
||||||
lookup_expr=filter_field.lookup_expr,
|
lookup_expr=filter_field.lookup_expr,
|
||||||
|
@ -92,8 +91,7 @@ def replace_csv_filters(filterset_class):
|
||||||
**filter_field.extra
|
**filter_field.extra
|
||||||
)
|
)
|
||||||
|
|
||||||
if filter_type == "range":
|
elif filter_type == "range":
|
||||||
assert isinstance(filter_field, BaseCSVFilter)
|
|
||||||
filterset_class.base_filters[name] = RangeFilter(
|
filterset_class.base_filters[name] = RangeFilter(
|
||||||
field_name=filter_field.field_name,
|
field_name=filter_field.field_name,
|
||||||
lookup_expr=filter_field.lookup_expr,
|
lookup_expr=filter_field.lookup_expr,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user