diff --git a/rest_framework/fields.py b/rest_framework/fields.py index bea773001..de0cf4eaa 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -639,20 +639,37 @@ class URLField(CharField): class UUIDField(Field): + valid_formats = ('hex_verbose', 'hex', 'int', 'urn') + default_error_messages = { 'invalid': _('"{value}" is not a valid UUID.'), } + def __init__(self, **kwargs): + self.uuid_format = kwargs.pop('format', None) or 'hex_verbose' + if self.uuid_format not in self.valid_formats: + raise ValueError( + 'Invalid format for uuid representation. ' + 'Must be one of "{0}"'.format('", "'.join(self.valid_formats)) + ) + super(UUIDField, self).__init__(**kwargs) + def to_internal_value(self, data): if not isinstance(data, uuid.UUID): try: - return uuid.UUID(data) + if self.uuid_format == 'int': + return uuid.UUID(int=data) + else: + return uuid.UUID(hex=data) except (ValueError, TypeError): self.fail('invalid', value=data) return data def to_representation(self, value): - return str(value) + if self.uuid_format == 'hex_verbose': + return str(value) + else: + return getattr(value, self.uuid_format) # Number types... diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 3a966c5bf..169898c31 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -134,10 +134,16 @@ class PrimaryKeyRelatedField(RelatedField): 'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'), } + def __init__(self, **kwargs): + self.pk_field = kwargs.pop('pk_field', None) + super(PrimaryKeyRelatedField, self).__init__(**kwargs) + def use_pk_only_optimization(self): return True def to_internal_value(self, data): + if self.pk_field is not None: + data = self.pk_field.to_internal_value(data) try: return self.get_queryset().get(pk=data) except ObjectDoesNotExist: @@ -146,6 +152,8 @@ class PrimaryKeyRelatedField(RelatedField): self.fail('incorrect_type', data_type=type(data).__name__) def to_representation(self, value): + if self.pk_field is not None: + return self.pk_field.to_representation(value.pk) return value.pk diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 2eef6eeb5..cad9d0b44 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -11,6 +11,7 @@ python primitives. response content is handled by parsers and renderers. """ from __future__ import unicode_literals + from django.db import models from django.db.models.fields import FieldDoesNotExist, Field as DjangoModelField from django.db.models import query @@ -731,6 +732,7 @@ class ModelSerializer(Serializer): models.TimeField: TimeField, models.URLField: URLField, } + serializer_related_field = PrimaryKeyRelatedField serializer_url_field = HyperlinkedIdentityField serializer_choice_field = ChoiceField @@ -1021,6 +1023,15 @@ class ModelSerializer(Serializer): field_class = self.serializer_related_field field_kwargs = get_relation_kwargs(field_name, relation_info) + # `pk_field` is only valid for primary key relationships + pk_field_info = field_kwargs.pop('pk_field', None) + if issubclass(field_class, PrimaryKeyRelatedField) and pk_field_info is not None: + pk_field_class, pk_field_kwargs = pk_field_info + pk_field_class = self.serializer_field_mapping[pk_field_class] + if not issubclass(pk_field_class, ModelField): + pk_field_kwargs.pop('model_field', None) + field_kwargs['pk_field'] = pk_field_class(**pk_field_kwargs) + # `view_name` is only valid for hyperlinked relationships. if not issubclass(field_class, HyperlinkedRelatedField): field_kwargs.pop('view_name', None) diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index c97ec5d0e..991846d53 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -203,6 +203,13 @@ def get_relation_kwargs(field_name, relation_info): 'view_name': get_detail_view_name(related_model) } + related_pk = related_model._meta.pk + if (not isinstance(related_pk, models.AutoField) and + not getattr(related_pk, 'is_relation', False)): + pk_field_class = type(related_pk) + pk_field_kwargs = get_field_kwargs('pk', related_pk) + kwargs['pk_field'] = (pk_field_class, pk_field_kwargs) + if to_many: kwargs['many'] = True diff --git a/tests/test_fields.py b/tests/test_fields.py index 1aa528da6..63288e372 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -526,7 +526,8 @@ class TestUUIDField(FieldValues): """ valid_inputs = { '825d7aeb-05a9-45b5-a5b7-05df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'), - '825d7aeb05a945b5a5b705df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda') + '825d7aeb05a945b5a5b705df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'), + 'urn:uuid:213b7d9b-244f-410d-828c-dabce7a2615d': uuid.UUID('213b7d9b-244f-410d-828c-dabce7a2615d'), } invalid_inputs = { '825d7aeb-05a9-45b5-a5b7': ['"825d7aeb-05a9-45b5-a5b7" is not a valid UUID.'] @@ -536,6 +537,18 @@ class TestUUIDField(FieldValues): } field = serializers.UUIDField() + def _test_format(self, uuid_format, formatted_uuid_0): + field = serializers.UUIDField(format=uuid_format) + assert field.to_representation(uuid.UUID(int=0)) == formatted_uuid_0 + assert field.to_internal_value(formatted_uuid_0) == uuid.UUID(int=0) + + def test_formats(self): + self._test_format('int', 0) + self._test_format(None, '00000000-0000-0000-0000-000000000000') + self._test_format('hex_verbose', '00000000-0000-0000-0000-000000000000') + self._test_format('urn', 'urn:uuid:00000000-0000-0000-0000-000000000000') + self._test_format('hex', '0' * 32) + # Number types... diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index bce2008a8..2ce06c0ac 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -279,6 +279,11 @@ class ThroughTargetModel(models.Model): name = models.CharField(max_length=100) +class StringForeignKeyTargetModel(models.Model): + id = models.CharField(primary_key=True, max_length=128) + name = models.CharField(max_length=100) + + class Supplementary(models.Model): extra = models.IntegerField() forwards = models.ForeignKey('ThroughTargetModel') @@ -291,6 +296,8 @@ class RelationalModel(models.Model): one_to_one = models.OneToOneField(OneToOneTargetModel, related_name='reverse_one_to_one') through = models.ManyToManyField(ThroughTargetModel, through=Supplementary, related_name='reverse_through') + string_foreign_key = models.ForeignKey(StringForeignKeyTargetModel, related_name='reverse_string_foreign_key') + class TestRelationalFieldMappings(TestCase): def test_pk_relations(self): @@ -303,6 +310,7 @@ class TestRelationalFieldMappings(TestCase): id = IntegerField(label='ID', read_only=True) foreign_key = PrimaryKeyRelatedField(queryset=ForeignKeyTargetModel.objects.all()) one_to_one = PrimaryKeyRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[]) + string_foreign_key = PrimaryKeyRelatedField(pk_field=CharField(label='Id', max_length=128, validators=[]), queryset=StringForeignKeyTargetModel.objects.all()) many_to_many = PrimaryKeyRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all()) through = PrimaryKeyRelatedField(many=True, read_only=True) """) @@ -323,6 +331,9 @@ class TestRelationalFieldMappings(TestCase): one_to_one = NestedSerializer(read_only=True): id = IntegerField(label='ID', read_only=True) name = CharField(max_length=100) + string_foreign_key = NestedSerializer(read_only=True): + id = CharField(max_length=128, validators=[]) + name = CharField(max_length=100) many_to_many = NestedSerializer(many=True, read_only=True): id = IntegerField(label='ID', read_only=True) name = CharField(max_length=100) @@ -342,6 +353,7 @@ class TestRelationalFieldMappings(TestCase): url = HyperlinkedIdentityField(view_name='relationalmodel-detail') foreign_key = HyperlinkedRelatedField(queryset=ForeignKeyTargetModel.objects.all(), view_name='foreignkeytargetmodel-detail') one_to_one = HyperlinkedRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[], view_name='onetoonetargetmodel-detail') + string_foreign_key = HyperlinkedRelatedField(queryset=StringForeignKeyTargetModel.objects.all(), view_name='stringforeignkeytargetmodel-detail') many_to_many = HyperlinkedRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail') through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail') """) @@ -362,6 +374,9 @@ class TestRelationalFieldMappings(TestCase): one_to_one = NestedSerializer(read_only=True): url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail') name = CharField(max_length=100) + string_foreign_key = NestedSerializer(read_only=True): + url = HyperlinkedIdentityField(view_name='stringforeignkeytargetmodel-detail') + name = CharField(max_length=100) many_to_many = NestedSerializer(many=True, read_only=True): url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail') name = CharField(max_length=100) @@ -427,6 +442,20 @@ class TestRelationalFieldMappings(TestCase): """) self.assertEqual(unicode_repr(TestSerializer()), expected) + def test_pk_reverse_string_foreign_key(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = StringForeignKeyTargetModel + fields = ('id', 'name', 'reverse_string_foreign_key') + + expected = dedent(""" + TestSerializer(): + id = CharField(max_length=128, validators=[]) + name = CharField(max_length=100) + reverse_string_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) + """) + self.assertEqual(unicode_repr(TestSerializer()), expected) + class TestIntegration(TestCase): def setUp(self): @@ -441,9 +470,14 @@ class TestIntegration(TestCase): name='many_to_many (%d)' % idx ) for idx in range(3) ] + self.string_foreign_key_target = StringForeignKeyTargetModel.objects.create( + id='stringified_id: 1', + name='string_foreign_key' + ) self.instance = RelationalModel.objects.create( foreign_key=self.foreign_key_target, one_to_one=self.one_to_one_target, + string_foreign_key=self.string_foreign_key_target ) self.instance.many_to_many = self.many_to_many_targets self.instance.save() @@ -459,6 +493,7 @@ class TestIntegration(TestCase): 'foreign_key': self.foreign_key_target.pk, 'one_to_one': self.one_to_one_target.pk, 'many_to_many': [item.pk for item in self.many_to_many_targets], + 'string_foreign_key': self.string_foreign_key_target.pk, 'through': [] } self.assertEqual(serializer.data, expected) @@ -479,15 +514,20 @@ class TestIntegration(TestCase): name='new many_to_many (%d)' % idx ) for idx in range(3) ] + new_string_foreign_key = StringForeignKeyTargetModel.objects.create( + id='stringified_id: 2', + name='string_foreign_key' + ) data = { 'foreign_key': new_foreign_key.pk, 'one_to_one': new_one_to_one.pk, 'many_to_many': [item.pk for item in new_many_to_many], + 'string_foreign_key': new_string_foreign_key.pk, } # Serializer should validate okay. serializer = TestSerializer(data=data) - assert serializer.is_valid() + assert serializer.is_valid(raise_exception=True) # Creating the instance, relationship attributes should be set. instance = serializer.save() @@ -498,6 +538,8 @@ class TestIntegration(TestCase): ] == [ item.pk for item in new_many_to_many ] + + assert instance.string_foreign_key.pk == new_string_foreign_key.pk assert list(instance.through.all()) == [] # Representation should be correct. @@ -506,6 +548,7 @@ class TestIntegration(TestCase): 'foreign_key': new_foreign_key.pk, 'one_to_one': new_one_to_one.pk, 'many_to_many': [item.pk for item in new_many_to_many], + 'string_foreign_key': new_string_foreign_key.pk, 'through': [] } self.assertEqual(serializer.data, expected) @@ -526,10 +569,15 @@ class TestIntegration(TestCase): name='new many_to_many (%d)' % idx ) for idx in range(3) ] + new_string_foreign_key = StringForeignKeyTargetModel.objects.create( + id='stringified_id: 2', + name='string_foreign_key' + ) data = { 'foreign_key': new_foreign_key.pk, 'one_to_one': new_one_to_one.pk, 'many_to_many': [item.pk for item in new_many_to_many], + 'string_foreign_key': new_string_foreign_key.pk } # Serializer should validate okay. @@ -545,6 +593,7 @@ class TestIntegration(TestCase): ] == [ item.pk for item in new_many_to_many ] + assert instance.string_foreign_key.pk == new_string_foreign_key.pk assert list(instance.through.all()) == [] # Representation should be correct. @@ -553,6 +602,7 @@ class TestIntegration(TestCase): 'foreign_key': new_foreign_key.pk, 'one_to_one': new_one_to_one.pk, 'many_to_many': [item.pk for item in new_many_to_many], + 'string_foreign_key': new_string_foreign_key.pk, 'through': [] } self.assertEqual(serializer.data, expected) @@ -639,3 +689,32 @@ class TestSerializerMetaClass(TestCase): str(exception), "Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer." ) +""" +{ + False = + { + = TestSerializer( + data= + { + 'foreign_key': 2, + 'many_to_many': [4, 5, 6], + 'one_to_one': 2, + + 'string_foreign_key': '' + }): id ...any=True, queryset=ManyToManyTargetModel.objects.all())\\n + through = PrimaryKeyRelatedField(many=True, read_only=True).is_valid\n + }()\n +} +""" diff --git a/tests/test_relations.py b/tests/test_relations.py index fbe176e24..45d5c1a1e 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -1,3 +1,4 @@ +import uuid from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset from django.core.exceptions import ImproperlyConfigured from django.utils.datastructures import MultiValueDict @@ -48,6 +49,42 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase): assert representation == self.instance.pk +class TestUUIDPrimaryKeyRelatedField(APISimpleTestCase): + def setUp(self): + self.queryset = MockQueryset([ + MockObject(pk=uuid.UUID(int=1), name='foo'), + MockObject(pk=uuid.UUID(int=2), name='bar'), + MockObject(pk=uuid.UUID(int=3), name='baz') + ]) + self.instance = self.queryset.items[2] + self.field = serializers.PrimaryKeyRelatedField( + pk_field=serializers.UUIDField(), + queryset=self.queryset + ) + + def test_pk_related_lookup_exists(self): + instance = self.field.to_internal_value(self.instance.pk) + assert instance is self.instance + + def test_pk_related_lookup_does_not_exist(self): + bad_value = uuid.UUID(int=4) + with pytest.raises(serializers.ValidationError) as excinfo: + self.field.to_internal_value(str(bad_value)) + msg = excinfo.value.detail[0] + assert msg == 'Invalid pk "{0}" - object does not exist.'.format(bad_value) + + def test_pk_related_lookup_invalid_type(self): + bad_value = BadType() + with pytest.raises(serializers.ValidationError) as excinfo: + self.field.to_internal_value(bad_value) + msg = excinfo.value.detail[0] + assert msg == '"{0}" is not a valid UUID.'.format(bad_value) + + def test_pk_representation(self): + representation = self.field.to_representation(self.instance) + assert representation == str(self.instance.pk) + + class TestHyperlinkedIdentityField(APISimpleTestCase): def setUp(self): self.instance = MockObject(pk=1, name='foo')