From 1c78bf53dbc4f75cfdc240c72f4db9d2376cb9cb Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 6 Sep 2012 13:49:15 +0100 Subject: [PATCH] Refactoring some basics --- djangorestframework/authentication.py | 30 +++++++--- djangorestframework/exceptions.py | 2 +- djangorestframework/permissions.py | 4 +- djangorestframework/response.py | 4 +- djangorestframework/tests/renderers.py | 12 ++-- djangorestframework/views.py | 78 ++++++++++++++------------ 6 files changed, 76 insertions(+), 54 deletions(-) diff --git a/djangorestframework/authentication.py b/djangorestframework/authentication.py index 197aa424c..4ebe72592 100644 --- a/djangorestframework/authentication.py +++ b/djangorestframework/authentication.py @@ -39,13 +39,14 @@ class BaseAuthentication(object): class BasicAuthentication(BaseAuthentication): """ - Use HTTP Basic authentication. + Base class for HTTP Basic authentication. + Subclasses should implement `.authenticate_credentials()`. """ def authenticate(self, request): """ - Returns a :obj:`User` if a correct username and password have been supplied - using HTTP Basic authentication. Otherwise returns :const:`None`. + Returns a `User` if a correct username and password have been supplied + using HTTP Basic authentication. Otherwise returns `None`. """ from django.utils.encoding import smart_unicode, DjangoUnicodeDecodeError @@ -58,15 +59,30 @@ class BasicAuthentication(BaseAuthentication): return None try: - uname, passwd = smart_unicode(auth_parts[0]), smart_unicode(auth_parts[2]) + userid, password = smart_unicode(auth_parts[0]), smart_unicode(auth_parts[2]) except DjangoUnicodeDecodeError: return None - user = authenticate(username=uname, password=passwd) - if user is not None and user.is_active: - return user + return self.authenticate_credentials(userid, password) return None + def authenticate_credentials(self, userid, password): + """ + Given the Basic authentication userid and password, authenticate + and return a user instance. + """ + raise NotImplementedError('.authenticate_credentials() must be overridden') + + +class UserBasicAuthentication(BasicAuthentication): + def authenticate_credentials(self, userid, password): + """ + Authenticate the userid and password against username and password. + """ + user = authenticate(username=userid, password=password) + if user is not None and user.is_active: + return user + class SessionAuthentication(BaseAuthentication): """ diff --git a/djangorestframework/exceptions.py b/djangorestframework/exceptions.py index 0b4dacf73..3f5b23f67 100644 --- a/djangorestframework/exceptions.py +++ b/djangorestframework/exceptions.py @@ -25,7 +25,7 @@ class ParseError(APIException): class PermissionDenied(APIException): status_code = status.HTTP_403_FORBIDDEN - default_detail = 'You do not have permission to access this resource.' + default_detail = 'You do not have permission to perform this action.' def __init__(self, detail=None): self.detail = detail or self.default_detail diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py index eff2ed2b8..64e455f57 100644 --- a/djangorestframework/permissions.py +++ b/djangorestframework/permissions.py @@ -52,7 +52,7 @@ class IsAdminUser(BasePermission): """ def check_permission(self, request, obj=None): - if request.user and request.user.is_staff(): + if request.user and request.user.is_staff: return True return False @@ -82,7 +82,7 @@ class DjangoModelPermissions(BasePermission): """ # Map methods into required permission codes. - # Override this if you need to also provide 'read' permissions, + # Override this if you need to also provide 'view' permissions, # or if you want to provide custom permission codes. perms_map = { 'GET': [], diff --git a/djangorestframework/response.py b/djangorestframework/response.py index 65173200d..f8b3504e2 100644 --- a/djangorestframework/response.py +++ b/djangorestframework/response.py @@ -144,9 +144,9 @@ class Response(SimpleTemplateResponse): # attempting more specific media types first # NB. The inner loop here isn't as bad as it first looks :) # Worst case is we're looping over len(accept_list) * len(self.renderers) - for media_type_list in order_by_precedence(accepts): + for media_type_set in order_by_precedence(accepts): for renderer in renderers: - for media_type in media_type_list: + for media_type in media_type_set: if renderer.can_handle_response(media_type): return renderer, media_type diff --git a/djangorestframework/tests/renderers.py b/djangorestframework/tests/renderers.py index 8b14038df..d8581540f 100644 --- a/djangorestframework/tests/renderers.py +++ b/djangorestframework/tests/renderers.py @@ -246,9 +246,9 @@ class JSONPRendererTests(TestCase): Test JSONP rendering with View JSON Renderer. """ resp = self.client.get('/jsonp/jsonrenderer', - HTTP_ACCEPT='application/json-p') + HTTP_ACCEPT='application/javascript') self.assertEquals(resp.status_code, 200) - self.assertEquals(resp['Content-Type'], 'application/json-p') + self.assertEquals(resp['Content-Type'], 'application/javascript') self.assertEquals(resp.content, 'callback(%s);' % _flat_repr) def test_without_callback_without_json_renderer(self): @@ -256,9 +256,9 @@ class JSONPRendererTests(TestCase): Test JSONP rendering without View JSON Renderer. """ resp = self.client.get('/jsonp/nojsonrenderer', - HTTP_ACCEPT='application/json-p') + HTTP_ACCEPT='application/javascript') self.assertEquals(resp.status_code, 200) - self.assertEquals(resp['Content-Type'], 'application/json-p') + self.assertEquals(resp['Content-Type'], 'application/javascript') self.assertEquals(resp.content, 'callback(%s);' % _flat_repr) def test_with_callback(self): @@ -267,9 +267,9 @@ class JSONPRendererTests(TestCase): """ callback_func = 'myjsonpcallback' resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func, - HTTP_ACCEPT='application/json-p') + HTTP_ACCEPT='application/javascript') self.assertEquals(resp.status_code, 200) - self.assertEquals(resp['Content-Type'], 'application/json-p') + self.assertEquals(resp['Content-Type'], 'application/javascript') self.assertEquals(resp.content, '%s(%s);' % (callback_func, _flat_repr)) diff --git a/djangorestframework/views.py b/djangorestframework/views.py index 5f9677823..e4d47a31e 100644 --- a/djangorestframework/views.py +++ b/djangorestframework/views.py @@ -21,15 +21,6 @@ from djangorestframework.settings import api_settings from djangorestframework import parsers, authentication, status, exceptions, mixins -__all__ = ( - 'View', - 'ModelView', - 'InstanceModelView', - 'ListModelView', - 'ListOrCreateModelView' -) - - def _remove_trailing_string(content, trailing): """ Strip trailing component `trailing` from `content` if it exists. @@ -65,11 +56,6 @@ def _camelcase_to_spaces(content): class APIView(_View): - """ - Handles incoming requests and maps them to REST operations. - Performs request deserialization, response serialization, authentication and input validation. - """ - renderers = api_settings.DEFAULT_RENDERERS """ List of renderer classes the view can serialize the response with, ordered by preference. @@ -81,7 +67,7 @@ class APIView(_View): """ authentication = (authentication.SessionAuthentication, - authentication.BasicAuthentication) + authentication.UserBasicAuthentication) """ List of all authenticating methods to attempt. """ @@ -155,10 +141,21 @@ class APIView(_View): def http_method_not_allowed(self, request, *args, **kwargs): """ Called if `request.method` does not corrospond to a handler method. - We raise an exception, which is handled by `.handle_exception()`. """ raise exceptions.MethodNotAllowed(request.method) + def permission_denied(self, request): + """ + If request is not permitted, determine what kind of exception to raise. + """ + raise exceptions.PermissionDenied() + + def throttled(self, request, wait): + """ + If request is throttled, determine what kind of exception to raise. + """ + raise exceptions.Throttled(wait) + @property def _parsed_media_types(self): """ @@ -208,35 +205,29 @@ class APIView(_View): def check_permissions(self, request, obj=None): """ - Check user permissions and either raise an ``PermissionDenied`` or return. + Check if request should be permitted. """ for permission in self.get_permissions(): if not permission.check_permission(request, obj): - raise exceptions.PermissionDenied() + self.permission_denied(request) def check_throttles(self, request): """ - Check throttles and either raise a `Throttled` exception or return. + Check if request should be throttled. """ for throttle in self.get_throttles(): if not throttle.check_throttle(request): - raise exceptions.Throttled(throttle.wait()) + self.throttled(request, throttle.wait()) - def initial(self, request, *args, **kargs): + def initialize_request(self, request, *args, **kargs): """ - This method runs prior to anything else in the view. - It should return the initial request object. - - You may need to override this if you want to do things like set - `request.upload_handlers` before the authentication and dispatch - handling is run. + Returns the initial request object. """ return Request(request, parsers=self.parsers, authentication=self.authentication) - def final(self, request, response, *args, **kargs): + def finalize_response(self, request, response, *args, **kargs): """ - This method runs after everything else in the view. - It should return the final response object. + Returns the final response object. """ if isinstance(response, Response): response.view = self @@ -248,6 +239,13 @@ class APIView(_View): return response + def initial(self, request, *args, **kwargs): + """ + Runs anything that needs to occur prior to calling the method handlers. + """ + self.check_permissions(request) + self.check_throttles(request) + def handle_exception(self, exc): """ Handle any exception that occurs, by returning an appropriate response, @@ -270,16 +268,24 @@ class APIView(_View): # all other authentication is CSRF exempt. @csrf_exempt def dispatch(self, request, *args, **kwargs): + """ + `APIView.dispatch()` is pretty much the same as Django's regular + `View.dispatch()`, except that it includes hooks to: + + * Initialize the request object. + * Finalize the response object. + * Handle exceptions that occur in the handler method. + * An initial hook for code such as permission checking that should + occur prior to running the method handlers. + """ + request = self.initialize_request(request, *args, **kwargs) + self.request = request self.args = args self.kwargs = kwargs self.headers = self.default_response_headers try: - self.request = self.initial(request, *args, **kwargs) - - # Check that the request is allowed - self.check_permissions(request) - self.check_throttles(request) + self.initial(request, *args, **kwargs) # Get the appropriate handler method if request.method.lower() in self.http_method_names: @@ -292,7 +298,7 @@ class APIView(_View): except Exception as exc: response = self.handle_exception(exc) - self.response = self.final(request, response, *args, **kwargs) + self.response = self.finalize_response(request, response, *args, **kwargs) return self.response