from functools import partial, reduce

from django.db.models.query import QuerySet
from graphene.types import Field, List
from graphene.relay import ConnectionField, PageInfo
from graphene.utils.get_unbound_function import get_unbound_function
from graphql_relay.connection.arrayconnection import connection_from_list_slice
from promise import Promise

from .settings import graphene_settings
from .utils import maybe_queryset, auth_resolver


class DjangoListField(Field):
    """Field exposing a ``List`` of ``_type``.

    Whatever the resolver returns is passed through ``maybe_queryset`` before
    graphene serialises it.
    """

    def __init__(self, _type, *args, **kwargs):
        super(DjangoListField, self).__init__(List(_type), *args, **kwargs)

    @property
    def model(self):
        """Django model reached through the wrapped type's ``_meta.node``."""
        return self.type.of_type._meta.node._meta.model

    @staticmethod
    def list_resolver(resolver, root, info, **args):
        """Call ``resolver`` and normalise its result with ``maybe_queryset``."""
        return maybe_queryset(resolver(root, info, **args))

    def get_resolver(self, parent_resolver):
        """Wrap ``parent_resolver`` with :meth:`list_resolver`."""
        return partial(self.list_resolver, parent_resolver)


class DjangoConnectionField(ConnectionField):
    """Relay connection field over a ``DjangoObjectType``.

    Extra keyword arguments, popped before delegating to ``ConnectionField``:

    :param on: name of the model attribute used as the manager instead of
        ``_default_manager``; ``False`` (the default) means the default manager.
    :param max_limit: upper bound for ``first``/``last``; defaults to
        ``graphene_settings.RELAY_CONNECTION_MAX_LIMIT``.
    :param enforce_first_or_last: require a ``first`` or ``last`` argument;
        defaults to ``graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST``.
    """

    def __init__(self, *args, **kwargs):
        self.on = kwargs.pop("on", False)
        self.max_limit = kwargs.pop(
            "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
        )
        self.enforce_first_or_last = kwargs.pop(
            "enforce_first_or_last",
            graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
        )
        super(DjangoConnectionField, self).__init__(*args, **kwargs)

    @property
    def type(self):
        """Return the connection type of the wrapped ``DjangoObjectType``.

        :raises AssertionError: if the wrapped type is not a
            ``DjangoObjectType`` or it has no connection.
        """
        # Imported here, presumably to avoid a circular import with .types.
        from .types import DjangoObjectType

        # super(ConnectionField, ...) skips ConnectionField.type and reads the
        # type exactly as it was declared on the field.
        _type = super(ConnectionField, self).type
        assert issubclass(
            _type, DjangoObjectType
        ), "DjangoConnectionField only accepts DjangoObjectType types"
        assert _type._meta.connection, "The type {} doesn't have a connection".format(
            _type.__name__
        )
        return _type._meta.connection

    @property
    def node_type(self):
        """Node type of the connection."""
        return self.type._meta.node

    @property
    def model(self):
        """Django model of the node type."""
        return self.node_type._meta.model

    def get_manager(self):
        """Return ``model.<on>`` if ``on`` was given, else the default manager."""
        if self.on:
            return getattr(self.model, self.on)
        else:
            return self.model._default_manager

    @classmethod
    def merge_querysets(cls, default_queryset, queryset):
        """Intersect ``queryset`` with ``default_queryset`` using ``&``.

        If exactly one side is DISTINCT, the other side is made DISTINCT
        first. Django's query combination rejects mixing distinct and
        non-distinct queries.
        """
        if default_queryset.query.distinct and not queryset.query.distinct:
            queryset = queryset.distinct()
        elif queryset.query.distinct and not default_queryset.query.distinct:
            default_queryset = default_queryset.distinct()
        return queryset & default_queryset

    @classmethod
    def resolve_connection(cls, connection, default_manager, args, iterable):
        """Build a ``connection`` instance for ``iterable`` paginated by ``args``.

        :param connection: connection type to instantiate.
        :param default_manager: used when ``iterable`` is ``None``. When
            ``iterable`` is a QuerySet other than this manager, it is also
            intersected with the manager's queryset (see
            :meth:`merge_querysets`).
        :param args: field arguments, including ``first``/``last``.
        :param iterable: value produced by the field resolver.
        :returns: the connection, with ``iterable`` and ``length`` attributes
            attached.
        """
        if iterable is None:
            iterable = default_manager
        iterable = maybe_queryset(iterable)
        if isinstance(iterable, QuerySet):
            if iterable is not default_manager:
                default_queryset = maybe_queryset(default_manager)
                iterable = cls.merge_querysets(default_queryset, iterable)
            # QuerySet.count() instead of len() so the size can come from a
            # COUNT query rather than loading every row.
            _len = iterable.count()
        else:
            _len = len(iterable)
        connection = connection_from_list_slice(
            iterable,
            args,
            slice_start=0,
            list_length=_len,
            list_slice_length=_len,
            connection_type=connection,
            edge_type=connection.Edge,
            pageinfo_type=PageInfo,
        )
        connection.iterable = iterable
        connection.length = _len
        return connection

    @classmethod
    def connection_resolver(
        cls,
        resolver,
        connection,
        default_manager,
        max_limit,
        enforce_first_or_last,
        root,
        info,
        **args
    ):
        """Validate pagination arguments, run ``resolver`` and build the connection.

        When ``max_limit`` is set, ``first``/``last`` above it fail the
        assertions and are otherwise clamped to it. If ``resolver`` returns a
        thenable, connection building is chained onto it.

        :raises AssertionError: when ``enforce_first_or_last`` is set and
            neither ``first`` nor ``last`` is given, or when ``first``/``last``
            exceeds ``max_limit``.
        """
        first = args.get("first")
        last = args.get("last")

        if enforce_first_or_last:
            assert first or last, (
                "You must provide a `first` or `last` value to properly paginate the `{}` connection."
            ).format(info.field_name)

        if max_limit:
            # NOTE(review): the asserts are stripped under ``python -O``; the
            # min() clamps below still bound the page size in that case.
            if first:
                assert first <= max_limit, (
                    "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
                ).format(first, info.field_name, max_limit)
                args["first"] = min(first, max_limit)

            if last:
                assert last <= max_limit, (
                    "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
                ).format(last, info.field_name, max_limit)
                args["last"] = min(last, max_limit)

        iterable = resolver(root, info, **args)
        on_resolve = partial(cls.resolve_connection, connection, default_manager, args)

        if Promise.is_thenable(iterable):
            return Promise.resolve(iterable).then(on_resolve)

        return on_resolve(iterable)

    def get_resolver(self, parent_resolver):
        """Bind the field configuration into :meth:`connection_resolver`."""
        return partial(
            self.connection_resolver,
            parent_resolver,
            self.type,
            self.get_manager(),
            self.max_limit,
            self.enforce_first_or_last,
        )


class DjangoField(Field):
    """Class to manage permission for fields"""

    def __init__(self, type, permissions=(), permissions_resolver=auth_resolver, *args, **kwargs):
        """Get permissions to access a field

        :param type: graphene type of the field.
        :param permissions: permissions required to resolve the field; when
            empty, the resolver is used unwrapped.
        :param permissions_resolver: callable that wraps the resolver to
            apply ``permissions`` (defaults to ``auth_resolver``).
        """
        super(DjangoField, self).__init__(type, *args, **kwargs)
        self.permissions = permissions
        self.permissions_resolver = permissions_resolver

    def get_resolver(self, parent_resolver):
        """Intercept resolver to analyse permissions"""
        parent_resolver = super(DjangoField, self).get_resolver(parent_resolver)
        if self.permissions:
            # get_unbound_function presumably lets a resolver stored as a
            # (Python 2 unbound) method be called as a plain function. The
            # trailing ``None, None, True`` are forwarded positionally; their
            # meaning is defined by the resolver in ``.utils`` -- TODO(review):
            # document them there.
            return partial(
                get_unbound_function(self.permissions_resolver),
                parent_resolver,
                self.permissions,
                None,
                None,
                True,
            )
        return parent_resolver


class DataLoaderField(DjangoField):
    """Class to manage access to data-loader when resolve the field"""

    def __init__(self, type, data_loader, source_loader, load_many=False, *args, **kwargs):
        """
        Initialization of data-loader to resolve field

        :param data_loader: data-loader to resolve field
        :param source_loader: field to obtain the key for data-loading: a
            dotted attribute path on the parent object, or the name of a field
            argument when there is no parent object
        :param load_many: Whether the resolver should try to obtain one element or multiple elements
        :param kwargs: Extra arguments
        """
        self.data_loader = data_loader
        self.source_loader = source_loader
        self.load_many = load_many
        super(DataLoaderField, self).__init__(type, *args, **kwargs)
        # If no resolver is explicitly provided, use dataloader
        self.resolver = self.resolver or self.resolver_data_loader

    def resolver_data_loader(self, root, info, *args, **kwargs):
        """Resolve field through dataloader

        The key is read from ``root`` by following the dotted
        ``source_loader`` path (e.g. ``"author.id"``). When ``root`` is
        ``None``, it is instead taken from the field argument named
        ``source_loader``.

        :returns: ``None`` when the key is ``None``, including when an
            intermediate object on the dotted path is ``None``. Otherwise
            ``data_loader.load_many(key)`` if ``load_many`` is set, else
            ``data_loader.load(key)``.
        """
        if root is not None:
            # Stop at the first None on the path so that a missing relation
            # yields a null key instead of raising AttributeError.
            key = reduce(
                lambda obj, attr: None if obj is None else getattr(obj, attr),
                self.source_loader.split('.'),
                root,
            )
        else:
            key = kwargs.get(self.source_loader)
        # Compare with None explicitly: falsy keys such as 0 or "" are valid
        # identifiers and must still be loaded.
        if key is None:
            return None
        if self.load_many:
            return self.data_loader.load_many(key)
        return self.data_loader.load(key)