This commit is contained in:
GitHub Merge Button 2011-12-09 03:28:59 -08:00
commit 8fe06939f2
13 changed files with 397 additions and 31 deletions

View File

@ -1,3 +1,64 @@
__version__ = '0.2.3'
VERSION = __version__ # synonym
from djangorestframework.builtins import DjangoRestFrameworkSite
from django.utils.importlib import import_module
import imp
__all__ = ('autodiscover','site', '__version__', 'VERSION')
api = DjangoRestFrameworkSite()
# A flag to tell us if autodiscover is running. autodiscover will set this to
# True while running, and False when it finishes.
LOADING = False
def autodiscover():
"""
Auto-discover INSTALLED_APPS api.py modules and fail silently when
not present. This forces an import on them to register any api bits they
may want.
"""
# Bail out if autodiscover didn't finish loading from a previous call so
# that we avoid running autodiscover again when the URLconf is loaded by
# the exception handler to resolve the handler500 view. This prevents an
# admin.py module with errors from re-registering models and raising a
# spurious AlreadyRegistered exception.
global LOADING
if LOADING:
return
LOADING = True
from django.conf import settings
for app in settings.INSTALLED_APPS:
# For each app, we need to look for a api.py inside that
# app's package. We can't use os.path here -- recall that modules may be
# imported different ways (think zip files) -- so we need to get
# the app's __path__ and look for api.py on that path.
# Step 1: find out the app's __path__ Import errors here will (and
# should) bubble up, but a missing __path__ (which is legal, but weird)
# fails silently -- apps that do weird things with __path__ might
# need to roll their own api registration.
try:
app_path = import_module(app).__path__
except (AttributeError, ImportError):
continue
# Step 2: use imp.find_module to find the app's gargoyle_conditions.py.
# For some # reason imp.find_module raises ImportError if the app can't
# be found # but doesn't actually try to import the module. So skip this
# app if its gargoyle.py doesn't exist
try:
imp.find_module('api', app_path)
except ImportError:
continue
import_module("%s.api" % app)
# autodiscover was successful, reset loading flag.
LOADING = False

View File

@ -82,6 +82,7 @@ class UserLoggedInAuthentication(BaseAuthentication):
"""
Use Django's session framework for authentication.
"""
check_csrf = True
def authenticate(self, request):
"""
@ -91,7 +92,7 @@ class UserLoggedInAuthentication(BaseAuthentication):
# TODO: Switch this back to request.POST, and let FormParser/MultiPartParser deal with the consequences.
if getattr(request, 'user', None) and request.user.is_active:
# If this is a POST request we enforce CSRF validation.
if request.method.upper() == 'POST':
if request.method.upper() == 'POST' and self.check_csrf:
# Temporarily replace request.POST with .DATA,
# so that we use our more generic request parsing
request._post = self.view.DATA

View File

@ -0,0 +1,107 @@
from django.conf.urls.defaults import patterns, url, include
from django.views.decorators.csrf import csrf_exempt
class ApiEntry(object):
def __init__(self, resource, view, prefix, resource_name):
self.resource, self.view = resource, view
self.prefix, self.resource_name = prefix, resource_name
if self.prefix is None:
self.prefix = ''
def get_urls(self):
from djangorestframework.mixins import ListModelMixin, InstanceMixin
from django.conf.urls.defaults import patterns, url
if self.prefix == '':
url_prefix = ''
else:
url_prefix = self.prefix + '/'
if issubclass(self.view, ListModelMixin):
urlpatterns = patterns('',
url(r'^%s%s/$' % (url_prefix, self.resource_name),
self.view.as_view(resource=self.resource),
name=self.resource_name,
)
)
elif issubclass(self.view, InstanceMixin):
urlpatterns = patterns('',
url(r'^%s%s/(?P<pk>[0-9a-zA-Z]+)/$' % (url_prefix, self.resource_name),
self.view.as_view(resource=self.resource),
name=self.resource_name + '_change',
)
)
return urlpatterns
def urls(self):
return self.get_urls(), 'api', self.prefix
urls = property(urls)
class DjangoRestFrameworkSite(object):
app_name = 'api'
name = 'api'
def __init__(self, *args, **kwargs):
self._registry = {}
super(DjangoRestFrameworkSite, self).__init__(*args, **kwargs)
def register(self, view, resource, prefix=None, resource_name=None):
if resource_name is None:
if hasattr(resource, 'model'):
resource_name = resource.model.__name__.lower()
else:
resource_name = resource.__name__.lower()
resource.resource_name = resource_name
if prefix not in self._registry:
self._registry[prefix] = {}
if resource_name not in self._registry[prefix]:
self._registry[prefix][resource_name] = []
api_entry = ApiEntry(resource, view, prefix, resource_name)
self._registry[prefix][resource_name].append(api_entry)
# def unregister(self, prefix=None, resource_name=None, resource=None):
# """
# Unregisters a resource.
# """
# if resource_name is None and resource is not None and \
# hasattr(resource, 'model'):
# resource_name = resource.model.__name__.lower()
#
# if resource_name is None:
# # do nothing
# return
#
# prefix_registry = self._registry.get(prefix, {})
# if resource_name in prefix_registry:
# del prefix_registry[resource_name]
@property
def urls(self):
return self.get_urls(), self.app_name, self.name
def get_urls(self):
# Site-wide views.
urlpatterns = patterns('',
)
# Add in each resource's views.
for prefix in self._registry.keys():
for resource_name in self._registry[prefix].keys():
for api_entry in self._registry[prefix][resource_name]:
urlpatterns += patterns('',
url(r'^', include(api_entry.urls))
)
return urlpatterns

View File

@ -185,7 +185,6 @@ class RequestMixin(object):
return (None, None)
parsers = as_tuple(self.parsers)
for parser_cls in parsers:
parser = parser_cls(self)
if parser.can_handle_request(content_type):
@ -387,6 +386,7 @@ class AuthMixin(object):
user = self.user
for permission_cls in self.permissions:
permission = permission_cls(self)
permission.request = self.request
permission.check_permission(user)
@ -570,7 +570,6 @@ class UpdateModelMixin(object):
else:
# Otherwise assume the kwargs uniquely identify the model
self.model_instance = model.objects.get(**kwargs)
for (key, val) in self.CONTENT.items():
setattr(self.model_instance, key, val)
except model.DoesNotExist:
@ -606,7 +605,6 @@ class ListModelMixin(object):
"""
Behavior to list a set of `model` instances on GET requests
"""
# NB. Not obvious to me if it would be better to set this on the resource?
#
# Presumably it's more useful to have on the view, because that way you can

View File

@ -19,6 +19,9 @@ from djangorestframework import status
from djangorestframework.compat import yaml
from djangorestframework.response import ErrorResponse
from djangorestframework.utils.mediatypes import media_type_matches
from xml.etree import ElementTree as ET
import datetime
import decimal
__all__ = (
@ -28,6 +31,7 @@ __all__ = (
'FormParser',
'MultiPartParser',
'YAMLParser',
'XMLParser',
)
@ -167,10 +171,66 @@ class MultiPartParser(BaseParser):
raise ErrorResponse(status.HTTP_400_BAD_REQUEST,
{'detail': 'multipart parse error - %s' % unicode(exc)})
return django_parser.parse()
class XMLParser(BaseParser):
"""
XML parser.
"""
media_type = 'application/xml'
def parse(self, stream):
"""
Returns a 2-tuple of `(data, files)`.
`data` will simply be a string representing the body of the request.
`files` will always be `None`.
"""
data = {}
tree = ET.parse(stream)
for child in tree.getroot().getchildren():
data[child.tag] = self._type_convert(child.text)
return (data, None)
def _type_convert(self, value):
"""
Converts the value returned by the XMl parse into the equivalent
Python type
"""
if value is None:
return value
try:
return datetime.datetime.strptime(value,'%Y-%m-%d %H:%M:%S')
except ValueError:
pass
try:
return int(value)
except ValueError:
pass
try:
return decimal.Decimal(value)
except decimal.InvalidOperation:
pass
return value
DEFAULT_PARSERS = ( JSONParser,
FormParser,
MultiPartParser )
MultiPartParser,
XMLParser
)
if YAMLParser:
DEFAULT_PARSERS += ( YAMLParser, )

View File

@ -26,6 +26,14 @@ _403_FORBIDDEN_RESPONSE = ErrorResponse(
{'detail': 'You do not have permission to access this resource. ' +
'You may need to login or otherwise authenticate the request.'})
_403_NOT_LOGGED_IN_RESPONSE = ErrorResponse(
status.HTTP_403_FORBIDDEN,
{'detail': 'You need to login before you can access this resource.'})
_403_PERMISSION_DENIED_RESPONSE = ErrorResponse(
status.HTTP_403_FORBIDDEN,
{'detail': 'You do not have permission to access this resource.'})
_503_SERVICE_UNAVAILABLE = ErrorResponse(
status.HTTP_503_SERVICE_UNAVAILABLE,
{'detail': 'request was throttled'})
@ -64,7 +72,7 @@ class IsAuthenticated(BasePermission):
def check_permission(self, user):
if not user.is_authenticated():
raise _403_FORBIDDEN_RESPONSE
raise _403_NOT_LOGGED_IN_RESPONSE
class IsAdminUser(BasePermission):

View File

@ -109,7 +109,6 @@ class JSONRenderer(BaseRenderer):
sort_keys = True
except (ValueError, TypeError):
indent = None
return json.dumps(obj, cls=DateTimeAwareJSONEncoder, indent=indent, sort_keys=sort_keys)

View File

@ -212,9 +212,14 @@ class FormResource(Resource):
return None
if data is not None or files is not None:
return form(data, files)
return form()
form_ = form(data=data, files=files)
else:
form_ = form()
if hasattr(self.view, 'request'):
form_.request = self.view.request
return form_
@ -279,13 +284,13 @@ class ModelResource(FormResource):
The list of extra fields to include. This is only used if :attr:`fields` is not set.
"""
def __init__(self, view=None, depth=None, stack=[], **kwargs):
def __init__(self, view):
"""
Allow :attr:`form` and :attr:`model` attributes set on the
:class:`View` to override the :attr:`form` and :attr:`model`
attributes set on the :class:`Resource`.
"""
super(ModelResource, self).__init__(view, depth, stack, **kwargs)
super(ModelResource, self).__init__(view)
self.model = getattr(view, 'model', None) or self.model
@ -333,11 +338,17 @@ class ModelResource(FormResource):
if data is not None or files is not None:
if issubclass(form, forms.ModelForm) and hasattr(self.view, 'model_instance'):
# Bound to an existing model instance
return form(data, files, instance=self.view.model_instance)
form_ = form(data=data, files=files, instance=self.view.model_instance)
else:
return form(data, files)
form_ = form(data=data, files=files)
return form()
else:
form_ = form()
if hasattr(self.view, 'request'):
form_.request = self.view.request
return form_
def url(self, instance):
@ -355,7 +366,7 @@ class ModelResource(FormResource):
# dis does teh magicks...
urlconf = get_urlconf()
resolver = get_resolver(urlconf)
possibilities = resolver.reverse_dict.getlist(self.view_callable[0])
for tuple_item in possibilities:
possibility = tuple_item[0]
@ -379,6 +390,18 @@ class ModelResource(FormResource):
return reverse(self.view_callable[0], kwargs=instance_attrs)
except NoReverseMatch:
pass
try:
if hasattr(self, 'resource_name'):
resource_name = self.resource_name
else:
resource_name = instance.__class__.__name__.split('.')[0].lower()
return reverse(
'%s:%s_change' % ('api', resource_name), args=(instance.pk,)
)
except NoReverseMatch:
pass
raise _SkipField

View File

@ -144,7 +144,10 @@ class Serializer(object):
def get_related_serializer(self, key):
info = _fields_to_dict(self.fields).get(key, None)
fields = _fields_to_dict(self.fields)
fields.update(_fields_to_dict(self.include))
info = fields.get(key, None)
# If an element in `fields` is a 2-tuple of (str, tuple)
# then the second element of the tuple is the fields to

View File

@ -1,13 +1,23 @@
"""Force import of all modules in this package in order to get the standard test runner to pick up the tests. Yowzers."""
import os
from django.conf import settings
modules = [filename.rsplit('.', 1)[0]
for filename in os.listdir(os.path.dirname(__file__))
if filename.endswith('.py') and not filename.startswith('_')]
__test__ = dict()
for module in modules:
exec("from djangorestframework.tests.%s import __doc__ as module_doc" % module)
exec("from djangorestframework.tests.%s import *" % module)
__test__[module] = module_doc or ""
# Try importing all tests if asked for (then we can run 'em)
try:
skiptest = settings.SKIP_DJANGORESTFRAMEWORK_TESTS
except:
skiptest = False
if not skiptest:
import os
modules = [filename.rsplit('.', 1)[0]
for filename in os.listdir(os.path.dirname(__file__))
if filename.endswith('.py') and not filename.startswith('_')]
__test__ = dict()
for module in modules:
exec("from djangorestframework.tests.%s import __doc__ as module_doc" % module)
exec("from djangorestframework.tests.%s import *" % module)
__test__[module] = module_doc or ""
print 'TestXMLParser' in locals().keys()

View File

@ -136,6 +136,8 @@ from cgi import parse_qs
from django import forms
from django.test import TestCase
from djangorestframework.parsers import FormParser
from djangorestframework.parsers import XMLParser
import datetime
class Form(forms.Form):
field1 = forms.CharField(max_length=3)
@ -153,3 +155,30 @@ class TestFormParser(TestCase):
(data, files) = parser.parse(stream)
self.assertEqual(Form(data).is_valid(), True)
class TestXMLParser(TestCase):
def setUp(self):
self.input = StringIO(
'<?xml version="1.0" encoding="utf-8"?>'
'<root>'
'<field_a>121.0</field_a>'
'<field_b>dasd</field_b>'
'<field_c></field_c>'
'<field_d>2011-12-25 12:45:00</field_d>'
'</root>'
)
self.data = {
'field_a': 121,
'field_b': 'dasd',
'field_c': None,
'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00)
}
def test_parse(self):
parser = XMLParser(None)
(data, files) = parser.parse(self.input)
self.assertEqual(data, self.data)

View File

@ -4,13 +4,16 @@ from django.test import TestCase
from djangorestframework import status
from djangorestframework.compat import View as DjangoView
from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer
from djangorestframework.parsers import JSONParser, YAMLParser
from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer,\
XMLRenderer
from djangorestframework.parsers import JSONParser, YAMLParser, XMLParser
from djangorestframework.mixins import ResponseMixin
from djangorestframework.response import Response
from djangorestframework.utils.mediatypes import add_media_type_param
from StringIO import StringIO
import datetime
from decimal import Decimal
DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent'
@ -223,4 +226,66 @@ if YAMLRenderer:
content = renderer.render(obj, 'application/yaml')
(data, files) = parser.parse(StringIO(content))
self.assertEquals(obj, data)
self.assertEquals(obj, data)
class XMLRendererTestCase(TestCase):
"""
Tests specific to the JSON Renderer
"""
def test_render_string(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': 'astring'}, 'application/xml')
self.assertXMLContains(content, '<field>astring</field>')
def test_render_integer(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': 111}, 'application/xml')
self.assertXMLContains(content, '<field>111</field>')
def test_render_datetime(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({
'field': datetime.datetime(2011, 12, 25, 12, 45, 00)
}, 'application/xml')
self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>')
def test_render_float(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': 123.4}, 'application/xml')
self.assertXMLContains(content, '<field>123.4</field>')
def test_render_decimal(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': Decimal('111.2')}, 'application/xml')
self.assertXMLContains(content, '<field>111.2</field>')
def test_render_none(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': None}, 'application/xml')
self.assertXMLContains(content, '<field></field>')
def assertXMLContains(self, xml, string):
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
self.assertTrue(xml.endswith('</root>'))
self.assertTrue(string in xml, '%r not in %r' % (string, xml))

View File

@ -150,7 +150,9 @@ class XMLRenderer():
xml.startElement(key, {})
self._to_xml(xml, value)
xml.endElement(key)
elif data is None:
# Don't output any value
pass
else:
xml.characters(smart_unicode(data))