diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d8e544d47..fadb5bea2 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -661,6 +661,9 @@ class ModelSerializer(Serializer): }) _related_class = PrimaryKeyRelatedField + def create_instance(self, model_class, validated_data): + return model_class.objects.create(**validated_data) + def create(self, validated_data): """ We have a bit of extra checking around this in order to provide @@ -696,7 +699,7 @@ class ModelSerializer(Serializer): many_to_many[field_name] = validated_data.pop(field_name) try: - instance = ModelClass.objects.create(**validated_data) + instance = self.create_instance(ModelClass, validated_data) except TypeError as exc: msg = ( 'Got a `TypeError` when calling `%s.objects.create()`. ' @@ -721,13 +724,15 @@ class ModelSerializer(Serializer): return instance + def save_instance(self, instance): + instance.save() + def update(self, instance, validated_data): raise_errors_on_nested_writes('update', self) for attr, value in validated_data.items(): setattr(instance, attr, value) - instance.save() - + self.save_instance(instance) return instance def get_validators(self):