From ecd3733c5e229505baca5a870963f2dd492d6dd7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 28 Aug 2012 15:46:38 +0100 Subject: [PATCH] Added serializers and fields --- djangorestframework/fields.py | 446 +++++++++++++++++++++++++ djangorestframework/parsers.py | 16 +- djangorestframework/request.py | 3 +- djangorestframework/serializers.py | 348 +++++++++++++++++++ djangorestframework/tests/parsers.py | 6 +- djangorestframework/tests/renderers.py | 2 +- 6 files changed, 809 insertions(+), 12 deletions(-) create mode 100644 djangorestframework/fields.py create mode 100644 djangorestframework/serializers.py diff --git a/djangorestframework/fields.py b/djangorestframework/fields.py new file mode 100644 index 000000000..a44eb417b --- /dev/null +++ b/djangorestframework/fields.py @@ -0,0 +1,446 @@ +import copy +import datetime +import inspect +import warnings + +from django.core import validators +from django.core.exceptions import ValidationError +from django.conf import settings +from django.db import DEFAULT_DB_ALIAS +from django.db.models.related import RelatedObject +from django.utils import timezone +from django.utils.dateparse import parse_date, parse_datetime +from django.utils.encoding import is_protected_type, smart_unicode +from django.utils.translation import ugettext_lazy as _ + + +def is_simple_callable(obj): + """ + True if the object is a callable that takes no arguments. + """ + return ( + (inspect.isfunction(obj) and not inspect.getargspec(obj)[0]) or + (inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1) + ) + + +class Field(object): + creation_counter = 0 + default_validators = [] + default_error_messages = { + 'required': _('This field is required.'), + 'invalid': _('Invalid value.'), + } + empty = '' + + def __init__(self, source=None, readonly=False, required=None, + validators=[], error_messages=None): + self.parent = None + + self.creation_counter = Field.creation_counter + Field.creation_counter += 1 + + self.source = source + self.readonly = readonly + self.required = not(readonly) + + messages = {} + for c in reversed(self.__class__.__mro__): + messages.update(getattr(c, 'default_error_messages', {})) + messages.update(error_messages or {}) + self.error_messages = messages + + self.validators = self.default_validators + validators + + def initialize(self, parent, model_field=None): + """ + Called to set up a field prior to field_to_native or field_from_native. + + parent - The parent serializer. + model_field - The model field this field corrosponds to, if one exists. + """ + self.parent = parent + self.root = parent.root or parent + self.context = self.root.context + if model_field: + self.model_field = model_field + + def validate(self, value): + pass + # if value in validators.EMPTY_VALUES and self.required: + # raise ValidationError(self.error_messages['required']) + + def run_validators(self, value): + if value in validators.EMPTY_VALUES: + return + errors = [] + for v in self.validators: + try: + v(value) + except ValidationError as e: + if hasattr(e, 'code') and e.code in self.error_messages: + message = self.error_messages[e.code] + if e.params: + message = message % e.params + errors.append(message) + else: + errors.extend(e.messages) + if errors: + raise ValidationError(errors) + + def field_from_native(self, data, field_name, into): + """ + Given a dictionary and a field name, updates the dictionary `into`, + with the field and it's deserialized value. + """ + if self.readonly: + return + + try: + native = data[field_name] + except KeyError: + return # TODO Consider validation behaviour, 'required' opt etc... + + value = self.from_native(native) + if self.source == '*': + if value: + into.update(value) + else: + self.validate(value) + self.run_validators(value) + into[self.source or field_name] = value + + def from_native(self, value): + """ + Reverts a simple representation back to the field's value. + """ + if hasattr(self, 'model_field'): + try: + return self.model_field.rel.to._meta.get_field(self.model_field.rel.field_name).to_python(value) + except: + return self.model_field.to_python(value) + return value + + def field_to_native(self, obj, field_name): + """ + Given and object and a field name, returns the value that should be + serialized for that field. + """ + if obj is None: + return self.empty + + if self.source == '*': + return self.to_native(obj) + + self.obj = obj # Need to hang onto this in the case of model fields + if hasattr(self, 'model_field'): + return self.to_native(self.model_field._get_val_from_obj(obj)) + + return self.to_native(getattr(obj, self.source or field_name)) + + def to_native(self, value): + """ + Converts the field's value into it's simple representation. + """ + if is_simple_callable(value): + value = value() + + if is_protected_type(value): + return value + elif hasattr(self, 'model_field'): + return self.model_field.value_to_string(self.obj) + return smart_unicode(value) + + def attributes(self): + """ + Returns a dictionary of attributes to be used when serializing to xml. + """ + try: + return { + "type": self.model_field.get_internal_type() + } + except AttributeError: + return {} + + +class RelatedField(Field): + """ + A base class for model related fields or related managers. + + Subclass this and override `convert` to define custom behaviour when + serializing related objects. + """ + + def field_to_native(self, obj, field_name): + obj = getattr(obj, field_name) + if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'): + return [self.to_native(item) for item in obj.all()] + return self.to_native(obj) + + def attributes(self): + try: + return { + "rel": self.model_field.rel.__class__.__name__, + "to": smart_unicode(self.model_field.rel.to._meta) + } + except AttributeError: + return {} + + +class PrimaryKeyRelatedField(RelatedField): + """ + Serializes a model related field or related manager to a pk value. + """ + + # Note the we use ModelRelatedField's implementation, as we want to get the + # raw database value directly, since that won't involve another + # database lookup. + # + # An alternative implementation would simply be this... + # + # class PrimaryKeyRelatedField(RelatedField): + # def to_native(self, obj): + # return obj.pk + + def to_native(self, pk): + """ + Simply returns the object's pk. You can subclass this method to + provide different serialization behavior of the pk. + (For example returning a URL based on the model's pk.) + """ + return pk + + def field_to_native(self, obj, field_name): + try: + obj = obj.serializable_value(field_name) + except AttributeError: + field = obj._meta.get_field_by_name(field_name)[0] + obj = getattr(obj, field_name) + if obj.__class__.__name__ == 'RelatedManager': + return [self.to_native(item.pk) for item in obj.all()] + elif isinstance(field, RelatedObject): + return self.to_native(obj.pk) + raise + if obj.__class__.__name__ == 'ManyRelatedManager': + return [self.to_native(item.pk) for item in obj.all()] + return self.to_native(obj) + + def field_from_native(self, data, field_name, into): + value = data.get(field_name) + if hasattr(value, '__iter__'): + into[field_name] = [self.from_native(item) for item in value] + else: + into[field_name + '_id'] = self.from_native(value) + + +class NaturalKeyRelatedField(RelatedField): + """ + Serializes a model related field or related manager to a natural key value. + """ + is_natural_key = True # XML renderer handles these differently + + def to_native(self, obj): + if hasattr(obj, 'natural_key'): + return obj.natural_key() + return obj + + def field_from_native(self, data, field_name, into): + value = data.get(field_name) + into[self.model_field.attname] = self.from_native(value) + + def from_native(self, value): + # TODO: Support 'using' : db = options.pop('using', DEFAULT_DB_ALIAS) + manager = self.model_field.rel.to._default_manager + manager = manager.db_manager(DEFAULT_DB_ALIAS) + return manager.get_by_natural_key(*value).pk + + +class BooleanField(Field): + default_error_messages = { + 'invalid': _(u"'%s' value must be either True or False."), + } + + def from_native(self, value): + if value in (True, False): + # if value is 1 or 0 than it's equal to True or False, but we want + # to return a true bool for semantic reasons. + return bool(value) + if value in ('t', 'True', '1'): + return True + if value in ('f', 'False', '0'): + return False + raise ValidationError(self.error_messages['invalid'] % value) + + +class CharField(Field): + def __init__(self, max_length=None, min_length=None, *args, **kwargs): + self.max_length, self.min_length = max_length, min_length + super(CharField, self).__init__(*args, **kwargs) + if min_length is not None: + self.validators.append(validators.MinLengthValidator(min_length)) + if max_length is not None: + self.validators.append(validators.MaxLengthValidator(max_length)) + + def from_native(self, value): + if isinstance(value, basestring) or value is None: + return value + return smart_unicode(value) + + +class EmailField(CharField): + default_error_messages = { + 'invalid': _('Enter a valid e-mail address.'), + } + default_validators = [validators.validate_email] + + def from_native(self, value): + return super(EmailField, self).from_native(value).strip() + + def __deepcopy__(self, memo): + result = copy.copy(self) + memo[id(self)] = result + #result.widget = copy.deepcopy(self.widget, memo) + result.validators = self.validators[:] + return result + + +class DateField(Field): + default_error_messages = { + 'invalid': _(u"'%s' value has an invalid date format. It must be " + u"in YYYY-MM-DD format."), + 'invalid_date': _(u"'%s' value has the correct format (YYYY-MM-DD) " + u"but it is an invalid date."), + } + empty = None + + def from_native(self, value): + if value is None: + return value + if isinstance(value, datetime.datetime): + if settings.USE_TZ and timezone.is_aware(value): + # Convert aware datetimes to the default time zone + # before casting them to dates (#17742). + default_timezone = timezone.get_default_timezone() + value = timezone.make_naive(value, default_timezone) + return value.date() + if isinstance(value, datetime.date): + return value + + try: + parsed = parse_date(value) + if parsed is not None: + return parsed + except ValueError: + msg = self.error_messages['invalid_date'] % value + raise ValidationError(msg) + + msg = self.error_messages['invalid'] % value + raise ValidationError(msg) + + +class DateTimeField(Field): + default_error_messages = { + 'invalid': _(u"'%s' value has an invalid format. It must be in " + u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."), + 'invalid_date': _(u"'%s' value has the correct format " + u"(YYYY-MM-DD) but it is an invalid date."), + 'invalid_datetime': _(u"'%s' value has the correct format " + u"(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) " + u"but it is an invalid date/time."), + } + empty = None + + def from_native(self, value): + if value is None: + return value + if isinstance(value, datetime.datetime): + return value + if isinstance(value, datetime.date): + value = datetime.datetime(value.year, value.month, value.day) + if settings.USE_TZ: + # For backwards compatibility, interpret naive datetimes in + # local time. This won't work during DST change, but we can't + # do much about it, so we let the exceptions percolate up the + # call stack. + warnings.warn(u"DateTimeField received a naive datetime (%s)" + u" while time zone support is active." % value, + RuntimeWarning) + default_timezone = timezone.get_default_timezone() + value = timezone.make_aware(value, default_timezone) + return value + + try: + parsed = parse_datetime(value) + if parsed is not None: + return parsed + except ValueError: + msg = self.error_messages['invalid_datetime'] % value + raise ValidationError(msg) + + try: + parsed = parse_date(value) + if parsed is not None: + return datetime.datetime(parsed.year, parsed.month, parsed.day) + except ValueError: + msg = self.error_messages['invalid_date'] % value + raise ValidationError(msg) + + msg = self.error_messages['invalid'] % value + raise ValidationError(msg) + + +class IntegerField(Field): + default_error_messages = { + 'invalid': _('Enter a whole number.'), + 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), + 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), + } + + def __init__(self, max_value=None, min_value=None, *args, **kwargs): + self.max_value, self.min_value = max_value, min_value + super(IntegerField, self).__init__(*args, **kwargs) + + if max_value is not None: + self.validators.append(validators.MaxValueValidator(max_value)) + if min_value is not None: + self.validators.append(validators.MinValueValidator(min_value)) + + def from_native(self, value): + if value in validators.EMPTY_VALUES: + return None + try: + value = int(str(value)) + except (ValueError, TypeError): + raise ValidationError(self.error_messages['invalid']) + return value + + +class FloatField(Field): + default_error_messages = { + 'invalid': _("'%s' value must be a float."), + } + + def from_native(self, value): + if value is None: + return value + try: + return float(value) + except (TypeError, ValueError): + msg = self.error_messages['invalid'] % value + raise ValidationError(msg) + +# field_mapping = { +# models.AutoField: IntegerField, +# models.BooleanField: BooleanField, +# models.CharField: CharField, +# models.DateTimeField: DateTimeField, +# models.DateField: DateField, +# models.BigIntegerField: IntegerField, +# models.IntegerField: IntegerField, +# models.PositiveIntegerField: IntegerField, +# models.FloatField: FloatField +# } + + +# def modelfield_to_serializerfield(field): +# return field_mapping.get(type(field), Field) diff --git a/djangorestframework/parsers.py b/djangorestframework/parsers.py index 1fff64f79..43ea0c4dc 100644 --- a/djangorestframework/parsers.py +++ b/djangorestframework/parsers.py @@ -57,7 +57,7 @@ class BaseParser(object): """ return media_type_matches(self.media_type, content_type) - def parse(self, stream, meta, upload_handlers): + def parse(self, stream, **opts): """ Given a *stream* to read from, return the deserialized output. Should return a 2-tuple of (data, files). @@ -72,7 +72,7 @@ class JSONParser(BaseParser): media_type = 'application/json' - def parse(self, stream, meta, upload_handlers): + def parse(self, stream, **opts): """ Returns a 2-tuple of `(data, files)`. @@ -92,7 +92,7 @@ class YAMLParser(BaseParser): media_type = 'application/yaml' - def parse(self, stream, meta, upload_handlers): + def parse(self, stream, **opts): """ Returns a 2-tuple of `(data, files)`. @@ -112,7 +112,7 @@ class PlainTextParser(BaseParser): media_type = 'text/plain' - def parse(self, stream, meta, upload_handlers): + def parse(self, stream, **opts): """ Returns a 2-tuple of `(data, files)`. @@ -129,7 +129,7 @@ class FormParser(BaseParser): media_type = 'application/x-www-form-urlencoded' - def parse(self, stream, meta, upload_handlers): + def parse(self, stream, **opts): """ Returns a 2-tuple of `(data, files)`. @@ -147,13 +147,15 @@ class MultiPartParser(BaseParser): media_type = 'multipart/form-data' - def parse(self, stream, meta, upload_handlers): + def parse(self, stream, **opts): """ Returns a 2-tuple of `(data, files)`. `data` will be a :class:`QueryDict` containing all the form parameters. `files` will be a :class:`QueryDict` containing all the form files. """ + meta = opts['meta'] + upload_handlers = opts['upload_handlers'] try: parser = DjangoMultiPartParser(meta, stream, upload_handlers) return parser.parse() @@ -168,7 +170,7 @@ class XMLParser(BaseParser): media_type = 'application/xml' - def parse(self, stream, meta, upload_handlers): + def parse(self, stream, **opts): """ Returns a 2-tuple of `(data, files)`. diff --git a/djangorestframework/request.py b/djangorestframework/request.py index 684f65914..84ca05753 100644 --- a/djangorestframework/request.py +++ b/djangorestframework/request.py @@ -214,7 +214,8 @@ class Request(object): for parser in self.get_parsers(): if parser.can_handle_request(self.content_type): - return parser.parse(self.stream, self.META, self.upload_handlers) + return parser.parse(self.stream, meta=self.META, + upload_handlers=self.upload_handlers) raise UnsupportedMediaType(self._content_type) diff --git a/djangorestframework/serializers.py b/djangorestframework/serializers.py new file mode 100644 index 000000000..46980ee6b --- /dev/null +++ b/djangorestframework/serializers.py @@ -0,0 +1,348 @@ +from decimal import Decimal +from django.core.serializers.base import DeserializedObject +from django.utils.datastructures import SortedDict +import copy +import datetime +import types +from djangorestframework.fields import * + + +class DictWithMetadata(dict): + """ + A dict-like object, that can have additional properties attached. + """ + pass + + +class SortedDictWithMetadata(SortedDict, DictWithMetadata): + """ + A sorted dict-like object, that can have additional properties attached. + """ + pass + + +class RecursionOccured(BaseException): + pass + + +def _is_protected_type(obj): + """ + True if the object is a native datatype that does not need to + be serialized further. + """ + return isinstance(obj, ( + types.NoneType, + int, long, + datetime.datetime, datetime.date, datetime.time, + float, Decimal, + basestring) + ) + + +def _get_declared_fields(bases, attrs): + """ + Create a list of serializer field instances from the passed in 'attrs', + plus any fields on the base classes (in 'bases'). + + Note that all fields from the base classes are used. + """ + fields = [(field_name, attrs.pop(field_name)) + for field_name, obj in attrs.items() + if isinstance(obj, Field)] + fields.sort(key=lambda x: x[1].creation_counter) + + # If this class is subclassing another Serializer, add that Serializer's + # fields. Note that we loop over the bases in *reverse*. This is necessary + # in order to the correct order of fields. + for base in bases[::-1]: + if hasattr(base, 'base_fields'): + fields = base.base_fields.items() + fields + + return SortedDict(fields) + + +class SerializerMetaclass(type): + def __new__(cls, name, bases, attrs): + attrs['base_fields'] = _get_declared_fields(bases, attrs) + return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) + + +class SerializerOptions(object): + """ + Meta class options for ModelSerializer + """ + def __init__(self, meta): + self.nested = getattr(meta, 'nested', False) + self.fields = getattr(meta, 'fields', ()) + self.exclude = getattr(meta, 'exclude', ()) + + +class BaseSerializer(Field): + class Meta(object): + pass + + _options_class = SerializerOptions + _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatability with unsorted implementations. + + def __init__(self, data=None, instance=None, context=None, **kwargs): + super(BaseSerializer, self).__init__(**kwargs) + self.fields = copy.deepcopy(self.base_fields) + self.opts = self._options_class(self.Meta) + self.parent = None + self.root = None + + self.stack = [] + self.context = context or {} + + self.init_data = data + self.instance = instance + + self._data = None + self._errors = None + + ##### + # Methods to determine which fields to use when (de)serializing objects. + + def default_fields(self, serialize, obj=None, data=None, nested=False): + """ + Return the complete set of default fields for the object, as a dict. + """ + return {} + + def get_fields(self, serialize, obj=None, data=None, nested=False): + """ + Returns the complete set of fields for the object as a dict. + + This will be the set of any explicitly declared fields, + plus the set of fields returned by get_default_fields(). + """ + ret = SortedDict() + + # Get the explicitly declared fields + for key, field in self.fields.items(): + ret[key] = field + # Determine if the declared field corrosponds to a model field. + try: + if key == 'pk': + model_field = obj._meta.pk + else: + model_field = obj._meta.get_field_by_name(key)[0] + except: + model_field = None + # Set up the field + field.initialize(parent=self, model_field=model_field) + + # Add in the default fields + fields = self.default_fields(serialize, obj, data, nested) + for key, val in fields.items(): + if key not in ret: + ret[key] = val + + # If 'fields' is specified, use those fields, in that order. + if self.opts.fields: + new = SortedDict() + for key in self.opts.fields: + new[key] = ret[key] + ret = new + + # Remove anything in 'exclude' + if self.opts.exclude: + for key in self.opts.exclude: + ret.pop(key, None) + + return ret + + ##### + # Field methods - used when the serializer class is itself used as a field. + + def initialize(self, parent, model_field=None): + """ + Same behaviour as usual Field, except that we need to keep track + of state so that we can deal with handling maximum depth and recursion. + """ + super(BaseSerializer, self).initialize(parent, model_field) + self.stack = parent.stack[:] + if parent.opts.nested and not isinstance(parent.opts.nested, bool): + self.opts.nested = parent.opts.nested - 1 + else: + self.opts.nested = parent.opts.nested + + ##### + # Methods to convert or revert from objects <--> primative representations. + + def get_field_key(self, field_name): + """ + Return the key that should be used for a given field. + """ + return field_name + + def convert_object(self, obj): + """ + Core of serialization. + Convert an object into a dictionary of serialized field values. + """ + if obj in self.stack and not self.source == '*': + raise RecursionOccured() + self.stack.append(obj) + + ret = self._dict_class() + ret.fields = {} + + fields = self.get_fields(serialize=True, obj=obj, nested=self.opts.nested) + for field_name, field in fields.items(): + key = self.get_field_key(field_name) + try: + value = field.field_to_native(obj, field_name) + except RecursionOccured: + field = self.get_fields(serialize=True, obj=obj, nested=False)[field_name] + value = field.field_to_native(obj, field_name) + ret[key] = value + ret.fields[key] = field + return ret + + def restore_fields(self, data): + """ + Core of deserialization, together with `restore_object`. + Converts a dictionary of data into a dictionary of deserialized fields. + """ + fields = self.get_fields(serialize=False, data=data, nested=self.opts.nested) + reverted_data = {} + for field_name, field in fields.items(): + try: + field.field_from_native(data, field_name, reverted_data) + except ValidationError as err: + self._errors[field_name] = list(err.messages) + + return reverted_data + + def restore_object(self, attrs, instance=None): + """ + Deserialize a dictionary of attributes into an object instance. + You should override this method to control how deserialized objects + are instantiated. + """ + if instance is not None: + instance.update(attrs) + return instance + return attrs + + def to_native(self, obj): + """ + Serialize objects -> primatives. + """ + if isinstance(obj, dict): + return dict([(key, self.to_native(val)) + for (key, val) in obj.items()]) + elif hasattr(obj, '__iter__'): + return (self.to_native(item) for item in obj) + return self.convert_object(obj) + + def from_native(self, data): + """ + Deserialize primatives -> objects. + """ + if hasattr(data, '__iter__') and not isinstance(data, dict): + # TODO: error data when deserializing lists + return (self.from_native(item) for item in data) + self._errors = {} + attrs = self.restore_fields(data) + if not self._errors: + return self.restore_object(attrs, instance=getattr(self, 'instance', None)) + + @property + def errors(self): + """ + Run deserialization and return error data, + setting self.object if no errors occured. + """ + if self._errors is None: + obj = self.from_native(self.init_data) + if not self._errors: + self.object = obj + return self._errors + + def is_valid(self): + return not self.errors + + @property + def data(self): + if self._data is None: + self._data = self.to_native(self.instance) + return self._data + + +class Serializer(BaseSerializer): + __metaclass__ = SerializerMetaclass + + +class ModelSerializerOptions(SerializerOptions): + """ + Meta class options for ModelSerializer + """ + def __init__(self, meta): + super(ModelSerializerOptions, self).__init__(meta) + self.model = getattr(meta, 'model', None) + + +class ModelSerializer(RelatedField, Serializer): + """ + A serializer that deals with model instances and querysets. + """ + _options_class = ModelSerializerOptions + + def default_fields(self, serialize, obj=None, data=None, nested=False): + """ + Return all the fields that should be serialized for the model. + """ + if serialize: + cls = obj.__class__ + else: + cls = self.opts.model + + opts = cls._meta.concrete_model._meta + pk_field = opts.pk + while pk_field.rel: + pk_field = pk_field.rel.to._meta.pk + fields = [pk_field] + fields += [field for field in opts.fields if field.serialize] + fields += [field for field in opts.many_to_many if field.serialize] + + ret = SortedDict() + for model_field in fields: + if model_field.rel and nested: + field = self.get_nested_field(model_field) + elif model_field.rel: + field = self.get_related_field(model_field) + else: + field = self.get_field(model_field) + field.initialize(parent=self, model_field=model_field) + ret[model_field.name] = field + return ret + + def get_nested_field(self, model_field): + """ + Creates a default instance of a nested relational field. + """ + return ModelSerializer() + + def get_related_field(self, model_field): + """ + Creates a default instance of a flat relational field. + """ + return PrimaryKeyRelatedField() + + def get_field(self, model_field): + """ + Creates a default instance of a basic field. + """ + return Field() + + def restore_object(self, attrs, instance=None): + """ + Restore the model instance. + """ + m2m_data = {} + for field in self.opts.model._meta.many_to_many: + if field.name in attrs: + m2m_data[field.name] = attrs.pop(field.name) + return DeserializedObject(self.opts.model(**attrs), m2m_data) diff --git a/djangorestframework/tests/parsers.py b/djangorestframework/tests/parsers.py index c733d9d09..a85409dc0 100644 --- a/djangorestframework/tests/parsers.py +++ b/djangorestframework/tests/parsers.py @@ -153,7 +153,7 @@ class TestFormParser(TestCase): parser = FormParser() stream = StringIO(self.string) - (data, files) = parser.parse(stream, {}, []) + (data, files) = parser.parse(stream) self.assertEqual(Form(data).is_valid(), True) @@ -203,10 +203,10 @@ class TestXMLParser(TestCase): def test_parse(self): parser = XMLParser() - (data, files) = parser.parse(self._input, {}, []) + (data, files) = parser.parse(self._input) self.assertEqual(data, self._data) def test_complex_data_parse(self): parser = XMLParser() - (data, files) = parser.parse(self._complex_data_input, {}, []) + (data, files) = parser.parse(self._complex_data_input) self.assertEqual(data, self._complex_data) diff --git a/djangorestframework/tests/renderers.py b/djangorestframework/tests/renderers.py index 610457c71..1943d012b 100644 --- a/djangorestframework/tests/renderers.py +++ b/djangorestframework/tests/renderers.py @@ -380,7 +380,7 @@ class XMLRendererTestCase(TestCase): content = StringIO(renderer.render(self._complex_data, 'application/xml')) parser = XMLParser() - complex_data_out, dummy = parser.parse(content, {}, []) + complex_data_out, dummy = parser.parse(content) error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out)) self.assertEqual(self._complex_data, complex_data_out, error_msg)