mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-22 09:36:49 +03:00
Version 3.5 (#4525)
* Start test case * Added 'requests' test client * Address typos * Graceful fallback if requests is not installed. * Add cookie support * Tests for auth and CSRF * Py3 compat * py3 compat * py3 compat * Add get_requests_client * Added SchemaGenerator.should_include_link * add settings for html cutoff on related fields * Router doesn't work if prefix is blank, though project urls.py handles prefix * Fix Django 1.10 to-many deprecation * Add django.core.urlresolvers compatibility * Update django-filter & django-guardian * Check for empty router prefix; adjust URL accordingly It's easiest to fix this issue after we have made the regex. To try to fix it before would require doing something different for List vs Detail, which means we'd have to know which type of url we're constructing before acting accordingly. * Fix misc django deprecations * Use TOC extension instead of header * Fix deprecations for py3k * Add py3k compatibility to is_simple_callable * Add is_simple_callable tests * Drop python 3.2 support (EOL, Dropped by Django) * schema_renderers= should *set* the renderers, not append to them. * API client (#4424) * Fix release notes * Add note about 'User account is disabled.' vs 'Unable to log in' * Clean up schema generation (#4527) * Handle multiple methods on custom action (#4529) * RequestsClient, CoreAPIClient * exclude_from_schema * Added 'get_schema_view()' shortcut * Added schema descriptions * Better descriptions for schemas * Add type annotation to schema generation * Coerce schema 'pk' in path to actual field name * Deprecations move into assertion errors * Use get_schema_view in tests * Updte CoreJSON media type * Handle schema structure correctly when path prefixs exist. Closes #4401 * Add PendingDeprecation to Router schema generation. * Added SCHEMA_COERCE_PATH_PK and SCHEMA_COERCE_METHOD_NAMES * Renamed and documented 'get_schema_fields' interface.
This commit is contained in:
parent
d49e26f127
commit
0dec36eb41
|
@ -14,7 +14,6 @@ env:
|
|||
- TOX_ENV=py35-django18
|
||||
- TOX_ENV=py34-django18
|
||||
- TOX_ENV=py33-django18
|
||||
- TOX_ENV=py32-django18
|
||||
- TOX_ENV=py27-django18
|
||||
- TOX_ENV=py27-django110
|
||||
- TOX_ENV=py35-django110
|
||||
|
|
|
@ -416,6 +416,12 @@ Generic filters may also present an interface in the browsable API. To do so you
|
|||
|
||||
The method should return a rendered HTML string.
|
||||
|
||||
## Pagination & schemas
|
||||
|
||||
You can also make the filter controls available to the schema autogeneration
|
||||
that REST framework provides, by implementing a `get_schema_fields()` method,
|
||||
which should return a list of `coreapi.Field` instances.
|
||||
|
||||
# Third party packages
|
||||
|
||||
The following third party packages provide additional filter implementations.
|
||||
|
|
|
@ -276,6 +276,12 @@ To have your custom pagination class be used by default, use the `DEFAULT_PAGINA
|
|||
|
||||
API responses for list endpoints will now include a `Link` header, instead of including the pagination links as part of the body of the response, for example:
|
||||
|
||||
## Pagination & schemas
|
||||
|
||||
You can also make the pagination controls available to the schema autogeneration
|
||||
that REST framework provides, by implementing a `get_schema_fields()` method,
|
||||
which should return a list of `coreapi.Field` instances.
|
||||
|
||||
---
|
||||
|
||||
![Link Header][link-header]
|
||||
|
|
|
@ -463,6 +463,8 @@ There are two keyword arguments you can use to control this behavior:
|
|||
- `html_cutoff` - If set this will be the maximum number of choices that will be displayed by a HTML select drop down. Set to `None` to disable any limiting. Defaults to `1000`.
|
||||
- `html_cutoff_text` - If set this will display a textual indicator if the maximum number of items have been cutoff in an HTML select drop down. Defaults to `"More than {count} items…"`
|
||||
|
||||
You can also control these globally using the settings `HTML_SELECT_CUTOFF` and `HTML_SELECT_CUTOFF_TEXT`.
|
||||
|
||||
In cases where the cutoff is being enforced you may want to instead use a plain input field in the HTML form. You can do so using the `style` keyword argument. For example:
|
||||
|
||||
assigned_to = serializers.SlugRelatedField(
|
||||
|
|
|
@ -23,7 +23,7 @@ There's no requirement for you to use them, but if you do then the self-describi
|
|||
|
||||
**Signature:** `reverse(viewname, *args, **kwargs)`
|
||||
|
||||
Has the same behavior as [`django.core.urlresolvers.reverse`][reverse], except that it returns a fully qualified URL, using the request to determine the host and port.
|
||||
Has the same behavior as [`django.urls.reverse`][reverse], except that it returns a fully qualified URL, using the request to determine the host and port.
|
||||
|
||||
You should **include the request as a keyword argument** to the function, for example:
|
||||
|
||||
|
@ -44,7 +44,7 @@ You should **include the request as a keyword argument** to the function, for ex
|
|||
|
||||
**Signature:** `reverse_lazy(viewname, *args, **kwargs)`
|
||||
|
||||
Has the same behavior as [`django.core.urlresolvers.reverse_lazy`][reverse-lazy], except that it returns a fully qualified URL, using the request to determine the host and port.
|
||||
Has the same behavior as [`django.urls.reverse_lazy`][reverse-lazy], except that it returns a fully qualified URL, using the request to determine the host and port.
|
||||
|
||||
As with the `reverse` function, you should **include the request as a keyword argument** to the function, for example:
|
||||
|
||||
|
|
|
@ -102,15 +102,20 @@ REST framework includes functionality for auto-generating a schema,
|
|||
or allows you to specify one explicitly. There are a few different ways to
|
||||
add a schema to your API, depending on exactly what you need.
|
||||
|
||||
## Using DefaultRouter
|
||||
## The get_schema_view shortcut
|
||||
|
||||
If you're using `DefaultRouter` then you can include an auto-generated schema,
|
||||
simply by adding a `schema_title` argument to the router.
|
||||
The simplest way to include a schema in your project is to use the
|
||||
`get_schema_view()` function.
|
||||
|
||||
router = DefaultRouter(schema_title='Server Monitoring API')
|
||||
schema_view = get_schema_view(title="Server Monitoring API")
|
||||
|
||||
The schema will be included at the root URL, `/`, and presented to clients
|
||||
that include the Core JSON media type in their `Accept` header.
|
||||
urlpatterns = [
|
||||
url('^$', schema_view),
|
||||
...
|
||||
]
|
||||
|
||||
Once the view has been added, you'll be able to make API requests to retrieve
|
||||
the auto-generated schema definition.
|
||||
|
||||
$ http http://127.0.0.1:8000/ Accept:application/vnd.coreapi+json
|
||||
HTTP/1.0 200 OK
|
||||
|
@ -125,48 +130,43 @@ that include the Core JSON media type in their `Accept` header.
|
|||
...
|
||||
}
|
||||
|
||||
This is a great zero-configuration option for when you want to get up and
|
||||
running really quickly.
|
||||
The arguments to `get_schema_view()` are:
|
||||
|
||||
The other available options to `DefaultRouter` are:
|
||||
#### `title`
|
||||
|
||||
#### schema_renderers
|
||||
May be used to provide a descriptive title for the schema definition.
|
||||
|
||||
May be used to pass the set of renderer classes that can be used to render schema output.
|
||||
#### `url`
|
||||
|
||||
May be used to pass a canonical URL for the schema.
|
||||
|
||||
schema_view = get_schema_view(
|
||||
title='Server Monitoring API',
|
||||
url='https://www.example.org/api/'
|
||||
)
|
||||
|
||||
#### `renderer_classes`
|
||||
|
||||
May be used to pass the set of renderer classes that can be used to render the API root endpoint.
|
||||
|
||||
from rest_framework.renderers import CoreJSONRenderer
|
||||
from my_custom_package import APIBlueprintRenderer
|
||||
|
||||
router = DefaultRouter(schema_title='Server Monitoring API', schema_renderers=[
|
||||
CoreJSONRenderer, APIBlueprintRenderer
|
||||
])
|
||||
|
||||
#### schema_url
|
||||
|
||||
May be used to pass the root URL for the schema. This can either be used to ensure that
|
||||
the schema URLs include a canonical hostname and schema, or to ensure that all the
|
||||
schema URLs include a path prefix.
|
||||
|
||||
router = DefaultRouter(
|
||||
schema_title='Server Monitoring API',
|
||||
schema_url='https://www.example.org/api/'
|
||||
schema_view = get_schema_view(
|
||||
title='Server Monitoring API',
|
||||
url='https://www.example.org/api/',
|
||||
renderer_classes=[CoreJSONRenderer, APIBlueprintRenderer]
|
||||
)
|
||||
|
||||
If you want more flexibility over the schema output then you'll need to consider
|
||||
using `SchemaGenerator` instead.
|
||||
## Using an explicit schema view
|
||||
|
||||
#### root_renderers
|
||||
|
||||
May be used to pass the set of renderer classes that can be used to render the API root endpoint.
|
||||
|
||||
## Using SchemaGenerator
|
||||
|
||||
The most common way to add a schema to your API is to use the `SchemaGenerator`
|
||||
class to auto-generate the `Document` instance, and to return that from a view.
|
||||
If you need a little more control than the `get_schema_view()` shortcut gives you,
|
||||
then you can use the `SchemaGenerator` class directly to auto-generate the
|
||||
`Document` instance, and to return that from a view.
|
||||
|
||||
This option gives you the flexibility of setting up the schema endpoint
|
||||
with whatever behavior you want. For example, you can apply different
|
||||
permission, throttling or authentication policies to the schema endpoint.
|
||||
with whatever behaviour you want. For example, you can apply different
|
||||
permission, throttling, or authentication policies to the schema endpoint.
|
||||
|
||||
Here's an example of using `SchemaGenerator` together with a view to
|
||||
return the schema.
|
||||
|
@ -176,12 +176,13 @@ return the schema.
|
|||
from rest_framework.decorators import api_view, renderer_classes
|
||||
from rest_framework import renderers, response, schemas
|
||||
|
||||
generator = schemas.SchemaGenerator(title='Bookings API')
|
||||
|
||||
@api_view()
|
||||
@renderer_classes([renderers.CoreJSONRenderer])
|
||||
def schema_view(request):
|
||||
generator = schemas.SchemaGenerator(title='Bookings API')
|
||||
return response.Response(generator.get_schema())
|
||||
schema = generator.get_schema(request)
|
||||
return response.Response(schema)
|
||||
|
||||
**urls.py:**
|
||||
|
||||
|
@ -241,6 +242,69 @@ You could then either:
|
|||
|
||||
---
|
||||
|
||||
# Schemas as documentation
|
||||
|
||||
One common usage of API schemas is to use them to build documentation pages.
|
||||
|
||||
The schema generation in REST framework uses docstrings to automatically
|
||||
populate descriptions in the schema document.
|
||||
|
||||
These descriptions will be based on:
|
||||
|
||||
* The corresponding method docstring if one exists.
|
||||
* A named section within the class docstring, which can be either single line or multi-line.
|
||||
* The class docstring.
|
||||
|
||||
## Examples
|
||||
|
||||
An `APIView`, with an explicit method docstring.
|
||||
|
||||
class ListUsernames(APIView):
|
||||
def get(self, request):
|
||||
"""
|
||||
Return a list of all user names in the system.
|
||||
"""
|
||||
usernames = [user.username for user in User.objects.all()]
|
||||
return Response(usernames)
|
||||
|
||||
A `ViewSet`, with an explict action docstring.
|
||||
|
||||
class ListUsernames(ViewSet):
|
||||
def list(self, request):
|
||||
"""
|
||||
Return a list of all user names in the system.
|
||||
"""
|
||||
usernames = [user.username for user in User.objects.all()]
|
||||
return Response(usernames)
|
||||
|
||||
A generic view with sections in the class docstring, using single-line style.
|
||||
|
||||
class UserList(generics.ListCreateAPIView):
|
||||
"""
|
||||
get: Create a new user.
|
||||
post: List all the users.
|
||||
"""
|
||||
queryset = User.objects.all()
|
||||
serializer_class = UserSerializer
|
||||
permission_classes = (IsAdminUser,)
|
||||
|
||||
A generic viewset with sections in the class docstring, using multi-line style.
|
||||
|
||||
class UserViewSet(viewsets.ModelViewSet):
|
||||
"""
|
||||
API endpoint that allows users to be viewed or edited.
|
||||
|
||||
retrieve:
|
||||
Return a user instance.
|
||||
|
||||
list:
|
||||
Return all users, ordered by most recently joined.
|
||||
"""
|
||||
queryset = User.objects.all().order_by('-date_joined')
|
||||
serializer_class = UserSerializer
|
||||
|
||||
---
|
||||
|
||||
# Alternate schema formats
|
||||
|
||||
In order to support an alternate schema format, you need to implement a custom renderer
|
||||
|
|
|
@ -234,6 +234,28 @@ Default:
|
|||
|
||||
---
|
||||
|
||||
## Schema generation controls
|
||||
|
||||
#### SCHEMA_COERCE_PATH_PK
|
||||
|
||||
If set, this maps the `'pk'` identifier in the URL conf onto the actual field
|
||||
name when generating a schema path parameter. Typically this will be `'id'`.
|
||||
This gives a more suitable representation as "primary key" is an implementation
|
||||
detail, wheras "identifier" is a more general concept.
|
||||
|
||||
Default: `True`
|
||||
|
||||
#### SCHEMA_COERCE_METHOD_NAMES
|
||||
|
||||
If set, this is used to map internal viewset method names onto external action
|
||||
names used in the schema generation. This allows us to generate names that
|
||||
are more suitable for an external representation than those that are used
|
||||
internally in the codebase.
|
||||
|
||||
Default: `{'retrieve': 'read', 'destroy': 'delete'}`
|
||||
|
||||
---
|
||||
|
||||
## Content type controls
|
||||
|
||||
#### URL_FORMAT_OVERRIDE
|
||||
|
@ -382,6 +404,22 @@ This should be a function with the following signature:
|
|||
|
||||
Default: `'rest_framework.views.get_view_description'`
|
||||
|
||||
## HTML Select Field cutoffs
|
||||
|
||||
Global settings for [select field cutoffs for rendering relational fields](relations.md#select-field-cutoffs) in the browsable API.
|
||||
|
||||
#### HTML_SELECT_CUTOFF
|
||||
|
||||
Global setting for the `html_cutoff` value. Must be an integer.
|
||||
|
||||
Default: 1000
|
||||
|
||||
#### HTML_SELECT_CUTOFF_TEXT
|
||||
|
||||
A string representing a global setting for `html_cutoff_text`.
|
||||
|
||||
Default: `"More than {count} items..."`
|
||||
|
||||
---
|
||||
|
||||
## Miscellaneous settings
|
||||
|
|
|
@ -184,6 +184,99 @@ As usual CSRF validation will only apply to any session authenticated views. Th
|
|||
|
||||
---
|
||||
|
||||
# RequestsClient
|
||||
|
||||
REST framework also includes a client for interacting with your application
|
||||
using the popular Python library, `requests`.
|
||||
|
||||
This exposes exactly the same interface as if you were using a requests session
|
||||
directly.
|
||||
|
||||
client = RequestsClient()
|
||||
response = client.get('http://testserver/users/')
|
||||
|
||||
Note that the requests client requires you to pass fully qualified URLs.
|
||||
|
||||
## Headers & Authentication
|
||||
|
||||
Custom headers and authentication credentials can be provided in the same way
|
||||
as [when using a standard `requests.Session` instance](http://docs.python-requests.org/en/master/user/advanced/#session-objects).
|
||||
|
||||
from requests.auth import HTTPBasicAuth
|
||||
|
||||
client.auth = HTTPBasicAuth('user', 'pass')
|
||||
client.headers.update({'x-test': 'true'})
|
||||
|
||||
## CSRF
|
||||
|
||||
If you're using `SessionAuthentication` then you'll need to include a CSRF token
|
||||
for any `POST`, `PUT`, `PATCH` or `DELETE` requests.
|
||||
|
||||
You can do so by following the same flow that a JavaScript based client would use.
|
||||
First make a `GET` request in order to obtain a CRSF token, then present that
|
||||
token in the following request.
|
||||
|
||||
For example...
|
||||
|
||||
client = RequestsClient()
|
||||
|
||||
# Obtain a CSRF token.
|
||||
response = client.get('/homepage/')
|
||||
assert response.status_code == 200
|
||||
csrftoken = response.cookies['csrftoken']
|
||||
|
||||
# Interact with the API.
|
||||
response = client.post('/organisations/', json={
|
||||
'name': 'MegaCorp',
|
||||
'status': 'active'
|
||||
}, headers={'X-CSRFToken': csrftoken})
|
||||
assert response.status_code == 200
|
||||
|
||||
## Live tests
|
||||
|
||||
With careful usage both the `RequestsClient` and the `CoreAPIClient` provide
|
||||
the ability to write test cases that can run either in development, or be run
|
||||
directly against your staging server or production environment.
|
||||
|
||||
Using this style to create basic tests of a few core piece of functionality is
|
||||
a powerful way to validate your live service. Doing so may require some careful
|
||||
attention to setup and teardown to ensure that the tests run in a way that they
|
||||
do not directly affect customer data.
|
||||
|
||||
---
|
||||
|
||||
# CoreAPIClient
|
||||
|
||||
The CoreAPIClient allows you to interact with your API using the Python
|
||||
`coreapi` client library.
|
||||
|
||||
# Fetch the API schema
|
||||
url = reverse('schema')
|
||||
client = CoreAPIClient()
|
||||
schema = client.get(url)
|
||||
|
||||
# Create a new organisation
|
||||
params = {'name': 'MegaCorp', 'status': 'active'}
|
||||
client.action(schema, ['organisations', 'create'], params)
|
||||
|
||||
# Ensure that the organisation exists in the listing
|
||||
data = client.action(schema, ['organisations', 'list'])
|
||||
assert(len(data) == 1)
|
||||
assert(data == [{'name': 'MegaCorp', 'status': 'active'}])
|
||||
|
||||
## Headers & Authentication
|
||||
|
||||
Custom headers and authentication may be used with `CoreAPIClient` in a
|
||||
similar way as with `RequestsClient`.
|
||||
|
||||
from requests.auth import HTTPBasicAuth
|
||||
|
||||
client = CoreAPIClient()
|
||||
client.session.auth = HTTPBasicAuth('user', 'pass')
|
||||
client.session.headers.update({'x-test': 'true'})
|
||||
|
||||
---
|
||||
|
||||
# Test cases
|
||||
|
||||
REST framework includes the following test case classes, that mirror the existing Django test case classes, but use `APIClient` instead of Django's default `Client`.
|
||||
|
@ -197,7 +290,7 @@ REST framework includes the following test case classes, that mirror the existin
|
|||
|
||||
You can use any of REST framework's test case classes as you would for the regular Django test case classes. The `self.client` attribute will be an `APIClient` instance.
|
||||
|
||||
from django.core.urlresolvers import reverse
|
||||
from django.urls import reverse
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APITestCase
|
||||
from myproject.apps.core.models import Account
|
||||
|
|
|
@ -127,7 +127,7 @@ REST framework also allows you to work with regular function based views. It pr
|
|||
|
||||
## @api_view()
|
||||
|
||||
**Signature:** `@api_view(http_method_names=['GET'])`
|
||||
**Signature:** `@api_view(http_method_names=['GET'], exclude_from_schema=False)`
|
||||
|
||||
The core of this functionality is the `api_view` decorator, which takes a list of HTTP methods that your view should respond to. For example, this is how you would write a very simple view that just manually returns some data:
|
||||
|
||||
|
@ -139,7 +139,7 @@ The core of this functionality is the `api_view` decorator, which takes a list o
|
|||
|
||||
This view will use the default renderers, parsers, authentication classes etc specified in the [settings].
|
||||
|
||||
By default only `GET` methods will be accepted. Other methods will respond with "405 Method Not Allowed". To alter this behavior, specify which methods the view allows, like so:
|
||||
By default only `GET` methods will be accepted. Other methods will respond with "405 Method Not Allowed". To alter this behaviour, specify which methods the view allows, like so:
|
||||
|
||||
@api_view(['GET', 'POST'])
|
||||
def hello_world(request):
|
||||
|
@ -147,6 +147,13 @@ By default only `GET` methods will be accepted. Other methods will respond with
|
|||
return Response({"message": "Got some data!", "data": request.data})
|
||||
return Response({"message": "Hello, world!"})
|
||||
|
||||
You can also mark an API view as being omitted from any [auto-generated schema][schemas],
|
||||
using the `exclude_from_schema` argument.:
|
||||
|
||||
@api_view(['GET'], exclude_from_schema=True)
|
||||
def api_docs(request):
|
||||
...
|
||||
|
||||
## API policy decorators
|
||||
|
||||
To override the default settings, REST framework provides a set of additional decorators which can be added to your views. These must come *after* (below) the `@api_view` decorator. For example, to create a view that uses a [throttle][throttling] to ensure it can only be called once per day by a particular user, use the `@throttle_classes` decorator, passing a list of throttle classes:
|
||||
|
@ -178,3 +185,4 @@ Each of these decorators takes a single argument which must be a list or tuple o
|
|||
[cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html
|
||||
[settings]: settings.md
|
||||
[throttling]: throttling.md
|
||||
[schemas]: schemas.md
|
||||
|
|
|
@ -88,7 +88,7 @@ The first thing we need to get started on our Web API is to provide a way of ser
|
|||
|
||||
|
||||
class SnippetSerializer(serializers.Serializer):
|
||||
pk = serializers.IntegerField(read_only=True)
|
||||
id = serializers.IntegerField(read_only=True)
|
||||
title = serializers.CharField(required=False, allow_blank=True, max_length=100)
|
||||
code = serializers.CharField(style={'base_template': 'textarea.html'})
|
||||
linenos = serializers.BooleanField(required=False)
|
||||
|
@ -144,13 +144,13 @@ We've now got a few snippet instances to play with. Let's take a look at serial
|
|||
|
||||
serializer = SnippetSerializer(snippet)
|
||||
serializer.data
|
||||
# {'pk': 2, 'title': u'', 'code': u'print "hello, world"\n', 'linenos': False, 'language': u'python', 'style': u'friendly'}
|
||||
# {'id': 2, 'title': u'', 'code': u'print "hello, world"\n', 'linenos': False, 'language': u'python', 'style': u'friendly'}
|
||||
|
||||
At this point we've translated the model instance into Python native datatypes. To finalize the serialization process we render the data into `json`.
|
||||
|
||||
content = JSONRenderer().render(serializer.data)
|
||||
content
|
||||
# '{"pk": 2, "title": "", "code": "print \\"hello, world\\"\\n", "linenos": false, "language": "python", "style": "friendly"}'
|
||||
# '{"id": 2, "title": "", "code": "print \\"hello, world\\"\\n", "linenos": false, "language": "python", "style": "friendly"}'
|
||||
|
||||
Deserialization is similar. First we parse a stream into Python native datatypes...
|
||||
|
||||
|
@ -175,7 +175,7 @@ We can also serialize querysets instead of model instances. To do so we simply
|
|||
|
||||
serializer = SnippetSerializer(Snippet.objects.all(), many=True)
|
||||
serializer.data
|
||||
# [OrderedDict([('pk', 1), ('title', u''), ('code', u'foo = "bar"\n'), ('linenos', False), ('language', 'python'), ('style', 'friendly')]), OrderedDict([('pk', 2), ('title', u''), ('code', u'print "hello, world"\n'), ('linenos', False), ('language', 'python'), ('style', 'friendly')]), OrderedDict([('pk', 3), ('title', u''), ('code', u'print "hello, world"'), ('linenos', False), ('language', 'python'), ('style', 'friendly')])]
|
||||
# [OrderedDict([('id', 1), ('title', u''), ('code', u'foo = "bar"\n'), ('linenos', False), ('language', 'python'), ('style', 'friendly')]), OrderedDict([('id', 2), ('title', u''), ('code', u'print "hello, world"\n'), ('linenos', False), ('language', 'python'), ('style', 'friendly')]), OrderedDict([('id', 3), ('title', u''), ('code', u'print "hello, world"'), ('linenos', False), ('language', 'python'), ('style', 'friendly')])]
|
||||
|
||||
## Using ModelSerializers
|
||||
|
||||
|
@ -259,12 +259,12 @@ Note that because we want to be able to POST to this view from clients that won'
|
|||
We'll also need a view which corresponds to an individual snippet, and can be used to retrieve, update or delete the snippet.
|
||||
|
||||
@csrf_exempt
|
||||
def snippet_detail(request, pk):
|
||||
def snippet_detail(request, id):
|
||||
"""
|
||||
Retrieve, update or delete a code snippet.
|
||||
"""
|
||||
try:
|
||||
snippet = Snippet.objects.get(pk=pk)
|
||||
snippet = Snippet.objects.get(id=id)
|
||||
except Snippet.DoesNotExist:
|
||||
return HttpResponse(status=404)
|
||||
|
||||
|
@ -291,7 +291,7 @@ Finally we need to wire these views up. Create the `snippets/urls.py` file:
|
|||
|
||||
urlpatterns = [
|
||||
url(r'^snippets/$', views.snippet_list),
|
||||
url(r'^snippets/(?P<pk>[0-9]+)/$', views.snippet_detail),
|
||||
url(r'^snippets/(?P<id>[0-9]+)/$', views.snippet_detail),
|
||||
]
|
||||
|
||||
We also need to wire up the root urlconf, in the `tutorial/urls.py` file, to include our snippet app's URLs.
|
||||
|
|
|
@ -66,12 +66,12 @@ Our instance view is an improvement over the previous example. It's a little mo
|
|||
Here is the view for an individual snippet, in the `views.py` module.
|
||||
|
||||
@api_view(['GET', 'PUT', 'DELETE'])
|
||||
def snippet_detail(request, pk):
|
||||
def snippet_detail(request, id):
|
||||
"""
|
||||
Retrieve, update or delete a snippet instance.
|
||||
"""
|
||||
try:
|
||||
snippet = Snippet.objects.get(pk=pk)
|
||||
snippet = Snippet.objects.get(id=id)
|
||||
except Snippet.DoesNotExist:
|
||||
return Response(status=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
|
@ -104,7 +104,7 @@ Start by adding a `format` keyword argument to both of the views, like so.
|
|||
|
||||
and
|
||||
|
||||
def snippet_detail(request, pk, format=None):
|
||||
def snippet_detail(request, id, format=None):
|
||||
|
||||
Now update the `urls.py` file slightly, to append a set of `format_suffix_patterns` in addition to the existing URLs.
|
||||
|
||||
|
@ -114,7 +114,7 @@ Now update the `urls.py` file slightly, to append a set of `format_suffix_patter
|
|||
|
||||
urlpatterns = [
|
||||
url(r'^snippets/$', views.snippet_list),
|
||||
url(r'^snippets/(?P<pk>[0-9]+)$', views.snippet_detail),
|
||||
url(r'^snippets/(?P<id>[0-9]+)$', views.snippet_detail),
|
||||
]
|
||||
|
||||
urlpatterns = format_suffix_patterns(urlpatterns)
|
||||
|
|
|
@ -36,27 +36,27 @@ So far, so good. It looks pretty similar to the previous case, but we've got be
|
|||
"""
|
||||
Retrieve, update or delete a snippet instance.
|
||||
"""
|
||||
def get_object(self, pk):
|
||||
def get_object(self, id):
|
||||
try:
|
||||
return Snippet.objects.get(pk=pk)
|
||||
return Snippet.objects.get(id=id)
|
||||
except Snippet.DoesNotExist:
|
||||
raise Http404
|
||||
|
||||
def get(self, request, pk, format=None):
|
||||
snippet = self.get_object(pk)
|
||||
def get(self, request, id, format=None):
|
||||
snippet = self.get_object(id)
|
||||
serializer = SnippetSerializer(snippet)
|
||||
return Response(serializer.data)
|
||||
|
||||
def put(self, request, pk, format=None):
|
||||
snippet = self.get_object(pk)
|
||||
def put(self, request, id, format=None):
|
||||
snippet = self.get_object(id)
|
||||
serializer = SnippetSerializer(snippet, data=request.data)
|
||||
if serializer.is_valid():
|
||||
serializer.save()
|
||||
return Response(serializer.data)
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
def delete(self, request, pk, format=None):
|
||||
snippet = self.get_object(pk)
|
||||
def delete(self, request, id, format=None):
|
||||
snippet = self.get_object(id)
|
||||
snippet.delete()
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
@ -70,7 +70,7 @@ We'll also need to refactor our `urls.py` slightly now we're using class-based v
|
|||
|
||||
urlpatterns = [
|
||||
url(r'^snippets/$', views.SnippetList.as_view()),
|
||||
url(r'^snippets/(?P<pk>[0-9]+)/$', views.SnippetDetail.as_view()),
|
||||
url(r'^snippets/(?P<id>[0-9]+)/$', views.SnippetDetail.as_view()),
|
||||
]
|
||||
|
||||
urlpatterns = format_suffix_patterns(urlpatterns)
|
||||
|
|
|
@ -88,7 +88,7 @@ Make sure to also import the `UserSerializer` class
|
|||
Finally we need to add those views into the API, by referencing them from the URL conf. Add the following to the patterns in `urls.py`.
|
||||
|
||||
url(r'^users/$', views.UserList.as_view()),
|
||||
url(r'^users/(?P<pk>[0-9]+)/$', views.UserDetail.as_view()),
|
||||
url(r'^users/(?P<id>[0-9]+)/$', views.UserDetail.as_view()),
|
||||
|
||||
## Associating Snippets with Users
|
||||
|
||||
|
@ -150,7 +150,7 @@ The `r'^api-auth/'` part of pattern can actually be whatever URL you want to use
|
|||
|
||||
Now if you open up the browser again and refresh the page you'll see a 'Login' link in the top right of the page. If you log in as one of the users you created earlier, you'll be able to create code snippets again.
|
||||
|
||||
Once you've created a few code snippets, navigate to the '/users/' endpoint, and notice that the representation includes a list of the snippet pks that are associated with each user, in each user's 'snippets' field.
|
||||
Once you've created a few code snippets, navigate to the '/users/' endpoint, and notice that the representation includes a list of the snippet ids that are associated with each user, in each user's 'snippets' field.
|
||||
|
||||
## Object level permissions
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ We'll add a url pattern for our new API root in `snippets/urls.py`:
|
|||
|
||||
And then add a url pattern for the snippet highlights:
|
||||
|
||||
url(r'^snippets/(?P<pk>[0-9]+)/highlight/$', views.SnippetHighlight.as_view()),
|
||||
url(r'^snippets/(?P<id>[0-9]+)/highlight/$', views.SnippetHighlight.as_view()),
|
||||
|
||||
## Hyperlinking our API
|
||||
|
||||
|
@ -67,7 +67,7 @@ In this case we'd like to use a hyperlinked style between entities. In order to
|
|||
|
||||
The `HyperlinkedModelSerializer` has the following differences from `ModelSerializer`:
|
||||
|
||||
* It does not include the `pk` field by default.
|
||||
* It does not include the `id` field by default.
|
||||
* It includes a `url` field, using `HyperlinkedIdentityField`.
|
||||
* Relationships use `HyperlinkedRelatedField`,
|
||||
instead of `PrimaryKeyRelatedField`.
|
||||
|
@ -80,7 +80,7 @@ We can easily re-write our existing serializers to use hyperlinking. In your `sn
|
|||
|
||||
class Meta:
|
||||
model = Snippet
|
||||
fields = ('url', 'pk', 'highlight', 'owner',
|
||||
fields = ('url', 'id', 'highlight', 'owner',
|
||||
'title', 'code', 'linenos', 'language', 'style')
|
||||
|
||||
|
||||
|
@ -89,7 +89,7 @@ We can easily re-write our existing serializers to use hyperlinking. In your `sn
|
|||
|
||||
class Meta:
|
||||
model = User
|
||||
fields = ('url', 'pk', 'username', 'snippets')
|
||||
fields = ('url', 'id', 'username', 'snippets')
|
||||
|
||||
Notice that we've also added a new `'highlight'` field. This field is of the same type as the `url` field, except that it points to the `'snippet-highlight'` url pattern, instead of the `'snippet-detail'` url pattern.
|
||||
|
||||
|
@ -116,16 +116,16 @@ After adding all those names into our URLconf, our final `snippets/urls.py` file
|
|||
url(r'^snippets/$',
|
||||
views.SnippetList.as_view(),
|
||||
name='snippet-list'),
|
||||
url(r'^snippets/(?P<pk>[0-9]+)/$',
|
||||
url(r'^snippets/(?P<id>[0-9]+)/$',
|
||||
views.SnippetDetail.as_view(),
|
||||
name='snippet-detail'),
|
||||
url(r'^snippets/(?P<pk>[0-9]+)/highlight/$',
|
||||
url(r'^snippets/(?P<id>[0-9]+)/highlight/$',
|
||||
views.SnippetHighlight.as_view(),
|
||||
name='snippet-highlight'),
|
||||
url(r'^users/$',
|
||||
views.UserList.as_view(),
|
||||
name='user-list'),
|
||||
url(r'^users/(?P<pk>[0-9]+)/$',
|
||||
url(r'^users/(?P<id>[0-9]+)/$',
|
||||
views.UserDetail.as_view(),
|
||||
name='user-detail')
|
||||
])
|
||||
|
|
|
@ -92,10 +92,10 @@ Now that we've bound our resources into concrete views, we can register the view
|
|||
urlpatterns = format_suffix_patterns([
|
||||
url(r'^$', api_root),
|
||||
url(r'^snippets/$', snippet_list, name='snippet-list'),
|
||||
url(r'^snippets/(?P<pk>[0-9]+)/$', snippet_detail, name='snippet-detail'),
|
||||
url(r'^snippets/(?P<pk>[0-9]+)/highlight/$', snippet_highlight, name='snippet-highlight'),
|
||||
url(r'^snippets/(?P<id>[0-9]+)/$', snippet_detail, name='snippet-detail'),
|
||||
url(r'^snippets/(?P<id>[0-9]+)/highlight/$', snippet_highlight, name='snippet-highlight'),
|
||||
url(r'^users/$', user_list, name='user-list'),
|
||||
url(r'^users/(?P<pk>[0-9]+)/$', user_detail, name='user-detail')
|
||||
url(r'^users/(?P<id>[0-9]+)/$', user_detail, name='user-detail')
|
||||
])
|
||||
|
||||
## Using Routers
|
||||
|
|
|
@ -33,10 +33,17 @@ API schema.
|
|||
|
||||
$ pip install coreapi
|
||||
|
||||
We can now include a schema for our API, by adding a `schema_title` argument to
|
||||
the router instantiation.
|
||||
We can now include a schema for our API, by including an autogenerated schema
|
||||
view in our URL configuration.
|
||||
|
||||
router = DefaultRouter(schema_title='Pastebin API')
|
||||
from rest_framework.schemas import get_schema_view
|
||||
|
||||
schema_view = get_schema_view(title='Pastebin API')
|
||||
|
||||
urlpatterns = [
|
||||
url('^schema/$', schema_view),
|
||||
...
|
||||
]
|
||||
|
||||
If you visit the API root endpoint in a browser you should now see `corejson`
|
||||
representation become available as an option.
|
||||
|
@ -46,7 +53,7 @@ representation become available as an option.
|
|||
We can also request the schema from the command line, by specifying the desired
|
||||
content type in the `Accept` header.
|
||||
|
||||
$ http http://127.0.0.1:8000/ Accept:application/vnd.coreapi+json
|
||||
$ http http://127.0.0.1:8000/schema/ Accept:application/vnd.coreapi+json
|
||||
HTTP/1.0 200 OK
|
||||
Allow: GET, HEAD, OPTIONS
|
||||
Content-Type: application/vnd.coreapi+json
|
||||
|
@ -91,16 +98,16 @@ Now check that it is available on the command line...
|
|||
|
||||
First we'll load the API schema using the command line client.
|
||||
|
||||
$ coreapi get http://127.0.0.1:8000/
|
||||
<Pastebin API "http://127.0.0.1:8000/">
|
||||
$ coreapi get http://127.0.0.1:8000/schema/
|
||||
<Pastebin API "http://127.0.0.1:8000/schema/">
|
||||
snippets: {
|
||||
highlight(pk)
|
||||
highlight(id)
|
||||
list()
|
||||
retrieve(pk)
|
||||
read(id)
|
||||
}
|
||||
users: {
|
||||
list()
|
||||
retrieve(pk)
|
||||
read(id)
|
||||
}
|
||||
|
||||
We haven't authenticated yet, so right now we're only able to see the read only
|
||||
|
@ -112,7 +119,7 @@ Let's try listing the existing snippets, using the command line client:
|
|||
[
|
||||
{
|
||||
"url": "http://127.0.0.1:8000/snippets/1/",
|
||||
"pk": 1,
|
||||
"id": 1,
|
||||
"highlight": "http://127.0.0.1:8000/snippets/1/highlight/",
|
||||
"owner": "lucy",
|
||||
"title": "Example",
|
||||
|
@ -126,7 +133,7 @@ Let's try listing the existing snippets, using the command line client:
|
|||
Some of the API endpoints require named parameters. For example, to get back
|
||||
the highlight HTML for a particular snippet we need to provide an id.
|
||||
|
||||
$ coreapi action snippets highlight --param pk=1
|
||||
$ coreapi action snippets highlight --param id=1
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01//EN" "http://www.w3.org/TR/html4/strict.dtd">
|
||||
|
||||
<html>
|
||||
|
@ -150,19 +157,19 @@ Now if we fetch the schema again, we should be able to see the full
|
|||
set of available interactions.
|
||||
|
||||
$ coreapi reload
|
||||
Pastebin API "http://127.0.0.1:8000/">
|
||||
Pastebin API "http://127.0.0.1:8000/schema/">
|
||||
snippets: {
|
||||
create(code, [title], [linenos], [language], [style])
|
||||
destroy(pk)
|
||||
highlight(pk)
|
||||
delete(id)
|
||||
highlight(id)
|
||||
list()
|
||||
partial_update(pk, [title], [code], [linenos], [language], [style])
|
||||
retrieve(pk)
|
||||
update(pk, code, [title], [linenos], [language], [style])
|
||||
partial_update(id, [title], [code], [linenos], [language], [style])
|
||||
read(id)
|
||||
update(id, code, [title], [linenos], [language], [style])
|
||||
}
|
||||
users: {
|
||||
list()
|
||||
retrieve(pk)
|
||||
read(id)
|
||||
}
|
||||
|
||||
We're now able to interact with these endpoints. For example, to create a new
|
||||
|
@ -171,7 +178,7 @@ snippet:
|
|||
$ coreapi action snippets create --param title="Example" --param code="print('hello, world')"
|
||||
{
|
||||
"url": "http://127.0.0.1:8000/snippets/7/",
|
||||
"pk": 7,
|
||||
"id": 7,
|
||||
"highlight": "http://127.0.0.1:8000/snippets/7/highlight/",
|
||||
"owner": "lucy",
|
||||
"title": "Example",
|
||||
|
@ -183,7 +190,7 @@ snippet:
|
|||
|
||||
And to delete a snippet:
|
||||
|
||||
$ coreapi action snippets destroy --param pk=7
|
||||
$ coreapi action snippets delete --param id=7
|
||||
|
||||
As well as the command line client, developers can also interact with your
|
||||
API using client libraries. The Python client library is the first of these
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Optional packages which may be used with REST framework.
|
||||
markdown==2.6.4
|
||||
django-guardian==1.4.3
|
||||
django-filter==0.13.0
|
||||
coreapi==1.32.0
|
||||
django-guardian==1.4.6
|
||||
django-filter==0.14.0
|
||||
coreapi==2.0.8
|
||||
|
|
|
@ -8,7 +8,7 @@ ______ _____ _____ _____ __
|
|||
"""
|
||||
|
||||
__title__ = 'Django REST framework'
|
||||
__version__ = '3.4.7'
|
||||
__version__ = '3.5.0'
|
||||
__author__ = 'Tom Christie'
|
||||
__license__ = 'BSD 2-Clause'
|
||||
__copyright__ = 'Copyright 2011-2016 Tom Christie'
|
||||
|
|
|
@ -16,6 +16,9 @@ class AuthTokenSerializer(serializers.Serializer):
|
|||
user = authenticate(username=username, password=password)
|
||||
|
||||
if user:
|
||||
# From Django 1.10 onwards the `authenticate` call simply
|
||||
# returns `None` for is_active=False users.
|
||||
# (Assuming the default `ModelBackend` authentication backend.)
|
||||
if not user.is_active:
|
||||
msg = _('User account is disabled.')
|
||||
raise serializers.ValidationError(msg)
|
||||
|
|
|
@ -23,6 +23,16 @@ except ImportError:
|
|||
from django.utils import importlib # Will be removed in Django 1.9
|
||||
|
||||
|
||||
try:
|
||||
from django.urls import (
|
||||
NoReverseMatch, RegexURLPattern, RegexURLResolver, ResolverMatch, Resolver404, get_script_prefix, reverse, reverse_lazy, resolve
|
||||
)
|
||||
except ImportError:
|
||||
from django.core.urlresolvers import ( # Will be removed in Django 2.0
|
||||
NoReverseMatch, RegexURLPattern, RegexURLResolver, ResolverMatch, Resolver404, get_script_prefix, reverse, reverse_lazy, resolve
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
import urlparse # Python 2.x
|
||||
except ImportError:
|
||||
|
@ -128,6 +138,12 @@ def is_authenticated(user):
|
|||
return user.is_authenticated
|
||||
|
||||
|
||||
def is_anonymous(user):
|
||||
if django.VERSION < (1, 10):
|
||||
return user.is_anonymous()
|
||||
return user.is_anonymous
|
||||
|
||||
|
||||
def get_related_model(field):
|
||||
if django.VERSION < (1, 9):
|
||||
return _resolve_model(field.rel.to)
|
||||
|
@ -178,6 +194,13 @@ except (ImportError, SyntaxError):
|
|||
uritemplate = None
|
||||
|
||||
|
||||
# requests is optional
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None
|
||||
|
||||
|
||||
# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS
|
||||
# Fixes (#1712). We keep the try/except for the test suite.
|
||||
guardian = None
|
||||
|
@ -200,8 +223,13 @@ try:
|
|||
|
||||
if markdown.version <= '2.2':
|
||||
HEADERID_EXT_PATH = 'headerid'
|
||||
else:
|
||||
LEVEL_PARAM = 'level'
|
||||
elif markdown.version < '2.6':
|
||||
HEADERID_EXT_PATH = 'markdown.extensions.headerid'
|
||||
LEVEL_PARAM = 'level'
|
||||
else:
|
||||
HEADERID_EXT_PATH = 'markdown.extensions.toc'
|
||||
LEVEL_PARAM = 'baselevel'
|
||||
|
||||
def apply_markdown(text):
|
||||
"""
|
||||
|
@ -211,7 +239,7 @@ try:
|
|||
extensions = [HEADERID_EXT_PATH]
|
||||
extension_configs = {
|
||||
HEADERID_EXT_PATH: {
|
||||
'level': '2'
|
||||
LEVEL_PARAM: '2'
|
||||
}
|
||||
}
|
||||
md = markdown.Markdown(
|
||||
|
@ -277,3 +305,11 @@ def template_render(template, context=None, request=None):
|
|||
# backends template, e.g. django.template.backends.django.Template
|
||||
else:
|
||||
return template.render(context, request=request)
|
||||
|
||||
|
||||
def set_many(instance, field, value):
|
||||
if django.VERSION < (1, 10):
|
||||
setattr(instance, field, value)
|
||||
else:
|
||||
field = getattr(instance, field)
|
||||
field.set(value)
|
||||
|
|
|
@ -15,7 +15,7 @@ from django.utils import six
|
|||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
def api_view(http_method_names=None):
|
||||
def api_view(http_method_names=None, exclude_from_schema=False):
|
||||
"""
|
||||
Decorator that converts a function-based view into an APIView subclass.
|
||||
Takes a list of allowed methods for the view as an argument.
|
||||
|
@ -72,6 +72,7 @@ def api_view(http_method_names=None):
|
|||
WrappedAPIView.permission_classes = getattr(func, 'permission_classes',
|
||||
APIView.permission_classes)
|
||||
|
||||
WrappedAPIView.exclude_from_schema = exclude_from_schema
|
||||
return WrappedAPIView.as_view()
|
||||
return decorator
|
||||
|
||||
|
|
|
@ -49,18 +49,32 @@ class empty:
|
|||
pass
|
||||
|
||||
|
||||
def is_simple_callable(obj):
|
||||
if six.PY3:
|
||||
def is_simple_callable(obj):
|
||||
"""
|
||||
True if the object is a callable that takes no arguments.
|
||||
"""
|
||||
if not callable(obj):
|
||||
return False
|
||||
|
||||
sig = inspect.signature(obj)
|
||||
params = sig.parameters.values()
|
||||
return all(param.default != param.empty for param in params)
|
||||
|
||||
else:
|
||||
def is_simple_callable(obj):
|
||||
function = inspect.isfunction(obj)
|
||||
method = inspect.ismethod(obj)
|
||||
|
||||
if not (function or method):
|
||||
return False
|
||||
|
||||
if method:
|
||||
is_unbound = obj.im_self is None
|
||||
|
||||
args, _, _, defaults = inspect.getargspec(obj)
|
||||
len_args = len(args) if function else len(args) - 1
|
||||
|
||||
len_args = len(args) if function or is_unbound else len(args) - 1
|
||||
len_defaults = len(defaults) if defaults else 0
|
||||
return len_args <= len_defaults
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ from django.utils import six
|
|||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from rest_framework.compat import (
|
||||
crispy_forms, distinct, django_filters, guardian, template_render
|
||||
coreapi, crispy_forms, distinct, django_filters, guardian, template_render
|
||||
)
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
@ -72,7 +72,8 @@ class BaseFilterBackend(object):
|
|||
"""
|
||||
raise NotImplementedError(".filter_queryset() must be overridden.")
|
||||
|
||||
def get_fields(self, view):
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
return []
|
||||
|
||||
|
||||
|
@ -131,14 +132,21 @@ class DjangoFilterBackend(BaseFilterBackend):
|
|||
template = loader.get_template(self.template)
|
||||
return template_render(template, context)
|
||||
|
||||
def get_fields(self, view):
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
filter_class = getattr(view, 'filter_class', None)
|
||||
if filter_class:
|
||||
return list(filter_class().filters.keys())
|
||||
return [
|
||||
coreapi.Field(name=field_name, required=False, location='query')
|
||||
for field_name in filter_class().filters.keys()
|
||||
]
|
||||
|
||||
filter_fields = getattr(view, 'filter_fields', None)
|
||||
if filter_fields:
|
||||
return filter_fields
|
||||
return [
|
||||
coreapi.Field(name=field_name, required=False, location='query')
|
||||
for field_name in filter_fields
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
|
@ -231,8 +239,9 @@ class SearchFilter(BaseFilterBackend):
|
|||
template = loader.get_template(self.template)
|
||||
return template_render(template, context)
|
||||
|
||||
def get_fields(self, view):
|
||||
return [self.search_param]
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
return [coreapi.Field(name=self.search_param, required=False, location='query')]
|
||||
|
||||
|
||||
class OrderingFilter(BaseFilterBackend):
|
||||
|
@ -348,8 +357,9 @@ class OrderingFilter(BaseFilterBackend):
|
|||
context = self.get_template_context(request, queryset, view)
|
||||
return template_render(template, context)
|
||||
|
||||
def get_fields(self, view):
|
||||
return [self.ordering_param]
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
return [coreapi.Field(name=self.ordering_param, required=False, location='query')]
|
||||
|
||||
|
||||
class DjangoObjectPermissionsFilter(BaseFilterBackend):
|
||||
|
|
|
@ -15,7 +15,7 @@ from django.utils import six
|
|||
from django.utils.six.moves.urllib import parse as urlparse
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from rest_framework.compat import template_render
|
||||
from rest_framework.compat import coreapi, template_render
|
||||
from rest_framework.exceptions import NotFound
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
|
@ -157,7 +157,8 @@ class BasePagination(object):
|
|||
def get_results(self, data):
|
||||
return data['results']
|
||||
|
||||
def get_fields(self, view):
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
return []
|
||||
|
||||
|
||||
|
@ -283,10 +284,16 @@ class PageNumberPagination(BasePagination):
|
|||
context = self.get_html_context()
|
||||
return template_render(template, context)
|
||||
|
||||
def get_fields(self, view):
|
||||
if self.page_size_query_param is None:
|
||||
return [self.page_query_param]
|
||||
return [self.page_query_param, self.page_size_query_param]
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
fields = [
|
||||
coreapi.Field(name=self.page_query_param, required=False, location='query')
|
||||
]
|
||||
if self.page_size_query_param is not None:
|
||||
fields.append([
|
||||
coreapi.Field(name=self.page_size_query_param, required=False, location='query')
|
||||
])
|
||||
return fields
|
||||
|
||||
|
||||
class LimitOffsetPagination(BasePagination):
|
||||
|
@ -415,8 +422,12 @@ class LimitOffsetPagination(BasePagination):
|
|||
context = self.get_html_context()
|
||||
return template_render(template, context)
|
||||
|
||||
def get_fields(self, view):
|
||||
return [self.limit_query_param, self.offset_query_param]
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
return [
|
||||
coreapi.Field(name=self.limit_query_param, required=False, location='query'),
|
||||
coreapi.Field(name=self.offset_query_param, required=False, location='query')
|
||||
]
|
||||
|
||||
|
||||
class CursorPagination(BasePagination):
|
||||
|
@ -721,5 +732,8 @@ class CursorPagination(BasePagination):
|
|||
context = self.get_html_context()
|
||||
return template_render(template, context)
|
||||
|
||||
def get_fields(self, view):
|
||||
return [self.cursor_query_param]
|
||||
def get_schema_fields(self, view):
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
return [
|
||||
coreapi.Field(name=self.cursor_query_param, required=False, location='query')
|
||||
]
|
||||
|
|
|
@ -4,9 +4,6 @@ from __future__ import unicode_literals
|
|||
from collections import OrderedDict
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
|
||||
from django.core.urlresolvers import (
|
||||
NoReverseMatch, Resolver404, get_script_prefix, resolve
|
||||
)
|
||||
from django.db.models import Manager
|
||||
from django.db.models.query import QuerySet
|
||||
from django.utils import six
|
||||
|
@ -14,10 +11,14 @@ from django.utils.encoding import python_2_unicode_compatible, smart_text
|
|||
from django.utils.six.moves.urllib import parse as urlparse
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from rest_framework.compat import (
|
||||
NoReverseMatch, Resolver404, get_script_prefix, resolve
|
||||
)
|
||||
from rest_framework.fields import (
|
||||
Field, empty, get_attribute, is_simple_callable, iter_options
|
||||
)
|
||||
from rest_framework.reverse import reverse
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import html
|
||||
|
||||
|
||||
|
@ -71,14 +72,19 @@ MANY_RELATION_KWARGS = (
|
|||
|
||||
class RelatedField(Field):
|
||||
queryset = None
|
||||
html_cutoff = 1000
|
||||
html_cutoff_text = _('More than {count} items...')
|
||||
html_cutoff = None
|
||||
html_cutoff_text = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.queryset = kwargs.pop('queryset', self.queryset)
|
||||
self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff)
|
||||
self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text)
|
||||
|
||||
self.html_cutoff = kwargs.pop(
|
||||
'html_cutoff',
|
||||
self.html_cutoff or int(api_settings.HTML_SELECT_CUTOFF)
|
||||
)
|
||||
self.html_cutoff_text = kwargs.pop(
|
||||
'html_cutoff_text',
|
||||
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
|
||||
)
|
||||
if not method_overridden('get_queryset', RelatedField, self):
|
||||
assert self.queryset is not None or kwargs.get('read_only', None), (
|
||||
'Relational field must provide a `queryset` argument, '
|
||||
|
@ -447,15 +453,20 @@ class ManyRelatedField(Field):
|
|||
'not_a_list': _('Expected a list of items but got type "{input_type}".'),
|
||||
'empty': _('This list may not be empty.')
|
||||
}
|
||||
html_cutoff = 1000
|
||||
html_cutoff_text = _('More than {count} items...')
|
||||
html_cutoff = None
|
||||
html_cutoff_text = None
|
||||
|
||||
def __init__(self, child_relation=None, *args, **kwargs):
|
||||
self.child_relation = child_relation
|
||||
self.allow_empty = kwargs.pop('allow_empty', True)
|
||||
self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff)
|
||||
self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text)
|
||||
|
||||
self.html_cutoff = kwargs.pop(
|
||||
'html_cutoff',
|
||||
self.html_cutoff or int(api_settings.HTML_SELECT_CUTOFF)
|
||||
)
|
||||
self.html_cutoff_text = kwargs.pop(
|
||||
'html_cutoff_text',
|
||||
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
|
||||
)
|
||||
assert child_relation is not None, '`child_relation` is a required argument.'
|
||||
super(ManyRelatedField, self).__init__(*args, **kwargs)
|
||||
self.child_relation.bind(field_name='', parent=self)
|
||||
|
|
|
@ -276,6 +276,10 @@ class HTMLFormRenderer(BaseRenderer):
|
|||
'base_template': 'input.html',
|
||||
'input_type': 'number'
|
||||
},
|
||||
serializers.FloatField: {
|
||||
'base_template': 'input.html',
|
||||
'input_type': 'number'
|
||||
},
|
||||
serializers.DateTimeField: {
|
||||
'base_template': 'input.html',
|
||||
'input_type': 'datetime-local'
|
||||
|
@ -809,7 +813,7 @@ class MultiPartRenderer(BaseRenderer):
|
|||
|
||||
|
||||
class CoreJSONRenderer(BaseRenderer):
|
||||
media_type = 'application/vnd.coreapi+json'
|
||||
media_type = 'application/coreapi+json'
|
||||
charset = None
|
||||
format = 'corejson'
|
||||
|
||||
|
|
|
@ -3,11 +3,11 @@ Provide urlresolver functions that return fully qualified URLs or view names
|
|||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.core.urlresolvers import reverse as django_reverse
|
||||
from django.core.urlresolvers import NoReverseMatch
|
||||
from django.utils import six
|
||||
from django.utils.functional import lazy
|
||||
|
||||
from rest_framework.compat import reverse as django_reverse
|
||||
from rest_framework.compat import NoReverseMatch
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils.urls import replace_query_param
|
||||
|
||||
|
@ -54,7 +54,7 @@ def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra
|
|||
|
||||
def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
|
||||
"""
|
||||
Same as `django.core.urlresolvers.reverse`, but optionally takes a request
|
||||
Same as `django.urls.reverse`, but optionally takes a request
|
||||
and returns a fully qualified URL, using the request to get the base URL.
|
||||
"""
|
||||
if format is not None:
|
||||
|
|
|
@ -16,13 +16,15 @@ For example, you might have a `urls.py` that looks something like this:
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import itertools
|
||||
import warnings
|
||||
from collections import OrderedDict, namedtuple
|
||||
|
||||
from django.conf.urls import url
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.urlresolvers import NoReverseMatch
|
||||
|
||||
from rest_framework import exceptions, renderers, views
|
||||
from rest_framework.compat import NoReverseMatch
|
||||
from rest_framework.renderers import BrowsableAPIRenderer
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.reverse import reverse
|
||||
from rest_framework.schemas import SchemaGenerator
|
||||
|
@ -83,6 +85,7 @@ class BaseRouter(object):
|
|||
|
||||
|
||||
class SimpleRouter(BaseRouter):
|
||||
|
||||
routes = [
|
||||
# List route.
|
||||
Route(
|
||||
|
@ -258,6 +261,13 @@ class SimpleRouter(BaseRouter):
|
|||
trailing_slash=self.trailing_slash
|
||||
)
|
||||
|
||||
# If there is no prefix, the first part of the url is probably
|
||||
# controlled by project's urls.py and the router is in an app,
|
||||
# so a slash in the beginning will (A) cause Django to give
|
||||
# warnings and (B) generate URLS that will require using '//'.
|
||||
if not prefix and regex[:2] == '^/':
|
||||
regex = '^' + regex[2:]
|
||||
|
||||
view = viewset.as_view(mapping, **route.initkwargs)
|
||||
name = route.name.format(basename=basename)
|
||||
ret.append(url(regex, view, name=name))
|
||||
|
@ -273,9 +283,15 @@ class DefaultRouter(SimpleRouter):
|
|||
include_root_view = True
|
||||
include_format_suffixes = True
|
||||
root_view_name = 'api-root'
|
||||
default_schema_renderers = [renderers.CoreJSONRenderer]
|
||||
default_schema_renderers = [renderers.CoreJSONRenderer, BrowsableAPIRenderer]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if 'schema_title' in kwargs:
|
||||
warnings.warn(
|
||||
"Including a schema directly via a router is now pending "
|
||||
"deprecation. Use `get_schema_view()` instead.",
|
||||
PendingDeprecationWarning
|
||||
)
|
||||
if 'schema_renderers' in kwargs:
|
||||
assert 'schema_title' in kwargs, 'Missing "schema_title" argument.'
|
||||
if 'schema_url' in kwargs:
|
||||
|
@ -289,42 +305,44 @@ class DefaultRouter(SimpleRouter):
|
|||
self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
|
||||
super(DefaultRouter, self).__init__(*args, **kwargs)
|
||||
|
||||
def get_schema_root_view(self, api_urls=None):
|
||||
"""
|
||||
Return a schema root view.
|
||||
"""
|
||||
schema_renderers = self.schema_renderers
|
||||
schema_generator = SchemaGenerator(
|
||||
title=self.schema_title,
|
||||
url=self.schema_url,
|
||||
patterns=api_urls
|
||||
)
|
||||
|
||||
class APISchemaView(views.APIView):
|
||||
_ignore_model_permissions = True
|
||||
exclude_from_schema = True
|
||||
renderer_classes = schema_renderers
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
schema = schema_generator.get_schema(request)
|
||||
if schema is None:
|
||||
raise exceptions.PermissionDenied()
|
||||
return Response(schema)
|
||||
|
||||
return APISchemaView.as_view()
|
||||
|
||||
def get_api_root_view(self, api_urls=None):
|
||||
"""
|
||||
Return a view to use as the API root.
|
||||
Return a basic root view.
|
||||
"""
|
||||
api_root_dict = OrderedDict()
|
||||
list_name = self.routes[0].name
|
||||
for prefix, viewset, basename in self.registry:
|
||||
api_root_dict[prefix] = list_name.format(basename=basename)
|
||||
|
||||
view_renderers = list(self.root_renderers)
|
||||
schema_media_types = []
|
||||
|
||||
if api_urls and self.schema_title:
|
||||
view_renderers += list(self.schema_renderers)
|
||||
schema_generator = SchemaGenerator(
|
||||
title=self.schema_title,
|
||||
url=self.schema_url,
|
||||
patterns=api_urls
|
||||
)
|
||||
schema_media_types = [
|
||||
renderer.media_type
|
||||
for renderer in self.schema_renderers
|
||||
]
|
||||
|
||||
class APIRoot(views.APIView):
|
||||
class APIRootView(views.APIView):
|
||||
_ignore_model_permissions = True
|
||||
renderer_classes = view_renderers
|
||||
exclude_from_schema = True
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
if request.accepted_renderer.media_type in schema_media_types:
|
||||
# Return a schema response.
|
||||
schema = schema_generator.get_schema(request)
|
||||
if schema is None:
|
||||
raise exceptions.PermissionDenied()
|
||||
return Response(schema)
|
||||
|
||||
# Return a plain {"name": "hyperlink"} response.
|
||||
ret = OrderedDict()
|
||||
namespace = request.resolver_match.namespace
|
||||
|
@ -345,7 +363,7 @@ class DefaultRouter(SimpleRouter):
|
|||
|
||||
return Response(ret)
|
||||
|
||||
return APIRoot.as_view()
|
||||
return APIRootView.as_view()
|
||||
|
||||
def get_urls(self):
|
||||
"""
|
||||
|
@ -355,6 +373,9 @@ class DefaultRouter(SimpleRouter):
|
|||
urls = super(DefaultRouter, self).get_urls()
|
||||
|
||||
if self.include_root_view:
|
||||
if self.schema_title:
|
||||
view = self.get_schema_root_view(api_urls=urls)
|
||||
else:
|
||||
view = self.get_api_root_view(api_urls=urls)
|
||||
root_url = url(r'^$', view, name=self.root_view_name)
|
||||
urls.append(root_url)
|
||||
|
|
|
@ -1,26 +1,45 @@
|
|||
import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.admindocs.views import simplify_regex
|
||||
from django.core.urlresolvers import RegexURLPattern, RegexURLResolver
|
||||
from django.utils import six
|
||||
from django.utils.encoding import force_text
|
||||
from django.utils.encoding import force_text, smart_text
|
||||
|
||||
from rest_framework import exceptions, serializers
|
||||
from rest_framework.compat import coreapi, uritemplate, urlparse
|
||||
from rest_framework import exceptions, renderers, serializers
|
||||
from rest_framework.compat import (
|
||||
RegexURLPattern, RegexURLResolver, coreapi, uritemplate, urlparse
|
||||
)
|
||||
from rest_framework.request import clone_request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import formatting
|
||||
from rest_framework.utils.field_mapping import ClassLookupDict
|
||||
from rest_framework.utils.model_meta import _get_pk
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
def as_query_fields(items):
|
||||
"""
|
||||
Take a list of Fields and plain strings.
|
||||
Convert any pain strings into `location='query'` Field instances.
|
||||
"""
|
||||
return [
|
||||
item if isinstance(item, coreapi.Field) else coreapi.Field(name=item, required=False, location='query')
|
||||
for item in items
|
||||
]
|
||||
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
|
||||
|
||||
types_lookup = ClassLookupDict({
|
||||
serializers.Field: 'string',
|
||||
serializers.IntegerField: 'integer',
|
||||
serializers.FloatField: 'number',
|
||||
serializers.DecimalField: 'number',
|
||||
serializers.BooleanField: 'boolean',
|
||||
serializers.FileField: 'file',
|
||||
serializers.MultipleChoiceField: 'array',
|
||||
serializers.ManyRelatedField: 'array',
|
||||
serializers.Serializer: 'object',
|
||||
serializers.ListSerializer: 'array'
|
||||
})
|
||||
|
||||
|
||||
def get_pk_name(model):
|
||||
meta = model._meta.concrete_model._meta
|
||||
return _get_pk(meta).name
|
||||
|
||||
|
||||
def is_api_view(callback):
|
||||
|
@ -31,106 +50,92 @@ def is_api_view(callback):
|
|||
return (cls is not None) and issubclass(cls, APIView)
|
||||
|
||||
|
||||
class SchemaGenerator(object):
|
||||
default_mapping = {
|
||||
'get': 'read',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
known_actions = (
|
||||
'create', 'read', 'retrieve', 'list',
|
||||
'update', 'partial_update', 'destroy'
|
||||
)
|
||||
def insert_into(target, keys, value):
|
||||
"""
|
||||
Nested dictionary insertion.
|
||||
|
||||
def __init__(self, title=None, url=None, patterns=None, urlconf=None):
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
>>> example = {}
|
||||
>>> insert_into(example, ['a', 'b', 'c'], 123)
|
||||
>>> example
|
||||
{'a': {'b': {'c': 123}}}
|
||||
"""
|
||||
for key in keys[:-1]:
|
||||
if key not in target:
|
||||
target[key] = {}
|
||||
target = target[key]
|
||||
target[keys[-1]] = value
|
||||
|
||||
if patterns is None and urlconf is not None:
|
||||
|
||||
def is_custom_action(action):
|
||||
return action not in set([
|
||||
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
|
||||
])
|
||||
|
||||
|
||||
def is_list_view(path, method, view):
|
||||
"""
|
||||
Return True if the given path/method appears to represent a list view.
|
||||
"""
|
||||
if hasattr(view, 'action'):
|
||||
# Viewsets have an explicitly defined action, which we can inspect.
|
||||
return view.action == 'list'
|
||||
|
||||
if method.lower() != 'get':
|
||||
return False
|
||||
path_components = path.strip('/').split('/')
|
||||
if path_components and '{' in path_components[-1]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def endpoint_ordering(endpoint):
|
||||
path, method, callback = endpoint
|
||||
method_priority = {
|
||||
'GET': 0,
|
||||
'POST': 1,
|
||||
'PUT': 2,
|
||||
'PATCH': 3,
|
||||
'DELETE': 4
|
||||
}.get(method, 5)
|
||||
return (path, method_priority)
|
||||
|
||||
|
||||
class EndpointInspector(object):
|
||||
"""
|
||||
A class to determine the available API endpoints that a project exposes.
|
||||
"""
|
||||
def __init__(self, patterns=None, urlconf=None):
|
||||
if patterns is None:
|
||||
if urlconf is None:
|
||||
# Use the default Django URL conf
|
||||
urlconf = settings.ROOT_URLCONF
|
||||
|
||||
# Load the given URLconf module
|
||||
if isinstance(urlconf, six.string_types):
|
||||
urls = import_module(urlconf)
|
||||
else:
|
||||
urls = urlconf
|
||||
self.patterns = urls.urlpatterns
|
||||
elif patterns is None and urlconf is None:
|
||||
urls = import_module(settings.ROOT_URLCONF)
|
||||
self.patterns = urls.urlpatterns
|
||||
else:
|
||||
patterns = urls.urlpatterns
|
||||
|
||||
self.patterns = patterns
|
||||
|
||||
if url and not url.endswith('/'):
|
||||
url += '/'
|
||||
|
||||
self.title = title
|
||||
self.url = url
|
||||
self.endpoints = None
|
||||
|
||||
def get_schema(self, request=None):
|
||||
if self.endpoints is None:
|
||||
self.endpoints = self.get_api_endpoints(self.patterns)
|
||||
|
||||
links = []
|
||||
for path, method, category, action, callback in self.endpoints:
|
||||
view = callback.cls()
|
||||
for attr, val in getattr(callback, 'initkwargs', {}).items():
|
||||
setattr(view, attr, val)
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
|
||||
actions = getattr(callback, 'actions', None)
|
||||
if actions is not None:
|
||||
if method == 'OPTIONS':
|
||||
view.action = 'metadata'
|
||||
else:
|
||||
view.action = actions.get(method.lower())
|
||||
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
try:
|
||||
view.check_permissions(view.request)
|
||||
except exceptions.APIException:
|
||||
continue
|
||||
else:
|
||||
view.request = None
|
||||
|
||||
link = self.get_link(path, method, callback, view)
|
||||
links.append((category, action, link))
|
||||
|
||||
if not links:
|
||||
return None
|
||||
|
||||
# Generate the schema content structure, eg:
|
||||
# {'users': {'list': Link()}}
|
||||
content = {}
|
||||
for category, action, link in links:
|
||||
if category is None:
|
||||
content[action] = link
|
||||
elif category in content:
|
||||
content[category][action] = link
|
||||
else:
|
||||
content[category] = {action: link}
|
||||
|
||||
# Return the schema document.
|
||||
return coreapi.Document(title=self.title, content=content, url=self.url)
|
||||
|
||||
def get_api_endpoints(self, patterns, prefix=''):
|
||||
def get_api_endpoints(self, patterns=None, prefix=''):
|
||||
"""
|
||||
Return a list of all available API endpoints by inspecting the URL conf.
|
||||
"""
|
||||
if patterns is None:
|
||||
patterns = self.patterns
|
||||
|
||||
api_endpoints = []
|
||||
|
||||
for pattern in patterns:
|
||||
path_regex = prefix + pattern.regex.pattern
|
||||
if isinstance(pattern, RegexURLPattern):
|
||||
path = self.get_path(path_regex)
|
||||
path = self.get_path_from_regex(path_regex)
|
||||
callback = pattern.callback
|
||||
if self.should_include_endpoint(path, callback):
|
||||
for method in self.get_allowed_methods(callback):
|
||||
action = self.get_action(path, method, callback)
|
||||
category = self.get_category(path, method, callback, action)
|
||||
endpoint = (path, method, category, action, callback)
|
||||
endpoint = (path, method, callback)
|
||||
api_endpoints.append(endpoint)
|
||||
|
||||
elif isinstance(pattern, RegexURLResolver):
|
||||
|
@ -140,9 +145,11 @@ class SchemaGenerator(object):
|
|||
)
|
||||
api_endpoints.extend(nested_endpoints)
|
||||
|
||||
api_endpoints = sorted(api_endpoints, key=endpoint_ordering)
|
||||
|
||||
return api_endpoints
|
||||
|
||||
def get_path(self, path_regex):
|
||||
def get_path_from_regex(self, path_regex):
|
||||
"""
|
||||
Given a URL conf regex, return a URI template string.
|
||||
"""
|
||||
|
@ -160,9 +167,6 @@ class SchemaGenerator(object):
|
|||
if path.endswith('.{format}') or path.endswith('.{format}/'):
|
||||
return False # Ignore .json style URLs.
|
||||
|
||||
if path == '/':
|
||||
return False # Ignore the root endpoint.
|
||||
|
||||
return True
|
||||
|
||||
def get_allowed_methods(self, callback):
|
||||
|
@ -177,60 +181,190 @@ class SchemaGenerator(object):
|
|||
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
|
||||
]
|
||||
|
||||
def get_action(self, path, method, callback):
|
||||
"""
|
||||
Return a descriptive action string for the endpoint, eg. 'list'.
|
||||
"""
|
||||
actions = getattr(callback, 'actions', self.default_mapping)
|
||||
return actions[method.lower()]
|
||||
|
||||
def get_category(self, path, method, callback, action):
|
||||
"""
|
||||
Return a descriptive category string for the endpoint, eg. 'users'.
|
||||
class SchemaGenerator(object):
|
||||
# Map HTTP methods onto actions.
|
||||
default_mapping = {
|
||||
'get': 'retrieve',
|
||||
'post': 'create',
|
||||
'put': 'update',
|
||||
'patch': 'partial_update',
|
||||
'delete': 'destroy',
|
||||
}
|
||||
endpoint_inspector_cls = EndpointInspector
|
||||
|
||||
Examples of category/action pairs that should be generated for various
|
||||
endpoints:
|
||||
# Map the method names we use for viewset actions onto external schema names.
|
||||
# These give us names that are more suitable for the external representation.
|
||||
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
|
||||
coerce_method_names = None
|
||||
|
||||
/users/ [users][list], [users][create]
|
||||
/users/{pk}/ [users][read], [users][update], [users][destroy]
|
||||
/users/enabled/ [users][enabled] (custom action)
|
||||
/users/{pk}/star/ [users][star] (custom action)
|
||||
/users/{pk}/groups/ [groups][list], [groups][create]
|
||||
/users/{pk}/groups/{pk}/ [groups][read], [groups][update], [groups][destroy]
|
||||
# 'pk' isn't great as an externally exposed name for an identifier,
|
||||
# so by default we prefer to use the actual model field name for schemas.
|
||||
# Set by 'SCHEMA_COERCE_PATH_PK'.
|
||||
coerce_path_pk = None
|
||||
|
||||
def __init__(self, title=None, url=None, patterns=None, urlconf=None):
|
||||
assert coreapi, '`coreapi` must be installed for schema support.'
|
||||
|
||||
if url and not url.endswith('/'):
|
||||
url += '/'
|
||||
|
||||
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
|
||||
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
|
||||
|
||||
self.patterns = patterns
|
||||
self.urlconf = urlconf
|
||||
self.title = title
|
||||
self.url = url
|
||||
self.endpoints = None
|
||||
|
||||
def get_schema(self, request=None):
|
||||
"""
|
||||
path_components = path.strip('/').split('/')
|
||||
path_components = [
|
||||
component for component in path_components
|
||||
if '{' not in component
|
||||
]
|
||||
if action in self.known_actions:
|
||||
# Default action, eg "/users/", "/users/{pk}/"
|
||||
idx = -1
|
||||
Generate a `coreapi.Document` representing the API schema.
|
||||
"""
|
||||
if self.endpoints is None:
|
||||
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
|
||||
self.endpoints = inspector.get_api_endpoints()
|
||||
|
||||
links = self.get_links(request)
|
||||
if not links:
|
||||
return None
|
||||
return coreapi.Document(title=self.title, url=self.url, content=links)
|
||||
|
||||
def get_links(self, request=None):
|
||||
"""
|
||||
Return a dictionary containing all the links that should be
|
||||
included in the API schema.
|
||||
"""
|
||||
links = OrderedDict()
|
||||
|
||||
# Generate (path, method, view) given (path, method, callback).
|
||||
paths = []
|
||||
view_endpoints = []
|
||||
for path, method, callback in self.endpoints:
|
||||
view = self.create_view(callback, method, request)
|
||||
if getattr(view, 'exclude_from_schema', False):
|
||||
continue
|
||||
path = self.coerce_path(path, method, view)
|
||||
paths.append(path)
|
||||
view_endpoints.append((path, method, view))
|
||||
|
||||
# Only generate the path prefix for paths that will be included
|
||||
prefix = self.determine_path_prefix(paths)
|
||||
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
link = self.get_link(path, method, view)
|
||||
subpath = path[len(prefix):]
|
||||
keys = self.get_keys(subpath, method, view)
|
||||
insert_into(links, keys, link)
|
||||
return links
|
||||
|
||||
# Methods used when we generate a view instance from the raw callback...
|
||||
|
||||
def determine_path_prefix(self, paths):
|
||||
"""
|
||||
Given a list of all paths, return the common prefix which should be
|
||||
discounted when generating a schema structure.
|
||||
|
||||
This will be the longest common string that does not include that last
|
||||
component of the URL, or the last component before a path parameter.
|
||||
|
||||
For example:
|
||||
|
||||
/api/v1/users/
|
||||
/api/v1/users/{pk}/
|
||||
|
||||
The path prefix is '/api/v1/'
|
||||
"""
|
||||
prefixes = []
|
||||
for path in paths:
|
||||
components = path.strip('/').split('/')
|
||||
initial_components = []
|
||||
for component in components:
|
||||
if '{' in component:
|
||||
break
|
||||
initial_components.append(component)
|
||||
prefix = '/'.join(initial_components[:-1])
|
||||
if not prefix:
|
||||
# We can just break early in the case that there's at least
|
||||
# one URL that doesn't have a path prefix.
|
||||
return '/'
|
||||
prefixes.append('/' + prefix + '/')
|
||||
return os.path.commonprefix(prefixes)
|
||||
|
||||
def create_view(self, callback, method, request=None):
|
||||
"""
|
||||
Given a callback, return an actual view instance.
|
||||
"""
|
||||
view = callback.cls()
|
||||
for attr, val in getattr(callback, 'initkwargs', {}).items():
|
||||
setattr(view, attr, val)
|
||||
view.args = ()
|
||||
view.kwargs = {}
|
||||
view.format_kwarg = None
|
||||
view.request = None
|
||||
view.action_map = getattr(callback, 'actions', None)
|
||||
|
||||
actions = getattr(callback, 'actions', None)
|
||||
if actions is not None:
|
||||
if method == 'OPTIONS':
|
||||
view.action = 'metadata'
|
||||
else:
|
||||
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
|
||||
idx = -2
|
||||
view.action = actions.get(method.lower())
|
||||
|
||||
if request is not None:
|
||||
view.request = clone_request(request, method)
|
||||
|
||||
return view
|
||||
|
||||
def has_view_permissions(self, path, method, view):
|
||||
"""
|
||||
Return `True` if the incoming request has the correct view permissions.
|
||||
"""
|
||||
if view.request is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
return path_components[idx]
|
||||
except IndexError:
|
||||
return None
|
||||
view.check_permissions(view.request)
|
||||
except exceptions.APIException:
|
||||
return False
|
||||
return True
|
||||
|
||||
def coerce_path(self, path, method, view):
|
||||
"""
|
||||
Coerce {pk} path arguments into the name of the model field,
|
||||
where possible. This is cleaner for an external representation.
|
||||
(Ie. "this is an identifier", not "this is a database primary key")
|
||||
"""
|
||||
if not self.coerce_path_pk or '{pk}' not in path:
|
||||
return path
|
||||
model = getattr(getattr(view, 'queryset', None), 'model', None)
|
||||
if model:
|
||||
field_name = get_pk_name(model)
|
||||
else:
|
||||
field_name = 'id'
|
||||
return path.replace('{pk}', '{%s}' % field_name)
|
||||
|
||||
# Methods for generating each individual `Link` instance...
|
||||
|
||||
def get_link(self, path, method, callback, view):
|
||||
def get_link(self, path, method, view):
|
||||
"""
|
||||
Return a `coreapi.Link` instance for the given endpoint.
|
||||
"""
|
||||
fields = self.get_path_fields(path, method, callback, view)
|
||||
fields += self.get_serializer_fields(path, method, callback, view)
|
||||
fields += self.get_pagination_fields(path, method, callback, view)
|
||||
fields += self.get_filter_fields(path, method, callback, view)
|
||||
fields = self.get_path_fields(path, method, view)
|
||||
fields += self.get_serializer_fields(path, method, view)
|
||||
fields += self.get_pagination_fields(path, method, view)
|
||||
fields += self.get_filter_fields(path, method, view)
|
||||
|
||||
if fields and any([field.location in ('form', 'body') for field in fields]):
|
||||
encoding = self.get_encoding(path, method, callback, view)
|
||||
encoding = self.get_encoding(path, method, view)
|
||||
else:
|
||||
encoding = None
|
||||
|
||||
description = self.get_description(path, method, view)
|
||||
|
||||
if self.url and path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
||||
|
@ -238,10 +372,44 @@ class SchemaGenerator(object):
|
|||
url=urlparse.urljoin(self.url, path),
|
||||
action=method.lower(),
|
||||
encoding=encoding,
|
||||
fields=fields
|
||||
fields=fields,
|
||||
description=description
|
||||
)
|
||||
|
||||
def get_encoding(self, path, method, callback, view):
|
||||
def get_description(self, path, method, view):
|
||||
"""
|
||||
Determine a link description.
|
||||
|
||||
This will be based on the method docstring if one exists,
|
||||
or else the class docstring.
|
||||
"""
|
||||
method_name = getattr(view, 'action', method.lower())
|
||||
method_docstring = getattr(view, method_name, None).__doc__
|
||||
if method_docstring:
|
||||
# An explicit docstring on the method or action.
|
||||
return formatting.dedent(smart_text(method_docstring))
|
||||
|
||||
description = view.get_view_description()
|
||||
lines = [line.strip() for line in description.splitlines()]
|
||||
current_section = ''
|
||||
sections = {'': ''}
|
||||
|
||||
for line in lines:
|
||||
if header_regex.match(line):
|
||||
current_section, seperator, lead = line.partition(':')
|
||||
sections[current_section] = lead.strip()
|
||||
else:
|
||||
sections[current_section] += line + '\n'
|
||||
|
||||
header = getattr(view, 'action', method.lower())
|
||||
if header in sections:
|
||||
return sections[header].strip()
|
||||
if header in self.coerce_method_names:
|
||||
if self.coerce_method_names[header] in sections:
|
||||
return sections[self.coerce_method_names[header]].strip()
|
||||
return sections[''].strip()
|
||||
|
||||
def get_encoding(self, path, method, view):
|
||||
"""
|
||||
Return the 'encoding' parameter to use for a given endpoint.
|
||||
"""
|
||||
|
@ -262,7 +430,7 @@ class SchemaGenerator(object):
|
|||
|
||||
return None
|
||||
|
||||
def get_path_fields(self, path, method, callback, view):
|
||||
def get_path_fields(self, path, method, view):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
templated path variables.
|
||||
|
@ -275,7 +443,7 @@ class SchemaGenerator(object):
|
|||
|
||||
return fields
|
||||
|
||||
def get_serializer_fields(self, path, method, callback, view):
|
||||
def get_serializer_fields(self, path, method, view):
|
||||
"""
|
||||
Return a list of `coreapi.Field` instances corresponding to any
|
||||
request body input, as determined by the serializer class.
|
||||
|
@ -289,7 +457,14 @@ class SchemaGenerator(object):
|
|||
serializer = view.get_serializer()
|
||||
|
||||
if isinstance(serializer, serializers.ListSerializer):
|
||||
return [coreapi.Field(name='data', location='body', required=True)]
|
||||
return [
|
||||
coreapi.Field(
|
||||
name='data',
|
||||
location='body',
|
||||
required=True,
|
||||
type='array'
|
||||
)
|
||||
]
|
||||
|
||||
if not isinstance(serializer, serializers.Serializer):
|
||||
return []
|
||||
|
@ -305,36 +480,104 @@ class SchemaGenerator(object):
|
|||
name=field.source,
|
||||
location='form',
|
||||
required=required,
|
||||
description=description
|
||||
description=description,
|
||||
type=types_lookup[field]
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def get_pagination_fields(self, path, method, callback, view):
|
||||
if method != 'GET':
|
||||
return []
|
||||
|
||||
if hasattr(callback, 'actions') and ('list' not in callback.actions.values()):
|
||||
def get_pagination_fields(self, path, method, view):
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
if not getattr(view, 'pagination_class', None):
|
||||
return []
|
||||
|
||||
paginator = view.pagination_class()
|
||||
return as_query_fields(paginator.get_fields(view))
|
||||
return paginator.get_schema_fields(view)
|
||||
|
||||
def get_filter_fields(self, path, method, callback, view):
|
||||
if method != 'GET':
|
||||
def get_filter_fields(self, path, method, view):
|
||||
if not is_list_view(path, method, view):
|
||||
return []
|
||||
|
||||
if hasattr(callback, 'actions') and ('list' not in callback.actions.values()):
|
||||
return []
|
||||
|
||||
if not hasattr(view, 'filter_backends'):
|
||||
if not getattr(view, 'filter_backends', None):
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for filter_backend in view.filter_backends:
|
||||
fields += as_query_fields(filter_backend().get_fields(view))
|
||||
fields += filter_backend().get_schema_fields(view)
|
||||
return fields
|
||||
|
||||
# Method for generating the link layout....
|
||||
|
||||
def get_keys(self, subpath, method, view):
|
||||
"""
|
||||
Return a list of keys that should be used to layout a link within
|
||||
the schema document.
|
||||
|
||||
/users/ ("users", "list"), ("users", "create")
|
||||
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
|
||||
/users/enabled/ ("users", "enabled") # custom viewset list action
|
||||
/users/{pk}/star/ ("users", "star") # custom viewset detail action
|
||||
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
|
||||
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
|
||||
"""
|
||||
if hasattr(view, 'action'):
|
||||
# Viewsets have explicitly named actions.
|
||||
action = view.action
|
||||
else:
|
||||
# Views have no associated action, so we determine one from the method.
|
||||
if is_list_view(subpath, method, view):
|
||||
action = 'list'
|
||||
else:
|
||||
action = self.default_mapping[method.lower()]
|
||||
|
||||
named_path_components = [
|
||||
component for component
|
||||
in subpath.strip('/').split('/')
|
||||
if '{' not in component
|
||||
]
|
||||
|
||||
if is_custom_action(action):
|
||||
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
|
||||
if len(view.action_map) > 1:
|
||||
action = self.default_mapping[method.lower()]
|
||||
if action in self.coerce_method_names:
|
||||
action = self.coerce_method_names[action]
|
||||
return named_path_components + [action]
|
||||
else:
|
||||
return named_path_components[:-1] + [action]
|
||||
|
||||
if action in self.coerce_method_names:
|
||||
action = self.coerce_method_names[action]
|
||||
|
||||
# Default action, eg "/users/", "/users/{pk}/"
|
||||
return named_path_components + [action]
|
||||
|
||||
|
||||
def get_schema_view(title=None, url=None, renderer_classes=None):
|
||||
"""
|
||||
Return a schema view.
|
||||
"""
|
||||
generator = SchemaGenerator(title=title, url=url)
|
||||
if renderer_classes is None:
|
||||
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
|
||||
rclasses = [renderers.CoreJSONRenderer, renderers.BrowsableAPIRenderer]
|
||||
else:
|
||||
rclasses = [renderers.CoreJSONRenderer]
|
||||
else:
|
||||
rclasses = renderer_classes
|
||||
|
||||
class SchemaView(APIView):
|
||||
_ignore_model_permissions = True
|
||||
exclude_from_schema = True
|
||||
renderer_classes = rclasses
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
schema = generator.get_schema(request)
|
||||
if schema is None:
|
||||
raise exceptions.PermissionDenied()
|
||||
return Response(schema)
|
||||
|
||||
return SchemaView.as_view()
|
||||
|
|
|
@ -13,7 +13,6 @@ response content is handled by parsers and renderers.
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import DurationField as ModelDurationField
|
||||
|
@ -23,7 +22,7 @@ from django.utils.functional import cached_property
|
|||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from rest_framework.compat import JSONField as ModelJSONField
|
||||
from rest_framework.compat import postgres_fields, unicode_to_repr
|
||||
from rest_framework.compat import postgres_fields, set_many, unicode_to_repr
|
||||
from rest_framework.utils import model_meta
|
||||
from rest_framework.utils.field_mapping import (
|
||||
ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs,
|
||||
|
@ -892,18 +891,22 @@ class ModelSerializer(Serializer):
|
|||
# Save many-to-many relationships after the instance is created.
|
||||
if many_to_many:
|
||||
for field_name, value in many_to_many.items():
|
||||
setattr(instance, field_name, value)
|
||||
set_many(instance, field_name, value)
|
||||
|
||||
return instance
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
raise_errors_on_nested_writes('update', self, validated_data)
|
||||
info = model_meta.get_field_info(instance)
|
||||
|
||||
# Simply set each attribute on the instance, and then save it.
|
||||
# Note that unlike `.create()` we don't need to treat many-to-many
|
||||
# relationships as being a special case. During updates we already
|
||||
# have an instance pk for the relationships to be associated with.
|
||||
for attr, value in validated_data.items():
|
||||
if attr in info.relations and info.relations[attr].to_many:
|
||||
set_many(instance, attr, value)
|
||||
else:
|
||||
setattr(instance, attr, value)
|
||||
instance.save()
|
||||
|
||||
|
@ -1012,15 +1015,13 @@ class ModelSerializer(Serializer):
|
|||
)
|
||||
)
|
||||
|
||||
if fields is None and exclude is None:
|
||||
warnings.warn(
|
||||
"Creating a ModelSerializer without either the 'fields' "
|
||||
"attribute or the 'exclude' attribute is deprecated "
|
||||
"since 3.3.0. Add an explicit fields = '__all__' to the "
|
||||
assert not (fields is None and exclude is None), (
|
||||
"Creating a ModelSerializer without either the 'fields' attribute "
|
||||
"or the 'exclude' attribute has been deprecated since 3.3.0, "
|
||||
"and is now disallowed. Add an explicit fields = '__all__' to the "
|
||||
"{serializer_class} serializer.".format(
|
||||
serializer_class=self.__class__.__name__
|
||||
),
|
||||
DeprecationWarning
|
||||
)
|
||||
|
||||
if fields == ALL_FIELDS:
|
||||
|
|
|
@ -111,6 +111,17 @@ DEFAULTS = {
|
|||
'COMPACT_JSON': True,
|
||||
'COERCE_DECIMAL_TO_STRING': True,
|
||||
'UPLOADED_FILES_USE_URL': True,
|
||||
|
||||
# Browseable API
|
||||
'HTML_SELECT_CUTOFF': 1000,
|
||||
'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...",
|
||||
|
||||
# Schemas
|
||||
'SCHEMA_COERCE_PATH_PK': True,
|
||||
'SCHEMA_COERCE_METHOD_NAMES': {
|
||||
'retrieve': 'read',
|
||||
'destroy': 'delete'
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -3,14 +3,13 @@ from __future__ import absolute_import, unicode_literals
|
|||
import re
|
||||
|
||||
from django import template
|
||||
from django.core.urlresolvers import NoReverseMatch, reverse
|
||||
from django.template import loader
|
||||
from django.utils import six
|
||||
from django.utils.encoding import force_text, iri_to_uri
|
||||
from django.utils.html import escape, format_html, smart_urlquote
|
||||
from django.utils.safestring import SafeData, mark_safe
|
||||
|
||||
from rest_framework.compat import template_render
|
||||
from rest_framework.compat import NoReverseMatch, reverse, template_render
|
||||
from rest_framework.renderers import HTMLFormRenderer
|
||||
from rest_framework.utils.urls import replace_query_param
|
||||
|
||||
|
|
|
@ -4,7 +4,11 @@
|
|||
# to make it harder for the user to import the wrong thing without realizing.
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import io
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.handlers.wsgi import WSGIHandler
|
||||
from django.test import testcases
|
||||
from django.test.client import Client as DjangoClient
|
||||
from django.test.client import RequestFactory as DjangoRequestFactory
|
||||
|
@ -13,6 +17,7 @@ from django.utils import six
|
|||
from django.utils.encoding import force_bytes
|
||||
from django.utils.http import urlencode
|
||||
|
||||
from rest_framework.compat import coreapi, requests
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
|
@ -21,6 +26,128 @@ def force_authenticate(request, user=None, token=None):
|
|||
request._force_auth_token = token
|
||||
|
||||
|
||||
if requests is not None:
|
||||
class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
|
||||
def get_all(self, key, default):
|
||||
return self.getheaders(key)
|
||||
|
||||
class MockOriginalResponse(object):
|
||||
def __init__(self, headers):
|
||||
self.msg = HeaderDict(headers)
|
||||
self.closed = False
|
||||
|
||||
def isclosed(self):
|
||||
return self.closed
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
class DjangoTestAdapter(requests.adapters.HTTPAdapter):
|
||||
"""
|
||||
A transport adapter for `requests`, that makes requests via the
|
||||
Django WSGI app, rather than making actual HTTP requests over the network.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.app = WSGIHandler()
|
||||
self.factory = DjangoRequestFactory()
|
||||
|
||||
def get_environ(self, request):
|
||||
"""
|
||||
Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
|
||||
"""
|
||||
method = request.method
|
||||
url = request.url
|
||||
kwargs = {}
|
||||
|
||||
# Set request content, if any exists.
|
||||
if request.body is not None:
|
||||
if hasattr(request.body, 'read'):
|
||||
kwargs['data'] = request.body.read()
|
||||
else:
|
||||
kwargs['data'] = request.body
|
||||
if 'content-type' in request.headers:
|
||||
kwargs['content_type'] = request.headers['content-type']
|
||||
|
||||
# Set request headers.
|
||||
for key, value in request.headers.items():
|
||||
key = key.upper()
|
||||
if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
|
||||
continue
|
||||
kwargs['HTTP_%s' % key.replace('-', '_')] = value
|
||||
|
||||
return self.factory.generic(method, url, **kwargs).environ
|
||||
|
||||
def send(self, request, *args, **kwargs):
|
||||
"""
|
||||
Make an outgoing request to the Django WSGI application.
|
||||
"""
|
||||
raw_kwargs = {}
|
||||
|
||||
def start_response(wsgi_status, wsgi_headers):
|
||||
status, _, reason = wsgi_status.partition(' ')
|
||||
raw_kwargs['status'] = int(status)
|
||||
raw_kwargs['reason'] = reason
|
||||
raw_kwargs['headers'] = wsgi_headers
|
||||
raw_kwargs['version'] = 11
|
||||
raw_kwargs['preload_content'] = False
|
||||
raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)
|
||||
|
||||
# Make the outgoing request via WSGI.
|
||||
environ = self.get_environ(request)
|
||||
wsgi_response = self.app(environ, start_response)
|
||||
|
||||
# Build the underlying urllib3.HTTPResponse
|
||||
raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
|
||||
raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
|
||||
|
||||
# Build the requests.Response
|
||||
return self.build_response(request, raw)
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
class NoExternalRequestsAdapter(requests.adapters.HTTPAdapter):
|
||||
def send(self, request, *args, **kwargs):
|
||||
msg = (
|
||||
'RequestsClient refusing to make an outgoing network request '
|
||||
'to "%s". Only "testserver" or hostnames in your ALLOWED_HOSTS '
|
||||
'setting are valid.' % request.url
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
class RequestsClient(requests.Session):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(RequestsClient, self).__init__(*args, **kwargs)
|
||||
adapter = DjangoTestAdapter()
|
||||
self.mount('http://', adapter)
|
||||
self.mount('https://', adapter)
|
||||
|
||||
def request(self, method, url, *args, **kwargs):
|
||||
if ':' not in url:
|
||||
raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url)
|
||||
return super(RequestsClient, self).request(method, url, *args, **kwargs)
|
||||
|
||||
else:
|
||||
def RequestsClient(*args, **kwargs):
|
||||
raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.')
|
||||
|
||||
|
||||
if coreapi is not None:
|
||||
class CoreAPIClient(coreapi.Client):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._session = RequestsClient()
|
||||
kwargs['transports'] = [coreapi.transports.HTTPTransport(session=self.session)]
|
||||
return super(CoreAPIClient, self).__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return self._session
|
||||
|
||||
else:
|
||||
def CoreAPIClient(*args, **kwargs):
|
||||
raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')
|
||||
|
||||
|
||||
class APIRequestFactory(DjangoRequestFactory):
|
||||
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
|
||||
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from django.conf.urls import include, url
|
||||
from django.core.urlresolvers import RegexURLResolver
|
||||
|
||||
from rest_framework.compat import RegexURLResolver
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from django.core.urlresolvers import get_script_prefix, resolve
|
||||
from rest_framework.compat import get_script_prefix, resolve
|
||||
|
||||
|
||||
def get_breadcrumbs(url, request=None):
|
||||
|
|
|
@ -110,6 +110,9 @@ class APIView(View):
|
|||
# Allow dependency injection of other settings to make testing easier.
|
||||
settings = api_settings
|
||||
|
||||
# Mark the view as being included or excluded from schema generation.
|
||||
exclude_from_schema = False
|
||||
|
||||
@classmethod
|
||||
def as_view(cls, **initkwargs):
|
||||
"""
|
||||
|
@ -129,6 +132,7 @@ class APIView(View):
|
|||
|
||||
view = super(APIView, cls).as_view(**initkwargs)
|
||||
view.cls = cls
|
||||
view.initkwargs = initkwargs
|
||||
|
||||
# Note: session based authentication is explicitly CSRF validated,
|
||||
# all other authentication is CSRF exempt.
|
||||
|
|
|
@ -11,6 +11,7 @@ factory = APIRequestFactory()
|
|||
class BasicSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = BasicModel
|
||||
fields = '__all__'
|
||||
|
||||
|
||||
class ManyPostView(generics.GenericAPIView):
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
def pytest_configure():
|
||||
from django.conf import settings
|
||||
|
||||
MIDDLEWARE = (
|
||||
'django.middleware.common.CommonMiddleware',
|
||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||
'django.contrib.messages.middleware.MessageMiddleware',
|
||||
)
|
||||
|
||||
settings.configure(
|
||||
DEBUG_PROPAGATE_EXCEPTIONS=True,
|
||||
DATABASES={
|
||||
|
@ -21,12 +28,8 @@ def pytest_configure():
|
|||
'APP_DIRS': True,
|
||||
},
|
||||
],
|
||||
MIDDLEWARE_CLASSES=(
|
||||
'django.middleware.common.CommonMiddleware',
|
||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||
'django.contrib.messages.middleware.MessageMiddleware',
|
||||
),
|
||||
MIDDLEWARE=MIDDLEWARE,
|
||||
MIDDLEWARE_CLASSES=MIDDLEWARE,
|
||||
INSTALLED_APPS=(
|
||||
'django.contrib.auth',
|
||||
'django.contrib.contenttypes',
|
||||
|
|
474
tests/test_api_client.py
Normal file
474
tests/test_api_client.py
Normal file
|
@ -0,0 +1,474 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from django.conf.urls import url
|
||||
from django.http import HttpResponse
|
||||
from django.test import override_settings
|
||||
|
||||
from rest_framework.compat import coreapi
|
||||
from rest_framework.parsers import FileUploadParser
|
||||
from rest_framework.renderers import CoreJSONRenderer
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.test import APITestCase, CoreAPIClient
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
def get_schema():
|
||||
return coreapi.Document(
|
||||
url='https://api.example.com/',
|
||||
title='Example API',
|
||||
content={
|
||||
'simple_link': coreapi.Link('/example/', description='example link'),
|
||||
'headers': coreapi.Link('/headers/'),
|
||||
'location': {
|
||||
'query': coreapi.Link('/example/', fields=[
|
||||
coreapi.Field(name='example', description='example field')
|
||||
]),
|
||||
'form': coreapi.Link('/example/', action='post', fields=[
|
||||
coreapi.Field(name='example'),
|
||||
]),
|
||||
'body': coreapi.Link('/example/', action='post', fields=[
|
||||
coreapi.Field(name='example', location='body')
|
||||
]),
|
||||
'path': coreapi.Link('/example/{id}', fields=[
|
||||
coreapi.Field(name='id', location='path')
|
||||
])
|
||||
},
|
||||
'encoding': {
|
||||
'multipart': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[
|
||||
coreapi.Field(name='example')
|
||||
]),
|
||||
'multipart-body': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[
|
||||
coreapi.Field(name='example', location='body')
|
||||
]),
|
||||
'urlencoded': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[
|
||||
coreapi.Field(name='example')
|
||||
]),
|
||||
'urlencoded-body': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[
|
||||
coreapi.Field(name='example', location='body')
|
||||
]),
|
||||
'raw_upload': coreapi.Link('/upload/', action='post', encoding='application/octet-stream', fields=[
|
||||
coreapi.Field(name='example', location='body')
|
||||
]),
|
||||
},
|
||||
'response': {
|
||||
'download': coreapi.Link('/download/'),
|
||||
'text': coreapi.Link('/text/')
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _iterlists(querydict):
|
||||
if hasattr(querydict, 'iterlists'):
|
||||
return querydict.iterlists()
|
||||
return querydict.lists()
|
||||
|
||||
|
||||
def _get_query_params(request):
|
||||
# Return query params in a plain dict, using a list value if more
|
||||
# than one item is present for a given key.
|
||||
return {
|
||||
key: (value[0] if len(value) == 1 else value)
|
||||
for key, value in
|
||||
_iterlists(request.query_params)
|
||||
}
|
||||
|
||||
|
||||
def _get_data(request):
|
||||
if not isinstance(request.data, dict):
|
||||
return request.data
|
||||
# Coerce multidict into regular dict, and remove files to
|
||||
# make assertions simpler.
|
||||
if hasattr(request.data, 'iterlists') or hasattr(request.data, 'lists'):
|
||||
# Use a list value if a QueryDict contains multiple items for a key.
|
||||
return {
|
||||
key: value[0] if len(value) == 1 else value
|
||||
for key, value in _iterlists(request.data)
|
||||
if key not in request.FILES
|
||||
}
|
||||
return {
|
||||
key: value
|
||||
for key, value in request.data.items()
|
||||
if key not in request.FILES
|
||||
}
|
||||
|
||||
|
||||
def _get_files(request):
|
||||
if not request.FILES:
|
||||
return {}
|
||||
return {
|
||||
key: {'name': value.name, 'content': value.read()}
|
||||
for key, value in request.FILES.items()
|
||||
}
|
||||
|
||||
|
||||
class SchemaView(APIView):
|
||||
renderer_classes = [CoreJSONRenderer]
|
||||
|
||||
def get(self, request):
|
||||
schema = get_schema()
|
||||
return Response(schema)
|
||||
|
||||
|
||||
class ListView(APIView):
|
||||
def get(self, request):
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'query_params': _get_query_params(request)
|
||||
})
|
||||
|
||||
def post(self, request):
|
||||
if request.content_type:
|
||||
content_type = request.content_type.split(';')[0]
|
||||
else:
|
||||
content_type = None
|
||||
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'query_params': _get_query_params(request),
|
||||
'data': _get_data(request),
|
||||
'files': _get_files(request),
|
||||
'content_type': content_type
|
||||
})
|
||||
|
||||
|
||||
class DetailView(APIView):
|
||||
def get(self, request, id):
|
||||
return Response({
|
||||
'id': id,
|
||||
'method': request.method,
|
||||
'query_params': _get_query_params(request)
|
||||
})
|
||||
|
||||
|
||||
class UploadView(APIView):
|
||||
parser_classes = [FileUploadParser]
|
||||
|
||||
def post(self, request):
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'files': _get_files(request),
|
||||
'content_type': request.content_type
|
||||
})
|
||||
|
||||
|
||||
class DownloadView(APIView):
|
||||
def get(self, request):
|
||||
return HttpResponse('some file content', content_type='image/png')
|
||||
|
||||
|
||||
class TextView(APIView):
|
||||
def get(self, request):
|
||||
return HttpResponse('123', content_type='text/plain')
|
||||
|
||||
|
||||
class HeadersView(APIView):
|
||||
def get(self, request):
|
||||
headers = {
|
||||
key[5:].replace('_', '-'): value
|
||||
for key, value in request.META.items()
|
||||
if key.startswith('HTTP_')
|
||||
}
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'headers': headers
|
||||
})
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
url(r'^$', SchemaView.as_view()),
|
||||
url(r'^example/$', ListView.as_view()),
|
||||
url(r'^example/(?P<id>[0-9]+)/$', DetailView.as_view()),
|
||||
url(r'^upload/$', UploadView.as_view()),
|
||||
url(r'^download/$', DownloadView.as_view()),
|
||||
url(r'^text/$', TextView.as_view()),
|
||||
url(r'^headers/$', HeadersView.as_view()),
|
||||
]
|
||||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi not installed')
|
||||
@override_settings(ROOT_URLCONF='tests.test_api_client')
|
||||
class APIClientTests(APITestCase):
|
||||
def test_api_client(self):
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
assert schema.title == 'Example API'
|
||||
assert schema.url == 'https://api.example.com/'
|
||||
assert schema['simple_link'].description == 'example link'
|
||||
assert schema['location']['query'].fields[0].description == 'example field'
|
||||
data = client.action(schema, ['simple_link'])
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_query_params(self):
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['location', 'query'], params={'example': 123})
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {'example': '123'}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_session_headers(self):
|
||||
client = CoreAPIClient()
|
||||
client.session.headers.update({'X-Custom-Header': 'foo'})
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['headers'])
|
||||
assert data['headers']['X-CUSTOM-HEADER'] == 'foo'
|
||||
|
||||
def test_query_params_with_multiple_values(self):
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['location', 'query'], params={'example': [1, 2, 3]})
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {'example': ['1', '2', '3']}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_form_params(self):
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['location', 'form'], params={'example': 123})
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'application/json',
|
||||
'query_params': {},
|
||||
'data': {'example': 123},
|
||||
'files': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_body_params(self):
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['location', 'body'], params={'example': 123})
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'application/json',
|
||||
'query_params': {},
|
||||
'data': 123,
|
||||
'files': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_path_params(self):
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['location', 'path'], params={'id': 123})
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {},
|
||||
'id': '123'
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_multipart_encoding(self):
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
temp = tempfile.NamedTemporaryFile()
|
||||
temp.write(b'example file content')
|
||||
temp.flush()
|
||||
|
||||
with open(temp.name, 'rb') as upload:
|
||||
name = os.path.basename(upload.name)
|
||||
data = client.action(schema, ['encoding', 'multipart'], params={'example': upload})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'multipart/form-data',
|
||||
'query_params': {},
|
||||
'data': {},
|
||||
'files': {'example': {'name': name, 'content': 'example file content'}}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_multipart_encoding_no_file(self):
|
||||
# When no file is included, multipart encoding should still be used.
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
data = client.action(schema, ['encoding', 'multipart'], params={'example': 123})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'multipart/form-data',
|
||||
'query_params': {},
|
||||
'data': {'example': '123'},
|
||||
'files': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_multipart_encoding_multiple_values(self):
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
data = client.action(schema, ['encoding', 'multipart'], params={'example': [1, 2, 3]})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'multipart/form-data',
|
||||
'query_params': {},
|
||||
'data': {'example': ['1', '2', '3']},
|
||||
'files': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_multipart_encoding_string_file_content(self):
|
||||
# Test for `coreapi.utils.File` support.
|
||||
from coreapi.utils import File
|
||||
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
example = File(name='example.txt', content='123')
|
||||
data = client.action(schema, ['encoding', 'multipart'], params={'example': example})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'multipart/form-data',
|
||||
'query_params': {},
|
||||
'data': {},
|
||||
'files': {'example': {'name': 'example.txt', 'content': '123'}}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_multipart_encoding_in_body(self):
|
||||
from coreapi.utils import File
|
||||
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
example = {'foo': File(name='example.txt', content='123'), 'bar': 'abc'}
|
||||
data = client.action(schema, ['encoding', 'multipart-body'], params={'example': example})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'multipart/form-data',
|
||||
'query_params': {},
|
||||
'data': {'bar': 'abc'},
|
||||
'files': {'foo': {'name': 'example.txt', 'content': '123'}}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
# URLencoded
|
||||
|
||||
def test_urlencoded_encoding(self):
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['encoding', 'urlencoded'], params={'example': 123})
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'application/x-www-form-urlencoded',
|
||||
'query_params': {},
|
||||
'data': {'example': '123'},
|
||||
'files': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_urlencoded_encoding_multiple_values(self):
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['encoding', 'urlencoded'], params={'example': [1, 2, 3]})
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'application/x-www-form-urlencoded',
|
||||
'query_params': {},
|
||||
'data': {'example': ['1', '2', '3']},
|
||||
'files': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_urlencoded_encoding_in_body(self):
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
data = client.action(schema, ['encoding', 'urlencoded-body'], params={'example': {'foo': 123, 'bar': True}})
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'content_type': 'application/x-www-form-urlencoded',
|
||||
'query_params': {},
|
||||
'data': {'foo': '123', 'bar': 'true'},
|
||||
'files': {}
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
# Raw uploads
|
||||
|
||||
def test_raw_upload(self):
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
temp = tempfile.NamedTemporaryFile()
|
||||
temp.write(b'example file content')
|
||||
temp.flush()
|
||||
|
||||
with open(temp.name, 'rb') as upload:
|
||||
name = os.path.basename(upload.name)
|
||||
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': upload})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'files': {'file': {'name': name, 'content': 'example file content'}},
|
||||
'content_type': 'application/octet-stream'
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_raw_upload_string_file_content(self):
|
||||
from coreapi.utils import File
|
||||
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
example = File('example.txt', '123')
|
||||
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'files': {'file': {'name': 'example.txt', 'content': '123'}},
|
||||
'content_type': 'text/plain'
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
def test_raw_upload_explicit_content_type(self):
|
||||
from coreapi.utils import File
|
||||
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
example = File('example.txt', '123', 'text/html')
|
||||
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example})
|
||||
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'files': {'file': {'name': 'example.txt', 'content': '123'}},
|
||||
'content_type': 'text/html'
|
||||
}
|
||||
assert data == expected
|
||||
|
||||
# Responses
|
||||
|
||||
def test_text_response(self):
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
data = client.action(schema, ['response', 'text'])
|
||||
|
||||
expected = '123'
|
||||
assert data == expected
|
||||
|
||||
def test_download_response(self):
|
||||
client = CoreAPIClient()
|
||||
schema = client.get('http://api.example.com/')
|
||||
|
||||
data = client.action(schema, ['response', 'download'])
|
||||
assert data.basename == 'download.png'
|
||||
assert data.read() == b'some file content'
|
|
@ -5,7 +5,7 @@ import unittest
|
|||
from django.conf.urls import url
|
||||
from django.db import connection, connections, transaction
|
||||
from django.http import Http404
|
||||
from django.test import TestCase, TransactionTestCase
|
||||
from django.test import TestCase, TransactionTestCase, override_settings
|
||||
from django.utils.decorators import method_decorator
|
||||
|
||||
from rest_framework import status
|
||||
|
@ -36,6 +36,20 @@ class APIExceptionView(APIView):
|
|||
raise APIException
|
||||
|
||||
|
||||
class NonAtomicAPIExceptionView(APIView):
|
||||
@method_decorator(transaction.non_atomic_requests)
|
||||
def dispatch(self, *args, **kwargs):
|
||||
return super(NonAtomicAPIExceptionView, self).dispatch(*args, **kwargs)
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
BasicModel.objects.all()
|
||||
raise Http404
|
||||
|
||||
urlpatterns = (
|
||||
url(r'^$', NonAtomicAPIExceptionView.as_view()),
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipUnless(
|
||||
connection.features.uses_savepoints,
|
||||
"'atomic' requires transactions and savepoints."
|
||||
|
@ -124,22 +138,8 @@ class DBTransactionAPIExceptionTests(TestCase):
|
|||
connection.features.uses_savepoints,
|
||||
"'atomic' requires transactions and savepoints."
|
||||
)
|
||||
@override_settings(ROOT_URLCONF='tests.test_atomic_requests')
|
||||
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
|
||||
@property
|
||||
def urls(self):
|
||||
class NonAtomicAPIExceptionView(APIView):
|
||||
@method_decorator(transaction.non_atomic_requests)
|
||||
def dispatch(self, *args, **kwargs):
|
||||
return super(NonAtomicAPIExceptionView, self).dispatch(*args, **kwargs)
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
BasicModel.objects.all()
|
||||
raise Http404
|
||||
|
||||
return (
|
||||
url(r'^$', NonAtomicAPIExceptionView.as_view()),
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
connections.databases['default']['ATOMIC_REQUESTS'] = True
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from rest_framework.authentication import (
|
|||
)
|
||||
from rest_framework.authtoken.models import Token
|
||||
from rest_framework.authtoken.views import obtain_auth_token
|
||||
from rest_framework.compat import is_authenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.test import APIClient, APIRequestFactory
|
||||
from rest_framework.views import APIView
|
||||
|
@ -408,7 +409,7 @@ class FailingAuthAccessedInRenderer(TestCase):
|
|||
|
||||
def render(self, data, media_type=None, renderer_context=None):
|
||||
request = renderer_context['request']
|
||||
if request.user.is_authenticated():
|
||||
if is_authenticated(request.user):
|
||||
return b'authenticated'
|
||||
return b'not authenticated'
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import datetime
|
||||
import os
|
||||
import re
|
||||
import unittest
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
|
@ -11,6 +12,67 @@ from django.utils import six, timezone
|
|||
|
||||
import rest_framework
|
||||
from rest_framework import serializers
|
||||
from rest_framework.fields import is_simple_callable
|
||||
|
||||
try:
|
||||
import typings
|
||||
except ImportError:
|
||||
typings = False
|
||||
|
||||
|
||||
# Tests for helper functions.
|
||||
# ---------------------------
|
||||
|
||||
class TestIsSimpleCallable:
|
||||
|
||||
def test_method(self):
|
||||
class Foo:
|
||||
@classmethod
|
||||
def classmethod(cls):
|
||||
pass
|
||||
|
||||
def valid(self):
|
||||
pass
|
||||
|
||||
def valid_kwargs(self, param='value'):
|
||||
pass
|
||||
|
||||
def invalid(self, param):
|
||||
pass
|
||||
|
||||
assert is_simple_callable(Foo.classmethod)
|
||||
|
||||
# unbound methods
|
||||
assert not is_simple_callable(Foo.valid)
|
||||
assert not is_simple_callable(Foo.valid_kwargs)
|
||||
assert not is_simple_callable(Foo.invalid)
|
||||
|
||||
# bound methods
|
||||
assert is_simple_callable(Foo().valid)
|
||||
assert is_simple_callable(Foo().valid_kwargs)
|
||||
assert not is_simple_callable(Foo().invalid)
|
||||
|
||||
def test_function(self):
|
||||
def simple():
|
||||
pass
|
||||
|
||||
def valid(param='value', param2='value'):
|
||||
pass
|
||||
|
||||
def invalid(param, param2='value'):
|
||||
pass
|
||||
|
||||
assert is_simple_callable(simple)
|
||||
assert is_simple_callable(valid)
|
||||
assert not is_simple_callable(invalid)
|
||||
|
||||
@unittest.skipUnless(typings, 'requires python 3.5')
|
||||
def test_type_annotation(self):
|
||||
# The annotation will otherwise raise a syntax error in python < 3.5
|
||||
exec("def valid(param: str='value'): pass", locals())
|
||||
valid = locals()['valid']
|
||||
|
||||
assert is_simple_callable(valid)
|
||||
|
||||
|
||||
# Tests for field keyword arguments and core functionality.
|
||||
|
|
|
@ -6,7 +6,6 @@ from decimal import Decimal
|
|||
|
||||
from django.conf.urls import url
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.urlresolvers import reverse
|
||||
from django.db import models
|
||||
from django.test import TestCase
|
||||
from django.test.utils import override_settings
|
||||
|
@ -14,7 +13,7 @@ from django.utils.dateparse import parse_date
|
|||
from django.utils.six.moves import reload_module
|
||||
|
||||
from rest_framework import filters, generics, serializers, status
|
||||
from rest_framework.compat import django_filters
|
||||
from rest_framework.compat import django_filters, reverse
|
||||
from rest_framework.test import APIRequestFactory
|
||||
|
||||
from .models import BaseFilterableItem, BasicModel, FilterableItem
|
||||
|
@ -77,6 +76,7 @@ if django_filters:
|
|||
|
||||
class Meta:
|
||||
model = BaseFilterableItem
|
||||
fields = '__all__'
|
||||
|
||||
class BaseFilterableItemFilterRootView(generics.ListCreateAPIView):
|
||||
queryset = FilterableItem.objects.all()
|
||||
|
@ -456,7 +456,7 @@ class AttributeModel(models.Model):
|
|||
|
||||
class SearchFilterModelFk(models.Model):
|
||||
title = models.CharField(max_length=20)
|
||||
attribute = models.ForeignKey(AttributeModel)
|
||||
attribute = models.ForeignKey(AttributeModel, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class SearchFilterFkSerializer(serializers.ModelSerializer):
|
||||
|
|
|
@ -20,7 +20,7 @@ from django.test import TestCase
|
|||
from django.utils import six
|
||||
|
||||
from rest_framework import serializers
|
||||
from rest_framework.compat import unicode_repr
|
||||
from rest_framework.compat import set_many, unicode_repr
|
||||
|
||||
|
||||
def dedent(blocktext):
|
||||
|
@ -651,7 +651,7 @@ class TestIntegration(TestCase):
|
|||
foreign_key=self.foreign_key_target,
|
||||
one_to_one=self.one_to_one_target,
|
||||
)
|
||||
self.instance.many_to_many = self.many_to_many_targets
|
||||
set_many(self.instance, 'many_to_many', self.many_to_many_targets)
|
||||
self.instance.save()
|
||||
|
||||
def test_pk_retrival(self):
|
||||
|
@ -962,7 +962,7 @@ class OneToOneTargetTestModel(models.Model):
|
|||
|
||||
|
||||
class OneToOneSourceTestModel(models.Model):
|
||||
target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True)
|
||||
target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class TestModelFieldValues(TestCase):
|
||||
|
@ -990,6 +990,7 @@ class TestUniquenessOverride(TestCase):
|
|||
class TestSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = TestModel
|
||||
fields = '__all__'
|
||||
extra_kwargs = {'field_1': {'required': False}}
|
||||
|
||||
fields = TestSerializer().fields
|
||||
|
|
|
@ -4,7 +4,6 @@ import base64
|
|||
import unittest
|
||||
|
||||
from django.contrib.auth.models import Group, Permission, User
|
||||
from django.core.urlresolvers import ResolverMatch
|
||||
from django.db import models
|
||||
from django.test import TestCase
|
||||
|
||||
|
@ -12,7 +11,7 @@ from rest_framework import (
|
|||
HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers,
|
||||
status
|
||||
)
|
||||
from rest_framework.compat import guardian
|
||||
from rest_framework.compat import ResolverMatch, guardian, set_many
|
||||
from rest_framework.filters import DjangoObjectPermissionsFilter
|
||||
from rest_framework.routers import DefaultRouter
|
||||
from rest_framework.test import APIRequestFactory
|
||||
|
@ -74,15 +73,15 @@ class ModelPermissionsIntegrationTests(TestCase):
|
|||
def setUp(self):
|
||||
User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
|
||||
user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
|
||||
user.user_permissions = [
|
||||
set_many(user, 'user_permissions', [
|
||||
Permission.objects.get(codename='add_basicmodel'),
|
||||
Permission.objects.get(codename='change_basicmodel'),
|
||||
Permission.objects.get(codename='delete_basicmodel')
|
||||
]
|
||||
])
|
||||
user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
|
||||
user.user_permissions = [
|
||||
set_many(user, 'user_permissions', [
|
||||
Permission.objects.get(codename='change_basicmodel'),
|
||||
]
|
||||
])
|
||||
|
||||
self.permitted_credentials = basic_auth_header('permitted', 'password')
|
||||
self.disallowed_credentials = basic_auth_header('disallowed', 'password')
|
||||
|
|
|
@ -13,6 +13,7 @@ from django.utils import six
|
|||
|
||||
from rest_framework import status
|
||||
from rest_framework.authentication import SessionAuthentication
|
||||
from rest_framework.compat import is_anonymous
|
||||
from rest_framework.parsers import BaseParser, FormParser, MultiPartParser
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
@ -169,9 +170,9 @@ class TestUserSetter(TestCase):
|
|||
|
||||
def test_user_can_logout(self):
|
||||
self.request.user = self.user
|
||||
self.assertFalse(self.request.user.is_anonymous())
|
||||
self.assertFalse(is_anonymous(self.request.user))
|
||||
logout(self.request)
|
||||
self.assertTrue(self.request.user.is_anonymous())
|
||||
self.assertTrue(is_anonymous(self.request.user))
|
||||
|
||||
def test_logged_in_user_is_set_on_wrapped_request(self):
|
||||
login(self.request, self.user)
|
||||
|
|
256
tests/test_requests_client.py
Normal file
256
tests/test_requests_client.py
Normal file
|
@ -0,0 +1,256 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import unittest
|
||||
|
||||
from django.conf.urls import url
|
||||
from django.contrib.auth import authenticate, login
|
||||
from django.contrib.auth.models import User
|
||||
from django.shortcuts import redirect
|
||||
from django.test import override_settings
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie
|
||||
|
||||
from rest_framework.compat import is_authenticated, requests
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.test import APITestCase, RequestsClient
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
class Root(APIView):
|
||||
def get(self, request):
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'query_params': request.query_params,
|
||||
})
|
||||
|
||||
def post(self, request):
|
||||
files = {
|
||||
key: (value.name, value.read())
|
||||
for key, value in request.FILES.items()
|
||||
}
|
||||
post = request.POST
|
||||
json = None
|
||||
if request.META.get('CONTENT_TYPE') == 'application/json':
|
||||
json = request.data
|
||||
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'query_params': request.query_params,
|
||||
'POST': post,
|
||||
'FILES': files,
|
||||
'JSON': json
|
||||
})
|
||||
|
||||
|
||||
class HeadersView(APIView):
|
||||
def get(self, request):
|
||||
headers = {
|
||||
key[5:].replace('_', '-'): value
|
||||
for key, value in request.META.items()
|
||||
if key.startswith('HTTP_')
|
||||
}
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'headers': headers
|
||||
})
|
||||
|
||||
|
||||
class SessionView(APIView):
|
||||
def get(self, request):
|
||||
return Response({
|
||||
key: value for key, value in request.session.items()
|
||||
})
|
||||
|
||||
def post(self, request):
|
||||
for key, value in request.data.items():
|
||||
request.session[key] = value
|
||||
return Response({
|
||||
key: value for key, value in request.session.items()
|
||||
})
|
||||
|
||||
|
||||
class AuthView(APIView):
|
||||
@method_decorator(ensure_csrf_cookie)
|
||||
def get(self, request):
|
||||
if is_authenticated(request.user):
|
||||
username = request.user.username
|
||||
else:
|
||||
username = None
|
||||
return Response({
|
||||
'username': username
|
||||
})
|
||||
|
||||
@method_decorator(csrf_protect)
|
||||
def post(self, request):
|
||||
username = request.data['username']
|
||||
password = request.data['password']
|
||||
user = authenticate(username=username, password=password)
|
||||
if user is None:
|
||||
return Response({'error': 'incorrect credentials'})
|
||||
login(request, user)
|
||||
return redirect('/auth/')
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
url(r'^$', Root.as_view(), name='root'),
|
||||
url(r'^headers/$', HeadersView.as_view(), name='headers'),
|
||||
url(r'^session/$', SessionView.as_view(), name='session'),
|
||||
url(r'^auth/$', AuthView.as_view(), name='auth'),
|
||||
]
|
||||
|
||||
|
||||
@unittest.skipUnless(requests, 'requests not installed')
|
||||
@override_settings(ROOT_URLCONF='tests.test_requests_client')
|
||||
class RequestsClientTests(APITestCase):
|
||||
def test_get_request(self):
|
||||
client = RequestsClient()
|
||||
response = client.get('http://testserver/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {}
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_get_request_query_params_in_url(self):
|
||||
client = RequestsClient()
|
||||
response = client.get('http://testserver/?key=value')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {'key': 'value'}
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_get_request_query_params_by_kwarg(self):
|
||||
client = RequestsClient()
|
||||
response = client.get('http://testserver/', params={'key': 'value'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {'key': 'value'}
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_get_with_headers(self):
|
||||
client = RequestsClient()
|
||||
response = client.get('http://testserver/headers/', headers={'User-Agent': 'example'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
headers = response.json()['headers']
|
||||
assert headers['USER-AGENT'] == 'example'
|
||||
|
||||
def test_get_with_session_headers(self):
|
||||
client = RequestsClient()
|
||||
client.headers.update({'User-Agent': 'example'})
|
||||
response = client.get('http://testserver/headers/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
headers = response.json()['headers']
|
||||
assert headers['USER-AGENT'] == 'example'
|
||||
|
||||
def test_post_form_request(self):
|
||||
client = RequestsClient()
|
||||
response = client.post('http://testserver/', data={'key': 'value'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'query_params': {},
|
||||
'POST': {'key': 'value'},
|
||||
'FILES': {},
|
||||
'JSON': None
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_post_json_request(self):
|
||||
client = RequestsClient()
|
||||
response = client.post('http://testserver/', json={'key': 'value'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'query_params': {},
|
||||
'POST': {},
|
||||
'FILES': {},
|
||||
'JSON': {'key': 'value'}
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_post_multipart_request(self):
|
||||
client = RequestsClient()
|
||||
files = {
|
||||
'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n')
|
||||
}
|
||||
response = client.post('http://testserver/', files=files)
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'query_params': {},
|
||||
'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']},
|
||||
'POST': {},
|
||||
'JSON': None
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_session(self):
|
||||
client = RequestsClient()
|
||||
response = client.get('http://testserver/session/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {}
|
||||
assert response.json() == expected
|
||||
|
||||
response = client.post('http://testserver/session/', json={'example': 'abc'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {'example': 'abc'}
|
||||
assert response.json() == expected
|
||||
|
||||
response = client.get('http://testserver/session/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {'example': 'abc'}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_auth(self):
|
||||
# Confirm session is not authenticated
|
||||
client = RequestsClient()
|
||||
response = client.get('http://testserver/auth/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'username': None
|
||||
}
|
||||
assert response.json() == expected
|
||||
assert 'csrftoken' in response.cookies
|
||||
csrftoken = response.cookies['csrftoken']
|
||||
|
||||
user = User.objects.create(username='tom')
|
||||
user.set_password('password')
|
||||
user.save()
|
||||
|
||||
# Perform a login
|
||||
response = client.post('http://testserver/auth/', json={
|
||||
'username': 'tom',
|
||||
'password': 'password'
|
||||
}, headers={'X-CSRFToken': csrftoken})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'username': 'tom'
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
# Confirm session is authenticated
|
||||
response = client.get('http://testserver/auth/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'username': 'tom'
|
||||
}
|
||||
assert response.json() == expected
|
|
@ -1,9 +1,9 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from django.conf.urls import url
|
||||
from django.core.urlresolvers import NoReverseMatch
|
||||
from django.test import TestCase, override_settings
|
||||
|
||||
from rest_framework.compat import NoReverseMatch
|
||||
from rest_framework.reverse import reverse
|
||||
from rest_framework.test import APIRequestFactory
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import json
|
||||
from collections import namedtuple
|
||||
|
||||
from django.conf.urls import include, url
|
||||
|
@ -47,6 +48,21 @@ class MockViewSet(viewsets.ModelViewSet):
|
|||
serializer_class = None
|
||||
|
||||
|
||||
class EmptyPrefixSerializer(serializers.HyperlinkedModelSerializer):
|
||||
class Meta:
|
||||
model = RouterTestModel
|
||||
fields = ('uuid', 'text')
|
||||
|
||||
|
||||
class EmptyPrefixViewSet(viewsets.ModelViewSet):
|
||||
queryset = [RouterTestModel(id=1, uuid='111', text='First'), RouterTestModel(id=2, uuid='222', text='Second')]
|
||||
serializer_class = EmptyPrefixSerializer
|
||||
|
||||
def get_object(self, *args, **kwargs):
|
||||
index = int(self.kwargs['pk']) - 1
|
||||
return self.queryset[index]
|
||||
|
||||
|
||||
notes_router = SimpleRouter()
|
||||
notes_router.register(r'notes', NoteViewSet)
|
||||
|
||||
|
@ -56,11 +72,19 @@ kwarged_notes_router.register(r'notes', KWargedNoteViewSet)
|
|||
namespaced_router = DefaultRouter()
|
||||
namespaced_router.register(r'example', MockViewSet, base_name='example')
|
||||
|
||||
empty_prefix_router = SimpleRouter()
|
||||
empty_prefix_router.register(r'', EmptyPrefixViewSet, base_name='empty_prefix')
|
||||
empty_prefix_urls = [
|
||||
url(r'^', include(empty_prefix_router.urls)),
|
||||
]
|
||||
|
||||
urlpatterns = [
|
||||
url(r'^non-namespaced/', include(namespaced_router.urls)),
|
||||
url(r'^namespaced/', include(namespaced_router.urls, namespace='example')),
|
||||
url(r'^example/', include(notes_router.urls)),
|
||||
url(r'^example2/', include(kwarged_notes_router.urls)),
|
||||
|
||||
url(r'^empty-prefix/', include(empty_prefix_urls)),
|
||||
]
|
||||
|
||||
|
||||
|
@ -384,3 +408,28 @@ class TestDynamicListAndDetailRouter(TestCase):
|
|||
|
||||
def test_inherited_list_and_detail_route_decorators(self):
|
||||
self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet)
|
||||
|
||||
|
||||
@override_settings(ROOT_URLCONF='tests.test_routers')
|
||||
class TestEmptyPrefix(TestCase):
|
||||
def test_empty_prefix_list(self):
|
||||
response = self.client.get('/empty-prefix/')
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertEqual(
|
||||
json.loads(response.content.decode('utf-8')),
|
||||
[
|
||||
{'uuid': '111', 'text': 'First'},
|
||||
{'uuid': '222', 'text': 'Second'}
|
||||
]
|
||||
)
|
||||
|
||||
def test_empty_prefix_detail(self):
|
||||
response = self.client.get('/empty-prefix/1/')
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertEqual(
|
||||
json.loads(response.content.decode('utf-8')),
|
||||
{
|
||||
'uuid': '111',
|
||||
'text': 'First'
|
||||
}
|
||||
)
|
||||
|
|
|
@ -6,9 +6,8 @@ from django.test import TestCase, override_settings
|
|||
from rest_framework import filters, pagination, permissions, serializers
|
||||
from rest_framework.compat import coreapi
|
||||
from rest_framework.decorators import detail_route, list_route
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.routers import DefaultRouter
|
||||
from rest_framework.schemas import SchemaGenerator
|
||||
from rest_framework.schemas import SchemaGenerator, get_schema_view
|
||||
from rest_framework.test import APIClient
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
@ -23,6 +22,10 @@ class ExamplePagination(pagination.PageNumberPagination):
|
|||
page_size = 100
|
||||
|
||||
|
||||
class EmptySerializer(serializers.Serializer):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleSerializer(serializers.Serializer):
|
||||
a = serializers.CharField(required=True, help_text='A field description')
|
||||
b = serializers.CharField(required=False)
|
||||
|
@ -43,36 +46,37 @@ class ExampleViewSet(ModelViewSet):
|
|||
|
||||
@detail_route(methods=['post'], serializer_class=AnotherSerializer)
|
||||
def custom_action(self, request, pk):
|
||||
"""
|
||||
A description of custom action.
|
||||
"""
|
||||
return super(ExampleSerializer, self).retrieve(self, request)
|
||||
|
||||
@list_route()
|
||||
def custom_list_action(self, request):
|
||||
return super(ExampleViewSet, self).list(self, request)
|
||||
|
||||
@list_route(methods=['post', 'get'], serializer_class=EmptySerializer)
|
||||
def custom_list_action_multiple_methods(self, request):
|
||||
return super(ExampleViewSet, self).list(self, request)
|
||||
|
||||
def get_serializer(self, *args, **kwargs):
|
||||
assert self.request
|
||||
assert self.action
|
||||
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
|
||||
|
||||
|
||||
class ExampleView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
if coreapi:
|
||||
schema_view = get_schema_view(title='Example API')
|
||||
else:
|
||||
def schema_view(request):
|
||||
pass
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
return Response()
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
return Response()
|
||||
|
||||
|
||||
router = DefaultRouter(schema_title='Example API' if coreapi else None)
|
||||
router = DefaultRouter()
|
||||
router.register('example', ExampleViewSet, base_name='example')
|
||||
urlpatterns = [
|
||||
url(r'^$', schema_view),
|
||||
url(r'^', include(router.urls))
|
||||
]
|
||||
urlpatterns2 = [
|
||||
url(r'^example-view/$', ExampleView.as_view(), name='example-view')
|
||||
]
|
||||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
|
@ -80,7 +84,7 @@ urlpatterns2 = [
|
|||
class TestRouterGeneratedSchema(TestCase):
|
||||
def test_anonymous_request(self):
|
||||
client = APIClient()
|
||||
response = client.get('/', HTTP_ACCEPT='application/vnd.coreapi+json')
|
||||
response = client.get('/', HTTP_ACCEPT='application/coreapi+json')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
expected = coreapi.Document(
|
||||
url='',
|
||||
|
@ -99,11 +103,17 @@ class TestRouterGeneratedSchema(TestCase):
|
|||
url='/example/custom_list_action/',
|
||||
action='get'
|
||||
),
|
||||
'retrieve': coreapi.Link(
|
||||
url='/example/{pk}/',
|
||||
'custom_list_action_multiple_methods': {
|
||||
'read': coreapi.Link(
|
||||
url='/example/custom_list_action_multiple_methods/',
|
||||
action='get'
|
||||
)
|
||||
},
|
||||
'read': coreapi.Link(
|
||||
url='/example/{id}/',
|
||||
action='get',
|
||||
fields=[
|
||||
coreapi.Field('pk', required=True, location='path')
|
||||
coreapi.Field('id', required=True, location='path')
|
||||
]
|
||||
)
|
||||
}
|
||||
|
@ -114,7 +124,7 @@ class TestRouterGeneratedSchema(TestCase):
|
|||
def test_authenticated_request(self):
|
||||
client = APIClient()
|
||||
client.force_authenticate(MockUser())
|
||||
response = client.get('/', HTTP_ACCEPT='application/vnd.coreapi+json')
|
||||
response = client.get('/', HTTP_ACCEPT='application/coreapi+json')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
expected = coreapi.Document(
|
||||
url='',
|
||||
|
@ -134,56 +144,67 @@ class TestRouterGeneratedSchema(TestCase):
|
|||
action='post',
|
||||
encoding='application/json',
|
||||
fields=[
|
||||
coreapi.Field('a', required=True, location='form', description='A field description'),
|
||||
coreapi.Field('b', required=False, location='form')
|
||||
coreapi.Field('a', required=True, location='form', type='string', description='A field description'),
|
||||
coreapi.Field('b', required=False, location='form', type='string')
|
||||
]
|
||||
),
|
||||
'retrieve': coreapi.Link(
|
||||
url='/example/{pk}/',
|
||||
'read': coreapi.Link(
|
||||
url='/example/{id}/',
|
||||
action='get',
|
||||
fields=[
|
||||
coreapi.Field('pk', required=True, location='path')
|
||||
coreapi.Field('id', required=True, location='path')
|
||||
]
|
||||
),
|
||||
'custom_action': coreapi.Link(
|
||||
url='/example/{pk}/custom_action/',
|
||||
url='/example/{id}/custom_action/',
|
||||
action='post',
|
||||
encoding='application/json',
|
||||
description='A description of custom action.',
|
||||
fields=[
|
||||
coreapi.Field('pk', required=True, location='path'),
|
||||
coreapi.Field('c', required=True, location='form'),
|
||||
coreapi.Field('d', required=False, location='form'),
|
||||
coreapi.Field('id', required=True, location='path'),
|
||||
coreapi.Field('c', required=True, location='form', type='string'),
|
||||
coreapi.Field('d', required=False, location='form', type='string'),
|
||||
]
|
||||
),
|
||||
'custom_list_action': coreapi.Link(
|
||||
url='/example/custom_list_action/',
|
||||
action='get'
|
||||
),
|
||||
'custom_list_action_multiple_methods': {
|
||||
'read': coreapi.Link(
|
||||
url='/example/custom_list_action_multiple_methods/',
|
||||
action='get'
|
||||
),
|
||||
'create': coreapi.Link(
|
||||
url='/example/custom_list_action_multiple_methods/',
|
||||
action='post'
|
||||
)
|
||||
},
|
||||
'update': coreapi.Link(
|
||||
url='/example/{pk}/',
|
||||
url='/example/{id}/',
|
||||
action='put',
|
||||
encoding='application/json',
|
||||
fields=[
|
||||
coreapi.Field('pk', required=True, location='path'),
|
||||
coreapi.Field('a', required=True, location='form', description='A field description'),
|
||||
coreapi.Field('b', required=False, location='form')
|
||||
coreapi.Field('id', required=True, location='path'),
|
||||
coreapi.Field('a', required=True, location='form', type='string', description='A field description'),
|
||||
coreapi.Field('b', required=False, location='form', type='string')
|
||||
]
|
||||
),
|
||||
'partial_update': coreapi.Link(
|
||||
url='/example/{pk}/',
|
||||
url='/example/{id}/',
|
||||
action='patch',
|
||||
encoding='application/json',
|
||||
fields=[
|
||||
coreapi.Field('pk', required=True, location='path'),
|
||||
coreapi.Field('a', required=False, location='form', description='A field description'),
|
||||
coreapi.Field('b', required=False, location='form')
|
||||
coreapi.Field('id', required=True, location='path'),
|
||||
coreapi.Field('a', required=False, location='form', type='string', description='A field description'),
|
||||
coreapi.Field('b', required=False, location='form', type='string')
|
||||
]
|
||||
),
|
||||
'destroy': coreapi.Link(
|
||||
url='/example/{pk}/',
|
||||
'delete': coreapi.Link(
|
||||
url='/example/{id}/',
|
||||
action='delete',
|
||||
fields=[
|
||||
coreapi.Field('pk', required=True, location='path')
|
||||
coreapi.Field('id', required=True, location='path')
|
||||
]
|
||||
)
|
||||
}
|
||||
|
@ -192,27 +213,123 @@ class TestRouterGeneratedSchema(TestCase):
|
|||
self.assertEqual(response.data, expected)
|
||||
|
||||
|
||||
class ExampleListView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleDetailView(APIView):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
class TestSchemaGenerator(TestCase):
|
||||
def test_view(self):
|
||||
schema_generator = SchemaGenerator(title='Test View', patterns=urlpatterns2)
|
||||
schema = schema_generator.get_schema()
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
url('^example/?$', ExampleListView.as_view()),
|
||||
url('^example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
|
||||
url('^example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()),
|
||||
]
|
||||
|
||||
def test_schema_for_regular_views(self):
|
||||
"""
|
||||
Ensure that schema generation works for APIView classes.
|
||||
"""
|
||||
generator = SchemaGenerator(title='Example API', patterns=self.patterns)
|
||||
schema = generator.get_schema()
|
||||
expected = coreapi.Document(
|
||||
url='',
|
||||
title='Test View',
|
||||
title='Example API',
|
||||
content={
|
||||
'example-view': {
|
||||
'example': {
|
||||
'create': coreapi.Link(
|
||||
url='/example-view/',
|
||||
url='/example/',
|
||||
action='post',
|
||||
fields=[]
|
||||
),
|
||||
'read': coreapi.Link(
|
||||
url='/example-view/',
|
||||
'list': coreapi.Link(
|
||||
url='/example/',
|
||||
action='get',
|
||||
fields=[]
|
||||
),
|
||||
'read': coreapi.Link(
|
||||
url='/example/{id}/',
|
||||
action='get',
|
||||
fields=[
|
||||
coreapi.Field('id', required=True, location='path')
|
||||
]
|
||||
),
|
||||
'sub': {
|
||||
'list': coreapi.Link(
|
||||
url='/example/{id}/sub/',
|
||||
action='get',
|
||||
fields=[
|
||||
coreapi.Field('id', required=True, location='path')
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
self.assertEquals(schema, expected)
|
||||
self.assertEqual(schema, expected)
|
||||
|
||||
|
||||
@unittest.skipUnless(coreapi, 'coreapi is not installed')
|
||||
class TestSchemaGeneratorNotAtRoot(TestCase):
|
||||
def setUp(self):
|
||||
self.patterns = [
|
||||
url('^api/v1/example/?$', ExampleListView.as_view()),
|
||||
url('^api/v1/example/(?P<pk>\d+)/?$', ExampleDetailView.as_view()),
|
||||
url('^api/v1/example/(?P<pk>\d+)/sub/?$', ExampleDetailView.as_view()),
|
||||
]
|
||||
|
||||
def test_schema_for_regular_views(self):
|
||||
"""
|
||||
Ensure that schema generation with an API that is not at the URL
|
||||
root continues to use correct structure for link keys.
|
||||
"""
|
||||
generator = SchemaGenerator(title='Example API', patterns=self.patterns)
|
||||
schema = generator.get_schema()
|
||||
expected = coreapi.Document(
|
||||
url='',
|
||||
title='Example API',
|
||||
content={
|
||||
'example': {
|
||||
'create': coreapi.Link(
|
||||
url='/api/v1/example/',
|
||||
action='post',
|
||||
fields=[]
|
||||
),
|
||||
'list': coreapi.Link(
|
||||
url='/api/v1/example/',
|
||||
action='get',
|
||||
fields=[]
|
||||
),
|
||||
'read': coreapi.Link(
|
||||
url='/api/v1/example/{id}/',
|
||||
action='get',
|
||||
fields=[
|
||||
coreapi.Field('id', required=True, location='path')
|
||||
]
|
||||
),
|
||||
'sub': {
|
||||
'list': coreapi.Link(
|
||||
url='/api/v1/example/{id}/sub/',
|
||||
action='get',
|
||||
fields=[
|
||||
coreapi.Field('id', required=True, location='path')
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
self.assertEqual(schema, expected)
|
||||
|
|
|
@ -3,9 +3,9 @@ from __future__ import unicode_literals
|
|||
from collections import namedtuple
|
||||
|
||||
from django.conf.urls import include, url
|
||||
from django.core import urlresolvers
|
||||
from django.test import TestCase
|
||||
|
||||
from rest_framework.compat import RegexURLResolver, Resolver404
|
||||
from rest_framework.test import APIRequestFactory
|
||||
from rest_framework.urlpatterns import format_suffix_patterns
|
||||
|
||||
|
@ -28,7 +28,7 @@ class FormatSuffixTests(TestCase):
|
|||
urlpatterns = format_suffix_patterns(urlpatterns)
|
||||
except Exception:
|
||||
self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
|
||||
resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
|
||||
resolver = RegexURLResolver(r'^/', urlpatterns)
|
||||
for test_path in test_paths:
|
||||
request = factory.get(test_path.path)
|
||||
try:
|
||||
|
@ -43,7 +43,7 @@ class FormatSuffixTests(TestCase):
|
|||
urlpatterns = format_suffix_patterns([
|
||||
url(r'^test/$', dummy_view),
|
||||
])
|
||||
resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
|
||||
resolver = RegexURLResolver(r'^/', urlpatterns)
|
||||
|
||||
test_paths = [
|
||||
(URLTestPath('/test.api', (), {'format': 'api'}), True),
|
||||
|
@ -55,7 +55,7 @@ class FormatSuffixTests(TestCase):
|
|||
request = factory.get(test_path.path)
|
||||
try:
|
||||
callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
|
||||
except urlresolvers.Resolver404:
|
||||
except Resolver404:
|
||||
callback, callback_args, callback_kwargs = (None, None, None)
|
||||
if not expected_resolved:
|
||||
assert callback is None
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.core.urlresolvers import NoReverseMatch
|
||||
from rest_framework.compat import NoReverseMatch
|
||||
|
||||
|
||||
class MockObject(object):
|
||||
|
|
3
tox.ini
3
tox.ini
|
@ -4,7 +4,7 @@ addopts=--tb=short
|
|||
[tox]
|
||||
envlist =
|
||||
py27-{lint,docs},
|
||||
{py27,py32,py33,py34,py35}-django18,
|
||||
{py27,py33,py34,py35}-django18,
|
||||
{py27,py34,py35}-django19,
|
||||
{py27,py34,py35}-django110,
|
||||
{py27,py34,py35}-django{master}
|
||||
|
@ -25,7 +25,6 @@ basepython =
|
|||
py35: python3.5
|
||||
py34: python3.4
|
||||
py33: python3.3
|
||||
py32: python3.2
|
||||
py27: python2.7
|
||||
|
||||
[testenv:py27-lint]
|
||||
|
|
Loading…
Reference in New Issue
Block a user