Introduce refine_queryset and get_connection_parameters

refine_queryset lets DjangoObjectType define exactly how the queryset
should be refined before being returned to the user. For instance, some
objects could be filtered out according to a predicate, or some fields
could be prefetched depending on what the initial query requested.

get_connection_parameters lets DjangoObjectType define the name and
value of parameters that can be passed to any DjangoConnectionField that
uses them.

Both these additions come as building blocks to allow custom refinements
and filters without having to go through django-filter. Moreover, such
filters can also be further optimized than previously allowed, as the
GraphQL info object is available in refine_queryset.
This commit is contained in:
Alexandre Kirszenberg 2019-02-04 19:38:07 +01:00
parent f76f38ef30
commit 2417f79693
3 changed files with 51 additions and 25 deletions

View File

@ -1,10 +1,12 @@
from functools import partial
from collections import OrderedDict
from django.db.models.query import QuerySet
from promise import Promise
from graphene.types import Field, List
from graphene.types.argument import to_arguments
from graphene.relay import ConnectionField, PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice
@ -29,7 +31,7 @@ class DjangoListField(Field):
class DjangoConnectionField(ConnectionField):
def __init__(self, *args, **kwargs):
def __init__(self, type, *args, **kwargs):
self.on = kwargs.pop("on", False)
self.max_limit = kwargs.pop(
"max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
@ -38,7 +40,15 @@ class DjangoConnectionField(ConnectionField):
"enforce_first_or_last",
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
)
super(DjangoConnectionField, self).__init__(*args, **kwargs)
super(DjangoConnectionField, self).__init__(type, *args, **kwargs)
@property
def args(self):
return to_arguments(self._base_args or OrderedDict(), self.node_type.get_connection_parameters())
@args.setter
def args(self, args):
self._base_args = args
@property
def type(self):
@ -76,7 +86,7 @@ class DjangoConnectionField(ConnectionField):
return queryset & default_queryset
@classmethod
def resolve_connection(cls, connection, default_manager, args, iterable):
def resolve_connection(cls, connection, node, default_manager, args, info, iterable):
if iterable is None:
iterable = default_manager
iterable = maybe_queryset(iterable)
@ -84,6 +94,9 @@ class DjangoConnectionField(ConnectionField):
if iterable is not default_manager:
default_queryset = maybe_queryset(default_manager)
iterable = cls.merge_querysets(default_queryset, iterable)
from .types import DjangoObjectType
if issubclass(node, DjangoObjectType):
iterable = node.refine_queryset(iterable, info, **args)
_len = iterable.count()
else:
_len = len(iterable)
@ -103,15 +116,16 @@ class DjangoConnectionField(ConnectionField):
@classmethod
def connection_resolver(
cls,
resolver,
connection,
default_manager,
max_limit,
enforce_first_or_last,
root,
info,
**args
cls,
resolver,
connection,
node,
default_manager,
max_limit,
enforce_first_or_last,
root,
info,
**args
):
first = args.get("first")
last = args.get("last")
@ -135,7 +149,7 @@ class DjangoConnectionField(ConnectionField):
args["last"] = min(last, max_limit)
iterable = resolver(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
on_resolve = partial(cls.resolve_connection, connection, node, default_manager, args, info)
if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve)
@ -147,6 +161,7 @@ class DjangoConnectionField(ConnectionField):
self.connection_resolver,
parent_resolver,
self.type,
self.node_type,
self.get_manager(),
self.max_limit,
self.enforce_first_or_last,

View File

@ -76,17 +76,18 @@ class DjangoFilterConnectionField(DjangoConnectionField):
@classmethod
def connection_resolver(
cls,
resolver,
connection,
default_manager,
max_limit,
enforce_first_or_last,
filterset_class,
filtering_args,
root,
info,
**args
cls,
resolver,
connection,
node,
default_manager,
max_limit,
enforce_first_or_last,
filterset_class,
filtering_args,
root,
info,
**args
):
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
qs = filterset_class(
@ -98,6 +99,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
return super(DjangoFilterConnectionField, cls).connection_resolver(
resolver,
connection,
node,
qs,
max_limit,
enforce_first_or_last,
@ -111,6 +113,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
self.connection_resolver,
parent_resolver,
self.type,
self.node_type,
self.get_manager(),
self.max_limit,
self.enforce_first_or_last,

View File

@ -111,6 +111,14 @@ class DjangoObjectType(ObjectType):
if not skip_registry:
registry.register(cls)
@classmethod
def refine_queryset(cls, qs, info, **kwargs):
return qs
@classmethod
def get_connection_parameters(cls):
return {}
def resolve_id(self, info):
return self.pk
@ -130,6 +138,6 @@ class DjangoObjectType(ObjectType):
@classmethod
def get_node(cls, info, id):
try:
return cls._meta.model.objects.get(pk=id)
return cls.refine_queryset(cls._meta.model.objects, info).get(pk=id)
except cls._meta.model.DoesNotExist:
return None