Remove Parser.can_handle_request()

This commit is contained in:
Tom Christie 2012-10-05 14:48:33 +01:00
parent 3e862c7737
commit 9d8bce8f5b
6 changed files with 79 additions and 58 deletions

View File

@ -11,6 +11,16 @@ class BaseContentNegotiation(object):
class DefaultContentNegotiation(object): class DefaultContentNegotiation(object):
settings = api_settings settings = api_settings
def select_parser(self, parsers, media_type):
"""
Given a list of parsers and a media type, return the appropriate
parser to handle the incoming request.
"""
for parser in parsers:
if media_type_matches(parser.media_type, media_type):
return parser
return None
def negotiate(self, request, renderers, format=None, force=False): def negotiate(self, request, renderers, format=None, force=False):
""" """
Given a request and a list of renderers, return a two-tuple of: Given a request and a list of renderers, return a two-tuple of:

View File

@ -15,11 +15,9 @@ from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError from django.http.multipartparser import MultiPartParserError
from django.utils import simplejson as json from django.utils import simplejson as json
from rest_framework.compat import yaml from rest_framework.compat import yaml, ETParseError
from rest_framework.exceptions import ParseError from rest_framework.exceptions import ParseError
from rest_framework.utils.mediatypes import media_type_matches
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
from rest_framework.compat import ETParseError
from xml.parsers.expat import ExpatError from xml.parsers.expat import ExpatError
import datetime import datetime
import decimal import decimal
@ -40,19 +38,6 @@ class BaseParser(object):
media_type = None media_type = None
def can_handle_request(self, content_type):
"""
Returns :const:`True` if this parser is able to deal with the given *content_type*.
The default implementation for this function is to check the *content_type*
argument against the :attr:`media_type` attribute set on the class to see if
they match.
This may be overridden to provide for other behavior, but typically you'll
instead want to just set the :attr:`media_type` attribute on the class.
"""
return media_type_matches(self.media_type, content_type)
def parse(self, string_or_stream, **opts): def parse(self, string_or_stream, **opts):
""" """
The main entry point to parsers. This is a light wrapper around The main entry point to parsers. This is a light wrapper around

View File

@ -34,8 +34,8 @@ def clone_request(request, method):
HTTP method. Used for checking permissions against other methods. HTTP method. Used for checking permissions against other methods.
""" """
ret = Request(request._request, ret = Request(request._request,
request.parser_classes, request.parsers,
request.authentication_classes) request.authenticators)
ret._data = request._data ret._data = request._data
ret._files = request._files ret._files = request._files
ret._content_type = request._content_type ret._content_type = request._content_type
@ -60,27 +60,20 @@ class Request(object):
_CONTENT_PARAM = api_settings.FORM_CONTENT_OVERRIDE _CONTENT_PARAM = api_settings.FORM_CONTENT_OVERRIDE
_CONTENTTYPE_PARAM = api_settings.FORM_CONTENTTYPE_OVERRIDE _CONTENTTYPE_PARAM = api_settings.FORM_CONTENTTYPE_OVERRIDE
def __init__(self, request, parser_classes=None, authentication_classes=None): def __init__(self, request, parsers=None, authenticators=None,
negotiator=None):
self._request = request self._request = request
self.parser_classes = parser_classes or () self.parsers = parsers or ()
self.authentication_classes = authentication_classes or () self.authenticators = authenticators or ()
self.negotiator = negotiator or self._default_negotiator()
self._data = Empty self._data = Empty
self._files = Empty self._files = Empty
self._method = Empty self._method = Empty
self._content_type = Empty self._content_type = Empty
self._stream = Empty self._stream = Empty
def get_parsers(self): def _default_negotiator(self):
""" return api_settings.DEFAULT_CONTENT_NEGOTIATION()
Instantiates and returns the list of parsers the request will use.
"""
return [parser() for parser in self.parser_classes]
def get_authentications(self):
"""
Instantiates and returns the list of parsers the request will use.
"""
return [authentication() for authentication in self.authentication_classes]
@property @property
def method(self): def method(self):
@ -254,26 +247,27 @@ class Request(object):
if self.stream is None or self.content_type is None: if self.stream is None or self.content_type is None:
return (None, None) return (None, None)
for parser in self.get_parsers(): parser = self.negotiator.select_parser(self.parsers, self.content_type)
if parser.can_handle_request(self.content_type):
parsed = parser.parse(self.stream, meta=self.META,
upload_handlers=self.upload_handlers)
# Parser classes may return the raw data, or a
# DataAndFiles object. Unpack the result as required.
try:
return (parsed.data, parsed.files)
except AttributeError:
return (parsed, None)
raise exceptions.UnsupportedMediaType(self._content_type) if not parser:
raise exceptions.UnsupportedMediaType(self._content_type)
parsed = parser.parse(self.stream, meta=self.META,
upload_handlers=self.upload_handlers)
# Parser classes may return the raw data, or a
# DataAndFiles object. Unpack the result as required.
try:
return (parsed.data, parsed.files)
except AttributeError:
return (parsed, None)
def _authenticate(self): def _authenticate(self):
""" """
Attempt to authenticate the request using each authentication instance in turn. Attempt to authenticate the request using each authentication instance in turn.
Returns a two-tuple of (user, authtoken). Returns a two-tuple of (user, authtoken).
""" """
for authentication in self.get_authentications(): for authenticator in self.authenticators:
user_auth_tuple = authentication.authenticate(self) user_auth_tuple = authenticator.authenticate(self)
if not user_auth_tuple is None: if not user_auth_tuple is None:
return user_auth_tuple return user_auth_tuple
return self._not_authenticated() return self._not_authenticated()

View File

@ -65,7 +65,9 @@ class DecoratorTestCase(TestCase):
@api_view(['GET']) @api_view(['GET'])
@parser_classes([JSONParser]) @parser_classes([JSONParser])
def view(request): def view(request):
self.assertEqual(request.parser_classes, [JSONParser]) self.assertEqual(len(request.parsers), 1)
self.assertTrue(isinstance(request.parsers[0],
JSONParser))
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
@ -76,7 +78,9 @@ class DecoratorTestCase(TestCase):
@api_view(['GET']) @api_view(['GET'])
@authentication_classes([BasicAuthentication]) @authentication_classes([BasicAuthentication])
def view(request): def view(request):
self.assertEqual(request.authentication_classes, [BasicAuthentication]) self.assertEqual(len(request.authenticators), 1)
self.assertTrue(isinstance(request.authenticators[0],
BasicAuthentication))
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')

View File

@ -61,7 +61,7 @@ class TestContentParsing(TestCase):
""" """
data = {'qwerty': 'uiop'} data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data)) request = Request(factory.post('/', data))
request.parser_classes = (FormParser, MultiPartParser) request.parsers = (FormParser(), MultiPartParser())
self.assertEqual(request.DATA.items(), data.items()) self.assertEqual(request.DATA.items(), data.items())
def test_request_DATA_with_text_content(self): def test_request_DATA_with_text_content(self):
@ -72,7 +72,7 @@ class TestContentParsing(TestCase):
content = 'qwerty' content = 'qwerty'
content_type = 'text/plain' content_type = 'text/plain'
request = Request(factory.post('/', content, content_type=content_type)) request = Request(factory.post('/', content, content_type=content_type))
request.parser_classes = (PlainTextParser,) request.parsers = (PlainTextParser(),)
self.assertEqual(request.DATA, content) self.assertEqual(request.DATA, content)
def test_request_POST_with_form_content(self): def test_request_POST_with_form_content(self):
@ -81,7 +81,7 @@ class TestContentParsing(TestCase):
""" """
data = {'qwerty': 'uiop'} data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data)) request = Request(factory.post('/', data))
request.parser_classes = (FormParser, MultiPartParser) request.parsers = (FormParser(), MultiPartParser())
self.assertEqual(request.POST.items(), data.items()) self.assertEqual(request.POST.items(), data.items())
def test_standard_behaviour_determines_form_content_PUT(self): def test_standard_behaviour_determines_form_content_PUT(self):
@ -99,7 +99,7 @@ class TestContentParsing(TestCase):
else: else:
request = Request(factory.put('/', data)) request = Request(factory.put('/', data))
request.parser_classes = (FormParser, MultiPartParser) request.parsers = (FormParser(), MultiPartParser())
self.assertEqual(request.DATA.items(), data.items()) self.assertEqual(request.DATA.items(), data.items())
def test_standard_behaviour_determines_non_form_content_PUT(self): def test_standard_behaviour_determines_non_form_content_PUT(self):
@ -110,7 +110,7 @@ class TestContentParsing(TestCase):
content = 'qwerty' content = 'qwerty'
content_type = 'text/plain' content_type = 'text/plain'
request = Request(factory.put('/', content, content_type=content_type)) request = Request(factory.put('/', content, content_type=content_type))
request.parser_classes = (PlainTextParser, ) request.parsers = (PlainTextParser(), )
self.assertEqual(request.DATA, content) self.assertEqual(request.DATA, content)
def test_overloaded_behaviour_allows_content_tunnelling(self): def test_overloaded_behaviour_allows_content_tunnelling(self):
@ -124,7 +124,7 @@ class TestContentParsing(TestCase):
Request._CONTENTTYPE_PARAM: content_type Request._CONTENTTYPE_PARAM: content_type
} }
request = Request(factory.post('/', data)) request = Request(factory.post('/', data))
request.parser_classes = (PlainTextParser, ) request.parsers = (PlainTextParser(), )
self.assertEqual(request.DATA, content) self.assertEqual(request.DATA, content)
# def test_accessing_post_after_data_form(self): # def test_accessing_post_after_data_form(self):

View File

@ -70,6 +70,7 @@ class APIView(View):
as an attribute on the callable function. This allows us to discover as an attribute on the callable function. This allows us to discover
information about the view when we do URL reverse lookups. information about the view when we do URL reverse lookups.
""" """
# TODO: deprecate?
view = super(APIView, cls).as_view(**initkwargs) view = super(APIView, cls).as_view(**initkwargs)
view.cls_instance = cls(**initkwargs) view.cls_instance = cls(**initkwargs)
return view return view
@ -84,6 +85,7 @@ class APIView(View):
@property @property
def default_response_headers(self): def default_response_headers(self):
# TODO: Only vary by accept if multiple renderers
return { return {
'Allow': ', '.join(self.allowed_methods), 'Allow': ', '.join(self.allowed_methods),
'Vary': 'Accept' 'Vary': 'Accept'
@ -94,6 +96,7 @@ class APIView(View):
Return the resource or view class name for use as this view's name. Return the resource or view class name for use as this view's name.
Override to customize. Override to customize.
""" """
# TODO: deprecate?
name = self.__class__.__name__ name = self.__class__.__name__
name = _remove_trailing_string(name, 'View') name = _remove_trailing_string(name, 'View')
return _camelcase_to_spaces(name) return _camelcase_to_spaces(name)
@ -103,6 +106,7 @@ class APIView(View):
Return the resource or view docstring for use as this view's description. Return the resource or view docstring for use as this view's description.
Override to customize. Override to customize.
""" """
# TODO: deprecate?
description = self.__doc__ or '' description = self.__doc__ or ''
description = _remove_leading_indent(description) description = _remove_leading_indent(description)
if html: if html:
@ -113,6 +117,7 @@ class APIView(View):
""" """
Apply HTML markup to the description of this view. Apply HTML markup to the description of this view.
""" """
# TODO: deprecate?
if apply_markdown: if apply_markdown:
description = apply_markdown(description) description = apply_markdown(description)
else: else:
@ -137,6 +142,8 @@ class APIView(View):
""" """
raise exceptions.Throttled(wait) raise exceptions.Throttled(wait)
# API policy instantiation methods
def get_format_suffix(self, **kwargs): def get_format_suffix(self, **kwargs):
""" """
Determine if the request includes a '.json' style format suffix Determine if the request includes a '.json' style format suffix
@ -144,12 +151,24 @@ class APIView(View):
if self.settings.FORMAT_SUFFIX_KWARG: if self.settings.FORMAT_SUFFIX_KWARG:
return kwargs.get(self.settings.FORMAT_SUFFIX_KWARG) return kwargs.get(self.settings.FORMAT_SUFFIX_KWARG)
def get_renderers(self, format=None): def get_renderers(self):
""" """
Instantiates and returns the list of renderers that this view can use. Instantiates and returns the list of renderers that this view can use.
""" """
return [renderer(self) for renderer in self.renderer_classes] return [renderer(self) for renderer in self.renderer_classes]
def get_parsers(self):
"""
Instantiates and returns the list of renderers that this view can use.
"""
return [parser() for parser in self.parser_classes]
def get_authenticators(self):
"""
Instantiates and returns the list of renderers that this view can use.
"""
return [auth() for auth in self.authentication_classes]
def get_permissions(self): def get_permissions(self):
""" """
Instantiates and returns the list of permissions that this view requires. Instantiates and returns the list of permissions that this view requires.
@ -166,7 +185,11 @@ class APIView(View):
""" """
Instantiate and return the content negotiation class to use. Instantiate and return the content negotiation class to use.
""" """
return self.content_negotiation_class() if not getattr(self, '_negotiator', None):
self._negotiator = self.content_negotiation_class()
return self._negotiator
# API policy implementation methods
def perform_content_negotiation(self, request, force=False): def perform_content_negotiation(self, request, force=False):
""" """
@ -193,19 +216,24 @@ class APIView(View):
if not throttle.allow_request(request): if not throttle.allow_request(request):
self.throttled(request, throttle.wait()) self.throttled(request, throttle.wait())
# Dispatch methods
def initialize_request(self, request, *args, **kargs): def initialize_request(self, request, *args, **kargs):
""" """
Returns the initial request object. Returns the initial request object.
""" """
return Request(request, parser_classes=self.parser_classes, return Request(request,
authentication_classes=self.authentication_classes) parsers=self.get_parsers(),
authenticators=self.get_authenticators(),
negotiator=self.get_content_negotiator())
def initial(self, request, *args, **kwargs): def initial(self, request, *args, **kwargs):
""" """
Runs anything that needs to occur prior to calling the method handlers. Runs anything that needs to occur prior to calling the method handler.
""" """
self.format_kwarg = self.get_format_suffix(**kwargs) self.format_kwarg = self.get_format_suffix(**kwargs)
# Ensure that the incoming request is permitted
if not self.has_permission(request): if not self.has_permission(request):
self.permission_denied(request) self.permission_denied(request)
self.check_throttles(request) self.check_throttles(request)