From 5d247a65c89594a7ab5ce2333612f23eadc6828d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 9 Oct 2014 15:11:19 +0100 Subject: [PATCH] First pass on nested serializers in HTML --- docs/tutorial/quickstart.md | 8 ++- rest_framework/compat.py | 16 ++++- rest_framework/fields.py | 28 ++++++-- rest_framework/relations.py | 20 +++++- rest_framework/renderers.py | 10 ++- rest_framework/serializers.py | 37 +++++++--- .../fields/horizontal/fieldset.html | 5 +- .../fields/horizontal/list_fieldset.html | 13 ++++ .../fields/inline/fieldset.html | 5 +- .../fields/vertical/fieldset.html | 5 +- .../fields/vertical/list_fieldset.html | 7 ++ tests/test_bound_fields.py | 69 +++++++++++++++++++ 12 files changed, 195 insertions(+), 28 deletions(-) create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/list_fieldset.html create mode 100644 rest_framework/templates/rest_framework/fields/vertical/list_fieldset.html create mode 100644 tests/test_bound_fields.py diff --git a/docs/tutorial/quickstart.md b/docs/tutorial/quickstart.md index 813e9872c..c2dc4bea9 100644 --- a/docs/tutorial/quickstart.md +++ b/docs/tutorial/quickstart.md @@ -26,11 +26,13 @@ Create a new Django project named `tutorial`, then start a new app called `quick Now sync your database for the first time: - python manage.py syncdb + python manage.py migrate -Make sure to create an initial user named `admin` with a password of `password`. We'll authenticate as that user later in our example. +We'll also create an initial user named `admin` with a password of `password`. We'll authenticate as that user later in our example. -Once you've set up a database and got everything synced and ready to go, open up the app's directory and we'll get coding... + python manage.py createsuperuser + +Once you've set up a database and initial user created and ready to go, open up the app's directory and we'll get coding... ## Serializers diff --git a/rest_framework/compat.py b/rest_framework/compat.py index e4e69580f..4ab23a4da 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -114,12 +114,15 @@ else: -# MinValueValidator and MaxValueValidator only accept `message` in 1.8+ +# MinValueValidator, MaxValueValidator et al. only accept `message` in 1.8+ if django.VERSION >= (1, 8): from django.core.validators import MinValueValidator, MaxValueValidator + from django.core.validators import MinLengthValidator, MaxLengthValidator else: from django.core.validators import MinValueValidator as DjangoMinValueValidator from django.core.validators import MaxValueValidator as DjangoMaxValueValidator + from django.core.validators import MinLengthValidator as DjangoMinLengthValidator + from django.core.validators import MaxLengthValidator as DjangoMaxLengthValidator class MinValueValidator(DjangoMinValueValidator): def __init__(self, *args, **kwargs): @@ -131,6 +134,17 @@ else: self.message = kwargs.pop('message', self.message) super(MaxValueValidator, self).__init__(*args, **kwargs) + class MinLengthValidator(DjangoMinLengthValidator): + def __init__(self, *args, **kwargs): + self.message = kwargs.pop('message', self.message) + super(MinLengthValidator, self).__init__(*args, **kwargs) + + class MaxLengthValidator(DjangoMaxLengthValidator): + def __init__(self, *args, **kwargs): + self.message = kwargs.pop('message', self.message) + super(MaxLengthValidator, self).__init__(*args, **kwargs) + + # URLValidator only accepts `message` in 1.6+ if django.VERSION >= (1, 6): from django.core.validators import URLValidator diff --git a/rest_framework/fields.py b/rest_framework/fields.py index b371c7d0a..7053acee0 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -8,7 +8,10 @@ from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ from rest_framework import ISO_8601 -from rest_framework.compat import smart_text, EmailValidator, MinValueValidator, MaxValueValidator, URLValidator +from rest_framework.compat import ( + smart_text, EmailValidator, MinValueValidator, MaxValueValidator, + MinLengthValidator, MaxLengthValidator, URLValidator +) from rest_framework.settings import api_settings from rest_framework.utils import html, representation, humanize_datetime import copy @@ -138,7 +141,7 @@ class Field(object): self.label = label self.help_text = help_text self.style = {} if style is None else style - self.validators = validators or self.default_validators[:] + self.validators = validators[:] or self.default_validators[:] self.allow_null = allow_null # These are set up by `.bind()` when the field is added to a serializer. @@ -412,16 +415,24 @@ class NullBooleanField(Field): class CharField(Field): default_error_messages = { - 'blank': _('This field may not be blank.') + 'blank': _('This field may not be blank.'), + 'max_length': _('Ensure this field has no more than {max_length} characters.'), + 'min_length': _('Ensure this field has no more than {min_length} characters.') } initial = '' coerce_blank_to_null = False def __init__(self, **kwargs): self.allow_blank = kwargs.pop('allow_blank', False) - self.max_length = kwargs.pop('max_length', None) - self.min_length = kwargs.pop('min_length', None) + max_length = kwargs.pop('max_length', None) + min_length = kwargs.pop('min_length', None) super(CharField, self).__init__(**kwargs) + if max_length is not None: + message = self.error_messages['max_length'].format(max_length=max_length) + self.validators.append(MaxLengthValidator(max_length, message=message)) + if min_length is not None: + message = self.error_messages['min_length'].format(min_length=min_length) + self.validators.append(MinLengthValidator(min_length, message=message)) def run_validation(self, data=empty): # Test for the empty string here so that it does not get validated, @@ -857,6 +868,13 @@ class MultipleChoiceField(ChoiceField): } default_empty_html = [] + def get_value(self, dictionary): + # We override the default field access in order to support + # lists in HTML forms. + if html.is_html_input(dictionary): + return dictionary.getlist(self.field_name) + return dictionary.get(self.field_name, empty) + def to_internal_value(self, data): if isinstance(data, type('')) or not hasattr(data, '__iter__'): self.fail('not_a_list', input_type=type(data).__name__) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index c1e5aa187..268b95cf1 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,6 +1,7 @@ from rest_framework.compat import smart_text, urlparse from rest_framework.fields import empty, Field from rest_framework.reverse import reverse +from rest_framework.utils import html from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch, Resolver404 from django.db.models.query import QuerySet @@ -263,6 +264,13 @@ class ManyRelation(Field): super(ManyRelation, self).__init__(*args, **kwargs) self.child_relation.bind(field_name='', parent=self) + def get_value(self, dictionary): + # We override the default field access in order to support + # lists in HTML forms. + if html.is_html_input(dictionary): + return dictionary.getlist(self.field_name) + return dictionary.get(self.field_name, empty) + def to_internal_value(self, data): return [ self.child_relation.to_internal_value(item) @@ -278,10 +286,16 @@ class ManyRelation(Field): @property def choices(self): + queryset = self.child_relation.queryset + iterable = queryset.all() if (hasattr(queryset, 'all')) else queryset + items_and_representations = [ + (item, self.child_relation.to_representation(item)) + for item in iterable + ] return dict([ ( - str(self.child_relation.to_representation(item)), - str(item) + str(item_representation), + str(item) + ' - ' + str(item_representation) ) - for item in self.child_relation.queryset.all() + for item, item_representation in items_and_representations ]) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 931dd434f..4fb360609 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -364,6 +364,12 @@ class HTMLFormRenderer(BaseRenderer): serializers.ManyRelation: { 'default': 'select_multiple.html', 'checkbox': 'select_checkbox.html' + }, + serializers.Serializer: { + 'default': 'fieldset.html' + }, + serializers.ListSerializer: { + 'default': 'list_fieldset.html' } }) @@ -392,7 +398,9 @@ class HTMLFormRenderer(BaseRenderer): template = loader.get_template(template_name) context = Context({ 'field': field, - 'input_type': input_type + 'input_type': input_type, + 'renderer': self, + 'layout': layout }) return template.render(context) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 9fcbcba76..1c006990b 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -166,14 +166,25 @@ class BoundField(object): Returned when iterating over a serializer instance, providing an API similar to Django forms and form fields. """ - def __init__(self, field, value, errors): + def __init__(self, field, value, errors, prefix=''): self._field = field self.value = value self.errors = errors + self.name = prefix + self.field_name def __getattr__(self, attr_name): return getattr(self._field, attr_name) + def __iter__(self): + for field in self.fields.values(): + yield self[field.field_name] + + def __getitem__(self, key): + field = self.fields[key] + value = self.value.get(key) if self.value else None + error = self.errors.get(key) if self.errors else None + return BoundField(field, value, error, prefix=self.name + '.') + @property def _proxy_class(self): return self._field.__class__ @@ -355,16 +366,23 @@ class Serializer(BaseSerializer): def validate(self, attrs): return attrs - def __iter__(self): - errors = self.errors if hasattr(self, '_errors') else {} - for field in self.fields.values(): - value = self.data.get(field.field_name) if self.data else None - error = errors.get(field.field_name) - yield BoundField(field, value, error) - def __repr__(self): return representation.serializer_repr(self, indent=1) + # The following are used for accessing `BoundField` instances on the + # serializer, for the purposes of presenting a form-like API onto the + # field values and field errors. + + def __iter__(self): + for field in self.fields.values(): + yield self[field.field_name] + + def __getitem__(self, key): + field = self.fields[key] + value = self.data.get(key) + error = self.errors.get(key) if hasattr(self, '_errors') else None + return BoundField(field, value, error) + # There's some replication of `ListField` here, # but that's probably better than obfuscating the call hierarchy. @@ -404,8 +422,9 @@ class ListSerializer(BaseSerializer): """ List of object instances -> List of dicts of primitive datatypes. """ + iterable = data.all() if (hasattr(data, 'all')) else data return ReturnList( - [self.child.to_representation(item) for item in data], + [self.child.to_representation(item) for item in iterable], serializer=self ) diff --git a/rest_framework/templates/rest_framework/fields/horizontal/fieldset.html b/rest_framework/templates/rest_framework/fields/horizontal/fieldset.html index 843a56b29..ff93c6baa 100644 --- a/rest_framework/templates/rest_framework/fields/horizontal/fieldset.html +++ b/rest_framework/templates/rest_framework/fields/horizontal/fieldset.html @@ -1,10 +1,11 @@ +{% load rest_framework %}
{% if field.label %}
{{ field.label }}
{% endif %} - {% for field_item in field.value.field_items.values() %} - {{ renderer.render_field(field_item, layout=layout) }} + {% for nested_field in field %} + {% render_field nested_field layout=layout renderer=renderer %} {% endfor %}
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/list_fieldset.html b/rest_framework/templates/rest_framework/fields/horizontal/list_fieldset.html new file mode 100644 index 000000000..68c75d4f8 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/horizontal/list_fieldset.html @@ -0,0 +1,13 @@ +{% load rest_framework %} +
+ {% if field.label %} +
+ {{ field.label }} +
+ {% endif %} + +
diff --git a/rest_framework/templates/rest_framework/fields/inline/fieldset.html b/rest_framework/templates/rest_framework/fields/inline/fieldset.html index 380d46272..ba9f1835a 100644 --- a/rest_framework/templates/rest_framework/fields/inline/fieldset.html +++ b/rest_framework/templates/rest_framework/fields/inline/fieldset.html @@ -1,3 +1,4 @@ -{% for field_item in field.value.field_items.values() %} - {{ renderer.render_field(field_item, layout=layout) }} +{% load rest_framework %} +{% for nested_field in field %} + {% render_field nested_field layout=layout renderer=renderer %} {% endfor %} diff --git a/rest_framework/templates/rest_framework/fields/vertical/fieldset.html b/rest_framework/templates/rest_framework/fields/vertical/fieldset.html index 8708916bd..248fe9044 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/fieldset.html +++ b/rest_framework/templates/rest_framework/fields/vertical/fieldset.html @@ -1,6 +1,7 @@ +{% load rest_framework %}
{% if field.label %}{{ field.label }}{% endif %} - {% for field_item in field.value.field_items.values() %} - {{ renderer.render_field(field_item, layout=layout) }} + {% for nested_field in field %} + {% render_field nested_field layout=layout renderer=renderer %} {% endfor %}
diff --git a/rest_framework/templates/rest_framework/fields/vertical/list_fieldset.html b/rest_framework/templates/rest_framework/fields/vertical/list_fieldset.html new file mode 100644 index 000000000..6b99a8349 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/vertical/list_fieldset.html @@ -0,0 +1,7 @@ +
+ {% if field.label %}{{ field.label }}{% endif %} + +
diff --git a/tests/test_bound_fields.py b/tests/test_bound_fields.py new file mode 100644 index 000000000..469437e4b --- /dev/null +++ b/tests/test_bound_fields.py @@ -0,0 +1,69 @@ +from rest_framework import serializers + + +class TestSimpleBoundField: + def test_empty_bound_field(self): + class ExampleSerializer(serializers.Serializer): + text = serializers.CharField(max_length=100) + amount = serializers.IntegerField() + + serializer = ExampleSerializer() + + assert serializer['text'].value == '' + assert serializer['text'].errors is None + assert serializer['text'].name == 'text' + assert serializer['amount'].value is None + assert serializer['amount'].errors is None + assert serializer['amount'].name == 'amount' + + def test_populated_bound_field(self): + class ExampleSerializer(serializers.Serializer): + text = serializers.CharField(max_length=100) + amount = serializers.IntegerField() + + serializer = ExampleSerializer(data={'text': 'abc', 'amount': 123}) + + assert serializer['text'].value == 'abc' + assert serializer['text'].errors is None + assert serializer['text'].name == 'text' + assert serializer['amount'].value is 123 + assert serializer['amount'].errors is None + assert serializer['amount'].name == 'amount' + + def test_error_bound_field(self): + class ExampleSerializer(serializers.Serializer): + text = serializers.CharField(max_length=100) + amount = serializers.IntegerField() + + serializer = ExampleSerializer(data={'text': 'x' * 1000, 'amount': 123}) + serializer.is_valid() + + assert serializer['text'].value == 'x' * 1000 + assert serializer['text'].errors == ['Ensure this field has no more than 100 characters.'] + assert serializer['text'].name == 'text' + assert serializer['amount'].value is 123 + assert serializer['amount'].errors is None + assert serializer['amount'].name == 'amount' + + +class TestNestedBoundField: + def test_nested_empty_bound_field(self): + class Nested(serializers.Serializer): + more_text = serializers.CharField(max_length=100) + amount = serializers.IntegerField() + + class ExampleSerializer(serializers.Serializer): + text = serializers.CharField(max_length=100) + nested = Nested() + + serializer = ExampleSerializer() + + assert serializer['text'].value == '' + assert serializer['text'].errors is None + assert serializer['text'].name == 'text' + assert serializer['nested']['more_text'].value == '' + assert serializer['nested']['more_text'].errors is None + assert serializer['nested']['more_text'].name == 'nested.more_text' + assert serializer['nested']['amount'].value is None + assert serializer['nested']['amount'].errors is None + assert serializer['nested']['amount'].name == 'nested.amount'