mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-24 23:19:47 +03:00
Merge 6261d8c02b
into 59afd87cd4
This commit is contained in:
commit
8fe06939f2
|
@ -1,3 +1,64 @@
|
|||
__version__ = '0.2.3'
|
||||
|
||||
VERSION = __version__ # synonym
|
||||
|
||||
from djangorestframework.builtins import DjangoRestFrameworkSite
|
||||
from django.utils.importlib import import_module
|
||||
|
||||
import imp
|
||||
|
||||
__all__ = ('autodiscover','site', '__version__', 'VERSION')
|
||||
|
||||
api = DjangoRestFrameworkSite()
|
||||
|
||||
# A flag to tell us if autodiscover is running. autodiscover will set this to
|
||||
# True while running, and False when it finishes.
|
||||
LOADING = False
|
||||
|
||||
def autodiscover():
|
||||
"""
|
||||
Auto-discover INSTALLED_APPS api.py modules and fail silently when
|
||||
not present. This forces an import on them to register any api bits they
|
||||
may want.
|
||||
"""
|
||||
|
||||
# Bail out if autodiscover didn't finish loading from a previous call so
|
||||
# that we avoid running autodiscover again when the URLconf is loaded by
|
||||
# the exception handler to resolve the handler500 view. This prevents an
|
||||
# admin.py module with errors from re-registering models and raising a
|
||||
# spurious AlreadyRegistered exception.
|
||||
global LOADING
|
||||
if LOADING:
|
||||
return
|
||||
LOADING = True
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
for app in settings.INSTALLED_APPS:
|
||||
# For each app, we need to look for a api.py inside that
|
||||
# app's package. We can't use os.path here -- recall that modules may be
|
||||
# imported different ways (think zip files) -- so we need to get
|
||||
# the app's __path__ and look for api.py on that path.
|
||||
|
||||
# Step 1: find out the app's __path__ Import errors here will (and
|
||||
# should) bubble up, but a missing __path__ (which is legal, but weird)
|
||||
# fails silently -- apps that do weird things with __path__ might
|
||||
# need to roll their own api registration.
|
||||
try:
|
||||
app_path = import_module(app).__path__
|
||||
except (AttributeError, ImportError):
|
||||
continue
|
||||
|
||||
# Step 2: use imp.find_module to find the app's gargoyle_conditions.py.
|
||||
# For some # reason imp.find_module raises ImportError if the app can't
|
||||
# be found # but doesn't actually try to import the module. So skip this
|
||||
# app if its gargoyle.py doesn't exist
|
||||
try:
|
||||
imp.find_module('api', app_path)
|
||||
except ImportError:
|
||||
continue
|
||||
|
||||
import_module("%s.api" % app)
|
||||
|
||||
# autodiscover was successful, reset loading flag.
|
||||
LOADING = False
|
||||
|
|
|
@ -82,6 +82,7 @@ class UserLoggedInAuthentication(BaseAuthentication):
|
|||
"""
|
||||
Use Django's session framework for authentication.
|
||||
"""
|
||||
check_csrf = True
|
||||
|
||||
def authenticate(self, request):
|
||||
"""
|
||||
|
@ -91,7 +92,7 @@ class UserLoggedInAuthentication(BaseAuthentication):
|
|||
# TODO: Switch this back to request.POST, and let FormParser/MultiPartParser deal with the consequences.
|
||||
if getattr(request, 'user', None) and request.user.is_active:
|
||||
# If this is a POST request we enforce CSRF validation.
|
||||
if request.method.upper() == 'POST':
|
||||
if request.method.upper() == 'POST' and self.check_csrf:
|
||||
# Temporarily replace request.POST with .DATA,
|
||||
# so that we use our more generic request parsing
|
||||
request._post = self.view.DATA
|
||||
|
|
107
djangorestframework/builtins.py
Normal file
107
djangorestframework/builtins.py
Normal file
|
@ -0,0 +1,107 @@
|
|||
|
||||
from django.conf.urls.defaults import patterns, url, include
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
|
||||
class ApiEntry(object):
|
||||
def __init__(self, resource, view, prefix, resource_name):
|
||||
self.resource, self.view = resource, view
|
||||
self.prefix, self.resource_name = prefix, resource_name
|
||||
if self.prefix is None:
|
||||
self.prefix = ''
|
||||
|
||||
def get_urls(self):
|
||||
from djangorestframework.mixins import ListModelMixin, InstanceMixin
|
||||
from django.conf.urls.defaults import patterns, url
|
||||
|
||||
if self.prefix == '':
|
||||
url_prefix = ''
|
||||
else:
|
||||
url_prefix = self.prefix + '/'
|
||||
|
||||
if issubclass(self.view, ListModelMixin):
|
||||
urlpatterns = patterns('',
|
||||
url(r'^%s%s/$' % (url_prefix, self.resource_name),
|
||||
self.view.as_view(resource=self.resource),
|
||||
name=self.resource_name,
|
||||
)
|
||||
)
|
||||
elif issubclass(self.view, InstanceMixin):
|
||||
urlpatterns = patterns('',
|
||||
url(r'^%s%s/(?P<pk>[0-9a-zA-Z]+)/$' % (url_prefix, self.resource_name),
|
||||
self.view.as_view(resource=self.resource),
|
||||
name=self.resource_name + '_change',
|
||||
)
|
||||
)
|
||||
return urlpatterns
|
||||
|
||||
|
||||
def urls(self):
|
||||
return self.get_urls(), 'api', self.prefix
|
||||
urls = property(urls)
|
||||
|
||||
class DjangoRestFrameworkSite(object):
|
||||
app_name = 'api'
|
||||
name = 'api'
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._registry = {}
|
||||
super(DjangoRestFrameworkSite, self).__init__(*args, **kwargs)
|
||||
|
||||
def register(self, view, resource, prefix=None, resource_name=None):
|
||||
if resource_name is None:
|
||||
if hasattr(resource, 'model'):
|
||||
resource_name = resource.model.__name__.lower()
|
||||
else:
|
||||
resource_name = resource.__name__.lower()
|
||||
|
||||
resource.resource_name = resource_name
|
||||
|
||||
if prefix not in self._registry:
|
||||
self._registry[prefix] = {}
|
||||
|
||||
if resource_name not in self._registry[prefix]:
|
||||
self._registry[prefix][resource_name] = []
|
||||
|
||||
api_entry = ApiEntry(resource, view, prefix, resource_name)
|
||||
self._registry[prefix][resource_name].append(api_entry)
|
||||
|
||||
|
||||
# def unregister(self, prefix=None, resource_name=None, resource=None):
|
||||
# """
|
||||
# Unregisters a resource.
|
||||
# """
|
||||
# if resource_name is None and resource is not None and \
|
||||
# hasattr(resource, 'model'):
|
||||
# resource_name = resource.model.__name__.lower()
|
||||
#
|
||||
# if resource_name is None:
|
||||
# # do nothing
|
||||
# return
|
||||
#
|
||||
# prefix_registry = self._registry.get(prefix, {})
|
||||
# if resource_name in prefix_registry:
|
||||
# del prefix_registry[resource_name]
|
||||
|
||||
@property
|
||||
def urls(self):
|
||||
return self.get_urls(), self.app_name, self.name
|
||||
|
||||
def get_urls(self):
|
||||
|
||||
# Site-wide views.
|
||||
urlpatterns = patterns('',
|
||||
|
||||
|
||||
)
|
||||
|
||||
# Add in each resource's views.
|
||||
for prefix in self._registry.keys():
|
||||
for resource_name in self._registry[prefix].keys():
|
||||
for api_entry in self._registry[prefix][resource_name]:
|
||||
urlpatterns += patterns('',
|
||||
url(r'^', include(api_entry.urls))
|
||||
)
|
||||
|
||||
|
||||
return urlpatterns
|
||||
|
|
@ -185,7 +185,6 @@ class RequestMixin(object):
|
|||
return (None, None)
|
||||
|
||||
parsers = as_tuple(self.parsers)
|
||||
|
||||
for parser_cls in parsers:
|
||||
parser = parser_cls(self)
|
||||
if parser.can_handle_request(content_type):
|
||||
|
@ -387,6 +386,7 @@ class AuthMixin(object):
|
|||
user = self.user
|
||||
for permission_cls in self.permissions:
|
||||
permission = permission_cls(self)
|
||||
permission.request = self.request
|
||||
permission.check_permission(user)
|
||||
|
||||
|
||||
|
@ -570,7 +570,6 @@ class UpdateModelMixin(object):
|
|||
else:
|
||||
# Otherwise assume the kwargs uniquely identify the model
|
||||
self.model_instance = model.objects.get(**kwargs)
|
||||
|
||||
for (key, val) in self.CONTENT.items():
|
||||
setattr(self.model_instance, key, val)
|
||||
except model.DoesNotExist:
|
||||
|
@ -606,7 +605,6 @@ class ListModelMixin(object):
|
|||
"""
|
||||
Behavior to list a set of `model` instances on GET requests
|
||||
"""
|
||||
|
||||
# NB. Not obvious to me if it would be better to set this on the resource?
|
||||
#
|
||||
# Presumably it's more useful to have on the view, because that way you can
|
||||
|
|
|
@ -19,6 +19,9 @@ from djangorestframework import status
|
|||
from djangorestframework.compat import yaml
|
||||
from djangorestframework.response import ErrorResponse
|
||||
from djangorestframework.utils.mediatypes import media_type_matches
|
||||
from xml.etree import ElementTree as ET
|
||||
import datetime
|
||||
import decimal
|
||||
|
||||
|
||||
__all__ = (
|
||||
|
@ -28,6 +31,7 @@ __all__ = (
|
|||
'FormParser',
|
||||
'MultiPartParser',
|
||||
'YAMLParser',
|
||||
'XMLParser',
|
||||
)
|
||||
|
||||
|
||||
|
@ -167,10 +171,66 @@ class MultiPartParser(BaseParser):
|
|||
raise ErrorResponse(status.HTTP_400_BAD_REQUEST,
|
||||
{'detail': 'multipart parse error - %s' % unicode(exc)})
|
||||
return django_parser.parse()
|
||||
|
||||
|
||||
|
||||
class XMLParser(BaseParser):
|
||||
"""
|
||||
XML parser.
|
||||
"""
|
||||
|
||||
media_type = 'application/xml'
|
||||
|
||||
def parse(self, stream):
|
||||
"""
|
||||
Returns a 2-tuple of `(data, files)`.
|
||||
|
||||
`data` will simply be a string representing the body of the request.
|
||||
`files` will always be `None`.
|
||||
"""
|
||||
data = {}
|
||||
tree = ET.parse(stream)
|
||||
for child in tree.getroot().getchildren():
|
||||
data[child.tag] = self._type_convert(child.text)
|
||||
|
||||
return (data, None)
|
||||
|
||||
def _type_convert(self, value):
|
||||
"""
|
||||
Converts the value returned by the XMl parse into the equivalent
|
||||
Python type
|
||||
"""
|
||||
if value is None:
|
||||
return value
|
||||
|
||||
try:
|
||||
return datetime.datetime.strptime(value,'%Y-%m-%d %H:%M:%S')
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
return decimal.Decimal(value)
|
||||
except decimal.InvalidOperation:
|
||||
pass
|
||||
|
||||
return value
|
||||
|
||||
|
||||
|
||||
DEFAULT_PARSERS = ( JSONParser,
|
||||
FormParser,
|
||||
MultiPartParser )
|
||||
MultiPartParser,
|
||||
XMLParser
|
||||
)
|
||||
|
||||
if YAMLParser:
|
||||
DEFAULT_PARSERS += ( YAMLParser, )
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -26,6 +26,14 @@ _403_FORBIDDEN_RESPONSE = ErrorResponse(
|
|||
{'detail': 'You do not have permission to access this resource. ' +
|
||||
'You may need to login or otherwise authenticate the request.'})
|
||||
|
||||
_403_NOT_LOGGED_IN_RESPONSE = ErrorResponse(
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
{'detail': 'You need to login before you can access this resource.'})
|
||||
|
||||
_403_PERMISSION_DENIED_RESPONSE = ErrorResponse(
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
{'detail': 'You do not have permission to access this resource.'})
|
||||
|
||||
_503_SERVICE_UNAVAILABLE = ErrorResponse(
|
||||
status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
{'detail': 'request was throttled'})
|
||||
|
@ -64,7 +72,7 @@ class IsAuthenticated(BasePermission):
|
|||
|
||||
def check_permission(self, user):
|
||||
if not user.is_authenticated():
|
||||
raise _403_FORBIDDEN_RESPONSE
|
||||
raise _403_NOT_LOGGED_IN_RESPONSE
|
||||
|
||||
|
||||
class IsAdminUser(BasePermission):
|
||||
|
|
|
@ -109,7 +109,6 @@ class JSONRenderer(BaseRenderer):
|
|||
sort_keys = True
|
||||
except (ValueError, TypeError):
|
||||
indent = None
|
||||
|
||||
return json.dumps(obj, cls=DateTimeAwareJSONEncoder, indent=indent, sort_keys=sort_keys)
|
||||
|
||||
|
||||
|
|
|
@ -212,9 +212,14 @@ class FormResource(Resource):
|
|||
return None
|
||||
|
||||
if data is not None or files is not None:
|
||||
return form(data, files)
|
||||
|
||||
return form()
|
||||
form_ = form(data=data, files=files)
|
||||
else:
|
||||
form_ = form()
|
||||
|
||||
if hasattr(self.view, 'request'):
|
||||
form_.request = self.view.request
|
||||
|
||||
return form_
|
||||
|
||||
|
||||
|
||||
|
@ -279,13 +284,13 @@ class ModelResource(FormResource):
|
|||
The list of extra fields to include. This is only used if :attr:`fields` is not set.
|
||||
"""
|
||||
|
||||
def __init__(self, view=None, depth=None, stack=[], **kwargs):
|
||||
def __init__(self, view):
|
||||
"""
|
||||
Allow :attr:`form` and :attr:`model` attributes set on the
|
||||
:class:`View` to override the :attr:`form` and :attr:`model`
|
||||
attributes set on the :class:`Resource`.
|
||||
"""
|
||||
super(ModelResource, self).__init__(view, depth, stack, **kwargs)
|
||||
super(ModelResource, self).__init__(view)
|
||||
|
||||
self.model = getattr(view, 'model', None) or self.model
|
||||
|
||||
|
@ -333,11 +338,17 @@ class ModelResource(FormResource):
|
|||
if data is not None or files is not None:
|
||||
if issubclass(form, forms.ModelForm) and hasattr(self.view, 'model_instance'):
|
||||
# Bound to an existing model instance
|
||||
return form(data, files, instance=self.view.model_instance)
|
||||
form_ = form(data=data, files=files, instance=self.view.model_instance)
|
||||
else:
|
||||
return form(data, files)
|
||||
form_ = form(data=data, files=files)
|
||||
|
||||
return form()
|
||||
else:
|
||||
form_ = form()
|
||||
|
||||
if hasattr(self.view, 'request'):
|
||||
form_.request = self.view.request
|
||||
|
||||
return form_
|
||||
|
||||
|
||||
def url(self, instance):
|
||||
|
@ -355,7 +366,7 @@ class ModelResource(FormResource):
|
|||
# dis does teh magicks...
|
||||
urlconf = get_urlconf()
|
||||
resolver = get_resolver(urlconf)
|
||||
|
||||
|
||||
possibilities = resolver.reverse_dict.getlist(self.view_callable[0])
|
||||
for tuple_item in possibilities:
|
||||
possibility = tuple_item[0]
|
||||
|
@ -379,6 +390,18 @@ class ModelResource(FormResource):
|
|||
return reverse(self.view_callable[0], kwargs=instance_attrs)
|
||||
except NoReverseMatch:
|
||||
pass
|
||||
|
||||
try:
|
||||
if hasattr(self, 'resource_name'):
|
||||
resource_name = self.resource_name
|
||||
else:
|
||||
resource_name = instance.__class__.__name__.split('.')[0].lower()
|
||||
return reverse(
|
||||
'%s:%s_change' % ('api', resource_name), args=(instance.pk,)
|
||||
)
|
||||
except NoReverseMatch:
|
||||
pass
|
||||
|
||||
raise _SkipField
|
||||
|
||||
|
||||
|
|
|
@ -144,7 +144,10 @@ class Serializer(object):
|
|||
|
||||
|
||||
def get_related_serializer(self, key):
|
||||
info = _fields_to_dict(self.fields).get(key, None)
|
||||
fields = _fields_to_dict(self.fields)
|
||||
fields.update(_fields_to_dict(self.include))
|
||||
info = fields.get(key, None)
|
||||
|
||||
|
||||
# If an element in `fields` is a 2-tuple of (str, tuple)
|
||||
# then the second element of the tuple is the fields to
|
||||
|
|
|
@ -1,13 +1,23 @@
|
|||
"""Force import of all modules in this package in order to get the standard test runner to pick up the tests. Yowzers."""
|
||||
import os
|
||||
from django.conf import settings
|
||||
|
||||
modules = [filename.rsplit('.', 1)[0]
|
||||
for filename in os.listdir(os.path.dirname(__file__))
|
||||
if filename.endswith('.py') and not filename.startswith('_')]
|
||||
__test__ = dict()
|
||||
|
||||
for module in modules:
|
||||
exec("from djangorestframework.tests.%s import __doc__ as module_doc" % module)
|
||||
exec("from djangorestframework.tests.%s import *" % module)
|
||||
__test__[module] = module_doc or ""
|
||||
# Try importing all tests if asked for (then we can run 'em)
|
||||
try:
|
||||
skiptest = settings.SKIP_DJANGORESTFRAMEWORK_TESTS
|
||||
except:
|
||||
skiptest = False
|
||||
|
||||
if not skiptest:
|
||||
import os
|
||||
|
||||
modules = [filename.rsplit('.', 1)[0]
|
||||
for filename in os.listdir(os.path.dirname(__file__))
|
||||
if filename.endswith('.py') and not filename.startswith('_')]
|
||||
__test__ = dict()
|
||||
|
||||
for module in modules:
|
||||
exec("from djangorestframework.tests.%s import __doc__ as module_doc" % module)
|
||||
exec("from djangorestframework.tests.%s import *" % module)
|
||||
__test__[module] = module_doc or ""
|
||||
|
||||
print 'TestXMLParser' in locals().keys()
|
||||
|
|
|
@ -136,6 +136,8 @@ from cgi import parse_qs
|
|||
from django import forms
|
||||
from django.test import TestCase
|
||||
from djangorestframework.parsers import FormParser
|
||||
from djangorestframework.parsers import XMLParser
|
||||
import datetime
|
||||
|
||||
class Form(forms.Form):
|
||||
field1 = forms.CharField(max_length=3)
|
||||
|
@ -153,3 +155,30 @@ class TestFormParser(TestCase):
|
|||
(data, files) = parser.parse(stream)
|
||||
|
||||
self.assertEqual(Form(data).is_valid(), True)
|
||||
|
||||
|
||||
class TestXMLParser(TestCase):
|
||||
def setUp(self):
|
||||
self.input = StringIO(
|
||||
'<?xml version="1.0" encoding="utf-8"?>'
|
||||
'<root>'
|
||||
'<field_a>121.0</field_a>'
|
||||
'<field_b>dasd</field_b>'
|
||||
'<field_c></field_c>'
|
||||
'<field_d>2011-12-25 12:45:00</field_d>'
|
||||
'</root>'
|
||||
)
|
||||
self.data = {
|
||||
'field_a': 121,
|
||||
'field_b': 'dasd',
|
||||
'field_c': None,
|
||||
'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00)
|
||||
|
||||
}
|
||||
|
||||
def test_parse(self):
|
||||
parser = XMLParser(None)
|
||||
|
||||
(data, files) = parser.parse(self.input)
|
||||
self.assertEqual(data, self.data)
|
||||
|
||||
|
|
|
@ -4,13 +4,16 @@ from django.test import TestCase
|
|||
|
||||
from djangorestframework import status
|
||||
from djangorestframework.compat import View as DjangoView
|
||||
from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer
|
||||
from djangorestframework.parsers import JSONParser, YAMLParser
|
||||
from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer,\
|
||||
XMLRenderer
|
||||
from djangorestframework.parsers import JSONParser, YAMLParser, XMLParser
|
||||
from djangorestframework.mixins import ResponseMixin
|
||||
from djangorestframework.response import Response
|
||||
from djangorestframework.utils.mediatypes import add_media_type_param
|
||||
|
||||
from StringIO import StringIO
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
DUMMYSTATUS = status.HTTP_200_OK
|
||||
DUMMYCONTENT = 'dummycontent'
|
||||
|
@ -223,4 +226,66 @@ if YAMLRenderer:
|
|||
|
||||
content = renderer.render(obj, 'application/yaml')
|
||||
(data, files) = parser.parse(StringIO(content))
|
||||
self.assertEquals(obj, data)
|
||||
self.assertEquals(obj, data)
|
||||
|
||||
|
||||
class XMLRendererTestCase(TestCase):
|
||||
"""
|
||||
Tests specific to the JSON Renderer
|
||||
"""
|
||||
|
||||
def test_render_string(self):
|
||||
"""
|
||||
Test XML rendering.
|
||||
"""
|
||||
renderer = XMLRenderer(None)
|
||||
content = renderer.render({'field': 'astring'}, 'application/xml')
|
||||
self.assertXMLContains(content, '<field>astring</field>')
|
||||
|
||||
def test_render_integer(self):
|
||||
"""
|
||||
Test XML rendering.
|
||||
"""
|
||||
renderer = XMLRenderer(None)
|
||||
content = renderer.render({'field': 111}, 'application/xml')
|
||||
self.assertXMLContains(content, '<field>111</field>')
|
||||
|
||||
def test_render_datetime(self):
|
||||
"""
|
||||
Test XML rendering.
|
||||
"""
|
||||
renderer = XMLRenderer(None)
|
||||
content = renderer.render({
|
||||
'field': datetime.datetime(2011, 12, 25, 12, 45, 00)
|
||||
}, 'application/xml')
|
||||
self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>')
|
||||
|
||||
def test_render_float(self):
|
||||
"""
|
||||
Test XML rendering.
|
||||
"""
|
||||
renderer = XMLRenderer(None)
|
||||
content = renderer.render({'field': 123.4}, 'application/xml')
|
||||
self.assertXMLContains(content, '<field>123.4</field>')
|
||||
|
||||
def test_render_decimal(self):
|
||||
"""
|
||||
Test XML rendering.
|
||||
"""
|
||||
renderer = XMLRenderer(None)
|
||||
content = renderer.render({'field': Decimal('111.2')}, 'application/xml')
|
||||
self.assertXMLContains(content, '<field>111.2</field>')
|
||||
|
||||
def test_render_none(self):
|
||||
"""
|
||||
Test XML rendering.
|
||||
"""
|
||||
renderer = XMLRenderer(None)
|
||||
content = renderer.render({'field': None}, 'application/xml')
|
||||
self.assertXMLContains(content, '<field></field>')
|
||||
|
||||
|
||||
def assertXMLContains(self, xml, string):
|
||||
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
|
||||
self.assertTrue(xml.endswith('</root>'))
|
||||
self.assertTrue(string in xml, '%r not in %r' % (string, xml))
|
|
@ -150,7 +150,9 @@ class XMLRenderer():
|
|||
xml.startElement(key, {})
|
||||
self._to_xml(xml, value)
|
||||
xml.endElement(key)
|
||||
|
||||
elif data is None:
|
||||
# Don't output any value
|
||||
pass
|
||||
else:
|
||||
xml.characters(smart_unicode(data))
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user