Add 'view' argument to 'get_fields()'

This commit is contained in:
Tom Christie 2016-06-21 21:16:45 +01:00
parent 8fb260214b
commit b438281f34
3 changed files with 10 additions and 10 deletions

View File

@ -71,7 +71,7 @@ class BaseFilterBackend(object):
"""
raise NotImplementedError(".filter_queryset() must be overridden.")
def get_fields(self):
def get_fields(self, view):
return []
@ -130,7 +130,7 @@ class DjangoFilterBackend(BaseFilterBackend):
template = loader.get_template(self.template)
return template_render(template, context)
def get_fields(self):
def get_fields(self, view):
filter_class = getattr(view, 'filter_class', None)
if filter_class:
return list(filter_class().filters.keys())
@ -205,7 +205,7 @@ class SearchFilter(BaseFilterBackend):
template = loader.get_template(self.template)
return template_render(template, context)
def get_fields(self):
def get_fields(self, view):
return [self.search_param]
@ -321,7 +321,7 @@ class OrderingFilter(BaseFilterBackend):
context = self.get_template_context(request, queryset, view)
return template_render(template, context)
def get_fields(self):
def get_fields(self, view):
return [self.ordering_param]

View File

@ -157,7 +157,7 @@ class BasePagination(object):
def get_results(self, data):
return data['results']
def get_fields(self):
def get_fields(self, view):
return []
@ -283,7 +283,7 @@ class PageNumberPagination(BasePagination):
context = self.get_html_context()
return template_render(template, context)
def get_fields(self):
def get_fields(self, view):
if self.page_size_query_param is None:
return [self.page_query_param]
return [self.page_query_param, self.page_size_query_param]
@ -412,7 +412,7 @@ class LimitOffsetPagination(BasePagination):
context = self.get_html_context()
return template_render(template, context)
def get_fields(self):
def get_fields(self, view):
return [self.limit_query_param, self.offset_query_param]
@ -719,5 +719,5 @@ class CursorPagination(BasePagination):
context = self.get_html_context()
return template_render(template, context)
def get_fields(self):
def get_fields(self, view):
return [self.cursor_query_param]

View File

@ -231,7 +231,7 @@ class SchemaGenerator(object):
return []
paginator = view.pagination_class()
return paginator.get_fields()
return paginator.get_fields(view)
def get_filter_fields(self, path, method, callback, view):
if method != 'GET':
@ -245,5 +245,5 @@ class SchemaGenerator(object):
fields = []
for filter_backend in view.filter_backends:
fields += filter_backend().get_fields()
fields += filter_backend().get_fields(view)
return fields