Merge pull request #277 from tomchristie/related-field-fixes

Related field fixes
This commit is contained in:
Tom Christie 2012-10-03 13:28:22 -07:00
commit 0a769f261e
5 changed files with 107 additions and 35 deletions

View File

@ -7,7 +7,6 @@ from django.core import validators
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.conf import settings from django.conf import settings
from django.db import DEFAULT_DB_ALIAS from django.db import DEFAULT_DB_ALIAS
from django.db.models.related import RelatedObject
from django.utils.encoding import is_protected_type, smart_unicode from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import parse_date, parse_datetime from rest_framework.compat import parse_date, parse_datetime
@ -181,6 +180,9 @@ class RelatedField(Field):
Subclass this and override `convert` to define custom behaviour when Subclass this and override `convert` to define custom behaviour when
serializing related objects. serializing related objects.
""" """
def __init__(self, *args, **kwargs):
self.queryset = kwargs.pop('queryset', None)
super(RelatedField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
obj = getattr(obj, self.source or field_name) obj = getattr(obj, self.source or field_name)
@ -200,48 +202,61 @@ class RelatedField(Field):
class PrimaryKeyRelatedField(RelatedField): class PrimaryKeyRelatedField(RelatedField):
""" """
Serializes a model related field or related manager to a pk value. Serializes a related field or related object to a pk value.
""" """
# Note the we use ModelRelatedField's implementation, as we want to get the
# raw database value directly, since that won't involve another
# database lookup.
#
# An alternative implementation would simply be this...
#
# class PrimaryKeyRelatedField(RelatedField):
# def to_native(self, obj):
# return obj.pk
def to_native(self, pk): def to_native(self, pk):
""" """
Simply returns the object's pk. You can subclass this method to You can subclass this method to provide different serialization
provide different serialization behavior of the pk. behavior based on the pk.
(For example returning a URL based on the model's pk.)
""" """
return pk return pk
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
# This is only implemented for performance reasons
#
# We could leave the default `RelatedField.field_to_native()` in place,
# and inside just implement `to_native()` as `return obj.pk`
#
# That would involve an extra database lookup.
try: try:
obj = obj.serializable_value(self.source or field_name) pk = obj.serializable_value(self.source or field_name)
except AttributeError: except AttributeError:
field = obj._meta.get_field_by_name(field_name)[0] # RelatedObject (reverse relationship)
obj = getattr(obj, self.source or field_name) obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ == 'RelatedManager': return self.to_native(obj.pk)
return [self.to_native(item.pk) for item in obj.all()] # Forward relationship
elif isinstance(field, RelatedObject): return self.to_native(pk)
return self.to_native(obj.pk)
raise
if obj.__class__.__name__ == 'ManyRelatedManager':
return [self.to_native(item.pk) for item in obj.all()]
return self.to_native(obj)
def field_from_native(self, data, field_name, into): def field_from_native(self, data, field_name, into):
value = data.get(field_name) value = data.get(field_name)
if hasattr(value, '__iter__'): into[field_name + '_id'] = self.from_native(value)
into[field_name] = [self.from_native(item) for item in value]
class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
"""
Serializes a to-many related field or related manager to a pk value.
"""
def field_to_native(self, obj, field_name):
try:
queryset = obj.serializable_value(self.source or field_name)
except AttributeError:
# RelatedManager (reverse relationship)
queryset = getattr(obj, self.source or field_name)
return [self.to_native(item.pk) for item in queryset.all()]
# Forward relationship
return [self.to_native(item.pk) for item in queryset.all()]
def field_from_native(self, data, field_name, into):
try:
value = data.getlist(field_name)
except:
value = data.get(field_name)
else: else:
into[field_name + '_id'] = self.from_native(value) if value == ['']:
value = []
into[field_name] = [self.from_native(item) for item in value]
class NaturalKeyRelatedField(RelatedField): class NaturalKeyRelatedField(RelatedField):

View File

@ -246,7 +246,9 @@ class DocumentingHTMLRenderer(BaseRenderer):
serializers.DateField: forms.DateField, serializers.DateField: forms.DateField,
serializers.EmailField: forms.EmailField, serializers.EmailField: forms.EmailField,
serializers.CharField: forms.CharField, serializers.CharField: forms.CharField,
serializers.BooleanField: forms.BooleanField serializers.BooleanField: forms.BooleanField,
serializers.PrimaryKeyRelatedField: forms.ModelChoiceField,
serializers.ManyPrimaryKeyRelatedField: forms.ModelMultipleChoiceField
} }
# Creating an on the fly form see: http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python # Creating an on the fly form see: http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
@ -257,12 +259,18 @@ class DocumentingHTMLRenderer(BaseRenderer):
serializer = view.get_serializer(instance=obj) serializer = view.get_serializer(instance=obj)
for k, v in serializer.get_fields(True).items(): for k, v in serializer.get_fields(True).items():
print k, v
if v.readonly: if v.readonly:
continue continue
kwargs = {}
if getattr(v, 'queryset', None):
kwargs['queryset'] = getattr(v, 'queryset', None)
try: try:
fields[k] = field_mapping[v.__class__]() fields[k] = field_mapping[v.__class__](**kwargs)
except KeyError: except KeyError:
fields[k] = forms.CharField fields[k] = forms.CharField()
OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields) OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields)
if obj and not view.request.method == 'DELETE': # Don't fill in the form when the object is deleted if obj and not view.request.method == 'DELETE': # Don't fill in the form when the object is deleted

View File

@ -351,7 +351,10 @@ class ModelSerializer(RelatedField, Serializer):
""" """
Creates a default instance of a flat relational field. Creates a default instance of a flat relational field.
""" """
return PrimaryKeyRelatedField() queryset = model_field.rel.to._default_manager # .using(db).complex_filter(self.rel.limit_choices_to)
if isinstance(model_field, models.fields.related.ManyToManyField):
return ManyPrimaryKeyRelatedField(queryset=queryset)
return PrimaryKeyRelatedField(queryset=queryset)
def get_field(self, model_field): def get_field(self, model_field):
""" """
@ -365,7 +368,7 @@ class ModelSerializer(RelatedField, Serializer):
models.EmailField: EmailField, models.EmailField: EmailField,
models.CharField: CharField, models.CharField: CharField,
models.CommaSeparatedIntegerField: CharField, models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField models.BooleanField: BooleanField,
} }
try: try:
return field_mapping[model_field.__class__]() return field_mapping[model_field.__class__]()

View File

@ -122,6 +122,7 @@
{% if response.status_code != 403 %} {% if response.status_code != 403 %}
{% if post_form %} {% if post_form %}
<div class="well">
<form action="{{ request.get_full_path }}" method="POST" {% if post_form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal"> <form action="{{ request.get_full_path }}" method="POST" {% if post_form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
<fieldset> <fieldset>
<h2>POST: {{ name }}</h2> <h2>POST: {{ name }}</h2>
@ -131,7 +132,7 @@
<div class="control-group {% if field.errors %}error{% endif %}"> <div class="control-group {% if field.errors %}error{% endif %}">
{{ field.label_tag|add_class:"control-label" }} {{ field.label_tag|add_class:"control-label" }}
<div class="controls"> <div class="controls">
{{ field }} {{ field|add_class:"input-xlarge" }}
<span class="help-inline">{{ field.help_text }}</span> <span class="help-inline">{{ field.help_text }}</span>
{{ field.errors|add_class:"help-block" }} {{ field.errors|add_class:"help-block" }}
</div> </div>
@ -142,9 +143,11 @@
</div> </div>
</fieldset> </fieldset>
</form> </form>
</div>
{% endif %} {% endif %}
{% if put_form %} {% if put_form %}
<div class="well">
<form action="{{ request.get_full_path }}" method="POST" {% if put_form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal"> <form action="{{ request.get_full_path }}" method="POST" {% if put_form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
<fieldset> <fieldset>
<h2>PUT: {{ name }}</h2> <h2>PUT: {{ name }}</h2>
@ -155,7 +158,7 @@
<div class="control-group {% if field.errors %}error{% endif %}"> <div class="control-group {% if field.errors %}error{% endif %}">
{{ field.label_tag|add_class:"control-label" }} {{ field.label_tag|add_class:"control-label" }}
<div class="controls"> <div class="controls">
{{ field }} {{ field|add_class:"input-xlarge" }}
<span class='help-inline'>{{ field.help_text }}</span> <span class='help-inline'>{{ field.help_text }}</span>
{{ field.errors|add_class:"help-block" }} {{ field.errors|add_class:"help-block" }}
</div> </div>
@ -167,6 +170,7 @@
</fieldset> </fieldset>
</form> </form>
</div>
{% endif %} {% endif %}
{% endif %} {% endif %}

View File

@ -160,6 +160,48 @@ class ManyToManyTests(TestCase):
self.assertEquals(instance.pk, 1) self.assertEquals(instance.pk, 1)
self.assertEquals(list(instance.rel.all()), [self.anchor, new_anchor]) self.assertEquals(list(instance.rel.all()), [self.anchor, new_anchor])
def test_create_empty_relationship(self):
"""
Create an instance of a model with a ManyToMany relationship,
containing no items.
"""
data = {'rel': []}
serializer = self.serializer_class(data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
self.assertEquals(instance.pk, 2)
self.assertEquals(list(instance.rel.all()), [])
def test_update_empty_relationship(self):
"""
Update an instance of a model with a ManyToMany relationship,
containing no items.
"""
new_anchor = Anchor()
new_anchor.save()
data = {'rel': []}
serializer = self.serializer_class(data, instance=self.instance)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 1)
self.assertEquals(instance.pk, 1)
self.assertEquals(list(instance.rel.all()), [])
def test_create_empty_relationship_flat_data(self):
"""
Create an instance of a model with a ManyToMany relationship,
containing no items, using a representation that does not support
lists (eg form data).
"""
data = {'rel': ''}
serializer = self.serializer_class(data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
self.assertEquals(instance.pk, 2)
self.assertEquals(list(instance.rel.all()), [])
# def test_deserialization_for_update(self): # def test_deserialization_for_update(self):
# serializer = self.serializer_class(self.data, instance=self.instance) # serializer = self.serializer_class(self.data, instance=self.instance)
# expected = self.instance # expected = self.instance