Merge pull request #279 from tomchristie/hyperlinked-relationships

Hyperlinked relationships
This commit is contained in:
Tom Christie 2012-10-04 06:07:26 -07:00
commit 42b3fdbdc2
5 changed files with 276 additions and 149 deletions

View File

@ -9,6 +9,7 @@ from django.conf import settings
from django.db import DEFAULT_DB_ALIAS
from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _
from rest_framework.reverse import reverse
from rest_framework.compat import parse_date, parse_datetime
from rest_framework.compat import timezone
@ -25,21 +26,88 @@ def is_simple_callable(obj):
class Field(object):
creation_counter = 0
default_validators = []
default_error_messages = {
'required': _('This field is required.'),
'invalid': _('Invalid value.'),
}
empty = ''
def __init__(self, source=None, readonly=False, required=None,
validators=[], error_messages=None):
def __init__(self, source=None):
self.parent = None
self.creation_counter = Field.creation_counter
Field.creation_counter += 1
self.source = source
def initialize(self, parent):
"""
Called to set up a field prior to field_to_native or field_from_native.
parent - The parent serializer.
model_field - The model field this field corrosponds to, if one exists.
"""
self.parent = parent
self.root = parent.root or parent
self.context = self.root.context
def field_from_native(self, data, field_name, into):
"""
Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value.
"""
return
def field_to_native(self, obj, field_name):
"""
Given and object and a field name, returns the value that should be
serialized for that field.
"""
if obj is None:
return self.empty
if self.source == '*':
return self.to_native(obj)
if self.source:
value = obj
for component in self.source.split('.'):
value = getattr(value, component)
else:
value = getattr(obj, field_name)
return self.to_native(value)
def to_native(self, value):
"""
Converts the field's value into it's simple representation.
"""
if is_simple_callable(value):
value = value()
if is_protected_type(value):
return value
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
return [self.to_native(item) for item in value]
return smart_unicode(value)
def attributes(self):
"""
Returns a dictionary of attributes to be used when serializing to xml.
"""
if getattr(self, 'type_name', None):
return {'type': self.type_name}
return {}
class WritableField(Field):
"""
Base for read/write fields.
"""
default_validators = []
default_error_messages = {
'required': _('This field is required.'),
'invalid': _('Invalid value.'),
}
def __init__(self, source=None, readonly=False, required=None,
validators=[], error_messages=None):
super(WritableField, self).__init__(source=source)
self.readonly = readonly
if required is None:
self.required = not(readonly)
@ -55,19 +123,6 @@ class Field(object):
self.validators = self.default_validators + validators
def initialize(self, parent, model_field=None):
"""
Called to set up a field prior to field_to_native or field_from_native.
parent - The parent serializer.
model_field - The model field this field corrosponds to, if one exists.
"""
self.parent = parent
self.root = parent.root or parent
self.context = self.root.context
if model_field:
self.model_field = model_field
def validate(self, value):
if value in validators.EMPTY_VALUES and self.required:
raise ValidationError(self.error_messages['required'])
@ -116,88 +171,75 @@ class Field(object):
"""
Reverts a simple representation back to the field's value.
"""
if hasattr(self, 'model_field'):
return value
class ModelField(WritableField):
"""
A generic field that can be used against an arbirtrary model field.
"""
def __init__(self, *args, **kwargs):
try:
return self.model_field.rel.to._meta.get_field(self.model_field.rel.field_name).to_python(value)
self.model_field = kwargs.pop('model_field')
except:
raise ValueError("ModelField requires 'model_field' kwarg")
super(ModelField, self).__init__(*args, **kwargs)
def from_native(self, value):
try:
rel = self.model_field.rel
except:
return self.model_field.to_python(value)
return value
return rel.to._meta.get_field(rel.field_name).to_python(value)
def field_to_native(self, obj, field_name):
"""
Given and object and a field name, returns the value that should be
serialized for that field.
"""
if obj is None:
return self.empty
if self.source == '*':
return self.to_native(obj)
self.obj = obj # Need to hang onto this in the case of model fields
if hasattr(self, 'model_field'):
return self.to_native(self.model_field._get_val_from_obj(obj))
if self.source:
value = obj
for component in self.source.split('.'):
value = getattr(value, component)
else:
value = getattr(obj, field_name)
return self.to_native(value)
def to_native(self, value):
"""
Converts the field's value into it's simple representation.
"""
if is_simple_callable(value):
value = value()
value = self.model_field._get_val_from_obj(obj)
if is_protected_type(value):
return value
elif hasattr(self, 'model_field'):
return self.model_field.value_to_string(self.obj)
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
return [self.to_native(item) for item in value]
return smart_unicode(value)
def attributes(self):
"""
Returns a dictionary of attributes to be used when serializing to xml.
"""
try:
return {
"type": self.model_field.get_internal_type()
}
except AttributeError:
return {}
##### Relational fields #####
class RelatedField(Field):
class RelatedField(WritableField):
"""
A base class for model related fields or related managers.
Subclass this and override `convert` to define custom behaviour when
serializing related objects.
Base class for related model fields.
"""
def __init__(self, *args, **kwargs):
self.queryset = kwargs.pop('queryset', None)
super(RelatedField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name):
obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
return [self.to_native(item) for item in obj.all()]
return self.to_native(obj)
value = getattr(obj, self.source or field_name)
return self.to_native(value)
def attributes(self):
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
into[(self.source or field_name) + '_id'] = self.from_native(value)
class ManyRelatedField(RelatedField):
"""
Base class for related model managers.
"""
def field_to_native(self, obj, field_name):
value = getattr(obj, self.source or field_name)
return [self.to_native(item) for item in value.all()]
def field_from_native(self, data, field_name, into):
try:
return {
"rel": self.model_field.rel.__class__.__name__,
"to": smart_unicode(self.model_field.rel.to._meta)
}
except AttributeError:
return {}
value = data.getlist(self.source or field_name)
except:
value = data.get(self.source or field_name)
else:
if value == ['']:
value = []
into[field_name] = [self.from_native(item) for item in value]
class PrimaryKeyRelatedField(RelatedField):
@ -206,20 +248,11 @@ class PrimaryKeyRelatedField(RelatedField):
"""
def to_native(self, pk):
"""
You can subclass this method to provide different serialization
behavior based on the pk.
"""
return pk
def field_to_native(self, obj, field_name):
# This is only implemented for performance reasons
#
# We could leave the default `RelatedField.field_to_native()` in place,
# and inside just implement `to_native()` as `return obj.pk`
#
# That would involve an extra database lookup.
try:
# Prefer obj.serializable_value for performance reasons
pk = obj.serializable_value(self.source or field_name)
except AttributeError:
# RelatedObject (reverse relationship)
@ -228,18 +261,17 @@ class PrimaryKeyRelatedField(RelatedField):
# Forward relationship
return self.to_native(pk)
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
into[field_name + '_id'] = self.from_native(value)
class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
class ManyPrimaryKeyRelatedField(ManyRelatedField):
"""
Serializes a to-many related field or related manager to a pk value.
"""
def to_native(self, pk):
return pk
def field_to_native(self, obj, field_name):
try:
# Prefer obj.serializable_value for performance reasons
queryset = obj.serializable_value(self.source or field_name)
except AttributeError:
# RelatedManager (reverse relationship)
@ -248,40 +280,25 @@ class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
# Forward relationship
return [self.to_native(item.pk) for item in queryset.all()]
def field_from_native(self, data, field_name, into):
try:
value = data.getlist(field_name)
except:
value = data.get(field_name)
else:
if value == ['']:
value = []
into[field_name] = [self.from_native(item) for item in value]
class NaturalKeyRelatedField(RelatedField):
class HyperlinkedIdentityField(Field):
"""
Serializes a model related field or related manager to a natural key value.
A field that represents the model's identity using a hyperlink.
"""
is_natural_key = True # XML renderer handles these differently
def __init__(self, *args, **kwargs):
pass
def to_native(self, obj):
if hasattr(obj, 'natural_key'):
return obj.natural_key()
return obj
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
into[self.model_field.attname] = self.from_native(value)
def from_native(self, value):
# TODO: Support 'using' : db = options.pop('using', DEFAULT_DB_ALIAS)
manager = self.model_field.rel.to._default_manager
manager = manager.db_manager(DEFAULT_DB_ALIAS)
return manager.get_by_natural_key(*value).pk
def field_to_native(self, obj, field_name):
request = self.context.get('request', None)
view_name = self.parent.opts.view_name
view_kwargs = {'pk': obj.pk}
return reverse(view_name, kwargs=view_kwargs, request=request)
class BooleanField(Field):
##### Typed Fields #####
class BooleanField(WritableField):
type_name = 'BooleanField'
default_error_messages = {
'invalid': _(u"'%s' value must be either True or False."),
}
@ -298,7 +315,9 @@ class BooleanField(Field):
raise ValidationError(self.error_messages['invalid'] % value)
class CharField(Field):
class CharField(WritableField):
type_name = 'CharField'
def __init__(self, max_length=None, min_length=None, *args, **kwargs):
self.max_length, self.min_length = max_length, min_length
super(CharField, self).__init__(*args, **kwargs)
@ -314,6 +333,8 @@ class CharField(Field):
class EmailField(CharField):
type_name = 'EmailField'
default_error_messages = {
'invalid': _('Enter a valid e-mail address.'),
}
@ -330,7 +351,9 @@ class EmailField(CharField):
return result
class DateField(Field):
class DateField(WritableField):
type_name = 'DateField'
default_error_messages = {
'invalid': _(u"'%s' value has an invalid date format. It must be "
u"in YYYY-MM-DD format."),
@ -364,7 +387,9 @@ class DateField(Field):
raise ValidationError(msg)
class DateTimeField(Field):
class DateTimeField(WritableField):
type_name = 'DateTimeField'
default_error_messages = {
'invalid': _(u"'%s' value has an invalid format. It must be in "
u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."),
@ -415,7 +440,9 @@ class DateTimeField(Field):
raise ValidationError(msg)
class IntegerField(Field):
class IntegerField(WritableField):
type_name = 'IntegerField'
default_error_messages = {
'invalid': _('Enter a whole number.'),
'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
@ -441,7 +468,9 @@ class IntegerField(Field):
return value
class FloatField(Field):
class FloatField(WritableField):
type_name = 'FloatField'
default_error_messages = {
'invalid': _("'%s' value must be a float."),
}

View File

@ -123,16 +123,8 @@ class BaseSerializer(Field):
# Get the explicitly declared fields
for key, field in self.fields.items():
ret[key] = field
# Determine if the declared field corrosponds to a model field.
try:
if key == 'pk':
model_field = obj._meta.pk
else:
model_field = obj._meta.get_field_by_name(key)[0]
except:
model_field = None
# Set up the field
field.initialize(parent=self, model_field=model_field)
field.initialize(parent=self)
# Add in the default fields
fields = self.default_fields(serialize, obj, data, nested)
@ -157,12 +149,12 @@ class BaseSerializer(Field):
#####
# Field methods - used when the serializer class is itself used as a field.
def initialize(self, parent, model_field=None):
def initialize(self, parent):
"""
Same behaviour as usual Field, except that we need to keep track
of state so that we can deal with handling maximum depth and recursion.
"""
super(BaseSerializer, self).initialize(parent, model_field)
super(BaseSerializer, self).initialize(parent)
self.stack = parent.stack[:]
if parent.opts.nested and not isinstance(parent.opts.nested, bool):
self.opts.nested = parent.opts.nested - 1
@ -296,12 +288,22 @@ class ModelSerializerOptions(SerializerOptions):
self.model = getattr(meta, 'model', None)
class ModelSerializer(RelatedField, Serializer):
class ModelSerializer(Serializer):
"""
A serializer that deals with model instances and querysets.
"""
_options_class = ModelSerializerOptions
def field_to_native(self, obj, field_name):
"""
Override default so that we can apply ModelSerializer as a nested
field to relationships.
"""
obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
return [self.to_native(item) for item in obj.all()]
return self.to_native(obj)
def default_fields(self, serialize, obj=None, data=None, nested=False):
"""
Return all the fields that should be serialized for the model.
@ -330,7 +332,7 @@ class ModelSerializer(RelatedField, Serializer):
field = self.get_field(model_field)
if field:
field.initialize(parent=self, model_field=model_field)
field.initialize(parent=self)
ret[model_field.name] = field
return ret
@ -339,7 +341,7 @@ class ModelSerializer(RelatedField, Serializer):
"""
Returns a default instance of the pk field.
"""
return Field(readonly=True)
return Field()
def get_nested_field(self, model_field):
"""
@ -373,7 +375,7 @@ class ModelSerializer(RelatedField, Serializer):
try:
return field_mapping[model_field.__class__]()
except KeyError:
return Field()
return ModelField(model_field=model_field)
def restore_object(self, attrs, instance=None):
"""
@ -396,3 +398,40 @@ class ModelSerializer(RelatedField, Serializer):
"""
self.object.save()
return self.object.object
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
"""
Options for HyperlinkedModelSerializer
"""
def __init__(self, meta):
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None)
class HyperlinkedModelSerializer(ModelSerializer):
"""
"""
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
url = HyperlinkedIdentityField()
def __init__(self, *args, **kwargs):
super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs)
if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name()
def _get_default_view_name(self):
"""
Return the view name to use if 'view_name' is not specified in 'Meta'
"""
model_meta = self.opts.model._meta
format_kwargs = {
'app_label': model_meta.app_label,
'model_name': model_meta.object_name.lower()
}
return self._default_view_name % format_kwargs
def get_pk_field(self, model_field):
return None

View File

@ -46,7 +46,7 @@ DEFAULTS = {
'MODEL_SERIALIZER': 'rest_framework.serializers.ModelSerializer',
'PAGINATION_SERIALIZER': 'rest_framework.pagination.PaginationSerializer',
'PAGINATE_BY': 20,
'PAGINATE_BY': None,
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,

View File

@ -13,7 +13,6 @@ class RootView(generics.ListCreateAPIView):
Example description for OPTIONS.
"""
model = BasicModel
paginate_by = None
class InstanceView(generics.RetrieveUpdateDestroyAPIView):

View File

@ -0,0 +1,60 @@
from django.conf.urls.defaults import patterns, url
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import generics, status, serializers
from rest_framework.tests.models import BasicModel
factory = RequestFactory()
class BasicList(generics.ListCreateAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
urlpatterns = patterns('',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
)
class TestHyperlinkedView(TestCase):
urls = 'rest_framework.tests.hyperlinkedserializers'
def setUp(self):
"""
Create 3 BasicModel intances.
"""
items = ['foo', 'bar', 'baz']
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
self.data = [
{'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.list_view = BasicList.as_view()
self.detail_view = BasicDetail.as_view()
def test_get_list_view(self):
"""
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/')
response = self.list_view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data)
def test_get_detail_view(self):
"""
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/1')
response = self.detail_view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0])