mirror of
https://github.com/graphql-python/graphene-django.git
synced 2025-12-08 02:34:16 +03:00
refine_queryset lets a DjangoObjectType define exactly how the queryset should be refined before it is returned to the user. For instance, some objects could be filtered out according to a predicate, or some fields could be prefetched depending on what the initial query requested. get_connection_parameters lets a DjangoObjectType define the name and value of the parameters that can be passed to any DjangoConnectionField that uses it. Both of these additions are building blocks that allow custom refinements and filters without having to go through django-filter. Moreover, such filters can be optimized further than was previously possible, because the GraphQL info object is available in refine_queryset.
169 lines
5.5 KiB
Python
169 lines
5.5 KiB
Python
from functools import partial
|
|
from collections import OrderedDict
|
|
|
|
from django.db.models.query import QuerySet
|
|
|
|
from promise import Promise
|
|
|
|
from graphene.types import Field, List
|
|
from graphene.types.argument import to_arguments
|
|
from graphene.relay import ConnectionField, PageInfo
|
|
from graphql_relay.connection.arrayconnection import connection_from_list_slice
|
|
|
|
from .settings import graphene_settings
|
|
from .utils import maybe_queryset
|
|
|
|
|
|
class DjangoListField(Field):
    """Graphene field that wraps ``_type`` in a ``List``.

    The resolver produced by :meth:`get_resolver` passes whatever the parent
    resolver returns through ``maybe_queryset``.
    """

    def __init__(self, _type, *args, **kwargs):
        """Build the field with ``List(_type)`` as its declared type.

        Remaining positional and keyword arguments are forwarded untouched
        to ``Field.__init__``.
        """
        list_type = List(_type)
        super(DjangoListField, self).__init__(list_type, *args, **kwargs)

    @property
    def model(self):
        """Model reached through the list's inner type (``of_type``)."""
        inner_type = self.type.of_type
        return inner_type._meta.node._meta.model

    @staticmethod
    def list_resolver(resolver, root, info, **args):
        """Call ``resolver`` and hand its result to ``maybe_queryset``.

        NOTE(review): ``maybe_queryset`` comes from ``.utils``; presumably it
        turns a manager into a queryset and leaves other values alone --
        confirm against that helper.
        """
        resolved = resolver(root, info, **args)
        return maybe_queryset(resolved)

    def get_resolver(self, parent_resolver):
        """Return :meth:`list_resolver` with ``parent_resolver`` pre-bound."""
        return partial(self.list_resolver, parent_resolver)
|
|
|
|
|
|
class DjangoConnectionField(ConnectionField):
    """Relay connection field whose node type is a ``DjangoObjectType``.

    Keyword arguments consumed by the constructor (everything else is
    forwarded to ``ConnectionField``):

    * ``on`` -- name of an attribute on the model to use as the manager
      instead of ``_default_manager`` (default ``False``: use the default
      manager).
    * ``max_limit`` -- largest accepted ``first``/``last`` value; defaults to
      ``graphene_settings.RELAY_CONNECTION_MAX_LIMIT``.
    * ``enforce_first_or_last`` -- when truthy, a query must supply ``first``
      or ``last``; defaults to
      ``graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST``.
    """

    def __init__(self, type, *args, **kwargs):
        # Pop our own options first so the base class never sees them.
        self.on = kwargs.pop("on", False)
        self.max_limit = kwargs.pop(
            "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
        )
        self.enforce_first_or_last = kwargs.pop(
            "enforce_first_or_last",
            graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
        )
        super(DjangoConnectionField, self).__init__(type, *args, **kwargs)

    @property
    def args(self):
        """Arguments of the field, extended with the node type's parameters.

        The stored base arguments (or an empty ``OrderedDict`` when they are
        falsy) are combined, on every access, with whatever
        ``node_type.get_connection_parameters()`` returns.
        """
        return to_arguments(
            self._base_args or OrderedDict(),
            self.node_type.get_connection_parameters(),
        )

    @args.setter
    def args(self, args):
        # Only the raw value is stored; the getter merges in the node type's
        # connection parameters lazily.
        self._base_args = args

    @property
    def type(self):
        """Connection type declared on the wrapped ``DjangoObjectType``.

        Raises:
            AssertionError: if the wrapped type is not a ``DjangoObjectType``
                subclass or has no ``_meta.connection``.
        """
        # Imported locally rather than at module level -- presumably to avoid
        # a circular import with ``.types``.
        from .types import DjangoObjectType

        # super() is deliberately anchored at ConnectionField, so
        # ConnectionField's own ``type`` property is skipped and the next one
        # in the MRO is used.
        _type = super(ConnectionField, self).type
        assert issubclass(
            _type, DjangoObjectType
        ), "DjangoConnectionField only accepts DjangoObjectType types"
        assert _type._meta.connection, "The type {} doesn't have a connection".format(
            _type.__name__
        )
        return _type._meta.connection

    @property
    def node_type(self):
        """Node type recorded on the connection's ``_meta``."""
        return self.type._meta.node

    @property
    def model(self):
        """Model recorded on the node type's ``_meta``."""
        return self.node_type._meta.model

    def get_manager(self):
        """Return the model attribute named by ``on``, else the default manager."""
        if self.on:
            return getattr(self.model, self.on)
        else:
            return self.model._default_manager

    @classmethod
    def merge_querysets(cls, default_queryset, queryset):
        """Combine two querysets with ``&``.

        If exactly one of the two is marked distinct, ``.distinct()`` is
        applied to the other one first so both sides agree before combining.
        """
        if default_queryset.query.distinct and not queryset.query.distinct:
            queryset = queryset.distinct()
        elif queryset.query.distinct and not default_queryset.query.distinct:
            default_queryset = default_queryset.distinct()
        return queryset & default_queryset

    @classmethod
    def resolve_connection(cls, connection, node, default_manager, args, info, iterable):
        """Turn the resolved ``iterable`` into a connection instance.

        Args:
            connection: connection type to instantiate; its ``Edge`` attribute
                is used as the edge type.
            node: node type; when it is a ``DjangoObjectType`` subclass its
                ``refine_queryset(iterable, info, **args)`` is applied.
            default_manager: fallback used when ``iterable`` is ``None``, and
                merged with any queryset that is not the same object.
            args: field arguments, forwarded to ``refine_queryset`` and to
                ``connection_from_list_slice``.
            info: GraphQL resolve info, forwarded to ``refine_queryset``.
            iterable: value returned by the parent resolver (may be ``None``).

        Returns:
            The connection built by ``connection_from_list_slice``, with
            ``iterable`` and ``length`` attributes attached.
        """
        if iterable is None:
            iterable = default_manager
        iterable = maybe_queryset(iterable)
        if isinstance(iterable, QuerySet):
            if iterable is not default_manager:
                default_queryset = maybe_queryset(default_manager)
                iterable = cls.merge_querysets(default_queryset, iterable)
            # Local import, as in the ``type`` property.
            from .types import DjangoObjectType

            if issubclass(node, DjangoObjectType):
                iterable = node.refine_queryset(iterable, info, **args)
            # Querysets are counted with ``.count()`` instead of ``len()``.
            _len = iterable.count()
        else:
            _len = len(iterable)
        result = connection_from_list_slice(
            iterable,
            args,
            slice_start=0,
            list_length=_len,
            list_slice_length=_len,
            connection_type=connection,
            edge_type=connection.Edge,
            pageinfo_type=PageInfo,
        )
        result.iterable = iterable
        result.length = _len
        return result

    @classmethod
    def connection_resolver(
        cls,
        resolver,
        connection,
        node,
        default_manager,
        max_limit,
        enforce_first_or_last,
        root,
        info,
        **args
    ):
        """Validate pagination arguments, call ``resolver`` and build the connection.

        Args:
            resolver: parent resolver, called as ``resolver(root, info, **args)``.
            connection, node, default_manager: forwarded to
                :meth:`resolve_connection`.
            max_limit: when truthy, upper bound for ``first`` and ``last``.
            enforce_first_or_last: when truthy, ``first`` or ``last`` is required.
            root, info: standard resolver arguments.
            **args: field arguments supplied by the query.

        Returns:
            The connection, or a promise of it when the resolver's result is
            thenable.

        Raises:
            AssertionError: when ``first``/``last`` is required but missing,
                or when either exceeds ``max_limit``.
        """
        first = args.get("first")
        last = args.get("last")

        # These checks validate client input, so they are raised explicitly
        # instead of using ``assert``: assert statements are removed when
        # Python runs with -O, which would silently disable them. The
        # exception type and messages are kept as they were.
        if enforce_first_or_last and not (first or last):
            raise AssertionError(
                (
                    "You must provide a `first` or `last` value to properly paginate the `{}` connection."
                ).format(info.field_name)
            )

        if max_limit:
            if first and first > max_limit:
                raise AssertionError(
                    (
                        "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
                    ).format(first, info.field_name, max_limit)
                )

            if last and last > max_limit:
                raise AssertionError(
                    (
                        "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
                    ).format(last, info.field_name, max_limit)
                )

        iterable = resolver(root, info, **args)
        on_resolve = partial(
            cls.resolve_connection, connection, node, default_manager, args, info
        )

        if Promise.is_thenable(iterable):
            return Promise.resolve(iterable).then(on_resolve)

        return on_resolve(iterable)

    def get_resolver(self, parent_resolver):
        """Return :meth:`connection_resolver` pre-bound with this field's settings."""
        return partial(
            self.connection_resolver,
            parent_resolver,
            self.type,
            self.node_type,
            self.get_manager(),
            self.max_limit,
            self.enforce_first_or_last,
        )
|