diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index e557abedf..6dc804983 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -34,27 +34,33 @@ class BasicAuthentication(BaseAuthentication): """ HTTP Basic authentication against username/password. """ + www_authenticate_realm = 'api' def authenticate(self, request): """ Returns a `User` if a correct username and password have been supplied using HTTP Basic authentication. Otherwise returns `None`. """ - if 'HTTP_AUTHORIZATION' in request.META: - auth = request.META['HTTP_AUTHORIZATION'].split() - if len(auth) == 2 and auth[0].lower() == "basic": - try: - auth_parts = base64.b64decode(auth[1]).partition(':') - except TypeError: - return None + auth = request.META.get('HTTP_AUTHORIZATION', '').split() - try: - userid = smart_unicode(auth_parts[0]) - password = smart_unicode(auth_parts[2]) - except DjangoUnicodeDecodeError: - return None + if not auth or auth[0].lower() != "basic": + return None - return self.authenticate_credentials(userid, password) + if len(auth) != 2: + raise exceptions.AuthenticationFailed('Invalid basic header') + + try: + auth_parts = base64.b64decode(auth[1]).partition(':') + except TypeError: + raise exceptions.AuthenticationFailed('Invalid basic header') + + try: + userid = smart_unicode(auth_parts[0]) + password = smart_unicode(auth_parts[2]) + except DjangoUnicodeDecodeError: + raise exceptions.AuthenticationFailed('Invalid basic header') + + return self.authenticate_credentials(userid, password) def authenticate_credentials(self, userid, password): """ @@ -63,6 +69,10 @@ class BasicAuthentication(BaseAuthentication): user = authenticate(username=userid, password=password) if user is not None and user.is_active: return (user, None) + raise exceptions.AuthenticationFailed('Invalid username/password') + + def authenticate_header(self): + return 'Basic realm="%s"' % self.www_authenticate_realm class SessionAuthentication(BaseAuthentication): @@ -82,7 +92,7 @@ class SessionAuthentication(BaseAuthentication): # Unauthenticated, CSRF validation not required if not user or not user.is_active: - return + return None # Enforce CSRF validation for session based authentication. class CSRFCheck(CsrfViewMiddleware): @@ -93,7 +103,7 @@ class SessionAuthentication(BaseAuthentication): reason = CSRFCheck().process_view(http_request, None, (), {}) if reason: # CSRF failed, bail with explicit error message - raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) + raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason) # CSRF passed with authenticated user return (user, None) @@ -120,14 +130,26 @@ class TokenAuthentication(BaseAuthentication): def authenticate(self, request): auth = request.META.get('HTTP_AUTHORIZATION', '').split() - if len(auth) == 2 and auth[0].lower() == "token": - key = auth[1] - try: - token = self.model.objects.get(key=key) - except self.model.DoesNotExist: - return None + if not auth or auth[0].lower() != "token": + return None + + if len(auth) != 2: + raise exceptions.AuthenticationFailed('Invalid token header') + + return self.authenticate_credentials(auth[1]) + + def authenticate_credentials(self, key): + try: + token = self.model.objects.get(key=key) + except self.model.DoesNotExist: + raise exceptions.AuthenticationFailed('Invalid token') + + if token.user.is_active: + return (token.user, token) + raise exceptions.AuthenticationFailed('User inactive or deleted') + + def authenticate_header(self): + return 'Token' - if token.user.is_active: - return (token.user, token) # TODO: OAuthAuthentication diff --git a/rest_framework/request.py b/rest_framework/request.py index a1827ba48..38ee36dd7 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -86,6 +86,7 @@ class Request(object): self._method = Empty self._content_type = Empty self._stream = Empty + self._authenticator = None if self.parser_context is None: self.parser_context = {} @@ -166,7 +167,7 @@ class Request(object): by the authentication classes provided to the request. """ if not hasattr(self, '_user'): - self._user, self._auth = self._authenticate() + self._authenticator, self._user, self._auth = self._authenticate() return self._user @property @@ -176,9 +177,17 @@ class Request(object): request, such as an authentication token. """ if not hasattr(self, '_auth'): - self._user, self._auth = self._authenticate() + self._authenticator, self._user, self._auth = self._authenticate() return self._auth + @property + def successful_authenticator(self): + """ + Return the instance of the authentication instance class that was used + to authenticate the request, or `None`. + """ + return self._authenticator + def _load_data_and_files(self): """ Parses the request content into self.DATA and self.FILES. @@ -282,21 +291,23 @@ class Request(object): def _authenticate(self): """ - Attempt to authenticate the request using each authentication instance in turn. - Returns a two-tuple of (user, authtoken). + Attempt to authenticate the request using each authentication instance + in turn. + Returns a three-tuple of (authenticator, user, authtoken). """ for authenticator in self.authenticators: user_auth_tuple = authenticator.authenticate(self) if not user_auth_tuple is None: - return user_auth_tuple + user, auth = user_auth_tuple + return (authenticator, user, auth) return self._not_authenticated() def _not_authenticated(self): """ - Return a two-tuple of (user, authtoken), representing an - unauthenticated request. + Return a three-tuple of (authenticator, user, authtoken), representing + an unauthenticated request. - By default this will be (AnonymousUser, None). + By default this will be (None, AnonymousUser, None). """ if api_settings.UNAUTHENTICATED_USER: user = api_settings.UNAUTHENTICATED_USER() @@ -308,7 +319,7 @@ class Request(object): else: auth = None - return (user, auth) + return (None, user, auth) def __getattr__(self, attr): """ diff --git a/rest_framework/views.py b/rest_framework/views.py index 1afbd6974..c470817a4 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -148,6 +148,8 @@ class APIView(View): """ If request is not permitted, determine what kind of exception to raise. """ + if self.request.successful_authenticator: + raise exceptions.NotAuthenticated() raise exceptions.PermissionDenied() def throttled(self, request, wait):