From bd4f41e2fe170a772ab4381fb5fceac1db5e3506 Mon Sep 17 00:00:00 2001 From: Umut Bozkurt Date: Fri, 11 Apr 2014 12:29:10 +0300 Subject: [PATCH] mongoengine support for serialisers code needs a lot of refactoring --- rest_framework/serializers.py | 213 ++++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index cb7539e0b..f65923d28 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -22,6 +22,8 @@ from django.forms import widgets from django.utils.datastructures import SortedDict from rest_framework.compat import get_concrete_model, six from rest_framework.settings import api_settings +from mongoengine import fields, dereference +from bson import DBRef # Note: We do the following so that users of the framework can use this style: @@ -1113,3 +1115,214 @@ class HyperlinkedModelSerializer(ModelSerializer): 'model_name': model_meta.object_name.lower() } return self._default_view_name % format_kwargs + + +class MongoEngineModelSerializerOptions(serializers.ModelSerializerOptions): + """ + Meta class options for MongoEngineModelSerializer + """ + def __init__(self, meta): + super(MongoEngineModelSerializerOptions, self).__init__(meta) + self.validations = getattr(meta, 'related_model_validations', {}) + + +class MongoEngineModelSerializer(ModelSerializer): + + ##USAGE################################################################### + # class BlogEntrySerializer(MongoEngineModelSerializer): # + # class Meta: # + # model = BlogEntry # + # depth = 1 # + # related_model_validations = {'Category': Category, 'author': User} # + ########################################################################## + _options_class = MongoEngineModelSerializerOptions + + def validate_related_field(self, attrs, source, object_type): + """ + Validate related model + """ + value = attrs[source] + + try: + object_type.objects.get(pk=value) + except object_type.DoesNotExist: + raise ValidationError(object_type.__name__ + ' with PK ' + value + ' does not exists.') + return attrs + + def perform_validation(self, attrs): + """ + Rest Framework built-in validation + related model validations + """ + for field_name, field in self.fields.items(): + if field_name in self._errors: + continue + + source = field.source or field_name + if self.partial and source not in attrs: + continue + + # Related Model Validations + if field_name in self.opts.validations: + try: + self.validate_related_field(attrs, source, self.opts.validations[field_name]) + except ValidationError as err: + self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) + + try: + validate_method = getattr(self, 'validate_%s' % field_name, None) + if validate_method: + attrs = validate_method(attrs, source) + except ValidationError as err: + self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) + + if not self._errors: + try: + attrs = self.validate(attrs) + except ValidationError as err: + if hasattr(err, 'message_dict'): + for field_name, error_messages in err.message_dict.items(): + self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages) + elif hasattr(err, 'messages'): + self._errors['non_field_errors'] = err.messages + + return attrs + + def restore_object(self, attrs, instance=None): + if instance is not None: + for key, val in attrs.items(): + try: + setattr(instance, key, val) + except ValueError: + self._errors[key] = self.error_messages['required'] + + else: + instance = self.opts.model(**attrs) + return instance + + def get_default_fields(self): + cls = self.opts.model + opts = get_concrete_model(cls) + fields = [] + fields += [getattr(opts, field) for field in opts._fields] + + ret = SortedDict() + + for model_field in fields: + if isinstance(model_field, fields.ObjectIdField): + field = self.get_pk_field(model_field) + else: + field = self.get_field(model_field) + + if field: + field.initialize(parent=self, field_name=model_field.name) + ret[model_field.name] = field + + for field_name in self.opts.read_only_fields: + assert field_name in ret,\ + "read_only_fields on '%s' included invalid item '%s'" %\ + (self.__class__.__name__, field_name) + ret[field_name].read_only = True + + return ret + + def get_field(self, model_field): + kwargs = {} + + if model_field.required: + kwargs['required'] = False + + if model_field.default: + kwargs['required'] = False + kwargs['default'] = model_field.default + + if model_field.__class__ == models.TextField: + kwargs['widget'] = widgets.Textarea + + field_mapping = { + fields.FloatField: FloatField, + fields.IntField: IntegerField, + fields.DateTimeField: DateTimeField, + fields.EmailField: EmailField, + fields.URLField: URLField, + fields.StringField: CharField, + fields.BooleanField: BooleanField, + fields.FileField: FileField, + fields.ImageField: ImageField, + fields.ObjectIdField: Field, + fields.ReferenceField: CharField, + } + try: + return field_mapping[model_field.__class__](**kwargs) + except KeyError: + return fields.ModelField(model_field=model_field, **kwargs) + + def transform_object(self, obj, fields, depth): + """ + Models to natives + Recursion for embedded models + """ + object_data = obj._data + counter = 0 + multiplier = 0 + + for field in fields: + if issubclass(object_data[field].__class__, DBRef) or issubclass(object_data[field].__class__, fields.Document): + multiplier += 1 + + for field in fields: + if depth == 0: + object_data = unicode(object_data['id']) + break + elif issubclass(object_data[field].__class__, DBRef): + object_data = dereference.DeReference().__call__(object_data) + if counter < depth*multiplier: + counter += 1 + object_data[field] = self.transform_object(object_data[field], object_data[field]._fields, depth-counter) + else: + object_data[field] = unicode(object_data[field].pk) + elif issubclass(object_data[field].__class__, fields.Document): + if counter < depth*multiplier: + counter += 1 + object_data[field] = self.transform_object(object_data[field], object_data[field]._fields, depth-counter) + else: + object_data[field] = unicode(object_data[field].pk) + else: + object_data[field] = unicode(object_data[field]) + + return object_data + + def to_native(self, obj): + """ + Rest framework built-in to_native + transform_object + """ + ret = self._dict_class() + ret.fields = self._dict_class() + depth = self.opts.depth + + for field_name, field in self.fields.items(): + if field.read_only and obj is None: + continue + field.initialize(parent=self, field_name=field_name) + key = self.get_field_key(field_name) + value = field.field_to_native(obj, field_name) + #Call transform_object if field is a related model + if issubclass(obj._data[key].__class__, fields.Document) or isinstance(obj._data[key], DBRef): + value = self.transform_object(obj._data[key], value, depth) + if not getattr(field, 'write_only', False): + ret[key] = value + ret.fields[key] = self.augment_field(field, field_name, key, value) + + return ret + + def from_native(self, data, files=None): + self._errors = {} + + if data is not None or files is not None: + attrs = self.restore_fields(data, files) + if attrs is not None: + attrs = self.perform_validation(attrs) + else: + self._errors['non_field_errors'] = ['No input provided'] + + if not self._errors: + return self.restore_object(attrs, instance=getattr(self, 'object', None))