From aac6b6cf4e61814e392829b1101ace4789bb0871 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Oct 2012 15:00:23 +0100 Subject: [PATCH 1/5] Tweak comment --- rest_framework/serializers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index ae0b3cdf6..4ffff65d5 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -353,7 +353,9 @@ class ModelSerializer(Serializer): """ Creates a default instance of a flat relational field. """ - queryset = model_field.rel.to._default_manager # .using(db).complex_filter(self.rel.limit_choices_to) + # TODO: filter queryset using: + # .using(db).complex_filter(self.rel.limit_choices_to) + queryset = model_field.rel.to._default_manager if isinstance(model_field, models.fields.related.ManyToManyField): return ManyPrimaryKeyRelatedField(queryset=queryset) return PrimaryKeyRelatedField(queryset=queryset) From 55e9cbecac1456f0e1521a4bcceb1ef4f44e5e0b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Oct 2012 15:01:44 +0100 Subject: [PATCH 2/5] Tweaks --- rest_framework/fields.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 32f2d1225..ad2ca5899 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -6,7 +6,6 @@ import warnings from django.core import validators from django.core.exceptions import ValidationError from django.conf import settings -from django.db import DEFAULT_DB_ALIAS from django.utils.encoding import is_protected_type, smart_unicode from django.utils.translation import ugettext_lazy as _ from rest_framework.reverse import reverse @@ -27,6 +26,7 @@ def is_simple_callable(obj): class Field(object): creation_counter = 0 empty = '' + type_name = None def __init__(self, source=None): self.parent = None @@ -90,7 +90,7 @@ class Field(object): """ Returns a dictionary of attributes to be used when serializing to xml. """ - if getattr(self, 'type_name', None): + if self.type_name: return {'type': self.type_name} return {} @@ -233,8 +233,10 @@ class ManyRelatedField(RelatedField): def field_from_native(self, data, field_name, into): try: + # Form data value = data.getlist(self.source or field_name) except: + # Non-form data value = data.get(self.source or field_name) else: if value == ['']: From c91d926b0664981de0fd239a4398dd71367a5911 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Oct 2012 16:58:18 +0100 Subject: [PATCH 3/5] Initial tests for hyperlinked relationships --- rest_framework/fields.py | 93 ++++++++++++++++++- rest_framework/serializers.py | 22 ++++- .../tests/hyperlinkedserializers.py | 74 ++++++++++++++- 3 files changed, 179 insertions(+), 10 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index ad2ca5899..9dbc11944 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -4,7 +4,8 @@ import inspect import warnings from django.core import validators -from django.core.exceptions import ValidationError +from django.core.exceptions import ObjectDoesNotExist, ValidationError +from django.core.urlresolvers import resolve from django.conf import settings from django.utils.encoding import is_protected_type, smart_unicode from django.utils.translation import ugettext_lazy as _ @@ -223,9 +224,9 @@ class RelatedField(WritableField): into[(self.source or field_name) + '_id'] = self.from_native(value) -class ManyRelatedField(RelatedField): +class ManyRelatedMixin(object): """ - Base class for related model managers. + Mixin to convert a related field to a many related field. """ def field_to_native(self, obj, field_name): value = getattr(obj, self.source or field_name) @@ -244,6 +245,15 @@ class ManyRelatedField(RelatedField): into[field_name] = [self.from_native(item) for item in value] +class ManyRelatedField(ManyRelatedMixin, RelatedField): + """ + Base class for related model managers. + """ + pass + + +### PrimaryKey relationships + class PrimaryKeyRelatedField(RelatedField): """ Serializes a related field or related object to a pk value. @@ -283,6 +293,83 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField): return [self.to_native(item.pk) for item in queryset.all()] +### Hyperlinked relationships + +class HyperlinkedRelatedField(RelatedField): + pk_url_kwarg = 'pk' + slug_url_kwarg = 'slug' + slug_field = 'slug' + + def __init__(self, *args, **kwargs): + try: + self.view_name = kwargs.pop('view_name') + except: + raise ValueError("Hyperlinked field requires 'view_name' kwarg") + super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) + + def to_native(self, obj): + view_name = self.view_name + request = self.context.get('request', None) + kwargs = {self.pk_url_kwarg: obj.pk} + try: + return reverse(view_name, kwargs=kwargs, request=request) + except: + pass + + slug = getattr(obj, self.slug_field, None) + + if not slug: + raise ValidationError('Could not resolve URL for field using view name "%s"', view_name) + + kwargs = {self.slug_url_kwarg: slug} + try: + return reverse(self.view_name, kwargs=kwargs, request=request) + except: + pass + + kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} + try: + return reverse(self.view_name, kwargs=kwargs, request=request) + except: + pass + + raise ValidationError('Could not resolve URL for field using view name "%s"', view_name) + + def from_native(self, value): + # Convert URL -> model instance pk + try: + match = resolve(value) + except: + raise ValidationError('Invalid hyperlink - No URL match') + + if match.url_name != self.view_name: + raise ValidationError('Invalid hyperlink - Incorrect URL match') + + pk = match.kwargs.get(self.pk_url_kwarg, None) + slug = match.kwargs.get(self.slug_url_kwarg, None) + + # Try explicit primary key. + if pk is not None: + return pk + # Next, try looking up by slug. + elif slug is not None: + slug_field = self.get_slug_field() + queryset = self.queryset.filter(**{slug_field: slug}) + # If none of those are defined, it's an error. + else: + raise ValidationError('Invalid hyperlink') + + try: + obj = queryset.get() + except ObjectDoesNotExist: + raise ValidationError('Invalid hyperlink - object does not exist.') + return obj.pk + + +class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField): + pass + + class HyperlinkedIdentityField(Field): """ A field that represents the model's identity using a hyperlink. diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4ffff65d5..ba8bf8ad2 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -422,13 +422,13 @@ class HyperlinkedModelSerializer(ModelSerializer): def __init__(self, *args, **kwargs): super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs) if self.opts.view_name is None: - self.opts.view_name = self._get_default_view_name() + self.opts.view_name = self._get_default_view_name(self.opts.model) - def _get_default_view_name(self): + def _get_default_view_name(self, model): """ Return the view name to use if 'view_name' is not specified in 'Meta' """ - model_meta = self.opts.model._meta + model_meta = model._meta format_kwargs = { 'app_label': model_meta.app_label, 'model_name': model_meta.object_name.lower() @@ -437,3 +437,19 @@ class HyperlinkedModelSerializer(ModelSerializer): def get_pk_field(self, model_field): return None + + def get_related_field(self, model_field): + """ + Creates a default instance of a flat relational field. + """ + # TODO: filter queryset using: + # .using(db).complex_filter(self.rel.limit_choices_to) + rel = model_field.rel.to + queryset = rel._default_manager + kwargs = { + 'queryset': queryset, + 'view_name': self._get_default_view_name(rel) + } + if isinstance(model_field, models.fields.related.ManyToManyField): + return ManyHyperlinkedRelatedField(**kwargs) + return HyperlinkedRelatedField(**kwargs) diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index 4f9393aa9..5532a8ee0 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -2,7 +2,7 @@ from django.conf.urls.defaults import patterns, url from django.test import TestCase from django.test.client import RequestFactory from rest_framework import generics, status, serializers -from rest_framework.tests.models import BasicModel +from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel factory = RequestFactory() @@ -17,13 +17,31 @@ class BasicDetail(generics.RetrieveUpdateDestroyAPIView): model_serializer_class = serializers.HyperlinkedModelSerializer +class AnchorDetail(generics.RetrieveAPIView): + model = Anchor + model_serializer_class = serializers.HyperlinkedModelSerializer + + +class ManyToManyList(generics.ListAPIView): + model = ManyToManyModel + model_serializer_class = serializers.HyperlinkedModelSerializer + + +class ManyToManyDetail(generics.RetrieveAPIView): + model = ManyToManyModel + model_serializer_class = serializers.HyperlinkedModelSerializer + + urlpatterns = patterns('', url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), url(r'^basic/(?P\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), + url(r'^anchor/(?P\d+)/$', AnchorDetail.as_view(), name='anchor-detail'), + url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'), + url(r'^manytomany/(?P\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'), ) -class TestHyperlinkedView(TestCase): +class TestBasicHyperlinkedView(TestCase): urls = 'rest_framework.tests.hyperlinkedserializers' def setUp(self): @@ -45,7 +63,7 @@ class TestHyperlinkedView(TestCase): """ GET requests to ListCreateAPIView should return list of objects. """ - request = factory.get('/') + request = factory.get('/basic/') response = self.list_view(request).render() self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.data, self.data) @@ -54,7 +72,55 @@ class TestHyperlinkedView(TestCase): """ GET requests to ListCreateAPIView should return list of objects. """ - request = factory.get('/1') + request = factory.get('/basic/1') + response = self.detail_view(request, pk=1).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, self.data[0]) + + +class TestManyToManyHyperlinkedView(TestCase): + urls = 'rest_framework.tests.hyperlinkedserializers' + + def setUp(self): + """ + Create 3 BasicModel intances. + """ + items = ['foo', 'bar', 'baz'] + anchors = [] + for item in items: + anchor = Anchor(text=item) + anchor.save() + anchors.append(anchor) + + manytomany = ManyToManyModel() + manytomany.save() + manytomany.rel.add(*anchors) + + self.data = [{ + 'url': 'http://testserver/manytomany/1/', + 'rel': [ + 'http://testserver/anchor/1/', + 'http://testserver/anchor/2/', + 'http://testserver/anchor/3/', + ] + }] + self.list_view = ManyToManyList.as_view() + self.detail_view = ManyToManyDetail.as_view() + + def test_get_list_view(self): + """ + GET requests to ListCreateAPIView should return list of objects. + """ + request = factory.get('/manytomany/') + response = self.list_view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, self.data) + + def test_get_detail_view(self): + """ + GET requests to ListCreateAPIView should return list of objects. + """ + request = factory.get('/manytomany/1/') response = self.detail_view(request, pk=1).render() self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.data, self.data[0]) From cc0d2601b8dfdf3f5fcee8591540b9cb4b2f3e44 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Oct 2012 21:36:40 +0100 Subject: [PATCH 4/5] Minor fixes --- rest_framework/fields.py | 5 +---- rest_framework/renderers.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 9dbc11944..be9182357 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -319,7 +319,7 @@ class HyperlinkedRelatedField(RelatedField): slug = getattr(obj, self.slug_field, None) if not slug: - raise ValidationError('Could not resolve URL for field using view name "%s"', view_name) + raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) kwargs = {self.slug_url_kwarg: slug} try: @@ -374,9 +374,6 @@ class HyperlinkedIdentityField(Field): """ A field that represents the model's identity using a hyperlink. """ - def __init__(self, *args, **kwargs): - pass - def field_to_native(self, obj, field_name): request = self.context.get('request', None) view_name = self.parent.opts.view_name diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 5bc5d5f8e..e33fa30e9 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -260,7 +260,7 @@ class DocumentingHTMLRenderer(BaseRenderer): serializer = view.get_serializer(instance=obj) for k, v in serializer.get_fields(True).items(): print k, v - if v.readonly: + if getattr(v, 'readonly', True): continue kwargs = {} From 693892ed0104b8ce8cd801e7bec6107feeb88782 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Oct 2012 22:07:24 +0100 Subject: [PATCH 5/5] Fix for field to make it easier to access field relationships --- rest_framework/fields.py | 6 ++++- rest_framework/tests/genericrelations.py | 33 ++++++++++++++++++++++++ rest_framework/tests/models.py | 24 +++++++++++++++++ 3 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 rest_framework/tests/genericrelations.py diff --git a/rest_framework/fields.py b/rest_framework/fields.py index be9182357..b9ac37768 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -83,6 +83,10 @@ class Field(object): if is_protected_type(value): return value + + all_callable = getattr(value, 'all', None) + if is_simple_callable(all_callable): + return [self.to_native(item) for item in value.all()] elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)): return [self.to_native(item) for item in value] return smart_unicode(value) @@ -197,7 +201,7 @@ class ModelField(WritableField): value = self.model_field._get_val_from_obj(obj) if is_protected_type(value): return value - return self.model_field.value_to_string(self.obj) + return self.model_field.value_to_string(obj) def attributes(self): return { diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py new file mode 100644 index 000000000..d88a6c06e --- /dev/null +++ b/rest_framework/tests/genericrelations.py @@ -0,0 +1,33 @@ +from django.test import TestCase +from rest_framework import serializers +from rest_framework.tests.models import * + + +class TestGenericRelations(TestCase): + def setUp(self): + bookmark = Bookmark(url='https://www.djangoproject.com/') + bookmark.save() + django = Tag(tag_name='django') + django.save() + python = Tag(tag_name='python') + python.save() + t1 = TaggedItem(content_object=bookmark, tag=django) + t1.save() + t2 = TaggedItem(content_object=bookmark, tag=python) + t2.save() + self.bookmark = bookmark + + def test_reverse_generic_relation(self): + class BookmarkSerializer(serializers.ModelSerializer): + tags = serializers.Field(source='tags') + + class Meta: + model = Bookmark + exclude = ('id',) + + serializer = BookmarkSerializer(instance=self.bookmark) + expected = { + 'tags': [u'django', u'python'], + 'url': u'https://www.djangoproject.com/' + } + self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 969c82978..7c7f485b1 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -1,4 +1,7 @@ from django.db import models +from django.contrib.contenttypes.models import ContentType +from django.contrib.contenttypes.generic import GenericForeignKey, GenericRelation + # from django.contrib.auth.models import Group @@ -59,3 +62,24 @@ class CallableDefaultValueModel(RESTFrameworkModel): class ManyToManyModel(RESTFrameworkModel): rel = models.ManyToManyField(Anchor) + +# Models to test generic relations + + +class Tag(RESTFrameworkModel): + tag_name = models.SlugField() + + +class TaggedItem(RESTFrameworkModel): + tag = models.ForeignKey(Tag, related_name='items') + content_type = models.ForeignKey(ContentType) + object_id = models.PositiveIntegerField() + content_object = GenericForeignKey('content_type', 'object_id') + + def __unicode__(self): + return self.tag.tag_name + + +class Bookmark(RESTFrameworkModel): + url = models.URLField() + tags = GenericRelation(TaggedItem)