Fix for v3

This commit is contained in:
Thomas Leonard 2021-02-13 18:27:16 +01:00
parent c5ac23ca78
commit e81256cdc5
11 changed files with 104 additions and 38 deletions

View File

@ -36,7 +36,7 @@ from .utils.str_converters import to_const
class BlankValueField(Field): class BlankValueField(Field):
def get_resolver(self, parent_resolver): def wrap_resolve(self, parent_resolver):
resolver = self.resolver or parent_resolver resolver = self.resolver or parent_resolver
# create custom resolver # create custom resolver

View File

@ -2,12 +2,31 @@ from collections import OrderedDict
from functools import partial from functools import partial
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from graphene.types.enum import EnumType
from graphene.types.argument import to_arguments from graphene.types.argument import to_arguments
from graphene.utils.str_converters import to_snake_case from graphene.utils.str_converters import to_snake_case
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
from .utils import get_filtering_args_from_filterset, get_filterset_class from .utils import get_filtering_args_from_filterset, get_filterset_class
def convert_enum(data):
"""
Check if the data is a enum option (or potentially nested list of enum option)
and convert it to its value.
This method is used to pre-process the data for the filters as they can take an
graphene.Enum as argument, but filters (from django_filters) expect a simple value.
"""
if isinstance(data, list):
return [convert_enum(item) for item in data]
if isinstance(type(data), EnumType):
return data.value
else:
return data
class DjangoFilterConnectionField(DjangoConnectionField): class DjangoFilterConnectionField(DjangoConnectionField):
def __init__( def __init__(
self, self,
@ -68,7 +87,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
if k in filtering_args: if k in filtering_args:
if k == "order_by" and v is not None: if k == "order_by" and v is not None:
v = to_snake_case(v) v = to_snake_case(v)
kwargs[k] = v kwargs[k] = convert_enum(v)
return kwargs return kwargs
qs = super(DjangoFilterConnectionField, cls).resolve_queryset( qs = super(DjangoFilterConnectionField, cls).resolve_queryset(
@ -78,7 +97,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
filterset = filterset_class( filterset = filterset_class(
data=filter_kwargs(), queryset=qs, request=info.context data=filter_kwargs(), queryset=qs, request=info.context
) )
if filterset.form.is_valid(): if filterset.is_valid():
return filterset.qs return filterset.qs
raise ValidationError(filterset.form.errors.as_json()) raise ValidationError(filterset.form.errors.as_json())

View File

@ -28,19 +28,15 @@ else:
STORE = {"events": []} STORE = {"events": []}
@pytest.fixture class Event(models.Model):
def Event(): name = models.CharField(max_length=50)
class Event(models.Model): tags = ArrayField(models.CharField(max_length=50))
name = models.CharField(max_length=50) tag_ids = ArrayField(models.IntegerField())
tags = ArrayField(models.CharField(max_length=50)) random_field = ArrayField(models.BooleanField())
tag_ids = ArrayField(models.IntegerField())
random_field = ArrayField(models.BooleanField())
return Event
@pytest.fixture @pytest.fixture
def EventFilterSet(Event): def EventFilterSet():
class EventFilterSet(FilterSet): class EventFilterSet(FilterSet):
class Meta: class Meta:
model = Event model = Event
@ -69,18 +65,19 @@ def EventFilterSet(Event):
@pytest.fixture @pytest.fixture
def EventType(Event, EventFilterSet): def EventType(EventFilterSet):
class EventType(DjangoObjectType): class EventType(DjangoObjectType):
class Meta: class Meta:
model = Event model = Event
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
filterset_class = EventFilterSet filterset_class = EventFilterSet
return EventType return EventType
@pytest.fixture @pytest.fixture
def Query(Event, EventType): def Query(EventType):
""" """
Note that we have to use a custom resolver to replicate the arrayfield filter behavior as Note that we have to use a custom resolver to replicate the arrayfield filter behavior as
we are running unit tests in sqlite which does not have ArrayFields. we are running unit tests in sqlite which does not have ArrayFields.

View File

@ -89,19 +89,41 @@ def test_array_field_filter_schema_type(Query):
schema_str = str(schema) schema_str = str(schema)
assert ( assert (
"""type EventType implements Node { '''type EventType implements Node {
"""The ID of the object"""
id: ID! id: ID!
name: String! name: String!
tags: [String!]! tags: [String!]!
tagIds: [Int!]! tagIds: [Int!]!
randomField: [Boolean!]! randomField: [Boolean!]!
}""" }'''
in schema_str in schema_str
) )
assert ( filters = {
"""type Query { "offset": "Int",
events(offset: Int, before: String, after: String, first: Int, last: Int, name: String, name_Contains: String, tags_Contains: [String!], tags_Overlap: [String!], tags: [String!], tagsIds_Contains: [Int!], tagsIds_Overlap: [Int!], tagsIds: [Int!], randomField_Contains: [Boolean!], randomField_Overlap: [Boolean!], randomField: [Boolean!]): EventTypeConnection "before": "String",
}""" "after": "String",
in schema_str "first": "Int",
"last": "Int",
"name": "String",
"name_Contains": "String",
"tags_Contains": "[String!]",
"tags_Overlap": "[String!]",
"tags": "[String!]",
"tagsIds_Contains": "[Int!]",
"tagsIds_Overlap": "[Int!]",
"tagsIds": "[Int!]",
"randomField_Contains": "[Boolean!]",
"randomField_Overlap": "[Boolean!]",
"randomField": "[Boolean!]",
}
filters_str = ", ".join(
[
f"{filter_field}: {gql_type} = null"
for filter_field, gql_type in filters.items()
]
)
assert (
f"type Query {{\n events({filters_str}): EventTypeConnection\n}}" in schema_str
) )

View File

@ -25,11 +25,13 @@ def schema():
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
class ArticleType(DjangoObjectType): class ArticleType(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
filter_fields = { filter_fields = {
"lang": ["exact", "in"], "lang": ["exact", "in"],
"reporter__a_choice": ["exact", "in"], "reporter__a_choice": ["exact", "in"],
@ -122,23 +124,37 @@ def test_filter_enum_field_schema_type(schema):
schema_str = str(schema) schema_str = str(schema)
assert ( assert (
"""type ArticleType implements Node { '''type ArticleType implements Node {
"""The ID of the object"""
id: ID! id: ID!
headline: String! headline: String!
pubDate: Date! pubDate: Date!
pubDateTime: DateTime! pubDateTime: DateTime!
reporter: ReporterType! reporter: ReporterType!
editor: ReporterType! editor: ReporterType!
lang: ArticleLang!
importance: ArticleImportance """Language"""
}""" lang: TestsArticleLangChoices!
importance: TestsArticleImportanceChoices
}'''
in schema_str in schema_str
) )
assert ( filters = {
"""type Query { "offset": "Int",
allReporters(offset: Int, before: String, after: String, first: Int, last: Int): ReporterTypeConnection "before": "String",
allArticles(offset: Int, before: String, after: String, first: Int, last: Int, lang: ArticleLang, lang_In: [ArticleLang], reporter_AChoice: ReporterAChoice, reporter_AChoice_In: [ReporterAChoice]): ArticleTypeConnection "after": "String",
}""" "first": "Int",
in schema_str "last": "Int",
"lang": "TestsArticleLangChoices",
"lang_In": "[TestsArticleLangChoices]",
"reporter_AChoice": "TestsReporterAChoiceChoices",
"reporter_AChoice_In": "[TestsReporterAChoiceChoices]",
}
filters_str = ", ".join(
[
f"{filter_field}: {gql_type} = null"
for filter_field, gql_type in filters.items()
]
) )
assert f" allArticles({filters_str}): ArticleTypeConnection\n" in schema_str

View File

@ -739,6 +739,7 @@ def test_order_by():
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField( all_reporters = DjangoFilterConnectionField(
@ -1236,6 +1237,7 @@ def test_filter_string_contains():
class Meta: class Meta:
model = Person model = Person
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
filter_fields = {"name": ["exact", "in", "contains", "icontains"]} filter_fields = {"name": ["exact", "in", "contains", "icontains"]}
class Query(ObjectType): class Query(ObjectType):

View File

@ -29,6 +29,7 @@ def query():
class Meta: class Meta:
model = Pet model = Pet
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
filter_fields = { filter_fields = {
"id": ["exact", "in"], "id": ["exact", "in"],
"name": ["exact", "in"], "name": ["exact", "in"],
@ -39,6 +40,7 @@ def query():
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
# choice filter using enum # choice filter using enum
filter_fields = {"reporter_type": ["exact", "in"]} filter_fields = {"reporter_type": ["exact", "in"]}
@ -46,12 +48,14 @@ def query():
class Meta: class Meta:
model = Article model = Article
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
filterset_class = ArticleFilter filterset_class = ArticleFilter
class FilmNode(DjangoObjectType): class FilmNode(DjangoObjectType):
class Meta: class Meta:
model = Film model = Film
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
# choice filter not using enum # choice filter not using enum
filter_fields = { filter_fields = {
"genre": ["exact", "in"], "genre": ["exact", "in"],
@ -77,6 +81,7 @@ def query():
model = Person model = Person
interfaces = (Node,) interfaces = (Node,)
filterset_class = PersonFilterSet filterset_class = PersonFilterSet
fields = "__all__"
class Query(ObjectType): class Query(ObjectType):
pets = DjangoFilterConnectionField(PetNode) pets = DjangoFilterConnectionField(PetNode)

View File

@ -25,6 +25,7 @@ class PetNode(DjangoObjectType):
class Meta: class Meta:
model = Pet model = Pet
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
filter_fields = { filter_fields = {
"name": ["exact", "in"], "name": ["exact", "in"],
"age": ["exact", "in", "range"], "age": ["exact", "in", "range"],
@ -101,14 +102,14 @@ def test_range_filter_with_invalid_input():
# Empty list # Empty list
result = schema.execute(query, variables={"rangeValue": []}) result = schema.execute(query, variables={"rangeValue": []})
assert len(result.errors) == 1 assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']" assert result.errors[0].message == expected_error
# Only one item in the list # Only one item in the list
result = schema.execute(query, variables={"rangeValue": [1]}) result = schema.execute(query, variables={"rangeValue": [1]})
assert len(result.errors) == 1 assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']" assert result.errors[0].message == expected_error
# More than 2 items in the list # More than 2 items in the list
result = schema.execute(query, variables={"rangeValue": [1, 2, 3]}) result = schema.execute(query, variables={"rangeValue": [1, 2, 3]})
assert len(result.errors) == 1 assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']" assert result.errors[0].message == expected_error

View File

@ -94,9 +94,7 @@ def get_filtering_args_from_filterset(filterset_class, type):
field_type = graphene.List(field_type.get_type()) field_type = graphene.List(field_type.get_type())
args[name] = graphene.Argument( args[name] = graphene.Argument(
type=field_type.get_type(), field_type.get_type(), description=filter_field.label, required=required,
description=filter_field.label,
required=required,
) )
return args return args

View File

@ -1253,6 +1253,7 @@ class TestBackwardPagination:
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
@ -1457,6 +1458,7 @@ def test_connection_should_enable_offset_filtering():
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
@ -1496,6 +1498,7 @@ def test_connection_should_enable_offset_filtering_higher_than_max_limit(
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
@ -1529,6 +1532,7 @@ def test_connection_should_forbid_offset_filtering_with_before():
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
@ -1563,6 +1567,7 @@ def test_connection_should_allow_offset_filtering_with_after():
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)

View File

@ -671,6 +671,7 @@ def test_django_objecttype_name_connection_propagation():
class Meta: class Meta:
model = ReporterModel model = ReporterModel
name = "CustomReporterName" name = "CustomReporterName"
fields = "__all__"
filter_fields = ["email"] filter_fields = ["email"]
interfaces = (Node,) interfaces = (Node,)