From 9d8bce8f5b0915223f57d9fe3d4b63029cfc64c2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 5 Oct 2012 14:48:33 +0100 Subject: [PATCH] Remove Parser.can_handle_request() --- rest_framework/negotiation.py | 10 ++++++ rest_framework/parsers.py | 17 +--------- rest_framework/request.py | 52 +++++++++++++----------------- rest_framework/tests/decorators.py | 8 +++-- rest_framework/tests/request.py | 12 +++---- rest_framework/views.py | 38 +++++++++++++++++++--- 6 files changed, 79 insertions(+), 58 deletions(-) diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py index 73ae78997..8b22f6690 100644 --- a/rest_framework/negotiation.py +++ b/rest_framework/negotiation.py @@ -11,6 +11,16 @@ class BaseContentNegotiation(object): class DefaultContentNegotiation(object): settings = api_settings + def select_parser(self, parsers, media_type): + """ + Given a list of parsers and a media type, return the appropriate + parser to handle the incoming request. + """ + for parser in parsers: + if media_type_matches(parser.media_type, media_type): + return parser + return None + def negotiate(self, request, renderers, format=None, force=False): """ Given a request and a list of renderers, return a two-tuple of: diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 5151b2525..5325a64b0 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -15,11 +15,9 @@ from django.http import QueryDict from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser from django.http.multipartparser import MultiPartParserError from django.utils import simplejson as json -from rest_framework.compat import yaml +from rest_framework.compat import yaml, ETParseError from rest_framework.exceptions import ParseError -from rest_framework.utils.mediatypes import media_type_matches from xml.etree import ElementTree as ET -from rest_framework.compat import ETParseError from xml.parsers.expat import ExpatError import datetime import decimal @@ -40,19 +38,6 @@ class BaseParser(object): media_type = None - def can_handle_request(self, content_type): - """ - Returns :const:`True` if this parser is able to deal with the given *content_type*. - - The default implementation for this function is to check the *content_type* - argument against the :attr:`media_type` attribute set on the class to see if - they match. - - This may be overridden to provide for other behavior, but typically you'll - instead want to just set the :attr:`media_type` attribute on the class. - """ - return media_type_matches(self.media_type, content_type) - def parse(self, string_or_stream, **opts): """ The main entry point to parsers. This is a light wrapper around diff --git a/rest_framework/request.py b/rest_framework/request.py index e254cf8e7..ac15defc2 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -34,8 +34,8 @@ def clone_request(request, method): HTTP method. Used for checking permissions against other methods. """ ret = Request(request._request, - request.parser_classes, - request.authentication_classes) + request.parsers, + request.authenticators) ret._data = request._data ret._files = request._files ret._content_type = request._content_type @@ -60,27 +60,20 @@ class Request(object): _CONTENT_PARAM = api_settings.FORM_CONTENT_OVERRIDE _CONTENTTYPE_PARAM = api_settings.FORM_CONTENTTYPE_OVERRIDE - def __init__(self, request, parser_classes=None, authentication_classes=None): + def __init__(self, request, parsers=None, authenticators=None, + negotiator=None): self._request = request - self.parser_classes = parser_classes or () - self.authentication_classes = authentication_classes or () + self.parsers = parsers or () + self.authenticators = authenticators or () + self.negotiator = negotiator or self._default_negotiator() self._data = Empty self._files = Empty self._method = Empty self._content_type = Empty self._stream = Empty - def get_parsers(self): - """ - Instantiates and returns the list of parsers the request will use. - """ - return [parser() for parser in self.parser_classes] - - def get_authentications(self): - """ - Instantiates and returns the list of parsers the request will use. - """ - return [authentication() for authentication in self.authentication_classes] + def _default_negotiator(self): + return api_settings.DEFAULT_CONTENT_NEGOTIATION() @property def method(self): @@ -254,26 +247,27 @@ class Request(object): if self.stream is None or self.content_type is None: return (None, None) - for parser in self.get_parsers(): - if parser.can_handle_request(self.content_type): - parsed = parser.parse(self.stream, meta=self.META, - upload_handlers=self.upload_handlers) - # Parser classes may return the raw data, or a - # DataAndFiles object. Unpack the result as required. - try: - return (parsed.data, parsed.files) - except AttributeError: - return (parsed, None) + parser = self.negotiator.select_parser(self.parsers, self.content_type) - raise exceptions.UnsupportedMediaType(self._content_type) + if not parser: + raise exceptions.UnsupportedMediaType(self._content_type) + + parsed = parser.parse(self.stream, meta=self.META, + upload_handlers=self.upload_handlers) + # Parser classes may return the raw data, or a + # DataAndFiles object. Unpack the result as required. + try: + return (parsed.data, parsed.files) + except AttributeError: + return (parsed, None) def _authenticate(self): """ Attempt to authenticate the request using each authentication instance in turn. Returns a two-tuple of (user, authtoken). """ - for authentication in self.get_authentications(): - user_auth_tuple = authentication.authenticate(self) + for authenticator in self.authenticators: + user_auth_tuple = authenticator.authenticate(self) if not user_auth_tuple is None: return user_auth_tuple return self._not_authenticated() diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py index e943d8fef..a3217bd68 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -65,7 +65,9 @@ class DecoratorTestCase(TestCase): @api_view(['GET']) @parser_classes([JSONParser]) def view(request): - self.assertEqual(request.parser_classes, [JSONParser]) + self.assertEqual(len(request.parsers), 1) + self.assertTrue(isinstance(request.parsers[0], + JSONParser)) return Response({}) request = self.factory.get('/') @@ -76,7 +78,9 @@ class DecoratorTestCase(TestCase): @api_view(['GET']) @authentication_classes([BasicAuthentication]) def view(request): - self.assertEqual(request.authentication_classes, [BasicAuthentication]) + self.assertEqual(len(request.authenticators), 1) + self.assertTrue(isinstance(request.authenticators[0], + BasicAuthentication)) return Response({}) request = self.factory.get('/') diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py index 42274fcd2..f5c63f110 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/request.py @@ -61,7 +61,7 @@ class TestContentParsing(TestCase): """ data = {'qwerty': 'uiop'} request = Request(factory.post('/', data)) - request.parser_classes = (FormParser, MultiPartParser) + request.parsers = (FormParser(), MultiPartParser()) self.assertEqual(request.DATA.items(), data.items()) def test_request_DATA_with_text_content(self): @@ -72,7 +72,7 @@ class TestContentParsing(TestCase): content = 'qwerty' content_type = 'text/plain' request = Request(factory.post('/', content, content_type=content_type)) - request.parser_classes = (PlainTextParser,) + request.parsers = (PlainTextParser(),) self.assertEqual(request.DATA, content) def test_request_POST_with_form_content(self): @@ -81,7 +81,7 @@ class TestContentParsing(TestCase): """ data = {'qwerty': 'uiop'} request = Request(factory.post('/', data)) - request.parser_classes = (FormParser, MultiPartParser) + request.parsers = (FormParser(), MultiPartParser()) self.assertEqual(request.POST.items(), data.items()) def test_standard_behaviour_determines_form_content_PUT(self): @@ -99,7 +99,7 @@ class TestContentParsing(TestCase): else: request = Request(factory.put('/', data)) - request.parser_classes = (FormParser, MultiPartParser) + request.parsers = (FormParser(), MultiPartParser()) self.assertEqual(request.DATA.items(), data.items()) def test_standard_behaviour_determines_non_form_content_PUT(self): @@ -110,7 +110,7 @@ class TestContentParsing(TestCase): content = 'qwerty' content_type = 'text/plain' request = Request(factory.put('/', content, content_type=content_type)) - request.parser_classes = (PlainTextParser, ) + request.parsers = (PlainTextParser(), ) self.assertEqual(request.DATA, content) def test_overloaded_behaviour_allows_content_tunnelling(self): @@ -124,7 +124,7 @@ class TestContentParsing(TestCase): Request._CONTENTTYPE_PARAM: content_type } request = Request(factory.post('/', data)) - request.parser_classes = (PlainTextParser, ) + request.parsers = (PlainTextParser(), ) self.assertEqual(request.DATA, content) # def test_accessing_post_after_data_form(self): diff --git a/rest_framework/views.py b/rest_framework/views.py index 166bf0b16..0aa1dd0d8 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -70,6 +70,7 @@ class APIView(View): as an attribute on the callable function. This allows us to discover information about the view when we do URL reverse lookups. """ + # TODO: deprecate? view = super(APIView, cls).as_view(**initkwargs) view.cls_instance = cls(**initkwargs) return view @@ -84,6 +85,7 @@ class APIView(View): @property def default_response_headers(self): + # TODO: Only vary by accept if multiple renderers return { 'Allow': ', '.join(self.allowed_methods), 'Vary': 'Accept' @@ -94,6 +96,7 @@ class APIView(View): Return the resource or view class name for use as this view's name. Override to customize. """ + # TODO: deprecate? name = self.__class__.__name__ name = _remove_trailing_string(name, 'View') return _camelcase_to_spaces(name) @@ -103,6 +106,7 @@ class APIView(View): Return the resource or view docstring for use as this view's description. Override to customize. """ + # TODO: deprecate? description = self.__doc__ or '' description = _remove_leading_indent(description) if html: @@ -113,6 +117,7 @@ class APIView(View): """ Apply HTML markup to the description of this view. """ + # TODO: deprecate? if apply_markdown: description = apply_markdown(description) else: @@ -137,6 +142,8 @@ class APIView(View): """ raise exceptions.Throttled(wait) + # API policy instantiation methods + def get_format_suffix(self, **kwargs): """ Determine if the request includes a '.json' style format suffix @@ -144,12 +151,24 @@ class APIView(View): if self.settings.FORMAT_SUFFIX_KWARG: return kwargs.get(self.settings.FORMAT_SUFFIX_KWARG) - def get_renderers(self, format=None): + def get_renderers(self): """ Instantiates and returns the list of renderers that this view can use. """ return [renderer(self) for renderer in self.renderer_classes] + def get_parsers(self): + """ + Instantiates and returns the list of renderers that this view can use. + """ + return [parser() for parser in self.parser_classes] + + def get_authenticators(self): + """ + Instantiates and returns the list of renderers that this view can use. + """ + return [auth() for auth in self.authentication_classes] + def get_permissions(self): """ Instantiates and returns the list of permissions that this view requires. @@ -166,7 +185,11 @@ class APIView(View): """ Instantiate and return the content negotiation class to use. """ - return self.content_negotiation_class() + if not getattr(self, '_negotiator', None): + self._negotiator = self.content_negotiation_class() + return self._negotiator + + # API policy implementation methods def perform_content_negotiation(self, request, force=False): """ @@ -193,19 +216,24 @@ class APIView(View): if not throttle.allow_request(request): self.throttled(request, throttle.wait()) + # Dispatch methods + def initialize_request(self, request, *args, **kargs): """ Returns the initial request object. """ - return Request(request, parser_classes=self.parser_classes, - authentication_classes=self.authentication_classes) + return Request(request, + parsers=self.get_parsers(), + authenticators=self.get_authenticators(), + negotiator=self.get_content_negotiator()) def initial(self, request, *args, **kwargs): """ - Runs anything that needs to occur prior to calling the method handlers. + Runs anything that needs to occur prior to calling the method handler. """ self.format_kwarg = self.get_format_suffix(**kwargs) + # Ensure that the incoming request is permitted if not self.has_permission(request): self.permission_denied(request) self.check_throttles(request)