Merge pull request #408 from markotibold/file_and_image_fields

Added a FileField and an ImageField
This commit is contained in:
Tom Christie 2012-11-16 14:48:42 -08:00
commit e40000c834
7 changed files with 178 additions and 42 deletions

View File

@ -165,6 +165,33 @@ A floating point representation.
Corresponds to `django.db.models.fields.FloatField`. Corresponds to `django.db.models.fields.FloatField`.
## FileField
A file representation. Performs Django's standard FileField validation.
Corresponds to `django.forms.fields.FileField`.
**Signature:** `FileField(max_length=None, allow_empty_file=False)`
- `max_length` designates the maximum length for the file name.
- `allow_empty_file` designates if empty files are allowed.
## ImageField
An image representation.
Corresponds to `django.forms.fields.ImageField`.
Requires the `PIL` package.
Signature and validation is the same as with `FileField`.
---
**Note:** `FileFields` and `ImageFields` are only suitable for use with MultiPartParser, since eg json doesn't support file uploads.
Django's regular [FILE_UPLOAD_HANDLERS] are used for handling uploaded files.
--- ---
# Relational Fields # Relational Fields
@ -286,3 +313,4 @@ This field is always read-only.
* `slug_url_kwarg` - The named url parameter for the slug field lookup. Default is to use the same value as given for `slug_field`. * `slug_url_kwarg` - The named url parameter for the slug field lookup. Default is to use the same value as given for `slug_field`.
[cite]: http://www.python.org/dev/peps/pep-0020/ [cite]: http://www.python.org/dev/peps/pep-0020/
[FILE_UPLOAD_HANDLERS]: https://docs.djangoproject.com/en/dev/ref/settings/#std:setting-FILE_UPLOAD_HANDLERS

View File

@ -3,6 +3,8 @@ import datetime
import inspect import inspect
import warnings import warnings
from io import BytesIO
from django.core import validators from django.core import validators
from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve, get_script_prefix from django.core.urlresolvers import resolve, get_script_prefix
@ -31,6 +33,7 @@ class Field(object):
creation_counter = 0 creation_counter = 0
empty = '' empty = ''
type_name = None type_name = None
_use_files = None
def __init__(self, source=None): def __init__(self, source=None):
self.parent = None self.parent = None
@ -51,7 +54,7 @@ class Field(object):
self.root = parent.root or parent self.root = parent.root or parent
self.context = self.root.context self.context = self.root.context
def field_from_native(self, data, field_name, into): def field_from_native(self, data, files, field_name, into):
""" """
Given a dictionary and a field name, updates the dictionary `into`, Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value. with the field and it's deserialized value.
@ -166,7 +169,7 @@ class WritableField(Field):
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)
def field_from_native(self, data, field_name, into): def field_from_native(self, data, files, field_name, into):
""" """
Given a dictionary and a field name, updates the dictionary `into`, Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value. with the field and it's deserialized value.
@ -175,7 +178,10 @@ class WritableField(Field):
return return
try: try:
native = data[field_name] if self._use_files:
native = files[field_name]
else:
native = data[field_name]
except KeyError: except KeyError:
if self.default is not None: if self.default is not None:
native = self.default native = self.default
@ -323,7 +329,7 @@ class RelatedField(WritableField):
value = getattr(obj, self.source or field_name) value = getattr(obj, self.source or field_name)
return self.to_native(value) return self.to_native(value)
def field_from_native(self, data, field_name, into): def field_from_native(self, data, files, field_name, into):
if self.read_only: if self.read_only:
return return
@ -341,7 +347,7 @@ class ManyRelatedMixin(object):
value = getattr(obj, self.source or field_name) value = getattr(obj, self.source or field_name)
return [self.to_native(item) for item in value.all()] return [self.to_native(item) for item in value.all()]
def field_from_native(self, data, field_name, into): def field_from_native(self, data, files, field_name, into):
if self.read_only: if self.read_only:
return return
@ -904,3 +910,95 @@ class FloatField(WritableField):
except (TypeError, ValueError): except (TypeError, ValueError):
msg = self.error_messages['invalid'] % value msg = self.error_messages['invalid'] % value
raise ValidationError(msg) raise ValidationError(msg)
class FileField(WritableField):
_use_files = True
type_name = 'FileField'
widget = widgets.FileInput
default_error_messages = {
'invalid': _("No file was submitted. Check the encoding type on the form."),
'missing': _("No file was submitted."),
'empty': _("The submitted file is empty."),
'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'),
'contradiction': _('Please either submit a file or check the clear checkbox, not both.')
}
def __init__(self, *args, **kwargs):
self.max_length = kwargs.pop('max_length', None)
self.allow_empty_file = kwargs.pop('allow_empty_file', False)
super(FileField, self).__init__(*args, **kwargs)
def from_native(self, data):
if data in validators.EMPTY_VALUES:
return None
# UploadedFile objects should have name and size attributes.
try:
file_name = data.name
file_size = data.size
except AttributeError:
raise ValidationError(self.error_messages['invalid'])
if self.max_length is not None and len(file_name) > self.max_length:
error_values = {'max': self.max_length, 'length': len(file_name)}
raise ValidationError(self.error_messages['max_length'] % error_values)
if not file_name:
raise ValidationError(self.error_messages['invalid'])
if not self.allow_empty_file and not file_size:
raise ValidationError(self.error_messages['empty'])
return data
def to_native(self, value):
return value.name
class ImageField(FileField):
_use_files = True
default_error_messages = {
'invalid_image': _("Upload a valid image. The file you uploaded was either not an image or a corrupted image."),
}
def from_native(self, data):
"""
Checks that the file-upload field data contains a valid image (GIF, JPG,
PNG, possibly others -- whatever the Python Imaging Library supports).
"""
f = super(ImageField, self).from_native(data)
if f is None:
return None
# Try to import PIL in either of the two ways it can end up installed.
try:
from PIL import Image
except ImportError:
import Image
# We need to get a file object for PIL. We might have a path or we might
# have to read the data into memory.
if hasattr(data, 'temporary_file_path'):
file = data.temporary_file_path()
else:
if hasattr(data, 'read'):
file = BytesIO(data.read())
else:
file = BytesIO(data['content'])
try:
# load() could spot a truncated JPEG, but it loads the entire
# image in memory, which is a DoS vector. See #3848 and #18520.
# verify() must be called immediately after the constructor.
Image.open(file).verify()
except ImportError:
# Under PyPy, it is possible to import PIL. However, the underlying
# _imaging C module isn't available, so an ImportError will be
# raised. Catch and re-raise.
raise
except Exception: # Python Imaging Library doesn't recognize it as an image
raise ValidationError(self.error_messages['invalid_image'])
if hasattr(f, 'seek') and callable(f.seek):
f.seek(0)
return f

View File

@ -47,11 +47,10 @@ class GenericAPIView(views.APIView):
return serializer_class return serializer_class
def get_serializer(self, instance=None, data=None, files=None): def get_serializer(self, instance=None, data=None, files=None):
# TODO: add support for files
# TODO: add support for seperate serializer/deserializer # TODO: add support for seperate serializer/deserializer
serializer_class = self.get_serializer_class() serializer_class = self.get_serializer_class()
context = self.get_serializer_context() context = self.get_serializer_context()
return serializer_class(instance, data=data, context=context) return serializer_class(instance, data=data, files=files, context=context)
class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):

View File

@ -15,7 +15,7 @@ class CreateModelMixin(object):
Should be mixed in with any `BaseView`. Should be mixed in with any `BaseView`.
""" """
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA) serializer = self.get_serializer(data=request.DATA, files=request.FILES)
if serializer.is_valid(): if serializer.is_valid():
self.pre_save(serializer.object) self.pre_save(serializer.object)
self.object = serializer.save() self.object = serializer.save()
@ -89,7 +89,7 @@ class UpdateModelMixin(object):
self.object = None self.object = None
created = True created = True
serializer = self.get_serializer(self.object, data=request.DATA) serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES)
if serializer.is_valid(): if serializer.is_valid():
self.pre_save(serializer.object) self.pre_save(serializer.object)

View File

@ -320,7 +320,9 @@ class BrowsableAPIRenderer(BaseRenderer):
serializers.SlugRelatedField: forms.ChoiceField, serializers.SlugRelatedField: forms.ChoiceField,
serializers.ManySlugRelatedField: forms.MultipleChoiceField, serializers.ManySlugRelatedField: forms.MultipleChoiceField,
serializers.HyperlinkedRelatedField: forms.ChoiceField, serializers.HyperlinkedRelatedField: forms.ChoiceField,
serializers.ManyHyperlinkedRelatedField: forms.MultipleChoiceField serializers.ManyHyperlinkedRelatedField: forms.MultipleChoiceField,
serializers.FileField: forms.FileField,
serializers.ImageField: forms.ImageField,
} }
fields = {} fields = {}

View File

@ -91,7 +91,7 @@ class BaseSerializer(Field):
_options_class = SerializerOptions _options_class = SerializerOptions
_dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatibility with unsorted implementations. _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatibility with unsorted implementations.
def __init__(self, instance=None, data=None, context=None, **kwargs): def __init__(self, instance=None, data=None, files=None, context=None, **kwargs):
super(BaseSerializer, self).__init__(**kwargs) super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta) self.opts = self._options_class(self.Meta)
self.fields = copy.deepcopy(self.base_fields) self.fields = copy.deepcopy(self.base_fields)
@ -101,9 +101,11 @@ class BaseSerializer(Field):
self.context = context or {} self.context = context or {}
self.init_data = data self.init_data = data
self.init_files = files
self.object = instance self.object = instance
self._data = None self._data = None
self._files = None
self._errors = None self._errors = None
##### #####
@ -187,7 +189,7 @@ class BaseSerializer(Field):
ret.fields[key] = field ret.fields[key] = field
return ret return ret
def restore_fields(self, data): def restore_fields(self, data, files):
""" """
Core of deserialization, together with `restore_object`. Core of deserialization, together with `restore_object`.
Converts a dictionary of data into a dictionary of deserialized fields. Converts a dictionary of data into a dictionary of deserialized fields.
@ -196,7 +198,7 @@ class BaseSerializer(Field):
reverted_data = {} reverted_data = {}
for field_name, field in fields.items(): for field_name, field in fields.items():
try: try:
field.field_from_native(data, field_name, reverted_data) field.field_from_native(data, files, field_name, reverted_data)
except ValidationError as err: except ValidationError as err:
self._errors[field_name] = list(err.messages) self._errors[field_name] = list(err.messages)
@ -250,7 +252,7 @@ class BaseSerializer(Field):
return [self.convert_object(item) for item in obj] return [self.convert_object(item) for item in obj]
return self.convert_object(obj) return self.convert_object(obj)
def from_native(self, data): def from_native(self, data, files):
""" """
Deserialize primitives -> objects. Deserialize primitives -> objects.
""" """
@ -259,8 +261,8 @@ class BaseSerializer(Field):
return (self.from_native(item) for item in data) return (self.from_native(item) for item in data)
self._errors = {} self._errors = {}
if data is not None: if data is not None or files is not None:
attrs = self.restore_fields(data) attrs = self.restore_fields(data, files)
attrs = self.perform_validation(attrs) attrs = self.perform_validation(attrs)
else: else:
self._errors['non_field_errors'] = ['No input provided'] self._errors['non_field_errors'] = ['No input provided']
@ -288,7 +290,7 @@ class BaseSerializer(Field):
setting self.object if no errors occurred. setting self.object if no errors occurred.
""" """
if self._errors is None: if self._errors is None:
obj = self.from_native(self.init_data) obj = self.from_native(self.init_data, self.init_files)
if not self._errors: if not self._errors:
self.object = obj self.object = obj
return self._errors return self._errors
@ -440,6 +442,8 @@ class ModelSerializer(Serializer):
models.TextField: CharField, models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField, models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField, models.BooleanField: BooleanField,
models.FileField: FileField,
models.ImageField: ImageField,
} }
try: try:
return field_mapping[model_field.__class__](**kwargs) return field_mapping[model_field.__class__](**kwargs)

View File

@ -1,34 +1,39 @@
# from django.test import TestCase import StringIO
# from django import forms import datetime
# from django.test.client import RequestFactory from django.test import TestCase
# from rest_framework.views import View
# from rest_framework.response import Response
# import StringIO from rest_framework import serializers
# class UploadFilesTests(TestCase): class UploadedFile(object):
# """Check uploading of files""" def __init__(self, file, created=None):
# def setUp(self): self.file = file
# self.factory = RequestFactory() self.created = created or datetime.datetime.now()
# def test_upload_file(self):
# class FileForm(forms.Form): class UploadedFileSerializer(serializers.Serializer):
# file = forms.FileField() file = serializers.FileField()
created = serializers.DateTimeField()
# class MockView(View): def restore_object(self, attrs, instance=None):
# permissions = () if instance:
# form = FileForm instance.file = attrs['file']
instance.created = attrs['created']
return instance
return UploadedFile(**attrs)
# def post(self, request, *args, **kwargs):
# return Response({'FILE_NAME': self.CONTENT['file'].name,
# 'FILE_CONTENT': self.CONTENT['file'].read()})
# file = StringIO.StringIO('stuff') class FileSerializerTests(TestCase):
# file.name = 'stuff.txt'
# request = self.factory.post('/', {'file': file}) def test_create(self):
# view = MockView.as_view() now = datetime.datetime.now()
# response = view(request) file = StringIO.StringIO('stuff')
# self.assertEquals(response.raw_content, {"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"}) file.name = 'stuff.txt'
file.size = file.len
serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
uploaded_file = UploadedFile(file=file, created=now)
self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.object.created, uploaded_file.created)
self.assertEquals(serializer.object.file, uploaded_file.file)
self.assertFalse(serializer.object is uploaded_file)