Merge pull request #58 from tangerilli/recursive-nodes

Allow recursive connections with DjangoFilterConnectionField
This commit is contained in:
Syrus Akbary 2017-03-02 17:26:51 -08:00 committed by GitHub
commit acff3d59db
2 changed files with 75 additions and 16 deletions

View File

@ -1,5 +1,9 @@
import inspect
from collections import OrderedDict
from functools import partial from functools import partial
from graphene.types.argument import to_arguments
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
from graphene.relay import is_node from graphene.relay import is_node
from .utils import get_filtering_args_from_filterset, get_filterset_class from .utils import get_filtering_args_from_filterset, get_filterset_class
@ -7,28 +11,66 @@ from .utils import get_filtering_args_from_filterset, get_filterset_class
class DjangoFilterConnectionField(DjangoConnectionField): class DjangoFilterConnectionField(DjangoConnectionField):
def __init__(self, type, fields=None, extra_filter_meta=None, def __init__(self, type, fields=None, order_by=None,
filterset_class=None, *args, **kwargs): extra_filter_meta=None, filterset_class=None,
*args, **kwargs):
self._fields = fields
self._type = type
self._filterset_class = filterset_class
self._extra_filter_meta = extra_filter_meta
self._base_args = None
super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs)
if is_node(type): @property
_fields = type._meta.filter_fields def node_type(self):
_model = type._meta.model if inspect.isfunction(self._type) or inspect.ismethod(self._type):
return self._type()
return self._type
@property
def meta(self):
if is_node(self.node_type):
_model = self.node_type._meta.model
else: else:
# ConnectionFields can also be passed Connections, # ConnectionFields can also be passed Connections,
# in which case, we need to use the Node of the connection # in which case, we need to use the Node of the connection
# to get our relevant args. # to get our relevant args.
_fields = type._meta.node._meta.filter_fields _model = self.node_type._meta.node._meta.model
_model = type._meta.node._meta.model
self.fields = fields or _fields meta = dict(model=_model,
meta = dict(model=_model, fields=self.fields) fields=self.fields)
if extra_filter_meta: if self._extra_filter_meta:
meta.update(extra_filter_meta) meta.update(self._extra_filter_meta)
self.filterset_class = get_filterset_class(filterset_class, **meta) return meta
self.filtering_args = get_filtering_args_from_filterset(self.filterset_class, type)
kwargs.setdefault('args', {}) @property
kwargs['args'].update(self.filtering_args) def fields(self):
super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs) if self._fields:
return self._fields
if is_node(self.node_type):
return self.node_type._meta.filter_fields
else:
# ConnectionFields can also be passed Connections,
# in which case, we need to use the Node of the connection
# to get our relevant args.
return self.node_type._meta.node._meta.filter_fields
@property
def args(self):
return to_arguments(self._base_args or OrderedDict(), self.filtering_args)
@args.setter
def args(self, args):
self._base_args = args
@property
def filterset_class(self):
return get_filterset_class(self._filterset_class, **self.meta)
@property
def filtering_args(self):
return get_filtering_args_from_filterset(self.filterset_class, self.node_type)
@staticmethod @staticmethod
def connection_resolver(resolver, connection, default_manager, filterset_class, filtering_args, def connection_resolver(resolver, connection, default_manager, filterset_class, filtering_args,

View File

@ -348,3 +348,20 @@ def test_filter_filterset_related_results():
assert not result.errors assert not result.errors
# We should only get two reporters # We should only get two reporters
assert len(result.data['allReporters']['edges']) == 2 assert len(result.data['allReporters']['edges']) == 2
def test_recursive_filter_connection():
class ReporterFilterNode(DjangoObjectType):
child_reporters = DjangoFilterConnectionField(lambda: ReporterFilterNode)
def resolve_child_reporters(self, args, context, info):
return []
class Meta:
model = Reporter
interfaces = (Node, )
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
assert ReporterFilterNode._meta.fields['child_reporters'].node_type == ReporterFilterNode