mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-25 23:50:01 +03:00
Merge 6261d8c02b
into 59afd87cd4
This commit is contained in:
commit
8fe06939f2
|
@ -1,3 +1,64 @@
|
||||||
__version__ = '0.2.3'
|
__version__ = '0.2.3'
|
||||||
|
|
||||||
VERSION = __version__ # synonym
|
VERSION = __version__ # synonym
|
||||||
|
|
||||||
|
from djangorestframework.builtins import DjangoRestFrameworkSite
|
||||||
|
from django.utils.importlib import import_module
|
||||||
|
|
||||||
|
import imp
|
||||||
|
|
||||||
|
__all__ = ('autodiscover','site', '__version__', 'VERSION')
|
||||||
|
|
||||||
|
api = DjangoRestFrameworkSite()
|
||||||
|
|
||||||
|
# A flag to tell us if autodiscover is running. autodiscover will set this to
|
||||||
|
# True while running, and False when it finishes.
|
||||||
|
LOADING = False
|
||||||
|
|
||||||
|
def autodiscover():
|
||||||
|
"""
|
||||||
|
Auto-discover INSTALLED_APPS api.py modules and fail silently when
|
||||||
|
not present. This forces an import on them to register any api bits they
|
||||||
|
may want.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Bail out if autodiscover didn't finish loading from a previous call so
|
||||||
|
# that we avoid running autodiscover again when the URLconf is loaded by
|
||||||
|
# the exception handler to resolve the handler500 view. This prevents an
|
||||||
|
# admin.py module with errors from re-registering models and raising a
|
||||||
|
# spurious AlreadyRegistered exception.
|
||||||
|
global LOADING
|
||||||
|
if LOADING:
|
||||||
|
return
|
||||||
|
LOADING = True
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
|
||||||
|
for app in settings.INSTALLED_APPS:
|
||||||
|
# For each app, we need to look for a api.py inside that
|
||||||
|
# app's package. We can't use os.path here -- recall that modules may be
|
||||||
|
# imported different ways (think zip files) -- so we need to get
|
||||||
|
# the app's __path__ and look for api.py on that path.
|
||||||
|
|
||||||
|
# Step 1: find out the app's __path__ Import errors here will (and
|
||||||
|
# should) bubble up, but a missing __path__ (which is legal, but weird)
|
||||||
|
# fails silently -- apps that do weird things with __path__ might
|
||||||
|
# need to roll their own api registration.
|
||||||
|
try:
|
||||||
|
app_path = import_module(app).__path__
|
||||||
|
except (AttributeError, ImportError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Step 2: use imp.find_module to find the app's gargoyle_conditions.py.
|
||||||
|
# For some # reason imp.find_module raises ImportError if the app can't
|
||||||
|
# be found # but doesn't actually try to import the module. So skip this
|
||||||
|
# app if its gargoyle.py doesn't exist
|
||||||
|
try:
|
||||||
|
imp.find_module('api', app_path)
|
||||||
|
except ImportError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
import_module("%s.api" % app)
|
||||||
|
|
||||||
|
# autodiscover was successful, reset loading flag.
|
||||||
|
LOADING = False
|
||||||
|
|
|
@ -82,6 +82,7 @@ class UserLoggedInAuthentication(BaseAuthentication):
|
||||||
"""
|
"""
|
||||||
Use Django's session framework for authentication.
|
Use Django's session framework for authentication.
|
||||||
"""
|
"""
|
||||||
|
check_csrf = True
|
||||||
|
|
||||||
def authenticate(self, request):
|
def authenticate(self, request):
|
||||||
"""
|
"""
|
||||||
|
@ -91,7 +92,7 @@ class UserLoggedInAuthentication(BaseAuthentication):
|
||||||
# TODO: Switch this back to request.POST, and let FormParser/MultiPartParser deal with the consequences.
|
# TODO: Switch this back to request.POST, and let FormParser/MultiPartParser deal with the consequences.
|
||||||
if getattr(request, 'user', None) and request.user.is_active:
|
if getattr(request, 'user', None) and request.user.is_active:
|
||||||
# If this is a POST request we enforce CSRF validation.
|
# If this is a POST request we enforce CSRF validation.
|
||||||
if request.method.upper() == 'POST':
|
if request.method.upper() == 'POST' and self.check_csrf:
|
||||||
# Temporarily replace request.POST with .DATA,
|
# Temporarily replace request.POST with .DATA,
|
||||||
# so that we use our more generic request parsing
|
# so that we use our more generic request parsing
|
||||||
request._post = self.view.DATA
|
request._post = self.view.DATA
|
||||||
|
|
107
djangorestframework/builtins.py
Normal file
107
djangorestframework/builtins.py
Normal file
|
@ -0,0 +1,107 @@
|
||||||
|
|
||||||
|
from django.conf.urls.defaults import patterns, url, include
|
||||||
|
from django.views.decorators.csrf import csrf_exempt
|
||||||
|
|
||||||
|
class ApiEntry(object):
|
||||||
|
def __init__(self, resource, view, prefix, resource_name):
|
||||||
|
self.resource, self.view = resource, view
|
||||||
|
self.prefix, self.resource_name = prefix, resource_name
|
||||||
|
if self.prefix is None:
|
||||||
|
self.prefix = ''
|
||||||
|
|
||||||
|
def get_urls(self):
|
||||||
|
from djangorestframework.mixins import ListModelMixin, InstanceMixin
|
||||||
|
from django.conf.urls.defaults import patterns, url
|
||||||
|
|
||||||
|
if self.prefix == '':
|
||||||
|
url_prefix = ''
|
||||||
|
else:
|
||||||
|
url_prefix = self.prefix + '/'
|
||||||
|
|
||||||
|
if issubclass(self.view, ListModelMixin):
|
||||||
|
urlpatterns = patterns('',
|
||||||
|
url(r'^%s%s/$' % (url_prefix, self.resource_name),
|
||||||
|
self.view.as_view(resource=self.resource),
|
||||||
|
name=self.resource_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif issubclass(self.view, InstanceMixin):
|
||||||
|
urlpatterns = patterns('',
|
||||||
|
url(r'^%s%s/(?P<pk>[0-9a-zA-Z]+)/$' % (url_prefix, self.resource_name),
|
||||||
|
self.view.as_view(resource=self.resource),
|
||||||
|
name=self.resource_name + '_change',
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return urlpatterns
|
||||||
|
|
||||||
|
|
||||||
|
def urls(self):
|
||||||
|
return self.get_urls(), 'api', self.prefix
|
||||||
|
urls = property(urls)
|
||||||
|
|
||||||
|
class DjangoRestFrameworkSite(object):
|
||||||
|
app_name = 'api'
|
||||||
|
name = 'api'
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self._registry = {}
|
||||||
|
super(DjangoRestFrameworkSite, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def register(self, view, resource, prefix=None, resource_name=None):
|
||||||
|
if resource_name is None:
|
||||||
|
if hasattr(resource, 'model'):
|
||||||
|
resource_name = resource.model.__name__.lower()
|
||||||
|
else:
|
||||||
|
resource_name = resource.__name__.lower()
|
||||||
|
|
||||||
|
resource.resource_name = resource_name
|
||||||
|
|
||||||
|
if prefix not in self._registry:
|
||||||
|
self._registry[prefix] = {}
|
||||||
|
|
||||||
|
if resource_name not in self._registry[prefix]:
|
||||||
|
self._registry[prefix][resource_name] = []
|
||||||
|
|
||||||
|
api_entry = ApiEntry(resource, view, prefix, resource_name)
|
||||||
|
self._registry[prefix][resource_name].append(api_entry)
|
||||||
|
|
||||||
|
|
||||||
|
# def unregister(self, prefix=None, resource_name=None, resource=None):
|
||||||
|
# """
|
||||||
|
# Unregisters a resource.
|
||||||
|
# """
|
||||||
|
# if resource_name is None and resource is not None and \
|
||||||
|
# hasattr(resource, 'model'):
|
||||||
|
# resource_name = resource.model.__name__.lower()
|
||||||
|
#
|
||||||
|
# if resource_name is None:
|
||||||
|
# # do nothing
|
||||||
|
# return
|
||||||
|
#
|
||||||
|
# prefix_registry = self._registry.get(prefix, {})
|
||||||
|
# if resource_name in prefix_registry:
|
||||||
|
# del prefix_registry[resource_name]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def urls(self):
|
||||||
|
return self.get_urls(), self.app_name, self.name
|
||||||
|
|
||||||
|
def get_urls(self):
|
||||||
|
|
||||||
|
# Site-wide views.
|
||||||
|
urlpatterns = patterns('',
|
||||||
|
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add in each resource's views.
|
||||||
|
for prefix in self._registry.keys():
|
||||||
|
for resource_name in self._registry[prefix].keys():
|
||||||
|
for api_entry in self._registry[prefix][resource_name]:
|
||||||
|
urlpatterns += patterns('',
|
||||||
|
url(r'^', include(api_entry.urls))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
return urlpatterns
|
||||||
|
|
|
@ -185,7 +185,6 @@ class RequestMixin(object):
|
||||||
return (None, None)
|
return (None, None)
|
||||||
|
|
||||||
parsers = as_tuple(self.parsers)
|
parsers = as_tuple(self.parsers)
|
||||||
|
|
||||||
for parser_cls in parsers:
|
for parser_cls in parsers:
|
||||||
parser = parser_cls(self)
|
parser = parser_cls(self)
|
||||||
if parser.can_handle_request(content_type):
|
if parser.can_handle_request(content_type):
|
||||||
|
@ -387,6 +386,7 @@ class AuthMixin(object):
|
||||||
user = self.user
|
user = self.user
|
||||||
for permission_cls in self.permissions:
|
for permission_cls in self.permissions:
|
||||||
permission = permission_cls(self)
|
permission = permission_cls(self)
|
||||||
|
permission.request = self.request
|
||||||
permission.check_permission(user)
|
permission.check_permission(user)
|
||||||
|
|
||||||
|
|
||||||
|
@ -570,7 +570,6 @@ class UpdateModelMixin(object):
|
||||||
else:
|
else:
|
||||||
# Otherwise assume the kwargs uniquely identify the model
|
# Otherwise assume the kwargs uniquely identify the model
|
||||||
self.model_instance = model.objects.get(**kwargs)
|
self.model_instance = model.objects.get(**kwargs)
|
||||||
|
|
||||||
for (key, val) in self.CONTENT.items():
|
for (key, val) in self.CONTENT.items():
|
||||||
setattr(self.model_instance, key, val)
|
setattr(self.model_instance, key, val)
|
||||||
except model.DoesNotExist:
|
except model.DoesNotExist:
|
||||||
|
@ -606,7 +605,6 @@ class ListModelMixin(object):
|
||||||
"""
|
"""
|
||||||
Behavior to list a set of `model` instances on GET requests
|
Behavior to list a set of `model` instances on GET requests
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# NB. Not obvious to me if it would be better to set this on the resource?
|
# NB. Not obvious to me if it would be better to set this on the resource?
|
||||||
#
|
#
|
||||||
# Presumably it's more useful to have on the view, because that way you can
|
# Presumably it's more useful to have on the view, because that way you can
|
||||||
|
|
|
@ -19,6 +19,9 @@ from djangorestframework import status
|
||||||
from djangorestframework.compat import yaml
|
from djangorestframework.compat import yaml
|
||||||
from djangorestframework.response import ErrorResponse
|
from djangorestframework.response import ErrorResponse
|
||||||
from djangorestframework.utils.mediatypes import media_type_matches
|
from djangorestframework.utils.mediatypes import media_type_matches
|
||||||
|
from xml.etree import ElementTree as ET
|
||||||
|
import datetime
|
||||||
|
import decimal
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
|
@ -28,6 +31,7 @@ __all__ = (
|
||||||
'FormParser',
|
'FormParser',
|
||||||
'MultiPartParser',
|
'MultiPartParser',
|
||||||
'YAMLParser',
|
'YAMLParser',
|
||||||
|
'XMLParser',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -167,10 +171,66 @@ class MultiPartParser(BaseParser):
|
||||||
raise ErrorResponse(status.HTTP_400_BAD_REQUEST,
|
raise ErrorResponse(status.HTTP_400_BAD_REQUEST,
|
||||||
{'detail': 'multipart parse error - %s' % unicode(exc)})
|
{'detail': 'multipart parse error - %s' % unicode(exc)})
|
||||||
return django_parser.parse()
|
return django_parser.parse()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class XMLParser(BaseParser):
|
||||||
|
"""
|
||||||
|
XML parser.
|
||||||
|
"""
|
||||||
|
|
||||||
|
media_type = 'application/xml'
|
||||||
|
|
||||||
|
def parse(self, stream):
|
||||||
|
"""
|
||||||
|
Returns a 2-tuple of `(data, files)`.
|
||||||
|
|
||||||
|
`data` will simply be a string representing the body of the request.
|
||||||
|
`files` will always be `None`.
|
||||||
|
"""
|
||||||
|
data = {}
|
||||||
|
tree = ET.parse(stream)
|
||||||
|
for child in tree.getroot().getchildren():
|
||||||
|
data[child.tag] = self._type_convert(child.text)
|
||||||
|
|
||||||
|
return (data, None)
|
||||||
|
|
||||||
|
def _type_convert(self, value):
|
||||||
|
"""
|
||||||
|
Converts the value returned by the XMl parse into the equivalent
|
||||||
|
Python type
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
|
||||||
|
try:
|
||||||
|
return datetime.datetime.strptime(value,'%Y-%m-%d %H:%M:%S')
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
return decimal.Decimal(value)
|
||||||
|
except decimal.InvalidOperation:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_PARSERS = ( JSONParser,
|
DEFAULT_PARSERS = ( JSONParser,
|
||||||
FormParser,
|
FormParser,
|
||||||
MultiPartParser )
|
MultiPartParser,
|
||||||
|
XMLParser
|
||||||
|
)
|
||||||
|
|
||||||
if YAMLParser:
|
if YAMLParser:
|
||||||
DEFAULT_PARSERS += ( YAMLParser, )
|
DEFAULT_PARSERS += ( YAMLParser, )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,14 @@ _403_FORBIDDEN_RESPONSE = ErrorResponse(
|
||||||
{'detail': 'You do not have permission to access this resource. ' +
|
{'detail': 'You do not have permission to access this resource. ' +
|
||||||
'You may need to login or otherwise authenticate the request.'})
|
'You may need to login or otherwise authenticate the request.'})
|
||||||
|
|
||||||
|
_403_NOT_LOGGED_IN_RESPONSE = ErrorResponse(
|
||||||
|
status.HTTP_403_FORBIDDEN,
|
||||||
|
{'detail': 'You need to login before you can access this resource.'})
|
||||||
|
|
||||||
|
_403_PERMISSION_DENIED_RESPONSE = ErrorResponse(
|
||||||
|
status.HTTP_403_FORBIDDEN,
|
||||||
|
{'detail': 'You do not have permission to access this resource.'})
|
||||||
|
|
||||||
_503_SERVICE_UNAVAILABLE = ErrorResponse(
|
_503_SERVICE_UNAVAILABLE = ErrorResponse(
|
||||||
status.HTTP_503_SERVICE_UNAVAILABLE,
|
status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
{'detail': 'request was throttled'})
|
{'detail': 'request was throttled'})
|
||||||
|
@ -64,7 +72,7 @@ class IsAuthenticated(BasePermission):
|
||||||
|
|
||||||
def check_permission(self, user):
|
def check_permission(self, user):
|
||||||
if not user.is_authenticated():
|
if not user.is_authenticated():
|
||||||
raise _403_FORBIDDEN_RESPONSE
|
raise _403_NOT_LOGGED_IN_RESPONSE
|
||||||
|
|
||||||
|
|
||||||
class IsAdminUser(BasePermission):
|
class IsAdminUser(BasePermission):
|
||||||
|
|
|
@ -109,7 +109,6 @@ class JSONRenderer(BaseRenderer):
|
||||||
sort_keys = True
|
sort_keys = True
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
indent = None
|
indent = None
|
||||||
|
|
||||||
return json.dumps(obj, cls=DateTimeAwareJSONEncoder, indent=indent, sort_keys=sort_keys)
|
return json.dumps(obj, cls=DateTimeAwareJSONEncoder, indent=indent, sort_keys=sort_keys)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -212,9 +212,14 @@ class FormResource(Resource):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if data is not None or files is not None:
|
if data is not None or files is not None:
|
||||||
return form(data, files)
|
form_ = form(data=data, files=files)
|
||||||
|
else:
|
||||||
return form()
|
form_ = form()
|
||||||
|
|
||||||
|
if hasattr(self.view, 'request'):
|
||||||
|
form_.request = self.view.request
|
||||||
|
|
||||||
|
return form_
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -279,13 +284,13 @@ class ModelResource(FormResource):
|
||||||
The list of extra fields to include. This is only used if :attr:`fields` is not set.
|
The list of extra fields to include. This is only used if :attr:`fields` is not set.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, view=None, depth=None, stack=[], **kwargs):
|
def __init__(self, view):
|
||||||
"""
|
"""
|
||||||
Allow :attr:`form` and :attr:`model` attributes set on the
|
Allow :attr:`form` and :attr:`model` attributes set on the
|
||||||
:class:`View` to override the :attr:`form` and :attr:`model`
|
:class:`View` to override the :attr:`form` and :attr:`model`
|
||||||
attributes set on the :class:`Resource`.
|
attributes set on the :class:`Resource`.
|
||||||
"""
|
"""
|
||||||
super(ModelResource, self).__init__(view, depth, stack, **kwargs)
|
super(ModelResource, self).__init__(view)
|
||||||
|
|
||||||
self.model = getattr(view, 'model', None) or self.model
|
self.model = getattr(view, 'model', None) or self.model
|
||||||
|
|
||||||
|
@ -333,11 +338,17 @@ class ModelResource(FormResource):
|
||||||
if data is not None or files is not None:
|
if data is not None or files is not None:
|
||||||
if issubclass(form, forms.ModelForm) and hasattr(self.view, 'model_instance'):
|
if issubclass(form, forms.ModelForm) and hasattr(self.view, 'model_instance'):
|
||||||
# Bound to an existing model instance
|
# Bound to an existing model instance
|
||||||
return form(data, files, instance=self.view.model_instance)
|
form_ = form(data=data, files=files, instance=self.view.model_instance)
|
||||||
else:
|
else:
|
||||||
return form(data, files)
|
form_ = form(data=data, files=files)
|
||||||
|
|
||||||
return form()
|
else:
|
||||||
|
form_ = form()
|
||||||
|
|
||||||
|
if hasattr(self.view, 'request'):
|
||||||
|
form_.request = self.view.request
|
||||||
|
|
||||||
|
return form_
|
||||||
|
|
||||||
|
|
||||||
def url(self, instance):
|
def url(self, instance):
|
||||||
|
@ -355,7 +366,7 @@ class ModelResource(FormResource):
|
||||||
# dis does teh magicks...
|
# dis does teh magicks...
|
||||||
urlconf = get_urlconf()
|
urlconf = get_urlconf()
|
||||||
resolver = get_resolver(urlconf)
|
resolver = get_resolver(urlconf)
|
||||||
|
|
||||||
possibilities = resolver.reverse_dict.getlist(self.view_callable[0])
|
possibilities = resolver.reverse_dict.getlist(self.view_callable[0])
|
||||||
for tuple_item in possibilities:
|
for tuple_item in possibilities:
|
||||||
possibility = tuple_item[0]
|
possibility = tuple_item[0]
|
||||||
|
@ -379,6 +390,18 @@ class ModelResource(FormResource):
|
||||||
return reverse(self.view_callable[0], kwargs=instance_attrs)
|
return reverse(self.view_callable[0], kwargs=instance_attrs)
|
||||||
except NoReverseMatch:
|
except NoReverseMatch:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
if hasattr(self, 'resource_name'):
|
||||||
|
resource_name = self.resource_name
|
||||||
|
else:
|
||||||
|
resource_name = instance.__class__.__name__.split('.')[0].lower()
|
||||||
|
return reverse(
|
||||||
|
'%s:%s_change' % ('api', resource_name), args=(instance.pk,)
|
||||||
|
)
|
||||||
|
except NoReverseMatch:
|
||||||
|
pass
|
||||||
|
|
||||||
raise _SkipField
|
raise _SkipField
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -144,7 +144,10 @@ class Serializer(object):
|
||||||
|
|
||||||
|
|
||||||
def get_related_serializer(self, key):
|
def get_related_serializer(self, key):
|
||||||
info = _fields_to_dict(self.fields).get(key, None)
|
fields = _fields_to_dict(self.fields)
|
||||||
|
fields.update(_fields_to_dict(self.include))
|
||||||
|
info = fields.get(key, None)
|
||||||
|
|
||||||
|
|
||||||
# If an element in `fields` is a 2-tuple of (str, tuple)
|
# If an element in `fields` is a 2-tuple of (str, tuple)
|
||||||
# then the second element of the tuple is the fields to
|
# then the second element of the tuple is the fields to
|
||||||
|
|
|
@ -1,13 +1,23 @@
|
||||||
"""Force import of all modules in this package in order to get the standard test runner to pick up the tests. Yowzers."""
|
"""Force import of all modules in this package in order to get the standard test runner to pick up the tests. Yowzers."""
|
||||||
import os
|
from django.conf import settings
|
||||||
|
|
||||||
modules = [filename.rsplit('.', 1)[0]
|
# Try importing all tests if asked for (then we can run 'em)
|
||||||
for filename in os.listdir(os.path.dirname(__file__))
|
try:
|
||||||
if filename.endswith('.py') and not filename.startswith('_')]
|
skiptest = settings.SKIP_DJANGORESTFRAMEWORK_TESTS
|
||||||
__test__ = dict()
|
except:
|
||||||
|
skiptest = False
|
||||||
for module in modules:
|
|
||||||
exec("from djangorestframework.tests.%s import __doc__ as module_doc" % module)
|
if not skiptest:
|
||||||
exec("from djangorestframework.tests.%s import *" % module)
|
import os
|
||||||
__test__[module] = module_doc or ""
|
|
||||||
|
modules = [filename.rsplit('.', 1)[0]
|
||||||
|
for filename in os.listdir(os.path.dirname(__file__))
|
||||||
|
if filename.endswith('.py') and not filename.startswith('_')]
|
||||||
|
__test__ = dict()
|
||||||
|
|
||||||
|
for module in modules:
|
||||||
|
exec("from djangorestframework.tests.%s import __doc__ as module_doc" % module)
|
||||||
|
exec("from djangorestframework.tests.%s import *" % module)
|
||||||
|
__test__[module] = module_doc or ""
|
||||||
|
|
||||||
|
print 'TestXMLParser' in locals().keys()
|
||||||
|
|
|
@ -136,6 +136,8 @@ from cgi import parse_qs
|
||||||
from django import forms
|
from django import forms
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from djangorestframework.parsers import FormParser
|
from djangorestframework.parsers import FormParser
|
||||||
|
from djangorestframework.parsers import XMLParser
|
||||||
|
import datetime
|
||||||
|
|
||||||
class Form(forms.Form):
|
class Form(forms.Form):
|
||||||
field1 = forms.CharField(max_length=3)
|
field1 = forms.CharField(max_length=3)
|
||||||
|
@ -153,3 +155,30 @@ class TestFormParser(TestCase):
|
||||||
(data, files) = parser.parse(stream)
|
(data, files) = parser.parse(stream)
|
||||||
|
|
||||||
self.assertEqual(Form(data).is_valid(), True)
|
self.assertEqual(Form(data).is_valid(), True)
|
||||||
|
|
||||||
|
|
||||||
|
class TestXMLParser(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.input = StringIO(
|
||||||
|
'<?xml version="1.0" encoding="utf-8"?>'
|
||||||
|
'<root>'
|
||||||
|
'<field_a>121.0</field_a>'
|
||||||
|
'<field_b>dasd</field_b>'
|
||||||
|
'<field_c></field_c>'
|
||||||
|
'<field_d>2011-12-25 12:45:00</field_d>'
|
||||||
|
'</root>'
|
||||||
|
)
|
||||||
|
self.data = {
|
||||||
|
'field_a': 121,
|
||||||
|
'field_b': 'dasd',
|
||||||
|
'field_c': None,
|
||||||
|
'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_parse(self):
|
||||||
|
parser = XMLParser(None)
|
||||||
|
|
||||||
|
(data, files) = parser.parse(self.input)
|
||||||
|
self.assertEqual(data, self.data)
|
||||||
|
|
||||||
|
|
|
@ -4,13 +4,16 @@ from django.test import TestCase
|
||||||
|
|
||||||
from djangorestframework import status
|
from djangorestframework import status
|
||||||
from djangorestframework.compat import View as DjangoView
|
from djangorestframework.compat import View as DjangoView
|
||||||
from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer
|
from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer,\
|
||||||
from djangorestframework.parsers import JSONParser, YAMLParser
|
XMLRenderer
|
||||||
|
from djangorestframework.parsers import JSONParser, YAMLParser, XMLParser
|
||||||
from djangorestframework.mixins import ResponseMixin
|
from djangorestframework.mixins import ResponseMixin
|
||||||
from djangorestframework.response import Response
|
from djangorestframework.response import Response
|
||||||
from djangorestframework.utils.mediatypes import add_media_type_param
|
from djangorestframework.utils.mediatypes import add_media_type_param
|
||||||
|
|
||||||
from StringIO import StringIO
|
from StringIO import StringIO
|
||||||
|
import datetime
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
DUMMYSTATUS = status.HTTP_200_OK
|
DUMMYSTATUS = status.HTTP_200_OK
|
||||||
DUMMYCONTENT = 'dummycontent'
|
DUMMYCONTENT = 'dummycontent'
|
||||||
|
@ -223,4 +226,66 @@ if YAMLRenderer:
|
||||||
|
|
||||||
content = renderer.render(obj, 'application/yaml')
|
content = renderer.render(obj, 'application/yaml')
|
||||||
(data, files) = parser.parse(StringIO(content))
|
(data, files) = parser.parse(StringIO(content))
|
||||||
self.assertEquals(obj, data)
|
self.assertEquals(obj, data)
|
||||||
|
|
||||||
|
|
||||||
|
class XMLRendererTestCase(TestCase):
|
||||||
|
"""
|
||||||
|
Tests specific to the JSON Renderer
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_render_string(self):
|
||||||
|
"""
|
||||||
|
Test XML rendering.
|
||||||
|
"""
|
||||||
|
renderer = XMLRenderer(None)
|
||||||
|
content = renderer.render({'field': 'astring'}, 'application/xml')
|
||||||
|
self.assertXMLContains(content, '<field>astring</field>')
|
||||||
|
|
||||||
|
def test_render_integer(self):
|
||||||
|
"""
|
||||||
|
Test XML rendering.
|
||||||
|
"""
|
||||||
|
renderer = XMLRenderer(None)
|
||||||
|
content = renderer.render({'field': 111}, 'application/xml')
|
||||||
|
self.assertXMLContains(content, '<field>111</field>')
|
||||||
|
|
||||||
|
def test_render_datetime(self):
|
||||||
|
"""
|
||||||
|
Test XML rendering.
|
||||||
|
"""
|
||||||
|
renderer = XMLRenderer(None)
|
||||||
|
content = renderer.render({
|
||||||
|
'field': datetime.datetime(2011, 12, 25, 12, 45, 00)
|
||||||
|
}, 'application/xml')
|
||||||
|
self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>')
|
||||||
|
|
||||||
|
def test_render_float(self):
|
||||||
|
"""
|
||||||
|
Test XML rendering.
|
||||||
|
"""
|
||||||
|
renderer = XMLRenderer(None)
|
||||||
|
content = renderer.render({'field': 123.4}, 'application/xml')
|
||||||
|
self.assertXMLContains(content, '<field>123.4</field>')
|
||||||
|
|
||||||
|
def test_render_decimal(self):
|
||||||
|
"""
|
||||||
|
Test XML rendering.
|
||||||
|
"""
|
||||||
|
renderer = XMLRenderer(None)
|
||||||
|
content = renderer.render({'field': Decimal('111.2')}, 'application/xml')
|
||||||
|
self.assertXMLContains(content, '<field>111.2</field>')
|
||||||
|
|
||||||
|
def test_render_none(self):
|
||||||
|
"""
|
||||||
|
Test XML rendering.
|
||||||
|
"""
|
||||||
|
renderer = XMLRenderer(None)
|
||||||
|
content = renderer.render({'field': None}, 'application/xml')
|
||||||
|
self.assertXMLContains(content, '<field></field>')
|
||||||
|
|
||||||
|
|
||||||
|
def assertXMLContains(self, xml, string):
|
||||||
|
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
|
||||||
|
self.assertTrue(xml.endswith('</root>'))
|
||||||
|
self.assertTrue(string in xml, '%r not in %r' % (string, xml))
|
|
@ -150,7 +150,9 @@ class XMLRenderer():
|
||||||
xml.startElement(key, {})
|
xml.startElement(key, {})
|
||||||
self._to_xml(xml, value)
|
self._to_xml(xml, value)
|
||||||
xml.endElement(key)
|
xml.endElement(key)
|
||||||
|
elif data is None:
|
||||||
|
# Don't output any value
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
xml.characters(smart_unicode(data))
|
xml.characters(smart_unicode(data))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user