From d8eb9e6d45c227582559ec4318b1f92562c718da Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 10:48:16 +0100 Subject: [PATCH 01/13] Docs whitespace fix. --- docs/api-guide/generic-views.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/api-guide/generic-views.md b/docs/api-guide/generic-views.md index b1c4e65ad..d30b7f9bf 100755 --- a/docs/api-guide/generic-views.md +++ b/docs/api-guide/generic-views.md @@ -19,8 +19,8 @@ Typically when using the generic views, you'll override the view, and set severa from django.contrib.auth.models import User from myapp.serializers import UserSerializer - from rest_framework import generics - from rest_framework.permissions import IsAdminUser + from rest_framework import generics + from rest_framework.permissions import IsAdminUser class UserList(generics.ListCreateAPIView): queryset = User.objects.all() From f62c874ea9621ae67fb56e7e453dca8fd5039051 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 10:48:40 +0100 Subject: [PATCH 02/13] Remove `filter_backend`. Closes #1775. --- rest_framework/generics.py | 20 +------------------- rest_framework/settings.py | 10 ---------- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index a6f686571..8bacf470b 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -83,7 +83,6 @@ class GenericAPIView(views.APIView): slug_url_kwarg = 'slug' slug_field = 'slug' allow_empty = True - filter_backend = api_settings.FILTER_BACKEND def get_serializer_context(self): """ @@ -191,24 +190,7 @@ class GenericAPIView(views.APIView): """ Returns the list of filter backends that this view requires. """ - if self.filter_backends is None: - filter_backends = [] - else: - # Note that we are returning a *copy* of the class attribute, - # so that it is safe for the view to mutate it if needed. - filter_backends = list(self.filter_backends) - - if not filter_backends and self.filter_backend: - warnings.warn( - 'The `filter_backend` attribute and `FILTER_BACKEND` setting ' - 'are deprecated in favor of a `filter_backends` ' - 'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take ' - 'a *list* of filter backend classes.', - DeprecationWarning, stacklevel=2 - ) - filter_backends = [self.filter_backend] - - return filter_backends + return list(self.filter_backends) # The following methods provide default implementations # that you may want to override for more complex cases. diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 644751f87..bbe7a56ad 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -111,9 +111,6 @@ DEFAULTS = { ), 'TIME_FORMAT': None, - # Pending deprecation - 'FILTER_BACKEND': None, - } @@ -129,7 +126,6 @@ IMPORT_STRINGS = ( 'DEFAULT_PAGINATION_SERIALIZER_CLASS', 'DEFAULT_FILTER_BACKENDS', 'EXCEPTION_HANDLER', - 'FILTER_BACKEND', 'TEST_REQUEST_RENDERER_CLASSES', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', @@ -196,15 +192,9 @@ class APISettings(object): if val and attr in self.import_strings: val = perform_import(val, attr) - self.validate_setting(attr, val) - # Cache the result setattr(self, attr, val) return val - def validate_setting(self, attr, val): - if attr == 'FILTER_BACKEND' and val is not None: - # Make sure we can initialize the class - val() api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS) From 0f8fdf4e72b67ff46474c13c8b532bf319a58099 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 10:57:24 +0100 Subject: [PATCH 03/13] Remove `allow_empty`. Closes #1774. --- docs/api-guide/generic-views.md | 2 -- rest_framework/generics.py | 12 +----------- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/docs/api-guide/generic-views.md b/docs/api-guide/generic-views.md index d30b7f9bf..49be0cae8 100755 --- a/docs/api-guide/generic-views.md +++ b/docs/api-guide/generic-views.md @@ -212,8 +212,6 @@ Provides a `.list(request, *args, **kwargs)` method, that implements listing a q If the queryset is populated, this returns a `200 OK` response, with a serialized representation of the queryset as the body of the response. The response data may optionally be paginated. -If the queryset is empty this returns a `200 OK` response, unless the `.allow_empty` attribute on the view is set to `False`, in which case it will return a `404 Not Found`. - ## CreateModelMixin Provides a `.create(request, *args, **kwargs)` method, that implements creating and saving a new model instance. diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 8bacf470b..cb8061b75 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -82,7 +82,6 @@ class GenericAPIView(views.APIView): pk_url_kwarg = 'pk' slug_url_kwarg = 'slug' slug_field = 'slug' - allow_empty = True def get_serializer_context(self): """ @@ -140,16 +139,7 @@ class GenericAPIView(views.APIView): if not page_size: return None - if not self.allow_empty: - warnings.warn( - 'The `allow_empty` parameter is deprecated. ' - 'To use `allow_empty=False` style behavior, You should override ' - '`get_queryset()` and explicitly raise a 404 on empty querysets.', - DeprecationWarning, stacklevel=2 - ) - - paginator = self.paginator_class(queryset, page_size, - allow_empty_first_page=self.allow_empty) + paginator = self.paginator_class(queryset, page_size) page_kwarg = self.kwargs.get(self.page_kwarg) page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) page = page_kwarg or page_query_param or 1 From b3bbf416707cf8c71861b0fd6e966a557acef412 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 11:09:35 +0100 Subject: [PATCH 04/13] Remove `allow_empty` --- rest_framework/mixins.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 2cc87eef1..dc4c9f353 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -70,24 +70,9 @@ class ListModelMixin(object): """ List a queryset. """ - empty_error = "Empty list and '%(class_name)s.allow_empty' is False." - def list(self, request, *args, **kwargs): self.object_list = self.filter_queryset(self.get_queryset()) - # Default is to allow empty querysets. This can be altered by setting - # `.allow_empty = False`, to raise 404 errors on empty querysets. - if not self.allow_empty and not self.object_list: - warnings.warn( - 'The `allow_empty` parameter is deprecated. ' - 'To use `allow_empty=False` style behavior, You should override ' - '`get_queryset()` and explicitly raise a 404 on empty querysets.', - DeprecationWarning - ) - class_name = self.__class__.__name__ - error_msg = self.empty_error % {'class_name': class_name} - raise Http404(error_msg) - # Switch between paginated or standard style responses page = self.paginate_queryset(self.object_list) if page is not None: From e5e6329a222def3b0745f90fc55ee36de95ada83 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 11:29:26 +0100 Subject: [PATCH 05/13] Remove `pk_url_field`, `slug_url_field`, `slug_field`. Closes #1773. --- rest_framework/generics.py | 36 ++--------- rest_framework/mixins.py | 36 +++-------- rest_framework/relations.py | 117 ++---------------------------------- 3 files changed, 15 insertions(+), 174 deletions(-) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index cb8061b75..e21dc5c7b 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -76,13 +76,6 @@ class GenericAPIView(views.APIView): model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS paginator_class = Paginator - ###################################### - # These are pending deprecation... - - pk_url_kwarg = 'pk' - slug_url_kwarg = 'slug' - slug_field = 'slug' - def get_serializer_context(self): """ Extra context provided to the serializer class. @@ -270,7 +263,7 @@ class GenericAPIView(views.APIView): error_format = "'%s' must define 'queryset' or 'model'" raise ImproperlyConfigured(error_format % self.__class__.__name__) - def get_object(self, queryset=None): + def get_object(self): """ Returns the object the view is displaying. @@ -278,36 +271,14 @@ class GenericAPIView(views.APIView): queryset lookups. Eg if objects are referenced using multiple keyword arguments in the url conf. """ - # Determine the base queryset to use. - if queryset is None: - queryset = self.filter_queryset(self.get_queryset()) - else: - pass # Deprecation warning + queryset = self.filter_queryset(self.get_queryset()) # Perform the lookup filtering. # Note that `pk` and `slug` are deprecated styles of lookup filtering. lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field lookup = self.kwargs.get(lookup_url_kwarg, None) - pk = self.kwargs.get(self.pk_url_kwarg, None) - slug = self.kwargs.get(self.slug_url_kwarg, None) - if lookup is not None: - filter_kwargs = {self.lookup_field: lookup} - elif pk is not None and self.lookup_field == 'pk': - warnings.warn( - 'The `pk_url_kwarg` attribute is deprecated. ' - 'Use the `lookup_field` attribute instead', - DeprecationWarning - ) - filter_kwargs = {'pk': pk} - elif slug is not None and self.lookup_field == 'pk': - warnings.warn( - 'The `slug_url_kwarg` attribute is deprecated. ' - 'Use the `lookup_field` attribute instead', - DeprecationWarning - ) - filter_kwargs = {self.slug_field: slug} - else: + if lookup is None: raise ImproperlyConfigured( 'Expected view %s to be called with a URL keyword argument ' 'named "%s". Fix your URL conf, or set the `.lookup_field` ' @@ -315,6 +286,7 @@ class GenericAPIView(views.APIView): (self.__class__.__name__, self.lookup_field) ) + filter_kwargs = {self.lookup_field: lookup} obj = get_object_or_404(queryset, **filter_kwargs) # May raise a permission denied diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index dc4c9f353..ac59d9795 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -12,10 +12,9 @@ from rest_framework import status from rest_framework.response import Response from rest_framework.request import clone_request from rest_framework.settings import api_settings -import warnings -def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None): +def _get_validation_exclusions(obj, lookup_field=None): """ Given a model instance, and an optional pk and slug field, return the full list of all other field names on that model. @@ -23,23 +22,13 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None) For use when performing full_clean on a model instance, so we only clean the required fields. """ - include = [] - - if pk: - # Deprecated + if lookup_field == 'pk': pk_field = obj._meta.pk while pk_field.rel: pk_field = pk_field.rel.to._meta.pk - include.append(pk_field.name) + lookup_field = pk_field.name - if slug_field: - # Deprecated - include.append(slug_field) - - if lookup_field and lookup_field != 'pk': - include.append(lookup_field) - - return [field.name for field in obj._meta.fields if field.name not in include] + return [field.name for field in obj._meta.fields if field.name != lookup_field] class CreateModelMixin(object): @@ -146,26 +135,15 @@ class UpdateModelMixin(object): """ Set any attributes on the object that are implicit in the request. """ - # pk and/or slug attributes are implicit in the URL. lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - lookup = self.kwargs.get(lookup_url_kwarg, None) - pk = self.kwargs.get(self.pk_url_kwarg, None) - slug = self.kwargs.get(self.slug_url_kwarg, None) - slug_field = slug and self.slug_field or None + lookup_value = self.kwargs[lookup_url_kwarg] - if lookup: - setattr(obj, self.lookup_field, lookup) - - if pk: - setattr(obj, 'pk', pk) - - if slug: - setattr(obj, slug_field, slug) + setattr(obj, self.lookup_field, lookup_value) # Ensure we clean the attributes so that we don't eg return integer # pk using a string representation, as provided by the url conf kwarg. if hasattr(obj, 'full_clean'): - exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field) + exclude = _get_validation_exclusions(obj, self.lookup_field) obj.full_clean(exclude) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 1acbdce26..56870b408 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -16,7 +16,6 @@ from rest_framework.fields import Field, WritableField, get_component, is_simple from rest_framework.reverse import reverse from rest_framework.compat import urlparse from rest_framework.compat import smart_text -import warnings # Relational fields @@ -320,11 +319,6 @@ class HyperlinkedRelatedField(RelatedField): 'incorrect_type': _('Incorrect type. Expected url string, received %s.'), } - # These are all deprecated - pk_url_kwarg = 'pk' - slug_field = 'slug' - slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden - def __init__(self, *args, **kwargs): try: self.view_name = kwargs.pop('view_name') @@ -334,22 +328,6 @@ class HyperlinkedRelatedField(RelatedField): self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) self.format = kwargs.pop('format', None) - # These are deprecated - if 'pk_url_kwarg' in kwargs: - msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.' - warnings.warn(msg, DeprecationWarning, stacklevel=2) - if 'slug_url_kwarg' in kwargs: - msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.' - warnings.warn(msg, DeprecationWarning, stacklevel=2) - if 'slug_field' in kwargs: - msg = 'slug_field is deprecated. Use lookup_field instead.' - warnings.warn(msg, DeprecationWarning, stacklevel=2) - - self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) - self.slug_field = kwargs.pop('slug_field', self.slug_field) - default_slug_kwarg = self.slug_url_kwarg or self.slug_field - self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) - super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) def get_url(self, obj, view_name, request, format): @@ -361,39 +339,7 @@ class HyperlinkedRelatedField(RelatedField): """ lookup_field = getattr(obj, self.lookup_field) kwargs = {self.lookup_field: lookup_field} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass - - if self.pk_url_kwarg != 'pk': - # Only try pk if it has been explicitly set. - # Otherwise, the default `lookup_field = 'pk'` has us covered. - pk = obj.pk - kwargs = {self.pk_url_kwarg: pk} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass - - slug = getattr(obj, self.slug_field, None) - if slug is not None: - # Only try slug if it corresponds to an attribute on the object. - kwargs = {self.slug_url_kwarg: slug} - try: - ret = reverse(view_name, kwargs=kwargs, request=request, format=format) - if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug': - # If the lookup succeeds using the default slug params, - # then `slug_field` is being used implicitly, and we - # we need to warn about the pending deprecation. - msg = 'Implicit slug field hyperlinked fields are deprecated.' \ - 'You should set `lookup_field=slug` on the HyperlinkedRelatedField.' - warnings.warn(msg, DeprecationWarning, stacklevel=2) - return ret - except NoReverseMatch: - pass - - raise NoReverseMatch() + return reverse(view_name, kwargs=kwargs, request=request, format=format) def get_object(self, queryset, view_name, view_args, view_kwargs): """ @@ -402,19 +348,8 @@ class HyperlinkedRelatedField(RelatedField): Takes the matched URL conf arguments, and the queryset, and should return an object instance, or raise an `ObjectDoesNotExist` exception. """ - lookup = view_kwargs.get(self.lookup_field, None) - pk = view_kwargs.get(self.pk_url_kwarg, None) - slug = view_kwargs.get(self.slug_url_kwarg, None) - - if lookup is not None: - filter_kwargs = {self.lookup_field: lookup} - elif pk is not None: - filter_kwargs = {'pk': pk} - elif slug is not None: - filter_kwargs = {self.slug_field: slug} - else: - raise ObjectDoesNotExist() - + lookup_value = view_kwargs[self.lookup_field] + filter_kwargs = {self.lookup_field: lookup_value} return queryset.get(**filter_kwargs) def to_native(self, obj): @@ -486,11 +421,6 @@ class HyperlinkedIdentityField(Field): lookup_field = 'pk' read_only = True - # These are all deprecated - pk_url_kwarg = 'pk' - slug_field = 'slug' - slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden - def __init__(self, *args, **kwargs): try: self.view_name = kwargs.pop('view_name') @@ -502,22 +432,6 @@ class HyperlinkedIdentityField(Field): lookup_field = kwargs.pop('lookup_field', None) self.lookup_field = lookup_field or self.lookup_field - # These are deprecated - if 'pk_url_kwarg' in kwargs: - msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.' - warnings.warn(msg, DeprecationWarning, stacklevel=2) - if 'slug_url_kwarg' in kwargs: - msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.' - warnings.warn(msg, DeprecationWarning, stacklevel=2) - if 'slug_field' in kwargs: - msg = 'slug_field is deprecated. Use lookup_field instead.' - warnings.warn(msg, DeprecationWarning, stacklevel=2) - - self.slug_field = kwargs.pop('slug_field', self.slug_field) - default_slug_kwarg = self.slug_url_kwarg or self.slug_field - self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) - self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) - super(HyperlinkedIdentityField, self).__init__(*args, **kwargs) def field_to_native(self, obj, field_name): @@ -569,27 +483,4 @@ class HyperlinkedIdentityField(Field): if lookup_field is None: return None - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass - - if self.pk_url_kwarg != 'pk': - # Only try pk lookup if it has been explicitly set. - # Otherwise, the default `lookup_field = 'pk'` has us covered. - kwargs = {self.pk_url_kwarg: obj.pk} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass - - slug = getattr(obj, self.slug_field, None) - if slug: - # Only use slug lookup if a slug field exists on the model - kwargs = {self.slug_url_kwarg: slug} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass - - raise NoReverseMatch() + return reverse(view_name, kwargs=kwargs, request=request, format=format) From b8c8d10a18741b76355ed7035655d0101c1d778a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 11:38:54 +0100 Subject: [PATCH 06/13] Remove `page_size` argument. `paginate_queryset` no longer takes an optional `page_size` argument. --- rest_framework/generics.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index e21dc5c7b..09035303f 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -111,26 +111,14 @@ class GenericAPIView(views.APIView): context = self.get_serializer_context() return pagination_serializer_class(instance=page, context=context) - def paginate_queryset(self, queryset, page_size=None): + def paginate_queryset(self, queryset): """ Paginate a queryset if required, either returning a page object, or `None` if pagination is not configured for this view. """ - deprecated_style = False - if page_size is not None: - warnings.warn('The `page_size` parameter to `paginate_queryset()` ' - 'is deprecated. ' - 'Note that the return style of this method is also ' - 'changed, and will simply return a page object ' - 'when called without a `page_size` argument.', - DeprecationWarning, stacklevel=2) - deprecated_style = True - else: - # Determine the required page size. - # If pagination is not configured, simply return None. - page_size = self.get_paginate_by() - if not page_size: - return None + page_size = self.get_paginate_by() + if not page_size: + return None paginator = self.paginator_class(queryset, page_size) page_kwarg = self.kwargs.get(self.page_kwarg) @@ -152,8 +140,6 @@ class GenericAPIView(views.APIView): 'message': str(exc) }) - if deprecated_style: - return (paginator, page, page.object_list, page.has_other_pages()) return page def filter_queryset(self, queryset): From b3253b42836acd123224e88c0927f1ee6a031d94 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 12:35:53 +0100 Subject: [PATCH 07/13] Remove `.model` usage in tests. Remove the shortcut `.model` view attribute usage from test cases. --- rest_framework/generics.py | 49 +++++---------------- tests/serializers.py | 7 --- tests/test_filters.py | 65 +++++++++++++++++++--------- tests/test_generics.py | 34 +++++++++++---- tests/test_hyperlinkedserializers.py | 62 ++++++++++++++++++-------- tests/test_nullable_fields.py | 13 +++++- tests/test_pagination.py | 32 ++++++++++---- tests/test_permissions.py | 24 +++++++--- tests/test_response.py | 5 ++- tests/test_validation.py | 4 +- tests/views.py | 8 ---- 11 files changed, 185 insertions(+), 118 deletions(-) delete mode 100644 tests/serializers.py delete mode 100644 tests/views.py diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 09035303f..68222864f 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -51,11 +51,6 @@ class GenericAPIView(views.APIView): queryset = None serializer_class = None - # This shortcut may be used instead of setting either or both - # of the `queryset`/`serializer_class` attributes, although using - # the explicit style is generally preferred. - model = None - # If you want to use object lookups other than pk, set this attribute. # For more complex lookup requirements override `get_object()`. lookup_field = 'pk' @@ -71,9 +66,8 @@ class GenericAPIView(views.APIView): # The filter backend classes to use for queryset filtering filter_backends = api_settings.DEFAULT_FILTER_BACKENDS - # The following attributes may be subject to change, + # The following attribute may be subject to change, # and should be considered private API. - model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS paginator_class = Paginator def get_serializer_context(self): @@ -199,26 +193,13 @@ class GenericAPIView(views.APIView): (Eg. admins get full serialization, others get basic serialization) """ - serializer_class = self.serializer_class - if serializer_class is not None: - return serializer_class - - warnings.warn( - 'The `.model` attribute on view classes is now deprecated in favor ' - 'of the more explicit `serializer_class` and `queryset` attributes.', - DeprecationWarning, stacklevel=2 + assert self.serializer_class is not None, ( + "'%s' should either include a `serializer_class` attribute, " + "or override the `get_serializer_class()` method." + % self.__class__.__name__ ) - assert self.model is not None, \ - "'%s' should either include a 'serializer_class' attribute, " \ - "or use the 'model' attribute as a shortcut for " \ - "automatically generating a serializer class." \ - % self.__class__.__name__ - - class DefaultSerializer(self.model_serializer_class): - class Meta: - model = self.model - return DefaultSerializer + return self.serializer_class def get_queryset(self): """ @@ -235,19 +216,13 @@ class GenericAPIView(views.APIView): (Eg. return a list of items that is specific to the user) """ - if self.queryset is not None: - return self.queryset._clone() + assert self.queryset is not None, ( + "'%s' should either include a `queryset` attribute, " + "or override the `get_queryset()` method." + % self.__class__.__name__ + ) - if self.model is not None: - warnings.warn( - 'The `.model` attribute on view classes is now deprecated in favor ' - 'of the more explicit `serializer_class` and `queryset` attributes.', - DeprecationWarning, stacklevel=2 - ) - return self.model._default_manager.all() - - error_format = "'%s' must define 'queryset' or 'model'" - raise ImproperlyConfigured(error_format % self.__class__.__name__) + return self.queryset._clone() def get_object(self): """ diff --git a/tests/serializers.py b/tests/serializers.py deleted file mode 100644 index be7b37722..000000000 --- a/tests/serializers.py +++ /dev/null @@ -1,7 +0,0 @@ -from rest_framework import serializers -from tests.models import NullableForeignKeySource - - -class NullableFKSourceSerializer(serializers.ModelSerializer): - class Meta: - model = NullableForeignKeySource diff --git a/tests/test_filters.py b/tests/test_filters.py index 47bffd436..6f24b1abb 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -16,9 +16,14 @@ factory = APIRequestFactory() if django_filters: + class FilterableItemSerializer(serializers.ModelSerializer): + class Meta: + model = FilterableItem + # Basic filter on a list view. class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer filter_fields = ['decimal', 'date'] filter_backends = (filters.DjangoFilterBackend,) @@ -33,7 +38,8 @@ if django_filters: fields = ['text', 'decimal', 'date'] class FilterClassRootView(generics.ListCreateAPIView): - model = FilterableItem + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer filter_class = SeveralFieldsFilter filter_backends = (filters.DjangoFilterBackend,) @@ -46,12 +52,14 @@ if django_filters: fields = ['text'] class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): - model = FilterableItem + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer filter_class = MisconfiguredFilter filter_backends = (filters.DjangoFilterBackend,) class FilterClassDetailView(generics.RetrieveAPIView): - model = FilterableItem + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer filter_class = SeveralFieldsFilter filter_backends = (filters.DjangoFilterBackend,) @@ -63,15 +71,12 @@ if django_filters: model = BaseFilterableItem class BaseFilterableItemFilterRootView(generics.ListCreateAPIView): - model = FilterableItem + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer filter_class = BaseFilterableItemFilter filter_backends = (filters.DjangoFilterBackend,) # Regression test for #814 - class FilterableItemSerializer(serializers.ModelSerializer): - class Meta: - model = FilterableItem - class FilterFieldsQuerysetView(generics.ListCreateAPIView): queryset = FilterableItem.objects.all() serializer_class = FilterableItemSerializer @@ -323,6 +328,11 @@ class SearchFilterModel(models.Model): text = models.CharField(max_length=100) +class SearchFilterSerializer(serializers.ModelSerializer): + class Meta: + model = SearchFilterModel + + class SearchFilterTests(TestCase): def setUp(self): # Sequence of title/text is: @@ -342,7 +352,8 @@ class SearchFilterTests(TestCase): def test_search(self): class SearchListView(generics.ListAPIView): - model = SearchFilterModel + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer filter_backends = (filters.SearchFilter,) search_fields = ('title', 'text') @@ -359,7 +370,8 @@ class SearchFilterTests(TestCase): def test_exact_search(self): class SearchListView(generics.ListAPIView): - model = SearchFilterModel + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer filter_backends = (filters.SearchFilter,) search_fields = ('=title', 'text') @@ -375,7 +387,8 @@ class SearchFilterTests(TestCase): def test_startswith_search(self): class SearchListView(generics.ListAPIView): - model = SearchFilterModel + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer filter_backends = (filters.SearchFilter,) search_fields = ('title', '^text') @@ -392,7 +405,8 @@ class SearchFilterTests(TestCase): def test_search_with_nonstandard_search_param(self): with temporary_setting('SEARCH_PARAM', 'query', module=filters): class SearchListView(generics.ListAPIView): - model = SearchFilterModel + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer filter_backends = (filters.SearchFilter,) search_fields = ('title', 'text') @@ -418,6 +432,11 @@ class OrderingFilterRelatedModel(models.Model): related_name="relateds") +class OrderingFilterSerializer(serializers.ModelSerializer): + class Meta: + model = OrdringFilterModel + + class OrderingFilterTests(TestCase): def setUp(self): # Sequence of title/text is: @@ -440,7 +459,8 @@ class OrderingFilterTests(TestCase): def test_ordering(self): class OrderingListView(generics.ListAPIView): - model = OrdringFilterModel + queryset = OrdringFilterModel.objects.all() + serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,) ordering = ('title',) ordering_fields = ('text',) @@ -459,7 +479,8 @@ class OrderingFilterTests(TestCase): def test_reverse_ordering(self): class OrderingListView(generics.ListAPIView): - model = OrdringFilterModel + queryset = OrdringFilterModel.objects.all() + serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,) ordering = ('title',) ordering_fields = ('text',) @@ -478,7 +499,8 @@ class OrderingFilterTests(TestCase): def test_incorrectfield_ordering(self): class OrderingListView(generics.ListAPIView): - model = OrdringFilterModel + queryset = OrdringFilterModel.objects.all() + serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,) ordering = ('title',) ordering_fields = ('text',) @@ -497,7 +519,8 @@ class OrderingFilterTests(TestCase): def test_default_ordering(self): class OrderingListView(generics.ListAPIView): - model = OrdringFilterModel + queryset = OrdringFilterModel.objects.all() + serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,) ordering = ('title',) oredering_fields = ('text',) @@ -516,7 +539,8 @@ class OrderingFilterTests(TestCase): def test_default_ordering_using_string(self): class OrderingListView(generics.ListAPIView): - model = OrdringFilterModel + queryset = OrdringFilterModel.objects.all() + serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,) ordering = 'title' ordering_fields = ('text',) @@ -545,7 +569,7 @@ class OrderingFilterTests(TestCase): new_related.save() class OrderingListView(generics.ListAPIView): - model = OrdringFilterModel + serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,) ordering = 'title' ordering_fields = '__all__' @@ -567,7 +591,8 @@ class OrderingFilterTests(TestCase): def test_ordering_with_nonstandard_ordering_param(self): with temporary_setting('ORDERING_PARAM', 'order', filters): class OrderingListView(generics.ListAPIView): - model = OrdringFilterModel + queryset = OrdringFilterModel.objects.all() + serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,) ordering = ('title',) ordering_fields = ('text',) diff --git a/tests/test_generics.py b/tests/test_generics.py index e9f5bebdd..f50d53e99 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -11,18 +11,30 @@ from tests.models import ForeignKeySource, ForeignKeyTarget factory = APIRequestFactory() +class BasicSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel + + +class ForeignKeySerializer(serializers.ModelSerializer): + class Meta: + model = ForeignKeySource + + class RootView(generics.ListCreateAPIView): """ Example description for OPTIONS. """ - model = BasicModel + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer class InstanceView(generics.RetrieveUpdateDestroyAPIView): """ Example description for OPTIONS. """ - model = BasicModel + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer def get_queryset(self): queryset = super(InstanceView, self).get_queryset() @@ -33,7 +45,8 @@ class FKInstanceView(generics.RetrieveUpdateDestroyAPIView): """ FK: example description for OPTIONS. """ - model = ForeignKeySource + queryset = ForeignKeySource.objects.all() + serializer_class = ForeignKeySerializer class SlugSerializer(serializers.ModelSerializer): @@ -48,7 +61,7 @@ class SlugBasedInstanceView(InstanceView): """ A model with a slug-field. """ - model = SlugBasedModel + queryset = SlugBasedModel.objects.all() serializer_class = SlugSerializer lookup_field = 'slug' @@ -503,7 +516,7 @@ class TestOverriddenGetObject(TestCase): """ Example detail view for override of get_object(). """ - model = BasicModel + serializer_class = BasicSerializer def get_object(self): pk = int(self.kwargs['pk']) @@ -573,7 +586,7 @@ class ClassASerializer(serializers.ModelSerializer): class ExampleView(generics.ListCreateAPIView): serializer_class = ClassASerializer - model = ClassA + queryset = ClassA.objects.all() class TestM2MBrowseableAPI(TestCase): @@ -603,7 +616,7 @@ class TwoFieldModel(models.Model): class DynamicSerializerView(generics.ListCreateAPIView): - model = TwoFieldModel + queryset = TwoFieldModel.objects.all() renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) def get_serializer_class(self): @@ -612,8 +625,11 @@ class DynamicSerializerView(generics.ListCreateAPIView): class Meta: model = TwoFieldModel fields = ('field_b',) - return DynamicSerializer - return super(DynamicSerializerView, self).get_serializer_class() + else: + class DynamicSerializer(serializers.ModelSerializer): + class Meta: + model = TwoFieldModel + return DynamicSerializer class TestFilterBackendAppliedToViews(TestCase): diff --git a/tests/test_hyperlinkedserializers.py b/tests/test_hyperlinkedserializers.py index d45485391..0e8c1ed46 100644 --- a/tests/test_hyperlinkedserializers.py +++ b/tests/test_hyperlinkedserializers.py @@ -39,59 +39,85 @@ class AlbumSerializer(serializers.ModelSerializer): fields = ('title', 'url') +class BasicSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = BasicModel + + +class AnchorSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = Anchor + + +class ManyToManySerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = ManyToManyModel + + +class BlogPostSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPost + + +class OptionalRelationSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = OptionalRelationModel + + class BasicList(generics.ListCreateAPIView): - model = BasicModel - model_serializer_class = serializers.HyperlinkedModelSerializer + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer class BasicDetail(generics.RetrieveUpdateDestroyAPIView): - model = BasicModel - model_serializer_class = serializers.HyperlinkedModelSerializer + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer class AnchorDetail(generics.RetrieveAPIView): - model = Anchor - model_serializer_class = serializers.HyperlinkedModelSerializer + queryset = Anchor.objects.all() + serializer_class = AnchorSerializer class ManyToManyList(generics.ListAPIView): - model = ManyToManyModel - model_serializer_class = serializers.HyperlinkedModelSerializer + queryset = ManyToManyModel.objects.all() + serializer_class = ManyToManySerializer class ManyToManyDetail(generics.RetrieveAPIView): - model = ManyToManyModel - model_serializer_class = serializers.HyperlinkedModelSerializer + queryset = ManyToManyModel.objects.all() + serializer_class = ManyToManySerializer class BlogPostCommentListCreate(generics.ListCreateAPIView): - model = BlogPostComment + queryset = BlogPostComment.objects.all() serializer_class = BlogPostCommentSerializer class BlogPostCommentDetail(generics.RetrieveAPIView): - model = BlogPostComment + queryset = BlogPostComment.objects.all() serializer_class = BlogPostCommentSerializer class BlogPostDetail(generics.RetrieveAPIView): - model = BlogPost + queryset = BlogPost.objects.all() + serializer_class = BlogPostSerializer class PhotoListCreate(generics.ListCreateAPIView): - model = Photo - model_serializer_class = PhotoSerializer + queryset = Photo.objects.all() + serializer_class = PhotoSerializer class AlbumDetail(generics.RetrieveAPIView): - model = Album + queryset = Album.objects.all() serializer_class = AlbumSerializer lookup_field = 'title' class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView): - model = OptionalRelationModel - model_serializer_class = serializers.HyperlinkedModelSerializer + queryset = OptionalRelationModel.objects.all() + serializer_class = OptionalRelationSerializer urlpatterns = patterns( diff --git a/tests/test_nullable_fields.py b/tests/test_nullable_fields.py index 0c133fc2c..8d0c84bb0 100644 --- a/tests/test_nullable_fields.py +++ b/tests/test_nullable_fields.py @@ -1,10 +1,19 @@ from django.core.urlresolvers import reverse from django.conf.urls import patterns, url +from rest_framework import serializers, generics from rest_framework.test import APITestCase from tests.models import NullableForeignKeySource -from tests.serializers import NullableFKSourceSerializer -from tests.views import NullableFKSourceDetail + + +class NullableFKSourceSerializer(serializers.ModelSerializer): + class Meta: + model = NullableForeignKeySource + + +class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView): + queryset = NullableForeignKeySource.objects.all() + serializer_class = NullableFKSourceSerializer urlpatterns = patterns( diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 80c33e2eb..8f9e0005e 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -4,7 +4,7 @@ from decimal import Decimal from django.core.paginator import Paginator from django.test import TestCase from django.utils import unittest -from rest_framework import generics, status, pagination, filters, serializers +from rest_framework import generics, serializers, status, pagination, filters from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory from .models import BasicModel, FilterableItem @@ -22,11 +22,22 @@ def split_arguments_from_url(url): return path, args +class BasicSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel + + +class FilterableItemSerializer(serializers.ModelSerializer): + class Meta: + model = FilterableItem + + class RootView(generics.ListCreateAPIView): """ Example description for OPTIONS. """ - model = BasicModel + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer paginate_by = 10 @@ -34,14 +45,16 @@ class DefaultPageSizeKwargView(generics.ListAPIView): """ View for testing default paginate_by_param usage """ - model = BasicModel + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer class PaginateByParamView(generics.ListAPIView): """ View for testing custom paginate_by_param usage """ - model = BasicModel + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer paginate_by_param = 'page_size' @@ -49,7 +62,8 @@ class MaxPaginateByView(generics.ListAPIView): """ View for testing custom max_paginate_by usage """ - model = BasicModel + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer paginate_by = 3 max_paginate_by = 5 paginate_by_param = 'page_size' @@ -140,7 +154,8 @@ class IntegrationTestPaginationAndFiltering(TestCase): fields = ['text', 'decimal', 'date'] class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer paginate_by = 10 filter_class = DecimalFilter filter_backends = (filters.DjangoFilterBackend,) @@ -188,7 +203,8 @@ class IntegrationTestPaginationAndFiltering(TestCase): return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) class BasicFilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer paginate_by = 10 filter_backends = (DecimalFilterBackend,) @@ -387,7 +403,7 @@ class TestContextPassedToCustomField(TestCase): def test_with_pagination(self): class ListView(generics.ListCreateAPIView): - model = BasicModel + queryset = BasicModel.objects.all() serializer_class = BasicModelSerializer paginate_by = 1 diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 93f8020f3..b90ba4f19 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -3,7 +3,7 @@ from django.contrib.auth.models import User, Permission, Group from django.db import models from django.test import TestCase from django.utils import unittest -from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING +from rest_framework import generics, serializers, status, permissions, authentication, HTTP_HEADER_ENCODING from rest_framework.compat import guardian, get_model_name from rest_framework.filters import DjangoObjectPermissionsFilter from rest_framework.test import APIRequestFactory @@ -13,14 +13,21 @@ import base64 factory = APIRequestFactory() +class BasicSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel + + class RootView(generics.ListCreateAPIView): - model = BasicModel + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer authentication_classes = [authentication.BasicAuthentication] permission_classes = [permissions.DjangoModelPermissions] class InstanceView(generics.RetrieveUpdateDestroyAPIView): - model = BasicModel + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer authentication_classes = [authentication.BasicAuthentication] permission_classes = [permissions.DjangoModelPermissions] @@ -167,6 +174,11 @@ class BasicPermModel(models.Model): ) +class BasicPermSerializer(serializers.ModelSerializer): + class Meta: + model = BasicPermModel + + # Custom object-level permission, that includes 'view' permissions class ViewObjectPermissions(permissions.DjangoObjectPermissions): perms_map = { @@ -181,7 +193,8 @@ class ViewObjectPermissions(permissions.DjangoObjectPermissions): class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView): - model = BasicPermModel + queryset = BasicPermModel.objects.all() + serializer_class = BasicPermSerializer authentication_classes = [authentication.BasicAuthentication] permission_classes = [ViewObjectPermissions] @@ -189,7 +202,8 @@ object_permissions_view = ObjectPermissionInstanceView.as_view() class ObjectPermissionListView(generics.ListAPIView): - model = BasicPermModel + queryset = BasicPermModel.objects.all() + serializer_class = BasicPermSerializer authentication_classes = [authentication.BasicAuthentication] permission_classes = [ViewObjectPermissions] diff --git a/tests/test_response.py b/tests/test_response.py index 2eff83d3d..004c565c9 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -86,14 +86,15 @@ class HTMLView1(APIView): class HTMLNewModelViewSet(viewsets.ModelViewSet): - model = BasicModel + serializer_class = BasicModelSerializer + queryset = BasicModel.objects.all() class HTMLNewModelView(generics.ListCreateAPIView): renderer_classes = (BrowsableAPIRenderer,) permission_classes = [] serializer_class = BasicModelSerializer - model = BasicModel + queryset = BasicModel.objects.all() new_model_viewset_router = routers.DefaultRouter() diff --git a/tests/test_validation.py b/tests/test_validation.py index e13e4078c..f62d9068b 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -22,7 +22,7 @@ class ValidationModelSerializer(serializers.ModelSerializer): class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView): - model = ValidationModel + queryset = ValidationModel.objects.all() serializer_class = ValidationModelSerializer @@ -117,7 +117,7 @@ class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer): class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView): - model = ValidationMaxValueValidatorModel + queryset = ValidationMaxValueValidatorModel.objects.all() serializer_class = ValidationMaxValueValidatorModelSerializer diff --git a/tests/views.py b/tests/views.py deleted file mode 100644 index 55935e924..000000000 --- a/tests/views.py +++ /dev/null @@ -1,8 +0,0 @@ -from rest_framework import generics -from .models import NullableForeignKeySource -from .serializers import NullableFKSourceSerializer - - -class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView): - model = NullableForeignKeySource - model_serializer_class = NullableFKSourceSerializer From 72c0811576feb89decf6fc6dc4ee5e25eca0aece Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 12:48:04 +0100 Subject: [PATCH 08/13] Minor tidy up. --- rest_framework/generics.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 68222864f..d0adeaec0 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -3,7 +3,7 @@ Generic views that provide commonly needed behaviour. """ from __future__ import unicode_literals -from django.core.exceptions import ImproperlyConfigured, PermissionDenied +from django.core.exceptions import PermissionDenied from django.core.paginator import Paginator, InvalidPage from django.http import Http404 from django.shortcuts import get_object_or_404 as _get_object_or_404 @@ -235,19 +235,16 @@ class GenericAPIView(views.APIView): queryset = self.filter_queryset(self.get_queryset()) # Perform the lookup filtering. - # Note that `pk` and `slug` are deprecated styles of lookup filtering. lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - lookup = self.kwargs.get(lookup_url_kwarg, None) - if lookup is None: - raise ImproperlyConfigured( - 'Expected view %s to be called with a URL keyword argument ' - 'named "%s". Fix your URL conf, or set the `.lookup_field` ' - 'attribute on the view correctly.' % - (self.__class__.__name__, self.lookup_field) - ) + assert lookup_url_kwarg in self.kwargs, ( + 'Expected view %s to be called with a URL keyword argument ' + 'named "%s". Fix your URL conf, or set the `.lookup_field` ' + 'attribute on the view correctly.' % + (self.__class__.__name__, lookup_url_kwarg) + ) - filter_kwargs = {self.lookup_field: lookup} + filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]} obj = get_object_or_404(queryset, **filter_kwargs) # May raise a permission denied From ce7b2cded94abc12ae1be076642de96684d0927b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 12:48:49 +0100 Subject: [PATCH 09/13] Remove deprecated generic views. `MultipleObjectAPIView` and `SingleObjectAPIView` are no longer required. --- rest_framework/generics.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index d0adeaec0..e6cbfca90 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -442,25 +442,3 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) - - -# Deprecated classes - -class MultipleObjectAPIView(GenericAPIView): - def __init__(self, *args, **kwargs): - warnings.warn( - 'Subclassing `MultipleObjectAPIView` is deprecated. ' - 'You should simply subclass `GenericAPIView` instead.', - DeprecationWarning, stacklevel=2 - ) - super(MultipleObjectAPIView, self).__init__(*args, **kwargs) - - -class SingleObjectAPIView(GenericAPIView): - def __init__(self, *args, **kwargs): - warnings.warn( - 'Subclassing `SingleObjectAPIView` is deprecated. ' - 'You should simply subclass `GenericAPIView` instead.', - DeprecationWarning, stacklevel=2 - ) - super(SingleObjectAPIView, self).__init__(*args, **kwargs) From f87d32558eb3b36f14a798ec48e4943d25380b92 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 12:53:45 +0100 Subject: [PATCH 10/13] Remove `.link()` and `.action()` decorators. --- rest_framework/decorators.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 449ba0a29..cc5d92c2e 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -130,37 +130,3 @@ def list_route(methods=['get'], **kwargs): func.kwargs = kwargs return func return decorator - - -# These are now pending deprecation, in favor of `detail_route` and `list_route`. - -def link(**kwargs): - """ - Used to mark a method on a ViewSet that should be routed for detail GET requests. - """ - msg = 'link is pending deprecation. Use detail_route instead.' - warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) - - def decorator(func): - func.bind_to_methods = ['get'] - func.detail = True - func.kwargs = kwargs - return func - - return decorator - - -def action(methods=['post'], **kwargs): - """ - Used to mark a method on a ViewSet that should be routed for detail POST requests. - """ - msg = 'action is pending deprecation. Use detail_route instead.' - warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) - - def decorator(func): - func.bind_to_methods = methods - func.detail = True - func.kwargs = kwargs - return func - - return decorator From b552b62540e5144272c9c13c28f120ffe5fcbe45 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 12:54:03 +0100 Subject: [PATCH 11/13] `get_paginate_by` no longer takes optional `.queryset` --- rest_framework/generics.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index e6cbfca90..40c498440 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -158,7 +158,7 @@ class GenericAPIView(views.APIView): # The following methods provide default implementations # that you may want to override for more complex cases. - def get_paginate_by(self, queryset=None): + def get_paginate_by(self): """ Return the size of pages to use with pagination. @@ -167,11 +167,6 @@ class GenericAPIView(views.APIView): Otherwise defaults to using `self.paginate_by`. """ - if queryset is not None: - warnings.warn('The `queryset` parameter to `get_paginate_by()` ' - 'is deprecated.', - DeprecationWarning, stacklevel=2) - if self.paginate_by_param: try: return strict_positive_int( From 371d30aa8737c4b3aaf28ee10cc2b77a9c4d1fd9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 12:54:52 +0100 Subject: [PATCH 12/13] Remove unused imports. --- rest_framework/decorators.py | 1 - rest_framework/generics.py | 1 - 2 files changed, 2 deletions(-) diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index cc5d92c2e..d28d6e22a 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -10,7 +10,6 @@ from __future__ import unicode_literals from django.utils import six from rest_framework.views import APIView import types -import warnings def api_view(http_method_names): diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 40c498440..b3bd6ce92 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -11,7 +11,6 @@ from django.utils.translation import ugettext as _ from rest_framework import views, mixins, exceptions from rest_framework.request import clone_request from rest_framework.settings import api_settings -import warnings def strict_positive_int(integer_string, cutoff=None): From 4ac4676a40b121d27cfd1173ff548d96b8d3de2f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 16:46:26 +0100 Subject: [PATCH 13/13] First pass --- rest_framework/fields.py | 1226 ++++++------------------------ rest_framework/generics.py | 10 +- rest_framework/mixins.py | 24 +- rest_framework/pagination.py | 22 +- rest_framework/relations.py | 486 ------------ rest_framework/renderers.py | 10 +- rest_framework/serializers.py | 1200 +++++++++-------------------- rest_framework/utils/encoders.py | 18 +- rest_framework/utils/html.py | 86 +++ 9 files changed, 714 insertions(+), 2368 deletions(-) create mode 100644 rest_framework/utils/html.py diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 9d707c9b5..a83bf94c4 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,1038 +1,308 @@ -""" -Serializer fields perform validation on incoming data. - -They are very similar to Django's form fields. -""" -from __future__ import unicode_literals - -import copy -import datetime -import inspect -import re -import warnings -from decimal import Decimal, DecimalException -from django import forms -from django.core import validators -from django.core.exceptions import ValidationError -from django.conf import settings -from django.db.models.fields import BLANK_CHOICE_DASH -from django.http import QueryDict -from django.forms import widgets -from django.utils import six, timezone -from django.utils.encoding import is_protected_type -from django.utils.translation import ugettext_lazy as _ -from django.utils.datastructures import SortedDict -from django.utils.dateparse import parse_date, parse_datetime, parse_time -from rest_framework import ISO_8601 -from rest_framework.compat import ( - BytesIO, smart_text, - force_text, is_non_str_iterable -) -from rest_framework.settings import api_settings +from rest_framework.utils import html -def is_simple_callable(obj): +class empty: """ - True if the object is a callable that takes no arguments. + This class is used to represent no data being provided for a given input + or output value. + + It is required because `None` may be a valid input or output value. """ - function = inspect.isfunction(obj) - method = inspect.ismethod(obj) - - if not (function or method): - return False - - args, _, _, defaults = inspect.getargspec(obj) - len_args = len(args) if function else len(args) - 1 - len_defaults = len(defaults) if defaults else 0 - return len_args <= len_defaults + pass -def get_component(obj, attr_name): +def get_attribute(instance, attrs): """ - Given an object, and an attribute name, - return that attribute on the object. + Similar to Python's built in `getattr(instance, attr)`, + but takes a list of nested attributes, instead of a single attribute. """ - if isinstance(obj, dict): - val = obj.get(attr_name) - else: - val = getattr(obj, attr_name) - - if is_simple_callable(val): - return val() - return val + for attr in attrs: + instance = getattr(instance, attr) + return instance -def readable_datetime_formats(formats): - format = ', '.join(formats).replace( - ISO_8601, - 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]' - ) - return humanize_strptime(format) - - -def readable_date_formats(formats): - format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]') - return humanize_strptime(format) - - -def readable_time_formats(formats): - format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') - return humanize_strptime(format) - - -def humanize_strptime(format_string): - # Note that we're missing some of the locale specific mappings that - # don't really make sense. - mapping = { - "%Y": "YYYY", - "%y": "YY", - "%m": "MM", - "%b": "[Jan-Dec]", - "%B": "[January-December]", - "%d": "DD", - "%H": "hh", - "%I": "hh", # Requires '%p' to differentiate from '%H'. - "%M": "mm", - "%S": "ss", - "%f": "uuuuuu", - "%a": "[Mon-Sun]", - "%A": "[Monday-Sunday]", - "%p": "[AM|PM]", - "%z": "[+HHMM|-HHMM]" - } - for key, val in mapping.items(): - format_string = format_string.replace(key, val) - return format_string - - -def strip_multiple_choice_msg(help_text): +def set_value(dictionary, keys, value): """ - Remove the 'Hold down "control" ...' message that is Django enforces in - select multiple fields on ModelForms. (Required for 1.5 and earlier) + Similar to Python's built in `dictionary[key] = value`, + but takes a list of nested keys instead of a single key. - See https://code.djangoproject.com/ticket/9321 + set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2} + set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2} + set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}} """ - multiple_choice_msg = _(' Hold down "Control", or "Command" on a Mac, to select more than one.') - multiple_choice_msg = force_text(multiple_choice_msg) + if not keys: + dictionary.update(value) + return - return help_text.replace(multiple_choice_msg, '') + for key in keys[:-1]: + if key not in dictionary: + dictionary[key] = {} + dictionary = dictionary[key] + + dictionary[keys[-1]] = value + + +class ValidationError(Exception): + pass + + +class SkipField(Exception): + pass class Field(object): - read_only = True - creation_counter = 0 - empty = '' - type_name = None - partial = False - use_files = False - form_field_class = forms.CharField - type_label = 'field' - widget = None + _creation_counter = 0 - def __init__(self, source=None, label=None, help_text=None): - self.parent = None - - self.creation_counter = Field.creation_counter - Field.creation_counter += 1 - - self.source = source - - if label is not None: - self.label = smart_text(label) - else: - self.label = None - - if help_text is not None: - self.help_text = strip_multiple_choice_msg(smart_text(help_text)) - else: - self.help_text = None - - self._errors = [] - self._value = None - self._name = None - - @property - def errors(self): - return self._errors - - def widget_html(self): - if not self.widget: - return '' - - attrs = {} - if 'id' not in self.widget.attrs: - attrs['id'] = self._name - - return self.widget.render(self._name, self._value, attrs=attrs) - - def label_tag(self): - return '' % (self._name, self.label) - - def initialize(self, parent, field_name): - """ - Called to set up a field prior to field_to_native or field_from_native. - - parent - The parent serializer. - field_name - The name of the field being initialized. - """ - self.parent = parent - self.root = parent.root or parent - self.context = self.root.context - self.partial = self.root.partial - if self.partial: - self.required = False - - def field_from_native(self, data, files, field_name, into): - """ - Given a dictionary and a field name, updates the dictionary `into`, - with the field and it's deserialized value. - """ - return - - def field_to_native(self, obj, field_name): - """ - Given an object and a field name, returns the value that should be - serialized for that field. - """ - if obj is None: - return self.empty - - if self.source == '*': - return self.to_native(obj) - - source = self.source or field_name - value = obj - - for component in source.split('.'): - value = get_component(value, component) - if value is None: - break - - return self.to_native(value) - - def to_native(self, value): - """ - Converts the field's value into it's simple representation. - """ - if is_simple_callable(value): - value = value() - - if is_protected_type(value): - return value - elif (is_non_str_iterable(value) and - not isinstance(value, (dict, six.string_types))): - return [self.to_native(item) for item in value] - elif isinstance(value, dict): - # Make sure we preserve field ordering, if it exists - ret = SortedDict() - for key, val in value.items(): - ret[key] = self.to_native(val) - return ret - return force_text(value) - - def attributes(self): - """ - Returns a dictionary of attributes to be used when serializing to xml. - """ - if self.type_name: - return {'type': self.type_name} - return {} - - def metadata(self): - metadata = SortedDict() - metadata['type'] = self.type_label - metadata['required'] = getattr(self, 'required', False) - optional_attrs = ['read_only', 'label', 'help_text', - 'min_length', 'max_length'] - for attr in optional_attrs: - value = getattr(self, attr, None) - if value is not None and value != '': - metadata[attr] = force_text(value, strings_only=True) - return metadata - - -class WritableField(Field): - """ - Base for read/write fields. - """ - write_only = False - default_validators = [] - default_error_messages = { - 'required': _('This field is required.'), - 'invalid': _('Invalid value.'), + MESSAGES = { + 'required': 'This field is required.' } - widget = widgets.TextInput - default = None - def __init__(self, source=None, label=None, help_text=None, - read_only=False, write_only=False, required=None, - validators=[], error_messages=None, widget=None, - default=None, blank=None): + _NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' + _NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' + _NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' + _NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' + _MISSING_ERROR_MESSAGE = ( + 'ValidationError raised by `{class_name}`, but error key `{key}` does ' + 'not exist in the `MESSAGES` dictionary.' + ) - super(WritableField, self).__init__(source=source, label=label, help_text=help_text) + def __init__(self, read_only=False, write_only=False, + required=None, default=empty, initial=None, source=None, + label=None, style=None): + self._creation_counter = Field._creation_counter + Field._creation_counter += 1 + + # If `required` is unset, then use `True` unless a default is provided. + if required is None: + required = default is empty and not read_only + + # Some combinations of keyword arguments do not make sense. + assert not (read_only and write_only), self._NOT_READ_ONLY_WRITE_ONLY + assert not (read_only and required), self._NOT_READ_ONLY_REQUIRED + assert not (read_only and default is not empty), self._NOT_READ_ONLY_DEFAULT + assert not (required and default is not empty), self._NOT_REQUIRED_DEFAULT self.read_only = read_only self.write_only = write_only + self.required = required + self.default = default + self.source = source + self.initial = initial + self.label = label + self.style = {} if style is None else style - assert not (read_only and write_only), "Cannot set read_only=True and write_only=True" + def bind(self, field_name, parent, root): + """ + Setup the context for the field instance. + """ + self.field_name = field_name + self.parent = parent + self.root = root - if required is None: - self.required = not(read_only) + # `self.label` should deafult to being based on the field name. + if self.label is None: + self.label = self.field_name.replace('_', ' ').capitalize() + + # self.source should default to being the same as the field name. + if self.source is None: + self.source = field_name + + # self.source_attrs is a list of attributes that need to be looked up + # when serializing the instance, or populating the validated data. + if self.source == '*': + self.source_attrs = [] else: - assert not (read_only and required), "Cannot set required=True and read_only=True" - self.required = required + self.source_attrs = self.source.split('.') - messages = {} - for c in reversed(self.__class__.__mro__): - messages.update(getattr(c, 'default_error_messages', {})) - messages.update(error_messages or {}) - self.error_messages = messages + def get_initial(self): + """ + Return a value to use when the field is being returned as a primative + value, without any object instance. + """ + return self.initial - self.validators = self.default_validators + validators - self.default = default if default is not None else self.default + def get_value(self, dictionary): + """ + Given the *incoming* primative data, return the value for this field + that should be validated and transformed to a native value. + """ + return dictionary.get(self.field_name, empty) - # Widgets are only used for HTML forms. - widget = widget or self.widget - if isinstance(widget, type): - widget = widget() - self.widget = widget + def get_attribute(self, instance): + """ + Given the *outgoing* object instance, return the value for this field + that should be returned as a primative value. + """ + return get_attribute(instance, self.source_attrs) - def __deepcopy__(self, memo): - result = copy.copy(self) - memo[id(self)] = result - result.validators = self.validators[:] - return result + def get_default(self): + """ + Return the default value to use when validating data if no input + is provided for this field. - def get_default_value(self): - if is_simple_callable(self.default): - return self.default() + If a default has not been set for this field then this will simply + return `empty`, indicating that no value should be set in the + validated data for this field. + """ + if self.default is empty: + raise SkipField() return self.default - def validate(self, value): - if value in validators.EMPTY_VALUES and self.required: - raise ValidationError(self.error_messages['required']) - - def run_validators(self, value): - if value in validators.EMPTY_VALUES: - return - errors = [] - for v in self.validators: - try: - v(value) - except ValidationError as e: - if hasattr(e, 'code') and e.code in self.error_messages: - message = self.error_messages[e.code] - if e.params: - message = message % e.params - errors.append(message) - else: - errors.extend(e.messages) - if errors: - raise ValidationError(errors) - - def field_to_native(self, obj, field_name): - if self.write_only: - return None - return super(WritableField, self).field_to_native(obj, field_name) - - def field_from_native(self, data, files, field_name, into): + def validate(self, data=empty): """ - Given a dictionary and a field name, updates the dictionary `into`, - with the field and it's deserialized value. + Validate a simple representation and return the internal value. + + The provided data may be `empty` if no representation was included. + May return `empty` if the field should not be included in the + validated data. """ - if self.read_only: - return + if data is empty: + if self.required: + self.fail('required') + return self.get_default() - try: - data = data or {} - if self.use_files: - files = files or {} - try: - native = files[field_name] - except KeyError: - native = data[field_name] - else: - native = data[field_name] - except KeyError: - if self.default is not None and not self.partial: - # Note: partial updates shouldn't set defaults - native = self.get_default_value() - else: - if self.required: - raise ValidationError(self.error_messages['required']) - return + return self.to_native(data) - value = self.from_native(native) - if self.source == '*': - if value: - into.update(value) - else: - self.validate(value) - self.run_validators(value) - into[self.source or field_name] = value - - def from_native(self, value): + def to_native(self, data): """ - Reverts a simple representation back to the field's value. + Transform the *incoming* primative data into a native value. + """ + return data + + def to_primative(self, value): + """ + Transform the *outgoing* native value into primative data. """ return value - -class ModelField(WritableField): - """ - A generic field that can be used against an arbitrary model field. - """ - def __init__(self, *args, **kwargs): + def fail(self, key, **kwargs): + """ + A helper method that simply raises a validation error. + """ try: - self.model_field = kwargs.pop('model_field') + raise ValidationError(self.MESSAGES[key].format(**kwargs)) except KeyError: - raise ValueError("ModelField requires 'model_field' kwarg") + class_name = self.__class__.__name__ + msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) + raise AssertionError(msg) - self.min_length = kwargs.pop('min_length', - getattr(self.model_field, 'min_length', None)) - self.max_length = kwargs.pop('max_length', - getattr(self.model_field, 'max_length', None)) - self.min_value = kwargs.pop('min_value', - getattr(self.model_field, 'min_value', None)) - self.max_value = kwargs.pop('max_value', - getattr(self.model_field, 'max_value', None)) - super(ModelField, self).__init__(*args, **kwargs) +class BooleanField(Field): + MESSAGES = { + 'required': 'This field is required.', + 'invalid_value': '`{input}` is not a valid boolean.' + } + TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True} + FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False} - if self.min_length is not None: - self.validators.append(validators.MinLengthValidator(self.min_length)) - if self.max_length is not None: - self.validators.append(validators.MaxLengthValidator(self.max_length)) - if self.min_value is not None: - self.validators.append(validators.MinValueValidator(self.min_value)) - if self.max_value is not None: - self.validators.append(validators.MaxValueValidator(self.max_value)) + def get_value(self, dictionary): + if html.is_html_input(dictionary): + # HTML forms do not send a `False` value on an empty checkbox, + # so we override the default empty value to be False. + return dictionary.get(self.field_name, False) + return dictionary.get(self.field_name, empty) - def from_native(self, value): - rel = getattr(self.model_field, "rel", None) - if rel is not None: - return rel.to._meta.get_field(rel.field_name).to_python(value) + def to_native(self, data): + if data in self.TRUE_VALUES: + return True + elif data in self.FALSE_VALUES: + return False + self.fail('invalid_value', input=data) + + +class CharField(Field): + MESSAGES = { + 'required': 'This field is required.', + 'blank': 'This field may not be blank.' + } + + def __init__(self, *args, **kwargs): + self.allow_blank = kwargs.pop('allow_blank', False) + super(CharField, self).__init__(*args, **kwargs) + + def to_native(self, data): + if data == '' and not self.allow_blank: + self.fail('blank') + return str(data) + + +class ChoiceField(Field): + MESSAGES = { + 'required': 'This field is required.', + 'invalid_choice': '`{input}` is not a valid choice.' + } + coerce_to_type = str + + def __init__(self, *args, **kwargs): + choices = kwargs.pop('choices') + + assert choices, '`choices` argument is required and may not be empty' + + # Allow either single or paired choices style: + # choices = [1, 2, 3] + # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')] + pairs = [ + isinstance(item, (list, tuple)) and len(item) == 2 + for item in choices + ] + if all(pairs): + self.choices = {key: val for key, val in choices} else: - return self.model_field.to_python(value) + self.choices = {item: item for item in choices} - def field_to_native(self, obj, field_name): - value = self.model_field._get_val_from_obj(obj) - if is_protected_type(value): - return value - return self.model_field.value_to_string(obj) - - def attributes(self): - return { - "type": self.model_field.get_internal_type() + # Map the string representation of choices to the underlying value. + # Allows us to deal with eg. integer choices while supporting either + # integer or string input, but still get the correct datatype out. + self.choice_strings_to_values = { + str(key): key for key in self.choices.keys() } - -# Typed Fields - -class BooleanField(WritableField): - type_name = 'BooleanField' - type_label = 'boolean' - form_field_class = forms.BooleanField - widget = widgets.CheckboxInput - default_error_messages = { - 'invalid': _("'%s' value must be either True or False."), - } - empty = False - - def field_from_native(self, data, files, field_name, into): - # HTML checkboxes do not explicitly represent unchecked as `False` - # we deal with that here... - if isinstance(data, QueryDict) and self.default is None: - self.default = False - - return super(BooleanField, self).field_from_native( - data, files, field_name, into - ) - - def from_native(self, value): - if value in ('true', 't', 'True', '1'): - return True - if value in ('false', 'f', 'False', '0'): - return False - return bool(value) - - -class CharField(WritableField): - type_name = 'CharField' - type_label = 'string' - form_field_class = forms.CharField - - def __init__(self, max_length=None, min_length=None, allow_none=False, *args, **kwargs): - self.max_length, self.min_length = max_length, min_length - self.allow_none = allow_none - super(CharField, self).__init__(*args, **kwargs) - if min_length is not None: - self.validators.append(validators.MinLengthValidator(min_length)) - if max_length is not None: - self.validators.append(validators.MaxLengthValidator(max_length)) - - def from_native(self, value): - if isinstance(value, six.string_types): - return value - - if value is None and not self.allow_none: - return '' - - return smart_text(value) - - -class URLField(CharField): - type_name = 'URLField' - type_label = 'url' - - def __init__(self, **kwargs): - if 'validators' not in kwargs: - kwargs['validators'] = [validators.URLValidator()] - super(URLField, self).__init__(**kwargs) - - -class SlugField(CharField): - type_name = 'SlugField' - type_label = 'slug' - form_field_class = forms.SlugField - - default_error_messages = { - 'invalid': _("Enter a valid 'slug' consisting of letters, numbers," - " underscores or hyphens."), - } - default_validators = [validators.validate_slug] - - def __init__(self, *args, **kwargs): - super(SlugField, self).__init__(*args, **kwargs) - - -class ChoiceField(WritableField): - type_name = 'ChoiceField' - type_label = 'choice' - form_field_class = forms.ChoiceField - widget = widgets.Select - default_error_messages = { - 'invalid_choice': _('Select a valid choice. %(value)s is not one of ' - 'the available choices.'), - } - - def __init__(self, choices=(), blank_display_value=None, *args, **kwargs): - self.empty = kwargs.pop('empty', '') super(ChoiceField, self).__init__(*args, **kwargs) - self.choices = choices - if not self.required: - if blank_display_value is None: - blank_choice = BLANK_CHOICE_DASH - else: - blank_choice = [('', blank_display_value)] - self.choices = blank_choice + self.choices - - def _get_choices(self): - return self._choices - - def _set_choices(self, value): - # Setting choices also sets the choices on the widget. - # choices can be any iterable, but we call list() on it because - # it will be consumed more than once. - self._choices = self.widget.choices = list(value) - - choices = property(_get_choices, _set_choices) - - def metadata(self): - data = super(ChoiceField, self).metadata() - data['choices'] = [{'value': v, 'display_name': n} for v, n in self.choices] - return data - - def validate(self, value): - """ - Validates that the input is in self.choices. - """ - super(ChoiceField, self).validate(value) - if value and not self.valid_value(value): - raise ValidationError(self.error_messages['invalid_choice'] % {'value': value}) - - def valid_value(self, value): - """ - Check to see if the provided value is a valid choice. - """ - for k, v in self.choices: - if isinstance(v, (list, tuple)): - # This is an optgroup, so look inside the group for options - for k2, v2 in v: - if value == smart_text(k2): - return True - else: - if value == smart_text(k) or value == k: - return True - return False - - def from_native(self, value): - value = super(ChoiceField, self).from_native(value) - if value == self.empty or value in validators.EMPTY_VALUES: - return self.empty - return value - - -class EmailField(CharField): - type_name = 'EmailField' - type_label = 'email' - form_field_class = forms.EmailField - - default_error_messages = { - 'invalid': _('Enter a valid email address.'), - } - default_validators = [validators.validate_email] - - def from_native(self, value): - ret = super(EmailField, self).from_native(value) - if ret is None: - return None - return ret.strip() - - -class RegexField(CharField): - type_name = 'RegexField' - type_label = 'regex' - form_field_class = forms.RegexField - - def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs): - super(RegexField, self).__init__(max_length, min_length, *args, **kwargs) - self.regex = regex - - def _get_regex(self): - return self._regex - - def _set_regex(self, regex): - if isinstance(regex, six.string_types): - regex = re.compile(regex) - self._regex = regex - if hasattr(self, '_regex_validator') and self._regex_validator in self.validators: - self.validators.remove(self._regex_validator) - self._regex_validator = validators.RegexValidator(regex=regex) - self.validators.append(self._regex_validator) - - regex = property(_get_regex, _set_regex) - - -class DateField(WritableField): - type_name = 'DateField' - type_label = 'date' - widget = widgets.DateInput - form_field_class = forms.DateField - - default_error_messages = { - 'invalid': _("Date has wrong format. Use one of these formats instead: %s"), - } - empty = None - input_formats = api_settings.DATE_INPUT_FORMATS - format = api_settings.DATE_FORMAT - - def __init__(self, input_formats=None, format=None, *args, **kwargs): - self.input_formats = input_formats if input_formats is not None else self.input_formats - self.format = format if format is not None else self.format - super(DateField, self).__init__(*args, **kwargs) - - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None - - if isinstance(value, datetime.datetime): - if timezone and settings.USE_TZ and timezone.is_aware(value): - # Convert aware datetimes to the default time zone - # before casting them to dates (#17742). - default_timezone = timezone.get_default_timezone() - value = timezone.make_naive(value, default_timezone) - return value.date() - if isinstance(value, datetime.date): - return value - - for format in self.input_formats: - if format.lower() == ISO_8601: - try: - parsed = parse_date(value) - except (ValueError, TypeError): - pass - else: - if parsed is not None: - return parsed - else: - try: - parsed = datetime.datetime.strptime(value, format) - except (ValueError, TypeError): - pass - else: - return parsed.date() - - msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats) - raise ValidationError(msg) - - def to_native(self, value): - if value is None or self.format is None: - return value - - if isinstance(value, datetime.datetime): - value = value.date() - - if self.format.lower() == ISO_8601: - return value.isoformat() - return value.strftime(self.format) - - -class DateTimeField(WritableField): - type_name = 'DateTimeField' - type_label = 'datetime' - widget = widgets.DateTimeInput - form_field_class = forms.DateTimeField - - default_error_messages = { - 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), - } - empty = None - input_formats = api_settings.DATETIME_INPUT_FORMATS - format = api_settings.DATETIME_FORMAT - - def __init__(self, input_formats=None, format=None, *args, **kwargs): - self.input_formats = input_formats if input_formats is not None else self.input_formats - self.format = format if format is not None else self.format - super(DateTimeField, self).__init__(*args, **kwargs) - - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None - - if isinstance(value, datetime.datetime): - return value - if isinstance(value, datetime.date): - value = datetime.datetime(value.year, value.month, value.day) - if settings.USE_TZ: - # For backwards compatibility, interpret naive datetimes in - # local time. This won't work during DST change, but we can't - # do much about it, so we let the exceptions percolate up the - # call stack. - warnings.warn("DateTimeField received a naive datetime (%s)" - " while time zone support is active." % value, - RuntimeWarning) - default_timezone = timezone.get_default_timezone() - value = timezone.make_aware(value, default_timezone) - return value - - for format in self.input_formats: - if format.lower() == ISO_8601: - try: - parsed = parse_datetime(value) - except (ValueError, TypeError): - pass - else: - if parsed is not None: - return parsed - else: - try: - parsed = datetime.datetime.strptime(value, format) - except (ValueError, TypeError): - pass - else: - return parsed - - msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats) - raise ValidationError(msg) - - def to_native(self, value): - if value is None or self.format is None: - return value - - if self.format.lower() == ISO_8601: - ret = value.isoformat() - if ret.endswith('+00:00'): - ret = ret[:-6] + 'Z' - return ret - return value.strftime(self.format) - - -class TimeField(WritableField): - type_name = 'TimeField' - type_label = 'time' - widget = widgets.TimeInput - form_field_class = forms.TimeField - - default_error_messages = { - 'invalid': _("Time has wrong format. Use one of these formats instead: %s"), - } - empty = None - input_formats = api_settings.TIME_INPUT_FORMATS - format = api_settings.TIME_FORMAT - - def __init__(self, input_formats=None, format=None, *args, **kwargs): - self.input_formats = input_formats if input_formats is not None else self.input_formats - self.format = format if format is not None else self.format - super(TimeField, self).__init__(*args, **kwargs) - - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None - - if isinstance(value, datetime.time): - return value - - for format in self.input_formats: - if format.lower() == ISO_8601: - try: - parsed = parse_time(value) - except (ValueError, TypeError): - pass - else: - if parsed is not None: - return parsed - else: - try: - parsed = datetime.datetime.strptime(value, format) - except (ValueError, TypeError): - pass - else: - return parsed.time() - - msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats) - raise ValidationError(msg) - - def to_native(self, value): - if value is None or self.format is None: - return value - - if isinstance(value, datetime.datetime): - value = value.time() - - if self.format.lower() == ISO_8601: - return value.isoformat() - return value.strftime(self.format) - - -class IntegerField(WritableField): - type_name = 'IntegerField' - type_label = 'integer' - form_field_class = forms.IntegerField - empty = 0 - - default_error_messages = { - 'invalid': _('Enter a whole number.'), - 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), - 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), - } - - def __init__(self, max_value=None, min_value=None, *args, **kwargs): - self.max_value, self.min_value = max_value, min_value - super(IntegerField, self).__init__(*args, **kwargs) - - if max_value is not None: - self.validators.append(validators.MaxValueValidator(max_value)) - if min_value is not None: - self.validators.append(validators.MinValueValidator(min_value)) - - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None + def to_native(self, data): try: - value = int(str(value)) + return self.choice_strings_to_values[str(data)] + except KeyError: + self.fail('invalid_choice', input=data) + + +class MultipleChoiceField(ChoiceField): + MESSAGES = { + 'required': 'This field is required.', + 'invalid_choice': '`{input}` is not a valid choice.', + 'not_a_list': 'Expected a list of items but got type `{input_type}`' + } + + def to_native(self, data): + if not hasattr(data, '__iter__'): + self.fail('not_a_list', input_type=type(data).__name__) + return set([ + super(MultipleChoiceField, self).to_native(item) + for item in data + ]) + + +class IntegerField(Field): + MESSAGES = { + 'required': 'This field is required.', + 'invalid_integer': 'A valid integer is required.' + } + + def to_native(self, data): + try: + data = int(str(data)) except (ValueError, TypeError): - raise ValidationError(self.error_messages['invalid']) - return value - - -class FloatField(WritableField): - type_name = 'FloatField' - type_label = 'float' - form_field_class = forms.FloatField - empty = 0 - - default_error_messages = { - 'invalid': _("'%s' value must be a float."), - } - - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None - - try: - return float(value) - except (TypeError, ValueError): - msg = self.error_messages['invalid'] % value - raise ValidationError(msg) - - -class DecimalField(WritableField): - type_name = 'DecimalField' - type_label = 'decimal' - form_field_class = forms.DecimalField - empty = Decimal('0') - - default_error_messages = { - 'invalid': _('Enter a number.'), - 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), - 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), - 'max_digits': _('Ensure that there are no more than %s digits in total.'), - 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'), - 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.') - } - - def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): - self.max_value, self.min_value = max_value, min_value - self.max_digits, self.decimal_places = max_digits, decimal_places - super(DecimalField, self).__init__(*args, **kwargs) - - if max_value is not None: - self.validators.append(validators.MaxValueValidator(max_value)) - if min_value is not None: - self.validators.append(validators.MinValueValidator(min_value)) - - def from_native(self, value): - """ - Validates that the input is a decimal number. Returns a Decimal - instance. Returns None for empty values. Ensures that there are no more - than max_digits in the number, and no more than decimal_places digits - after the decimal point. - """ - if value in validators.EMPTY_VALUES: - return None - value = smart_text(value).strip() - try: - value = Decimal(value) - except DecimalException: - raise ValidationError(self.error_messages['invalid']) - return value - - def validate(self, value): - super(DecimalField, self).validate(value) - if value in validators.EMPTY_VALUES: - return - # Check for NaN, Inf and -Inf values. We can't compare directly for NaN, - # since it is never equal to itself. However, NaN is the only value that - # isn't equal to itself, so we can use this to identify NaN - if value != value or value == Decimal("Inf") or value == Decimal("-Inf"): - raise ValidationError(self.error_messages['invalid']) - sign, digittuple, exponent = value.as_tuple() - decimals = abs(exponent) - # digittuple doesn't include any leading zeros. - digits = len(digittuple) - if decimals > digits: - # We have leading zeros up to or past the decimal point. Count - # everything past the decimal point as a digit. We do not count - # 0 before the decimal point as a digit since that would mean - # we would not allow max_digits = decimal_places. - digits = decimals - whole_digits = digits - decimals - - if self.max_digits is not None and digits > self.max_digits: - raise ValidationError(self.error_messages['max_digits'] % self.max_digits) - if self.decimal_places is not None and decimals > self.decimal_places: - raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places) - if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): - raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places)) - return value - - -class FileField(WritableField): - use_files = True - type_name = 'FileField' - type_label = 'file upload' - form_field_class = forms.FileField - widget = widgets.FileInput - - default_error_messages = { - 'invalid': _("No file was submitted. Check the encoding type on the form."), - 'missing': _("No file was submitted."), - 'empty': _("The submitted file is empty."), - 'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'), - 'contradiction': _('Please either submit a file or check the clear checkbox, not both.') - } - - def __init__(self, *args, **kwargs): - self.max_length = kwargs.pop('max_length', None) - self.allow_empty_file = kwargs.pop('allow_empty_file', False) - super(FileField, self).__init__(*args, **kwargs) - - def from_native(self, data): - if data in validators.EMPTY_VALUES: - return None - - # UploadedFile objects should have name and size attributes. - try: - file_name = data.name - file_size = data.size - except AttributeError: - raise ValidationError(self.error_messages['invalid']) - - if self.max_length is not None and len(file_name) > self.max_length: - error_values = {'max': self.max_length, 'length': len(file_name)} - raise ValidationError(self.error_messages['max_length'] % error_values) - if not file_name: - raise ValidationError(self.error_messages['invalid']) - if not self.allow_empty_file and not file_size: - raise ValidationError(self.error_messages['empty']) - + self.fail('invalid_integer') return data - def to_native(self, value): - return value.name +class MethodField(Field): + def __init__(self, **kwargs): + kwargs['source'] = '*' + kwargs['read_only'] = True + super(MethodField, self).__init__(**kwargs) -class ImageField(FileField): - use_files = True - type_name = 'ImageField' - type_label = 'image upload' - form_field_class = forms.ImageField - - default_error_messages = { - 'invalid_image': _("Upload a valid image. The file you uploaded was " - "either not an image or a corrupted image."), - } - - def from_native(self, data): - """ - Checks that the file-upload field data contains a valid image (GIF, JPG, - PNG, possibly others -- whatever the Python Imaging Library supports). - """ - f = super(ImageField, self).from_native(data) - if f is None: - return None - - from rest_framework.compat import Image - assert Image is not None, 'Either Pillow or PIL must be installed for ImageField support.' - - # We need to get a file object for PIL. We might have a path or we might - # have to read the data into memory. - if hasattr(data, 'temporary_file_path'): - file = data.temporary_file_path() - else: - if hasattr(data, 'read'): - file = BytesIO(data.read()) - else: - file = BytesIO(data['content']) - - try: - # load() could spot a truncated JPEG, but it loads the entire - # image in memory, which is a DoS vector. See #3848 and #18520. - # verify() must be called immediately after the constructor. - Image.open(file).verify() - except ImportError: - # Under PyPy, it is possible to import PIL. However, the underlying - # _imaging C module isn't available, so an ImportError will be - # raised. Catch and re-raise. - raise - except Exception: # Python Imaging Library doesn't recognize it as an image - raise ValidationError(self.error_messages['invalid_image']) - if hasattr(f, 'seek') and callable(f.seek): - f.seek(0) - return f - - -class SerializerMethodField(Field): - """ - A field that gets its value by calling a method on the serializer it's attached to. - """ - - def __init__(self, method_name, *args, **kwargs): - self.method_name = method_name - super(SerializerMethodField, self).__init__(*args, **kwargs) - - def field_to_native(self, obj, field_name): - value = getattr(self.parent, self.method_name)(obj) - return self.to_native(value) + def to_primative(self, value): + attr = 'get_{field_name}'.format(field_name=self.field_name) + method = getattr(self.parent, attr) + return method(value) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index b3bd6ce92..6705cbb2f 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -79,18 +79,16 @@ class GenericAPIView(views.APIView): 'view': self } - def get_serializer(self, instance=None, data=None, files=None, many=False, - partial=False, allow_add_remove=False): + def get_serializer(self, instance=None, data=None, many=False, partial=False): """ Return the serializer instance that should be used for validating and deserializing input, and for serializing output. """ serializer_class = self.get_serializer_class() context = self.get_serializer_context() - return serializer_class(instance, data=data, files=files, - many=many, partial=partial, - allow_add_remove=allow_add_remove, - context=context) + return serializer_class( + instance, data=data, many=many, partial=partial, context=context + ) def get_pagination_serializer(self, page): """ diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index ac59d9795..ee01cabc7 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -36,12 +36,10 @@ class CreateModelMixin(object): Create a model instance. """ def create(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.DATA, files=request.FILES) + serializer = self.get_serializer(data=request.DATA) if serializer.is_valid(): - self.pre_save(serializer.object) - self.object = serializer.save(force_insert=True) - self.post_save(self.object, created=True) + self.object = serializer.save() headers = self.get_success_headers(serializer.data) return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) @@ -90,26 +88,20 @@ class UpdateModelMixin(object): partial = kwargs.pop('partial', False) self.object = self.get_object_or_none() - serializer = self.get_serializer(self.object, data=request.DATA, - files=request.FILES, partial=partial) + serializer = self.get_serializer(self.object, data=request.DATA, partial=partial) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - try: - self.pre_save(serializer.object) - except ValidationError as err: - # full_clean on model instance may be called in pre_save, - # so we have to handle eventual errors. - return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST) + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + lookup_value = self.kwargs[lookup_url_kwarg] + extras = {self.lookup_field: lookup_value} if self.object is None: - self.object = serializer.save(force_insert=True) - self.post_save(self.object, created=True) + self.object = serializer.save(extras=extras) return Response(serializer.data, status=status.HTTP_201_CREATED) - self.object = serializer.save(force_update=True) - self.post_save(self.object, created=False) + self.object = serializer.save(extras=extras) return Response(serializer.data, status=status.HTTP_200_OK) def partial_update(self, request, *args, **kwargs): diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index d51ea929b..83ef97c5c 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -48,17 +48,17 @@ class DefaultObjectSerializer(serializers.Field): super(DefaultObjectSerializer, self).__init__(source=source) -class PaginationSerializerOptions(serializers.SerializerOptions): - """ - An object that stores the options that may be provided to a - pagination serializer by using the inner `Meta` class. +# class PaginationSerializerOptions(serializers.SerializerOptions): +# """ +# An object that stores the options that may be provided to a +# pagination serializer by using the inner `Meta` class. - Accessible on the instance as `serializer.opts`. - """ - def __init__(self, meta): - super(PaginationSerializerOptions, self).__init__(meta) - self.object_serializer_class = getattr(meta, 'object_serializer_class', - DefaultObjectSerializer) +# Accessible on the instance as `serializer.opts`. +# """ +# def __init__(self, meta): +# super(PaginationSerializerOptions, self).__init__(meta) +# self.object_serializer_class = getattr(meta, 'object_serializer_class', +# DefaultObjectSerializer) class BasePaginationSerializer(serializers.Serializer): @@ -66,7 +66,7 @@ class BasePaginationSerializer(serializers.Serializer): A base class for pagination serializers to inherit from, to make implementing custom serializers more easy. """ - _options_class = PaginationSerializerOptions + # _options_class = PaginationSerializerOptions results_field = 'results' def __init__(self, *args, **kwargs): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 56870b408..e69de29bb 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,486 +0,0 @@ -""" -Serializer fields that deal with relationships. - -These fields allow you to specify the style that should be used to represent -model relationships, including hyperlinks, primary keys, or slugs. -""" -from __future__ import unicode_literals -from django.core.exceptions import ObjectDoesNotExist, ValidationError -from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch -from django import forms -from django.db.models.fields import BLANK_CHOICE_DASH -from django.forms import widgets -from django.forms.models import ModelChoiceIterator -from django.utils.translation import ugettext_lazy as _ -from rest_framework.fields import Field, WritableField, get_component, is_simple_callable -from rest_framework.reverse import reverse -from rest_framework.compat import urlparse -from rest_framework.compat import smart_text - - -# Relational fields - -# Not actually Writable, but subclasses may need to be. -class RelatedField(WritableField): - """ - Base class for related model fields. - - This represents a relationship using the unicode representation of the target. - """ - widget = widgets.Select - many_widget = widgets.SelectMultiple - form_field_class = forms.ChoiceField - many_form_field_class = forms.MultipleChoiceField - null_values = (None, '', 'None') - - cache_choices = False - empty_label = None - read_only = True - many = False - - def __init__(self, *args, **kwargs): - queryset = kwargs.pop('queryset', None) - self.many = kwargs.pop('many', self.many) - if self.many: - self.widget = self.many_widget - self.form_field_class = self.many_form_field_class - - kwargs['read_only'] = kwargs.pop('read_only', self.read_only) - super(RelatedField, self).__init__(*args, **kwargs) - - if not self.required: - # Accessed in ModelChoiceIterator django/forms/models.py:1034 - # If set adds empty choice. - self.empty_label = BLANK_CHOICE_DASH[0][1] - - self.queryset = queryset - - def initialize(self, parent, field_name): - super(RelatedField, self).initialize(parent, field_name) - if self.queryset is None and not self.read_only: - manager = getattr(self.parent.opts.model, self.source or field_name) - if hasattr(manager, 'related'): # Forward - self.queryset = manager.related.model._default_manager.all() - else: # Reverse - self.queryset = manager.field.rel.to._default_manager.all() - - # We need this stuff to make form choices work... - - def prepare_value(self, obj): - return self.to_native(obj) - - def label_from_instance(self, obj): - """ - Return a readable representation for use with eg. select widgets. - """ - desc = smart_text(obj) - ident = smart_text(self.to_native(obj)) - if desc == ident: - return desc - return "%s - %s" % (desc, ident) - - def _get_queryset(self): - return self._queryset - - def _set_queryset(self, queryset): - self._queryset = queryset - self.widget.choices = self.choices - - queryset = property(_get_queryset, _set_queryset) - - def _get_choices(self): - # If self._choices is set, then somebody must have manually set - # the property self.choices. In this case, just return self._choices. - if hasattr(self, '_choices'): - return self._choices - - # Otherwise, execute the QuerySet in self.queryset to determine the - # choices dynamically. Return a fresh ModelChoiceIterator that has not been - # consumed. Note that we're instantiating a new ModelChoiceIterator *each* - # time _get_choices() is called (and, thus, each time self.choices is - # accessed) so that we can ensure the QuerySet has not been consumed. This - # construct might look complicated but it allows for lazy evaluation of - # the queryset. - return ModelChoiceIterator(self) - - def _set_choices(self, value): - # Setting choices also sets the choices on the widget. - # choices can be any iterable, but we call list() on it because - # it will be consumed more than once. - self._choices = self.widget.choices = list(value) - - choices = property(_get_choices, _set_choices) - - # Default value handling - - def get_default_value(self): - default = super(RelatedField, self).get_default_value() - if self.many and default is None: - return [] - return default - - # Regular serializer stuff... - - def field_to_native(self, obj, field_name): - try: - if self.source == '*': - return self.to_native(obj) - - source = self.source or field_name - value = obj - - for component in source.split('.'): - if value is None: - break - value = get_component(value, component) - except ObjectDoesNotExist: - return None - - if value is None: - return None - - if self.many: - if is_simple_callable(getattr(value, 'all', None)): - return [self.to_native(item) for item in value.all()] - else: - # Also support non-queryset iterables. - # This allows us to also support plain lists of related items. - return [self.to_native(item) for item in value] - return self.to_native(value) - - def field_from_native(self, data, files, field_name, into): - if self.read_only: - return - - try: - if self.many: - try: - # Form data - value = data.getlist(field_name) - if value == [''] or value == []: - raise KeyError - except AttributeError: - # Non-form data - value = data[field_name] - else: - value = data[field_name] - except KeyError: - if self.partial: - return - value = self.get_default_value() - - if value in self.null_values: - if self.required: - raise ValidationError(self.error_messages['required']) - into[(self.source or field_name)] = None - elif self.many: - into[(self.source or field_name)] = [self.from_native(item) for item in value] - else: - into[(self.source or field_name)] = self.from_native(value) - - -# PrimaryKey relationships - -class PrimaryKeyRelatedField(RelatedField): - """ - Represents a relationship as a pk value. - """ - read_only = False - - default_error_messages = { - 'does_not_exist': _("Invalid pk '%s' - object does not exist."), - 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'), - } - - # TODO: Remove these field hacks... - def prepare_value(self, obj): - return self.to_native(obj.pk) - - def label_from_instance(self, obj): - """ - Return a readable representation for use with eg. select widgets. - """ - desc = smart_text(obj) - ident = smart_text(self.to_native(obj.pk)) - if desc == ident: - return desc - return "%s - %s" % (desc, ident) - - # TODO: Possibly change this to just take `obj`, through prob less performant - def to_native(self, pk): - return pk - - def from_native(self, data): - if self.queryset is None: - raise Exception('Writable related fields must include a `queryset` argument') - - try: - return self.queryset.get(pk=data) - except ObjectDoesNotExist: - msg = self.error_messages['does_not_exist'] % smart_text(data) - raise ValidationError(msg) - except (TypeError, ValueError): - received = type(data).__name__ - msg = self.error_messages['incorrect_type'] % received - raise ValidationError(msg) - - def field_to_native(self, obj, field_name): - if self.many: - # To-many relationship - - queryset = None - if not self.source: - # Prefer obj.serializable_value for performance reasons - try: - queryset = obj.serializable_value(field_name) - except AttributeError: - pass - if queryset is None: - # RelatedManager (reverse relationship) - source = self.source or field_name - queryset = obj - for component in source.split('.'): - if queryset is None: - return [] - queryset = get_component(queryset, component) - - # Forward relationship - if is_simple_callable(getattr(queryset, 'all', None)): - return [self.to_native(item.pk) for item in queryset.all()] - else: - # Also support non-queryset iterables. - # This allows us to also support plain lists of related items. - return [self.to_native(item.pk) for item in queryset] - - # To-one relationship - try: - # Prefer obj.serializable_value for performance reasons - pk = obj.serializable_value(self.source or field_name) - except AttributeError: - # RelatedObject (reverse relationship) - try: - pk = getattr(obj, self.source or field_name).pk - except (ObjectDoesNotExist, AttributeError): - return None - - # Forward relationship - return self.to_native(pk) - - -# Slug relationships - -class SlugRelatedField(RelatedField): - """ - Represents a relationship using a unique field on the target. - """ - read_only = False - - default_error_messages = { - 'does_not_exist': _("Object with %s=%s does not exist."), - 'invalid': _('Invalid value.'), - } - - def __init__(self, *args, **kwargs): - self.slug_field = kwargs.pop('slug_field', None) - assert self.slug_field, 'slug_field is required' - super(SlugRelatedField, self).__init__(*args, **kwargs) - - def to_native(self, obj): - return getattr(obj, self.slug_field) - - def from_native(self, data): - if self.queryset is None: - raise Exception('Writable related fields must include a `queryset` argument') - - try: - return self.queryset.get(**{self.slug_field: data}) - except ObjectDoesNotExist: - raise ValidationError(self.error_messages['does_not_exist'] % - (self.slug_field, smart_text(data))) - except (TypeError, ValueError): - msg = self.error_messages['invalid'] - raise ValidationError(msg) - - -# Hyperlinked relationships - -class HyperlinkedRelatedField(RelatedField): - """ - Represents a relationship using hyperlinking. - """ - read_only = False - lookup_field = 'pk' - - default_error_messages = { - 'no_match': _('Invalid hyperlink - No URL match'), - 'incorrect_match': _('Invalid hyperlink - Incorrect URL match'), - 'configuration_error': _('Invalid hyperlink due to configuration error'), - 'does_not_exist': _("Invalid hyperlink - object does not exist."), - 'incorrect_type': _('Incorrect type. Expected url string, received %s.'), - } - - def __init__(self, *args, **kwargs): - try: - self.view_name = kwargs.pop('view_name') - except KeyError: - raise ValueError("Hyperlinked field requires 'view_name' kwarg") - - self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) - self.format = kwargs.pop('format', None) - - super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) - - def get_url(self, obj, view_name, request, format): - """ - Given an object, return the URL that hyperlinks to the object. - - May raise a `NoReverseMatch` if the `view_name` and `lookup_field` - attributes are not configured to correctly match the URL conf. - """ - lookup_field = getattr(obj, self.lookup_field) - kwargs = {self.lookup_field: lookup_field} - return reverse(view_name, kwargs=kwargs, request=request, format=format) - - def get_object(self, queryset, view_name, view_args, view_kwargs): - """ - Return the object corresponding to a matched URL. - - Takes the matched URL conf arguments, and the queryset, and should - return an object instance, or raise an `ObjectDoesNotExist` exception. - """ - lookup_value = view_kwargs[self.lookup_field] - filter_kwargs = {self.lookup_field: lookup_value} - return queryset.get(**filter_kwargs) - - def to_native(self, obj): - view_name = self.view_name - request = self.context.get('request', None) - format = self.format or self.context.get('format', None) - - assert request is not None, ( - "`HyperlinkedRelatedField` requires the request in the serializer " - "context. Add `context={'request': request}` when instantiating " - "the serializer." - ) - - # If the object has not yet been saved then we cannot hyperlink to it. - if getattr(obj, 'pk', None) is None: - return - - # Return the hyperlink, or error if incorrectly configured. - try: - return self.get_url(obj, view_name, request, format) - except NoReverseMatch: - msg = ( - 'Could not resolve URL for hyperlinked relationship using ' - 'view name "%s". You may have failed to include the related ' - 'model in your API, or incorrectly configured the ' - '`lookup_field` attribute on this field.' - ) - raise Exception(msg % view_name) - - def from_native(self, value): - # Convert URL -> model instance pk - # TODO: Use values_list - queryset = self.queryset - if queryset is None: - raise Exception('Writable related fields must include a `queryset` argument') - - try: - http_prefix = value.startswith(('http:', 'https:')) - except AttributeError: - msg = self.error_messages['incorrect_type'] - raise ValidationError(msg % type(value).__name__) - - if http_prefix: - # If needed convert absolute URLs to relative path - value = urlparse.urlparse(value).path - prefix = get_script_prefix() - if value.startswith(prefix): - value = '/' + value[len(prefix):] - - try: - match = resolve(value) - except Exception: - raise ValidationError(self.error_messages['no_match']) - - if match.view_name != self.view_name: - raise ValidationError(self.error_messages['incorrect_match']) - - try: - return self.get_object(queryset, match.view_name, - match.args, match.kwargs) - except (ObjectDoesNotExist, TypeError, ValueError): - raise ValidationError(self.error_messages['does_not_exist']) - - -class HyperlinkedIdentityField(Field): - """ - Represents the instance, or a property on the instance, using hyperlinking. - """ - lookup_field = 'pk' - read_only = True - - def __init__(self, *args, **kwargs): - try: - self.view_name = kwargs.pop('view_name') - except KeyError: - msg = "HyperlinkedIdentityField requires 'view_name' argument" - raise ValueError(msg) - - self.format = kwargs.pop('format', None) - lookup_field = kwargs.pop('lookup_field', None) - self.lookup_field = lookup_field or self.lookup_field - - super(HyperlinkedIdentityField, self).__init__(*args, **kwargs) - - def field_to_native(self, obj, field_name): - request = self.context.get('request', None) - format = self.context.get('format', None) - view_name = self.view_name - - assert request is not None, ( - "`HyperlinkedIdentityField` requires the request in the serializer" - " context. Add `context={'request': request}` when instantiating " - "the serializer." - ) - - # By default use whatever format is given for the current context - # unless the target is a different type to the source. - # - # Eg. Consider a HyperlinkedIdentityField pointing from a json - # representation to an html property of that representation... - # - # '/snippets/1/' should link to '/snippets/1/highlight/' - # ...but... - # '/snippets/1/.json' should link to '/snippets/1/highlight/.html' - if format and self.format and self.format != format: - format = self.format - - # Return the hyperlink, or error if incorrectly configured. - try: - return self.get_url(obj, view_name, request, format) - except NoReverseMatch: - msg = ( - 'Could not resolve URL for hyperlinked relationship using ' - 'view name "%s". You may have failed to include the related ' - 'model in your API, or incorrectly configured the ' - '`lookup_field` attribute on this field.' - ) - raise Exception(msg % view_name) - - def get_url(self, obj, view_name, request, format): - """ - Given an object, return the URL that hyperlinks to the object. - - May raise a `NoReverseMatch` if the `view_name` and `lookup_field` - attributes are not configured to correctly match the URL conf. - """ - lookup_field = getattr(obj, self.lookup_field, None) - kwargs = {self.lookup_field: lookup_field} - - # Handle unsaved object case - if lookup_field is None: - return None - - return reverse(view_name, kwargs=kwargs, request=request, format=format) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 748ebac94..e8935b012 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -458,7 +458,7 @@ class BrowsableAPIRenderer(BaseRenderer): ): return - serializer = view.get_serializer(instance=obj, data=data, files=files) + serializer = view.get_serializer(instance=obj, data=data) serializer.is_valid() data = serializer.data @@ -579,10 +579,10 @@ class BrowsableAPIRenderer(BaseRenderer): 'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes], 'response_headers': response_headers, - 'put_form': self.get_rendered_html_form(view, 'PUT', request), - 'post_form': self.get_rendered_html_form(view, 'POST', request), - 'delete_form': self.get_rendered_html_form(view, 'DELETE', request), - 'options_form': self.get_rendered_html_form(view, 'OPTIONS', request), + #'put_form': self.get_rendered_html_form(view, 'PUT', request), + #'post_form': self.get_rendered_html_form(view, 'POST', request), + #'delete_form': self.get_rendered_html_form(view, 'DELETE', request), + #'options_form': self.get_rendered_html_form(view, 'OPTIONS', request), 'raw_data_put_form': raw_data_put_form, 'raw_data_post_form': raw_data_post_form, diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index be8ad3f24..d121812d6 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,21 +10,14 @@ python primitives. 2. The process of marshalling between python primitives and request and response content is handled by parsers and renderers. """ -from __future__ import unicode_literals -import copy -import datetime -import inspect -import types -from decimal import Decimal -from django.contrib.contenttypes.generic import GenericForeignKey -from django.core.paginator import Page from django.db import models -from django.forms import widgets from django.utils import six -from django.utils.datastructures import SortedDict -from django.core.exceptions import ObjectDoesNotExist +from collections import namedtuple, OrderedDict +from rest_framework.fields import empty, set_value, Field, SkipField, ValidationError from rest_framework.settings import api_settings - +from rest_framework.utils import html +import copy +import inspect # Note: We do the following so that users of the framework can use this style: # @@ -37,6 +30,253 @@ from rest_framework.relations import * # NOQA from rest_framework.fields import * # NOQA +FieldResult = namedtuple('FieldResult', ['field', 'value', 'error']) + + +class BaseSerializer(Field): + def __init__(self, instance=None, data=None, **kwargs): + super(BaseSerializer, self).__init__(**kwargs) + self.instance = instance + self._initial_data = data + + def to_native(self, data): + raise NotImplementedError() + + def to_primative(self, instance): + raise NotImplementedError() + + def update(self, instance): + raise NotImplementedError() + + def create(self): + raise NotImplementedError() + + def save(self, extras=None): + if extras is not None: + self._validated_data.update(extras) + + if self.instance is not None: + self.update(self.instance) + else: + self.instance = self.create() + + return self.instance + + def is_valid(self): + try: + self._validated_data = self.to_native(self._initial_data) + except ValidationError as exc: + self._validated_data = {} + self._errors = exc.args[0] + return False + self._errors = {} + return True + + @property + def data(self): + if not hasattr(self, '_data'): + if self.instance is not None: + self._data = self.to_primative(self.instance) + elif self._initial_data is not None: + self._data = { + field_name: field.get_value(self._initial_data) + for field_name, field in self.fields.items() + } + else: + self._data = self.get_initial() + return self._data + + @property + def errors(self): + if not hasattr(self, '_errors'): + msg = 'You must call `.is_valid()` before accessing `.errors`.' + raise AssertionError(msg) + return self._errors + + @property + def validated_data(self): + if not hasattr(self, '_validated_data'): + msg = 'You must call `.is_valid()` before accessing `.validated_data`.' + raise AssertionError(msg) + return self._validated_data + + +class SerializerMetaclass(type): + """ + This metaclass sets a dictionary named `base_fields` on the class. + + Any fields included as attributes on either the class or it's superclasses + will be include in the `base_fields` dictionary. + """ + + @classmethod + def _get_fields(cls, bases, attrs): + fields = [(field_name, attrs.pop(field_name)) + for field_name, obj in list(attrs.items()) + if isinstance(obj, Field)] + fields.sort(key=lambda x: x[1]._creation_counter) + + # If this class is subclassing another Serializer, add that Serializer's + # fields. Note that we loop over the bases in *reverse*. This is necessary + # in order to maintain the correct order of fields. + for base in bases[::-1]: + if hasattr(base, 'base_fields'): + fields = list(base.base_fields.items()) + fields + + return OrderedDict(fields) + + def __new__(cls, name, bases, attrs): + attrs['base_fields'] = cls._get_fields(bases, attrs) + return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) + + +@six.add_metaclass(SerializerMetaclass) +class Serializer(BaseSerializer): + + def __new__(cls, *args, **kwargs): + many = kwargs.pop('many', False) + if many: + class DynamicListSerializer(ListSerializer): + child = cls() + return DynamicListSerializer(*args, **kwargs) + return super(Serializer, cls).__new__(cls) + + def __init__(self, *args, **kwargs): + kwargs.pop('context', None) + kwargs.pop('partial', None) + kwargs.pop('many', False) + + super(Serializer, self).__init__(*args, **kwargs) + + # Every new serializer is created with a clone of the field instances. + # This allows users to dynamically modify the fields on a serializer + # instance without affecting every other serializer class. + self.fields = self.get_fields() + + # Setup all the child fields, to provide them with the current context. + for field_name, field in self.fields.items(): + field.bind(field_name, self, self) + + def get_fields(self): + return copy.deepcopy(self.base_fields) + + def bind(self, field_name, parent, root): + # If the serializer is used as a field then when it becomes bound + # it also needs to bind all its child fields. + super(Serializer, self).bind(field_name, parent, root) + for field_name, field in self.fields.items(): + field.bind(field_name, self, root) + + def get_initial(self): + return { + field.field_name: field.get_initial() + for field in self.fields.values() + } + + def get_value(self, dictionary): + # We override the default field access in order to support + # nested HTML forms. + if html.is_html_input(dictionary): + return html.parse_html_dict(dictionary, prefix=self.field_name) + return dictionary.get(self.field_name, empty) + + def to_native(self, data): + """ + Dict of native values <- Dict of primitive datatypes. + """ + ret = {} + errors = {} + fields = [field for field in self.fields.values() if not field.read_only] + + for field in fields: + primitive_value = field.get_value(data) + try: + validated_value = field.validate(primitive_value) + except ValidationError as exc: + errors[field.field_name] = str(exc) + except SkipField: + pass + else: + set_value(ret, field.source_attrs, validated_value) + + if errors: + raise ValidationError(errors) + + return ret + + def to_primative(self, instance): + """ + Object instance -> Dict of primitive datatypes. + """ + ret = OrderedDict() + fields = [field for field in self.fields.values() if not field.write_only] + + for field in fields: + native_value = field.get_attribute(instance) + ret[field.field_name] = field.to_primative(native_value) + + return ret + + def __iter__(self): + errors = self.errors if hasattr(self, '_errors') else {} + for field in self.fields.values(): + value = self.data.get(field.field_name) if self.data else None + error = errors.get(field.field_name) + yield FieldResult(field, value, error) + + +class ListSerializer(BaseSerializer): + child = None + initial = [] + + def __init__(self, *args, **kwargs): + self.child = kwargs.pop('child', copy.deepcopy(self.child)) + assert self.child is not None, '`child` is a required argument.' + + kwargs.pop('context', None) + kwargs.pop('partial', None) + + super(ListSerializer, self).__init__(*args, **kwargs) + self.child.bind('', self, self) + + def bind(self, field_name, parent, root): + # If the list is used as a field then it needs to provide + # the current context to the child serializer. + super(ListSerializer, self).bind(field_name, parent, root) + self.child.bind(field_name, self, root) + + def get_value(self, dictionary): + # We override the default field access in order to support + # lists in HTML forms. + if is_html_input(dictionary): + return html.parse_html_list(dictionary, prefix=self.field_name) + return dictionary.get(self.field_name, empty) + + def to_native(self, data): + """ + List of dicts of native values <- List of dicts of primitive datatypes. + """ + if html.is_html_input(data): + data = html.parse_html_list(data) + + return [self.child.validate(item) for item in data] + + def to_primative(self, data): + """ + List of object instances -> List of dicts of primitive datatypes. + """ + return [self.child.to_primative(item) for item in data] + + def create(self, attrs_list): + return [self.child.create(attrs) for attrs in attrs_list] + + def save(self): + if self.instance is not None: + self.update(self.instance, self.validated_data) + self.instance = self.create(self.validated_data) + return self.instance + + def _resolve_model(obj): """ Resolve supplied `obj` to a Django model class. @@ -58,614 +298,71 @@ def _resolve_model(obj): raise ValueError("{0} is not a Django model".format(obj)) -def pretty_name(name): - """Converts 'first_name' to 'First name'""" - if not name: - return '' - return name.replace('_', ' ').capitalize() - - -class RelationsList(list): - _deleted = [] - - -class NestedValidationError(ValidationError): - """ - The default ValidationError behavior is to stringify each item in the list - if the messages are a list of error messages. - - In the case of nested serializers, where the parent has many children, - then the child's `serializer.errors` will be a list of dicts. In the case - of a single child, the `serializer.errors` will be a dict. - - We need to override the default behavior to get properly nested error dicts. - """ - - def __init__(self, message): - if isinstance(message, dict): - self._messages = [message] - else: - self._messages = message - - @property - def messages(self): - return self._messages - - -class DictWithMetadata(dict): - """ - A dict-like object, that can have additional properties attached. - """ - def __getstate__(self): - """ - Used by pickle (e.g., caching). - Overridden to remove the metadata from the dict, since it shouldn't be - pickled and may in some instances be unpickleable. - """ - return dict(self) - - -class SortedDictWithMetadata(SortedDict): - """ - A sorted dict-like object, that can have additional properties attached. - """ - def __getstate__(self): - """ - Used by pickle (e.g., caching). - Overriden to remove the metadata from the dict, since it shouldn't be - pickle and may in some instances be unpickleable. - """ - return SortedDict(self).__dict__ - - -def _is_protected_type(obj): - """ - True if the object is a native datatype that does not need to - be serialized further. - """ - return isinstance(obj, ( - types.NoneType, - int, long, - datetime.datetime, datetime.date, datetime.time, - float, Decimal, - basestring) - ) - - -def _get_declared_fields(bases, attrs): - """ - Create a list of serializer field instances from the passed in 'attrs', - plus any fields on the base classes (in 'bases'). - - Note that all fields from the base classes are used. - """ - fields = [(field_name, attrs.pop(field_name)) - for field_name, obj in list(six.iteritems(attrs)) - if isinstance(obj, Field)] - fields.sort(key=lambda x: x[1].creation_counter) - - # If this class is subclassing another Serializer, add that Serializer's - # fields. Note that we loop over the bases in *reverse*. This is necessary - # in order to maintain the correct order of fields. - for base in bases[::-1]: - if hasattr(base, 'base_fields'): - fields = list(base.base_fields.items()) + fields - - return SortedDict(fields) - - -class SerializerMetaclass(type): - def __new__(cls, name, bases, attrs): - attrs['base_fields'] = _get_declared_fields(bases, attrs) - return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) - - -class SerializerOptions(object): - """ - Meta class options for Serializer - """ - def __init__(self, meta): - self.depth = getattr(meta, 'depth', 0) - self.fields = getattr(meta, 'fields', ()) - self.exclude = getattr(meta, 'exclude', ()) - - -class BaseSerializer(WritableField): - """ - This is the Serializer implementation. - We need to implement it as `BaseSerializer` due to metaclass magicks. - """ - class Meta(object): - pass - - _options_class = SerializerOptions - _dict_class = SortedDictWithMetadata - - def __init__(self, instance=None, data=None, files=None, - context=None, partial=False, many=False, - allow_add_remove=False, **kwargs): - super(BaseSerializer, self).__init__(**kwargs) - self.opts = self._options_class(self.Meta) - self.parent = None - self.root = None - self.partial = partial - self.many = many - self.allow_add_remove = allow_add_remove - - self.context = context or {} - - self.init_data = data - self.init_files = files - self.object = instance - self.fields = self.get_fields() - - self._data = None - self._files = None - self._errors = None - - if many and instance is not None and not hasattr(instance, '__iter__'): - raise ValueError('instance should be a queryset or other iterable with many=True') - - if allow_add_remove and not many: - raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True') - - ##### - # Methods to determine which fields to use when (de)serializing objects. - - def get_default_fields(self): - """ - Return the complete set of default fields for the object, as a dict. - """ - return {} - - def get_fields(self): - """ - Returns the complete set of fields for the object as a dict. - - This will be the set of any explicitly declared fields, - plus the set of fields returned by get_default_fields(). - """ - ret = SortedDict() - - # Get the explicitly declared fields - base_fields = copy.deepcopy(self.base_fields) - for key, field in base_fields.items(): - ret[key] = field - - # Add in the default fields - default_fields = self.get_default_fields() - for key, val in default_fields.items(): - if key not in ret: - ret[key] = val - - # If 'fields' is specified, use those fields, in that order. - if self.opts.fields: - assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple' - new = SortedDict() - for key in self.opts.fields: - new[key] = ret[key] - ret = new - - # Remove anything in 'exclude' - if self.opts.exclude: - assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple' - for key in self.opts.exclude: - ret.pop(key, None) - - for key, field in ret.items(): - field.initialize(parent=self, field_name=key) - - return ret - - ##### - # Methods to convert or revert from objects <--> primitive representations. - - def get_field_key(self, field_name): - """ - Return the key that should be used for a given field. - """ - return field_name - - def restore_fields(self, data, files): - """ - Core of deserialization, together with `restore_object`. - Converts a dictionary of data into a dictionary of deserialized fields. - """ - reverted_data = {} - - if data is not None and not isinstance(data, dict): - self._errors['non_field_errors'] = ['Invalid data'] - return None - - for field_name, field in self.fields.items(): - field.initialize(parent=self, field_name=field_name) - try: - field.field_from_native(data, files, field_name, reverted_data) - except ValidationError as err: - self._errors[field_name] = list(err.messages) - - return reverted_data - - def perform_validation(self, attrs): - """ - Run `validate_()` and `validate()` methods on the serializer - """ - for field_name, field in self.fields.items(): - if field_name in self._errors: - continue - - source = field.source or field_name - if self.partial and source not in attrs: - continue - try: - validate_method = getattr(self, 'validate_%s' % field_name, None) - if validate_method: - attrs = validate_method(attrs, source) - except ValidationError as err: - self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) - - # If there are already errors, we don't run .validate() because - # field-validation failed and thus `attrs` may not be complete. - # which in turn can cause inconsistent validation errors. - if not self._errors: - try: - attrs = self.validate(attrs) - except ValidationError as err: - if hasattr(err, 'message_dict'): - for field_name, error_messages in err.message_dict.items(): - self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages) - elif hasattr(err, 'messages'): - self._errors['non_field_errors'] = err.messages - - return attrs - - def validate(self, attrs): - """ - Stub method, to be overridden in Serializer subclasses - """ - return attrs - - def restore_object(self, attrs, instance=None): - """ - Deserialize a dictionary of attributes into an object instance. - You should override this method to control how deserialized objects - are instantiated. - """ - if instance is not None: - instance.update(attrs) - return instance - return attrs - - def to_native(self, obj): - """ - Serialize objects -> primitives. - """ - ret = self._dict_class() - ret.fields = self._dict_class() - - for field_name, field in self.fields.items(): - if field.read_only and obj is None: - continue - field.initialize(parent=self, field_name=field_name) - key = self.get_field_key(field_name) - value = field.field_to_native(obj, field_name) - method = getattr(self, 'transform_%s' % field_name, None) - if callable(method): - value = method(obj, value) - if not getattr(field, 'write_only', False): - ret[key] = value - ret.fields[key] = self.augment_field(field, field_name, key, value) - - return ret - - def from_native(self, data, files=None): - """ - Deserialize primitives -> objects. - """ - self._errors = {} - - if data is not None or files is not None: - attrs = self.restore_fields(data, files) - if attrs is not None: - attrs = self.perform_validation(attrs) - else: - self._errors['non_field_errors'] = ['No input provided'] - - if not self._errors: - return self.restore_object(attrs, instance=getattr(self, 'object', None)) - - def augment_field(self, field, field_name, key, value): - # This horrible stuff is to manage serializers rendering to HTML - field._errors = self._errors.get(key) if self._errors else None - field._name = field_name - field._value = self.init_data.get(key) if self._errors and self.init_data else value - if not field.label: - field.label = pretty_name(key) - return field - - def field_to_native(self, obj, field_name): - """ - Override default so that the serializer can be used as a nested field - across relationships. - """ - if self.write_only: - return None - - if self.source == '*': - return self.to_native(obj) - - # Get the raw field value - try: - source = self.source or field_name - value = obj - - for component in source.split('.'): - if value is None: - break - value = get_component(value, component) - except ObjectDoesNotExist: - return None - - if is_simple_callable(getattr(value, 'all', None)): - return [self.to_native(item) for item in value.all()] - - if value is None: - return None - - if self.many: - return [self.to_native(item) for item in value] - return self.to_native(value) - - def field_from_native(self, data, files, field_name, into): - """ - Override default so that the serializer can be used as a writable - nested field across relationships. - """ - if self.read_only: - return - - try: - value = data[field_name] - except KeyError: - if self.default is not None and not self.partial: - # Note: partial updates shouldn't set defaults - value = copy.deepcopy(self.default) - else: - if self.required: - raise ValidationError(self.error_messages['required']) - return - - if self.source == '*': - if value: - reverted_data = self.restore_fields(value, {}) - if not self._errors: - into.update(reverted_data) - else: - if value in (None, ''): - into[(self.source or field_name)] = None - else: - # Set the serializer object if it exists - obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None - - # If we have a model manager or similar object then we need - # to iterate through each instance. - if ( - self.many and - not hasattr(obj, '__iter__') and - is_simple_callable(getattr(obj, 'all', None)) - ): - obj = obj.all() - - kwargs = { - 'instance': obj, - 'data': value, - 'context': self.context, - 'partial': self.partial, - 'many': self.many, - 'allow_add_remove': self.allow_add_remove - } - serializer = self.__class__(**kwargs) - - if serializer.is_valid(): - into[self.source or field_name] = serializer.object - else: - # Propagate errors up to our parent - raise NestedValidationError(serializer.errors) - - def get_identity(self, data): - """ - This hook is required for bulk update. - It is used to determine the canonical identity of a given object. - - Note that the data has not been validated at this point, so we need - to make sure that we catch any cases of incorrect datatypes being - passed to this method. - """ - try: - return data.get('id', None) - except AttributeError: - return None - - @property - def errors(self): - """ - Run deserialization and return error data, - setting self.object if no errors occurred. - """ - if self._errors is None: - data, files = self.init_data, self.init_files - - if self.many is not None: - many = self.many - else: - many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type)) - if many: - warnings.warn('Implicit list/queryset serialization is deprecated. ' - 'Use the `many=True` flag when instantiating the serializer.', - DeprecationWarning, stacklevel=3) - - if many: - ret = RelationsList() - errors = [] - update = self.object is not None - - if update: - # If this is a bulk update we need to map all the objects - # to a canonical identity so we can determine which - # individual object is being updated for each item in the - # incoming data - objects = self.object - identities = [self.get_identity(self.to_native(obj)) for obj in objects] - identity_to_objects = dict(zip(identities, objects)) - - if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)): - for item in data: - if update: - # Determine which object we're updating - identity = self.get_identity(item) - self.object = identity_to_objects.pop(identity, None) - if self.object is None and not self.allow_add_remove: - ret.append(None) - errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}) - continue - - ret.append(self.from_native(item, None)) - errors.append(self._errors) - - if update and self.allow_add_remove: - ret._deleted = identity_to_objects.values() - - self._errors = any(errors) and errors or [] - else: - self._errors = {'non_field_errors': ['Expected a list of items.']} - else: - ret = self.from_native(data, files) - - if not self._errors: - self.object = ret - - return self._errors - - def is_valid(self): - return not self.errors - - @property - def data(self): - """ - Returns the serialized data on the serializer. - """ - if self._data is None: - obj = self.object - - if self.many is not None: - many = self.many - else: - many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict)) - if many: - warnings.warn('Implicit list/queryset serialization is deprecated. ' - 'Use the `many=True` flag when instantiating the serializer.', - DeprecationWarning, stacklevel=2) - - if many: - self._data = [self.to_native(item) for item in obj] - else: - self._data = self.to_native(obj) - - return self._data - - def save_object(self, obj, **kwargs): - obj.save(**kwargs) - - def delete_object(self, obj): - obj.delete() - - def save(self, **kwargs): - """ - Save the deserialized object and return it. - """ - # Clear cached _data, which may be invalidated by `save()` - self._data = None - - if isinstance(self.object, list): - [self.save_object(item, **kwargs) for item in self.object] - - if self.object._deleted: - [self.delete_object(item) for item in self.object._deleted] - else: - self.save_object(self.object, **kwargs) - - return self.object - - def metadata(self): - """ - Return a dictionary of metadata about the fields on the serializer. - Useful for things like responding to OPTIONS requests, or generating - API schemas for auto-documentation. - """ - return SortedDict( - [ - (field_name, field.metadata()) - for field_name, field in six.iteritems(self.fields) - ] - ) - - -class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)): - pass - - -class ModelSerializerOptions(SerializerOptions): +class ModelSerializerOptions(object): """ Meta class options for ModelSerializer """ def __init__(self, meta): - super(ModelSerializerOptions, self).__init__(meta) - self.model = getattr(meta, 'model', None) - self.read_only_fields = getattr(meta, 'read_only_fields', ()) - self.write_only_fields = getattr(meta, 'write_only_fields', ()) + self.model = getattr(meta, 'model') + self.fields = getattr(meta, 'fields', ()) + self.depth = getattr(meta, 'depth', 0) class ModelSerializer(Serializer): - """ - A serializer that deals with model instances and querysets. - """ - _options_class = ModelSerializerOptions - field_mapping = { models.AutoField: IntegerField, - models.FloatField: FloatField, + # models.FloatField: FloatField, models.IntegerField: IntegerField, models.PositiveIntegerField: IntegerField, models.SmallIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField, - models.DateTimeField: DateTimeField, - models.DateField: DateField, - models.TimeField: TimeField, - models.DecimalField: DecimalField, - models.EmailField: EmailField, + # models.DateTimeField: DateTimeField, + # models.DateField: DateField, + # models.TimeField: TimeField, + # models.DecimalField: DecimalField, + # models.EmailField: EmailField, models.CharField: CharField, - models.URLField: URLField, - models.SlugField: SlugField, + # models.URLField: URLField, + # models.SlugField: SlugField, models.TextField: CharField, models.CommaSeparatedIntegerField: CharField, models.BooleanField: BooleanField, models.NullBooleanField: BooleanField, - models.FileField: FileField, - models.ImageField: ImageField, + # models.FileField: FileField, + # models.ImageField: ImageField, } + _options_class = ModelSerializerOptions + + def __init__(self, *args, **kwargs): + self.opts = self._options_class(self.Meta) + super(ModelSerializer, self).__init__(*args, **kwargs) + + def get_fields(self): + # Get the explicitly declared fields. + fields = copy.deepcopy(self.base_fields) + + # Add in the default fields. + for key, val in self.get_default_fields().items(): + if key not in fields: + fields[key] = val + + # If `fields` is set on the `Meta` class, + # then use only those fields, and in that order. + if self.opts.fields: + fields = OrderedDict([ + (key, fields[key]) for key in self.opts.fields + ]) + + return fields + def get_default_fields(self): """ Return all the fields that should be serialized for the model. """ - cls = self.opts.model - assert cls is not None, ( - "Serializer class '%s' is missing 'model' Meta option" % - self.__class__.__name__ - ) opts = cls._meta.concrete_model._meta - ret = SortedDict() + ret = OrderedDict() nested = bool(self.opts.depth) # Deal with adding the primary key field @@ -694,29 +391,9 @@ class ModelSerializer(Serializer): has_through_model = True if model_field.rel and nested: - if len(inspect.getargspec(self.get_nested_field).args) == 2: - warnings.warn( - 'The `get_nested_field(model_field)` call signature ' - 'is deprecated. ' - 'Use `get_nested_field(model_field, related_model, ' - 'to_many) instead', - DeprecationWarning - ) - field = self.get_nested_field(model_field) - else: - field = self.get_nested_field(model_field, related_model, to_many) + field = self.get_nested_field(model_field, related_model, to_many) elif model_field.rel: - if len(inspect.getargspec(self.get_nested_field).args) == 3: - warnings.warn( - 'The `get_related_field(model_field, to_many)` call ' - 'signature is deprecated. ' - 'Use `get_related_field(model_field, related_model, ' - 'to_many) instead', - DeprecationWarning - ) - field = self.get_related_field(model_field, to_many=to_many) - else: - field = self.get_related_field(model_field, related_model, to_many) + field = self.get_related_field(model_field, related_model, to_many) else: field = self.get_field(model_field) @@ -763,38 +440,6 @@ class ModelSerializer(Serializer): ret[accessor_name] = field - # Ensure that 'read_only_fields' is an iterable - assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple' - - # Add the `read_only` flag to any fields that have been specified - # in the `read_only_fields` option - for field_name in self.opts.read_only_fields: - assert field_name not in self.base_fields.keys(), ( - "field '%s' on serializer '%s' specified in " - "`read_only_fields`, but also added " - "as an explicit field. Remove it from `read_only_fields`." % - (field_name, self.__class__.__name__)) - assert field_name in ret, ( - "Non-existant field '%s' specified in `read_only_fields` " - "on serializer '%s'." % - (field_name, self.__class__.__name__)) - ret[field_name].read_only = True - - # Ensure that 'write_only_fields' is an iterable - assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple' - - for field_name in self.opts.write_only_fields: - assert field_name not in self.base_fields.keys(), ( - "field '%s' on serializer '%s' specified in " - "`write_only_fields`, but also added " - "as an explicit field. Remove it from `write_only_fields`." % - (field_name, self.__class__.__name__)) - assert field_name in ret, ( - "Non-existant field '%s' specified in `write_only_fields` " - "on serializer '%s'." % - (field_name, self.__class__.__name__)) - ret[field_name].write_only = True - return ret def get_pk_field(self, model_field): @@ -825,28 +470,24 @@ class ModelSerializer(Serializer): # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - kwargs = { - 'queryset': related_model._default_manager, - 'many': to_many - } + kwargs = {} + # 'queryset': related_model._default_manager, + # 'many': to_many + # } if model_field: kwargs['required'] = not(model_field.null or model_field.blank) - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text + # if model_field.help_text is not None: + # kwargs['help_text'] = model_field.help_text if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name - if not model_field.editable: kwargs['read_only'] = True - if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text - - return PrimaryKeyRelatedField(**kwargs) + return IntegerField(**kwargs) + # TODO: return PrimaryKeyRelatedField(**kwargs) def get_field(self, model_field): """ @@ -869,8 +510,8 @@ class ModelSerializer(Serializer): if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text + # if model_field.help_text is not None: + # kwargs['help_text'] = model_field.help_text # TODO: TypedChoiceField? if model_field.flatchoices: # This ModelField contains choices @@ -880,7 +521,7 @@ class ModelSerializer(Serializer): return ChoiceField(**kwargs) # put this below the ChoiceField because min_value isn't a valid initializer - if issubclass(model_field.__class__, models.PositiveIntegerField) or\ + if issubclass(model_field.__class__, models.PositiveIntegerField) or \ issubclass(model_field.__class__, models.PositiveSmallIntegerField): kwargs['min_value'] = 0 @@ -888,170 +529,27 @@ class ModelSerializer(Serializer): issubclass(model_field.__class__, (models.CharField, models.TextField)): kwargs['allow_none'] = True - attribute_dict = { - models.CharField: ['max_length'], - models.CommaSeparatedIntegerField: ['max_length'], - models.DecimalField: ['max_digits', 'decimal_places'], - models.EmailField: ['max_length'], - models.FileField: ['max_length'], - models.ImageField: ['max_length'], - models.SlugField: ['max_length'], - models.URLField: ['max_length'], - } + # attribute_dict = { + # models.CharField: ['max_length'], + # models.CommaSeparatedIntegerField: ['max_length'], + # models.DecimalField: ['max_digits', 'decimal_places'], + # models.EmailField: ['max_length'], + # models.FileField: ['max_length'], + # models.ImageField: ['max_length'], + # models.SlugField: ['max_length'], + # models.URLField: ['max_length'], + # } - if model_field.__class__ in attribute_dict: - attributes = attribute_dict[model_field.__class__] - for attribute in attributes: - kwargs.update({attribute: getattr(model_field, attribute)}) + # if model_field.__class__ in attribute_dict: + # attributes = attribute_dict[model_field.__class__] + # for attribute in attributes: + # kwargs.update({attribute: getattr(model_field, attribute)}) try: return self.field_mapping[model_field.__class__](**kwargs) except KeyError: - return ModelField(model_field=model_field, **kwargs) - - def get_validation_exclusions(self, instance=None): - """ - Return a list of field names to exclude from model validation. - """ - cls = self.opts.model - opts = cls._meta.concrete_model._meta - exclusions = [field.name for field in opts.fields + opts.many_to_many] - - for field_name, field in self.fields.items(): - field_name = field.source or field_name - if ( - field_name in exclusions - and not field.read_only - and (field.required or hasattr(instance, field_name)) - and not isinstance(field, Serializer) - ): - exclusions.remove(field_name) - return exclusions - - def full_clean(self, instance): - """ - Perform Django's full_clean, and populate the `errors` dictionary - if any validation errors occur. - - Note that we don't perform this inside the `.restore_object()` method, - so that subclasses can override `.restore_object()`, and still get - the full_clean validation checking. - """ - try: - instance.full_clean(exclude=self.get_validation_exclusions(instance)) - except ValidationError as err: - self._errors = err.message_dict - return None - return instance - - def restore_object(self, attrs, instance=None): - """ - Restore the model instance. - """ - m2m_data = {} - related_data = {} - nested_forward_relations = {} - meta = self.opts.model._meta - - # Reverse fk or one-to-one relations - for (obj, model) in meta.get_all_related_objects_with_model(): - field_name = obj.get_accessor_name() - if field_name in attrs: - related_data[field_name] = attrs.pop(field_name) - - # Reverse m2m relations - for (obj, model) in meta.get_all_related_m2m_objects_with_model(): - field_name = obj.get_accessor_name() - if field_name in attrs: - m2m_data[field_name] = attrs.pop(field_name) - - # Forward m2m relations - for field in meta.many_to_many + meta.virtual_fields: - if isinstance(field, GenericForeignKey): - continue - if field.name in attrs: - m2m_data[field.name] = attrs.pop(field.name) - - # Nested forward relations - These need to be marked so we can save - # them before saving the parent model instance. - for field_name in attrs.keys(): - if isinstance(self.fields.get(field_name, None), Serializer): - nested_forward_relations[field_name] = attrs[field_name] - - # Create an empty instance of the model - if instance is None: - instance = self.opts.model() - - for key, val in attrs.items(): - try: - setattr(instance, key, val) - except ValueError: - self._errors[key] = [self.error_messages['required']] - - # Any relations that cannot be set until we've - # saved the model get hidden away on these - # private attributes, so we can deal with them - # at the point of save. - instance._related_data = related_data - instance._m2m_data = m2m_data - instance._nested_forward_relations = nested_forward_relations - - return instance - - def from_native(self, data, files): - """ - Override the default method to also include model field validation. - """ - instance = super(ModelSerializer, self).from_native(data, files) - if not self._errors: - return self.full_clean(instance) - - def save_object(self, obj, **kwargs): - """ - Save the deserialized object. - """ - if getattr(obj, '_nested_forward_relations', None): - # Nested relationships need to be saved before we can save the - # parent instance. - for field_name, sub_object in obj._nested_forward_relations.items(): - if sub_object: - self.save_object(sub_object) - setattr(obj, field_name, sub_object) - - obj.save(**kwargs) - - if getattr(obj, '_m2m_data', None): - for accessor_name, object_list in obj._m2m_data.items(): - setattr(obj, accessor_name, object_list) - del(obj._m2m_data) - - if getattr(obj, '_related_data', None): - related_fields = dict([ - (field.get_accessor_name(), field) - for field, model - in obj._meta.get_all_related_objects_with_model() - ]) - for accessor_name, related in obj._related_data.items(): - if isinstance(related, RelationsList): - # Nested reverse fk relationship - for related_item in related: - fk_field = related_fields[accessor_name].field.name - setattr(related_item, fk_field, obj) - self.save_object(related_item) - - # Delete any removed objects - if related._deleted: - [self.delete_object(item) for item in related._deleted] - - elif isinstance(related, models.Model): - # Nested reverse one-one relationship - fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - setattr(related, fk_field, obj) - self.save_object(related) - else: - # Reverse FK or reverse one-one - setattr(obj, accessor_name, related) - del(obj._related_data) + # TODO: Change this to `return ModelField(model_field=model_field, **kwargs)` + return CharField(**kwargs) class HyperlinkedModelSerializerOptions(ModelSerializerOptions): @@ -1066,14 +564,10 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializer(ModelSerializer): - """ - A subclass of ModelSerializer that uses hyperlinked relationships, - instead of primary key relationships. - """ _options_class = HyperlinkedModelSerializerOptions _default_view_name = '%(model_name)s-detail' - _hyperlink_field_class = HyperlinkedRelatedField - _hyperlink_identify_field_class = HyperlinkedIdentityField + #_hyperlink_field_class = HyperlinkedRelatedField + #_hyperlink_identify_field_class = HyperlinkedIdentityField def get_default_fields(self): fields = super(HyperlinkedModelSerializer, self).get_default_fields() @@ -1081,15 +575,15 @@ class HyperlinkedModelSerializer(ModelSerializer): if self.opts.view_name is None: self.opts.view_name = self._get_default_view_name(self.opts.model) - if self.opts.url_field_name not in fields: - url_field = self._hyperlink_identify_field_class( - view_name=self.opts.view_name, - lookup_field=self.opts.lookup_field - ) - ret = self._dict_class() - ret[self.opts.url_field_name] = url_field - ret.update(fields) - fields = ret + # if self.opts.url_field_name not in fields: + # url_field = self._hyperlink_identify_field_class( + # view_name=self.opts.view_name, + # lookup_field=self.opts.lookup_field + # ) + # ret = self._dict_class() + # ret[self.opts.url_field_name] = url_field + # ret.update(fields) + # fields = ret return fields @@ -1103,33 +597,25 @@ class HyperlinkedModelSerializer(ModelSerializer): """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - kwargs = { - 'queryset': related_model._default_manager, - 'view_name': self._get_default_view_name(related_model), - 'many': to_many - } + # kwargs = { + # 'queryset': related_model._default_manager, + # 'view_name': self._get_default_view_name(related_model), + # 'many': to_many + # } + kwargs = {} if model_field: kwargs['required'] = not(model_field.null or model_field.blank) - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text + # if model_field.help_text is not None: + # kwargs['help_text'] = model_field.help_text if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name - if self.opts.lookup_field: - kwargs['lookup_field'] = self.opts.lookup_field + return IntegerField(**kwargs) + # if self.opts.lookup_field: + # kwargs['lookup_field'] = self.opts.lookup_field - return self._hyperlink_field_class(**kwargs) - - def get_identity(self, data): - """ - This hook is required for bulk update. - We need to override the default, to use the url as the identity. - """ - try: - return data.get(self.opts.url_field_name, None) - except AttributeError: - return None + # return self._hyperlink_field_class(**kwargs) def _get_default_view_name(self, model): """ diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 00ffdfbae..6a2f61266 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -7,7 +7,7 @@ from django.db.models.query import QuerySet from django.utils.datastructures import SortedDict from django.utils.functional import Promise from rest_framework.compat import force_text -from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata +# from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata import datetime import decimal import types @@ -106,14 +106,14 @@ else: SortedDict, yaml.representer.SafeRepresenter.represent_dict ) - SafeDumper.add_representer( - DictWithMetadata, - yaml.representer.SafeRepresenter.represent_dict - ) - SafeDumper.add_representer( - SortedDictWithMetadata, - yaml.representer.SafeRepresenter.represent_dict - ) + # SafeDumper.add_representer( + # DictWithMetadata, + # yaml.representer.SafeRepresenter.represent_dict + # ) + # SafeDumper.add_representer( + # SortedDictWithMetadata, + # yaml.representer.SafeRepresenter.represent_dict + # ) SafeDumper.add_representer( types.GeneratorType, yaml.representer.SafeRepresenter.represent_list diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py new file mode 100644 index 000000000..bf17050df --- /dev/null +++ b/rest_framework/utils/html.py @@ -0,0 +1,86 @@ +""" +Helpers for dealing with HTML input. +""" + +def is_html_input(dictionary): + # MultiDict type datastructures are used to represent HTML form input, + # which may have more than one value for each key. + return hasattr(dictionary, 'getlist') + + +def parse_html_list(dictionary, prefix=''): + """ + Used to suport list values in HTML forms. + Supports lists of primitives and/or dictionaries. + + * List of primitives. + + { + '[0]': 'abc', + '[1]': 'def', + '[2]': 'hij' + } + --> + [ + 'abc', + 'def', + 'hij' + ] + + * List of dictionaries. + + { + '[0]foo': 'abc', + '[0]bar': 'def', + '[1]foo': 'hij', + '[2]bar': 'klm', + } + --> + [ + {'foo': 'abc', 'bar': 'def'}, + {'foo': 'hij', 'bar': 'klm'} + ] + """ + Dict = type(dictionary) + ret = {} + regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix)) + for field, value in dictionary.items(): + match = regex.match(field) + if not match: + continue + index, key = match.groups() + index = int(index) + if not key: + ret[index] = value + elif isinstance(ret.get(index), dict): + ret[index][key] = value + else: + ret[index] = Dict({key: value}) + return [ret[item] for item in sorted(ret.keys())] + + +def parse_html_dict(dictionary, prefix): + """ + Used to support dictionary values in HTML forms. + + { + 'profile.username': 'example', + 'profile.email': 'example@example.com', + } + --> + { + 'profile': { + 'username': 'example, + 'email': 'example@example.com' + } + } + """ + ret = {} + regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix)) + for field, value in dictionary.items(): + match = regex.match(field) + if not match: + continue + key = match.groups()[0] + ret[key] = value + return ret