Initial tests for hyperlinked relationships

This commit is contained in:
Tom Christie 2012-10-04 16:58:18 +01:00
parent 55e9cbecac
commit c91d926b06
3 changed files with 179 additions and 10 deletions

View File

@ -4,7 +4,8 @@ import inspect
import warnings import warnings
from django.core import validators from django.core import validators
from django.core.exceptions import ValidationError from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve
from django.conf import settings from django.conf import settings
from django.utils.encoding import is_protected_type, smart_unicode from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -223,9 +224,9 @@ class RelatedField(WritableField):
into[(self.source or field_name) + '_id'] = self.from_native(value) into[(self.source or field_name) + '_id'] = self.from_native(value)
class ManyRelatedField(RelatedField): class ManyRelatedMixin(object):
""" """
Base class for related model managers. Mixin to convert a related field to a many related field.
""" """
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
value = getattr(obj, self.source or field_name) value = getattr(obj, self.source or field_name)
@ -244,6 +245,15 @@ class ManyRelatedField(RelatedField):
into[field_name] = [self.from_native(item) for item in value] into[field_name] = [self.from_native(item) for item in value]
class ManyRelatedField(ManyRelatedMixin, RelatedField):
"""
Base class for related model managers.
"""
pass
### PrimaryKey relationships
class PrimaryKeyRelatedField(RelatedField): class PrimaryKeyRelatedField(RelatedField):
""" """
Serializes a related field or related object to a pk value. Serializes a related field or related object to a pk value.
@ -283,6 +293,83 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
return [self.to_native(item.pk) for item in queryset.all()] return [self.to_native(item.pk) for item in queryset.all()]
### Hyperlinked relationships
class HyperlinkedRelatedField(RelatedField):
pk_url_kwarg = 'pk'
slug_url_kwarg = 'slug'
slug_field = 'slug'
def __init__(self, *args, **kwargs):
try:
self.view_name = kwargs.pop('view_name')
except:
raise ValueError("Hyperlinked field requires 'view_name' kwarg")
super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
def to_native(self, obj):
view_name = self.view_name
request = self.context.get('request', None)
kwargs = {self.pk_url_kwarg: obj.pk}
try:
return reverse(view_name, kwargs=kwargs, request=request)
except:
pass
slug = getattr(obj, self.slug_field, None)
if not slug:
raise ValidationError('Could not resolve URL for field using view name "%s"', view_name)
kwargs = {self.slug_url_kwarg: slug}
try:
return reverse(self.view_name, kwargs=kwargs, request=request)
except:
pass
kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
try:
return reverse(self.view_name, kwargs=kwargs, request=request)
except:
pass
raise ValidationError('Could not resolve URL for field using view name "%s"', view_name)
def from_native(self, value):
# Convert URL -> model instance pk
try:
match = resolve(value)
except:
raise ValidationError('Invalid hyperlink - No URL match')
if match.url_name != self.view_name:
raise ValidationError('Invalid hyperlink - Incorrect URL match')
pk = match.kwargs.get(self.pk_url_kwarg, None)
slug = match.kwargs.get(self.slug_url_kwarg, None)
# Try explicit primary key.
if pk is not None:
return pk
# Next, try looking up by slug.
elif slug is not None:
slug_field = self.get_slug_field()
queryset = self.queryset.filter(**{slug_field: slug})
# If none of those are defined, it's an error.
else:
raise ValidationError('Invalid hyperlink')
try:
obj = queryset.get()
except ObjectDoesNotExist:
raise ValidationError('Invalid hyperlink - object does not exist.')
return obj.pk
class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
pass
class HyperlinkedIdentityField(Field): class HyperlinkedIdentityField(Field):
""" """
A field that represents the model's identity using a hyperlink. A field that represents the model's identity using a hyperlink.

View File

@ -422,13 +422,13 @@ class HyperlinkedModelSerializer(ModelSerializer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs) super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs)
if self.opts.view_name is None: if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name() self.opts.view_name = self._get_default_view_name(self.opts.model)
def _get_default_view_name(self): def _get_default_view_name(self, model):
""" """
Return the view name to use if 'view_name' is not specified in 'Meta' Return the view name to use if 'view_name' is not specified in 'Meta'
""" """
model_meta = self.opts.model._meta model_meta = model._meta
format_kwargs = { format_kwargs = {
'app_label': model_meta.app_label, 'app_label': model_meta.app_label,
'model_name': model_meta.object_name.lower() 'model_name': model_meta.object_name.lower()
@ -437,3 +437,19 @@ class HyperlinkedModelSerializer(ModelSerializer):
def get_pk_field(self, model_field): def get_pk_field(self, model_field):
return None return None
def get_related_field(self, model_field):
"""
Creates a default instance of a flat relational field.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
rel = model_field.rel.to
queryset = rel._default_manager
kwargs = {
'queryset': queryset,
'view_name': self._get_default_view_name(rel)
}
if isinstance(model_field, models.fields.related.ManyToManyField):
return ManyHyperlinkedRelatedField(**kwargs)
return HyperlinkedRelatedField(**kwargs)

View File

@ -2,7 +2,7 @@ from django.conf.urls.defaults import patterns, url
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from rest_framework import generics, status, serializers from rest_framework import generics, status, serializers
from rest_framework.tests.models import BasicModel from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel
factory = RequestFactory() factory = RequestFactory()
@ -17,13 +17,31 @@ class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
model_serializer_class = serializers.HyperlinkedModelSerializer model_serializer_class = serializers.HyperlinkedModelSerializer
class AnchorDetail(generics.RetrieveAPIView):
model = Anchor
model_serializer_class = serializers.HyperlinkedModelSerializer
class ManyToManyList(generics.ListAPIView):
model = ManyToManyModel
model_serializer_class = serializers.HyperlinkedModelSerializer
class ManyToManyDetail(generics.RetrieveAPIView):
model = ManyToManyModel
model_serializer_class = serializers.HyperlinkedModelSerializer
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
) )
class TestHyperlinkedView(TestCase): class TestBasicHyperlinkedView(TestCase):
urls = 'rest_framework.tests.hyperlinkedserializers' urls = 'rest_framework.tests.hyperlinkedserializers'
def setUp(self): def setUp(self):
@ -45,7 +63,7 @@ class TestHyperlinkedView(TestCase):
""" """
GET requests to ListCreateAPIView should return list of objects. GET requests to ListCreateAPIView should return list of objects.
""" """
request = factory.get('/') request = factory.get('/basic/')
response = self.list_view(request).render() response = self.list_view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data) self.assertEquals(response.data, self.data)
@ -54,7 +72,55 @@ class TestHyperlinkedView(TestCase):
""" """
GET requests to ListCreateAPIView should return list of objects. GET requests to ListCreateAPIView should return list of objects.
""" """
request = factory.get('/1') request = factory.get('/basic/1')
response = self.detail_view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0])
class TestManyToManyHyperlinkedView(TestCase):
urls = 'rest_framework.tests.hyperlinkedserializers'
def setUp(self):
"""
Create 3 BasicModel intances.
"""
items = ['foo', 'bar', 'baz']
anchors = []
for item in items:
anchor = Anchor(text=item)
anchor.save()
anchors.append(anchor)
manytomany = ManyToManyModel()
manytomany.save()
manytomany.rel.add(*anchors)
self.data = [{
'url': 'http://testserver/manytomany/1/',
'rel': [
'http://testserver/anchor/1/',
'http://testserver/anchor/2/',
'http://testserver/anchor/3/',
]
}]
self.list_view = ManyToManyList.as_view()
self.detail_view = ManyToManyDetail.as_view()
def test_get_list_view(self):
"""
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/manytomany/')
response = self.list_view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data)
def test_get_detail_view(self):
"""
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/manytomany/1/')
response = self.detail_view(request, pk=1).render() response = self.detail_view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0]) self.assertEquals(response.data, self.data[0])