Merge pull request #279 from tomchristie/hyperlinked-relationships

Hyperlinked relationships
This commit is contained in:
Tom Christie 2012-10-04 06:07:26 -07:00
commit 42b3fdbdc2
5 changed files with 276 additions and 149 deletions

View File

@ -9,6 +9,7 @@ from django.conf import settings
from django.db import DEFAULT_DB_ALIAS from django.db import DEFAULT_DB_ALIAS
from django.utils.encoding import is_protected_type, smart_unicode from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.reverse import reverse
from rest_framework.compat import parse_date, parse_datetime from rest_framework.compat import parse_date, parse_datetime
from rest_framework.compat import timezone from rest_framework.compat import timezone
@ -25,21 +26,88 @@ def is_simple_callable(obj):
class Field(object): class Field(object):
creation_counter = 0 creation_counter = 0
default_validators = []
default_error_messages = {
'required': _('This field is required.'),
'invalid': _('Invalid value.'),
}
empty = '' empty = ''
def __init__(self, source=None, readonly=False, required=None, def __init__(self, source=None):
validators=[], error_messages=None):
self.parent = None self.parent = None
self.creation_counter = Field.creation_counter self.creation_counter = Field.creation_counter
Field.creation_counter += 1 Field.creation_counter += 1
self.source = source self.source = source
def initialize(self, parent):
"""
Called to set up a field prior to field_to_native or field_from_native.
parent - The parent serializer.
model_field - The model field this field corrosponds to, if one exists.
"""
self.parent = parent
self.root = parent.root or parent
self.context = self.root.context
def field_from_native(self, data, field_name, into):
"""
Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value.
"""
return
def field_to_native(self, obj, field_name):
"""
Given and object and a field name, returns the value that should be
serialized for that field.
"""
if obj is None:
return self.empty
if self.source == '*':
return self.to_native(obj)
if self.source:
value = obj
for component in self.source.split('.'):
value = getattr(value, component)
else:
value = getattr(obj, field_name)
return self.to_native(value)
def to_native(self, value):
"""
Converts the field's value into it's simple representation.
"""
if is_simple_callable(value):
value = value()
if is_protected_type(value):
return value
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
return [self.to_native(item) for item in value]
return smart_unicode(value)
def attributes(self):
"""
Returns a dictionary of attributes to be used when serializing to xml.
"""
if getattr(self, 'type_name', None):
return {'type': self.type_name}
return {}
class WritableField(Field):
"""
Base for read/write fields.
"""
default_validators = []
default_error_messages = {
'required': _('This field is required.'),
'invalid': _('Invalid value.'),
}
def __init__(self, source=None, readonly=False, required=None,
validators=[], error_messages=None):
super(WritableField, self).__init__(source=source)
self.readonly = readonly self.readonly = readonly
if required is None: if required is None:
self.required = not(readonly) self.required = not(readonly)
@ -55,19 +123,6 @@ class Field(object):
self.validators = self.default_validators + validators self.validators = self.default_validators + validators
def initialize(self, parent, model_field=None):
"""
Called to set up a field prior to field_to_native or field_from_native.
parent - The parent serializer.
model_field - The model field this field corrosponds to, if one exists.
"""
self.parent = parent
self.root = parent.root or parent
self.context = self.root.context
if model_field:
self.model_field = model_field
def validate(self, value): def validate(self, value):
if value in validators.EMPTY_VALUES and self.required: if value in validators.EMPTY_VALUES and self.required:
raise ValidationError(self.error_messages['required']) raise ValidationError(self.error_messages['required'])
@ -116,88 +171,75 @@ class Field(object):
""" """
Reverts a simple representation back to the field's value. Reverts a simple representation back to the field's value.
""" """
if hasattr(self, 'model_field'):
try:
return self.model_field.rel.to._meta.get_field(self.model_field.rel.field_name).to_python(value)
except:
return self.model_field.to_python(value)
return value return value
class ModelField(WritableField):
"""
A generic field that can be used against an arbirtrary model field.
"""
def __init__(self, *args, **kwargs):
try:
self.model_field = kwargs.pop('model_field')
except:
raise ValueError("ModelField requires 'model_field' kwarg")
super(ModelField, self).__init__(*args, **kwargs)
def from_native(self, value):
try:
rel = self.model_field.rel
except:
return self.model_field.to_python(value)
return rel.to._meta.get_field(rel.field_name).to_python(value)
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
""" value = self.model_field._get_val_from_obj(obj)
Given and object and a field name, returns the value that should be
serialized for that field.
"""
if obj is None:
return self.empty
if self.source == '*':
return self.to_native(obj)
self.obj = obj # Need to hang onto this in the case of model fields
if hasattr(self, 'model_field'):
return self.to_native(self.model_field._get_val_from_obj(obj))
if self.source:
value = obj
for component in self.source.split('.'):
value = getattr(value, component)
else:
value = getattr(obj, field_name)
return self.to_native(value)
def to_native(self, value):
"""
Converts the field's value into it's simple representation.
"""
if is_simple_callable(value):
value = value()
if is_protected_type(value): if is_protected_type(value):
return value return value
elif hasattr(self, 'model_field'): return self.model_field.value_to_string(self.obj)
return self.model_field.value_to_string(self.obj)
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
return [self.to_native(item) for item in value]
return smart_unicode(value)
def attributes(self): def attributes(self):
""" return {
Returns a dictionary of attributes to be used when serializing to xml. "type": self.model_field.get_internal_type()
""" }
try:
return { ##### Relational fields #####
"type": self.model_field.get_internal_type()
}
except AttributeError:
return {}
class RelatedField(Field): class RelatedField(WritableField):
""" """
A base class for model related fields or related managers. Base class for related model fields.
Subclass this and override `convert` to define custom behaviour when
serializing related objects.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.queryset = kwargs.pop('queryset', None) self.queryset = kwargs.pop('queryset', None)
super(RelatedField, self).__init__(*args, **kwargs) super(RelatedField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
obj = getattr(obj, self.source or field_name) value = getattr(obj, self.source or field_name)
if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'): return self.to_native(value)
return [self.to_native(item) for item in obj.all()]
return self.to_native(obj)
def attributes(self): def field_from_native(self, data, field_name, into):
value = data.get(field_name)
into[(self.source or field_name) + '_id'] = self.from_native(value)
class ManyRelatedField(RelatedField):
"""
Base class for related model managers.
"""
def field_to_native(self, obj, field_name):
value = getattr(obj, self.source or field_name)
return [self.to_native(item) for item in value.all()]
def field_from_native(self, data, field_name, into):
try: try:
return { value = data.getlist(self.source or field_name)
"rel": self.model_field.rel.__class__.__name__, except:
"to": smart_unicode(self.model_field.rel.to._meta) value = data.get(self.source or field_name)
} else:
except AttributeError: if value == ['']:
return {} value = []
into[field_name] = [self.from_native(item) for item in value]
class PrimaryKeyRelatedField(RelatedField): class PrimaryKeyRelatedField(RelatedField):
@ -206,20 +248,11 @@ class PrimaryKeyRelatedField(RelatedField):
""" """
def to_native(self, pk): def to_native(self, pk):
"""
You can subclass this method to provide different serialization
behavior based on the pk.
"""
return pk return pk
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
# This is only implemented for performance reasons
#
# We could leave the default `RelatedField.field_to_native()` in place,
# and inside just implement `to_native()` as `return obj.pk`
#
# That would involve an extra database lookup.
try: try:
# Prefer obj.serializable_value for performance reasons
pk = obj.serializable_value(self.source or field_name) pk = obj.serializable_value(self.source or field_name)
except AttributeError: except AttributeError:
# RelatedObject (reverse relationship) # RelatedObject (reverse relationship)
@ -228,18 +261,17 @@ class PrimaryKeyRelatedField(RelatedField):
# Forward relationship # Forward relationship
return self.to_native(pk) return self.to_native(pk)
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
into[field_name + '_id'] = self.from_native(value)
class ManyPrimaryKeyRelatedField(ManyRelatedField):
class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
""" """
Serializes a to-many related field or related manager to a pk value. Serializes a to-many related field or related manager to a pk value.
""" """
def to_native(self, pk):
return pk
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
try: try:
# Prefer obj.serializable_value for performance reasons
queryset = obj.serializable_value(self.source or field_name) queryset = obj.serializable_value(self.source or field_name)
except AttributeError: except AttributeError:
# RelatedManager (reverse relationship) # RelatedManager (reverse relationship)
@ -248,40 +280,25 @@ class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
# Forward relationship # Forward relationship
return [self.to_native(item.pk) for item in queryset.all()] return [self.to_native(item.pk) for item in queryset.all()]
def field_from_native(self, data, field_name, into):
try:
value = data.getlist(field_name)
except:
value = data.get(field_name)
else:
if value == ['']:
value = []
into[field_name] = [self.from_native(item) for item in value]
class HyperlinkedIdentityField(Field):
class NaturalKeyRelatedField(RelatedField):
""" """
Serializes a model related field or related manager to a natural key value. A field that represents the model's identity using a hyperlink.
""" """
is_natural_key = True # XML renderer handles these differently def __init__(self, *args, **kwargs):
pass
def to_native(self, obj): def field_to_native(self, obj, field_name):
if hasattr(obj, 'natural_key'): request = self.context.get('request', None)
return obj.natural_key() view_name = self.parent.opts.view_name
return obj view_kwargs = {'pk': obj.pk}
return reverse(view_name, kwargs=view_kwargs, request=request)
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
into[self.model_field.attname] = self.from_native(value)
def from_native(self, value):
# TODO: Support 'using' : db = options.pop('using', DEFAULT_DB_ALIAS)
manager = self.model_field.rel.to._default_manager
manager = manager.db_manager(DEFAULT_DB_ALIAS)
return manager.get_by_natural_key(*value).pk
class BooleanField(Field): ##### Typed Fields #####
class BooleanField(WritableField):
type_name = 'BooleanField'
default_error_messages = { default_error_messages = {
'invalid': _(u"'%s' value must be either True or False."), 'invalid': _(u"'%s' value must be either True or False."),
} }
@ -298,7 +315,9 @@ class BooleanField(Field):
raise ValidationError(self.error_messages['invalid'] % value) raise ValidationError(self.error_messages['invalid'] % value)
class CharField(Field): class CharField(WritableField):
type_name = 'CharField'
def __init__(self, max_length=None, min_length=None, *args, **kwargs): def __init__(self, max_length=None, min_length=None, *args, **kwargs):
self.max_length, self.min_length = max_length, min_length self.max_length, self.min_length = max_length, min_length
super(CharField, self).__init__(*args, **kwargs) super(CharField, self).__init__(*args, **kwargs)
@ -314,6 +333,8 @@ class CharField(Field):
class EmailField(CharField): class EmailField(CharField):
type_name = 'EmailField'
default_error_messages = { default_error_messages = {
'invalid': _('Enter a valid e-mail address.'), 'invalid': _('Enter a valid e-mail address.'),
} }
@ -330,7 +351,9 @@ class EmailField(CharField):
return result return result
class DateField(Field): class DateField(WritableField):
type_name = 'DateField'
default_error_messages = { default_error_messages = {
'invalid': _(u"'%s' value has an invalid date format. It must be " 'invalid': _(u"'%s' value has an invalid date format. It must be "
u"in YYYY-MM-DD format."), u"in YYYY-MM-DD format."),
@ -364,7 +387,9 @@ class DateField(Field):
raise ValidationError(msg) raise ValidationError(msg)
class DateTimeField(Field): class DateTimeField(WritableField):
type_name = 'DateTimeField'
default_error_messages = { default_error_messages = {
'invalid': _(u"'%s' value has an invalid format. It must be in " 'invalid': _(u"'%s' value has an invalid format. It must be in "
u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."), u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."),
@ -415,7 +440,9 @@ class DateTimeField(Field):
raise ValidationError(msg) raise ValidationError(msg)
class IntegerField(Field): class IntegerField(WritableField):
type_name = 'IntegerField'
default_error_messages = { default_error_messages = {
'invalid': _('Enter a whole number.'), 'invalid': _('Enter a whole number.'),
'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
@ -441,7 +468,9 @@ class IntegerField(Field):
return value return value
class FloatField(Field): class FloatField(WritableField):
type_name = 'FloatField'
default_error_messages = { default_error_messages = {
'invalid': _("'%s' value must be a float."), 'invalid': _("'%s' value must be a float."),
} }

View File

@ -123,16 +123,8 @@ class BaseSerializer(Field):
# Get the explicitly declared fields # Get the explicitly declared fields
for key, field in self.fields.items(): for key, field in self.fields.items():
ret[key] = field ret[key] = field
# Determine if the declared field corrosponds to a model field.
try:
if key == 'pk':
model_field = obj._meta.pk
else:
model_field = obj._meta.get_field_by_name(key)[0]
except:
model_field = None
# Set up the field # Set up the field
field.initialize(parent=self, model_field=model_field) field.initialize(parent=self)
# Add in the default fields # Add in the default fields
fields = self.default_fields(serialize, obj, data, nested) fields = self.default_fields(serialize, obj, data, nested)
@ -157,12 +149,12 @@ class BaseSerializer(Field):
##### #####
# Field methods - used when the serializer class is itself used as a field. # Field methods - used when the serializer class is itself used as a field.
def initialize(self, parent, model_field=None): def initialize(self, parent):
""" """
Same behaviour as usual Field, except that we need to keep track Same behaviour as usual Field, except that we need to keep track
of state so that we can deal with handling maximum depth and recursion. of state so that we can deal with handling maximum depth and recursion.
""" """
super(BaseSerializer, self).initialize(parent, model_field) super(BaseSerializer, self).initialize(parent)
self.stack = parent.stack[:] self.stack = parent.stack[:]
if parent.opts.nested and not isinstance(parent.opts.nested, bool): if parent.opts.nested and not isinstance(parent.opts.nested, bool):
self.opts.nested = parent.opts.nested - 1 self.opts.nested = parent.opts.nested - 1
@ -296,12 +288,22 @@ class ModelSerializerOptions(SerializerOptions):
self.model = getattr(meta, 'model', None) self.model = getattr(meta, 'model', None)
class ModelSerializer(RelatedField, Serializer): class ModelSerializer(Serializer):
""" """
A serializer that deals with model instances and querysets. A serializer that deals with model instances and querysets.
""" """
_options_class = ModelSerializerOptions _options_class = ModelSerializerOptions
def field_to_native(self, obj, field_name):
"""
Override default so that we can apply ModelSerializer as a nested
field to relationships.
"""
obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
return [self.to_native(item) for item in obj.all()]
return self.to_native(obj)
def default_fields(self, serialize, obj=None, data=None, nested=False): def default_fields(self, serialize, obj=None, data=None, nested=False):
""" """
Return all the fields that should be serialized for the model. Return all the fields that should be serialized for the model.
@ -330,7 +332,7 @@ class ModelSerializer(RelatedField, Serializer):
field = self.get_field(model_field) field = self.get_field(model_field)
if field: if field:
field.initialize(parent=self, model_field=model_field) field.initialize(parent=self)
ret[model_field.name] = field ret[model_field.name] = field
return ret return ret
@ -339,7 +341,7 @@ class ModelSerializer(RelatedField, Serializer):
""" """
Returns a default instance of the pk field. Returns a default instance of the pk field.
""" """
return Field(readonly=True) return Field()
def get_nested_field(self, model_field): def get_nested_field(self, model_field):
""" """
@ -373,7 +375,7 @@ class ModelSerializer(RelatedField, Serializer):
try: try:
return field_mapping[model_field.__class__]() return field_mapping[model_field.__class__]()
except KeyError: except KeyError:
return Field() return ModelField(model_field=model_field)
def restore_object(self, attrs, instance=None): def restore_object(self, attrs, instance=None):
""" """
@ -396,3 +398,40 @@ class ModelSerializer(RelatedField, Serializer):
""" """
self.object.save() self.object.save()
return self.object.object return self.object.object
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
"""
Options for HyperlinkedModelSerializer
"""
def __init__(self, meta):
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None)
class HyperlinkedModelSerializer(ModelSerializer):
"""
"""
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
url = HyperlinkedIdentityField()
def __init__(self, *args, **kwargs):
super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs)
if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name()
def _get_default_view_name(self):
"""
Return the view name to use if 'view_name' is not specified in 'Meta'
"""
model_meta = self.opts.model._meta
format_kwargs = {
'app_label': model_meta.app_label,
'model_name': model_meta.object_name.lower()
}
return self._default_view_name % format_kwargs
def get_pk_field(self, model_field):
return None

View File

@ -46,7 +46,7 @@ DEFAULTS = {
'MODEL_SERIALIZER': 'rest_framework.serializers.ModelSerializer', 'MODEL_SERIALIZER': 'rest_framework.serializers.ModelSerializer',
'PAGINATION_SERIALIZER': 'rest_framework.pagination.PaginationSerializer', 'PAGINATION_SERIALIZER': 'rest_framework.pagination.PaginationSerializer',
'PAGINATE_BY': 20, 'PAGINATE_BY': None,
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None, 'UNAUTHENTICATED_TOKEN': None,

View File

@ -13,7 +13,6 @@ class RootView(generics.ListCreateAPIView):
Example description for OPTIONS. Example description for OPTIONS.
""" """
model = BasicModel model = BasicModel
paginate_by = None
class InstanceView(generics.RetrieveUpdateDestroyAPIView): class InstanceView(generics.RetrieveUpdateDestroyAPIView):

View File

@ -0,0 +1,60 @@
from django.conf.urls.defaults import patterns, url
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import generics, status, serializers
from rest_framework.tests.models import BasicModel
factory = RequestFactory()
class BasicList(generics.ListCreateAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
urlpatterns = patterns('',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
)
class TestHyperlinkedView(TestCase):
urls = 'rest_framework.tests.hyperlinkedserializers'
def setUp(self):
"""
Create 3 BasicModel intances.
"""
items = ['foo', 'bar', 'baz']
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
self.data = [
{'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.list_view = BasicList.as_view()
self.detail_view = BasicDetail.as_view()
def test_get_list_view(self):
"""
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/')
response = self.list_view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data)
def test_get_detail_view(self):
"""
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/1')
response = self.detail_view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0])