diff --git a/graphene_django/auth/decorators.py b/graphene_django/auth/decorators.py index 18b4a8d..f0cbfb4 100644 --- a/graphene_django/auth/decorators.py +++ b/graphene_django/auth/decorators.py @@ -1,37 +1,45 @@ from functools import wraps from django.core.exceptions import PermissionDenied +from ..fields import DjangoConnectionField -from .utils import has_perm, is_authorized_to_mutate_object, is_related_to_user +from .utils import has_perm -def node_require_permission(permissions, user_field=None): +def node_require_permission(permissions): def require_permission_decorator(func): @wraps(func) def func_wrapper(cls, info, id): - if user_field: - user_field is not None - if is_authorized_to_mutate_object(cls._meta.model, info.context.user, user_field): - return func(cls, info, id) - print("Has Perm Result", has_perm(permissions=permissions, context=info.context)) if has_perm(permissions=permissions, context=info.context): - print("Node has persmissions") return func(cls, info, id) raise PermissionDenied('Permission Denied') return func_wrapper return require_permission_decorator -def mutation_require_permission(permissions, model=None, user_field=None): +def mutation_require_permission(permissions): def require_permission_decorator(func): @wraps(func) def func_wrapper(cls, root, info, **input): - if model or user_field: - assert model is not None and user_field is not None - object_instance = cls._meta.model.objects.get(pk=id) - if is_related_to_user(object_instance, info.context.user, user_field): - return func(cls, root, info, **input) if has_perm(permissions=permissions, context=info.context): return func(cls, root, info, **input) return cls(errors=PermissionDenied('Permission Denied')) return func_wrapper return require_permission_decorator + + +def connection_require_permission(permissions): + def require_permission_decorator(func): + @wraps(func) + def func_wrapper( + cls, resolver, connection, default_manager, max_limit, + enforce_first_or_last, root, info, **args): + if has_perm(permissions=permissions, context=info.context): + print("Has Perms") + return func( + cls, resolver, connection, default_manager, max_limit, + enforce_first_or_last, root, info, **args) + return DjangoConnectionField.connection_resolver( + resolver, connection, [PermissionDenied('Permission Denied'), ], max_limit, + enforce_first_or_last, root, info, **args) + return func_wrapper + return require_permission_decorator diff --git a/graphene_django/auth/fields.py b/graphene_django/auth/fields.py deleted file mode 100644 index 108ac03..0000000 --- a/graphene_django/auth/fields.py +++ /dev/null @@ -1,25 +0,0 @@ - -from django.core.exceptions import PermissionDenied - -from .utils import has_perm -from ..fields import DjangoConnectionField - - -class AuthDjangoConnectionField(DjangoConnectionField): - - @classmethod - def connection_resolver(cls, resolver, connection, default_manager, max_limit, - enforce_first_or_last, root, info, **args): - """ - Resolve the required connection if the user in context has the permission required. If the user - does not have the required permission then returns a *Permission Denied* to the request. - """ - assert self._permissions is not None - if has_perm(self._permissions, info.context) is not True: - print(DjangoConnectionField) - return DjangoConnectionField.connection_resolver( - resolver, connection, [PermissionDenied('Permission Denied'), ], max_limit, - enforce_first_or_last, root, info, **args) - return super(AuthDjangoConnectionField, self).connection_resolver( - cls, resolver, connection, default_manager, max_limit, - enforce_first_or_last, root, info, **args) diff --git a/graphene_django/auth/utils.py b/graphene_django/auth/utils.py index 51ed60e..38a90f7 100644 --- a/graphene_django/auth/utils.py +++ b/graphene_django/auth/utils.py @@ -14,7 +14,7 @@ def is_related_to_user(object_instance, user, field): return False -def is_authorized_to_mutate_object(model, user, field): +def is_authorized_to_mutate_object(model, user, id, field): """Return True when the when the user is unauthorized.""" object_instance = model.objects.get(pk=id) if is_related_to_user(object_instance, user, field): @@ -26,19 +26,14 @@ def has_perm(permissions, context): """ Validates if the user in the context has the permission required. """ + assert permissions if context is None: return False user = context.user if user.is_authenticated() is False: return False - print("Username", user.username) - print("Username Auth", user.is_authenticated()) - - if type(permissions) is tuple: - print("permissions", permissions) - for permission in permissions: - print("User has perm", user.has_perm(permission)) - if not user.has_perm(permission): - return False + for permission in permissions: + if not user.has_perm(permission): + return False return True diff --git a/graphene_django/tests/test_auth.py b/graphene_django/tests/test_auth.py index 04c2681..7adb883 100644 --- a/graphene_django/tests/test_auth.py +++ b/graphene_django/tests/test_auth.py @@ -16,7 +16,8 @@ from ..fields import DjangoConnectionField from ..types import DjangoObjectType from ..settings import graphene_settings from .models import Article, Reporter -from ..auth.decorators import node_require_permission, mutation_require_permission +from ..auth.decorators import node_require_permission, mutation_require_permission, connection_require_permission +from ..auth.utils import is_related_to_user, is_authorized_to_mutate_object from ..rest_framework.mutation import SerializerMutation pytestmark = pytest.mark.django_db @@ -66,6 +67,12 @@ class MyModelSerializer(serializers.ModelSerializer): fields = '__all__' +class ArticleSerializer(serializers.ModelSerializer): + class Meta: + model = Article + fields = '__all__' + + class MySerializer(serializers.Serializer): text = serializers.CharField() model = MyModelSerializer() @@ -74,6 +81,58 @@ class MySerializer(serializers.Serializer): return validated_data +def test_is_related_to_user(): + r = Reporter.objects.create( + first_name='John', + last_name='Doe', + email='johndoe@example.com', + a_choice=1 + ) + r2 = Reporter.objects.create( + first_name='Michael', + last_name='Doe', + email='mdoe@example.com', + a_choice=1 + ) + a = Article.objects.create( + headline='Article Node 1', + pub_date=datetime.date.today(), + reporter=r, + editor=r, + lang='es' + ) + result_1 = is_related_to_user(a, r, 'reporter') + result_2 = is_related_to_user(a, r2, 'reporter') + assert result_1 is True + assert result_2 is False + + +def test_is_authorized_to_mutate_object(): + r = Reporter.objects.create( + first_name='John', + last_name='Doe', + email='johndoe@example.com', + a_choice=1 + ) + r2 = Reporter.objects.create( + first_name='Michael', + last_name='Doe', + email='mdoe@example.com', + a_choice=1 + ) + Article.objects.create( + headline='Article Node 1', + pub_date=datetime.date.today(), + reporter=r, + editor=r, + lang='es' + ) + result_1 = is_authorized_to_mutate_object(Article, r, 1, 'reporter') + result_2 = is_authorized_to_mutate_object(Article, r2, 1, 'reporter') + assert result_1 is True + assert result_2 is False + + def test_node_anonymous_user(): class ReporterType(DjangoObjectType): @@ -113,6 +172,43 @@ def test_node_anonymous_user(): } +def test_node_no_context(): + class ReporterType(DjangoObjectType): + + class Meta: + model = Reporter + interfaces = (Node, ) + + @classmethod + @node_require_permission(permissions=('can_view_foo', )) + def get_node(cls, info, id): + return super(ReporterType, cls).get_node(info, id) + + r = Reporter.objects.create( + first_name='John', + last_name='Doe', + email='johndoe@example.com', + a_choice=1 + ) + + class Query(graphene.ObjectType): + reporter = Node.Field(ReporterType) + + schema = graphene.Schema(query=Query) + query = ''' + query { + reporter(id: "UmVwb3J0ZXJUeXBlOjE="){ + firstName + } + } + ''' + result = schema.execute(query) + assert result.errors + assert result.data == { + 'reporter': None + } + + def test_node_authenticated_user_no_permissions(): class ReporterType(DjangoObjectType): @@ -193,22 +289,172 @@ def test_node_authenticated_user_with_permissions(): } -def test_mutate_and_get_payload_success(): +def test_auth_mutate_and_get_payload_anonymous(): class MyMutation(SerializerMutation): class Meta: serializer_class = MySerializer + @classmethod + @mutation_require_permission(permissions=('can_view_foo', )) + def mutate_and_get_payload(cls, root, info, **input): + return super(MyMutation, cls).mutate_and_get_payload(root, info, **input) + + context = Context(user=user_anonymous) + request = Mock(context=context, user=user_anonymous) + result = MyMutation.mutate_and_get_payload(root=None, info=request, **{ + 'text': 'value', + 'model': { + 'cool_name': 'other_value' + } + }) + assert result.errors is not None + + +def test_auth_mutate_and_get_payload_autheticated(): + + class MyMutation(SerializerMutation): + class Meta: + serializer_class = MySerializer + + @classmethod + @mutation_require_permission(permissions=('can_view_foo', )) + def mutate_and_get_payload(cls, root, info, **input): + return super(MyMutation, cls).mutate_and_get_payload(root, info, **input) + + context = Context(user=user_authenticated) + request = Mock(context=context, user=user_authenticated) + result = MyMutation.mutate_and_get_payload(root=None, info=request, **{ + 'text': 'value', + 'model': { + 'cool_name': 'other_value' + } + }) + assert result.errors is not None + + +def test_auth_mutate_and_get_payload_with_permissions(): + + class MyMutation(SerializerMutation): + class Meta: + serializer_class = MySerializer + + @classmethod @mutation_require_permission(permissions=('can_view_foo', )) def mutate_and_get_payload(cls, root, info, **input): return super(MyMutation, cls).mutate_and_get_payload(root, info, **input) context = Context(user=user_with_permissions) request = Mock(context=context, user=user_with_permissions) - result = MyMutation.mutate_and_get_payload(None, request, **{ + result = MyMutation.mutate_and_get_payload(root=None, info=request, **{ 'text': 'value', 'model': { 'cool_name': 'other_value' } }) assert result.errors is None + + +def test_auth_connection(): + + class ReporterType(DjangoObjectType): + + class Meta: + model = Reporter + interfaces = (Node, ) + + class ArticleType(DjangoObjectType): + + class Meta: + model = Article + interfaces = (Node, ) + filter_fields = ('lang', 'headline') + + class MyAuthDjangoConnectionField(DjangoConnectionField): + + @classmethod + @connection_require_permission(permissions=('can_view_foo', )) + def connection_resolver(cls, resolver, connection, default_manager, max_limit, + enforce_first_or_last, root, info, **args): + return super(MyAuthDjangoConnectionField, cls).connection_resolver( + resolver, connection, default_manager, max_limit, + enforce_first_or_last, root, info, **args) + + class Query(graphene.ObjectType): + all_reporters = MyAuthDjangoConnectionField(ReporterType) + + r = Reporter.objects.create( + first_name='John', + last_name='Doe', + email='johndoe@example.com', + a_choice=1 + ) + Article.objects.create( + headline='Article Node 1', + pub_date=datetime.date.today(), + reporter=r, + editor=r, + lang='es' + ) + Article.objects.create( + headline='Article Node 2', + pub_date=datetime.date.today(), + reporter=r, + editor=r, + lang='es' + ) + Article.objects.create( + headline='Article Node 3', + pub_date=datetime.date.today(), + reporter=r, + editor=r, + lang='en' + ) + + schema = graphene.Schema(query=Query) + query = ''' + query NodeFilteringQuery { + allReporters { + edges { + node { + id + articles(lang: "es", headline: "Article Node 1") { + edges { + node { + id + } + } + } + } + } + } + } + ''' + + expected = { + 'allReporters': { + 'edges': [{ + 'node': { + 'id': 'UmVwb3J0ZXJUeXBlOjE=', + 'articles': { + 'edges': [{ + 'node': { + 'id': 'QXJ0aWNsZVR5cGU6MQ==' + } + }] + } + } + }] + } + } + + context = Context(user=user_with_permissions) + request = Mock(context=context, user=user_with_permissions) + result = schema.execute(query, context_value=request) + assert not result.errors + assert result.data == expected + + context = Context(user=user_anonymous) + request = Mock(context=context, user=user_anonymous) + result = schema.execute(query, context_value=request) + assert result.errors