diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 33ab0682c..1818e705e 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -80,6 +80,10 @@ def set_value(dictionary, keys, value): dictionary[keys[-1]] = value +def field_name_to_label(field_name): + return field_name.replace('_', ' ').capitalize() + + class SkipField(Exception): pass @@ -158,7 +162,7 @@ class Field(object): # `self.label` should deafult to being based on the field name. if self.label is None: - self.label = self.field_name.replace('_', ' ').capitalize() + self.label = field_name_to_label(self.field_name) # self.source should default to being the same as the field name. if self.source is None: diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index ecb2829b6..ba8d475f8 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -15,11 +15,12 @@ from django.core.exceptions import ValidationError from django.db import models from django.utils import six from django.utils.datastructures import SortedDict +from django.utils.text import capfirst from collections import namedtuple from rest_framework.compat import clean_manytomany_helptext from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.settings import api_settings -from rest_framework.utils import html, modelinfo, representation +from rest_framework.utils import html, model_meta, representation import copy # Note: We do the following so that users of the framework can use this style: @@ -334,6 +335,14 @@ def lookup_class(mapping, instance): raise KeyError('Class %s not found in lookup.', cls.__name__) +def needs_label(model_field, field_name): + """ + Returns `True` if the label based on the model's verbose name + is not equal to the default label it would have based on it's field name. + """ + return capfirst(model_field.verbose_name) != field_name_to_label(field_name) + + class ModelSerializer(Serializer): field_mapping = { models.AutoField: IntegerField, @@ -397,54 +406,55 @@ class ModelSerializer(Serializer): """ Return all the fields that should be serialized for the model. """ - info = modelinfo.get_field_info(self.opts.model) + info = model_meta.get_field_info(self.opts.model) ret = SortedDict() serializer_url_field = self.get_url_field() if serializer_url_field: ret[api_settings.URL_FIELD_NAME] = serializer_url_field - serializer_pk_field = self.get_pk_field(info.pk) + field_name = info.pk.name + serializer_pk_field = self.get_pk_field(field_name, info.pk) if serializer_pk_field: - ret[info.pk.name] = serializer_pk_field + ret[field_name] = serializer_pk_field # Regular fields for field_name, field in info.fields.items(): - ret[field_name] = self.get_field(field) + ret[field_name] = self.get_field(field_name, field) # Forward relations for field_name, relation_info in info.forward_relations.items(): if self.opts.depth: - ret[field_name] = self.get_nested_field(*relation_info) + ret[field_name] = self.get_nested_field(field_name, *relation_info) else: - ret[field_name] = self.get_related_field(*relation_info) + ret[field_name] = self.get_related_field(field_name, *relation_info) # Reverse relations for accessor_name, relation_info in info.reverse_relations.items(): if accessor_name in self.opts.fields: if self.opts.depth: - ret[accessor_name] = self.get_nested_field(*relation_info) + ret[accessor_name] = self.get_nested_field(accessor_name, *relation_info) else: - ret[accessor_name] = self.get_related_field(*relation_info) + ret[accessor_name] = self.get_related_field(accessor_name, *relation_info) return ret def get_url_field(self): return None - def get_pk_field(self, model_field): + def get_pk_field(self, field_name, model_field): """ Returns a default instance of the pk field. """ - return self.get_field(model_field) + return self.get_field(field_name, model_field) - def get_nested_field(self, model_field, related_model, to_many, has_through_model): + def get_nested_field(self, field_name, model_field, related_model, to_many, has_through_model): """ Creates a default instance of a nested relational field. Note that model_field will be `None` for reverse relationships. """ - class NestedModelSerializer(ModelSerializer): # Not right! + class NestedModelSerializer(ModelSerializer): class Meta: model = related_model depth = self.opts.depth - 1 @@ -454,7 +464,7 @@ class ModelSerializer(Serializer): kwargs['many'] = True return NestedModelSerializer(**kwargs) - def get_related_field(self, model_field, related_model, to_many, has_through_model): + def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model): """ Creates a default instance of a flat relational field. @@ -474,8 +484,8 @@ class ModelSerializer(Serializer): if model_field: if model_field.null or model_field.blank: kwargs['required'] = False - if model_field.verbose_name: - kwargs['label'] = model_field.verbose_name + if model_field.verbose_name and needs_label(model_field, field_name): + kwargs['label'] = capfirst(model_field.verbose_name) if not model_field.editable: kwargs['read_only'] = True kwargs.pop('queryset', None) @@ -485,7 +495,7 @@ class ModelSerializer(Serializer): return PrimaryKeyRelatedField(**kwargs) - def get_field(self, model_field): + def get_field(self, field_name, model_field): """ Creates a default instance of a basic non-relational field. """ @@ -496,8 +506,8 @@ class ModelSerializer(Serializer): if model_field.null or model_field.blank: kwargs['required'] = False - if model_field.verbose_name is not None: - kwargs['label'] = model_field.verbose_name + if model_field.verbose_name and needs_label(model_field, field_name): + kwargs['label'] = capfirst(model_field.verbose_name) if model_field.help_text: kwargs['help_text'] = model_field.help_text @@ -642,11 +652,11 @@ class HyperlinkedModelSerializer(ModelSerializer): return HyperlinkedIdentityField(**kwargs) - def get_pk_field(self, model_field): + def get_pk_field(self, field_name, model_field): if self.opts.fields and model_field.name in self.opts.fields: return self.get_field(model_field) - def get_related_field(self, model_field, related_model, to_many, has_through_model): + def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model): """ Creates a default instance of a flat relational field. """ @@ -665,8 +675,8 @@ class HyperlinkedModelSerializer(ModelSerializer): if model_field: if model_field.null or model_field.blank: kwargs['required'] = False - if model_field.verbose_name: - kwargs['label'] = model_field.verbose_name + if model_field.verbose_name and needs_label(model_field, field_name): + kwargs['label'] = capfirst(model_field.verbose_name) if not model_field.editable: kwargs['read_only'] = True kwargs.pop('queryset', None) diff --git a/rest_framework/utils/modelinfo.py b/rest_framework/utils/model_meta.py similarity index 100% rename from rest_framework/utils/modelinfo.py rename to rest_framework/utils/model_meta.py diff --git a/tests/test_model_field_mappings.py b/tests/test_model_field_mappings.py index d5750f9e3..b04ad5f2f 100644 --- a/tests/test_model_field_mappings.py +++ b/tests/test_model_field_mappings.py @@ -36,25 +36,25 @@ class RegularFieldsModel(models.Model): REGULAR_FIELDS_REPR = """ TestSerializer(): - auto_field = IntegerField(label='auto field', read_only=True) - big_integer_field = IntegerField(label='big integer field') - boolean_field = BooleanField(default=False, label='boolean field') - char_field = CharField(label='char field', max_length=100) - comma_seperated_integer_field = CharField(label='comma seperated integer field', max_length=100, validators=[]) - date_field = DateField(label='date field') - datetime_field = DateTimeField(label='datetime field') - decimal_field = DecimalField(decimal_places=1, label='decimal field', max_digits=3) - email_field = EmailField(label='email field', max_length=100) - float_field = FloatField(label='float field') - integer_field = IntegerField(label='integer field') - null_boolean_field = BooleanField(label='null boolean field', required=False) - positive_integer_field = IntegerField(label='positive integer field') - positive_small_integer_field = IntegerField(label='positive small integer field') - slug_field = SlugField(label='slug field', max_length=100) - small_integer_field = IntegerField(label='small integer field') - text_field = CharField(label='text field') - time_field = TimeField(label='time field') - url_field = URLField(label='url field', max_length=100) + auto_field = IntegerField(read_only=True) + big_integer_field = IntegerField() + boolean_field = BooleanField(default=False) + char_field = CharField(max_length=100) + comma_seperated_integer_field = CharField(max_length=100, validators=[]) + date_field = DateField() + datetime_field = DateTimeField() + decimal_field = DecimalField(decimal_places=1, max_digits=3) + email_field = EmailField(max_length=100) + float_field = FloatField() + integer_field = IntegerField() + null_boolean_field = BooleanField(required=False) + positive_integer_field = IntegerField() + positive_small_integer_field = IntegerField() + slug_field = SlugField(max_length=100) + small_integer_field = IntegerField() + text_field = CharField() + time_field = TimeField() + url_field = URLField(max_length=100) """.strip() @@ -81,9 +81,9 @@ class RelationalModel(models.Model): RELATIONAL_FLAT_REPR = """ TestSerializer(): id = IntegerField(label='ID', read_only=True) - foreign_key = PrimaryKeyRelatedField(label='foreign key', queryset=ForeignKeyTargetModel.objects.all()) - one_to_one = PrimaryKeyRelatedField(label='one to one', queryset=OneToOneTargetModel.objects.all()) - many_to_many = PrimaryKeyRelatedField(label='many to many', many=True, queryset=ManyToManyTargetModel.objects.all()) + foreign_key = PrimaryKeyRelatedField(queryset=ForeignKeyTargetModel.objects.all()) + one_to_one = PrimaryKeyRelatedField(queryset=OneToOneTargetModel.objects.all()) + many_to_many = PrimaryKeyRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all()) """.strip() @@ -92,22 +92,22 @@ TestSerializer(): id = IntegerField(label='ID', read_only=True) foreign_key = NestedModelSerializer(read_only=True): id = IntegerField(label='ID', read_only=True) - name = CharField(label='name', max_length=100) + name = CharField(max_length=100) one_to_one = NestedModelSerializer(read_only=True): id = IntegerField(label='ID', read_only=True) - name = CharField(label='name', max_length=100) + name = CharField(max_length=100) many_to_many = NestedModelSerializer(many=True, read_only=True): id = IntegerField(label='ID', read_only=True) - name = CharField(label='name', max_length=100) + name = CharField(max_length=100) """.strip() HYPERLINKED_FLAT_REPR = """ TestSerializer(): url = HyperlinkedIdentityField(view_name='relationalmodel-detail') - foreign_key = HyperlinkedRelatedField(label='foreign key', queryset=ForeignKeyTargetModel.objects.all(), view_name='foreignkeytargetmodel-detail') - one_to_one = HyperlinkedRelatedField(label='one to one', queryset=OneToOneTargetModel.objects.all(), view_name='onetoonetargetmodel-detail') - many_to_many = HyperlinkedRelatedField(label='many to many', many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail') + foreign_key = HyperlinkedRelatedField(queryset=ForeignKeyTargetModel.objects.all(), view_name='foreignkeytargetmodel-detail') + one_to_one = HyperlinkedRelatedField(queryset=OneToOneTargetModel.objects.all(), view_name='onetoonetargetmodel-detail') + many_to_many = HyperlinkedRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail') """.strip() @@ -116,13 +116,13 @@ TestSerializer(): url = HyperlinkedIdentityField(view_name='relationalmodel-detail') foreign_key = NestedModelSerializer(read_only=True): id = IntegerField(label='ID', read_only=True) - name = CharField(label='name', max_length=100) + name = CharField(max_length=100) one_to_one = NestedModelSerializer(read_only=True): id = IntegerField(label='ID', read_only=True) - name = CharField(label='name', max_length=100) + name = CharField(max_length=100) many_to_many = NestedModelSerializer(many=True, read_only=True): id = IntegerField(label='ID', read_only=True) - name = CharField(label='name', max_length=100) + name = CharField(max_length=100) """.strip() @@ -180,4 +180,4 @@ class TestSerializerMappings(TestCase): # class Meta: # model = ManyToManyTargetModel # fields = ('id', 'name', 'reverse_many_to_many') - # print repr(TestSerializer()) \ No newline at end of file + # print repr(TestSerializer())