# Mirror of https://github.com/graphql-python/graphene-django.git
# (synced 2024-11-24 18:44:08 +03:00) -- 233 lines, 7.3 KiB, Python.
from functools import partial
|
|
|
|
import six
|
|
from django.db.models.query import QuerySet
|
|
from graphql_relay.connection.arrayconnection import (
|
|
connection_from_list_slice,
|
|
get_offset_with_default,
|
|
)
|
|
from promise import Promise
|
|
|
|
from graphene import NonNull
|
|
from graphene.relay import ConnectionField, PageInfo
|
|
from graphene.types import Field, List
|
|
|
|
from .settings import graphene_settings
|
|
from .utils import maybe_queryset
|
|
|
|
|
|
class DjangoListField(Field):
    """A list field whose items are instances of a ``DjangoObjectType``.

    The field is always exposed as ``[T!]``.  A ``NonNull`` wrapper passed in
    as ``_type`` is stripped first, so it is not wrapped twice.
    """

    def __init__(self, _type, *args, **kwargs):
        from .types import DjangoObjectType

        if isinstance(_type, NonNull):
            _type = _type.of_type

        # Django would never return a Set of None vvvvvvv
        super(DjangoListField, self).__init__(List(NonNull(_type)), *args, **kwargs)

        assert issubclass(
            self._underlying_type, DjangoObjectType
        ), "DjangoListField only accepts DjangoObjectType types"

    @property
    def _underlying_type(self):
        """Return the innermost type, unwrapping every ``of_type`` layer."""
        _type = self._type
        while hasattr(_type, "of_type"):
            _type = _type.of_type
        return _type

    @property
    def model(self):
        """Return the model declared on the underlying object type's ``_meta``."""
        return self._underlying_type._meta.model

    def get_default_queryset(self):
        """Return the queryset used when the field's resolver returns ``None``."""
        return self.model._default_manager.get_queryset()

    @staticmethod
    def list_resolver(
        django_object_type, resolver, default_queryset, root, info, **args
    ):
        """Resolve the field's value.

        Args:
            django_object_type: The ``DjangoObjectType`` of the list items.
                Its ``get_queryset`` is applied to any ``QuerySet`` result.
            resolver: The parent resolver, called as
                ``resolver(root, info, **args)``.
            default_queryset: Fallback used when ``resolver`` returns ``None``.
                It is shared between calls, so it is cloned before use.
            root, info, **args: Standard graphene resolver arguments.

        Returns:
            The resolver's value after ``maybe_queryset`` (presumably a
            Manager -> QuerySet conversion; see ``.utils``).  A ``QuerySet``
            result is additionally filtered through
            ``django_object_type.get_queryset``.
        """
        queryset = maybe_queryset(resolver(root, info, **args))
        if queryset is None:
            queryset = maybe_queryset(default_queryset)
            if isinstance(queryset, QuerySet):
                # ``default_queryset`` is built once in ``get_resolver`` and
                # bound into the partial, so the same object is seen by every
                # request.  Evaluating it directly would populate its result
                # cache and serve stale rows afterwards; ``.all()`` returns a
                # fresh, unevaluated clone instead.
                queryset = queryset.all()

        if isinstance(queryset, QuerySet):
            # Pass queryset to the DjangoObjectType get_queryset method
            queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))

        return queryset

    def get_resolver(self, parent_resolver):
        """Bind the item type, parent resolver and default queryset into ``list_resolver``."""
        _type = self.type
        # ``self.type`` is ``List(NonNull(T))``, possibly wrapped once more in
        # ``NonNull`` (e.g. ``required=True``); peel that off first.
        if isinstance(_type, NonNull):
            _type = _type.of_type
        # List -> NonNull -> T
        django_object_type = _type.of_type.of_type
        return partial(
            self.list_resolver,
            django_object_type,
            parent_resolver,
            self.get_default_queryset(),
        )
|
|
|
|
|
|
class DjangoConnectionField(ConnectionField):
    """Relay connection field for a ``DjangoObjectType`` with a connection.

    Extra keyword arguments, popped before ``ConnectionField.__init__``:

    * ``on`` -- name of a model attribute to use as the manager instead of
      ``_default_manager`` (default ``False``: use ``_default_manager``).
    * ``max_limit`` -- page-size cap; defaults to
      ``graphene_settings.RELAY_CONNECTION_MAX_LIMIT``.
    * ``enforce_first_or_last`` -- require ``first`` or ``last`` on every
      query; defaults to
      ``graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST``.
    """

    def __init__(self, *args, **kwargs):
        self.on = kwargs.pop("on", False)
        self.max_limit = kwargs.pop(
            "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
        )
        self.enforce_first_or_last = kwargs.pop(
            "enforce_first_or_last",
            graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
        )
        super(DjangoConnectionField, self).__init__(*args, **kwargs)

    @property
    def type(self):
        """Return the node type's connection class, keeping any ``NonNull`` wrapper."""
        from .types import DjangoObjectType

        # Skip ConnectionField's own ``type`` and read the next one in the MRO,
        # which yields the node type as declared on this field.
        _type = super(ConnectionField, self).type
        non_null = False
        if isinstance(_type, NonNull):
            _type = _type.of_type
            non_null = True
        assert issubclass(
            _type, DjangoObjectType
        ), "DjangoConnectionField only accepts DjangoObjectType types"
        assert _type._meta.connection, "The type {} doesn't have a connection".format(
            _type.__name__
        )
        connection_type = _type._meta.connection
        if non_null:
            return NonNull(connection_type)
        return connection_type

    @property
    def connection_type(self):
        """Return ``self.type`` without a ``NonNull`` wrapper."""
        field_type = self.type
        if isinstance(field_type, NonNull):
            return field_type.of_type
        return field_type

    @property
    def node_type(self):
        """Return the node type declared on the connection's ``_meta``."""
        return self.connection_type._meta.node

    @property
    def model(self):
        """Return the model declared on the node type's ``_meta``."""
        return self.node_type._meta.model

    def get_manager(self):
        """Return ``getattr(model, on)`` when ``on`` is set, else ``_default_manager``."""
        if self.on:
            return getattr(self.model, self.on)
        else:
            return self.model._default_manager

    @classmethod
    def resolve_queryset(cls, connection, queryset, info, args):
        """Filter the resolved value through the node type's ``get_queryset``."""
        # queryset is the resolved iterable from ObjectType
        return connection._meta.node.get_queryset(queryset, info)

    @classmethod
    def resolve_connection(cls, connection, args, iterable, max_limit=None):
        """Build a ``connection`` instance from ``iterable`` and the paging ``args``.

        Args:
            connection: Connection class to instantiate (must expose ``Edge``).
            args: Relay paging arguments (``first``/``last``/``before``/
                ``after``).  May be mutated: ``first`` is set to
                ``max_limit`` for forward pagination without an explicit
                ``first``.
            iterable: A ``QuerySet`` (or whatever ``maybe_queryset`` turns into
                one), or any sequence supporting ``len`` and slicing.
            max_limit: Optional cap on the number of items fetched.

        Returns:
            The connection, with ``iterable`` and ``length`` attributes
            attached.
        """
        iterable = maybe_queryset(iterable)

        if isinstance(iterable, QuerySet):
            # Let the database count rather than materialising every row.
            list_length = iterable.count()
        else:
            list_length = len(iterable)
        list_slice_length = (
            min(max_limit, list_length) if max_limit is not None else list_length
        )

        # If after is higher than list_length, connection_from_list_slice
        # would try to do a negative slicing which makes django throw an
        # AssertionError.  Clamp at 0 as well: a crafted cursor can decode to
        # a negative offset, and querysets reject negative indexes.
        after = max(
            0, min(get_offset_with_default(args.get("after"), -1) + 1, list_length)
        )

        # ``is None`` (not ``in``) so an explicit ``first: null`` is treated
        # the same as an omitted ``first``.
        if max_limit is not None and args.get("first") is None:
            if args.get("last") is not None:
                # Backward pagination.  connection_from_list_slice only sees a
                # window of ``list_slice_length`` items starting at ``after``,
                # so position that window ``last`` items before the end of the
                # requested range; otherwise ``last`` would be taken from the
                # *front* of the list.
                end = min(
                    get_offset_with_default(args.get("before"), list_length),
                    list_length,
                )
                after = max(after, end - args["last"])
            else:
                args["first"] = max_limit

        connection = connection_from_list_slice(
            iterable[after:],
            args,
            slice_start=after,
            list_length=list_length,
            list_slice_length=list_slice_length,
            connection_type=connection,
            edge_type=connection.Edge,
            pageinfo_type=PageInfo,
        )
        connection.iterable = iterable
        connection.length = list_length
        return connection

    @classmethod
    def connection_resolver(
        cls,
        resolver,
        connection,
        default_manager,
        queryset_resolver,
        max_limit,
        enforce_first_or_last,
        root,
        info,
        **args
    ):
        """Validate paging args, resolve the iterable and build the connection.

        Raises:
            AssertionError: if ``enforce_first_or_last`` is set and neither
                ``first`` nor ``last`` is truthy, or if either exceeds
                ``max_limit``.

        Returns:
            The connection, or a ``Promise`` of it when the (filtered) resolver
            result is thenable.
        """
        first = args.get("first")
        last = args.get("last")

        if enforce_first_or_last:
            assert first or last, (
                "You must provide a `first` or `last` value to properly paginate the `{}` connection."
            ).format(info.field_name)

        if max_limit:
            if first:
                assert first <= max_limit, (
                    "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
                ).format(first, info.field_name, max_limit)
                args["first"] = min(first, max_limit)

            if last:
                assert last <= max_limit, (
                    "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
                ).format(last, info.field_name, max_limit)
                args["last"] = min(last, max_limit)

        # eventually leads to DjangoObjectType's get_queryset (accepts queryset)
        # or a resolve_foo (does not accept queryset)
        iterable = resolver(root, info, **args)
        if iterable is None:
            iterable = default_manager
        # thus the iterable gets refiltered by resolve_queryset
        # but iterable might be promise
        iterable = queryset_resolver(connection, iterable, info, args)
        on_resolve = partial(
            cls.resolve_connection, connection, args, max_limit=max_limit
        )

        if Promise.is_thenable(iterable):
            return Promise.resolve(iterable).then(on_resolve)

        return on_resolve(iterable)

    def get_resolver(self, parent_resolver):
        """Bind this field's configuration into ``connection_resolver``."""
        return partial(
            self.connection_resolver,
            parent_resolver,
            self.connection_type,
            self.get_manager(),
            self.get_queryset_resolver(),
            self.max_limit,
            self.enforce_first_or_last,
        )

    def get_queryset_resolver(self):
        """Return the callable used to filter resolved iterables (``resolve_queryset``)."""
        return self.resolve_queryset
|