mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-07 05:50:13 +03:00
Support for PrimaryKeyRelatedFields with arbitrary types
Include support for relations where the target primary key is not an AutoField or IntegerField. CharFields may already have been supported? So this is mainly to support models with UUID primary keys.
This commit is contained in:
parent
079da5e7f0
commit
3209dc02ad
|
@ -134,10 +134,16 @@ class PrimaryKeyRelatedField(RelatedField):
|
|||
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.pk_field = kwargs.pop('pk_field', None)
|
||||
super(PrimaryKeyRelatedField, self).__init__(**kwargs)
|
||||
|
||||
def use_pk_only_optimization(self):
|
||||
return True
|
||||
|
||||
def to_internal_value(self, data):
|
||||
if self.pk_field is not None:
|
||||
data = self.pk_field.to_internal_value(data)
|
||||
try:
|
||||
return self.get_queryset().get(pk=data)
|
||||
except ObjectDoesNotExist:
|
||||
|
@ -146,6 +152,8 @@ class PrimaryKeyRelatedField(RelatedField):
|
|||
self.fail('incorrect_type', data_type=type(data).__name__)
|
||||
|
||||
def to_representation(self, value):
|
||||
if self.pk_field is not None:
|
||||
return self.pk_field.to_representation(value.pk)
|
||||
return value.pk
|
||||
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ python primitives.
|
|||
response content is handled by parsers and renderers.
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.db import models
|
||||
from django.db.models.fields import FieldDoesNotExist, Field as DjangoModelField
|
||||
from django.db.models import query
|
||||
|
@ -731,6 +732,7 @@ class ModelSerializer(Serializer):
|
|||
models.TimeField: TimeField,
|
||||
models.URLField: URLField,
|
||||
}
|
||||
|
||||
serializer_related_field = PrimaryKeyRelatedField
|
||||
serializer_url_field = HyperlinkedIdentityField
|
||||
serializer_choice_field = ChoiceField
|
||||
|
@ -1021,6 +1023,15 @@ class ModelSerializer(Serializer):
|
|||
field_class = self.serializer_related_field
|
||||
field_kwargs = get_relation_kwargs(field_name, relation_info)
|
||||
|
||||
# `pk_field` is only valid for primary key relationships
|
||||
pk_field_info = field_kwargs.pop('pk_field', None)
|
||||
if issubclass(field_class, PrimaryKeyRelatedField) and pk_field_info is not None:
|
||||
pk_field_class, pk_field_kwargs = pk_field_info
|
||||
pk_field_class = self.serializer_field_mapping[pk_field_class]
|
||||
if not issubclass(pk_field_class, ModelField):
|
||||
pk_field_kwargs.pop('model_field', None)
|
||||
field_kwargs['pk_field'] = pk_field_class(**pk_field_kwargs)
|
||||
|
||||
# `view_name` is only valid for hyperlinked relationships.
|
||||
if not issubclass(field_class, HyperlinkedRelatedField):
|
||||
field_kwargs.pop('view_name', None)
|
||||
|
|
|
@ -203,6 +203,13 @@ def get_relation_kwargs(field_name, relation_info):
|
|||
'view_name': get_detail_view_name(related_model)
|
||||
}
|
||||
|
||||
related_pk = related_model._meta.pk
|
||||
if (not isinstance(related_pk, models.AutoField) and
|
||||
not getattr(related_pk, 'is_relation', False)):
|
||||
pk_field_class = type(related_pk)
|
||||
pk_field_kwargs = get_field_kwargs('pk', related_pk)
|
||||
kwargs['pk_field'] = (pk_field_class, pk_field_kwargs)
|
||||
|
||||
if to_many:
|
||||
kwargs['many'] = True
|
||||
|
||||
|
|
|
@ -279,6 +279,11 @@ class ThroughTargetModel(models.Model):
|
|||
name = models.CharField(max_length=100)
|
||||
|
||||
|
||||
class StringForeignKeyTargetModel(models.Model):
|
||||
id = models.CharField(primary_key=True, max_length=128)
|
||||
name = models.CharField(max_length=100)
|
||||
|
||||
|
||||
class Supplementary(models.Model):
|
||||
extra = models.IntegerField()
|
||||
forwards = models.ForeignKey('ThroughTargetModel')
|
||||
|
@ -291,6 +296,8 @@ class RelationalModel(models.Model):
|
|||
one_to_one = models.OneToOneField(OneToOneTargetModel, related_name='reverse_one_to_one')
|
||||
through = models.ManyToManyField(ThroughTargetModel, through=Supplementary, related_name='reverse_through')
|
||||
|
||||
string_foreign_key = models.ForeignKey(StringForeignKeyTargetModel, related_name='reverse_string_foreign_key')
|
||||
|
||||
|
||||
class TestRelationalFieldMappings(TestCase):
|
||||
def test_pk_relations(self):
|
||||
|
@ -303,6 +310,7 @@ class TestRelationalFieldMappings(TestCase):
|
|||
id = IntegerField(label='ID', read_only=True)
|
||||
foreign_key = PrimaryKeyRelatedField(queryset=ForeignKeyTargetModel.objects.all())
|
||||
one_to_one = PrimaryKeyRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[<UniqueValidator(queryset=RelationalModel.objects.all())>])
|
||||
string_foreign_key = PrimaryKeyRelatedField(pk_field=CharField(label='Id', max_length=128, validators=[<UniqueValidator(queryset=StringForeignKeyTargetModel.objects.all())>]), queryset=StringForeignKeyTargetModel.objects.all())
|
||||
many_to_many = PrimaryKeyRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all())
|
||||
through = PrimaryKeyRelatedField(many=True, read_only=True)
|
||||
""")
|
||||
|
@ -323,6 +331,9 @@ class TestRelationalFieldMappings(TestCase):
|
|||
one_to_one = NestedSerializer(read_only=True):
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
name = CharField(max_length=100)
|
||||
string_foreign_key = NestedSerializer(read_only=True):
|
||||
id = CharField(max_length=128, validators=[<UniqueValidator(queryset=StringForeignKeyTargetModel.objects.all())>])
|
||||
name = CharField(max_length=100)
|
||||
many_to_many = NestedSerializer(many=True, read_only=True):
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
name = CharField(max_length=100)
|
||||
|
@ -342,6 +353,7 @@ class TestRelationalFieldMappings(TestCase):
|
|||
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
|
||||
foreign_key = HyperlinkedRelatedField(queryset=ForeignKeyTargetModel.objects.all(), view_name='foreignkeytargetmodel-detail')
|
||||
one_to_one = HyperlinkedRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[<UniqueValidator(queryset=RelationalModel.objects.all())>], view_name='onetoonetargetmodel-detail')
|
||||
string_foreign_key = HyperlinkedRelatedField(queryset=StringForeignKeyTargetModel.objects.all(), view_name='stringforeignkeytargetmodel-detail')
|
||||
many_to_many = HyperlinkedRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail')
|
||||
through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail')
|
||||
""")
|
||||
|
@ -362,6 +374,9 @@ class TestRelationalFieldMappings(TestCase):
|
|||
one_to_one = NestedSerializer(read_only=True):
|
||||
url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')
|
||||
name = CharField(max_length=100)
|
||||
string_foreign_key = NestedSerializer(read_only=True):
|
||||
url = HyperlinkedIdentityField(view_name='stringforeignkeytargetmodel-detail')
|
||||
name = CharField(max_length=100)
|
||||
many_to_many = NestedSerializer(many=True, read_only=True):
|
||||
url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail')
|
||||
name = CharField(max_length=100)
|
||||
|
@ -427,6 +442,20 @@ class TestRelationalFieldMappings(TestCase):
|
|||
""")
|
||||
self.assertEqual(unicode_repr(TestSerializer()), expected)
|
||||
|
||||
def test_pk_reverse_string_foreign_key(self):
|
||||
class TestSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = StringForeignKeyTargetModel
|
||||
fields = ('id', 'name', 'reverse_string_foreign_key')
|
||||
|
||||
expected = dedent("""
|
||||
TestSerializer():
|
||||
id = CharField(max_length=128, validators=[<UniqueValidator(queryset=StringForeignKeyTargetModel.objects.all())>])
|
||||
name = CharField(max_length=100)
|
||||
reverse_string_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
|
||||
""")
|
||||
self.assertEqual(unicode_repr(TestSerializer()), expected)
|
||||
|
||||
|
||||
class TestIntegration(TestCase):
|
||||
def setUp(self):
|
||||
|
@ -441,9 +470,14 @@ class TestIntegration(TestCase):
|
|||
name='many_to_many (%d)' % idx
|
||||
) for idx in range(3)
|
||||
]
|
||||
self.string_foreign_key_target = StringForeignKeyTargetModel.objects.create(
|
||||
id='stringified_id: 1',
|
||||
name='string_foreign_key'
|
||||
)
|
||||
self.instance = RelationalModel.objects.create(
|
||||
foreign_key=self.foreign_key_target,
|
||||
one_to_one=self.one_to_one_target,
|
||||
string_foreign_key=self.string_foreign_key_target
|
||||
)
|
||||
self.instance.many_to_many = self.many_to_many_targets
|
||||
self.instance.save()
|
||||
|
@ -459,6 +493,7 @@ class TestIntegration(TestCase):
|
|||
'foreign_key': self.foreign_key_target.pk,
|
||||
'one_to_one': self.one_to_one_target.pk,
|
||||
'many_to_many': [item.pk for item in self.many_to_many_targets],
|
||||
'string_foreign_key': self.string_foreign_key_target.pk,
|
||||
'through': []
|
||||
}
|
||||
self.assertEqual(serializer.data, expected)
|
||||
|
@ -479,15 +514,20 @@ class TestIntegration(TestCase):
|
|||
name='new many_to_many (%d)' % idx
|
||||
) for idx in range(3)
|
||||
]
|
||||
new_string_foreign_key = StringForeignKeyTargetModel.objects.create(
|
||||
id='stringified_id: 2',
|
||||
name='string_foreign_key'
|
||||
)
|
||||
data = {
|
||||
'foreign_key': new_foreign_key.pk,
|
||||
'one_to_one': new_one_to_one.pk,
|
||||
'many_to_many': [item.pk for item in new_many_to_many],
|
||||
'string_foreign_key': new_string_foreign_key.pk,
|
||||
}
|
||||
|
||||
# Serializer should validate okay.
|
||||
serializer = TestSerializer(data=data)
|
||||
assert serializer.is_valid()
|
||||
assert serializer.is_valid(raise_exception=True)
|
||||
|
||||
# Creating the instance, relationship attributes should be set.
|
||||
instance = serializer.save()
|
||||
|
@ -498,6 +538,8 @@ class TestIntegration(TestCase):
|
|||
] == [
|
||||
item.pk for item in new_many_to_many
|
||||
]
|
||||
|
||||
assert instance.string_foreign_key.pk == new_string_foreign_key.pk
|
||||
assert list(instance.through.all()) == []
|
||||
|
||||
# Representation should be correct.
|
||||
|
@ -506,6 +548,7 @@ class TestIntegration(TestCase):
|
|||
'foreign_key': new_foreign_key.pk,
|
||||
'one_to_one': new_one_to_one.pk,
|
||||
'many_to_many': [item.pk for item in new_many_to_many],
|
||||
'string_foreign_key': new_string_foreign_key.pk,
|
||||
'through': []
|
||||
}
|
||||
self.assertEqual(serializer.data, expected)
|
||||
|
@ -526,10 +569,15 @@ class TestIntegration(TestCase):
|
|||
name='new many_to_many (%d)' % idx
|
||||
) for idx in range(3)
|
||||
]
|
||||
new_string_foreign_key = StringForeignKeyTargetModel.objects.create(
|
||||
id='stringified_id: 2',
|
||||
name='string_foreign_key'
|
||||
)
|
||||
data = {
|
||||
'foreign_key': new_foreign_key.pk,
|
||||
'one_to_one': new_one_to_one.pk,
|
||||
'many_to_many': [item.pk for item in new_many_to_many],
|
||||
'string_foreign_key': new_string_foreign_key.pk
|
||||
}
|
||||
|
||||
# Serializer should validate okay.
|
||||
|
@ -545,6 +593,7 @@ class TestIntegration(TestCase):
|
|||
] == [
|
||||
item.pk for item in new_many_to_many
|
||||
]
|
||||
assert instance.string_foreign_key.pk == new_string_foreign_key.pk
|
||||
assert list(instance.through.all()) == []
|
||||
|
||||
# Representation should be correct.
|
||||
|
@ -553,6 +602,7 @@ class TestIntegration(TestCase):
|
|||
'foreign_key': new_foreign_key.pk,
|
||||
'one_to_one': new_one_to_one.pk,
|
||||
'many_to_many': [item.pk for item in new_many_to_many],
|
||||
'string_foreign_key': new_string_foreign_key.pk,
|
||||
'through': []
|
||||
}
|
||||
self.assertEqual(serializer.data, expected)
|
||||
|
@ -639,3 +689,32 @@ class TestSerializerMetaClass(TestCase):
|
|||
str(exception),
|
||||
"Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer."
|
||||
)
|
||||
"""
|
||||
{
|
||||
False = <bound method TestSerializer.is_valid of TestSerializer(
|
||||
data=
|
||||
{
|
||||
'foreign_key': 2,
|
||||
'many_to_many': [4, 5, 6],
|
||||
'one_to_one...ny=True,
|
||||
queryset=ManyToManyTargetModel.objects.all())
|
||||
through = PrimaryKeyRelatedField(many=True, read_only=True)>
|
||||
{
|
||||
<bound method TestSerializer.is_valid of TestSerializer(data=
|
||||
{
|
||||
'foreign_key': 2,
|
||||
'many_to_many': [4, 5, 6],
|
||||
'one_to_one...ny=True,
|
||||
queryset=ManyToManyTargetModel.objects.all())\\n through = PrimaryKeyRelatedField(many=True, read_only=True)> = TestSerializer(
|
||||
data=
|
||||
{
|
||||
'foreign_key': 2,
|
||||
'many_to_many': [4, 5, 6],
|
||||
'one_to_one': 2,
|
||||
|
||||
'string_foreign_key': ''
|
||||
}): id ...any=True, queryset=ManyToManyTargetModel.objects.all())\\n
|
||||
through = PrimaryKeyRelatedField(many=True, read_only=True).is_valid\n
|
||||
}()\n
|
||||
}
|
||||
"""
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import uuid
|
||||
from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.utils.datastructures import MultiValueDict
|
||||
|
@ -48,6 +49,42 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase):
|
|||
assert representation == self.instance.pk
|
||||
|
||||
|
||||
class TestUUIDPrimaryKeyRelatedField(APISimpleTestCase):
|
||||
def setUp(self):
|
||||
self.queryset = MockQueryset([
|
||||
MockObject(pk=uuid.UUID(int=1), name='foo'),
|
||||
MockObject(pk=uuid.UUID(int=2), name='bar'),
|
||||
MockObject(pk=uuid.UUID(int=3), name='baz')
|
||||
])
|
||||
self.instance = self.queryset.items[2]
|
||||
self.field = serializers.PrimaryKeyRelatedField(
|
||||
pk_field=serializers.UUIDField(),
|
||||
queryset=self.queryset
|
||||
)
|
||||
|
||||
def test_pk_related_lookup_exists(self):
|
||||
instance = self.field.to_internal_value(self.instance.pk)
|
||||
assert instance is self.instance
|
||||
|
||||
def test_pk_related_lookup_does_not_exist(self):
|
||||
bad_value = uuid.UUID(int=4)
|
||||
with pytest.raises(serializers.ValidationError) as excinfo:
|
||||
self.field.to_internal_value(str(bad_value))
|
||||
msg = excinfo.value.detail[0]
|
||||
assert msg == 'Invalid pk "{0}" - object does not exist.'.format(bad_value)
|
||||
|
||||
def test_pk_related_lookup_invalid_type(self):
|
||||
bad_value = BadType()
|
||||
with pytest.raises(serializers.ValidationError) as excinfo:
|
||||
self.field.to_internal_value(bad_value)
|
||||
msg = excinfo.value.detail[0]
|
||||
assert msg == '"{0}" is not a valid UUID.'.format(bad_value)
|
||||
|
||||
def test_pk_representation(self):
|
||||
representation = self.field.to_representation(self.instance)
|
||||
assert representation == str(self.instance.pk)
|
||||
|
||||
|
||||
class TestHyperlinkedIdentityField(APISimpleTestCase):
|
||||
def setUp(self):
|
||||
self.instance = MockObject(pk=1, name='foo')
|
||||
|
|
Loading…
Reference in New Issue
Block a user