Refactor to use self.CONTENT to access request body. Get file upload working

This commit is contained in:
tom christie tom@tomchristie.com 2011-04-02 16:32:37 +01:00
parent 8845b281fe
commit 4687db680c
16 changed files with 540 additions and 372 deletions

View File

@ -21,3 +21,4 @@ MANIFEST
.cache .cache
.coverage .coverage
.tox .tox
.DS_Store

View File

@ -1,57 +0,0 @@
"""Mixin classes that provide a determine_content(request) method to return the content type and content of a request.
We use this more generic behaviour to allow for overloaded content in POST forms.
"""
class ContentMixin(object):
"""Base class for all ContentMixin classes, which simply defines the interface they provide."""
def determine_content(self, request):
"""If the request contains content return a tuple of (content_type, content) otherwise return None.
Note that content_type may be None if it is unset.
Must be overridden to be implemented."""
raise NotImplementedError()
class StandardContentMixin(ContentMixin):
"""Standard HTTP request content behaviour.
See RFC 2616 sec 4.3 - http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3"""
def determine_content(self, request):
"""If the request contains content return a tuple of (content_type, content) otherwise return None.
Note that content_type may be None if it is unset."""
if not request.META.get('CONTENT_LENGTH', None) and not request.META.get('TRANSFER_ENCODING', None):
return None
return (request.META.get('CONTENT_TYPE', None), request.raw_post_data)
class OverloadedContentMixin(ContentMixin):
"""HTTP request content behaviour that also allows arbitrary content to be tunneled in form data."""
"""The name to use for the content override field in the POST form.
Set this to *None* to desactivate content overloading."""
CONTENT_PARAM = '_content'
"""The name to use for the content-type override field in the POST form.
Taken into account only if content overloading is activated."""
CONTENTTYPE_PARAM = '_contenttype'
def determine_content(self, request):
"""If the request contains content, returns a tuple of (content_type, content) otherwise returns None.
Note that content_type may be None if it is unset."""
if not request.META.get('CONTENT_LENGTH', None) and not request.META.get('TRANSFER_ENCODING', None):
return None
content_type = request.META.get('CONTENT_TYPE', None)
if (request.method == 'POST' and self.CONTENT_PARAM and
request.POST.get(self.CONTENT_PARAM, None) is not None):
# Set content type if form contains a non-empty CONTENTTYPE_PARAM field
content_type = None
if self.CONTENTTYPE_PARAM and request.POST.get(self.CONTENTTYPE_PARAM, None):
content_type = request.POST.get(self.CONTENTTYPE_PARAM, None)
request.META['CONTENT_TYPE'] = content_type # TODO : VERY BAD, avoid modifying original request.
return (content_type, request.POST[self.CONTENT_PARAM])
else:
return (content_type, request.raw_post_data)

View File

@ -13,7 +13,6 @@ from djangorestframework.validators import FormValidatorMixin
from djangorestframework.utils import dict2xml, url_resolves from djangorestframework.utils import dict2xml, url_resolves
from djangorestframework.markdownwrapper import apply_markdown from djangorestframework.markdownwrapper import apply_markdown
from djangorestframework.breadcrumbs import get_breadcrumbs from djangorestframework.breadcrumbs import get_breadcrumbs
from djangorestframework.content import OverloadedContentMixin
from djangorestframework.description import get_name, get_description from djangorestframework.description import get_name, get_description
from djangorestframework import status from djangorestframework import status
@ -254,7 +253,7 @@ class DocumentingTemplateEmitter(BaseEmitter):
# If we're not using content overloading there's no point in supplying a generic form, # If we're not using content overloading there's no point in supplying a generic form,
# as the resource won't treat the form's value as the content of the request. # as the resource won't treat the form's value as the content of the request.
if not isinstance(resource, OverloadedContentMixin): if not getattr(resource, 'USE_FORM_OVERLOADING', False):
return None return None
# NB. http://jacobian.org/writing/dynamic-form-generation/ # NB. http://jacobian.org/writing/dynamic-form-generation/

View File

@ -0,0 +1,78 @@
"""
Handling of media types, as found in HTTP Content-Type and Accept headers.
See http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7
"""
from django.http.multipartparser import parse_header
class MediaType(object):
def __init__(self, media_type_str):
self.orig = media_type_str
self.media_type, self.params = parse_header(media_type_str)
self.main_type, sep, self.sub_type = self.media_type.partition('/')
def match(self, other):
"""Return true if this MediaType satisfies the constraint of the given MediaType."""
for key in other.params.keys():
if key != 'q' and other.params[key] != self.params.get(key, None):
return False
if other.sub_type != '*' and other.sub_type != self.sub_type:
return False
if other.main_type != '*' and other.main_type != self.main_type:
return False
return True
def precedence(self):
"""
Return a precedence level for the media type given how specific it is.
"""
if self.main_type == '*':
return 1
elif self.sub_type == '*':
return 2
elif not self.params or self.params.keys() == ['q']:
return 3
return 4
def quality(self):
"""
Return a quality level for the media type.
"""
try:
return Decimal(self.params.get('q', '1.0'))
except:
return Decimal(0)
def score(self):
"""
Return an overall score for a given media type given it's quality and precedence.
"""
# NB. quality values should only have up to 3 decimal points
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.9
return self.quality * 10000 + self.precedence
def is_form(self):
"""
Return True if the MediaType is a valid form media type as defined by the HTML4 spec.
(NB. HTML5 also adds text/plain to the list of valid form media types, but we don't support this here)
"""
return self.media_type == 'application/x-www-form-urlencoded' or \
self.media_type == 'multipart/form-data'
def as_tuple(self):
return (self.main_type, self.sub_type, self.params)
def __repr__(self):
return "<MediaType %s>" % (self.as_tuple(),)
def __str__(self):
return unicode(self).encode('utf-8')
def __unicode__(self):
return self.orig

View File

@ -1,35 +0,0 @@
"""Mixin classes that provide a determine_method(request) function to determine the HTTP
method that a given request should be treated as. We use this more generic behaviour to
allow for overloaded methods in POST forms.
See Richardson & Ruby's RESTful Web Services for justification.
"""
class MethodMixin(object):
"""Base class for all MethodMixin classes, which simply defines the interface they provide."""
def determine_method(self, request):
"""Simply return GET, POST etc... as appropriate."""
raise NotImplementedError()
class StandardMethodMixin(MethodMixin):
"""Provide for standard HTTP behaviour, with no overloaded POST."""
def determine_method(self, request):
"""Simply return GET, POST etc... as appropriate."""
return request.method.upper()
class OverloadedPOSTMethodMixin(MethodMixin):
"""Provide for overloaded POST behaviour."""
"""The name to use for the method override field in the POST form."""
METHOD_PARAM = '_method'
def determine_method(self, request):
"""Simply return GET, POST etc... as appropriate, allowing for POST overloading
by setting a form field with the requested method name."""
method = request.method.upper()
if method == 'POST' and self.METHOD_PARAM and request.POST.has_key(self.METHOD_PARAM):
method = request.POST[self.METHOD_PARAM].upper()
return method

View File

@ -1,9 +1,18 @@
from StringIO import StringIO """Django supports parsing the content of an HTTP request, but only for form POST requests.
That behaviour is sufficient for dealing with standard HTML forms, but it doesn't map well
to general HTTP requests.
We need a method to be able to:
1) Determine the parsed content on a request for methods other than POST (eg typically also PUT)
2) Determine the parsed content on a request for media types other than application/x-www-form-urlencoded
and multipart/form-data. (eg also handle multipart/json)
"""
from django.http.multipartparser import MultiPartParser as DjangoMPParser from django.http.multipartparser import MultiPartParser as DjangoMPParser
from djangorestframework.response import ResponseException from djangorestframework.response import ResponseException
from djangorestframework import status from djangorestframework import status
from djangorestframework.utils import as_tuple
from djangorestframework.mediatypes import MediaType
try: try:
import json import json
@ -18,22 +27,27 @@ except ImportError:
class ParserMixin(object): class ParserMixin(object):
parsers = () parsers = ()
def parse(self, content_type, content): def parse(self, stream, content_type):
# See RFC 2616 sec 3 - http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7 """
split = content_type.split(';', 1) Parse the request content.
if len(split) > 1:
content_type = split[0]
content_type = content_type.strip()
media_type_to_parser = dict([(parser.media_type, parser) for parser in self.parsers]) May raise a 415 ResponseException (Unsupported Media Type),
or a 400 ResponseException (Bad Request).
"""
parsers = as_tuple(self.parsers)
try: parser = None
parser = media_type_to_parser[content_type] for parser_cls in parsers:
except KeyError: if parser_cls.handles(content_type):
parser = parser_cls(self)
break
if parser is None:
raise ResponseException(status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, raise ResponseException(status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
{'error': 'Unsupported media type in request \'%s\'.' % content_type}) {'error': 'Unsupported media type in request \'%s\'.' %
content_type.media_type})
return parser(self).parse(content)
return parser.parse(stream)
@property @property
def parsed_media_types(self): def parsed_media_types(self):
@ -48,36 +62,41 @@ class ParserMixin(object):
class BaseParser(object): class BaseParser(object):
"""All parsers should extend BaseParser, specifing a media_type attribute, """All parsers should extend BaseParser, specifying a media_type attribute,
and overriding the parse() method.""" and overriding the parse() method."""
media_type = None media_type = None
def __init__(self, resource): def __init__(self, view):
"""Initialise the parser with the Resource instance as state, """
in case the parser needs to access any metadata on the Resource object.""" Initialise the parser with the View instance as state,
self.resource = resource in case the parser needs to access any metadata on the View object.
"""
self.view = view
def parse(self, input): @classmethod
"""Given some serialized input, return the deserialized output. def handles(self, media_type):
The input will be the raw request content body. The return value may be of """
any type, but for many parsers/inputs it might typically be a dict.""" Returns `True` if this parser is able to deal with the given MediaType.
return input """
return media_type.match(self.media_type)
def parse(self, stream):
"""Given a stream to read from, return the deserialized output.
The return value may be of any type, but for many parsers it might typically be a dict-like object."""
raise NotImplementedError("BaseParser.parse() Must be overridden to be implemented.")
class JSONParser(BaseParser): class JSONParser(BaseParser):
media_type = 'application/json' media_type = MediaType('application/json')
def parse(self, input): def parse(self, stream):
try: try:
return json.loads(input) return json.load(stream)
except ValueError, exc: except ValueError, exc:
raise ResponseException(status.HTTP_400_BAD_REQUEST, {'detail': 'JSON parse error - %s' % str(exc)}) raise ResponseException(status.HTTP_400_BAD_REQUEST, {'detail': 'JSON parse error - %s' % str(exc)})
class XMLParser(BaseParser):
media_type = 'application/xml'
class DataFlatener(object): class DataFlatener(object):
"""Utility object for flatening dictionaries of lists. Useful for "urlencoded" decoded data.""" """Utility object for flatening dictionaries of lists. Useful for "urlencoded" decoded data."""
@ -102,6 +121,7 @@ class DataFlatener(object):
*val_list* which is the received value for parameter *key* can be used to guess the answer.""" *val_list* which is the received value for parameter *key* can be used to guess the answer."""
return False return False
class FormParser(BaseParser, DataFlatener): class FormParser(BaseParser, DataFlatener):
"""The default parser for form data. """The default parser for form data.
Return a dict containing a single value for each non-reserved parameter. Return a dict containing a single value for each non-reserved parameter.
@ -109,16 +129,17 @@ class FormParser(BaseParser, DataFlatener):
In order to handle select multiple (and having possibly more than a single value for each parameter), In order to handle select multiple (and having possibly more than a single value for each parameter),
you can customize the output by subclassing the method 'is_a_list'.""" you can customize the output by subclassing the method 'is_a_list'."""
media_type = 'application/x-www-form-urlencoded' media_type = MediaType('application/x-www-form-urlencoded')
"""The value of the parameter when the select multiple is empty. """The value of the parameter when the select multiple is empty.
Browsers are usually stripping the select multiple that have no option selected from the parameters sent. Browsers are usually stripping the select multiple that have no option selected from the parameters sent.
A common hack to avoid this is to send the parameter with a value specifying that the list is empty. A common hack to avoid this is to send the parameter with a value specifying that the list is empty.
This value will always be stripped before the data is returned.""" This value will always be stripped before the data is returned."""
EMPTY_VALUE = '_empty' EMPTY_VALUE = '_empty'
RESERVED_FORM_PARAMS = ('csrfmiddlewaretoken',)
def parse(self, input): def parse(self, stream):
data = parse_qs(input, keep_blank_values=True) data = parse_qs(stream.read(), keep_blank_values=True)
# removing EMPTY_VALUEs from the lists and flatening the data # removing EMPTY_VALUEs from the lists and flatening the data
for key, val_list in data.items(): for key, val_list in data.items():
@ -127,8 +148,9 @@ class FormParser(BaseParser, DataFlatener):
# Strip any parameters that we are treating as reserved # Strip any parameters that we are treating as reserved
for key in data.keys(): for key in data.keys():
if key in self.resource.RESERVED_FORM_PARAMS: if key in self.RESERVED_FORM_PARAMS:
data.pop(key) data.pop(key)
return data return data
def remove_empty_val(self, val_list): def remove_empty_val(self, val_list):
@ -141,27 +163,28 @@ class FormParser(BaseParser, DataFlatener):
else: else:
val_list.pop(ind) val_list.pop(ind)
# TODO: Allow parsers to specify multiple media_types
class MultipartData(dict):
def __init__(self, data, files):
dict.__init__(self, data)
self.FILES = files
class MultipartParser(BaseParser, DataFlatener): class MultipartParser(BaseParser, DataFlatener):
media_type = 'multipart/form-data' media_type = MediaType('multipart/form-data')
RESERVED_FORM_PARAMS = ('csrfmiddlewaretoken',)
def parse(self, input): def parse(self, stream):
upload_handlers = self.view.request._get_upload_handlers()
request = self.resource.request django_mpp = DjangoMPParser(self.view.request.META, stream, upload_handlers)
#TODO : that's pretty dumb : files are loaded with
#upload_handlers, but as we read the request body completely (input),
#then it kind of misses the point. Why not input as a stream ?
upload_handlers = request._get_upload_handlers()
django_mpp = DjangoMPParser(request.META, StringIO(input), upload_handlers)
data, files = django_mpp.parse() data, files = django_mpp.parse()
# Flatening data, files and combining them # Flatening data, files and combining them
data = self.flatten_data(dict(data.iterlists())) data = self.flatten_data(dict(data.iterlists()))
files = self.flatten_data(dict(files.iterlists())) files = self.flatten_data(dict(files.iterlists()))
data.update(files)
# Strip any parameters that we are treating as reserved # Strip any parameters that we are treating as reserved
for key in data.keys(): for key in data.keys():
if key in self.resource.RESERVED_FORM_PARAMS: if key in self.RESERVED_FORM_PARAMS:
data.pop(key) data.pop(key)
return data
return MultipartData(data, files)

View File

@ -0,0 +1,128 @@
from djangorestframework.mediatypes import MediaType
#from djangorestframework.requestparsing import parse, load_parser
from StringIO import StringIO
class RequestMixin(object):
"""Delegate class that supplements an HttpRequest object with additional behaviour."""
USE_FORM_OVERLOADING = True
METHOD_PARAM = "_method"
CONTENTTYPE_PARAM = "_content_type"
CONTENT_PARAM = "_content"
def _get_method(self):
"""
Returns the HTTP method for the current view.
"""
if not hasattr(self, '_method'):
self._method = self.request.method
return self._method
def _set_method(self, method):
"""
Set the method for the current view.
"""
self._method = method
def _get_content_type(self):
"""
Returns a MediaType object, representing the request's content type header.
"""
if not hasattr(self, '_content_type'):
content_type = self.request.META.get('HTTP_CONTENT_TYPE', self.request.META.get('CONTENT_TYPE', ''))
self._content_type = MediaType(content_type)
return self._content_type
def _set_content_type(self, content_type):
"""
Set the content type. Should be a MediaType object.
"""
self._content_type = content_type
def _get_accept(self):
"""
Returns a list of MediaType objects, representing the request's accept header.
"""
if not hasattr(self, '_accept'):
accept = self.request.META.get('HTTP_ACCEPT', '*/*')
self._accept = [MediaType(elem) for elem in accept.split(',')]
return self._accept
def _set_accept(self):
"""
Set the acceptable media types. Should be a list of MediaType objects.
"""
self._accept = accept
def _get_stream(self):
"""
Returns an object that may be used to stream the request content.
"""
if not hasattr(self, '_stream'):
if hasattr(self.request, 'read'):
self._stream = self.request
else:
self._stream = StringIO(self.request.raw_post_data)
return self._stream
def _set_stream(self, stream):
"""
Set the stream representing the request body.
"""
self._stream = stream
def _get_raw_content(self):
"""
Returns the parsed content of the request
"""
if not hasattr(self, '_raw_content'):
self._raw_content = self.parse(self.stream, self.content_type)
return self._raw_content
def _get_content(self):
"""
Returns the parsed and validated content of the request
"""
if not hasattr(self, '_content'):
self._content = self.validate(self.RAW_CONTENT)
return self._content
def perform_form_overloading(self):
"""
Check the request to see if it is using form POST '_method'/'_content'/'_content_type' overrides.
If it is then alter self.method, self.content_type, self.CONTENT to reflect that rather than simply
delegating them to the original request.
"""
if not self.USE_FORM_OVERLOADING or self.method != 'POST' or not self.content_type.is_form():
return
content = self.RAW_CONTENT
if self.METHOD_PARAM in content:
self.method = content[self.METHOD_PARAM].upper()
del self._raw_content[self.METHOD_PARAM]
if self.CONTENT_PARAM in content and self.CONTENTTYPE_PARAM in content:
self._content_type = MediaType(content[self.CONTENTTYPE_PARAM])
self._stream = StringIO(content[self.CONTENT_PARAM])
del(self._raw_content)
method = property(_get_method, _set_method)
content_type = property(_get_content_type, _set_content_type)
accept = property(_get_accept, _set_accept)
stream = property(_get_stream, _set_stream)
RAW_CONTENT = property(_get_raw_content)
CONTENT = property(_get_content)

View File

@ -6,47 +6,42 @@ from djangorestframework.emitters import EmitterMixin
from djangorestframework.parsers import ParserMixin from djangorestframework.parsers import ParserMixin
from djangorestframework.authenticators import AuthenticatorMixin from djangorestframework.authenticators import AuthenticatorMixin
from djangorestframework.validators import FormValidatorMixin from djangorestframework.validators import FormValidatorMixin
from djangorestframework.content import OverloadedContentMixin
from djangorestframework.methods import OverloadedPOSTMethodMixin
from djangorestframework.response import Response, ResponseException from djangorestframework.response import Response, ResponseException
from djangorestframework.request import RequestMixin
from djangorestframework import emitters, parsers, authenticators, status from djangorestframework import emitters, parsers, authenticators, status
import re
# TODO: Figure how out references and named urls need to work nicely # TODO: Figure how out references and named urls need to work nicely
# TODO: POST on existing 404 URL, PUT on existing 404 URL # TODO: POST on existing 404 URL, PUT on existing 404 URL
# #
# NEXT: Exceptions on func() -> 500, tracebacks emitted if settings.DEBUG # NEXT: Exceptions on func() -> 500, tracebacks emitted if settings.DEBUG
#
__all__ = ['Resource'] __all__ = ['Resource']
class Resource(EmitterMixin, ParserMixin, AuthenticatorMixin, FormValidatorMixin, RequestMixin, View):
class Resource(EmitterMixin, ParserMixin, AuthenticatorMixin, FormValidatorMixin,
OverloadedContentMixin, OverloadedPOSTMethodMixin, View):
"""Handles incoming requests and maps them to REST operations, """Handles incoming requests and maps them to REST operations,
performing authentication, input deserialization, input validation, output serialization.""" performing authentication, input deserialization, input validation, output serialization."""
# List of RESTful operations which may be performed on this resource. # List of RESTful operations which may be performed on this resource.
# These are going to get dropped at some point, the allowable methods will be defined simply by
# which methods are present on the request (in the same way as Django's generic View)
allowed_methods = ('GET',) allowed_methods = ('GET',)
anon_allowed_methods = () anon_allowed_methods = ()
# List of emitters the resource can serialize the response with, ordered by preference # List of emitters the resource can serialize the response with, ordered by preference.
emitters = ( emitters.JSONEmitter, emitters = ( emitters.JSONEmitter,
emitters.DocumentingHTMLEmitter, emitters.DocumentingHTMLEmitter,
emitters.DocumentingXHTMLEmitter, emitters.DocumentingXHTMLEmitter,
emitters.DocumentingPlainTextEmitter, emitters.DocumentingPlainTextEmitter,
emitters.XMLEmitter ) emitters.XMLEmitter )
# List of content-types the resource can read from # List of parsers the resource can parse the request with.
parsers = ( parsers.JSONParser, parsers = ( parsers.JSONParser,
parsers.XMLParser,
parsers.FormParser, parsers.FormParser,
parsers.MultipartParser ) parsers.MultipartParser )
# List of all authenticating methods to attempt # List of all authenticating methods to attempt.
authenticators = ( authenticators.UserLoggedInAuthenticator, authenticators = ( authenticators.UserLoggedInAuthenticator,
authenticators.BasicAuthenticator ) authenticators.BasicAuthenticator )
@ -63,12 +58,6 @@ class Resource(EmitterMixin, ParserMixin, AuthenticatorMixin, FormValidatorMixin
callmap = { 'GET': 'get', 'POST': 'post', callmap = { 'GET': 'get', 'POST': 'post',
'PUT': 'put', 'DELETE': 'delete' } 'PUT': 'put', 'DELETE': 'delete' }
# Some reserved parameters to allow us to use standard HTML forms with our resource
# Override any/all of these with None to disable them, or override them with another value to rename them.
CSRF_PARAM = 'csrfmiddlewaretoken' # Django's CSRF token used in form params
def get(self, request, auth, *args, **kwargs): def get(self, request, auth, *args, **kwargs):
"""Must be subclassed to be implemented.""" """Must be subclassed to be implemented."""
self.not_implemented('GET') self.not_implemented('GET')
@ -137,24 +126,14 @@ class Resource(EmitterMixin, ParserMixin, AuthenticatorMixin, FormValidatorMixin
4. cleanup the response data 4. cleanup the response data
5. serialize response data into response content, using standard HTTP content negotiation 5. serialize response data into response content, using standard HTTP content negotiation
""" """
self.request = request self.request = request
# Calls to 'reverse' will not be fully qualified unless we set the scheme/host/port here. # Calls to 'reverse' will not be fully qualified unless we set the scheme/host/port here.
prefix = '%s://%s' % (request.is_secure() and 'https' or 'http', request.get_host()) prefix = '%s://%s' % (request.is_secure() and 'https' or 'http', request.get_host())
set_script_prefix(prefix) set_script_prefix(prefix)
# These sets are determined now so that overridding classes can modify the various parameter names,
# or set them to None to disable them.
self.RESERVED_FORM_PARAMS = set((self.METHOD_PARAM, self.CONTENTTYPE_PARAM, self.CONTENT_PARAM, self.CSRF_PARAM))
self.RESERVED_QUERY_PARAMS = set((self.ACCEPT_QUERY_PARAM))
self.RESERVED_FORM_PARAMS.discard(None)
self.RESERVED_QUERY_PARAMS.discard(None)
method = self.determine_method(request)
try: try:
# Authenticate the request, and store any context so that the resource operations can # Authenticate the request, and store any context so that the resource operations can
# do more fine grained authentication if required. # do more fine grained authentication if required.
# #
@ -163,19 +142,21 @@ class Resource(EmitterMixin, ParserMixin, AuthenticatorMixin, FormValidatorMixin
# has been signed against a particular set of permissions) # has been signed against a particular set of permissions)
auth_context = self.authenticate(request) auth_context = self.authenticate(request)
# If using a form POST with '_method'/'_content'/'_content_type' overrides, then alter
# self.method, self.content_type, self.CONTENT appropriately.
self.perform_form_overloading()
# Ensure the requested operation is permitted on this resource # Ensure the requested operation is permitted on this resource
self.check_method_allowed(method, auth_context) self.check_method_allowed(self.method, auth_context)
# Get the appropriate create/read/update/delete function # Get the appropriate create/read/update/delete function
func = getattr(self, self.callmap.get(method, None)) func = getattr(self, self.callmap.get(self.method, None))
# Either generate the response data, deserializing and validating any request data # Either generate the response data, deserializing and validating any request data
# TODO: Add support for message bodys on other HTTP methods, as it is valid (although non-conventional). # TODO: This is going to change to: func(request, *args, **kwargs)
if method in ('PUT', 'POST'): # That'll work out now that we have the lazily evaluated self.CONTENT property.
(content_type, content) = self.determine_content(request) if self.method in ('PUT', 'POST'):
parser_content = self.parse(content_type, content) response_obj = func(request, auth_context, self.CONTENT, *args, **kwargs)
cleaned_content = self.validate(parser_content)
response_obj = func(request, auth_context, cleaned_content, *args, **kwargs)
else: else:
response_obj = func(request, auth_context, *args, **kwargs) response_obj = func(request, auth_context, *args, **kwargs)
@ -191,11 +172,13 @@ class Resource(EmitterMixin, ParserMixin, AuthenticatorMixin, FormValidatorMixin
# Pre-serialize filtering (eg filter complex objects into natively serializable types) # Pre-serialize filtering (eg filter complex objects into natively serializable types)
response.cleaned_content = self.cleanup_response(response.raw_content) response.cleaned_content = self.cleanup_response(response.raw_content)
except ResponseException, exc: except ResponseException, exc:
response = exc.response response = exc.response
# Always add these headers # Always add these headers.
#
# TODO - this isn't actually the correct way to set the vary header,
# also it's currently sub-obtimal for HTTP caching - need to sort that out.
response.headers['Allow'] = ', '.join(self.allowed_methods) response.headers['Allow'] = ', '.join(self.allowed_methods)
response.headers['Vary'] = 'Authenticate, Accept' response.headers['Vary'] = 'Authenticate, Accept'

View File

@ -65,7 +65,7 @@
{% if resource.METHOD_PARAM and form %} {% if resource.METHOD_PARAM and form %}
{% if 'POST' in resource.allowed_methods %} {% if 'POST' in resource.allowed_methods %}
<form action="{{ request.path }}" method="post"> <form action="{{ request.path }}" method="post" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %}>
<fieldset class='module aligned'> <fieldset class='module aligned'>
<h2>POST {{ name }}</h2> <h2>POST {{ name }}</h2>
{% csrf_token %} {% csrf_token %}
@ -86,7 +86,7 @@
{% endif %} {% endif %}
{% if 'PUT' in resource.allowed_methods %} {% if 'PUT' in resource.allowed_methods %}
<form action="{{ request.path }}" method="post"> <form action="{{ request.path }}" method="post" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %}>
<fieldset class='module aligned'> <fieldset class='module aligned'>
<h2>PUT {{ name }}</h2> <h2>PUT {{ name }}</h2>
<input type="hidden" name="{{ resource.METHOD_PARAM }}" value="PUT" /> <input type="hidden" name="{{ resource.METHOD_PARAM }}" value="PUT" />

View File

@ -4,7 +4,6 @@ import os
modules = [filename.rsplit('.', 1)[0] modules = [filename.rsplit('.', 1)[0]
for filename in os.listdir(os.path.dirname(__file__)) for filename in os.listdir(os.path.dirname(__file__))
if filename.endswith('.py') and not filename.startswith('_')] if filename.endswith('.py') and not filename.startswith('_')]
__test__ = dict() __test__ = dict()
for module in modules: for module in modules:

View File

@ -1,121 +1,122 @@
from django.test import TestCase # TODO: refactor these tests
from djangorestframework.compat import RequestFactory #from django.test import TestCase
from djangorestframework.content import ContentMixin, StandardContentMixin, OverloadedContentMixin #from djangorestframework.compat import RequestFactory
#from djangorestframework.content import ContentMixin, StandardContentMixin, OverloadedContentMixin
#
class TestContentMixins(TestCase): #
def setUp(self): #class TestContentMixins(TestCase):
self.req = RequestFactory() # def setUp(self):
# self.req = RequestFactory()
# Interface tests #
# # Interface tests
def test_content_mixin_interface(self): #
"""Ensure the ContentMixin interface is as expected.""" # def test_content_mixin_interface(self):
self.assertRaises(NotImplementedError, ContentMixin().determine_content, None) # """Ensure the ContentMixin interface is as expected."""
# self.assertRaises(NotImplementedError, ContentMixin().determine_content, None)
def test_standard_content_mixin_interface(self): #
"""Ensure the OverloadedContentMixin interface is as expected.""" # def test_standard_content_mixin_interface(self):
self.assertTrue(issubclass(StandardContentMixin, ContentMixin)) # """Ensure the OverloadedContentMixin interface is as expected."""
getattr(StandardContentMixin, 'determine_content') # self.assertTrue(issubclass(StandardContentMixin, ContentMixin))
# getattr(StandardContentMixin, 'determine_content')
def test_overloaded_content_mixin_interface(self): #
"""Ensure the OverloadedContentMixin interface is as expected.""" # def test_overloaded_content_mixin_interface(self):
self.assertTrue(issubclass(OverloadedContentMixin, ContentMixin)) # """Ensure the OverloadedContentMixin interface is as expected."""
getattr(OverloadedContentMixin, 'CONTENT_PARAM') # self.assertTrue(issubclass(OverloadedContentMixin, ContentMixin))
getattr(OverloadedContentMixin, 'CONTENTTYPE_PARAM') # getattr(OverloadedContentMixin, 'CONTENT_PARAM')
getattr(OverloadedContentMixin, 'determine_content') # getattr(OverloadedContentMixin, 'CONTENTTYPE_PARAM')
# getattr(OverloadedContentMixin, 'determine_content')
#
# Common functionality to test with both StandardContentMixin and OverloadedContentMixin #
# # Common functionality to test with both StandardContentMixin and OverloadedContentMixin
def ensure_determines_no_content_GET(self, mixin): #
"""Ensure determine_content(request) returns None for GET request with no content.""" # def ensure_determines_no_content_GET(self, mixin):
request = self.req.get('/') # """Ensure determine_content(request) returns None for GET request with no content."""
self.assertEqual(mixin.determine_content(request), None) # request = self.req.get('/')
# self.assertEqual(mixin.determine_content(request), None)
def ensure_determines_form_content_POST(self, mixin): #
"""Ensure determine_content(request) returns content for POST request with content.""" # def ensure_determines_form_content_POST(self, mixin):
form_data = {'qwerty': 'uiop'} # """Ensure determine_content(request) returns content for POST request with content."""
request = self.req.post('/', data=form_data) # form_data = {'qwerty': 'uiop'}
self.assertEqual(mixin.determine_content(request), (request.META['CONTENT_TYPE'], request.raw_post_data)) # request = self.req.post('/', data=form_data)
# self.assertEqual(mixin.determine_content(request), (request.META['CONTENT_TYPE'], request.raw_post_data))
def ensure_determines_non_form_content_POST(self, mixin): #
"""Ensure determine_content(request) returns (content type, content) for POST request with content.""" # def ensure_determines_non_form_content_POST(self, mixin):
content = 'qwerty' # """Ensure determine_content(request) returns (content type, content) for POST request with content."""
content_type = 'text/plain' # content = 'qwerty'
request = self.req.post('/', content, content_type=content_type) # content_type = 'text/plain'
self.assertEqual(mixin.determine_content(request), (content_type, content)) # request = self.req.post('/', content, content_type=content_type)
# self.assertEqual(mixin.determine_content(request), (content_type, content))
def ensure_determines_form_content_PUT(self, mixin): #
"""Ensure determine_content(request) returns content for PUT request with content.""" # def ensure_determines_form_content_PUT(self, mixin):
form_data = {'qwerty': 'uiop'} # """Ensure determine_content(request) returns content for PUT request with content."""
request = self.req.put('/', data=form_data) # form_data = {'qwerty': 'uiop'}
self.assertEqual(mixin.determine_content(request), (request.META['CONTENT_TYPE'], request.raw_post_data)) # request = self.req.put('/', data=form_data)
# self.assertEqual(mixin.determine_content(request), (request.META['CONTENT_TYPE'], request.raw_post_data))
def ensure_determines_non_form_content_PUT(self, mixin): #
"""Ensure determine_content(request) returns (content type, content) for PUT request with content.""" # def ensure_determines_non_form_content_PUT(self, mixin):
content = 'qwerty' # """Ensure determine_content(request) returns (content type, content) for PUT request with content."""
content_type = 'text/plain' # content = 'qwerty'
request = self.req.put('/', content, content_type=content_type) # content_type = 'text/plain'
self.assertEqual(mixin.determine_content(request), (content_type, content)) # request = self.req.put('/', content, content_type=content_type)
# self.assertEqual(mixin.determine_content(request), (content_type, content))
# StandardContentMixin behavioural tests #
# # StandardContentMixin behavioural tests
def test_standard_behaviour_determines_no_content_GET(self): #
"""Ensure StandardContentMixin.determine_content(request) returns None for GET request with no content.""" # def test_standard_behaviour_determines_no_content_GET(self):
self.ensure_determines_no_content_GET(StandardContentMixin()) # """Ensure StandardContentMixin.determine_content(request) returns None for GET request with no content."""
# self.ensure_determines_no_content_GET(StandardContentMixin())
def test_standard_behaviour_determines_form_content_POST(self): #
"""Ensure StandardContentMixin.determine_content(request) returns content for POST request with content.""" # def test_standard_behaviour_determines_form_content_POST(self):
self.ensure_determines_form_content_POST(StandardContentMixin()) # """Ensure StandardContentMixin.determine_content(request) returns content for POST request with content."""
# self.ensure_determines_form_content_POST(StandardContentMixin())
def test_standard_behaviour_determines_non_form_content_POST(self): #
"""Ensure StandardContentMixin.determine_content(request) returns (content type, content) for POST request with content.""" # def test_standard_behaviour_determines_non_form_content_POST(self):
self.ensure_determines_non_form_content_POST(StandardContentMixin()) # """Ensure StandardContentMixin.determine_content(request) returns (content type, content) for POST request with content."""
# self.ensure_determines_non_form_content_POST(StandardContentMixin())
def test_standard_behaviour_determines_form_content_PUT(self): #
"""Ensure StandardContentMixin.determine_content(request) returns content for PUT request with content.""" # def test_standard_behaviour_determines_form_content_PUT(self):
self.ensure_determines_form_content_PUT(StandardContentMixin()) # """Ensure StandardContentMixin.determine_content(request) returns content for PUT request with content."""
# self.ensure_determines_form_content_PUT(StandardContentMixin())
def test_standard_behaviour_determines_non_form_content_PUT(self): #
"""Ensure StandardContentMixin.determine_content(request) returns (content type, content) for PUT request with content.""" # def test_standard_behaviour_determines_non_form_content_PUT(self):
self.ensure_determines_non_form_content_PUT(StandardContentMixin()) # """Ensure StandardContentMixin.determine_content(request) returns (content type, content) for PUT request with content."""
# self.ensure_determines_non_form_content_PUT(StandardContentMixin())
# OverloadedContentMixin behavioural tests #
# # OverloadedContentMixin behavioural tests
def test_overloaded_behaviour_determines_no_content_GET(self): #
"""Ensure StandardContentMixin.determine_content(request) returns None for GET request with no content.""" # def test_overloaded_behaviour_determines_no_content_GET(self):
self.ensure_determines_no_content_GET(OverloadedContentMixin()) # """Ensure StandardContentMixin.determine_content(request) returns None for GET request with no content."""
# self.ensure_determines_no_content_GET(OverloadedContentMixin())
def test_overloaded_behaviour_determines_form_content_POST(self): #
"""Ensure StandardContentMixin.determine_content(request) returns content for POST request with content.""" # def test_overloaded_behaviour_determines_form_content_POST(self):
self.ensure_determines_form_content_POST(OverloadedContentMixin()) # """Ensure StandardContentMixin.determine_content(request) returns content for POST request with content."""
# self.ensure_determines_form_content_POST(OverloadedContentMixin())
def test_overloaded_behaviour_determines_non_form_content_POST(self): #
"""Ensure StandardContentMixin.determine_content(request) returns (content type, content) for POST request with content.""" # def test_overloaded_behaviour_determines_non_form_content_POST(self):
self.ensure_determines_non_form_content_POST(OverloadedContentMixin()) # """Ensure StandardContentMixin.determine_content(request) returns (content type, content) for POST request with content."""
# self.ensure_determines_non_form_content_POST(OverloadedContentMixin())
def test_overloaded_behaviour_determines_form_content_PUT(self): #
"""Ensure StandardContentMixin.determine_content(request) returns content for PUT request with content.""" # def test_overloaded_behaviour_determines_form_content_PUT(self):
self.ensure_determines_form_content_PUT(OverloadedContentMixin()) # """Ensure StandardContentMixin.determine_content(request) returns content for PUT request with content."""
# self.ensure_determines_form_content_PUT(OverloadedContentMixin())
def test_overloaded_behaviour_determines_non_form_content_PUT(self): #
"""Ensure StandardContentMixin.determine_content(request) returns (content type, content) for PUT request with content.""" # def test_overloaded_behaviour_determines_non_form_content_PUT(self):
self.ensure_determines_non_form_content_PUT(OverloadedContentMixin()) # """Ensure StandardContentMixin.determine_content(request) returns (content type, content) for PUT request with content."""
# self.ensure_determines_non_form_content_PUT(OverloadedContentMixin())
def test_overloaded_behaviour_allows_content_tunnelling(self): #
"""Ensure determine_content(request) returns (content type, content) for overloaded POST request""" # def test_overloaded_behaviour_allows_content_tunnelling(self):
content = 'qwerty' # """Ensure determine_content(request) returns (content type, content) for overloaded POST request"""
content_type = 'text/plain' # content = 'qwerty'
form_data = {OverloadedContentMixin.CONTENT_PARAM: content, # content_type = 'text/plain'
OverloadedContentMixin.CONTENTTYPE_PARAM: content_type} # form_data = {OverloadedContentMixin.CONTENT_PARAM: content,
request = self.req.post('/', form_data) # OverloadedContentMixin.CONTENTTYPE_PARAM: content_type}
self.assertEqual(OverloadedContentMixin().determine_content(request), (content_type, content)) # request = self.req.post('/', form_data)
self.assertEqual(request.META['CONTENT_TYPE'], content_type) # self.assertEqual(OverloadedContentMixin().determine_content(request), (content_type, content))
# self.assertEqual(request.META['CONTENT_TYPE'], content_type)
def test_overloaded_behaviour_allows_content_tunnelling_content_type_not_set(self): #
"""Ensure determine_content(request) returns (None, content) for overloaded POST request with content type not set""" # def test_overloaded_behaviour_allows_content_tunnelling_content_type_not_set(self):
content = 'qwerty' # """Ensure determine_content(request) returns (None, content) for overloaded POST request with content type not set"""
request = self.req.post('/', {OverloadedContentMixin.CONTENT_PARAM: content}) # content = 'qwerty'
self.assertEqual(OverloadedContentMixin().determine_content(request), (None, content)) # request = self.req.post('/', {OverloadedContentMixin.CONTENT_PARAM: content})
# self.assertEqual(OverloadedContentMixin().determine_content(request), (None, content))

View File

@ -0,0 +1,37 @@
from django.test import TestCase
from django import forms
from djangorestframework.compat import RequestFactory
from djangorestframework.resource import Resource
import StringIO
class UploadFilesTests(TestCase):
"""Check uploading of files"""
def setUp(self):
self.factory = RequestFactory()
def test_upload_file(self):
class FileForm(forms.Form):
file = forms.FileField
class MockResource(Resource):
allowed_methods = anon_allowed_methods = ('POST',)
form = FileForm
def post(self, request, auth, content, *args, **kwargs):
#self.uploaded = content.file
return {'FILE_NAME': content['file'].name,
'FILE_CONTENT': content['file'].read()}
file = StringIO.StringIO('stuff')
file.name = 'stuff.txt'
request = self.factory.post('/', {'file': file})
view = MockResource.as_view()
response = view(request)
self.assertEquals(response.content, '{"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"}')

View File

@ -1,52 +1,53 @@
from django.test import TestCase # TODO: Refactor these tests
from djangorestframework.compat import RequestFactory #from django.test import TestCase
from djangorestframework.methods import MethodMixin, StandardMethodMixin, OverloadedPOSTMethodMixin #from djangorestframework.compat import RequestFactory
#from djangorestframework.methods import MethodMixin, StandardMethodMixin, OverloadedPOSTMethodMixin
#
class TestMethodMixins(TestCase): #
def setUp(self): #class TestMethodMixins(TestCase):
self.req = RequestFactory() # def setUp(self):
# self.req = RequestFactory()
# Interface tests #
# # Interface tests
def test_method_mixin_interface(self): #
"""Ensure the base ContentMixin interface is as expected.""" # def test_method_mixin_interface(self):
self.assertRaises(NotImplementedError, MethodMixin().determine_method, None) # """Ensure the base ContentMixin interface is as expected."""
# self.assertRaises(NotImplementedError, MethodMixin().determine_method, None)
def test_standard_method_mixin_interface(self): #
"""Ensure the StandardMethodMixin interface is as expected.""" # def test_standard_method_mixin_interface(self):
self.assertTrue(issubclass(StandardMethodMixin, MethodMixin)) # """Ensure the StandardMethodMixin interface is as expected."""
getattr(StandardMethodMixin, 'determine_method') # self.assertTrue(issubclass(StandardMethodMixin, MethodMixin))
# getattr(StandardMethodMixin, 'determine_method')
def test_overloaded_method_mixin_interface(self): #
"""Ensure the OverloadedPOSTMethodMixin interface is as expected.""" # def test_overloaded_method_mixin_interface(self):
self.assertTrue(issubclass(OverloadedPOSTMethodMixin, MethodMixin)) # """Ensure the OverloadedPOSTMethodMixin interface is as expected."""
getattr(OverloadedPOSTMethodMixin, 'METHOD_PARAM') # self.assertTrue(issubclass(OverloadedPOSTMethodMixin, MethodMixin))
getattr(OverloadedPOSTMethodMixin, 'determine_method') # getattr(OverloadedPOSTMethodMixin, 'METHOD_PARAM')
# getattr(OverloadedPOSTMethodMixin, 'determine_method')
# Behavioural tests #
# # Behavioural tests
def test_standard_behaviour_determines_GET(self): #
"""GET requests identified as GET method with StandardMethodMixin""" # def test_standard_behaviour_determines_GET(self):
request = self.req.get('/') # """GET requests identified as GET method with StandardMethodMixin"""
self.assertEqual(StandardMethodMixin().determine_method(request), 'GET') # request = self.req.get('/')
# self.assertEqual(StandardMethodMixin().determine_method(request), 'GET')
def test_standard_behaviour_determines_POST(self): #
"""POST requests identified as POST method with StandardMethodMixin""" # def test_standard_behaviour_determines_POST(self):
request = self.req.post('/') # """POST requests identified as POST method with StandardMethodMixin"""
self.assertEqual(StandardMethodMixin().determine_method(request), 'POST') # request = self.req.post('/')
# self.assertEqual(StandardMethodMixin().determine_method(request), 'POST')
def test_overloaded_POST_behaviour_determines_GET(self): #
"""GET requests identified as GET method with OverloadedPOSTMethodMixin""" # def test_overloaded_POST_behaviour_determines_GET(self):
request = self.req.get('/') # """GET requests identified as GET method with OverloadedPOSTMethodMixin"""
self.assertEqual(OverloadedPOSTMethodMixin().determine_method(request), 'GET') # request = self.req.get('/')
# self.assertEqual(OverloadedPOSTMethodMixin().determine_method(request), 'GET')
def test_overloaded_POST_behaviour_determines_POST(self): #
"""POST requests identified as POST method with OverloadedPOSTMethodMixin""" # def test_overloaded_POST_behaviour_determines_POST(self):
request = self.req.post('/') # """POST requests identified as POST method with OverloadedPOSTMethodMixin"""
self.assertEqual(OverloadedPOSTMethodMixin().determine_method(request), 'POST') # request = self.req.post('/')
# self.assertEqual(OverloadedPOSTMethodMixin().determine_method(request), 'POST')
def test_overloaded_POST_behaviour_determines_overloaded_method(self): #
"""POST requests can be overloaded to another method by setting a reserved form field with OverloadedPOSTMethodMixin""" # def test_overloaded_POST_behaviour_determines_overloaded_method(self):
request = self.req.post('/', {OverloadedPOSTMethodMixin.METHOD_PARAM: 'DELETE'}) # """POST requests can be overloaded to another method by setting a reserved form field with OverloadedPOSTMethodMixin"""
self.assertEqual(OverloadedPOSTMethodMixin().determine_method(request), 'DELETE') # request = self.req.post('/', {OverloadedPOSTMethodMixin.METHOD_PARAM: 'DELETE'})
# self.assertEqual(OverloadedPOSTMethodMixin().determine_method(request), 'DELETE')

View File

@ -1,12 +1,13 @@
""" """
.. ..
>>> from djangorestframework.parsers import FormParser >>> from djangorestframework.parsers import FormParser
>>> from djangorestframework.resource import Resource
>>> from djangorestframework.compat import RequestFactory >>> from djangorestframework.compat import RequestFactory
>>> from djangorestframework.resource import Resource
>>> from StringIO import StringIO
>>> from urllib import urlencode >>> from urllib import urlencode
>>> req = RequestFactory().get('/') >>> req = RequestFactory().get('/')
>>> some_resource = Resource() >>> some_resource = Resource()
>>> trash = some_resource.dispatch(req)# Some variables are set only when calling dispatch >>> some_resource.request = req # Make as if this request had been dispatched
FormParser FormParser
============ ============
@ -23,7 +24,7 @@ Here is some example data, which would eventually be sent along with a post requ
Default behaviour for :class:`parsers.FormParser`, is to return a single value for each parameter : Default behaviour for :class:`parsers.FormParser`, is to return a single value for each parameter :
>>> FormParser(some_resource).parse(inpt) == {'key1': 'bla1', 'key2': 'blo1'} >>> FormParser(some_resource).parse(StringIO(inpt)) == {'key1': 'bla1', 'key2': 'blo1'}
True True
However, you can customize this behaviour by subclassing :class:`parsers.FormParser`, and overriding :meth:`parsers.FormParser.is_a_list` : However, you can customize this behaviour by subclassing :class:`parsers.FormParser`, and overriding :meth:`parsers.FormParser.is_a_list` :
@ -35,7 +36,7 @@ However, you can customize this behaviour by subclassing :class:`parsers.FormPar
This new parser only flattens the lists of parameters that contain a single value. This new parser only flattens the lists of parameters that contain a single value.
>>> MyFormParser(some_resource).parse(inpt) == {'key1': 'bla1', 'key2': ['blo1', 'blo2']} >>> MyFormParser(some_resource).parse(StringIO(inpt)) == {'key1': 'bla1', 'key2': ['blo1', 'blo2']}
True True
.. note:: The same functionality is available for :class:`parsers.MultipartParser`. .. note:: The same functionality is available for :class:`parsers.MultipartParser`.
@ -60,7 +61,7 @@ The browsers usually strip the parameter completely. A hack to avoid this, and t
:class:`parsers.FormParser` strips the values ``_empty`` from all the lists. :class:`parsers.FormParser` strips the values ``_empty`` from all the lists.
>>> MyFormParser(some_resource).parse(inpt) == {'key1': 'blo1'} >>> MyFormParser(some_resource).parse(StringIO(inpt)) == {'key1': 'blo1'}
True True
Oh ... but wait a second, the parameter ``key2`` isn't even supposed to be a list, so the parser just stripped it. Oh ... but wait a second, the parameter ``key2`` isn't even supposed to be a list, so the parser just stripped it.
@ -70,7 +71,7 @@ Oh ... but wait a second, the parameter ``key2`` isn't even supposed to be a lis
... def is_a_list(self, key, val_list): ... def is_a_list(self, key, val_list):
... return key == 'key2' ... return key == 'key2'
... ...
>>> MyFormParser(some_resource).parse(inpt) == {'key1': 'blo1', 'key2': []} >>> MyFormParser(some_resource).parse(StringIO(inpt)) == {'key1': 'blo1', 'key2': []}
True True
Better like that. Note that you can configure something else than ``_empty`` for the empty value by setting :attr:`parsers.FormParser.EMPTY_VALUE`. Better like that. Note that you can configure something else than ``_empty`` for the empty value by setting :attr:`parsers.FormParser.EMPTY_VALUE`.
@ -81,6 +82,8 @@ from django.test import TestCase
from djangorestframework.compat import RequestFactory from djangorestframework.compat import RequestFactory
from djangorestframework.parsers import MultipartParser from djangorestframework.parsers import MultipartParser
from djangorestframework.resource import Resource from djangorestframework.resource import Resource
from djangorestframework.mediatypes import MediaType
from StringIO import StringIO
def encode_multipart_formdata(fields, files): def encode_multipart_formdata(fields, files):
"""For testing multipart parser. """For testing multipart parser.
@ -119,9 +122,9 @@ class TestMultipartParser(TestCase):
def test_multipartparser(self): def test_multipartparser(self):
"""Ensure that MultipartParser can parse multipart/form-data that contains a mix of several files and parameters.""" """Ensure that MultipartParser can parse multipart/form-data that contains a mix of several files and parameters."""
post_req = RequestFactory().post('/', self.body, content_type=self.content_type) post_req = RequestFactory().post('/', self.body, content_type=self.content_type)
some_resource = Resource() resource = Resource()
some_resource.dispatch(post_req) resource.request = post_req
parsed = MultipartParser(some_resource).parse(self.body) parsed = MultipartParser(resource).parse(StringIO(self.body))
self.assertEqual(parsed['key1'], 'val1') self.assertEqual(parsed['key1'], 'val1')
self.assertEqual(parsed['file1'].read(), 'blablabla') self.assertEqual(parsed.FILES['file1'].read(), 'blablabla')

View File

@ -143,7 +143,7 @@ class TestFormValidation(TestCase):
try: try:
validator.validate(content) validator.validate(content)
except ResponseException, exc: except ResponseException, exc:
self.assertEqual(exc.response.raw_content, {'errors': ['No content was supplied.']}) self.assertEqual(exc.response.raw_content, {'field-errors': {'qwerty': ['This field is required.']}})
else: else:
self.fail('ResourceException was not raised') #pragma: no cover self.fail('ResourceException was not raised') #pragma: no cover

View File

@ -58,6 +58,8 @@ class FormValidatorMixin(ValidatorMixin):
# Validation succeeded... # Validation succeeded...
cleaned_data = bound_form.cleaned_data cleaned_data = bound_form.cleaned_data
cleaned_data.update(bound_form.files)
# Add in any extra fields to the cleaned content... # Add in any extra fields to the cleaned content...
for key in (allowed_extra_fields_set & seen_fields_set) - set(cleaned_data.keys()): for key in (allowed_extra_fields_set & seen_fields_set) - set(cleaned_data.keys()):
cleaned_data[key] = content[key] cleaned_data[key] = content[key]
@ -95,7 +97,9 @@ class FormValidatorMixin(ValidatorMixin):
if not self.form: if not self.form:
return None return None
if content: if not content is None:
if hasattr(content, 'FILES'):
return self.form(content, content.FILES)
return self.form(content) return self.form(content)
return self.form() return self.form()
@ -157,8 +161,11 @@ class ModelFormValidatorMixin(FormValidatorMixin):
# Instantiate the ModelForm as appropriate # Instantiate the ModelForm as appropriate
if content and isinstance(content, models.Model): if content and isinstance(content, models.Model):
# Bound to an existing model instance
return OnTheFlyModelForm(instance=content) return OnTheFlyModelForm(instance=content)
elif content: elif not content is None:
if hasattr(content, 'FILES'):
return OnTheFlyModelForm(content, content.FILES)
return OnTheFlyModelForm(content) return OnTheFlyModelForm(content)
return OnTheFlyModelForm() return OnTheFlyModelForm()
@ -189,4 +196,4 @@ class ModelFormValidatorMixin(FormValidatorMixin):
return property_fields - set(as_tuple(self.exclude_fields)) return property_fields - set(as_tuple(self.exclude_fields))