mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-04 04:20:12 +03:00
Merge bd4f41e2fe
into b3b0515ae6
This commit is contained in:
commit
6b9555b044
|
@ -22,6 +22,8 @@ from django.forms import widgets
|
||||||
from django.utils.datastructures import SortedDict
|
from django.utils.datastructures import SortedDict
|
||||||
from rest_framework.compat import get_concrete_model, six
|
from rest_framework.compat import get_concrete_model, six
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
from mongoengine import fields, dereference
|
||||||
|
from bson import DBRef
|
||||||
|
|
||||||
|
|
||||||
# Note: We do the following so that users of the framework can use this style:
|
# Note: We do the following so that users of the framework can use this style:
|
||||||
|
@ -1113,3 +1115,214 @@ class HyperlinkedModelSerializer(ModelSerializer):
|
||||||
'model_name': model_meta.object_name.lower()
|
'model_name': model_meta.object_name.lower()
|
||||||
}
|
}
|
||||||
return self._default_view_name % format_kwargs
|
return self._default_view_name % format_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class MongoEngineModelSerializerOptions(serializers.ModelSerializerOptions):
|
||||||
|
"""
|
||||||
|
Meta class options for MongoEngineModelSerializer
|
||||||
|
"""
|
||||||
|
def __init__(self, meta):
|
||||||
|
super(MongoEngineModelSerializerOptions, self).__init__(meta)
|
||||||
|
self.validations = getattr(meta, 'related_model_validations', {})
|
||||||
|
|
||||||
|
|
||||||
|
class MongoEngineModelSerializer(ModelSerializer):
|
||||||
|
|
||||||
|
##USAGE###################################################################
|
||||||
|
# class BlogEntrySerializer(MongoEngineModelSerializer): #
|
||||||
|
# class Meta: #
|
||||||
|
# model = BlogEntry #
|
||||||
|
# depth = 1 #
|
||||||
|
# related_model_validations = {'Category': Category, 'author': User} #
|
||||||
|
##########################################################################
|
||||||
|
_options_class = MongoEngineModelSerializerOptions
|
||||||
|
|
||||||
|
def validate_related_field(self, attrs, source, object_type):
|
||||||
|
"""
|
||||||
|
Validate related model
|
||||||
|
"""
|
||||||
|
value = attrs[source]
|
||||||
|
|
||||||
|
try:
|
||||||
|
object_type.objects.get(pk=value)
|
||||||
|
except object_type.DoesNotExist:
|
||||||
|
raise ValidationError(object_type.__name__ + ' with PK ' + value + ' does not exists.')
|
||||||
|
return attrs
|
||||||
|
|
||||||
|
def perform_validation(self, attrs):
|
||||||
|
"""
|
||||||
|
Rest Framework built-in validation + related model validations
|
||||||
|
"""
|
||||||
|
for field_name, field in self.fields.items():
|
||||||
|
if field_name in self._errors:
|
||||||
|
continue
|
||||||
|
|
||||||
|
source = field.source or field_name
|
||||||
|
if self.partial and source not in attrs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Related Model Validations
|
||||||
|
if field_name in self.opts.validations:
|
||||||
|
try:
|
||||||
|
self.validate_related_field(attrs, source, self.opts.validations[field_name])
|
||||||
|
except ValidationError as err:
|
||||||
|
self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
|
||||||
|
|
||||||
|
try:
|
||||||
|
validate_method = getattr(self, 'validate_%s' % field_name, None)
|
||||||
|
if validate_method:
|
||||||
|
attrs = validate_method(attrs, source)
|
||||||
|
except ValidationError as err:
|
||||||
|
self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
|
||||||
|
|
||||||
|
if not self._errors:
|
||||||
|
try:
|
||||||
|
attrs = self.validate(attrs)
|
||||||
|
except ValidationError as err:
|
||||||
|
if hasattr(err, 'message_dict'):
|
||||||
|
for field_name, error_messages in err.message_dict.items():
|
||||||
|
self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages)
|
||||||
|
elif hasattr(err, 'messages'):
|
||||||
|
self._errors['non_field_errors'] = err.messages
|
||||||
|
|
||||||
|
return attrs
|
||||||
|
|
||||||
|
def restore_object(self, attrs, instance=None):
|
||||||
|
if instance is not None:
|
||||||
|
for key, val in attrs.items():
|
||||||
|
try:
|
||||||
|
setattr(instance, key, val)
|
||||||
|
except ValueError:
|
||||||
|
self._errors[key] = self.error_messages['required']
|
||||||
|
|
||||||
|
else:
|
||||||
|
instance = self.opts.model(**attrs)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
def get_default_fields(self):
|
||||||
|
cls = self.opts.model
|
||||||
|
opts = get_concrete_model(cls)
|
||||||
|
fields = []
|
||||||
|
fields += [getattr(opts, field) for field in opts._fields]
|
||||||
|
|
||||||
|
ret = SortedDict()
|
||||||
|
|
||||||
|
for model_field in fields:
|
||||||
|
if isinstance(model_field, fields.ObjectIdField):
|
||||||
|
field = self.get_pk_field(model_field)
|
||||||
|
else:
|
||||||
|
field = self.get_field(model_field)
|
||||||
|
|
||||||
|
if field:
|
||||||
|
field.initialize(parent=self, field_name=model_field.name)
|
||||||
|
ret[model_field.name] = field
|
||||||
|
|
||||||
|
for field_name in self.opts.read_only_fields:
|
||||||
|
assert field_name in ret,\
|
||||||
|
"read_only_fields on '%s' included invalid item '%s'" %\
|
||||||
|
(self.__class__.__name__, field_name)
|
||||||
|
ret[field_name].read_only = True
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def get_field(self, model_field):
|
||||||
|
kwargs = {}
|
||||||
|
|
||||||
|
if model_field.required:
|
||||||
|
kwargs['required'] = False
|
||||||
|
|
||||||
|
if model_field.default:
|
||||||
|
kwargs['required'] = False
|
||||||
|
kwargs['default'] = model_field.default
|
||||||
|
|
||||||
|
if model_field.__class__ == models.TextField:
|
||||||
|
kwargs['widget'] = widgets.Textarea
|
||||||
|
|
||||||
|
field_mapping = {
|
||||||
|
fields.FloatField: FloatField,
|
||||||
|
fields.IntField: IntegerField,
|
||||||
|
fields.DateTimeField: DateTimeField,
|
||||||
|
fields.EmailField: EmailField,
|
||||||
|
fields.URLField: URLField,
|
||||||
|
fields.StringField: CharField,
|
||||||
|
fields.BooleanField: BooleanField,
|
||||||
|
fields.FileField: FileField,
|
||||||
|
fields.ImageField: ImageField,
|
||||||
|
fields.ObjectIdField: Field,
|
||||||
|
fields.ReferenceField: CharField,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
return field_mapping[model_field.__class__](**kwargs)
|
||||||
|
except KeyError:
|
||||||
|
return fields.ModelField(model_field=model_field, **kwargs)
|
||||||
|
|
||||||
|
def transform_object(self, obj, fields, depth):
|
||||||
|
"""
|
||||||
|
Models to natives
|
||||||
|
Recursion for embedded models
|
||||||
|
"""
|
||||||
|
object_data = obj._data
|
||||||
|
counter = 0
|
||||||
|
multiplier = 0
|
||||||
|
|
||||||
|
for field in fields:
|
||||||
|
if issubclass(object_data[field].__class__, DBRef) or issubclass(object_data[field].__class__, fields.Document):
|
||||||
|
multiplier += 1
|
||||||
|
|
||||||
|
for field in fields:
|
||||||
|
if depth == 0:
|
||||||
|
object_data = unicode(object_data['id'])
|
||||||
|
break
|
||||||
|
elif issubclass(object_data[field].__class__, DBRef):
|
||||||
|
object_data = dereference.DeReference().__call__(object_data)
|
||||||
|
if counter < depth*multiplier:
|
||||||
|
counter += 1
|
||||||
|
object_data[field] = self.transform_object(object_data[field], object_data[field]._fields, depth-counter)
|
||||||
|
else:
|
||||||
|
object_data[field] = unicode(object_data[field].pk)
|
||||||
|
elif issubclass(object_data[field].__class__, fields.Document):
|
||||||
|
if counter < depth*multiplier:
|
||||||
|
counter += 1
|
||||||
|
object_data[field] = self.transform_object(object_data[field], object_data[field]._fields, depth-counter)
|
||||||
|
else:
|
||||||
|
object_data[field] = unicode(object_data[field].pk)
|
||||||
|
else:
|
||||||
|
object_data[field] = unicode(object_data[field])
|
||||||
|
|
||||||
|
return object_data
|
||||||
|
|
||||||
|
def to_native(self, obj):
|
||||||
|
"""
|
||||||
|
Rest framework built-in to_native + transform_object
|
||||||
|
"""
|
||||||
|
ret = self._dict_class()
|
||||||
|
ret.fields = self._dict_class()
|
||||||
|
depth = self.opts.depth
|
||||||
|
|
||||||
|
for field_name, field in self.fields.items():
|
||||||
|
if field.read_only and obj is None:
|
||||||
|
continue
|
||||||
|
field.initialize(parent=self, field_name=field_name)
|
||||||
|
key = self.get_field_key(field_name)
|
||||||
|
value = field.field_to_native(obj, field_name)
|
||||||
|
#Call transform_object if field is a related model
|
||||||
|
if issubclass(obj._data[key].__class__, fields.Document) or isinstance(obj._data[key], DBRef):
|
||||||
|
value = self.transform_object(obj._data[key], value, depth)
|
||||||
|
if not getattr(field, 'write_only', False):
|
||||||
|
ret[key] = value
|
||||||
|
ret.fields[key] = self.augment_field(field, field_name, key, value)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def from_native(self, data, files=None):
|
||||||
|
self._errors = {}
|
||||||
|
|
||||||
|
if data is not None or files is not None:
|
||||||
|
attrs = self.restore_fields(data, files)
|
||||||
|
if attrs is not None:
|
||||||
|
attrs = self.perform_validation(attrs)
|
||||||
|
else:
|
||||||
|
self._errors['non_field_errors'] = ['No input provided']
|
||||||
|
|
||||||
|
if not self._errors:
|
||||||
|
return self.restore_object(attrs, instance=getattr(self, 'object', None))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user