From 0425985dab3e468c5e9e2edde4e747888e554f50 Mon Sep 17 00:00:00 2001 From: Kike Isidoro Date: Fri, 2 Aug 2019 16:31:28 +0200 Subject: [PATCH] Check for filters defined on base filterset classes --- graphene_django/filter/tests/test_fields.py | 83 +++++++++++++++++++++ graphene_django/filter/utils.py | 22 +++--- 2 files changed, 96 insertions(+), 9 deletions(-) diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index 99876b6..026f909 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -818,3 +818,86 @@ def test_integer_field_filter_type(): } """ ) + + +def test_filter_filterset_based_on_mixin(): + class ArticleFilterMixin: + + @classmethod + def get_filters(cls): + filters = super().get_filters() + filters.update({ + 'viewer__email__in': django_filters.CharFilter( + method='filter_email_in', + field_name='reporter__email__in', + ), + }) + + return filters + + class NewArticleFilter(ArticleFilterMixin, ArticleFilter): + pass + + class NewReporterNode(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + class NewArticleFilterNode(DjangoObjectType): + viewer = Field(NewReporterNode) + + class Meta: + model = Article + interfaces = (Node,) + filterset_class = NewArticleFilter + + def resolve_viewer(self, info): + return self.reporter + + class Query(ObjectType): + all_articles = DjangoFilterConnectionField(NewArticleFilterNode) + + reporter = Reporter.objects.create( + first_name="John", last_name="Doe", email="john@doe.com") + + article = Article.objects.create( + headline="Hello", + reporter=reporter, + editor=reporter, + pub_date=datetime.now(), + pub_date_time=datetime.now()) + + schema = Schema(query=Query) + + query = """ + query NodeFilteringQuery { + allArticles { + edges { + node { + viewer { + email + } + } + } + } + } + """ + + expected = { + "allArticles": { + "edges": [ + { + "node": { + "viewer": { + "email": reporter.email, + } + } + } + ] + } + } + + result = schema.execute(query) + + assert not result.errors + assert result.data == expected diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py index 00030a0..81efb63 100644 --- a/graphene_django/filter/utils.py +++ b/graphene_django/filter/utils.py @@ -13,21 +13,25 @@ def get_filtering_args_from_filterset(filterset_class, type): args = {} model = filterset_class._meta.model for name, filter_field in six.iteritems(filterset_class.base_filters): + form_field = None + if name in filterset_class.declared_filters: form_field = filter_field.field else: field_name = name.split("__", 1)[0] - model_field = model._meta.get_field(field_name) - if hasattr(model_field, "formfield"): - form_field = model_field.formfield( - required=filter_field.extra.get("required", False) - ) + if hasattr(model, field_name): + model_field = model._meta.get_field(field_name) - # Fallback to field defined on filter if we can't get it from the - # model field - if not form_field: - form_field = filter_field.field + if hasattr(model_field, "formfield"): + form_field = model_field.formfield( + required=filter_field.extra.get("required", False) + ) + + # Fallback to field defined on filter if we can't get it from the + # model field + if not form_field: + form_field = filter_field.field field_type = convert_form_field(form_field).Argument() field_type.description = filter_field.label