graphene-django/graphene_django/fields.py
Alexandre Kirszenberg 2417f79693 Introduce refine_queryset and get_connection_parameters
refine_queryset lets DjangoObjectType define exactly how the queryset
should be refined before being returned to the user. For instance, some
objects could be filtered out according to a predicate, or some fields
could be prefetched depending on what the initial query requested.

get_connection_parameters lets DjangoObjectType declare the names and
types of extra arguments accepted by every DjangoConnectionField built on
that type.

Both additions are building blocks for custom refinements and filters
that do not have to go through django-filter. Such filters can also be
optimized further than was previously possible, because the GraphQL info
object is available in refine_queryset.
2019-02-05 18:13:21 +01:00

169 lines
5.5 KiB
Python

from functools import partial
from collections import OrderedDict
from django.db.models.query import QuerySet
from promise import Promise
from graphene.types import Field, List
from graphene.types.argument import to_arguments
from graphene.relay import ConnectionField, PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice
from .settings import graphene_settings
from .utils import maybe_queryset
class DjangoListField(Field):
    """Field that resolves to a ``List`` of ``_type``.

    The resolver's result is passed through ``maybe_queryset`` before it is
    returned to GraphQL.
    """

    def __init__(self, _type, *args, **kwargs):
        super(DjangoListField, self).__init__(List(_type), *args, **kwargs)

    @property
    def model(self):
        """Return the Django model behind the list's item type.

        Wrapper types (the ``List`` created in ``__init__``, and any wrapper
        the caller passed in) are unwrapped through their ``of_type``. If the
        item type's options expose a ``node`` (as a connection type does),
        that node's model is returned, matching the previous lookup.
        Otherwise the item type's own ``_meta.model`` is used, which is where
        a ``DjangoObjectType`` keeps its model.
        """
        item_type = self.type
        while hasattr(item_type, "of_type"):
            item_type = item_type.of_type
        meta = item_type._meta
        node = getattr(meta, "node", None)
        if node is not None:
            return node._meta.model
        return meta.model

    @staticmethod
    def list_resolver(resolver, root, info, **args):
        """Call ``resolver`` and normalize its result with ``maybe_queryset``."""
        return maybe_queryset(resolver(root, info, **args))

    def get_resolver(self, parent_resolver):
        """Wrap ``parent_resolver`` so its result goes through ``list_resolver``."""
        return partial(self.list_resolver, parent_resolver)
class DjangoConnectionField(ConnectionField):
    """Relay connection field over the model of a ``DjangoObjectType``.

    Besides the standard connection arguments, the field exposes the extra
    arguments returned by the node type's ``get_connection_parameters()``.
    When the resolved value is a queryset, it is passed through the node
    type's ``refine_queryset()`` hook.

    Keyword arguments consumed here (all others go to ``ConnectionField``):

    * ``on``: name of the manager attribute on the model to read from. When
      falsy (the default), the model's ``_default_manager`` is used.
    * ``max_limit``: upper bound for ``first`` / ``last``. Defaults to
      ``graphene_settings.RELAY_CONNECTION_MAX_LIMIT``.
    * ``enforce_first_or_last``: when true, a query must supply ``first``
      or ``last``. Defaults to
      ``graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST``.
    """

    def __init__(self, type, *args, **kwargs):
        self.on = kwargs.pop("on", False)
        self.max_limit = kwargs.pop(
            "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
        )
        self.enforce_first_or_last = kwargs.pop(
            "enforce_first_or_last",
            graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
        )
        super(DjangoConnectionField, self).__init__(type, *args, **kwargs)

    @property
    def args(self):
        """Return the stored arguments merged with the node type's extras.

        The result is recomputed on every access from
        ``self.node_type.get_connection_parameters()``.
        """
        return to_arguments(
            self._base_args or OrderedDict(),
            self.node_type.get_connection_parameters(),
        )

    @args.setter
    def args(self, args):
        # Keep whatever the base Field assigns; the getter layers the node
        # type's connection parameters on top of it.
        self._base_args = args

    @property
    def type(self):
        """Return the connection class of the wrapped ``DjangoObjectType``.

        Raises:
            AssertionError: if the wrapped type is not a ``DjangoObjectType``
                or has no connection.
        """
        # Local import; presumably avoids a circular import with .types.
        from .types import DjangoObjectType

        # super(ConnectionField, self) bypasses ConnectionField's own `type`
        # property and yields the raw type given to this field (a node type).
        _type = super(ConnectionField, self).type
        assert issubclass(
            _type, DjangoObjectType
        ), "DjangoConnectionField only accepts DjangoObjectType types"
        assert _type._meta.connection, "The type {} doesn't have a connection".format(
            _type.__name__
        )
        return _type._meta.connection

    @property
    def node_type(self):
        """Return the node type (the ``DjangoObjectType``) of the connection."""
        return self.type._meta.node

    @property
    def model(self):
        """Return the Django model of the node type."""
        return self.node_type._meta.model

    def get_manager(self):
        """Return the model attribute named by ``on``, else ``_default_manager``."""
        if self.on:
            return getattr(self.model, self.on)
        return self.model._default_manager

    @classmethod
    def merge_querysets(cls, default_queryset, queryset):
        """Return the ``&`` intersection of ``queryset`` and ``default_queryset``.

        If exactly one side is distinct, the other is made distinct first.
        This is presumably because Django's ``&`` rejects combining a distinct
        query with a non-distinct one.
        """
        if default_queryset.query.distinct and not queryset.query.distinct:
            queryset = queryset.distinct()
        elif queryset.query.distinct and not default_queryset.query.distinct:
            default_queryset = default_queryset.distinct()
        return queryset & default_queryset

    @classmethod
    def resolve_connection(cls, connection, node, default_manager, args, info, iterable):
        """Build a ``connection`` instance from a resolver's result.

        Args:
            connection: connection class to instantiate.
            node: node type of the connection. When it is a
                ``DjangoObjectType``, its ``refine_queryset`` hook is applied
                to querysets.
            default_manager: used when ``iterable`` is ``None``. A different
                queryset returned by the resolver is intersected with it.
            args: connection arguments (``first``, ``last``, ``before``,
                ``after`` and any extra connection parameters).
            info: GraphQL resolve info, forwarded to ``refine_queryset``.
            iterable: value returned by the field's resolver.

        Returns:
            The connection, with ``iterable`` and ``length`` attributes set.
        """
        # maybe_queryset() may replace a manager with a new QuerySet object,
        # so identity against default_manager must be checked before it runs.
        # Otherwise the default queryset would always be merged with itself.
        uses_default = iterable is None or iterable is default_manager
        if iterable is None:
            iterable = default_manager
        iterable = maybe_queryset(iterable)
        if isinstance(iterable, QuerySet):
            if not uses_default:
                # Restrict the resolver's queryset to the default manager's.
                default_queryset = maybe_queryset(default_manager)
                iterable = cls.merge_querysets(default_queryset, iterable)
            from .types import DjangoObjectType

            if issubclass(node, DjangoObjectType):
                iterable = node.refine_queryset(iterable, info, **args)
            _len = iterable.count()
        else:
            # Non-queryset results (e.g. lists) are not refined.
            _len = len(iterable)
        connection = connection_from_list_slice(
            iterable,
            args,
            slice_start=0,
            list_length=_len,
            list_slice_length=_len,
            connection_type=connection,
            edge_type=connection.Edge,
            pageinfo_type=PageInfo,
        )
        connection.iterable = iterable
        connection.length = _len
        return connection

    @classmethod
    def connection_resolver(
        cls,
        resolver,
        connection,
        node,
        default_manager,
        max_limit,
        enforce_first_or_last,
        root,
        info,
        **args
    ):
        """Validate pagination arguments, call ``resolver`` and build the connection.

        ``first`` and ``last`` are checked against ``enforce_first_or_last``
        and ``max_limit`` before the resolver runs. ``max_limit`` only bounds
        explicitly supplied ``first`` / ``last`` values; it does not cap a
        query that supplies neither. If the resolver returns a thenable, the
        connection is built once it resolves.

        Raises:
            AssertionError: if a required ``first``/``last`` is missing, or a
                supplied one exceeds ``max_limit``.
        """
        first = args.get("first")
        last = args.get("last")

        if enforce_first_or_last:
            # Compare with None so an explicit `first: 0` / `last: 0` counts
            # as provided instead of triggering "you must provide" below.
            assert first is not None or last is not None, (
                "You must provide a `first` or `last` value to properly paginate the `{}` connection."
            ).format(info.field_name)

        if max_limit:
            # The min() clamps still bound the values if assertions are
            # disabled (python -O).
            if first is not None:
                assert first <= max_limit, (
                    "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
                ).format(first, info.field_name, max_limit)
                args["first"] = min(first, max_limit)

            if last is not None:
                assert last <= max_limit, (
                    "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
                ).format(last, info.field_name, max_limit)
                args["last"] = min(last, max_limit)

        iterable = resolver(root, info, **args)
        on_resolve = partial(cls.resolve_connection, connection, node, default_manager, args, info)

        if Promise.is_thenable(iterable):
            return Promise.resolve(iterable).then(on_resolve)

        return on_resolve(iterable)

    def get_resolver(self, parent_resolver):
        """Bind this field's configuration into ``connection_resolver``."""
        return partial(
            self.connection_resolver,
            parent_resolver,
            self.type,
            self.node_type,
            self.get_manager(),
            self.max_limit,
            self.enforce_first_or_last,
        )