diff --git a/graphene/types/field.py b/graphene/types/field.py index a1428632..2c5aa459 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -1,4 +1,5 @@ import inspect +from graphql.error import GraphQLError from collections import Mapping, OrderedDict from functools import partial @@ -31,6 +32,7 @@ class Field(MountedType): required=False, _creation_counter=None, default_value=None, + permission_classes=[], **extra_args ): super(Field, self).__init__(_creation_counter=_creation_counter) @@ -66,10 +68,39 @@ class Field(MountedType): self.deprecation_reason = deprecation_reason self.description = description self.default_value = default_value + self.permission_classes = permission_classes + + def get_permissions(self): + """ + Instantiates and returns the list of permissions that this field requires. + """ + return [permission() for permission in self.permission_classes] + + def check_permissions(self, info): + for permission in self.get_permissions(): + if not permission.has_permission(info, self): + self.permission_denied( + info, message=getattr(permission, 'message', None) + ) + + def permission_denied(self, info, message=None): + raise GraphQLError(message) @property def type(self): return get_type(self._type) def get_resolver(self, parent_resolver): - return self.resolver or parent_resolver + _resolver = self.resolver or parent_resolver + + if not _resolver: + return None + + def resolver(root, info, *args, **kwargs): + + # TODO: pass root? + self.check_permissions(info) + + return _resolver(root, info, *args, **kwargs) + + return resolver diff --git a/graphene/types/tests/test_field_permissions.py b/graphene/types/tests/test_field_permissions.py new file mode 100644 index 00000000..7b7d8322 --- /dev/null +++ b/graphene/types/tests/test_field_permissions.py @@ -0,0 +1,44 @@ +from functools import partial + +import pytest +from graphql.error import GraphQLError + +from ..argument import Argument +from ..field import Field +from ..scalars import String +from ..structures import NonNull +from .utils import MyLazyType + + +class MyInstance(object): + value = "value" + value_func = staticmethod(lambda: "value_func") + + def value_method(self): + return "value_method" + + +class AlwaysFalsePermission(object): + def has_permission(self, info, field): + return False + +class AlwaysTruePermission(object): + def has_permission(self, info, field): + return True + + +def test_raises_error(): + MyType = object() + field = Field(MyType, source="value", permission_classes=[AlwaysFalsePermission]) + + with pytest.raises(GraphQLError): + + field.get_resolver(None)(MyInstance(), None) + + # TODO: test error message + +def test_does_not_raise_error(): + MyType = object() + field = Field(MyType, source="value", permission_classes=[AlwaysTruePermission]) + + field.get_resolver(None)(MyInstance(), None)