diff --git a/docs/api-guide/filtering.md b/docs/api-guide/filtering.md index ea1e7d23e..ca901b039 100644 --- a/docs/api-guide/filtering.md +++ b/docs/api-guide/filtering.md @@ -82,24 +82,30 @@ We can override `.get_queryset()` to deal with URLs such as `http://example.com/ As well as being able to override the default queryset, REST framework also includes support for generic filtering backends that allow you to easily construct complex filters that can be specified by the client using query parameters. -REST framework supports pluggable backends to implement filtering, and includes a default implementation which uses the [django-filter] package. +REST framework supports pluggable backends to implement filtering, and provides an implementation which uses the [django-filter] package. To use REST framework's default filtering backend, first install `django-filter`. pip install -e git+https://github.com/alex/django-filter.git#egg=django-filter +You must also set the filter backend to `DjangoFilterBackend` in your settings: + + REST_FRAMEWORK = { + 'FILTER_BACKEND': 'rest_framework.filters.DjangoFilterBackend' + } + **Note**: The currently supported version of `django-filter` is the `master` branch. A PyPI release is expected to be coming soon. -## Specifying filter fields - -**TODO**: Document setting `.filter_fields` on the view. - ## Specifying a FilterSet **TODO**: Document setting `.filter_class` on the view. **TODO**: Note support for `lookup_type`, double underscore relationship spanning, and ordering. +## Specifying filter fields + +**TODO**: Document setting `.filter_fields` on the view. + **TODO**: Note that overiding `get_queryset()` can be used together with generic filtering --- diff --git a/rest_framework/filters.py b/rest_framework/filters.py index b972e82a1..14902a69b 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -17,6 +17,10 @@ class DjangoFilterBackend(BaseFilterBackend): """ A filter backend that uses django-filter. """ + default_filter_set = django_filters.FilterSet + + def __init__(self): + assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed' def get_filter_class(self, view): """ @@ -24,20 +28,21 @@ class DjangoFilterBackend(BaseFilterBackend): """ filter_class = getattr(view, 'filter_class', None) filter_fields = getattr(view, 'filter_fields', None) - filter_model = getattr(view, 'model', None) - - if filter_class or filter_fields: - assert django_filters, 'django-filter is not installed' + view_model = getattr(view, 'model', None) if filter_class: - assert issubclass(filter_class.Meta.model, filter_model), \ - '%s is not a subclass of %s' % (filter_class.Meta.model, filter_model) + filter_model = filter_class.Meta.model + + assert issubclass(filter_model, view_model), \ + 'FilterSet model %s does not match view model %s' % \ + (filter_model, view_model) + return filter_class if filter_fields: - class AutoFilterSet(django_filters.FilterSet): + class AutoFilterSet(self.default_filter_set): class Meta: - model = filter_model + model = view_model fields = filter_fields return AutoFilterSet diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index b48f85e4e..dd5d9dc3c 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -107,6 +107,7 @@ import django if django.VERSION < (1, 3): INSTALLED_APPS += ('staticfiles',) + # If we're running on the Jenkins server we want to archive the coverage reports as XML. import os if os.environ.get('HUDSON_URL', None): diff --git a/rest_framework/settings.py b/rest_framework/settings.py index da647658e..906a7cf6c 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -55,7 +55,7 @@ DEFAULTS = { 'anon': None, }, 'PAGINATE_BY': None, - 'FILTER_BACKEND': 'rest_framework.filters.DjangoFilterBackend', + 'FILTER_BACKEND': None, 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, @@ -144,8 +144,15 @@ class APISettings(object): if val and attr in self.import_strings: val = perform_import(val, attr) + self.validate_setting(attr, val) + # Cache the result setattr(self, attr, val) return val + def validate_setting(self, attr, val): + if attr == 'FILTER_BACKEND' and val is not None: + # Make sure we can initilize the class + val() + api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS) diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 6cdea32fe..af2e6c2e7 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -3,7 +3,7 @@ from decimal import Decimal from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest -from rest_framework import generics, status +from rest_framework import generics, status, filters from rest_framework.compat import django_filters from rest_framework.tests.models import FilterableItem, BasicModel @@ -15,6 +15,7 @@ if django_filters: class FilterFieldsRootView(generics.ListCreateAPIView): model = FilterableItem filter_fields = ['decimal', 'date'] + filter_backend = filters.DjangoFilterBackend # These class are used to test a filter class. class SeveralFieldsFilter(django_filters.FilterSet): @@ -29,6 +30,7 @@ if django_filters: class FilterClassRootView(generics.ListCreateAPIView): model = FilterableItem filter_class = SeveralFieldsFilter + filter_backend = filters.DjangoFilterBackend # These classes are used to test a misconfigured filter class. class MisconfiguredFilter(django_filters.FilterSet): @@ -41,6 +43,7 @@ if django_filters: class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): model = FilterableItem filter_class = MisconfiguredFilter + filter_backend = filters.DjangoFilterBackend class IntegrationTestFiltering(TestCase): diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 7f8cd5247..713a7255b 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -4,7 +4,7 @@ from django.core.paginator import Paginator from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest -from rest_framework import generics, status, pagination +from rest_framework import generics, status, pagination, filters from rest_framework.compat import django_filters from rest_framework.tests.models import BasicModel, FilterableItem @@ -31,6 +31,7 @@ if django_filters: model = FilterableItem paginate_by = 10 filter_class = DecimalFilter + filter_backend = filters.DjangoFilterBackend class IntegrationTestPagination(TestCase):