Access settings lazily, not at module import

This commit is contained in:
Tom Christie 2015-09-03 16:24:13 +01:00
parent 39ec564ae9
commit 10da18b20b
3 changed files with 82 additions and 49 deletions

View File

@ -9,6 +9,7 @@ import re
import uuid import uuid
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.exceptions import ValidationError as DjangoValidationError from django.core.exceptions import ValidationError as DjangoValidationError
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.core.validators import RegexValidator, ip_address_validators from django.core.validators import RegexValidator, ip_address_validators
@ -882,12 +883,11 @@ class DecimalField(Field):
} }
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
coerce_to_string = api_settings.COERCE_DECIMAL_TO_STRING
def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, **kwargs): def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, **kwargs):
self.max_digits = max_digits self.max_digits = max_digits
self.decimal_places = decimal_places self.decimal_places = decimal_places
self.coerce_to_string = coerce_to_string if (coerce_to_string is not None) else self.coerce_to_string if coerce_to_string is not None:
self.coerce_to_string = coerce_to_string
self.max_value = max_value self.max_value = max_value
self.min_value = min_value self.min_value = min_value
@ -967,12 +967,14 @@ class DecimalField(Field):
return value return value
def to_representation(self, value): def to_representation(self, value):
coerce_to_string = getattr(self, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING)
if not isinstance(value, decimal.Decimal): if not isinstance(value, decimal.Decimal):
value = decimal.Decimal(six.text_type(value).strip()) value = decimal.Decimal(six.text_type(value).strip())
quantized = self.quantize(value) quantized = self.quantize(value)
if not self.coerce_to_string: if not coerce_to_string:
return quantized return quantized
return '{0:f}'.format(quantized) return '{0:f}'.format(quantized)
@ -994,15 +996,15 @@ class DateTimeField(Field):
'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}.'), 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}.'),
'date': _('Expected a datetime but got a date.'), 'date': _('Expected a datetime but got a date.'),
} }
format = api_settings.DATETIME_FORMAT
input_formats = api_settings.DATETIME_INPUT_FORMATS
default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None
datetime_parser = datetime.datetime.strptime datetime_parser = datetime.datetime.strptime
def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs): def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs):
self.format = format if format is not empty else self.format if format is not empty:
self.input_formats = input_formats if input_formats is not None else self.input_formats self.format = format
self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone if input_formats is not None:
self.input_formats = input_formats
if default_timezone is not None:
self.timezone = default_timezone
super(DateTimeField, self).__init__(*args, **kwargs) super(DateTimeField, self).__init__(*args, **kwargs)
def enforce_timezone(self, value): def enforce_timezone(self, value):
@ -1010,21 +1012,31 @@ class DateTimeField(Field):
When `self.default_timezone` is `None`, always return naive datetimes. When `self.default_timezone` is `None`, always return naive datetimes.
When `self.default_timezone` is not `None`, always return aware datetimes. When `self.default_timezone` is not `None`, always return aware datetimes.
""" """
if (self.default_timezone is not None) and not timezone.is_aware(value): field_timezone = getattr(self, 'timezone', self.default_timezone())
return timezone.make_aware(value, self.default_timezone)
elif (self.default_timezone is None) and timezone.is_aware(value): if (field_timezone is not None) and not timezone.is_aware(value):
return timezone.make_aware(value, field_timezone)
elif (field_timezone is None) and timezone.is_aware(value):
return timezone.make_naive(value, timezone.UTC()) return timezone.make_naive(value, timezone.UTC())
return value return value
def default_timezone(self):
try:
return timezone.get_default_timezone() if settings.USE_TZ else None
except ImproperlyConfigured:
return None
def to_internal_value(self, value): def to_internal_value(self, value):
input_formats = getattr(self, 'input_formats', api_settings.DATETIME_INPUT_FORMATS)
if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime): if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime):
self.fail('date') self.fail('date')
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
return self.enforce_timezone(value) return self.enforce_timezone(value)
for format in self.input_formats: for input_format in input_formats:
if format.lower() == ISO_8601: if input_format.lower() == ISO_8601:
try: try:
parsed = parse_datetime(value) parsed = parse_datetime(value)
except (ValueError, TypeError): except (ValueError, TypeError):
@ -1034,25 +1046,27 @@ class DateTimeField(Field):
return self.enforce_timezone(parsed) return self.enforce_timezone(parsed)
else: else:
try: try:
parsed = self.datetime_parser(value, format) parsed = self.datetime_parser(value, input_format)
except (ValueError, TypeError): except (ValueError, TypeError):
pass pass
else: else:
return self.enforce_timezone(parsed) return self.enforce_timezone(parsed)
humanized_format = humanize_datetime.datetime_formats(self.input_formats) humanized_format = humanize_datetime.datetime_formats(input_formats)
self.fail('invalid', format=humanized_format) self.fail('invalid', format=humanized_format)
def to_representation(self, value): def to_representation(self, value):
if self.format is None: output_format = getattr(self, 'format', api_settings.DATETIME_FORMAT)
if output_format is None:
return value return value
if self.format.lower() == ISO_8601: if output_format.lower() == ISO_8601:
value = value.isoformat() value = value.isoformat()
if value.endswith('+00:00'): if value.endswith('+00:00'):
value = value[:-6] + 'Z' value = value[:-6] + 'Z'
return value return value
return value.strftime(self.format) return value.strftime(output_format)
class DateField(Field): class DateField(Field):
@ -1060,24 +1074,26 @@ class DateField(Field):
'invalid': _('Date has wrong format. Use one of these formats instead: {format}.'), 'invalid': _('Date has wrong format. Use one of these formats instead: {format}.'),
'datetime': _('Expected a date but got a datetime.'), 'datetime': _('Expected a date but got a datetime.'),
} }
format = api_settings.DATE_FORMAT
input_formats = api_settings.DATE_INPUT_FORMATS
datetime_parser = datetime.datetime.strptime datetime_parser = datetime.datetime.strptime
def __init__(self, format=empty, input_formats=None, *args, **kwargs): def __init__(self, format=empty, input_formats=None, *args, **kwargs):
self.format = format if format is not empty else self.format if format is not empty:
self.input_formats = input_formats if input_formats is not None else self.input_formats self.format = format
if input_formats is not None:
self.input_formats = input_formats
super(DateField, self).__init__(*args, **kwargs) super(DateField, self).__init__(*args, **kwargs)
def to_internal_value(self, value): def to_internal_value(self, value):
input_formats = getattr(self, 'input_formats', api_settings.DATE_INPUT_FORMATS)
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
self.fail('datetime') self.fail('datetime')
if isinstance(value, datetime.date): if isinstance(value, datetime.date):
return value return value
for format in self.input_formats: for input_format in input_formats:
if format.lower() == ISO_8601: if input_format.lower() == ISO_8601:
try: try:
parsed = parse_date(value) parsed = parse_date(value)
except (ValueError, TypeError): except (ValueError, TypeError):
@ -1087,20 +1103,22 @@ class DateField(Field):
return parsed return parsed
else: else:
try: try:
parsed = self.datetime_parser(value, format) parsed = self.datetime_parser(value, input_format)
except (ValueError, TypeError): except (ValueError, TypeError):
pass pass
else: else:
return parsed.date() return parsed.date()
humanized_format = humanize_datetime.date_formats(self.input_formats) humanized_format = humanize_datetime.date_formats(input_formats)
self.fail('invalid', format=humanized_format) self.fail('invalid', format=humanized_format)
def to_representation(self, value): def to_representation(self, value):
output_format = getattr(self, 'format', api_settings.DATE_FORMAT)
if not value: if not value:
return None return None
if self.format is None: if output_format is None:
return value return value
# Applying a `DateField` to a datetime value is almost always # Applying a `DateField` to a datetime value is almost always
@ -1112,33 +1130,35 @@ class DateField(Field):
'read-only field and deal with timezone issues explicitly.' 'read-only field and deal with timezone issues explicitly.'
) )
if self.format.lower() == ISO_8601: if output_format.lower() == ISO_8601:
if (isinstance(value, str)): if (isinstance(value, str)):
value = datetime.datetime.strptime(value, '%Y-%m-%d').date() value = datetime.datetime.strptime(value, '%Y-%m-%d').date()
return value.isoformat() return value.isoformat()
return value.strftime(self.format) return value.strftime(output_format)
class TimeField(Field): class TimeField(Field):
default_error_messages = { default_error_messages = {
'invalid': _('Time has wrong format. Use one of these formats instead: {format}.'), 'invalid': _('Time has wrong format. Use one of these formats instead: {format}.'),
} }
format = api_settings.TIME_FORMAT
input_formats = api_settings.TIME_INPUT_FORMATS
datetime_parser = datetime.datetime.strptime datetime_parser = datetime.datetime.strptime
def __init__(self, format=empty, input_formats=None, *args, **kwargs): def __init__(self, format=empty, input_formats=None, *args, **kwargs):
self.format = format if format is not empty else self.format if format is not empty:
self.input_formats = input_formats if input_formats is not None else self.input_formats self.format = format
if input_formats is not None:
self.input_formats = input_formats
super(TimeField, self).__init__(*args, **kwargs) super(TimeField, self).__init__(*args, **kwargs)
def to_internal_value(self, value): def to_internal_value(self, value):
input_formats = getattr(self, 'input_formats', api_settings.TIME_INPUT_FORMATS)
if isinstance(value, datetime.time): if isinstance(value, datetime.time):
return value return value
for format in self.input_formats: for input_format in input_formats:
if format.lower() == ISO_8601: if input_format.lower() == ISO_8601:
try: try:
parsed = parse_time(value) parsed = parse_time(value)
except (ValueError, TypeError): except (ValueError, TypeError):
@ -1148,17 +1168,19 @@ class TimeField(Field):
return parsed return parsed
else: else:
try: try:
parsed = self.datetime_parser(value, format) parsed = self.datetime_parser(value, input_format)
except (ValueError, TypeError): except (ValueError, TypeError):
pass pass
else: else:
return parsed.time() return parsed.time()
humanized_format = humanize_datetime.time_formats(self.input_formats) humanized_format = humanize_datetime.time_formats(input_formats)
self.fail('invalid', format=humanized_format) self.fail('invalid', format=humanized_format)
def to_representation(self, value): def to_representation(self, value):
if self.format is None: output_format = getattr(self, 'format', api_settings.TIME_FORMAT)
if output_format is None:
return value return value
# Applying a `TimeField` to a datetime value is almost always # Applying a `TimeField` to a datetime value is almost always
@ -1170,9 +1192,9 @@ class TimeField(Field):
'read-only field and deal with timezone issues explicitly.' 'read-only field and deal with timezone issues explicitly.'
) )
if self.format.lower() == ISO_8601: if output_format.lower() == ISO_8601:
return value.isoformat() return value.isoformat()
return value.strftime(self.format) return value.strftime(output_format)
class DurationField(Field): class DurationField(Field):
@ -1316,12 +1338,12 @@ class FileField(Field):
'empty': _('The submitted file is empty.'), 'empty': _('The submitted file is empty.'),
'max_length': _('Ensure this filename has at most {max_length} characters (it has {length}).'), 'max_length': _('Ensure this filename has at most {max_length} characters (it has {length}).'),
} }
use_url = api_settings.UPLOADED_FILES_USE_URL
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.max_length = kwargs.pop('max_length', None) self.max_length = kwargs.pop('max_length', None)
self.allow_empty_file = kwargs.pop('allow_empty_file', False) self.allow_empty_file = kwargs.pop('allow_empty_file', False)
self.use_url = kwargs.pop('use_url', self.use_url) if 'use_url' in kwargs:
self.use_url = kwargs.pop('use_url')
super(FileField, self).__init__(*args, **kwargs) super(FileField, self).__init__(*args, **kwargs)
def to_internal_value(self, data): def to_internal_value(self, data):
@ -1342,10 +1364,12 @@ class FileField(Field):
return data return data
def to_representation(self, value): def to_representation(self, value):
use_url = getattr(self, 'use_url', api_settings.UPLOADED_FILES_USE_URL)
if not value: if not value:
return None return None
if self.use_url: if use_url:
if not getattr(value, 'url', None): if not getattr(value, 'url', None):
# If the file has not been saved it may not have a URL. # If the file has not been saved it may not have a URL.
return None return None

View File

@ -786,7 +786,7 @@ class ModelSerializer(Serializer):
# you'll also need to ensure you update the `create` method on any generic # you'll also need to ensure you update the `create` method on any generic
# views, to correctly handle the 'Location' response header for # views, to correctly handle the 'Location' response header for
# "HTTP 201 Created" responses. # "HTTP 201 Created" responses.
url_field_name = api_settings.URL_FIELD_NAME url_field_name = None
# Default `create` and `update` behavior... # Default `create` and `update` behavior...
def create(self, validated_data): def create(self, validated_data):
@ -869,6 +869,9 @@ class ModelSerializer(Serializer):
Return the dict of field names -> field instances that should be Return the dict of field names -> field instances that should be
used for `self.fields` when instantiating the serializer. used for `self.fields` when instantiating the serializer.
""" """
if self.url_field_name is None:
self.url_field_name = api_settings.URL_FIELD_NAME
assert hasattr(self, 'Meta'), ( assert hasattr(self, 'Meta'), (
'Class {serializer_class} missing "Meta" attribute'.format( 'Class {serializer_class} missing "Meta" attribute'.format(
serializer_class=self.__class__.__name__ serializer_class=self.__class__.__name__

View File

@ -26,7 +26,6 @@ from django.utils import six
from rest_framework import ISO_8601 from rest_framework import ISO_8601
from rest_framework.compat import importlib from rest_framework.compat import importlib
USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None)
DEFAULTS = { DEFAULTS = {
# Base API policies # Base API policies
@ -188,10 +187,17 @@ class APISettings(object):
and return the class, rather than the string literal. and return the class, rather than the string literal.
""" """
def __init__(self, user_settings=None, defaults=None, import_strings=None): def __init__(self, user_settings=None, defaults=None, import_strings=None):
self.user_settings = user_settings or {} if user_settings:
self._user_settings = user_settings
self.defaults = defaults or DEFAULTS self.defaults = defaults or DEFAULTS
self.import_strings = import_strings or IMPORT_STRINGS self.import_strings = import_strings or IMPORT_STRINGS
@property
def user_settings(self):
if not hasattr(self, '_user_settings'):
self._user_settings = getattr(settings, 'REST_FRAMEWORK', {})
return self._user_settings
def __getattr__(self, attr): def __getattr__(self, attr):
if attr not in self.defaults.keys(): if attr not in self.defaults.keys():
raise AttributeError("Invalid API setting: '%s'" % attr) raise AttributeError("Invalid API setting: '%s'" % attr)
@ -212,7 +218,7 @@ class APISettings(object):
return val return val
api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS) api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)
def reload_api_settings(*args, **kwargs): def reload_api_settings(*args, **kwargs):