From 47b534a13e42d498629bf9522225633122c563d5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 7 Nov 2012 21:07:24 +0000 Subject: [PATCH] Make filtering optional, and pluggable. --- docs/api-guide/filtering.md | 114 ++++++++++++++++++ rest_framework/compat.py | 34 ++---- rest_framework/filters.py | 52 ++++++++ rest_framework/generics.py | 33 +---- rest_framework/pagination.py | 19 +-- rest_framework/settings.py | 2 + rest_framework/templatetags/rest_framework.py | 27 ++--- rest_framework/tests/filterset.py | 71 ++++++----- rest_framework/tests/pagination.py | 25 ++-- rest_framework/tests/response.py | 6 - rest_framework/utils/__init__.py | 1 - 11 files changed, 251 insertions(+), 133 deletions(-) create mode 100644 docs/api-guide/filtering.md create mode 100644 rest_framework/filters.py diff --git a/docs/api-guide/filtering.md b/docs/api-guide/filtering.md new file mode 100644 index 000000000..7f6a9c970 --- /dev/null +++ b/docs/api-guide/filtering.md @@ -0,0 +1,114 @@ +# Filtering + +> The root QuerySet provided by the Manager describes all objects in the database table. Usually, though, you'll need to select only a subset of the complete set of objects. +> +> — [Django documentation][cite] + +The default behavior of REST framework's generic list views is to return the entire queryset for a model manager. Often you will want your API to restrict the items that are returned by the queryset. + +The simplest way to filter the queryset of any view that subclasses `MultipleObjectAPIView` is to override the `.get_queryset()` method. + +Overriding this method allows you to customize the queryset returned by the view in a number of different ways. + +## Filtering against the current user + +You might want to filter the queryset to ensure that only results relevant to the currently authenticated user making the request are returned. + +You can do so by filtering based on the value of `request.user`. + +For example: + + class PurchaseList(generics.ListAPIView) + model = Purchase + serializer_class = PurchaseSerializer + + def get_queryset(self): + """ + This view should return a list of all the purchases + for the currently authenticated user. + """ + user = self.request.user + return Purchase.objects.filter(purchaser=user) + + +## Filtering against the URL + +Another style of filtering might involve restricting the queryset based on some part of the URL. + +For example if your URL config contained an entry like this: + + url('^purchases/(?P.+)/$', PurchaseList.as_view()), + +You could then write a view that returned a purchase queryset filtered by the username portion of the URL: + + class PurchaseList(generics.ListAPIView) + model = Purchase + serializer_class = PurchaseSerializer + + def get_queryset(self): + """ + This view should return a list of all the purchases for + the user as determined by the username portion of the URL. + """ + username = self.kwargs['username'] + return Purchase.objects.filter(purchaser__username=username) + +## Filtering against query parameters + +A final example of filtering the initial queryset would be to determine the initial queryset based on query parameters in the url. + +We can override `.get_queryset()` to deal with URLs such as `http://example.com/api/purchases?username=denvercoder9`, and filter the queryset only if the `username` parameter is included in the URL: + + class PurchaseList(generics.ListAPIView) + model = Purchase + serializer_class = PurchaseSerializer + + def get_queryset(self): + """ + Optionally restricts the returned purchases to a given user, + by filtering against a `username` query parameter in the URL. + """ + queryset = Purchase.objects.all() + username = self.request.QUERY_PARAMS.get('username', None): + if username is not None: + queryset = queryset.filter(purchaser__username=username) + return queryset + +# Generic Filtering + +As well as being able to override the default queryset, REST framework also includes support for generic filtering backends that allow you to easily construct complex filters that can be specified by the client using query parameters. + +REST framework supports pluggable backends to implement filtering, and includes a default implementation which uses the [django-filter] package. + +To use REST framework's default filtering backend, first install `django-filter`. + + pip install -e git+https://github.com/alex/django-filter.git#egg=django-filter + +**Note**: The currently supported version of `django-filter` is the `master` branch. A PyPI release is expected to be coming soon. + +## Specifying filter fields + +**TODO**: Document setting `.filter_fields` on the view. + +## Specifying a FilterSet + +**TODO**: Document setting `.filter_class` on the view. + +**TODO**: Note support for `lookup_type`, double underscore relationship spanning, and ordering. + +# Custom generic filtering + +You can also provide your own generic filtering backend, or write an installable app for other developers to use. + +To do so overide `BaseFilterBackend`, and override the `.filter_queryset(self, request, queryset, view)` method. + +To install the filter, set the `'FILTER_BACKEND'` key in your `'REST_FRAMEWORK'` setting, using the dotted import path of the filter backend class. + +For example: + + REST_FRAMEWORK = { + 'FILTER_BACKEND': 'custom_filters.CustomFilterBackend' + } + +[cite]: https://docs.djangoproject.com/en/dev/topics/db/queries/#retrieving-specific-objects-with-filters +[django-filter]: https://github.com/alex/django-filter \ No newline at end of file diff --git a/rest_framework/compat.py b/rest_framework/compat.py index b0367a32c..02e50604e 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -5,6 +5,13 @@ versions of django/python, and compatbility wrappers around optional packages. # flake8: noqa import django +# django-filter is optional +try: + import django_filters +except: + django_filters = None + + # cStringIO only if it's available, otherwise StringIO try: import cStringIO as StringIO @@ -348,33 +355,6 @@ except ImportError: yaml = None -import unittest -try: - import unittest.skip -except ImportError: # python < 2.7 - from unittest import TestCase - import functools - - def skip(reason): - # Pasted from py27/lib/unittest/case.py - """ - Unconditionally skip a test. - """ - def decorator(test_item): - if not (isinstance(test_item, type) and issubclass(test_item, TestCase)): - @functools.wraps(test_item) - def skip_wrapper(*args, **kwargs): - pass - test_item = skip_wrapper - - test_item.__unittest_skip__ = True - test_item.__unittest_skip_why__ = reason - return test_item - return decorator - - unittest.skip = skip - - # xml.etree.parse only throws ParseError for python >= 2.7 try: from xml.etree import ParseError as ETParseError diff --git a/rest_framework/filters.py b/rest_framework/filters.py new file mode 100644 index 000000000..b972e82a1 --- /dev/null +++ b/rest_framework/filters.py @@ -0,0 +1,52 @@ +from rest_framework.compat import django_filters + + +class BaseFilterBackend(object): + """ + A base class from which all filter backend classes should inherit. + """ + + def filter_queryset(self, request, queryset, view): + """ + Return a filtered queryset. + """ + raise NotImplementedError(".filter_queryset() must be overridden.") + + +class DjangoFilterBackend(BaseFilterBackend): + """ + A filter backend that uses django-filter. + """ + + def get_filter_class(self, view): + """ + Return the django-filters `FilterSet` used to filter the queryset. + """ + filter_class = getattr(view, 'filter_class', None) + filter_fields = getattr(view, 'filter_fields', None) + filter_model = getattr(view, 'model', None) + + if filter_class or filter_fields: + assert django_filters, 'django-filter is not installed' + + if filter_class: + assert issubclass(filter_class.Meta.model, filter_model), \ + '%s is not a subclass of %s' % (filter_class.Meta.model, filter_model) + return filter_class + + if filter_fields: + class AutoFilterSet(django_filters.FilterSet): + class Meta: + model = filter_model + fields = filter_fields + return AutoFilterSet + + return None + + def filter_queryset(self, request, queryset, view): + filter_class = self.get_filter_class(view) + + if filter_class: + return filter_class(request.GET, queryset=queryset) + + return queryset diff --git a/rest_framework/generics.py b/rest_framework/generics.py index ac02d3da4..ebd06e452 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -6,7 +6,7 @@ from rest_framework import views, mixins from rest_framework.settings import api_settings from django.views.generic.detail import SingleObjectMixin from django.views.generic.list import MultipleObjectMixin -import django_filters + ### Base classes for the generic views ### @@ -58,34 +58,13 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS paginate_by = api_settings.PAGINATE_BY - filter_class = None - filter_fields = None - - def get_filter_class(self): - """ - Return the django-filters `FilterSet` used to filter the queryset. - """ - if self.filter_class: - return self.filter_class - - if self.filter_fields: - class AutoFilterSet(django_filters.FilterSet): - class Meta: - model = self.model - fields = self.filter_fields - return AutoFilterSet - - return None + filter_backend = api_settings.FILTER_BACKEND def filter_queryset(self, queryset): - filter_class = self.get_filter_class() - - if filter_class: - assert issubclass(filter_class.Meta.model, self.model), \ - "%s is not a subclass of %s" % (filter_class.Meta.model, self.model) - return filter_class(self.request.GET, queryset=queryset) - - return queryset + if not self.filter_backend: + return queryset + backend = self.filter_backend() + return backend.filter_queryset(self.request, queryset, self) def get_filtered_queryset(self): return self.filter_queryset(self.get_queryset()) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index c77a10051..aa54d154a 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -1,4 +1,5 @@ from rest_framework import serializers +from rest_framework.templatetags.rest_framework import replace_query_param # TODO: Support URLconf kwarg-style paging @@ -16,13 +17,8 @@ class NextPageField(PageField): return None page = value.next_page_number() request = self.context.get('request') - relative_url = '?%s=%d' % (self.page_field, page) - if request: - for field, value in request.QUERY_PARAMS.iteritems(): - if field != self.page_field: - relative_url += '&%s=%s' % (field, value) - return request.build_absolute_uri(relative_url) - return relative_url + url = request and request.get_full_path() or '' + return replace_query_param(url, self.page_field, page) class PreviousPageField(PageField): @@ -34,13 +30,8 @@ class PreviousPageField(PageField): return None page = value.previous_page_number() request = self.context.get('request') - relative_url = '?%s=%d' % (self.page_field, page) - if request: - for field, value in request.QUERY_PARAMS.iteritems(): - if field != self.page_field: - relative_url += '&%s=%s' % (field, value) - return request.build_absolute_uri(relative_url) - return relative_url + url = request and request.get_full_path() or '' + return replace_query_param(url, self.page_field, page) class PaginationSerializerOptions(serializers.SerializerOptions): diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 9c40a2144..da647658e 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -55,6 +55,7 @@ DEFAULTS = { 'anon': None, }, 'PAGINATE_BY': None, + 'FILTER_BACKEND': 'rest_framework.filters.DjangoFilterBackend', 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, @@ -79,6 +80,7 @@ IMPORT_STRINGS = ( 'DEFAULT_CONTENT_NEGOTIATION_CLASS', 'DEFAULT_MODEL_SERIALIZER_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS', + 'FILTER_BACKEND', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', ) diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index c9b6eb10d..0672ee4f6 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -1,9 +1,9 @@ from django import template from django.core.urlresolvers import reverse -from django.http import QueryDict from django.utils.encoding import force_unicode from django.utils.html import escape from django.utils.safestring import SafeData, mark_safe +from django.http import QueryDict from urlparse import urlsplit, urlunsplit import re import string @@ -11,6 +11,18 @@ import string register = template.Library() +def replace_query_param(url, key, val): + """ + Given a URL and a key/val pair, set or replace an item in the query + parameters of the URL, and return the new URL. + """ + (scheme, netloc, path, query, fragment) = urlsplit(url) + query_dict = QueryDict(query).copy() + query_dict[key] = val + query = query_dict.urlencode() + return urlunsplit((scheme, netloc, path, query, fragment)) + + # Regex for adding classes to html snippets class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') @@ -31,19 +43,6 @@ hard_coded_bullets_re = re.compile(r'((?:

(?:%s).*?[a-zA-Z].*?

\s*)+)' % '| trailing_empty_content_re = re.compile(r'(?:

(?: |\s|
)*?

\s*)+\Z') -# Helper function for 'add_query_param' -def replace_query_param(url, key, val): - """ - Given a URL and a key/val pair, set or replace an item in the query - parameters of the URL, and return the new URL. - """ - (scheme, netloc, path, query, fragment) = urlsplit(url) - query_dict = QueryDict(query).copy() - query_dict[key] = val - query = query_dict.urlencode() - return urlunsplit((scheme, netloc, path, query, fragment)) - - # And the template tags themselves... @register.simple_tag diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 5374eefc8..6cdea32fe 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -2,44 +2,45 @@ import datetime from decimal import Decimal from django.test import TestCase from django.test.client import RequestFactory +from django.utils import unittest from rest_framework import generics, status +from rest_framework.compat import django_filters from rest_framework.tests.models import FilterableItem, BasicModel -import django_filters factory = RequestFactory() -# Basic filter on a list view. -class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_fields = ['decimal', 'date'] - -# These class are used to test a filter class. -class SeveralFieldsFilter(django_filters.FilterSet): - text = django_filters.CharFilter(lookup_type='icontains') - decimal = django_filters.NumberFilter(lookup_type='lt') - date = django_filters.DateFilter(lookup_type='gt') - class Meta: +if django_filters: + # Basic filter on a list view. + class FilterFieldsRootView(generics.ListCreateAPIView): model = FilterableItem - fields = ['text', 'decimal', 'date'] + filter_fields = ['decimal', 'date'] + # These class are used to test a filter class. + class SeveralFieldsFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + decimal = django_filters.NumberFilter(lookup_type='lt') + date = django_filters.DateFilter(lookup_type='gt') -class FilterClassRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_class = SeveralFieldsFilter + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + class FilterClassRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter -# These classes are used to test a misconfigured filter class. -class MisconfiguredFilter(django_filters.FilterSet): - text = django_filters.CharFilter(lookup_type='icontains') - class Meta: - model = BasicModel - fields = ['text'] + # These classes are used to test a misconfigured filter class. + class MisconfiguredFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + class Meta: + model = BasicModel + fields = ['text'] -class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_class = MisconfiguredFilter + class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = MisconfiguredFilter class IntegrationTestFiltering(TestCase): @@ -64,6 +65,7 @@ class IntegrationTestFiltering(TestCase): for obj in self.objects.all() ] + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_fields_root_view(self): """ GET requests to paginated ListCreateAPIView should return paginated results. @@ -81,7 +83,7 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?decimal=%s' % search_decimal) response = view(request).render() self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [ f for f in self.data if f['decimal'] == search_decimal ] + expected_data = [f for f in self.data if f['decimal'] == search_decimal] self.assertEquals(response.data, expected_data) # Tests that the date filter works. @@ -89,9 +91,10 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' response = view(request).render() self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [ f for f in self.data if f['date'] == search_date ] + expected_data = [f for f in self.data if f['date'] == search_date] self.assertEquals(response.data, expected_data) + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_class_root_view(self): """ GET requests to filtered ListCreateAPIView that have a filter_class set @@ -110,7 +113,7 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?decimal=%s' % search_decimal) response = view(request).render() self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [ f for f in self.data if f['decimal'] < search_decimal ] + expected_data = [f for f in self.data if f['decimal'] < search_decimal] self.assertEquals(response.data, expected_data) # Tests that the date filter set with 'gt' in the filter class works. @@ -118,7 +121,7 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' response = view(request).render() self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [ f for f in self.data if f['date'] > search_date ] + expected_data = [f for f in self.data if f['date'] > search_date] self.assertEquals(response.data, expected_data) # Tests that the text filter set with 'icontains' in the filter class works. @@ -126,7 +129,7 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?text=%s' % search_text) response = view(request).render() self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [ f for f in self.data if search_text in f['text'].lower() ] + expected_data = [f for f in self.data if search_text in f['text'].lower()] self.assertEquals(response.data, expected_data) # Tests that multiple filters works. @@ -135,10 +138,11 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) response = view(request).render() self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [ f for f in self.data if f['date'] > search_date and - f['decimal'] < search_decimal ] + expected_data = [f for f in self.data if f['date'] > search_date and + f['decimal'] < search_decimal] self.assertEquals(response.data, expected_data) + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_incorrectly_configured_filter(self): """ An error should be displayed when the filter class is misconfigured. @@ -148,6 +152,7 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/') self.assertRaises(AssertionError, view, request) + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_unknown_filter(self): """ GET requests with filters that aren't configured should return 200. @@ -157,4 +162,4 @@ class IntegrationTestFiltering(TestCase): search_integer = 10 request = factory.get('/?integer=%s' % search_integer) response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) \ No newline at end of file + self.assertEquals(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 7a2134e01..7f8cd5247 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -3,9 +3,10 @@ from decimal import Decimal from django.core.paginator import Paginator from django.test import TestCase from django.test.client import RequestFactory +from django.utils import unittest from rest_framework import generics, status, pagination +from rest_framework.compat import django_filters from rest_framework.tests.models import BasicModel, FilterableItem -import django_filters factory = RequestFactory() @@ -18,17 +19,18 @@ class RootView(generics.ListCreateAPIView): paginate_by = 10 -class DecimalFilter(django_filters.FilterSet): - decimal = django_filters.NumberFilter(lookup_type='lt') - class Meta: +if django_filters: + class DecimalFilter(django_filters.FilterSet): + decimal = django_filters.NumberFilter(lookup_type='lt') + + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + class FilterFieldsRootView(generics.ListCreateAPIView): model = FilterableItem - fields = ['text', 'decimal', 'date'] - - -class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem - paginate_by = 10 - filter_class = DecimalFilter + paginate_by = 10 + filter_class = DecimalFilter class IntegrationTestPagination(TestCase): @@ -98,6 +100,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): ] self.view = FilterFieldsRootView.as_view() + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_paginated_filtered_root_view(self): """ GET requests to paginated filtered ListCreateAPIView should return diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py index 18b6af394..d7b75450c 100644 --- a/rest_framework/tests/response.py +++ b/rest_framework/tests/response.py @@ -131,12 +131,6 @@ class RendererIntegrationTests(TestCase): self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEquals(resp.status_code, DUMMYSTATUS) - @unittest.skip('can\'t pass because view is a simple Django view and response is an ImmediateResponse') - def test_unsatisfiable_accept_header_on_request_returns_406_status(self): - """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" - resp = self.client.get('/', HTTP_ACCEPT='foo/bar') - self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) - def test_specified_renderer_serializes_content_on_format_query(self): """If a 'format' query is specified, the renderer with the matching format attribute should serialize the response.""" diff --git a/rest_framework/utils/__init__.py b/rest_framework/utils/__init__.py index a59fff453..84fcb5dbb 100644 --- a/rest_framework/utils/__init__.py +++ b/rest_framework/utils/__init__.py @@ -1,7 +1,6 @@ from django.utils.encoding import smart_unicode from django.utils.xmlutils import SimplerXMLGenerator from rest_framework.compat import StringIO - import re import xml.etree.ElementTree as ET