From 250755def707e1397876614fa0c08130d9fcc449 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 12 Sep 2014 10:59:51 +0100 Subject: [PATCH] Clean up relational fields queryset usage --- rest_framework/fields.py | 15 +++----- rest_framework/generics.py | 2 +- rest_framework/relations.py | 73 ++++++++++++++++++++----------------- tests/test_generics.py | 4 +- 4 files changed, 49 insertions(+), 45 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index a96f9ba89..4f06d1868 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -508,7 +508,7 @@ class DecimalField(Field): class DateField(Field): default_error_messages = { - 'invalid': _("Date has wrong format. Use one of these formats instead: %s"), + 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'), } input_formats = api_settings.DATE_INPUT_FORMATS format = api_settings.DATE_FORMAT @@ -551,8 +551,7 @@ class DateField(Field): return parsed.date() humanized_format = humanize_datetime.date_formats(self.input_formats) - msg = self.error_messages['invalid'] % humanized_format - raise ValidationError(msg) + self.fail('invalid', format=humanized_format) def to_representation(self, value): if value is None or self.format is None: @@ -568,7 +567,7 @@ class DateField(Field): class DateTimeField(Field): default_error_messages = { - 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), + 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'), } input_formats = api_settings.DATETIME_INPUT_FORMATS format = api_settings.DATETIME_FORMAT @@ -617,8 +616,7 @@ class DateTimeField(Field): return parsed humanized_format = humanize_datetime.datetime_formats(self.input_formats) - msg = self.error_messages['invalid'] % humanized_format - raise ValidationError(msg) + self.fail('invalid', format=humanized_format) def to_representation(self, value): if value is None or self.format is None: @@ -634,7 +632,7 @@ class DateTimeField(Field): class TimeField(Field): default_error_messages = { - 'invalid': _("Time has wrong format. Use one of these formats instead: %s"), + 'invalid': _('Time has wrong format. Use one of these formats instead: {format}'), } input_formats = api_settings.TIME_INPUT_FORMATS format = api_settings.TIME_FORMAT @@ -669,8 +667,7 @@ class TimeField(Field): return parsed.time() humanized_format = humanize_datetime.time_formats(self.input_formats) - msg = self.error_messages['invalid'] % humanized_format - raise ValidationError(msg) + self.fail('invalid', format=humanized_format) def to_representation(self, value): if value is None or self.format is None: diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 338d56a6a..eb6b64efa 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -216,7 +216,7 @@ class GenericAPIView(views.APIView): ) queryset = self.queryset - if isinstance(self.queryset, QuerySet): + if isinstance(queryset, QuerySet): # Ensure queryset is re-evaluated on each request. queryset = queryset.all() return queryset diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 30a252db9..e23a41526 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -2,28 +2,35 @@ from rest_framework.fields import Field from rest_framework.reverse import reverse from django.core.exceptions import ObjectDoesNotExist from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch +from django.db.models.query import QuerySet from rest_framework.compat import urlparse -def get_default_queryset(serializer_class, field_name): - manager = getattr(serializer_class.opts.model, field_name) - if hasattr(manager, 'related'): - # Forward relationships - return manager.related.model._default_manager.all() - # Reverse relationships - return manager.field.rel.to._default_manager.all() - - class RelatedField(Field): def __init__(self, **kwargs): self.queryset = kwargs.pop('queryset', None) self.many = kwargs.pop('many', False) + assert self.queryset is not None or kwargs.get('read_only', False), ( + 'Relational field must provide a `queryset` argument, ' + 'or set read_only=`True`.' + ) super(RelatedField, self).__init__(**kwargs) - def bind(self, field_name, parent, root): - super(RelatedField, self).bind(field_name, parent, root) - if self.queryset is None and not self.read_only: - self.queryset = get_default_queryset(parent, self.source) + def get_queryset(self): + queryset = self.queryset + if isinstance(queryset, QuerySet): + # Ensure queryset is re-evaluated whenever used. + queryset = queryset.all() + return queryset + + +class StringRelatedField(Field): + def __init__(self, **kwargs): + kwargs['read_only'] = True + super(StringRelatedField, self).__init__(**kwargs) + + def to_representation(self, value): + return str(value) class PrimaryKeyRelatedField(RelatedField): @@ -33,9 +40,9 @@ class PrimaryKeyRelatedField(RelatedField): 'incorrect_type': 'Incorrect type. Expected pk value, received {data_type}.', } - def from_native(self, data): + def to_internal_value(self, data): try: - return self.queryset.get(pk=data) + return self.get_queryset().get(pk=data) except ObjectDoesNotExist: self.fail('does_not_exist', pk_value=data) except (TypeError, ValueError): @@ -68,9 +75,9 @@ class HyperlinkedRelatedField(RelatedField): """ lookup_value = view_kwargs[self.lookup_url_kwarg] lookup_kwargs = {self.lookup_field: lookup_value} - return self.queryset.get(**lookup_kwargs) + return self.get_queryset().get(**lookup_kwargs) - def from_native(self, value): + def to_internal_value(self, value): try: http_prefix = value.startswith(('http:', 'https:')) except AttributeError: @@ -102,13 +109,26 @@ class HyperlinkedIdentityField(RelatedField): def __init__(self, **kwargs): kwargs['read_only'] = True + kwargs['source'] = '*' self.view_name = kwargs.pop('view_name') self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) super(HyperlinkedIdentityField, self).__init__(**kwargs) - def get_attribute(self, instance): - return instance + def get_url(self, obj, view_name, request, format): + """ + Given an object, return the URL that hyperlinks to the object. + + May raise a `NoReverseMatch` if the `view_name` and `lookup_field` + attributes are not configured to correctly match the URL conf. + """ + # Unsaved objects will not yet have a valid URL. + if obj.pk is None: + return None + + lookup_value = getattr(obj, self.lookup_field) + kwargs = {self.lookup_url_kwarg: lookup_value} + return reverse(view_name, kwargs=kwargs, request=request, format=format) def to_representation(self, value): request = self.context.get('request', None) @@ -144,21 +164,6 @@ class HyperlinkedIdentityField(RelatedField): ) raise Exception(msg % self.view_name) - def get_url(self, obj, view_name, request, format): - """ - Given an object, return the URL that hyperlinks to the object. - - May raise a `NoReverseMatch` if the `view_name` and `lookup_field` - attributes are not configured to correctly match the URL conf. - """ - # Unsaved objects will not yet have a valid URL. - if obj.pk is None: - return None - - lookup_value = getattr(obj, self.lookup_field) - kwargs = {self.lookup_url_kwarg: lookup_value} - return reverse(view_name, kwargs=kwargs, request=request, format=format) - class SlugRelatedField(RelatedField): def __init__(self, **kwargs): diff --git a/tests/test_generics.py b/tests/test_generics.py index 17bfca2ff..51004edfe 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -547,7 +547,9 @@ class ClassA(models.Model): class ClassASerializer(serializers.ModelSerializer): - childs = serializers.PrimaryKeyRelatedField(many=True, source='childs') + childs = serializers.PrimaryKeyRelatedField( + many=True, queryset=ClassB.objects.all() + ) class Meta: model = ClassA