mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-13 05:06:53 +03:00
Fix rest_framework.filters.OrderingFilter doesn't pass context to ser… (#4543)
* Fix rest_framework.filters.OrderingFilter doesn't pass context to serializers #4541 * #4541 Additional fix for remove_invalid_fields()
This commit is contained in:
parent
4ff9e96b4c
commit
e99b30d28b
|
@ -252,7 +252,7 @@ class OrderingFilter(BaseFilterBackend):
|
||||||
params = request.query_params.get(self.ordering_param)
|
params = request.query_params.get(self.ordering_param)
|
||||||
if params:
|
if params:
|
||||||
fields = [param.strip() for param in params.split(',')]
|
fields = [param.strip() for param in params.split(',')]
|
||||||
ordering = self.remove_invalid_fields(queryset, fields, view)
|
ordering = self.remove_invalid_fields(queryset, fields, view, request)
|
||||||
if ordering:
|
if ordering:
|
||||||
return ordering
|
return ordering
|
||||||
|
|
||||||
|
@ -265,7 +265,7 @@ class OrderingFilter(BaseFilterBackend):
|
||||||
return (ordering,)
|
return (ordering,)
|
||||||
return ordering
|
return ordering
|
||||||
|
|
||||||
def get_default_valid_fields(self, queryset, view):
|
def get_default_valid_fields(self, queryset, view, context={}):
|
||||||
# If `ordering_fields` is not specified, then we determine a default
|
# If `ordering_fields` is not specified, then we determine a default
|
||||||
# based on the serializer class, if one exists on the view.
|
# based on the serializer class, if one exists on the view.
|
||||||
if hasattr(view, 'get_serializer_class'):
|
if hasattr(view, 'get_serializer_class'):
|
||||||
|
@ -288,16 +288,16 @@ class OrderingFilter(BaseFilterBackend):
|
||||||
|
|
||||||
return [
|
return [
|
||||||
(field.source or field_name, field.label)
|
(field.source or field_name, field.label)
|
||||||
for field_name, field in serializer_class().fields.items()
|
for field_name, field in serializer_class(context=context).fields.items()
|
||||||
if not getattr(field, 'write_only', False) and not field.source == '*'
|
if not getattr(field, 'write_only', False) and not field.source == '*'
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_valid_fields(self, queryset, view):
|
def get_valid_fields(self, queryset, view, context={}):
|
||||||
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
|
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
|
||||||
|
|
||||||
if valid_fields is None:
|
if valid_fields is None:
|
||||||
# Default to allowing filtering on serializer fields
|
# Default to allowing filtering on serializer fields
|
||||||
return self.get_default_valid_fields(queryset, view)
|
return self.get_default_valid_fields(queryset, view, context)
|
||||||
|
|
||||||
elif valid_fields == '__all__':
|
elif valid_fields == '__all__':
|
||||||
# View explicitly allows filtering on any model field
|
# View explicitly allows filtering on any model field
|
||||||
|
@ -316,8 +316,8 @@ class OrderingFilter(BaseFilterBackend):
|
||||||
|
|
||||||
return valid_fields
|
return valid_fields
|
||||||
|
|
||||||
def remove_invalid_fields(self, queryset, fields, view):
|
def remove_invalid_fields(self, queryset, fields, view, request):
|
||||||
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view)]
|
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
|
||||||
return [term for term in fields if term.lstrip('-') in valid_fields]
|
return [term for term in fields if term.lstrip('-') in valid_fields]
|
||||||
|
|
||||||
def filter_queryset(self, request, queryset, view):
|
def filter_queryset(self, request, queryset, view):
|
||||||
|
@ -332,15 +332,16 @@ class OrderingFilter(BaseFilterBackend):
|
||||||
current = self.get_ordering(request, queryset, view)
|
current = self.get_ordering(request, queryset, view)
|
||||||
current = None if current is None else current[0]
|
current = None if current is None else current[0]
|
||||||
options = []
|
options = []
|
||||||
for key, label in self.get_valid_fields(queryset, view):
|
context = {
|
||||||
options.append((key, '%s - %s' % (label, _('ascending'))))
|
|
||||||
options.append(('-' + key, '%s - %s' % (label, _('descending'))))
|
|
||||||
return {
|
|
||||||
'request': request,
|
'request': request,
|
||||||
'current': current,
|
'current': current,
|
||||||
'param': self.ordering_param,
|
'param': self.ordering_param,
|
||||||
'options': options,
|
|
||||||
}
|
}
|
||||||
|
for key, label in self.get_valid_fields(queryset, view, context):
|
||||||
|
options.append((key, '%s - %s' % (label, _('ascending'))))
|
||||||
|
options.append(('-' + key, '%s - %s' % (label, _('descending'))))
|
||||||
|
context['options'] = options
|
||||||
|
return context
|
||||||
|
|
||||||
def to_html(self, request, queryset, view):
|
def to_html(self, request, queryset, view):
|
||||||
template = loader.get_template(self.template)
|
template = loader.get_template(self.template)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user