diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 42abf3ca7..4ecf795cb 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -493,8 +493,9 @@ class HyperlinkedIdentityField(Field): self.view_name = kwargs.pop('view_name', None) # Optionally the format of the target hyperlink may be specified self.format = kwargs.pop('format', None) - - self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) + lookup_field = kwargs.pop('lookup_field', None) + if lookup_field is not None: + self.lookup_field = lookup_field # These are pending deprecation if 'pk_url_kwarg' in kwargs: diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 3e5c366ec..4dde0d7c1 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -907,6 +907,8 @@ class HyperlinkedModelSerializer(ModelSerializer): def __init__(self, *args, **kwargs): super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs) + lookup_field = self.opts.lookup_field + self.fields['url'] = HyperlinkedIdentityField(lookup_field=lookup_field) if self.opts.view_name is None: self.opts.view_name = self._get_default_view_name(self.opts.model) diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py index c73f5e726..fc3a87e99 100644 --- a/rest_framework/tests/test_routers.py +++ b/rest_framework/tests/test_routers.py @@ -1,13 +1,17 @@ from __future__ import unicode_literals +from django.db import models from django.test import TestCase from django.test.client import RequestFactory -from rest_framework.response import Response -from rest_framework import viewsets +from rest_framework import serializers, viewsets +from rest_framework.compat import include, patterns, url from rest_framework.decorators import link, action +from rest_framework.response import Response from rest_framework.routers import SimpleRouter factory = RequestFactory() +urlpatterns = patterns('',) + class BasicViewSet(viewsets.ViewSet): def list(self, request, *args, **kwargs): @@ -49,3 +53,62 @@ class TestSimpleRouter(TestCase): else: method_map = 'get' self.assertEqual(route.mapping[method_map], endpoint) + + +class RouterTestModel(models.Model): + uuid = models.CharField(max_length=20) + text = models.CharField(max_length=200) + + +class TestCustomLookupFields(TestCase): + """ + Ensure that custom lookup fields are correctly routed. + """ + urls = 'rest_framework.tests.test_routers' + + def setUp(self): + class NoteSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = RouterTestModel + lookup_field = 'uuid' + fields = ('url', 'uuid', 'text') + + class NoteViewSet(viewsets.ModelViewSet): + queryset = RouterTestModel.objects.all() + serializer_class = NoteSerializer + lookup_field = 'uuid' + + RouterTestModel.objects.create(uuid='123', text='foo bar') + + self.router = SimpleRouter() + self.router.register(r'notes', NoteViewSet) + + from rest_framework.tests import test_routers + urls = getattr(test_routers, 'urlpatterns') + urls += patterns('', + url(r'^', include(self.router.urls)), + ) + + def test_custom_lookup_field_route(self): + detail_route = self.router.urls[-1] + detail_url_pattern = detail_route.regex.pattern + self.assertIn('', detail_url_pattern) + + def test_retrieve_lookup_field_list_view(self): + response = self.client.get('/notes/') + self.assertEquals(response.data, + [{ + "url": "http://testserver/notes/123/", + "uuid": "123", "text": "foo bar" + }] + ) + + def test_retrieve_lookup_field_detail_view(self): + response = self.client.get('/notes/123/') + self.assertEquals(response.data, + { + "url": "http://testserver/notes/123/", + "uuid": "123", "text": "foo bar" + } + ) +