From 3209dc02adf904ac01a01b065ec3ed4f9d691c34 Mon Sep 17 00:00:00 2001 From: Thomas Stephenson Date: Thu, 26 Mar 2015 07:32:02 +1100 Subject: [PATCH] Support for PrimaryKeyRelatedFields with arbitrary types Include support for relations where the target primary key is not an AutoField or IntegerField. CharFields may already have been supported? So this is mainly to support models with UUID primary keys. --- rest_framework/relations.py | 8 +++ rest_framework/serializers.py | 11 ++++ rest_framework/utils/field_mapping.py | 7 +++ tests/test_model_serializer.py | 81 ++++++++++++++++++++++++++- tests/test_relations.py | 37 ++++++++++++ 5 files changed, 143 insertions(+), 1 deletion(-) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 3a966c5bf..169898c31 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -134,10 +134,16 @@ class PrimaryKeyRelatedField(RelatedField): 'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'), } + def __init__(self, **kwargs): + self.pk_field = kwargs.pop('pk_field', None) + super(PrimaryKeyRelatedField, self).__init__(**kwargs) + def use_pk_only_optimization(self): return True def to_internal_value(self, data): + if self.pk_field is not None: + data = self.pk_field.to_internal_value(data) try: return self.get_queryset().get(pk=data) except ObjectDoesNotExist: @@ -146,6 +152,8 @@ class PrimaryKeyRelatedField(RelatedField): self.fail('incorrect_type', data_type=type(data).__name__) def to_representation(self, value): + if self.pk_field is not None: + return self.pk_field.to_representation(value.pk) return value.pk diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 2eef6eeb5..cad9d0b44 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -11,6 +11,7 @@ python primitives. response content is handled by parsers and renderers. """ from __future__ import unicode_literals + from django.db import models from django.db.models.fields import FieldDoesNotExist, Field as DjangoModelField from django.db.models import query @@ -731,6 +732,7 @@ class ModelSerializer(Serializer): models.TimeField: TimeField, models.URLField: URLField, } + serializer_related_field = PrimaryKeyRelatedField serializer_url_field = HyperlinkedIdentityField serializer_choice_field = ChoiceField @@ -1021,6 +1023,15 @@ class ModelSerializer(Serializer): field_class = self.serializer_related_field field_kwargs = get_relation_kwargs(field_name, relation_info) + # `pk_field` is only valid for primary key relationships + pk_field_info = field_kwargs.pop('pk_field', None) + if issubclass(field_class, PrimaryKeyRelatedField) and pk_field_info is not None: + pk_field_class, pk_field_kwargs = pk_field_info + pk_field_class = self.serializer_field_mapping[pk_field_class] + if not issubclass(pk_field_class, ModelField): + pk_field_kwargs.pop('model_field', None) + field_kwargs['pk_field'] = pk_field_class(**pk_field_kwargs) + # `view_name` is only valid for hyperlinked relationships. if not issubclass(field_class, HyperlinkedRelatedField): field_kwargs.pop('view_name', None) diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index c97ec5d0e..991846d53 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -203,6 +203,13 @@ def get_relation_kwargs(field_name, relation_info): 'view_name': get_detail_view_name(related_model) } + related_pk = related_model._meta.pk + if (not isinstance(related_pk, models.AutoField) and + not getattr(related_pk, 'is_relation', False)): + pk_field_class = type(related_pk) + pk_field_kwargs = get_field_kwargs('pk', related_pk) + kwargs['pk_field'] = (pk_field_class, pk_field_kwargs) + if to_many: kwargs['many'] = True diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index bce2008a8..2ce06c0ac 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -279,6 +279,11 @@ class ThroughTargetModel(models.Model): name = models.CharField(max_length=100) +class StringForeignKeyTargetModel(models.Model): + id = models.CharField(primary_key=True, max_length=128) + name = models.CharField(max_length=100) + + class Supplementary(models.Model): extra = models.IntegerField() forwards = models.ForeignKey('ThroughTargetModel') @@ -291,6 +296,8 @@ class RelationalModel(models.Model): one_to_one = models.OneToOneField(OneToOneTargetModel, related_name='reverse_one_to_one') through = models.ManyToManyField(ThroughTargetModel, through=Supplementary, related_name='reverse_through') + string_foreign_key = models.ForeignKey(StringForeignKeyTargetModel, related_name='reverse_string_foreign_key') + class TestRelationalFieldMappings(TestCase): def test_pk_relations(self): @@ -303,6 +310,7 @@ class TestRelationalFieldMappings(TestCase): id = IntegerField(label='ID', read_only=True) foreign_key = PrimaryKeyRelatedField(queryset=ForeignKeyTargetModel.objects.all()) one_to_one = PrimaryKeyRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[]) + string_foreign_key = PrimaryKeyRelatedField(pk_field=CharField(label='Id', max_length=128, validators=[]), queryset=StringForeignKeyTargetModel.objects.all()) many_to_many = PrimaryKeyRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all()) through = PrimaryKeyRelatedField(many=True, read_only=True) """) @@ -323,6 +331,9 @@ class TestRelationalFieldMappings(TestCase): one_to_one = NestedSerializer(read_only=True): id = IntegerField(label='ID', read_only=True) name = CharField(max_length=100) + string_foreign_key = NestedSerializer(read_only=True): + id = CharField(max_length=128, validators=[]) + name = CharField(max_length=100) many_to_many = NestedSerializer(many=True, read_only=True): id = IntegerField(label='ID', read_only=True) name = CharField(max_length=100) @@ -342,6 +353,7 @@ class TestRelationalFieldMappings(TestCase): url = HyperlinkedIdentityField(view_name='relationalmodel-detail') foreign_key = HyperlinkedRelatedField(queryset=ForeignKeyTargetModel.objects.all(), view_name='foreignkeytargetmodel-detail') one_to_one = HyperlinkedRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[], view_name='onetoonetargetmodel-detail') + string_foreign_key = HyperlinkedRelatedField(queryset=StringForeignKeyTargetModel.objects.all(), view_name='stringforeignkeytargetmodel-detail') many_to_many = HyperlinkedRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail') through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail') """) @@ -362,6 +374,9 @@ class TestRelationalFieldMappings(TestCase): one_to_one = NestedSerializer(read_only=True): url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail') name = CharField(max_length=100) + string_foreign_key = NestedSerializer(read_only=True): + url = HyperlinkedIdentityField(view_name='stringforeignkeytargetmodel-detail') + name = CharField(max_length=100) many_to_many = NestedSerializer(many=True, read_only=True): url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail') name = CharField(max_length=100) @@ -427,6 +442,20 @@ class TestRelationalFieldMappings(TestCase): """) self.assertEqual(unicode_repr(TestSerializer()), expected) + def test_pk_reverse_string_foreign_key(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = StringForeignKeyTargetModel + fields = ('id', 'name', 'reverse_string_foreign_key') + + expected = dedent(""" + TestSerializer(): + id = CharField(max_length=128, validators=[]) + name = CharField(max_length=100) + reverse_string_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) + """) + self.assertEqual(unicode_repr(TestSerializer()), expected) + class TestIntegration(TestCase): def setUp(self): @@ -441,9 +470,14 @@ class TestIntegration(TestCase): name='many_to_many (%d)' % idx ) for idx in range(3) ] + self.string_foreign_key_target = StringForeignKeyTargetModel.objects.create( + id='stringified_id: 1', + name='string_foreign_key' + ) self.instance = RelationalModel.objects.create( foreign_key=self.foreign_key_target, one_to_one=self.one_to_one_target, + string_foreign_key=self.string_foreign_key_target ) self.instance.many_to_many = self.many_to_many_targets self.instance.save() @@ -459,6 +493,7 @@ class TestIntegration(TestCase): 'foreign_key': self.foreign_key_target.pk, 'one_to_one': self.one_to_one_target.pk, 'many_to_many': [item.pk for item in self.many_to_many_targets], + 'string_foreign_key': self.string_foreign_key_target.pk, 'through': [] } self.assertEqual(serializer.data, expected) @@ -479,15 +514,20 @@ class TestIntegration(TestCase): name='new many_to_many (%d)' % idx ) for idx in range(3) ] + new_string_foreign_key = StringForeignKeyTargetModel.objects.create( + id='stringified_id: 2', + name='string_foreign_key' + ) data = { 'foreign_key': new_foreign_key.pk, 'one_to_one': new_one_to_one.pk, 'many_to_many': [item.pk for item in new_many_to_many], + 'string_foreign_key': new_string_foreign_key.pk, } # Serializer should validate okay. serializer = TestSerializer(data=data) - assert serializer.is_valid() + assert serializer.is_valid(raise_exception=True) # Creating the instance, relationship attributes should be set. instance = serializer.save() @@ -498,6 +538,8 @@ class TestIntegration(TestCase): ] == [ item.pk for item in new_many_to_many ] + + assert instance.string_foreign_key.pk == new_string_foreign_key.pk assert list(instance.through.all()) == [] # Representation should be correct. @@ -506,6 +548,7 @@ class TestIntegration(TestCase): 'foreign_key': new_foreign_key.pk, 'one_to_one': new_one_to_one.pk, 'many_to_many': [item.pk for item in new_many_to_many], + 'string_foreign_key': new_string_foreign_key.pk, 'through': [] } self.assertEqual(serializer.data, expected) @@ -526,10 +569,15 @@ class TestIntegration(TestCase): name='new many_to_many (%d)' % idx ) for idx in range(3) ] + new_string_foreign_key = StringForeignKeyTargetModel.objects.create( + id='stringified_id: 2', + name='string_foreign_key' + ) data = { 'foreign_key': new_foreign_key.pk, 'one_to_one': new_one_to_one.pk, 'many_to_many': [item.pk for item in new_many_to_many], + 'string_foreign_key': new_string_foreign_key.pk } # Serializer should validate okay. @@ -545,6 +593,7 @@ class TestIntegration(TestCase): ] == [ item.pk for item in new_many_to_many ] + assert instance.string_foreign_key.pk == new_string_foreign_key.pk assert list(instance.through.all()) == [] # Representation should be correct. @@ -553,6 +602,7 @@ class TestIntegration(TestCase): 'foreign_key': new_foreign_key.pk, 'one_to_one': new_one_to_one.pk, 'many_to_many': [item.pk for item in new_many_to_many], + 'string_foreign_key': new_string_foreign_key.pk, 'through': [] } self.assertEqual(serializer.data, expected) @@ -639,3 +689,32 @@ class TestSerializerMetaClass(TestCase): str(exception), "Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer." ) +""" +{ + False = + { + = TestSerializer( + data= + { + 'foreign_key': 2, + 'many_to_many': [4, 5, 6], + 'one_to_one': 2, + + 'string_foreign_key': '' + }): id ...any=True, queryset=ManyToManyTargetModel.objects.all())\\n + through = PrimaryKeyRelatedField(many=True, read_only=True).is_valid\n + }()\n +} +""" diff --git a/tests/test_relations.py b/tests/test_relations.py index fbe176e24..45d5c1a1e 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -1,3 +1,4 @@ +import uuid from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset from django.core.exceptions import ImproperlyConfigured from django.utils.datastructures import MultiValueDict @@ -48,6 +49,42 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase): assert representation == self.instance.pk +class TestUUIDPrimaryKeyRelatedField(APISimpleTestCase): + def setUp(self): + self.queryset = MockQueryset([ + MockObject(pk=uuid.UUID(int=1), name='foo'), + MockObject(pk=uuid.UUID(int=2), name='bar'), + MockObject(pk=uuid.UUID(int=3), name='baz') + ]) + self.instance = self.queryset.items[2] + self.field = serializers.PrimaryKeyRelatedField( + pk_field=serializers.UUIDField(), + queryset=self.queryset + ) + + def test_pk_related_lookup_exists(self): + instance = self.field.to_internal_value(self.instance.pk) + assert instance is self.instance + + def test_pk_related_lookup_does_not_exist(self): + bad_value = uuid.UUID(int=4) + with pytest.raises(serializers.ValidationError) as excinfo: + self.field.to_internal_value(str(bad_value)) + msg = excinfo.value.detail[0] + assert msg == 'Invalid pk "{0}" - object does not exist.'.format(bad_value) + + def test_pk_related_lookup_invalid_type(self): + bad_value = BadType() + with pytest.raises(serializers.ValidationError) as excinfo: + self.field.to_internal_value(bad_value) + msg = excinfo.value.detail[0] + assert msg == '"{0}" is not a valid UUID.'.format(bad_value) + + def test_pk_representation(self): + representation = self.field.to_representation(self.instance) + assert representation == str(self.instance.pk) + + class TestHyperlinkedIdentityField(APISimpleTestCase): def setUp(self): self.instance = MockObject(pk=1, name='foo')