diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index 1ffa0f4..462fcf8 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -877,7 +877,7 @@ def test_filter_filterset_based_on_mixin(): filters = super(FilterSet, cls).get_filters() filters.update( { - "viewer__email__in": django_filters.CharFilter( + "reporter__email__in": django_filters.CharFilter( method="filter_email_in", field_name="reporter__email__in" ) } @@ -897,16 +897,11 @@ def test_filter_filterset_based_on_mixin(): interfaces = (Node,) class NewArticleFilterNode(DjangoObjectType): - viewer = Field(NewReporterNode) - class Meta: model = Article interfaces = (Node,) filterset_class = NewArticleFilter - def resolve_viewer(self, info): - return self.reporter - class Query(ObjectType): all_articles = DjangoFilterConnectionField(NewArticleFilterNode) @@ -939,11 +934,11 @@ def test_filter_filterset_based_on_mixin(): query = ( """ query NodeFilteringQuery { - allArticles(viewer_Email_In: "%s") { + allArticles(reporter_Email_In: "%s") { edges { node { headline - viewer { + reporter { email } } @@ -960,7 +955,7 @@ def test_filter_filterset_based_on_mixin(): { "node": { "headline": article_1.headline, - "viewer": {"email": reporter_1.email}, + "reporter": {"email": reporter_1.email}, } } ] diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py index abb03a9..2ec6b40 100644 --- a/graphene_django/filter/utils.py +++ b/graphene_django/filter/utils.py @@ -1,8 +1,36 @@ import six +from django.db.models.constants import LOOKUP_SEP +from django.core.exceptions import FieldDoesNotExist +from django.db.models.fields.related import ForeignObjectRel, RelatedField from .filterset import custom_filterset_factory, setup_filterset +def get_field_parts_with_expression(model, query_expr): + """ + Traverses the model with a given query expression, + returns the found fields along the path and the remaining expression + """ + parts = query_expr.split(LOOKUP_SEP) + opts = model._meta + fields = [] + + # walk relationships + for i, name in enumerate(parts): + try: + field = opts.get_field(name) + except FieldDoesNotExist: + return fields, LOOKUP_SEP.join(parts[i:]) + + fields.append(field) + if isinstance(field, RelatedField): + opts = field.remote_field.model._meta + elif isinstance(field, ForeignObjectRel): + opts = field.related_model._meta + + return fields, "exact" + + def get_filtering_args_from_filterset(filterset_class, type): """ Inspect a FilterSet and produce the arguments to pass to a Graphene Field. These arguments will be available to @@ -18,22 +46,15 @@ def get_filtering_args_from_filterset(filterset_class, type): if name in filterset_class.declared_filters: form_field = filter_field.field else: - try: - field_name, filter_type = name.rsplit("__", 1) - except ValueError: - field_name = name - filter_type = None + fields, lookup_expr = get_field_parts_with_expression(model, name) + assert fields, str((model, name, filterset_class)) + model_field = fields[-1] + filter_type = lookup_expr - # If the filter type is `isnull` then use the filter provided by - # DjangoFilter (a BooleanFilter). - # Otherwise try and get a filter based on the actual model field - if filter_type != "isnull" and hasattr(model, field_name): - model_field = model._meta.get_field(field_name) - - if hasattr(model_field, "formfield"): - form_field = model_field.formfield( - required=filter_field.extra.get("required", False) - ) + if filter_type != "isnull" and hasattr(model_field, "formfield"): + form_field = model_field.formfield( + required=filter_field.extra.get("required", False) + ) # Fallback to field defined on filter if we can't get it from the # model field