diff --git a/.travis.yml b/.travis.yml index c9d9a1648..100a7cd8b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,7 +14,6 @@ env: - TOX_ENV=py35-django18 - TOX_ENV=py34-django18 - TOX_ENV=py33-django18 - - TOX_ENV=py32-django18 - TOX_ENV=py27-django18 - TOX_ENV=py27-django110 - TOX_ENV=py35-django110 diff --git a/docs/api-guide/relations.md b/docs/api-guide/relations.md index 8695b2c1e..75643a83e 100644 --- a/docs/api-guide/relations.md +++ b/docs/api-guide/relations.md @@ -457,6 +457,8 @@ There are two keyword arguments you can use to control this behavior: - `html_cutoff` - If set this will be the maximum number of choices that will be displayed by a HTML select drop down. Set to `None` to disable any limiting. Defaults to `1000`. - `html_cutoff_text` - If set this will display a textual indicator if the maximum number of items have been cutoff in an HTML select drop down. Defaults to `"More than {count} items…"` +You can also control these globally using the settings `HTML_SELECT_CUTOFF` and `HTML_SELECT_CUTOFF_TEXT`. + In cases where the cutoff is being enforced you may want to instead use a plain input field in the HTML form. You can do so using the `style` keyword argument. For example: assigned_to = serializers.SlugRelatedField( diff --git a/docs/api-guide/reverse.md b/docs/api-guide/reverse.md index 71fb83f9e..35d88e2db 100644 --- a/docs/api-guide/reverse.md +++ b/docs/api-guide/reverse.md @@ -23,7 +23,7 @@ There's no requirement for you to use them, but if you do then the self-describi **Signature:** `reverse(viewname, *args, **kwargs)` -Has the same behavior as [`django.core.urlresolvers.reverse`][reverse], except that it returns a fully qualified URL, using the request to determine the host and port. +Has the same behavior as [`django.urls.reverse`][reverse], except that it returns a fully qualified URL, using the request to determine the host and port. You should **include the request as a keyword argument** to the function, for example: @@ -44,7 +44,7 @@ You should **include the request as a keyword argument** to the function, for ex **Signature:** `reverse_lazy(viewname, *args, **kwargs)` -Has the same behavior as [`django.core.urlresolvers.reverse_lazy`][reverse-lazy], except that it returns a fully qualified URL, using the request to determine the host and port. +Has the same behavior as [`django.urls.reverse_lazy`][reverse-lazy], except that it returns a fully qualified URL, using the request to determine the host and port. As with the `reverse` function, you should **include the request as a keyword argument** to the function, for example: diff --git a/docs/api-guide/settings.md b/docs/api-guide/settings.md index ea018053f..67d317839 100644 --- a/docs/api-guide/settings.md +++ b/docs/api-guide/settings.md @@ -382,6 +382,22 @@ This should be a function with the following signature: Default: `'rest_framework.views.get_view_description'` +## HTML Select Field cutoffs + +Global settings for [select field cutoffs for rendering relational fields](relations.md#select-field-cutoffs) in the browsable API. + +#### HTML_SELECT_CUTOFF + +Global setting for the `html_cutoff` value. Must be an integer. + +Default: 1000 + +#### HTML_SELECT_CUTOFF_TEXT + +A string representing a global setting for `html_cutoff_text`. + +Default: `"More than {count} items..."` + --- ## Miscellaneous settings diff --git a/docs/api-guide/testing.md b/docs/api-guide/testing.md index 69da7d105..18f9e19e9 100644 --- a/docs/api-guide/testing.md +++ b/docs/api-guide/testing.md @@ -197,7 +197,7 @@ REST framework includes the following test case classes, that mirror the existin You can use any of REST framework's test case classes as you would for the regular Django test case classes. The `self.client` attribute will be an `APIClient` instance. - from django.core.urlresolvers import reverse + from django.urls import reverse from rest_framework import status from rest_framework.test import APITestCase from myproject.apps.core.models import Account diff --git a/requirements/requirements-optionals.txt b/requirements/requirements-optionals.txt index 20436e6b4..31f24f4b7 100644 --- a/requirements/requirements-optionals.txt +++ b/requirements/requirements-optionals.txt @@ -1,5 +1,5 @@ # Optional packages which may be used with REST framework. markdown==2.6.4 -django-guardian==1.4.3 -django-filter==0.13.0 -coreapi==1.32.0 +django-guardian==1.4.6 +django-filter==0.14.0 +coreapi==2.0.8 diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 457af6c88..68e96703f 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -8,7 +8,7 @@ ______ _____ _____ _____ __ """ __title__ = 'Django REST framework' -__version__ = '3.4.7' +__version__ = '3.5.0' __author__ = 'Tom Christie' __license__ = 'BSD 2-Clause' __copyright__ = 'Copyright 2011-2016 Tom Christie' diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index df0c48b86..90d3bd96e 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -16,6 +16,9 @@ class AuthTokenSerializer(serializers.Serializer): user = authenticate(username=username, password=password) if user: + # From Django 1.10 onwards the `authenticate` call simply + # returns `None` for is_active=False users. + # (Assuming the default `ModelBackend` authentication backend.) if not user.is_active: msg = _('User account is disabled.') raise serializers.ValidationError(msg) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index cee430a84..7ec39ba63 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -23,6 +23,16 @@ except ImportError: from django.utils import importlib # Will be removed in Django 1.9 +try: + from django.urls import ( + NoReverseMatch, RegexURLPattern, RegexURLResolver, ResolverMatch, Resolver404, get_script_prefix, reverse, reverse_lazy, resolve + ) +except ImportError: + from django.core.urlresolvers import ( # Will be removed in Django 2.0 + NoReverseMatch, RegexURLPattern, RegexURLResolver, ResolverMatch, Resolver404, get_script_prefix, reverse, reverse_lazy, resolve + ) + + try: import urlparse # Python 2.x except ImportError: @@ -128,6 +138,12 @@ def is_authenticated(user): return user.is_authenticated +def is_anonymous(user): + if django.VERSION < (1, 10): + return user.is_anonymous() + return user.is_anonymous + + def get_related_model(field): if django.VERSION < (1, 9): return _resolve_model(field.rel.to) @@ -178,6 +194,13 @@ except (ImportError, SyntaxError): uritemplate = None +# requests is optional +try: + import requests +except ImportError: + requests = None + + # Django-guardian is optional. Import only if guardian is in INSTALLED_APPS # Fixes (#1712). We keep the try/except for the test suite. guardian = None @@ -200,8 +223,13 @@ try: if markdown.version <= '2.2': HEADERID_EXT_PATH = 'headerid' - else: + LEVEL_PARAM = 'level' + elif markdown.version < '2.6': HEADERID_EXT_PATH = 'markdown.extensions.headerid' + LEVEL_PARAM = 'level' + else: + HEADERID_EXT_PATH = 'markdown.extensions.toc' + LEVEL_PARAM = 'baselevel' def apply_markdown(text): """ @@ -211,7 +239,7 @@ try: extensions = [HEADERID_EXT_PATH] extension_configs = { HEADERID_EXT_PATH: { - 'level': '2' + LEVEL_PARAM: '2' } } md = markdown.Markdown( @@ -277,3 +305,11 @@ def template_render(template, context=None, request=None): # backends template, e.g. django.template.backends.django.Template else: return template.render(context, request=request) + + +def set_many(instance, field, value): + if django.VERSION < (1, 10): + setattr(instance, field, value) + else: + field = getattr(instance, field) + field.set(value) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 917a151e5..7f8391b8a 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -49,20 +49,34 @@ class empty: pass -def is_simple_callable(obj): - """ - True if the object is a callable that takes no arguments. - """ - function = inspect.isfunction(obj) - method = inspect.ismethod(obj) +if six.PY3: + def is_simple_callable(obj): + """ + True if the object is a callable that takes no arguments. + """ + if not callable(obj): + return False - if not (function or method): - return False + sig = inspect.signature(obj) + params = sig.parameters.values() + return all(param.default != param.empty for param in params) - args, _, _, defaults = inspect.getargspec(obj) - len_args = len(args) if function else len(args) - 1 - len_defaults = len(defaults) if defaults else 0 - return len_args <= len_defaults +else: + def is_simple_callable(obj): + function = inspect.isfunction(obj) + method = inspect.ismethod(obj) + + if not (function or method): + return False + + if method: + is_unbound = obj.im_self is None + + args, _, _, defaults = inspect.getargspec(obj) + + len_args = len(args) if function or is_unbound else len(args) - 1 + len_defaults = len(defaults) if defaults else 0 + return len_args <= len_defaults def get_attribute(instance, attrs): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 65c4c0318..c6eb92a24 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -4,9 +4,6 @@ from __future__ import unicode_literals from collections import OrderedDict from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist -from django.core.urlresolvers import ( - NoReverseMatch, Resolver404, get_script_prefix, resolve -) from django.db.models import Manager from django.db.models.query import QuerySet from django.utils import six @@ -14,10 +11,14 @@ from django.utils.encoding import python_2_unicode_compatible, smart_text from django.utils.six.moves.urllib import parse as urlparse from django.utils.translation import ugettext_lazy as _ +from rest_framework.compat import ( + NoReverseMatch, Resolver404, get_script_prefix, resolve +) from rest_framework.fields import ( Field, empty, get_attribute, is_simple_callable, iter_options ) from rest_framework.reverse import reverse +from rest_framework.settings import api_settings from rest_framework.utils import html @@ -71,14 +72,19 @@ MANY_RELATION_KWARGS = ( class RelatedField(Field): queryset = None - html_cutoff = 1000 - html_cutoff_text = _('More than {count} items...') + html_cutoff = None + html_cutoff_text = None def __init__(self, **kwargs): self.queryset = kwargs.pop('queryset', self.queryset) - self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff) - self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text) - + self.html_cutoff = kwargs.pop( + 'html_cutoff', + self.html_cutoff or int(api_settings.HTML_SELECT_CUTOFF) + ) + self.html_cutoff_text = kwargs.pop( + 'html_cutoff_text', + self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT) + ) if not method_overridden('get_queryset', RelatedField, self): assert self.queryset is not None or kwargs.get('read_only', None), ( 'Relational field must provide a `queryset` argument, ' @@ -447,15 +453,20 @@ class ManyRelatedField(Field): 'not_a_list': _('Expected a list of items but got type "{input_type}".'), 'empty': _('This list may not be empty.') } - html_cutoff = 1000 - html_cutoff_text = _('More than {count} items...') + html_cutoff = None + html_cutoff_text = None def __init__(self, child_relation=None, *args, **kwargs): self.child_relation = child_relation self.allow_empty = kwargs.pop('allow_empty', True) - self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff) - self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text) - + self.html_cutoff = kwargs.pop( + 'html_cutoff', + self.html_cutoff or int(api_settings.HTML_SELECT_CUTOFF) + ) + self.html_cutoff_text = kwargs.pop( + 'html_cutoff_text', + self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT) + ) assert child_relation is not None, '`child_relation` is a required argument.' super(ManyRelatedField, self).__init__(*args, **kwargs) self.child_relation.bind(field_name='', parent=self) diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py index 5a7ba09a8..fd418dcca 100644 --- a/rest_framework/reverse.py +++ b/rest_framework/reverse.py @@ -3,11 +3,11 @@ Provide urlresolver functions that return fully qualified URLs or view names """ from __future__ import unicode_literals -from django.core.urlresolvers import reverse as django_reverse -from django.core.urlresolvers import NoReverseMatch from django.utils import six from django.utils.functional import lazy +from rest_framework.compat import reverse as django_reverse +from rest_framework.compat import NoReverseMatch from rest_framework.settings import api_settings from rest_framework.utils.urls import replace_query_param @@ -54,7 +54,7 @@ def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra): """ - Same as `django.core.urlresolvers.reverse`, but optionally takes a request + Same as `django.urls.reverse`, but optionally takes a request and returns a fully qualified URL, using the request to get the base URL. """ if format is not None: diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 4eec70bda..a02a0f1bf 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -20,9 +20,9 @@ from collections import OrderedDict, namedtuple from django.conf.urls import url from django.core.exceptions import ImproperlyConfigured -from django.core.urlresolvers import NoReverseMatch from rest_framework import exceptions, renderers, views +from rest_framework.compat import NoReverseMatch from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.schemas import SchemaGenerator @@ -83,6 +83,7 @@ class BaseRouter(object): class SimpleRouter(BaseRouter): + routes = [ # List route. Route( @@ -258,6 +259,13 @@ class SimpleRouter(BaseRouter): trailing_slash=self.trailing_slash ) + # If there is no prefix, the first part of the url is probably + # controlled by project's urls.py and the router is in an app, + # so a slash in the beginning will (A) cause Django to give + # warnings and (B) generate URLS that will require using '//'. + if not prefix and regex[:2] == '^/': + regex = '^' + regex[2:] + view = viewset.as_view(mapping, **route.initkwargs) name = route.name.format(basename=basename) ret.append(url(regex, view, name=name)) @@ -289,42 +297,42 @@ class DefaultRouter(SimpleRouter): self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES) super(DefaultRouter, self).__init__(*args, **kwargs) + def get_schema_root_view(self, api_urls=None): + """ + Return a schema root view. + """ + schema_renderers = self.schema_renderers + schema_generator = SchemaGenerator( + title=self.schema_title, + url=self.schema_url, + patterns=api_urls + ) + + class APISchemaView(views.APIView): + _ignore_model_permissions = True + renderer_classes = schema_renderers + + def get(self, request, *args, **kwargs): + schema = schema_generator.get_schema(request) + if schema is None: + raise exceptions.PermissionDenied() + return Response(schema) + + return APISchemaView.as_view() + def get_api_root_view(self, api_urls=None): """ - Return a view to use as the API root. + Return a basic root view. """ api_root_dict = OrderedDict() list_name = self.routes[0].name for prefix, viewset, basename in self.registry: api_root_dict[prefix] = list_name.format(basename=basename) - view_renderers = list(self.root_renderers) - schema_media_types = [] - - if api_urls and self.schema_title: - view_renderers += list(self.schema_renderers) - schema_generator = SchemaGenerator( - title=self.schema_title, - url=self.schema_url, - patterns=api_urls - ) - schema_media_types = [ - renderer.media_type - for renderer in self.schema_renderers - ] - - class APIRoot(views.APIView): + class APIRootView(views.APIView): _ignore_model_permissions = True - renderer_classes = view_renderers def get(self, request, *args, **kwargs): - if request.accepted_renderer.media_type in schema_media_types: - # Return a schema response. - schema = schema_generator.get_schema(request) - if schema is None: - raise exceptions.PermissionDenied() - return Response(schema) - # Return a plain {"name": "hyperlink"} response. ret = OrderedDict() namespace = request.resolver_match.namespace @@ -345,7 +353,7 @@ class DefaultRouter(SimpleRouter): return Response(ret) - return APIRoot.as_view() + return APIRootView.as_view() def get_urls(self): """ @@ -355,7 +363,10 @@ class DefaultRouter(SimpleRouter): urls = super(DefaultRouter, self).get_urls() if self.include_root_view: - view = self.get_api_root_view(api_urls=urls) + if self.schema_title: + view = self.get_schema_root_view(api_urls=urls) + else: + view = self.get_api_root_view(api_urls=urls) root_url = url(r'^$', view, name=self.root_view_name) urls.append(root_url) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index fdd8a1b1e..0c3ba530f 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -2,12 +2,13 @@ from importlib import import_module from django.conf import settings from django.contrib.admindocs.views import simplify_regex -from django.core.urlresolvers import RegexURLPattern, RegexURLResolver from django.utils import six from django.utils.encoding import force_text from rest_framework import exceptions, serializers -from rest_framework.compat import coreapi, uritemplate, urlparse +from rest_framework.compat import ( + RegexURLPattern, RegexURLResolver, coreapi, uritemplate, urlparse +) from rest_framework.request import clone_request from rest_framework.views import APIView @@ -329,7 +330,7 @@ class SchemaGenerator(object): fields += as_query_fields(filter_backend().get_fields(view)) return fields - # Methods for generating the keys which are used to layout each link. + # Methods for generating the link layout.... default_mapping = { 'get': 'read', diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4d1ed63ae..28f70bd40 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -23,7 +23,7 @@ from django.utils.functional import cached_property from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import JSONField as ModelJSONField -from rest_framework.compat import postgres_fields, unicode_to_repr +from rest_framework.compat import postgres_fields, set_many, unicode_to_repr from rest_framework.utils import model_meta from rest_framework.utils.field_mapping import ( ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs, @@ -892,19 +892,23 @@ class ModelSerializer(Serializer): # Save many-to-many relationships after the instance is created. if many_to_many: for field_name, value in many_to_many.items(): - setattr(instance, field_name, value) + set_many(instance, field_name, value) return instance def update(self, instance, validated_data): raise_errors_on_nested_writes('update', self, validated_data) + info = model_meta.get_field_info(instance) # Simply set each attribute on the instance, and then save it. # Note that unlike `.create()` we don't need to treat many-to-many # relationships as being a special case. During updates we already # have an instance pk for the relationships to be associated with. for attr, value in validated_data.items(): - setattr(instance, attr, value) + if attr in info.relations and info.relations[attr].to_many: + set_many(instance, attr, value) + else: + setattr(instance, attr, value) instance.save() return instance diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 68c7709e8..89e27e743 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -111,6 +111,10 @@ DEFAULTS = { 'COMPACT_JSON': True, 'COERCE_DECIMAL_TO_STRING': True, 'UPLOADED_FILES_USE_URL': True, + + # Browseable API + 'HTML_SELECT_CUTOFF': 1000, + 'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...", } diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 3bb85e472..c1c8a5396 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -3,14 +3,13 @@ from __future__ import absolute_import, unicode_literals import re from django import template -from django.core.urlresolvers import NoReverseMatch, reverse from django.template import loader from django.utils import six from django.utils.encoding import force_text, iri_to_uri from django.utils.html import escape, format_html, smart_urlquote from django.utils.safestring import SafeData, mark_safe -from rest_framework.compat import template_render +from rest_framework.compat import NoReverseMatch, reverse, template_render from rest_framework.renderers import HTMLFormRenderer from rest_framework.utils.urls import replace_query_param diff --git a/rest_framework/test.py b/rest_framework/test.py index fd9f6ab13..1b3ad80c2 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -4,7 +4,10 @@ # to make it harder for the user to import the wrong thing without realizing. from __future__ import unicode_literals +import io + from django.conf import settings +from django.core.handlers.wsgi import WSGIHandler from django.test import testcases from django.test.client import Client as DjangoClient from django.test.client import RequestFactory as DjangoRequestFactory @@ -13,6 +16,7 @@ from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode +from rest_framework.compat import coreapi, requests from rest_framework.settings import api_settings @@ -21,6 +25,118 @@ def force_authenticate(request, user=None, token=None): request._force_auth_token = token +if requests is not None: + class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): + def get_all(self, key, default): + return self.getheaders(key) + + class MockOriginalResponse(object): + def __init__(self, headers): + self.msg = HeaderDict(headers) + self.closed = False + + def isclosed(self): + return self.closed + + def close(self): + self.closed = True + + class DjangoTestAdapter(requests.adapters.HTTPAdapter): + """ + A transport adapter for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests over the network. + """ + def __init__(self): + self.app = WSGIHandler() + self.factory = DjangoRequestFactory() + + def get_environ(self, request): + """ + Given a `requests.PreparedRequest` instance, return a WSGI environ dict. + """ + method = request.method + url = request.url + kwargs = {} + + # Set request content, if any exists. + if request.body is not None: + if hasattr(request.body, 'read'): + kwargs['data'] = request.body.read() + else: + kwargs['data'] = request.body + if 'content-type' in request.headers: + kwargs['content_type'] = request.headers['content-type'] + + # Set request headers. + for key, value in request.headers.items(): + key = key.upper() + if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): + continue + kwargs['HTTP_%s' % key.replace('-', '_')] = value + + return self.factory.generic(method, url, **kwargs).environ + + def send(self, request, *args, **kwargs): + """ + Make an outgoing request to the Django WSGI application. + """ + raw_kwargs = {} + + def start_response(wsgi_status, wsgi_headers): + status, _, reason = wsgi_status.partition(' ') + raw_kwargs['status'] = int(status) + raw_kwargs['reason'] = reason + raw_kwargs['headers'] = wsgi_headers + raw_kwargs['version'] = 11 + raw_kwargs['preload_content'] = False + raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers) + + # Make the outgoing request via WSGI. + environ = self.get_environ(request) + wsgi_response = self.app(environ, start_response) + + # Build the underlying urllib3.HTTPResponse + raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response)) + raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs) + + # Build the requests.Response + return self.build_response(request, raw) + + def close(self): + pass + + class DjangoTestSession(requests.Session): + def __init__(self, *args, **kwargs): + super(DjangoTestSession, self).__init__(*args, **kwargs) + + adapter = DjangoTestAdapter() + hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] + + for hostname in hostnames: + if hostname == '*': + hostname = '' + self.mount('http://%s' % hostname, adapter) + self.mount('https://%s' % hostname, adapter) + + def request(self, method, url, *args, **kwargs): + if ':' not in url: + url = 'http://testserver/' + url.lstrip('/') + return super(DjangoTestSession, self).request(method, url, *args, **kwargs) + + +def get_requests_client(): + assert requests is not None, 'requests must be installed' + return DjangoTestSession() + + +def get_api_client(): + assert coreapi is not None, 'coreapi must be installed' + session = get_requests_client() + return coreapi.Client(transports=[ + coreapi.transports.HTTPTransport(session=session) + ]) + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 7a02bb0f0..4ea55300e 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -1,8 +1,8 @@ from __future__ import unicode_literals from django.conf.urls import include, url -from django.core.urlresolvers import RegexURLResolver +from rest_framework.compat import RegexURLResolver from rest_framework.settings import api_settings diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 2e3ab9084..74f4f7840 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -1,6 +1,6 @@ from __future__ import unicode_literals -from django.core.urlresolvers import get_script_prefix, resolve +from rest_framework.compat import get_script_prefix, resolve def get_breadcrumbs(url, request=None): diff --git a/tests/browsable_api/test_form_rendering.py b/tests/browsable_api/test_form_rendering.py index 5a31ae0dd..8b79ab6ff 100644 --- a/tests/browsable_api/test_form_rendering.py +++ b/tests/browsable_api/test_form_rendering.py @@ -11,6 +11,7 @@ factory = APIRequestFactory() class BasicSerializer(serializers.ModelSerializer): class Meta: model = BasicModel + fields = '__all__' class ManyPostView(generics.GenericAPIView): diff --git a/tests/conftest.py b/tests/conftest.py index a5123b9d8..256678226 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,13 @@ def pytest_configure(): from django.conf import settings + MIDDLEWARE = ( + 'django.middleware.common.CommonMiddleware', + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', + ) + settings.configure( DEBUG_PROPAGATE_EXCEPTIONS=True, DATABASES={ @@ -21,12 +28,8 @@ def pytest_configure(): 'APP_DIRS': True, }, ], - MIDDLEWARE_CLASSES=( - 'django.middleware.common.CommonMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - ), + MIDDLEWARE=MIDDLEWARE, + MIDDLEWARE_CLASSES=MIDDLEWARE, INSTALLED_APPS=( 'django.contrib.auth', 'django.contrib.contenttypes', diff --git a/tests/test_api_client.py b/tests/test_api_client.py new file mode 100644 index 000000000..9daf3f3fe --- /dev/null +++ b/tests/test_api_client.py @@ -0,0 +1,452 @@ +from __future__ import unicode_literals + +import os +import tempfile +import unittest + +from django.conf.urls import url +from django.http import HttpResponse +from django.test import override_settings + +from rest_framework.compat import coreapi +from rest_framework.parsers import FileUploadParser +from rest_framework.renderers import CoreJSONRenderer +from rest_framework.response import Response +from rest_framework.test import APITestCase, get_api_client +from rest_framework.views import APIView + + +def get_schema(): + return coreapi.Document( + url='https://api.example.com/', + title='Example API', + content={ + 'simple_link': coreapi.Link('/example/', description='example link'), + 'location': { + 'query': coreapi.Link('/example/', fields=[ + coreapi.Field(name='example', description='example field') + ]), + 'form': coreapi.Link('/example/', action='post', fields=[ + coreapi.Field(name='example'), + ]), + 'body': coreapi.Link('/example/', action='post', fields=[ + coreapi.Field(name='example', location='body') + ]), + 'path': coreapi.Link('/example/{id}', fields=[ + coreapi.Field(name='id', location='path') + ]) + }, + 'encoding': { + 'multipart': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[ + coreapi.Field(name='example') + ]), + 'multipart-body': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[ + coreapi.Field(name='example', location='body') + ]), + 'urlencoded': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[ + coreapi.Field(name='example') + ]), + 'urlencoded-body': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[ + coreapi.Field(name='example', location='body') + ]), + 'raw_upload': coreapi.Link('/upload/', action='post', encoding='application/octet-stream', fields=[ + coreapi.Field(name='example', location='body') + ]), + }, + 'response': { + 'download': coreapi.Link('/download/'), + 'text': coreapi.Link('/text/') + } + } + ) + + +def _iterlists(querydict): + if hasattr(querydict, 'iterlists'): + return querydict.iterlists() + return querydict.lists() + + +def _get_query_params(request): + # Return query params in a plain dict, using a list value if more + # than one item is present for a given key. + return { + key: (value[0] if len(value) == 1 else value) + for key, value in + _iterlists(request.query_params) + } + + +def _get_data(request): + if not isinstance(request.data, dict): + return request.data + # Coerce multidict into regular dict, and remove files to + # make assertions simpler. + if hasattr(request.data, 'iterlists') or hasattr(request.data, 'lists'): + # Use a list value if a QueryDict contains multiple items for a key. + return { + key: value[0] if len(value) == 1 else value + for key, value in _iterlists(request.data) + if key not in request.FILES + } + return { + key: value + for key, value in request.data.items() + if key not in request.FILES + } + + +def _get_files(request): + if not request.FILES: + return {} + return { + key: {'name': value.name, 'content': value.read()} + for key, value in request.FILES.items() + } + + +class SchemaView(APIView): + renderer_classes = [CoreJSONRenderer] + + def get(self, request): + schema = get_schema() + return Response(schema) + + +class ListView(APIView): + def get(self, request): + return Response({ + 'method': request.method, + 'query_params': _get_query_params(request) + }) + + def post(self, request): + if request.content_type: + content_type = request.content_type.split(';')[0] + else: + content_type = None + + return Response({ + 'method': request.method, + 'query_params': _get_query_params(request), + 'data': _get_data(request), + 'files': _get_files(request), + 'content_type': content_type + }) + + +class DetailView(APIView): + def get(self, request, id): + return Response({ + 'id': id, + 'method': request.method, + 'query_params': _get_query_params(request) + }) + + +class UploadView(APIView): + parser_classes = [FileUploadParser] + + def post(self, request): + return Response({ + 'method': request.method, + 'files': _get_files(request), + 'content_type': request.content_type + }) + + +class DownloadView(APIView): + def get(self, request): + return HttpResponse('some file content', content_type='image/png') + + +class TextView(APIView): + def get(self, request): + return HttpResponse('123', content_type='text/plain') + + +urlpatterns = [ + url(r'^$', SchemaView.as_view()), + url(r'^example/$', ListView.as_view()), + url(r'^example/(?P[0-9]+)/$', DetailView.as_view()), + url(r'^upload/$', UploadView.as_view()), + url(r'^download/$', DownloadView.as_view()), + url(r'^text/$', TextView.as_view()), +] + + +@unittest.skipUnless(coreapi, 'coreapi not installed') +@override_settings(ROOT_URLCONF='tests.test_api_client') +class APIClientTests(APITestCase): + def test_api_client(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + assert schema.title == 'Example API' + assert schema.url == 'https://api.example.com/' + assert schema['simple_link'].description == 'example link' + assert schema['location']['query'].fields[0].description == 'example field' + data = client.action(schema, ['simple_link']) + expected = { + 'method': 'GET', + 'query_params': {} + } + assert data == expected + + def test_query_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'query'], params={'example': 123}) + expected = { + 'method': 'GET', + 'query_params': {'example': '123'} + } + assert data == expected + + def test_query_params_with_multiple_values(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'query'], params={'example': [1, 2, 3]}) + expected = { + 'method': 'GET', + 'query_params': {'example': ['1', '2', '3']} + } + assert data == expected + + def test_form_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'form'], params={'example': 123}) + expected = { + 'method': 'POST', + 'content_type': 'application/json', + 'query_params': {}, + 'data': {'example': 123}, + 'files': {} + } + assert data == expected + + def test_body_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'body'], params={'example': 123}) + expected = { + 'method': 'POST', + 'content_type': 'application/json', + 'query_params': {}, + 'data': 123, + 'files': {} + } + assert data == expected + + def test_path_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'path'], params={'id': 123}) + expected = { + 'method': 'GET', + 'query_params': {}, + 'id': '123' + } + assert data == expected + + def test_multipart_encoding(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + temp = tempfile.NamedTemporaryFile() + temp.write(b'example file content') + temp.flush() + + with open(temp.name, 'rb') as upload: + name = os.path.basename(upload.name) + data = client.action(schema, ['encoding', 'multipart'], params={'example': upload}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {}, + 'files': {'example': {'name': name, 'content': 'example file content'}} + } + assert data == expected + + def test_multipart_encoding_no_file(self): + # When no file is included, multipart encoding should still be used. + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['encoding', 'multipart'], params={'example': 123}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {'example': '123'}, + 'files': {} + } + assert data == expected + + def test_multipart_encoding_multiple_values(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['encoding', 'multipart'], params={'example': [1, 2, 3]}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {'example': ['1', '2', '3']}, + 'files': {} + } + assert data == expected + + def test_multipart_encoding_string_file_content(self): + # Test for `coreapi.utils.File` support. + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = File(name='example.txt', content='123') + data = client.action(schema, ['encoding', 'multipart'], params={'example': example}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {}, + 'files': {'example': {'name': 'example.txt', 'content': '123'}} + } + assert data == expected + + def test_multipart_encoding_in_body(self): + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = {'foo': File(name='example.txt', content='123'), 'bar': 'abc'} + data = client.action(schema, ['encoding', 'multipart-body'], params={'example': example}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {'bar': 'abc'}, + 'files': {'foo': {'name': 'example.txt', 'content': '123'}} + } + assert data == expected + + # URLencoded + + def test_urlencoded_encoding(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['encoding', 'urlencoded'], params={'example': 123}) + expected = { + 'method': 'POST', + 'content_type': 'application/x-www-form-urlencoded', + 'query_params': {}, + 'data': {'example': '123'}, + 'files': {} + } + assert data == expected + + def test_urlencoded_encoding_multiple_values(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['encoding', 'urlencoded'], params={'example': [1, 2, 3]}) + expected = { + 'method': 'POST', + 'content_type': 'application/x-www-form-urlencoded', + 'query_params': {}, + 'data': {'example': ['1', '2', '3']}, + 'files': {} + } + assert data == expected + + def test_urlencoded_encoding_in_body(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['encoding', 'urlencoded-body'], params={'example': {'foo': 123, 'bar': True}}) + expected = { + 'method': 'POST', + 'content_type': 'application/x-www-form-urlencoded', + 'query_params': {}, + 'data': {'foo': '123', 'bar': 'true'}, + 'files': {} + } + assert data == expected + + # Raw uploads + + def test_raw_upload(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + temp = tempfile.NamedTemporaryFile() + temp.write(b'example file content') + temp.flush() + + with open(temp.name, 'rb') as upload: + name = os.path.basename(upload.name) + data = client.action(schema, ['encoding', 'raw_upload'], params={'example': upload}) + + expected = { + 'method': 'POST', + 'files': {'file': {'name': name, 'content': 'example file content'}}, + 'content_type': 'application/octet-stream' + } + assert data == expected + + def test_raw_upload_string_file_content(self): + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = File('example.txt', '123') + data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example}) + + expected = { + 'method': 'POST', + 'files': {'file': {'name': 'example.txt', 'content': '123'}}, + 'content_type': 'text/plain' + } + assert data == expected + + def test_raw_upload_explicit_content_type(self): + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = File('example.txt', '123', 'text/html') + data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example}) + + expected = { + 'method': 'POST', + 'files': {'file': {'name': 'example.txt', 'content': '123'}}, + 'content_type': 'text/html' + } + assert data == expected + + # Responses + + def test_text_response(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['response', 'text']) + + expected = '123' + assert data == expected + + def test_download_response(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['response', 'download']) + assert data.basename == 'download.png' + assert data.read() == b'some file content' diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index 8342ad3af..09d7f2fb1 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -5,7 +5,7 @@ import unittest from django.conf.urls import url from django.db import connection, connections, transaction from django.http import Http404 -from django.test import TestCase, TransactionTestCase +from django.test import TestCase, TransactionTestCase, override_settings from django.utils.decorators import method_decorator from rest_framework import status @@ -36,6 +36,20 @@ class APIExceptionView(APIView): raise APIException +class NonAtomicAPIExceptionView(APIView): + @method_decorator(transaction.non_atomic_requests) + def dispatch(self, *args, **kwargs): + return super(NonAtomicAPIExceptionView, self).dispatch(*args, **kwargs) + + def get(self, request, *args, **kwargs): + BasicModel.objects.all() + raise Http404 + +urlpatterns = ( + url(r'^$', NonAtomicAPIExceptionView.as_view()), +) + + @unittest.skipUnless( connection.features.uses_savepoints, "'atomic' requires transactions and savepoints." @@ -124,22 +138,8 @@ class DBTransactionAPIExceptionTests(TestCase): connection.features.uses_savepoints, "'atomic' requires transactions and savepoints." ) +@override_settings(ROOT_URLCONF='tests.test_atomic_requests') class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase): - @property - def urls(self): - class NonAtomicAPIExceptionView(APIView): - @method_decorator(transaction.non_atomic_requests) - def dispatch(self, *args, **kwargs): - return super(NonAtomicAPIExceptionView, self).dispatch(*args, **kwargs) - - def get(self, request, *args, **kwargs): - BasicModel.objects.all() - raise Http404 - - return ( - url(r'^$', NonAtomicAPIExceptionView.as_view()), - ) - def setUp(self): connections.databases['default']['ATOMIC_REQUESTS'] = True diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 5ef620abe..6f17ea14f 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -20,6 +20,7 @@ from rest_framework.authentication import ( ) from rest_framework.authtoken.models import Token from rest_framework.authtoken.views import obtain_auth_token +from rest_framework.compat import is_authenticated from rest_framework.response import Response from rest_framework.test import APIClient, APIRequestFactory from rest_framework.views import APIView @@ -408,7 +409,7 @@ class FailingAuthAccessedInRenderer(TestCase): def render(self, data, media_type=None, renderer_context=None): request = renderer_context['request'] - if request.user.is_authenticated(): + if is_authenticated(request.user): return b'authenticated' return b'not authenticated' diff --git a/tests/test_fields.py b/tests/test_fields.py index 4a4b741c5..c271afa9e 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1,6 +1,7 @@ import datetime import os import re +import unittest import uuid from decimal import Decimal @@ -11,6 +12,67 @@ from django.utils import six, timezone import rest_framework from rest_framework import serializers +from rest_framework.fields import is_simple_callable + +try: + import typings +except ImportError: + typings = False + + +# Tests for helper functions. +# --------------------------- + +class TestIsSimpleCallable: + + def test_method(self): + class Foo: + @classmethod + def classmethod(cls): + pass + + def valid(self): + pass + + def valid_kwargs(self, param='value'): + pass + + def invalid(self, param): + pass + + assert is_simple_callable(Foo.classmethod) + + # unbound methods + assert not is_simple_callable(Foo.valid) + assert not is_simple_callable(Foo.valid_kwargs) + assert not is_simple_callable(Foo.invalid) + + # bound methods + assert is_simple_callable(Foo().valid) + assert is_simple_callable(Foo().valid_kwargs) + assert not is_simple_callable(Foo().invalid) + + def test_function(self): + def simple(): + pass + + def valid(param='value', param2='value'): + pass + + def invalid(param, param2='value'): + pass + + assert is_simple_callable(simple) + assert is_simple_callable(valid) + assert not is_simple_callable(invalid) + + @unittest.skipUnless(typings, 'requires python 3.5') + def test_type_annotation(self): + # The annotation will otherwise raise a syntax error in python < 3.5 + exec("def valid(param: str='value'): pass", locals()) + valid = locals()['valid'] + + assert is_simple_callable(valid) # Tests for field keyword arguments and core functionality. diff --git a/tests/test_filters.py b/tests/test_filters.py index 03d61fc37..c67412dd7 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -6,7 +6,6 @@ from decimal import Decimal from django.conf.urls import url from django.core.exceptions import ImproperlyConfigured -from django.core.urlresolvers import reverse from django.db import models from django.test import TestCase from django.test.utils import override_settings @@ -14,7 +13,7 @@ from django.utils.dateparse import parse_date from django.utils.six.moves import reload_module from rest_framework import filters, generics, serializers, status -from rest_framework.compat import django_filters +from rest_framework.compat import django_filters, reverse from rest_framework.test import APIRequestFactory from .models import BaseFilterableItem, BasicModel, FilterableItem @@ -77,6 +76,7 @@ if django_filters: class Meta: model = BaseFilterableItem + fields = '__all__' class BaseFilterableItemFilterRootView(generics.ListCreateAPIView): queryset = FilterableItem.objects.all() @@ -456,7 +456,7 @@ class AttributeModel(models.Model): class SearchFilterModelFk(models.Model): title = models.CharField(max_length=20) - attribute = models.ForeignKey(AttributeModel) + attribute = models.ForeignKey(AttributeModel, on_delete=models.CASCADE) class SearchFilterFkSerializer(serializers.ModelSerializer): diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 01243ff6e..cd9b2dfc3 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -20,7 +20,7 @@ from django.test import TestCase from django.utils import six from rest_framework import serializers -from rest_framework.compat import unicode_repr +from rest_framework.compat import set_many, unicode_repr def dedent(blocktext): @@ -651,7 +651,7 @@ class TestIntegration(TestCase): foreign_key=self.foreign_key_target, one_to_one=self.one_to_one_target, ) - self.instance.many_to_many = self.many_to_many_targets + set_many(self.instance, 'many_to_many', self.many_to_many_targets) self.instance.save() def test_pk_retrival(self): @@ -962,7 +962,7 @@ class OneToOneTargetTestModel(models.Model): class OneToOneSourceTestModel(models.Model): - target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True) + target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True, on_delete=models.CASCADE) class TestModelFieldValues(TestCase): @@ -990,6 +990,7 @@ class TestUniquenessOverride(TestCase): class TestSerializer(serializers.ModelSerializer): class Meta: model = TestModel + fields = '__all__' extra_kwargs = {'field_1': {'required': False}} fields = TestSerializer().fields diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 5cef22628..f8561e61d 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -4,7 +4,6 @@ import base64 import unittest from django.contrib.auth.models import Group, Permission, User -from django.core.urlresolvers import ResolverMatch from django.db import models from django.test import TestCase @@ -12,7 +11,7 @@ from rest_framework import ( HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers, status ) -from rest_framework.compat import guardian +from rest_framework.compat import ResolverMatch, guardian, set_many from rest_framework.filters import DjangoObjectPermissionsFilter from rest_framework.routers import DefaultRouter from rest_framework.test import APIRequestFactory @@ -74,15 +73,15 @@ class ModelPermissionsIntegrationTests(TestCase): def setUp(self): User.objects.create_user('disallowed', 'disallowed@example.com', 'password') user = User.objects.create_user('permitted', 'permitted@example.com', 'password') - user.user_permissions = [ + set_many(user, 'user_permissions', [ Permission.objects.get(codename='add_basicmodel'), Permission.objects.get(codename='change_basicmodel'), Permission.objects.get(codename='delete_basicmodel') - ] + ]) user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') - user.user_permissions = [ + set_many(user, 'user_permissions', [ Permission.objects.get(codename='change_basicmodel'), - ] + ]) self.permitted_credentials = basic_auth_header('permitted', 'password') self.disallowed_credentials = basic_auth_header('disallowed', 'password') diff --git a/tests/test_request.py b/tests/test_request.py index dbfa695fd..32fbbc50b 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -13,6 +13,7 @@ from django.utils import six from rest_framework import status from rest_framework.authentication import SessionAuthentication +from rest_framework.compat import is_anonymous from rest_framework.parsers import BaseParser, FormParser, MultiPartParser from rest_framework.request import Request from rest_framework.response import Response @@ -169,9 +170,9 @@ class TestUserSetter(TestCase): def test_user_can_logout(self): self.request.user = self.user - self.assertFalse(self.request.user.is_anonymous()) + self.assertFalse(is_anonymous(self.request.user)) logout(self.request) - self.assertTrue(self.request.user.is_anonymous()) + self.assertTrue(is_anonymous(self.request.user)) def test_logged_in_user_is_set_on_wrapped_request(self): login(self.request, self.user) diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py new file mode 100644 index 000000000..37bde1092 --- /dev/null +++ b/tests/test_requests_client.py @@ -0,0 +1,247 @@ +from __future__ import unicode_literals + +import unittest + +from django.conf.urls import url +from django.contrib.auth import authenticate, login +from django.contrib.auth.models import User +from django.shortcuts import redirect +from django.test import override_settings +from django.utils.decorators import method_decorator +from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie + +from rest_framework.compat import is_authenticated, requests +from rest_framework.response import Response +from rest_framework.test import APITestCase, get_requests_client +from rest_framework.views import APIView + + +class Root(APIView): + def get(self, request): + return Response({ + 'method': request.method, + 'query_params': request.query_params, + }) + + def post(self, request): + files = { + key: (value.name, value.read()) + for key, value in request.FILES.items() + } + post = request.POST + json = None + if request.META.get('CONTENT_TYPE') == 'application/json': + json = request.data + + return Response({ + 'method': request.method, + 'query_params': request.query_params, + 'POST': post, + 'FILES': files, + 'JSON': json + }) + + +class HeadersView(APIView): + def get(self, request): + headers = { + key[5:].replace('_', '-'): value + for key, value in request.META.items() + if key.startswith('HTTP_') + } + return Response({ + 'method': request.method, + 'headers': headers + }) + + +class SessionView(APIView): + def get(self, request): + return Response({ + key: value for key, value in request.session.items() + }) + + def post(self, request): + for key, value in request.data.items(): + request.session[key] = value + return Response({ + key: value for key, value in request.session.items() + }) + + +class AuthView(APIView): + @method_decorator(ensure_csrf_cookie) + def get(self, request): + if is_authenticated(request.user): + username = request.user.username + else: + username = None + return Response({ + 'username': username + }) + + @method_decorator(csrf_protect) + def post(self, request): + username = request.data['username'] + password = request.data['password'] + user = authenticate(username=username, password=password) + if user is None: + return Response({'error': 'incorrect credentials'}) + login(request, user) + return redirect('/auth/') + + +urlpatterns = [ + url(r'^$', Root.as_view()), + url(r'^headers/$', HeadersView.as_view()), + url(r'^session/$', SessionView.as_view()), + url(r'^auth/$', AuthView.as_view()), +] + + +@unittest.skipUnless(requests, 'requests not installed') +@override_settings(ROOT_URLCONF='tests.test_requests_client') +class RequestsClientTests(APITestCase): + def test_get_request(self): + client = get_requests_client() + response = client.get('/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {} + } + assert response.json() == expected + + def test_get_request_query_params_in_url(self): + client = get_requests_client() + response = client.get('/?key=value') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {'key': 'value'} + } + assert response.json() == expected + + def test_get_request_query_params_by_kwarg(self): + client = get_requests_client() + response = client.get('/', params={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {'key': 'value'} + } + assert response.json() == expected + + def test_get_with_headers(self): + client = get_requests_client() + response = client.get('/headers/', headers={'User-Agent': 'example'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + headers = response.json()['headers'] + assert headers['USER-AGENT'] == 'example' + + def test_post_form_request(self): + client = get_requests_client() + response = client.post('/', data={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'POST': {'key': 'value'}, + 'FILES': {}, + 'JSON': None + } + assert response.json() == expected + + def test_post_json_request(self): + client = get_requests_client() + response = client.post('/', json={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'POST': {}, + 'FILES': {}, + 'JSON': {'key': 'value'} + } + assert response.json() == expected + + def test_post_multipart_request(self): + client = get_requests_client() + files = { + 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n') + } + response = client.post('/', files=files) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']}, + 'POST': {}, + 'JSON': None + } + assert response.json() == expected + + def test_session(self): + client = get_requests_client() + response = client.get('/session/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {} + assert response.json() == expected + + response = client.post('/session/', json={'example': 'abc'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {'example': 'abc'} + assert response.json() == expected + + response = client.get('/session/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {'example': 'abc'} + assert response.json() == expected + + def test_auth(self): + # Confirm session is not authenticated + client = get_requests_client() + response = client.get('/auth/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'username': None + } + assert response.json() == expected + assert 'csrftoken' in response.cookies + csrftoken = response.cookies['csrftoken'] + + user = User.objects.create(username='tom') + user.set_password('password') + user.save() + + # Perform a login + response = client.post('/auth/', json={ + 'username': 'tom', + 'password': 'password' + }, headers={'X-CSRFToken': csrftoken}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'username': 'tom' + } + assert response.json() == expected + + # Confirm session is authenticated + response = client.get('/auth/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'username': 'tom' + } + assert response.json() == expected diff --git a/tests/test_reverse.py b/tests/test_reverse.py index 03d31f1f9..f30a8bf9a 100644 --- a/tests/test_reverse.py +++ b/tests/test_reverse.py @@ -1,9 +1,9 @@ from __future__ import unicode_literals from django.conf.urls import url -from django.core.urlresolvers import NoReverseMatch from django.test import TestCase, override_settings +from rest_framework.compat import NoReverseMatch from rest_framework.reverse import reverse from rest_framework.test import APIRequestFactory diff --git a/tests/test_routers.py b/tests/test_routers.py index f45039f80..d28e301a0 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import json from collections import namedtuple from django.conf.urls import include, url @@ -47,6 +48,21 @@ class MockViewSet(viewsets.ModelViewSet): serializer_class = None +class EmptyPrefixSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = RouterTestModel + fields = ('uuid', 'text') + + +class EmptyPrefixViewSet(viewsets.ModelViewSet): + queryset = [RouterTestModel(id=1, uuid='111', text='First'), RouterTestModel(id=2, uuid='222', text='Second')] + serializer_class = EmptyPrefixSerializer + + def get_object(self, *args, **kwargs): + index = int(self.kwargs['pk']) - 1 + return self.queryset[index] + + notes_router = SimpleRouter() notes_router.register(r'notes', NoteViewSet) @@ -56,11 +72,19 @@ kwarged_notes_router.register(r'notes', KWargedNoteViewSet) namespaced_router = DefaultRouter() namespaced_router.register(r'example', MockViewSet, base_name='example') +empty_prefix_router = SimpleRouter() +empty_prefix_router.register(r'', EmptyPrefixViewSet, base_name='empty_prefix') +empty_prefix_urls = [ + url(r'^', include(empty_prefix_router.urls)), +] + urlpatterns = [ url(r'^non-namespaced/', include(namespaced_router.urls)), url(r'^namespaced/', include(namespaced_router.urls, namespace='example')), url(r'^example/', include(notes_router.urls)), url(r'^example2/', include(kwarged_notes_router.urls)), + + url(r'^empty-prefix/', include(empty_prefix_urls)), ] @@ -384,3 +408,28 @@ class TestDynamicListAndDetailRouter(TestCase): def test_inherited_list_and_detail_route_decorators(self): self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet) + + +@override_settings(ROOT_URLCONF='tests.test_routers') +class TestEmptyPrefix(TestCase): + def test_empty_prefix_list(self): + response = self.client.get('/empty-prefix/') + self.assertEqual(200, response.status_code) + self.assertEqual( + json.loads(response.content.decode('utf-8')), + [ + {'uuid': '111', 'text': 'First'}, + {'uuid': '222', 'text': 'Second'} + ] + ) + + def test_empty_prefix_detail(self): + response = self.client.get('/empty-prefix/1/') + self.assertEqual(200, response.status_code) + self.assertEqual( + json.loads(response.content.decode('utf-8')), + { + 'uuid': '111', + 'text': 'First' + } + ) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 5d2b58a86..51e8ae441 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -215,7 +215,7 @@ class TestSchemaGenerator(TestCase): } } ) - self.assertEquals(schema, expected) + self.assertEqual(schema, expected) class SnippetListView(APIView): diff --git a/tests/test_urlpatterns.py b/tests/test_urlpatterns.py index 78d37c1a8..33d367e1d 100644 --- a/tests/test_urlpatterns.py +++ b/tests/test_urlpatterns.py @@ -3,9 +3,9 @@ from __future__ import unicode_literals from collections import namedtuple from django.conf.urls import include, url -from django.core import urlresolvers from django.test import TestCase +from rest_framework.compat import RegexURLResolver, Resolver404 from rest_framework.test import APIRequestFactory from rest_framework.urlpatterns import format_suffix_patterns @@ -28,7 +28,7 @@ class FormatSuffixTests(TestCase): urlpatterns = format_suffix_patterns(urlpatterns) except Exception: self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns") - resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) + resolver = RegexURLResolver(r'^/', urlpatterns) for test_path in test_paths: request = factory.get(test_path.path) try: @@ -43,7 +43,7 @@ class FormatSuffixTests(TestCase): urlpatterns = format_suffix_patterns([ url(r'^test/$', dummy_view), ]) - resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) + resolver = RegexURLResolver(r'^/', urlpatterns) test_paths = [ (URLTestPath('/test.api', (), {'format': 'api'}), True), @@ -55,7 +55,7 @@ class FormatSuffixTests(TestCase): request = factory.get(test_path.path) try: callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) - except urlresolvers.Resolver404: + except Resolver404: callback, callback_args, callback_kwargs = (None, None, None) if not expected_resolved: assert callback is None diff --git a/tests/utils.py b/tests/utils.py index 5b2d75864..52582f093 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,5 @@ from django.core.exceptions import ObjectDoesNotExist -from django.core.urlresolvers import NoReverseMatch +from rest_framework.compat import NoReverseMatch class MockObject(object): diff --git a/tox.ini b/tox.ini index 1e8a7e5c4..c20021e4b 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ addopts=--tb=short [tox] envlist = py27-{lint,docs}, - {py27,py32,py33,py34,py35}-django18, + {py27,py33,py34,py35}-django18, {py27,py34,py35}-django19, {py27,py34,py35}-django110, {py27,py34,py35}-django{master} @@ -25,7 +25,6 @@ basepython = py35: python3.5 py34: python3.4 py33: python3.3 - py32: python3.2 py27: python2.7 [testenv:py27-lint]