from functools import partial

from django.db.models.query import QuerySet
from promise import Promise
from neomodel import NodeSet

from graphene.types import Field, List
from graphene.types.scalars import Boolean, Int
from graphene.relay import ConnectionField, PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice

from .settings import graphene_settings
from .utils import maybe_queryset, is_parent_set


class DjangoListField(Field):
    """Field that exposes a ``List(_type)``.

    Whatever the parent resolver returns is passed through
    ``maybe_queryset`` before being handed back to graphene.
    """

    def __init__(self, _type, *args, **kwargs):
        super(DjangoListField, self).__init__(List(_type), *args, **kwargs)

    @property
    def model(self):
        """Return ``type.of_type._meta.node._meta.model`` for the wrapped type."""
        # NOTE(review): assumes the list's inner type exposes `_meta.node`
        # (as a Connection type does) -- confirm for plain object types.
        return self.type.of_type._meta.node._meta.model

    @staticmethod
    def list_resolver(resolver, root, info, **args):
        """Call ``resolver`` and normalise its result with ``maybe_queryset``."""
        return maybe_queryset(resolver(root, info, **args))

    def get_resolver(self, parent_resolver):
        """Wrap ``parent_resolver`` so its result goes through ``list_resolver``."""
        return partial(self.list_resolver, parent_resolver)


class DjangoConnectionField(ConnectionField):
    """Relay connection field backed by the model's ``nodes`` attribute.

    Keyword arguments consumed here (not forwarded to ``ConnectionField``):

    * ``max_limit``: upper bound for ``first``/``last``. Defaults to
      ``graphene_settings.RELAY_CONNECTION_MAX_LIMIT``.
    * ``enforce_first_or_last``: require clients to pass ``first`` or
      ``last``. Defaults to
      ``graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST``.

    GraphQL arguments added unless the caller supplies their own:

    * ``know_parent`` (Boolean, default ``False``): when truthy, each
      resolved node gets a ``_parent`` attribute pointing at the root object.
    * ``first`` (Int): defaults to ``DEFAULT_FIRST``, capped at ``max_limit``.
    """

    # Page size used when the client does not pass `first`.
    DEFAULT_FIRST = 100

    def __init__(self, *args, **kwargs):
        # Name of the model attribute used as the default manager
        # (see `get_manager`).
        self.on = "nodes"
        self.max_limit = kwargs.pop(
            "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
        )
        self.enforce_first_or_last = kwargs.pop(
            "enforce_first_or_last",
            graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
        )
        kwargs.setdefault(
            "know_parent",
            Boolean(
                default_value=False,
                description="Know parent type in nodes? \n Default = False",
            ),
        )
        # The implicit page size must never exceed max_limit. Otherwise every
        # request that omits `first` would fail the max_limit check in
        # `connection_resolver_original`.
        default_first = self.DEFAULT_FIRST
        if self.max_limit:
            default_first = min(default_first, self.max_limit)
        # setdefault (not assignment) so an explicit `first=` from the caller
        # is respected instead of being silently overwritten.
        kwargs.setdefault("first", Int(default_value=default_first))
        super(DjangoConnectionField, self).__init__(*args, **kwargs)

    @property
    def type(self):
        """Return the connection type declared on the wrapped DjangoObjectType.

        Raises:
            AssertionError: if the wrapped type is not a ``DjangoObjectType``
                or it has no ``_meta.connection``.
        """
        from .types import DjangoObjectType

        # Deliberately skip ConnectionField.type and read the raw declared type.
        _type = super(ConnectionField, self).type
        assert issubclass(
            _type, DjangoObjectType
        ), "DjangoConnectionField only accepts DjangoObjectType types"
        assert _type._meta.connection, "The type {} doesn't have a connection".format(
            _type.__name__
        )
        return _type._meta.connection

    @property
    def node_type(self):
        """Return the node type of this field's connection."""
        return self.type._meta.node

    @property
    def model(self):
        """Return the model class behind the node type."""
        return self.node_type._meta.model

    def get_manager(self):
        """Return ``model.<self.on>``, or ``model._default_manager`` when ``on`` is falsy."""
        if self.on:
            return getattr(self.model, self.on)
        else:
            return self.model._default_manager

    @classmethod
    def merge_querysets(cls, default_queryset, queryset):
        """Intersect two Django querysets, matching their ``distinct`` state first.

        If only one side is ``distinct``, the other is made distinct too, so
        that the ``&`` combination is allowed.
        """
        if default_queryset.query.distinct and not queryset.query.distinct:
            queryset = queryset.distinct()
        elif queryset.query.distinct and not default_queryset.query.distinct:
            default_queryset = default_queryset.distinct()
        return queryset & default_queryset

    @classmethod
    def resolve_connection(cls, connection, default_manager, args, iterable):
        """Build a ``connection`` instance from ``iterable`` and pagination ``args``.

        A ``None`` iterable falls back to ``default_manager``. The full
        (unsliced) iterable and its length are stored on the result as
        ``connection.iterable`` and ``connection.length``.
        """
        if iterable is None:
            iterable = default_manager
        iterable = maybe_queryset(iterable)
        _len = len(iterable)
        connection = connection_from_list_slice(
            iterable,
            args,
            slice_start=0,
            list_length=_len,
            list_slice_length=_len,
            connection_type=connection,
            edge_type=connection.Edge,
            pageinfo_type=PageInfo,
        )
        connection.iterable = iterable
        connection.length = _len
        return connection

    @classmethod
    def connection_resolver_original(
        cls,
        resolver,
        connection,
        default_manager,
        max_limit,
        enforce_first_or_last,
        root,
        info,
        **args
    ):
        """Validate/clamp ``first``/``last``, run ``resolver``, build the connection.

        Supports resolvers that return a Promise.

        Raises:
            AssertionError: when pagination is enforced but neither ``first``
                nor ``last`` is given, or when either exceeds ``max_limit``.
        """
        first = args.get("first")
        last = args.get("last")

        if enforce_first_or_last:
            assert first or last, (
                "You must provide a `first` or `last` value to properly paginate the `{}` connection."
            ).format(info.field_name)

        if max_limit:
            if first:
                assert first <= max_limit, (
                    "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
                ).format(first, info.field_name, max_limit)
                # Clamping still applies when asserts are stripped (python -O).
                args["first"] = min(first, max_limit)

            if last:
                assert last <= max_limit, (
                    "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
                ).format(last, info.field_name, max_limit)
                args["last"] = min(last, max_limit)

        iterable = resolver(root, info, **args)
        on_resolve = partial(cls.resolve_connection, connection, default_manager, args)

        if Promise.is_thenable(iterable):
            return Promise.resolve(iterable).then(on_resolve)

        return on_resolve(iterable)

    @classmethod
    def connection_resolver(
        cls,
        resolver,
        connection,
        default_manager,
        max_limit,
        enforce_first_or_last,
        root,
        info,
        **args
    ):
        """Resolve the connection, optionally tagging each node with ``_parent``.

        Tagging is enabled when ``is_parent_set(info)`` is truthy or the client
        passed ``know_parent``. A resolver result of ``None`` falls back to
        ``default_manager.filter()``. Both rules now also apply when the
        resolver returns a Promise.
        """
        _parent = is_parent_set(info) or args.get("know_parent", False)

        def new_resolver(root, info, **kwargs):
            def attach_parent(qs):
                if qs is None:
                    qs = default_manager.filter()
                if _parent and root is not None:
                    # Materialises the whole result so each node can carry
                    # a reference to the object it was resolved from.
                    instances = []
                    for instance in qs:
                        instance._parent = root
                        instances.append(instance)
                    return instances
                return qs

            qs = resolver(root, info, **kwargs)
            # Iterating a Promise directly would raise TypeError, so the
            # fallback and tagging are applied once it resolves.
            if Promise.is_thenable(qs):
                return Promise.resolve(qs).then(attach_parent)
            return attach_parent(qs)

        return cls.connection_resolver_original(
            new_resolver,
            connection,
            default_manager,
            max_limit,
            enforce_first_or_last,
            root,
            info,
            **args
        )

    def get_resolver(self, parent_resolver):
        """Bind ``parent_resolver`` plus this field's config into ``connection_resolver``."""
        return partial(
            self.connection_resolver,
            parent_resolver,
            self.type,
            self.get_manager(),
            self.max_limit,
            self.enforce_first_or_last,
        )