diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index defcfc1..bdccd4a 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -1,6 +1,7 @@ from functools import partial from ..fields import DjangoConnectionField +from graphene.relay import is_node from .utils import get_filtering_args_from_filterset, get_filterset_class @@ -9,9 +10,18 @@ class DjangoFilterConnectionField(DjangoConnectionField): def __init__(self, type, fields=None, extra_filter_meta=None, filterset_class=None, *args, **kwargs): - self.fields = fields or type._meta.filter_fields - meta = dict(model=type._meta.model, - fields=self.fields) + if is_node(type): + _fields = type._meta.filter_fields + _model = type._meta.model + else: + # ConnectionFields can also be passed Connections, + # in which case, we need to use the Node of the connection + # to get our relevant args. + _fields = type._meta.node._meta.filter_fields + _model = type._meta.node._meta.model + + self.fields = fields or _fields + meta = dict(model=_model, fields=self.fields) if extra_filter_meta: meta.update(extra_filter_meta) self.filterset_class = get_filterset_class(filterset_class, **meta)