Access settings lazily, not at module import

This commit is contained in:
Tom Christie 2015-09-03 16:24:13 +01:00
parent 39ec564ae9
commit 10da18b20b
3 changed files with 82 additions and 49 deletions

View File

@ -9,6 +9,7 @@ import re
import uuid
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.exceptions import ValidationError as DjangoValidationError
from django.core.exceptions import ObjectDoesNotExist
from django.core.validators import RegexValidator, ip_address_validators
@ -882,12 +883,11 @@ class DecimalField(Field):
}
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
coerce_to_string = api_settings.COERCE_DECIMAL_TO_STRING
def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, **kwargs):
self.max_digits = max_digits
self.decimal_places = decimal_places
self.coerce_to_string = coerce_to_string if (coerce_to_string is not None) else self.coerce_to_string
if coerce_to_string is not None:
self.coerce_to_string = coerce_to_string
self.max_value = max_value
self.min_value = min_value
@ -967,12 +967,14 @@ class DecimalField(Field):
return value
def to_representation(self, value):
coerce_to_string = getattr(self, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING)
if not isinstance(value, decimal.Decimal):
value = decimal.Decimal(six.text_type(value).strip())
quantized = self.quantize(value)
if not self.coerce_to_string:
if not coerce_to_string:
return quantized
return '{0:f}'.format(quantized)
@ -994,15 +996,15 @@ class DateTimeField(Field):
'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}.'),
'date': _('Expected a datetime but got a date.'),
}
format = api_settings.DATETIME_FORMAT
input_formats = api_settings.DATETIME_INPUT_FORMATS
default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None
datetime_parser = datetime.datetime.strptime
def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs):
self.format = format if format is not empty else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone
if format is not empty:
self.format = format
if input_formats is not None:
self.input_formats = input_formats
if default_timezone is not None:
self.timezone = default_timezone
super(DateTimeField, self).__init__(*args, **kwargs)
def enforce_timezone(self, value):
@ -1010,21 +1012,31 @@ class DateTimeField(Field):
When `self.default_timezone` is `None`, always return naive datetimes.
When `self.default_timezone` is not `None`, always return aware datetimes.
"""
if (self.default_timezone is not None) and not timezone.is_aware(value):
return timezone.make_aware(value, self.default_timezone)
elif (self.default_timezone is None) and timezone.is_aware(value):
field_timezone = getattr(self, 'timezone', self.default_timezone())
if (field_timezone is not None) and not timezone.is_aware(value):
return timezone.make_aware(value, field_timezone)
elif (field_timezone is None) and timezone.is_aware(value):
return timezone.make_naive(value, timezone.UTC())
return value
def default_timezone(self):
try:
return timezone.get_default_timezone() if settings.USE_TZ else None
except ImproperlyConfigured:
return None
def to_internal_value(self, value):
input_formats = getattr(self, 'input_formats', api_settings.DATETIME_INPUT_FORMATS)
if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime):
self.fail('date')
if isinstance(value, datetime.datetime):
return self.enforce_timezone(value)
for format in self.input_formats:
if format.lower() == ISO_8601:
for input_format in input_formats:
if input_format.lower() == ISO_8601:
try:
parsed = parse_datetime(value)
except (ValueError, TypeError):
@ -1034,25 +1046,27 @@ class DateTimeField(Field):
return self.enforce_timezone(parsed)
else:
try:
parsed = self.datetime_parser(value, format)
parsed = self.datetime_parser(value, input_format)
except (ValueError, TypeError):
pass
else:
return self.enforce_timezone(parsed)
humanized_format = humanize_datetime.datetime_formats(self.input_formats)
humanized_format = humanize_datetime.datetime_formats(input_formats)
self.fail('invalid', format=humanized_format)
def to_representation(self, value):
if self.format is None:
output_format = getattr(self, 'format', api_settings.DATETIME_FORMAT)
if output_format is None:
return value
if self.format.lower() == ISO_8601:
if output_format.lower() == ISO_8601:
value = value.isoformat()
if value.endswith('+00:00'):
value = value[:-6] + 'Z'
return value
return value.strftime(self.format)
return value.strftime(output_format)
class DateField(Field):
@ -1060,24 +1074,26 @@ class DateField(Field):
'invalid': _('Date has wrong format. Use one of these formats instead: {format}.'),
'datetime': _('Expected a date but got a datetime.'),
}
format = api_settings.DATE_FORMAT
input_formats = api_settings.DATE_INPUT_FORMATS
datetime_parser = datetime.datetime.strptime
def __init__(self, format=empty, input_formats=None, *args, **kwargs):
self.format = format if format is not empty else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
if format is not empty:
self.format = format
if input_formats is not None:
self.input_formats = input_formats
super(DateField, self).__init__(*args, **kwargs)
def to_internal_value(self, value):
input_formats = getattr(self, 'input_formats', api_settings.DATE_INPUT_FORMATS)
if isinstance(value, datetime.datetime):
self.fail('datetime')
if isinstance(value, datetime.date):
return value
for format in self.input_formats:
if format.lower() == ISO_8601:
for input_format in input_formats:
if input_format.lower() == ISO_8601:
try:
parsed = parse_date(value)
except (ValueError, TypeError):
@ -1087,20 +1103,22 @@ class DateField(Field):
return parsed
else:
try:
parsed = self.datetime_parser(value, format)
parsed = self.datetime_parser(value, input_format)
except (ValueError, TypeError):
pass
else:
return parsed.date()
humanized_format = humanize_datetime.date_formats(self.input_formats)
humanized_format = humanize_datetime.date_formats(input_formats)
self.fail('invalid', format=humanized_format)
def to_representation(self, value):
output_format = getattr(self, 'format', api_settings.DATE_FORMAT)
if not value:
return None
if self.format is None:
if output_format is None:
return value
# Applying a `DateField` to a datetime value is almost always
@ -1112,33 +1130,35 @@ class DateField(Field):
'read-only field and deal with timezone issues explicitly.'
)
if self.format.lower() == ISO_8601:
if output_format.lower() == ISO_8601:
if (isinstance(value, str)):
value = datetime.datetime.strptime(value, '%Y-%m-%d').date()
return value.isoformat()
return value.strftime(self.format)
return value.strftime(output_format)
class TimeField(Field):
default_error_messages = {
'invalid': _('Time has wrong format. Use one of these formats instead: {format}.'),
}
format = api_settings.TIME_FORMAT
input_formats = api_settings.TIME_INPUT_FORMATS
datetime_parser = datetime.datetime.strptime
def __init__(self, format=empty, input_formats=None, *args, **kwargs):
self.format = format if format is not empty else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
if format is not empty:
self.format = format
if input_formats is not None:
self.input_formats = input_formats
super(TimeField, self).__init__(*args, **kwargs)
def to_internal_value(self, value):
input_formats = getattr(self, 'input_formats', api_settings.TIME_INPUT_FORMATS)
if isinstance(value, datetime.time):
return value
for format in self.input_formats:
if format.lower() == ISO_8601:
for input_format in input_formats:
if input_format.lower() == ISO_8601:
try:
parsed = parse_time(value)
except (ValueError, TypeError):
@ -1148,17 +1168,19 @@ class TimeField(Field):
return parsed
else:
try:
parsed = self.datetime_parser(value, format)
parsed = self.datetime_parser(value, input_format)
except (ValueError, TypeError):
pass
else:
return parsed.time()
humanized_format = humanize_datetime.time_formats(self.input_formats)
humanized_format = humanize_datetime.time_formats(input_formats)
self.fail('invalid', format=humanized_format)
def to_representation(self, value):
if self.format is None:
output_format = getattr(self, 'format', api_settings.TIME_FORMAT)
if output_format is None:
return value
# Applying a `TimeField` to a datetime value is almost always
@ -1170,9 +1192,9 @@ class TimeField(Field):
'read-only field and deal with timezone issues explicitly.'
)
if self.format.lower() == ISO_8601:
if output_format.lower() == ISO_8601:
return value.isoformat()
return value.strftime(self.format)
return value.strftime(output_format)
class DurationField(Field):
@ -1316,12 +1338,12 @@ class FileField(Field):
'empty': _('The submitted file is empty.'),
'max_length': _('Ensure this filename has at most {max_length} characters (it has {length}).'),
}
use_url = api_settings.UPLOADED_FILES_USE_URL
def __init__(self, *args, **kwargs):
self.max_length = kwargs.pop('max_length', None)
self.allow_empty_file = kwargs.pop('allow_empty_file', False)
self.use_url = kwargs.pop('use_url', self.use_url)
if 'use_url' in kwargs:
self.use_url = kwargs.pop('use_url')
super(FileField, self).__init__(*args, **kwargs)
def to_internal_value(self, data):
@ -1342,10 +1364,12 @@ class FileField(Field):
return data
def to_representation(self, value):
use_url = getattr(self, 'use_url', api_settings.UPLOADED_FILES_USE_URL)
if not value:
return None
if self.use_url:
if use_url:
if not getattr(value, 'url', None):
# If the file has not been saved it may not have a URL.
return None

View File

@ -786,7 +786,7 @@ class ModelSerializer(Serializer):
# you'll also need to ensure you update the `create` method on any generic
# views, to correctly handle the 'Location' response header for
# "HTTP 201 Created" responses.
url_field_name = api_settings.URL_FIELD_NAME
url_field_name = None
# Default `create` and `update` behavior...
def create(self, validated_data):
@ -869,6 +869,9 @@ class ModelSerializer(Serializer):
Return the dict of field names -> field instances that should be
used for `self.fields` when instantiating the serializer.
"""
if self.url_field_name is None:
self.url_field_name = api_settings.URL_FIELD_NAME
assert hasattr(self, 'Meta'), (
'Class {serializer_class} missing "Meta" attribute'.format(
serializer_class=self.__class__.__name__

View File

@ -26,7 +26,6 @@ from django.utils import six
from rest_framework import ISO_8601
from rest_framework.compat import importlib
USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None)
DEFAULTS = {
# Base API policies
@ -188,10 +187,17 @@ class APISettings(object):
and return the class, rather than the string literal.
"""
def __init__(self, user_settings=None, defaults=None, import_strings=None):
self.user_settings = user_settings or {}
if user_settings:
self._user_settings = user_settings
self.defaults = defaults or DEFAULTS
self.import_strings = import_strings or IMPORT_STRINGS
@property
def user_settings(self):
if not hasattr(self, '_user_settings'):
self._user_settings = getattr(settings, 'REST_FRAMEWORK', {})
return self._user_settings
def __getattr__(self, attr):
if attr not in self.defaults.keys():
raise AttributeError("Invalid API setting: '%s'" % attr)
@ -212,7 +218,7 @@ class APISettings(object):
return val
api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)
api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)
def reload_api_settings(*args, **kwargs):