diff --git a/.travis.yml b/.travis.yml
index 100a7cd8b..cd87dd339 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,44 +1,51 @@
language: python
python:
+ - "2.7"
+ - "3.4"
- "3.5"
sudo: false
env:
- - TOX_ENV=py27-lint
- - TOX_ENV=py27-docs
- - TOX_ENV=py35-django19
- - TOX_ENV=py34-django19
- - TOX_ENV=py27-django19
- - TOX_ENV=py35-django18
- - TOX_ENV=py34-django18
- - TOX_ENV=py33-django18
- - TOX_ENV=py27-django18
- - TOX_ENV=py27-django110
- - TOX_ENV=py35-django110
- - TOX_ENV=py34-django110
- - TOX_ENV=py27-djangomaster
- - TOX_ENV=py34-djangomaster
- - TOX_ENV=py35-djangomaster
+ - DJANGO=1.8
+ - DJANGO=1.9
+ - DJANGO=1.10
+ - DJANGO=1.11
+ - DJANGO=master
matrix:
fast_finish: true
+ include:
+ - python: "3.6"
+ env: DJANGO=master
+ - python: "3.6"
+ env: DJANGO=1.11
+ - python: "3.3"
+ env: DJANGO=1.8
+ - python: "2.7"
+ env: TOXENV="lint"
+ - python: "2.7"
+ env: TOXENV="docs"
+ exclude:
+ - python: "2.7"
+ env: DJANGO=master
+ - python: "3.4"
+ env: DJANGO=master
+
allow_failures:
- - env: TOX_ENV=py27-djangomaster
- - env: TOX_ENV=py34-djangomaster
- - env: TOX_ENV=py35-djangomaster
+ - env: DJANGO=master
+ - env: DJANGO=1.11
install:
- # Virtualenv < 14 is required to keep the Python 3.2 builds running.
- - pip install tox "virtualenv<14"
+ - pip install tox tox-travis
script:
- - tox -e $TOX_ENV
+ - tox
after_success:
- pip install codecov
- - codecov -e TOX_ENV
+ - codecov -e TOXENV,DJANGO
notifications:
email: false
diff --git a/README.md b/README.md
index 7bf90ba6a..609f99184 100644
--- a/README.md
+++ b/README.md
@@ -24,10 +24,11 @@ The initial aim is to provide a single full-time position on REST framework.
-
+
+
diff --git a/rest_framework/test.py b/rest_framework/test.py index 241f94c91..87255bca0 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -114,7 +114,7 @@ if requests is not None: self.mount('https://', adapter) def request(self, method, url, *args, **kwargs): - if ':' not in url: + if not url.startswith('http'): raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url) return super(RequestsClient, self).request(method, url, *args, **kwargs) diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 4ea55300e..2ce4ba52d 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -1,8 +1,8 @@ from __future__ import unicode_literals -from django.conf.urls import include, url +from django.conf.urls import url -from rest_framework.compat import RegexURLResolver +from rest_framework.compat import RegexURLResolver, include from rest_framework.settings import api_settings diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 8203a7bc8..8896e4f2c 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -55,6 +55,11 @@ class JSONEncoder(json.JSONEncoder): elif hasattr(obj, 'tolist'): # Numpy arrays and array scalars. return obj.tolist() + elif (coreapi is not None) and isinstance(obj, (coreapi.Document, coreapi.Error)): + raise RuntimeError( + 'Cannot return a coreapi object from a JSON view. ' + 'You should be using a schema renderer instead for this view.' + ) elif hasattr(obj, '__getitem__'): try: return dict(obj) @@ -62,9 +67,4 @@ class JSONEncoder(json.JSONEncoder): pass elif hasattr(obj, '__iter__'): return tuple(item for item in obj) - elif (coreapi is not None) and isinstance(obj, (coreapi.Document, coreapi.Error)): - raise RuntimeError( - 'Cannot return a coreapi object from a JSON view. ' - 'You should be using a schema renderer instead for this view.' - ) return super(JSONEncoder, self).default(obj) diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index ca5b33c5e..78cb37e56 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -32,23 +32,18 @@ def dedent(content): unindented text on the initial line. """ content = force_text(content) - whitespace_counts = [ - len(line) - len(line.lstrip(' ')) - for line in content.splitlines()[1:] if line.lstrip() - ] - tab_counts = [ - len(line) - len(line.lstrip('\t')) - for line in content.splitlines()[1:] if line.lstrip() - ] + lines = [line for line in content.splitlines()[1:] if line.lstrip()] # unindent the content if needed - if whitespace_counts: - whitespace_pattern = '^' + (' ' * min(whitespace_counts)) - content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) - elif tab_counts: - whitespace_pattern = '^' + ('\t' * min(whitespace_counts)) - content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) - + if lines: + whitespace_counts = min([len(line) - len(line.lstrip(' ')) for line in lines]) + tab_counts = min([len(line) - len(line.lstrip('\t')) for line in lines]) + if whitespace_counts: + whitespace_pattern = '^' + (' ' * whitespace_counts) + content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) + elif tab_counts: + whitespace_pattern = '^' + ('\t' * tab_counts) + content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) return content.strip() diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py index def4b9641..865c283cc 100644 --- a/rest_framework/utils/mediatypes.py +++ b/rest_framework/utils/mediatypes.py @@ -49,10 +49,8 @@ def order_by_precedence(media_type_lst): @python_2_unicode_compatible class _MediaType(object): def __init__(self, media_type_str): - if media_type_str is None: - media_type_str = '' - self.orig = media_type_str - self.full_type, self.params = parse_header(media_type_str.encode(HTTP_HEADER_ENCODING)) + self.orig = '' if (media_type_str is None) else media_type_str + self.full_type, self.params = parse_header(self.orig.encode(HTTP_HEADER_ENCODING)) self.main_type, sep, self.sub_type = self.full_type.partition('/') def match(self, other): diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py index 3e3e434e6..f8200c98f 100644 --- a/rest_framework/utils/model_meta.py +++ b/rest_framework/utils/model_meta.py @@ -76,7 +76,12 @@ def _get_forward_relationships(opts): Returns an `OrderedDict` of field names to `RelationInfo`. """ forward_relations = OrderedDict() - for field in [field for field in opts.fields if field.serialize and get_remote_field(field)]: + for field in [ + field for field in opts.fields + if field.serialize and get_remote_field(field) and not (field.primary_key and field.one_to_one) + # If the field is a OneToOneField and it's been marked as PK, then this + # is a multi-table inheritance auto created PK ('%_ptr'). + ]: forward_relations[field.name] = RelationInfo( model_field=field, related_model=get_related_model(field), diff --git a/rest_framework/versioning.py b/rest_framework/versioning.py index e5524afe8..5c9a7ade1 100644 --- a/rest_framework/versioning.py +++ b/rest_framework/versioning.py @@ -117,7 +117,7 @@ class NamespaceVersioning(BaseVersioning): def determine_version(self, request, *args, **kwargs): resolver_match = getattr(request, 'resolver_match', None) - if (resolver_match is None or not resolver_match.namespace): + if resolver_match is None or not resolver_match.namespace: return self.default_version # Allow for possibly nested namespaces. diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 151fabad5..631649369 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals import base64 +import pytest from django.conf.urls import include, url from django.contrib.auth.models import User from django.db import models @@ -151,6 +152,18 @@ class BasicAuthTests(TestCase): assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response['WWW-Authenticate'] == 'Basic realm="api"' + def test_fail_post_if_credentials_are_missing(self): + response = self.csrf_client.post( + '/basic/', {'example': 'example'}, HTTP_AUTHORIZATION='Basic ') + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_fail_post_if_credentials_contain_spaces(self): + response = self.csrf_client.post( + '/basic/', {'example': 'example'}, + HTTP_AUTHORIZATION='Basic foo bar' + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + @override_settings(ROOT_URLCONF='tests.test_authentication') class SessionAuthTests(TestCase): @@ -249,6 +262,17 @@ class BaseTokenAuthTests(object): ) assert response.status_code == status.HTTP_200_OK + def test_fail_authentication_if_user_is_not_active(self): + user = User.objects.create_user('foo', 'bar', 'baz') + user.is_active = False + user.save() + self.model.objects.create(key='foobar_token', user=user) + response = self.csrf_client.post( + self.path, {'example': 'example'}, + HTTP_AUTHORIZATION=self.header_prefix + 'foobar_token' + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_fail_post_form_passing_nonexistent_token_auth(self): # use a nonexistent token key auth = self.header_prefix + 'wxyz6789' @@ -257,6 +281,19 @@ class BaseTokenAuthTests(object): ) assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_fail_post_if_token_is_missing(self): + response = self.csrf_client.post( + self.path, {'example': 'example'}, + HTTP_AUTHORIZATION=self.header_prefix) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_fail_post_if_token_contains_spaces(self): + response = self.csrf_client.post( + self.path, {'example': 'example'}, + HTTP_AUTHORIZATION=self.header_prefix + 'foo bar' + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_fail_post_form_passing_invalid_token_auth(self): # add an 'invalid' unicode character auth = self.header_prefix + self.key + "ΒΈ" @@ -461,3 +498,28 @@ class NoAuthenticationClassesTests(TestCase): response = view(request) assert response.status_code == status.HTTP_403_FORBIDDEN assert response.data == {'detail': 'Dummy permission message'} + + +class BasicAuthenticationUnitTests(TestCase): + + def test_base_authentication_abstract_method(self): + with pytest.raises(NotImplementedError): + BaseAuthentication().authenticate({}) + + def test_basic_authentication_raises_error_if_user_not_found(self): + auth = BasicAuthentication() + with pytest.raises(exceptions.AuthenticationFailed): + auth.authenticate_credentials('invalid id', 'invalid password') + + def test_basic_authentication_raises_error_if_user_not_active(self): + from rest_framework import authentication + + class MockUser(object): + is_active = False + old_authenticate = authentication.authenticate + authentication.authenticate = lambda **kwargs: MockUser() + auth = authentication.BasicAuthentication() + with pytest.raises(exceptions.AuthenticationFailed) as error: + auth.authenticate_credentials('foo', 'bar') + assert 'User inactive or deleted.' in str(error) + authentication.authenticate = old_authenticate diff --git a/tests/test_authtoken.py b/tests/test_authtoken.py new file mode 100644 index 000000000..04eeb2f63 --- /dev/null +++ b/tests/test_authtoken.py @@ -0,0 +1,29 @@ +import pytest +from django.contrib.admin import site +from django.contrib.auth.models import User +from django.test import TestCase + +from rest_framework.authtoken.admin import TokenAdmin +from rest_framework.authtoken.models import Token +from rest_framework.authtoken.serializers import AuthTokenSerializer +from rest_framework.exceptions import ValidationError + + +class AuthTokenTests(TestCase): + + def setUp(self): + self.site = site + self.user = User.objects.create_user(username='test_user') + self.token = Token.objects.create(key='test token', user=self.user) + + def test_model_admin_displayed_fields(self): + mock_request = object() + token_admin = TokenAdmin(self.token, self.site) + assert token_admin.get_fields(mock_request) == ('user',) + + def test_token_string_representation(self): + assert str(self.token) == 'test token' + + def test_validate_raise_error_if_no_credentials_provided(self): + with pytest.raises(ValidationError): + AuthTokenSerializer().validate({}) diff --git a/tests/test_bound_fields.py b/tests/test_bound_fields.py index 814ddf9a5..ade739d41 100644 --- a/tests/test_bound_fields.py +++ b/tests/test_bound_fields.py @@ -45,6 +45,15 @@ class TestSimpleBoundField: assert serializer['amount'].errors is None assert serializer['amount'].name == 'amount' + def test_delete_field(self): + class ExampleSerializer(serializers.Serializer): + text = serializers.CharField(max_length=100) + amount = serializers.IntegerField() + + serializer = ExampleSerializer() + del serializer.fields['text'] + assert 'text' not in serializer.fields.keys() + def test_as_form_fields(self): class ExampleSerializer(serializers.Serializer): bool_field = serializers.BooleanField() diff --git a/tests/test_compat.py b/tests/test_compat.py new file mode 100644 index 000000000..4c1a5e94d --- /dev/null +++ b/tests/test_compat.py @@ -0,0 +1,67 @@ +from django.test import TestCase + +from rest_framework import compat + + +class CompatTests(TestCase): + + def setUp(self): + self.original_django_version = compat.django.VERSION + self.original_transaction = compat.transaction + + def tearDown(self): + compat.django.VERSION = self.original_django_version + compat.transaction = self.original_transaction + + def test_total_seconds(self): + class MockTimedelta(object): + days = 1 + seconds = 1 + microseconds = 100 + timedelta = MockTimedelta() + expected = (timedelta.days * 86400.0) + float(timedelta.seconds) + (timedelta.microseconds / 1000000.0) + assert compat.total_seconds(timedelta) == expected + + def test_get_remote_field_with_old_django_version(self): + class MockField(object): + rel = 'example_rel' + compat.django.VERSION = (1, 8) + assert compat.get_remote_field(MockField(), default='default_value') == 'example_rel' + assert compat.get_remote_field(object(), default='default_value') == 'default_value' + + def test_get_remote_field_with_new_django_version(self): + class MockField(object): + remote_field = 'example_remote_field' + compat.django.VERSION = (1, 10) + assert compat.get_remote_field(MockField(), default='default_value') == 'example_remote_field' + assert compat.get_remote_field(object(), default='default_value') == 'default_value' + + def test_set_rollback_for_transaction_in_managed_mode(self): + class MockTransaction(object): + called_rollback = False + called_leave_transaction_management = False + + def is_managed(self): + return True + + def is_dirty(self): + return True + + def rollback(self): + self.called_rollback = True + + def leave_transaction_management(self): + self.called_leave_transaction_management = True + + dirty_mock_transaction = MockTransaction() + compat.transaction = dirty_mock_transaction + compat.set_rollback() + assert dirty_mock_transaction.called_rollback is True + assert dirty_mock_transaction.called_leave_transaction_management is True + + clean_mock_transaction = MockTransaction() + clean_mock_transaction.is_dirty = lambda: False + compat.transaction = clean_mock_transaction + compat.set_rollback() + assert clean_mock_transaction.called_rollback is False + assert clean_mock_transaction.called_leave_transaction_management is True diff --git a/tests/test_description.py b/tests/test_description.py index 08d8bddec..001a3ea21 100644 --- a/tests/test_description.py +++ b/tests/test_description.py @@ -124,4 +124,8 @@ class TestViewNamesAndDescriptions(TestCase): def test_dedent_tabs(): - assert dedent("\tfirst string\n\n\tsecond string") == 'first string\n\n\tsecond string' + result = 'first string\n\nsecond string' + assert dedent(" first string\n\n second string") == result + assert dedent("first string\n\n second string") == result + assert dedent("\tfirst string\n\n\tsecond string") == result + assert dedent("first string\n\n\tsecond string") == result diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 687141476..5c6915d0c 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -2,8 +2,10 @@ from datetime import date, datetime, timedelta, tzinfo from decimal import Decimal from uuid import uuid4 +import pytest from django.test import TestCase +from rest_framework.compat import coreapi from rest_framework.utils.encoders import JSONEncoder @@ -56,7 +58,7 @@ class JSONEncoderTests(TestCase): current_time = datetime.now().time() current_time = current_time.replace(tzinfo=UTC()) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.encoder.default(current_time) def test_encode_date(self): @@ -79,3 +81,13 @@ class JSONEncoderTests(TestCase): """ unique_id = uuid4() assert self.encoder.default(unique_id) == str(unique_id) + + def test_encode_coreapi_raises_error(self): + """ + Tests encoding a coreapi objects raises proper error + """ + with pytest.raises(RuntimeError): + self.encoder.default(coreapi.Document()) + + with pytest.raises(RuntimeError): + self.encoder.default(coreapi.Error()) diff --git a/tests/test_fields.py b/tests/test_fields.py index 069ba879d..16221d4cc 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -8,7 +8,8 @@ from decimal import Decimal import pytest from django.http import QueryDict from django.test import TestCase, override_settings -from django.utils import six, timezone +from django.utils import six +from django.utils.timezone import utc import rest_framework from rest_framework import serializers @@ -1129,13 +1130,13 @@ class TestDateTimeField(FieldValues): Valid and invalid values for `DateTimeField`. """ valid_inputs = { - '2001-01-01 13:00': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), - '2001-01-01T13:00': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), - '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), - datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), - datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), + '2001-01-01 13:00': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc), + '2001-01-01T13:00': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc), + '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc), + datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc), + datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc), # Django 1.4 does not support timezone string parsing. - '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()) + '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc) } invalid_inputs = { 'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'], @@ -1144,13 +1145,13 @@ class TestDateTimeField(FieldValues): } outputs = { datetime.datetime(2001, 1, 1, 13, 00): '2001-01-01T13:00:00', - datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): '2001-01-01T13:00:00Z', + datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc): '2001-01-01T13:00:00Z', '2001-01-01T00:00:00': '2001-01-01T00:00:00', six.text_type('2016-01-10T00:00:00'): '2016-01-10T00:00:00', None: None, '': None, } - field = serializers.DateTimeField(default_timezone=timezone.UTC()) + field = serializers.DateTimeField(default_timezone=utc) class TestCustomInputFormatDateTimeField(FieldValues): @@ -1158,13 +1159,13 @@ class TestCustomInputFormatDateTimeField(FieldValues): Valid and invalid values for `DateTimeField` with a custom input format. """ valid_inputs = { - '1:35pm, 1 Jan 2001': datetime.datetime(2001, 1, 1, 13, 35, tzinfo=timezone.UTC()), + '1:35pm, 1 Jan 2001': datetime.datetime(2001, 1, 1, 13, 35, tzinfo=utc), } invalid_inputs = { '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY.'] } outputs = {} - field = serializers.DateTimeField(default_timezone=timezone.UTC(), input_formats=['%I:%M%p, %d %b %Y']) + field = serializers.DateTimeField(default_timezone=utc, input_formats=['%I:%M%p, %d %b %Y']) class TestCustomOutputFormatDateTimeField(FieldValues): @@ -1196,7 +1197,7 @@ class TestNaiveDateTimeField(FieldValues): Valid and invalid values for `DateTimeField` with naive datetimes. """ valid_inputs = { - datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): datetime.datetime(2001, 1, 1, 13, 00), + datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc): datetime.datetime(2001, 1, 1, 13, 00), '2001-01-01 13:00': datetime.datetime(2001, 1, 1, 13, 00), } invalid_inputs = {} @@ -1667,6 +1668,16 @@ class TestEmptyListField(FieldValues): field = serializers.ListField(child=serializers.IntegerField(), allow_empty=False) +class TestListFieldLengthLimit(FieldValues): + valid_inputs = () + invalid_inputs = [ + ((0, 1), ['Ensure this field has at least 3 elements.']), + ((0, 1, 2, 3, 4, 5), ['Ensure this field has no more than 4 elements.']), + ] + outputs = () + field = serializers.ListField(child=serializers.IntegerField(), min_length=3, max_length=4) + + class TestUnvalidatedListField(FieldValues): """ Values for `ListField` with no `child` argument. diff --git a/tests/test_filters.py b/tests/test_filters.py index 0cc326239..7db5da63f 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -5,6 +5,7 @@ import unittest import warnings from decimal import Decimal +import pytest from django.conf.urls import url from django.core.exceptions import ImproperlyConfigured from django.db import models @@ -119,6 +120,27 @@ if django_filters: ] +class BaseFilterTests(TestCase): + def setUp(self): + self.original_coreapi = filters.coreapi + filters.coreapi = True # mock it, because not None value needed + self.filter_backend = filters.BaseFilterBackend() + + def tearDown(self): + filters.coreapi = self.original_coreapi + + def test_filter_queryset_raises_error(self): + with pytest.raises(NotImplementedError): + self.filter_backend.filter_queryset(None, None, None) + + def test_get_schema_fields_checks_for_coreapi(self): + filters.coreapi = None + with pytest.raises(AssertionError): + self.filter_backend.get_schema_fields({}) + filters.coreapi = True + assert self.filter_backend.get_schema_fields({}) == [] + + class CommonFilteringTestCase(TestCase): def _serialize_object(self, obj): return {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()} @@ -429,6 +451,19 @@ class SearchFilterTests(TestCase): {'id': 2, 'title': 'zz', 'text': 'bcd'} ] + def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self): + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (filters.SearchFilter,) + + view = SearchListView.as_view() + request = factory.get('/') + response = view(request) + expected = SearchFilterSerializer(SearchFilterModel.objects.all(), + many=True).data + assert response.data == expected + def test_exact_search(self): class SearchListView(generics.ListAPIView): queryset = SearchFilterModel.objects.all() diff --git a/tests/test_generics.py b/tests/test_generics.py index c24cda006..59278572e 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -547,3 +547,94 @@ class TestGuardedQueryset(TestCase): request = factory.get('/') with pytest.raises(RuntimeError): view(request).render() + + +class ApiViewsTests(TestCase): + + def test_create_api_view_post(self): + class MockCreateApiView(generics.CreateAPIView): + def create(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) + view = MockCreateApiView() + data = ('test request', ('test arg',), {'test_kwarg': 'test'}) + view.post('test request', 'test arg', test_kwarg='test') + assert view.called is True + assert view.call_args == data + + def test_destroy_api_view_delete(self): + class MockDestroyApiView(generics.DestroyAPIView): + def destroy(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) + view = MockDestroyApiView() + data = ('test request', ('test arg',), {'test_kwarg': 'test'}) + view.delete('test request', 'test arg', test_kwarg='test') + assert view.called is True + assert view.call_args == data + + def test_update_api_view_partial_update(self): + class MockUpdateApiView(generics.UpdateAPIView): + def partial_update(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) + view = MockUpdateApiView() + data = ('test request', ('test arg',), {'test_kwarg': 'test'}) + view.patch('test request', 'test arg', test_kwarg='test') + assert view.called is True + assert view.call_args == data + + def test_retrieve_update_api_view_get(self): + class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): + def retrieve(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) + view = MockRetrieveUpdateApiView() + data = ('test request', ('test arg',), {'test_kwarg': 'test'}) + view.get('test request', 'test arg', test_kwarg='test') + assert view.called is True + assert view.call_args == data + + def test_retrieve_update_api_view_put(self): + class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): + def update(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) + view = MockRetrieveUpdateApiView() + data = ('test request', ('test arg',), {'test_kwarg': 'test'}) + view.put('test request', 'test arg', test_kwarg='test') + assert view.called is True + assert view.call_args == data + + def test_retrieve_update_api_view_patch(self): + class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): + def partial_update(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) + view = MockRetrieveUpdateApiView() + data = ('test request', ('test arg',), {'test_kwarg': 'test'}) + view.patch('test request', 'test arg', test_kwarg='test') + assert view.called is True + assert view.call_args == data + + def test_retrieve_destroy_api_view_get(self): + class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView): + def retrieve(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) + view = MockRetrieveDestroyUApiView() + data = ('test request', ('test arg',), {'test_kwarg': 'test'}) + view.get('test request', 'test arg', test_kwarg='test') + assert view.called is True + assert view.call_args == data + + def test_retrieve_destroy_api_view_delete(self): + class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView): + def destroy(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) + view = MockRetrieveDestroyUApiView() + data = ('test request', ('test arg',), {'test_kwarg': 'test'}) + view.delete('test request', 'test arg', test_kwarg='test') + assert view.called is True + assert view.call_args == data diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py index 3dfe6ef74..c49fc96d4 100644 --- a/tests/test_htmlrenderer.py +++ b/tests/test_htmlrenderer.py @@ -1,8 +1,9 @@ from __future__ import unicode_literals import django.template.loader +import pytest from django.conf.urls import url -from django.core.exceptions import PermissionDenied +from django.core.exceptions import ImproperlyConfigured, PermissionDenied from django.http import Http404 from django.template import Template, TemplateDoesNotExist from django.test import TestCase, override_settings @@ -46,6 +47,12 @@ urlpatterns = [ @override_settings(ROOT_URLCONF='tests.test_htmlrenderer') class TemplateHTMLRendererTests(TestCase): def setUp(self): + class MockResponse(object): + template_name = None + self.mock_response = MockResponse() + self._monkey_patch_get_template() + + def _monkey_patch_get_template(self): """ Monkeypatch get_template """ @@ -87,6 +94,40 @@ class TemplateHTMLRendererTests(TestCase): self.assertEqual(response.content, six.b("403 Forbidden")) self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') + # 2 tests below are based on order of if statements in corresponding method + # of TemplateHTMLRenderer + def test_get_template_names_returns_own_template_name(self): + renderer = TemplateHTMLRenderer() + renderer.template_name = 'test_template' + template_name = renderer.get_template_names(self.mock_response, view={}) + assert template_name == ['test_template'] + + def test_get_template_names_returns_view_template_name(self): + renderer = TemplateHTMLRenderer() + + class MockResponse(object): + template_name = None + + class MockView(object): + def get_template_names(self): + return ['template from get_template_names method'] + + class MockView2(object): + template_name = 'template from template_name attribute' + + template_name = renderer.get_template_names(self.mock_response, + MockView()) + assert template_name == ['template from get_template_names method'] + + template_name = renderer.get_template_names(self.mock_response, + MockView2()) + assert template_name == ['template from template_name attribute'] + + def test_get_template_names_raises_error_if_no_template_found(self): + renderer = TemplateHTMLRenderer() + with pytest.raises(ImproperlyConfigured): + renderer.get_template_names(self.mock_response, view=object()) + @override_settings(ROOT_URLCONF='tests.test_htmlrenderer') class TemplateHTMLRendererExceptionTests(TestCase): diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 7a02c2a3d..a9d2dc0c9 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import pytest from django.core.validators import MaxValueValidator, MinValueValidator from django.db import models from django.test import TestCase @@ -17,6 +18,11 @@ request = Request(APIRequestFactory().options('/')) class TestMetadata: + + def test_determine_metadata_abstract_method_raises_proper_error(self): + with pytest.raises(NotImplementedError): + metadata.BaseMetadata().determine_metadata(None, None) + def test_metadata(self): """ OPTIONS requests to views should return a valid 200 response. @@ -263,12 +269,25 @@ class TestMetadata: view = ExampleView.as_view(versioning_class=scheme) view(request=request) + def test_list_serializer_metadata_returns_info_about_fields_of_child_serializer(self): + class ExampleSerializer(serializers.Serializer): + integer_field = serializers.IntegerField(max_value=10) + char_field = serializers.CharField(required=False) + + class ExampleListSerializer(serializers.ListSerializer): + pass + + options = metadata.SimpleMetadata() + child_serializer = ExampleSerializer() + list_serializer = ExampleListSerializer(child=child_serializer) + assert options.get_serializer_info(list_serializer) == options.get_serializer_info(child_serializer) + class TestSimpleMetadataFieldInfo(TestCase): def test_null_boolean_field_info_type(self): options = metadata.SimpleMetadata() field_info = options.get_field_info(serializers.NullBooleanField()) - self.assertEqual(field_info['type'], 'boolean') + assert field_info['type'] == 'boolean' def test_related_field_choices(self): options = metadata.SimpleMetadata() @@ -277,7 +296,7 @@ class TestSimpleMetadataFieldInfo(TestCase): field_info = options.get_field_info( serializers.RelatedField(queryset=BasicModel.objects.all()) ) - self.assertNotIn('choices', field_info) + assert 'choices' not in field_info class TestModelSerializerMetadata(TestCase): diff --git a/tests/test_negotiation.py b/tests/test_negotiation.py index 6b8fd5060..b435b876a 100644 --- a/tests/test_negotiation.py +++ b/tests/test_negotiation.py @@ -1,11 +1,16 @@ from __future__ import unicode_literals +import pytest +from django.http import Http404 from django.test import TestCase -from rest_framework.negotiation import DefaultContentNegotiation +from rest_framework.negotiation import ( + BaseContentNegotiation, DefaultContentNegotiation +) from rest_framework.renderers import BaseRenderer from rest_framework.request import Request from rest_framework.test import APIRequestFactory +from rest_framework.utils.mediatypes import _MediaType factory = APIRequestFactory() @@ -55,3 +60,46 @@ class TestAcceptedMediaType(TestCase): accepted_renderer, accepted_media_type = self.select_renderer(request) assert accepted_media_type == 'application/openapi+json;version=2.0' assert accepted_renderer.format == 'swagger' + + def test_match_is_false_if_main_types_not_match(self): + mediatype = _MediaType('test_1') + anoter_mediatype = _MediaType('test_2') + assert mediatype.match(anoter_mediatype) is False + + def test_mediatype_match_is_false_if_keys_not_match(self): + mediatype = _MediaType(';test_param=foo') + another_mediatype = _MediaType(';test_param=bar') + assert mediatype.match(another_mediatype) is False + + def test_mediatype_precedence_with_wildcard_subtype(self): + mediatype = _MediaType('test/*') + assert mediatype.precedence == 1 + + def test_mediatype_string_representation(self): + mediatype = _MediaType('test/*; foo=bar') + params_str = '' + for key, val in mediatype.params.items(): + params_str += '; %s=%s' % (key, val) + expected = 'test/*' + params_str + assert str(mediatype) == expected + + def test_raise_error_if_no_suitable_renderers_found(self): + class MockRenderer(object): + format = 'xml' + renderers = [MockRenderer()] + with pytest.raises(Http404): + self.negotiator.filter_renderers(renderers, format='json') + + +class BaseContentNegotiationTests(TestCase): + + def setUp(self): + self.negotiator = BaseContentNegotiation() + + def test_raise_error_for_abstract_select_parser_method(self): + with pytest.raises(NotImplementedError): + self.negotiator.select_parser(None, None) + + def test_raise_error_for_abstract_select_renderer_method(self): + with pytest.raises(NotImplementedError): + self.negotiator.select_renderer(None, None) diff --git a/tests/test_one_to_one_with_inheritance.py b/tests/test_one_to_one_with_inheritance.py index 06e1cd8b8..9c489c1df 100644 --- a/tests/test_one_to_one_with_inheritance.py +++ b/tests/test_one_to_one_with_inheritance.py @@ -14,7 +14,7 @@ from tests.test_multitable_inheritance import ChildModel # Regression test for #4290 class ChildAssociatedModel(RESTFrameworkModel): - child_model = models.OneToOneField(ChildModel) + child_model = models.OneToOneField(ChildModel, on_delete=models.CASCADE) child_name = models.CharField(max_length=100) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 9f2e1c57c..dd7f70330 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -370,6 +370,13 @@ class TestLimitOffset: assert self.pagination.display_page_controls assert isinstance(self.pagination.to_html(), type('')) + def test_pagination_not_applied_if_limit_or_default_limit_not_set(self): + class MockPagination(pagination.LimitOffsetPagination): + default_limit = None + request = Request(factory.get('/')) + queryset = MockPagination().paginate_queryset(self.queryset, request) + assert queryset is None + def test_single_offset(self): """ When the offset is not a multiple of the limit we get some edge cases: diff --git a/tests/test_parsers.py b/tests/test_parsers.py index 5052e2e53..2499bfa3a 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -35,7 +35,7 @@ class TestFormParser(TestCase): stream = StringIO(self.string) data = parser.parse(stream) - self.assertEqual(Form(data).is_valid(), True) + assert Form(data).is_valid() is True class TestFileUploadParser(TestCase): @@ -62,7 +62,7 @@ class TestFileUploadParser(TestCase): self.stream.seek(0) data_and_files = parser.parse(self.stream, None, self.parser_context) file_obj = data_and_files.files['file'] - self.assertEqual(file_obj._size, 14) + assert file_obj._size == 14 def test_parse_missing_filename(self): """ @@ -108,22 +108,22 @@ class TestFileUploadParser(TestCase): def test_get_filename(self): parser = FileUploadParser() filename = parser.get_filename(self.stream, None, self.parser_context) - self.assertEqual(filename, 'file.txt') + assert filename == 'file.txt' def test_get_encoded_filename(self): parser = FileUploadParser() self.__replace_content_disposition('inline; filename*=utf-8\'\'ΓΔ₯Ζ¦.txt') filename = parser.get_filename(self.stream, None, self.parser_context) - self.assertEqual(filename, 'ΓΔ₯Ζ¦.txt') + assert filename == 'ΓΔ₯Ζ¦.txt' self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ΓΔ₯Ζ¦.txt') filename = parser.get_filename(self.stream, None, self.parser_context) - self.assertEqual(filename, 'ΓΔ₯Ζ¦.txt') + assert filename == 'ΓΔ₯Ζ¦.txt' self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ΓΔ₯Ζ¦.txt') filename = parser.get_filename(self.stream, None, self.parser_context) - self.assertEqual(filename, 'ΓΔ₯Ζ¦.txt') + assert filename == 'ΓΔ₯Ζ¦.txt' def __replace_content_disposition(self, disposition): self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py index cc74debb3..887a6f423 100644 --- a/tests/test_relations_hyperlink.py +++ b/tests/test_relations_hyperlink.py @@ -91,7 +91,7 @@ class HyperlinkedManyToManyTests(TestCase): {'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} ] with self.assertNumQueries(4): - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_many_to_many_retrieve(self): queryset = ManyToManySource.objects.all() @@ -102,7 +102,7 @@ class HyperlinkedManyToManyTests(TestCase): {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ] with self.assertNumQueries(4): - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_many_to_many_retrieve_prefetch_related(self): queryset = ManyToManySource.objects.all().prefetch_related('targets') @@ -119,15 +119,15 @@ class HyperlinkedManyToManyTests(TestCase): {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ] with self.assertNumQueries(4): - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_many_to_many_update(self): data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} instance = ManyToManySource.objects.get(pk=1) serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request}) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() serializer.save() - self.assertEqual(serializer.data, data) + assert serializer.data == data # Ensure source 1 is updated, and everything else is as expected queryset = ManyToManySource.objects.all() @@ -137,16 +137,15 @@ class HyperlinkedManyToManyTests(TestCase): {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_reverse_many_to_many_update(self): data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']} instance = ManyToManyTarget.objects.get(pk=1) serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request}) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() serializer.save() - self.assertEqual(serializer.data, data) - + assert serializer.data == data # Ensure target 1 is updated, and everything else is as expected queryset = ManyToManyTarget.objects.all() serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) @@ -156,15 +155,15 @@ class HyperlinkedManyToManyTests(TestCase): {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_many_to_many_create(self): data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} serializer = ManyToManySourceSerializer(data=data, context={'request': request}) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() obj = serializer.save() - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'source-4') + assert serializer.data == data + assert obj.name == 'source-4' # Ensure source 4 is added, and everything else is as expected queryset = ManyToManySource.objects.all() @@ -175,15 +174,15 @@ class HyperlinkedManyToManyTests(TestCase): {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_reverse_many_to_many_create(self): data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} serializer = ManyToManyTargetSerializer(data=data, context={'request': request}) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() obj = serializer.save() - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'target-4') + assert serializer.data == data + assert obj.name == 'target-4' # Ensure target 4 is added, and everything else is as expected queryset = ManyToManyTarget.objects.all() @@ -194,7 +193,7 @@ class HyperlinkedManyToManyTests(TestCase): {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') @@ -217,7 +216,7 @@ class HyperlinkedForeignKeyTests(TestCase): {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ] with self.assertNumQueries(1): - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_reverse_foreign_key_retrieve(self): queryset = ForeignKeyTarget.objects.all() @@ -227,15 +226,15 @@ class HyperlinkedForeignKeyTests(TestCase): {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ] with self.assertNumQueries(3): - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_update(self): data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'} instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() serializer.save() - self.assertEqual(serializer.data, data) + assert serializer.data == data # Ensure source 1 is updated, and everything else is as expected queryset = ForeignKeySource.objects.all() @@ -245,20 +244,20 @@ class HyperlinkedForeignKeyTests(TestCase): {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_update_incorrect_type(self): data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2} instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected URL string, received int.']}) + assert not serializer.is_valid() + assert serializer.errors == {'target': ['Incorrect type. Expected URL string, received int.']} def test_reverse_foreign_key_update(self): data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} instance = ForeignKeyTarget.objects.get(pk=2) serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request}) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() # We shouldn't have saved anything to the db yet since save # hasn't been called. queryset = ForeignKeyTarget.objects.all() @@ -267,10 +266,10 @@ class HyperlinkedForeignKeyTests(TestCase): {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ] - self.assertEqual(new_serializer.data, expected) + assert new_serializer.data == expected serializer.save() - self.assertEqual(serializer.data, data) + assert serializer.data == data # Ensure target 2 is update, and everything else is as expected queryset = ForeignKeyTarget.objects.all() @@ -279,15 +278,15 @@ class HyperlinkedForeignKeyTests(TestCase): {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_create(self): data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'} serializer = ForeignKeySourceSerializer(data=data, context={'request': request}) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() obj = serializer.save() - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'source-4') + assert serializer.data == data + assert obj.name == 'source-4' # Ensure source 1 is updated, and everything else is as expected queryset = ForeignKeySource.objects.all() @@ -298,15 +297,15 @@ class HyperlinkedForeignKeyTests(TestCase): {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_reverse_foreign_key_create(self): data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} serializer = ForeignKeyTargetSerializer(data=data, context={'request': request}) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() obj = serializer.save() - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'target-3') + assert serializer.data == data + assert obj.name == 'target-3' # Ensure target 4 is added, and everything else is as expected queryset = ForeignKeyTarget.objects.all() @@ -316,14 +315,14 @@ class HyperlinkedForeignKeyTests(TestCase): {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_update_with_invalid_null(self): data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None} instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['This field may not be null.']}) + assert not serializer.is_valid() + assert serializer.errors == {'target': ['This field may not be null.']} @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') @@ -345,15 +344,15 @@ class HyperlinkedNullableForeignKeyTests(TestCase): {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_create_with_valid_null(self): data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() obj = serializer.save() - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'source-4') + assert serializer.data == data + assert obj.name == 'source-4' # Ensure source 4 is created, and everything else is as expected queryset = NullableForeignKeySource.objects.all() @@ -364,7 +363,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase): {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_create_with_valid_emptystring(self): """ @@ -374,10 +373,10 @@ class HyperlinkedNullableForeignKeyTests(TestCase): data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''} expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() obj = serializer.save() - self.assertEqual(serializer.data, expected_data) - self.assertEqual(obj.name, 'source-4') + assert serializer.data == expected_data + assert obj.name == 'source-4' # Ensure source 4 is created, and everything else is as expected queryset = NullableForeignKeySource.objects.all() @@ -388,15 +387,15 @@ class HyperlinkedNullableForeignKeyTests(TestCase): {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_update_with_valid_null(self): data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} instance = NullableForeignKeySource.objects.get(pk=1) serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() serializer.save() - self.assertEqual(serializer.data, data) + assert serializer.data == data # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() @@ -406,7 +405,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase): {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_update_with_valid_emptystring(self): """ @@ -417,9 +416,9 @@ class HyperlinkedNullableForeignKeyTests(TestCase): expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} instance = NullableForeignKeySource.objects.get(pk=1) serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() serializer.save() - self.assertEqual(serializer.data, expected_data) + assert serializer.data == expected_data # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() @@ -429,7 +428,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase): {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') @@ -449,4 +448,4 @@ class HyperlinkedNullableOneToOneTests(TestCase): {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected diff --git a/tests/test_relations_slug.py b/tests/test_relations_slug.py index 1e1bdaa62..0b9ca79d3 100644 --- a/tests/test_relations_slug.py +++ b/tests/test_relations_slug.py @@ -61,7 +61,7 @@ class SlugForeignKeyTests(TestCase): {'id': 3, 'name': 'source-3', 'target': 'target-1'} ] with self.assertNumQueries(4): - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_retrieve_select_related(self): queryset = ForeignKeySource.objects.all().select_related('target') @@ -76,7 +76,7 @@ class SlugForeignKeyTests(TestCase): {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 2, 'name': 'target-2', 'sources': []}, ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_reverse_foreign_key_retrieve_prefetch_related(self): queryset = ForeignKeyTarget.objects.all().prefetch_related('sources') @@ -88,9 +88,9 @@ class SlugForeignKeyTests(TestCase): data = {'id': 1, 'name': 'source-1', 'target': 'target-2'} instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() serializer.save() - self.assertEqual(serializer.data, data) + assert serializer.data == data # Ensure source 1 is updated, and everything else is as expected queryset = ForeignKeySource.objects.all() @@ -100,20 +100,20 @@ class SlugForeignKeyTests(TestCase): {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'} ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_update_incorrect_type(self): data = {'id': 1, 'name': 'source-1', 'target': 123} instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']}) + assert not serializer.is_valid() + assert serializer.errors == {'target': ['Object with name=123 does not exist.']} def test_reverse_foreign_key_update(self): data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']} instance = ForeignKeyTarget.objects.get(pk=2) serializer = ForeignKeyTargetSerializer(instance, data=data) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() # We shouldn't have saved anything to the db yet since save # hasn't been called. queryset = ForeignKeyTarget.objects.all() @@ -122,10 +122,10 @@ class SlugForeignKeyTests(TestCase): {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 2, 'name': 'target-2', 'sources': []}, ] - self.assertEqual(new_serializer.data, expected) + assert new_serializer.data == expected serializer.save() - self.assertEqual(serializer.data, data) + assert serializer.data == data # Ensure target 2 is update, and everything else is as expected queryset = ForeignKeyTarget.objects.all() @@ -134,16 +134,16 @@ class SlugForeignKeyTests(TestCase): {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_create(self): data = {'id': 4, 'name': 'source-4', 'target': 'target-2'} serializer = ForeignKeySourceSerializer(data=data) serializer.is_valid() - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() obj = serializer.save() - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'source-4') + assert serializer.data == data + assert obj.name == 'source-4' # Ensure source 4 is added, and everything else is as expected queryset = ForeignKeySource.objects.all() @@ -154,15 +154,15 @@ class SlugForeignKeyTests(TestCase): {'id': 3, 'name': 'source-3', 'target': 'target-1'}, {'id': 4, 'name': 'source-4', 'target': 'target-2'}, ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_reverse_foreign_key_create(self): data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']} serializer = ForeignKeyTargetSerializer(data=data) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() obj = serializer.save() - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'target-3') + assert serializer.data == data + assert obj.name == 'target-3' # Ensure target 3 is added, and everything else is as expected queryset = ForeignKeyTarget.objects.all() @@ -172,14 +172,14 @@ class SlugForeignKeyTests(TestCase): {'id': 2, 'name': 'target-2', 'sources': []}, {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_update_with_invalid_null(self): data = {'id': 1, 'name': 'source-1', 'target': None} instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['This field may not be null.']}) + assert not serializer.is_valid() + assert serializer.errors == {'target': ['This field may not be null.']} class SlugNullableForeignKeyTests(TestCase): @@ -200,15 +200,15 @@ class SlugNullableForeignKeyTests(TestCase): {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_create_with_valid_null(self): data = {'id': 4, 'name': 'source-4', 'target': None} serializer = NullableForeignKeySourceSerializer(data=data) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() obj = serializer.save() - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'source-4') + assert serializer.data == data + assert obj.name == 'source-4' # Ensure source 4 is created, and everything else is as expected queryset = NullableForeignKeySource.objects.all() @@ -219,7 +219,7 @@ class SlugNullableForeignKeyTests(TestCase): {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_create_with_valid_emptystring(self): """ @@ -229,10 +229,10 @@ class SlugNullableForeignKeyTests(TestCase): data = {'id': 4, 'name': 'source-4', 'target': ''} expected_data = {'id': 4, 'name': 'source-4', 'target': None} serializer = NullableForeignKeySourceSerializer(data=data) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() obj = serializer.save() - self.assertEqual(serializer.data, expected_data) - self.assertEqual(obj.name, 'source-4') + assert serializer.data == expected_data + assert obj.name == 'source-4' # Ensure source 4 is created, and everything else is as expected queryset = NullableForeignKeySource.objects.all() @@ -243,15 +243,15 @@ class SlugNullableForeignKeyTests(TestCase): {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_update_with_valid_null(self): data = {'id': 1, 'name': 'source-1', 'target': None} instance = NullableForeignKeySource.objects.get(pk=1) serializer = NullableForeignKeySourceSerializer(instance, data=data) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() serializer.save() - self.assertEqual(serializer.data, data) + assert serializer.data == data # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() @@ -261,7 +261,7 @@ class SlugNullableForeignKeyTests(TestCase): {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None} ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected def test_foreign_key_update_with_valid_emptystring(self): """ @@ -272,9 +272,9 @@ class SlugNullableForeignKeyTests(TestCase): expected_data = {'id': 1, 'name': 'source-1', 'target': None} instance = NullableForeignKeySource.objects.get(pk=1) serializer = NullableForeignKeySourceSerializer(instance, data=data) - self.assertTrue(serializer.is_valid()) + assert serializer.is_valid() serializer.save() - self.assertEqual(serializer.data, expected_data) + assert serializer.data == expected_data # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() @@ -284,4 +284,4 @@ class SlugNullableForeignKeyTests(TestCase): {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None} ] - self.assertEqual(serializer.data, expected) + assert serializer.data == expected diff --git a/tests/test_renderers.py b/tests/test_renderers.py index a2620e93c..eba5b8104 100644 --- a/tests/test_renderers.py +++ b/tests/test_renderers.py @@ -5,9 +5,11 @@ import json import re from collections import MutableMapping, OrderedDict +import pytest from django.conf.urls import include, url from django.core.cache import cache from django.db import models +from django.http.request import HttpRequest from django.test import TestCase, override_settings from django.utils import six from django.utils.safestring import SafeText @@ -15,8 +17,10 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import permissions, serializers, status from rest_framework.renderers import ( - BaseRenderer, BrowsableAPIRenderer, HTMLFormRenderer, JSONRenderer + AdminRenderer, BaseRenderer, BrowsableAPIRenderer, + HTMLFormRenderer, JSONRenderer, StaticHTMLRenderer ) +from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.test import APIRequestFactory @@ -269,6 +273,18 @@ def strip_trailing_whitespace(content): return re.sub(' +\n', '\n', content) +class BaseRendererTests(TestCase): + """ + Tests BaseRenderer + """ + def test_render_raise_error(self): + """ + BaseRenderer.render should raise NotImplementedError + """ + with pytest.raises(NotImplementedError): + BaseRenderer().render('test') + + class JSONRendererTests(TestCase): """ Tests specific to the JSON Renderer @@ -568,3 +584,67 @@ class TestMultipleChoiceFieldHTMLFormRenderer(TestCase): result) self.assertInHTML('', result) self.assertInHTML('', result) + + +class StaticHTMLRendererTests(TestCase): + """ + Tests specific for Static HTML Renderer + """ + def setUp(self): + self.renderer = StaticHTMLRenderer() + + def test_static_renderer(self): + data = '
text' + result = self.renderer.render(data) + assert result == data + + def test_static_renderer_with_exception(self): + context = { + 'response': Response(status=500, exception=True), + 'request': Request(HttpRequest()) + } + result = self.renderer.render({}, renderer_context=context) + assert result == '500 Internal Server Error' + + +class BrowsableAPIRendererTests(TestCase): + + def setUp(self): + self.renderer = BrowsableAPIRenderer() + + def test_get_description_returns_empty_string_for_401_and_403_statuses(self): + assert self.renderer.get_description({}, status_code=401) == '' + assert self.renderer.get_description({}, status_code=403) == '' + + def test_get_filter_form_returns_none_if_data_is_not_list_instance(self): + class DummyView(object): + get_queryset = None + filter_backends = None + + result = self.renderer.get_filter_form(data='not list', + view=DummyView(), request={}) + assert result is None + + +class AdminRendererTests(TestCase): + + def setUp(self): + self.renderer = AdminRenderer() + + def test_render_when_resource_created(self): + class DummyView(APIView): + renderer_classes = (AdminRenderer, ) + request = Request(HttpRequest()) + request.build_absolute_uri = lambda: 'http://example.com' + response = Response(status=201, headers={'Location': '/test'}) + context = { + 'view': DummyView(), + 'request': request, + 'response': response + } + + result = self.renderer.render(data={'test': 'test'}, + renderer_context=context) + assert result == '' + assert response.status_code == status.HTTP_303_SEE_OTHER + assert response['Location'] == 'http://example.com' diff --git a/tests/test_request.py b/tests/test_request.py index 32fbbc50b..428b969f5 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -42,14 +42,14 @@ class TestContentParsing(TestCase): Ensure request.data returns empty QueryDict for GET request. """ request = Request(factory.get('/')) - self.assertEqual(request.data, {}) + assert request.data == {} def test_standard_behaviour_determines_no_content_HEAD(self): """ Ensure request.data returns empty QueryDict for HEAD request. """ request = Request(factory.head('/')) - self.assertEqual(request.data, {}) + assert request.data == {} def test_request_DATA_with_form_content(self): """ @@ -58,7 +58,7 @@ class TestContentParsing(TestCase): data = {'qwerty': 'uiop'} request = Request(factory.post('/', data)) request.parsers = (FormParser(), MultiPartParser()) - self.assertEqual(list(request.data.items()), list(data.items())) + assert list(request.data.items()) == list(data.items()) def test_request_DATA_with_text_content(self): """ @@ -69,7 +69,7 @@ class TestContentParsing(TestCase): content_type = 'text/plain' request = Request(factory.post('/', content, content_type=content_type)) request.parsers = (PlainTextParser(),) - self.assertEqual(request.data, content) + assert request.data == content def test_request_POST_with_form_content(self): """ @@ -78,7 +78,7 @@ class TestContentParsing(TestCase): data = {'qwerty': 'uiop'} request = Request(factory.post('/', data)) request.parsers = (FormParser(), MultiPartParser()) - self.assertEqual(list(request.POST.items()), list(data.items())) + assert list(request.POST.items()) == list(data.items()) def test_request_POST_with_files(self): """ @@ -87,8 +87,8 @@ class TestContentParsing(TestCase): upload = SimpleUploadedFile("file.txt", b"file_content") request = Request(factory.post('/', {'upload': upload})) request.parsers = (FormParser(), MultiPartParser()) - self.assertEqual(list(request.POST.keys()), []) - self.assertEqual(list(request.FILES.keys()), ['upload']) + assert list(request.POST.keys()) == [] + assert list(request.FILES.keys()) == ['upload'] def test_standard_behaviour_determines_form_content_PUT(self): """ @@ -97,7 +97,7 @@ class TestContentParsing(TestCase): data = {'qwerty': 'uiop'} request = Request(factory.put('/', data)) request.parsers = (FormParser(), MultiPartParser()) - self.assertEqual(list(request.data.items()), list(data.items())) + assert list(request.data.items()) == list(data.items()) def test_standard_behaviour_determines_non_form_content_PUT(self): """ @@ -108,7 +108,7 @@ class TestContentParsing(TestCase): content_type = 'text/plain' request = Request(factory.put('/', content, content_type=content_type)) request.parsers = (PlainTextParser(), ) - self.assertEqual(request.data, content) + assert request.data == content class MockView(APIView): @@ -142,10 +142,10 @@ class TestContentParsingWithAuthentication(TestCase): content = {'example': 'example'} response = self.client.post('/', content) - self.assertEqual(status.HTTP_200_OK, response.status_code) + assert status.HTTP_200_OK == response.status_code response = self.csrf_client.post('/', content) - self.assertEqual(status.HTTP_200_OK, response.status_code) + assert status.HTTP_200_OK == response.status_code class TestUserSetter(TestCase): @@ -162,11 +162,11 @@ class TestUserSetter(TestCase): def test_user_can_be_set(self): self.request.user = self.user - self.assertEqual(self.request.user, self.user) + assert self.request.user == self.user def test_user_can_login(self): login(self.request, self.user) - self.assertEqual(self.request.user, self.user) + assert self.request.user == self.user def test_user_can_logout(self): self.request.user = self.user @@ -176,7 +176,7 @@ class TestUserSetter(TestCase): def test_logged_in_user_is_set_on_wrapped_request(self): login(self.request, self.user) - self.assertEqual(self.wrapped_request.user, self.user) + assert self.wrapped_request.user == self.user def test_calling_user_fails_when_attribute_error_is_raised(self): """ @@ -207,15 +207,15 @@ class TestAuthSetter(TestCase): def test_auth_can_be_set(self): request = Request(factory.get('/')) request.auth = 'DUMMY' - self.assertEqual(request.auth, 'DUMMY') + assert request.auth == 'DUMMY' class TestSecure(TestCase): def test_default_secure_false(self): request = Request(factory.get('/', secure=False)) - self.assertEqual(request.scheme, 'http') + assert request.scheme == 'http' def test_default_secure_true(self): request = Request(factory.get('/', secure=True)) - self.assertEqual(request.scheme, 'https') + assert request.scheme == 'https' diff --git a/tests/test_reverse.py b/tests/test_reverse.py index f30a8bf9a..2ca44ab77 100644 --- a/tests/test_reverse.py +++ b/tests/test_reverse.py @@ -38,18 +38,18 @@ class ReverseTests(TestCase): def test_reversed_urls_are_fully_qualified(self): request = factory.get('/view') url = reverse('view', request=request) - self.assertEqual(url, 'http://testserver/view') + assert url == 'http://testserver/view' def test_reverse_with_versioning_scheme(self): request = factory.get('/view') request.versioning_scheme = MockVersioningScheme() url = reverse('view', request=request) - self.assertEqual(url, 'http://scheme-reversed/view') + assert url == 'http://scheme-reversed/view' def test_reverse_with_versioning_scheme_fallback_to_default_on_error(self): request = factory.get('/view') request.versioning_scheme = MockVersioningScheme(raise_error=True) url = reverse('view', request=request) - self.assertEqual(url, 'http://testserver/view') + assert url == 'http://testserver/view' diff --git a/tests/test_routers.py b/tests/test_routers.py index d28e301a0..dc3df2e7b 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -3,12 +3,14 @@ from __future__ import unicode_literals import json from collections import namedtuple -from django.conf.urls import include, url +import pytest +from django.conf.urls import url from django.core.exceptions import ImproperlyConfigured from django.db import models from django.test import TestCase, override_settings from rest_framework import permissions, serializers, viewsets +from rest_framework.compat import include from rest_framework.decorators import detail_route, list_route from rest_framework.response import Response from rest_framework.routers import DefaultRouter, SimpleRouter @@ -80,7 +82,7 @@ empty_prefix_urls = [ urlpatterns = [ url(r'^non-namespaced/', include(namespaced_router.urls)), - url(r'^namespaced/', include(namespaced_router.urls, namespace='example')), + url(r'^namespaced/', include(namespaced_router.urls, namespace='example', app_name='example')), url(r'^example/', include(notes_router.urls)), url(r'^example2/', include(kwarged_notes_router.urls)), @@ -124,8 +126,7 @@ class TestSimpleRouter(TestCase): for i, endpoint in enumerate(['action1', 'action2', 'action3', 'link1', 'link2']): route = decorator_routes[i] # check url listing - self.assertEqual(route.url, - '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint)) + assert route.url == '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint) # check method to function mapping if endpoint == 'action3': methods_map = ['post', 'delete'] @@ -134,28 +135,18 @@ class TestSimpleRouter(TestCase): else: methods_map = ['get'] for method in methods_map: - self.assertEqual(route.mapping[method], endpoint) + assert route.mapping[method] == endpoint @override_settings(ROOT_URLCONF='tests.test_routers') class TestRootView(TestCase): def test_retrieve_namespaced_root(self): response = self.client.get('/namespaced/') - self.assertEqual( - response.data, - { - "example": "http://testserver/namespaced/example/", - } - ) + assert response.data == {"example": "http://testserver/namespaced/example/"} def test_retrieve_non_namespaced_root(self): response = self.client.get('/non-namespaced/') - self.assertEqual( - response.data, - { - "example": "http://testserver/non-namespaced/example/", - } - ) + assert response.data == {"example": "http://testserver/non-namespaced/example/"} @override_settings(ROOT_URLCONF='tests.test_routers') @@ -169,27 +160,15 @@ class TestCustomLookupFields(TestCase): def test_custom_lookup_field_route(self): detail_route = notes_router.urls[-1] detail_url_pattern = detail_route.regex.pattern - self.assertIn('