Long-winded intersection using sets

This commit is contained in:
Niall 2017-03-06 19:41:04 +00:00
parent e2284fefb5
commit fda876fdc2
2 changed files with 10 additions and 25 deletions

View File

@ -53,9 +53,11 @@ class DjangoConnectionField(ConnectionField):
iterable = default_manager iterable = default_manager
iterable = maybe_queryset(iterable) iterable = maybe_queryset(iterable)
if isinstance(iterable, QuerySet): if isinstance(iterable, QuerySet):
if default_manager is not None and iterable is not default_manager: if iterable is not default_manager:
iterable &= maybe_queryset(default_manager) iterable = list(set(iterable).intersection(maybe_queryset(default_manager)))
_len = iterable.count() _len = len(iterable)
else:
_len = iterable.count()
else: else:
_len = len(iterable) _len = len(iterable)
connection = connection_from_list_slice( connection = connection_from_list_slice(

View File

@ -1,8 +1,6 @@
from collections import OrderedDict from collections import OrderedDict
from functools import partial from functools import partial
from django.db.models.query import QuerySet
# from graphene.relay import is_node # from graphene.relay import is_node
from graphene.types.argument import to_arguments from graphene.types.argument import to_arguments
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
@ -46,30 +44,15 @@ class DjangoFilterConnectionField(DjangoConnectionField):
def filtering_args(self): def filtering_args(self):
return get_filtering_args_from_filterset(self.filterset_class, self.node_type) return get_filtering_args_from_filterset(self.filterset_class, self.node_type)
# @staticmethod
# def connection_resolver(resolver, connection, default_manager, filterset_class, filtering_args,
# root, args, context, info):
# filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
# qs = filterset_class(
# data=filter_kwargs,
# queryset=default_manager.get_queryset()
# ).qs
# return DjangoConnectionField.connection_resolver(resolver, connection, qs, root, args, context, info)
@staticmethod @staticmethod
def connection_resolver(resolver, connection, default_manager, filterset_class, filtering_args, def connection_resolver(resolver, connection, default_manager, filterset_class, filtering_args,
root, args, context, info): root, args, context, info):
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
qs = filterset_class(
def new_resolver(root, args, context, info): data=filter_kwargs,
qs = resolver(root, args, context, info) queryset=default_manager.get_queryset()
if qs is None or not isinstance(qs, QuerySet): ).qs
qs = default_manager.get_queryset() return DjangoConnectionField.connection_resolver(resolver, connection, qs, root, args, context, info)
qs = filterset_class(data=filter_kwargs, queryset=qs).qs
return qs
return DjangoConnectionField.connection_resolver(new_resolver, connection, None, root, args, context, info)
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager(), return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager(),