mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-03 20:10:10 +03:00
Merge bd4f41e2fe
into b3b0515ae6
This commit is contained in:
commit
6b9555b044
|
@ -22,6 +22,8 @@ from django.forms import widgets
|
|||
from django.utils.datastructures import SortedDict
|
||||
from rest_framework.compat import get_concrete_model, six
|
||||
from rest_framework.settings import api_settings
|
||||
from mongoengine import fields, dereference
|
||||
from bson import DBRef
|
||||
|
||||
|
||||
# Note: We do the following so that users of the framework can use this style:
|
||||
|
@ -1113,3 +1115,214 @@ class HyperlinkedModelSerializer(ModelSerializer):
|
|||
'model_name': model_meta.object_name.lower()
|
||||
}
|
||||
return self._default_view_name % format_kwargs
|
||||
|
||||
|
||||
class MongoEngineModelSerializerOptions(serializers.ModelSerializerOptions):
|
||||
"""
|
||||
Meta class options for MongoEngineModelSerializer
|
||||
"""
|
||||
def __init__(self, meta):
|
||||
super(MongoEngineModelSerializerOptions, self).__init__(meta)
|
||||
self.validations = getattr(meta, 'related_model_validations', {})
|
||||
|
||||
|
||||
class MongoEngineModelSerializer(ModelSerializer):
|
||||
|
||||
##USAGE###################################################################
|
||||
# class BlogEntrySerializer(MongoEngineModelSerializer): #
|
||||
# class Meta: #
|
||||
# model = BlogEntry #
|
||||
# depth = 1 #
|
||||
# related_model_validations = {'Category': Category, 'author': User} #
|
||||
##########################################################################
|
||||
_options_class = MongoEngineModelSerializerOptions
|
||||
|
||||
def validate_related_field(self, attrs, source, object_type):
|
||||
"""
|
||||
Validate related model
|
||||
"""
|
||||
value = attrs[source]
|
||||
|
||||
try:
|
||||
object_type.objects.get(pk=value)
|
||||
except object_type.DoesNotExist:
|
||||
raise ValidationError(object_type.__name__ + ' with PK ' + value + ' does not exists.')
|
||||
return attrs
|
||||
|
||||
def perform_validation(self, attrs):
|
||||
"""
|
||||
Rest Framework built-in validation + related model validations
|
||||
"""
|
||||
for field_name, field in self.fields.items():
|
||||
if field_name in self._errors:
|
||||
continue
|
||||
|
||||
source = field.source or field_name
|
||||
if self.partial and source not in attrs:
|
||||
continue
|
||||
|
||||
# Related Model Validations
|
||||
if field_name in self.opts.validations:
|
||||
try:
|
||||
self.validate_related_field(attrs, source, self.opts.validations[field_name])
|
||||
except ValidationError as err:
|
||||
self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
|
||||
|
||||
try:
|
||||
validate_method = getattr(self, 'validate_%s' % field_name, None)
|
||||
if validate_method:
|
||||
attrs = validate_method(attrs, source)
|
||||
except ValidationError as err:
|
||||
self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
|
||||
|
||||
if not self._errors:
|
||||
try:
|
||||
attrs = self.validate(attrs)
|
||||
except ValidationError as err:
|
||||
if hasattr(err, 'message_dict'):
|
||||
for field_name, error_messages in err.message_dict.items():
|
||||
self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages)
|
||||
elif hasattr(err, 'messages'):
|
||||
self._errors['non_field_errors'] = err.messages
|
||||
|
||||
return attrs
|
||||
|
||||
def restore_object(self, attrs, instance=None):
|
||||
if instance is not None:
|
||||
for key, val in attrs.items():
|
||||
try:
|
||||
setattr(instance, key, val)
|
||||
except ValueError:
|
||||
self._errors[key] = self.error_messages['required']
|
||||
|
||||
else:
|
||||
instance = self.opts.model(**attrs)
|
||||
return instance
|
||||
|
||||
def get_default_fields(self):
|
||||
cls = self.opts.model
|
||||
opts = get_concrete_model(cls)
|
||||
fields = []
|
||||
fields += [getattr(opts, field) for field in opts._fields]
|
||||
|
||||
ret = SortedDict()
|
||||
|
||||
for model_field in fields:
|
||||
if isinstance(model_field, fields.ObjectIdField):
|
||||
field = self.get_pk_field(model_field)
|
||||
else:
|
||||
field = self.get_field(model_field)
|
||||
|
||||
if field:
|
||||
field.initialize(parent=self, field_name=model_field.name)
|
||||
ret[model_field.name] = field
|
||||
|
||||
for field_name in self.opts.read_only_fields:
|
||||
assert field_name in ret,\
|
||||
"read_only_fields on '%s' included invalid item '%s'" %\
|
||||
(self.__class__.__name__, field_name)
|
||||
ret[field_name].read_only = True
|
||||
|
||||
return ret
|
||||
|
||||
def get_field(self, model_field):
|
||||
kwargs = {}
|
||||
|
||||
if model_field.required:
|
||||
kwargs['required'] = False
|
||||
|
||||
if model_field.default:
|
||||
kwargs['required'] = False
|
||||
kwargs['default'] = model_field.default
|
||||
|
||||
if model_field.__class__ == models.TextField:
|
||||
kwargs['widget'] = widgets.Textarea
|
||||
|
||||
field_mapping = {
|
||||
fields.FloatField: FloatField,
|
||||
fields.IntField: IntegerField,
|
||||
fields.DateTimeField: DateTimeField,
|
||||
fields.EmailField: EmailField,
|
||||
fields.URLField: URLField,
|
||||
fields.StringField: CharField,
|
||||
fields.BooleanField: BooleanField,
|
||||
fields.FileField: FileField,
|
||||
fields.ImageField: ImageField,
|
||||
fields.ObjectIdField: Field,
|
||||
fields.ReferenceField: CharField,
|
||||
}
|
||||
try:
|
||||
return field_mapping[model_field.__class__](**kwargs)
|
||||
except KeyError:
|
||||
return fields.ModelField(model_field=model_field, **kwargs)
|
||||
|
||||
def transform_object(self, obj, fields, depth):
|
||||
"""
|
||||
Models to natives
|
||||
Recursion for embedded models
|
||||
"""
|
||||
object_data = obj._data
|
||||
counter = 0
|
||||
multiplier = 0
|
||||
|
||||
for field in fields:
|
||||
if issubclass(object_data[field].__class__, DBRef) or issubclass(object_data[field].__class__, fields.Document):
|
||||
multiplier += 1
|
||||
|
||||
for field in fields:
|
||||
if depth == 0:
|
||||
object_data = unicode(object_data['id'])
|
||||
break
|
||||
elif issubclass(object_data[field].__class__, DBRef):
|
||||
object_data = dereference.DeReference().__call__(object_data)
|
||||
if counter < depth*multiplier:
|
||||
counter += 1
|
||||
object_data[field] = self.transform_object(object_data[field], object_data[field]._fields, depth-counter)
|
||||
else:
|
||||
object_data[field] = unicode(object_data[field].pk)
|
||||
elif issubclass(object_data[field].__class__, fields.Document):
|
||||
if counter < depth*multiplier:
|
||||
counter += 1
|
||||
object_data[field] = self.transform_object(object_data[field], object_data[field]._fields, depth-counter)
|
||||
else:
|
||||
object_data[field] = unicode(object_data[field].pk)
|
||||
else:
|
||||
object_data[field] = unicode(object_data[field])
|
||||
|
||||
return object_data
|
||||
|
||||
def to_native(self, obj):
|
||||
"""
|
||||
Rest framework built-in to_native + transform_object
|
||||
"""
|
||||
ret = self._dict_class()
|
||||
ret.fields = self._dict_class()
|
||||
depth = self.opts.depth
|
||||
|
||||
for field_name, field in self.fields.items():
|
||||
if field.read_only and obj is None:
|
||||
continue
|
||||
field.initialize(parent=self, field_name=field_name)
|
||||
key = self.get_field_key(field_name)
|
||||
value = field.field_to_native(obj, field_name)
|
||||
#Call transform_object if field is a related model
|
||||
if issubclass(obj._data[key].__class__, fields.Document) or isinstance(obj._data[key], DBRef):
|
||||
value = self.transform_object(obj._data[key], value, depth)
|
||||
if not getattr(field, 'write_only', False):
|
||||
ret[key] = value
|
||||
ret.fields[key] = self.augment_field(field, field_name, key, value)
|
||||
|
||||
return ret
|
||||
|
||||
def from_native(self, data, files=None):
|
||||
self._errors = {}
|
||||
|
||||
if data is not None or files is not None:
|
||||
attrs = self.restore_fields(data, files)
|
||||
if attrs is not None:
|
||||
attrs = self.perform_validation(attrs)
|
||||
else:
|
||||
self._errors['non_field_errors'] = ['No input provided']
|
||||
|
||||
if not self._errors:
|
||||
return self.restore_object(attrs, instance=getattr(self, 'object', None))
|
||||
|
|
Loading…
Reference in New Issue
Block a user