2016-09-18 02:29:00 +03:00
|
|
|
from functools import partial
|
|
|
|
|
|
|
|
from django.db.models.query import QuerySet
|
2016-09-18 03:09:56 +03:00
|
|
|
|
2017-05-20 02:12:28 +03:00
|
|
|
from promise import Promise
|
|
|
|
|
2016-09-21 07:25:05 +03:00
|
|
|
from graphene.types import Field, List
|
2016-09-18 02:29:00 +03:00
|
|
|
from graphene.relay import ConnectionField, PageInfo
|
|
|
|
from graphql_relay.connection.arrayconnection import connection_from_list_slice
|
2016-09-18 03:09:56 +03:00
|
|
|
|
2017-04-15 12:09:05 +03:00
|
|
|
from .settings import graphene_settings
|
2017-07-25 09:42:40 +03:00
|
|
|
from .utils import maybe_queryset
|
2016-09-18 02:29:00 +03:00
|
|
|
|
|
|
|
|
2016-09-22 05:32:39 +03:00
|
|
|
class DjangoListField(Field):
    """A graphene ``Field`` exposing a plain (non-paginated) list of a type.

    The wrapped type is stored as ``List(_type)``. Whatever the parent
    resolver returns is passed through ``maybe_queryset`` before being
    handed back to graphene.
    """

    def __init__(self, _type, *args, **kwargs):
        super(DjangoListField, self).__init__(List(_type), *args, **kwargs)

    @property
    def model(self):
        """Return the Django model class behind the listed element type.

        The element type is first unwrapped from any wrapper structures that
        expose ``of_type`` (e.g. graphene's ``NonNull``). The model is then
        read from the element type's ``_meta.model``, which is where object
        types keep it.

        For backwards compatibility, an element type whose meta carries no
        ``model`` but does carry a ``node`` (a connection-style type) is
        still resolved through ``_meta.node._meta.model``, as before.
        """
        _type = self.type.of_type
        # Peel off wrappers such as NonNull so we reach the actual object type.
        while hasattr(_type, 'of_type'):
            _type = _type.of_type
        meta = _type._meta
        model = getattr(meta, 'model', None)
        if model is None:
            # Legacy path: connection-style types reference their node type.
            model = meta.node._meta.model
        return model

    @staticmethod
    def list_resolver(resolver, root, info, **args):
        """Call the parent ``resolver`` and normalize its result via ``maybe_queryset``."""
        return maybe_queryset(resolver(root, info, **args))

    def get_resolver(self, parent_resolver):
        """Wrap ``parent_resolver`` so its result goes through ``list_resolver``."""
        return partial(self.list_resolver, parent_resolver)
|
2016-09-21 07:25:05 +03:00
|
|
|
|
|
|
|
|
2016-09-18 02:29:00 +03:00
|
|
|
class DjangoConnectionField(ConnectionField):
    """Relay connection field whose nodes come from a Django model.

    Extra keyword arguments are consumed here; the rest go to ``ConnectionField``:

    * ``on`` -- name of a manager attribute on the model to use as the
      default row source. When falsy, ``model._default_manager`` is used.
    * ``max_limit`` -- upper bound for ``first``/``last``, also used as the
      page size when neither is supplied. Defaults to
      ``graphene_settings.RELAY_CONNECTION_MAX_LIMIT``; a falsy value
      disables the limit.
    * ``enforce_first_or_last`` -- when true, every query must pass ``first``
      or ``last``. Defaults to
      ``graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST``.
    """

    def __init__(self, *args, **kwargs):
        self.on = kwargs.pop('on', False)
        self.max_limit = kwargs.pop(
            'max_limit',
            graphene_settings.RELAY_CONNECTION_MAX_LIMIT
        )
        self.enforce_first_or_last = kwargs.pop(
            'enforce_first_or_last',
            graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST
        )
        super(DjangoConnectionField, self).__init__(*args, **kwargs)

    @property
    def type(self):
        """Return the Relay connection type registered on the wrapped DjangoObjectType."""
        from .types import DjangoObjectType
        # Deliberately skip ConnectionField.type: we want the raw type given to
        # the field, then swap it for the connection stored in its meta.
        _type = super(ConnectionField, self).type
        assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types"
        assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
        return _type._meta.connection

    @property
    def node_type(self):
        """The node type of this field's connection."""
        return self.type._meta.node

    @property
    def model(self):
        """The Django model class of the node type."""
        return self.node_type._meta.model

    def get_manager(self):
        """Return the default row source: the ``on`` manager if set, else the model's default manager."""
        if self.on:
            return getattr(self.model, self.on)
        return self.model._default_manager

    @classmethod
    def merge_querysets(cls, default_queryset, queryset):
        """Combine a resolver-supplied queryset with the default one (``&``); overridable hook."""
        return queryset & default_queryset

    @classmethod
    def resolve_connection(cls, connection, default_manager, args, iterable):
        """Build a ``connection`` instance from the rows returned by a resolver.

        :param connection: connection type to instantiate (its ``Edge`` is used too).
        :param default_manager: default row source, used when ``iterable`` is None.
        :param args: pagination arguments handed to ``connection_from_list_slice``.
        :param iterable: resolver result (None, a manager, a QuerySet or any sized iterable).
        :returns: the connection, with ``iterable`` and ``length`` attributes attached.
        """
        # Decide *before* maybe_queryset() runs: it turns a manager into a fresh
        # QuerySet, after which an identity check against default_manager can
        # never match and the default queryset would be AND-ed with itself.
        uses_default_source = iterable is None or iterable is default_manager
        if iterable is None:
            iterable = default_manager
        iterable = maybe_queryset(iterable)

        if isinstance(iterable, QuerySet):
            if not uses_default_source:
                default_queryset = maybe_queryset(default_manager)
                iterable = cls.merge_querysets(default_queryset, iterable)
            # count() lets the database compute the size without fetching rows.
            _len = iterable.count()
        else:
            _len = len(iterable)

        resolved = connection_from_list_slice(
            iterable,
            args,
            slice_start=0,
            list_length=_len,
            list_slice_length=_len,
            connection_type=connection,
            edge_type=connection.Edge,
            pageinfo_type=PageInfo,
        )
        resolved.iterable = iterable
        resolved.length = _len
        return resolved

    @classmethod
    def connection_resolver(cls, resolver, connection, default_manager, max_limit,
                            enforce_first_or_last, root, info, **args):
        """Validate pagination args, run ``resolver`` and build the connection.

        Raises ``AssertionError`` in two cases:

        * ``enforce_first_or_last`` is set and neither ``first`` nor ``last``
          is truthy;
        * ``first``/``last`` exceeds ``max_limit``.

        The checks use explicit ``raise`` rather than ``assert`` so they still
        run under ``python -O``; the exception type is unchanged for callers.
        When ``max_limit`` is set and the client gives neither ``first`` nor
        ``last``, the slice is capped at ``max_limit`` records.
        """
        first = args.get('first')
        last = args.get('last')

        if enforce_first_or_last and not (first or last):
            raise AssertionError(
                'You must provide a `first` or `last` value to properly paginate the `{}` connection.'
                .format(info.field_name)
            )

        if max_limit:
            if first and first > max_limit:
                raise AssertionError(
                    'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.'
                    .format(first, info.field_name, max_limit)
                )
            if last and last > max_limit:
                raise AssertionError(
                    'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.'
                    .format(last, info.field_name, max_limit)
                )

        iterable = resolver(root, info, **args)

        # The resolver sees the caller's arguments unchanged. Only the slicing
        # step gets a default page size, so omitting first/last can no longer
        # bypass max_limit and return the entire result set.
        slice_args = args
        if max_limit and first is None and last is None:
            slice_args = dict(args, first=max_limit)

        on_resolve = partial(cls.resolve_connection, connection, default_manager, slice_args)

        # Resolvers may return a promise (e.g. from a DataLoader); resolve lazily.
        if Promise.is_thenable(iterable):
            return Promise.resolve(iterable).then(on_resolve)

        return on_resolve(iterable)

    def get_resolver(self, parent_resolver):
        """Bind this field's type, manager and limits into ``connection_resolver``."""
        return partial(
            self.connection_resolver,
            parent_resolver,
            self.type,
            self.get_manager(),
            self.max_limit,
            self.enforce_first_or_last
        )
|