diff --git a/docs/api-guide/relations.md b/docs/api-guide/relations.md index cc4f55851..286c3dcff 100644 --- a/docs/api-guide/relations.md +++ b/docs/api-guide/relations.md @@ -179,6 +179,43 @@ When using `SlugRelatedField` as a read-write field, you will normally want to e * `required` - If set to `False`, the field will accept values of `None` or the empty-string for nullable relationships. * `queryset` - By default `ModelSerializer` classes will use the default queryset for the relationship. `Serializer` classes must either set a queryset explicitly, or set `read_only=True`. +## MultiSlugRelatedField + +`MultiSlugRelatedField` may be used to represent the target of the relationship using a set of fields on the target. + +For example, the following serializer: + + class AddressSerializer(serializers.ModelSerializer): + postal_code = serializers.SlugRelatedField(many=True, read_only=True, + slug_fields=('code', 'country')) + + class Meta: + model = Address + fields = ('street', 'city', 'state', 'postal_code') + +Would serialize to a representation like this: + + { + 'street': '123 Main St.', + 'city': 'Boulder', + 'state': 'CO', + 'postal_code': { + 'code': '80305', + 'country': 'USA', + } + } + +By default this field is read-write, although you can change this behavior using the `read_only` flag. + +When using `MultiSlugRelatedField` as a read-write field, you will normally want to ensure that the slug fields corresponds to a set of model field declared as `unique_together`. + +**Arguments**: + +* `slug_fields` - The fields on the target that should be used to represent it. This should be a set of fields that uniquely identifies any given instance. For example, `('postal_code', 'country')`. **required** +* `many` - If applied to a to-many relationship, you should set this argument to `True`. +* `required` - If set to `False`, the field will accept values of `None` or the empty-string for nullable relationships. +* `queryset` - By default `ModelSerializer` classes will use the default queryset for the relationship. `Serializer` classes must either set a queryset explicitly, or set `read_only=True`. + ## HyperlinkedIdentityField This field can be applied as an identity relationship, such as the `'url'` field on a HyperlinkedModelSerializer. It can also be used for an attribute on the object. For example, the following serializer: diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 3463954dc..0833aaf94 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -17,6 +17,7 @@ from rest_framework.reverse import reverse from rest_framework.compat import urlparse from rest_framework.compat import smart_text import warnings +import collections ##### Relational fields ##### @@ -313,6 +314,52 @@ class SlugRelatedField(RelatedField): raise ValidationError(msg) +### Multi-Slug relations + + +class MultiSlugRelatedField(RelatedField): + """ + Represents a relationship using a unique set of fields on the target. + """ + read_only = False + + default_error_messages = { + 'does_not_exist': _("Object with %s does not exist."), + 'invalid': _('Invalid value.'), + } + + def __init__(self, *args, **kwargs): + self.slug_fields = kwargs.pop('slug_fields', None) + assert self.slug_fields, "slug_fields is required" + super(MultiSlugRelatedField, self).__init__(*args, **kwargs) + + def to_native(self, obj): + return dict(zip( + self.slug_fields, + (getattr(obj, slug_field) for slug_field in self.slug_fields), + )) + + def from_native(self, data): + if self.queryset is None: + raise Exception('Writable related fields must include a `queryset` argument') + + if not isinstance(data, collections.Mapping): + raise ValidationError(self.error_messages['invalid']) + + if not set(data.keys()) == set(self.slug_fields): + raise ValidationError(self.error_messages['invalid']) + + try: + return self.queryset.get(**data) + except ObjectDoesNotExist: + lookups = ['='.join((lookup, value)) for lookup, value in zip(self.slug_fields, data)] + raise ValidationError(self.error_messages['does_not_exist'] % + ' '.join(lookups)) + except (TypeError, ValueError): + msg = self.error_messages['invalid'] + raise ValidationError(msg) + + ### Hyperlinked relationships class HyperlinkedRelatedField(RelatedField): diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index fba3f8f7c..a7f262358 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -179,3 +179,25 @@ class FilterableItem(models.Model): text = models.CharField(max_length=100) decimal = models.DecimalField(max_digits=4, decimal_places=2) date = models.DateField() + + +# Models to test multi-slig relations +class TimeZone(models.Model): + pass + + +class PostalCode(models.Model): + code = models.CharField(max_length=10) + country = models.CharField(max_length=5) + + timezone = models.ForeignKey(TimeZone, null=True, blank=True, + related_name='postal_codes') + + class Meta: + unique_together = ( + ('code', 'country'), + ) + + +class Address(models.Model): + postal_code = models.ForeignKey(PostalCode, null=True, blank=True) diff --git a/rest_framework/tests/test_relations_multi_slug.py b/rest_framework/tests/test_relations_multi_slug.py new file mode 100644 index 000000000..618aa4778 --- /dev/null +++ b/rest_framework/tests/test_relations_multi_slug.py @@ -0,0 +1,195 @@ +from django.test import TestCase + +from rest_framework import serializers +from rest_framework.tests.models import PostalCode, Address, TimeZone + + +class AddressSerializer(serializers.ModelSerializer): + postal_code = serializers.MultiSlugRelatedField( + slug_fields=('code', 'country'), + ) + + class Meta: + model = Address + fields = ('id', 'postal_code',) + + +class TimeZoneSerializer(serializers.ModelSerializer): + postal_codes = serializers.MultiSlugRelatedField( + many=True, slug_fields=('code', 'country'), + ) + + class Meta: + model = TimeZone + fields = ('id', 'postal_codes',) + + +class MultiSlugFieldTest(TestCase): + def test_many_serialization(self): + postal_code = PostalCode.objects.create(code='12345', country='USA') + + address_a = Address.objects.create(postal_code=postal_code) + address_b = Address.objects.create(postal_code=postal_code) + + queryset = Address.objects.all() + serializer = AddressSerializer(queryset, many=True) + + expected = [ + {'id': address_a.pk, 'postal_code': {'code': '12345', 'country': 'USA'}}, + {'id': address_b.pk, 'postal_code': {'code': '12345', 'country': 'USA'}}, + ] + self.assertEqual( + serializer.data, + expected, + ) + + def test_singular_serialization(self): + postal_code = PostalCode.objects.create(code='12345', country='USA') + address = Address.objects.create(postal_code=postal_code) + + serializer = AddressSerializer(address) + + expected = { + 'id': address.pk, + 'postal_code': { + 'code': postal_code.code, + 'country': postal_code.country, + }, + } + self.assertEqual( + serializer.data, + expected, + ) + + def test_singular_serialization_when_null(self): + address = Address.objects.create() + + serializer = AddressSerializer(address) + + expected = { + 'id': address.pk, + 'postal_code': None, + } + self.assertEqual( + serializer.data, + expected, + ) + + def test_foreign_key_creation(self): + postal_code = PostalCode.objects.create(code='12345', country='USA') + + serializer = AddressSerializer(data={ + 'postal_code': { + 'code': postal_code.code, + 'country': postal_code.country, + }, + }) + self.assertTrue(serializer.is_valid()) + address = serializer.save() + self.assertEqual(address.postal_code, postal_code) + + def test_foreign_key_update(self): + postal_code = PostalCode.objects.create(code='12345', country='USA') + address = Address.objects.create(postal_code=postal_code) + + new_postal_code = PostalCode.objects.create(code='54321', country='USA') + + serializer = AddressSerializer(data={ + 'postal_code': { + 'code': new_postal_code.code, + 'country': new_postal_code.country, + }, + }) + self.assertTrue(serializer.is_valid()) + address = serializer.save() + self.assertEqual(address.postal_code, new_postal_code) + + def test_foreign_key_update_incomplete_slug(self): + postal_code = PostalCode.objects.create(code='12345', country='USA') + + serializer = AddressSerializer(data={ + 'postal_code': { + 'code': postal_code.code, + }, + }) + self.assertFalse(serializer.is_valid()) + self.assertIn('postal_code', serializer.errors) + + def test_foreign_key_update_incorrect_type(self): + serializer = AddressSerializer(data={ + 'postal_code': 1234, + }) + self.assertFalse(serializer.is_valid()) + self.assertIn('postal_code', serializer.errors) + + def test_reverse_foreign_key_retrieve(self): + timezone = TimeZone.objects.create() + PostalCode.objects.create(code='12345', country='USA', timezone=timezone) + PostalCode.objects.create(code='54321', country='USA', timezone=timezone) + + serializer = TimeZoneSerializer(timezone) + + expected = { + 'id': timezone.pk, + 'postal_codes': [ + {'code': '12345', 'country': 'USA'}, + {'code': '54321', 'country': 'USA'}, + ] + } + self.assertEqual( + serializer.data, + expected, + ) + + def test_reverse_foreign_key_create(self): + PostalCode.objects.create(code='12345', country='USA') + PostalCode.objects.create(code='54321', country='USA') + data = { + 'postal_codes': [ + {'code': '12345', 'country': 'USA'}, + {'code': '54321', 'country': 'USA'}, + ] + } + + serializer = TimeZoneSerializer(data=data) + + self.assertTrue(serializer.is_valid()) + + new_timezone = serializer.save() + + self.assertEqual(new_timezone.postal_codes.count(), 2) + + self.assertTrue( + PostalCode.objects.filter( + code='12345', country='USA', timezone=new_timezone, + ).exists(), + ) + self.assertTrue( + PostalCode.objects.filter( + code='54321', country='USA', timezone=new_timezone, + ).exists(), + ) + + def test_reverse_foreign_key_update(self): + timezone = TimeZone.objects.create() + PostalCode.objects.create(code='12345', country='USA') + PostalCode.objects.create(code='54321', country='USA') + + data = { + 'id': timezone.pk, + 'postal_codes': [ + {'code': '12345', 'country': 'USA'}, + {'code': '54321', 'country': 'USA'}, + ] + } + + # There should be no postal codes + self.assertEqual(timezone.postal_codes.count(), 0) + + serializer = TimeZoneSerializer(timezone, data=data) + + self.assertTrue(serializer.is_valid()) + + updated_timezone = serializer.save() + + self.assertEqual(updated_timezone.postal_codes.count(), 2)