diff --git a/rest_framework/fields.py b/rest_framework/fields.py index da2dd54be..f218713f1 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1758,6 +1758,7 @@ class JSONField(Field): def __init__(self, *args, **kwargs): self.binary = kwargs.pop('binary', False) self.encoder = kwargs.pop('encoder', None) + self.decoder = kwargs.pop('decoder', None) super().__init__(*args, **kwargs) def get_value(self, dictionary): @@ -1777,7 +1778,7 @@ class JSONField(Field): if self.binary or getattr(data, 'is_json_string', False): if isinstance(data, bytes): data = data.decode() - return json.loads(data) + return json.loads(data, cls=self.decoder) else: json.dumps(data, cls=self.encoder) except (TypeError, ValueError): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 916f8bec4..439220b34 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -884,6 +884,8 @@ class ModelSerializer(Serializer): models.GenericIPAddressField: IPAddressField, models.FilePathField: FilePathField, } + if hasattr(models, 'JSONField'): + serializer_field_mapping[models.JSONField] = JSONField if postgres_fields: serializer_field_mapping[postgres_fields.HStoreField] = HStoreField serializer_field_mapping[postgres_fields.ArrayField] = ListField @@ -1242,10 +1244,13 @@ class ModelSerializer(Serializer): # `allow_blank` is only valid for textual fields. field_kwargs.pop('allow_blank', None) - if postgres_fields and isinstance(model_field, postgres_fields.JSONField): + is_django_jsonfield = hasattr(models, 'JSONField') and isinstance(model_field, models.JSONField) + if (postgres_fields and isinstance(model_field, postgres_fields.JSONField)) or is_django_jsonfield: # Populate the `encoder` argument of `JSONField` instances generated - # for the PostgreSQL specific `JSONField`. + # for the model `JSONField`. field_kwargs['encoder'] = getattr(model_field, 'encoder', None) + if is_django_jsonfield: + field_kwargs['decoder'] = getattr(model_field, 'decoder', None) if postgres_fields and isinstance(model_field, postgres_fields.ArrayField): # Populate the `child` argument on `ListField` instances generated diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index ed270be5e..c008495cc 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -92,7 +92,8 @@ def get_field_kwargs(field_name, model_field): kwargs['allow_unicode'] = model_field.allow_unicode if isinstance(model_field, models.TextField) and not model_field.choices or \ - (postgres_fields and isinstance(model_field, postgres_fields.JSONField)): + (postgres_fields and isinstance(model_field, postgres_fields.JSONField)) or \ + (hasattr(models, 'JSONField') and isinstance(model_field, models.JSONField)): kwargs['style'] = {'base_template': 'textarea.html'} if isinstance(model_field, models.AutoField) or not model_field.editable: diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 51b8f2e22..1733930a6 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -7,6 +7,7 @@ an appropriate set of serializer fields for each case. """ import datetime import decimal +import json # noqa import sys import tempfile from collections import OrderedDict @@ -478,6 +479,7 @@ class TestPosgresFieldsMapping(TestCase): """) self.assertEqual(repr(TestSerializer()), expected) + @pytest.mark.skipif(hasattr(models, 'JSONField'), reason='has models.JSONField') def test_json_field(self): class JSONFieldModel(models.Model): json_field = postgres_fields.JSONField() @@ -496,6 +498,30 @@ class TestPosgresFieldsMapping(TestCase): self.assertEqual(repr(TestSerializer()), expected) +class CustomJSONDecoder(json.JSONDecoder): + pass + + +@pytest.mark.skipif(not hasattr(models, 'JSONField'), reason='no models.JSONField') +class TestDjangoJSONFieldMapping(TestCase): + def test_json_field(self): + class JSONFieldModel(models.Model): + json_field = models.JSONField() + json_field_with_encoder = models.JSONField(encoder=DjangoJSONEncoder, decoder=CustomJSONDecoder) + + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = JSONFieldModel + fields = ['json_field', 'json_field_with_encoder'] + + expected = dedent(""" + TestSerializer(): + json_field = JSONField(decoder=None, encoder=None, style={'base_template': 'textarea.html'}) + json_field_with_encoder = JSONField(decoder=, encoder=, style={'base_template': 'textarea.html'}) + """) + self.assertEqual(repr(TestSerializer()), expected) + + # Tests for relational field mappings. # ------------------------------------