mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-01-25 00:34:21 +03:00
commit
e2f3153b13
|
@ -19,6 +19,9 @@ from djangorestframework import status
|
||||||
from djangorestframework.compat import yaml
|
from djangorestframework.compat import yaml
|
||||||
from djangorestframework.response import ErrorResponse
|
from djangorestframework.response import ErrorResponse
|
||||||
from djangorestframework.utils.mediatypes import media_type_matches
|
from djangorestframework.utils.mediatypes import media_type_matches
|
||||||
|
from xml.etree import ElementTree as ET
|
||||||
|
import datetime
|
||||||
|
import decimal
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
|
@ -28,6 +31,7 @@ __all__ = (
|
||||||
'FormParser',
|
'FormParser',
|
||||||
'MultiPartParser',
|
'MultiPartParser',
|
||||||
'YAMLParser',
|
'YAMLParser',
|
||||||
|
'XMLParser'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -167,10 +171,60 @@ class MultiPartParser(BaseParser):
|
||||||
raise ErrorResponse(status.HTTP_400_BAD_REQUEST,
|
raise ErrorResponse(status.HTTP_400_BAD_REQUEST,
|
||||||
{'detail': 'multipart parse error - %s' % unicode(exc)})
|
{'detail': 'multipart parse error - %s' % unicode(exc)})
|
||||||
return django_parser.parse()
|
return django_parser.parse()
|
||||||
|
|
||||||
|
|
||||||
|
class XMLParser(BaseParser):
|
||||||
|
"""
|
||||||
|
XML parser.
|
||||||
|
"""
|
||||||
|
|
||||||
|
media_type = 'application/xml'
|
||||||
|
|
||||||
|
def parse(self, stream):
|
||||||
|
"""
|
||||||
|
Returns a 2-tuple of `(data, files)`.
|
||||||
|
|
||||||
|
`data` will simply be a string representing the body of the request.
|
||||||
|
`files` will always be `None`.
|
||||||
|
"""
|
||||||
|
data = {}
|
||||||
|
tree = ET.parse(stream)
|
||||||
|
for child in tree.getroot().getchildren():
|
||||||
|
data[child.tag] = self._type_convert(child.text)
|
||||||
|
|
||||||
|
return (data, None)
|
||||||
|
|
||||||
|
def _type_convert(self, value):
|
||||||
|
"""
|
||||||
|
Converts the value returned by the XMl parse into the equivalent
|
||||||
|
Python type
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
|
||||||
|
try:
|
||||||
|
return datetime.datetime.strptime(value,'%Y-%m-%d %H:%M:%S')
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
return decimal.Decimal(value)
|
||||||
|
except decimal.InvalidOperation:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_PARSERS = ( JSONParser,
|
DEFAULT_PARSERS = ( JSONParser,
|
||||||
FormParser,
|
FormParser,
|
||||||
MultiPartParser )
|
MultiPartParser,
|
||||||
|
XMLParser
|
||||||
|
)
|
||||||
|
|
||||||
if YAMLParser:
|
if YAMLParser:
|
||||||
DEFAULT_PARSERS += ( YAMLParser, )
|
DEFAULT_PARSERS += ( YAMLParser, )
|
|
@ -136,6 +136,8 @@ from cgi import parse_qs
|
||||||
from django import forms
|
from django import forms
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from djangorestframework.parsers import FormParser
|
from djangorestframework.parsers import FormParser
|
||||||
|
from djangorestframework.parsers import XMLParser
|
||||||
|
import datetime
|
||||||
|
|
||||||
class Form(forms.Form):
|
class Form(forms.Form):
|
||||||
field1 = forms.CharField(max_length=3)
|
field1 = forms.CharField(max_length=3)
|
||||||
|
@ -153,3 +155,28 @@ class TestFormParser(TestCase):
|
||||||
(data, files) = parser.parse(stream)
|
(data, files) = parser.parse(stream)
|
||||||
|
|
||||||
self.assertEqual(Form(data).is_valid(), True)
|
self.assertEqual(Form(data).is_valid(), True)
|
||||||
|
|
||||||
|
|
||||||
|
class TestXMLParser(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.input = StringIO(
|
||||||
|
'<?xml version="1.0" encoding="utf-8"?>'
|
||||||
|
'<root>'
|
||||||
|
'<field_a>121.0</field_a>'
|
||||||
|
'<field_b>dasd</field_b>'
|
||||||
|
'<field_c></field_c>'
|
||||||
|
'<field_d>2011-12-25 12:45:00</field_d>'
|
||||||
|
'</root>'
|
||||||
|
)
|
||||||
|
self.data = {
|
||||||
|
'field_a': 121,
|
||||||
|
'field_b': 'dasd',
|
||||||
|
'field_c': None,
|
||||||
|
'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_parse(self):
|
||||||
|
parser = XMLParser(None)
|
||||||
|
(data, files) = parser.parse(self.input)
|
||||||
|
self.assertEqual(data, self.data)
|
|
@ -4,13 +4,16 @@ from django.test import TestCase
|
||||||
|
|
||||||
from djangorestframework import status
|
from djangorestframework import status
|
||||||
from djangorestframework.compat import View as DjangoView
|
from djangorestframework.compat import View as DjangoView
|
||||||
from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer
|
from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer,\
|
||||||
from djangorestframework.parsers import JSONParser, YAMLParser
|
XMLRenderer
|
||||||
|
from djangorestframework.parsers import JSONParser, YAMLParser, XMLParser
|
||||||
from djangorestframework.mixins import ResponseMixin
|
from djangorestframework.mixins import ResponseMixin
|
||||||
from djangorestframework.response import Response
|
from djangorestframework.response import Response
|
||||||
from djangorestframework.utils.mediatypes import add_media_type_param
|
from djangorestframework.utils.mediatypes import add_media_type_param
|
||||||
|
|
||||||
from StringIO import StringIO
|
from StringIO import StringIO
|
||||||
|
import datetime
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
DUMMYSTATUS = status.HTTP_200_OK
|
DUMMYSTATUS = status.HTTP_200_OK
|
||||||
DUMMYCONTENT = 'dummycontent'
|
DUMMYCONTENT = 'dummycontent'
|
||||||
|
@ -224,3 +227,66 @@ if YAMLRenderer:
|
||||||
content = renderer.render(obj, 'application/yaml')
|
content = renderer.render(obj, 'application/yaml')
|
||||||
(data, files) = parser.parse(StringIO(content))
|
(data, files) = parser.parse(StringIO(content))
|
||||||
self.assertEquals(obj, data)
|
self.assertEquals(obj, data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class XMLRendererTestCase(TestCase):
|
||||||
|
"""
|
||||||
|
Tests specific to the JSON Renderer
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_render_string(self):
|
||||||
|
"""
|
||||||
|
Test XML rendering.
|
||||||
|
"""
|
||||||
|
renderer = XMLRenderer(None)
|
||||||
|
content = renderer.render({'field': 'astring'}, 'application/xml')
|
||||||
|
self.assertXMLContains(content, '<field>astring</field>')
|
||||||
|
|
||||||
|
def test_render_integer(self):
|
||||||
|
"""
|
||||||
|
Test XML rendering.
|
||||||
|
"""
|
||||||
|
renderer = XMLRenderer(None)
|
||||||
|
content = renderer.render({'field': 111}, 'application/xml')
|
||||||
|
self.assertXMLContains(content, '<field>111</field>')
|
||||||
|
|
||||||
|
def test_render_datetime(self):
|
||||||
|
"""
|
||||||
|
Test XML rendering.
|
||||||
|
"""
|
||||||
|
renderer = XMLRenderer(None)
|
||||||
|
content = renderer.render({
|
||||||
|
'field': datetime.datetime(2011, 12, 25, 12, 45, 00)
|
||||||
|
}, 'application/xml')
|
||||||
|
self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>')
|
||||||
|
|
||||||
|
def test_render_float(self):
|
||||||
|
"""
|
||||||
|
Test XML rendering.
|
||||||
|
"""
|
||||||
|
renderer = XMLRenderer(None)
|
||||||
|
content = renderer.render({'field': 123.4}, 'application/xml')
|
||||||
|
self.assertXMLContains(content, '<field>123.4</field>')
|
||||||
|
|
||||||
|
def test_render_decimal(self):
|
||||||
|
"""
|
||||||
|
Test XML rendering.
|
||||||
|
"""
|
||||||
|
renderer = XMLRenderer(None)
|
||||||
|
content = renderer.render({'field': Decimal('111.2')}, 'application/xml')
|
||||||
|
self.assertXMLContains(content, '<field>111.2</field>')
|
||||||
|
|
||||||
|
def test_render_none(self):
|
||||||
|
"""
|
||||||
|
Test XML rendering.
|
||||||
|
"""
|
||||||
|
renderer = XMLRenderer(None)
|
||||||
|
content = renderer.render({'field': None}, 'application/xml')
|
||||||
|
self.assertXMLContains(content, '<field></field>')
|
||||||
|
|
||||||
|
|
||||||
|
def assertXMLContains(self, xml, string):
|
||||||
|
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
|
||||||
|
self.assertTrue(xml.endswith('</root>'))
|
||||||
|
self.assertTrue(string in xml, '%r not in %r' % (string, xml))
|
||||||
|
|
|
@ -150,6 +150,10 @@ class XMLRenderer():
|
||||||
xml.startElement(key, {})
|
xml.startElement(key, {})
|
||||||
self._to_xml(xml, value)
|
self._to_xml(xml, value)
|
||||||
xml.endElement(key)
|
xml.endElement(key)
|
||||||
|
|
||||||
|
elif data is None:
|
||||||
|
# Don't output any value
|
||||||
|
pass
|
||||||
|
|
||||||
else:
|
else:
|
||||||
xml.characters(smart_unicode(data))
|
xml.characters(smart_unicode(data))
|
||||||
|
@ -168,4 +172,4 @@ class XMLRenderer():
|
||||||
return stream.getvalue()
|
return stream.getvalue()
|
||||||
|
|
||||||
def dict2xml(input):
|
def dict2xml(input):
|
||||||
return XMLRenderer().dict2xml(input)
|
return XMLRenderer().dict2xml(input)
|
Loading…
Reference in New Issue
Block a user