diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index bdccd4a..c6425b8 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -1,5 +1,9 @@ +import inspect + +from collections import OrderedDict from functools import partial +from graphene.types.argument import to_arguments from ..fields import DjangoConnectionField from graphene.relay import is_node from .utils import get_filtering_args_from_filterset, get_filterset_class @@ -7,28 +11,66 @@ from .utils import get_filtering_args_from_filterset, get_filterset_class class DjangoFilterConnectionField(DjangoConnectionField): - def __init__(self, type, fields=None, extra_filter_meta=None, - filterset_class=None, *args, **kwargs): + def __init__(self, type, fields=None, order_by=None, + extra_filter_meta=None, filterset_class=None, + *args, **kwargs): + self._fields = fields + self._type = type + self._filterset_class = filterset_class + self._extra_filter_meta = extra_filter_meta + self._base_args = None + super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs) - if is_node(type): - _fields = type._meta.filter_fields - _model = type._meta.model + @property + def node_type(self): + if inspect.isfunction(self._type) or inspect.ismethod(self._type): + return self._type() + return self._type + + @property + def meta(self): + if is_node(self.node_type): + _model = self.node_type._meta.model else: # ConnectionFields can also be passed Connections, # in which case, we need to use the Node of the connection # to get our relevant args. - _fields = type._meta.node._meta.filter_fields - _model = type._meta.node._meta.model + _model = self.node_type._meta.node._meta.model - self.fields = fields or _fields - meta = dict(model=_model, fields=self.fields) - if extra_filter_meta: - meta.update(extra_filter_meta) - self.filterset_class = get_filterset_class(filterset_class, **meta) - self.filtering_args = get_filtering_args_from_filterset(self.filterset_class, type) - kwargs.setdefault('args', {}) - kwargs['args'].update(self.filtering_args) - super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs) + meta = dict(model=_model, + fields=self.fields) + if self._extra_filter_meta: + meta.update(self._extra_filter_meta) + return meta + + @property + def fields(self): + if self._fields: + return self._fields + + if is_node(self.node_type): + return self.node_type._meta.filter_fields + else: + # ConnectionFields can also be passed Connections, + # in which case, we need to use the Node of the connection + # to get our relevant args. + return self.node_type._meta.node._meta.filter_fields + + @property + def args(self): + return to_arguments(self._base_args or OrderedDict(), self.filtering_args) + + @args.setter + def args(self, args): + self._base_args = args + + @property + def filterset_class(self): + return get_filterset_class(self._filterset_class, **self.meta) + + @property + def filtering_args(self): + return get_filtering_args_from_filterset(self.filterset_class, self.node_type) @staticmethod def connection_resolver(resolver, connection, default_manager, filterset_class, filtering_args, diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index 63c9e37..c95e2d7 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -348,3 +348,20 @@ def test_filter_filterset_related_results(): assert not result.errors # We should only get two reporters assert len(result.data['allReporters']['edges']) == 2 + + +def test_recursive_filter_connection(): + class ReporterFilterNode(DjangoObjectType): + child_reporters = DjangoFilterConnectionField(lambda: ReporterFilterNode) + + def resolve_child_reporters(self, args, context, info): + return [] + + class Meta: + model = Reporter + interfaces = (Node, ) + + class Query(ObjectType): + all_reporters = DjangoFilterConnectionField(ReporterFilterNode) + + assert ReporterFilterNode._meta.fields['child_reporters'].node_type == ReporterFilterNode