From 10da18b20b7feaa7a615d6cddbdabda7968ff98c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 3 Sep 2015 16:24:13 +0100 Subject: [PATCH] Access settings lazily, not at module import --- rest_framework/fields.py | 114 ++++++++++++++++++++-------------- rest_framework/serializers.py | 5 +- rest_framework/settings.py | 12 +++- 3 files changed, 82 insertions(+), 49 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 159784ea3..8b42690d7 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -9,6 +9,7 @@ import re import uuid from django.conf import settings +from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ValidationError as DjangoValidationError from django.core.exceptions import ObjectDoesNotExist from django.core.validators import RegexValidator, ip_address_validators @@ -882,12 +883,11 @@ class DecimalField(Field): } MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. - coerce_to_string = api_settings.COERCE_DECIMAL_TO_STRING - def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, **kwargs): self.max_digits = max_digits self.decimal_places = decimal_places - self.coerce_to_string = coerce_to_string if (coerce_to_string is not None) else self.coerce_to_string + if coerce_to_string is not None: + self.coerce_to_string = coerce_to_string self.max_value = max_value self.min_value = min_value @@ -967,12 +967,14 @@ class DecimalField(Field): return value def to_representation(self, value): + coerce_to_string = getattr(self, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING) + if not isinstance(value, decimal.Decimal): value = decimal.Decimal(six.text_type(value).strip()) quantized = self.quantize(value) - if not self.coerce_to_string: + if not coerce_to_string: return quantized return '{0:f}'.format(quantized) @@ -994,15 +996,15 @@ class DateTimeField(Field): 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}.'), 'date': _('Expected a datetime but got a date.'), } - format = api_settings.DATETIME_FORMAT - input_formats = api_settings.DATETIME_INPUT_FORMATS - default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None datetime_parser = datetime.datetime.strptime def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs): - self.format = format if format is not empty else self.format - self.input_formats = input_formats if input_formats is not None else self.input_formats - self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone + if format is not empty: + self.format = format + if input_formats is not None: + self.input_formats = input_formats + if default_timezone is not None: + self.timezone = default_timezone super(DateTimeField, self).__init__(*args, **kwargs) def enforce_timezone(self, value): @@ -1010,21 +1012,31 @@ class DateTimeField(Field): When `self.default_timezone` is `None`, always return naive datetimes. When `self.default_timezone` is not `None`, always return aware datetimes. """ - if (self.default_timezone is not None) and not timezone.is_aware(value): - return timezone.make_aware(value, self.default_timezone) - elif (self.default_timezone is None) and timezone.is_aware(value): + field_timezone = getattr(self, 'timezone', self.default_timezone()) + + if (field_timezone is not None) and not timezone.is_aware(value): + return timezone.make_aware(value, field_timezone) + elif (field_timezone is None) and timezone.is_aware(value): return timezone.make_naive(value, timezone.UTC()) return value + def default_timezone(self): + try: + return timezone.get_default_timezone() if settings.USE_TZ else None + except ImproperlyConfigured: + return None + def to_internal_value(self, value): + input_formats = getattr(self, 'input_formats', api_settings.DATETIME_INPUT_FORMATS) + if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime): self.fail('date') if isinstance(value, datetime.datetime): return self.enforce_timezone(value) - for format in self.input_formats: - if format.lower() == ISO_8601: + for input_format in input_formats: + if input_format.lower() == ISO_8601: try: parsed = parse_datetime(value) except (ValueError, TypeError): @@ -1034,25 +1046,27 @@ class DateTimeField(Field): return self.enforce_timezone(parsed) else: try: - parsed = self.datetime_parser(value, format) + parsed = self.datetime_parser(value, input_format) except (ValueError, TypeError): pass else: return self.enforce_timezone(parsed) - humanized_format = humanize_datetime.datetime_formats(self.input_formats) + humanized_format = humanize_datetime.datetime_formats(input_formats) self.fail('invalid', format=humanized_format) def to_representation(self, value): - if self.format is None: + output_format = getattr(self, 'format', api_settings.DATETIME_FORMAT) + + if output_format is None: return value - if self.format.lower() == ISO_8601: + if output_format.lower() == ISO_8601: value = value.isoformat() if value.endswith('+00:00'): value = value[:-6] + 'Z' return value - return value.strftime(self.format) + return value.strftime(output_format) class DateField(Field): @@ -1060,24 +1074,26 @@ class DateField(Field): 'invalid': _('Date has wrong format. Use one of these formats instead: {format}.'), 'datetime': _('Expected a date but got a datetime.'), } - format = api_settings.DATE_FORMAT - input_formats = api_settings.DATE_INPUT_FORMATS datetime_parser = datetime.datetime.strptime def __init__(self, format=empty, input_formats=None, *args, **kwargs): - self.format = format if format is not empty else self.format - self.input_formats = input_formats if input_formats is not None else self.input_formats + if format is not empty: + self.format = format + if input_formats is not None: + self.input_formats = input_formats super(DateField, self).__init__(*args, **kwargs) def to_internal_value(self, value): + input_formats = getattr(self, 'input_formats', api_settings.DATE_INPUT_FORMATS) + if isinstance(value, datetime.datetime): self.fail('datetime') if isinstance(value, datetime.date): return value - for format in self.input_formats: - if format.lower() == ISO_8601: + for input_format in input_formats: + if input_format.lower() == ISO_8601: try: parsed = parse_date(value) except (ValueError, TypeError): @@ -1087,20 +1103,22 @@ class DateField(Field): return parsed else: try: - parsed = self.datetime_parser(value, format) + parsed = self.datetime_parser(value, input_format) except (ValueError, TypeError): pass else: return parsed.date() - humanized_format = humanize_datetime.date_formats(self.input_formats) + humanized_format = humanize_datetime.date_formats(input_formats) self.fail('invalid', format=humanized_format) def to_representation(self, value): + output_format = getattr(self, 'format', api_settings.DATE_FORMAT) + if not value: return None - if self.format is None: + if output_format is None: return value # Applying a `DateField` to a datetime value is almost always @@ -1112,33 +1130,35 @@ class DateField(Field): 'read-only field and deal with timezone issues explicitly.' ) - if self.format.lower() == ISO_8601: + if output_format.lower() == ISO_8601: if (isinstance(value, str)): value = datetime.datetime.strptime(value, '%Y-%m-%d').date() return value.isoformat() - return value.strftime(self.format) + return value.strftime(output_format) class TimeField(Field): default_error_messages = { 'invalid': _('Time has wrong format. Use one of these formats instead: {format}.'), } - format = api_settings.TIME_FORMAT - input_formats = api_settings.TIME_INPUT_FORMATS datetime_parser = datetime.datetime.strptime def __init__(self, format=empty, input_formats=None, *args, **kwargs): - self.format = format if format is not empty else self.format - self.input_formats = input_formats if input_formats is not None else self.input_formats + if format is not empty: + self.format = format + if input_formats is not None: + self.input_formats = input_formats super(TimeField, self).__init__(*args, **kwargs) def to_internal_value(self, value): + input_formats = getattr(self, 'input_formats', api_settings.TIME_INPUT_FORMATS) + if isinstance(value, datetime.time): return value - for format in self.input_formats: - if format.lower() == ISO_8601: + for input_format in input_formats: + if input_format.lower() == ISO_8601: try: parsed = parse_time(value) except (ValueError, TypeError): @@ -1148,17 +1168,19 @@ class TimeField(Field): return parsed else: try: - parsed = self.datetime_parser(value, format) + parsed = self.datetime_parser(value, input_format) except (ValueError, TypeError): pass else: return parsed.time() - humanized_format = humanize_datetime.time_formats(self.input_formats) + humanized_format = humanize_datetime.time_formats(input_formats) self.fail('invalid', format=humanized_format) def to_representation(self, value): - if self.format is None: + output_format = getattr(self, 'format', api_settings.TIME_FORMAT) + + if output_format is None: return value # Applying a `TimeField` to a datetime value is almost always @@ -1170,9 +1192,9 @@ class TimeField(Field): 'read-only field and deal with timezone issues explicitly.' ) - if self.format.lower() == ISO_8601: + if output_format.lower() == ISO_8601: return value.isoformat() - return value.strftime(self.format) + return value.strftime(output_format) class DurationField(Field): @@ -1316,12 +1338,12 @@ class FileField(Field): 'empty': _('The submitted file is empty.'), 'max_length': _('Ensure this filename has at most {max_length} characters (it has {length}).'), } - use_url = api_settings.UPLOADED_FILES_USE_URL def __init__(self, *args, **kwargs): self.max_length = kwargs.pop('max_length', None) self.allow_empty_file = kwargs.pop('allow_empty_file', False) - self.use_url = kwargs.pop('use_url', self.use_url) + if 'use_url' in kwargs: + self.use_url = kwargs.pop('use_url') super(FileField, self).__init__(*args, **kwargs) def to_internal_value(self, data): @@ -1342,10 +1364,12 @@ class FileField(Field): return data def to_representation(self, value): + use_url = getattr(self, 'use_url', api_settings.UPLOADED_FILES_USE_URL) + if not value: return None - if self.use_url: + if use_url: if not getattr(value, 'url', None): # If the file has not been saved it may not have a URL. return None diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index c7d4405c5..8ba1f3c40 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -786,7 +786,7 @@ class ModelSerializer(Serializer): # you'll also need to ensure you update the `create` method on any generic # views, to correctly handle the 'Location' response header for # "HTTP 201 Created" responses. - url_field_name = api_settings.URL_FIELD_NAME + url_field_name = None # Default `create` and `update` behavior... def create(self, validated_data): @@ -869,6 +869,9 @@ class ModelSerializer(Serializer): Return the dict of field names -> field instances that should be used for `self.fields` when instantiating the serializer. """ + if self.url_field_name is None: + self.url_field_name = api_settings.URL_FIELD_NAME + assert hasattr(self, 'Meta'), ( 'Class {serializer_class} missing "Meta" attribute'.format( serializer_class=self.__class__.__name__ diff --git a/rest_framework/settings.py b/rest_framework/settings.py index e20e51287..bc3868715 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -26,7 +26,6 @@ from django.utils import six from rest_framework import ISO_8601 from rest_framework.compat import importlib -USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None) DEFAULTS = { # Base API policies @@ -188,10 +187,17 @@ class APISettings(object): and return the class, rather than the string literal. """ def __init__(self, user_settings=None, defaults=None, import_strings=None): - self.user_settings = user_settings or {} + if user_settings: + self._user_settings = user_settings self.defaults = defaults or DEFAULTS self.import_strings = import_strings or IMPORT_STRINGS + @property + def user_settings(self): + if not hasattr(self, '_user_settings'): + self._user_settings = getattr(settings, 'REST_FRAMEWORK', {}) + return self._user_settings + def __getattr__(self, attr): if attr not in self.defaults.keys(): raise AttributeError("Invalid API setting: '%s'" % attr) @@ -212,7 +218,7 @@ class APISettings(object): return val -api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS) +api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS) def reload_api_settings(*args, **kwargs):