Merge pull request #125 from michelelazzeri-nextage/master

application/xml parser fully compatible with application/xml render
This commit is contained in:
Marko Tibold 2012-01-13 12:36:02 -08:00
commit c204101563
3 changed files with 98 additions and 101 deletions

View File

@ -187,13 +187,33 @@ class XMLParser(BaseParser):
`data` will simply be a string representing the body of the request. `data` will simply be a string representing the body of the request.
`files` will always be `None`. `files` will always be `None`.
""" """
data = {}
tree = ET.parse(stream) tree = ET.parse(stream)
for child in tree.getroot().getchildren(): data = self._xml_convert(tree.getroot())
data[child.tag] = self._type_convert(child.text)
return (data, None) return (data, None)
def _xml_convert(self, element):
"""
convert the xml `element` into the corresponding python object
"""
children = element.getchildren()
if len(children) == 0:
return self._type_convert(element.text)
else:
# if the fist child tag is list-item means all children are list-item
if children[0].tag == "list-item":
data = []
for child in children:
data.append(self._xml_convert(child))
else:
data = {}
for child in children:
data[child.tag] = self._xml_convert(child)
return data
def _type_convert(self, value): def _type_convert(self, value):
""" """
Converts the value returned by the XMl parse into the equivalent Converts the value returned by the XMl parse into the equivalent

View File

@ -156,10 +156,9 @@ class TestFormParser(TestCase):
self.assertEqual(Form(data).is_valid(), True) self.assertEqual(Form(data).is_valid(), True)
class TestXMLParser(TestCase): class TestXMLParser(TestCase):
def setUp(self): def setUp(self):
self.input = StringIO( self._input = StringIO(
'<?xml version="1.0" encoding="utf-8"?>' '<?xml version="1.0" encoding="utf-8"?>'
'<root>' '<root>'
'<field_a>121.0</field_a>' '<field_a>121.0</field_a>'
@ -168,15 +167,45 @@ class TestXMLParser(TestCase):
'<field_d>2011-12-25 12:45:00</field_d>' '<field_d>2011-12-25 12:45:00</field_d>'
'</root>' '</root>'
) )
self.data = { self._data = {
'field_a': 121, 'field_a': 121,
'field_b': 'dasd', 'field_b': 'dasd',
'field_c': None, 'field_c': None,
'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00) 'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00)
} }
self._complex_data_input = StringIO(
'<?xml version="1.0" encoding="utf-8"?>'
'<root>'
'<creation_date>2011-12-25 12:45:00</creation_date>'
'<sub_data_list>'
'<list-item><sub_id>1</sub_id><sub_name>first</sub_name></list-item>'
'<list-item><sub_id>2</sub_id><sub_name>second</sub_name></list-item>'
'</sub_data_list>'
'<name>name</name>'
'</root>'
)
self._complex_data = {
"creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
"name": "name",
"sub_data_list": [
{
"sub_id": 1,
"sub_name": "first"
},
{
"sub_id": 2,
"sub_name": "second"
}
]
}
def test_parse(self): def test_parse(self):
parser = XMLParser(None) parser = XMLParser(None)
(data, files) = parser.parse(self.input) (data, files) = parser.parse(self._input)
self.assertEqual(data, self.data) self.assertEqual(data, self._data)
def test_complex_data_parse(self):
parser = XMLParser(None)
(data, files) = parser.parse(self._complex_data_input)
self.assertEqual(data, self._complex_data)

View File

@ -6,7 +6,7 @@ from djangorestframework.views import View
from djangorestframework.compat import View as DjangoView from djangorestframework.compat import View as DjangoView
from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
XMLRenderer, JSONPRenderer, DocumentingHTMLRenderer XMLRenderer, JSONPRenderer, DocumentingHTMLRenderer
from djangorestframework.parsers import JSONParser, YAMLParser from djangorestframework.parsers import JSONParser, YAMLParser, XMLParser
from djangorestframework.mixins import ResponseMixin from djangorestframework.mixins import ResponseMixin
from djangorestframework.response import Response from djangorestframework.response import Response
@ -283,72 +283,6 @@ if YAMLRenderer:
self.assertEquals(obj, data) self.assertEquals(obj, data)
class XMLRendererTestCase(TestCase):
"""
Tests specific to the XML Renderer
"""
def test_render_string(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': 'astring'}, 'application/xml')
self.assertXMLContains(content, '<field>astring</field>')
def test_render_integer(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': 111}, 'application/xml')
self.assertXMLContains(content, '<field>111</field>')
def test_render_datetime(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({
'field': datetime.datetime(2011, 12, 25, 12, 45, 00)
}, 'application/xml')
self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>')
def test_render_float(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': 123.4}, 'application/xml')
self.assertXMLContains(content, '<field>123.4</field>')
def test_render_decimal(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': Decimal('111.2')}, 'application/xml')
self.assertXMLContains(content, '<field>111.2</field>')
def test_render_none(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': None}, 'application/xml')
self.assertXMLContains(content, '<field></field>')
def assertXMLContains(self, xml, string):
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
self.assertTrue(xml.endswith('</root>'))
self.assertTrue(string in xml, '%r not in %r' % (string, xml))
class HTMLView(View):
renderers = (DocumentingHTMLRenderer)
def get(self, request, **kwargs):
return 'text'
urlpatterns += patterns('', urlpatterns += patterns('',
url(r'^/html$', HTMLView.as_view()), url(r'^/html$', HTMLView.as_view()),
) )
@ -429,6 +363,21 @@ class XMLRendererTestCase(TestCase):
Tests specific to the XML Renderer Tests specific to the XML Renderer
""" """
_complex_data = {
"creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
"name": "name",
"sub_data_list": [
{
"sub_id": 1,
"sub_name": "first"
},
{
"sub_id": 2,
"sub_name": "second"
}
]
}
def test_render_string(self): def test_render_string(self):
""" """
Test XML rendering. Test XML rendering.
@ -479,29 +428,28 @@ class XMLRendererTestCase(TestCase):
content = renderer.render({'field': None}, 'application/xml') content = renderer.render({'field': None}, 'application/xml')
self.assertXMLContains(content, '<field></field>') self.assertXMLContains(content, '<field></field>')
def test_render_complex_data(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render(self._complex_data, 'application/xml')
self.assertXMLContains(content, '<sub_name>first</sub_name>')
self.assertXMLContains(content, '<sub_name>second</sub_name>')
def test_render_and_parse_complex_data(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = StringIO(renderer.render(self._complex_data, 'application/xml'))
parser = XMLParser(None)
complex_data_out, dummy = parser.parse(content)
error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out))
self.assertDictEqual(self._complex_data, complex_data_out, error_msg)
def assertXMLContains(self, xml, string): def assertXMLContains(self, xml, string):
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>')) self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
self.assertTrue(xml.endswith('</root>')) self.assertTrue(xml.endswith('</root>'))
self.assertTrue(string in xml, '%r not in %r' % (string, xml)) self.assertTrue(string in xml, '%r not in %r' % (string, xml))
class Issue122Tests(TestCase):
"""
Tests that covers #122.
"""
urls = 'djangorestframework.tests.renderers'
def test_only_html_renderer(self):
"""
Test if no recursion occurs.
"""
resp = self.client.get('/html')
def test_html_renderer_is_first(self):
"""
Test if no recursion occurs.
"""
resp = self.client.get('/html1')