From 15024f3f07b9311887ad6f5afb366f6c032c0486 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Tue, 14 Nov 2017 03:55:59 -0500 Subject: [PATCH 1/8] Remove set_rollback() from compat (#5591) * Remove Django 1.6 transaction compat * Move set_rollback from compat => views --- rest_framework/compat.py | 21 ++----------------- rest_framework/views.py | 9 ++++++-- tests/test_compat.py | 44 ---------------------------------------- 3 files changed, 9 insertions(+), 65 deletions(-) delete mode 100644 tests/test_compat.py diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 75a840ad5..5009ffee1 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -12,7 +12,7 @@ from django.apps import apps from django.conf import settings from django.core import validators from django.core.exceptions import ImproperlyConfigured -from django.db import connection, models, transaction +from django.db import models from django.utils import six from django.views.generic import View @@ -250,7 +250,7 @@ else: # pytz is required from Django 1.11. Remove when dropping Django 1.10 support. try: - import pytz # noqa + import pytz # noqa from pytz.exceptions import InvalidTimeError except ImportError: InvalidTimeError = Exception @@ -297,23 +297,6 @@ class MaxLengthValidator(CustomValidatorMessage, validators.MaxLengthValidator): pass -def set_rollback(): - if hasattr(transaction, 'set_rollback'): - if connection.settings_dict.get('ATOMIC_REQUESTS', False): - # If running in >=1.6 then mark a rollback as required, - # and allow it to be handled by Django. - if connection.in_atomic_block: - transaction.set_rollback(True) - elif transaction.is_managed(): - # Otherwise handle it explicitly if in managed mode. - if transaction.is_dirty(): - transaction.rollback() - transaction.leave_transaction_management() - else: - # transaction not managed - pass - - def authenticate(request=None, **credentials): from django.contrib.auth import authenticate if django.VERSION < (1, 11): diff --git a/rest_framework/views.py b/rest_framework/views.py index dfed15888..3140bb9a3 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -5,7 +5,7 @@ from __future__ import unicode_literals from django.conf import settings from django.core.exceptions import PermissionDenied -from django.db import models +from django.db import connection, models, transaction from django.http import Http404 from django.http.response import HttpResponseBase from django.utils import six @@ -16,7 +16,6 @@ from django.views.decorators.csrf import csrf_exempt from django.views.generic import View from rest_framework import exceptions, status -from rest_framework.compat import set_rollback from rest_framework.request import Request from rest_framework.response import Response from rest_framework.schemas import AutoSchema @@ -55,6 +54,12 @@ def get_view_description(view_cls, html=False): return description +def set_rollback(): + atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False) + if atomic_requests and connection.in_atomic_block: + transaction.set_rollback(True) + + def exception_handler(exc, context): """ Returns the response that should be used for any given exception. diff --git a/tests/test_compat.py b/tests/test_compat.py deleted file mode 100644 index 842cb8ef8..000000000 --- a/tests/test_compat.py +++ /dev/null @@ -1,44 +0,0 @@ -from django.test import TestCase - -from rest_framework import compat - - -class CompatTests(TestCase): - - def setUp(self): - self.original_django_version = compat.django.VERSION - self.original_transaction = compat.transaction - - def tearDown(self): - compat.django.VERSION = self.original_django_version - compat.transaction = self.original_transaction - - def test_set_rollback_for_transaction_in_managed_mode(self): - class MockTransaction(object): - called_rollback = False - called_leave_transaction_management = False - - def is_managed(self): - return True - - def is_dirty(self): - return True - - def rollback(self): - self.called_rollback = True - - def leave_transaction_management(self): - self.called_leave_transaction_management = True - - dirty_mock_transaction = MockTransaction() - compat.transaction = dirty_mock_transaction - compat.set_rollback() - assert dirty_mock_transaction.called_rollback is True - assert dirty_mock_transaction.called_leave_transaction_management is True - - clean_mock_transaction = MockTransaction() - clean_mock_transaction.is_dirty = lambda: False - compat.transaction = clean_mock_transaction - compat.set_rollback() - assert clean_mock_transaction.called_rollback is False - assert clean_mock_transaction.called_leave_transaction_management is True From 9f66e8baddd9d8106a121b159356422086e7d90c Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Wed, 15 Nov 2017 14:58:37 -0500 Subject: [PATCH 2/8] Fix request body/POST access (#5590) * Modernize middleware tests * Added a failing test for #5582 * Set data ref on underlying django request --- rest_framework/request.py | 5 ++-- tests/test_middleware.py | 60 +++++++++++++++++++++++++++++++++------ 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/rest_framework/request.py b/rest_framework/request.py index 4f413e03f..f9503cd59 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -250,9 +250,10 @@ class Request(object): else: self._full_data = self._data - # copy files refs to the underlying request so that closable + # copy data & files refs to the underlying request so that closable # objects are handled appropriately. - self._request._files = self._files + self._request._post = self.POST + self._request._files = self.FILES def _load_stream(self): """ diff --git a/tests/test_middleware.py b/tests/test_middleware.py index a9f620c0e..9df7d8e3e 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,34 +1,76 @@ from django.conf.urls import url from django.contrib.auth.models import User +from django.http import HttpRequest from django.test import override_settings from rest_framework.authentication import TokenAuthentication from rest_framework.authtoken.models import Token +from rest_framework.request import is_form_media_type +from rest_framework.response import Response from rest_framework.test import APITestCase from rest_framework.views import APIView + +class PostView(APIView): + def post(self, request): + return Response(data=request.data, status=200) + + urlpatterns = [ - url(r'^$', APIView.as_view(authentication_classes=(TokenAuthentication,))), + url(r'^auth$', APIView.as_view(authentication_classes=(TokenAuthentication,))), + url(r'^post$', PostView.as_view()), ] -class MyMiddleware(object): +class RequestUserMiddleware(object): + def __init__(self, get_response): + self.get_response = get_response - def process_response(self, request, response): + def __call__(self, request): + response = self.get_response(request) assert hasattr(request, 'user'), '`user` is not set on request' - assert request.user.is_authenticated(), '`user` is not authenticated' + assert request.user.is_authenticated, '`user` is not authenticated' + + return response + + +class RequestPOSTMiddleware(object): + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + assert isinstance(request, HttpRequest) + + # Parse body with underlying Django request + request.body + + # Process request with DRF view + response = self.get_response(request) + + # Ensure request.POST is set as appropriate + if is_form_media_type(request.content_type): + assert request.POST == {'foo': ['bar']} + else: + assert request.POST == {} + return response @override_settings(ROOT_URLCONF='tests.test_middleware') class TestMiddleware(APITestCase): + + @override_settings(MIDDLEWARE=('tests.test_middleware.RequestUserMiddleware',)) def test_middleware_can_access_user_when_processing_response(self): user = User.objects.create_user('john', 'john@example.com', 'password') key = 'abcd1234' Token.objects.create(key=key, user=user) - with self.settings( - MIDDLEWARE_CLASSES=('tests.test_middleware.MyMiddleware',) - ): - auth = 'Token ' + key - self.client.get('/', HTTP_AUTHORIZATION=auth) + self.client.get('/auth', HTTP_AUTHORIZATION='Token %s' % key) + + @override_settings(MIDDLEWARE=('tests.test_middleware.RequestPOSTMiddleware',)) + def test_middleware_can_access_request_post_when_processing_response(self): + response = self.client.post('/post', {'foo': 'bar'}) + assert response.status_code == 200 + + response = self.client.post('/post', {'foo': 'bar'}, format='json') + assert response.status_code == 200 From 25319984274e799ab557dc96a2f6563d4d2c07a7 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Mon, 20 Nov 2017 02:58:29 -0500 Subject: [PATCH 3/8] Rename test to reference correct issue (#5610) --- tests/test_serializer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index df8839356..23c6ec2c1 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -515,7 +515,7 @@ class TestSerializerValidationWithCompiledRegexField: assert serializer.errors == {} -class Test2505Regression: +class Test2555Regression: def test_serializer_context(self): class NestedSerializer(serializers.Serializer): def __init__(self, *args, **kwargs): From 20954469b2938f2e701f11e0398815c557bafa67 Mon Sep 17 00:00:00 2001 From: Alexei Znamensky Date: Mon, 20 Nov 2017 21:07:36 +1300 Subject: [PATCH 4/8] Fix in documentation (#5611) - model serializers now must provide either "fields" or "exclude" as attribute --- docs/api-guide/serializers.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/api-guide/serializers.md b/docs/api-guide/serializers.md index 0e235c88d..021ef1c38 100644 --- a/docs/api-guide/serializers.md +++ b/docs/api-guide/serializers.md @@ -493,6 +493,8 @@ The names in the `fields` and `exclude` attributes will normally map to model fi Alternatively names in the `fields` options can map to properties or methods which take no arguments that exist on the model class. +Since version 3.3.0, it is **mandatory** to provide one of the attributes `fields` or `exclude`. + ## Specifying nested serialization The default `ModelSerializer` uses primary keys for relationships, but you can also easily generate nested representations using the `depth` option: From 9c11077cf63283fcebb939d58202d4841c942eb8 Mon Sep 17 00:00:00 2001 From: bartkim0426 Date: Mon, 20 Nov 2017 17:08:16 +0900 Subject: [PATCH 5/8] Fix in documentation (#5612) - typo in serialization document: 'intead' => 'instead' --- docs/api-guide/serializers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api-guide/serializers.md b/docs/api-guide/serializers.md index 021ef1c38..ee6e41607 100644 --- a/docs/api-guide/serializers.md +++ b/docs/api-guide/serializers.md @@ -1011,7 +1011,7 @@ Takes the object instance that requires serialization, and should return a primi Takes the unvalidated incoming data as input and should return the validated data that will be made available as `serializer.validated_data`. The return value will also be passed to the `.create()` or `.update()` methods if `.save()` is called on the serializer class. -If any of the validation fails, then the method should raise a `serializers.ValidationError(errors)`. The `errors` argument should be a dictionary mapping field names (or `settings.NON_FIELD_ERRORS_KEY`) to a list of error messages. If you don't need to alter deserialization behavior and instead want to provide object-level validation, it's recommended that you intead override the [`.validate()`](#object-level-validation) method. +If any of the validation fails, then the method should raise a `serializers.ValidationError(errors)`. The `errors` argument should be a dictionary mapping field names (or `settings.NON_FIELD_ERRORS_KEY`) to a list of error messages. If you don't need to alter deserialization behavior and instead want to provide object-level validation, it's recommended that you instead override the [`.validate()`](#object-level-validation) method. The `data` argument passed to this method will normally be the value of `request.data`, so the datatype it provides will depend on the parser classes you have configured for your API. From ff556a91fdd62b389b9b0b6c353d2161263346bc Mon Sep 17 00:00:00 2001 From: Jon Dufresne Date: Mon, 20 Nov 2017 00:35:54 -0800 Subject: [PATCH 6/8] Remove references to unsupported Django versions in docs and code (#5602) Per the trove classifiers, DRF only supports Django versions 1.10+. Can drop documentation, code comments, and workarounds for older Django versions. --- docs/api-guide/fields.md | 4 +--- docs/index.md | 4 ++-- docs/tutorial/1-serialization.md | 2 -- docs/tutorial/4-authentication-and-permissions.md | 5 ++--- rest_framework/authtoken/serializers.py | 12 ++++-------- rest_framework/renderers.py | 2 +- rest_framework/urls.py | 5 ++--- rest_framework/utils/model_meta.py | 10 ++-------- tests/test_atomic_requests.py | 5 ++--- tests/test_fields.py | 6 ------ tests/test_filters.py | 3 --- 11 files changed, 16 insertions(+), 42 deletions(-) diff --git a/docs/api-guide/fields.md b/docs/api-guide/fields.md index 64014b56e..d209a945b 100644 --- a/docs/api-guide/fields.md +++ b/docs/api-guide/fields.md @@ -356,8 +356,6 @@ Corresponds to `django.db.models.fields.DurationField` The `validated_data` for these fields will contain a `datetime.timedelta` instance. The representation is a string following this format `'[DD] [HH:[MM:]]ss[.uuuuuu]'`. -**Note:** This field is only available with Django versions >= 1.8. - **Signature:** `DurationField()` --- @@ -681,4 +679,4 @@ The [django-rest-framework-hstore][django-rest-framework-hstore] package provide [django-rest-framework-gis]: https://github.com/djangonauts/django-rest-framework-gis [django-rest-framework-hstore]: https://github.com/djangonauts/django-rest-framework-hstore [django-hstore]: https://github.com/djangonauts/django-hstore -[python-decimal-rounding-modes]: https://docs.python.org/3/library/decimal.html#rounding-modes \ No newline at end of file +[python-decimal-rounding-modes]: https://docs.python.org/3/library/decimal.html#rounding-modes diff --git a/docs/index.md b/docs/index.md index a902ed3af..0e747463b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -120,10 +120,10 @@ If you're intending to use the browsable API you'll probably also want to add RE urlpatterns = [ ... - url(r'^api-auth/', include('rest_framework.urls', namespace='rest_framework')) + url(r'^api-auth/', include('rest_framework.urls')) ] -Note that the URL path can be whatever you want, but you must include `'rest_framework.urls'` with the `'rest_framework'` namespace. You may leave out the namespace in Django 1.9+, and REST framework will set it for you. +Note that the URL path can be whatever you want. ## Example diff --git a/docs/tutorial/1-serialization.md b/docs/tutorial/1-serialization.md index 558797816..a834c8dbb 100644 --- a/docs/tutorial/1-serialization.md +++ b/docs/tutorial/1-serialization.md @@ -48,8 +48,6 @@ We'll need to add our new `snippets` app and the `rest_framework` app to `INSTAL 'snippets.apps.SnippetsConfig', ) -Please note that if you're using Django <1.9, you need to replace `snippets.apps.SnippetsConfig` with `snippets`. - Okay, we're ready to roll. ## Creating a model to work with diff --git a/docs/tutorial/4-authentication-and-permissions.md b/docs/tutorial/4-authentication-and-permissions.md index b43fabfac..72cf64e37 100644 --- a/docs/tutorial/4-authentication-and-permissions.md +++ b/docs/tutorial/4-authentication-and-permissions.md @@ -142,11 +142,10 @@ Add the following import at the top of the file: And, at the end of the file, add a pattern to include the login and logout views for the browsable API. urlpatterns += [ - url(r'^api-auth/', include('rest_framework.urls', - namespace='rest_framework')), + url(r'^api-auth/', include('rest_framework.urls'), ] -The `r'^api-auth/'` part of pattern can actually be whatever URL you want to use. The only restriction is that the included urls must use the `'rest_framework'` namespace. In Django 1.9+, REST framework will set the namespace, so you may leave it out. +The `r'^api-auth/'` part of pattern can actually be whatever URL you want to use. Now if you open up the browser again and refresh the page you'll see a 'Login' link in the top right of the page. If you log in as one of the users you created earlier, you'll be able to create code snippets again. diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index 301b6a0cb..01d2d40b9 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -20,14 +20,10 @@ class AuthTokenSerializer(serializers.Serializer): user = authenticate(request=self.context.get('request'), username=username, password=password) - if user: - # From Django 1.10 onwards the `authenticate` call simply - # returns `None` for is_active=False users. - # (Assuming the default `ModelBackend` authentication backend.) - if not user.is_active: - msg = _('User account is disabled.') - raise serializers.ValidationError(msg, code='authorization') - else: + # The authenticate call simply returns None for is_active=False + # users. (Assuming the default ModelBackend authentication + # backend.) + if not user: msg = _('Unable to log in with provided credentials.') raise serializers.ValidationError(msg, code='authorization') else: diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 3298294ce..bbefb4624 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -666,7 +666,7 @@ class BrowsableAPIRenderer(BaseRenderer): paginator = None csrf_cookie_name = settings.CSRF_COOKIE_NAME - csrf_header_name = getattr(settings, 'CSRF_HEADER_NAME', 'HTTP_X_CSRFToken') # Fallback for Django 1.8 + csrf_header_name = settings.CSRF_HEADER_NAME if csrf_header_name.startswith('HTTP_'): csrf_header_name = csrf_header_name[5:] csrf_header_name = csrf_header_name.replace('_', '-') diff --git a/rest_framework/urls.py b/rest_framework/urls.py index 10cc5def0..80fce5dc4 100644 --- a/rest_framework/urls.py +++ b/rest_framework/urls.py @@ -6,11 +6,10 @@ your API requires authentication: urlpatterns = [ ... - url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')) + url(r'^auth/', include('rest_framework.urls')) ] -In Django versions older than 1.9, the urls must be namespaced as 'rest_framework', -and you should make sure your authentication settings include `SessionAuthentication`. +You should make sure your authentication settings include `SessionAuthentication`. """ from __future__ import unicode_literals diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py index f0ae02bb2..4cc93b8ef 100644 --- a/rest_framework/utils/model_meta.py +++ b/rest_framework/utils/model_meta.py @@ -105,18 +105,13 @@ def _get_reverse_relationships(opts): """ Returns an `OrderedDict` of field names to `RelationInfo`. """ - # Note that we have a hack here to handle internal API differences for - # this internal API across Django 1.7 -> Django 1.8. - # See: https://code.djangoproject.com/ticket/24208 - reverse_relations = OrderedDict() all_related_objects = [r for r in opts.related_objects if not r.field.many_to_many] for relation in all_related_objects: accessor_name = relation.get_accessor_name() - related = getattr(relation, 'related_model', relation.model) reverse_relations[accessor_name] = RelationInfo( model_field=None, - related_model=related, + related_model=relation.related_model, to_many=relation.field.remote_field.multiple, to_field=_get_to_field(relation.field), has_through_model=False, @@ -127,10 +122,9 @@ def _get_reverse_relationships(opts): all_related_many_to_many_objects = [r for r in opts.related_objects if r.field.many_to_many] for relation in all_related_many_to_many_objects: accessor_name = relation.get_accessor_name() - related = getattr(relation, 'related_model', relation.model) reverse_relations[accessor_name] = RelationInfo( model_field=None, - related_model=related, + related_model=relation.related_model, to_many=True, # manytomany do not have to_fields to_field=None, diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index f925ce3d3..697c549de 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -120,13 +120,12 @@ class DBTransactionAPIExceptionTests(TestCase): Transaction is rollbacked by our transaction atomic block. """ request = factory.post('/') - num_queries = (4 if getattr(connection.features, - 'can_release_savepoints', False) else 3) + num_queries = 4 if connection.features.can_release_savepoints else 3 with self.assertNumQueries(num_queries): # 1 - begin savepoint # 2 - insert # 3 - rollback savepoint - # 4 - release savepoint (django>=1.8 only) + # 4 - release savepoint with transaction.atomic(): response = self.view(request) assert transaction.get_rollback() diff --git a/tests/test_fields.py b/tests/test_fields.py index 101d3b26d..fc9ce192a 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -5,7 +5,6 @@ import unittest import uuid from decimal import ROUND_DOWN, ROUND_UP, Decimal -import django import pytest from django.http import QueryDict from django.test import TestCase, override_settings @@ -1197,11 +1196,6 @@ class TestDateTimeField(FieldValues): field = serializers.DateTimeField(default_timezone=utc) -if django.VERSION[:2] <= (1, 8): - # Doesn't raise an error on earlier versions of Django - TestDateTimeField.invalid_inputs.pop('2018-08-16 22:00-24:00') - - class TestCustomInputFormatDateTimeField(FieldValues): """ Valid and invalid values for `DateTimeField` with a custom input format. diff --git a/tests/test_filters.py b/tests/test_filters.py index 970f6bdfc..f9e068fec 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,9 +1,7 @@ from __future__ import unicode_literals import datetime -import unittest -import django import pytest from django.core.exceptions import ImproperlyConfigured from django.db import models @@ -291,7 +289,6 @@ class SearchFilterToManyTests(TestCase): Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1)) Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1)) - @unittest.skipIf(django.VERSION < (1, 9), "Django 1.8 does not support transforms") def test_multiple_filter_conditions(self): class SearchListView(generics.ListAPIView): queryset = Blog.objects.all() From a3df1c119967e04fd57495ebbb4645b02883fc4e Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Mon, 20 Nov 2017 03:51:16 -0500 Subject: [PATCH 7/8] Test Serializer exclude for declared fields (#5599) * Test current behavior of exclude+declared field * Assert declared fields are not present in exclude --- rest_framework/serializers.py | 11 +++++++++++ tests/test_model_serializer.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 994d0273f..0952e190c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1102,6 +1102,17 @@ class ModelSerializer(Serializer): if exclude is not None: # If `Meta.exclude` is included, then remove those fields. for field_name in exclude: + assert field_name not in self._declared_fields, ( + "Cannot both declare the field '{field_name}' and include " + "it in the {serializer_class} 'exclude' option. Remove the " + "field or, if inherited from a parent serializer, disable " + "with `{field_name} = None`." + .format( + field_name=field_name, + serializer_class=self.__class__.__name__ + ) + ) + assert field_name in fields, ( "The field '{field_name}' was included on serializer " "{serializer_class} in the 'exclude' option, but does " diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 203e1fe7f..98586b971 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -900,6 +900,22 @@ class TestSerializerMetaClass(TestCase): "Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer." ) + def test_declared_fields_with_exclude_option(self): + class ExampleSerializer(serializers.ModelSerializer): + text = serializers.CharField() + + class Meta: + model = MetaClassTestModel + exclude = ('text',) + + expected = ( + "Cannot both declare the field 'text' and include it in the " + "ExampleSerializer 'exclude' option. Remove the field or, if " + "inherited from a parent serializer, disable with `text = None`." + ) + with self.assertRaisesMessage(AssertionError, expected): + ExampleSerializer().fields + class Issue2704TestCase(TestCase): def test_queryset_all(self): From 134a6f66f92dc4e6ae390f166c6cead8a8d8c13b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Bielawski?= Date: Wed, 22 Nov 2017 05:11:59 +0000 Subject: [PATCH 8/8] Fixed schema generation for filter backends (#5613) --- rest_framework/schemas/inspectors.py | 2 +- tests/test_schemas.py | 48 ++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index bae4d38ed..80dc49268 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -368,7 +368,7 @@ class AutoSchema(ViewInspector): if hasattr(self.view, 'action'): return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] - return method.lower in ["get", "put", "patch", "delete"] + return method.lower() in ["get", "put", "patch", "delete"] def get_filter_fields(self, path, method): if not self._allows_filters(path, method): diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 1a84dfc89..56692d4f5 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -951,3 +951,51 @@ def test_head_and_options_methods_are_excluded(): assert inspector.should_include_endpoint(path, callback) assert inspector.get_allowed_methods(callback) == ["GET"] + + +class TestAutoSchemaAllowsFilters(object): + class MockAPIView(APIView): + filter_backends = [filters.OrderingFilter] + + def _test(self, method): + view = self.MockAPIView() + fields = view.schema.get_filter_fields('', method) + field_names = [f.name for f in fields] + + return 'ordering' in field_names + + def test_get(self): + assert self._test('get') + + def test_GET(self): + assert self._test('GET') + + def test_put(self): + assert self._test('put') + + def test_PUT(self): + assert self._test('PUT') + + def test_patch(self): + assert self._test('patch') + + def test_PATCH(self): + assert self._test('PATCH') + + def test_delete(self): + assert self._test('delete') + + def test_DELETE(self): + assert self._test('DELETE') + + def test_post(self): + assert not self._test('post') + + def test_POST(self): + assert not self._test('POST') + + def test_foo(self): + assert not self._test('foo') + + def test_FOO(self): + assert not self._test('FOO')