mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-27 03:54:01 +03:00
Merge pull request #280 from tomchristie/hyperlinked-relationships
Hyperlinked relationships
This commit is contained in:
commit
ad5e6eb16f
|
@ -4,9 +4,9 @@ import inspect
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from django.core import validators
|
from django.core import validators
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ObjectDoesNotExist, ValidationError
|
||||||
|
from django.core.urlresolvers import resolve
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db import DEFAULT_DB_ALIAS
|
|
||||||
from django.utils.encoding import is_protected_type, smart_unicode
|
from django.utils.encoding import is_protected_type, smart_unicode
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
from rest_framework.reverse import reverse
|
from rest_framework.reverse import reverse
|
||||||
|
@ -27,6 +27,7 @@ def is_simple_callable(obj):
|
||||||
class Field(object):
|
class Field(object):
|
||||||
creation_counter = 0
|
creation_counter = 0
|
||||||
empty = ''
|
empty = ''
|
||||||
|
type_name = None
|
||||||
|
|
||||||
def __init__(self, source=None):
|
def __init__(self, source=None):
|
||||||
self.parent = None
|
self.parent = None
|
||||||
|
@ -82,6 +83,10 @@ class Field(object):
|
||||||
|
|
||||||
if is_protected_type(value):
|
if is_protected_type(value):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
all_callable = getattr(value, 'all', None)
|
||||||
|
if is_simple_callable(all_callable):
|
||||||
|
return [self.to_native(item) for item in value.all()]
|
||||||
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
|
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
|
||||||
return [self.to_native(item) for item in value]
|
return [self.to_native(item) for item in value]
|
||||||
return smart_unicode(value)
|
return smart_unicode(value)
|
||||||
|
@ -90,7 +95,7 @@ class Field(object):
|
||||||
"""
|
"""
|
||||||
Returns a dictionary of attributes to be used when serializing to xml.
|
Returns a dictionary of attributes to be used when serializing to xml.
|
||||||
"""
|
"""
|
||||||
if getattr(self, 'type_name', None):
|
if self.type_name:
|
||||||
return {'type': self.type_name}
|
return {'type': self.type_name}
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@ -196,7 +201,7 @@ class ModelField(WritableField):
|
||||||
value = self.model_field._get_val_from_obj(obj)
|
value = self.model_field._get_val_from_obj(obj)
|
||||||
if is_protected_type(value):
|
if is_protected_type(value):
|
||||||
return value
|
return value
|
||||||
return self.model_field.value_to_string(self.obj)
|
return self.model_field.value_to_string(obj)
|
||||||
|
|
||||||
def attributes(self):
|
def attributes(self):
|
||||||
return {
|
return {
|
||||||
|
@ -223,9 +228,9 @@ class RelatedField(WritableField):
|
||||||
into[(self.source or field_name) + '_id'] = self.from_native(value)
|
into[(self.source or field_name) + '_id'] = self.from_native(value)
|
||||||
|
|
||||||
|
|
||||||
class ManyRelatedField(RelatedField):
|
class ManyRelatedMixin(object):
|
||||||
"""
|
"""
|
||||||
Base class for related model managers.
|
Mixin to convert a related field to a many related field.
|
||||||
"""
|
"""
|
||||||
def field_to_native(self, obj, field_name):
|
def field_to_native(self, obj, field_name):
|
||||||
value = getattr(obj, self.source or field_name)
|
value = getattr(obj, self.source or field_name)
|
||||||
|
@ -233,8 +238,10 @@ class ManyRelatedField(RelatedField):
|
||||||
|
|
||||||
def field_from_native(self, data, field_name, into):
|
def field_from_native(self, data, field_name, into):
|
||||||
try:
|
try:
|
||||||
|
# Form data
|
||||||
value = data.getlist(self.source or field_name)
|
value = data.getlist(self.source or field_name)
|
||||||
except:
|
except:
|
||||||
|
# Non-form data
|
||||||
value = data.get(self.source or field_name)
|
value = data.get(self.source or field_name)
|
||||||
else:
|
else:
|
||||||
if value == ['']:
|
if value == ['']:
|
||||||
|
@ -242,6 +249,15 @@ class ManyRelatedField(RelatedField):
|
||||||
into[field_name] = [self.from_native(item) for item in value]
|
into[field_name] = [self.from_native(item) for item in value]
|
||||||
|
|
||||||
|
|
||||||
|
class ManyRelatedField(ManyRelatedMixin, RelatedField):
|
||||||
|
"""
|
||||||
|
Base class for related model managers.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
### PrimaryKey relationships
|
||||||
|
|
||||||
class PrimaryKeyRelatedField(RelatedField):
|
class PrimaryKeyRelatedField(RelatedField):
|
||||||
"""
|
"""
|
||||||
Serializes a related field or related object to a pk value.
|
Serializes a related field or related object to a pk value.
|
||||||
|
@ -281,13 +297,87 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
|
||||||
return [self.to_native(item.pk) for item in queryset.all()]
|
return [self.to_native(item.pk) for item in queryset.all()]
|
||||||
|
|
||||||
|
|
||||||
|
### Hyperlinked relationships
|
||||||
|
|
||||||
|
class HyperlinkedRelatedField(RelatedField):
|
||||||
|
pk_url_kwarg = 'pk'
|
||||||
|
slug_url_kwarg = 'slug'
|
||||||
|
slug_field = 'slug'
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
self.view_name = kwargs.pop('view_name')
|
||||||
|
except:
|
||||||
|
raise ValueError("Hyperlinked field requires 'view_name' kwarg")
|
||||||
|
super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def to_native(self, obj):
|
||||||
|
view_name = self.view_name
|
||||||
|
request = self.context.get('request', None)
|
||||||
|
kwargs = {self.pk_url_kwarg: obj.pk}
|
||||||
|
try:
|
||||||
|
return reverse(view_name, kwargs=kwargs, request=request)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
slug = getattr(obj, self.slug_field, None)
|
||||||
|
|
||||||
|
if not slug:
|
||||||
|
raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
|
||||||
|
|
||||||
|
kwargs = {self.slug_url_kwarg: slug}
|
||||||
|
try:
|
||||||
|
return reverse(self.view_name, kwargs=kwargs, request=request)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
|
||||||
|
try:
|
||||||
|
return reverse(self.view_name, kwargs=kwargs, request=request)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
raise ValidationError('Could not resolve URL for field using view name "%s"', view_name)
|
||||||
|
|
||||||
|
def from_native(self, value):
|
||||||
|
# Convert URL -> model instance pk
|
||||||
|
try:
|
||||||
|
match = resolve(value)
|
||||||
|
except:
|
||||||
|
raise ValidationError('Invalid hyperlink - No URL match')
|
||||||
|
|
||||||
|
if match.url_name != self.view_name:
|
||||||
|
raise ValidationError('Invalid hyperlink - Incorrect URL match')
|
||||||
|
|
||||||
|
pk = match.kwargs.get(self.pk_url_kwarg, None)
|
||||||
|
slug = match.kwargs.get(self.slug_url_kwarg, None)
|
||||||
|
|
||||||
|
# Try explicit primary key.
|
||||||
|
if pk is not None:
|
||||||
|
return pk
|
||||||
|
# Next, try looking up by slug.
|
||||||
|
elif slug is not None:
|
||||||
|
slug_field = self.get_slug_field()
|
||||||
|
queryset = self.queryset.filter(**{slug_field: slug})
|
||||||
|
# If none of those are defined, it's an error.
|
||||||
|
else:
|
||||||
|
raise ValidationError('Invalid hyperlink')
|
||||||
|
|
||||||
|
try:
|
||||||
|
obj = queryset.get()
|
||||||
|
except ObjectDoesNotExist:
|
||||||
|
raise ValidationError('Invalid hyperlink - object does not exist.')
|
||||||
|
return obj.pk
|
||||||
|
|
||||||
|
|
||||||
|
class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class HyperlinkedIdentityField(Field):
|
class HyperlinkedIdentityField(Field):
|
||||||
"""
|
"""
|
||||||
A field that represents the model's identity using a hyperlink.
|
A field that represents the model's identity using a hyperlink.
|
||||||
"""
|
"""
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def field_to_native(self, obj, field_name):
|
def field_to_native(self, obj, field_name):
|
||||||
request = self.context.get('request', None)
|
request = self.context.get('request', None)
|
||||||
view_name = self.parent.opts.view_name
|
view_name = self.parent.opts.view_name
|
||||||
|
|
|
@ -260,7 +260,7 @@ class DocumentingHTMLRenderer(BaseRenderer):
|
||||||
serializer = view.get_serializer(instance=obj)
|
serializer = view.get_serializer(instance=obj)
|
||||||
for k, v in serializer.get_fields(True).items():
|
for k, v in serializer.get_fields(True).items():
|
||||||
print k, v
|
print k, v
|
||||||
if v.readonly:
|
if getattr(v, 'readonly', True):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|
|
@ -353,7 +353,9 @@ class ModelSerializer(Serializer):
|
||||||
"""
|
"""
|
||||||
Creates a default instance of a flat relational field.
|
Creates a default instance of a flat relational field.
|
||||||
"""
|
"""
|
||||||
queryset = model_field.rel.to._default_manager # .using(db).complex_filter(self.rel.limit_choices_to)
|
# TODO: filter queryset using:
|
||||||
|
# .using(db).complex_filter(self.rel.limit_choices_to)
|
||||||
|
queryset = model_field.rel.to._default_manager
|
||||||
if isinstance(model_field, models.fields.related.ManyToManyField):
|
if isinstance(model_field, models.fields.related.ManyToManyField):
|
||||||
return ManyPrimaryKeyRelatedField(queryset=queryset)
|
return ManyPrimaryKeyRelatedField(queryset=queryset)
|
||||||
return PrimaryKeyRelatedField(queryset=queryset)
|
return PrimaryKeyRelatedField(queryset=queryset)
|
||||||
|
@ -420,13 +422,13 @@ class HyperlinkedModelSerializer(ModelSerializer):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs)
|
super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs)
|
||||||
if self.opts.view_name is None:
|
if self.opts.view_name is None:
|
||||||
self.opts.view_name = self._get_default_view_name()
|
self.opts.view_name = self._get_default_view_name(self.opts.model)
|
||||||
|
|
||||||
def _get_default_view_name(self):
|
def _get_default_view_name(self, model):
|
||||||
"""
|
"""
|
||||||
Return the view name to use if 'view_name' is not specified in 'Meta'
|
Return the view name to use if 'view_name' is not specified in 'Meta'
|
||||||
"""
|
"""
|
||||||
model_meta = self.opts.model._meta
|
model_meta = model._meta
|
||||||
format_kwargs = {
|
format_kwargs = {
|
||||||
'app_label': model_meta.app_label,
|
'app_label': model_meta.app_label,
|
||||||
'model_name': model_meta.object_name.lower()
|
'model_name': model_meta.object_name.lower()
|
||||||
|
@ -435,3 +437,19 @@ class HyperlinkedModelSerializer(ModelSerializer):
|
||||||
|
|
||||||
def get_pk_field(self, model_field):
|
def get_pk_field(self, model_field):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_related_field(self, model_field):
|
||||||
|
"""
|
||||||
|
Creates a default instance of a flat relational field.
|
||||||
|
"""
|
||||||
|
# TODO: filter queryset using:
|
||||||
|
# .using(db).complex_filter(self.rel.limit_choices_to)
|
||||||
|
rel = model_field.rel.to
|
||||||
|
queryset = rel._default_manager
|
||||||
|
kwargs = {
|
||||||
|
'queryset': queryset,
|
||||||
|
'view_name': self._get_default_view_name(rel)
|
||||||
|
}
|
||||||
|
if isinstance(model_field, models.fields.related.ManyToManyField):
|
||||||
|
return ManyHyperlinkedRelatedField(**kwargs)
|
||||||
|
return HyperlinkedRelatedField(**kwargs)
|
||||||
|
|
33
rest_framework/tests/genericrelations.py
Normal file
33
rest_framework/tests/genericrelations.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
from django.test import TestCase
|
||||||
|
from rest_framework import serializers
|
||||||
|
from rest_framework.tests.models import *
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenericRelations(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
bookmark = Bookmark(url='https://www.djangoproject.com/')
|
||||||
|
bookmark.save()
|
||||||
|
django = Tag(tag_name='django')
|
||||||
|
django.save()
|
||||||
|
python = Tag(tag_name='python')
|
||||||
|
python.save()
|
||||||
|
t1 = TaggedItem(content_object=bookmark, tag=django)
|
||||||
|
t1.save()
|
||||||
|
t2 = TaggedItem(content_object=bookmark, tag=python)
|
||||||
|
t2.save()
|
||||||
|
self.bookmark = bookmark
|
||||||
|
|
||||||
|
def test_reverse_generic_relation(self):
|
||||||
|
class BookmarkSerializer(serializers.ModelSerializer):
|
||||||
|
tags = serializers.Field(source='tags')
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = Bookmark
|
||||||
|
exclude = ('id',)
|
||||||
|
|
||||||
|
serializer = BookmarkSerializer(instance=self.bookmark)
|
||||||
|
expected = {
|
||||||
|
'tags': [u'django', u'python'],
|
||||||
|
'url': u'https://www.djangoproject.com/'
|
||||||
|
}
|
||||||
|
self.assertEquals(serializer.data, expected)
|
|
@ -2,7 +2,7 @@ from django.conf.urls.defaults import patterns, url
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.test.client import RequestFactory
|
from django.test.client import RequestFactory
|
||||||
from rest_framework import generics, status, serializers
|
from rest_framework import generics, status, serializers
|
||||||
from rest_framework.tests.models import BasicModel
|
from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel
|
||||||
|
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
|
|
||||||
|
@ -17,13 +17,31 @@ class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
|
||||||
model_serializer_class = serializers.HyperlinkedModelSerializer
|
model_serializer_class = serializers.HyperlinkedModelSerializer
|
||||||
|
|
||||||
|
|
||||||
|
class AnchorDetail(generics.RetrieveAPIView):
|
||||||
|
model = Anchor
|
||||||
|
model_serializer_class = serializers.HyperlinkedModelSerializer
|
||||||
|
|
||||||
|
|
||||||
|
class ManyToManyList(generics.ListAPIView):
|
||||||
|
model = ManyToManyModel
|
||||||
|
model_serializer_class = serializers.HyperlinkedModelSerializer
|
||||||
|
|
||||||
|
|
||||||
|
class ManyToManyDetail(generics.RetrieveAPIView):
|
||||||
|
model = ManyToManyModel
|
||||||
|
model_serializer_class = serializers.HyperlinkedModelSerializer
|
||||||
|
|
||||||
|
|
||||||
urlpatterns = patterns('',
|
urlpatterns = patterns('',
|
||||||
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
|
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
|
||||||
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
|
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
|
||||||
|
url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
|
||||||
|
url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
|
||||||
|
url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestHyperlinkedView(TestCase):
|
class TestBasicHyperlinkedView(TestCase):
|
||||||
urls = 'rest_framework.tests.hyperlinkedserializers'
|
urls = 'rest_framework.tests.hyperlinkedserializers'
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -45,7 +63,7 @@ class TestHyperlinkedView(TestCase):
|
||||||
"""
|
"""
|
||||||
GET requests to ListCreateAPIView should return list of objects.
|
GET requests to ListCreateAPIView should return list of objects.
|
||||||
"""
|
"""
|
||||||
request = factory.get('/')
|
request = factory.get('/basic/')
|
||||||
response = self.list_view(request).render()
|
response = self.list_view(request).render()
|
||||||
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
self.assertEquals(response.data, self.data)
|
self.assertEquals(response.data, self.data)
|
||||||
|
@ -54,7 +72,55 @@ class TestHyperlinkedView(TestCase):
|
||||||
"""
|
"""
|
||||||
GET requests to ListCreateAPIView should return list of objects.
|
GET requests to ListCreateAPIView should return list of objects.
|
||||||
"""
|
"""
|
||||||
request = factory.get('/1')
|
request = factory.get('/basic/1')
|
||||||
|
response = self.detail_view(request, pk=1).render()
|
||||||
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEquals(response.data, self.data[0])
|
||||||
|
|
||||||
|
|
||||||
|
class TestManyToManyHyperlinkedView(TestCase):
|
||||||
|
urls = 'rest_framework.tests.hyperlinkedserializers'
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""
|
||||||
|
Create 3 BasicModel intances.
|
||||||
|
"""
|
||||||
|
items = ['foo', 'bar', 'baz']
|
||||||
|
anchors = []
|
||||||
|
for item in items:
|
||||||
|
anchor = Anchor(text=item)
|
||||||
|
anchor.save()
|
||||||
|
anchors.append(anchor)
|
||||||
|
|
||||||
|
manytomany = ManyToManyModel()
|
||||||
|
manytomany.save()
|
||||||
|
manytomany.rel.add(*anchors)
|
||||||
|
|
||||||
|
self.data = [{
|
||||||
|
'url': 'http://testserver/manytomany/1/',
|
||||||
|
'rel': [
|
||||||
|
'http://testserver/anchor/1/',
|
||||||
|
'http://testserver/anchor/2/',
|
||||||
|
'http://testserver/anchor/3/',
|
||||||
|
]
|
||||||
|
}]
|
||||||
|
self.list_view = ManyToManyList.as_view()
|
||||||
|
self.detail_view = ManyToManyDetail.as_view()
|
||||||
|
|
||||||
|
def test_get_list_view(self):
|
||||||
|
"""
|
||||||
|
GET requests to ListCreateAPIView should return list of objects.
|
||||||
|
"""
|
||||||
|
request = factory.get('/manytomany/')
|
||||||
|
response = self.list_view(request).render()
|
||||||
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEquals(response.data, self.data)
|
||||||
|
|
||||||
|
def test_get_detail_view(self):
|
||||||
|
"""
|
||||||
|
GET requests to ListCreateAPIView should return list of objects.
|
||||||
|
"""
|
||||||
|
request = factory.get('/manytomany/1/')
|
||||||
response = self.detail_view(request, pk=1).render()
|
response = self.detail_view(request, pk=1).render()
|
||||||
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
self.assertEquals(response.data, self.data[0])
|
self.assertEquals(response.data, self.data[0])
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
from django.contrib.contenttypes.models import ContentType
|
||||||
|
from django.contrib.contenttypes.generic import GenericForeignKey, GenericRelation
|
||||||
|
|
||||||
# from django.contrib.auth.models import Group
|
# from django.contrib.auth.models import Group
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,3 +62,24 @@ class CallableDefaultValueModel(RESTFrameworkModel):
|
||||||
|
|
||||||
class ManyToManyModel(RESTFrameworkModel):
|
class ManyToManyModel(RESTFrameworkModel):
|
||||||
rel = models.ManyToManyField(Anchor)
|
rel = models.ManyToManyField(Anchor)
|
||||||
|
|
||||||
|
# Models to test generic relations
|
||||||
|
|
||||||
|
|
||||||
|
class Tag(RESTFrameworkModel):
|
||||||
|
tag_name = models.SlugField()
|
||||||
|
|
||||||
|
|
||||||
|
class TaggedItem(RESTFrameworkModel):
|
||||||
|
tag = models.ForeignKey(Tag, related_name='items')
|
||||||
|
content_type = models.ForeignKey(ContentType)
|
||||||
|
object_id = models.PositiveIntegerField()
|
||||||
|
content_object = GenericForeignKey('content_type', 'object_id')
|
||||||
|
|
||||||
|
def __unicode__(self):
|
||||||
|
return self.tag.tag_name
|
||||||
|
|
||||||
|
|
||||||
|
class Bookmark(RESTFrameworkModel):
|
||||||
|
url = models.URLField()
|
||||||
|
tags = GenericRelation(TaggedItem)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user