Fix for v3

This commit is contained in:
Thomas Leonard 2021-02-13 18:27:16 +01:00
parent c5ac23ca78
commit e81256cdc5
11 changed files with 104 additions and 38 deletions

View File

@ -36,7 +36,7 @@ from .utils.str_converters import to_const
class BlankValueField(Field):
def get_resolver(self, parent_resolver):
def wrap_resolve(self, parent_resolver):
resolver = self.resolver or parent_resolver
# create custom resolver

View File

@ -2,12 +2,31 @@ from collections import OrderedDict
from functools import partial
from django.core.exceptions import ValidationError
from graphene.types.enum import EnumType
from graphene.types.argument import to_arguments
from graphene.utils.str_converters import to_snake_case
from ..fields import DjangoConnectionField
from .utils import get_filtering_args_from_filterset, get_filterset_class
def convert_enum(data):
"""
Check if the data is a enum option (or potentially nested list of enum option)
and convert it to its value.
This method is used to pre-process the data for the filters as they can take an
graphene.Enum as argument, but filters (from django_filters) expect a simple value.
"""
if isinstance(data, list):
return [convert_enum(item) for item in data]
if isinstance(type(data), EnumType):
return data.value
else:
return data
class DjangoFilterConnectionField(DjangoConnectionField):
def __init__(
self,
@ -68,7 +87,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
if k in filtering_args:
if k == "order_by" and v is not None:
v = to_snake_case(v)
kwargs[k] = v
kwargs[k] = convert_enum(v)
return kwargs
qs = super(DjangoFilterConnectionField, cls).resolve_queryset(
@ -78,7 +97,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
filterset = filterset_class(
data=filter_kwargs(), queryset=qs, request=info.context
)
if filterset.form.is_valid():
if filterset.is_valid():
return filterset.qs
raise ValidationError(filterset.form.errors.as_json())

View File

@ -28,19 +28,15 @@ else:
STORE = {"events": []}
@pytest.fixture
def Event():
class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))
tag_ids = ArrayField(models.IntegerField())
random_field = ArrayField(models.BooleanField())
return Event
@pytest.fixture
def EventFilterSet(Event):
def EventFilterSet():
class EventFilterSet(FilterSet):
class Meta:
model = Event
@ -69,18 +65,19 @@ def EventFilterSet(Event):
@pytest.fixture
def EventType(Event, EventFilterSet):
def EventType(EventFilterSet):
class EventType(DjangoObjectType):
class Meta:
model = Event
interfaces = (Node,)
fields = "__all__"
filterset_class = EventFilterSet
return EventType
@pytest.fixture
def Query(Event, EventType):
def Query(EventType):
"""
Note that we have to use a custom resolver to replicate the arrayfield filter behavior as
we are running unit tests in sqlite which does not have ArrayFields.

View File

@ -89,19 +89,41 @@ def test_array_field_filter_schema_type(Query):
schema_str = str(schema)
assert (
"""type EventType implements Node {
'''type EventType implements Node {
"""The ID of the object"""
id: ID!
name: String!
tags: [String!]!
tagIds: [Int!]!
randomField: [Boolean!]!
}"""
}'''
in schema_str
)
assert (
"""type Query {
events(offset: Int, before: String, after: String, first: Int, last: Int, name: String, name_Contains: String, tags_Contains: [String!], tags_Overlap: [String!], tags: [String!], tagsIds_Contains: [Int!], tagsIds_Overlap: [Int!], tagsIds: [Int!], randomField_Contains: [Boolean!], randomField_Overlap: [Boolean!], randomField: [Boolean!]): EventTypeConnection
}"""
in schema_str
filters = {
"offset": "Int",
"before": "String",
"after": "String",
"first": "Int",
"last": "Int",
"name": "String",
"name_Contains": "String",
"tags_Contains": "[String!]",
"tags_Overlap": "[String!]",
"tags": "[String!]",
"tagsIds_Contains": "[Int!]",
"tagsIds_Overlap": "[Int!]",
"tagsIds": "[Int!]",
"randomField_Contains": "[Boolean!]",
"randomField_Overlap": "[Boolean!]",
"randomField": "[Boolean!]",
}
filters_str = ", ".join(
[
f"{filter_field}: {gql_type} = null"
for filter_field, gql_type in filters.items()
]
)
assert (
f"type Query {{\n events({filters_str}): EventTypeConnection\n}}" in schema_str
)

View File

@ -25,11 +25,13 @@ def schema():
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class ArticleType(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
fields = "__all__"
filter_fields = {
"lang": ["exact", "in"],
"reporter__a_choice": ["exact", "in"],
@ -122,23 +124,37 @@ def test_filter_enum_field_schema_type(schema):
schema_str = str(schema)
assert (
"""type ArticleType implements Node {
'''type ArticleType implements Node {
"""The ID of the object"""
id: ID!
headline: String!
pubDate: Date!
pubDateTime: DateTime!
reporter: ReporterType!
editor: ReporterType!
lang: ArticleLang!
importance: ArticleImportance
}"""
"""Language"""
lang: TestsArticleLangChoices!
importance: TestsArticleImportanceChoices
}'''
in schema_str
)
assert (
"""type Query {
allReporters(offset: Int, before: String, after: String, first: Int, last: Int): ReporterTypeConnection
allArticles(offset: Int, before: String, after: String, first: Int, last: Int, lang: ArticleLang, lang_In: [ArticleLang], reporter_AChoice: ReporterAChoice, reporter_AChoice_In: [ReporterAChoice]): ArticleTypeConnection
}"""
in schema_str
filters = {
"offset": "Int",
"before": "String",
"after": "String",
"first": "Int",
"last": "Int",
"lang": "TestsArticleLangChoices",
"lang_In": "[TestsArticleLangChoices]",
"reporter_AChoice": "TestsReporterAChoiceChoices",
"reporter_AChoice_In": "[TestsReporterAChoiceChoices]",
}
filters_str = ", ".join(
[
f"{filter_field}: {gql_type} = null"
for filter_field, gql_type in filters.items()
]
)
assert f" allArticles({filters_str}): ArticleTypeConnection\n" in schema_str

View File

@ -739,6 +739,7 @@ def test_order_by():
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(
@ -1236,6 +1237,7 @@ def test_filter_string_contains():
class Meta:
model = Person
interfaces = (Node,)
fields = "__all__"
filter_fields = {"name": ["exact", "in", "contains", "icontains"]}
class Query(ObjectType):

View File

@ -29,6 +29,7 @@ def query():
class Meta:
model = Pet
interfaces = (Node,)
fields = "__all__"
filter_fields = {
"id": ["exact", "in"],
"name": ["exact", "in"],
@ -39,6 +40,7 @@ def query():
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
# choice filter using enum
filter_fields = {"reporter_type": ["exact", "in"]}
@ -46,12 +48,14 @@ def query():
class Meta:
model = Article
interfaces = (Node,)
fields = "__all__"
filterset_class = ArticleFilter
class FilmNode(DjangoObjectType):
class Meta:
model = Film
interfaces = (Node,)
fields = "__all__"
# choice filter not using enum
filter_fields = {
"genre": ["exact", "in"],
@ -77,6 +81,7 @@ def query():
model = Person
interfaces = (Node,)
filterset_class = PersonFilterSet
fields = "__all__"
class Query(ObjectType):
pets = DjangoFilterConnectionField(PetNode)

View File

@ -25,6 +25,7 @@ class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node,)
fields = "__all__"
filter_fields = {
"name": ["exact", "in"],
"age": ["exact", "in", "range"],
@ -101,14 +102,14 @@ def test_range_filter_with_invalid_input():
# Empty list
result = schema.execute(query, variables={"rangeValue": []})
assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']"
assert result.errors[0].message == expected_error
# Only one item in the list
result = schema.execute(query, variables={"rangeValue": [1]})
assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']"
assert result.errors[0].message == expected_error
# More than 2 items in the list
result = schema.execute(query, variables={"rangeValue": [1, 2, 3]})
assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']"
assert result.errors[0].message == expected_error

View File

@ -94,9 +94,7 @@ def get_filtering_args_from_filterset(filterset_class, type):
field_type = graphene.List(field_type.get_type())
args[name] = graphene.Argument(
type=field_type.get_type(),
description=filter_field.label,
required=required,
field_type.get_type(), description=filter_field.label, required=required,
)
return args

View File

@ -1253,6 +1253,7 @@ class TestBackwardPagination:
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1457,6 +1458,7 @@ def test_connection_should_enable_offset_filtering():
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1496,6 +1498,7 @@ def test_connection_should_enable_offset_filtering_higher_than_max_limit(
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1529,6 +1532,7 @@ def test_connection_should_forbid_offset_filtering_with_before():
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1563,6 +1567,7 @@ def test_connection_should_allow_offset_filtering_with_after():
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)

View File

@ -671,6 +671,7 @@ def test_django_objecttype_name_connection_propagation():
class Meta:
model = ReporterModel
name = "CustomReporterName"
fields = "__all__"
filter_fields = ["email"]
interfaces = (Node,)