diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 1d12b1d92..8d339e798 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -155,6 +155,39 @@ def flatten_choices_dict(choices): return ret +def iter_options(grouped_choices): + """ + Helper function for options and option groups in templates. + """ + class StartOptionGroup(object): + start_option_group = True + end_option_group = False + + def __init__(self, label): + self.label = label + + class EndOptionGroup(object): + start_option_group = False + end_option_group = True + + class Option(object): + start_option_group = False + end_option_group = False + + def __init__(self, value, display_text): + self.value = value + self.display_text = display_text + + for key, value in grouped_choices.items(): + if isinstance(value, dict): + yield StartOptionGroup(label=key) + for sub_key, sub_value in value.items(): + yield Option(value=sub_key, display_text=sub_value) + yield EndOptionGroup() + else: + yield Option(value=key, display_text=value) + + class CreateOnlyDefault(object): """ This class may be used to provide default values that are only used @@ -1190,33 +1223,7 @@ class ChoiceField(Field): """ Helper method for use with templates rendering select widgets. """ - class StartOptionGroup(object): - start_option_group = True - end_option_group = False - - def __init__(self, label): - self.label = label - - class EndOptionGroup(object): - start_option_group = False - end_option_group = True - - class Option(object): - start_option_group = False - end_option_group = False - - def __init__(self, value, display_text): - self.value = value - self.display_text = display_text - - for key, value in self.grouped_choices.items(): - if isinstance(value, dict): - yield StartOptionGroup(label=key) - for sub_key, sub_value in value.items(): - yield Option(value=sub_key, display_text=sub_value) - yield EndOptionGroup() - else: - yield Option(value=key, display_text=value) + return iter_options(self.grouped_choices) class MultipleChoiceField(ChoiceField): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 076248541..874406696 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -14,7 +14,7 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import OrderedDict from rest_framework.fields import ( - Field, empty, get_attribute, is_simple_callable + Field, empty, get_attribute, is_simple_callable, iter_options ) from rest_framework.reverse import reverse from rest_framework.utils import html @@ -153,6 +153,13 @@ class RelatedField(Field): for item in queryset ]) + @property + def grouped_choices(self): + return self.choices + + def iter_options(self): + return iter_options(self.grouped_choices) + class StringRelatedField(RelatedField): """ @@ -453,3 +460,10 @@ class ManyRelatedField(Field): @property def choices(self): return self.child_relation.choices + + @property + def grouped_choices(self): + return self.choices + + def iter_options(self): + return iter_options(self.grouped_choices)