diff --git a/graphene_django/converter.py b/graphene_django/converter.py index 1da68d5..024f6bd 100644 --- a/graphene_django/converter.py +++ b/graphene_django/converter.py @@ -36,7 +36,7 @@ from .utils.str_converters import to_const class BlankValueField(Field): - def get_resolver(self, parent_resolver): + def wrap_resolve(self, parent_resolver): resolver = self.resolver or parent_resolver # create custom resolver diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index cf3f71a..c6dd50e 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -2,12 +2,31 @@ from collections import OrderedDict from functools import partial from django.core.exceptions import ValidationError + +from graphene.types.enum import EnumType from graphene.types.argument import to_arguments from graphene.utils.str_converters import to_snake_case + from ..fields import DjangoConnectionField from .utils import get_filtering_args_from_filterset, get_filterset_class +def convert_enum(data): + """ + Check if the data is a enum option (or potentially nested list of enum option) + and convert it to its value. + + This method is used to pre-process the data for the filters as they can take an + graphene.Enum as argument, but filters (from django_filters) expect a simple value. + """ + if isinstance(data, list): + return [convert_enum(item) for item in data] + if isinstance(type(data), EnumType): + return data.value + else: + return data + + class DjangoFilterConnectionField(DjangoConnectionField): def __init__( self, @@ -68,7 +87,7 @@ class DjangoFilterConnectionField(DjangoConnectionField): if k in filtering_args: if k == "order_by" and v is not None: v = to_snake_case(v) - kwargs[k] = v + kwargs[k] = convert_enum(v) return kwargs qs = super(DjangoFilterConnectionField, cls).resolve_queryset( @@ -78,7 +97,7 @@ class DjangoFilterConnectionField(DjangoConnectionField): filterset = filterset_class( data=filter_kwargs(), queryset=qs, request=info.context ) - if filterset.form.is_valid(): + if filterset.is_valid(): return filterset.qs raise ValidationError(filterset.form.errors.as_json()) diff --git a/graphene_django/filter/tests/conftest.py b/graphene_django/filter/tests/conftest.py index 710234f..57924af 100644 --- a/graphene_django/filter/tests/conftest.py +++ b/graphene_django/filter/tests/conftest.py @@ -28,19 +28,15 @@ else: STORE = {"events": []} -@pytest.fixture -def Event(): - class Event(models.Model): - name = models.CharField(max_length=50) - tags = ArrayField(models.CharField(max_length=50)) - tag_ids = ArrayField(models.IntegerField()) - random_field = ArrayField(models.BooleanField()) - - return Event +class Event(models.Model): + name = models.CharField(max_length=50) + tags = ArrayField(models.CharField(max_length=50)) + tag_ids = ArrayField(models.IntegerField()) + random_field = ArrayField(models.BooleanField()) @pytest.fixture -def EventFilterSet(Event): +def EventFilterSet(): class EventFilterSet(FilterSet): class Meta: model = Event @@ -69,18 +65,19 @@ def EventFilterSet(Event): @pytest.fixture -def EventType(Event, EventFilterSet): +def EventType(EventFilterSet): class EventType(DjangoObjectType): class Meta: model = Event interfaces = (Node,) + fields = "__all__" filterset_class = EventFilterSet return EventType @pytest.fixture -def Query(Event, EventType): +def Query(EventType): """ Note that we have to use a custom resolver to replicate the arrayfield filter behavior as we are running unit tests in sqlite which does not have ArrayFields. diff --git a/graphene_django/filter/tests/test_array_field_exact_filter.py b/graphene_django/filter/tests/test_array_field_exact_filter.py index 814fd33..b07abed 100644 --- a/graphene_django/filter/tests/test_array_field_exact_filter.py +++ b/graphene_django/filter/tests/test_array_field_exact_filter.py @@ -89,19 +89,41 @@ def test_array_field_filter_schema_type(Query): schema_str = str(schema) assert ( - """type EventType implements Node { + '''type EventType implements Node { + """The ID of the object""" id: ID! name: String! tags: [String!]! tagIds: [Int!]! randomField: [Boolean!]! -}""" +}''' in schema_str ) - assert ( - """type Query { - events(offset: Int, before: String, after: String, first: Int, last: Int, name: String, name_Contains: String, tags_Contains: [String!], tags_Overlap: [String!], tags: [String!], tagsIds_Contains: [Int!], tagsIds_Overlap: [Int!], tagsIds: [Int!], randomField_Contains: [Boolean!], randomField_Overlap: [Boolean!], randomField: [Boolean!]): EventTypeConnection -}""" - in schema_str + filters = { + "offset": "Int", + "before": "String", + "after": "String", + "first": "Int", + "last": "Int", + "name": "String", + "name_Contains": "String", + "tags_Contains": "[String!]", + "tags_Overlap": "[String!]", + "tags": "[String!]", + "tagsIds_Contains": "[Int!]", + "tagsIds_Overlap": "[Int!]", + "tagsIds": "[Int!]", + "randomField_Contains": "[Boolean!]", + "randomField_Overlap": "[Boolean!]", + "randomField": "[Boolean!]", + } + filters_str = ", ".join( + [ + f"{filter_field}: {gql_type} = null" + for filter_field, gql_type in filters.items() + ] + ) + assert ( + f"type Query {{\n events({filters_str}): EventTypeConnection\n}}" in schema_str ) diff --git a/graphene_django/filter/tests/test_enum_filtering.py b/graphene_django/filter/tests/test_enum_filtering.py index 650c55e..09c69b3 100644 --- a/graphene_django/filter/tests/test_enum_filtering.py +++ b/graphene_django/filter/tests/test_enum_filtering.py @@ -25,11 +25,13 @@ def schema(): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class ArticleType(DjangoObjectType): class Meta: model = Article interfaces = (Node,) + fields = "__all__" filter_fields = { "lang": ["exact", "in"], "reporter__a_choice": ["exact", "in"], @@ -122,23 +124,37 @@ def test_filter_enum_field_schema_type(schema): schema_str = str(schema) assert ( - """type ArticleType implements Node { + '''type ArticleType implements Node { + """The ID of the object""" id: ID! headline: String! pubDate: Date! pubDateTime: DateTime! reporter: ReporterType! editor: ReporterType! - lang: ArticleLang! - importance: ArticleImportance -}""" + + """Language""" + lang: TestsArticleLangChoices! + importance: TestsArticleImportanceChoices +}''' in schema_str ) - assert ( - """type Query { - allReporters(offset: Int, before: String, after: String, first: Int, last: Int): ReporterTypeConnection - allArticles(offset: Int, before: String, after: String, first: Int, last: Int, lang: ArticleLang, lang_In: [ArticleLang], reporter_AChoice: ReporterAChoice, reporter_AChoice_In: [ReporterAChoice]): ArticleTypeConnection -}""" - in schema_str + filters = { + "offset": "Int", + "before": "String", + "after": "String", + "first": "Int", + "last": "Int", + "lang": "TestsArticleLangChoices", + "lang_In": "[TestsArticleLangChoices]", + "reporter_AChoice": "TestsReporterAChoiceChoices", + "reporter_AChoice_In": "[TestsReporterAChoiceChoices]", + } + filters_str = ", ".join( + [ + f"{filter_field}: {gql_type} = null" + for filter_field, gql_type in filters.items() + ] ) + assert f" allArticles({filters_str}): ArticleTypeConnection\n" in schema_str diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index 461f29f..274f6ac 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -739,6 +739,7 @@ def test_order_by(): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(ObjectType): all_reporters = DjangoFilterConnectionField( @@ -1236,6 +1237,7 @@ def test_filter_string_contains(): class Meta: model = Person interfaces = (Node,) + fields = "__all__" filter_fields = {"name": ["exact", "in", "contains", "icontains"]} class Query(ObjectType): diff --git a/graphene_django/filter/tests/test_in_filter.py b/graphene_django/filter/tests/test_in_filter.py index f0015b6..7ad0286 100644 --- a/graphene_django/filter/tests/test_in_filter.py +++ b/graphene_django/filter/tests/test_in_filter.py @@ -29,6 +29,7 @@ def query(): class Meta: model = Pet interfaces = (Node,) + fields = "__all__" filter_fields = { "id": ["exact", "in"], "name": ["exact", "in"], @@ -39,6 +40,7 @@ def query(): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" # choice filter using enum filter_fields = {"reporter_type": ["exact", "in"]} @@ -46,12 +48,14 @@ def query(): class Meta: model = Article interfaces = (Node,) + fields = "__all__" filterset_class = ArticleFilter class FilmNode(DjangoObjectType): class Meta: model = Film interfaces = (Node,) + fields = "__all__" # choice filter not using enum filter_fields = { "genre": ["exact", "in"], @@ -77,6 +81,7 @@ def query(): model = Person interfaces = (Node,) filterset_class = PersonFilterSet + fields = "__all__" class Query(ObjectType): pets = DjangoFilterConnectionField(PetNode) diff --git a/graphene_django/filter/tests/test_range_filter.py b/graphene_django/filter/tests/test_range_filter.py index 995f588..6227a70 100644 --- a/graphene_django/filter/tests/test_range_filter.py +++ b/graphene_django/filter/tests/test_range_filter.py @@ -25,6 +25,7 @@ class PetNode(DjangoObjectType): class Meta: model = Pet interfaces = (Node,) + fields = "__all__" filter_fields = { "name": ["exact", "in"], "age": ["exact", "in", "range"], @@ -101,14 +102,14 @@ def test_range_filter_with_invalid_input(): # Empty list result = schema.execute(query, variables={"rangeValue": []}) assert len(result.errors) == 1 - assert result.errors[0].message == f"['{expected_error}']" + assert result.errors[0].message == expected_error # Only one item in the list result = schema.execute(query, variables={"rangeValue": [1]}) assert len(result.errors) == 1 - assert result.errors[0].message == f"['{expected_error}']" + assert result.errors[0].message == expected_error # More than 2 items in the list result = schema.execute(query, variables={"rangeValue": [1, 2, 3]}) assert len(result.errors) == 1 - assert result.errors[0].message == f"['{expected_error}']" + assert result.errors[0].message == expected_error diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py index 6df8859..d4fc1bf 100644 --- a/graphene_django/filter/utils.py +++ b/graphene_django/filter/utils.py @@ -94,9 +94,7 @@ def get_filtering_args_from_filterset(filterset_class, type): field_type = graphene.List(field_type.get_type()) args[name] = graphene.Argument( - type=field_type.get_type(), - description=filter_field.label, - required=required, + field_type.get_type(), description=filter_field.label, required=required, ) return args diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index e01c098..aabe19c 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -1253,6 +1253,7 @@ class TestBackwardPagination: class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -1457,6 +1458,7 @@ def test_connection_should_enable_offset_filtering(): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -1496,6 +1498,7 @@ def test_connection_should_enable_offset_filtering_higher_than_max_limit( class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -1529,6 +1532,7 @@ def test_connection_should_forbid_offset_filtering_with_before(): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -1563,6 +1567,7 @@ def test_connection_should_allow_offset_filtering_with_after(): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) diff --git a/graphene_django/tests/test_types.py b/graphene_django/tests/test_types.py index cb653e1..bde72c7 100644 --- a/graphene_django/tests/test_types.py +++ b/graphene_django/tests/test_types.py @@ -671,6 +671,7 @@ def test_django_objecttype_name_connection_propagation(): class Meta: model = ReporterModel name = "CustomReporterName" + fields = "__all__" filter_fields = ["email"] interfaces = (Node,)