diff --git a/djangorestframework/__init__.py b/djangorestframework/__init__.py index b1ef6ddaa..ceb76b3b3 100644 --- a/djangorestframework/__init__.py +++ b/djangorestframework/__init__.py @@ -1,3 +1,64 @@ __version__ = '0.2.3' VERSION = __version__ # synonym + +from djangorestframework.builtins import DjangoRestFrameworkSite +from django.utils.importlib import import_module + +import imp + +__all__ = ('autodiscover','site', '__version__', 'VERSION') + +api = DjangoRestFrameworkSite() + +# A flag to tell us if autodiscover is running. autodiscover will set this to +# True while running, and False when it finishes. +LOADING = False + +def autodiscover(): + """ + Auto-discover INSTALLED_APPS api.py modules and fail silently when + not present. This forces an import on them to register any api bits they + may want. + """ + + # Bail out if autodiscover didn't finish loading from a previous call so + # that we avoid running autodiscover again when the URLconf is loaded by + # the exception handler to resolve the handler500 view. This prevents an + # admin.py module with errors from re-registering models and raising a + # spurious AlreadyRegistered exception. + global LOADING + if LOADING: + return + LOADING = True + + from django.conf import settings + + for app in settings.INSTALLED_APPS: + # For each app, we need to look for a api.py inside that + # app's package. We can't use os.path here -- recall that modules may be + # imported different ways (think zip files) -- so we need to get + # the app's __path__ and look for api.py on that path. + + # Step 1: find out the app's __path__ Import errors here will (and + # should) bubble up, but a missing __path__ (which is legal, but weird) + # fails silently -- apps that do weird things with __path__ might + # need to roll their own api registration. + try: + app_path = import_module(app).__path__ + except (AttributeError, ImportError): + continue + + # Step 2: use imp.find_module to find the app's gargoyle_conditions.py. + # For some # reason imp.find_module raises ImportError if the app can't + # be found # but doesn't actually try to import the module. So skip this + # app if its gargoyle.py doesn't exist + try: + imp.find_module('api', app_path) + except ImportError: + continue + + import_module("%s.api" % app) + + # autodiscover was successful, reset loading flag. + LOADING = False diff --git a/djangorestframework/authentication.py b/djangorestframework/authentication.py index be22103e6..a40c5e653 100644 --- a/djangorestframework/authentication.py +++ b/djangorestframework/authentication.py @@ -82,6 +82,7 @@ class UserLoggedInAuthentication(BaseAuthentication): """ Use Django's session framework for authentication. """ + check_csrf = True def authenticate(self, request): """ @@ -91,7 +92,7 @@ class UserLoggedInAuthentication(BaseAuthentication): # TODO: Switch this back to request.POST, and let FormParser/MultiPartParser deal with the consequences. if getattr(request, 'user', None) and request.user.is_active: # If this is a POST request we enforce CSRF validation. - if request.method.upper() == 'POST': + if request.method.upper() == 'POST' and self.check_csrf: # Temporarily replace request.POST with .DATA, # so that we use our more generic request parsing request._post = self.view.DATA diff --git a/djangorestframework/builtins.py b/djangorestframework/builtins.py new file mode 100644 index 000000000..55d257ca5 --- /dev/null +++ b/djangorestframework/builtins.py @@ -0,0 +1,107 @@ + +from django.conf.urls.defaults import patterns, url, include +from django.views.decorators.csrf import csrf_exempt + +class ApiEntry(object): + def __init__(self, resource, view, prefix, resource_name): + self.resource, self.view = resource, view + self.prefix, self.resource_name = prefix, resource_name + if self.prefix is None: + self.prefix = '' + + def get_urls(self): + from djangorestframework.mixins import ListModelMixin, InstanceMixin + from django.conf.urls.defaults import patterns, url + + if self.prefix == '': + url_prefix = '' + else: + url_prefix = self.prefix + '/' + + if issubclass(self.view, ListModelMixin): + urlpatterns = patterns('', + url(r'^%s%s/$' % (url_prefix, self.resource_name), + self.view.as_view(resource=self.resource), + name=self.resource_name, + ) + ) + elif issubclass(self.view, InstanceMixin): + urlpatterns = patterns('', + url(r'^%s%s/(?P[0-9a-zA-Z]+)/$' % (url_prefix, self.resource_name), + self.view.as_view(resource=self.resource), + name=self.resource_name + '_change', + ) + ) + return urlpatterns + + + def urls(self): + return self.get_urls(), 'api', self.prefix + urls = property(urls) + +class DjangoRestFrameworkSite(object): + app_name = 'api' + name = 'api' + + def __init__(self, *args, **kwargs): + self._registry = {} + super(DjangoRestFrameworkSite, self).__init__(*args, **kwargs) + + def register(self, view, resource, prefix=None, resource_name=None): + if resource_name is None: + if hasattr(resource, 'model'): + resource_name = resource.model.__name__.lower() + else: + resource_name = resource.__name__.lower() + + resource.resource_name = resource_name + + if prefix not in self._registry: + self._registry[prefix] = {} + + if resource_name not in self._registry[prefix]: + self._registry[prefix][resource_name] = [] + + api_entry = ApiEntry(resource, view, prefix, resource_name) + self._registry[prefix][resource_name].append(api_entry) + + +# def unregister(self, prefix=None, resource_name=None, resource=None): +# """ +# Unregisters a resource. +# """ +# if resource_name is None and resource is not None and \ +# hasattr(resource, 'model'): +# resource_name = resource.model.__name__.lower() +# +# if resource_name is None: +# # do nothing +# return +# +# prefix_registry = self._registry.get(prefix, {}) +# if resource_name in prefix_registry: +# del prefix_registry[resource_name] + + @property + def urls(self): + return self.get_urls(), self.app_name, self.name + + def get_urls(self): + + # Site-wide views. + urlpatterns = patterns('', + + + ) + + # Add in each resource's views. + for prefix in self._registry.keys(): + for resource_name in self._registry[prefix].keys(): + for api_entry in self._registry[prefix][resource_name]: + urlpatterns += patterns('', + url(r'^', include(api_entry.urls)) + ) + + + return urlpatterns + diff --git a/djangorestframework/mixins.py b/djangorestframework/mixins.py index 9fed61221..11b62162e 100644 --- a/djangorestframework/mixins.py +++ b/djangorestframework/mixins.py @@ -185,7 +185,6 @@ class RequestMixin(object): return (None, None) parsers = as_tuple(self.parsers) - for parser_cls in parsers: parser = parser_cls(self) if parser.can_handle_request(content_type): @@ -387,6 +386,7 @@ class AuthMixin(object): user = self.user for permission_cls in self.permissions: permission = permission_cls(self) + permission.request = self.request permission.check_permission(user) @@ -570,7 +570,6 @@ class UpdateModelMixin(object): else: # Otherwise assume the kwargs uniquely identify the model self.model_instance = model.objects.get(**kwargs) - for (key, val) in self.CONTENT.items(): setattr(self.model_instance, key, val) except model.DoesNotExist: @@ -606,7 +605,6 @@ class ListModelMixin(object): """ Behavior to list a set of `model` instances on GET requests """ - # NB. Not obvious to me if it would be better to set this on the resource? # # Presumably it's more useful to have on the view, because that way you can diff --git a/djangorestframework/parsers.py b/djangorestframework/parsers.py index 2ff64bd3e..235017fdb 100644 --- a/djangorestframework/parsers.py +++ b/djangorestframework/parsers.py @@ -19,6 +19,9 @@ from djangorestframework import status from djangorestframework.compat import yaml from djangorestframework.response import ErrorResponse from djangorestframework.utils.mediatypes import media_type_matches +from xml.etree import ElementTree as ET +import datetime +import decimal __all__ = ( @@ -28,6 +31,7 @@ __all__ = ( 'FormParser', 'MultiPartParser', 'YAMLParser', + 'XMLParser', ) @@ -167,10 +171,66 @@ class MultiPartParser(BaseParser): raise ErrorResponse(status.HTTP_400_BAD_REQUEST, {'detail': 'multipart parse error - %s' % unicode(exc)}) return django_parser.parse() + + + +class XMLParser(BaseParser): + """ + XML parser. + """ + + media_type = 'application/xml' + + def parse(self, stream): + """ + Returns a 2-tuple of `(data, files)`. + + `data` will simply be a string representing the body of the request. + `files` will always be `None`. + """ + data = {} + tree = ET.parse(stream) + for child in tree.getroot().getchildren(): + data[child.tag] = self._type_convert(child.text) + + return (data, None) + + def _type_convert(self, value): + """ + Converts the value returned by the XMl parse into the equivalent + Python type + """ + if value is None: + return value + + try: + return datetime.datetime.strptime(value,'%Y-%m-%d %H:%M:%S') + except ValueError: + pass + + try: + return int(value) + except ValueError: + pass + + + try: + return decimal.Decimal(value) + except decimal.InvalidOperation: + pass + + return value + + DEFAULT_PARSERS = ( JSONParser, FormParser, - MultiPartParser ) + MultiPartParser, + XMLParser + ) if YAMLParser: DEFAULT_PARSERS += ( YAMLParser, ) + + + diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py index 0052a6094..d704665dc 100644 --- a/djangorestframework/permissions.py +++ b/djangorestframework/permissions.py @@ -26,6 +26,14 @@ _403_FORBIDDEN_RESPONSE = ErrorResponse( {'detail': 'You do not have permission to access this resource. ' + 'You may need to login or otherwise authenticate the request.'}) +_403_NOT_LOGGED_IN_RESPONSE = ErrorResponse( + status.HTTP_403_FORBIDDEN, + {'detail': 'You need to login before you can access this resource.'}) + +_403_PERMISSION_DENIED_RESPONSE = ErrorResponse( + status.HTTP_403_FORBIDDEN, + {'detail': 'You do not have permission to access this resource.'}) + _503_SERVICE_UNAVAILABLE = ErrorResponse( status.HTTP_503_SERVICE_UNAVAILABLE, {'detail': 'request was throttled'}) @@ -64,7 +72,7 @@ class IsAuthenticated(BasePermission): def check_permission(self, user): if not user.is_authenticated(): - raise _403_FORBIDDEN_RESPONSE + raise _403_NOT_LOGGED_IN_RESPONSE class IsAdminUser(BasePermission): diff --git a/djangorestframework/renderers.py b/djangorestframework/renderers.py index aae2cab25..9600ef571 100644 --- a/djangorestframework/renderers.py +++ b/djangorestframework/renderers.py @@ -109,7 +109,6 @@ class JSONRenderer(BaseRenderer): sort_keys = True except (ValueError, TypeError): indent = None - return json.dumps(obj, cls=DateTimeAwareJSONEncoder, indent=indent, sort_keys=sort_keys) diff --git a/djangorestframework/resources.py b/djangorestframework/resources.py index 5770d07f9..037004832 100644 --- a/djangorestframework/resources.py +++ b/djangorestframework/resources.py @@ -212,9 +212,14 @@ class FormResource(Resource): return None if data is not None or files is not None: - return form(data, files) - - return form() + form_ = form(data=data, files=files) + else: + form_ = form() + + if hasattr(self.view, 'request'): + form_.request = self.view.request + + return form_ @@ -279,13 +284,13 @@ class ModelResource(FormResource): The list of extra fields to include. This is only used if :attr:`fields` is not set. """ - def __init__(self, view=None, depth=None, stack=[], **kwargs): + def __init__(self, view): """ Allow :attr:`form` and :attr:`model` attributes set on the :class:`View` to override the :attr:`form` and :attr:`model` attributes set on the :class:`Resource`. """ - super(ModelResource, self).__init__(view, depth, stack, **kwargs) + super(ModelResource, self).__init__(view) self.model = getattr(view, 'model', None) or self.model @@ -333,11 +338,17 @@ class ModelResource(FormResource): if data is not None or files is not None: if issubclass(form, forms.ModelForm) and hasattr(self.view, 'model_instance'): # Bound to an existing model instance - return form(data, files, instance=self.view.model_instance) + form_ = form(data=data, files=files, instance=self.view.model_instance) else: - return form(data, files) + form_ = form(data=data, files=files) - return form() + else: + form_ = form() + + if hasattr(self.view, 'request'): + form_.request = self.view.request + + return form_ def url(self, instance): @@ -355,7 +366,7 @@ class ModelResource(FormResource): # dis does teh magicks... urlconf = get_urlconf() resolver = get_resolver(urlconf) - + possibilities = resolver.reverse_dict.getlist(self.view_callable[0]) for tuple_item in possibilities: possibility = tuple_item[0] @@ -379,6 +390,18 @@ class ModelResource(FormResource): return reverse(self.view_callable[0], kwargs=instance_attrs) except NoReverseMatch: pass + + try: + if hasattr(self, 'resource_name'): + resource_name = self.resource_name + else: + resource_name = instance.__class__.__name__.split('.')[0].lower() + return reverse( + '%s:%s_change' % ('api', resource_name), args=(instance.pk,) + ) + except NoReverseMatch: + pass + raise _SkipField diff --git a/djangorestframework/serializer.py b/djangorestframework/serializer.py index 55b84df16..c28663216 100644 --- a/djangorestframework/serializer.py +++ b/djangorestframework/serializer.py @@ -144,7 +144,10 @@ class Serializer(object): def get_related_serializer(self, key): - info = _fields_to_dict(self.fields).get(key, None) + fields = _fields_to_dict(self.fields) + fields.update(_fields_to_dict(self.include)) + info = fields.get(key, None) + # If an element in `fields` is a 2-tuple of (str, tuple) # then the second element of the tuple is the fields to diff --git a/djangorestframework/tests/__init__.py b/djangorestframework/tests/__init__.py index f664c5c12..b409f9f1d 100644 --- a/djangorestframework/tests/__init__.py +++ b/djangorestframework/tests/__init__.py @@ -1,13 +1,23 @@ """Force import of all modules in this package in order to get the standard test runner to pick up the tests. Yowzers.""" -import os +from django.conf import settings -modules = [filename.rsplit('.', 1)[0] - for filename in os.listdir(os.path.dirname(__file__)) - if filename.endswith('.py') and not filename.startswith('_')] -__test__ = dict() - -for module in modules: - exec("from djangorestframework.tests.%s import __doc__ as module_doc" % module) - exec("from djangorestframework.tests.%s import *" % module) - __test__[module] = module_doc or "" +# Try importing all tests if asked for (then we can run 'em) +try: + skiptest = settings.SKIP_DJANGORESTFRAMEWORK_TESTS +except: + skiptest = False + +if not skiptest: + import os + + modules = [filename.rsplit('.', 1)[0] + for filename in os.listdir(os.path.dirname(__file__)) + if filename.endswith('.py') and not filename.startswith('_')] + __test__ = dict() + + for module in modules: + exec("from djangorestframework.tests.%s import __doc__ as module_doc" % module) + exec("from djangorestframework.tests.%s import *" % module) + __test__[module] = module_doc or "" +print 'TestXMLParser' in locals().keys() diff --git a/djangorestframework/tests/parsers.py b/djangorestframework/tests/parsers.py index deba688e5..08b963304 100644 --- a/djangorestframework/tests/parsers.py +++ b/djangorestframework/tests/parsers.py @@ -136,6 +136,8 @@ from cgi import parse_qs from django import forms from django.test import TestCase from djangorestframework.parsers import FormParser +from djangorestframework.parsers import XMLParser +import datetime class Form(forms.Form): field1 = forms.CharField(max_length=3) @@ -153,3 +155,30 @@ class TestFormParser(TestCase): (data, files) = parser.parse(stream) self.assertEqual(Form(data).is_valid(), True) + + +class TestXMLParser(TestCase): + def setUp(self): + self.input = StringIO( + '' + '' + '121.0' + 'dasd' + '' + '2011-12-25 12:45:00' + '' + ) + self.data = { + 'field_a': 121, + 'field_b': 'dasd', + 'field_c': None, + 'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00) + + } + + def test_parse(self): + parser = XMLParser(None) + + (data, files) = parser.parse(self.input) + self.assertEqual(data, self.data) + diff --git a/djangorestframework/tests/renderers.py b/djangorestframework/tests/renderers.py index d2046212f..9ae43cdda 100644 --- a/djangorestframework/tests/renderers.py +++ b/djangorestframework/tests/renderers.py @@ -4,13 +4,16 @@ from django.test import TestCase from djangorestframework import status from djangorestframework.compat import View as DjangoView -from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer -from djangorestframework.parsers import JSONParser, YAMLParser +from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer,\ + XMLRenderer +from djangorestframework.parsers import JSONParser, YAMLParser, XMLParser from djangorestframework.mixins import ResponseMixin from djangorestframework.response import Response from djangorestframework.utils.mediatypes import add_media_type_param from StringIO import StringIO +import datetime +from decimal import Decimal DUMMYSTATUS = status.HTTP_200_OK DUMMYCONTENT = 'dummycontent' @@ -223,4 +226,66 @@ if YAMLRenderer: content = renderer.render(obj, 'application/yaml') (data, files) = parser.parse(StringIO(content)) - self.assertEquals(obj, data) \ No newline at end of file + self.assertEquals(obj, data) + + +class XMLRendererTestCase(TestCase): + """ + Tests specific to the JSON Renderer + """ + + def test_render_string(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': 'astring'}, 'application/xml') + self.assertXMLContains(content, 'astring') + + def test_render_integer(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': 111}, 'application/xml') + self.assertXMLContains(content, '111') + + def test_render_datetime(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({ + 'field': datetime.datetime(2011, 12, 25, 12, 45, 00) + }, 'application/xml') + self.assertXMLContains(content, '2011-12-25 12:45:00') + + def test_render_float(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': 123.4}, 'application/xml') + self.assertXMLContains(content, '123.4') + + def test_render_decimal(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': Decimal('111.2')}, 'application/xml') + self.assertXMLContains(content, '111.2') + + def test_render_none(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': None}, 'application/xml') + self.assertXMLContains(content, '') + + + def assertXMLContains(self, xml, string): + self.assertTrue(xml.startswith('\n')) + self.assertTrue(xml.endswith('')) + self.assertTrue(string in xml, '%r not in %r' % (string, xml)) \ No newline at end of file diff --git a/djangorestframework/utils/__init__.py b/djangorestframework/utils/__init__.py index 99f9724ce..1f1a08661 100644 --- a/djangorestframework/utils/__init__.py +++ b/djangorestframework/utils/__init__.py @@ -150,7 +150,9 @@ class XMLRenderer(): xml.startElement(key, {}) self._to_xml(xml, value) xml.endElement(key) - + elif data is None: + # Don't output any value + pass else: xml.characters(smart_unicode(data))