Merge pull request #280 from tomchristie/hyperlinked-relationships

Hyperlinked relationships
This commit is contained in:
Tom Christie 2012-10-04 14:14:56 -07:00
commit ad5e6eb16f
6 changed files with 249 additions and 18 deletions

View File

@ -4,9 +4,9 @@ import inspect
import warnings import warnings
from django.core import validators from django.core import validators
from django.core.exceptions import ValidationError from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve
from django.conf import settings from django.conf import settings
from django.db import DEFAULT_DB_ALIAS
from django.utils.encoding import is_protected_type, smart_unicode from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
@ -27,6 +27,7 @@ def is_simple_callable(obj):
class Field(object): class Field(object):
creation_counter = 0 creation_counter = 0
empty = '' empty = ''
type_name = None
def __init__(self, source=None): def __init__(self, source=None):
self.parent = None self.parent = None
@ -82,6 +83,10 @@ class Field(object):
if is_protected_type(value): if is_protected_type(value):
return value return value
all_callable = getattr(value, 'all', None)
if is_simple_callable(all_callable):
return [self.to_native(item) for item in value.all()]
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)): elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
return [self.to_native(item) for item in value] return [self.to_native(item) for item in value]
return smart_unicode(value) return smart_unicode(value)
@ -90,7 +95,7 @@ class Field(object):
""" """
Returns a dictionary of attributes to be used when serializing to xml. Returns a dictionary of attributes to be used when serializing to xml.
""" """
if getattr(self, 'type_name', None): if self.type_name:
return {'type': self.type_name} return {'type': self.type_name}
return {} return {}
@ -196,7 +201,7 @@ class ModelField(WritableField):
value = self.model_field._get_val_from_obj(obj) value = self.model_field._get_val_from_obj(obj)
if is_protected_type(value): if is_protected_type(value):
return value return value
return self.model_field.value_to_string(self.obj) return self.model_field.value_to_string(obj)
def attributes(self): def attributes(self):
return { return {
@ -223,9 +228,9 @@ class RelatedField(WritableField):
into[(self.source or field_name) + '_id'] = self.from_native(value) into[(self.source or field_name) + '_id'] = self.from_native(value)
class ManyRelatedField(RelatedField): class ManyRelatedMixin(object):
""" """
Base class for related model managers. Mixin to convert a related field to a many related field.
""" """
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
value = getattr(obj, self.source or field_name) value = getattr(obj, self.source or field_name)
@ -233,8 +238,10 @@ class ManyRelatedField(RelatedField):
def field_from_native(self, data, field_name, into): def field_from_native(self, data, field_name, into):
try: try:
# Form data
value = data.getlist(self.source or field_name) value = data.getlist(self.source or field_name)
except: except:
# Non-form data
value = data.get(self.source or field_name) value = data.get(self.source or field_name)
else: else:
if value == ['']: if value == ['']:
@ -242,6 +249,15 @@ class ManyRelatedField(RelatedField):
into[field_name] = [self.from_native(item) for item in value] into[field_name] = [self.from_native(item) for item in value]
class ManyRelatedField(ManyRelatedMixin, RelatedField):
"""
Base class for related model managers.
"""
pass
### PrimaryKey relationships
class PrimaryKeyRelatedField(RelatedField): class PrimaryKeyRelatedField(RelatedField):
""" """
Serializes a related field or related object to a pk value. Serializes a related field or related object to a pk value.
@ -281,13 +297,87 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
return [self.to_native(item.pk) for item in queryset.all()] return [self.to_native(item.pk) for item in queryset.all()]
### Hyperlinked relationships
class HyperlinkedRelatedField(RelatedField):
pk_url_kwarg = 'pk'
slug_url_kwarg = 'slug'
slug_field = 'slug'
def __init__(self, *args, **kwargs):
try:
self.view_name = kwargs.pop('view_name')
except:
raise ValueError("Hyperlinked field requires 'view_name' kwarg")
super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
def to_native(self, obj):
view_name = self.view_name
request = self.context.get('request', None)
kwargs = {self.pk_url_kwarg: obj.pk}
try:
return reverse(view_name, kwargs=kwargs, request=request)
except:
pass
slug = getattr(obj, self.slug_field, None)
if not slug:
raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
kwargs = {self.slug_url_kwarg: slug}
try:
return reverse(self.view_name, kwargs=kwargs, request=request)
except:
pass
kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
try:
return reverse(self.view_name, kwargs=kwargs, request=request)
except:
pass
raise ValidationError('Could not resolve URL for field using view name "%s"', view_name)
def from_native(self, value):
# Convert URL -> model instance pk
try:
match = resolve(value)
except:
raise ValidationError('Invalid hyperlink - No URL match')
if match.url_name != self.view_name:
raise ValidationError('Invalid hyperlink - Incorrect URL match')
pk = match.kwargs.get(self.pk_url_kwarg, None)
slug = match.kwargs.get(self.slug_url_kwarg, None)
# Try explicit primary key.
if pk is not None:
return pk
# Next, try looking up by slug.
elif slug is not None:
slug_field = self.get_slug_field()
queryset = self.queryset.filter(**{slug_field: slug})
# If none of those are defined, it's an error.
else:
raise ValidationError('Invalid hyperlink')
try:
obj = queryset.get()
except ObjectDoesNotExist:
raise ValidationError('Invalid hyperlink - object does not exist.')
return obj.pk
class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
pass
class HyperlinkedIdentityField(Field): class HyperlinkedIdentityField(Field):
""" """
A field that represents the model's identity using a hyperlink. A field that represents the model's identity using a hyperlink.
""" """
def __init__(self, *args, **kwargs):
pass
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
request = self.context.get('request', None) request = self.context.get('request', None)
view_name = self.parent.opts.view_name view_name = self.parent.opts.view_name

View File

@ -260,7 +260,7 @@ class DocumentingHTMLRenderer(BaseRenderer):
serializer = view.get_serializer(instance=obj) serializer = view.get_serializer(instance=obj)
for k, v in serializer.get_fields(True).items(): for k, v in serializer.get_fields(True).items():
print k, v print k, v
if v.readonly: if getattr(v, 'readonly', True):
continue continue
kwargs = {} kwargs = {}

View File

@ -353,7 +353,9 @@ class ModelSerializer(Serializer):
""" """
Creates a default instance of a flat relational field. Creates a default instance of a flat relational field.
""" """
queryset = model_field.rel.to._default_manager # .using(db).complex_filter(self.rel.limit_choices_to) # TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
queryset = model_field.rel.to._default_manager
if isinstance(model_field, models.fields.related.ManyToManyField): if isinstance(model_field, models.fields.related.ManyToManyField):
return ManyPrimaryKeyRelatedField(queryset=queryset) return ManyPrimaryKeyRelatedField(queryset=queryset)
return PrimaryKeyRelatedField(queryset=queryset) return PrimaryKeyRelatedField(queryset=queryset)
@ -420,13 +422,13 @@ class HyperlinkedModelSerializer(ModelSerializer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs) super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs)
if self.opts.view_name is None: if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name() self.opts.view_name = self._get_default_view_name(self.opts.model)
def _get_default_view_name(self): def _get_default_view_name(self, model):
""" """
Return the view name to use if 'view_name' is not specified in 'Meta' Return the view name to use if 'view_name' is not specified in 'Meta'
""" """
model_meta = self.opts.model._meta model_meta = model._meta
format_kwargs = { format_kwargs = {
'app_label': model_meta.app_label, 'app_label': model_meta.app_label,
'model_name': model_meta.object_name.lower() 'model_name': model_meta.object_name.lower()
@ -435,3 +437,19 @@ class HyperlinkedModelSerializer(ModelSerializer):
def get_pk_field(self, model_field): def get_pk_field(self, model_field):
return None return None
def get_related_field(self, model_field):
"""
Creates a default instance of a flat relational field.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
rel = model_field.rel.to
queryset = rel._default_manager
kwargs = {
'queryset': queryset,
'view_name': self._get_default_view_name(rel)
}
if isinstance(model_field, models.fields.related.ManyToManyField):
return ManyHyperlinkedRelatedField(**kwargs)
return HyperlinkedRelatedField(**kwargs)

View File

@ -0,0 +1,33 @@
from django.test import TestCase
from rest_framework import serializers
from rest_framework.tests.models import *
class TestGenericRelations(TestCase):
def setUp(self):
bookmark = Bookmark(url='https://www.djangoproject.com/')
bookmark.save()
django = Tag(tag_name='django')
django.save()
python = Tag(tag_name='python')
python.save()
t1 = TaggedItem(content_object=bookmark, tag=django)
t1.save()
t2 = TaggedItem(content_object=bookmark, tag=python)
t2.save()
self.bookmark = bookmark
def test_reverse_generic_relation(self):
class BookmarkSerializer(serializers.ModelSerializer):
tags = serializers.Field(source='tags')
class Meta:
model = Bookmark
exclude = ('id',)
serializer = BookmarkSerializer(instance=self.bookmark)
expected = {
'tags': [u'django', u'python'],
'url': u'https://www.djangoproject.com/'
}
self.assertEquals(serializer.data, expected)

View File

@ -2,7 +2,7 @@ from django.conf.urls.defaults import patterns, url
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from rest_framework import generics, status, serializers from rest_framework import generics, status, serializers
from rest_framework.tests.models import BasicModel from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel
factory = RequestFactory() factory = RequestFactory()
@ -17,13 +17,31 @@ class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
model_serializer_class = serializers.HyperlinkedModelSerializer model_serializer_class = serializers.HyperlinkedModelSerializer
class AnchorDetail(generics.RetrieveAPIView):
model = Anchor
model_serializer_class = serializers.HyperlinkedModelSerializer
class ManyToManyList(generics.ListAPIView):
model = ManyToManyModel
model_serializer_class = serializers.HyperlinkedModelSerializer
class ManyToManyDetail(generics.RetrieveAPIView):
model = ManyToManyModel
model_serializer_class = serializers.HyperlinkedModelSerializer
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
) )
class TestHyperlinkedView(TestCase): class TestBasicHyperlinkedView(TestCase):
urls = 'rest_framework.tests.hyperlinkedserializers' urls = 'rest_framework.tests.hyperlinkedserializers'
def setUp(self): def setUp(self):
@ -45,7 +63,7 @@ class TestHyperlinkedView(TestCase):
""" """
GET requests to ListCreateAPIView should return list of objects. GET requests to ListCreateAPIView should return list of objects.
""" """
request = factory.get('/') request = factory.get('/basic/')
response = self.list_view(request).render() response = self.list_view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data) self.assertEquals(response.data, self.data)
@ -54,7 +72,55 @@ class TestHyperlinkedView(TestCase):
""" """
GET requests to ListCreateAPIView should return list of objects. GET requests to ListCreateAPIView should return list of objects.
""" """
request = factory.get('/1') request = factory.get('/basic/1')
response = self.detail_view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0])
class TestManyToManyHyperlinkedView(TestCase):
urls = 'rest_framework.tests.hyperlinkedserializers'
def setUp(self):
"""
Create 3 BasicModel intances.
"""
items = ['foo', 'bar', 'baz']
anchors = []
for item in items:
anchor = Anchor(text=item)
anchor.save()
anchors.append(anchor)
manytomany = ManyToManyModel()
manytomany.save()
manytomany.rel.add(*anchors)
self.data = [{
'url': 'http://testserver/manytomany/1/',
'rel': [
'http://testserver/anchor/1/',
'http://testserver/anchor/2/',
'http://testserver/anchor/3/',
]
}]
self.list_view = ManyToManyList.as_view()
self.detail_view = ManyToManyDetail.as_view()
def test_get_list_view(self):
"""
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/manytomany/')
response = self.list_view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data)
def test_get_detail_view(self):
"""
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/manytomany/1/')
response = self.detail_view(request, pk=1).render() response = self.detail_view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0]) self.assertEquals(response.data, self.data[0])

View File

@ -1,4 +1,7 @@
from django.db import models from django.db import models
from django.contrib.contenttypes.models import ContentType
from django.contrib.contenttypes.generic import GenericForeignKey, GenericRelation
# from django.contrib.auth.models import Group # from django.contrib.auth.models import Group
@ -59,3 +62,24 @@ class CallableDefaultValueModel(RESTFrameworkModel):
class ManyToManyModel(RESTFrameworkModel): class ManyToManyModel(RESTFrameworkModel):
rel = models.ManyToManyField(Anchor) rel = models.ManyToManyField(Anchor)
# Models to test generic relations
class Tag(RESTFrameworkModel):
tag_name = models.SlugField()
class TaggedItem(RESTFrameworkModel):
tag = models.ForeignKey(Tag, related_name='items')
content_type = models.ForeignKey(ContentType)
object_id = models.PositiveIntegerField()
content_object = GenericForeignKey('content_type', 'object_id')
def __unicode__(self):
return self.tag.tag_name
class Bookmark(RESTFrameworkModel):
url = models.URLField()
tags = GenericRelation(TaggedItem)