diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 25be2e6ab..60080b29b 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -13,6 +13,8 @@ from django.http.multipartparser import MultiPartParserError, parse_header, Chun from rest_framework.compat import yaml, etree from rest_framework.exceptions import ParseError from rest_framework.compat import six +from rest_framework.utils.datastructures import TokenExpandedDict +from rest_framework.settings import api_settings import json import datetime import decimal @@ -40,6 +42,16 @@ class BaseParser(object): """ raise NotImplementedError(".parse() must be overridden.") + def _parse_tokenization(self, data): + """ + Configuration dependant processing of input data, where character tokens + (such as periods '.') can be used to reshape for data into nested form, + for use with NestedModelSerializer. + """ + if api_settings.NESTED_FIELDS: + data = TokenExpandedDict(data) + return data + class JSONParser(BaseParser): """ @@ -108,6 +120,8 @@ class FormParser(BaseParser): parser_context = parser_context or {} encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) data = QueryDict(stream.read(), encoding=encoding) + data = self._parse_tokenization(data) + return data @@ -134,6 +148,8 @@ class MultiPartParser(BaseParser): try: parser = DjangoMultiPartParser(meta, stream, upload_handlers, encoding) data, files = parser.parse() + data = self._parse_tokenization(data) + return DataAndFiles(data, files) except MultiPartParserError as exc: raise ParseError('Multipart form parse error - %s' % six.u(exc)) diff --git a/rest_framework/request.py b/rest_framework/request.py index 0d88ebc7e..b08eb9939 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -150,8 +150,8 @@ class Request(object): Similar to usual behaviour of `request.POST`, except that it handles arbitrary parsers, and also works on methods other than POST (eg PUT). """ - if not _hasattr(self, '_data'): - self._load_data_and_files() + #if not _hasattr(self, '_data'): + # self._load_data_and_files() return self._data @property @@ -277,8 +277,7 @@ class Request(object): return # At this point we're committed to parsing the request as form data. - self._data = self._request.POST - self._files = self._request.FILES + self._data, self._files = self._parse() # Method overloading - change the method and remove the param from the content. if (self._METHOD_PARAM and diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 2b260c256..cdb384615 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -887,43 +887,66 @@ class ModelSerializer(Serializer): """ Save the deserialized object and return it. """ - if getattr(obj, '_nested_forward_relations', None): - # Nested relationships need to be saved before we can save the - # parent instance. - for field_name, sub_object in obj._nested_forward_relations.items(): - if sub_object: - self.save_object(sub_object) - setattr(obj, field_name, sub_object) + def save_nested_forward_relations(obj): + """ + Save nested nested forward relations + """ + if getattr(obj, '_nested_forward_relations', None): + # Nested relationships need to be saved before we can save the + # parent instance. + for field_name, sub_object in obj._nested_forward_relations.items(): + if sub_object: + self.save_object(sub_object) + setattr(obj, field_name, sub_object) - obj.save(**kwargs) + obj.save(**kwargs) - if getattr(obj, '_m2m_data', None): - for accessor_name, object_list in obj._m2m_data.items(): - setattr(obj, accessor_name, object_list) - del(obj._m2m_data) + def save_m2m(obj): + """ + Save nested ManyToMany relations + """ + if getattr(obj, '_m2m_data', None) and hasattr(obj._m2m_data, '__iter__'): + for accessor_name, object_list in obj._m2m_data.items(): + if hasattr(object_list, '__iter__'): + for m2m_object in object_list: + save_nested_forward_relations(m2m_object) + save_m2m(m2m_object) + save_related_data(m2m_object) + setattr(obj, accessor_name, object_list) + m2m_object.save() + del(obj._m2m_data) - if getattr(obj, '_related_data', None): - for accessor_name, related in obj._related_data.items(): - if isinstance(related, RelationsList): - # Nested reverse fk relationship - for related_item in related: + def save_related_data(obj): + """ + Save nested nested related data + """ + if getattr(obj, '_related_data', None): + for accessor_name, related in obj._related_data.items(): + if isinstance(related, RelationsList): + # Nested reverse fk relationship + for related_item in related: + fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + setattr(related_item, fk_field, obj) + self.save_object(related_item) + + # Delete any removed objects + if related._deleted: + [self.delete_object(item) for item in related._deleted] + + elif isinstance(related, models.Model): + # Nested reverse one-one relationship fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - setattr(related_item, fk_field, obj) - self.save_object(related_item) + setattr(related, fk_field, obj) + self.save_object(related) + else: + # Reverse FK or reverse one-one + setattr(obj, accessor_name, related) + del(obj._related_data) - # Delete any removed objects - if related._deleted: - [self.delete_object(item) for item in related._deleted] - - elif isinstance(related, models.Model): - # Nested reverse one-one relationship - fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - setattr(related, fk_field, obj) - self.save_object(related) - else: - # Reverse FK or reverse one-one - setattr(obj, accessor_name, related) - del(obj._related_data) + # Save + save_nested_forward_relations(obj) + save_m2m(obj) + save_related_data(obj) class HyperlinkedModelSerializerOptions(ModelSerializerOptions): diff --git a/rest_framework/settings.py b/rest_framework/settings.py index beb511aca..4b75397b5 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -65,6 +65,10 @@ DEFAULTS = { 'anon': None, }, + # Nested (multi-dimensional) field names + 'NESTED_FIELDS': False, + 'NESTED_FIELD_TOKENIZER': '.', + # Pagination 'PAGINATE_BY': None, 'PAGINATE_BY_PARAM': None, diff --git a/rest_framework/utils/datastructures.py b/rest_framework/utils/datastructures.py new file mode 100755 index 000000000..57012d1e7 --- /dev/null +++ b/rest_framework/utils/datastructures.py @@ -0,0 +1,38 @@ +""" +Utility functions for reshaping datastructures +""" +from rest_framework.settings import api_settings +from django.http import QueryDict + +class TokenExpandedDict(QueryDict): + """ + A special dictionary constructor that takes a dictionary in which the keys + may contain dots to specify inner dictionaries. It's confusing, but this + example should make sense. + + >>> d = TokenExpandedDict({'person.1.firstname': ['Simon'], \ + 'person.1.lastname': ['Willison'], \ + 'person.2.firstname': ['Adrian'], \ + 'person.2.lastname': ['Holovaty']}) + >>> d + {'person': {'1': {'lastname': ['Willison'], 'firstname': ['Simon']}, '2': {'lastname': ['Holovaty'], 'firstname': ['Adrian']}}} + >>> d['person'] + {'1': {'lastname': ['Willison'], 'firstname': ['Simon']}, '2': {'lastname': ['Holovaty'], 'firstname': ['Adrian']}} + >>> d['person']['1'] + {'lastname': ['Willison'], 'firstname': ['Simon']} + + # Gotcha: Results are unpredictable if the dots are "uneven": + >>> TokenExpandedDict({'c.1': 2, 'c.2': 3, 'c': 1}) + {'c': 1} + """ + def __init__(self, key_to_list_mapping): + for k, v in key_to_list_mapping.items(): + current = self + bits = k.split(api_settings.NESTED_FIELD_TOKENIZER) + for bit in bits[:-1]: + current = current.setdefault(bit, {}) + # Now assign value to current position + try: + current[bits[-1]] = v + except TypeError: # Special-case if current isn't a dict. + current = {bits[-1]: v} \ No newline at end of file