From 06ac5209acff6b5e7768534d1da983e8cf8d6e94 Mon Sep 17 00:00:00 2001 From: Craig Blaszczyk Date: Sun, 28 Aug 2011 22:54:02 +0100 Subject: [PATCH] initial XML parsing checkin --- djangorestframework/mixins.py | 1 - djangorestframework/parsers.py | 62 +++++++++++++++++++++- djangorestframework/tests/__init__.py | 1 + djangorestframework/tests/parsers.py | 29 +++++++++++ djangorestframework/tests/renderers.py | 71 ++++++++++++++++++++++++-- djangorestframework/utils/__init__.py | 4 +- 6 files changed, 162 insertions(+), 6 deletions(-) diff --git a/djangorestframework/mixins.py b/djangorestframework/mixins.py index 9fed61221..9b238d306 100644 --- a/djangorestframework/mixins.py +++ b/djangorestframework/mixins.py @@ -185,7 +185,6 @@ class RequestMixin(object): return (None, None) parsers = as_tuple(self.parsers) - for parser_cls in parsers: parser = parser_cls(self) if parser.can_handle_request(content_type): diff --git a/djangorestframework/parsers.py b/djangorestframework/parsers.py index 2ff64bd3e..235017fdb 100644 --- a/djangorestframework/parsers.py +++ b/djangorestframework/parsers.py @@ -19,6 +19,9 @@ from djangorestframework import status from djangorestframework.compat import yaml from djangorestframework.response import ErrorResponse from djangorestframework.utils.mediatypes import media_type_matches +from xml.etree import ElementTree as ET +import datetime +import decimal __all__ = ( @@ -28,6 +31,7 @@ __all__ = ( 'FormParser', 'MultiPartParser', 'YAMLParser', + 'XMLParser', ) @@ -167,10 +171,66 @@ class MultiPartParser(BaseParser): raise ErrorResponse(status.HTTP_400_BAD_REQUEST, {'detail': 'multipart parse error - %s' % unicode(exc)}) return django_parser.parse() + + + +class XMLParser(BaseParser): + """ + XML parser. + """ + + media_type = 'application/xml' + + def parse(self, stream): + """ + Returns a 2-tuple of `(data, files)`. + + `data` will simply be a string representing the body of the request. + `files` will always be `None`. + """ + data = {} + tree = ET.parse(stream) + for child in tree.getroot().getchildren(): + data[child.tag] = self._type_convert(child.text) + + return (data, None) + + def _type_convert(self, value): + """ + Converts the value returned by the XMl parse into the equivalent + Python type + """ + if value is None: + return value + + try: + return datetime.datetime.strptime(value,'%Y-%m-%d %H:%M:%S') + except ValueError: + pass + + try: + return int(value) + except ValueError: + pass + + + try: + return decimal.Decimal(value) + except decimal.InvalidOperation: + pass + + return value + + DEFAULT_PARSERS = ( JSONParser, FormParser, - MultiPartParser ) + MultiPartParser, + XMLParser + ) if YAMLParser: DEFAULT_PARSERS += ( YAMLParser, ) + + + diff --git a/djangorestframework/tests/__init__.py b/djangorestframework/tests/__init__.py index f664c5c12..65d8ec43c 100644 --- a/djangorestframework/tests/__init__.py +++ b/djangorestframework/tests/__init__.py @@ -11,3 +11,4 @@ for module in modules: exec("from djangorestframework.tests.%s import *" % module) __test__[module] = module_doc or "" +print 'TestXMLParser' in locals().keys() diff --git a/djangorestframework/tests/parsers.py b/djangorestframework/tests/parsers.py index deba688e5..08b963304 100644 --- a/djangorestframework/tests/parsers.py +++ b/djangorestframework/tests/parsers.py @@ -136,6 +136,8 @@ from cgi import parse_qs from django import forms from django.test import TestCase from djangorestframework.parsers import FormParser +from djangorestframework.parsers import XMLParser +import datetime class Form(forms.Form): field1 = forms.CharField(max_length=3) @@ -153,3 +155,30 @@ class TestFormParser(TestCase): (data, files) = parser.parse(stream) self.assertEqual(Form(data).is_valid(), True) + + +class TestXMLParser(TestCase): + def setUp(self): + self.input = StringIO( + '' + '' + '121.0' + 'dasd' + '' + '2011-12-25 12:45:00' + '' + ) + self.data = { + 'field_a': 121, + 'field_b': 'dasd', + 'field_c': None, + 'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00) + + } + + def test_parse(self): + parser = XMLParser(None) + + (data, files) = parser.parse(self.input) + self.assertEqual(data, self.data) + diff --git a/djangorestframework/tests/renderers.py b/djangorestframework/tests/renderers.py index d2046212f..9ae43cdda 100644 --- a/djangorestframework/tests/renderers.py +++ b/djangorestframework/tests/renderers.py @@ -4,13 +4,16 @@ from django.test import TestCase from djangorestframework import status from djangorestframework.compat import View as DjangoView -from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer -from djangorestframework.parsers import JSONParser, YAMLParser +from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer,\ + XMLRenderer +from djangorestframework.parsers import JSONParser, YAMLParser, XMLParser from djangorestframework.mixins import ResponseMixin from djangorestframework.response import Response from djangorestframework.utils.mediatypes import add_media_type_param from StringIO import StringIO +import datetime +from decimal import Decimal DUMMYSTATUS = status.HTTP_200_OK DUMMYCONTENT = 'dummycontent' @@ -223,4 +226,66 @@ if YAMLRenderer: content = renderer.render(obj, 'application/yaml') (data, files) = parser.parse(StringIO(content)) - self.assertEquals(obj, data) \ No newline at end of file + self.assertEquals(obj, data) + + +class XMLRendererTestCase(TestCase): + """ + Tests specific to the JSON Renderer + """ + + def test_render_string(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': 'astring'}, 'application/xml') + self.assertXMLContains(content, 'astring') + + def test_render_integer(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': 111}, 'application/xml') + self.assertXMLContains(content, '111') + + def test_render_datetime(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({ + 'field': datetime.datetime(2011, 12, 25, 12, 45, 00) + }, 'application/xml') + self.assertXMLContains(content, '2011-12-25 12:45:00') + + def test_render_float(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': 123.4}, 'application/xml') + self.assertXMLContains(content, '123.4') + + def test_render_decimal(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': Decimal('111.2')}, 'application/xml') + self.assertXMLContains(content, '111.2') + + def test_render_none(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': None}, 'application/xml') + self.assertXMLContains(content, '') + + + def assertXMLContains(self, xml, string): + self.assertTrue(xml.startswith('\n')) + self.assertTrue(xml.endswith('')) + self.assertTrue(string in xml, '%r not in %r' % (string, xml)) \ No newline at end of file diff --git a/djangorestframework/utils/__init__.py b/djangorestframework/utils/__init__.py index 99f9724ce..1f1a08661 100644 --- a/djangorestframework/utils/__init__.py +++ b/djangorestframework/utils/__init__.py @@ -150,7 +150,9 @@ class XMLRenderer(): xml.startElement(key, {}) self._to_xml(xml, value) xml.endElement(key) - + elif data is None: + # Don't output any value + pass else: xml.characters(smart_unicode(data))