mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-07 13:54:47 +03:00
Merge branch 'version-3-5' into rejig-schema-generation
This commit is contained in:
commit
657a7c5e00
|
@ -14,7 +14,6 @@ env:
|
|||
- TOX_ENV=py35-django18
|
||||
- TOX_ENV=py34-django18
|
||||
- TOX_ENV=py33-django18
|
||||
- TOX_ENV=py32-django18
|
||||
- TOX_ENV=py27-django18
|
||||
- TOX_ENV=py27-django110
|
||||
- TOX_ENV=py35-django110
|
||||
|
|
|
@ -457,6 +457,8 @@ There are two keyword arguments you can use to control this behavior:
|
|||
- `html_cutoff` - If set this will be the maximum number of choices that will be displayed by a HTML select drop down. Set to `None` to disable any limiting. Defaults to `1000`.
|
||||
- `html_cutoff_text` - If set this will display a textual indicator if the maximum number of items have been cutoff in an HTML select drop down. Defaults to `"More than {count} items…"`
|
||||
|
||||
You can also control these globally using the settings `HTML_SELECT_CUTOFF` and `HTML_SELECT_CUTOFF_TEXT`.
|
||||
|
||||
In cases where the cutoff is being enforced you may want to instead use a plain input field in the HTML form. You can do so using the `style` keyword argument. For example:
|
||||
|
||||
assigned_to = serializers.SlugRelatedField(
|
||||
|
|
|
@ -23,7 +23,7 @@ There's no requirement for you to use them, but if you do then the self-describi
|
|||
|
||||
**Signature:** `reverse(viewname, *args, **kwargs)`
|
||||
|
||||
Has the same behavior as [`django.core.urlresolvers.reverse`][reverse], except that it returns a fully qualified URL, using the request to determine the host and port.
|
||||
Has the same behavior as [`django.urls.reverse`][reverse], except that it returns a fully qualified URL, using the request to determine the host and port.
|
||||
|
||||
You should **include the request as a keyword argument** to the function, for example:
|
||||
|
||||
|
@ -44,7 +44,7 @@ You should **include the request as a keyword argument** to the function, for ex
|
|||
|
||||
**Signature:** `reverse_lazy(viewname, *args, **kwargs)`
|
||||
|
||||
Has the same behavior as [`django.core.urlresolvers.reverse_lazy`][reverse-lazy], except that it returns a fully qualified URL, using the request to determine the host and port.
|
||||
Has the same behavior as [`django.urls.reverse_lazy`][reverse-lazy], except that it returns a fully qualified URL, using the request to determine the host and port.
|
||||
|
||||
As with the `reverse` function, you should **include the request as a keyword argument** to the function, for example:
|
||||
|
||||
|
|
|
@ -382,6 +382,22 @@ This should be a function with the following signature:
|
|||
|
||||
Default: `'rest_framework.views.get_view_description'`
|
||||
|
||||
## HTML Select Field cutoffs
|
||||
|
||||
Global settings for [select field cutoffs for rendering relational fields](relations.md#select-field-cutoffs) in the browsable API.
|
||||
|
||||
#### HTML_SELECT_CUTOFF
|
||||
|
||||
Global setting for the `html_cutoff` value. Must be an integer.
|
||||
|
||||
Default: 1000
|
||||
|
||||
#### HTML_SELECT_CUTOFF_TEXT
|
||||
|
||||
A string representing a global setting for `html_cutoff_text`.
|
||||
|
||||
Default: `"More than {count} items..."`
|
||||
|
||||
---
|
||||
|
||||
## Miscellaneous settings
|
||||
|
|
|
@ -197,7 +197,7 @@ REST framework includes the following test case classes, that mirror the existin
|
|||
|
||||
You can use any of REST framework's test case classes as you would for the regular Django test case classes. The `self.client` attribute will be an `APIClient` instance.
|
||||
|
||||
from django.core.urlresolvers import reverse
|
||||
from django.urls import reverse
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APITestCase
|
||||
from myproject.apps.core.models import Account
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Optional packages which may be used with REST framework.
|
||||
markdown==2.6.4
|
||||
django-guardian==1.4.3
|
||||
django-filter==0.13.0
|
||||
coreapi==1.32.0
|
||||
django-guardian==1.4.6
|
||||
django-filter==0.14.0
|
||||
coreapi==2.0.8
|
||||
|
|
|
@ -8,7 +8,7 @@ ______ _____ _____ _____ __
|
|||
"""
|
||||
|
||||
__title__ = 'Django REST framework'
|
||||
__version__ = '3.4.7'
|
||||
__version__ = '3.5.0'
|
||||
__author__ = 'Tom Christie'
|
||||
__license__ = 'BSD 2-Clause'
|
||||
__copyright__ = 'Copyright 2011-2016 Tom Christie'
|
||||
|
|
|
@ -16,6 +16,9 @@ class AuthTokenSerializer(serializers.Serializer):
|
|||
user = authenticate(username=username, password=password)
|
||||
|
||||
if user:
|
||||
# From Django 1.10 onwards the `authenticate` call simply
|
||||
# returns `None` for is_active=False users.
|
||||
# (Assuming the default `ModelBackend` authentication backend.)
|
||||
if not user.is_active:
|
||||
msg = _('User account is disabled.')
|
||||
raise serializers.ValidationError(msg)
|
||||
|
|
|
@ -23,6 +23,16 @@ except ImportError:
|
|||
from django.utils import importlib # Will be removed in Django 1.9
|
||||
|
||||
|
||||
try:
|
||||
from django.urls import (
|
||||
NoReverseMatch, RegexURLPattern, RegexURLResolver, ResolverMatch, Resolver404, get_script_prefix, reverse, reverse_lazy, resolve
|
||||
)
|
||||
except ImportError:
|
||||
from django.core.urlresolvers import ( # Will be removed in Django 2.0
|
||||
NoReverseMatch, RegexURLPattern, RegexURLResolver, ResolverMatch, Resolver404, get_script_prefix, reverse, reverse_lazy, resolve
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
import urlparse # Python 2.x
|
||||
except ImportError:
|
||||
|
@ -128,6 +138,12 @@ def is_authenticated(user):
|
|||
return user.is_authenticated
|
||||
|
||||
|
||||
def is_anonymous(user):
|
||||
if django.VERSION < (1, 10):
|
||||
return user.is_anonymous()
|
||||
return user.is_anonymous
|
||||
|
||||
|
||||
def get_related_model(field):
|
||||
if django.VERSION < (1, 9):
|
||||
return _resolve_model(field.rel.to)
|
||||
|
@ -178,6 +194,13 @@ except (ImportError, SyntaxError):
|
|||
uritemplate = None
|
||||
|
||||
|
||||
# requests is optional
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None
|
||||
|
||||
|
||||
# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS
|
||||
# Fixes (#1712). We keep the try/except for the test suite.
|
||||
guardian = None
|
||||
|
@ -200,8 +223,13 @@ try:
|
|||
|
||||
if markdown.version <= '2.2':
|
||||
HEADERID_EXT_PATH = 'headerid'
|
||||
else:
|
||||
LEVEL_PARAM = 'level'
|
||||
elif markdown.version < '2.6':
|
||||
HEADERID_EXT_PATH = 'markdown.extensions.headerid'
|
||||
LEVEL_PARAM = 'level'
|
||||
else:
|
||||
HEADERID_EXT_PATH = 'markdown.extensions.toc'
|
||||
LEVEL_PARAM = 'baselevel'
|
||||
|
||||
def apply_markdown(text):
|
||||
"""
|
||||
|
@ -211,7 +239,7 @@ try:
|
|||
extensions = [HEADERID_EXT_PATH]
|
||||
extension_configs = {
|
||||
HEADERID_EXT_PATH: {
|
||||
'level': '2'
|
||||
LEVEL_PARAM: '2'
|
||||
}
|
||||
}
|
||||
md = markdown.Markdown(
|
||||
|
@ -277,3 +305,11 @@ def template_render(template, context=None, request=None):
|
|||
# backends template, e.g. django.template.backends.django.Template
|
||||
else:
|
||||
return template.render(context, request=request)
|
||||
|
||||
|
||||
def set_many(instance, field, value):
|
||||
if django.VERSION < (1, 10):
|
||||
setattr(instance, field, value)
|
||||
else:
|
||||
field = getattr(instance, field)
|
||||
field.set(value)
|
||||
|
|
|
@ -49,18 +49,32 @@ class empty:
|
|||
pass
|
||||
|
||||
|
||||
def is_simple_callable(obj):
|
||||
if six.PY3:
|
||||
def is_simple_callable(obj):
|
||||
"""
|
||||
True if the object is a callable that takes no arguments.
|
||||
"""
|
||||
if not callable(obj):
|
||||
return False
|
||||
|
||||
sig = inspect.signature(obj)
|
||||
params = sig.parameters.values()
|
||||
return all(param.default != param.empty for param in params)
|
||||
|
||||
else:
|
||||
def is_simple_callable(obj):
|
||||
function = inspect.isfunction(obj)
|
||||
method = inspect.ismethod(obj)
|
||||
|
||||
if not (function or method):
|
||||
return False
|
||||
|
||||
if method:
|
||||
is_unbound = obj.im_self is None
|
||||
|
||||
args, _, _, defaults = inspect.getargspec(obj)
|
||||
len_args = len(args) if function else len(args) - 1
|
||||
|
||||
len_args = len(args) if function or is_unbound else len(args) - 1
|
||||
len_defaults = len(defaults) if defaults else 0
|
||||
return len_args <= len_defaults
|
||||
|
||||
|
|
|
@ -4,9 +4,6 @@ from __future__ import unicode_literals
|
|||
from collections import OrderedDict
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
|
||||
from django.core.urlresolvers import (
|
||||
NoReverseMatch, Resolver404, get_script_prefix, resolve
|
||||
)
|
||||
from django.db.models import Manager
|
||||
from django.db.models.query import QuerySet
|
||||
from django.utils import six
|
||||
|
@ -14,10 +11,14 @@ from django.utils.encoding import python_2_unicode_compatible, smart_text
|
|||
from django.utils.six.moves.urllib import parse as urlparse
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from rest_framework.compat import (
|
||||
NoReverseMatch, Resolver404, get_script_prefix, resolve
|
||||
)
|
||||
from rest_framework.fields import (
|
||||
Field, empty, get_attribute, is_simple_callable, iter_options
|
||||
)
|
||||
from rest_framework.reverse import reverse
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import html
|
||||
|
||||
|
||||
|
@ -71,14 +72,19 @@ MANY_RELATION_KWARGS = (
|
|||
|
||||
class RelatedField(Field):
|
||||
queryset = None
|
||||
html_cutoff = 1000
|
||||
html_cutoff_text = _('More than {count} items...')
|
||||
html_cutoff = None
|
||||
html_cutoff_text = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.queryset = kwargs.pop('queryset', self.queryset)
|
||||
self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff)
|
||||
self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text)
|
||||
|
||||
self.html_cutoff = kwargs.pop(
|
||||
'html_cutoff',
|
||||
self.html_cutoff or int(api_settings.HTML_SELECT_CUTOFF)
|
||||
)
|
||||
self.html_cutoff_text = kwargs.pop(
|
||||
'html_cutoff_text',
|
||||
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
|
||||
)
|
||||
if not method_overridden('get_queryset', RelatedField, self):
|
||||
assert self.queryset is not None or kwargs.get('read_only', None), (
|
||||
'Relational field must provide a `queryset` argument, '
|
||||
|
@ -447,15 +453,20 @@ class ManyRelatedField(Field):
|
|||
'not_a_list': _('Expected a list of items but got type "{input_type}".'),
|
||||
'empty': _('This list may not be empty.')
|
||||
}
|
||||
html_cutoff = 1000
|
||||
html_cutoff_text = _('More than {count} items...')
|
||||
html_cutoff = None
|
||||
html_cutoff_text = None
|
||||
|
||||
def __init__(self, child_relation=None, *args, **kwargs):
|
||||
self.child_relation = child_relation
|
||||
self.allow_empty = kwargs.pop('allow_empty', True)
|
||||
self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff)
|
||||
self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text)
|
||||
|
||||
self.html_cutoff = kwargs.pop(
|
||||
'html_cutoff',
|
||||
self.html_cutoff or int(api_settings.HTML_SELECT_CUTOFF)
|
||||
)
|
||||
self.html_cutoff_text = kwargs.pop(
|
||||
'html_cutoff_text',
|
||||
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
|
||||
)
|
||||
assert child_relation is not None, '`child_relation` is a required argument.'
|
||||
super(ManyRelatedField, self).__init__(*args, **kwargs)
|
||||
self.child_relation.bind(field_name='', parent=self)
|
||||
|
|
|
@ -3,11 +3,11 @@ Provide urlresolver functions that return fully qualified URLs or view names
|
|||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.core.urlresolvers import reverse as django_reverse
|
||||
from django.core.urlresolvers import NoReverseMatch
|
||||
from django.utils import six
|
||||
from django.utils.functional import lazy
|
||||
|
||||
from rest_framework.compat import reverse as django_reverse
|
||||
from rest_framework.compat import NoReverseMatch
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils.urls import replace_query_param
|
||||
|
||||
|
@ -54,7 +54,7 @@ def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra
|
|||
|
||||
def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
|
||||
"""
|
||||
Same as `django.core.urlresolvers.reverse`, but optionally takes a request
|
||||
Same as `django.urls.reverse`, but optionally takes a request
|
||||
and returns a fully qualified URL, using the request to get the base URL.
|
||||
"""
|
||||
if format is not None:
|
||||
|
|
|
@ -20,9 +20,9 @@ from collections import OrderedDict, namedtuple
|
|||
|
||||
from django.conf.urls import url
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.urlresolvers import NoReverseMatch
|
||||
|
||||
from rest_framework import exceptions, renderers, views
|
||||
from rest_framework.compat import NoReverseMatch
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.reverse import reverse
|
||||
from rest_framework.schemas import SchemaGenerator
|
||||
|
@ -83,6 +83,7 @@ class BaseRouter(object):
|
|||
|
||||
|
||||
class SimpleRouter(BaseRouter):
|
||||
|
||||
routes = [
|
||||
# List route.
|
||||
Route(
|
||||
|
@ -258,6 +259,13 @@ class SimpleRouter(BaseRouter):
|
|||
trailing_slash=self.trailing_slash
|
||||
)
|
||||
|
||||
# If there is no prefix, the first part of the url is probably
|
||||
# controlled by project's urls.py and the router is in an app,
|
||||
# so a slash in the beginning will (A) cause Django to give
|
||||
# warnings and (B) generate URLS that will require using '//'.
|
||||
if not prefix and regex[:2] == '^/':
|
||||
regex = '^' + regex[2:]
|
||||
|
||||
view = viewset.as_view(mapping, **route.initkwargs)
|
||||
name = route.name.format(basename=basename)
|
||||
ret.append(url(regex, view, name=name))
|
||||
|
@ -289,42 +297,42 @@ class DefaultRouter(SimpleRouter):
|
|||
self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
|
||||
super(DefaultRouter, self).__init__(*args, **kwargs)
|
||||
|
||||
def get_schema_root_view(self, api_urls=None):
|
||||
"""
|
||||
Return a schema root view.
|
||||
"""
|
||||
schema_renderers = self.schema_renderers
|
||||
schema_generator = SchemaGenerator(
|
||||
title=self.schema_title,
|
||||
url=self.schema_url,
|
||||
patterns=api_urls
|
||||
)
|
||||
|
||||
class APISchemaView(views.APIView):
|
||||
_ignore_model_permissions = True
|
||||
renderer_classes = schema_renderers
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
schema = schema_generator.get_schema(request)
|
||||
if schema is None:
|
||||
raise exceptions.PermissionDenied()
|
||||
return Response(schema)
|
||||
|
||||
return APISchemaView.as_view()
|
||||
|
||||
def get_api_root_view(self, api_urls=None):
|
||||
"""
|
||||
Return a view to use as the API root.
|
||||
Return a basic root view.
|
||||
"""
|
||||
api_root_dict = OrderedDict()
|
||||
list_name = self.routes[0].name
|
||||
for prefix, viewset, basename in self.registry:
|
||||
api_root_dict[prefix] = list_name.format(basename=basename)
|
||||
|
||||
view_renderers = list(self.root_renderers)
|
||||
schema_media_types = []
|
||||
|
||||
if api_urls and self.schema_title:
|
||||
view_renderers += list(self.schema_renderers)
|
||||
schema_generator = SchemaGenerator(
|
||||
title=self.schema_title,
|
||||
url=self.schema_url,
|
||||
patterns=api_urls
|
||||
)
|
||||
schema_media_types = [
|
||||
renderer.media_type
|
||||
for renderer in self.schema_renderers
|
||||
]
|
||||
|
||||
class APIRoot(views.APIView):
|
||||
class APIRootView(views.APIView):
|
||||
_ignore_model_permissions = True
|
||||
renderer_classes = view_renderers
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
if request.accepted_renderer.media_type in schema_media_types:
|
||||
# Return a schema response.
|
||||
schema = schema_generator.get_schema(request)
|
||||
if schema is None:
|
||||
raise exceptions.PermissionDenied()
|
||||
return Response(schema)
|
||||
|
||||
# Return a plain {"name": "hyperlink"} response.
|
||||
ret = OrderedDict()
|
||||
namespace = request.resolver_match.namespace
|
||||
|
@ -345,7 +353,7 @@ class DefaultRouter(SimpleRouter):
|
|||
|
||||
return Response(ret)
|
||||
|
||||
return APIRoot.as_view()
|
||||
return APIRootView.as_view()
|
||||
|
||||
def get_urls(self):
|
||||
"""
|
||||
|
@ -355,6 +363,9 @@ class DefaultRouter(SimpleRouter):
|
|||
urls = super(DefaultRouter, self).get_urls()
|
||||
|
||||
if self.include_root_view:
|
||||
if self.schema_title:
|
||||
view = self.get_schema_root_view(api_urls=urls)
|
||||
else:
|
||||
view = self.get_api_root_view(api_urls=urls)
|
||||
root_url = url(r'^$', view, name=self.root_view_name)
|
||||
urls.append(root_url)
|
||||
|
|
|
@ -2,12 +2,13 @@ from importlib import import_module
|
|||
|
||||
from django.conf import settings
|
||||
from django.contrib.admindocs.views import simplify_regex
|
||||
from django.core.urlresolvers import RegexURLPattern, RegexURLResolver
|
||||
from django.utils import six
|
||||
from django.utils.encoding import force_text
|
||||
|
||||
from rest_framework import exceptions, serializers
|
||||
from rest_framework.compat import coreapi, uritemplate, urlparse
|
||||
from rest_framework.compat import (
|
||||
RegexURLPattern, RegexURLResolver, coreapi, uritemplate, urlparse
|
||||
)
|
||||
from rest_framework.request import clone_request
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
@ -329,7 +330,7 @@ class SchemaGenerator(object):
|
|||
fields += as_query_fields(filter_backend().get_fields(view))
|
||||
return fields
|
||||
|
||||
# Methods for generating the keys which are used to layout each link.
|
||||
# Methods for generating the link layout....
|
||||
|
||||
default_mapping = {
|
||||
'get': 'read',
|
||||
|
|
|
@ -23,7 +23,7 @@ from django.utils.functional import cached_property
|
|||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from rest_framework.compat import JSONField as ModelJSONField
|
||||
from rest_framework.compat import postgres_fields, unicode_to_repr
|
||||
from rest_framework.compat import postgres_fields, set_many, unicode_to_repr
|
||||
from rest_framework.utils import model_meta
|
||||
from rest_framework.utils.field_mapping import (
|
||||
ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs,
|
||||
|
@ -892,18 +892,22 @@ class ModelSerializer(Serializer):
|
|||
# Save many-to-many relationships after the instance is created.
|
||||
if many_to_many:
|
||||
for field_name, value in many_to_many.items():
|
||||
setattr(instance, field_name, value)
|
||||
set_many(instance, field_name, value)
|
||||
|
||||
return instance
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
raise_errors_on_nested_writes('update', self, validated_data)
|
||||
info = model_meta.get_field_info(instance)
|
||||
|
||||
# Simply set each attribute on the instance, and then save it.
|
||||
# Note that unlike `.create()` we don't need to treat many-to-many
|
||||
# relationships as being a special case. During updates we already
|
||||
# have an instance pk for the relationships to be associated with.
|
||||
for attr, value in validated_data.items():
|
||||
if attr in info.relations and info.relations[attr].to_many:
|
||||
set_many(instance, attr, value)
|
||||
else:
|
||||
setattr(instance, attr, value)
|
||||
instance.save()
|
||||
|
||||
|
|
|
@ -111,6 +111,10 @@ DEFAULTS = {
|
|||
'COMPACT_JSON': True,
|
||||
'COERCE_DECIMAL_TO_STRING': True,
|
||||
'UPLOADED_FILES_USE_URL': True,
|
||||
|
||||
# Browseable API
|
||||
'HTML_SELECT_CUTOFF': 1000,
|
||||
'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -3,14 +3,13 @@ from __future__ import absolute_import, unicode_literals
|
|||
import re
|
||||
|
||||
from django import template
|
||||
from django.core.urlresolvers import NoReverseMatch, reverse
|
||||
from django.template import loader
|
||||
from django.utils import six
|
||||
from django.utils.encoding import force_text, iri_to_uri
|
||||
from django.utils.html import escape, format_html, smart_urlquote
|
||||
from django.utils.safestring import SafeData, mark_safe
|
||||
|
||||
from rest_framework.compat import template_render
|
||||
from rest_framework.compat import NoReverseMatch, reverse, template_render
|
||||
from rest_framework.renderers import HTMLFormRenderer
|
||||
from rest_framework.utils.urls import replace_query_param
|
||||
|
||||
|
|
|
@ -4,7 +4,10 @@
|
|||
# to make it harder for the user to import the wrong thing without realizing.
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import io
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.handlers.wsgi import WSGIHandler
|
||||
from django.test import testcases
|
||||
from django.test.client import Client as DjangoClient
|
||||
from django.test.client import RequestFactory as DjangoRequestFactory
|
||||
|
@ -13,6 +16,7 @@ from django.utils import six
|
|||
from django.utils.encoding import force_bytes
|
||||
from django.utils.http import urlencode
|
||||
|
||||
from rest_framework.compat import coreapi, requests
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
|
@ -21,6 +25,118 @@ def force_authenticate(request, user=None, token=None):
|
|||
request._force_auth_token = token
|
||||
|
||||
|
||||
if requests is not None:
|
||||
class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
|
||||
def get_all(self, key, default):
|
||||
return self.getheaders(key)
|
||||
|
||||
class MockOriginalResponse(object):
|
||||
def __init__(self, headers):
|
||||
self.msg = HeaderDict(headers)
|
||||
self.closed = False
|
||||
|
||||
def isclosed(self):
|
||||
return self.closed
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
class DjangoTestAdapter(requests.adapters.HTTPAdapter):
|
||||
"""
|
||||
A transport adapter for `requests`, that makes requests via the
|
||||
Django WSGI app, rather than making actual HTTP requests over the network.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.app = WSGIHandler()
|
||||
self.factory = DjangoRequestFactory()
|
||||
|
||||
def get_environ(self, request):
|
||||
"""
|
||||
Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
|
||||
"""
|
||||
method = request.method
|
||||
url = request.url
|
||||
kwargs = {}
|
||||
|
||||
# Set request content, if any exists.
|
||||
if request.body is not None:
|
||||
if hasattr(request.body, 'read'):
|
||||
kwargs['data'] = request.body.read()
|
||||
else:
|
||||
kwargs['data'] = request.body
|
||||
if 'content-type' in request.headers:
|
||||
kwargs['content_type'] = request.headers['content-type']
|
||||
|
||||
# Set request headers.
|
||||
for key, value in request.headers.items():
|
||||
key = key.upper()
|
||||
if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
|
||||
continue
|
||||
kwargs['HTTP_%s' % key.replace('-', '_')] = value
|
||||
|
||||
return self.factory.generic(method, url, **kwargs).environ
|
||||
|
||||
def send(self, request, *args, **kwargs):
|
||||
"""
|
||||
Make an outgoing request to the Django WSGI application.
|
||||
"""
|
||||
raw_kwargs = {}
|
||||
|
||||
def start_response(wsgi_status, wsgi_headers):
|
||||
status, _, reason = wsgi_status.partition(' ')
|
||||
raw_kwargs['status'] = int(status)
|
||||
raw_kwargs['reason'] = reason
|
||||
raw_kwargs['headers'] = wsgi_headers
|
||||
raw_kwargs['version'] = 11
|
||||
raw_kwargs['preload_content'] = False
|
||||
raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)
|
||||
|
||||
# Make the outgoing request via WSGI.
|
||||
environ = self.get_environ(request)
|
||||
wsgi_response = self.app(environ, start_response)
|
||||
|
||||
# Build the underlying urllib3.HTTPResponse
|
||||
raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
|
||||
raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
|
||||
|
||||
# Build the requests.Response
|
||||
return self.build_response(request, raw)
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
class DjangoTestSession(requests.Session):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DjangoTestSession, self).__init__(*args, **kwargs)
|
||||
|
||||
adapter = DjangoTestAdapter()
|
||||
hostnames = list(settings.ALLOWED_HOSTS) + ['testserver']
|
||||
|
||||
for hostname in hostnames:
|
||||
if hostname == '*':
|
||||
hostname = ''
|
||||
self.mount('http://%s' % hostname, adapter)
|
||||
self.mount('https://%s' % hostname, adapter)
|
||||
|
||||
def request(self, method, url, *args, **kwargs):
|
||||
if ':' not in url:
|
||||
url = 'http://testserver/' + url.lstrip('/')
|
||||
return super(DjangoTestSession, self).request(method, url, *args, **kwargs)
|
||||
|
||||
|
||||
def get_requests_client():
|
||||
assert requests is not None, 'requests must be installed'
|
||||
return DjangoTestSession()
|
||||
|
||||
|
||||
def get_api_client():
|
||||
assert coreapi is not None, 'coreapi must be installed'
|
||||
session = get_requests_client()
|
||||
return coreapi.Client(transports=[
|
||||
coreapi.transports.HTTPTransport(session=session)
|
||||
])
|
||||
|
||||
|
||||
class APIRequestFactory(DjangoRequestFactory):
|
||||
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
|
||||
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from django.conf.urls import include, url
|
||||
from django.core.urlresolvers import RegexURLResolver
|
||||
|
||||
from rest_framework.compat import RegexURLResolver
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from django.core.urlresolvers import get_script_prefix, resolve
|
||||
from rest_framework.compat import get_script_prefix, resolve
|
||||
|
||||
|
||||
def get_breadcrumbs(url, request=None):
|
||||
|
|
|
@ -11,6 +11,7 @@ factory = APIRequestFactory()
|
|||
class BasicSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = BasicModel
|
||||
fields = '__all__'
|
||||
|
||||
|
||||
class ManyPostView(generics.GenericAPIView):
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
def pytest_configure():
|
||||
from django.conf import settings
|
||||
|
||||
MIDDLEWARE = (
|
||||
'django.middleware.common.CommonMiddleware',
|
||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||
'django.contrib.messages.middleware.MessageMiddleware',
|
||||
)
|
||||
|
||||
settings.configure(
|
||||
DEBUG_PROPAGATE_EXCEPTIONS=True,
|
||||
DATABASES={
|
||||
|
@ -21,12 +28,8 @@ def pytest_configure():
|
|||
'APP_DIRS': True,
|
||||
},
|
||||
],
|
||||
MIDDLEWARE_CLASSES=(
|
||||
'django.middleware.common.CommonMiddleware',
|
||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||
'django.contrib.messages.middleware.MessageMiddleware',
|
||||
),
|
||||
MIDDLEWARE=MIDDLEWARE,
|
||||
MIDDLEWARE_CLASSES=MIDDLEWARE,
|
||||
INSTALLED_APPS=(
|
||||
'django.contrib.auth',
|
||||
'django.contrib.contenttypes',
|
||||
|
|
452
tests/test_api_client.py
Normal file
452
tests/test_api_client.py
Normal file
|
@ -0,0 +1,452 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from django.conf.urls import url
|
||||
from django.http import HttpResponse
|
||||
from django.test import override_settings
|
||||
|
||||
from rest_framework.compat import coreapi
|
||||
from rest_framework.parsers import FileUploadParser
|
||||
from rest_framework.renderers import CoreJSONRenderer
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.test import APITestCase, get_api_client
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
def get_schema():
|
||||
return coreapi.Document(
|
||||
url='https://api.example.com/',
|
||||
title='Example API',
|
||||
content={
|
||||
'simple_link': coreapi.Link('/example/', description='example link'),
|
||||
'location': {
|
||||
'query': coreapi.Link('/example/', fields=[
|
||||
coreapi.Field(name='example', description='example field')
|
||||
]),
|
||||
'form': coreapi.Link('/example/', action='post', fields=[
|
||||
coreapi.Field(name='example'),
|
||||
]),
|
||||
'body': coreapi.Link('/example/', action='post', fields=[
|
||||
coreapi.Field(name='example', location='body')
|
||||
]),
|
||||
'path': coreapi.Link('/example/{id}', fields=[
|
||||
coreapi.Field(name='id', location='path')
|
||||
])
|
||||
},
|
||||
'encoding': {
|
||||
'multipart': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[
|
||||
coreapi.Field(name='example')
|
||||
]),
|
||||
'multipart-body': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[
|
||||
coreapi.Field(name='example', location='body')
|
||||
]),
|
||||
'urlencoded': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[
|
||||
coreapi.Field(name='example')
|
||||
]),
|
||||
'urlencoded-body': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[
|
||||
coreapi.Field(name='example', location='body')
|
||||
]),
|
||||
'raw_upload': coreapi.Link('/upload/', action='post', encoding='application/octet-stream', fields=[
|
||||
coreapi.Field(name='example', location='body')
|
||||
]),
|
||||
},
|
||||
'response': {
|
||||
'download': coreapi.Link('/download/'),
|
||||
'text': coreapi.Link('/text/')
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _iterlists(querydict):
|
||||
if hasattr(querydict, 'iterlists'):
|
||||
return querydict.iterlists()
|
||||
return querydict.lists()
|
||||
|
||||
|
||||
def _get_query_params(request):
|
||||
# Return query params in a plain dict, using a list value if more
|
||||
# than one item is present for a given key.
|
||||
return {
|
||||
key: (value[0] if len(value) == 1 else value)
|
||||
for key, value in
|
||||
_iterlists(request.query_params)
|
||||
}
|
||||
|
||||
|
||||
def _get_data(request):
|
||||
if not isinstance(request.data, dict):
|
||||
return request.data
|
||||
# Coerce multidict into regular dict, and remove files to
|
||||
# make assertions simpler.
|
||||
if hasattr(request.data, 'iterlists') or hasattr(request.data, 'lists'):
|
||||
# Use a list value if a QueryDict contains multiple items for a key.
|
||||
return {
|
||||
key: value[0] if len(value) == 1 else value
|
||||
for key, value in _iterlists(request.data)
|
||||
if key not in request.FILES
|
||||
}
|
||||
return {
|
||||
key: value
|
||||
for key, value in request.data.items()
|
||||
if key not in request.FILES
|
||||
}
|
||||
|
||||
|
||||
def _get_files(request):
|
||||
if not request.FILES:
|
||||
return {}
|
||||
return {
|
||||
key: {'name': value.name, 'content': value.read()}
|
||||
for key, value in request.FILES.items()
|
||||
}
|
||||
|
||||
|
||||
class SchemaView(APIView):
|
||||
renderer_classes = [CoreJSONRenderer]
|
||||
|
||||
def get(self, request):
|
||||
schema = get_schema()
|
||||
return Response(schema)
|
||||
|
||||
|
||||
class ListView(APIView):
|
||||
def get(self, request):
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'query_params': _get_query_params(request)
|
||||
})
|
||||
|
||||
def post(self, request):
|
||||
if request.content_type:
|
||||
content_type = request.content_type.split(';')[0]
|
||||
else:
|
||||
content_type = None
|
||||
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'query_params': _get_query_params(request),
|
||||
'data': _get_data(request),
|
||||
'files': _get_files(request),
|
||||
'content_type': content_type
|
||||
})
|
||||
|
||||
|
||||
class DetailView(APIView):
|
||||
def get(self, request, id):
|
||||
return Response({
|
||||
'id': id,
|
||||
'method': request.method,
|
||||
'query_params': _get_query_params(request)
|
||||
})
|
||||
|
||||
|
||||
class UploadView(APIView):
|
||||
parser_classes = [FileUploadParser]
|
||||
|
||||
def post(self, request):
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'files': _get_files(request),
|
||||
'content_type': request.content_type
|
||||
})
|
||||
|
||||
|
||||
class DownloadView(APIView):
|
||||
def get(self, request):
|
||||
return HttpResponse('some file content', content_type='image/png')
|
||||
|
||||
|
||||
class TextView(APIView):
|
||||
def get(self, request):
|
||||
return HttpResponse('123', content_type='text/plain')
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
url(r'^$', SchemaView.as_view()),
|
||||
url(r'^example/$', ListView.as_view()),
|
||||
url(r'^example/(?P<id>[0-9]+)/$', DetailView.as_view()),
|
||||
url(r'^upload/$', UploadView.as_view()),
|
||||
url(r'^download/$', DownloadView.as_view()),
|
||||
url(r'^text/$', TextView.as_view()),
|
||||
]
|
||||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi not installed')
|
||||
@override_settings(ROOT_URLCONF='tests.test_api_client')
|
||||
class APIClientTests(APITestCase):
|
||||
def test_api_client(self):
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
assert schema.title == 'Example API'
|
||||
assert schema.url == 'https://api.example.com/'
|
||||
assert schema['simple_link'].description == 'example link'
|
||||
assert schema['location']['query'].fields[0].description == 'example field'
|
||||
data = client.action(schema, ['simple_link'])
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_query_params(self):
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['location', 'query'], params={'example': 123})
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {'example': '123'}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_query_params_with_multiple_values(self):
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['location', 'query'], params={'example': [1, 2, 3]})
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {'example': ['1', '2', '3']}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_form_params(self):
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['location', 'form'], params={'example': 123})
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'application/json',
|
||||
'query_params': {},
|
||||
'data': {'example': 123},
|
||||
'files': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_body_params(self):
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['location', 'body'], params={'example': 123})
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'application/json',
|
||||
'query_params': {},
|
||||
'data': 123,
|
||||
'files': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_path_params(self):
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['location', 'path'], params={'id': 123})
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {},
|
||||
'id': '123'
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_multipart_encoding(self):
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
temp = tempfile.NamedTemporaryFile()
|
||||
temp.write(b'example file content')
|
||||
temp.flush()
|
||||
|
||||
with open(temp.name, 'rb') as upload:
|
||||
name = os.path.basename(upload.name)
|
||||
data = client.action(schema, ['encoding', 'multipart'], params={'example': upload})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'multipart/form-data',
|
||||
'query_params': {},
|
||||
'data': {},
|
||||
'files': {'example': {'name': name, 'content': 'example file content'}}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_multipart_encoding_no_file(self):
|
||||
# When no file is included, multipart encoding should still be used.
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
data = client.action(schema, ['encoding', 'multipart'], params={'example': 123})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'multipart/form-data',
|
||||
'query_params': {},
|
||||
'data': {'example': '123'},
|
||||
'files': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_multipart_encoding_multiple_values(self):
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
data = client.action(schema, ['encoding', 'multipart'], params={'example': [1, 2, 3]})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'multipart/form-data',
|
||||
'query_params': {},
|
||||
'data': {'example': ['1', '2', '3']},
|
||||
'files': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_multipart_encoding_string_file_content(self):
|
||||
# Test for `coreapi.utils.File` support.
|
||||
from coreapi.utils import File
|
||||
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
example = File(name='example.txt', content='123')
|
||||
data = client.action(schema, ['encoding', 'multipart'], params={'example': example})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'multipart/form-data',
|
||||
'query_params': {},
|
||||
'data': {},
|
||||
'files': {'example': {'name': 'example.txt', 'content': '123'}}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_multipart_encoding_in_body(self):
|
||||
from coreapi.utils import File
|
||||
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
example = {'foo': File(name='example.txt', content='123'), 'bar': 'abc'}
|
||||
data = client.action(schema, ['encoding', 'multipart-body'], params={'example': example})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'multipart/form-data',
|
||||
'query_params': {},
|
||||
'data': {'bar': 'abc'},
|
||||
'files': {'foo': {'name': 'example.txt', 'content': '123'}}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
# URLencoded
|
||||
|
||||
def test_urlencoded_encoding(self):
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['encoding', 'urlencoded'], params={'example': 123})
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'application/x-www-form-urlencoded',
|
||||
'query_params': {},
|
||||
'data': {'example': '123'},
|
||||
'files': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_urlencoded_encoding_multiple_values(self):
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['encoding', 'urlencoded'], params={'example': [1, 2, 3]})
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'application/x-www-form-urlencoded',
|
||||
'query_params': {},
|
||||
'data': {'example': ['1', '2', '3']},
|
||||
'files': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_urlencoded_encoding_in_body(self):
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['encoding', 'urlencoded-body'], params={'example': {'foo': 123, 'bar': True}})
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'application/x-www-form-urlencoded',
|
||||
'query_params': {},
|
||||
'data': {'foo': '123', 'bar': 'true'},
|
||||
'files': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
# Raw uploads
|
||||
|
||||
def test_raw_upload(self):
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
temp = tempfile.NamedTemporaryFile()
|
||||
temp.write(b'example file content')
|
||||
temp.flush()
|
||||
|
||||
with open(temp.name, 'rb') as upload:
|
||||
name = os.path.basename(upload.name)
|
||||
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': upload})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'files': {'file': {'name': name, 'content': 'example file content'}},
|
||||
'content_type': 'application/octet-stream'
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_raw_upload_string_file_content(self):
|
||||
from coreapi.utils import File
|
||||
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
example = File('example.txt', '123')
|
||||
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'files': {'file': {'name': 'example.txt', 'content': '123'}},
|
||||
'content_type': 'text/plain'
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_raw_upload_explicit_content_type(self):
|
||||
from coreapi.utils import File
|
||||
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
example = File('example.txt', '123', 'text/html')
|
||||
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'files': {'file': {'name': 'example.txt', 'content': '123'}},
|
||||
'content_type': 'text/html'
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
# Responses
|
||||
|
||||
def test_text_response(self):
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
data = client.action(schema, ['response', 'text'])
|
||||
|
||||
expected = '123'
|
||||
assert data == expected
|
||||
|
||||
def test_download_response(self):
|
||||
client = get_api_client()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
data = client.action(schema, ['response', 'download'])
|
||||
assert data.basename == 'download.png'
|
||||
assert data.read() == b'some file content'
|
|
@ -5,7 +5,7 @@ import unittest
|
|||
from django.conf.urls import url
|
||||
from django.db import connection, connections, transaction
|
||||
from django.http import Http404
|
||||
from django.test import TestCase, TransactionTestCase
|
||||
from django.test import TestCase, TransactionTestCase, override_settings
|
||||
from django.utils.decorators import method_decorator
|
||||
|
||||
from rest_framework import status
|
||||
|
@ -36,6 +36,20 @@ class APIExceptionView(APIView):
|
|||
raise APIException
|
||||
|
||||
|
||||
class NonAtomicAPIExceptionView(APIView):
|
||||
@method_decorator(transaction.non_atomic_requests)
|
||||
def dispatch(self, *args, **kwargs):
|
||||
return super(NonAtomicAPIExceptionView, self).dispatch(*args, **kwargs)
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
BasicModel.objects.all()
|
||||
raise Http404
|
||||
|
||||
urlpatterns = (
|
||||
url(r'^$', NonAtomicAPIExceptionView.as_view()),
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipUnless(
|
||||
connection.features.uses_savepoints,
|
||||
"'atomic' requires transactions and savepoints."
|
||||
|
@ -124,22 +138,8 @@ class DBTransactionAPIExceptionTests(TestCase):
|
|||
connection.features.uses_savepoints,
|
||||
"'atomic' requires transactions and savepoints."
|
||||
)
|
||||
@override_settings(ROOT_URLCONF='tests.test_atomic_requests')
|
||||
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
|
||||
@property
|
||||
def urls(self):
|
||||
class NonAtomicAPIExceptionView(APIView):
|
||||
@method_decorator(transaction.non_atomic_requests)
|
||||
def dispatch(self, *args, **kwargs):
|
||||
return super(NonAtomicAPIExceptionView, self).dispatch(*args, **kwargs)
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
BasicModel.objects.all()
|
||||
raise Http404
|
||||
|
||||
return (
|
||||
url(r'^$', NonAtomicAPIExceptionView.as_view()),
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
connections.databases['default']['ATOMIC_REQUESTS'] = True
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from rest_framework.authentication import (
|
|||
)
|
||||
from rest_framework.authtoken.models import Token
|
||||
from rest_framework.authtoken.views import obtain_auth_token
|
||||
from rest_framework.compat import is_authenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.test import APIClient, APIRequestFactory
|
||||
from rest_framework.views import APIView
|
||||
|
@ -408,7 +409,7 @@ class FailingAuthAccessedInRenderer(TestCase):
|
|||
|
||||
def render(self, data, media_type=None, renderer_context=None):
|
||||
request = renderer_context['request']
|
||||
if request.user.is_authenticated():
|
||||
if is_authenticated(request.user):
|
||||
return b'authenticated'
|
||||
return b'not authenticated'
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import datetime
|
||||
import os
|
||||
import re
|
||||
import unittest
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
|
@ -11,6 +12,67 @@ from django.utils import six, timezone
|
|||
|
||||
import rest_framework
|
||||
from rest_framework import serializers
|
||||
from rest_framework.fields import is_simple_callable
|
||||
|
||||
try:
|
||||
import typings
|
||||
except ImportError:
|
||||
typings = False
|
||||
|
||||
|
||||
# Tests for helper functions.
|
||||
# ---------------------------
|
||||
|
||||
class TestIsSimpleCallable:
|
||||
|
||||
def test_method(self):
|
||||
class Foo:
|
||||
@classmethod
|
||||
def classmethod(cls):
|
||||
pass
|
||||
|
||||
def valid(self):
|
||||
pass
|
||||
|
||||
def valid_kwargs(self, param='value'):
|
||||
pass
|
||||
|
||||
def invalid(self, param):
|
||||
pass
|
||||
|
||||
assert is_simple_callable(Foo.classmethod)
|
||||
|
||||
# unbound methods
|
||||
assert not is_simple_callable(Foo.valid)
|
||||
assert not is_simple_callable(Foo.valid_kwargs)
|
||||
assert not is_simple_callable(Foo.invalid)
|
||||
|
||||
# bound methods
|
||||
assert is_simple_callable(Foo().valid)
|
||||
assert is_simple_callable(Foo().valid_kwargs)
|
||||
assert not is_simple_callable(Foo().invalid)
|
||||
|
||||
def test_function(self):
|
||||
def simple():
|
||||
pass
|
||||
|
||||
def valid(param='value', param2='value'):
|
||||
pass
|
||||
|
||||
def invalid(param, param2='value'):
|
||||
pass
|
||||
|
||||
assert is_simple_callable(simple)
|
||||
assert is_simple_callable(valid)
|
||||
assert not is_simple_callable(invalid)
|
||||
|
||||
@unittest.skipUnless(typings, 'requires python 3.5')
|
||||
def test_type_annotation(self):
|
||||
# The annotation will otherwise raise a syntax error in python < 3.5
|
||||
exec("def valid(param: str='value'): pass", locals())
|
||||
valid = locals()['valid']
|
||||
|
||||
assert is_simple_callable(valid)
|
||||
|
||||
|
||||
# Tests for field keyword arguments and core functionality.
|
||||
|
|
|
@ -6,7 +6,6 @@ from decimal import Decimal
|
|||
|
||||
from django.conf.urls import url
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.urlresolvers import reverse
|
||||
from django.db import models
|
||||
from django.test import TestCase
|
||||
from django.test.utils import override_settings
|
||||
|
@ -14,7 +13,7 @@ from django.utils.dateparse import parse_date
|
|||
from django.utils.six.moves import reload_module
|
||||
|
||||
from rest_framework import filters, generics, serializers, status
|
||||
from rest_framework.compat import django_filters
|
||||
from rest_framework.compat import django_filters, reverse
|
||||
from rest_framework.test import APIRequestFactory
|
||||
|
||||
from .models import BaseFilterableItem, BasicModel, FilterableItem
|
||||
|
@ -77,6 +76,7 @@ if django_filters:
|
|||
|
||||
class Meta:
|
||||
model = BaseFilterableItem
|
||||
fields = '__all__'
|
||||
|
||||
class BaseFilterableItemFilterRootView(generics.ListCreateAPIView):
|
||||
queryset = FilterableItem.objects.all()
|
||||
|
@ -456,7 +456,7 @@ class AttributeModel(models.Model):
|
|||
|
||||
class SearchFilterModelFk(models.Model):
|
||||
title = models.CharField(max_length=20)
|
||||
attribute = models.ForeignKey(AttributeModel)
|
||||
attribute = models.ForeignKey(AttributeModel, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class SearchFilterFkSerializer(serializers.ModelSerializer):
|
||||
|
|
|
@ -20,7 +20,7 @@ from django.test import TestCase
|
|||
from django.utils import six
|
||||
|
||||
from rest_framework import serializers
|
||||
from rest_framework.compat import unicode_repr
|
||||
from rest_framework.compat import set_many, unicode_repr
|
||||
|
||||
|
||||
def dedent(blocktext):
|
||||
|
@ -651,7 +651,7 @@ class TestIntegration(TestCase):
|
|||
foreign_key=self.foreign_key_target,
|
||||
one_to_one=self.one_to_one_target,
|
||||
)
|
||||
self.instance.many_to_many = self.many_to_many_targets
|
||||
set_many(self.instance, 'many_to_many', self.many_to_many_targets)
|
||||
self.instance.save()
|
||||
|
||||
def test_pk_retrival(self):
|
||||
|
@ -962,7 +962,7 @@ class OneToOneTargetTestModel(models.Model):
|
|||
|
||||
|
||||
class OneToOneSourceTestModel(models.Model):
|
||||
target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True)
|
||||
target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class TestModelFieldValues(TestCase):
|
||||
|
@ -990,6 +990,7 @@ class TestUniquenessOverride(TestCase):
|
|||
class TestSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = TestModel
|
||||
fields = '__all__'
|
||||
extra_kwargs = {'field_1': {'required': False}}
|
||||
|
||||
fields = TestSerializer().fields
|
||||
|
|
|
@ -4,7 +4,6 @@ import base64
|
|||
import unittest
|
||||
|
||||
from django.contrib.auth.models import Group, Permission, User
|
||||
from django.core.urlresolvers import ResolverMatch
|
||||
from django.db import models
|
||||
from django.test import TestCase
|
||||
|
||||
|
@ -12,7 +11,7 @@ from rest_framework import (
|
|||
HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers,
|
||||
status
|
||||
)
|
||||
from rest_framework.compat import guardian
|
||||
from rest_framework.compat import ResolverMatch, guardian, set_many
|
||||
from rest_framework.filters import DjangoObjectPermissionsFilter
|
||||
from rest_framework.routers import DefaultRouter
|
||||
from rest_framework.test import APIRequestFactory
|
||||
|
@ -74,15 +73,15 @@ class ModelPermissionsIntegrationTests(TestCase):
|
|||
def setUp(self):
|
||||
User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
|
||||
user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
|
||||
user.user_permissions = [
|
||||
set_many(user, 'user_permissions', [
|
||||
Permission.objects.get(codename='add_basicmodel'),
|
||||
Permission.objects.get(codename='change_basicmodel'),
|
||||
Permission.objects.get(codename='delete_basicmodel')
|
||||
]
|
||||
])
|
||||
user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
|
||||
user.user_permissions = [
|
||||
set_many(user, 'user_permissions', [
|
||||
Permission.objects.get(codename='change_basicmodel'),
|
||||
]
|
||||
])
|
||||
|
||||
self.permitted_credentials = basic_auth_header('permitted', 'password')
|
||||
self.disallowed_credentials = basic_auth_header('disallowed', 'password')
|
||||
|
|
|
@ -13,6 +13,7 @@ from django.utils import six
|
|||
|
||||
from rest_framework import status
|
||||
from rest_framework.authentication import SessionAuthentication
|
||||
from rest_framework.compat import is_anonymous
|
||||
from rest_framework.parsers import BaseParser, FormParser, MultiPartParser
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
@ -169,9 +170,9 @@ class TestUserSetter(TestCase):
|
|||
|
||||
def test_user_can_logout(self):
|
||||
self.request.user = self.user
|
||||
self.assertFalse(self.request.user.is_anonymous())
|
||||
self.assertFalse(is_anonymous(self.request.user))
|
||||
logout(self.request)
|
||||
self.assertTrue(self.request.user.is_anonymous())
|
||||
self.assertTrue(is_anonymous(self.request.user))
|
||||
|
||||
def test_logged_in_user_is_set_on_wrapped_request(self):
|
||||
login(self.request, self.user)
|
||||
|
|
247
tests/test_requests_client.py
Normal file
247
tests/test_requests_client.py
Normal file
|
@ -0,0 +1,247 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import unittest
|
||||
|
||||
from django.conf.urls import url
|
||||
from django.contrib.auth import authenticate, login
|
||||
from django.contrib.auth.models import User
|
||||
from django.shortcuts import redirect
|
||||
from django.test import override_settings
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie
|
||||
|
||||
from rest_framework.compat import is_authenticated, requests
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.test import APITestCase, get_requests_client
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
class Root(APIView):
|
||||
def get(self, request):
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'query_params': request.query_params,
|
||||
})
|
||||
|
||||
def post(self, request):
|
||||
files = {
|
||||
key: (value.name, value.read())
|
||||
for key, value in request.FILES.items()
|
||||
}
|
||||
post = request.POST
|
||||
json = None
|
||||
if request.META.get('CONTENT_TYPE') == 'application/json':
|
||||
json = request.data
|
||||
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'query_params': request.query_params,
|
||||
'POST': post,
|
||||
'FILES': files,
|
||||
'JSON': json
|
||||
})
|
||||
|
||||
|
||||
class HeadersView(APIView):
|
||||
def get(self, request):
|
||||
headers = {
|
||||
key[5:].replace('_', '-'): value
|
||||
for key, value in request.META.items()
|
||||
if key.startswith('HTTP_')
|
||||
}
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'headers': headers
|
||||
})
|
||||
|
||||
|
||||
class SessionView(APIView):
|
||||
def get(self, request):
|
||||
return Response({
|
||||
key: value for key, value in request.session.items()
|
||||
})
|
||||
|
||||
def post(self, request):
|
||||
for key, value in request.data.items():
|
||||
request.session[key] = value
|
||||
return Response({
|
||||
key: value for key, value in request.session.items()
|
||||
})
|
||||
|
||||
|
||||
class AuthView(APIView):
|
||||
@method_decorator(ensure_csrf_cookie)
|
||||
def get(self, request):
|
||||
if is_authenticated(request.user):
|
||||
username = request.user.username
|
||||
else:
|
||||
username = None
|
||||
return Response({
|
||||
'username': username
|
||||
})
|
||||
|
||||
@method_decorator(csrf_protect)
|
||||
def post(self, request):
|
||||
username = request.data['username']
|
||||
password = request.data['password']
|
||||
user = authenticate(username=username, password=password)
|
||||
if user is None:
|
||||
return Response({'error': 'incorrect credentials'})
|
||||
login(request, user)
|
||||
return redirect('/auth/')
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
url(r'^$', Root.as_view()),
|
||||
url(r'^headers/$', HeadersView.as_view()),
|
||||
url(r'^session/$', SessionView.as_view()),
|
||||
url(r'^auth/$', AuthView.as_view()),
|
||||
]
|
||||
|
||||
|
||||
@unittest.skipUnless(requests, 'requests not installed')
|
||||
@override_settings(ROOT_URLCONF='tests.test_requests_client')
|
||||
class RequestsClientTests(APITestCase):
|
||||
def test_get_request(self):
|
||||
client = get_requests_client()
|
||||
response = client.get('/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {}
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_get_request_query_params_in_url(self):
|
||||
client = get_requests_client()
|
||||
response = client.get('/?key=value')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {'key': 'value'}
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_get_request_query_params_by_kwarg(self):
|
||||
client = get_requests_client()
|
||||
response = client.get('/', params={'key': 'value'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {'key': 'value'}
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_get_with_headers(self):
|
||||
client = get_requests_client()
|
||||
response = client.get('/headers/', headers={'User-Agent': 'example'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
headers = response.json()['headers']
|
||||
assert headers['USER-AGENT'] == 'example'
|
||||
|
||||
def test_post_form_request(self):
|
||||
client = get_requests_client()
|
||||
response = client.post('/', data={'key': 'value'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'query_params': {},
|
||||
'POST': {'key': 'value'},
|
||||
'FILES': {},
|
||||
'JSON': None
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_post_json_request(self):
|
||||
client = get_requests_client()
|
||||
response = client.post('/', json={'key': 'value'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'query_params': {},
|
||||
'POST': {},
|
||||
'FILES': {},
|
||||
'JSON': {'key': 'value'}
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_post_multipart_request(self):
|
||||
client = get_requests_client()
|
||||
files = {
|
||||
'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n')
|
||||
}
|
||||
response = client.post('/', files=files)
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'query_params': {},
|
||||
'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']},
|
||||
'POST': {},
|
||||
'JSON': None
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_session(self):
|
||||
client = get_requests_client()
|
||||
response = client.get('/session/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {}
|
||||
assert response.json() == expected
|
||||
|
||||
response = client.post('/session/', json={'example': 'abc'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {'example': 'abc'}
|
||||
assert response.json() == expected
|
||||
|
||||
response = client.get('/session/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {'example': 'abc'}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_auth(self):
|
||||
# Confirm session is not authenticated
|
||||
client = get_requests_client()
|
||||
response = client.get('/auth/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'username': None
|
||||
}
|
||||
assert response.json() == expected
|
||||
assert 'csrftoken' in response.cookies
|
||||
csrftoken = response.cookies['csrftoken']
|
||||
|
||||
user = User.objects.create(username='tom')
|
||||
user.set_password('password')
|
||||
user.save()
|
||||
|
||||
# Perform a login
|
||||
response = client.post('/auth/', json={
|
||||
'username': 'tom',
|
||||
'password': 'password'
|
||||
}, headers={'X-CSRFToken': csrftoken})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'username': 'tom'
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
# Confirm session is authenticated
|
||||
response = client.get('/auth/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'username': 'tom'
|
||||
}
|
||||
assert response.json() == expected
|
|
@ -1,9 +1,9 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from django.conf.urls import url
|
||||
from django.core.urlresolvers import NoReverseMatch
|
||||
from django.test import TestCase, override_settings
|
||||
|
||||
from rest_framework.compat import NoReverseMatch
|
||||
from rest_framework.reverse import reverse
|
||||
from rest_framework.test import APIRequestFactory
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import json
|
||||
from collections import namedtuple
|
||||
|
||||
from django.conf.urls import include, url
|
||||
|
@ -47,6 +48,21 @@ class MockViewSet(viewsets.ModelViewSet):
|
|||
serializer_class = None
|
||||
|
||||
|
||||
class EmptyPrefixSerializer(serializers.HyperlinkedModelSerializer):
|
||||
class Meta:
|
||||
model = RouterTestModel
|
||||
fields = ('uuid', 'text')
|
||||
|
||||
|
||||
class EmptyPrefixViewSet(viewsets.ModelViewSet):
|
||||
queryset = [RouterTestModel(id=1, uuid='111', text='First'), RouterTestModel(id=2, uuid='222', text='Second')]
|
||||
serializer_class = EmptyPrefixSerializer
|
||||
|
||||
def get_object(self, *args, **kwargs):
|
||||
index = int(self.kwargs['pk']) - 1
|
||||
return self.queryset[index]
|
||||
|
||||
|
||||
notes_router = SimpleRouter()
|
||||
notes_router.register(r'notes', NoteViewSet)
|
||||
|
||||
|
@ -56,11 +72,19 @@ kwarged_notes_router.register(r'notes', KWargedNoteViewSet)
|
|||
namespaced_router = DefaultRouter()
|
||||
namespaced_router.register(r'example', MockViewSet, base_name='example')
|
||||
|
||||
empty_prefix_router = SimpleRouter()
|
||||
empty_prefix_router.register(r'', EmptyPrefixViewSet, base_name='empty_prefix')
|
||||
empty_prefix_urls = [
|
||||
url(r'^', include(empty_prefix_router.urls)),
|
||||
]
|
||||
|
||||
urlpatterns = [
|
||||
url(r'^non-namespaced/', include(namespaced_router.urls)),
|
||||
url(r'^namespaced/', include(namespaced_router.urls, namespace='example')),
|
||||
url(r'^example/', include(notes_router.urls)),
|
||||
url(r'^example2/', include(kwarged_notes_router.urls)),
|
||||
|
||||
url(r'^empty-prefix/', include(empty_prefix_urls)),
|
||||
]
|
||||
|
||||
|
||||
|
@ -384,3 +408,28 @@ class TestDynamicListAndDetailRouter(TestCase):
|
|||
|
||||
def test_inherited_list_and_detail_route_decorators(self):
|
||||
self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet)
|
||||
|
||||
|
||||
@override_settings(ROOT_URLCONF='tests.test_routers')
|
||||
class TestEmptyPrefix(TestCase):
|
||||
def test_empty_prefix_list(self):
|
||||
response = self.client.get('/empty-prefix/')
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertEqual(
|
||||
json.loads(response.content.decode('utf-8')),
|
||||
[
|
||||
{'uuid': '111', 'text': 'First'},
|
||||
{'uuid': '222', 'text': 'Second'}
|
||||
]
|
||||
)
|
||||
|
||||
def test_empty_prefix_detail(self):
|
||||
response = self.client.get('/empty-prefix/1/')
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertEqual(
|
||||
json.loads(response.content.decode('utf-8')),
|
||||
{
|
||||
'uuid': '111',
|
||||
'text': 'First'
|
||||
}
|
||||
)
|
||||
|
|
|
@ -215,7 +215,7 @@ class TestSchemaGenerator(TestCase):
|
|||
}
|
||||
}
|
||||
)
|
||||
self.assertEquals(schema, expected)
|
||||
self.assertEqual(schema, expected)
|
||||
|
||||
|
||||
class SnippetListView(APIView):
|
||||
|
|
|
@ -3,9 +3,9 @@ from __future__ import unicode_literals
|
|||
from collections import namedtuple
|
||||
|
||||
from django.conf.urls import include, url
|
||||
from django.core import urlresolvers
|
||||
from django.test import TestCase
|
||||
|
||||
from rest_framework.compat import RegexURLResolver, Resolver404
|
||||
from rest_framework.test import APIRequestFactory
|
||||
from rest_framework.urlpatterns import format_suffix_patterns
|
||||
|
||||
|
@ -28,7 +28,7 @@ class FormatSuffixTests(TestCase):
|
|||
urlpatterns = format_suffix_patterns(urlpatterns)
|
||||
except Exception:
|
||||
self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
|
||||
resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
|
||||
resolver = RegexURLResolver(r'^/', urlpatterns)
|
||||
for test_path in test_paths:
|
||||
request = factory.get(test_path.path)
|
||||
try:
|
||||
|
@ -43,7 +43,7 @@ class FormatSuffixTests(TestCase):
|
|||
urlpatterns = format_suffix_patterns([
|
||||
url(r'^test/$', dummy_view),
|
||||
])
|
||||
resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
|
||||
resolver = RegexURLResolver(r'^/', urlpatterns)
|
||||
|
||||
test_paths = [
|
||||
(URLTestPath('/test.api', (), {'format': 'api'}), True),
|
||||
|
@ -55,7 +55,7 @@ class FormatSuffixTests(TestCase):
|
|||
request = factory.get(test_path.path)
|
||||
try:
|
||||
callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
|
||||
except urlresolvers.Resolver404:
|
||||
except Resolver404:
|
||||
callback, callback_args, callback_kwargs = (None, None, None)
|
||||
if not expected_resolved:
|
||||
assert callback is None
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.core.urlresolvers import NoReverseMatch
|
||||
from rest_framework.compat import NoReverseMatch
|
||||
|
||||
|
||||
class MockObject(object):
|
||||
|
|
3
tox.ini
3
tox.ini
|
@ -4,7 +4,7 @@ addopts=--tb=short
|
|||
[tox]
|
||||
envlist =
|
||||
py27-{lint,docs},
|
||||
{py27,py32,py33,py34,py35}-django18,
|
||||
{py27,py33,py34,py35}-django18,
|
||||
{py27,py34,py35}-django19,
|
||||
{py27,py34,py35}-django110,
|
||||
{py27,py34,py35}-django{master}
|
||||
|
@ -25,7 +25,6 @@ basepython =
|
|||
py35: python3.5
|
||||
py34: python3.4
|
||||
py33: python3.3
|
||||
py32: python3.2
|
||||
py27: python2.7
|
||||
|
||||
[testenv:py27-lint]
|
||||
|
|
Loading…
Reference in New Issue
Block a user