graphene-django/graphene_django/fields.py
Olivia Rodriguez Valdes 4c6e7209c3 Add viewer management
2019-01-04 15:02:50 -05:00

193 lines
6.5 KiB
Python

from functools import partial
from django.core.exceptions import PermissionDenied
from django.db.models.query import QuerySet
from promise import Promise
from graphene.types import Field, List
from graphene.relay import ConnectionField, PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice
from .settings import graphene_settings
from .utils import maybe_queryset, has_permissions, resolve_bound_resolver
class DjangoListField(Field):
    """Field exposing a ``List`` of the given type.

    The declared type is wrapped in ``List`` at construction time, and every
    value produced by the parent resolver is passed through ``maybe_queryset``
    before being returned.
    """

    def __init__(self, _type, *args, **kwargs):
        wrapped_type = List(_type)
        super(DjangoListField, self).__init__(wrapped_type, *args, **kwargs)

    @property
    def model(self):
        """Django model reached via the list item type's ``_meta.node``."""
        item_type = self.type.of_type
        node = item_type._meta.node
        return node._meta.model

    @staticmethod
    def list_resolver(resolver, root, info, **args):
        """Invoke ``resolver`` and normalise its result with ``maybe_queryset``."""
        resolved = resolver(root, info, **args)
        return maybe_queryset(resolved)

    def get_resolver(self, parent_resolver):
        """Return ``list_resolver`` with ``parent_resolver`` pre-bound."""
        wrapper = partial(self.list_resolver, parent_resolver)
        return wrapper
class DjangoConnectionField(ConnectionField):
    """Relay connection field for a ``DjangoObjectType``.

    Resolved values are paginated with ``connection_from_list_slice``.
    Querysets returned by a resolver are first combined with the queryset of
    the field's manager (see ``get_manager``).

    Extra keyword arguments, popped before ``ConnectionField.__init__``:

    :param on: name of the model attribute to use as the manager; when falsy
        (default ``False``) ``model._default_manager`` is used.
    :param max_limit: upper bound for ``first``/``last``; defaults to
        ``graphene_settings.RELAY_CONNECTION_MAX_LIMIT``.
    :param enforce_first_or_last: when truthy, require ``first`` or ``last``;
        defaults to ``graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST``.
    """

    def __init__(self, *args, **kwargs):
        self.on = kwargs.pop("on", False)
        self.max_limit = kwargs.pop(
            "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
        )
        self.enforce_first_or_last = kwargs.pop(
            "enforce_first_or_last",
            graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
        )
        super(DjangoConnectionField, self).__init__(*args, **kwargs)

    @property
    def type(self):
        """Return the connection class of the wrapped ``DjangoObjectType``.

        :raises AssertionError: if the wrapped type is not a
            ``DjangoObjectType`` or has no connection.
        """
        # Local import; presumably avoids a circular import with .types.
        from .types import DjangoObjectType

        # Skip ConnectionField.type and read the declared type from its parent.
        _type = super(ConnectionField, self).type
        assert issubclass(
            _type, DjangoObjectType
        ), "DjangoConnectionField only accepts DjangoObjectType types"
        assert _type._meta.connection, "The type {} doesn't have a connection".format(
            _type.__name__
        )
        return _type._meta.connection

    @property
    def node_type(self):
        """Node type of the connection (``self.type._meta.node``)."""
        return self.type._meta.node

    @property
    def model(self):
        """Django model of the connection's node type."""
        return self.node_type._meta.model

    def get_manager(self):
        """Return ``getattr(model, on)`` if ``on`` is set, else ``model._default_manager``."""
        if self.on:
            return getattr(self.model, self.on)
        else:
            return self.model._default_manager

    @classmethod
    def merge_querysets(cls, default_queryset, queryset):
        """AND-combine ``queryset`` with ``default_queryset``.

        If exactly one side is ``distinct``, the other is made distinct first
        (presumably because Django refuses to combine distinct and
        non-distinct queries).
        """
        if default_queryset.query.distinct and not queryset.query.distinct:
            queryset = queryset.distinct()
        elif queryset.query.distinct and not default_queryset.query.distinct:
            default_queryset = default_queryset.distinct()
        return queryset & default_queryset

    @classmethod
    def resolve_connection(cls, connection, default_manager, args, iterable):
        """Build a connection instance from a resolved value.

        :param connection: connection class (must expose ``Edge``).
        :param default_manager: manager used when ``iterable`` is ``None`` and
            merged into any other queryset the resolver returns.
        :param args: pagination arguments forwarded to
            ``connection_from_list_slice``.
        :param iterable: value returned by the field's resolver.
        :return: the connection, with ``iterable`` and ``length`` attached.
        """
        if iterable is None:
            iterable = default_manager
        # Decide whether to merge BEFORE maybe_queryset(): once a manager has
        # been turned into a QuerySet the identity test can no longer match,
        # and the default queryset would be AND-combined with itself
        # (extra joins, no filtering benefit).
        needs_merge = iterable is not default_manager
        iterable = maybe_queryset(iterable)
        if isinstance(iterable, QuerySet):
            if needs_merge:
                default_queryset = maybe_queryset(default_manager)
                iterable = cls.merge_querysets(default_queryset, iterable)
            _len = iterable.count()
        else:
            _len = len(iterable)
        connection = connection_from_list_slice(
            iterable,
            args,
            slice_start=0,
            list_length=_len,
            list_slice_length=_len,
            connection_type=connection,
            edge_type=connection.Edge,
            pageinfo_type=PageInfo,
        )
        connection.iterable = iterable
        connection.length = _len
        return connection

    @classmethod
    def connection_resolver(
        cls,
        resolver,
        connection,
        default_manager,
        max_limit,
        enforce_first_or_last,
        root,
        info,
        **args
    ):
        """Validate pagination args, run ``resolver`` and build the connection.

        Supports resolvers returning a promise: the connection is then built
        once the promise resolves.

        :raises AssertionError: if ``enforce_first_or_last`` is set and neither
            ``first`` nor ``last`` is given, or if either exceeds ``max_limit``.
        """
        first = args.get("first")
        last = args.get("last")
        # ``first``/``last`` come from the client, so validate with explicit
        # raises: ``assert`` statements are stripped under ``python -O``.
        # AssertionError is kept so the exception type callers see is unchanged.
        if enforce_first_or_last and not (first or last):
            raise AssertionError(
                "You must provide a `first` or `last` value to properly paginate the `{}` connection.".format(
                    info.field_name
                )
            )
        if max_limit:
            if first and first > max_limit:
                raise AssertionError(
                    "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.".format(
                        first, info.field_name, max_limit
                    )
                )
            if last and last > max_limit:
                raise AssertionError(
                    "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.".format(
                        last, info.field_name, max_limit
                    )
                )
        iterable = resolver(root, info, **args)
        on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
        if Promise.is_thenable(iterable):
            return Promise.resolve(iterable).then(on_resolve)
        return on_resolve(iterable)

    def get_resolver(self, parent_resolver):
        """Bind ``connection_resolver`` to this field's type, manager and limits."""
        return partial(
            self.connection_resolver,
            parent_resolver,
            self.type,
            self.get_manager(),
            self.max_limit,
            self.enforce_first_or_last,
        )
class DjangoPermissionField(Field):
    """Field that only resolves when the viewer has the required permissions.

    :param type: GraphQL type of the field, forwarded to ``Field``. (The name
        shadows the builtin but is kept for keyword-argument compatibility.)
    :param permissions: value checked with ``has_permissions`` against the
        viewer returned by ``get_viewer``.
    """

    def __init__(self, type, permissions, *args, **kwargs):
        """Store the permissions required to access the field."""
        super(DjangoPermissionField, self).__init__(type, *args, **kwargs)
        self.permissions = permissions

    def get_viewer(self, root, info, **args):
        """Return the viewer whose permissions are checked (``info.context.user``).

        Override to obtain the viewer from somewhere else.
        """
        return info.context.user

    def permission_resolver(self, parent_resolver, raise_exception, root, info, **args):
        """
        Middleware resolver to check the viewer's permissions.

        :param parent_resolver: Field resolver
        :param raise_exception: If True a PermissionDenied is raised when the
            viewer lacks the permissions
        :param root: Schema root
        :param info: Schema info
        :param args: Schema args
        :return: Resolved field. None if the viewer does not have permission
            to access the field (and raise_exception is False), or if no
            resolver is available.
        :raises PermissionDenied: if the viewer lacks the permissions and
            raise_exception is True.
        """
        user = self.get_viewer(root, info, **args)
        if has_permissions(user, self.permissions):
            if parent_resolver:
                return resolve_bound_resolver(parent_resolver, root, info, **args)
            # NOTE(review): no resolver available; there is no default-resolver
            # fallback here, so the field resolves to None.
            return None
        if raise_exception:
            raise PermissionDenied()
        return None

    def get_resolver(self, parent_resolver):
        """Wrap the field's resolver in ``permission_resolver`` (raising on denial)."""
        # Like graphene's Field.get_resolver, prefer an explicit ``resolver=``
        # given to this field over the parent (type-level/default) resolver;
        # previously an explicit resolver was silently ignored.
        resolver = self.resolver or parent_resolver
        return partial(self.permission_resolver, resolver, True)