diff --git a/docs/api-guide/settings.md b/docs/api-guide/settings.md index 47e2ce993..81ee8a4d2 100644 --- a/docs/api-guide/settings.md +++ b/docs/api-guide/settings.md @@ -143,6 +143,17 @@ Default: `ordering` --- +## Serializer settings + +#### MODEL_SERIALIZER_FIELD_MAPPING + +Extra field mapping used to extend or override mapping of django db fields to serializer fields which is used by +ModelSerializer to set up fields for serializer. + +Default: `{}` + +--- + ## Versioning settings #### DEFAULT_VERSION diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b1b7b6477..4241c0106 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -894,48 +894,8 @@ class ModelSerializer(Serializer): * A set of default validators are automatically populated. * Default `.create()` and `.update()` implementations are provided. - The process of automatically determining a set of serializer fields - based on the model fields is reasonably complex, but you almost certainly - don't need to dig into the implementation. - - If the `ModelSerializer` class *doesn't* generate the set of fields that - you need you should either declare the extra/differing fields explicitly on - the serializer class, or simply use a `Serializer` class. """ - serializer_field_mapping = { - models.AutoField: IntegerField, - models.BigIntegerField: IntegerField, - models.BooleanField: BooleanField, - models.CharField: CharField, - models.CommaSeparatedIntegerField: CharField, - models.DateField: DateField, - models.DateTimeField: DateTimeField, - models.DecimalField: DecimalField, - models.DurationField: DurationField, - models.EmailField: EmailField, - models.Field: ModelField, - models.FileField: FileField, - models.FloatField: FloatField, - models.ImageField: ImageField, - models.IntegerField: IntegerField, - models.NullBooleanField: BooleanField, - models.PositiveIntegerField: IntegerField, - models.PositiveSmallIntegerField: IntegerField, - models.SlugField: SlugField, - models.SmallIntegerField: IntegerField, - models.TextField: CharField, - models.TimeField: TimeField, - models.URLField: URLField, - models.UUIDField: UUIDField, - models.GenericIPAddressField: IPAddressField, - models.FilePathField: FilePathField, - } - if hasattr(models, 'JSONField'): - serializer_field_mapping[models.JSONField] = JSONField - if postgres_fields: - serializer_field_mapping[postgres_fields.HStoreField] = HStoreField - serializer_field_mapping[postgres_fields.ArrayField] = ListField - serializer_field_mapping[postgres_fields.JSONField] = JSONField + serializer_related_field = PrimaryKeyRelatedField serializer_related_to_field = SlugRelatedField serializer_url_field = HyperlinkedIdentityField @@ -950,6 +910,61 @@ class ModelSerializer(Serializer): # "HTTP 201 Created" responses. url_field_name = None + @property + def serializer_field_mapping(self): + """Get mapping of django model field to serializer field. + + The process of automatically determining a set of serializer fields + based on the model fields is reasonably complex, but you almost certainly + don't need to dig into the implementation. + + If the `ModelSerializer` class *doesn't* generate the set of fields that + you need you should either extend serializer_field_mapping with + the extra/differing fields explicitly, or simply use a `Serializer` + class. + + """ + serializer_field_mapping = { + models.AutoField: IntegerField, + models.BigIntegerField: IntegerField, + models.BooleanField: BooleanField, + models.CharField: CharField, + models.CommaSeparatedIntegerField: CharField, + models.DateField: DateField, + models.DateTimeField: DateTimeField, + models.DecimalField: DecimalField, + models.DurationField: DurationField, + models.EmailField: EmailField, + models.Field: ModelField, + models.FileField: FileField, + models.FloatField: FloatField, + models.ImageField: ImageField, + models.IntegerField: IntegerField, + models.NullBooleanField: BooleanField, + models.PositiveIntegerField: IntegerField, + models.PositiveSmallIntegerField: IntegerField, + models.SlugField: SlugField, + models.SmallIntegerField: IntegerField, + models.TextField: CharField, + models.TimeField: TimeField, + models.URLField: URLField, + models.UUIDField: UUIDField, + models.GenericIPAddressField: IPAddressField, + models.FilePathField: FilePathField, + } + if hasattr(models, 'JSONField'): + serializer_field_mapping[models.JSONField] = JSONField + if postgres_fields: + serializer_field_mapping[postgres_fields.HStoreField] = HStoreField + serializer_field_mapping[postgres_fields.ArrayField] = ListField + serializer_field_mapping[postgres_fields.JSONField] = JSONField + for ( + model_field, + serializer_field, + ) in api_settings.MODEL_SERIALIZER_FIELD_MAPPING.items(): + serializer_field_mapping[model_field] = serializer_field + return serializer_field_mapping + # Default `create` and `update` behavior... def create(self, validated_data): """ diff --git a/rest_framework/settings.py b/rest_framework/settings.py index b0d7bacec..5f42bc455 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -126,6 +126,9 @@ DEFAULTS = { 'retrieve': 'read', 'destroy': 'delete' }, + + # Serializers + 'MODEL_SERIALIZER_FIELD_MAPPING': {} } @@ -147,7 +150,8 @@ IMPORT_STRINGS = [ 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', 'VIEW_NAME_FUNCTION', - 'VIEW_DESCRIPTION_FUNCTION' + 'VIEW_DESCRIPTION_FUNCTION', + 'MODEL_SERIALIZER_FIELD_MAPPING', ] @@ -168,6 +172,16 @@ def perform_import(val, setting_name): return import_from_string(val, setting_name) elif isinstance(val, (list, tuple)): return [import_from_string(item, setting_name) for item in val] + elif isinstance(val, (dict)): + return { + import_from_string( + key, + setting_name, + ): import_from_string( + value, + setting_name, + ) for key, value in val.items() + } return val diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index ae1a2b0fa..719093032 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -21,7 +21,7 @@ from django.core.validators import ( from django.db import models from django.db.models.signals import m2m_changed from django.dispatch import receiver -from django.test import TestCase +from django.test import TestCase, override_settings from rest_framework import serializers from rest_framework.compat import postgres_fields @@ -43,6 +43,12 @@ class CustomField(models.Field): pass +class CustomCharFieldField(serializers.CharField): + """ + A custom serializer field simply for testing purposes. + """ + + class OneFieldModel(models.Model): char_field = models.CharField(max_length=100) @@ -194,6 +200,32 @@ class TestRegularFieldMappings(TestCase): custom_field = ModelField\(model_field=\) file_path_field = FilePathField\(path=%r\) """ % tempfile.gettempdir()) + print(expected) + assert re.search(expected, repr(TestSerializer())) is not None + + @override_settings( + REST_FRAMEWORK={ + 'MODEL_SERIALIZER_FIELD_MAPPING': { + 'django.db.models.CharField': 'tests.test_model_serializer.CustomCharFieldField', + } + }, + ) + def test_custom_fields(self): + """ + If MODEL_SERIALIZER_FIELD_MAPPING is set than model fields should map + to their equivalent serializer fields. + """ + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RegularFieldsModel + fields = ( + "char_field", + ) + + expected = dedent(r""" + TestSerializer\(\): + char_field = CustomCharFieldField\(max_length=100\) + """) assert re.search(expected, repr(TestSerializer())) is not None def test_field_options(self):