Merge pull request #90 from jakul/xmlparser6

Lovely stuff!
This commit is contained in:
Tom Christie 2011-12-11 10:42:37 -08:00
commit e2f3153b13
4 changed files with 156 additions and 5 deletions

View File

@ -19,6 +19,9 @@ from djangorestframework import status
from djangorestframework.compat import yaml
from djangorestframework.response import ErrorResponse
from djangorestframework.utils.mediatypes import media_type_matches
from xml.etree import ElementTree as ET
import datetime
import decimal
__all__ = (
@ -28,6 +31,7 @@ __all__ = (
'FormParser',
'MultiPartParser',
'YAMLParser',
'XMLParser'
)
@ -167,10 +171,60 @@ class MultiPartParser(BaseParser):
raise ErrorResponse(status.HTTP_400_BAD_REQUEST,
{'detail': 'multipart parse error - %s' % unicode(exc)})
return django_parser.parse()
class XMLParser(BaseParser):
"""
XML parser.
"""
media_type = 'application/xml'
def parse(self, stream):
"""
Returns a 2-tuple of `(data, files)`.
`data` will simply be a string representing the body of the request.
`files` will always be `None`.
"""
data = {}
tree = ET.parse(stream)
for child in tree.getroot().getchildren():
data[child.tag] = self._type_convert(child.text)
return (data, None)
def _type_convert(self, value):
"""
Converts the value returned by the XMl parse into the equivalent
Python type
"""
if value is None:
return value
try:
return datetime.datetime.strptime(value,'%Y-%m-%d %H:%M:%S')
except ValueError:
pass
try:
return int(value)
except ValueError:
pass
try:
return decimal.Decimal(value)
except decimal.InvalidOperation:
pass
return value
DEFAULT_PARSERS = ( JSONParser,
FormParser,
MultiPartParser )
MultiPartParser,
XMLParser
)
if YAMLParser:
DEFAULT_PARSERS += ( YAMLParser, )
DEFAULT_PARSERS += ( YAMLParser, )

View File

@ -136,6 +136,8 @@ from cgi import parse_qs
from django import forms
from django.test import TestCase
from djangorestframework.parsers import FormParser
from djangorestframework.parsers import XMLParser
import datetime
class Form(forms.Form):
field1 = forms.CharField(max_length=3)
@ -153,3 +155,28 @@ class TestFormParser(TestCase):
(data, files) = parser.parse(stream)
self.assertEqual(Form(data).is_valid(), True)
class TestXMLParser(TestCase):
def setUp(self):
self.input = StringIO(
'<?xml version="1.0" encoding="utf-8"?>'
'<root>'
'<field_a>121.0</field_a>'
'<field_b>dasd</field_b>'
'<field_c></field_c>'
'<field_d>2011-12-25 12:45:00</field_d>'
'</root>'
)
self.data = {
'field_a': 121,
'field_b': 'dasd',
'field_c': None,
'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00)
}
def test_parse(self):
parser = XMLParser(None)
(data, files) = parser.parse(self.input)
self.assertEqual(data, self.data)

View File

@ -4,13 +4,16 @@ from django.test import TestCase
from djangorestframework import status
from djangorestframework.compat import View as DjangoView
from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer
from djangorestframework.parsers import JSONParser, YAMLParser
from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer,\
XMLRenderer
from djangorestframework.parsers import JSONParser, YAMLParser, XMLParser
from djangorestframework.mixins import ResponseMixin
from djangorestframework.response import Response
from djangorestframework.utils.mediatypes import add_media_type_param
from StringIO import StringIO
import datetime
from decimal import Decimal
DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent'
@ -224,3 +227,66 @@ if YAMLRenderer:
content = renderer.render(obj, 'application/yaml')
(data, files) = parser.parse(StringIO(content))
self.assertEquals(obj, data)
class XMLRendererTestCase(TestCase):
"""
Tests specific to the JSON Renderer
"""
def test_render_string(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': 'astring'}, 'application/xml')
self.assertXMLContains(content, '<field>astring</field>')
def test_render_integer(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': 111}, 'application/xml')
self.assertXMLContains(content, '<field>111</field>')
def test_render_datetime(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({
'field': datetime.datetime(2011, 12, 25, 12, 45, 00)
}, 'application/xml')
self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>')
def test_render_float(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': 123.4}, 'application/xml')
self.assertXMLContains(content, '<field>123.4</field>')
def test_render_decimal(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': Decimal('111.2')}, 'application/xml')
self.assertXMLContains(content, '<field>111.2</field>')
def test_render_none(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': None}, 'application/xml')
self.assertXMLContains(content, '<field></field>')
def assertXMLContains(self, xml, string):
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
self.assertTrue(xml.endswith('</root>'))
self.assertTrue(string in xml, '%r not in %r' % (string, xml))

View File

@ -150,6 +150,10 @@ class XMLRenderer():
xml.startElement(key, {})
self._to_xml(xml, value)
xml.endElement(key)
elif data is None:
# Don't output any value
pass
else:
xml.characters(smart_unicode(data))
@ -168,4 +172,4 @@ class XMLRenderer():
return stream.getvalue()
def dict2xml(input):
return XMLRenderer().dict2xml(input)
return XMLRenderer().dict2xml(input)