Merge pull request #277 from tomchristie/related-field-fixes

Related field fixes
This commit is contained in:
Tom Christie 2012-10-03 13:28:22 -07:00
commit 0a769f261e
5 changed files with 107 additions and 35 deletions

View File

@ -7,7 +7,6 @@ from django.core import validators
from django.core.exceptions import ValidationError
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS
from django.db.models.related import RelatedObject
from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import parse_date, parse_datetime
@ -181,6 +180,9 @@ class RelatedField(Field):
Subclass this and override `convert` to define custom behaviour when
serializing related objects.
"""
def __init__(self, *args, **kwargs):
self.queryset = kwargs.pop('queryset', None)
super(RelatedField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name):
obj = getattr(obj, self.source or field_name)
@ -200,48 +202,61 @@ class RelatedField(Field):
class PrimaryKeyRelatedField(RelatedField):
"""
Serializes a model related field or related manager to a pk value.
Serializes a related field or related object to a pk value.
"""
# Note the we use ModelRelatedField's implementation, as we want to get the
# raw database value directly, since that won't involve another
# database lookup.
#
# An alternative implementation would simply be this...
#
# class PrimaryKeyRelatedField(RelatedField):
# def to_native(self, obj):
# return obj.pk
def to_native(self, pk):
"""
Simply returns the object's pk. You can subclass this method to
provide different serialization behavior of the pk.
(For example returning a URL based on the model's pk.)
You can subclass this method to provide different serialization
behavior based on the pk.
"""
return pk
def field_to_native(self, obj, field_name):
# This is only implemented for performance reasons
#
# We could leave the default `RelatedField.field_to_native()` in place,
# and inside just implement `to_native()` as `return obj.pk`
#
# That would involve an extra database lookup.
try:
obj = obj.serializable_value(self.source or field_name)
pk = obj.serializable_value(self.source or field_name)
except AttributeError:
field = obj._meta.get_field_by_name(field_name)[0]
# RelatedObject (reverse relationship)
obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ == 'RelatedManager':
return [self.to_native(item.pk) for item in obj.all()]
elif isinstance(field, RelatedObject):
return self.to_native(obj.pk)
raise
if obj.__class__.__name__ == 'ManyRelatedManager':
return [self.to_native(item.pk) for item in obj.all()]
return self.to_native(obj)
return self.to_native(obj.pk)
# Forward relationship
return self.to_native(pk)
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
if hasattr(value, '__iter__'):
into[field_name] = [self.from_native(item) for item in value]
into[field_name + '_id'] = self.from_native(value)
class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
"""
Serializes a to-many related field or related manager to a pk value.
"""
def field_to_native(self, obj, field_name):
try:
queryset = obj.serializable_value(self.source or field_name)
except AttributeError:
# RelatedManager (reverse relationship)
queryset = getattr(obj, self.source or field_name)
return [self.to_native(item.pk) for item in queryset.all()]
# Forward relationship
return [self.to_native(item.pk) for item in queryset.all()]
def field_from_native(self, data, field_name, into):
try:
value = data.getlist(field_name)
except:
value = data.get(field_name)
else:
into[field_name + '_id'] = self.from_native(value)
if value == ['']:
value = []
into[field_name] = [self.from_native(item) for item in value]
class NaturalKeyRelatedField(RelatedField):

View File

@ -246,7 +246,9 @@ class DocumentingHTMLRenderer(BaseRenderer):
serializers.DateField: forms.DateField,
serializers.EmailField: forms.EmailField,
serializers.CharField: forms.CharField,
serializers.BooleanField: forms.BooleanField
serializers.BooleanField: forms.BooleanField,
serializers.PrimaryKeyRelatedField: forms.ModelChoiceField,
serializers.ManyPrimaryKeyRelatedField: forms.ModelMultipleChoiceField
}
# Creating an on the fly form see: http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
@ -257,12 +259,18 @@ class DocumentingHTMLRenderer(BaseRenderer):
serializer = view.get_serializer(instance=obj)
for k, v in serializer.get_fields(True).items():
print k, v
if v.readonly:
continue
kwargs = {}
if getattr(v, 'queryset', None):
kwargs['queryset'] = getattr(v, 'queryset', None)
try:
fields[k] = field_mapping[v.__class__]()
fields[k] = field_mapping[v.__class__](**kwargs)
except KeyError:
fields[k] = forms.CharField
fields[k] = forms.CharField()
OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields)
if obj and not view.request.method == 'DELETE': # Don't fill in the form when the object is deleted

View File

@ -351,7 +351,10 @@ class ModelSerializer(RelatedField, Serializer):
"""
Creates a default instance of a flat relational field.
"""
return PrimaryKeyRelatedField()
queryset = model_field.rel.to._default_manager # .using(db).complex_filter(self.rel.limit_choices_to)
if isinstance(model_field, models.fields.related.ManyToManyField):
return ManyPrimaryKeyRelatedField(queryset=queryset)
return PrimaryKeyRelatedField(queryset=queryset)
def get_field(self, model_field):
"""
@ -365,7 +368,7 @@ class ModelSerializer(RelatedField, Serializer):
models.EmailField: EmailField,
models.CharField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField
models.BooleanField: BooleanField,
}
try:
return field_mapping[model_field.__class__]()

View File

@ -122,6 +122,7 @@
{% if response.status_code != 403 %}
{% if post_form %}
<div class="well">
<form action="{{ request.get_full_path }}" method="POST" {% if post_form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
<fieldset>
<h2>POST: {{ name }}</h2>
@ -131,7 +132,7 @@
<div class="control-group {% if field.errors %}error{% endif %}">
{{ field.label_tag|add_class:"control-label" }}
<div class="controls">
{{ field }}
{{ field|add_class:"input-xlarge" }}
<span class="help-inline">{{ field.help_text }}</span>
{{ field.errors|add_class:"help-block" }}
</div>
@ -142,9 +143,11 @@
</div>
</fieldset>
</form>
</div>
{% endif %}
{% if put_form %}
<div class="well">
<form action="{{ request.get_full_path }}" method="POST" {% if put_form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
<fieldset>
<h2>PUT: {{ name }}</h2>
@ -155,7 +158,7 @@
<div class="control-group {% if field.errors %}error{% endif %}">
{{ field.label_tag|add_class:"control-label" }}
<div class="controls">
{{ field }}
{{ field|add_class:"input-xlarge" }}
<span class='help-inline'>{{ field.help_text }}</span>
{{ field.errors|add_class:"help-block" }}
</div>
@ -167,6 +170,7 @@
</fieldset>
</form>
</div>
{% endif %}
{% endif %}

View File

@ -160,6 +160,48 @@ class ManyToManyTests(TestCase):
self.assertEquals(instance.pk, 1)
self.assertEquals(list(instance.rel.all()), [self.anchor, new_anchor])
def test_create_empty_relationship(self):
"""
Create an instance of a model with a ManyToMany relationship,
containing no items.
"""
data = {'rel': []}
serializer = self.serializer_class(data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
self.assertEquals(instance.pk, 2)
self.assertEquals(list(instance.rel.all()), [])
def test_update_empty_relationship(self):
"""
Update an instance of a model with a ManyToMany relationship,
containing no items.
"""
new_anchor = Anchor()
new_anchor.save()
data = {'rel': []}
serializer = self.serializer_class(data, instance=self.instance)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 1)
self.assertEquals(instance.pk, 1)
self.assertEquals(list(instance.rel.all()), [])
def test_create_empty_relationship_flat_data(self):
"""
Create an instance of a model with a ManyToMany relationship,
containing no items, using a representation that does not support
lists (eg form data).
"""
data = {'rel': ''}
serializer = self.serializer_class(data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
self.assertEquals(instance.pk, 2)
self.assertEquals(list(instance.rel.all()), [])
# def test_deserialization_for_update(self):
# serializer = self.serializer_class(self.data, instance=self.instance)
# expected = self.instance