diff --git a/rest_framework/fields.py b/rest_framework/fields.py index ad2ca5899..9dbc11944 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -4,7 +4,8 @@ import inspect import warnings from django.core import validators -from django.core.exceptions import ValidationError +from django.core.exceptions import ObjectDoesNotExist, ValidationError +from django.core.urlresolvers import resolve from django.conf import settings from django.utils.encoding import is_protected_type, smart_unicode from django.utils.translation import ugettext_lazy as _ @@ -223,9 +224,9 @@ class RelatedField(WritableField): into[(self.source or field_name) + '_id'] = self.from_native(value) -class ManyRelatedField(RelatedField): +class ManyRelatedMixin(object): """ - Base class for related model managers. + Mixin to convert a related field to a many related field. """ def field_to_native(self, obj, field_name): value = getattr(obj, self.source or field_name) @@ -244,6 +245,15 @@ class ManyRelatedField(RelatedField): into[field_name] = [self.from_native(item) for item in value] +class ManyRelatedField(ManyRelatedMixin, RelatedField): + """ + Base class for related model managers. + """ + pass + + +### PrimaryKey relationships + class PrimaryKeyRelatedField(RelatedField): """ Serializes a related field or related object to a pk value. @@ -283,6 +293,83 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField): return [self.to_native(item.pk) for item in queryset.all()] +### Hyperlinked relationships + +class HyperlinkedRelatedField(RelatedField): + pk_url_kwarg = 'pk' + slug_url_kwarg = 'slug' + slug_field = 'slug' + + def __init__(self, *args, **kwargs): + try: + self.view_name = kwargs.pop('view_name') + except: + raise ValueError("Hyperlinked field requires 'view_name' kwarg") + super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) + + def to_native(self, obj): + view_name = self.view_name + request = self.context.get('request', None) + kwargs = {self.pk_url_kwarg: obj.pk} + try: + return reverse(view_name, kwargs=kwargs, request=request) + except: + pass + + slug = getattr(obj, self.slug_field, None) + + if not slug: + raise ValidationError('Could not resolve URL for field using view name "%s"', view_name) + + kwargs = {self.slug_url_kwarg: slug} + try: + return reverse(self.view_name, kwargs=kwargs, request=request) + except: + pass + + kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} + try: + return reverse(self.view_name, kwargs=kwargs, request=request) + except: + pass + + raise ValidationError('Could not resolve URL for field using view name "%s"', view_name) + + def from_native(self, value): + # Convert URL -> model instance pk + try: + match = resolve(value) + except: + raise ValidationError('Invalid hyperlink - No URL match') + + if match.url_name != self.view_name: + raise ValidationError('Invalid hyperlink - Incorrect URL match') + + pk = match.kwargs.get(self.pk_url_kwarg, None) + slug = match.kwargs.get(self.slug_url_kwarg, None) + + # Try explicit primary key. + if pk is not None: + return pk + # Next, try looking up by slug. + elif slug is not None: + slug_field = self.get_slug_field() + queryset = self.queryset.filter(**{slug_field: slug}) + # If none of those are defined, it's an error. + else: + raise ValidationError('Invalid hyperlink') + + try: + obj = queryset.get() + except ObjectDoesNotExist: + raise ValidationError('Invalid hyperlink - object does not exist.') + return obj.pk + + +class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField): + pass + + class HyperlinkedIdentityField(Field): """ A field that represents the model's identity using a hyperlink. diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4ffff65d5..ba8bf8ad2 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -422,13 +422,13 @@ class HyperlinkedModelSerializer(ModelSerializer): def __init__(self, *args, **kwargs): super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs) if self.opts.view_name is None: - self.opts.view_name = self._get_default_view_name() + self.opts.view_name = self._get_default_view_name(self.opts.model) - def _get_default_view_name(self): + def _get_default_view_name(self, model): """ Return the view name to use if 'view_name' is not specified in 'Meta' """ - model_meta = self.opts.model._meta + model_meta = model._meta format_kwargs = { 'app_label': model_meta.app_label, 'model_name': model_meta.object_name.lower() @@ -437,3 +437,19 @@ class HyperlinkedModelSerializer(ModelSerializer): def get_pk_field(self, model_field): return None + + def get_related_field(self, model_field): + """ + Creates a default instance of a flat relational field. + """ + # TODO: filter queryset using: + # .using(db).complex_filter(self.rel.limit_choices_to) + rel = model_field.rel.to + queryset = rel._default_manager + kwargs = { + 'queryset': queryset, + 'view_name': self._get_default_view_name(rel) + } + if isinstance(model_field, models.fields.related.ManyToManyField): + return ManyHyperlinkedRelatedField(**kwargs) + return HyperlinkedRelatedField(**kwargs) diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index 4f9393aa9..5532a8ee0 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -2,7 +2,7 @@ from django.conf.urls.defaults import patterns, url from django.test import TestCase from django.test.client import RequestFactory from rest_framework import generics, status, serializers -from rest_framework.tests.models import BasicModel +from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel factory = RequestFactory() @@ -17,13 +17,31 @@ class BasicDetail(generics.RetrieveUpdateDestroyAPIView): model_serializer_class = serializers.HyperlinkedModelSerializer +class AnchorDetail(generics.RetrieveAPIView): + model = Anchor + model_serializer_class = serializers.HyperlinkedModelSerializer + + +class ManyToManyList(generics.ListAPIView): + model = ManyToManyModel + model_serializer_class = serializers.HyperlinkedModelSerializer + + +class ManyToManyDetail(generics.RetrieveAPIView): + model = ManyToManyModel + model_serializer_class = serializers.HyperlinkedModelSerializer + + urlpatterns = patterns('', url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), url(r'^basic/(?P\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), + url(r'^anchor/(?P\d+)/$', AnchorDetail.as_view(), name='anchor-detail'), + url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'), + url(r'^manytomany/(?P\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'), ) -class TestHyperlinkedView(TestCase): +class TestBasicHyperlinkedView(TestCase): urls = 'rest_framework.tests.hyperlinkedserializers' def setUp(self): @@ -45,7 +63,7 @@ class TestHyperlinkedView(TestCase): """ GET requests to ListCreateAPIView should return list of objects. """ - request = factory.get('/') + request = factory.get('/basic/') response = self.list_view(request).render() self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.data, self.data) @@ -54,7 +72,55 @@ class TestHyperlinkedView(TestCase): """ GET requests to ListCreateAPIView should return list of objects. """ - request = factory.get('/1') + request = factory.get('/basic/1') + response = self.detail_view(request, pk=1).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, self.data[0]) + + +class TestManyToManyHyperlinkedView(TestCase): + urls = 'rest_framework.tests.hyperlinkedserializers' + + def setUp(self): + """ + Create 3 BasicModel intances. + """ + items = ['foo', 'bar', 'baz'] + anchors = [] + for item in items: + anchor = Anchor(text=item) + anchor.save() + anchors.append(anchor) + + manytomany = ManyToManyModel() + manytomany.save() + manytomany.rel.add(*anchors) + + self.data = [{ + 'url': 'http://testserver/manytomany/1/', + 'rel': [ + 'http://testserver/anchor/1/', + 'http://testserver/anchor/2/', + 'http://testserver/anchor/3/', + ] + }] + self.list_view = ManyToManyList.as_view() + self.detail_view = ManyToManyDetail.as_view() + + def test_get_list_view(self): + """ + GET requests to ListCreateAPIView should return list of objects. + """ + request = factory.get('/manytomany/') + response = self.list_view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, self.data) + + def test_get_detail_view(self): + """ + GET requests to ListCreateAPIView should return list of objects. + """ + request = factory.get('/manytomany/1/') response = self.detail_view(request, pk=1).render() self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.data, self.data[0])