From fda876fdc20756c712ffa579f8f37ffb8c158c3b Mon Sep 17 00:00:00 2001 From: Niall Date: Mon, 6 Mar 2017 19:41:04 +0000 Subject: [PATCH] Long-winded intersection using sets --- graphene_django/fields.py | 8 +++++--- graphene_django/filter/fields.py | 27 +++++---------------------- 2 files changed, 10 insertions(+), 25 deletions(-) diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 9ba999c..c7d9968 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -53,9 +53,11 @@ class DjangoConnectionField(ConnectionField): iterable = default_manager iterable = maybe_queryset(iterable) if isinstance(iterable, QuerySet): - if default_manager is not None and iterable is not default_manager: - iterable &= maybe_queryset(default_manager) - _len = iterable.count() + if iterable is not default_manager: + iterable = list(set(iterable).intersection(maybe_queryset(default_manager))) + _len = len(iterable) + else: + _len = iterable.count() else: _len = len(iterable) connection = connection_from_list_slice( diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index 1b2c1c8..363e1d9 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -1,8 +1,6 @@ from collections import OrderedDict from functools import partial -from django.db.models.query import QuerySet - # from graphene.relay import is_node from graphene.types.argument import to_arguments from ..fields import DjangoConnectionField @@ -46,30 +44,15 @@ class DjangoFilterConnectionField(DjangoConnectionField): def filtering_args(self): return get_filtering_args_from_filterset(self.filterset_class, self.node_type) - # @staticmethod - # def connection_resolver(resolver, connection, default_manager, filterset_class, filtering_args, - # root, args, context, info): - # filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} - # qs = filterset_class( - # data=filter_kwargs, - # queryset=default_manager.get_queryset() - # ).qs - # return DjangoConnectionField.connection_resolver(resolver, connection, qs, root, args, context, info) - @staticmethod def connection_resolver(resolver, connection, default_manager, filterset_class, filtering_args, root, args, context, info): filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} - - def new_resolver(root, args, context, info): - qs = resolver(root, args, context, info) - if qs is None or not isinstance(qs, QuerySet): - qs = default_manager.get_queryset() - qs = filterset_class(data=filter_kwargs, queryset=qs).qs - - return qs - - return DjangoConnectionField.connection_resolver(new_resolver, connection, None, root, args, context, info) + qs = filterset_class( + data=filter_kwargs, + queryset=default_manager.get_queryset() + ).qs + return DjangoConnectionField.connection_resolver(resolver, connection, qs, root, args, context, info) def get_resolver(self, parent_resolver): return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager(),