diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 1ecce45..bbf1940 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -6,10 +6,11 @@ from promise import Promise from graphene.types import Field, List from graphene.relay import ConnectionField, PageInfo +from graphene.utils.get_unbound_function import get_unbound_function from graphql_relay.connection.arrayconnection import connection_from_list_slice from .settings import graphene_settings -from .utils import maybe_queryset +from .utils import maybe_queryset, auth_resolver class DjangoListField(Field): @@ -151,3 +152,20 @@ class DjangoConnectionField(ConnectionField): self.max_limit, self.enforce_first_or_last, ) + + +class PermissionField(Field): + """Class to manage permission for fields""" + + def __init__(self, type, permissions=(), permissions_resolver=auth_resolver, *args, **kwargs): + """Get permissions to access a field""" + super(PermissionField, self).__init__(type, *args, **kwargs) + self.permissions = permissions + self.permissions_resolver = permissions_resolver + + def get_resolver(self, parent_resolver): + """Intercept resolver to analyse permissions""" + parent_resolver = super(PermissionField, self).get_resolver(parent_resolver) + if self.permissions: + return partial(get_unbound_function(self.permissions_resolver), parent_resolver, self.permissions, True) + return parent_resolver diff --git a/graphene_django/tests/test_fields.py b/graphene_django/tests/test_fields.py new file mode 100644 index 0000000..23cce7f --- /dev/null +++ b/graphene_django/tests/test_fields.py @@ -0,0 +1,46 @@ +from unittest import TestCase +from django.core.exceptions import PermissionDenied +from graphene_django.fields import PermissionField + + +class MyInstance(object): + value = "value" + + def resolver(self): + return "resolver method" + + +class PermissionFieldTests(TestCase): + + def test_permission_field(self): + MyType = object() + field = PermissionField(MyType, permissions=['perm1', 'perm2'], source='resolver') + resolver = field.get_resolver(None) + + class Viewer(object): + def has_perm(self, perm): + return perm == 'perm2' + + class Info(object): + class Context(object): + user = Viewer() + context = Context() + + self.assertEqual(resolver(MyInstance(), Info()), MyInstance().resolver()) + + def test_permission_field_without_permission(self): + MyType = object() + field = PermissionField(MyType, permissions=['perm1', 'perm2'], source='resolver') + resolver = field.get_resolver(field.resolver) + + class Viewer(object): + def has_perm(self, perm): + return False + + class Info(object): + class Context(object): + user = Viewer() + context = Context() + + with self.assertRaises(PermissionDenied): + resolver(MyInstance(), Info()) diff --git a/graphene_django/tests/test_utils.py b/graphene_django/tests/test_utils.py index becd031..e1068c1 100644 --- a/graphene_django/tests/test_utils.py +++ b/graphene_django/tests/test_utils.py @@ -1,4 +1,4 @@ -from ..utils import get_model_fields +from ..utils import get_model_fields, has_permissions from .models import Film, Reporter @@ -10,3 +10,23 @@ def test_get_model_fields_no_duplication(): film_fields = get_model_fields(Film) film_name_set = set([field[0] for field in film_fields]) assert len(film_fields) == len(film_name_set) + + +def test_has_permissions(): + class Viewer(object): + @staticmethod + def has_perm(permission): + return permission + + viewer_as_perm = has_permissions(Viewer(), [False, True, False]) + assert viewer_as_perm + + +def test_viewer_without_permissions(): + class Viewer(object): + @staticmethod + def has_perm(permission): + return permission + + viewer_as_perm = has_permissions(Viewer(), [False, False, False]) + assert not viewer_as_perm diff --git a/graphene_django/utils.py b/graphene_django/utils.py index 560f604..f4800eb 100644 --- a/graphene_django/utils.py +++ b/graphene_django/utils.py @@ -1,10 +1,12 @@ import inspect +from django.core.exceptions import PermissionDenied from django.db import models from django.db.models.manager import Manager # from graphene.utils import LazyList +from graphene.utils.get_unbound_function import get_unbound_function class LazyList(object): @@ -81,3 +83,53 @@ def import_single_dispatch(): ) return singledispatch + + +def has_permissions(viewer, permissions): + """ + Verify that at least one permission is accomplished + :param viewer: Field's viewer + :param permissions: Field permissions + :return: True if viewer has permission. False otherwise. + """ + if not permissions: + return True + return any([viewer.has_perm(perm) for perm in permissions]) + + +def resolve_bound_resolver(resolver, root, info, **args): + """ + Resolve provided resolver + :param resolver: Explicit field resolver + :param root: Schema root + :param info: Schema info + :param args: Schema args + :return: Resolved field + """ + resolver = get_unbound_function(resolver) + return resolver(root, info, **args) + + +def auth_resolver(parent_resolver, permissions, raise_exception, root, info, **args): + """ + Middleware resolver to check viewer's permissions + :param parent_resolver: Field resolver + :param permissions: Field permissions + :param raise_exception: If True a PermissionDenied is raised + :param root: Schema root + :param info: Schema info + :param args: Schema args + :return: Resolved field. None if the viewer does not have permission to access the field. + """ + # Get viewer from context + if not hasattr(info.context, 'user'): + raise PermissionDenied() + user = info.context.user + + if has_permissions(user, permissions): + if parent_resolver: + # A resolver is provided in the class + return resolve_bound_resolver(parent_resolver, root, info, **args) + elif raise_exception: + raise PermissionDenied() + return None