diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 54e67cd16..91612e970 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -22,6 +22,7 @@ from rest_framework.fields import ( from rest_framework.reverse import reverse from rest_framework.settings import api_settings from rest_framework.utils import html +from rest_framework.validators import ValidateSetRelationPermission def method_overridden(method_name, klass, instance): @@ -105,6 +106,12 @@ class RelatedField(Field): ) kwargs.pop('many', None) kwargs.pop('allow_empty', None) + try: + permission = kwargs.pop('permission') + except KeyError: + pass + else: + self.validators.append(ValidateSetRelationPermission(permission)) super(RelatedField, self).__init__(**kwargs) def __new__(cls, *args, **kwargs): diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 7f7740711..b413af18a 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -12,7 +12,7 @@ from django.db import DataError from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import unicode_to_repr -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import PermissionDenied, ValidationError from rest_framework.utils.representation import smart_repr @@ -284,3 +284,19 @@ class UniqueForYearValidator(BaseUniqueForValidator): filter_kwargs[self.field_name] = value filter_kwargs['%s__year' % self.date_field_name] = date.year return qs_filter(queryset, **filter_kwargs) + + +class ValidateSetRelationPermission(object): + def __init__(self, permission): + self.permission = permission + self.request = None + + def set_context(self, field): + self.field_name = field.source_attrs[-1] + self.request = field.parent.context.get('request', None) + + def __call__(self, value): + if not getattr(self.request, 'user', None): + return + if not self.request.user.has_perm(self.permission, value): + raise PermissionDenied(detail='You are not allowed to set a relationship on %s field.' % self.field_name) diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py index 887a6f423..42cadf46a 100644 --- a/tests/test_relations_hyperlink.py +++ b/tests/test_relations_hyperlink.py @@ -1,10 +1,14 @@ from __future__ import unicode_literals from django.conf.urls import url +from django.contrib.auth.models import Permission, User from django.test import TestCase, override_settings +import pytest from rest_framework import serializers +from rest_framework.exceptions import PermissionDenied from rest_framework.test import APIRequestFactory +from rest_framework.request import Request from tests.models import ( ForeignKeySource, ForeignKeyTarget, ManyToManySource, ManyToManyTarget, NullableForeignKeySource, NullableOneToOneSource, OneToOneTarget @@ -56,6 +60,15 @@ class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): fields = ('url', 'name', 'target') +class ForeignKeySourceSerializerWithPermission(serializers.HyperlinkedModelSerializer): + class Meta: + model = ForeignKeySource + fields = ('url', 'name', 'target') + extra_kwargs = { + 'target': {'permission': 'tests.add_foreignkeytarget'} + } + + # Nullable ForeignKey class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: @@ -206,6 +219,7 @@ class HyperlinkedForeignKeyTests(TestCase): for idx in range(1, 4): source = ForeignKeySource(name='source-%d' % idx, target=target) source.save() + User.objects.create_user('permitted', 'permitted@example.com', 'password') def test_foreign_key_retrieve(self): queryset = ForeignKeySource.objects.all() @@ -324,6 +338,64 @@ class HyperlinkedForeignKeyTests(TestCase): assert not serializer.is_valid() assert serializer.errors == {'target': ['This field may not be null.']} + def test_foreign_key_create_with_permission(self): + from django.contrib.auth.backends import ModelBackend + + def get_all_permissions(self, user_obj, obj=None): + if not user_obj.is_active or user_obj.is_anonymous: + return set() + if not hasattr(user_obj, '_perm_cache'): + user_obj._perm_cache = self.get_user_permissions(user_obj) + user_obj._perm_cache.update(self.get_group_permissions(user_obj)) + return user_obj._perm_cache + + def _get_permissions(self, user_obj, obj, from_name): + if not user_obj.is_active or user_obj.is_anonymous: + return set() + + perm_cache_name = '_%s_perm_cache' % from_name + if not hasattr(user_obj, perm_cache_name): + if user_obj.is_superuser: + perms = Permission.objects.all() + else: + perms = getattr(self, '_get_%s_permissions' % from_name)(user_obj) + perms = perms.values_list('content_type__app_label', 'codename').order_by() + setattr(user_obj, perm_cache_name, set("%s.%s" % (ct, name) for ct, name in perms)) + return getattr(user_obj, perm_cache_name) + + # normally django.contrib.auth.backends.ModelBackend doesn't accept + # permission checking if an object is passed. + # Here the monkey patching pretend it is OK to pass such obj. + original_get_all_permissions = ModelBackend.get_all_permissions + original_get_permissions = ModelBackend._get_permissions + + try: + ModelBackend.get_all_permissions = get_all_permissions + ModelBackend._get_permissions = _get_permissions + request = Request(factory.post('/')) + user = User.objects.get(username='permitted') + request.user = user + + data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'} + serializer = ForeignKeySourceSerializerWithPermission(data=data, context={'request': request}) + with pytest.raises(PermissionDenied) as excinfo: + serializer.is_valid() + assert ('You are not allowed to set a relationship on target field.' in str(excinfo.value)) + + permission = Permission.objects.get(codename='add_foreignkeytarget') + user.user_permissions.add(permission) + user = User.objects.get(username='permitted') + request.user = user + + serializer = ForeignKeySourceSerializerWithPermission(data=data, context={'request': request}) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == data + assert obj.target == ForeignKeyTarget.objects.get(name='target-2') + finally: + ModelBackend.get_all_permissions = original_get_all_permissions + ModelBackend._get_permissions = original_get_permissions + @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') class HyperlinkedNullableForeignKeyTests(TestCase):