diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 1ecce45..bbf1940 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -6,10 +6,11 @@ from promise import Promise from graphene.types import Field, List from graphene.relay import ConnectionField, PageInfo +from graphene.utils.get_unbound_function import get_unbound_function from graphql_relay.connection.arrayconnection import connection_from_list_slice from .settings import graphene_settings -from .utils import maybe_queryset +from .utils import maybe_queryset, auth_resolver class DjangoListField(Field): @@ -151,3 +152,20 @@ class DjangoConnectionField(ConnectionField): self.max_limit, self.enforce_first_or_last, ) + + +class PermissionField(Field): + """Class to manage permission for fields""" + + def __init__(self, type, permissions=(), permissions_resolver=auth_resolver, *args, **kwargs): + """Get permissions to access a field""" + super(PermissionField, self).__init__(type, *args, **kwargs) + self.permissions = permissions + self.permissions_resolver = permissions_resolver + + def get_resolver(self, parent_resolver): + """Intercept resolver to analyse permissions""" + parent_resolver = super(PermissionField, self).get_resolver(parent_resolver) + if self.permissions: + return partial(get_unbound_function(self.permissions_resolver), parent_resolver, self.permissions, True) + return parent_resolver diff --git a/graphene_django/forms/mutation.py b/graphene_django/forms/mutation.py index c39b3c3..1d1fe74 100644 --- a/graphene_django/forms/mutation.py +++ b/graphene_django/forms/mutation.py @@ -107,7 +107,7 @@ class DjangoFormMutation(BaseDjangoFormMutation): @classmethod def __init_subclass_with_meta__( - cls, form_class=None, only_fields=(), exclude_fields=(), **options + cls, form_class=None, mirror_input=False, only_fields=(), exclude_fields=(), **options ): if not form_class: @@ -115,7 +115,10 @@ class DjangoFormMutation(BaseDjangoFormMutation): form = form_class() input_fields = fields_for_form(form, only_fields, exclude_fields) - output_fields = fields_for_form(form, only_fields, exclude_fields) + if mirror_input: + output_fields = fields_for_form(form, only_fields, exclude_fields) + else: + output_fields = {} _meta = DjangoFormMutationOptions(cls) _meta.form_class = form_class diff --git a/graphene_django/forms/tests/test_mutation.py b/graphene_django/forms/tests/test_mutation.py index df0ffd5..1f39afe 100644 --- a/graphene_django/forms/tests/test_mutation.py +++ b/graphene_django/forms/tests/test_mutation.py @@ -139,3 +139,36 @@ class ModelFormMutationTests(TestCase): self.assertEqual(result.errors[0].messages, ["This field is required."]) self.assertIn("age", fields_w_error) self.assertEqual(result.errors[1].messages, ["This field is required."]) + + +class FormMutationTests(TestCase): + def test_default_meta_fields(self): + class MyMutation(DjangoFormMutation): + class Meta: + form_class = MyForm + self.assertNotIn("text", MyMutation._meta.fields) + + def test_mirror_meta_fields(self): + class MyMutation(DjangoFormMutation): + class Meta: + form_class = MyForm + mirror_input = True + + self.assertIn("text", MyMutation._meta.fields) + + def test_default_input_meta_fields(self): + class MyMutation(DjangoFormMutation): + class Meta: + form_class = MyForm + + self.assertIn("client_mutation_id", MyMutation.Input._meta.fields) + self.assertIn("text", MyMutation.Input._meta.fields) + + def test_exclude_fields_input_meta_fields(self): + class MyMutation(DjangoFormMutation): + class Meta: + form_class = MyForm + exclude_fields = ['text'] + + self.assertNotIn("text", MyMutation.Input._meta.fields) + self.assertIn("client_mutation_id", MyMutation.Input._meta.fields) diff --git a/graphene_django/tests/test_fields.py b/graphene_django/tests/test_fields.py new file mode 100644 index 0000000..23cce7f --- /dev/null +++ b/graphene_django/tests/test_fields.py @@ -0,0 +1,46 @@ +from unittest import TestCase +from django.core.exceptions import PermissionDenied +from graphene_django.fields import PermissionField + + +class MyInstance(object): + value = "value" + + def resolver(self): + return "resolver method" + + +class PermissionFieldTests(TestCase): + + def test_permission_field(self): + MyType = object() + field = PermissionField(MyType, permissions=['perm1', 'perm2'], source='resolver') + resolver = field.get_resolver(None) + + class Viewer(object): + def has_perm(self, perm): + return perm == 'perm2' + + class Info(object): + class Context(object): + user = Viewer() + context = Context() + + self.assertEqual(resolver(MyInstance(), Info()), MyInstance().resolver()) + + def test_permission_field_without_permission(self): + MyType = object() + field = PermissionField(MyType, permissions=['perm1', 'perm2'], source='resolver') + resolver = field.get_resolver(field.resolver) + + class Viewer(object): + def has_perm(self, perm): + return False + + class Info(object): + class Context(object): + user = Viewer() + context = Context() + + with self.assertRaises(PermissionDenied): + resolver(MyInstance(), Info()) diff --git a/graphene_django/tests/test_utils.py b/graphene_django/tests/test_utils.py index becd031..e1068c1 100644 --- a/graphene_django/tests/test_utils.py +++ b/graphene_django/tests/test_utils.py @@ -1,4 +1,4 @@ -from ..utils import get_model_fields +from ..utils import get_model_fields, has_permissions from .models import Film, Reporter @@ -10,3 +10,23 @@ def test_get_model_fields_no_duplication(): film_fields = get_model_fields(Film) film_name_set = set([field[0] for field in film_fields]) assert len(film_fields) == len(film_name_set) + + +def test_has_permissions(): + class Viewer(object): + @staticmethod + def has_perm(permission): + return permission + + viewer_as_perm = has_permissions(Viewer(), [False, True, False]) + assert viewer_as_perm + + +def test_viewer_without_permissions(): + class Viewer(object): + @staticmethod + def has_perm(permission): + return permission + + viewer_as_perm = has_permissions(Viewer(), [False, False, False]) + assert not viewer_as_perm diff --git a/graphene_django/utils.py b/graphene_django/utils.py index 560f604..f4800eb 100644 --- a/graphene_django/utils.py +++ b/graphene_django/utils.py @@ -1,10 +1,12 @@ import inspect +from django.core.exceptions import PermissionDenied from django.db import models from django.db.models.manager import Manager # from graphene.utils import LazyList +from graphene.utils.get_unbound_function import get_unbound_function class LazyList(object): @@ -81,3 +83,53 @@ def import_single_dispatch(): ) return singledispatch + + +def has_permissions(viewer, permissions): + """ + Verify that at least one permission is accomplished + :param viewer: Field's viewer + :param permissions: Field permissions + :return: True if viewer has permission. False otherwise. + """ + if not permissions: + return True + return any([viewer.has_perm(perm) for perm in permissions]) + + +def resolve_bound_resolver(resolver, root, info, **args): + """ + Resolve provided resolver + :param resolver: Explicit field resolver + :param root: Schema root + :param info: Schema info + :param args: Schema args + :return: Resolved field + """ + resolver = get_unbound_function(resolver) + return resolver(root, info, **args) + + +def auth_resolver(parent_resolver, permissions, raise_exception, root, info, **args): + """ + Middleware resolver to check viewer's permissions + :param parent_resolver: Field resolver + :param permissions: Field permissions + :param raise_exception: If True a PermissionDenied is raised + :param root: Schema root + :param info: Schema info + :param args: Schema args + :return: Resolved field. None if the viewer does not have permission to access the field. + """ + # Get viewer from context + if not hasattr(info.context, 'user'): + raise PermissionDenied() + user = info.context.user + + if has_permissions(user, permissions): + if parent_resolver: + # A resolver is provided in the class + return resolve_bound_resolver(parent_resolver, root, info, **args) + elif raise_exception: + raise PermissionDenied() + return None