This commit is contained in:
Alan Braithwaite 2014-04-18 20:13:51 +00:00
commit e6dd318f9e
3 changed files with 105 additions and 9 deletions

View File

@ -15,6 +15,7 @@ import copy
import datetime
import inspect
import types
import warnings
from decimal import Decimal
from django.contrib.contenttypes.generic import GenericForeignKey
from django.core.paginator import Page
@ -131,6 +132,31 @@ def _is_protected_type(obj):
)
def _convert_fields(data, mapping):
"""
Takes data as dict or none and str->str map of keys and
returns None or new dictionary with converted keys
"""
# Handle translation of serialized fields into non serailzed fields
if data is not None:
translated_data = copy.deepcopy(data)
for key in mapping.keys():
if key not in translated_data:
continue
newkey = mapping.get(key)
try: # MultiValueDict
value = translated_data.getlist(key)
del translated_data[key]
translated_data.setlist(newkey, value)
except AttributeError:
value = translated_data.pop(key)
translated_data[newkey] = value
else: # Data can be None so translated_data is too
translated_data = None
return translated_data
def _get_declared_fields(bases, attrs):
"""
Create a list of serializer field instances from the passed in 'attrs',
@ -167,6 +193,7 @@ class SerializerOptions(object):
self.depth = getattr(meta, 'depth', 0)
self.fields = getattr(meta, 'fields', ())
self.exclude = getattr(meta, 'exclude', ())
self.convert_fields = getattr(meta, 'convert_fields', False)
class BaseSerializer(WritableField):
@ -265,6 +292,20 @@ class BaseSerializer(WritableField):
"""
return field_name
def get_field_name_map(self):
"""
Return a map of serialized->python field names
"""
ret = SortedDict()
for name, value in list(self.fields.items()):
key = self.get_field_key(name)
if key in ret:
warnings.warn("Duplicate key found in fields. This can happen if `get_field_key`"
" can return the same string for two different inputs! Ensure your keys are unique"
" after running them all through `get_field_key`", Warning, stacklevel=3)
ret[key] = name
return ret
def restore_fields(self, data, files):
"""
Core of deserialization, together with `restore_object`.
@ -276,6 +317,10 @@ class BaseSerializer(WritableField):
self._errors['non_field_errors'] = ['Invalid data']
return None
if self.opts.convert_fields:
key_map = self.get_field_name_map()
data = _convert_fields(data, key_map)
for field_name, field in self.fields.items():
field.initialize(parent=self, field_name=field_name)
try:
@ -758,9 +803,9 @@ class ModelSerializer(Serializer):
field.read_only = True
ret[accessor_name] = field
# Ensure that 'read_only_fields' is an iterable
assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
# Add the `read_only` flag to any fields that have been specified
# in the `read_only_fields` option
@ -775,10 +820,10 @@ class ModelSerializer(Serializer):
"on serializer '%s'." %
(field_name, self.__class__.__name__))
ret[field_name].read_only = True
# Ensure that 'write_only_fields' is an iterable
assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
for field_name in self.opts.write_only_fields:
assert field_name not in self.base_fields.keys(), (
"field '%s' on serializer '%s' specified in "
@ -789,7 +834,7 @@ class ModelSerializer(Serializer):
"Non-existant field '%s' specified in `write_only_fields` "
"on serializer '%s'." %
(field_name, self.__class__.__name__))
ret[field_name].write_only = True
ret[field_name].write_only = True
return ret

View File

@ -175,3 +175,8 @@ class FilterableItem(models.Model):
text = models.CharField(max_length=100)
decimal = models.DecimalField(max_digits=4, decimal_places=2)
date = models.DateField()
class ModelWithUnderscoreFields(RESTFrameworkModel):
char_field = models.CharField(max_length=100)
number_field = models.IntegerField()

View File

@ -7,9 +7,12 @@ from django.utils import unittest
from django.utils.datastructures import MultiValueDict
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers, fields, relations
from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel)
from rest_framework.tests.models import (
HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel,
DefaultValueModel, ManyToManyModel, Person, ReadOnlyManyToManyModel,
Photo, RESTFrameworkModel, ModelWithUnderscoreFields
)
from rest_framework.tests.models import BasicModelSerializer
import datetime
import pickle
@ -176,6 +179,16 @@ class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
fields = ['some_integer']
class ModelFieldConversionSerializer(serializers.ModelSerializer):
def get_field_key(self, field_name):
return field_name.replace('_','').capitalize()
class Meta:
convert_fields = True
model = ModelWithUnderscoreFields
class BasicTests(TestCase):
def setUp(self):
self.comment = Comment(
@ -331,6 +344,23 @@ class BasicTests(TestCase):
exclusions = serializer.get_validation_exclusions()
self.assertTrue('title' in exclusions, '`title` field was marked `required=False` and should be excluded')
def test_serialize_with_conversion(self):
"""
Verify keys get converted from serialized value to deserialized value
"""
underscore = ModelWithUnderscoreFields(char_field='slartibartfast', number_field=42)
underscore.save()
serializer = ModelFieldConversionSerializer(underscore)
serialized = {'Id': 1, 'Numberfield': 42, 'Charfield':'slartibartfast'}
self.assertEqual(serialized, serializer.data, "Validate that serialing data with conversion works")
serializer = ModelFieldConversionSerializer(data=serialized)
self.assertTrue(serializer.is_valid(),
'Data should get converted from serialized value into deserialized value')
self.assertEqual('slartibartfast', serializer.object.char_field)
self.assertEqual(42, serializer.object.number_field)
class DictStyleSerializer(serializers.Serializer):
"""
@ -836,6 +866,22 @@ class ManyToManyTests(TestCase):
self.assertEqual(instance.pk, 2)
self.assertEqual(list(instance.rel.all()), [])
def test_create_empty_relationship_flat_data_field_convert(self):
"""
Create an instance of a model with a ManyToMany relationship,
containing no items, using a representation that does not support
lists (eg form data).
"""
data = MultiValueDict()
data.setlist('rel', [''])
self.serializer_class.Meta.convert_fields = True
serializer = self.serializer_class(data=data)
self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
self.assertEqual(len(ManyToManyModel.objects.all()), 2)
self.assertEqual(instance.pk, 2)
self.assertEqual(list(instance.rel.all()), [])
class ReadOnlyManyToManyTests(TestCase):
def setUp(self):