From e8126c3a9165cdc904660f4dc0ace4d2cba86ddf Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 14 Dec 2011 10:48:19 +0000 Subject: [PATCH 1/5] Getting more resourceful - read, create, update and delete methods on ModelMixin --- djangorestframework/mixins.py | 208 +++++++++++++--------------- djangorestframework/resources.py | 187 +++++++++++++------------ djangorestframework/tests/mixins.py | 10 +- djangorestframework/views.py | 57 ++++++-- 4 files changed, 247 insertions(+), 215 deletions(-) diff --git a/djangorestframework/mixins.py b/djangorestframework/mixins.py index b1a634a07..93a094e6f 100644 --- a/djangorestframework/mixins.py +++ b/djangorestframework/mixins.py @@ -9,9 +9,8 @@ from django.db.models.fields.related import ForeignKey from django.http import HttpResponse from djangorestframework import status -from djangorestframework.renderers import BaseRenderer from djangorestframework.resources import Resource, FormResource, ModelResource -from djangorestframework.response import Response, ErrorResponse +from djangorestframework.response import ErrorResponse from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX from djangorestframework.utils.mediatypes import is_form_media_type, order_by_precedence @@ -27,11 +26,7 @@ __all__ = ( # Reverse URL lookup behavior 'InstanceMixin', # Model behavior mixins - 'ReadModelMixin', - 'CreateModelMixin', - 'UpdateModelMixin', - 'DeleteModelMixin', - 'ListModelMixin' + 'ModelMixin', ) @@ -267,25 +262,32 @@ class ResponseMixin(object): def _determine_renderer(self, request): """ - Determines the appropriate renderer for the output, given the client's 'Accept' header, - and the :attr:`renderers` set on this class. + Determines the appropriate renderer for the output, given the client's + 'Accept' header, and the :attr:`renderers` set on this class. Returns a 2-tuple of `(renderer, media_type)` - See: RFC 2616, Section 14 - http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html + See: RFC 2616, Section 14 + http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html """ - if self._ACCEPT_QUERY_PARAM and request.GET.get(self._ACCEPT_QUERY_PARAM, None): + if (self._ACCEPT_QUERY_PARAM and + request.GET.get(self._ACCEPT_QUERY_PARAM, None)): # Use _accept parameter override accept_list = [request.GET.get(self._ACCEPT_QUERY_PARAM)] + elif (self._IGNORE_IE_ACCEPT_HEADER and - request.META.has_key('HTTP_USER_AGENT') and + 'HTTP_USER_AGENT' in request.META and MSIE_USER_AGENT_REGEX.match(request.META['HTTP_USER_AGENT'])): - # Ignore MSIE's broken accept behavior and do something sensible instead + # Ignore MSIE's broken accept behavior and do something sensible + # instead. accept_list = ['text/html', '*/*'] - elif request.META.has_key('HTTP_ACCEPT'): + + elif 'HTTP_USER_AGENT' in request.META: # Use standard HTTP Accept negotiation - accept_list = [token.strip() for token in request.META["HTTP_ACCEPT"].split(',')] + accept_list = [token.strip() for token in + request.META["HTTP_ACCEPT"].split(',')] + else: # No accept header specified accept_list = ['*/*'] @@ -481,48 +483,86 @@ class InstanceMixin(object): ########## Model Mixins ########## -class ReadModelMixin(object): - """ - Behavior to read a `model` instance on GET requests - """ - def get(self, request, *args, **kwargs): - model = self.resource.model +class ModelMixin(object): + def get_model(self): + """ + Return the model class for this view. + """ + return getattr(self, 'model', self.resource.model) + + def get_queryset(self): + """ + Return the queryset that should be used when retrieving or listing + instances. + """ + return getattr(self, 'queryset', + getattr(self.resource, 'queryset', + self._get_model().objects.all())) + + def get_ordering(self): + """ + Return the ordering that should be used when listing instances. + """ + return getattr(self, 'ordering', + getattr(self.resource, 'ordering', + None)) + + def get_instance(self, *args, **kwargs): + """ + Return a model instance or None. + """ + model = self.get_model() + queryset = self.get_queryset() + kwargs = self._filter_kwargs(kwargs) try: + # If we have any positional args then assume the last + # represents the primary key. Otherwise assume the named kwargs + # uniquely identify the instance. if args: - # If we have any none kwargs then assume the last represents the primrary key - self.model_instance = model.objects.get(pk=args[-1], **kwargs) + return queryset.get(pk=args[-1], **kwargs) else: - # Otherwise assume the kwargs uniquely identify the model - filtered_keywords = kwargs.copy() - if BaseRenderer._FORMAT_QUERY_PARAM in filtered_keywords: - del filtered_keywords[BaseRenderer._FORMAT_QUERY_PARAM] - self.model_instance = model.objects.get(**filtered_keywords) + return queryset.get(**kwargs) except model.DoesNotExist: - raise ErrorResponse(status.HTTP_404_NOT_FOUND) + return None - return self.model_instance + def read(self, request, *args, **kwargs): + instance = self.get_instance(*args, **kwargs) + return instance + def update(self, request, *args, **kwargs): + """ + Return a model instance. + """ + instance = self.get_instance(*args, **kwargs) -class CreateModelMixin(object): - """ - Behavior to create a `model` instance on POST requests - """ - def post(self, request, *args, **kwargs): - model = self.resource.model + if instance: + for (key, val) in self.CONTENT.items(): + setattr(instance, key, val) + else: + instance = self.get_model()(**self.CONTENT) + + instance.save() + return instance + + def create(self, request, *args, **kwargs): + """ + Return a model instance. + """ + model = self._get_model() # Copy the dict to keep self.CONTENT intact content = dict(self.CONTENT) m2m_data = {} for field in model._meta.fields: - if isinstance(field, ForeignKey) and kwargs.has_key(field.name): + if isinstance(field, ForeignKey) and field.name in kwargs: # translate 'related_field' kwargs into 'related_field_id' kwargs[field.name + '_id'] = kwargs[field.name] del kwargs[field.name] for field in model._meta.many_to_many: - if content.has_key(field.name): + if field.name in content: m2m_data[field.name] = ( field.m2m_reverse_field_name(), content[field.name] ) @@ -549,90 +589,30 @@ class CreateModelMixin(object): data[m2m_data[fieldname][0]] = related_item manager.through(**data).save() - headers = {} - if hasattr(instance, 'get_absolute_url'): - headers['Location'] = self.resource(self).url(instance) - return Response(status.HTTP_201_CREATED, instance, headers) + return instance -class UpdateModelMixin(object): - """ - Behavior to update a `model` instance on PUT requests - """ - def put(self, request, *args, **kwargs): - model = self.resource.model + def destroy(self, request, *args, **kwargs): + """ + Return a model instance or None. + """ + instance = self.get_instance(*args, **kwargs) - # TODO: update on the url of a non-existing resource url doesn't work correctly at the moment - will end up with a new url - try: - if args: - # If we have any none kwargs then assume the last represents the primary key - self.model_instance = model.objects.get(pk=args[-1], **kwargs) - else: - # Otherwise assume the kwargs uniquely identify the model - self.model_instance = model.objects.get(**kwargs) + if instance: + instance.delete() - for (key, val) in self.CONTENT.items(): - setattr(self.model_instance, key, val) - except model.DoesNotExist: - self.model_instance = model(**self.CONTENT) - self.model_instance.save() - - self.model_instance.save() - return self.model_instance + return instance -class DeleteModelMixin(object): - """ - Behavior to delete a `model` instance on DELETE requests - """ - def delete(self, request, *args, **kwargs): - model = self.resource.model - - try: - if args: - # If we have any none kwargs then assume the last represents the primrary key - instance = model.objects.get(pk=args[-1], **kwargs) - else: - # Otherwise assume the kwargs uniquely identify the model - instance = model.objects.get(**kwargs) - except model.DoesNotExist: - raise ErrorResponse(status.HTTP_404_NOT_FOUND, None, {}) - - instance.delete() - return - - -class ListModelMixin(object): - """ - Behavior to list a set of `model` instances on GET requests - """ - - # NB. Not obvious to me if it would be better to set this on the resource? - # - # Presumably it's more useful to have on the view, because that way you can - # have multiple views across different querysets mapping to the same resource. - # - # Perhaps it ought to be: - # - # 1) View.queryset - # 2) if None fall back to Resource.queryset - # 3) if None fall back to Resource.model.objects.all() - # - # Any feedback welcomed. - queryset = None - - def get(self, request, *args, **kwargs): - model = self.resource.model - - queryset = self.queryset if self.queryset is not None else model.objects.all() - - if hasattr(self, 'resource'): - ordering = getattr(self.resource, 'ordering', None) - else: - ordering = None + def list(self, request, *args, **kwargs): + """ + Return a list of instances. + """ + queryset = self.get_queryset() + ordering = self.get_ordering() if ordering: - args = as_tuple(ordering) + assert(hasattr(ordering, '__iter__')) queryset = queryset.order_by(*args) return queryset.filter(**kwargs) diff --git a/djangorestframework/resources.py b/djangorestframework/resources.py index 5770d07f9..dc11c31f8 100644 --- a/djangorestframework/resources.py +++ b/djangorestframework/resources.py @@ -1,24 +1,16 @@ from django import forms from django.core.urlresolvers import reverse, get_urlconf, get_resolver, NoReverseMatch from django.db import models -from django.db.models.query import QuerySet -from django.db.models.fields.related import RelatedField -from django.utils.encoding import smart_unicode from djangorestframework.response import ErrorResponse from djangorestframework.serializer import Serializer, _SkipField from djangorestframework.utils import as_tuple -import decimal -import inspect -import re - - - class BaseResource(Serializer): """ - Base class for all Resource classes, which simply defines the interface they provide. + Base class for all Resource classes, which simply defines the interface + they provide. """ fields = None include = None @@ -31,10 +23,11 @@ class BaseResource(Serializer): def validate_request(self, data, files=None): """ Given the request content return the cleaned, validated content. - Typically raises a :exc:`response.ErrorResponse` with status code 400 (Bad Request) on failure. + Typically raises a :exc:`response.ErrorResponse` with status code 400 + (Bad Request) on failure. """ return data - + def filter_response(self, obj): """ Given the response content, filter it into a serializable object. @@ -45,18 +38,20 @@ class BaseResource(Serializer): class Resource(BaseResource): """ A Resource determines how a python object maps to some serializable data. - Objects that a resource can act on include plain Python object instances, Django Models, and Django QuerySets. + Objects that a resource can act on include plain Python object instances, + Django Models, and Django QuerySets. """ - - # The model attribute refers to the Django Model which this Resource maps to. - # (The Model's class, rather than an instance of the Model) + + # The model attribute refers to the Django Model which this Resource maps + # to. (The Model's class, rather than an instance of the Model) model = None - + # By default the set of returned fields will be the set of: # # 0. All the fields on the model, excluding 'id'. # 1. All the properties on the model. - # 2. The absolute_url of the model, if a get_absolute_url method exists for the model. + # 2. The absolute_url of the model, if a get_absolute_url method exists for + # the model. # # If you wish to override this behaviour, # you should explicitly set the fields attribute on your class. @@ -66,60 +61,68 @@ class Resource(BaseResource): class FormResource(Resource): """ Resource class that uses forms for validation. - Also provides a :meth:`get_bound_form` method which may be used by some renderers. + Also provides a :meth:`get_bound_form` method which may be used by some + renderers. - On calling :meth:`validate_request` this validator may set a :attr:`bound_form_instance` attribute on the - view, which may be used by some renderers. + On calling :meth:`validate_request` this validator may set a + :attr:`bound_form_instance` attribute on the view, which may be used by + some renderers. """ form = None """ The :class:`Form` class that should be used for request validation. - This can be overridden by a :attr:`form` attribute on the :class:`views.View`. + This can be overridden by a :attr:`form` attribute on the + :class:`views.View`. """ - def validate_request(self, data, files=None): """ Given some content as input return some cleaned, validated content. - Raises a :exc:`response.ErrorResponse` with status code 400 (Bad Request) on failure. - - Validation is standard form validation, with an additional constraint that *no extra unknown fields* may be supplied. + Raises a :exc:`response.ErrorResponse` with status code 400 + # (Bad Request) on failure. - On failure the :exc:`response.ErrorResponse` content is a dict which may contain :obj:`'errors'` and :obj:`'field-errors'` keys. - If the :obj:`'errors'` key exists it is a list of strings of non-field errors. - If the :obj:`'field-errors'` key exists it is a dict of ``{'field name as string': ['errors as strings', ...]}``. + Validation is standard form validation, with an additional constraint + that *no extra unknown fields* may be supplied. + + On failure the :exc:`response.ErrorResponse` content is a dict which + may contain :obj:`'errors'` and :obj:`'field-errors'` keys. + If the :obj:`'errors'` key exists it is a list of strings of non-field + errors. + If the :obj:`'field-errors'` key exists it is a dict of + ``{'field name as string': ['errors as strings', ...]}``. """ return self._validate(data, files) - def _validate(self, data, files, allowed_extra_fields=(), fake_data=None): """ - Wrapped by validate to hide the extra flags that are used in the implementation. + Wrapped by validate to hide the extra flags that are used in the + implementation. - allowed_extra_fields is a list of fields which are not defined by the form, but which we still - expect to see on the input. - - fake_data is a string that should be used as an extra key, as a kludge to force .errors - to be populated when an empty dict is supplied in `data` + allowed_extra_fields is a list of fields which are not defined by the + form, but which we still expect to see on the input. + + fake_data is a string that should be used as an extra key, as a kludge + to force `.errors` to be populated when an empty dict is supplied in + `data` """ - + # We'd like nice error messages even if no content is supplied. # Typically if an empty dict is given to a form Django will # return .is_valid() == False, but .errors == {} # - # To get around this case we revalidate with some fake data. + # To get around this case we revalidate with some fake data. if fake_data: data[fake_data] = '_fake_data' allowed_extra_fields = tuple(allowed_extra_fields) + ('_fake_data',) - + bound_form = self.get_bound_form(data, files) if bound_form is None: return data - + self.view.bound_form_instance = bound_form - + data = data and data or {} files = files and files or {} @@ -127,10 +130,11 @@ class FormResource(Resource): form_fields_set = set(bound_form.fields.keys()) allowed_extra_fields_set = set(allowed_extra_fields) - # In addition to regular validation we also ensure no additional fields are being passed in... + # In addition to regular validation we also ensure no additional fields + # are being passed in... unknown_fields = seen_fields_set - (form_fields_set | allowed_extra_fields_set) unknown_fields = unknown_fields - set(('csrfmiddlewaretoken', '_accept', '_method')) # TODO: Ugh. - + # Check using both regular validation, and our stricter no additional fields rule if bound_form.is_valid() and not unknown_fields: # Validation succeeded... @@ -155,7 +159,7 @@ class FormResource(Resource): # If we've already set fake_dict and we're still here, fallback gracefully. detail = {u'errors': [u'No content was supplied.']} - else: + else: # Add any non-field errors if bound_form.non_field_errors(): detail[u'errors'] = bound_form.non_field_errors() @@ -171,14 +175,13 @@ class FormResource(Resource): # Add any unknown field errors for key in unknown_fields: field_errors[key] = [u'This field does not exist.'] - + if field_errors: detail[u'field-errors'] = field_errors # Return HTTP 400 response (BAD REQUEST) raise ErrorResponse(400, detail) - def get_form_class(self, method=None): """ Returns the form class used to validate this resource. @@ -199,7 +202,6 @@ class FormResource(Resource): form = getattr(self.view, '%s_form' % method.lower(), form) return form - def get_bound_form(self, data=None, files=None, method=None): """ @@ -217,7 +219,6 @@ class FormResource(Resource): return form() - #class _RegisterModelResource(type): # """ # Auto register new ModelResource classes into ``_model_to_resource`` @@ -230,14 +231,15 @@ class FormResource(Resource): # return resource_cls - class ModelResource(FormResource): """ - Resource class that uses forms for validation and otherwise falls back to a model form if no form is set. - Also provides a :meth:`get_bound_form` method which may be used by some renderers. + Resource class that uses forms for validation and otherwise falls back to a + model form if no form is set. + Also provides a :meth:`get_bound_form` method which may be used by some + renderers. """ - # Auto-register new ModelResource classes into _model_to_resource + # Auto-register new ModelResource classes into _model_to_resource #__metaclass__ = _RegisterModelResource form = None @@ -245,38 +247,45 @@ class ModelResource(FormResource): The form class that should be used for request validation. If set to :const:`None` then the default model form validation will be used. - This can be overridden by a :attr:`form` attribute on the :class:`views.View`. + This can be overridden by a :attr:`form` attribute on the + :class:`views.View`. """ model = None """ The model class which this resource maps to. - This can be overridden by a :attr:`model` attribute on the :class:`views.View`. + This can be overridden by a :attr:`model` attribute on the + :class:`views.View`. """ fields = None """ The list of fields to use on the output. - + May be any of: - - The name of a model field. To view nested resources, give the field as a tuple of ("fieldName", resource) where `resource` may be any of ModelResource reference, the name of a ModelResourc reference as a string or a tuple of strings representing fields on the nested model. + + The name of a model field. To view nested resources, give the field as a + tuple of ("fieldName", resource) where `resource` may be any of + ModelResource reference, the name of a ModelResourc reference as a string + or a tuple of strings representing fields on the nested model. The name of an attribute on the model. The name of an attribute on the resource. The name of a method on the model, with a signature like ``func(self)``. - The name of a method on the resource, with a signature like ``func(self, instance)``. + The name of a method on the resource, with a signature like + ``func(self, instance)``. """ - + exclude = ('id', 'pk') """ - The list of fields to exclude. This is only used if :attr:`fields` is not set. + The list of fields to exclude. This is only used if :attr:`fields` is not + set. """ - include = ('url',) """ - The list of extra fields to include. This is only used if :attr:`fields` is not set. + The list of extra fields to include. This is only used if :attr:`fields` + is not set. """ def __init__(self, view=None, depth=None, stack=[], **kwargs): @@ -289,30 +298,35 @@ class ModelResource(FormResource): self.model = getattr(view, 'model', None) or self.model - def validate_request(self, data, files=None): """ Given some content as input return some cleaned, validated content. - Raises a :exc:`response.ErrorResponse` with status code 400 (Bad Request) on failure. - + Raises a :exc:`response.ErrorResponse` with status code 400 + (Bad Request) on failure. + Validation is standard form or model form validation, - with an additional constraint that no extra unknown fields may be supplied, - and that all fields specified by the fields class attribute must be supplied, - even if they are not validated by the form/model form. + with an additional constraint that no extra unknown fields may be + supplied, and that all fields specified by the fields class attribute + must be supplied, even if they are not validated by the Form/ModelForm. - On failure the ErrorResponse content is a dict which may contain :obj:`'errors'` and :obj:`'field-errors'` keys. - If the :obj:`'errors'` key exists it is a list of strings of non-field errors. - If the ''field-errors'` key exists it is a dict of {field name as string: list of errors as strings}. + On failure the ErrorResponse content is a dict which may contain + :obj:`'errors'` and :obj:`'field-errors'` keys. + If the :obj:`'errors'` key exists it is a list of strings of non-field + errors. + If the ''field-errors'` key exists it is a dict of + `{field name as string: list of errors as strings}`. """ - return self._validate(data, files, allowed_extra_fields=self._property_fields_set) - + return self._validate(data, files, + allowed_extra_fields=self._property_fields_set) def get_bound_form(self, data=None, files=None, method=None): """ Given some content return a ``Form`` instance bound to that content. - If the :attr:`form` class attribute has been explicitly set then that class will be used - to create the Form, otherwise the model will be used to create a ModelForm. + If the :attr:`form` class attribute has been explicitly set then that + class will be used + to create the Form, otherwise the model will be used to create a + ModelForm. """ form = self.get_form_class(method) @@ -339,18 +353,20 @@ class ModelResource(FormResource): return form() - def url(self, instance): """ - Attempts to reverse resolve the url of the given model *instance* for this resource. + Attempts to reverse resolve the url of the given model *instance* for + this resource. - Requires a ``View`` with :class:`mixins.InstanceMixin` to have been created for this resource. - - This method can be overridden if you need to set the resource url reversing explicitly. + Requires a ``View`` with :class:`mixins.InstanceMixin` to have been + created for this resource. + + This method can be overridden if you need to set the resource url + reversing explicitly. """ if not hasattr(self, 'view_callable'): - raise _SkipField + raise _SkipField # dis does teh magicks... urlconf = get_urlconf() @@ -363,7 +379,9 @@ class ModelResource(FormResource): # Note: defaults = tuple_item[2] for django >= 1.3 for result, params in possibility: - #instance_attrs = dict([ (param, getattr(instance, param)) for param in params if hasattr(instance, param) ]) + # instance_attrs = dict([ (param, getattr(instance, param)) + # for param in params + # if hasattr(instance, param) ]) instance_attrs = {} for param in params: @@ -381,7 +399,6 @@ class ModelResource(FormResource): pass raise _SkipField - @property def _model_fields_set(self): """ @@ -389,11 +406,11 @@ class ModelResource(FormResource): """ model_fields = set(field.name for field in self.model._meta.fields) - if fields: + if self.fields: return model_fields & set(as_tuple(self.fields)) return model_fields - set(as_tuple(self.exclude)) - + @property def _property_fields_set(self): """ diff --git a/djangorestframework/tests/mixins.py b/djangorestframework/tests/mixins.py index 65cf4a45a..0ccef5d3a 100644 --- a/djangorestframework/tests/mixins.py +++ b/djangorestframework/tests/mixins.py @@ -4,7 +4,7 @@ from django.utils import simplejson as json from djangorestframework import status from djangorestframework.compat import RequestFactory from django.contrib.auth.models import Group, User -from djangorestframework.mixins import CreateModelMixin, PaginatorMixin +from djangorestframework.mixins import PaginatorMixin from djangorestframework.resources import ModelResource from djangorestframework.response import Response from djangorestframework.tests.models import CustomUser @@ -25,7 +25,7 @@ class TestModelCreation(TestCase): form_data = {'name': 'foo'} request = self.req.post('/groups', data=form_data) - mixin = CreateModelMixin() + mixin = ModelMixin() mixin.resource = GroupResource mixin.CONTENT = form_data @@ -51,7 +51,7 @@ class TestModelCreation(TestCase): request = self.req.post('/groups', data=form_data) cleaned_data = dict(form_data) cleaned_data['groups'] = [group] - mixin = CreateModelMixin() + mixin = ModelMixin() mixin.resource = UserResource mixin.CONTENT = cleaned_data @@ -74,7 +74,7 @@ class TestModelCreation(TestCase): request = self.req.post('/groups', data=form_data) cleaned_data = dict(form_data) cleaned_data['groups'] = [] - mixin = CreateModelMixin() + mixin = ModelMixin() mixin.resource = UserResource mixin.CONTENT = cleaned_data @@ -105,7 +105,7 @@ class TestModelCreation(TestCase): request = self.req.post('/groups', data=form_data) cleaned_data = dict(form_data) cleaned_data['groups'] = [group, group2] - mixin = CreateModelMixin() + mixin = ModelMixin() mixin.resource = UserResource mixin.CONTENT = cleaned_data diff --git a/djangorestframework/views.py b/djangorestframework/views.py index 0a3594047..9045c1061 100644 --- a/djangorestframework/views.py +++ b/djangorestframework/views.py @@ -25,7 +25,6 @@ __all__ = ( ) - class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): """ Handles incoming requests and maps them to REST operations. @@ -59,7 +58,6 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): """ permissions = ( permissions.FullAnonAccess, ) - @classmethod def as_view(cls, **initkwargs): """ @@ -71,7 +69,6 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): view.cls_instance = cls(**initkwargs) return view - @property def allowed_methods(self): """ @@ -79,7 +76,6 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): """ return [method.upper() for method in self.http_method_names if hasattr(self, method)] - def http_method_not_allowed(self, request, *args, **kwargs): """ Return an HTTP 405 error if an operation is called which does not have a handler method. @@ -87,7 +83,6 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): raise ErrorResponse(status.HTTP_405_METHOD_NOT_ALLOWED, {'detail': 'Method \'%s\' not allowed on this resource.' % self.method}) - def initial(self, request, *args, **kargs): """ Hook for any code that needs to run prior to anything else. @@ -96,14 +91,12 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): """ pass - def add_header(self, field, value): """ Add *field* and *value* to the :attr:`headers` attribute of the :class:`View` class. """ self.headers[field] = value - # Note: session based authentication is explicitly CSRF validated, # all other authentication is CSRF exempt. @csrf_exempt @@ -183,26 +176,68 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): return response_obj -class ModelView(View): +class ModelView(ModelMixin, View): """ A RESTful view that maps to a model in the database. """ resource = resources.ModelResource -class InstanceModelView(InstanceMixin, ReadModelMixin, UpdateModelMixin, DeleteModelMixin, ModelView): + def _filter_kwargs(self, kwargs): + kwargs = kwargs.copy() + if BaseRenderer._FORMAT_QUERY_PARAM in kwargs: + del kwargs[BaseRenderer._FORMAT_QUERY_PARAM] + return kwargs + + +class InstanceModelView(ModelView): """ A view which provides default operations for read/update/delete against a model instance. """ _suffix = 'Instance' -class ListModelView(ListModelMixin, ModelView): + def get(self, request, *args, **kwargs): + instance = self.read(request, *args, **self._filter_kwargs(kwargs)) + + if not instance: + raise ErrorResponse(status.HTTP_404_NOT_FOUND) + + return instance + + def put(self, request, *args, **kwargs): + return self.update(request, *args, **self._filter_kwargs(kwargs)) + + def delete(self, request, *args, **kwargs): + instance = self.destroy(request, *args, **self._filter_kwargs(kwargs)) + + if not instance: + raise ErrorResponse(status.HTTP_404_NOT_FOUND, None, {}) + + return None + + +class ListModelView(ModelView): """ A view which provides default operations for list, against a model in the database. """ _suffix = 'List' -class ListOrCreateModelView(ListModelMixin, CreateModelMixin, ModelView): + def get(self, request, *args, **kwargs): + return self.list(request, *args, **self._filter_kwargs(kwargs)) + + +class ListOrCreateModelView(ModelView): """ A view which provides default operations for list and create, against a model in the database. """ _suffix = 'List' + + def get(self, request, *args, **kwargs): + return self.list(request, *args, **self._filter_kwargs(kwargs)) + + def post(self, request, *args, **kwargs): + instance = self.create(request, *args, **self._filter_kwargs(kwargs)) + + headers = {} + if hasattr(instance, 'get_absolute_url'): + headers['Location'] = self.resource(self).url(instance) + return Response(status.HTTP_201_CREATED, instance, headers) From db6df5ce61248d6bcf5e83d1c7fa5d495e3584ec Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 14 Dec 2011 14:45:59 +0000 Subject: [PATCH 2/5] Drop .as_tuple() --- djangorestframework/mixins.py | 4 ++-- djangorestframework/resources.py | 29 ++++++++++++++------------- djangorestframework/utils/__init__.py | 18 ----------------- 3 files changed, 17 insertions(+), 34 deletions(-) diff --git a/djangorestframework/mixins.py b/djangorestframework/mixins.py index 93a094e6f..350af5cb1 100644 --- a/djangorestframework/mixins.py +++ b/djangorestframework/mixins.py @@ -12,6 +12,7 @@ from djangorestframework import status from djangorestframework.resources import Resource, FormResource, ModelResource from djangorestframework.response import ErrorResponse from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX +from djangorestframework.utils import MSIE_USER_AGENT_REGEX from djangorestframework.utils.mediatypes import is_form_media_type, order_by_precedence from StringIO import StringIO @@ -177,8 +178,7 @@ class RequestMixin(object): return (None, None) parsers = as_tuple(self.parsers) - - for parser_cls in parsers: + for parser_cls in self.parsers: parser = parser_cls(self) if parser.can_handle_request(content_type): return parser.parse(stream) diff --git a/djangorestframework/resources.py b/djangorestframework/resources.py index dc11c31f8..42c294a5a 100644 --- a/djangorestframework/resources.py +++ b/djangorestframework/resources.py @@ -4,7 +4,6 @@ from django.db import models from djangorestframework.response import ErrorResponse from djangorestframework.serializer import Serializer, _SkipField -from djangorestframework.utils import as_tuple class BaseResource(Serializer): @@ -12,9 +11,9 @@ class BaseResource(Serializer): Base class for all Resource classes, which simply defines the interface they provide. """ - fields = None - include = None - exclude = None + fields = () + include = () + exclude = () def __init__(self, view=None, depth=None, stack=[], **kwargs): super(BaseResource, self).__init__(depth, stack, **kwargs) @@ -126,14 +125,16 @@ class FormResource(Resource): data = data and data or {} files = files and files or {} - seen_fields_set = set(data.keys()) - form_fields_set = set(bound_form.fields.keys()) - allowed_extra_fields_set = set(allowed_extra_fields) + seen_fields = set(data.keys()) + form_fields = set(bound_form.fields.keys()) + allowed_extra_fields = set(allowed_extra_fields) # In addition to regular validation we also ensure no additional fields # are being passed in... - unknown_fields = seen_fields_set - (form_fields_set | allowed_extra_fields_set) - unknown_fields = unknown_fields - set(('csrfmiddlewaretoken', '_accept', '_method')) # TODO: Ugh. + # TODO: Hardcoded ignore_fields here is pretty icky. + ignore_fields = set(('csrfmiddlewaretoken', '_accept', '_method')) + allowed_fields = form_fields | allowed_extra_fields | ignore_fields + unknown_fields = seen_fields - allowed_fields # Check using both regular validation, and our stricter no additional fields rule if bound_form.is_valid() and not unknown_fields: @@ -141,7 +142,7 @@ class FormResource(Resource): cleaned_data = bound_form.cleaned_data # Add in any extra fields to the cleaned content... - for key in (allowed_extra_fields_set & seen_fields_set) - set(cleaned_data.keys()): + for key in (allowed_extra_fields & seen_fields) - set(cleaned_data.keys()): cleaned_data[key] = data[key] return cleaned_data @@ -407,9 +408,9 @@ class ModelResource(FormResource): model_fields = set(field.name for field in self.model._meta.fields) if self.fields: - return model_fields & set(as_tuple(self.fields)) + return model_fields & set(self.fields) - return model_fields - set(as_tuple(self.exclude)) + return model_fields - set(self.exclude) @property def _property_fields_set(self): @@ -421,6 +422,6 @@ class ModelResource(FormResource): and not attr.startswith('_')) if self.fields: - return property_fields & set(as_tuple(self.fields)) + return property_fields & set(self.fields) - return property_fields.union(set(as_tuple(self.include))) - set(as_tuple(self.exclude)) + return property_fields.union(set(self.include)) - set(self.exclude) diff --git a/djangorestframework/utils/__init__.py b/djangorestframework/utils/__init__.py index 04baea78f..7d693cc44 100644 --- a/djangorestframework/utils/__init__.py +++ b/djangorestframework/utils/__init__.py @@ -18,25 +18,7 @@ from mediatypes import add_media_type_param, get_media_type_params, order_by_pre MSIE_USER_AGENT_REGEX = re.compile(r'^Mozilla/[0-9]+\.[0-9]+ \([^)]*; MSIE [0-9]+\.[0-9]+[a-z]?;[^)]*\)(?!.* Opera )') -def as_tuple(obj): - """ - Given an object which may be a list/tuple, another object, or None, - return that object in list form. - IE: - If the object is already a list/tuple just return it. - If the object is not None, return it in a list with a single element. - If the object is None return an empty list. - """ - if obj is None: - return () - elif isinstance(obj, list): - return tuple(obj) - elif isinstance(obj, tuple): - return obj - return (obj,) - - def url_resolves(url): """ Return True if the given URL is mapped to a view in the urlconf, False otherwise. From f84fd47825a7921c9d0dbdf7aba05f49df428c9b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 14 Dec 2011 14:48:22 +0000 Subject: [PATCH 3/5] Cleanups - line lengths, whitespace etc. --- djangorestframework/authentication.py | 71 +++--- djangorestframework/mixins.py | 302 +++++++++++++------------ djangorestframework/tests/renderers.py | 59 +++-- djangorestframework/utils/__init__.py | 18 +- djangorestframework/views.py | 30 +-- 5 files changed, 256 insertions(+), 224 deletions(-) diff --git a/djangorestframework/authentication.py b/djangorestframework/authentication.py index be22103e6..d342e7a0f 100644 --- a/djangorestframework/authentication.py +++ b/djangorestframework/authentication.py @@ -1,15 +1,17 @@ """ -The :mod:`authentication` module provides a set of pluggable authentication classes. +The :mod:`authentication` module provides a set of pluggable authentication +classes. -Authentication behavior is provided by mixing the :class:`mixins.AuthMixin` class into a :class:`View` class. +Authentication behavior is provided by mixing the :class:`mixins.AuthMixin` +class into a :class:`View` class. -The set of authentication methods which are used is then specified by setting the -:attr:`authentication` attribute on the :class:`View` class, and listing a set of :class:`authentication` classes. +The set of authentication methods which are used is then specified by setting +the :attr:`authentication` attribute on the :class:`View` class, and listing a +set of :class:`authentication` classes. """ from django.contrib.auth import authenticate from django.middleware.csrf import CsrfViewMiddleware -from djangorestframework.utils import as_tuple import base64 __all__ = ( @@ -26,23 +28,25 @@ class BaseAuthentication(object): def __init__(self, view): """ - :class:`Authentication` classes are always passed the current view on creation. + :class:`Authentication` classes are always passed the current view on + creation. """ self.view = view def authenticate(self, request): """ - Authenticate the :obj:`request` and return a :obj:`User` or :const:`None`. [*]_ - + Authenticate the :obj:`request` and return a :obj:`User` or + :const:`None`. [*]_ + .. [*] The authentication context *will* typically be a :obj:`User`, but it need not be. It can be any user-like object so long as the - permissions classes (see the :mod:`permissions` module) on the view can - handle the object and use it to determine if the request has the required - permissions or not. - - This can be an important distinction if you're implementing some token - based authentication mechanism, where the authentication context - may be more involved than simply mapping to a :obj:`User`. + permissions classes (see the :mod:`permissions` module) on the view + can handle the object and use it to determine if the request has + the required permissions or not. + + This can be an important distinction if you're implementing some + token based authentication mechanism, where the authentication + context may be more involved than simply mapping to a :obj:`User`. """ return None @@ -51,14 +55,20 @@ class BasicAuthentication(BaseAuthentication): """ Use HTTP Basic authentication. """ + def _authenticate_user(self, username, password): + user = authenticate(username=username, password=password) + if user and user.is_active: + return user + return None def authenticate(self, request): """ - Returns a :obj:`User` if a correct username and password have been supplied - using HTTP Basic authentication. Otherwise returns :const:`None`. + Returns a :obj:`User` if a correct username and password have been + supplied using HTTP Basic authentication. + Otherwise returns :const:`None`. """ - from django.utils.encoding import smart_unicode, DjangoUnicodeDecodeError - + from django.utils import encoding + if 'HTTP_AUTHORIZATION' in request.META: auth = request.META['HTTP_AUTHORIZATION'].split() if len(auth) == 2 and auth[0].lower() == "basic": @@ -66,17 +76,19 @@ class BasicAuthentication(BaseAuthentication): auth_parts = base64.b64decode(auth[1]).partition(':') except TypeError: return None - + try: - uname, passwd = smart_unicode(auth_parts[0]), smart_unicode(auth_parts[2]) - except DjangoUnicodeDecodeError: + username = encoding.smart_unicode(auth_parts[0]) + password = encoding.smart_unicode(auth_parts[2]) + except encoding.DjangoUnicodeDecodeError: return None - - user = authenticate(username=uname, password=passwd) - if user is not None and user.is_active: + + user = self._authenticate_user(username, password) + if user: return user + return None - + class UserLoggedInAuthentication(BaseAuthentication): """ @@ -85,10 +97,11 @@ class UserLoggedInAuthentication(BaseAuthentication): def authenticate(self, request): """ - Returns a :obj:`User` if the request session currently has a logged in user. - Otherwise returns :const:`None`. + Returns a :obj:`User` if the request session currently has a logged in + user. Otherwise returns :const:`None`. """ - # TODO: Switch this back to request.POST, and let FormParser/MultiPartParser deal with the consequences. + # TODO: Switch this back to request.POST, and let + # FormParser/MultiPartParser deal with the consequences. if getattr(request, 'user', None) and request.user.is_active: # If this is a POST request we enforce CSRF validation. if request.method.upper() == 'POST': diff --git a/djangorestframework/mixins.py b/djangorestframework/mixins.py index 350af5cb1..69e03f696 100644 --- a/djangorestframework/mixins.py +++ b/djangorestframework/mixins.py @@ -5,13 +5,12 @@ classes that can be added to a `View`. from django.contrib.auth.models import AnonymousUser from django.core.paginator import Paginator -from django.db.models.fields.related import ForeignKey from django.http import HttpResponse from djangorestframework import status +from djangorestframework.renderers import BaseRenderer from djangorestframework.resources import Resource, FormResource, ModelResource -from djangorestframework.response import ErrorResponse -from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX +from djangorestframework.response import Response, ErrorResponse from djangorestframework.utils import MSIE_USER_AGENT_REGEX from djangorestframework.utils.mediatypes import is_form_media_type, order_by_precedence @@ -55,14 +54,14 @@ class RequestMixin(object): """ Returns the HTTP method. - This should be used instead of just reading :const:`request.method`, as it allows the `method` - to be overridden by using a hidden `form` field on a form POST request. + This should be used instead of just reading :const:`request.method`, as + it allows the `method` to be overridden by using a hidden `form` field + on a form POST request. """ if not hasattr(self, '_method'): self._load_method_and_content_type() return self._method - @property def content_type(self): """ @@ -76,7 +75,6 @@ class RequestMixin(object): self._load_method_and_content_type() return self._content_type - @property def DATA(self): """ @@ -89,7 +87,6 @@ class RequestMixin(object): self._load_data_and_files() return self._data - @property def FILES(self): """ @@ -101,7 +98,6 @@ class RequestMixin(object): self._load_data_and_files() return self._files - def _load_data_and_files(self): """ Parse the request content into self.DATA and self.FILES. @@ -110,18 +106,19 @@ class RequestMixin(object): self._load_method_and_content_type() if not hasattr(self, '_data'): - (self._data, self._files) = self._parse(self._get_stream(), self._content_type) - + (self._data, self._files) = self._parse(self._get_stream(), + self._content_type) def _load_method_and_content_type(self): """ - Set the method and content_type, and then check if they've been overridden. + Set the method and content_type, and then check if they've been + overridden. """ self._method = self.request.method - self._content_type = self.request.META.get('HTTP_CONTENT_TYPE', self.request.META.get('CONTENT_TYPE', '')) + self._content_type = self.request.META.get('HTTP_CONTENT_TYPE', + self.request.META.get('CONTENT_TYPE', '')) self._perform_form_overloading() - def _get_stream(self): """ Returns an object that may be used to stream the request content. @@ -129,7 +126,8 @@ class RequestMixin(object): request = self.request try: - content_length = int(request.META.get('CONTENT_LENGTH', request.META.get('HTTP_CONTENT_LENGTH'))) + content_length = int(request.META.get('CONTENT_LENGTH', + request.META.get('HTTP_CONTENT_LENGTH'))) except (ValueError, TypeError): content_length = 0 @@ -138,18 +136,20 @@ class RequestMixin(object): if content_length == 0: return None elif hasattr(request, 'read'): - return request + return request return StringIO(request.raw_post_data) - def _perform_form_overloading(self): """ - If this is a form POST request, then we need to check if the method and content/content_type have been - overridden by setting them in hidden form fields or not. + If this is a form POST request, then we need to check if the method and + content/content_type have been overridden by setting them in hidden + form fields or not. """ # We only need to use form overloading on form POST requests. - if not self._USE_FORM_OVERLOADING or self._method != 'POST' or not is_form_media_type(self._content_type): + if (not self._USE_FORM_OVERLOADING + or self._method != 'POST' + or not is_form_media_type(self._content_type)): return # At this point we're committed to parsing the request as form data. @@ -167,26 +167,24 @@ class RequestMixin(object): stream = StringIO(self._data.pop(self._CONTENT_PARAM)[0]) (self._data, self._files) = self._parse(stream, self._content_type) - def _parse(self, stream, content_type): """ Parse the request content. - May raise a 415 ErrorResponse (Unsupported Media Type), or a 400 ErrorResponse (Bad Request). + May raise a 415 ErrorResponse (Unsupported Media Type), or a 400 + ErrorResponse (Bad Request). """ if stream is None or content_type is None: return (None, None) - parsers = as_tuple(self.parsers) for parser_cls in self.parsers: parser = parser_cls(self) if parser.can_handle_request(content_type): return parser.parse(stream) - raise ErrorResponse(status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, - {'error': 'Unsupported media type in request \'%s\'.' % - content_type}) - + error = {'error': + "Unsupported media type in request '%s'." % content_type} + raise ErrorResponse(status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, error) @property def _parsed_media_types(self): @@ -195,7 +193,6 @@ class RequestMixin(object): """ return [parser.media_type for parser in self.parsers] - @property def _default_parser(self): """ @@ -204,29 +201,34 @@ class RequestMixin(object): return self.parsers[0] - ########## ResponseMixin ########## + class ResponseMixin(object): """ Adds behavior for pluggable `Renderers` to a :class:`views.View` class. Default behavior is to use standard HTTP Accept header content negotiation. - Also supports overriding the content type by specifying an ``_accept=`` parameter in the URL. - Ignores Accept headers from Internet Explorer user agents and uses a sensible browser Accept header instead. + + Also supports overriding the content type by specifying an ``_accept=`` + parameter in the URL. + + Ignores Accept headers from Internet Explorer user agents and uses a + sensible browser Accept header instead. """ - _ACCEPT_QUERY_PARAM = '_accept' # Allow override of Accept header in URL query params + # Allow override of Accept header in URL query params + _ACCEPT_QUERY_PARAM = '_accept' _IGNORE_IE_ACCEPT_HEADER = True """ The set of response renderers that the view can handle. - Should be a tuple/list of classes as described in the :mod:`renderers` module. + Should be a tuple/list of classes as described in the :mod:`renderers` + module. """ renderers = () - # TODO: wrap this behavior around dispatch(), ensuring it works # out of the box with existing Django classes that use render_to_response. def render(self, response): @@ -253,13 +255,13 @@ class ResponseMixin(object): content = renderer.render() # Build the HTTP Response - resp = HttpResponse(content, mimetype=response.media_type, status=response.status) + resp = HttpResponse(content, mimetype=response.media_type, + status=response.status) for (key, val) in response.headers.items(): resp[key] = val return resp - def _determine_renderer(self, request): """ Determines the appropriate renderer for the output, given the client's @@ -283,10 +285,10 @@ class ResponseMixin(object): # instead. accept_list = ['text/html', '*/*'] - elif 'HTTP_USER_AGENT' in request.META: + elif 'HTTP_ACCEPT' in request.META: # Use standard HTTP Accept negotiation accept_list = [token.strip() for token in - request.META["HTTP_ACCEPT"].split(',')] + request.META['HTTP_ACCEPT'].split(',')] else: # No accept header specified @@ -295,7 +297,7 @@ class ResponseMixin(object): # Check the acceptable media types against each renderer, # attempting more specific media types first # NB. The inner loop here isn't as bad as it first looks :) - # Worst case is we're looping over len(accept_list) * len(self.renderers) + # Worst case is: len(accept_list) * len(self.renderers) renderers = [renderer_cls(self) for renderer_cls in self.renderers] for accepted_media_type_lst in order_by_precedence(accept_list): @@ -305,10 +307,9 @@ class ResponseMixin(object): return renderer, accepted_media_type # No acceptable renderers were found - raise ErrorResponse(status.HTTP_406_NOT_ACCEPTABLE, - {'detail': 'Could not satisfy the client\'s Accept header', - 'available_types': self._rendered_media_types}) - + error = {'detail': "Could not satisfy the client's Accept header", + 'available_types': self._rendered_media_types} + raise ErrorResponse(status.HTTP_406_NOT_ACCEPTABLE, error) @property def _rendered_media_types(self): @@ -336,39 +337,40 @@ class ResponseMixin(object): class AuthMixin(object): """ - Simple :class:`mixin` class to add authentication and permission checking to a :class:`View` class. + Simple :class:`mixin` class to add authentication and permission checking + to a :class:`View` class. """ """ The set of authentication types that this view can handle. - Should be a tuple/list of classes as described in the :mod:`authentication` module. + Should be a tuple/list of classes as described in the :mod:`authentication` + module. """ authentication = () """ The set of permissions that will be enforced on this view. - Should be a tuple/list of classes as described in the :mod:`permissions` module. + Should be a tuple/list of classes as described in the :mod:`permissions` + module. """ permissions = () - @property def user(self): """ - Returns the :obj:`user` for the current request, as determined by the set of - :class:`authentication` classes applied to the :class:`View`. + Returns the :obj:`user` for the current request, as determined by the + set of :class:`authentication` classes applied to the :class:`View`. """ if not hasattr(self, '_user'): self._user = self._authenticate() return self._user - def _authenticate(self): """ - Attempt to authenticate the request using each authentication class in turn. - Returns a ``User`` object, which may be ``AnonymousUser``. + Attempt to authenticate the request using each authentication class in + turn. Returns a ``User`` object, which may be ``AnonymousUser``. """ for authentication_cls in self.authentication: authentication = authentication_cls(self) @@ -377,7 +379,6 @@ class AuthMixin(object): return user return AnonymousUser() - # TODO: wrap this behavior around dispatch() def _check_permissions(self): """ @@ -397,10 +398,12 @@ class ResourceMixin(object): Should be a class as described in the :mod:`resources` module. - The :obj:`resource` is an object that maps a view onto it's representation on the server. + The :obj:`resource` is an object that maps a view onto it's representation + on the server. It provides validation on the content of incoming requests, - and filters the object representation into a serializable object for the response. + and filters the object representation into a serializable object for the + response. """ resource = None @@ -409,7 +412,8 @@ class ResourceMixin(object): """ Returns the cleaned, validated request content. - May raise an :class:`response.ErrorResponse` with status code 400 (Bad Request). + May raise an :class:`response.ErrorResponse` with status code 400 + (Bad Request). """ if not hasattr(self, '_content'): self._content = self.validate_request(self.DATA, self.FILES) @@ -420,7 +424,8 @@ class ResourceMixin(object): """ Returns the cleaned, validated query parameters. - May raise an :class:`response.ErrorResponse` with status code 400 (Bad Request). + May raise an :class:`response.ErrorResponse` with status code 400 + (Bad Request). """ return self.validate_request(self.request.GET) @@ -438,8 +443,10 @@ class ResourceMixin(object): def validate_request(self, data, files=None): """ - Given the request *data* and optional *files*, return the cleaned, validated content. - May raise an :class:`response.ErrorResponse` with status code 400 (Bad Request) on failure. + Given the request *data* and optional *files*, return the cleaned, + validated content. + May raise an :class:`response.ErrorResponse` with status code 400 + (Bad Request) on failure. """ return self._resource.validate_request(data, files) @@ -456,26 +463,28 @@ class ResourceMixin(object): return None - ########## + class InstanceMixin(object): """ - `Mixin` class that is used to identify a `View` class as being the canonical identifier - for the resources it is mapped to. + `Mixin` class that is used to identify a `View` class as being the + canonical identifier for the resources it is mapped to. """ @classmethod def as_view(cls, **initkwargs): """ - Store the callable object on the resource class that has been associated with this view. + Store the callable object on the resource class that has been + associated with this view. """ view = super(InstanceMixin, cls).as_view(**initkwargs) resource = getattr(cls(**initkwargs), 'resource', None) if resource: # We do a little dance when we store the view callable... - # we need to store it wrapped in a 1-tuple, so that inspect will treat it - # as a function when we later look it up (rather than turning it into a method). + # we need to store it wrapped in a 1-tuple, so that inspect will + # treat it as a function when we later look it up (rather than + # turning it into a method). # This makes sure our URL reversing works ok. resource.view_callable = (view,) return view @@ -483,6 +492,7 @@ class InstanceMixin(object): ########## Model Mixins ########## + class ModelMixin(object): def get_model(self): """ @@ -497,7 +507,7 @@ class ModelMixin(object): """ return getattr(self, 'queryset', getattr(self.resource, 'queryset', - self._get_model().objects.all())) + self.get_model().objects.all())) def get_ordering(self): """ @@ -507,73 +517,32 @@ class ModelMixin(object): getattr(self.resource, 'ordering', None)) + # Underlying instance API... + def get_instance(self, *args, **kwargs): """ Return a model instance or None. """ model = self.get_model() queryset = self.get_queryset() - kwargs = self._filter_kwargs(kwargs) try: - # If we have any positional args then assume the last - # represents the primary key. Otherwise assume the named kwargs - # uniquely identify the instance. - if args: - return queryset.get(pk=args[-1], **kwargs) - else: - return queryset.get(**kwargs) + return queryset.get(**kwargs) except model.DoesNotExist: return None - def read(self, request, *args, **kwargs): - instance = self.get_instance(*args, **kwargs) - return instance + def create_instance(self, *args, **kwargs): + model = self.get_model() - def update(self, request, *args, **kwargs): - """ - Return a model instance. - """ - instance = self.get_instance(*args, **kwargs) - - if instance: - for (key, val) in self.CONTENT.items(): - setattr(instance, key, val) - else: - instance = self.get_model()(**self.CONTENT) - - instance.save() - return instance - - def create(self, request, *args, **kwargs): - """ - Return a model instance. - """ - model = self._get_model() - - # Copy the dict to keep self.CONTENT intact - content = dict(self.CONTENT) m2m_data = {} - - for field in model._meta.fields: - if isinstance(field, ForeignKey) and field.name in kwargs: - # translate 'related_field' kwargs into 'related_field_id' - kwargs[field.name + '_id'] = kwargs[field.name] + for field in model._meta.many_to_many: + if field.name in kwargs: + m2m_data[field.name] = ( + field.m2m_reverse_field_name(), kwargs[field.name] + ) del kwargs[field.name] - for field in model._meta.many_to_many: - if field.name in content: - m2m_data[field.name] = ( - field.m2m_reverse_field_name(), content[field.name] - ) - del content[field.name] - - all_kw_args = dict(content.items() + kwargs.items()) - - if args: - instance = model(pk=args[-1], **all_kw_args) - else: - instance = model(**all_kw_args) + instance = model(**kwargs) instance.save() for fieldname in m2m_data: @@ -591,31 +560,81 @@ class ModelMixin(object): return instance - - def destroy(self, request, *args, **kwargs): - """ - Return a model instance or None. - """ - instance = self.get_instance(*args, **kwargs) - - if instance: - instance.delete() - + def update_instance(self, instance, *args, **kwargs): + for (key, val) in kwargs.items(): + setattr(instance, key, val) + instance.save() return instance + def delete_instance(self, instance, *args, **kwargs): + instance.delete() + return instance - def list(self, request, *args, **kwargs): - """ - Return a list of instances. - """ + def list_instances(self, *args, **kwargs): queryset = self.get_queryset() ordering = self.get_ordering() if ordering: - assert(hasattr(ordering, '__iter__')) - queryset = queryset.order_by(*args) + queryset = queryset.order_by(ordering) return queryset.filter(**kwargs) + # Request/Response layer... + + def _get_url_kwargs(self, kwargs): + format_arg = BaseRenderer._FORMAT_QUERY_PARAM + if format_arg in kwargs: + kwargs = kwargs.copy() + del kwargs[format_arg] + return kwargs + + def _get_content_kwargs(self, kwargs): + return dict(self._get_url_kwargs(kwargs).items() + + self.CONTENT.items()) + + def read(self, request, *args, **kwargs): + kwargs = self._get_url_kwargs(kwargs) + instance = self.get_instance(**kwargs) + + if instance is None: + raise ErrorResponse(status.HTTP_404_NOT_FOUND, None, {}) + + return instance + + def update(self, request, *args, **kwargs): + kwargs = self._get_url_kwargs(kwargs) + instance = self.get_instance(**kwargs) + + kwargs = self._get_content_kwargs(kwargs) + if instance: + instance = self.update_instance(instance, **kwargs) + else: + instance = self.create_instance(**kwargs) + + return instance + + def create(self, request, *args, **kwargs): + kwargs = self._get_content_kwargs(kwargs) + instance = self.create_instance(**kwargs) + + headers = {} + try: + headers['Location'] = self.resource(self).url(instance) + except: # TODO: _SkipField should not really happen. + pass + + return Response(status.HTTP_201_CREATED, instance, headers) + + def destroy(self, request, *args, **kwargs): + kwargs = self._get_url_kwargs(kwargs) + instance = self.delete_instance(**kwargs) + if not instance: + raise ErrorResponse(status.HTTP_404_NOT_FOUND, None, {}) + + return instance + + def list(self, request, *args, **kwargs): + return self.list_instances(**kwargs) + ########## Pagination Mixins ########## @@ -638,7 +657,7 @@ class PaginatorMixin(object): return self.limit def url_with_page_number(self, page_number): - """ Constructs a url used for getting the next/previous urls """ + """Constructs a url used for getting the next/previous urls.""" url = "%s?page=%d" % (self.request.path, page_number) limit = self.get_limit() @@ -648,21 +667,21 @@ class PaginatorMixin(object): return url def next(self, page): - """ Returns a url to the next page of results (if any) """ + """Returns a url to the next page of results. (If any exists.)""" if not page.has_next(): return None return self.url_with_page_number(page.next_page_number()) def previous(self, page): - """ Returns a url to the previous page of results (if any) """ + """Returns a url to the previous page of results. (If any exists.)""" if not page.has_previous(): return None return self.url_with_page_number(page.previous_page_number()) def serialize_page_info(self, page): - """ This is some useful information that is added to the response """ + """This is some useful information that is added to the response.""" return { 'next': self.next(page), 'page': page.number, @@ -676,14 +695,15 @@ class PaginatorMixin(object): """ Given the response content, paginate and then serialize. - The response is modified to include to useful data relating to the number - of objects, number of pages, next/previous urls etc. etc. + The response is modified to include to useful data relating to the + number of objects, number of pages, next/previous urls etc. etc. The serialised objects are put into `results` on this new, modified response """ - # We don't want to paginate responses for anything other than GET requests + # We don't want to paginate responses for anything other than GET + # requests if self.method.upper() != 'GET': return self._resource.filter_response(obj) diff --git a/djangorestframework/tests/renderers.py b/djangorestframework/tests/renderers.py index 997fd5103..023de54b9 100644 --- a/djangorestframework/tests/renderers.py +++ b/djangorestframework/tests/renderers.py @@ -1,15 +1,13 @@ from django.conf.urls.defaults import patterns, url -from django import http from django.test import TestCase from djangorestframework import status from djangorestframework.compat import View as DjangoView -from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer,\ - XMLRenderer +from djangorestframework.renderers import BaseRenderer, JSONRenderer, \ + YAMLRenderer, XMLRenderer from djangorestframework.parsers import JSONParser, YAMLParser from djangorestframework.mixins import ResponseMixin from djangorestframework.response import Response -from djangorestframework.utils.mediatypes import add_media_type_param from StringIO import StringIO import datetime @@ -21,27 +19,30 @@ DUMMYCONTENT = 'dummycontent' RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x + class RendererA(BaseRenderer): media_type = 'mock/renderera' - format="formata" + format = "formata" def render(self, obj=None, media_type=None): return RENDERER_A_SERIALIZER(obj) + class RendererB(BaseRenderer): media_type = 'mock/rendererb' - format="formatb" + format = "formatb" def render(self, obj=None, media_type=None): return RENDERER_B_SERIALIZER(obj) + class MockView(ResponseMixin, DjangoView): renderers = (RendererA, RendererB) def get(self, request, **kwargs): response = Response(DUMMYSTATUS, DUMMYCONTENT) return self.render(response) - + urlpatterns = patterns('', url(r'^.*\.(?P.+)$', MockView.as_view(renderers=[RendererA, RendererB])), @@ -92,7 +93,7 @@ class RendererIntegrationTests(TestCase): self.assertEquals(resp['Content-Type'], RendererB.media_type) self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEquals(resp.status_code, DUMMYSTATUS) - + def test_specified_renderer_serializes_content_on_accept_query(self): """The '_accept' query string should behave in the same way as the Accept header.""" resp = self.client.get('/?_accept=%s' % RendererB.media_type) @@ -148,12 +149,7 @@ class RendererIntegrationTests(TestCase): _flat_repr = '{"foo": ["bar", "baz"]}' -_indented_repr = """{ - "foo": [ - "bar", - "baz" - ] -}""" +_indented_repr = '{\n "foo": [\n "bar", \n "baz"\n ]\n}' class JSONRendererTests(TestCase): @@ -165,45 +161,44 @@ class JSONRendererTests(TestCase): """ Test basic JSON rendering. """ - obj = {'foo':['bar','baz']} + obj = {'foo': ['bar', 'baz']} renderer = JSONRenderer(None) content = renderer.render(obj, 'application/json') self.assertEquals(content, _flat_repr) def test_with_content_type_args(self): """ - Test JSON rendering with additional content type arguments supplied. + Test JSON rendering with additional content type arguments supplied. """ - obj = {'foo':['bar','baz']} + obj = {'foo': ['bar', 'baz']} renderer = JSONRenderer(None) - content = renderer.render(obj, 'application/json; indent=2') + content = renderer.render(obj, 'application/json; indent=4') self.assertEquals(content, _indented_repr) - + def test_render_and_parse(self): """ Test rendering and then parsing returns the original object. IE obj -> render -> parse -> obj. """ - obj = {'foo':['bar','baz']} + obj = {'foo': ['bar', 'baz']} renderer = JSONRenderer(None) parser = JSONParser(None) content = renderer.render(obj, 'application/json') (data, files) = parser.parse(StringIO(content)) - self.assertEquals(obj, data) - + self.assertEquals(obj, data) if YAMLRenderer: _yaml_repr = 'foo: [bar, baz]\n' - - + + class YAMLRendererTests(TestCase): """ Tests specific to the JSON Renderer """ - + def test_render(self): """ Test basic YAML rendering. @@ -212,24 +207,24 @@ if YAMLRenderer: renderer = YAMLRenderer(None) content = renderer.render(obj, 'application/yaml') self.assertEquals(content, _yaml_repr) - - + + def test_render_and_parse(self): """ Test rendering and then parsing returns the original object. IE obj -> render -> parse -> obj. """ obj = {'foo':['bar','baz']} - + renderer = YAMLRenderer(None) parser = YAMLParser(None) - + content = renderer.render(obj, 'application/yaml') (data, files) = parser.parse(StringIO(content)) - self.assertEquals(obj, data) + self.assertEquals(obj, data) + - class XMLRendererTestCase(TestCase): """ Tests specific to the XML Renderer @@ -289,4 +284,4 @@ class XMLRendererTestCase(TestCase): def assertXMLContains(self, xml, string): self.assertTrue(xml.startswith('\n')) self.assertTrue(xml.endswith('')) - self.assertTrue(string in xml, '%r not in %r' % (string, xml)) + self.assertTrue(string in xml, '%r not in %r' % (string, xml)) diff --git a/djangorestframework/utils/__init__.py b/djangorestframework/utils/__init__.py index 7d693cc44..86fa4a295 100644 --- a/djangorestframework/utils/__init__.py +++ b/djangorestframework/utils/__init__.py @@ -32,7 +32,7 @@ def url_resolves(url): # From http://www.koders.com/python/fidB6E125C586A6F49EAC38992CF3AFDAAE35651975.aspx?s=mdef:xml #class object_dict(dict): -# """object view of dict, you can +# """object view of dict, you can # >>> a = object_dict() # >>> a.fish = 'fish' # >>> a['fish'] @@ -85,8 +85,8 @@ class XML2Dict(object): old = node_tree[tag] if not isinstance(old, list): node_tree.pop(tag) - node_tree[tag] = [old] # multi times, so change old dict to a list - node_tree[tag].append(tree) # add the new one + node_tree[tag] = [old] # multi times, so change old dict to a list + node_tree[tag].append(tree) # add the new one return node_tree @@ -99,13 +99,13 @@ class XML2Dict(object): """ result = re.compile("\{(.*)\}(.*)").search(tag) if result: - value.namespace, tag = result.groups() + value.namespace, tag = result.groups() return (tag, value) def parse(self, file): """parse a xml file to a dict""" f = open(file, 'r') - return self.fromstring(f.read()) + return self.fromstring(f.read()) def fromstring(self, s): """parse a string""" @@ -132,16 +132,16 @@ class XMLRenderer(): xml.startElement(key, {}) self._to_xml(xml, value) xml.endElement(key) - + elif data is None: # Don't output any value - pass + pass else: xml.characters(smart_unicode(data)) def dict2xml(self, data): - stream = StringIO.StringIO() + stream = StringIO.StringIO() xml = SimplerXMLGenerator(stream, "utf-8") xml.startDocument() @@ -154,4 +154,4 @@ class XMLRenderer(): return stream.getvalue() def dict2xml(input): - return XMLRenderer().dict2xml(input) \ No newline at end of file + return XMLRenderer().dict2xml(input) diff --git a/djangorestframework/views.py b/djangorestframework/views.py index 4c472e2dc..d449636db 100644 --- a/djangorestframework/views.py +++ b/djangorestframework/views.py @@ -50,13 +50,13 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): """ List of all authenticating methods to attempt. """ - authentication = ( authentication.UserLoggedInAuthentication, - authentication.BasicAuthentication ) + authentication = (authentication.UserLoggedInAuthentication, + authentication.BasicAuthentication) """ List of all permissions that must be checked. """ - permissions = ( permissions.FullAnonAccess, ) + permissions = (permissions.FullAnonAccess,) @classmethod def as_view(cls, **initkwargs): @@ -86,8 +86,8 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): def initial(self, request, *args, **kargs): """ Hook for any code that needs to run prior to anything else. - Required if you want to do things like set `request.upload_handlers` before - the authentication and dispatch handling is run. + Required if you want to do things like set `request.upload_handlers` + before the authentication and dispatch handling is run. """ pass @@ -136,11 +136,13 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): response = Response(status.HTTP_204_NO_CONTENT) if request.method == 'OPTIONS': - # do not filter the response for HTTP OPTIONS, else the response fields are lost, + # do not filter the response for HTTP OPTIONS, + # else the response fields are lost, # as they do not correspond with model fields response.cleaned_content = response.raw_content else: - # Pre-serialize filtering (eg filter complex objects into natively serializable types) + # Pre-serialize filtering (eg filter complex objects into + # natively serializable types) response.cleaned_content = self.filter_response(response.raw_content) except ErrorResponse, exc: @@ -148,8 +150,8 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): # Always add these headers. # - # TODO - this isn't actually the correct way to set the vary header, - # also it's currently sub-optimal for HTTP caching - need to sort that out. + # TODO - this isn't really the correct way to set the Vary header, + # also it's currently sub-optimal for HTTP caching. response.headers['Allow'] = ', '.join(self.allowed_methods) response.headers['Vary'] = 'Authenticate, Accept' @@ -161,7 +163,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): return self.render(response) def options(self, request, *args, **kwargs): - response_obj = { + ret = { 'name': get_name(self), 'description': get_description(self), 'renders': self._rendered_media_types, @@ -172,8 +174,8 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): field_name_types = {} for name, field in form.fields.iteritems(): field_name_types[name] = field.__class__.__name__ - response_obj['fields'] = field_name_types - return response_obj + ret['fields'] = field_name_types + return ret class ModelView(ModelMixin, View): @@ -191,7 +193,9 @@ class ModelView(ModelMixin, View): class InstanceModelView(ModelView): """ - A view which provides default operations for read/update/delete against a model instance. + A view which provides default operations for read/update/delete against a + model instance. This view is also treated as the Canonical identifier + of the instances. """ _suffix = 'Instance' From a32eb50bae8b16803eec108ec3a9c48ecf304d7a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 14 Dec 2011 14:48:51 +0000 Subject: [PATCH 4/5] De-couple request/response layer CRUD, from underlying resource CRUD. --- djangorestframework/tests/mixins.py | 14 ++++---- djangorestframework/views.py | 50 +++++++---------------------- 2 files changed, 18 insertions(+), 46 deletions(-) diff --git a/djangorestframework/tests/mixins.py b/djangorestframework/tests/mixins.py index 0ccef5d3a..3b6512185 100644 --- a/djangorestframework/tests/mixins.py +++ b/djangorestframework/tests/mixins.py @@ -4,7 +4,7 @@ from django.utils import simplejson as json from djangorestframework import status from djangorestframework.compat import RequestFactory from django.contrib.auth.models import Group, User -from djangorestframework.mixins import PaginatorMixin +from djangorestframework.mixins import PaginatorMixin, ModelMixin from djangorestframework.resources import ModelResource from djangorestframework.response import Response from djangorestframework.tests.models import CustomUser @@ -29,7 +29,7 @@ class TestModelCreation(TestCase): mixin.resource = GroupResource mixin.CONTENT = form_data - response = mixin.post(request) + response = mixin.create(request) self.assertEquals(1, Group.objects.count()) self.assertEquals('foo', response.cleaned_content.name) @@ -55,7 +55,7 @@ class TestModelCreation(TestCase): mixin.resource = UserResource mixin.CONTENT = cleaned_data - response = mixin.post(request) + response = mixin.create(request) self.assertEquals(1, User.objects.count()) self.assertEquals(1, response.cleaned_content.groups.count()) self.assertEquals('foo', response.cleaned_content.groups.all()[0].name) @@ -78,7 +78,7 @@ class TestModelCreation(TestCase): mixin.resource = UserResource mixin.CONTENT = cleaned_data - response = mixin.post(request) + response = mixin.create(request) self.assertEquals(1, CustomUser.objects.count()) self.assertEquals(0, response.cleaned_content.groups.count()) @@ -89,11 +89,11 @@ class TestModelCreation(TestCase): request = self.req.post('/groups', data=form_data) cleaned_data = dict(form_data) cleaned_data['groups'] = [group] - mixin = CreateModelMixin() + mixin = ModelMixin() mixin.resource = UserResource mixin.CONTENT = cleaned_data - response = mixin.post(request) + response = mixin.create(request) self.assertEquals(2, CustomUser.objects.count()) self.assertEquals(1, response.cleaned_content.groups.count()) self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name) @@ -109,7 +109,7 @@ class TestModelCreation(TestCase): mixin.resource = UserResource mixin.CONTENT = cleaned_data - response = mixin.post(request) + response = mixin.create(request) self.assertEquals(3, CustomUser.objects.count()) self.assertEquals(2, response.cleaned_content.groups.count()) self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name) diff --git a/djangorestframework/views.py b/djangorestframework/views.py index d449636db..18911a52c 100644 --- a/djangorestframework/views.py +++ b/djangorestframework/views.py @@ -184,14 +184,8 @@ class ModelView(ModelMixin, View): """ resource = resources.ModelResource - def _filter_kwargs(self, kwargs): - kwargs = kwargs.copy() - if BaseRenderer._FORMAT_QUERY_PARAM in kwargs: - del kwargs[BaseRenderer._FORMAT_QUERY_PARAM] - return kwargs - -class InstanceModelView(ModelView): +class InstanceModelView(InstanceMixin, ModelView): """ A view which provides default operations for read/update/delete against a model instance. This view is also treated as the Canonical identifier @@ -199,49 +193,27 @@ class InstanceModelView(ModelView): """ _suffix = 'Instance' - def get(self, request, *args, **kwargs): - instance = self.read(request, *args, **self._filter_kwargs(kwargs)) - - if not instance: - raise ErrorResponse(status.HTTP_404_NOT_FOUND) - - return instance - - def put(self, request, *args, **kwargs): - return self.update(request, *args, **self._filter_kwargs(kwargs)) - - def delete(self, request, *args, **kwargs): - instance = self.destroy(request, *args, **self._filter_kwargs(kwargs)) - - if not instance: - raise ErrorResponse(status.HTTP_404_NOT_FOUND, None, {}) - - return None + get = ModelMixin.read + put = ModelMixin.update + delete = ModelMixin.destroy class ListModelView(ModelView): """ - A view which provides default operations for list, against a model in the database. + A view which provides default operations for list, against a model in the + database. """ _suffix = 'List' - def get(self, request, *args, **kwargs): - return self.list(request, *args, **self._filter_kwargs(kwargs)) + get = ModelMixin.list class ListOrCreateModelView(ModelView): """ - A view which provides default operations for list and create, against a model in the database. + A view which provides default operations for list and create, against a + model in the database. """ _suffix = 'List' - def get(self, request, *args, **kwargs): - return self.list(request, *args, **self._filter_kwargs(kwargs)) - - def post(self, request, *args, **kwargs): - instance = self.create(request, *args, **self._filter_kwargs(kwargs)) - - headers = {} - if hasattr(instance, 'get_absolute_url'): - headers['Location'] = self.resource(self).url(instance) - return Response(status.HTTP_201_CREATED, instance, headers) + get = ModelMixin.list + post = ModelMixin.create From 9dbe8b646eb8f599d94eff9a00f6f06ddf1799f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Piquemal?= Date: Tue, 3 Jan 2012 11:10:14 +0200 Subject: [PATCH 5/5] ran the tests and made some corrections to the merging --- djangorestframework/authentication.py | 6 +++--- djangorestframework/resources.py | 4 +--- djangorestframework/tests/renderers.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/djangorestframework/authentication.py b/djangorestframework/authentication.py index 656c58e4b..48f898dd5 100644 --- a/djangorestframework/authentication.py +++ b/djangorestframework/authentication.py @@ -79,9 +79,9 @@ class BasicAuthentication(BaseAuthentication): return None try: - username = encoding.smart_unicode(auth_parts[0]) - password = encoding.smart_unicode(auth_parts[2]) - except encoding.DjangoUnicodeDecodeError: + username = smart_unicode(auth_parts[0]) + password = smart_unicode(auth_parts[2]) + except DjangoUnicodeDecodeError: return None user = authenticate(username=username, password=password) diff --git a/djangorestframework/resources.py b/djangorestframework/resources.py index 30d75eef9..66ab0a855 100644 --- a/djangorestframework/resources.py +++ b/djangorestframework/resources.py @@ -126,8 +126,6 @@ class FormResource(Resource): data = data and data or {} files = files and files or {} - # In addition to regular validation we also ensure no additional fields - # are being passed in... seen_fields_set = set(data.keys()) form_fields_set = set(bound_form.fields.keys()) allowed_extra_fields_set = set(allowed_extra_fields) @@ -142,7 +140,7 @@ class FormResource(Resource): cleaned_data = bound_form.cleaned_data # Add in any extra fields to the cleaned content... - for key in (allowed_extra_fields & seen_fields) - set(cleaned_data.keys()): + for key in (allowed_extra_fields_set & seen_fields_set) - set(cleaned_data.keys()): cleaned_data[key] = data[key] return cleaned_data diff --git a/djangorestframework/tests/renderers.py b/djangorestframework/tests/renderers.py index b683e27d3..c3dfb98b6 100644 --- a/djangorestframework/tests/renderers.py +++ b/djangorestframework/tests/renderers.py @@ -157,7 +157,7 @@ class RendererIntegrationTests(TestCase): _flat_repr = '{"foo": ["bar", "baz"]}' -_indented_repr = '{\n "foo": [\n "bar", \n "baz"\n ]\n}' +_indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}' class JSONRendererTests(TestCase):