Support for PrimaryKeyRelatedFields with arbitrary types

Include support for relations where the target primary key
is not an AutoField or IntegerField.

CharFields may already have been supported? So this is mainly
to support models with UUID primary keys.
This commit is contained in:
Thomas Stephenson 2015-03-26 07:32:02 +11:00
parent 079da5e7f0
commit 3209dc02ad
5 changed files with 143 additions and 1 deletions

View File

@ -134,10 +134,16 @@ class PrimaryKeyRelatedField(RelatedField):
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
}
def __init__(self, **kwargs):
self.pk_field = kwargs.pop('pk_field', None)
super(PrimaryKeyRelatedField, self).__init__(**kwargs)
def use_pk_only_optimization(self):
return True
def to_internal_value(self, data):
if self.pk_field is not None:
data = self.pk_field.to_internal_value(data)
try:
return self.get_queryset().get(pk=data)
except ObjectDoesNotExist:
@ -146,6 +152,8 @@ class PrimaryKeyRelatedField(RelatedField):
self.fail('incorrect_type', data_type=type(data).__name__)
def to_representation(self, value):
if self.pk_field is not None:
return self.pk_field.to_representation(value.pk)
return value.pk

View File

@ -11,6 +11,7 @@ python primitives.
response content is handled by parsers and renderers.
"""
from __future__ import unicode_literals
from django.db import models
from django.db.models.fields import FieldDoesNotExist, Field as DjangoModelField
from django.db.models import query
@ -731,6 +732,7 @@ class ModelSerializer(Serializer):
models.TimeField: TimeField,
models.URLField: URLField,
}
serializer_related_field = PrimaryKeyRelatedField
serializer_url_field = HyperlinkedIdentityField
serializer_choice_field = ChoiceField
@ -1021,6 +1023,15 @@ class ModelSerializer(Serializer):
field_class = self.serializer_related_field
field_kwargs = get_relation_kwargs(field_name, relation_info)
# `pk_field` is only valid for primary key relationships
pk_field_info = field_kwargs.pop('pk_field', None)
if issubclass(field_class, PrimaryKeyRelatedField) and pk_field_info is not None:
pk_field_class, pk_field_kwargs = pk_field_info
pk_field_class = self.serializer_field_mapping[pk_field_class]
if not issubclass(pk_field_class, ModelField):
pk_field_kwargs.pop('model_field', None)
field_kwargs['pk_field'] = pk_field_class(**pk_field_kwargs)
# `view_name` is only valid for hyperlinked relationships.
if not issubclass(field_class, HyperlinkedRelatedField):
field_kwargs.pop('view_name', None)

View File

@ -203,6 +203,13 @@ def get_relation_kwargs(field_name, relation_info):
'view_name': get_detail_view_name(related_model)
}
related_pk = related_model._meta.pk
if (not isinstance(related_pk, models.AutoField) and
not getattr(related_pk, 'is_relation', False)):
pk_field_class = type(related_pk)
pk_field_kwargs = get_field_kwargs('pk', related_pk)
kwargs['pk_field'] = (pk_field_class, pk_field_kwargs)
if to_many:
kwargs['many'] = True

View File

@ -279,6 +279,11 @@ class ThroughTargetModel(models.Model):
name = models.CharField(max_length=100)
class StringForeignKeyTargetModel(models.Model):
id = models.CharField(primary_key=True, max_length=128)
name = models.CharField(max_length=100)
class Supplementary(models.Model):
extra = models.IntegerField()
forwards = models.ForeignKey('ThroughTargetModel')
@ -291,6 +296,8 @@ class RelationalModel(models.Model):
one_to_one = models.OneToOneField(OneToOneTargetModel, related_name='reverse_one_to_one')
through = models.ManyToManyField(ThroughTargetModel, through=Supplementary, related_name='reverse_through')
string_foreign_key = models.ForeignKey(StringForeignKeyTargetModel, related_name='reverse_string_foreign_key')
class TestRelationalFieldMappings(TestCase):
def test_pk_relations(self):
@ -303,6 +310,7 @@ class TestRelationalFieldMappings(TestCase):
id = IntegerField(label='ID', read_only=True)
foreign_key = PrimaryKeyRelatedField(queryset=ForeignKeyTargetModel.objects.all())
one_to_one = PrimaryKeyRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[<UniqueValidator(queryset=RelationalModel.objects.all())>])
string_foreign_key = PrimaryKeyRelatedField(pk_field=CharField(label='Id', max_length=128, validators=[<UniqueValidator(queryset=StringForeignKeyTargetModel.objects.all())>]), queryset=StringForeignKeyTargetModel.objects.all())
many_to_many = PrimaryKeyRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all())
through = PrimaryKeyRelatedField(many=True, read_only=True)
""")
@ -323,6 +331,9 @@ class TestRelationalFieldMappings(TestCase):
one_to_one = NestedSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
string_foreign_key = NestedSerializer(read_only=True):
id = CharField(max_length=128, validators=[<UniqueValidator(queryset=StringForeignKeyTargetModel.objects.all())>])
name = CharField(max_length=100)
many_to_many = NestedSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
@ -342,6 +353,7 @@ class TestRelationalFieldMappings(TestCase):
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
foreign_key = HyperlinkedRelatedField(queryset=ForeignKeyTargetModel.objects.all(), view_name='foreignkeytargetmodel-detail')
one_to_one = HyperlinkedRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[<UniqueValidator(queryset=RelationalModel.objects.all())>], view_name='onetoonetargetmodel-detail')
string_foreign_key = HyperlinkedRelatedField(queryset=StringForeignKeyTargetModel.objects.all(), view_name='stringforeignkeytargetmodel-detail')
many_to_many = HyperlinkedRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail')
through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail')
""")
@ -362,6 +374,9 @@ class TestRelationalFieldMappings(TestCase):
one_to_one = NestedSerializer(read_only=True):
url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')
name = CharField(max_length=100)
string_foreign_key = NestedSerializer(read_only=True):
url = HyperlinkedIdentityField(view_name='stringforeignkeytargetmodel-detail')
name = CharField(max_length=100)
many_to_many = NestedSerializer(many=True, read_only=True):
url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail')
name = CharField(max_length=100)
@ -427,6 +442,20 @@ class TestRelationalFieldMappings(TestCase):
""")
self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_pk_reverse_string_foreign_key(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = StringForeignKeyTargetModel
fields = ('id', 'name', 'reverse_string_foreign_key')
expected = dedent("""
TestSerializer():
id = CharField(max_length=128, validators=[<UniqueValidator(queryset=StringForeignKeyTargetModel.objects.all())>])
name = CharField(max_length=100)
reverse_string_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
""")
self.assertEqual(unicode_repr(TestSerializer()), expected)
class TestIntegration(TestCase):
def setUp(self):
@ -441,9 +470,14 @@ class TestIntegration(TestCase):
name='many_to_many (%d)' % idx
) for idx in range(3)
]
self.string_foreign_key_target = StringForeignKeyTargetModel.objects.create(
id='stringified_id: 1',
name='string_foreign_key'
)
self.instance = RelationalModel.objects.create(
foreign_key=self.foreign_key_target,
one_to_one=self.one_to_one_target,
string_foreign_key=self.string_foreign_key_target
)
self.instance.many_to_many = self.many_to_many_targets
self.instance.save()
@ -459,6 +493,7 @@ class TestIntegration(TestCase):
'foreign_key': self.foreign_key_target.pk,
'one_to_one': self.one_to_one_target.pk,
'many_to_many': [item.pk for item in self.many_to_many_targets],
'string_foreign_key': self.string_foreign_key_target.pk,
'through': []
}
self.assertEqual(serializer.data, expected)
@ -479,15 +514,20 @@ class TestIntegration(TestCase):
name='new many_to_many (%d)' % idx
) for idx in range(3)
]
new_string_foreign_key = StringForeignKeyTargetModel.objects.create(
id='stringified_id: 2',
name='string_foreign_key'
)
data = {
'foreign_key': new_foreign_key.pk,
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
'string_foreign_key': new_string_foreign_key.pk,
}
# Serializer should validate okay.
serializer = TestSerializer(data=data)
assert serializer.is_valid()
assert serializer.is_valid(raise_exception=True)
# Creating the instance, relationship attributes should be set.
instance = serializer.save()
@ -498,6 +538,8 @@ class TestIntegration(TestCase):
] == [
item.pk for item in new_many_to_many
]
assert instance.string_foreign_key.pk == new_string_foreign_key.pk
assert list(instance.through.all()) == []
# Representation should be correct.
@ -506,6 +548,7 @@ class TestIntegration(TestCase):
'foreign_key': new_foreign_key.pk,
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
'string_foreign_key': new_string_foreign_key.pk,
'through': []
}
self.assertEqual(serializer.data, expected)
@ -526,10 +569,15 @@ class TestIntegration(TestCase):
name='new many_to_many (%d)' % idx
) for idx in range(3)
]
new_string_foreign_key = StringForeignKeyTargetModel.objects.create(
id='stringified_id: 2',
name='string_foreign_key'
)
data = {
'foreign_key': new_foreign_key.pk,
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
'string_foreign_key': new_string_foreign_key.pk
}
# Serializer should validate okay.
@ -545,6 +593,7 @@ class TestIntegration(TestCase):
] == [
item.pk for item in new_many_to_many
]
assert instance.string_foreign_key.pk == new_string_foreign_key.pk
assert list(instance.through.all()) == []
# Representation should be correct.
@ -553,6 +602,7 @@ class TestIntegration(TestCase):
'foreign_key': new_foreign_key.pk,
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
'string_foreign_key': new_string_foreign_key.pk,
'through': []
}
self.assertEqual(serializer.data, expected)
@ -639,3 +689,32 @@ class TestSerializerMetaClass(TestCase):
str(exception),
"Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer."
)
"""
{
False = <bound method TestSerializer.is_valid of TestSerializer(
data=
{
'foreign_key': 2,
'many_to_many': [4, 5, 6],
'one_to_one...ny=True,
queryset=ManyToManyTargetModel.objects.all())
through = PrimaryKeyRelatedField(many=True, read_only=True)>
{
<bound method TestSerializer.is_valid of TestSerializer(data=
{
'foreign_key': 2,
'many_to_many': [4, 5, 6],
'one_to_one...ny=True,
queryset=ManyToManyTargetModel.objects.all())\\n through = PrimaryKeyRelatedField(many=True, read_only=True)> = TestSerializer(
data=
{
'foreign_key': 2,
'many_to_many': [4, 5, 6],
'one_to_one': 2,
'string_foreign_key': ''
}): id ...any=True, queryset=ManyToManyTargetModel.objects.all())\\n
through = PrimaryKeyRelatedField(many=True, read_only=True).is_valid\n
}()\n
}
"""

View File

@ -1,3 +1,4 @@
import uuid
from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset
from django.core.exceptions import ImproperlyConfigured
from django.utils.datastructures import MultiValueDict
@ -48,6 +49,42 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase):
assert representation == self.instance.pk
class TestUUIDPrimaryKeyRelatedField(APISimpleTestCase):
def setUp(self):
self.queryset = MockQueryset([
MockObject(pk=uuid.UUID(int=1), name='foo'),
MockObject(pk=uuid.UUID(int=2), name='bar'),
MockObject(pk=uuid.UUID(int=3), name='baz')
])
self.instance = self.queryset.items[2]
self.field = serializers.PrimaryKeyRelatedField(
pk_field=serializers.UUIDField(),
queryset=self.queryset
)
def test_pk_related_lookup_exists(self):
instance = self.field.to_internal_value(self.instance.pk)
assert instance is self.instance
def test_pk_related_lookup_does_not_exist(self):
bad_value = uuid.UUID(int=4)
with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value(str(bad_value))
msg = excinfo.value.detail[0]
assert msg == 'Invalid pk "{0}" - object does not exist.'.format(bad_value)
def test_pk_related_lookup_invalid_type(self):
bad_value = BadType()
with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value(bad_value)
msg = excinfo.value.detail[0]
assert msg == '"{0}" is not a valid UUID.'.format(bad_value)
def test_pk_representation(self):
representation = self.field.to_representation(self.instance)
assert representation == str(self.instance.pk)
class TestHyperlinkedIdentityField(APISimpleTestCase):
def setUp(self):
self.instance = MockObject(pk=1, name='foo')