Raise PEP 3134 chained exceptions

Use PEP 3134 (https://peps.python.org/pep-3134/) exception chaining to
provide enhanced reporting and extra context when errors are encountered.
This commit is contained in:
Tyson 2022-04-12 20:48:25 +10:00
parent df92e57ad6
commit f0ff449232
11 changed files with 50 additions and 49 deletions

View File

@ -79,9 +79,9 @@ class BasicAuthentication(BaseAuthentication):
except UnicodeDecodeError: except UnicodeDecodeError:
auth_decoded = base64.b64decode(auth[1]).decode('latin-1') auth_decoded = base64.b64decode(auth[1]).decode('latin-1')
auth_parts = auth_decoded.partition(':') auth_parts = auth_decoded.partition(':')
except (TypeError, UnicodeDecodeError, binascii.Error): except (TypeError, UnicodeDecodeError, binascii.Error) as exc:
msg = _('Invalid basic header. Credentials not correctly base64 encoded.') msg = _('Invalid basic header. Credentials not correctly base64 encoded.')
raise exceptions.AuthenticationFailed(msg) raise exceptions.AuthenticationFailed(msg) from exc
userid, password = auth_parts[0], auth_parts[2] userid, password = auth_parts[0], auth_parts[2]
return self.authenticate_credentials(userid, password, request) return self.authenticate_credentials(userid, password, request)
@ -189,9 +189,9 @@ class TokenAuthentication(BaseAuthentication):
try: try:
token = auth[1].decode() token = auth[1].decode()
except UnicodeError: except UnicodeError as exc:
msg = _('Invalid token header. Token string should not contain invalid characters.') msg = _('Invalid token header. Token string should not contain invalid characters.')
raise exceptions.AuthenticationFailed(msg) raise exceptions.AuthenticationFailed(msg) from exc
return self.authenticate_credentials(token) return self.authenticate_credentials(token)

View File

@ -4,6 +4,7 @@ import decimal
import functools import functools
import inspect import inspect
import re import re
import sys
import uuid import uuid
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
@ -104,7 +105,7 @@ def get_attribute(instance, attrs):
# If we raised an Attribute or KeyError here it'd get treated # If we raised an Attribute or KeyError here it'd get treated
# as an omitted field in `Field.get_attribute()`. Instead we # as an omitted field in `Field.get_attribute()`. Instead we
# raise a ValueError to ensure the exception is not masked. # raise a ValueError to ensure the exception is not masked.
raise ValueError('Exception raised in callable attribute "{}"; original exception was: {}'.format(attr, exc)) raise ValueError('Exception raised in callable attribute "{}"; original exception was: {}'.format(attr, exc)) from exc
return instance return instance
@ -466,14 +467,14 @@ class Field:
instance=instance.__class__.__name__, instance=instance.__class__.__name__,
) )
) )
raise type(exc)(msg) raise type(exc)(msg) from exc
except (KeyError, AttributeError) as exc: except (KeyError, AttributeError) as exc:
if self.default is not empty: if self.default is not empty:
return self.get_default() return self.get_default()
if self.allow_null: if self.allow_null:
return None return None
if not self.required: if not self.required:
raise SkipField() raise SkipField() from exc
msg = ( msg = (
'Got {exc_type} when attempting to get a value for field ' 'Got {exc_type} when attempting to get a value for field '
'`{field}` on serializer `{serializer}`.\nThe serializer ' '`{field}` on serializer `{serializer}`.\nThe serializer '
@ -487,7 +488,7 @@ class Field:
exc=exc exc=exc
) )
) )
raise type(exc)(msg) raise type(exc)(msg) from exc
def get_default(self): def get_default(self):
""" """
@ -633,12 +634,17 @@ class Field:
""" """
try: try:
msg = self.error_messages[key] msg = self.error_messages[key]
except KeyError: except KeyError as exc:
class_name = self.__class__.__name__ class_name = self.__class__.__name__
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg) raise AssertionError(msg) from exc
message_string = msg.format(**kwargs) message_string = msg.format(**kwargs)
raise ValidationError(message_string, code=key) err = ValidationError(message_string, code=key)
(_, exc, _) = sys.exc_info()
if exc:
raise err from exc
else:
raise err
@property @property
def root(self): def root(self):

View File

@ -17,8 +17,8 @@ def get_object_or_404(queryset, *filter_args, **filter_kwargs):
""" """
try: try:
return _get_object_or_404(queryset, *filter_args, **filter_kwargs) return _get_object_or_404(queryset, *filter_args, **filter_kwargs)
except (TypeError, ValueError, ValidationError): except (TypeError, ValueError, ValidationError) as exc:
raise Http404 raise Http404 from exc
class GenericAPIView(views.APIView): class GenericAPIView(views.APIView):

View File

@ -206,7 +206,7 @@ class PageNumberPagination(BasePagination):
msg = self.invalid_page_message.format( msg = self.invalid_page_message.format(
page_number=page_number, message=str(exc) page_number=page_number, message=str(exc)
) )
raise NotFound(msg) raise NotFound(msg) from exc
if paginator.num_pages > 1 and self.template is not None: if paginator.num_pages > 1 and self.template is not None:
# The browsable API should display pagination controls. # The browsable API should display pagination controls.
@ -862,8 +862,8 @@ class CursorPagination(BasePagination):
reverse = bool(int(reverse)) reverse = bool(int(reverse))
position = tokens.get('p', [None])[0] position = tokens.get('p', [None])[0]
except (TypeError, ValueError): except (TypeError, ValueError) as exc:
raise NotFound(self.invalid_cursor_message) raise NotFound(self.invalid_cursor_message) from exc
return Cursor(offset=offset, reverse=reverse, position=position) return Cursor(offset=offset, reverse=reverse, position=position)

View File

@ -64,7 +64,7 @@ class JSONParser(BaseParser):
parse_constant = json.strict_constant if self.strict else None parse_constant = json.strict_constant if self.strict else None
return json.load(decoded_stream, parse_constant=parse_constant) return json.load(decoded_stream, parse_constant=parse_constant)
except ValueError as exc: except ValueError as exc:
raise ParseError('JSON parse error - %s' % str(exc)) raise ParseError('JSON parse error - %s' % str(exc)) from exc
class FormParser(BaseParser): class FormParser(BaseParser):
@ -109,7 +109,7 @@ class MultiPartParser(BaseParser):
data, files = parser.parse() data, files = parser.parse()
return DataAndFiles(data, files) return DataAndFiles(data, files)
except MultiPartParserError as exc: except MultiPartParserError as exc:
raise ParseError('Multipart form parse error - %s' % str(exc)) raise ParseError('Multipart form parse error - %s' % str(exc)) from exc
class FileUploadParser(BaseParser): class FileUploadParser(BaseParser):

View File

@ -1,4 +1,3 @@
import sys
from collections import OrderedDict from collections import OrderedDict
from urllib import parse from urllib import parse
@ -316,12 +315,10 @@ class HyperlinkedRelatedField(RelatedField):
try: try:
return queryset.get(**lookup_kwargs) return queryset.get(**lookup_kwargs)
except ValueError: except ValueError as exc:
exc = ObjectValueError(str(sys.exc_info()[1])) raise ObjectValueError(str(exc)) from exc
raise exc.with_traceback(sys.exc_info()[2]) except TypeError as exc:
except TypeError: raise ObjectTypeError(str(exc)) from exc
exc = ObjectTypeError(str(sys.exc_info()[1]))
raise exc.with_traceback(sys.exc_info()[2])
def get_url(self, obj, view_name, request, format): def get_url(self, obj, view_name, request, format):
""" """
@ -399,7 +396,7 @@ class HyperlinkedRelatedField(RelatedField):
# Return the hyperlink, or error if incorrectly configured. # Return the hyperlink, or error if incorrectly configured.
try: try:
url = self.get_url(value, self.view_name, request, format) url = self.get_url(value, self.view_name, request, format)
except NoReverseMatch: except NoReverseMatch as exc:
msg = ( msg = (
'Could not resolve URL for hyperlinked relationship using ' 'Could not resolve URL for hyperlinked relationship using '
'view name "%s". You may have failed to include the related ' 'view name "%s". You may have failed to include the related '
@ -413,7 +410,7 @@ class HyperlinkedRelatedField(RelatedField):
"was %s, which may be why it didn't match any " "was %s, which may be why it didn't match any "
"entries in your URL conf." % value_string "entries in your URL conf." % value_string
) )
raise ImproperlyConfigured(msg % self.view_name) raise ImproperlyConfigured(msg % self.view_name) from exc
if url is None: if url is None:
return None return None

View File

@ -9,7 +9,6 @@ The wrapped request then offers a richer API, in particular :
- form overloading of HTTP method, content type and content - form overloading of HTTP method, content type and content
""" """
import io import io
import sys
from contextlib import contextmanager from contextlib import contextmanager
from django.conf import settings from django.conf import settings
@ -72,10 +71,8 @@ def wrap_attributeerrors():
""" """
try: try:
yield yield
except AttributeError: except AttributeError as exc:
info = sys.exc_info() raise WrappedAttributeError(str(exc)) from exc
exc = WrappedAttributeError(str(info[1]))
raise exc.with_traceback(info[2])
class Empty: class Empty:

View File

@ -89,13 +89,13 @@ def insert_into(target, keys, value):
try: try:
target.links.append((keys[-1], value)) target.links.append((keys[-1], value))
except TypeError: except TypeError as exc:
msg = INSERT_INTO_COLLISION_FMT.format( msg = INSERT_INTO_COLLISION_FMT.format(
value_url=value.url, value_url=value.url,
target_url=target.url, target_url=target.url,
keys=keys keys=keys
) )
raise ValueError(msg) raise ValueError(msg) from exc
class SchemaGenerator(BaseSchemaGenerator): class SchemaGenerator(BaseSchemaGenerator):

View File

@ -12,7 +12,6 @@ response content is handled by parsers and renderers.
""" """
import copy import copy
import inspect import inspect
import traceback
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from collections.abc import Mapping from collections.abc import Mapping
@ -227,12 +226,14 @@ class BaseSerializer(Field):
self._validated_data = self.run_validation(self.initial_data) self._validated_data = self.run_validation(self.initial_data)
except ValidationError as exc: except ValidationError as exc:
self._validated_data = {} self._validated_data = {}
self._exc = exc
self._errors = exc.detail self._errors = exc.detail
else: else:
self._exc = None
self._errors = {} self._errors = {}
if self._errors and raise_exception: if raise_exception and self._exc is not None:
raise ValidationError(self.errors) raise ValidationError(self.errors) from self._exc
return not bool(self._errors) return not bool(self._errors)
@ -429,7 +430,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
value = self.validate(value) value = self.validate(value)
assert value is not None, '.validate() should return the validated data' assert value is not None, '.validate() should return the validated data'
except (ValidationError, DjangoValidationError) as exc: except (ValidationError, DjangoValidationError) as exc:
raise ValidationError(detail=as_serializer_error(exc)) raise ValidationError(detail=as_serializer_error(exc)) from exc
return value return value
@ -621,7 +622,7 @@ class ListSerializer(BaseSerializer):
value = self.validate(value) value = self.validate(value)
assert value is not None, '.validate() should return the validated data' assert value is not None, '.validate() should return the validated data'
except (ValidationError, DjangoValidationError) as exc: except (ValidationError, DjangoValidationError) as exc:
raise ValidationError(detail=as_serializer_error(exc)) raise ValidationError(detail=as_serializer_error(exc)) from exc
return value return value
@ -748,12 +749,14 @@ class ListSerializer(BaseSerializer):
self._validated_data = self.run_validation(self.initial_data) self._validated_data = self.run_validation(self.initial_data)
except ValidationError as exc: except ValidationError as exc:
self._validated_data = [] self._validated_data = []
self._exc = exc
self._errors = exc.detail self._errors = exc.detail
else: else:
self._exc = None
self._errors = [] self._errors = []
if self._errors and raise_exception: if raise_exception and self._exc is not None:
raise ValidationError(self.errors) raise ValidationError(self.errors) from self._exc
return not bool(self._errors) return not bool(self._errors)
@ -960,25 +963,23 @@ class ModelSerializer(Serializer):
try: try:
instance = ModelClass._default_manager.create(**validated_data) instance = ModelClass._default_manager.create(**validated_data)
except TypeError: except TypeError as exc:
tb = traceback.format_exc()
msg = ( msg = (
'Got a `TypeError` when calling `%s.%s.create()`. ' 'Got a `TypeError` when calling `%s.%s.create()`. '
'This may be because you have a writable field on the ' 'This may be because you have a writable field on the '
'serializer class that is not a valid argument to ' 'serializer class that is not a valid argument to '
'`%s.%s.create()`. You may need to make the field ' '`%s.%s.create()`. You may need to make the field '
'read-only, or override the %s.create() method to handle ' 'read-only, or override the %s.create() method to handle '
'this correctly.\nOriginal exception was:\n %s' % 'this correctly.' %
( (
ModelClass.__name__, ModelClass.__name__,
ModelClass._default_manager.name, ModelClass._default_manager.name,
ModelClass.__name__, ModelClass.__name__,
ModelClass._default_manager.name, ModelClass._default_manager.name,
self.__class__.__name__, self.__class__.__name__,
tb
) )
) )
raise TypeError(msg) raise TypeError(msg) from exc
# Save many-to-many relationships after the instance is created. # Save many-to-many relationships after the instance is created.
if many_to_many: if many_to_many:

View File

@ -177,7 +177,7 @@ def import_from_string(val, setting_name):
return import_string(val) return import_string(val)
except ImportError as e: except ImportError as e:
msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e) msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e)
raise ImportError(msg) raise ImportError(msg) from e
class APISettings: class APISettings:

View File

@ -90,9 +90,9 @@ class SimpleRateThrottle(BaseThrottle):
try: try:
return self.THROTTLE_RATES[self.scope] return self.THROTTLE_RATES[self.scope]
except KeyError: except KeyError as exc:
msg = "No default throttle rate set for '%s' scope" % self.scope msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg) raise ImproperlyConfigured(msg) from exc
def parse_rate(self, rate): def parse_rate(self, rate):
""" """