Merge branch 'version-3-5' into rejig-schema-generation

This commit is contained in:
Tom Christie 2016-09-30 11:28:24 +01:00
commit 657a7c5e00
37 changed files with 1153 additions and 122 deletions

View File

@ -14,7 +14,6 @@ env:
- TOX_ENV=py35-django18
- TOX_ENV=py34-django18
- TOX_ENV=py33-django18
- TOX_ENV=py32-django18
- TOX_ENV=py27-django18
- TOX_ENV=py27-django110
- TOX_ENV=py35-django110

View File

@ -457,6 +457,8 @@ There are two keyword arguments you can use to control this behavior:
- `html_cutoff` - If set this will be the maximum number of choices that will be displayed by a HTML select drop down. Set to `None` to disable any limiting. Defaults to `1000`.
- `html_cutoff_text` - If set this will display a textual indicator if the maximum number of items have been cutoff in an HTML select drop down. Defaults to `"More than {count} items…"`
You can also control these globally using the settings `HTML_SELECT_CUTOFF` and `HTML_SELECT_CUTOFF_TEXT`.
In cases where the cutoff is being enforced you may want to instead use a plain input field in the HTML form. You can do so using the `style` keyword argument. For example:
assigned_to = serializers.SlugRelatedField(

View File

@ -23,7 +23,7 @@ There's no requirement for you to use them, but if you do then the self-describi
**Signature:** `reverse(viewname, *args, **kwargs)`
Has the same behavior as [`django.core.urlresolvers.reverse`][reverse], except that it returns a fully qualified URL, using the request to determine the host and port.
Has the same behavior as [`django.urls.reverse`][reverse], except that it returns a fully qualified URL, using the request to determine the host and port.
You should **include the request as a keyword argument** to the function, for example:
@ -44,7 +44,7 @@ You should **include the request as a keyword argument** to the function, for ex
**Signature:** `reverse_lazy(viewname, *args, **kwargs)`
Has the same behavior as [`django.core.urlresolvers.reverse_lazy`][reverse-lazy], except that it returns a fully qualified URL, using the request to determine the host and port.
Has the same behavior as [`django.urls.reverse_lazy`][reverse-lazy], except that it returns a fully qualified URL, using the request to determine the host and port.
As with the `reverse` function, you should **include the request as a keyword argument** to the function, for example:

View File

@ -382,6 +382,22 @@ This should be a function with the following signature:
Default: `'rest_framework.views.get_view_description'`
## HTML Select Field cutoffs
Global settings for [select field cutoffs for rendering relational fields](relations.md#select-field-cutoffs) in the browsable API.
#### HTML_SELECT_CUTOFF
Global setting for the `html_cutoff` value. Must be an integer.
Default: 1000
#### HTML_SELECT_CUTOFF_TEXT
A string representing a global setting for `html_cutoff_text`.
Default: `"More than {count} items..."`
---
## Miscellaneous settings

View File

@ -197,7 +197,7 @@ REST framework includes the following test case classes, that mirror the existin
You can use any of REST framework's test case classes as you would for the regular Django test case classes. The `self.client` attribute will be an `APIClient` instance.
from django.core.urlresolvers import reverse
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APITestCase
from myproject.apps.core.models import Account

View File

@ -1,5 +1,5 @@
# Optional packages which may be used with REST framework.
markdown==2.6.4
django-guardian==1.4.3
django-filter==0.13.0
coreapi==1.32.0
django-guardian==1.4.6
django-filter==0.14.0
coreapi==2.0.8

View File

@ -8,7 +8,7 @@ ______ _____ _____ _____ __
"""
__title__ = 'Django REST framework'
__version__ = '3.4.7'
__version__ = '3.5.0'
__author__ = 'Tom Christie'
__license__ = 'BSD 2-Clause'
__copyright__ = 'Copyright 2011-2016 Tom Christie'

View File

@ -16,6 +16,9 @@ class AuthTokenSerializer(serializers.Serializer):
user = authenticate(username=username, password=password)
if user:
# From Django 1.10 onwards the `authenticate` call simply
# returns `None` for is_active=False users.
# (Assuming the default `ModelBackend` authentication backend.)
if not user.is_active:
msg = _('User account is disabled.')
raise serializers.ValidationError(msg)

View File

@ -23,6 +23,16 @@ except ImportError:
from django.utils import importlib # Will be removed in Django 1.9
try:
from django.urls import (
NoReverseMatch, RegexURLPattern, RegexURLResolver, ResolverMatch, Resolver404, get_script_prefix, reverse, reverse_lazy, resolve
)
except ImportError:
from django.core.urlresolvers import ( # Will be removed in Django 2.0
NoReverseMatch, RegexURLPattern, RegexURLResolver, ResolverMatch, Resolver404, get_script_prefix, reverse, reverse_lazy, resolve
)
try:
import urlparse # Python 2.x
except ImportError:
@ -128,6 +138,12 @@ def is_authenticated(user):
return user.is_authenticated
def is_anonymous(user):
if django.VERSION < (1, 10):
return user.is_anonymous()
return user.is_anonymous
def get_related_model(field):
if django.VERSION < (1, 9):
return _resolve_model(field.rel.to)
@ -178,6 +194,13 @@ except (ImportError, SyntaxError):
uritemplate = None
# requests is optional
try:
import requests
except ImportError:
requests = None
# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS
# Fixes (#1712). We keep the try/except for the test suite.
guardian = None
@ -200,8 +223,13 @@ try:
if markdown.version <= '2.2':
HEADERID_EXT_PATH = 'headerid'
else:
LEVEL_PARAM = 'level'
elif markdown.version < '2.6':
HEADERID_EXT_PATH = 'markdown.extensions.headerid'
LEVEL_PARAM = 'level'
else:
HEADERID_EXT_PATH = 'markdown.extensions.toc'
LEVEL_PARAM = 'baselevel'
def apply_markdown(text):
"""
@ -211,7 +239,7 @@ try:
extensions = [HEADERID_EXT_PATH]
extension_configs = {
HEADERID_EXT_PATH: {
'level': '2'
LEVEL_PARAM: '2'
}
}
md = markdown.Markdown(
@ -277,3 +305,11 @@ def template_render(template, context=None, request=None):
# backends template, e.g. django.template.backends.django.Template
else:
return template.render(context, request=request)
def set_many(instance, field, value):
if django.VERSION < (1, 10):
setattr(instance, field, value)
else:
field = getattr(instance, field)
field.set(value)

View File

@ -49,18 +49,32 @@ class empty:
pass
if six.PY3:
def is_simple_callable(obj):
"""
True if the object is a callable that takes no arguments.
"""
if not callable(obj):
return False
sig = inspect.signature(obj)
params = sig.parameters.values()
return all(param.default != param.empty for param in params)
else:
def is_simple_callable(obj):
function = inspect.isfunction(obj)
method = inspect.ismethod(obj)
if not (function or method):
return False
if method:
is_unbound = obj.im_self is None
args, _, _, defaults = inspect.getargspec(obj)
len_args = len(args) if function else len(args) - 1
len_args = len(args) if function or is_unbound else len(args) - 1
len_defaults = len(defaults) if defaults else 0
return len_args <= len_defaults

View File

@ -4,9 +4,6 @@ from __future__ import unicode_literals
from collections import OrderedDict
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
from django.core.urlresolvers import (
NoReverseMatch, Resolver404, get_script_prefix, resolve
)
from django.db.models import Manager
from django.db.models.query import QuerySet
from django.utils import six
@ -14,10 +11,14 @@ from django.utils.encoding import python_2_unicode_compatible, smart_text
from django.utils.six.moves.urllib import parse as urlparse
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import (
NoReverseMatch, Resolver404, get_script_prefix, resolve
)
from rest_framework.fields import (
Field, empty, get_attribute, is_simple_callable, iter_options
)
from rest_framework.reverse import reverse
from rest_framework.settings import api_settings
from rest_framework.utils import html
@ -71,14 +72,19 @@ MANY_RELATION_KWARGS = (
class RelatedField(Field):
queryset = None
html_cutoff = 1000
html_cutoff_text = _('More than {count} items...')
html_cutoff = None
html_cutoff_text = None
def __init__(self, **kwargs):
self.queryset = kwargs.pop('queryset', self.queryset)
self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff)
self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text)
self.html_cutoff = kwargs.pop(
'html_cutoff',
self.html_cutoff or int(api_settings.HTML_SELECT_CUTOFF)
)
self.html_cutoff_text = kwargs.pop(
'html_cutoff_text',
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
)
if not method_overridden('get_queryset', RelatedField, self):
assert self.queryset is not None or kwargs.get('read_only', None), (
'Relational field must provide a `queryset` argument, '
@ -447,15 +453,20 @@ class ManyRelatedField(Field):
'not_a_list': _('Expected a list of items but got type "{input_type}".'),
'empty': _('This list may not be empty.')
}
html_cutoff = 1000
html_cutoff_text = _('More than {count} items...')
html_cutoff = None
html_cutoff_text = None
def __init__(self, child_relation=None, *args, **kwargs):
self.child_relation = child_relation
self.allow_empty = kwargs.pop('allow_empty', True)
self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff)
self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text)
self.html_cutoff = kwargs.pop(
'html_cutoff',
self.html_cutoff or int(api_settings.HTML_SELECT_CUTOFF)
)
self.html_cutoff_text = kwargs.pop(
'html_cutoff_text',
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
)
assert child_relation is not None, '`child_relation` is a required argument.'
super(ManyRelatedField, self).__init__(*args, **kwargs)
self.child_relation.bind(field_name='', parent=self)

View File

@ -3,11 +3,11 @@ Provide urlresolver functions that return fully qualified URLs or view names
"""
from __future__ import unicode_literals
from django.core.urlresolvers import reverse as django_reverse
from django.core.urlresolvers import NoReverseMatch
from django.utils import six
from django.utils.functional import lazy
from rest_framework.compat import reverse as django_reverse
from rest_framework.compat import NoReverseMatch
from rest_framework.settings import api_settings
from rest_framework.utils.urls import replace_query_param
@ -54,7 +54,7 @@ def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra
def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
"""
Same as `django.core.urlresolvers.reverse`, but optionally takes a request
Same as `django.urls.reverse`, but optionally takes a request
and returns a fully qualified URL, using the request to get the base URL.
"""
if format is not None:

View File

@ -20,9 +20,9 @@ from collections import OrderedDict, namedtuple
from django.conf.urls import url
from django.core.exceptions import ImproperlyConfigured
from django.core.urlresolvers import NoReverseMatch
from rest_framework import exceptions, renderers, views
from rest_framework.compat import NoReverseMatch
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator
@ -83,6 +83,7 @@ class BaseRouter(object):
class SimpleRouter(BaseRouter):
routes = [
# List route.
Route(
@ -258,6 +259,13 @@ class SimpleRouter(BaseRouter):
trailing_slash=self.trailing_slash
)
# If there is no prefix, the first part of the url is probably
# controlled by project's urls.py and the router is in an app,
# so a slash in the beginning will (A) cause Django to give
# warnings and (B) generate URLS that will require using '//'.
if not prefix and regex[:2] == '^/':
regex = '^' + regex[2:]
view = viewset.as_view(mapping, **route.initkwargs)
name = route.name.format(basename=basename)
ret.append(url(regex, view, name=name))
@ -289,42 +297,42 @@ class DefaultRouter(SimpleRouter):
self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
super(DefaultRouter, self).__init__(*args, **kwargs)
def get_schema_root_view(self, api_urls=None):
"""
Return a schema root view.
"""
schema_renderers = self.schema_renderers
schema_generator = SchemaGenerator(
title=self.schema_title,
url=self.schema_url,
patterns=api_urls
)
class APISchemaView(views.APIView):
_ignore_model_permissions = True
renderer_classes = schema_renderers
def get(self, request, *args, **kwargs):
schema = schema_generator.get_schema(request)
if schema is None:
raise exceptions.PermissionDenied()
return Response(schema)
return APISchemaView.as_view()
def get_api_root_view(self, api_urls=None):
"""
Return a view to use as the API root.
Return a basic root view.
"""
api_root_dict = OrderedDict()
list_name = self.routes[0].name
for prefix, viewset, basename in self.registry:
api_root_dict[prefix] = list_name.format(basename=basename)
view_renderers = list(self.root_renderers)
schema_media_types = []
if api_urls and self.schema_title:
view_renderers += list(self.schema_renderers)
schema_generator = SchemaGenerator(
title=self.schema_title,
url=self.schema_url,
patterns=api_urls
)
schema_media_types = [
renderer.media_type
for renderer in self.schema_renderers
]
class APIRoot(views.APIView):
class APIRootView(views.APIView):
_ignore_model_permissions = True
renderer_classes = view_renderers
def get(self, request, *args, **kwargs):
if request.accepted_renderer.media_type in schema_media_types:
# Return a schema response.
schema = schema_generator.get_schema(request)
if schema is None:
raise exceptions.PermissionDenied()
return Response(schema)
# Return a plain {"name": "hyperlink"} response.
ret = OrderedDict()
namespace = request.resolver_match.namespace
@ -345,7 +353,7 @@ class DefaultRouter(SimpleRouter):
return Response(ret)
return APIRoot.as_view()
return APIRootView.as_view()
def get_urls(self):
"""
@ -355,6 +363,9 @@ class DefaultRouter(SimpleRouter):
urls = super(DefaultRouter, self).get_urls()
if self.include_root_view:
if self.schema_title:
view = self.get_schema_root_view(api_urls=urls)
else:
view = self.get_api_root_view(api_urls=urls)
root_url = url(r'^$', view, name=self.root_view_name)
urls.append(root_url)

View File

@ -2,12 +2,13 @@ from importlib import import_module
from django.conf import settings
from django.contrib.admindocs.views import simplify_regex
from django.core.urlresolvers import RegexURLPattern, RegexURLResolver
from django.utils import six
from django.utils.encoding import force_text
from rest_framework import exceptions, serializers
from rest_framework.compat import coreapi, uritemplate, urlparse
from rest_framework.compat import (
RegexURLPattern, RegexURLResolver, coreapi, uritemplate, urlparse
)
from rest_framework.request import clone_request
from rest_framework.views import APIView
@ -329,7 +330,7 @@ class SchemaGenerator(object):
fields += as_query_fields(filter_backend().get_fields(view))
return fields
# Methods for generating the keys which are used to layout each link.
# Methods for generating the link layout....
default_mapping = {
'get': 'read',

View File

@ -23,7 +23,7 @@ from django.utils.functional import cached_property
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import JSONField as ModelJSONField
from rest_framework.compat import postgres_fields, unicode_to_repr
from rest_framework.compat import postgres_fields, set_many, unicode_to_repr
from rest_framework.utils import model_meta
from rest_framework.utils.field_mapping import (
ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs,
@ -892,18 +892,22 @@ class ModelSerializer(Serializer):
# Save many-to-many relationships after the instance is created.
if many_to_many:
for field_name, value in many_to_many.items():
setattr(instance, field_name, value)
set_many(instance, field_name, value)
return instance
def update(self, instance, validated_data):
raise_errors_on_nested_writes('update', self, validated_data)
info = model_meta.get_field_info(instance)
# Simply set each attribute on the instance, and then save it.
# Note that unlike `.create()` we don't need to treat many-to-many
# relationships as being a special case. During updates we already
# have an instance pk for the relationships to be associated with.
for attr, value in validated_data.items():
if attr in info.relations and info.relations[attr].to_many:
set_many(instance, attr, value)
else:
setattr(instance, attr, value)
instance.save()

View File

@ -111,6 +111,10 @@ DEFAULTS = {
'COMPACT_JSON': True,
'COERCE_DECIMAL_TO_STRING': True,
'UPLOADED_FILES_USE_URL': True,
# Browseable API
'HTML_SELECT_CUTOFF': 1000,
'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...",
}

View File

@ -3,14 +3,13 @@ from __future__ import absolute_import, unicode_literals
import re
from django import template
from django.core.urlresolvers import NoReverseMatch, reverse
from django.template import loader
from django.utils import six
from django.utils.encoding import force_text, iri_to_uri
from django.utils.html import escape, format_html, smart_urlquote
from django.utils.safestring import SafeData, mark_safe
from rest_framework.compat import template_render
from rest_framework.compat import NoReverseMatch, reverse, template_render
from rest_framework.renderers import HTMLFormRenderer
from rest_framework.utils.urls import replace_query_param

View File

@ -4,7 +4,10 @@
# to make it harder for the user to import the wrong thing without realizing.
from __future__ import unicode_literals
import io
from django.conf import settings
from django.core.handlers.wsgi import WSGIHandler
from django.test import testcases
from django.test.client import Client as DjangoClient
from django.test.client import RequestFactory as DjangoRequestFactory
@ -13,6 +16,7 @@ from django.utils import six
from django.utils.encoding import force_bytes
from django.utils.http import urlencode
from rest_framework.compat import coreapi, requests
from rest_framework.settings import api_settings
@ -21,6 +25,118 @@ def force_authenticate(request, user=None, token=None):
request._force_auth_token = token
if requests is not None:
class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
def get_all(self, key, default):
return self.getheaders(key)
class MockOriginalResponse(object):
def __init__(self, headers):
self.msg = HeaderDict(headers)
self.closed = False
def isclosed(self):
return self.closed
def close(self):
self.closed = True
class DjangoTestAdapter(requests.adapters.HTTPAdapter):
"""
A transport adapter for `requests`, that makes requests via the
Django WSGI app, rather than making actual HTTP requests over the network.
"""
def __init__(self):
self.app = WSGIHandler()
self.factory = DjangoRequestFactory()
def get_environ(self, request):
"""
Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
"""
method = request.method
url = request.url
kwargs = {}
# Set request content, if any exists.
if request.body is not None:
if hasattr(request.body, 'read'):
kwargs['data'] = request.body.read()
else:
kwargs['data'] = request.body
if 'content-type' in request.headers:
kwargs['content_type'] = request.headers['content-type']
# Set request headers.
for key, value in request.headers.items():
key = key.upper()
if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
continue
kwargs['HTTP_%s' % key.replace('-', '_')] = value
return self.factory.generic(method, url, **kwargs).environ
def send(self, request, *args, **kwargs):
"""
Make an outgoing request to the Django WSGI application.
"""
raw_kwargs = {}
def start_response(wsgi_status, wsgi_headers):
status, _, reason = wsgi_status.partition(' ')
raw_kwargs['status'] = int(status)
raw_kwargs['reason'] = reason
raw_kwargs['headers'] = wsgi_headers
raw_kwargs['version'] = 11
raw_kwargs['preload_content'] = False
raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)
# Make the outgoing request via WSGI.
environ = self.get_environ(request)
wsgi_response = self.app(environ, start_response)
# Build the underlying urllib3.HTTPResponse
raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
# Build the requests.Response
return self.build_response(request, raw)
def close(self):
pass
class DjangoTestSession(requests.Session):
def __init__(self, *args, **kwargs):
super(DjangoTestSession, self).__init__(*args, **kwargs)
adapter = DjangoTestAdapter()
hostnames = list(settings.ALLOWED_HOSTS) + ['testserver']
for hostname in hostnames:
if hostname == '*':
hostname = ''
self.mount('http://%s' % hostname, adapter)
self.mount('https://%s' % hostname, adapter)
def request(self, method, url, *args, **kwargs):
if ':' not in url:
url = 'http://testserver/' + url.lstrip('/')
return super(DjangoTestSession, self).request(method, url, *args, **kwargs)
def get_requests_client():
assert requests is not None, 'requests must be installed'
return DjangoTestSession()
def get_api_client():
assert coreapi is not None, 'coreapi must be installed'
session = get_requests_client()
return coreapi.Client(transports=[
coreapi.transports.HTTPTransport(session=session)
])
class APIRequestFactory(DjangoRequestFactory):
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT

View File

@ -1,8 +1,8 @@
from __future__ import unicode_literals
from django.conf.urls import include, url
from django.core.urlresolvers import RegexURLResolver
from rest_framework.compat import RegexURLResolver
from rest_framework.settings import api_settings

View File

@ -1,6 +1,6 @@
from __future__ import unicode_literals
from django.core.urlresolvers import get_script_prefix, resolve
from rest_framework.compat import get_script_prefix, resolve
def get_breadcrumbs(url, request=None):

View File

@ -11,6 +11,7 @@ factory = APIRequestFactory()
class BasicSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
fields = '__all__'
class ManyPostView(generics.GenericAPIView):

View File

@ -1,6 +1,13 @@
def pytest_configure():
from django.conf import settings
MIDDLEWARE = (
'django.middleware.common.CommonMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
)
settings.configure(
DEBUG_PROPAGATE_EXCEPTIONS=True,
DATABASES={
@ -21,12 +28,8 @@ def pytest_configure():
'APP_DIRS': True,
},
],
MIDDLEWARE_CLASSES=(
'django.middleware.common.CommonMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
),
MIDDLEWARE=MIDDLEWARE,
MIDDLEWARE_CLASSES=MIDDLEWARE,
INSTALLED_APPS=(
'django.contrib.auth',
'django.contrib.contenttypes',

452
tests/test_api_client.py Normal file
View File

@ -0,0 +1,452 @@
from __future__ import unicode_literals
import os
import tempfile
import unittest
from django.conf.urls import url
from django.http import HttpResponse
from django.test import override_settings
from rest_framework.compat import coreapi
from rest_framework.parsers import FileUploadParser
from rest_framework.renderers import CoreJSONRenderer
from rest_framework.response import Response
from rest_framework.test import APITestCase, get_api_client
from rest_framework.views import APIView
def get_schema():
return coreapi.Document(
url='https://api.example.com/',
title='Example API',
content={
'simple_link': coreapi.Link('/example/', description='example link'),
'location': {
'query': coreapi.Link('/example/', fields=[
coreapi.Field(name='example', description='example field')
]),
'form': coreapi.Link('/example/', action='post', fields=[
coreapi.Field(name='example'),
]),
'body': coreapi.Link('/example/', action='post', fields=[
coreapi.Field(name='example', location='body')
]),
'path': coreapi.Link('/example/{id}', fields=[
coreapi.Field(name='id', location='path')
])
},
'encoding': {
'multipart': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[
coreapi.Field(name='example')
]),
'multipart-body': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[
coreapi.Field(name='example', location='body')
]),
'urlencoded': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[
coreapi.Field(name='example')
]),
'urlencoded-body': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[
coreapi.Field(name='example', location='body')
]),
'raw_upload': coreapi.Link('/upload/', action='post', encoding='application/octet-stream', fields=[
coreapi.Field(name='example', location='body')
]),
},
'response': {
'download': coreapi.Link('/download/'),
'text': coreapi.Link('/text/')
}
}
)
def _iterlists(querydict):
if hasattr(querydict, 'iterlists'):
return querydict.iterlists()
return querydict.lists()
def _get_query_params(request):
# Return query params in a plain dict, using a list value if more
# than one item is present for a given key.
return {
key: (value[0] if len(value) == 1 else value)
for key, value in
_iterlists(request.query_params)
}
def _get_data(request):
if not isinstance(request.data, dict):
return request.data
# Coerce multidict into regular dict, and remove files to
# make assertions simpler.
if hasattr(request.data, 'iterlists') or hasattr(request.data, 'lists'):
# Use a list value if a QueryDict contains multiple items for a key.
return {
key: value[0] if len(value) == 1 else value
for key, value in _iterlists(request.data)
if key not in request.FILES
}
return {
key: value
for key, value in request.data.items()
if key not in request.FILES
}
def _get_files(request):
if not request.FILES:
return {}
return {
key: {'name': value.name, 'content': value.read()}
for key, value in request.FILES.items()
}
class SchemaView(APIView):
renderer_classes = [CoreJSONRenderer]
def get(self, request):
schema = get_schema()
return Response(schema)
class ListView(APIView):
def get(self, request):
return Response({
'method': request.method,
'query_params': _get_query_params(request)
})
def post(self, request):
if request.content_type:
content_type = request.content_type.split(';')[0]
else:
content_type = None
return Response({
'method': request.method,
'query_params': _get_query_params(request),
'data': _get_data(request),
'files': _get_files(request),
'content_type': content_type
})
class DetailView(APIView):
def get(self, request, id):
return Response({
'id': id,
'method': request.method,
'query_params': _get_query_params(request)
})
class UploadView(APIView):
parser_classes = [FileUploadParser]
def post(self, request):
return Response({
'method': request.method,
'files': _get_files(request),
'content_type': request.content_type
})
class DownloadView(APIView):
def get(self, request):
return HttpResponse('some file content', content_type='image/png')
class TextView(APIView):
def get(self, request):
return HttpResponse('123', content_type='text/plain')
urlpatterns = [
url(r'^$', SchemaView.as_view()),
url(r'^example/$', ListView.as_view()),
url(r'^example/(?P<id>[0-9]+)/$', DetailView.as_view()),
url(r'^upload/$', UploadView.as_view()),
url(r'^download/$', DownloadView.as_view()),
url(r'^text/$', TextView.as_view()),
]
@unittest.skipUnless(coreapi, 'coreapi not installed')
@override_settings(ROOT_URLCONF='tests.test_api_client')
class APIClientTests(APITestCase):
def test_api_client(self):
client = get_api_client()
schema = client.get('http://api.example.com/')
assert schema.title == 'Example API'
assert schema.url == 'https://api.example.com/'
assert schema['simple_link'].description == 'example link'
assert schema['location']['query'].fields[0].description == 'example field'
data = client.action(schema, ['simple_link'])
expected = {
'method': 'GET',
'query_params': {}
}
assert data == expected
def test_query_params(self):
client = get_api_client()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'query'], params={'example': 123})
expected = {
'method': 'GET',
'query_params': {'example': '123'}
}
assert data == expected
def test_query_params_with_multiple_values(self):
client = get_api_client()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'query'], params={'example': [1, 2, 3]})
expected = {
'method': 'GET',
'query_params': {'example': ['1', '2', '3']}
}
assert data == expected
def test_form_params(self):
client = get_api_client()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'form'], params={'example': 123})
expected = {
'method': 'POST',
'content_type': 'application/json',
'query_params': {},
'data': {'example': 123},
'files': {}
}
assert data == expected
def test_body_params(self):
client = get_api_client()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'body'], params={'example': 123})
expected = {
'method': 'POST',
'content_type': 'application/json',
'query_params': {},
'data': 123,
'files': {}
}
assert data == expected
def test_path_params(self):
client = get_api_client()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'path'], params={'id': 123})
expected = {
'method': 'GET',
'query_params': {},
'id': '123'
}
assert data == expected
def test_multipart_encoding(self):
client = get_api_client()
schema = client.get('http://api.example.com/')
temp = tempfile.NamedTemporaryFile()
temp.write(b'example file content')
temp.flush()
with open(temp.name, 'rb') as upload:
name = os.path.basename(upload.name)
data = client.action(schema, ['encoding', 'multipart'], params={'example': upload})
expected = {
'method': 'POST',
'content_type': 'multipart/form-data',
'query_params': {},
'data': {},
'files': {'example': {'name': name, 'content': 'example file content'}}
}
assert data == expected
def test_multipart_encoding_no_file(self):
# When no file is included, multipart encoding should still be used.
client = get_api_client()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['encoding', 'multipart'], params={'example': 123})
expected = {
'method': 'POST',
'content_type': 'multipart/form-data',
'query_params': {},
'data': {'example': '123'},
'files': {}
}
assert data == expected
def test_multipart_encoding_multiple_values(self):
client = get_api_client()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['encoding', 'multipart'], params={'example': [1, 2, 3]})
expected = {
'method': 'POST',
'content_type': 'multipart/form-data',
'query_params': {},
'data': {'example': ['1', '2', '3']},
'files': {}
}
assert data == expected
def test_multipart_encoding_string_file_content(self):
# Test for `coreapi.utils.File` support.
from coreapi.utils import File
client = get_api_client()
schema = client.get('http://api.example.com/')
example = File(name='example.txt', content='123')
data = client.action(schema, ['encoding', 'multipart'], params={'example': example})
expected = {
'method': 'POST',
'content_type': 'multipart/form-data',
'query_params': {},
'data': {},
'files': {'example': {'name': 'example.txt', 'content': '123'}}
}
assert data == expected
def test_multipart_encoding_in_body(self):
from coreapi.utils import File
client = get_api_client()
schema = client.get('http://api.example.com/')
example = {'foo': File(name='example.txt', content='123'), 'bar': 'abc'}
data = client.action(schema, ['encoding', 'multipart-body'], params={'example': example})
expected = {
'method': 'POST',
'content_type': 'multipart/form-data',
'query_params': {},
'data': {'bar': 'abc'},
'files': {'foo': {'name': 'example.txt', 'content': '123'}}
}
assert data == expected
# URLencoded
def test_urlencoded_encoding(self):
client = get_api_client()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['encoding', 'urlencoded'], params={'example': 123})
expected = {
'method': 'POST',
'content_type': 'application/x-www-form-urlencoded',
'query_params': {},
'data': {'example': '123'},
'files': {}
}
assert data == expected
def test_urlencoded_encoding_multiple_values(self):
client = get_api_client()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['encoding', 'urlencoded'], params={'example': [1, 2, 3]})
expected = {
'method': 'POST',
'content_type': 'application/x-www-form-urlencoded',
'query_params': {},
'data': {'example': ['1', '2', '3']},
'files': {}
}
assert data == expected
def test_urlencoded_encoding_in_body(self):
client = get_api_client()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['encoding', 'urlencoded-body'], params={'example': {'foo': 123, 'bar': True}})
expected = {
'method': 'POST',
'content_type': 'application/x-www-form-urlencoded',
'query_params': {},
'data': {'foo': '123', 'bar': 'true'},
'files': {}
}
assert data == expected
# Raw uploads
def test_raw_upload(self):
client = get_api_client()
schema = client.get('http://api.example.com/')
temp = tempfile.NamedTemporaryFile()
temp.write(b'example file content')
temp.flush()
with open(temp.name, 'rb') as upload:
name = os.path.basename(upload.name)
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': upload})
expected = {
'method': 'POST',
'files': {'file': {'name': name, 'content': 'example file content'}},
'content_type': 'application/octet-stream'
}
assert data == expected
def test_raw_upload_string_file_content(self):
from coreapi.utils import File
client = get_api_client()
schema = client.get('http://api.example.com/')
example = File('example.txt', '123')
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example})
expected = {
'method': 'POST',
'files': {'file': {'name': 'example.txt', 'content': '123'}},
'content_type': 'text/plain'
}
assert data == expected
def test_raw_upload_explicit_content_type(self):
from coreapi.utils import File
client = get_api_client()
schema = client.get('http://api.example.com/')
example = File('example.txt', '123', 'text/html')
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example})
expected = {
'method': 'POST',
'files': {'file': {'name': 'example.txt', 'content': '123'}},
'content_type': 'text/html'
}
assert data == expected
# Responses
def test_text_response(self):
client = get_api_client()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['response', 'text'])
expected = '123'
assert data == expected
def test_download_response(self):
client = get_api_client()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['response', 'download'])
assert data.basename == 'download.png'
assert data.read() == b'some file content'

View File

@ -5,7 +5,7 @@ import unittest
from django.conf.urls import url
from django.db import connection, connections, transaction
from django.http import Http404
from django.test import TestCase, TransactionTestCase
from django.test import TestCase, TransactionTestCase, override_settings
from django.utils.decorators import method_decorator
from rest_framework import status
@ -36,6 +36,20 @@ class APIExceptionView(APIView):
raise APIException
class NonAtomicAPIExceptionView(APIView):
@method_decorator(transaction.non_atomic_requests)
def dispatch(self, *args, **kwargs):
return super(NonAtomicAPIExceptionView, self).dispatch(*args, **kwargs)
def get(self, request, *args, **kwargs):
BasicModel.objects.all()
raise Http404
urlpatterns = (
url(r'^$', NonAtomicAPIExceptionView.as_view()),
)
@unittest.skipUnless(
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
@ -124,22 +138,8 @@ class DBTransactionAPIExceptionTests(TestCase):
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
)
@override_settings(ROOT_URLCONF='tests.test_atomic_requests')
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
@property
def urls(self):
class NonAtomicAPIExceptionView(APIView):
@method_decorator(transaction.non_atomic_requests)
def dispatch(self, *args, **kwargs):
return super(NonAtomicAPIExceptionView, self).dispatch(*args, **kwargs)
def get(self, request, *args, **kwargs):
BasicModel.objects.all()
raise Http404
return (
url(r'^$', NonAtomicAPIExceptionView.as_view()),
)
def setUp(self):
connections.databases['default']['ATOMIC_REQUESTS'] = True

View File

@ -20,6 +20,7 @@ from rest_framework.authentication import (
)
from rest_framework.authtoken.models import Token
from rest_framework.authtoken.views import obtain_auth_token
from rest_framework.compat import is_authenticated
from rest_framework.response import Response
from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.views import APIView
@ -408,7 +409,7 @@ class FailingAuthAccessedInRenderer(TestCase):
def render(self, data, media_type=None, renderer_context=None):
request = renderer_context['request']
if request.user.is_authenticated():
if is_authenticated(request.user):
return b'authenticated'
return b'not authenticated'

View File

@ -1,6 +1,7 @@
import datetime
import os
import re
import unittest
import uuid
from decimal import Decimal
@ -11,6 +12,67 @@ from django.utils import six, timezone
import rest_framework
from rest_framework import serializers
from rest_framework.fields import is_simple_callable
try:
import typings
except ImportError:
typings = False
# Tests for helper functions.
# ---------------------------
class TestIsSimpleCallable:
def test_method(self):
class Foo:
@classmethod
def classmethod(cls):
pass
def valid(self):
pass
def valid_kwargs(self, param='value'):
pass
def invalid(self, param):
pass
assert is_simple_callable(Foo.classmethod)
# unbound methods
assert not is_simple_callable(Foo.valid)
assert not is_simple_callable(Foo.valid_kwargs)
assert not is_simple_callable(Foo.invalid)
# bound methods
assert is_simple_callable(Foo().valid)
assert is_simple_callable(Foo().valid_kwargs)
assert not is_simple_callable(Foo().invalid)
def test_function(self):
def simple():
pass
def valid(param='value', param2='value'):
pass
def invalid(param, param2='value'):
pass
assert is_simple_callable(simple)
assert is_simple_callable(valid)
assert not is_simple_callable(invalid)
@unittest.skipUnless(typings, 'requires python 3.5')
def test_type_annotation(self):
# The annotation will otherwise raise a syntax error in python < 3.5
exec("def valid(param: str='value'): pass", locals())
valid = locals()['valid']
assert is_simple_callable(valid)
# Tests for field keyword arguments and core functionality.

View File

@ -6,7 +6,6 @@ from decimal import Decimal
from django.conf.urls import url
from django.core.exceptions import ImproperlyConfigured
from django.core.urlresolvers import reverse
from django.db import models
from django.test import TestCase
from django.test.utils import override_settings
@ -14,7 +13,7 @@ from django.utils.dateparse import parse_date
from django.utils.six.moves import reload_module
from rest_framework import filters, generics, serializers, status
from rest_framework.compat import django_filters
from rest_framework.compat import django_filters, reverse
from rest_framework.test import APIRequestFactory
from .models import BaseFilterableItem, BasicModel, FilterableItem
@ -77,6 +76,7 @@ if django_filters:
class Meta:
model = BaseFilterableItem
fields = '__all__'
class BaseFilterableItemFilterRootView(generics.ListCreateAPIView):
queryset = FilterableItem.objects.all()
@ -456,7 +456,7 @@ class AttributeModel(models.Model):
class SearchFilterModelFk(models.Model):
title = models.CharField(max_length=20)
attribute = models.ForeignKey(AttributeModel)
attribute = models.ForeignKey(AttributeModel, on_delete=models.CASCADE)
class SearchFilterFkSerializer(serializers.ModelSerializer):

View File

@ -20,7 +20,7 @@ from django.test import TestCase
from django.utils import six
from rest_framework import serializers
from rest_framework.compat import unicode_repr
from rest_framework.compat import set_many, unicode_repr
def dedent(blocktext):
@ -651,7 +651,7 @@ class TestIntegration(TestCase):
foreign_key=self.foreign_key_target,
one_to_one=self.one_to_one_target,
)
self.instance.many_to_many = self.many_to_many_targets
set_many(self.instance, 'many_to_many', self.many_to_many_targets)
self.instance.save()
def test_pk_retrival(self):
@ -962,7 +962,7 @@ class OneToOneTargetTestModel(models.Model):
class OneToOneSourceTestModel(models.Model):
target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True)
target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True, on_delete=models.CASCADE)
class TestModelFieldValues(TestCase):
@ -990,6 +990,7 @@ class TestUniquenessOverride(TestCase):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = TestModel
fields = '__all__'
extra_kwargs = {'field_1': {'required': False}}
fields = TestSerializer().fields

View File

@ -4,7 +4,6 @@ import base64
import unittest
from django.contrib.auth.models import Group, Permission, User
from django.core.urlresolvers import ResolverMatch
from django.db import models
from django.test import TestCase
@ -12,7 +11,7 @@ from rest_framework import (
HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers,
status
)
from rest_framework.compat import guardian
from rest_framework.compat import ResolverMatch, guardian, set_many
from rest_framework.filters import DjangoObjectPermissionsFilter
from rest_framework.routers import DefaultRouter
from rest_framework.test import APIRequestFactory
@ -74,15 +73,15 @@ class ModelPermissionsIntegrationTests(TestCase):
def setUp(self):
User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
user.user_permissions = [
set_many(user, 'user_permissions', [
Permission.objects.get(codename='add_basicmodel'),
Permission.objects.get(codename='change_basicmodel'),
Permission.objects.get(codename='delete_basicmodel')
]
])
user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
user.user_permissions = [
set_many(user, 'user_permissions', [
Permission.objects.get(codename='change_basicmodel'),
]
])
self.permitted_credentials = basic_auth_header('permitted', 'password')
self.disallowed_credentials = basic_auth_header('disallowed', 'password')

View File

@ -13,6 +13,7 @@ from django.utils import six
from rest_framework import status
from rest_framework.authentication import SessionAuthentication
from rest_framework.compat import is_anonymous
from rest_framework.parsers import BaseParser, FormParser, MultiPartParser
from rest_framework.request import Request
from rest_framework.response import Response
@ -169,9 +170,9 @@ class TestUserSetter(TestCase):
def test_user_can_logout(self):
self.request.user = self.user
self.assertFalse(self.request.user.is_anonymous())
self.assertFalse(is_anonymous(self.request.user))
logout(self.request)
self.assertTrue(self.request.user.is_anonymous())
self.assertTrue(is_anonymous(self.request.user))
def test_logged_in_user_is_set_on_wrapped_request(self):
login(self.request, self.user)

View File

@ -0,0 +1,247 @@
from __future__ import unicode_literals
import unittest
from django.conf.urls import url
from django.contrib.auth import authenticate, login
from django.contrib.auth.models import User
from django.shortcuts import redirect
from django.test import override_settings
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie
from rest_framework.compat import is_authenticated, requests
from rest_framework.response import Response
from rest_framework.test import APITestCase, get_requests_client
from rest_framework.views import APIView
class Root(APIView):
def get(self, request):
return Response({
'method': request.method,
'query_params': request.query_params,
})
def post(self, request):
files = {
key: (value.name, value.read())
for key, value in request.FILES.items()
}
post = request.POST
json = None
if request.META.get('CONTENT_TYPE') == 'application/json':
json = request.data
return Response({
'method': request.method,
'query_params': request.query_params,
'POST': post,
'FILES': files,
'JSON': json
})
class HeadersView(APIView):
def get(self, request):
headers = {
key[5:].replace('_', '-'): value
for key, value in request.META.items()
if key.startswith('HTTP_')
}
return Response({
'method': request.method,
'headers': headers
})
class SessionView(APIView):
def get(self, request):
return Response({
key: value for key, value in request.session.items()
})
def post(self, request):
for key, value in request.data.items():
request.session[key] = value
return Response({
key: value for key, value in request.session.items()
})
class AuthView(APIView):
@method_decorator(ensure_csrf_cookie)
def get(self, request):
if is_authenticated(request.user):
username = request.user.username
else:
username = None
return Response({
'username': username
})
@method_decorator(csrf_protect)
def post(self, request):
username = request.data['username']
password = request.data['password']
user = authenticate(username=username, password=password)
if user is None:
return Response({'error': 'incorrect credentials'})
login(request, user)
return redirect('/auth/')
urlpatterns = [
url(r'^$', Root.as_view()),
url(r'^headers/$', HeadersView.as_view()),
url(r'^session/$', SessionView.as_view()),
url(r'^auth/$', AuthView.as_view()),
]
@unittest.skipUnless(requests, 'requests not installed')
@override_settings(ROOT_URLCONF='tests.test_requests_client')
class RequestsClientTests(APITestCase):
def test_get_request(self):
client = get_requests_client()
response = client.get('/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'GET',
'query_params': {}
}
assert response.json() == expected
def test_get_request_query_params_in_url(self):
client = get_requests_client()
response = client.get('/?key=value')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'GET',
'query_params': {'key': 'value'}
}
assert response.json() == expected
def test_get_request_query_params_by_kwarg(self):
client = get_requests_client()
response = client.get('/', params={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'GET',
'query_params': {'key': 'value'}
}
assert response.json() == expected
def test_get_with_headers(self):
client = get_requests_client()
response = client.get('/headers/', headers={'User-Agent': 'example'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
headers = response.json()['headers']
assert headers['USER-AGENT'] == 'example'
def test_post_form_request(self):
client = get_requests_client()
response = client.post('/', data={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'POST',
'query_params': {},
'POST': {'key': 'value'},
'FILES': {},
'JSON': None
}
assert response.json() == expected
def test_post_json_request(self):
client = get_requests_client()
response = client.post('/', json={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'POST',
'query_params': {},
'POST': {},
'FILES': {},
'JSON': {'key': 'value'}
}
assert response.json() == expected
def test_post_multipart_request(self):
client = get_requests_client()
files = {
'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n')
}
response = client.post('/', files=files)
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'POST',
'query_params': {},
'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']},
'POST': {},
'JSON': None
}
assert response.json() == expected
def test_session(self):
client = get_requests_client()
response = client.get('/session/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {}
assert response.json() == expected
response = client.post('/session/', json={'example': 'abc'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {'example': 'abc'}
assert response.json() == expected
response = client.get('/session/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {'example': 'abc'}
assert response.json() == expected
def test_auth(self):
# Confirm session is not authenticated
client = get_requests_client()
response = client.get('/auth/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'username': None
}
assert response.json() == expected
assert 'csrftoken' in response.cookies
csrftoken = response.cookies['csrftoken']
user = User.objects.create(username='tom')
user.set_password('password')
user.save()
# Perform a login
response = client.post('/auth/', json={
'username': 'tom',
'password': 'password'
}, headers={'X-CSRFToken': csrftoken})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'username': 'tom'
}
assert response.json() == expected
# Confirm session is authenticated
response = client.get('/auth/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'username': 'tom'
}
assert response.json() == expected

View File

@ -1,9 +1,9 @@
from __future__ import unicode_literals
from django.conf.urls import url
from django.core.urlresolvers import NoReverseMatch
from django.test import TestCase, override_settings
from rest_framework.compat import NoReverseMatch
from rest_framework.reverse import reverse
from rest_framework.test import APIRequestFactory

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals
import json
from collections import namedtuple
from django.conf.urls import include, url
@ -47,6 +48,21 @@ class MockViewSet(viewsets.ModelViewSet):
serializer_class = None
class EmptyPrefixSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = RouterTestModel
fields = ('uuid', 'text')
class EmptyPrefixViewSet(viewsets.ModelViewSet):
queryset = [RouterTestModel(id=1, uuid='111', text='First'), RouterTestModel(id=2, uuid='222', text='Second')]
serializer_class = EmptyPrefixSerializer
def get_object(self, *args, **kwargs):
index = int(self.kwargs['pk']) - 1
return self.queryset[index]
notes_router = SimpleRouter()
notes_router.register(r'notes', NoteViewSet)
@ -56,11 +72,19 @@ kwarged_notes_router.register(r'notes', KWargedNoteViewSet)
namespaced_router = DefaultRouter()
namespaced_router.register(r'example', MockViewSet, base_name='example')
empty_prefix_router = SimpleRouter()
empty_prefix_router.register(r'', EmptyPrefixViewSet, base_name='empty_prefix')
empty_prefix_urls = [
url(r'^', include(empty_prefix_router.urls)),
]
urlpatterns = [
url(r'^non-namespaced/', include(namespaced_router.urls)),
url(r'^namespaced/', include(namespaced_router.urls, namespace='example')),
url(r'^example/', include(notes_router.urls)),
url(r'^example2/', include(kwarged_notes_router.urls)),
url(r'^empty-prefix/', include(empty_prefix_urls)),
]
@ -384,3 +408,28 @@ class TestDynamicListAndDetailRouter(TestCase):
def test_inherited_list_and_detail_route_decorators(self):
self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet)
@override_settings(ROOT_URLCONF='tests.test_routers')
class TestEmptyPrefix(TestCase):
def test_empty_prefix_list(self):
response = self.client.get('/empty-prefix/')
self.assertEqual(200, response.status_code)
self.assertEqual(
json.loads(response.content.decode('utf-8')),
[
{'uuid': '111', 'text': 'First'},
{'uuid': '222', 'text': 'Second'}
]
)
def test_empty_prefix_detail(self):
response = self.client.get('/empty-prefix/1/')
self.assertEqual(200, response.status_code)
self.assertEqual(
json.loads(response.content.decode('utf-8')),
{
'uuid': '111',
'text': 'First'
}
)

View File

@ -215,7 +215,7 @@ class TestSchemaGenerator(TestCase):
}
}
)
self.assertEquals(schema, expected)
self.assertEqual(schema, expected)
class SnippetListView(APIView):

View File

@ -3,9 +3,9 @@ from __future__ import unicode_literals
from collections import namedtuple
from django.conf.urls import include, url
from django.core import urlresolvers
from django.test import TestCase
from rest_framework.compat import RegexURLResolver, Resolver404
from rest_framework.test import APIRequestFactory
from rest_framework.urlpatterns import format_suffix_patterns
@ -28,7 +28,7 @@ class FormatSuffixTests(TestCase):
urlpatterns = format_suffix_patterns(urlpatterns)
except Exception:
self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
resolver = RegexURLResolver(r'^/', urlpatterns)
for test_path in test_paths:
request = factory.get(test_path.path)
try:
@ -43,7 +43,7 @@ class FormatSuffixTests(TestCase):
urlpatterns = format_suffix_patterns([
url(r'^test/$', dummy_view),
])
resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
resolver = RegexURLResolver(r'^/', urlpatterns)
test_paths = [
(URLTestPath('/test.api', (), {'format': 'api'}), True),
@ -55,7 +55,7 @@ class FormatSuffixTests(TestCase):
request = factory.get(test_path.path)
try:
callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
except urlresolvers.Resolver404:
except Resolver404:
callback, callback_args, callback_kwargs = (None, None, None)
if not expected_resolved:
assert callback is None

View File

@ -1,5 +1,5 @@
from django.core.exceptions import ObjectDoesNotExist
from django.core.urlresolvers import NoReverseMatch
from rest_framework.compat import NoReverseMatch
class MockObject(object):

View File

@ -4,7 +4,7 @@ addopts=--tb=short
[tox]
envlist =
py27-{lint,docs},
{py27,py32,py33,py34,py35}-django18,
{py27,py33,py34,py35}-django18,
{py27,py34,py35}-django19,
{py27,py34,py35}-django110,
{py27,py34,py35}-django{master}
@ -25,7 +25,6 @@ basepython =
py35: python3.5
py34: python3.4
py33: python3.3
py32: python3.2
py27: python2.7
[testenv:py27-lint]