Fix rest_framework.filters.OrderingFilter doesn't pass context to ser… (#4543)

* Fix rest_framework.filters.OrderingFilter doesn't pass context to serializers #4541

* #4541 Additional fix for remove_invalid_fields()
This commit is contained in:
Camille Harang 2016-10-10 12:59:02 +02:00 committed by Tom Christie
parent 4ff9e96b4c
commit e99b30d28b

View File

@ -252,7 +252,7 @@ class OrderingFilter(BaseFilterBackend):
params = request.query_params.get(self.ordering_param)
if params:
fields = [param.strip() for param in params.split(',')]
ordering = self.remove_invalid_fields(queryset, fields, view)
ordering = self.remove_invalid_fields(queryset, fields, view, request)
if ordering:
return ordering
@ -265,7 +265,7 @@ class OrderingFilter(BaseFilterBackend):
return (ordering,)
return ordering
def get_default_valid_fields(self, queryset, view):
def get_default_valid_fields(self, queryset, view, context={}):
# If `ordering_fields` is not specified, then we determine a default
# based on the serializer class, if one exists on the view.
if hasattr(view, 'get_serializer_class'):
@ -288,16 +288,16 @@ class OrderingFilter(BaseFilterBackend):
return [
(field.source or field_name, field.label)
for field_name, field in serializer_class().fields.items()
for field_name, field in serializer_class(context=context).fields.items()
if not getattr(field, 'write_only', False) and not field.source == '*'
]
def get_valid_fields(self, queryset, view):
def get_valid_fields(self, queryset, view, context={}):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
if valid_fields is None:
# Default to allowing filtering on serializer fields
return self.get_default_valid_fields(queryset, view)
return self.get_default_valid_fields(queryset, view, context)
elif valid_fields == '__all__':
# View explicitly allows filtering on any model field
@ -316,8 +316,8 @@ class OrderingFilter(BaseFilterBackend):
return valid_fields
def remove_invalid_fields(self, queryset, fields, view):
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view)]
def remove_invalid_fields(self, queryset, fields, view, request):
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
return [term for term in fields if term.lstrip('-') in valid_fields]
def filter_queryset(self, request, queryset, view):
@ -332,15 +332,16 @@ class OrderingFilter(BaseFilterBackend):
current = self.get_ordering(request, queryset, view)
current = None if current is None else current[0]
options = []
for key, label in self.get_valid_fields(queryset, view):
options.append((key, '%s - %s' % (label, _('ascending'))))
options.append(('-' + key, '%s - %s' % (label, _('descending'))))
return {
context = {
'request': request,
'current': current,
'param': self.ordering_param,
'options': options,
}
for key, label in self.get_valid_fields(queryset, view, context):
options.append((key, '%s - %s' % (label, _('ascending'))))
options.append(('-' + key, '%s - %s' % (label, _('descending'))))
context['options'] = options
return context
def to_html(self, request, queryset, view):
template = loader.get_template(self.template)