Remove Parser.can_handle_request()

This commit is contained in:
Tom Christie 2012-10-05 14:48:33 +01:00
parent 3e862c7737
commit 9d8bce8f5b
6 changed files with 79 additions and 58 deletions

View File

@ -11,6 +11,16 @@ class BaseContentNegotiation(object):
class DefaultContentNegotiation(object):
settings = api_settings
def select_parser(self, parsers, media_type):
"""
Given a list of parsers and a media type, return the appropriate
parser to handle the incoming request.
"""
for parser in parsers:
if media_type_matches(parser.media_type, media_type):
return parser
return None
def negotiate(self, request, renderers, format=None, force=False):
"""
Given a request and a list of renderers, return a two-tuple of:

View File

@ -15,11 +15,9 @@ from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError
from django.utils import simplejson as json
from rest_framework.compat import yaml
from rest_framework.compat import yaml, ETParseError
from rest_framework.exceptions import ParseError
from rest_framework.utils.mediatypes import media_type_matches
from xml.etree import ElementTree as ET
from rest_framework.compat import ETParseError
from xml.parsers.expat import ExpatError
import datetime
import decimal
@ -40,19 +38,6 @@ class BaseParser(object):
media_type = None
def can_handle_request(self, content_type):
"""
Returns :const:`True` if this parser is able to deal with the given *content_type*.
The default implementation for this function is to check the *content_type*
argument against the :attr:`media_type` attribute set on the class to see if
they match.
This may be overridden to provide for other behavior, but typically you'll
instead want to just set the :attr:`media_type` attribute on the class.
"""
return media_type_matches(self.media_type, content_type)
def parse(self, string_or_stream, **opts):
"""
The main entry point to parsers. This is a light wrapper around

View File

@ -34,8 +34,8 @@ def clone_request(request, method):
HTTP method. Used for checking permissions against other methods.
"""
ret = Request(request._request,
request.parser_classes,
request.authentication_classes)
request.parsers,
request.authenticators)
ret._data = request._data
ret._files = request._files
ret._content_type = request._content_type
@ -60,27 +60,20 @@ class Request(object):
_CONTENT_PARAM = api_settings.FORM_CONTENT_OVERRIDE
_CONTENTTYPE_PARAM = api_settings.FORM_CONTENTTYPE_OVERRIDE
def __init__(self, request, parser_classes=None, authentication_classes=None):
def __init__(self, request, parsers=None, authenticators=None,
negotiator=None):
self._request = request
self.parser_classes = parser_classes or ()
self.authentication_classes = authentication_classes or ()
self.parsers = parsers or ()
self.authenticators = authenticators or ()
self.negotiator = negotiator or self._default_negotiator()
self._data = Empty
self._files = Empty
self._method = Empty
self._content_type = Empty
self._stream = Empty
def get_parsers(self):
"""
Instantiates and returns the list of parsers the request will use.
"""
return [parser() for parser in self.parser_classes]
def get_authentications(self):
"""
Instantiates and returns the list of parsers the request will use.
"""
return [authentication() for authentication in self.authentication_classes]
def _default_negotiator(self):
return api_settings.DEFAULT_CONTENT_NEGOTIATION()
@property
def method(self):
@ -254,26 +247,27 @@ class Request(object):
if self.stream is None or self.content_type is None:
return (None, None)
for parser in self.get_parsers():
if parser.can_handle_request(self.content_type):
parsed = parser.parse(self.stream, meta=self.META,
upload_handlers=self.upload_handlers)
# Parser classes may return the raw data, or a
# DataAndFiles object. Unpack the result as required.
try:
return (parsed.data, parsed.files)
except AttributeError:
return (parsed, None)
parser = self.negotiator.select_parser(self.parsers, self.content_type)
raise exceptions.UnsupportedMediaType(self._content_type)
if not parser:
raise exceptions.UnsupportedMediaType(self._content_type)
parsed = parser.parse(self.stream, meta=self.META,
upload_handlers=self.upload_handlers)
# Parser classes may return the raw data, or a
# DataAndFiles object. Unpack the result as required.
try:
return (parsed.data, parsed.files)
except AttributeError:
return (parsed, None)
def _authenticate(self):
"""
Attempt to authenticate the request using each authentication instance in turn.
Returns a two-tuple of (user, authtoken).
"""
for authentication in self.get_authentications():
user_auth_tuple = authentication.authenticate(self)
for authenticator in self.authenticators:
user_auth_tuple = authenticator.authenticate(self)
if not user_auth_tuple is None:
return user_auth_tuple
return self._not_authenticated()

View File

@ -65,7 +65,9 @@ class DecoratorTestCase(TestCase):
@api_view(['GET'])
@parser_classes([JSONParser])
def view(request):
self.assertEqual(request.parser_classes, [JSONParser])
self.assertEqual(len(request.parsers), 1)
self.assertTrue(isinstance(request.parsers[0],
JSONParser))
return Response({})
request = self.factory.get('/')
@ -76,7 +78,9 @@ class DecoratorTestCase(TestCase):
@api_view(['GET'])
@authentication_classes([BasicAuthentication])
def view(request):
self.assertEqual(request.authentication_classes, [BasicAuthentication])
self.assertEqual(len(request.authenticators), 1)
self.assertTrue(isinstance(request.authenticators[0],
BasicAuthentication))
return Response({})
request = self.factory.get('/')

View File

@ -61,7 +61,7 @@ class TestContentParsing(TestCase):
"""
data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data))
request.parser_classes = (FormParser, MultiPartParser)
request.parsers = (FormParser(), MultiPartParser())
self.assertEqual(request.DATA.items(), data.items())
def test_request_DATA_with_text_content(self):
@ -72,7 +72,7 @@ class TestContentParsing(TestCase):
content = 'qwerty'
content_type = 'text/plain'
request = Request(factory.post('/', content, content_type=content_type))
request.parser_classes = (PlainTextParser,)
request.parsers = (PlainTextParser(),)
self.assertEqual(request.DATA, content)
def test_request_POST_with_form_content(self):
@ -81,7 +81,7 @@ class TestContentParsing(TestCase):
"""
data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data))
request.parser_classes = (FormParser, MultiPartParser)
request.parsers = (FormParser(), MultiPartParser())
self.assertEqual(request.POST.items(), data.items())
def test_standard_behaviour_determines_form_content_PUT(self):
@ -99,7 +99,7 @@ class TestContentParsing(TestCase):
else:
request = Request(factory.put('/', data))
request.parser_classes = (FormParser, MultiPartParser)
request.parsers = (FormParser(), MultiPartParser())
self.assertEqual(request.DATA.items(), data.items())
def test_standard_behaviour_determines_non_form_content_PUT(self):
@ -110,7 +110,7 @@ class TestContentParsing(TestCase):
content = 'qwerty'
content_type = 'text/plain'
request = Request(factory.put('/', content, content_type=content_type))
request.parser_classes = (PlainTextParser, )
request.parsers = (PlainTextParser(), )
self.assertEqual(request.DATA, content)
def test_overloaded_behaviour_allows_content_tunnelling(self):
@ -124,7 +124,7 @@ class TestContentParsing(TestCase):
Request._CONTENTTYPE_PARAM: content_type
}
request = Request(factory.post('/', data))
request.parser_classes = (PlainTextParser, )
request.parsers = (PlainTextParser(), )
self.assertEqual(request.DATA, content)
# def test_accessing_post_after_data_form(self):

View File

@ -70,6 +70,7 @@ class APIView(View):
as an attribute on the callable function. This allows us to discover
information about the view when we do URL reverse lookups.
"""
# TODO: deprecate?
view = super(APIView, cls).as_view(**initkwargs)
view.cls_instance = cls(**initkwargs)
return view
@ -84,6 +85,7 @@ class APIView(View):
@property
def default_response_headers(self):
# TODO: Only vary by accept if multiple renderers
return {
'Allow': ', '.join(self.allowed_methods),
'Vary': 'Accept'
@ -94,6 +96,7 @@ class APIView(View):
Return the resource or view class name for use as this view's name.
Override to customize.
"""
# TODO: deprecate?
name = self.__class__.__name__
name = _remove_trailing_string(name, 'View')
return _camelcase_to_spaces(name)
@ -103,6 +106,7 @@ class APIView(View):
Return the resource or view docstring for use as this view's description.
Override to customize.
"""
# TODO: deprecate?
description = self.__doc__ or ''
description = _remove_leading_indent(description)
if html:
@ -113,6 +117,7 @@ class APIView(View):
"""
Apply HTML markup to the description of this view.
"""
# TODO: deprecate?
if apply_markdown:
description = apply_markdown(description)
else:
@ -137,6 +142,8 @@ class APIView(View):
"""
raise exceptions.Throttled(wait)
# API policy instantiation methods
def get_format_suffix(self, **kwargs):
"""
Determine if the request includes a '.json' style format suffix
@ -144,12 +151,24 @@ class APIView(View):
if self.settings.FORMAT_SUFFIX_KWARG:
return kwargs.get(self.settings.FORMAT_SUFFIX_KWARG)
def get_renderers(self, format=None):
def get_renderers(self):
"""
Instantiates and returns the list of renderers that this view can use.
"""
return [renderer(self) for renderer in self.renderer_classes]
def get_parsers(self):
"""
Instantiates and returns the list of renderers that this view can use.
"""
return [parser() for parser in self.parser_classes]
def get_authenticators(self):
"""
Instantiates and returns the list of renderers that this view can use.
"""
return [auth() for auth in self.authentication_classes]
def get_permissions(self):
"""
Instantiates and returns the list of permissions that this view requires.
@ -166,7 +185,11 @@ class APIView(View):
"""
Instantiate and return the content negotiation class to use.
"""
return self.content_negotiation_class()
if not getattr(self, '_negotiator', None):
self._negotiator = self.content_negotiation_class()
return self._negotiator
# API policy implementation methods
def perform_content_negotiation(self, request, force=False):
"""
@ -193,19 +216,24 @@ class APIView(View):
if not throttle.allow_request(request):
self.throttled(request, throttle.wait())
# Dispatch methods
def initialize_request(self, request, *args, **kargs):
"""
Returns the initial request object.
"""
return Request(request, parser_classes=self.parser_classes,
authentication_classes=self.authentication_classes)
return Request(request,
parsers=self.get_parsers(),
authenticators=self.get_authenticators(),
negotiator=self.get_content_negotiator())
def initial(self, request, *args, **kwargs):
"""
Runs anything that needs to occur prior to calling the method handlers.
Runs anything that needs to occur prior to calling the method handler.
"""
self.format_kwarg = self.get_format_suffix(**kwargs)
# Ensure that the incoming request is permitted
if not self.has_permission(request):
self.permission_denied(request)
self.check_throttles(request)