Merge pull request #647 from maspwr/writable-nested-serializers

Writable nested serializers
This commit is contained in:
Tom Christie 2013-02-09 12:30:22 -08:00
commit 08fac165ae
100 changed files with 2750 additions and 1412 deletions

View File

@ -3,16 +3,30 @@ language: python
python: python:
- "2.6" - "2.6"
- "2.7" - "2.7"
- "3.2"
- "3.3"
env: env:
- DJANGO=https://github.com/django/django/zipball/master - DJANGO=https://www.djangoproject.com/download/1.5c1/tarball/
- DJANGO=django==1.4.3 --use-mirrors - DJANGO="django==1.4.3 --use-mirrors"
- DJANGO=django==1.3.5 --use-mirrors - DJANGO="django==1.3.5 --use-mirrors"
install: install:
- pip install $DJANGO - pip install $DJANGO
- pip install django-filter==0.5.4 --use-mirrors - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-filter==0.5.4 --use-mirrors; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} == '3' ]]; then pip install https://github.com/alex/django-filter/tarball/master; fi"
- export PYTHONPATH=. - export PYTHONPATH=.
script: script:
- python rest_framework/runtests/runtests.py - python rest_framework/runtests/runtests.py
matrix:
exclude:
- python: "3.2"
env: DJANGO="django==1.4.3 --use-mirrors"
- python: "3.2"
env: DJANGO="django==1.3.5 --use-mirrors"
- python: "3.3"
env: DJANGO="django==1.4.3 --use-mirrors"
- python: "3.3"
env: DJANGO="django==1.3.5 --use-mirrors"

View File

@ -12,8 +12,6 @@
**Full documentation for REST framework is available on [http://django-rest-framework.org][docs].** **Full documentation for REST framework is available on [http://django-rest-framework.org][docs].**
Note that this is the 2.0 version of REST framework. If you are looking for earlier versions please see the [0.4.x branch][0.4] on GitHub.
--- ---
# Overview # Overview
@ -28,7 +26,7 @@ There is also a sandbox API you can use for testing purposes, [available here][s
# Requirements # Requirements
* Python (2.6, 2.7) * Python (2.6, 2.7, 3.2, 3.3)
* Django (1.3, 1.4, 1.5) * Django (1.3, 1.4, 1.5)
**Optional:** **Optional:**
@ -81,6 +79,21 @@ To run the tests.
# Changelog # Changelog
### 2.1.17
**Date**: 26th Jan 2013
* Support proper 401 Unauthorized responses where appropriate, instead of always using 403 Forbidden.
* Support json encoding of timedelta objects.
* `format_suffix_patterns()` now supports `include` style URL patterns.
* Bugfix: Fix issues with custom pagination serializers.
* Bugfix: Nested serializers now accept `source='*'` argument.
* Bugfix: Return proper validation errors when incorrect types supplied for relational fields.
* Bugfix: Support nullable FKs with `SlugRelatedField`.
* Bugfix: Don't call custom validation methods if the field has an error.
**Note**: If the primary authentication class is `TokenAuthentication` or `BasicAuthentication`, a view will now correctly return 401 responses to unauthenticated access, with an appropriate `WWW-Authenticate` header, instead of 403 responses.
### 2.1.16 ### 2.1.16
**Date**: 14th Jan 2013 **Date**: 14th Jan 2013
@ -131,20 +144,20 @@ This change will not affect user code, so long as it's following the recommended
* Bugfix: Fix exception in browseable API on DELETE. * Bugfix: Fix exception in browseable API on DELETE.
* Bugfix: Fix issue where pk was was being set to a string if set by URL kwarg. * Bugfix: Fix issue where pk was was being set to a string if set by URL kwarg.
## 2.1.11 ### 2.1.11
**Date**: 17th Dec 2012 **Date**: 17th Dec 2012
* Bugfix: Fix issue with M2M fields in browseable API. * Bugfix: Fix issue with M2M fields in browseable API.
## 2.1.10 ### 2.1.10
**Date**: 17th Dec 2012 **Date**: 17th Dec 2012
* Bugfix: Ensure read-only fields don't have model validation applied. * Bugfix: Ensure read-only fields don't have model validation applied.
* Bugfix: Fix hyperlinked fields in paginated results. * Bugfix: Fix hyperlinked fields in paginated results.
## 2.1.9 ### 2.1.9
**Date**: 11th Dec 2012 **Date**: 11th Dec 2012
@ -152,14 +165,14 @@ This change will not affect user code, so long as it's following the recommended
* Bugfix: Fix `Meta.fields` only working as tuple not as list. * Bugfix: Fix `Meta.fields` only working as tuple not as list.
* Bugfix: Edge case if unnecessarily specifying `required=False` on read only field. * Bugfix: Edge case if unnecessarily specifying `required=False` on read only field.
## 2.1.8 ### 2.1.8
**Date**: 8th Dec 2012 **Date**: 8th Dec 2012
* Fix for creating nullable Foreign Keys with `''` as well as `None`. * Fix for creating nullable Foreign Keys with `''` as well as `None`.
* Added `null=<bool>` related field option. * Added `null=<bool>` related field option.
## 2.1.7 ### 2.1.7
**Date**: 7th Dec 2012 **Date**: 7th Dec 2012
@ -171,19 +184,19 @@ This change will not affect user code, so long as it's following the recommended
* Make `Request.user` settable. * Make `Request.user` settable.
* Bugfix: Fix `RegexField` to work with `BrowsableAPIRenderer` * Bugfix: Fix `RegexField` to work with `BrowsableAPIRenderer`
## 2.1.6 ### 2.1.6
**Date**: 23rd Nov 2012 **Date**: 23rd Nov 2012
* Bugfix: Unfix DjangoModelPermissions. (I am a doofus.) * Bugfix: Unfix DjangoModelPermissions. (I am a doofus.)
## 2.1.5 ### 2.1.5
**Date**: 23rd Nov 2012 **Date**: 23rd Nov 2012
* Bugfix: Fix DjangoModelPermissions. * Bugfix: Fix DjangoModelPermissions.
## 2.1.4 ### 2.1.4
**Date**: 22nd Nov 2012 **Date**: 22nd Nov 2012
@ -194,7 +207,7 @@ This change will not affect user code, so long as it's following the recommended
* Added `obtain_token_view` to get tokens when using `TokenAuthentication`. * Added `obtain_token_view` to get tokens when using `TokenAuthentication`.
* Bugfix: Django 1.5 configurable user support for `TokenAuthentication`. * Bugfix: Django 1.5 configurable user support for `TokenAuthentication`.
## 2.1.3 ### 2.1.3
**Date**: 16th Nov 2012 **Date**: 16th Nov 2012
@ -205,14 +218,14 @@ This change will not affect user code, so long as it's following the recommended
* 201 Responses now return a 'Location' header. * 201 Responses now return a 'Location' header.
* Bugfix: Serializer fields now respect `max_length`. * Bugfix: Serializer fields now respect `max_length`.
## 2.1.2 ### 2.1.2
**Date**: 9th Nov 2012 **Date**: 9th Nov 2012
* **Filtering support.** * **Filtering support.**
* Bugfix: Support creation of objects with reverse M2M relations. * Bugfix: Support creation of objects with reverse M2M relations.
## 2.1.1 ### 2.1.1
**Date**: 7th Nov 2012 **Date**: 7th Nov 2012
@ -222,7 +235,7 @@ This change will not affect user code, so long as it's following the recommended
* Bugfix: Make textareas same width as other fields in browsable API. * Bugfix: Make textareas same width as other fields in browsable API.
* Private API change: `.get_serializer` now uses same `instance` and `data` ordering as serializer initialization. * Private API change: `.get_serializer` now uses same `instance` and `data` ordering as serializer initialization.
## 2.1.0 ### 2.1.0
**Date**: 5th Nov 2012 **Date**: 5th Nov 2012
@ -235,13 +248,13 @@ This change will not affect user code, so long as it's following the recommended
* Minor field improvements. (Don't stringify dicts, more robust many-pk fields.) * Minor field improvements. (Don't stringify dicts, more robust many-pk fields.)
* Bugfixes (Support choice field in Browseable API) * Bugfixes (Support choice field in Browseable API)
## 2.0.2 ### 2.0.2
**Date**: 2nd Nov 2012 **Date**: 2nd Nov 2012
* Fix issues with pk related fields in the browsable API. * Fix issues with pk related fields in the browsable API.
## 2.0.1 ### 2.0.1
**Date**: 1st Nov 2012 **Date**: 1st Nov 2012
@ -249,12 +262,12 @@ This change will not affect user code, so long as it's following the recommended
* Added SlugRelatedField and ManySlugRelatedField. * Added SlugRelatedField and ManySlugRelatedField.
* If PUT creates an instance return '201 Created', instead of '200 OK'. * If PUT creates an instance return '201 Created', instead of '200 OK'.
## 2.0.0 ### 2.0.0
**Date**: 30th Oct 2012 **Date**: 30th Oct 2012
* Redesign of core components. * Redesign of core components.
* Fix **all of the things**. * **Fix all of the things**.
# License # License

View File

@ -8,7 +8,7 @@
Authentication is the mechanism of associating an incoming request with a set of identifying credentials, such as the user the request came from, or the token that it was signed with. The [permission] and [throttling] policies can then use those credentials to determine if the request should be permitted. Authentication is the mechanism of associating an incoming request with a set of identifying credentials, such as the user the request came from, or the token that it was signed with. The [permission] and [throttling] policies can then use those credentials to determine if the request should be permitted.
REST framework provides a number of authentication policies out of the box, and also allows you to implement custom policies. REST framework provides a number of authentication schemes out of the box, and also allows you to implement custom schemes.
Authentication will run the first time either the `request.user` or `request.auth` properties are accessed, and determines how those properties are initialized. Authentication will run the first time either the `request.user` or `request.auth` properties are accessed, and determines how those properties are initialized.
@ -16,17 +16,25 @@ The `request.user` property will typically be set to an instance of the `contrib
The `request.auth` property is used for any additional authentication information, for example, it may be used to represent an authentication token that the request was signed with. The `request.auth` property is used for any additional authentication information, for example, it may be used to represent an authentication token that the request was signed with.
---
**Note:** Don't forget that **authentication by itself won't allow or disallow an incoming request**, it simply identifies the credentials that the request was made with.
For information on how to setup the permission polices for your API please see the [permissions documentation][permission].
---
## How authentication is determined ## How authentication is determined
The authentication policy is always defined as a list of classes. REST framework will attempt to authenticate with each class in the list, and will set `request.user` and `request.auth` using the return value of the first class that successfully authenticates. The authentication schemes are always defined as a list of classes. REST framework will attempt to authenticate with each class in the list, and will set `request.user` and `request.auth` using the return value of the first class that successfully authenticates.
If no class authenticates, `request.user` will be set to an instance of `django.contrib.auth.models.AnonymousUser`, and `request.auth` will be set to `None`. If no class authenticates, `request.user` will be set to an instance of `django.contrib.auth.models.AnonymousUser`, and `request.auth` will be set to `None`.
The value of `request.user` and `request.auth` for unauthenticated requests can be modified using the `UNAUTHENTICATED_USER` and `UNAUTHENTICATED_TOKEN` settings. The value of `request.user` and `request.auth` for unauthenticated requests can be modified using the `UNAUTHENTICATED_USER` and `UNAUTHENTICATED_TOKEN` settings.
## Setting the authentication policy ## Setting the authentication scheme
The default authentication policy may be set globally, using the `DEFAULT_AUTHENTICATION_CLASSES` setting. For example. The default authentication schemes may be set globally, using the `DEFAULT_AUTHENTICATION` setting. For example.
REST_FRAMEWORK = { REST_FRAMEWORK = {
'DEFAULT_AUTHENTICATION_CLASSES': ( 'DEFAULT_AUTHENTICATION_CLASSES': (
@ -35,7 +43,7 @@ The default authentication policy may be set globally, using the `DEFAULT_AUTHEN
) )
} }
You can also set the authentication policy on a per-view basis, using the `APIView` class based views. You can also set the authentication scheme on a per-view basis, using the `APIView` class based views.
class ExampleView(APIView): class ExampleView(APIView):
authentication_classes = (SessionAuthentication, BasicAuthentication) authentication_classes = (SessionAuthentication, BasicAuthentication)
@ -60,24 +68,52 @@ Or, if you're using the `@api_view` decorator with function based views.
} }
return Response(content) return Response(content)
## Unauthorized and Forbidden responses
When an unauthenticated request is denied permission there are two different error codes that may be appropriate.
* [HTTP 401 Unauthorized][http401]
* [HTTP 403 Permission Denied][http403]
HTTP 401 responses must always include a `WWW-Authenticate` header, that instructs the client how to authenticate. HTTP 403 responses do not include the `WWW-Authenticate` header.
The kind of response that will be used depends on the authentication scheme. Although multiple authentication schemes may be in use, only one scheme may be used to determine the type of response. **The first authentication class set on the view is used when determining the type of response**.
Note that when a request may successfully authenticate, but still be denied permission to perform the request, in which case a `403 Permission Denied` response will always be used, regardless of the authentication scheme.
## Apache mod_wsgi specific configuration
Note that if deploying to [Apache using mod_wsgi][mod_wsgi_official], the authorization header is not passed through to a WSGI application by default, as it is assumed that authentication will be handled by Apache, rather than at an application level.
If you are deploying to Apache, and using any non-session based authentication, you will need to explicitly configure mod_wsgi to pass the required headers through to the application. This can be done by specifying the `WSGIPassAuthorization` directive in the appropriate context and setting it to `'On'`.
# this can go in either server config, virtual host, directory or .htaccess
WSGIPassAuthorization On
---
# API Reference # API Reference
## BasicAuthentication ## BasicAuthentication
This policy uses [HTTP Basic Authentication][basicauth], signed against a user's username and password. Basic authentication is generally only appropriate for testing. This authentication scheme uses [HTTP Basic Authentication][basicauth], signed against a user's username and password. Basic authentication is generally only appropriate for testing.
If successfully authenticated, `BasicAuthentication` provides the following credentials. If successfully authenticated, `BasicAuthentication` provides the following credentials.
* `request.user` will be a Django `User` instance. * `request.user` will be a Django `User` instance.
* `request.auth` will be `None`. * `request.auth` will be `None`.
Unauthenticated responses that are denied permission will result in an `HTTP 401 Unauthorized` response with an appropriate WWW-Authenticate header. For example:
WWW-Authenticate: Basic realm="api"
**Note:** If you use `BasicAuthentication` in production you must ensure that your API is only available over `https` only. You should also ensure that your API clients will always re-request the username and password at login, and will never store those details to persistent storage. **Note:** If you use `BasicAuthentication` in production you must ensure that your API is only available over `https` only. You should also ensure that your API clients will always re-request the username and password at login, and will never store those details to persistent storage.
## TokenAuthentication ## TokenAuthentication
This policy uses a simple token-based HTTP Authentication scheme. Token authentication is appropriate for client-server setups, such as native desktop and mobile clients. This authentication scheme uses a simple token-based HTTP Authentication scheme. Token authentication is appropriate for client-server setups, such as native desktop and mobile clients.
To use the `TokenAuthentication` policy, include `rest_framework.authtoken` in your `INSTALLED_APPS` setting. To use the `TokenAuthentication` scheme, include `rest_framework.authtoken` in your `INSTALLED_APPS` setting.
You'll also need to create tokens for your users. You'll also need to create tokens for your users.
@ -93,10 +129,16 @@ For clients to authenticate, the token key should be included in the `Authorizat
If successfully authenticated, `TokenAuthentication` provides the following credentials. If successfully authenticated, `TokenAuthentication` provides the following credentials.
* `request.user` will be a Django `User` instance. * `request.user` will be a Django `User` instance.
* `request.auth` will be a `rest_framework.tokenauth.models.BasicToken` instance. * `request.auth` will be a `rest_framework.authtoken.models.BasicToken` instance.
Unauthenticated responses that are denied permission will result in an `HTTP 401 Unauthorized` response with an appropriate WWW-Authenticate header. For example:
WWW-Authenticate: Token
**Note:** If you use `TokenAuthentication` in production you must ensure that your API is only available over `https` only. **Note:** If you use `TokenAuthentication` in production you must ensure that your API is only available over `https` only.
---
If you want every user to have an automatically generated Token, you can simply catch the User's `post_save` signal. If you want every user to have an automatically generated Token, you can simply catch the User's `post_save` signal.
@receiver(post_save, sender=User) @receiver(post_save, sender=User)
@ -127,22 +169,67 @@ The `obtain_auth_token` view will return a JSON response when valid `username` a
## SessionAuthentication ## SessionAuthentication
This policy uses Django's default session backend for authentication. Session authentication is appropriate for AJAX clients that are running in the same session context as your website. This authentication scheme uses Django's default session backend for authentication. Session authentication is appropriate for AJAX clients that are running in the same session context as your website.
If successfully authenticated, `SessionAuthentication` provides the following credentials. If successfully authenticated, `SessionAuthentication` provides the following credentials.
* `request.user` will be a Django `User` instance. * `request.user` will be a Django `User` instance.
* `request.auth` will be `None`. * `request.auth` will be `None`.
If you're using an AJAX style API with SessionAuthentication, you'll need to make sure you include a valid CSRF token for any "unsafe" HTTP method calls, such as `PUT`, `POST` or `DELETE` requests. See the [Django CSRF documentation][csrf-ajax] for more details. Unauthenticated responses that are denied permission will result in an `HTTP 403 Forbidden` response.
If you're using an AJAX style API with SessionAuthentication, you'll need to make sure you include a valid CSRF token for any "unsafe" HTTP method calls, such as `PUT`, `PATCH`, `POST` or `DELETE` requests. See the [Django CSRF documentation][csrf-ajax] for more details.
# Custom authentication # Custom authentication
To implement a custom authentication policy, subclass `BaseAuthentication` and override the `.authenticate(self, request)` method. The method should return a two-tuple of `(user, auth)` if authentication succeeds, or `None` otherwise. To implement a custom authentication scheme, subclass `BaseAuthentication` and override the `.authenticate(self, request)` method. The method should return a two-tuple of `(user, auth)` if authentication succeeds, or `None` otherwise.
In some circumstances instead of returning `None`, you may want to raise an `AuthenticationFailed` exception from the `.authenticate()` method.
Typically the approach you should take is:
* If authentication is not attempted, return `None`. Any other authentication schemes also in use will still be checked.
* If authentication is attempted but fails, raise a `AuthenticationFailed` exception. An error response will be returned immediately, without checking any other authentication schemes.
You *may* also override the `.authenticate_header(self, request)` method. If implemented, it should return a string that will be used as the value of the `WWW-Authenticate` header in a `HTTP 401 Unauthorized` response.
If the `.authenticate_header()` method is not overridden, the authentication scheme will return `HTTP 403 Forbidden` responses when an unauthenticated request is denied access.
## Example
The following example will authenticate any incoming request as the user given by the username in a custom request header named 'X_USERNAME'.
class ExampleAuthentication(authentication.BaseAuthentication):
def has_permission(self, request, view, obj=None):
username = request.META.get('X_USERNAME')
if not username:
return None
try:
user = User.objects.get(username=username)
except User.DoesNotExist:
raise authenticate.AuthenticationFailed('No such user')
return (user, None)
---
# Third party packages
The following third party packages are also available.
## Digest Authentication
HTTP digest authentication is a widely implemented scheme that was intended to replace HTTP basic authentication, and which provides a simple encrypted authentication mechanism. [Juan Riaza][juanriaza] maintains the [djangorestframework-digestauth][djangorestframework-digestauth] package which provides HTTP digest authentication support for REST framework.
[cite]: http://jacobian.org/writing/rest-worst-practices/ [cite]: http://jacobian.org/writing/rest-worst-practices/
[http401]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.4.2
[http403]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.4.4
[basicauth]: http://tools.ietf.org/html/rfc2617 [basicauth]: http://tools.ietf.org/html/rfc2617
[oauth]: http://oauth.net/2/ [oauth]: http://oauth.net/2/
[permission]: permissions.md [permission]: permissions.md
[throttling]: throttling.md [throttling]: throttling.md
[csrf-ajax]: https://docs.djangoproject.com/en/dev/ref/contrib/csrf/#ajax [csrf-ajax]: https://docs.djangoproject.com/en/dev/ref/contrib/csrf/#ajax
[mod_wsgi_official]: http://code.google.com/p/modwsgi/wiki/ConfigurationDirectives#WSGIPassAuthorization
[juanriaza]: https://github.com/juanriaza
[djangorestframework-digestauth]: https://github.com/juanriaza/django-rest-framework-digestauth

View File

@ -53,11 +53,27 @@ Raised if the request contains malformed data when accessing `request.DATA` or `
By default this exception results in a response with the HTTP status code "400 Bad Request". By default this exception results in a response with the HTTP status code "400 Bad Request".
## AuthenticationFailed
**Signature:** `AuthenticationFailed(detail=None)`
Raised when an incoming request includes incorrect authentication.
By default this exception results in a response with the HTTP status code "401 Unauthenticated", but it may also result in a "403 Forbidden" response, depending on the authentication scheme in use. See the [authentication documentation][authentication] for more details.
## NotAuthenticated
**Signature:** `NotAuthenticated(detail=None)`
Raised when an unauthenticated request fails the permission checks.
By default this exception results in a response with the HTTP status code "401 Unauthenticated", but it may also result in a "403 Forbidden" response, depending on the authentication scheme in use. See the [authentication documentation][authentication] for more details.
## PermissionDenied ## PermissionDenied
**Signature:** `PermissionDenied(detail=None)` **Signature:** `PermissionDenied(detail=None)`
Raised when an incoming request fails the permission checks. Raised when an authenticated request fails the permission checks.
By default this exception results in a response with the HTTP status code "403 Forbidden". By default this exception results in a response with the HTTP status code "403 Forbidden".
@ -86,3 +102,4 @@ Raised when an incoming request fails the throttling checks.
By default this exception results in a response with the HTTP status code "429 Too Many Requests". By default this exception results in a response with the HTTP status code "429 Too Many Requests".
[cite]: http://www.doughellmann.com/articles/how-tos/python-exception-handling/index.html [cite]: http://www.doughellmann.com/articles/how-tos/python-exception-handling/index.html
[authentication]: authentication.md

View File

@ -193,6 +193,16 @@ A date and time representation.
Corresponds to `django.db.models.fields.DateTimeField` Corresponds to `django.db.models.fields.DateTimeField`
When using `ModelSerializer` or `HyperlinkedModelSerializer`, note that any model fields with `auto_now=True` or `auto_now_add=True` will use serializer fields that are `read_only=True` by default.
If you want to override this behavior, you'll need to declare the `DateTimeField` explicitly on the serializer. For example:
class CommentSerializer(serializers.ModelSerializer):
created = serializers.DateTimeField()
class Meta:
model = Comment
## IntegerField ## IntegerField
An integer representation. An integer representation.
@ -232,5 +242,7 @@ Signature and validation is the same as with `FileField`.
**Note:** `FileFields` and `ImageFields` are only suitable for use with MultiPartParser, since e.g. json doesn't support file uploads. **Note:** `FileFields` and `ImageFields` are only suitable for use with MultiPartParser, since e.g. json doesn't support file uploads.
Django's regular [FILE_UPLOAD_HANDLERS] are used for handling uploaded files. Django's regular [FILE_UPLOAD_HANDLERS] are used for handling uploaded files.
---
[cite]: https://docs.djangoproject.com/en/dev/ref/forms/api/#django.forms.Form.cleaned_data [cite]: https://docs.djangoproject.com/en/dev/ref/forms/api/#django.forms.Form.cleaned_data
[FILE_UPLOAD_HANDLERS]: https://docs.djangoproject.com/en/dev/ref/settings/#std:setting-FILE_UPLOAD_HANDLERS [FILE_UPLOAD_HANDLERS]: https://docs.djangoproject.com/en/dev/ref/settings/#std:setting-FILE_UPLOAD_HANDLERS

View File

@ -131,6 +131,15 @@ Each of the generic views provided is built by combining one of the base views b
Extends REST framework's `APIView` class, adding support for serialization of model instances and model querysets. Extends REST framework's `APIView` class, adding support for serialization of model instances and model querysets.
**Methods**:
* `get_serializer_context(self)` - Returns a dictionary containing any extra context that should be supplied to the serializer. Defaults to including `'request'`, `'view'` and `'format'` keys.
* `get_serializer_class(self)` - Returns the class that should be used for the serializer.
* `get_serializer(self, instance=None, data=None, files=None, many=False, partial=False)` - Returns a serializer instance.
* `pre_save(self, obj)` - A hook that is called before saving an object.
* `post_save(self, obj, created=False)` - A hook that is called after saving an object.
**Attributes**: **Attributes**:
* `model` - The model that should be used for this view. Used as a fallback for determining the serializer if `serializer_class` is not set, and as a fallback for determining the queryset if `queryset` is not set. Otherwise not required. * `model` - The model that should be used for this view. Used as a fallback for determining the serializer if `serializer_class` is not set, and as a fallback for determining the queryset if `queryset` is not set. Otherwise not required.

View File

@ -114,8 +114,8 @@ You can also override the name used for the object list field, by setting the `r
For example, to nest a pair of links labelled 'prev' and 'next', and set the name for the results field to 'objects', you might use something like this. For example, to nest a pair of links labelled 'prev' and 'next', and set the name for the results field to 'objects', you might use something like this.
class LinksSerializer(serializers.Serializer): class LinksSerializer(serializers.Serializer):
next = pagination.NextURLField(source='*') next = pagination.NextPageField(source='*')
prev = pagination.PreviousURLField(source='*') prev = pagination.PreviousPageField(source='*')
class CustomPaginationSerializer(pagination.BasePaginationSerializer): class CustomPaginationSerializer(pagination.BasePaginationSerializer):
links = LinksSerializer(source='*') # Takes the page object as the source links = LinksSerializer(source='*') # Takes the page object as the source

View File

@ -14,6 +14,16 @@ REST framework includes a number of built in Parser classes, that allow you to a
The set of valid parsers for a view is always defined as a list of classes. When either `request.DATA` or `request.FILES` is accessed, REST framework will examine the `Content-Type` header on the incoming request, and determine which parser to use to parse the request content. The set of valid parsers for a view is always defined as a list of classes. When either `request.DATA` or `request.FILES` is accessed, REST framework will examine the `Content-Type` header on the incoming request, and determine which parser to use to parse the request content.
---
**Note**: When developing client applications always remember to make sure you're setting the `Content-Type` header when sending data in an HTTP request.
If you don't set the content type, most clients will default to using `'application/x-www-form-urlencoded'`, which may not be what you wanted.
As an example, if you are sending `json` encoded data using jQuery with the [.ajax() method][jquery-ajax], you should make sure to include the `contentType: 'application/json'` setting.
---
## Setting the parsers ## Setting the parsers
The default set of parsers may be set globally, using the `DEFAULT_PARSER_CLASSES` setting. For example, the following settings would allow requests with `YAML` content. The default set of parsers may be set globally, using the `DEFAULT_PARSER_CLASSES` setting. For example, the following settings would allow requests with `YAML` content.
@ -169,6 +179,7 @@ The following third party packages are also available.
[MessagePack][messagepack] is a fast, efficient binary serialization format. [Juan Riaza][juanriaza] maintains the [djangorestframework-msgpack][djangorestframework-msgpack] package which provides MessagePack renderer and parser support for REST framework. [MessagePack][messagepack] is a fast, efficient binary serialization format. [Juan Riaza][juanriaza] maintains the [djangorestframework-msgpack][djangorestframework-msgpack] package which provides MessagePack renderer and parser support for REST framework.
[jquery-ajax]: http://api.jquery.com/jQuery.ajax/
[cite]: https://groups.google.com/d/topic/django-developers/dxI4qVzrBY4/discussion [cite]: https://groups.google.com/d/topic/django-developers/dxI4qVzrBY4/discussion
[messagepack]: https://github.com/juanriaza/django-rest-framework-msgpack [messagepack]: https://github.com/juanriaza/django-rest-framework-msgpack
[juanriaza]: https://github.com/juanriaza [juanriaza]: https://github.com/juanriaza

View File

@ -110,6 +110,15 @@ To implement a custom permission, override `BasePermission` and implement the `.
The method should return `True` if the request should be granted access, and `False` otherwise. The method should return `True` if the request should be granted access, and `False` otherwise.
## Example
The following is an example of a permission class that checks the incoming request's IP address against a blacklist, and denies the request if the IP has been blacklisted.
class BlacklistPermission(permissions.BasePermission):
def has_permission(self, request, view, obj=None):
ip_addr = request.META['REMOTE_ADDR']
blacklisted = Blacklist.objects.filter(ip_addr=ip_addr).exists()
return not blacklisted
[cite]: https://developer.apple.com/library/mac/#documentation/security/Conceptual/AuthenticationAndAuthorizationGuide/Authorization/Authorization.html [cite]: https://developer.apple.com/library/mac/#documentation/security/Conceptual/AuthenticationAndAuthorizationGuide/Authorization/Authorization.html
[authentication]: authentication.md [authentication]: authentication.md

View File

@ -12,35 +12,260 @@ Relational fields are used to represent model relationships. They can be applie
--- ---
**Note:** The relational fields are declared in `relations.py`, but by convention you should import them using `from rest_framework import serializers` and refer to fields as `serializers.<FieldName>`. **Note:** The relational fields are declared in `relations.py`, but by convention you should import them from the `serializers` module, using `from rest_framework import serializers` and refer to fields as `serializers.<FieldName>`.
--- ---
# API Reference
In order to explain the various types of relational fields, we'll use a couple of simple models for our examples. Our models will be for music albums, and the tracks listed on each album.
class Album(models.Model):
album_name = models.CharField(max_length=100)
artist = models.CharField(max_length=100)
class Track(models.Model):
album = models.ForeignKey(Album, related_name='tracks')
order = models.IntegerField()
title = models.CharField(max_length=100)
duration = models.IntegerField()
class Meta:
unique_together = ('album', 'order')
def __unicode__(self):
return '%d: %s' % (self.order, self.title)
## RelatedField ## RelatedField
This field can be applied to any of the following: `RelatedField` may be used to represent the target of the relationship using it's `__unicode__` method.
* A `ForeignKey` field. For example, the following serializer.
* A `OneToOneField` field.
* A reverse OneToOne relationship
* Any other "to-one" relationship.
By default `RelatedField` will represent the target of the field using it's `__unicode__` method. class AlbumSerializer(serializer.ModelSerializer):
tracks = RelatedField(many=True)
You can customize this behavior by subclassing `ManyRelatedField`, and overriding the `.to_native(self, value)` method. class Meta:
model = Album
fields = ('album_name', 'artist', 'tracks')
## ManyRelatedField Would serialize to the following representation.
This field can be applied to any of the following: {
'album_name': 'Things We Lost In The Fire',
'artist': 'Low'
'tracks': [
'1: Sunflower',
'2: Whitetail',
'3: Dinosaur Act',
...
]
}
* A `ManyToManyField` field. This field is read only.
* A reverse ManyToMany relationship.
* A reverse ForeignKey relationship
* Any other "to-many" relationship.
By default `ManyRelatedField` will represent the targets of the field using their `__unicode__` method. ## PrimaryKeyRelatedField
For example, given the following models: `PrimaryKeyRelatedField` may be used to represent the target of the relationship using it's primary key.
For example, the following serializer:
class AlbumSerializer(serializer.ModelSerializer):
tracks = PrimaryKeyRelatedField(many=True, read_only=True)
class Meta:
model = Album
fields = ('album_name', 'artist', 'tracks')
Would serialize to a representation like this:
{
'album_name': 'The Roots',
'artist': 'Undun'
'tracks': [
89,
90,
91,
...
]
}
By default this field is read-write, although you can change this behavior using the `read_only` flag.
**Arguments**:
* `queryset` - By default `ModelSerializer` classes will use the default queryset for the relationship. `Serializer` classes must either set a queryset explicitly, or set `read_only=True`.
* `required` - If set to `False`, the field will accept values of `None` or the empty-string for nullable relationships.
## HyperlinkedRelatedField
`HyperlinkedRelatedField` may be used to represent the target of the relationship using a hyperlink.
For example, the following serializer:
class AlbumSerializer(serializer.ModelSerializer):
tracks = HyperlinkedRelatedField(many=True, read_only=True,
view_name='track-detail')
class Meta:
model = Album
fields = ('album_name', 'artist', 'tracks')
Would serialize to a representation like this:
{
'album_name': 'Graceland',
'artist': 'Paul Simon'
'tracks': [
'http://www.example.com/api/tracks/45',
'http://www.example.com/api/tracks/46',
'http://www.example.com/api/tracks/47',
...
]
}
By default this field is read-write, although you can change this behavior using the `read_only` flag.
**Arguments**:
* `view_name` - The view name that should be used as the target of the relationship. **required**.
* `required` - If set to `False`, the field will accept values of `None` or the empty-string for nullable relationships.
* `queryset` - By default `ModelSerializer` classes will use the default queryset for the relationship. `Serializer` classes must either set a queryset explicitly, or set `read_only=True`.
* `slug_field` - The field on the target that should be used for the lookup. Default is `'slug'`.
* `pk_url_kwarg` - The named url parameter for the pk field lookup. Default is `pk`.
* `slug_url_kwarg` - The named url parameter for the slug field lookup. Default is to use the same value as given for `slug_field`.
* `format` - If using format suffixes, hyperlinked fields will use the same format suffix for the target unless overridden by using the `format` argument.
## SlugRelatedField
`SlugRelatedField` may be used to represent the target of the relationship using a field on the target.
For example, the following serializer:
class AlbumSerializer(serializer.ModelSerializer):
tracks = SlugRelatedField(many=True, read_only=True, slug_field='title')
class Meta:
model = Album
fields = ('album_name', 'artist', 'tracks')
Would serialize to a representation like this:
{
'album_name': 'Dear John',
'artist': 'Loney Dear'
'tracks': [
'Airport Surroundings',
'Everything Turns to You',
'I Was Only Going Out',
...
]
}
By default this field is read-write, although you can change this behavior using the `read_only` flag.
When using `SlugRelatedField` as a read-write field, you will normally want to ensure that the slug field corresponds to a model field with `unique=True`.
**Arguments**:
* `slug_field` - The field on the target that should be used to represent it. This should be a field that uniquely identifies any given instance. For example, `username`.
* `queryset` - By default `ModelSerializer` classes will use the default queryset for the relationship. `Serializer` classes must either set a queryset explicitly, or set `read_only=True`.
* `null` - If set to `True`, the field will accept values of `None` or the empty-string for nullable relationships.
## HyperLinkedIdentityField
This field can be applied as an identity relationship, such as the `'url'` field on a HyperlinkedModelSerializer. It can also be used for an attribute on the object. For example, the following serializer:
class AlbumSerializer(serializers.HyperlinkedModelSerializer):
track_listing = HyperLinkedIdentityField(view_name='track-list')
class Meta:
model = Album
fields = ('album_name', 'artist', 'track_listing')
Would serialize to a representation like this:
{
'album_name': 'The Eraser',
'artist': 'Thom Yorke'
'track_listing': 'http://www.example.com/api/track_list/12',
}
This field is always read-only.
**Arguments**:
* `view_name` - The view name that should be used as the target of the relationship. **required**.
* `slug_field` - The field on the target that should be used for the lookup. Default is `'slug'`.
* `pk_url_kwarg` - The named url parameter for the pk field lookup. Default is `pk`.
* `slug_url_kwarg` - The named url parameter for the slug field lookup. Default is to use the same value as given for `slug_field`.
* `format` - If using format suffixes, hyperlinked fields will use the same format suffix for the target unless overridden by using the `format` argument.
## Nested relationships
Nested relationships can be expressed by using serializers as fields. For example:
class TrackSerializer(serializer.ModelSerializer):
class Meta:
fields = ('order', 'title')
class AlbumSerializer(serializer.ModelSerializer):
tracks = TrackSerializer(many=True)
class Meta:
model = Album
fields = ('album_name', 'artist', 'tracks')
Note that nested relationships are currently read-only. For read-write relationships, you should use a flat relational style.
## Custom relational fields
To implement a custom relational field, you should override `RelatedField`, and implement the `.to_native(self, value)` method. This method takes the target of the field as the `value` argument, and should return the representation that should be used to serialize the target.
class TrackListingField(serializers.RelatedField):
def to_native(self, value):
return 'Track %d: %s' % (value.ordering, value.name)
If you want to implement a read-write relational field, you must also implement the `.from_native(self, data)` method, and add `read_only = False` to the class definition.
# Further notes
## Reverse relations
Note that reverse relationships are not automatically generated by the `ModelSerializer` and `HyperlinkedModelSerializer` classes. To include a reverse relationship, you cannot simply add it to the fields list.
**The following will not work:**
class AlbumSerializer(serializer.ModelSerializer):
class Meta:
fields = ('tracks', ...)
Instead, you must explicitly add it to the serializer. For example:
class AlbumSerializer(serializer.ModelSerializer):
tracks = serializers.PrimaryKeyRelationship(many=True)
...
By default, the field will uses the same accessor as it's field name to retrieve the relationship, so in this example, `Album` instances would need to have the `tracks` attribute for this relationship to work.
The best way to ensure this is typically to make sure that the relationship on the model definition has it's `related_name` argument properly set. For example:
class Track(models.Model):
album = models.ForeignKey(Album, related_name='tracks')
...
Alternatively, you can use the `source` argument on the serializer field, to use a different accessor attribute than the field name. For example.
class AlbumSerializer(serializer.ModelSerializer):
tracks = serializers.PrimaryKeyRelationship(many=True, source='track_set')
See the Django documentation on [reverse relationships][reverse-relationships] for more details.
## Generic relationships
If you want to serialize a generic foreign key, you need to define a custom field, to determine explicitly how you want serialize the targets of the relationship.
For example, given the following model for a tag, which has a generic relationship with other arbitrary models:
class TaggedItem(models.Model): class TaggedItem(models.Model):
""" """
@ -48,14 +273,15 @@ For example, given the following models:
See: https://docs.djangoproject.com/en/dev/ref/contrib/contenttypes/ See: https://docs.djangoproject.com/en/dev/ref/contrib/contenttypes/
""" """
tag = models.SlugField() tag_name = models.SlugField()
content_type = models.ForeignKey(ContentType) content_type = models.ForeignKey(ContentType)
object_id = models.PositiveIntegerField() object_id = models.PositiveIntegerField()
content_object = GenericForeignKey('content_type', 'object_id') tagged_object = GenericForeignKey('content_type', 'object_id')
def __unicode__(self): def __unicode__(self):
return self.tag return self.tag
And the following two models, which may be have associated tags:
class Bookmark(models.Model): class Bookmark(models.Model):
""" """
@ -64,76 +290,65 @@ For example, given the following models:
url = models.URLField() url = models.URLField()
tags = GenericRelation(TaggedItem) tags = GenericRelation(TaggedItem)
And a model serializer defined like this:
class BookmarkSerializer(serializers.ModelSerializer): class Note(models.Model):
tags = serializers.ManyRelatedField(source='tags') """
A note consists of some text, and 0 or more descriptive tags.
"""
text = models.CharField(max_length=1000)
tags = GenericRelation(TaggedItem)
class Meta: We could define a custom field that could be used to serialize tagged instances, using the type of each instance to determine how it should be serialized.
model = Bookmark
exclude = ('id',)
Then an example output format for a Bookmark instance would be: class TaggedObjectRelatedField(serializers.RelatedField):
"""
A custom field to use for the `tagged_object` generic relationship.
"""
{ def to_native(self, value):
'tags': [u'django', u'python'], """
'url': u'https://www.djangoproject.com/' Serialize tagged objects to a simple textual representation.
} """
if isinstance(value, Bookmark):
return 'Bookmark: ' + value.url
elif isinstance(value, Note):
return 'Note: ' + value.text
raise Exception('Unexpected type of tagged object')
## PrimaryKeyRelatedField If you need the target of the relationship to have a nested representation, you can use the required serializers inside the `.to_native()` method:
## ManyPrimaryKeyRelatedField
`PrimaryKeyRelatedField` and `ManyPrimaryKeyRelatedField` will represent the target of the relationship using it's primary key. def to_native(self, value):
"""
Serialize bookmark instances using a bookmark serializer,
and note instances using a note serializer.
"""
if isinstance(value, Bookmark):
serializer = BookmarkSerializer(value)
elif isinstance(value, Note):
serializer = NoteSerializer(value)
else:
raise Exception('Unexpected type of tagged object')
By default these fields are read-write, although you can change this behavior using the `read_only` flag. return serializer.data
**Arguments**: Note that reverse generic keys, expressed using the `GenericRelation` field, can be serialized using the regular relational field types, since the type of the target in the relationship is always known.
* `queryset` - By default `ModelSerializer` classes will use the default queryset for the relationship. `Serializer` classes must either set a queryset explicitly, or set `read_only=True`. For more information see [the Django documentation on generic relations][generic-relations].
* `null` - If set to `True`, the field will accept values of `None` or the empty-string for nullable relationships.
## SlugRelatedField ---
## ManySlugRelatedField
`SlugRelatedField` and `ManySlugRelatedField` will represent the target of the relationship using a unique slug. ## Deprecated relational fields
By default these fields read-write, although you can change this behavior using the `read_only` flag. The following classes have been deprecated, in favor of the `many=<bool>` syntax.
They continue to function, but their usage will raise a `PendingDeprecationWarning`, which is silent by default.
In the 2.3 release, this warning will be escalated to a `DeprecationWarning`.
In the 2.4 release, they will be removed entirely.
**Arguments**: * `ManyRelatedField`
* `ManyPrimaryKeyRelatedField`
* `slug_field` - The field on the target that should be used to represent it. This should be a field that uniquely identifies any given instance. For example, `username`. * `ManyHyperlinkedRelatedField`
* `queryset` - By default `ModelSerializer` classes will use the default queryset for the relationship. `Serializer` classes must either set a queryset explicitly, or set `read_only=True`. * `ManySlugRelatedField`
* `null` - If set to `True`, the field will accept values of `None` or the empty-string for nullable relationships.
## HyperlinkedRelatedField
## ManyHyperlinkedRelatedField
`HyperlinkedRelatedField` and `ManyHyperlinkedRelatedField` will represent the target of the relationship using a hyperlink.
By default, `HyperlinkedRelatedField` is read-write, although you can change this behavior using the `read_only` flag.
**Arguments**:
* `view_name` - The view name that should be used as the target of the relationship. **required**.
* `format` - If using format suffixes, hyperlinked fields will use the same format suffix for the target unless overridden by using the `format` argument.
* `queryset` - By default `ModelSerializer` classes will use the default queryset for the relationship. `Serializer` classes must either set a queryset explicitly, or set `read_only=True`.
* `slug_field` - The field on the target that should be used for the lookup. Default is `'slug'`.
* `pk_url_kwarg` - The named url parameter for the pk field lookup. Default is `pk`.
* `slug_url_kwarg` - The named url parameter for the slug field lookup. Default is to use the same value as given for `slug_field`.
* `null` - If set to `True`, the field will accept values of `None` or the empty-string for nullable relationships.
## HyperLinkedIdentityField
This field can be applied as an identity relationship, such as the `'url'` field on a HyperlinkedModelSerializer.
This field is always read-only.
**Arguments**:
* `view_name` - The view name that should be used as the target of the relationship. **required**.
* `format` - If using format suffixes, hyperlinked fields will use the same format suffix for the target unless overridden by using the `format` argument.
* `slug_field` - The field on the target that should be used for the lookup. Default is `'slug'`.
* `pk_url_kwarg` - The named url parameter for the pk field lookup. Default is `pk`.
* `slug_url_kwarg` - The named url parameter for the slug field lookup. Default is to use the same value as given for `slug_field`.
[cite]: http://lwn.net/Articles/193245/ [cite]: http://lwn.net/Articles/193245/
[reverse-relationships]: https://docs.djangoproject.com/en/dev/topics/db/queries/#following-relationships-backward
[generic-relations]: https://docs.djangoproject.com/en/dev/ref/contrib/contenttypes/#id1

View File

@ -80,7 +80,7 @@ Renders the request data into `JSONP`. The `JSONP` media type provides a mechan
The javascript callback function must be set by the client including a `callback` URL query parameter. For example `http://example.com/api/users?callback=jsonpCallback`. If the callback function is not explicitly set by the client it will default to `'callback'`. The javascript callback function must be set by the client including a `callback` URL query parameter. For example `http://example.com/api/users?callback=jsonpCallback`. If the callback function is not explicitly set by the client it will default to `'callback'`.
**Note**: If you require cross-domain AJAX requests, you may also want to consider using [CORS] as an alternative to `JSONP`. **Note**: If you require cross-domain AJAX requests, you may want to consider using the more modern approach of [CORS][cors] as an alternative to `JSONP`. See the [CORS documentation][cors-docs] for more details.
**.media_type**: `application/javascript` **.media_type**: `application/javascript`
@ -288,7 +288,8 @@ Comma-separated values are a plain-text tabular data format, that can be easily
[cite]: https://docs.djangoproject.com/en/dev/ref/template-response/#the-rendering-process [cite]: https://docs.djangoproject.com/en/dev/ref/template-response/#the-rendering-process
[conneg]: content-negotiation.md [conneg]: content-negotiation.md
[browser-accept-headers]: http://www.gethifi.com/blog/browser-rest-http-accept-headers [browser-accept-headers]: http://www.gethifi.com/blog/browser-rest-http-accept-headers
[CORS]: http://en.wikipedia.org/wiki/Cross-origin_resource_sharing [cors]: http://www.w3.org/TR/cors/
[cors-docs]: ../topics/ajax-csrf-cors.md
[HATEOAS]: http://timelessrepo.com/haters-gonna-hateoas [HATEOAS]: http://timelessrepo.com/haters-gonna-hateoas
[quote]: http://roy.gbiv.com/untangled/2008/rest-apis-must-be-hypertext-driven [quote]: http://roy.gbiv.com/untangled/2008/rest-apis-must-be-hypertext-driven
[application/vnd.github+json]: http://developer.github.com/v3/media/ [application/vnd.github+json]: http://developer.github.com/v3/media/

View File

@ -83,13 +83,13 @@ You won't typically need to access this property.
# Browser enhancements # Browser enhancements
REST framework supports a few browser enhancements such as browser-based `PUT` and `DELETE` forms. REST framework supports a few browser enhancements such as browser-based `PUT`, `PATCH` and `DELETE` forms.
## .method ## .method
`request.method` returns the **uppercased** string representation of the request's HTTP method. `request.method` returns the **uppercased** string representation of the request's HTTP method.
Browser-based `PUT` and `DELETE` forms are transparently supported. Browser-based `PUT`, `PATCH` and `DELETE` forms are transparently supported.
For more information see the [browser enhancements documentation]. For more information see the [browser enhancements documentation].

View File

@ -190,18 +190,12 @@ By default field values are treated as mapping to an attribute on the object. I
As an example, let's create a field that can be used represent the class name of the object being serialized: As an example, let's create a field that can be used represent the class name of the object being serialized:
class ClassNameField(serializers.WritableField): class ClassNameField(serializers.Field):
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
""" """
Serialize the object's class name, not an attribute of the object. Serialize the object's class name.
""" """
return obj.__class__.__name__ return obj.__class__
def field_from_native(self, data, field_name, into):
"""
We don't want to set anything when we revert this field.
"""
pass
--- ---

View File

@ -150,8 +150,16 @@ User requests to either `ContactListView` or `ContactDetailView` would be restri
# Custom throttles # Custom throttles
To create a custom throttle, override `BaseThrottle` and implement `.allow_request(request, view)`. The method should return `True` if the request should be allowed, and `False` otherwise. To create a custom throttle, override `BaseThrottle` and implement `.allow_request(self, request, view)`. The method should return `True` if the request should be allowed, and `False` otherwise.
Optionally you may also override the `.wait()` method. If implemented, `.wait()` should return a recommended number of seconds to wait before attempting the next request, or `None`. The `.wait()` method will only be called if `.allow_request()` has previously returned `False`. Optionally you may also override the `.wait()` method. If implemented, `.wait()` should return a recommended number of seconds to wait before attempting the next request, or `None`. The `.wait()` method will only be called if `.allow_request()` has previously returned `False`.
## Example
The following is an example of a rate throttle, that will randomly throttle 1 in every 10 requests.
class RandomRateThrottle(throttles.BaseThrottle):
def allow_request(self, request, view):
return random.randint(1, 10) == 1
[permissions]: permissions.md [permissions]: permissions.md

View File

@ -85,7 +85,7 @@ The following methods are called before dispatching to the handler method.
## Dispatch methods ## Dispatch methods
The following methods are called directly by the view's `.dispatch()` method. The following methods are called directly by the view's `.dispatch()` method.
These perform any actions that need to occur before or after calling the handler methods such as `.get()`, `.post()`, `put()` and `.delete()`. These perform any actions that need to occur before or after calling the handler methods such as `.get()`, `.post()`, `put()`, `patch()` and `.delete()`.
### .initial(self, request, \*args, **kwargs) ### .initial(self, request, \*args, **kwargs)

View File

@ -25,18 +25,29 @@ pre {
margin-top: 9px; margin-top: 9px;
} }
body.index-page #main-content p.badges {
padding-bottom: 1px;
}
/* GitHub 'Star' badge */ /* GitHub 'Star' badge */
body.index-page #main-content iframe { body.index-page #main-content iframe.github-star-button {
float: right; float: right;
margin-top: -12px; margin-top: -12px;
margin-right: -15px; margin-right: -15px;
} }
/* Tweet button */
body.index-page #main-content iframe.twitter-share-button {
float: right;
margin-top: -12px;
margin-right: 8px;
}
/* Travis CI badge */ /* Travis CI badge */
body.index-page #main-content p:first-of-type { body.index-page #main-content img.travis-build-image {
float: right; float: right;
margin-right: 8px; margin-right: 8px;
margin-top: -14px; margin-top: -9px;
margin-bottom: 0px; margin-bottom: 0px;
} }

View File

@ -1,16 +1,16 @@
<iframe src="http://ghbtns.com/github-btn.html?user=tomchristie&amp;repo=django-rest-framework&amp;type=watch&amp;count=true" allowtransparency="true" frameborder="0" scrolling="0" width="110px" height="20px"></iframe> <p class="badges">
[![Travis build image][travis-build-image]][travis] <iframe src="http://ghbtns.com/github-btn.html?user=tomchristie&amp;repo=django-rest-framework&amp;type=watch&amp;count=true" class="github-star-button" allowtransparency="true" frameborder="0" scrolling="0" width="110px" height="20px"></iframe>
<a href="https://twitter.com/share" class="twitter-share-button" data-url="django-rest-framework.org" data-text="Current status: Checking out the totally awesome Django REST framework! http://django-rest-framework.org" data-count="none">Tweet</a>
<script>!function(d,s,id){var js,fjs=d.getElementsByTagName(s)[0];if(!d.getElementById(id)){js=d.createElement(s);js.id=id;js.src="http://platform.twitter.com/widgets.js";fjs.parentNode.insertBefore(js,fjs);}}(document,"script","twitter-wjs");</script>
<img alt="Travis build image" src="https://secure.travis-ci.org/tomchristie/django-rest-framework.png?branch=master" class="travis-build-image">
</p>
# Django REST framework # Django REST framework
**A toolkit for building well-connected, self-describing Web APIs.** **A toolkit for building well-connected, self-describing Web APIs.**
---
**Note**: This documentation is for the 2.0 version of REST framework. If you are looking for earlier versions please see the [0.4.x branch][0.4] on GitHub.
---
Django REST framework is a lightweight library that makes it easy to build Web APIs. It is designed as a modular and easy to customize architecture, based on Django's class based views. Django REST framework is a lightweight library that makes it easy to build Web APIs. It is designed as a modular and easy to customize architecture, based on Django's class based views.
Web APIs built using REST framework are fully self-describing and web browseable - a huge useability win for your developers. It also supports a wide range of media types, authentication and permission policies out of the box. Web APIs built using REST framework are fully self-describing and web browseable - a huge useability win for your developers. It also supports a wide range of media types, authentication and permission policies out of the box.
@ -27,7 +27,7 @@ There is also a sandbox API you can use for testing purposes, [available here][s
REST framework requires the following: REST framework requires the following:
* Python (2.6, 2.7) * Python (2.6, 2.7, 3.2, 3.3)
* Django (1.3, 1.4, 1.5) * Django (1.3, 1.4, 1.5)
The following packages are optional: The following packages are optional:
@ -111,6 +111,7 @@ The API guide is your complete reference manual to all the functionality provide
General guides to using REST framework. General guides to using REST framework.
* [AJAX, CSRF & CORS][ajax-csrf-cors]
* [Browser enhancements][browser-enhancements] * [Browser enhancements][browser-enhancements]
* [The Browsable API][browsableapi] * [The Browsable API][browsableapi]
* [REST, Hypermedia & HATEOAS][rest-hypermedia-hateoas] * [REST, Hypermedia & HATEOAS][rest-hypermedia-hateoas]
@ -132,9 +133,14 @@ Run the tests:
## Support ## Support
For support please see the [REST framework discussion group][group], or try the `#restframework` channel on `irc.freenode.net`. For support please see the [REST framework discussion group][group], try the `#restframework` channel on `irc.freenode.net`, or raise a question on [Stack Overflow][stack-overflow], making sure to include the ['django-rest-framework'][django-rest-framework-tag] tag.
Paid support is also available from [DabApps], and can include work on REST framework core, or support with building your REST framework API. Please contact [Tom Christie][email] if you'd like to discuss commercial support options. [Paid support is available][paid-support] from [DabApps][dabapps], and can include work on REST framework core, or support with building your REST framework API. Please [contact DabApps][contact-dabapps] if you'd like to discuss commercial support options.
For updates on REST framework development, you may also want to follow [the author][twitter] on Twitter.
<a style="padding-top: 10px" href="https://twitter.com/_tomchristie" class="twitter-follow-button" data-show-count="false">Follow @_tomchristie</a>
<script>!function(d,s,id){var js,fjs=d.getElementsByTagName(s)[0];if(!d.getElementById(id)){js=d.createElement(s);js.id=id;js.src="//platform.twitter.com/widgets.js";fjs.parentNode.insertBefore(js,fjs);}}(document,"script","twitter-wjs");</script>
## License ## License
@ -199,7 +205,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
[status]: api-guide/status-codes.md [status]: api-guide/status-codes.md
[settings]: api-guide/settings.md [settings]: api-guide/settings.md
[csrf]: topics/csrf.md [ajax-csrf-cors]: topics/ajax-csrf-cors.md
[browser-enhancements]: topics/browser-enhancements.md [browser-enhancements]: topics/browser-enhancements.md
[browsableapi]: topics/browsable-api.md [browsableapi]: topics/browsable-api.md
[rest-hypermedia-hateoas]: topics/rest-hypermedia-hateoas.md [rest-hypermedia-hateoas]: topics/rest-hypermedia-hateoas.md
@ -209,5 +215,10 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
[credits]: topics/credits.md [credits]: topics/credits.md
[group]: https://groups.google.com/forum/?fromgroups#!forum/django-rest-framework [group]: https://groups.google.com/forum/?fromgroups#!forum/django-rest-framework
[DabApps]: http://dabapps.com [stack-overflow]: http://stackoverflow.com/
[email]: mailto:tom@tomchristie.com [django-rest-framework-tag]: http://stackoverflow.com/questions/tagged/django-rest-framework
[django-tag]: http://stackoverflow.com/questions/tagged/django
[paid-support]: http://dabapps.com/services/build/api-development/
[dabapps]: http://dabapps.com
[contact-dabapps]: http://dabapps.com/contact/
[twitter]: https://twitter.com/_tomchristie

View File

@ -89,6 +89,7 @@
<li class="dropdown"> <li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">Topics <b class="caret"></b></a> <a href="#" class="dropdown-toggle" data-toggle="dropdown">Topics <b class="caret"></b></a>
<ul class="dropdown-menu"> <ul class="dropdown-menu">
<li><a href="{{ base_url }}/topics/ajax-csrf-cors{{ suffix }}">AJAX, CSRF & CORS</a></li>
<li><a href="{{ base_url }}/topics/browser-enhancements{{ suffix }}">Browser enhancements</a></li> <li><a href="{{ base_url }}/topics/browser-enhancements{{ suffix }}">Browser enhancements</a></li>
<li><a href="{{ base_url }}/topics/browsable-api{{ suffix }}">The Browsable API</a></li> <li><a href="{{ base_url }}/topics/browsable-api{{ suffix }}">The Browsable API</a></li>
<li><a href="{{ base_url }}/topics/rest-hypermedia-hateoas{{ suffix }}">REST, Hypermedia & HATEOAS</a></li> <li><a href="{{ base_url }}/topics/rest-hypermedia-hateoas{{ suffix }}">REST, Hypermedia & HATEOAS</a></li>

View File

@ -0,0 +1,113 @@
# REST framework 2.2 release notes
The 2.2 release represents an important point for REST framework, with the addition of Python 3 support, and the introduction of an official deprecation policy.
## Python 3 support
Thanks to some fantastic work from [Xavier Ordoquy][xordoquy], Django REST framework 2.2 now supports Python 3. You'll need to be running Django 1.5, and it's worth keeping in mind that Django's Python 3 support is currently [considered experimental][django-python-3].
Django 1.6's Python 3 support is expected to be officially labeled as 'production-ready'.
If you want to start ensuring that your own projects are Python 3 ready, we can highly recommend Django's [Porting to Python 3][porting-python-3] documentation.
## Deprecation policy
We've now introduced an official deprecation policy, which is in line with [Django's deprecation policy][django-deprecation-policy]. This policy will make it easy for you to continue to track the latest, greatest version of REST framework.
The timeline for deprecation works as follows:
* Version 2.2 introduces some API changes as detailed in the release notes. It remains fully backwards compatible with 2.1, but will raise `PendingDeprecationWarning` warnings if you use bits API that are due to be deprecated. These warnings are silent by default, but can be explicitly enabled when you're ready to start migrating any required changes. For example if you start running your tests using `python -Wd manage.py test`, you'll be warned of any API changes you need to make.
* Version 2.3 will escalate these warnings to `DeprecationWarning`, which is loud by default.
* Version 2.4 will remove the deprecated bits of API entirely.
Note that in line with Django's policy, any parts of the framework not mentioned in the documentation should generally be considered private API, and may be subject to change.
## Community
As of the 2.2 merge, we've also hit an impressive milestone. The number of committers listed in [the credits][credits], is now at over **one hundred individuals**. Each name on that list represents at least one merged pull request, however large or small.
Our [mailing list][mailing-list] and #restframework IRC channel are also very active, and we've got a really impressive rate of development both on REST framework itself, and on third party packages such as the great [django-rest-framework-docs][django-rest-framework-docs] package from [Marc Gibbons][marcgibbons].
## Issue management
All the design work that went into version 2 of Django REST framework has made keeping on top of issues much easier. We've been super-focused on keeping the [issues list][issues] strictly under control, and we've hit another important milestone. At the point of releasing 2.2 there are currently **no open 'bug' tickets**, and the plan is to keep it that way for as much of the time as possible.
## API changes
The 2.2 release makes a few changes to the serializer fields API, in order to make it more consistent, simple, and easier to use.
### Cleaner to-many related fields
The `ManyRelatedField()` style is being deprecated in favor of a new `RelatedField(many=True)` syntax.
For example, if a user is associated with multiple questions, which we want to represent using a primary key relationship, we might use something like the following:
class UserSerializer(serializers.HyperlinkedModelSerializer):
questions = serializers.PrimaryKeyRelatedField(many=True)
class Meta:
fields = ('username', 'questions')
The new syntax is cleaner and more obvious, and the change will also make the documentation cleaner, simplify the internal API, and make writing custom relational fields easier.
The change also applies to serializers. If you have a nested serializer, you should start using `many=True` for to-many relationships. For example, a serializer representation of an Album that can contain many Tracks might look something like this:
class TrackSerializer(serializer.ModelSerializer):
class Meta:
model = Track
fields = ('name', 'duration')
class AlbumSerializer(serializer.ModelSerializer):
tracks = TrackSerializer(many=True)
class Meta:
model = Album
fields = ('album_name', 'artist', 'tracks')
Additionally, the change also applies when serializing or deserializing data. For example to serialize a queryset of models you should now use the `many=True` flag.
serializer = SnippetSerializer(Snippet.objects.all(), many=True)
serializer.data
This more explicit behavior on serializing and deserializing data [makes integration with non-ORM backends such as MongoDB easier][564], as instances to be serialized can include the `__iter__` method, without incorrectly triggering list-based serialization, or requiring workarounds.
The implicit to-many behavior on serializers, and the `ManyRelatedField` style classes will continue to function, but will raise a `PendingDeprecationWarning`, which can be made visible using the `-Wd` flag.
**Note**: If you need to forcibly turn off the implict "`many=True` for `__iter__` objects" behavior, you can now do so by specifying `many=False`. This will become the default (instead of the current default of `None`) once the deprecation of the implicit behavior is finalised in version 2.4.
### Cleaner optional relationships
Serializer relationships for nullable Foreign Keys will change from using the current `null=True` flag, to instead using `required=False`.
For example, is a user account has an optional foreign key to a company, that you want to express using a hyperlink, you might use the following field in a `Serializer` class:
current_company = serializers.HyperlinkedRelatedField(required=False)
This is in line both with the rest of the serializer fields API, and with Django's `Form` and `ModelForm` API.
Using `required` throughout the serializers API means you won't need to consider if a particular field should take `blank` or `null` arguments instead of `required`, and also means there will be more consistent behavior for how fields are treated when they are not present in the incoming data.
The `null=True` argument will continue to function, and will imply `required=False`, but will raise a `PendingDeprecationWarning`.
### Cleaner CharField syntax
The `CharField` API previously took an optional `blank=True` argument, which was intended to differentiate between null CharField input, and blank CharField input.
In keeping with Django's CharField API, REST framework's `CharField` will only ever return the empty string, for missing or `None` inputs. The `blank` flag will no longer be in use, and you should instead just use the `required=<bool>` flag. For example:
extra_details = CharField(required=False)
The `blank` keyword argument will continue to function, but will raise a `PendingDeprecationWarning`.
[xordoquy]: https://github.com/xordoquy
[django-python-3]: https://docs.djangoproject.com/en/dev/faq/install/#can-i-use-django-with-python-3
[porting-python-3]: https://docs.djangoproject.com/en/dev/topics/python3/
[django-deprecation-policy]: https://docs.djangoproject.com/en/dev/internals/release-process/#internal-release-deprecation-policy
[credits]: http://django-rest-framework.org/topics/credits.html
[mailing-list]: https://groups.google.com/forum/?fromgroups#!forum/django-rest-framework
[django-rest-framework-docs]: https://github.com/marcgibbons/django-rest-framework-docs
[marcgibbons]: https://github.com/marcgibbons/
[issues]: https://github.com/tomchristie/django-rest-framework/issues
[564]: https://github.com/tomchristie/django-rest-framework/issues/564

View File

@ -0,0 +1,41 @@
# Working with AJAX, CSRF & CORS
> "Take a close look at possible CSRF / XSRF vulnerabilities on your own websites. They're the worst kind of vulnerability &mdash; very easy to exploit by attackers, yet not so intuitively easy to understand for software developers, at least until you've been bitten by one."
>
> &mdash; [Jeff Atwood][cite]
## Javascript clients
If your building a javascript client to interface with your Web API, you'll need to consider if the client can use the same authentication policy that is used by the rest of the website, and also determine if you need to use CSRF tokens or CORS headers.
AJAX requests that are made within the same context as the API they are interacting with will typically use `SessionAuthentication`. This ensures that once a user has logged in, any AJAX requests made can be authenticated using the same session-based authentication that is used for the rest of the website.
AJAX requests that are made on a different site from the API they are communicating with will typically need to use a non-session-based authentication scheme, such as `TokenAuthentication`.
## CSRF protection
[Cross Site Request Forgery][csrf] protection is a mechanism of guarding against a particular type of attack, which can occur when a user has not logged out of a web site, and continues to have a valid session. In this circumstance a malicious site may be able to perform actions against the target site, within the context of the logged-in session.
To guard against these type of attacks, you need to do two things:
1. Ensure that the 'safe' HTTP operations, such as `GET`, `HEAD` and `OPTIONS` cannot be used to alter any server-side state.
2. Ensure that any 'unsafe' HTTP operations, such as `POST`, `PUT`, `PATCH` and `DELETE`, always require a valid CSRF token.
If you're using `SessionAuthentication` you'll need to include valid CSRF tokens for any `POST`, `PUT`, `PATCH` or `DELETE` operations.
The Django documentation describes how to [include CSRF tokens in AJAX requests][csrf-ajax].
## CORS
[Cross-Origin Resource Sharing][cors] is a mechanism for allowing clients to interact with APIs that are hosted on a different domain. CORS works by requiring the server to include a specific set of headers that allow a browser to determine if and when cross-domain requests should be allowed.
The best way to deal with CORS in REST framework is to add the required response headers in middleware. This ensures that CORS is supported transparently, without having to change any behavior in your views.
[Otto Yiu][ottoyiu] maintains the [django-cors-headers] package, which is known to work correctly with REST framework APIs.
[cite]: http://www.codinghorror.com/blog/2008/10/preventing-csrf-and-xsrf-attacks.html
[csrf]: https://www.owasp.org/index.php/Cross-Site_Request_Forgery_(CSRF)
[csrf-ajax]: https://docs.djangoproject.com/en/dev/ref/contrib/csrf/#ajax
[cors]: http://www.w3.org/TR/cors/
[ottoyiu]: https://github.com/ottoyiu/
[django-cors-headers]: https://github.com/ottoyiu/django-cors-headers/

View File

@ -35,23 +35,20 @@ A suitable replacement theme can be generated using Bootstrap's [Customize Tool]
You can also change the navbar variant, which by default is `navbar-inverse`, using the `bootstrap_navbar_variant` block. The empty `{% block bootstrap_navbar_variant %}{% endblock %}` will use the original Bootstrap navbar style. You can also change the navbar variant, which by default is `navbar-inverse`, using the `bootstrap_navbar_variant` block. The empty `{% block bootstrap_navbar_variant %}{% endblock %}` will use the original Bootstrap navbar style.
For more specific CSS tweaks, use the `extra_style` block instead. For more specific CSS tweaks, use the `style` block instead.
### Blocks ### Blocks
All of the blocks available in the browsable API base template that can be used in your `api.html`. All of the blocks available in the browsable API base template that can be used in your `api.html`.
* `blockbots` - `<meta>` tag that blocks crawlers
* `bodyclass` - (empty) class attribute for the `<body>` * `bodyclass` - (empty) class attribute for the `<body>`
* `bootstrap_theme` - CSS for the Bootstrap theme * `bootstrap_theme` - CSS for the Bootstrap theme
* `bootstrap_navbar_variant` - CSS class for the navbar * `bootstrap_navbar_variant` - CSS class for the navbar
* `branding` - section of the navbar, see [Bootstrap components][bcomponentsnav] * `branding` - section of the navbar, see [Bootstrap components][bcomponentsnav]
* `breadcrumbs` - Links showing resource nesting, allowing the user to go back up the resources. It's recommended to preserve these, but they can be overridden using the breadcrumbs block. * `breadcrumbs` - Links showing resource nesting, allowing the user to go back up the resources. It's recommended to preserve these, but they can be overridden using the breadcrumbs block.
* `extrastyle` - (empty) extra CSS for the page
* `extrahead` - (empty) extra markup for the page `<head>`
* `footer` - Any copyright notices or similar footer materials can go here (by default right-aligned) * `footer` - Any copyright notices or similar footer materials can go here (by default right-aligned)
* `global_heading` - (empty) Use to insert content below the header but before the breadcrumbs. * `style` - CSS stylesheets for the page
* `title` - title of the page * `title` - title of the page
* `userlinks` - This is a list of links on the right of the header, by default containing login/logout links. To add links instead of replace, use {{ block.super }} to preserve the authentication links. * `userlinks` - This is a list of links on the right of the header, by default containing login/logout links. To add links instead of replace, use {{ block.super }} to preserve the authentication links.

View File

@ -92,6 +92,16 @@ The following people have helped make REST framework great.
* Johannes Spielmann - [shezi] * Johannes Spielmann - [shezi]
* James Cleveland - [radiosilence] * James Cleveland - [radiosilence]
* Steve Gregory - [steve-gregory] * Steve Gregory - [steve-gregory]
* Federico Capoano - [nemesisdesign]
* Bruno Renié - [brutasse]
* Kevin Stone - [kevinastone]
* Guglielmo Celata - [guglielmo]
* Mike Tums - [mktums]
* Michael Elovskikh - [wronglink]
* Michał Jaworski - [swistakm]
* Andrea de Marco - [z4r]
* Fernando Rocha - [fernandogrd]
* Xavier Ordoquy - [xordoquy]
Many thanks to everyone who's contributed to the project. Many thanks to everyone who's contributed to the project.
@ -115,7 +125,6 @@ For usage questions please see the [REST framework discussion group][group].
You can also contact [@_tomchristie][twitter] directly on twitter. You can also contact [@_tomchristie][twitter] directly on twitter.
[email]: mailto:tom@tomchristie.com
[twitter]: http://twitter.com/_tomchristie [twitter]: http://twitter.com/_tomchristie
[bootstrap]: http://twitter.github.com/bootstrap/ [bootstrap]: http://twitter.github.com/bootstrap/
[markdown]: http://daringfireball.net/projects/markdown/ [markdown]: http://daringfireball.net/projects/markdown/
@ -219,3 +228,13 @@ You can also contact [@_tomchristie][twitter] directly on twitter.
[shezi]: https://github.com/shezi [shezi]: https://github.com/shezi
[radiosilence]: https://github.com/radiosilence [radiosilence]: https://github.com/radiosilence
[steve-gregory]: https://github.com/steve-gregory [steve-gregory]: https://github.com/steve-gregory
[nemesisdesign]: https://github.com/nemesisdesign
[brutasse]: https://github.com/brutasse
[kevinastone]: https://github.com/kevinastone
[guglielmo]: https://github.com/guglielmo
[mktums]: https://github.com/mktums
[wronglink]: https://github.com/wronglink
[swistakm]: https://github.com/swistakm
[z4r]: https://github.com/z4r
[fernandogrd]: https://github.com/fernandogrd
[xordoquy]: https://github.com/xordoquy

View File

@ -1,12 +0,0 @@
# Working with AJAX and CSRF
> "Take a close look at possible CSRF / XSRF vulnerabilities on your own websites. They're the worst kind of vulnerability -- very easy to exploit by attackers, yet not so intuitively easy to understand for software developers, at least until you've been bitten by one."
>
> &mdash; [Jeff Atwood][cite]
* Explain need to add CSRF token to AJAX requests.
* Explain deferred CSRF style used by REST framework
* Why you should use Django's standard login/logout views, and not REST framework view
[cite]: http://www.codinghorror.com/blog/2008/10/preventing-csrf-and-xsrf-attacks.html

View File

@ -12,14 +12,42 @@ Medium version numbers (0.x.0) may include minor API changes. You should read t
Major version numbers (x.0.0) are reserved for project milestones. No major point releases are currently planned. Major version numbers (x.0.0) are reserved for project milestones. No major point releases are currently planned.
## Upgrading
To upgrade Django REST framework to the latest version, use pip:
pip install -U djangorestframework
You can determine your currently installed version using `pip freeze`:
pip freeze | grep djangorestframework
--- ---
## 2.1.x series ## 2.1.x series
### Master ### Master
* Added a `post_save()` hook to the generic views.
* Allow serializers to handle dicts as well as objects.
* Bugfix: Fix styling on browsable API login.
* Bugfix: Fix issue with deserializing empty to-many relations.
* Bugfix: Ensure model field validation is still applied for ModelSerializer subclasses with an custom `.restore_object()` method.
### 2.1.17
**Date**: 26th Jan 2013
* Support proper 401 Unauthorized responses where appropriate, instead of always using 403 Forbidden.
* Support json encoding of timedelta objects. * Support json encoding of timedelta objects.
* `format_suffix_patterns()` now supports `include` style URL patterns.
* Bugfix: Fix issues with custom pagination serializers.
* Bugfix: Nested serializers now accept `source='*'` argument.
* Bugfix: Return proper validation errors when incorrect types supplied for relational fields.
* Bugfix: Support nullable FKs with `SlugRelatedField`. * Bugfix: Support nullable FKs with `SlugRelatedField`.
* Bugfix: Don't call custom validation methods if the field has an error.
**Note**: If the primary authentication class is `TokenAuthentication` or `BasicAuthentication`, a view will now correctly return 401 responses to unauthenticated access, with an appropriate `WWW-Authenticate` header, instead of 403 responses.
### 2.1.16 ### 2.1.16

View File

@ -4,7 +4,7 @@
This tutorial will cover creating a simple pastebin code highlighting Web API. Along the way it will introduce the various components that make up REST framework, and give you a comprehensive understanding of how everything fits together. This tutorial will cover creating a simple pastebin code highlighting Web API. Along the way it will introduce the various components that make up REST framework, and give you a comprehensive understanding of how everything fits together.
The tutorial is fairly in-depth, so you should probably get a cookie and a cup of your favorite brew before getting started.<!-- If you just want a quick overview, you should head over to the [quickstart] documentation instead. --> The tutorial is fairly in-depth, so you should probably get a cookie and a cup of your favorite brew before getting started. If you just want a quick overview, you should head over to the [quickstart] documentation instead.
--- ---
@ -109,7 +109,7 @@ The first thing we need to get started on our Web API is provide a way of serial
from django.forms import widgets from django.forms import widgets
from rest_framework import serializers from rest_framework import serializers
from snippets import models from snippets.models import Snippet
class SnippetSerializer(serializers.Serializer): class SnippetSerializer(serializers.Serializer):
@ -130,15 +130,15 @@ The first thing we need to get started on our Web API is provide a way of serial
""" """
if instance: if instance:
# Update existing instance # Update existing instance
instance.title = attrs['title'] instance.title = attrs.get('title', instance.title)
instance.code = attrs['code'] instance.code = attrs.get('code', instance.code)
instance.linenos = attrs['linenos'] instance.linenos = attrs.get('linenos', instance.linenos)
instance.language = attrs['language'] instance.language = attrs.get('language', instance.language)
instance.style = attrs['style'] instance.style = attrs.get('style', instance.style)
return instance return instance
# Create new instance # Create new instance
return models.Snippet(**attrs) return Snippet(**attrs)
The first part of serializer class defines the fields that get serialized/deserialized. The `restore_object` method defines how fully fledged instances get created when deserializing data. The first part of serializer class defines the fields that get serialized/deserialized. The `restore_object` method defines how fully fledged instances get created when deserializing data.

View File

@ -22,7 +22,7 @@ We'd also need to make sure that when the model is saved, that we populate the h
We'll need some extra imports: We'll need some extra imports:
from pygments.lexers import get_lexer_by_name from pygments.lexers import get_lexer_by_name
from pygments.formatters import HtmlFormatter from pygments.formatters.html import HtmlFormatter
from pygments import highlight from pygments import highlight
And now we can add a `.save()` method to our model class: And now we can add a `.save()` method to our model class:
@ -54,6 +54,8 @@ You might also want to create a few different users, to use for testing the API.
Now that we've got some users to work with, we'd better add representations of those users to our API. Creating a new serializer is easy: Now that we've got some users to work with, we'd better add representations of those users to our API. Creating a new serializer is easy:
from django.contrib.auth.models import User
class UserSerializer(serializers.ModelSerializer): class UserSerializer(serializers.ModelSerializer):
snippets = serializers.ManyPrimaryKeyRelatedField() snippets = serializers.ManyPrimaryKeyRelatedField()
@ -164,7 +166,8 @@ In the snippets app, create a new file, `permissions.py`
if obj is None: if obj is None:
return True return True
# Read permissions are allowed to any request # Read permissions are allowed to any request,
# so we'll always allow GET, HEAD or OPTIONS requests.
if request.method in permissions.SAFE_METHODS: if request.method in permissions.SAFE_METHODS:
return True return True

View File

@ -165,7 +165,7 @@ We've reached the end of our tutorial. If you want to get more involved in the
* Contribute on [GitHub][github] by reviewing and submitting issues, and making pull requests. * Contribute on [GitHub][github] by reviewing and submitting issues, and making pull requests.
* Join the [REST framework discussion group][group], and help build the community. * Join the [REST framework discussion group][group], and help build the community.
* [Follow the author on Twitter][twitter] and say hi. * Follow [the author][twitter] on Twitter and say hi.
**Now go build awesome things.** **Now go build awesome things.**

View File

@ -1,3 +1,6 @@
__version__ = '2.1.16' __version__ = '2.1.17'
VERSION = __version__ # synonym VERSION = __version__ # synonym
# Header encoding (see RFC5987)
HTTP_HEADER_ENCODING = 'iso-8859-1'

View File

@ -1,10 +1,10 @@
""" """
Provides a set of pluggable authentication policies. Provides a set of pluggable authentication policies.
""" """
from __future__ import unicode_literals
from django.contrib.auth import authenticate from django.contrib.auth import authenticate
from django.utils.encoding import smart_unicode, DjangoUnicodeDecodeError from django.utils.encoding import DjangoUnicodeDecodeError
from rest_framework import exceptions from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import CsrfViewMiddleware
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
import base64 import base64
@ -21,32 +21,49 @@ class BaseAuthentication(object):
""" """
raise NotImplementedError(".authenticate() must be overridden.") raise NotImplementedError(".authenticate() must be overridden.")
def authenticate_header(self, request):
"""
Return a string to be used as the value of the `WWW-Authenticate`
header in a `401 Unauthenticated` response, or `None` if the
authentication scheme should return `403 Permission Denied` responses.
"""
pass
class BasicAuthentication(BaseAuthentication): class BasicAuthentication(BaseAuthentication):
""" """
HTTP Basic authentication against username/password. HTTP Basic authentication against username/password.
""" """
www_authenticate_realm = 'api'
def authenticate(self, request): def authenticate(self, request):
""" """
Returns a `User` if a correct username and password have been supplied Returns a `User` if a correct username and password have been supplied
using HTTP Basic authentication. Otherwise returns `None`. using HTTP Basic authentication. Otherwise returns `None`.
""" """
if 'HTTP_AUTHORIZATION' in request.META: auth = request.META.get('HTTP_AUTHORIZATION', b'')
auth = request.META['HTTP_AUTHORIZATION'].split() if type(auth) == type(''):
if len(auth) == 2 and auth[0].lower() == "basic": # Work around django test client oddness
try: auth = auth.encode(HTTP_HEADER_ENCODING)
auth_parts = base64.b64decode(auth[1]).partition(':') auth = auth.split()
except TypeError:
return None
try: if not auth or auth[0].lower() != b'basic':
userid = smart_unicode(auth_parts[0]) return None
password = smart_unicode(auth_parts[2])
except DjangoUnicodeDecodeError:
return None
return self.authenticate_credentials(userid, password) if len(auth) != 2:
raise exceptions.AuthenticationFailed('Invalid basic header')
try:
auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(':')
except (TypeError, UnicodeDecodeError):
raise exceptions.AuthenticationFailed('Invalid basic header')
try:
userid, password = auth_parts[0], auth_parts[2]
except DjangoUnicodeDecodeError:
raise exceptions.AuthenticationFailed('Invalid basic header')
return self.authenticate_credentials(userid, password)
def authenticate_credentials(self, userid, password): def authenticate_credentials(self, userid, password):
""" """
@ -55,6 +72,10 @@ class BasicAuthentication(BaseAuthentication):
user = authenticate(username=userid, password=password) user = authenticate(username=userid, password=password)
if user is not None and user.is_active: if user is not None and user.is_active:
return (user, None) return (user, None)
raise exceptions.AuthenticationFailed('Invalid username/password')
def authenticate_header(self, request):
return 'Basic realm="%s"' % self.www_authenticate_realm
class SessionAuthentication(BaseAuthentication): class SessionAuthentication(BaseAuthentication):
@ -74,7 +95,7 @@ class SessionAuthentication(BaseAuthentication):
# Unauthenticated, CSRF validation not required # Unauthenticated, CSRF validation not required
if not user or not user.is_active: if not user or not user.is_active:
return return None
# Enforce CSRF validation for session based authentication. # Enforce CSRF validation for session based authentication.
class CSRFCheck(CsrfViewMiddleware): class CSRFCheck(CsrfViewMiddleware):
@ -85,7 +106,7 @@ class SessionAuthentication(BaseAuthentication):
reason = CSRFCheck().process_view(http_request, None, (), {}) reason = CSRFCheck().process_view(http_request, None, (), {})
if reason: if reason:
# CSRF failed, bail with explicit error message # CSRF failed, bail with explicit error message
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason)
# CSRF passed with authenticated user # CSRF passed with authenticated user
return (user, None) return (user, None)
@ -112,14 +133,26 @@ class TokenAuthentication(BaseAuthentication):
def authenticate(self, request): def authenticate(self, request):
auth = request.META.get('HTTP_AUTHORIZATION', '').split() auth = request.META.get('HTTP_AUTHORIZATION', '').split()
if len(auth) == 2 and auth[0].lower() == "token": if not auth or auth[0].lower() != "token":
key = auth[1] return None
try:
token = self.model.objects.get(key=key) if len(auth) != 2:
except self.model.DoesNotExist: raise exceptions.AuthenticationFailed('Invalid token header')
return None
return self.authenticate_credentials(auth[1])
def authenticate_credentials(self, key):
try:
token = self.model.objects.get(key=key)
except self.model.DoesNotExist:
raise exceptions.AuthenticationFailed('Invalid token')
if token.user.is_active:
return (token.user, token)
raise exceptions.AuthenticationFailed('User inactive or deleted')
def authenticate_header(self, request):
return 'Token'
if token.user.is_active:
return (token.user, token)
# TODO: OAuthAuthentication # TODO: OAuthAuthentication

View File

@ -19,8 +19,8 @@ class Token(models.Model):
return super(Token, self).save(*args, **kwargs) return super(Token, self).save(*args, **kwargs)
def generate_key(self): def generate_key(self):
unique = str(uuid.uuid4()) unique = uuid.uuid4()
return hmac.new(unique, digestmod=sha1).hexdigest() return hmac.new(unique.bytes, digestmod=sha1).hexdigest()
def __unicode__(self): def __unicode__(self):
return self.key return self.key

View File

@ -3,26 +3,56 @@ The `compat` module provides support for backwards compatibility with older
versions of django/python, and compatibility wrappers around optional packages. versions of django/python, and compatibility wrappers around optional packages.
""" """
# flake8: noqa # flake8: noqa
from __future__ import unicode_literals
import django import django
# Try to import six from Django, fallback to included `six`.
try:
from django.utils import six
except ImportError:
from rest_framework import six
# location of patterns, url, include changes in 1.4 onwards # location of patterns, url, include changes in 1.4 onwards
try: try:
from django.conf.urls import patterns, url, include from django.conf.urls import patterns, url, include
except: except ImportError:
from django.conf.urls.defaults import patterns, url, include from django.conf.urls.defaults import patterns, url, include
# Handle django.utils.encoding rename:
# smart_unicode -> smart_text
# force_unicode -> force_text
try:
from django.utils.encoding import smart_text
except ImportError:
from django.utils.encoding import smart_unicode as smart_text
try:
from django.utils.encoding import force_text
except ImportError:
from django.utils.encoding import force_unicode as force_text
# django-filter is optional # django-filter is optional
try: try:
import django_filters import django_filters
except: except ImportError:
django_filters = None django_filters = None
# cStringIO only if it's available, otherwise StringIO # cStringIO only if it's available, otherwise StringIO
try: try:
import cStringIO as StringIO import cStringIO.StringIO as StringIO
except ImportError: except ImportError:
import StringIO StringIO = six.StringIO
BytesIO = six.BytesIO
# urlparse compat import (Required because it changed in python 3.x)
try:
from urllib import parse as urlparse
except ImportError:
import urlparse
# Try to import PIL in either of the two ways it can end up installed. # Try to import PIL in either of the two ways it can end up installed.
@ -54,7 +84,7 @@ else:
try: try:
from django.contrib.auth.models import User from django.contrib.auth.models import User
except ImportError: except ImportError:
raise ImportError(u"User model is not to be found.") raise ImportError("User model is not to be found.")
# First implementation of Django class-based views did not include head method # First implementation of Django class-based views did not include head method
@ -75,11 +105,11 @@ else:
# sanitize keyword arguments # sanitize keyword arguments
for key in initkwargs: for key in initkwargs:
if key in cls.http_method_names: if key in cls.http_method_names:
raise TypeError(u"You tried to pass in the %s method name as a " raise TypeError("You tried to pass in the %s method name as a "
u"keyword argument to %s(). Don't do that." "keyword argument to %s(). Don't do that."
% (key, cls.__name__)) % (key, cls.__name__))
if not hasattr(cls, key): if not hasattr(cls, key):
raise TypeError(u"%s() received an invalid keyword %r" % ( raise TypeError("%s() received an invalid keyword %r" % (
cls.__name__, key)) cls.__name__, key))
def view(request, *args, **kwargs): def view(request, *args, **kwargs):
@ -110,7 +140,6 @@ else:
import re import re
import random import random
import logging import logging
import urlparse
from django.conf import settings from django.conf import settings
from django.core.urlresolvers import get_callable from django.core.urlresolvers import get_callable
@ -152,7 +181,8 @@ else:
randrange = random.SystemRandom().randrange randrange = random.SystemRandom().randrange
else: else:
randrange = random.randrange randrange = random.randrange
_MAX_CSRF_KEY = 18446744073709551616L # 2 << 63
_MAX_CSRF_KEY = 18446744073709551616 # 2 << 63
REASON_NO_REFERER = "Referer checking failed - no Referer." REASON_NO_REFERER = "Referer checking failed - no Referer."
REASON_BAD_REFERER = "Referer checking failed - %s does not match %s." REASON_BAD_REFERER = "Referer checking failed - %s does not match %s."
@ -396,3 +426,12 @@ try:
from xml.etree import ParseError as ETParseError from xml.etree import ParseError as ETParseError
except ImportError: # python < 2.7 except ImportError: # python < 2.7
ETParseError = None ETParseError = None
# XMLParser only takes an encoding arg from >= 2.7
def ET_XMLParser(encoding=None):
from xml.etree import ElementTree as ET
try:
return ET.XMLParser(encoding=encoding)
except TypeError:
return ET.XMLParser()

View File

@ -1,4 +1,7 @@
from __future__ import unicode_literals
from rest_framework.compat import six
from rest_framework.views import APIView from rest_framework.views import APIView
import types
def api_view(http_method_names): def api_view(http_method_names):
@ -11,7 +14,7 @@ def api_view(http_method_names):
def decorator(func): def decorator(func):
WrappedAPIView = type( WrappedAPIView = type(
'WrappedAPIView', six.PY3 and 'WrappedAPIView' or b'WrappedAPIView',
(APIView,), (APIView,),
{'__doc__': func.__doc__} {'__doc__': func.__doc__}
) )
@ -23,6 +26,14 @@ def api_view(http_method_names):
# pass # pass
# WrappedAPIView.__doc__ = func.doc <--- Not possible to do this # WrappedAPIView.__doc__ = func.doc <--- Not possible to do this
# api_view applied without (method_names)
assert not(isinstance(http_method_names, types.FunctionType)), \
'@api_view missing list of allowed HTTP methods'
# api_view applied with eg. string instead of list of strings
assert isinstance(http_method_names, (list, tuple)), \
'@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__
allowed_methods = set(http_method_names) | set(('options',)) allowed_methods = set(http_method_names) | set(('options',))
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]

View File

@ -4,6 +4,7 @@ Handled exceptions raised by REST framework.
In addition Django's built in 403 and 404 exceptions are handled. In addition Django's built in 403 and 404 exceptions are handled.
(`django.http.Http404` and `django.core.exceptions.PermissionDenied`) (`django.http.Http404` and `django.core.exceptions.PermissionDenied`)
""" """
from __future__ import unicode_literals
from rest_framework import status from rest_framework import status
@ -23,6 +24,22 @@ class ParseError(APIException):
self.detail = detail or self.default_detail self.detail = detail or self.default_detail
class AuthenticationFailed(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = 'Incorrect authentication credentials.'
def __init__(self, detail=None):
self.detail = detail or self.default_detail
class NotAuthenticated(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = 'Authentication credentials were not provided.'
def __init__(self, detail=None):
self.detail = detail or self.default_detail
class PermissionDenied(APIException): class PermissionDenied(APIException):
status_code = status.HTTP_403_FORBIDDEN status_code = status.HTTP_403_FORBIDDEN
default_detail = 'You do not have permission to perform this action.' default_detail = 'You do not have permission to perform this action.'

View File

@ -1,20 +1,23 @@
from __future__ import unicode_literals
import copy import copy
import datetime import datetime
import inspect import inspect
import re import re
import warnings import warnings
from io import BytesIO
from django.core import validators from django.core import validators
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.conf import settings from django.conf import settings
from django import forms from django import forms
from django.forms import widgets from django.forms import widgets
from django.utils.encoding import is_protected_type, smart_unicode from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import parse_date, parse_datetime from rest_framework.compat import parse_date, parse_datetime
from rest_framework.compat import timezone from rest_framework.compat import timezone
from rest_framework.compat import BytesIO
from rest_framework.compat import six
from rest_framework.compat import smart_text
def is_simple_callable(obj): def is_simple_callable(obj):
@ -27,12 +30,28 @@ def is_simple_callable(obj):
) )
def get_component(obj, attr_name):
"""
Given an object, and an attribute name,
return that attribute on the object.
"""
if isinstance(obj, dict):
val = obj[attr_name]
else:
val = getattr(obj, attr_name)
if is_simple_callable(val):
return val()
return val
class Field(object): class Field(object):
read_only = True read_only = True
creation_counter = 0 creation_counter = 0
empty = '' empty = ''
type_name = None type_name = None
_use_files = None partial = False
use_files = False
form_field_class = forms.CharField form_field_class = forms.CharField
def __init__(self, source=None): def __init__(self, source=None):
@ -53,7 +72,8 @@ class Field(object):
self.parent = parent self.parent = parent
self.root = parent.root or parent self.root = parent.root or parent
self.context = self.root.context self.context = self.root.context
if self.root.partial: self.partial = self.root.partial
if self.partial:
self.required = False self.required = False
def field_from_native(self, data, files, field_name, into): def field_from_native(self, data, files, field_name, into):
@ -77,11 +97,9 @@ class Field(object):
if self.source: if self.source:
value = obj value = obj
for component in self.source.split('.'): for component in self.source.split('.'):
value = getattr(value, component) value = get_component(value, component)
if is_simple_callable(value):
value = value()
else: else:
value = getattr(obj, field_name) value = get_component(obj, field_name)
return self.to_native(value) return self.to_native(value)
def to_native(self, value): def to_native(self, value):
@ -93,11 +111,11 @@ class Field(object):
if is_protected_type(value): if is_protected_type(value):
return value return value
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)): elif hasattr(value, '__iter__') and not isinstance(value, (dict, six.string_types)):
return [self.to_native(item) for item in value] return [self.to_native(item) for item in value]
elif isinstance(value, dict): elif isinstance(value, dict):
return dict(map(self.to_native, (k, v)) for k, v in value.items()) return dict(map(self.to_native, (k, v)) for k, v in value.items())
return smart_unicode(value) return smart_text(value)
def attributes(self): def attributes(self):
""" """
@ -124,6 +142,13 @@ class WritableField(Field):
validators=[], error_messages=None, widget=None, validators=[], error_messages=None, widget=None,
default=None, blank=None): default=None, blank=None):
# 'blank' is to be deprecated in favor of 'required'
if blank is not None:
warnings.warn('The `blank` keyword argument is due to deprecated. '
'Use the `required` keyword argument instead.',
PendingDeprecationWarning, stacklevel=2)
required = not(blank)
super(WritableField, self).__init__(source=source) super(WritableField, self).__init__(source=source)
self.read_only = read_only self.read_only = read_only
@ -141,7 +166,6 @@ class WritableField(Field):
self.validators = self.default_validators + validators self.validators = self.default_validators + validators
self.default = default if default is not None else self.default self.default = default if default is not None else self.default
self.blank = blank
# Widgets are ony used for HTML forms. # Widgets are ony used for HTML forms.
widget = widget or self.widget widget = widget or self.widget
@ -180,13 +204,13 @@ class WritableField(Field):
return return
try: try:
if self._use_files: if self.use_files:
files = files or {} files = files or {}
native = files[field_name] native = files[field_name]
else: else:
native = data[field_name] native = data[field_name]
except KeyError: except KeyError:
if self.default is not None and not self.root.partial: if self.default is not None and not self.partial:
# Note: partial updates shouldn't set defaults # Note: partial updates shouldn't set defaults
native = self.default native = self.default
else: else:
@ -217,7 +241,7 @@ class ModelField(WritableField):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
try: try:
self.model_field = kwargs.pop('model_field') self.model_field = kwargs.pop('model_field')
except: except KeyError:
raise ValueError("ModelField requires 'model_field' kwarg") raise ValueError("ModelField requires 'model_field' kwarg")
self.min_length = kwargs.pop('min_length', self.min_length = kwargs.pop('min_length',
@ -258,7 +282,7 @@ class BooleanField(WritableField):
form_field_class = forms.BooleanField form_field_class = forms.BooleanField
widget = widgets.CheckboxInput widget = widgets.CheckboxInput
default_error_messages = { default_error_messages = {
'invalid': _(u"'%s' value must be either True or False."), 'invalid': _("'%s' value must be either True or False."),
} }
empty = False empty = False
@ -287,20 +311,10 @@ class CharField(WritableField):
if max_length is not None: if max_length is not None:
self.validators.append(validators.MaxLengthValidator(max_length)) self.validators.append(validators.MaxLengthValidator(max_length))
def validate(self, value):
"""
Validates that the value is supplied (if required).
"""
# if empty string and allow blank
if self.blank and not value:
return
else:
super(CharField, self).validate(value)
def from_native(self, value): def from_native(self, value):
if isinstance(value, basestring) or value is None: if isinstance(value, six.string_types) or value is None:
return value return value
return smart_unicode(value) return smart_text(value)
class URLField(CharField): class URLField(CharField):
@ -325,7 +339,8 @@ class ChoiceField(WritableField):
form_field_class = forms.ChoiceField form_field_class = forms.ChoiceField
widget = widgets.Select widget = widgets.Select
default_error_messages = { default_error_messages = {
'invalid_choice': _('Select a valid choice. %(value)s is not one of the available choices.'), 'invalid_choice': _('Select a valid choice. %(value)s is not one of '
'the available choices.'),
} }
def __init__(self, choices=(), *args, **kwargs): def __init__(self, choices=(), *args, **kwargs):
@ -359,10 +374,10 @@ class ChoiceField(WritableField):
if isinstance(v, (list, tuple)): if isinstance(v, (list, tuple)):
# This is an optgroup, so look inside the group for options # This is an optgroup, so look inside the group for options
for k2, v2 in v: for k2, v2 in v:
if value == smart_unicode(k2): if value == smart_text(k2):
return True return True
else: else:
if value == smart_unicode(k) or value == k: if value == smart_text(k) or value == k:
return True return True
return False return False
@ -402,7 +417,7 @@ class RegexField(CharField):
return self._regex return self._regex
def _set_regex(self, regex): def _set_regex(self, regex):
if isinstance(regex, basestring): if isinstance(regex, six.string_types):
regex = re.compile(regex) regex = re.compile(regex)
self._regex = regex self._regex = regex
if hasattr(self, '_regex_validator') and self._regex_validator in self.validators: if hasattr(self, '_regex_validator') and self._regex_validator in self.validators:
@ -425,10 +440,10 @@ class DateField(WritableField):
form_field_class = forms.DateField form_field_class = forms.DateField
default_error_messages = { default_error_messages = {
'invalid': _(u"'%s' value has an invalid date format. It must be " 'invalid': _("'%s' value has an invalid date format. It must be "
u"in YYYY-MM-DD format."), "in YYYY-MM-DD format."),
'invalid_date': _(u"'%s' value has the correct format (YYYY-MM-DD) " 'invalid_date': _("'%s' value has the correct format (YYYY-MM-DD) "
u"but it is an invalid date."), "but it is an invalid date."),
} }
empty = None empty = None
@ -464,13 +479,13 @@ class DateTimeField(WritableField):
form_field_class = forms.DateTimeField form_field_class = forms.DateTimeField
default_error_messages = { default_error_messages = {
'invalid': _(u"'%s' value has an invalid format. It must be in " 'invalid': _("'%s' value has an invalid format. It must be in "
u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."), "YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."),
'invalid_date': _(u"'%s' value has the correct format " 'invalid_date': _("'%s' value has the correct format "
u"(YYYY-MM-DD) but it is an invalid date."), "(YYYY-MM-DD) but it is an invalid date."),
'invalid_datetime': _(u"'%s' value has the correct format " 'invalid_datetime': _("'%s' value has the correct format "
u"(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) " "(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) "
u"but it is an invalid date/time."), "but it is an invalid date/time."),
} }
empty = None empty = None
@ -487,8 +502,8 @@ class DateTimeField(WritableField):
# local time. This won't work during DST change, but we can't # local time. This won't work during DST change, but we can't
# do much about it, so we let the exceptions percolate up the # do much about it, so we let the exceptions percolate up the
# call stack. # call stack.
warnings.warn(u"DateTimeField received a naive datetime (%s)" warnings.warn("DateTimeField received a naive datetime (%s)"
u" while time zone support is active." % value, " while time zone support is active." % value,
RuntimeWarning) RuntimeWarning)
default_timezone = timezone.get_default_timezone() default_timezone = timezone.get_default_timezone()
value = timezone.make_aware(value, default_timezone) value = timezone.make_aware(value, default_timezone)
@ -564,7 +579,7 @@ class FloatField(WritableField):
class FileField(WritableField): class FileField(WritableField):
_use_files = True use_files = True
type_name = 'FileField' type_name = 'FileField'
form_field_class = forms.FileField form_field_class = forms.FileField
widget = widgets.FileInput widget = widgets.FileInput
@ -608,11 +623,12 @@ class FileField(WritableField):
class ImageField(FileField): class ImageField(FileField):
_use_files = True use_files = True
form_field_class = forms.ImageField form_field_class = forms.ImageField
default_error_messages = { default_error_messages = {
'invalid_image': _("Upload a valid image. The file you uploaded was either not an image or a corrupted image."), 'invalid_image': _("Upload a valid image. The file you uploaded was "
"either not an image or a corrupted image."),
} }
def from_native(self, data): def from_native(self, data):

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
from rest_framework.compat import django_filters from rest_framework.compat import django_filters
FilterSet = django_filters and django_filters.FilterSet or None FilterSet = django_filters and django_filters.FilterSet or None
@ -54,6 +55,6 @@ class DjangoFilterBackend(BaseFilterBackend):
filter_class = self.get_filter_class(view) filter_class = self.get_filter_class(view)
if filter_class: if filter_class:
return filter_class(request.GET, queryset=queryset) return filter_class(request.QUERY_PARAMS, queryset=queryset)
return queryset return queryset

View File

@ -1,7 +1,7 @@
""" """
Generic views that provide commonly needed behaviour. Generic views that provide commonly needed behaviour.
""" """
from __future__ import unicode_literals
from rest_framework import views, mixins from rest_framework import views, mixins
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from django.views.generic.detail import SingleObjectMixin from django.views.generic.detail import SingleObjectMixin
@ -48,7 +48,7 @@ class GenericAPIView(views.APIView):
return serializer_class return serializer_class
def get_serializer(self, instance=None, data=None, def get_serializer(self, instance=None, data=None,
files=None, partial=False): files=None, many=False, partial=False):
""" """
Return the serializer instance that should be used for validating and Return the serializer instance that should be used for validating and
deserializing input, and for serializing output. deserializing input, and for serializing output.
@ -56,7 +56,21 @@ class GenericAPIView(views.APIView):
serializer_class = self.get_serializer_class() serializer_class = self.get_serializer_class()
context = self.get_serializer_context() context = self.get_serializer_context()
return serializer_class(instance, data=data, files=files, return serializer_class(instance, data=data, files=files,
partial=partial, context=context) many=many, partial=partial, context=context)
def pre_save(self, obj):
"""
Placeholder method for calling before saving an object.
May be used eg. to set attributes on the object that are implicit
in either the request, or the url.
"""
pass
def post_save(self, obj, created=False):
"""
Placeholder method for calling after saving an object.
"""
pass
class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):

View File

@ -4,6 +4,8 @@ Basic building blocks for generic class based views.
We don't bind behaviour to http method handlers yet, We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways. which allows mixin classes to be composed in interesting ways.
""" """
from __future__ import unicode_literals
from django.http import Http404 from django.http import Http404
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
@ -20,6 +22,7 @@ class CreateModelMixin(object):
if serializer.is_valid(): if serializer.is_valid():
self.pre_save(serializer.object) self.pre_save(serializer.object)
self.object = serializer.save() self.object = serializer.save()
self.post_save(self.object, created=True)
headers = self.get_success_headers(serializer.data) headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, return Response(serializer.data, status=status.HTTP_201_CREATED,
headers=headers) headers=headers)
@ -32,16 +35,13 @@ class CreateModelMixin(object):
except (TypeError, KeyError): except (TypeError, KeyError):
return {} return {}
def pre_save(self, obj):
pass
class ListModelMixin(object): class ListModelMixin(object):
""" """
List a queryset. List a queryset.
Should be mixed in with `MultipleObjectAPIView`. Should be mixed in with `MultipleObjectAPIView`.
""" """
empty_error = u"Empty list and '%(class_name)s.allow_empty' is False." empty_error = "Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
queryset = self.get_queryset() queryset = self.get_queryset()
@ -63,7 +63,7 @@ class ListModelMixin(object):
paginator, page, queryset, is_paginated = packed paginator, page, queryset, is_paginated = packed
serializer = self.get_pagination_serializer(page) serializer = self.get_pagination_serializer(page)
else: else:
serializer = self.get_serializer(self.object_list) serializer = self.get_serializer(self.object_list, many=True)
return Response(serializer.data) return Response(serializer.data)
@ -86,12 +86,15 @@ class UpdateModelMixin(object):
""" """
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False) partial = kwargs.pop('partial', False)
self.object = None
try: try:
self.object = self.get_object() self.object = self.get_object()
success_status_code = status.HTTP_200_OK
except Http404: except Http404:
self.object = None created = True
success_status_code = status.HTTP_201_CREATED success_status_code = status.HTTP_201_CREATED
else:
created = False
success_status_code = status.HTTP_200_OK
serializer = self.get_serializer(self.object, data=request.DATA, serializer = self.get_serializer(self.object, data=request.DATA,
files=request.FILES, partial=partial) files=request.FILES, partial=partial)
@ -99,6 +102,7 @@ class UpdateModelMixin(object):
if serializer.is_valid(): if serializer.is_valid():
self.pre_save(serializer.object) self.pre_save(serializer.object)
self.object = serializer.save() self.object = serializer.save()
self.post_save(self.object, created=created)
return Response(serializer.data, status=success_status_code) return Response(serializer.data, status=success_status_code)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
from django.http import Http404 from django.http import Http404
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -33,7 +34,7 @@ class DefaultContentNegotiation(BaseContentNegotiation):
""" """
# Allow URL style format override. eg. "?format=json # Allow URL style format override. eg. "?format=json
format_query_param = self.settings.URL_FORMAT_OVERRIDE format_query_param = self.settings.URL_FORMAT_OVERRIDE
format = format_suffix or request.GET.get(format_query_param) format = format_suffix or request.QUERY_PARAMS.get(format_query_param)
if format: if format:
renderers = self.filter_renderers(renderers, format) renderers = self.filter_renderers(renderers, format)
@ -80,5 +81,5 @@ class DefaultContentNegotiation(BaseContentNegotiation):
Allows URL style accept override. eg. "?accept=application/json" Allows URL style accept override. eg. "?accept=application/json"
""" """
header = request.META.get('HTTP_ACCEPT', '*/*') header = request.META.get('HTTP_ACCEPT', '*/*')
header = request.GET.get(self.settings.URL_ACCEPT_OVERRIDE, header) header = request.QUERY_PARAMS.get(self.settings.URL_ACCEPT_OVERRIDE, header)
return [token.strip() for token in header.split(',')] return [token.strip() for token in header.split(',')]

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
from rest_framework import serializers from rest_framework import serializers
from rest_framework.templatetags.rest_framework import replace_query_param from rest_framework.templatetags.rest_framework import replace_query_param
@ -34,6 +35,17 @@ class PreviousPageField(serializers.Field):
return replace_query_param(url, self.page_field, page) return replace_query_param(url, self.page_field, page)
class DefaultObjectSerializer(serializers.Field):
"""
If no object serializer is specified, then this serializer will be applied
as the default.
"""
def __init__(self, source=None, context=None):
# Note: Swallow context kwarg - only required for eg. ModelSerializer.
super(DefaultObjectSerializer, self).__init__(source=source)
class PaginationSerializerOptions(serializers.SerializerOptions): class PaginationSerializerOptions(serializers.SerializerOptions):
""" """
An object that stores the options that may be provided to a An object that stores the options that may be provided to a
@ -44,7 +56,7 @@ class PaginationSerializerOptions(serializers.SerializerOptions):
def __init__(self, meta): def __init__(self, meta):
super(PaginationSerializerOptions, self).__init__(meta) super(PaginationSerializerOptions, self).__init__(meta)
self.object_serializer_class = getattr(meta, 'object_serializer_class', self.object_serializer_class = getattr(meta, 'object_serializer_class',
serializers.Field) DefaultObjectSerializer)
class BasePaginationSerializer(serializers.Serializer): class BasePaginationSerializer(serializers.Serializer):
@ -62,14 +74,13 @@ class BasePaginationSerializer(serializers.Serializer):
super(BasePaginationSerializer, self).__init__(*args, **kwargs) super(BasePaginationSerializer, self).__init__(*args, **kwargs)
results_field = self.results_field results_field = self.results_field
object_serializer = self.opts.object_serializer_class object_serializer = self.opts.object_serializer_class
self.fields[results_field] = object_serializer(source='object_list')
def to_native(self, obj): if 'context' in kwargs:
""" context_kwarg = {'context': kwargs['context']}
Prevent default behaviour of iterating over elements, and serializing else:
each in turn. context_kwarg = {}
"""
return self.convert_object(obj) self.fields[results_field] = object_serializer(source='object_list', **context_kwarg)
class PaginationSerializer(BasePaginationSerializer): class PaginationSerializer(BasePaginationSerializer):

View File

@ -4,12 +4,14 @@ Parsers are used to parse the content of incoming HTTP requests.
They give us a generic way of being able to handle various media types They give us a generic way of being able to handle various media types
on the request, such as form content or json encoded data. on the request, such as form content or json encoded data.
""" """
from __future__ import unicode_literals
from django.conf import settings
from django.http import QueryDict from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError from django.http.multipartparser import MultiPartParserError
from rest_framework.compat import yaml, ETParseError from rest_framework.compat import yaml, ETParseError, ET_XMLParser
from rest_framework.exceptions import ParseError from rest_framework.exceptions import ParseError
from rest_framework.compat import six
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
from xml.parsers.expat import ExpatError from xml.parsers.expat import ExpatError
import json import json
@ -54,10 +56,14 @@ class JSONParser(BaseParser):
`data` will be an object which is the parsed content of the response. `data` will be an object which is the parsed content of the response.
`files` will always be `None`. `files` will always be `None`.
""" """
parser_context = parser_context or {}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
try: try:
return json.load(stream) data = stream.read().decode(encoding)
except ValueError, exc: return json.loads(data)
raise ParseError('JSON parse error - %s' % unicode(exc)) except ValueError as exc:
raise ParseError('JSON parse error - %s' % six.text_type(exc))
class YAMLParser(BaseParser): class YAMLParser(BaseParser):
@ -74,10 +80,14 @@ class YAMLParser(BaseParser):
`data` will be an object which is the parsed content of the response. `data` will be an object which is the parsed content of the response.
`files` will always be `None`. `files` will always be `None`.
""" """
parser_context = parser_context or {}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
try: try:
return yaml.safe_load(stream) data = stream.read().decode(encoding)
except (ValueError, yaml.parser.ParserError), exc: return yaml.safe_load(data)
raise ParseError('YAML parse error - %s' % unicode(exc)) except (ValueError, yaml.parser.ParserError) as exc:
raise ParseError('YAML parse error - %s' % six.u(exc))
class FormParser(BaseParser): class FormParser(BaseParser):
@ -94,7 +104,9 @@ class FormParser(BaseParser):
`data` will be a :class:`QueryDict` containing all the form parameters. `data` will be a :class:`QueryDict` containing all the form parameters.
`files` will always be :const:`None`. `files` will always be :const:`None`.
""" """
data = QueryDict(stream.read()) parser_context = parser_context or {}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
data = QueryDict(stream.read(), encoding=encoding)
return data return data
@ -114,15 +126,16 @@ class MultiPartParser(BaseParser):
""" """
parser_context = parser_context or {} parser_context = parser_context or {}
request = parser_context['request'] request = parser_context['request']
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
meta = request.META meta = request.META
upload_handlers = request.upload_handlers upload_handlers = request.upload_handlers
try: try:
parser = DjangoMultiPartParser(meta, stream, upload_handlers) parser = DjangoMultiPartParser(meta, stream, upload_handlers, encoding)
data, files = parser.parse() data, files = parser.parse()
return DataAndFiles(data, files) return DataAndFiles(data, files)
except MultiPartParserError, exc: except MultiPartParserError as exc:
raise ParseError('Multipart form parse error - %s' % unicode(exc)) raise ParseError('Multipart form parse error - %s' % six.u(exc))
class XMLParser(BaseParser): class XMLParser(BaseParser):
@ -133,10 +146,13 @@ class XMLParser(BaseParser):
media_type = 'application/xml' media_type = 'application/xml'
def parse(self, stream, media_type=None, parser_context=None): def parse(self, stream, media_type=None, parser_context=None):
parser_context = parser_context or {}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
parser = ET_XMLParser(encoding=encoding)
try: try:
tree = ET.parse(stream) tree = ET.parse(stream, parser=parser)
except (ExpatError, ETParseError, ValueError), exc: except (ExpatError, ETParseError, ValueError) as exc:
raise ParseError('XML parse error - %s' % unicode(exc)) raise ParseError('XML parse error - %s' % six.u(exc))
data = self._xml_convert(tree.getroot()) data = self._xml_convert(tree.getroot())
return data return data
@ -146,7 +162,7 @@ class XMLParser(BaseParser):
convert the xml `element` into the corresponding python object convert the xml `element` into the corresponding python object
""" """
children = element.getchildren() children = list(element)
if len(children) == 0: if len(children) == 0:
return self._type_convert(element.text) return self._type_convert(element.text)

View File

@ -1,7 +1,7 @@
""" """
Provides a set of pluggable permission policies. Provides a set of pluggable permission policies.
""" """
from __future__ import unicode_literals
SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS'] SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS']

View File

@ -1,13 +1,16 @@
from __future__ import unicode_literals
from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve, get_script_prefix from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
from django import forms from django import forms
from django.forms import widgets from django.forms import widgets
from django.forms.models import ModelChoiceIterator from django.forms.models import ModelChoiceIterator
from django.utils.encoding import smart_unicode
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.fields import Field, WritableField from rest_framework.fields import Field, WritableField
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from urlparse import urlparse from rest_framework.compat import urlparse
from rest_framework.compat import smart_text
import warnings
##### Relational fields ##### ##### Relational fields #####
@ -17,19 +20,35 @@ class RelatedField(WritableField):
""" """
Base class for related model fields. Base class for related model fields.
If not overridden, this represents a to-one relationship, using the unicode This represents a relationship using the unicode representation of the target.
representation of the target.
""" """
widget = widgets.Select widget = widgets.Select
many_widget = widgets.SelectMultiple
form_field_class = forms.ChoiceField
many_form_field_class = forms.MultipleChoiceField
cache_choices = False cache_choices = False
empty_label = None empty_label = None
default_read_only = True # TODO: Remove this read_only = True
many = False
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# 'null' is to be deprecated in favor of 'required'
if 'null' in kwargs:
warnings.warn('The `null` keyword argument is due to be deprecated. '
'Use the `required` keyword argument instead.',
PendingDeprecationWarning, stacklevel=2)
kwargs['required'] = not kwargs.pop('null')
self.queryset = kwargs.pop('queryset', None) self.queryset = kwargs.pop('queryset', None)
self.null = kwargs.pop('null', False) self.many = kwargs.pop('many', self.many)
if self.many:
self.widget = self.many_widget
self.form_field_class = self.many_form_field_class
kwargs['read_only'] = kwargs.pop('read_only', self.read_only)
super(RelatedField, self).__init__(*args, **kwargs) super(RelatedField, self).__init__(*args, **kwargs)
self.read_only = kwargs.pop('read_only', self.default_read_only)
def initialize(self, parent, field_name): def initialize(self, parent, field_name):
super(RelatedField, self).initialize(parent, field_name) super(RelatedField, self).initialize(parent, field_name)
@ -40,7 +59,7 @@ class RelatedField(WritableField):
self.queryset = manager.related.model._default_manager.all() self.queryset = manager.related.model._default_manager.all()
else: # Reverse else: # Reverse
self.queryset = manager.field.rel.to._default_manager.all() self.queryset = manager.field.rel.to._default_manager.all()
except: except Exception:
raise raise
msg = ('Serializer related fields must include a `queryset`' + msg = ('Serializer related fields must include a `queryset`' +
' argument or set `read_only=True') ' argument or set `read_only=True')
@ -48,11 +67,6 @@ class RelatedField(WritableField):
### We need this stuff to make form choices work... ### We need this stuff to make form choices work...
# def __deepcopy__(self, memo):
# result = super(RelatedField, self).__deepcopy__(memo)
# result.queryset = result.queryset
# return result
def prepare_value(self, obj): def prepare_value(self, obj):
return self.to_native(obj) return self.to_native(obj)
@ -60,8 +74,8 @@ class RelatedField(WritableField):
""" """
Return a readable representation for use with eg. select widgets. Return a readable representation for use with eg. select widgets.
""" """
desc = smart_unicode(obj) desc = smart_text(obj)
ident = smart_unicode(self.to_native(obj)) ident = smart_text(self.to_native(obj))
if desc == ident: if desc == ident:
return desc return desc
return "%s - %s" % (desc, ident) return "%s - %s" % (desc, ident)
@ -108,6 +122,9 @@ class RelatedField(WritableField):
if value is None: if value is None:
return None return None
if self.many:
return [self.to_native(item) for item in value.all()]
return self.to_native(value) return self.to_native(value)
def field_from_native(self, data, files, field_name, into): def field_from_native(self, data, files, field_name, into):
@ -115,69 +132,43 @@ class RelatedField(WritableField):
return return
try: try:
value = data[field_name] if self.many:
try:
# Form data
value = data.getlist(field_name)
if value == [''] or value == []:
raise KeyError
except AttributeError:
# Non-form data
value = data[field_name]
else:
value = data[field_name]
except KeyError: except KeyError:
if self.required: if self.partial:
raise ValidationError(self.error_messages['required']) return
return value = [] if self.many else None
if value in (None, '') and not self.null: if value in (None, '') and self.required:
raise ValidationError('Value may not be null') raise ValidationError(self.error_messages['required'])
elif value in (None, '') and self.null: elif value in (None, ''):
into[(self.source or field_name)] = None into[(self.source or field_name)] = None
elif self.many:
into[(self.source or field_name)] = [self.from_native(item) for item in value]
else: else:
into[(self.source or field_name)] = self.from_native(value) into[(self.source or field_name)] = self.from_native(value)
class ManyRelatedMixin(object):
"""
Mixin to convert a related field to a many related field.
"""
widget = widgets.SelectMultiple
def field_to_native(self, obj, field_name):
value = getattr(obj, self.source or field_name)
return [self.to_native(item) for item in value.all()]
def field_from_native(self, data, files, field_name, into):
if self.read_only:
return
try:
# Form data
value = data.getlist(self.source or field_name)
except:
# Non-form data
value = data.get(self.source or field_name)
else:
if value == ['']:
value = []
into[field_name] = [self.from_native(item) for item in value]
class ManyRelatedField(ManyRelatedMixin, RelatedField):
"""
Base class for related model managers.
If not overridden, this represents a to-many relationship, using the unicode
representations of the target, and is read-only.
"""
pass
### PrimaryKey relationships ### PrimaryKey relationships
class PrimaryKeyRelatedField(RelatedField): class PrimaryKeyRelatedField(RelatedField):
""" """
Represents a to-one relationship as a pk value. Represents a relationship as a pk value.
""" """
default_read_only = False read_only = False
form_field_class = forms.ChoiceField
default_error_messages = { default_error_messages = {
'does_not_exist': _("Invalid pk '%s' - object does not exist."), 'does_not_exist': _("Invalid pk '%s' - object does not exist."),
'invalid': _('Invalid value.'), 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'),
} }
# TODO: Remove these field hacks... # TODO: Remove these field hacks...
@ -188,8 +179,8 @@ class PrimaryKeyRelatedField(RelatedField):
""" """
Return a readable representation for use with eg. select widgets. Return a readable representation for use with eg. select widgets.
""" """
desc = smart_unicode(obj) desc = smart_text(obj)
ident = smart_unicode(self.to_native(obj.pk)) ident = smart_text(self.to_native(obj.pk))
if desc == ident: if desc == ident:
return desc return desc
return "%s - %s" % (desc, ident) return "%s - %s" % (desc, ident)
@ -205,85 +196,50 @@ class PrimaryKeyRelatedField(RelatedField):
try: try:
return self.queryset.get(pk=data) return self.queryset.get(pk=data)
except ObjectDoesNotExist: except ObjectDoesNotExist:
msg = self.error_messages['does_not_exist'] % smart_unicode(data) msg = self.error_messages['does_not_exist'] % smart_text(data)
raise ValidationError(msg) raise ValidationError(msg)
except (TypeError, ValueError): except (TypeError, ValueError):
msg = self.error_messages['invalid'] received = type(data).__name__
msg = self.error_messages['incorrect_type'] % received
raise ValidationError(msg) raise ValidationError(msg)
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
if self.many:
# To-many relationship
try:
# Prefer obj.serializable_value for performance reasons
queryset = obj.serializable_value(self.source or field_name)
except AttributeError:
# RelatedManager (reverse relationship)
queryset = getattr(obj, self.source or field_name)
# Forward relationship
return [self.to_native(item.pk) for item in queryset.all()]
# To-one relationship
try: try:
# Prefer obj.serializable_value for performance reasons # Prefer obj.serializable_value for performance reasons
pk = obj.serializable_value(self.source or field_name) pk = obj.serializable_value(self.source or field_name)
except AttributeError: except AttributeError:
# RelatedObject (reverse relationship) # RelatedObject (reverse relationship)
try: try:
obj = getattr(obj, self.source or field_name) pk = getattr(obj, self.source or field_name).pk
except ObjectDoesNotExist: except ObjectDoesNotExist:
return None return None
return self.to_native(obj.pk) return self.to_native(obj.pk)
# Forward relationship # Forward relationship
return self.to_native(pk) return self.to_native(pk)
class ManyPrimaryKeyRelatedField(ManyRelatedField):
"""
Represents a to-many relationship as a pk value.
"""
default_read_only = False
form_field_class = forms.MultipleChoiceField
default_error_messages = {
'does_not_exist': _("Invalid pk '%s' - object does not exist."),
'invalid': _('Invalid value.'),
}
def prepare_value(self, obj):
return self.to_native(obj.pk)
def label_from_instance(self, obj):
"""
Return a readable representation for use with eg. select widgets.
"""
desc = smart_unicode(obj)
ident = smart_unicode(self.to_native(obj.pk))
if desc == ident:
return desc
return "%s - %s" % (desc, ident)
def to_native(self, pk):
return pk
def field_to_native(self, obj, field_name):
try:
# Prefer obj.serializable_value for performance reasons
queryset = obj.serializable_value(self.source or field_name)
except AttributeError:
# RelatedManager (reverse relationship)
queryset = getattr(obj, self.source or field_name)
return [self.to_native(item.pk) for item in queryset.all()]
# Forward relationship
return [self.to_native(item.pk) for item in queryset.all()]
def from_native(self, data):
if self.queryset is None:
raise Exception('Writable related fields must include a `queryset` argument')
try:
return self.queryset.get(pk=data)
except ObjectDoesNotExist:
msg = self.error_messages['does_not_exist'] % smart_unicode(data)
raise ValidationError(msg)
except (TypeError, ValueError):
msg = self.error_messages['invalid']
raise ValidationError(msg)
### Slug relationships ### Slug relationships
class SlugRelatedField(RelatedField): class SlugRelatedField(RelatedField):
default_read_only = False """
form_field_class = forms.ChoiceField Represents a relationship using a unique field on the target.
"""
read_only = False
default_error_messages = { default_error_messages = {
'does_not_exist': _("Object with %s=%s does not exist."), 'does_not_exist': _("Object with %s=%s does not exist."),
@ -306,40 +262,35 @@ class SlugRelatedField(RelatedField):
return self.queryset.get(**{self.slug_field: data}) return self.queryset.get(**{self.slug_field: data})
except ObjectDoesNotExist: except ObjectDoesNotExist:
raise ValidationError(self.error_messages['does_not_exist'] % raise ValidationError(self.error_messages['does_not_exist'] %
(self.slug_field, unicode(data))) (self.slug_field, smart_text(data)))
except (TypeError, ValueError): except (TypeError, ValueError):
msg = self.error_messages['invalid'] msg = self.error_messages['invalid']
raise ValidationError(msg) raise ValidationError(msg)
class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField):
form_field_class = forms.MultipleChoiceField
### Hyperlinked relationships ### Hyperlinked relationships
class HyperlinkedRelatedField(RelatedField): class HyperlinkedRelatedField(RelatedField):
""" """
Represents a to-one relationship, using hyperlinking. Represents a relationship using hyperlinking.
""" """
pk_url_kwarg = 'pk' pk_url_kwarg = 'pk'
slug_field = 'slug' slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
default_read_only = False read_only = False
form_field_class = forms.ChoiceField
default_error_messages = { default_error_messages = {
'no_match': _('Invalid hyperlink - No URL match'), 'no_match': _('Invalid hyperlink - No URL match'),
'incorrect_match': _('Invalid hyperlink - Incorrect URL match'), 'incorrect_match': _('Invalid hyperlink - Incorrect URL match'),
'configuration_error': _('Invalid hyperlink due to configuration error'), 'configuration_error': _('Invalid hyperlink due to configuration error'),
'does_not_exist': _("Invalid hyperlink - object does not exist."), 'does_not_exist': _("Invalid hyperlink - object does not exist."),
'invalid': _('Invalid value.'), 'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
} }
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
try: try:
self.view_name = kwargs.pop('view_name') self.view_name = kwargs.pop('view_name')
except: except KeyError:
raise ValueError("Hyperlinked field requires 'view_name' kwarg") raise ValueError("Hyperlinked field requires 'view_name' kwarg")
self.slug_field = kwargs.pop('slug_field', self.slug_field) self.slug_field = kwargs.pop('slug_field', self.slug_field)
@ -366,7 +317,7 @@ class HyperlinkedRelatedField(RelatedField):
kwargs = {self.pk_url_kwarg: pk} kwargs = {self.pk_url_kwarg: pk}
try: try:
return reverse(view_name, kwargs=kwargs, request=request, format=format) return reverse(view_name, kwargs=kwargs, request=request, format=format)
except: except NoReverseMatch:
pass pass
slug = getattr(obj, self.slug_field, None) slug = getattr(obj, self.slug_field, None)
@ -377,13 +328,13 @@ class HyperlinkedRelatedField(RelatedField):
kwargs = {self.slug_url_kwarg: slug} kwargs = {self.slug_url_kwarg: slug}
try: try:
return reverse(view_name, kwargs=kwargs, request=request, format=format) return reverse(view_name, kwargs=kwargs, request=request, format=format)
except: except NoReverseMatch:
pass pass
kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
try: try:
return reverse(view_name, kwargs=kwargs, request=request, format=format) return reverse(view_name, kwargs=kwargs, request=request, format=format)
except: except NoReverseMatch:
pass pass
raise Exception('Could not resolve URL for field using view name "%s"' % view_name) raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
@ -397,19 +348,19 @@ class HyperlinkedRelatedField(RelatedField):
try: try:
http_prefix = value.startswith('http:') or value.startswith('https:') http_prefix = value.startswith('http:') or value.startswith('https:')
except AttributeError: except AttributeError:
msg = self.error_messages['invalid'] msg = self.error_messages['incorrect_type']
raise ValidationError(msg) raise ValidationError(msg % type(value).__name__)
if http_prefix: if http_prefix:
# If needed convert absolute URLs to relative path # If needed convert absolute URLs to relative path
value = urlparse(value).path value = urlparse.urlparse(value).path
prefix = get_script_prefix() prefix = get_script_prefix()
if value.startswith(prefix): if value.startswith(prefix):
value = '/' + value[len(prefix):] value = '/' + value[len(prefix):]
try: try:
match = resolve(value) match = resolve(value)
except: except Exception:
raise ValidationError(self.error_messages['no_match']) raise ValidationError(self.error_messages['no_match'])
if match.view_name != self.view_name: if match.view_name != self.view_name:
@ -434,19 +385,12 @@ class HyperlinkedRelatedField(RelatedField):
except ObjectDoesNotExist: except ObjectDoesNotExist:
raise ValidationError(self.error_messages['does_not_exist']) raise ValidationError(self.error_messages['does_not_exist'])
except (TypeError, ValueError): except (TypeError, ValueError):
msg = self.error_messages['invalid'] msg = self.error_messages['incorrect_type']
raise ValidationError(msg) raise ValidationError(msg % type(value).__name__)
return obj return obj
class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
"""
Represents a to-many relationship, using hyperlinking.
"""
form_field_class = forms.MultipleChoiceField
class HyperlinkedIdentityField(Field): class HyperlinkedIdentityField(Field):
""" """
Represents the instance, or a property on the instance, using hyperlinking. Represents the instance, or a property on the instance, using hyperlinking.
@ -454,6 +398,7 @@ class HyperlinkedIdentityField(Field):
pk_url_kwarg = 'pk' pk_url_kwarg = 'pk'
slug_field = 'slug' slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
read_only = True
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# TODO: Make view_name mandatory, and have the # TODO: Make view_name mandatory, and have the
@ -489,7 +434,7 @@ class HyperlinkedIdentityField(Field):
try: try:
return reverse(view_name, kwargs=kwargs, request=request, format=format) return reverse(view_name, kwargs=kwargs, request=request, format=format)
except: except NoReverseMatch:
pass pass
slug = getattr(obj, self.slug_field, None) slug = getattr(obj, self.slug_field, None)
@ -500,13 +445,51 @@ class HyperlinkedIdentityField(Field):
kwargs = {self.slug_url_kwarg: slug} kwargs = {self.slug_url_kwarg: slug}
try: try:
return reverse(view_name, kwargs=kwargs, request=request, format=format) return reverse(view_name, kwargs=kwargs, request=request, format=format)
except: except NoReverseMatch:
pass pass
kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
try: try:
return reverse(view_name, kwargs=kwargs, request=request, format=format) return reverse(view_name, kwargs=kwargs, request=request, format=format)
except: except NoReverseMatch:
pass pass
raise Exception('Could not resolve URL for field using view name "%s"' % view_name) raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
### Old-style many classes for backwards compat
class ManyRelatedField(RelatedField):
def __init__(self, *args, **kwargs):
warnings.warn('`ManyRelatedField()` is due to be deprecated. '
'Use `RelatedField(many=True)` instead.',
PendingDeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyRelatedField, self).__init__(*args, **kwargs)
class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
def __init__(self, *args, **kwargs):
warnings.warn('`ManyPrimaryKeyRelatedField()` is due to be deprecated. '
'Use `PrimaryKeyRelatedField(many=True)` instead.',
PendingDeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyPrimaryKeyRelatedField, self).__init__(*args, **kwargs)
class ManySlugRelatedField(SlugRelatedField):
def __init__(self, *args, **kwargs):
warnings.warn('`ManySlugRelatedField()` is due to be deprecated. '
'Use `SlugRelatedField(many=True)` instead.',
PendingDeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManySlugRelatedField, self).__init__(*args, **kwargs)
class ManyHyperlinkedRelatedField(HyperlinkedRelatedField):
def __init__(self, *args, **kwargs):
warnings.warn('`ManyHyperlinkedRelatedField()` is due to be deprecated. '
'Use `HyperlinkedRelatedField(many=True)` instead.',
PendingDeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyHyperlinkedRelatedField, self).__init__(*args, **kwargs)

View File

@ -6,6 +6,8 @@ on the response, such as JSON encoded data or HTML output.
REST framework also provides an HTML renderer the renders the browsable API. REST framework also provides an HTML renderer the renders the browsable API.
""" """
from __future__ import unicode_literals
import copy import copy
import string import string
import json import json
@ -60,7 +62,7 @@ class JSONRenderer(BaseRenderer):
if accepted_media_type: if accepted_media_type:
# If the media type looks like 'application/json; indent=4', # If the media type looks like 'application/json; indent=4',
# then pretty print the result. # then pretty print the result.
base_media_type, params = parse_header(accepted_media_type) base_media_type, params = parse_header(accepted_media_type.encode('ascii'))
indent = params.get('indent', indent) indent = params.get('indent', indent)
try: try:
indent = max(min(int(indent), 8), 0) indent = max(min(int(indent), 8), 0)
@ -86,7 +88,7 @@ class JSONPRenderer(JSONRenderer):
Determine the name of the callback to wrap around the json output. Determine the name of the callback to wrap around the json output.
""" """
request = renderer_context.get('request', None) request = renderer_context.get('request', None)
params = request and request.GET or {} params = request and request.QUERY_PARAMS or {}
return params.get(self.callback_parameter, self.default_callback) return params.get(self.callback_parameter, self.default_callback)
def render(self, data, accepted_media_type=None, renderer_context=None): def render(self, data, accepted_media_type=None, renderer_context=None):
@ -100,7 +102,7 @@ class JSONPRenderer(JSONRenderer):
callback = self.get_callback(renderer_context) callback = self.get_callback(renderer_context)
json = super(JSONPRenderer, self).render(data, accepted_media_type, json = super(JSONPRenderer, self).render(data, accepted_media_type,
renderer_context) renderer_context)
return u"%s(%s);" % (callback, json) return "%s(%s);" % (callback, json)
class XMLRenderer(BaseRenderer): class XMLRenderer(BaseRenderer):
@ -215,7 +217,7 @@ class TemplateHTMLRenderer(BaseRenderer):
try: try:
# Try to find an appropriate error template # Try to find an appropriate error template
return self.resolve_template(template_names) return self.resolve_template(template_names)
except: except Exception:
# Fall back to using eg '404 Not Found' # Fall back to using eg '404 Not Found'
return Template('%d %s' % (response.status_code, return Template('%d %s' % (response.status_code,
response.status_text.title())) response.status_text.title()))
@ -301,7 +303,7 @@ class BrowsableAPIRenderer(BaseRenderer):
try: try:
if not view.has_permission(request, obj): if not view.has_permission(request, obj):
return # Don't have permission return # Don't have permission
except: except Exception:
return # Don't have permission and exception explicitly raise return # Don't have permission and exception explicitly raise
return True return True
@ -333,6 +335,7 @@ class BrowsableAPIRenderer(BaseRenderer):
kwargs['label'] = k kwargs['label'] = k
fields[k] = v.form_field_class(**kwargs) fields[k] = v.form_field_class(**kwargs)
return fields return fields
def get_form(self, view, method, request): def get_form(self, view, method, request):
@ -357,7 +360,7 @@ class BrowsableAPIRenderer(BaseRenderer):
# Creating an on the fly form see: # Creating an on the fly form see:
# http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields) OnTheFlyForm = type(str("OnTheFlyForm"), (forms.Form,), fields)
data = (obj is not None) and serializer.data or None data = (obj is not None) and serializer.data or None
form_instance = OnTheFlyForm(data) form_instance = OnTheFlyForm(data)
return form_instance return form_instance

View File

@ -9,10 +9,12 @@ The wrapped request then offers a richer API, in particular :
- full support of PUT method, including support for file uploads - full support of PUT method, including support for file uploads
- form overloading of HTTP method, content type and content - form overloading of HTTP method, content type and content
""" """
from StringIO import StringIO from __future__ import unicode_literals
from django.conf import settings
from django.http.multipartparser import parse_header from django.http.multipartparser import parse_header
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework.compat import BytesIO
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -20,7 +22,7 @@ def is_form_media_type(media_type):
""" """
Return True if the media type is a valid form media type. Return True if the media type is a valid form media type.
""" """
base_media_type, params = parse_header(media_type) base_media_type, params = parse_header(media_type.encode(HTTP_HEADER_ENCODING))
return (base_media_type == 'application/x-www-form-urlencoded' or return (base_media_type == 'application/x-www-form-urlencoded' or
base_media_type == 'multipart/form-data') base_media_type == 'multipart/form-data')
@ -86,10 +88,12 @@ class Request(object):
self._method = Empty self._method = Empty
self._content_type = Empty self._content_type = Empty
self._stream = Empty self._stream = Empty
self._authenticator = None
if self.parser_context is None: if self.parser_context is None:
self.parser_context = {} self.parser_context = {}
self.parser_context['request'] = self self.parser_context['request'] = self
self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
def _default_negotiator(self): def _default_negotiator(self):
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS() return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
@ -166,7 +170,7 @@ class Request(object):
by the authentication classes provided to the request. by the authentication classes provided to the request.
""" """
if not hasattr(self, '_user'): if not hasattr(self, '_user'):
self._user, self._auth = self._authenticate() self._authenticator, self._user, self._auth = self._authenticate()
return self._user return self._user
@user.setter @user.setter
@ -185,7 +189,7 @@ class Request(object):
request, such as an authentication token. request, such as an authentication token.
""" """
if not hasattr(self, '_auth'): if not hasattr(self, '_auth'):
self._user, self._auth = self._authenticate() self._authenticator, self._user, self._auth = self._authenticate()
return self._auth return self._auth
@auth.setter @auth.setter
@ -196,6 +200,14 @@ class Request(object):
""" """
self._auth = value self._auth = value
@property
def successful_authenticator(self):
"""
Return the instance of the authentication instance class that was used
to authenticate the request, or `None`.
"""
return self._authenticator
def _load_data_and_files(self): def _load_data_and_files(self):
""" """
Parses the request content into self.DATA and self.FILES. Parses the request content into self.DATA and self.FILES.
@ -233,7 +245,7 @@ class Request(object):
elif hasattr(self._request, 'read'): elif hasattr(self._request, 'read'):
self._stream = self._request self._stream = self._request
else: else:
self._stream = StringIO(self.raw_post_data) self._stream = BytesIO(self.raw_post_data)
def _perform_form_overloading(self): def _perform_form_overloading(self):
""" """
@ -268,7 +280,7 @@ class Request(object):
self._CONTENT_PARAM in self._data and self._CONTENT_PARAM in self._data and
self._CONTENTTYPE_PARAM in self._data): self._CONTENTTYPE_PARAM in self._data):
self._content_type = self._data[self._CONTENTTYPE_PARAM] self._content_type = self._data[self._CONTENTTYPE_PARAM]
self._stream = StringIO(self._data[self._CONTENT_PARAM]) self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(HTTP_HEADER_ENCODING))
self._data, self._files = (Empty, Empty) self._data, self._files = (Empty, Empty)
def _parse(self): def _parse(self):
@ -299,21 +311,23 @@ class Request(object):
def _authenticate(self): def _authenticate(self):
""" """
Attempt to authenticate the request using each authentication instance in turn. Attempt to authenticate the request using each authentication instance
Returns a two-tuple of (user, authtoken). in turn.
Returns a three-tuple of (authenticator, user, authtoken).
""" """
for authenticator in self.authenticators: for authenticator in self.authenticators:
user_auth_tuple = authenticator.authenticate(self) user_auth_tuple = authenticator.authenticate(self)
if not user_auth_tuple is None: if not user_auth_tuple is None:
return user_auth_tuple user, auth = user_auth_tuple
return (authenticator, user, auth)
return self._not_authenticated() return self._not_authenticated()
def _not_authenticated(self): def _not_authenticated(self):
""" """
Return a two-tuple of (user, authtoken), representing an Return a three-tuple of (authenticator, user, authtoken), representing
unauthenticated request. an unauthenticated request.
By default this will be (AnonymousUser, None). By default this will be (None, AnonymousUser, None).
""" """
if api_settings.UNAUTHENTICATED_USER: if api_settings.UNAUTHENTICATED_USER:
user = api_settings.UNAUTHENTICATED_USER() user = api_settings.UNAUTHENTICATED_USER()
@ -325,7 +339,7 @@ class Request(object):
else: else:
auth = None auth = None
return (user, auth) return (None, user, auth)
def __getattr__(self, attr): def __getattr__(self, attr):
""" """

View File

@ -1,5 +1,7 @@
from __future__ import unicode_literals
from django.core.handlers.wsgi import STATUS_CODE_TEXT from django.core.handlers.wsgi import STATUS_CODE_TEXT
from django.template.response import SimpleTemplateResponse from django.template.response import SimpleTemplateResponse
from rest_framework.compat import six
class Response(SimpleTemplateResponse): class Response(SimpleTemplateResponse):
@ -24,7 +26,7 @@ class Response(SimpleTemplateResponse):
self.exception = exception self.exception = exception
if headers: if headers:
for name,value in headers.iteritems(): for name, value in six.iteritems(headers):
self[name] = value self[name] = value
@property @property

View File

@ -1,6 +1,7 @@
""" """
Provide reverse functions that return fully qualified URLs Provide reverse functions that return fully qualified URLs
""" """
from __future__ import unicode_literals
from django.core.urlresolvers import reverse as django_reverse from django.core.urlresolvers import reverse as django_reverse
from django.utils.functional import lazy from django.utils.functional import lazy

View File

@ -33,7 +33,7 @@ def main():
elif len(sys.argv) == 1: elif len(sys.argv) == 1:
test_case = '' test_case = ''
else: else:
print usage() print(usage())
sys.exit(1) sys.exit(1)
failures = test_runner.run_tests(['tests' + test_case]) failures = test_runner.run_tests(['tests' + test_case])

View File

@ -1,11 +1,14 @@
from __future__ import unicode_literals
import copy import copy
import datetime import datetime
import types import types
from decimal import Decimal from decimal import Decimal
from django.core.paginator import Page
from django.db import models from django.db import models
from django.forms import widgets from django.forms import widgets
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from rest_framework.compat import get_concrete_model from rest_framework.compat import get_concrete_model
from rest_framework.compat import six
# Note: We do the following so that users of the framework can use this style: # Note: We do the following so that users of the framework can use this style:
# #
@ -63,7 +66,7 @@ def _get_declared_fields(bases, attrs):
Note that all fields from the base classes are used. Note that all fields from the base classes are used.
""" """
fields = [(field_name, attrs.pop(field_name)) fields = [(field_name, attrs.pop(field_name))
for field_name, obj in attrs.items() for field_name, obj in list(six.iteritems(attrs))
if isinstance(obj, Field)] if isinstance(obj, Field)]
fields.sort(key=lambda x: x[1].creation_counter) fields.sort(key=lambda x: x[1].creation_counter)
@ -72,7 +75,7 @@ def _get_declared_fields(bases, attrs):
# in order to maintain the correct order of fields. # in order to maintain the correct order of fields.
for base in bases[::-1]: for base in bases[::-1]:
if hasattr(base, 'base_fields'): if hasattr(base, 'base_fields'):
fields = base.base_fields.items() + fields fields = list(base.base_fields.items()) + fields
return SortedDict(fields) return SortedDict(fields)
@ -93,20 +96,25 @@ class SerializerOptions(object):
self.exclude = getattr(meta, 'exclude', ()) self.exclude = getattr(meta, 'exclude', ())
class BaseSerializer(WritableField): class BaseSerializer(Field):
"""
This is the Serializer implementation.
We need to implement it as `BaseSerializer` due to metaclass magicks.
"""
class Meta(object): class Meta(object):
pass pass
_options_class = SerializerOptions _options_class = SerializerOptions
_dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatibility with unsorted implementations. _dict_class = SortedDictWithMetadata
def __init__(self, instance=None, data=None, files=None, def __init__(self, instance=None, data=None, files=None,
context=None, partial=False, **kwargs): context=None, partial=False, many=None, source=None):
super(BaseSerializer, self).__init__(**kwargs) super(BaseSerializer, self).__init__(source=source)
self.opts = self._options_class(self.Meta) self.opts = self._options_class(self.Meta)
self.parent = None self.parent = None
self.root = None self.root = None
self.partial = partial self.partial = partial
self.many = many
self.context = context or {} self.context = context or {}
@ -118,7 +126,6 @@ class BaseSerializer(WritableField):
self._data = None self._data = None
self._files = None self._files = None
self._errors = None self._errors = None
self._delete = False
##### #####
# Methods to determine which fields to use when (de)serializing objects. # Methods to determine which fields to use when (de)serializing objects.
@ -187,22 +194,6 @@ class BaseSerializer(WritableField):
""" """
return field_name return field_name
def convert_object(self, obj):
"""
Core of serialization.
Convert an object into a dictionary of serialized field values.
"""
ret = self._dict_class()
ret.fields = {}
for field_name, field in self.fields.items():
field.initialize(parent=self, field_name=field_name)
key = self.get_field_key(field_name)
value = field.field_to_native(obj, field_name)
ret[key] = value
ret.fields[key] = field
return ret
def restore_fields(self, data, files): def restore_fields(self, data, files):
""" """
Core of deserialization, together with `restore_object`. Core of deserialization, together with `restore_object`.
@ -211,7 +202,7 @@ class BaseSerializer(WritableField):
reverted_data = {} reverted_data = {}
if data is not None and not isinstance(data, dict): if data is not None and not isinstance(data, dict):
self._errors['non_field_errors'] = [u'Invalid data'] self._errors['non_field_errors'] = ['Invalid data']
return None return None
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
@ -219,10 +210,7 @@ class BaseSerializer(WritableField):
try: try:
field.field_from_native(data, files, field_name, reverted_data) field.field_from_native(data, files, field_name, reverted_data)
except ValidationError as err: except ValidationError as err:
if hasattr(err, 'message_dict'): self._errors[field_name] = list(err.messages)
self._errors[field_name] = [err.message_dict]
else:
self._errors[field_name] = list(err.messages)
return reverted_data return reverted_data
@ -231,6 +219,8 @@ class BaseSerializer(WritableField):
Run `validate_<fieldname>()` and `validate()` methods on the serializer Run `validate_<fieldname>()` and `validate()` methods on the serializer
""" """
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
if field_name in self._errors:
continue
try: try:
validate_method = getattr(self, 'validate_%s' % field_name, None) validate_method = getattr(self, 'validate_%s' % field_name, None)
if validate_method: if validate_method:
@ -275,15 +265,22 @@ class BaseSerializer(WritableField):
""" """
Serialize objects -> primitives. Serialize objects -> primitives.
""" """
if hasattr(obj, '__iter__'): ret = self._dict_class()
return [self.convert_object(item) for item in obj] ret.fields = {}
return self.convert_object(obj)
for field_name, field in self.fields.items():
field.initialize(parent=self, field_name=field_name)
key = self.get_field_key(field_name)
value = field.field_to_native(obj, field_name)
ret[key] = value
ret.fields[key] = field
return ret
def from_native(self, data, files): def from_native(self, data, files):
""" """
Deserialize primitives -> objects. Deserialize primitives -> objects.
""" """
if hasattr(data, '__iter__') and not isinstance(data, dict): if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)):
# TODO: error data when deserializing lists # TODO: error data when deserializing lists
return [self.from_native(item, None) for item in data] return [self.from_native(item, None) for item in data]
@ -302,6 +299,9 @@ class BaseSerializer(WritableField):
Override default so that we can apply ModelSerializer as a nested Override default so that we can apply ModelSerializer as a nested
field to relationships. field to relationships.
""" """
if self.source == '*':
return self.to_native(obj)
try: try:
if self.source: if self.source:
for component in self.source.split('.'): for component in self.source.split('.'):
@ -322,6 +322,13 @@ class BaseSerializer(WritableField):
if obj is None: if obj is None:
return None return None
if self.many is not None:
many = self.many
else:
many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))
if many:
return [self.to_native(item) for item in obj]
return self.to_native(obj) return self.to_native(obj)
@property @property
@ -331,9 +338,20 @@ class BaseSerializer(WritableField):
setting self.object if no errors occurred. setting self.object if no errors occurred.
""" """
if self._errors is None: if self._errors is None:
obj = self.from_native(self.init_data, self.init_files) data, files = self.init_data, self.init_files
if self.many is not None:
many = self.many
else:
many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict))
# TODO: error data when deserializing lists
if many:
ret = [self.from_native(item, None) for item in data]
ret = self.from_native(data, files)
if not self._errors: if not self._errors:
self.object = obj self.object = ret
return self._errors return self._errors
def is_valid(self): def is_valid(self):
@ -341,8 +359,22 @@ class BaseSerializer(WritableField):
@property @property
def data(self): def data(self):
"""
Returns the serialized data on the serializer.
"""
if self._data is None: if self._data is None:
self._data = self.to_native(self.object) obj = self.object
if self.many is not None:
many = self.many
else:
many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))
if many:
self._data = [self.to_native(item) for item in obj]
else:
self._data = self.to_native(obj)
return self._data return self._data
def save(self): def save(self):
@ -353,8 +385,8 @@ class BaseSerializer(WritableField):
return self.object return self.object
class Serializer(BaseSerializer): class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)):
__metaclass__ = SerializerMetaclass pass
class ModelSerializerOptions(SerializerOptions): class ModelSerializerOptions(SerializerOptions):
@ -373,35 +405,6 @@ class ModelSerializer(Serializer):
""" """
_options_class = ModelSerializerOptions _options_class = ModelSerializerOptions
def field_from_native(self, data, files, field_name, into):
if self.read_only:
return
try:
value = data[field_name]
except KeyError:
if self.required:
raise ValidationError(self.error_messages['required'])
return
if self.parent.object:
# Set the serializer object if it exists
pk_field_name = self.opts.model._meta.pk.name
obj = getattr(self.parent.object, field_name)
self.object = obj
if value in (None, ''):
self._delete = True
into[(self.source or field_name)] = self
else:
obj = self.from_native(value, files)
if not self._errors:
self.object = obj
into[self.source or field_name] = self
else:
# Propagate errors up to our parent
raise ValidationError(self._errors)
def get_default_fields(self): def get_default_fields(self):
""" """
Return all the fields that should be serialized for the model. Return all the fields that should be serialized for the model.
@ -466,12 +469,11 @@ class ModelSerializer(Serializer):
# TODO: filter queryset using: # TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to) # .using(db).complex_filter(self.rel.limit_choices_to)
kwargs = { kwargs = {
'null': model_field.null or model_field.blank, 'required': not(model_field.null or model_field.blank),
'queryset': model_field.rel.to._default_manager 'queryset': model_field.rel.to._default_manager,
'many': to_many
} }
if to_many:
return ManyPrimaryKeyRelatedField(**kwargs)
return PrimaryKeyRelatedField(**kwargs) return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field): def get_field(self, model_field):
@ -479,20 +481,18 @@ class ModelSerializer(Serializer):
Creates a default instance of a basic non-relational field. Creates a default instance of a basic non-relational field.
""" """
kwargs = {} kwargs = {}
has_default = model_field.has_default()
kwargs['blank'] = model_field.blank if model_field.null or model_field.blank or has_default:
if model_field.null or model_field.blank:
kwargs['required'] = False kwargs['required'] = False
if isinstance(model_field, models.AutoField) or not model_field.editable: if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True kwargs['read_only'] = True
if model_field.has_default(): if has_default:
kwargs['required'] = False
kwargs['default'] = model_field.get_default() kwargs['default'] = model_field.get_default()
if model_field.__class__ == models.TextField: if issubclass(model_field.__class__, models.TextField):
kwargs['widget'] = widgets.Textarea kwargs['widget'] = widgets.Textarea
# TODO: TypedChoiceField? # TODO: TypedChoiceField?
@ -536,6 +536,22 @@ class ModelSerializer(Serializer):
exclusions.remove(field_name) exclusions.remove(field_name)
return exclusions return exclusions
def full_clean(self, instance):
"""
Perform Django's full_clean, and populate the `errors` dictionary
if any validation errors occur.
Note that we don't perform this inside the `.restore_object()` method,
so that subclasses can override `.restore_object()`, and still get
the full_clean validation checking.
"""
try:
instance.full_clean(exclude=self.get_validation_exclusions())
except ValidationError as err:
self._errors = err.message_dict
return None
return instance
def restore_object(self, attrs, instance=None): def restore_object(self, attrs, instance=None):
""" """
Restore the model instance. Restore the model instance.
@ -569,19 +585,24 @@ class ModelSerializer(Serializer):
try: try:
instance.full_clean(exclude=self.get_validation_exclusions()) instance.full_clean(exclude=self.get_validation_exclusions())
except ValidationError, err: except ValidationError as err:
self._errors = err.message_dict self._errors = err.message_dict
return None return None
return instance return instance
def _save(self, parent=None, fk_field=None): def from_native(self, data, files):
if self._delete: """
self.object.delete() Override the default method to also include model field validation.
return """
instance = super(ModelSerializer, self).from_native(data, files)
if instance:
return self.full_clean(instance)
if parent and fk_field: def save(self):
setattr(self.object, fk_field, parent) """
Save the deserialized object and return it.
"""
self.object.save() self.object.save()
if getattr(self, 'm2m_data', None): if getattr(self, 'm2m_data', None):
@ -591,18 +612,9 @@ class ModelSerializer(Serializer):
if getattr(self, 'related_data', None): if getattr(self, 'related_data', None):
for accessor_name, object_list in self.related_data.items(): for accessor_name, object_list in self.related_data.items():
if isinstance(object_list, ModelSerializer): setattr(self.object, accessor_name, object_list)
fk_field = self.object._meta.get_field_by_name(accessor_name)[0].field.name
object_list._save(parent=self.object, fk_field=fk_field)
else:
setattr(self.object, accessor_name, object_list)
self.related_data = {} self.related_data = {}
def save(self):
"""
Save the deserialized object and return it.
"""
self._save()
return self.object return self.object
@ -617,6 +629,8 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
class HyperlinkedModelSerializer(ModelSerializer): class HyperlinkedModelSerializer(ModelSerializer):
""" """
A subclass of ModelSerializer that uses hyperlinked relationships,
instead of primary key relationships.
""" """
_options_class = HyperlinkedModelSerializerOptions _options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail' _default_view_name = '%(model_name)s-detail'
@ -650,10 +664,9 @@ class HyperlinkedModelSerializer(ModelSerializer):
# .using(db).complex_filter(self.rel.limit_choices_to) # .using(db).complex_filter(self.rel.limit_choices_to)
rel = model_field.rel.to rel = model_field.rel.to
kwargs = { kwargs = {
'null': model_field.null, 'required': not(model_field.null or model_field.blank),
'queryset': rel._default_manager, 'queryset': rel._default_manager,
'view_name': self._get_default_view_name(rel) 'view_name': self._get_default_view_name(rel),
'many': to_many
} }
if to_many:
return ManyHyperlinkedRelatedField(**kwargs)
return HyperlinkedRelatedField(**kwargs) return HyperlinkedRelatedField(**kwargs)

View File

@ -17,8 +17,10 @@ This module provides the `api_setting` object, that is used to access
REST framework settings, checking for user settings first, then falling REST framework settings, checking for user settings first, then falling
back to the defaults. back to the defaults.
""" """
from __future__ import unicode_literals
from django.conf import settings from django.conf import settings
from django.utils import importlib from django.utils import importlib
from rest_framework.compat import six
USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None) USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None)
@ -98,7 +100,7 @@ def perform_import(val, setting_name):
If the given setting is a string import notation, If the given setting is a string import notation,
then perform the necessary import or imports. then perform the necessary import or imports.
""" """
if isinstance(val, basestring): if isinstance(val, six.string_types):
return import_from_string(val, setting_name) return import_from_string(val, setting_name)
elif isinstance(val, (list, tuple)): elif isinstance(val, (list, tuple)):
return [import_from_string(item, setting_name) for item in val] return [import_from_string(item, setting_name) for item in val]

389
rest_framework/six.py Normal file
View File

@ -0,0 +1,389 @@
"""Utilities for writing code that runs on Python 2 and 3"""
import operator
import sys
import types
__author__ = "Benjamin Peterson <benjamin@python.org>"
__version__ = "1.2.0"
# True if we are running on Python 3.
PY3 = sys.version_info[0] == 3
if PY3:
string_types = str,
integer_types = int,
class_types = type,
text_type = str
binary_type = bytes
MAXSIZE = sys.maxsize
else:
string_types = basestring,
integer_types = (int, long)
class_types = (type, types.ClassType)
text_type = unicode
binary_type = str
if sys.platform == "java":
# Jython always uses 32 bits.
MAXSIZE = int((1 << 31) - 1)
else:
# It's possible to have sizeof(long) != sizeof(Py_ssize_t).
class X(object):
def __len__(self):
return 1 << 31
try:
len(X())
except OverflowError:
# 32-bit
MAXSIZE = int((1 << 31) - 1)
else:
# 64-bit
MAXSIZE = int((1 << 63) - 1)
del X
def _add_doc(func, doc):
"""Add documentation to a function."""
func.__doc__ = doc
def _import_module(name):
"""Import module, returning the module after the last dot."""
__import__(name)
return sys.modules[name]
class _LazyDescr(object):
def __init__(self, name):
self.name = name
def __get__(self, obj, tp):
result = self._resolve()
setattr(obj, self.name, result)
# This is a bit ugly, but it avoids running this again.
delattr(tp, self.name)
return result
class MovedModule(_LazyDescr):
def __init__(self, name, old, new=None):
super(MovedModule, self).__init__(name)
if PY3:
if new is None:
new = name
self.mod = new
else:
self.mod = old
def _resolve(self):
return _import_module(self.mod)
class MovedAttribute(_LazyDescr):
def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None):
super(MovedAttribute, self).__init__(name)
if PY3:
if new_mod is None:
new_mod = name
self.mod = new_mod
if new_attr is None:
if old_attr is None:
new_attr = name
else:
new_attr = old_attr
self.attr = new_attr
else:
self.mod = old_mod
if old_attr is None:
old_attr = name
self.attr = old_attr
def _resolve(self):
module = _import_module(self.mod)
return getattr(module, self.attr)
class _MovedItems(types.ModuleType):
"""Lazy loading of moved objects"""
_moved_attributes = [
MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"),
MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"),
MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"),
MovedAttribute("map", "itertools", "builtins", "imap", "map"),
MovedAttribute("reload_module", "__builtin__", "imp", "reload"),
MovedAttribute("reduce", "__builtin__", "functools"),
MovedAttribute("StringIO", "StringIO", "io"),
MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
MovedModule("builtins", "__builtin__"),
MovedModule("configparser", "ConfigParser"),
MovedModule("copyreg", "copy_reg"),
MovedModule("http_cookiejar", "cookielib", "http.cookiejar"),
MovedModule("http_cookies", "Cookie", "http.cookies"),
MovedModule("html_entities", "htmlentitydefs", "html.entities"),
MovedModule("html_parser", "HTMLParser", "html.parser"),
MovedModule("http_client", "httplib", "http.client"),
MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"),
MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"),
MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"),
MovedModule("cPickle", "cPickle", "pickle"),
MovedModule("queue", "Queue"),
MovedModule("reprlib", "repr"),
MovedModule("socketserver", "SocketServer"),
MovedModule("tkinter", "Tkinter"),
MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"),
MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"),
MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"),
MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"),
MovedModule("tkinter_tix", "Tix", "tkinter.tix"),
MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"),
MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"),
MovedModule("tkinter_colorchooser", "tkColorChooser",
"tkinter.colorchooser"),
MovedModule("tkinter_commondialog", "tkCommonDialog",
"tkinter.commondialog"),
MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"),
MovedModule("tkinter_font", "tkFont", "tkinter.font"),
MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"),
MovedModule("tkinter_tksimpledialog", "tkSimpleDialog",
"tkinter.simpledialog"),
MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"),
MovedModule("winreg", "_winreg"),
]
for attr in _moved_attributes:
setattr(_MovedItems, attr.name, attr)
del attr
moves = sys.modules["django.utils.six.moves"] = _MovedItems("moves")
def add_move(move):
"""Add an item to six.moves."""
setattr(_MovedItems, move.name, move)
def remove_move(name):
"""Remove item from six.moves."""
try:
delattr(_MovedItems, name)
except AttributeError:
try:
del moves.__dict__[name]
except KeyError:
raise AttributeError("no such move, %r" % (name,))
if PY3:
_meth_func = "__func__"
_meth_self = "__self__"
_func_code = "__code__"
_func_defaults = "__defaults__"
_iterkeys = "keys"
_itervalues = "values"
_iteritems = "items"
else:
_meth_func = "im_func"
_meth_self = "im_self"
_func_code = "func_code"
_func_defaults = "func_defaults"
_iterkeys = "iterkeys"
_itervalues = "itervalues"
_iteritems = "iteritems"
try:
advance_iterator = next
except NameError:
def advance_iterator(it):
return it.next()
next = advance_iterator
if PY3:
def get_unbound_function(unbound):
return unbound
Iterator = object
def callable(obj):
return any("__call__" in klass.__dict__ for klass in type(obj).__mro__)
else:
def get_unbound_function(unbound):
return unbound.im_func
class Iterator(object):
def next(self):
return type(self).__next__(self)
callable = callable
_add_doc(get_unbound_function,
"""Get the function out of a possibly unbound function""")
get_method_function = operator.attrgetter(_meth_func)
get_method_self = operator.attrgetter(_meth_self)
get_function_code = operator.attrgetter(_func_code)
get_function_defaults = operator.attrgetter(_func_defaults)
def iterkeys(d):
"""Return an iterator over the keys of a dictionary."""
return iter(getattr(d, _iterkeys)())
def itervalues(d):
"""Return an iterator over the values of a dictionary."""
return iter(getattr(d, _itervalues)())
def iteritems(d):
"""Return an iterator over the (key, value) pairs of a dictionary."""
return iter(getattr(d, _iteritems)())
if PY3:
def b(s):
return s.encode("latin-1")
def u(s):
return s
if sys.version_info[1] <= 1:
def int2byte(i):
return bytes((i,))
else:
# This is about 2x faster than the implementation above on 3.2+
int2byte = operator.methodcaller("to_bytes", 1, "big")
import io
StringIO = io.StringIO
BytesIO = io.BytesIO
else:
def b(s):
return s
def u(s):
return unicode(s, "unicode_escape")
int2byte = chr
import StringIO
StringIO = BytesIO = StringIO.StringIO
_add_doc(b, """Byte literal""")
_add_doc(u, """Text literal""")
if PY3:
import builtins
exec_ = getattr(builtins, "exec")
def reraise(tp, value, tb=None):
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
raise value
print_ = getattr(builtins, "print")
del builtins
else:
def exec_(code, globs=None, locs=None):
"""Execute code in a namespace."""
if globs is None:
frame = sys._getframe(1)
globs = frame.f_globals
if locs is None:
locs = frame.f_locals
del frame
elif locs is None:
locs = globs
exec("""exec code in globs, locs""")
exec_("""def reraise(tp, value, tb=None):
raise tp, value, tb
""")
def print_(*args, **kwargs):
"""The new-style print function."""
fp = kwargs.pop("file", sys.stdout)
if fp is None:
return
def write(data):
if not isinstance(data, basestring):
data = str(data)
fp.write(data)
want_unicode = False
sep = kwargs.pop("sep", None)
if sep is not None:
if isinstance(sep, unicode):
want_unicode = True
elif not isinstance(sep, str):
raise TypeError("sep must be None or a string")
end = kwargs.pop("end", None)
if end is not None:
if isinstance(end, unicode):
want_unicode = True
elif not isinstance(end, str):
raise TypeError("end must be None or a string")
if kwargs:
raise TypeError("invalid keyword arguments to print()")
if not want_unicode:
for arg in args:
if isinstance(arg, unicode):
want_unicode = True
break
if want_unicode:
newline = unicode("\n")
space = unicode(" ")
else:
newline = "\n"
space = " "
if sep is None:
sep = space
if end is None:
end = newline
for i, arg in enumerate(args):
if i:
write(sep)
write(arg)
write(end)
_add_doc(reraise, """Reraise an exception.""")
def with_metaclass(meta, base=object):
"""Create a base class with a metaclass."""
return meta("NewBase", (base,), {})
### Additional customizations for Django ###
if PY3:
_iterlists = "lists"
_assertRaisesRegex = "assertRaisesRegex"
else:
_iterlists = "iterlists"
_assertRaisesRegex = "assertRaisesRegexp"
def iterlists(d):
"""Return an iterator over the values of a MultiValueDict."""
return getattr(d, _iterlists)()
def assertRaisesRegex(self, *args, **kwargs):
return getattr(self, _assertRaisesRegex)(*args, **kwargs)
add_move(MovedModule("_dummy_thread", "dummy_thread"))
add_move(MovedModule("_thread", "thread"))

View File

@ -4,6 +4,7 @@ Descriptive HTTP status codes, for code readability.
See RFC 2616 - http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html See RFC 2616 - http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html
And RFC 6585 - http://tools.ietf.org/html/rfc6585 And RFC 6585 - http://tools.ietf.org/html/rfc6585
""" """
from __future__ import unicode_literals
HTTP_100_CONTINUE = 100 HTTP_100_CONTINUE = 100
HTTP_101_SWITCHING_PROTOCOLS = 101 HTTP_101_SWITCHING_PROTOCOLS = 101

View File

@ -13,7 +13,7 @@
<title>{% block title %}Django REST framework{% endblock %}</title> <title>{% block title %}Django REST framework{% endblock %}</title>
{% block style %} {% block style %}
<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/> {% block bootstrap_theme %}<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>{% endblock %}
<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/> <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/prettify.css" %}"/> <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/prettify.css" %}"/>
<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/> <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>

View File

@ -25,14 +25,14 @@
<form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post"> <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post">
{% csrf_token %} {% csrf_token %}
<div id="div_id_username" class="clearfix control-group"> <div id="div_id_username" class="clearfix control-group">
<div class="controls" style="height: 30px"> <div class="controls">
<Label class="span4" style="margin-top: 3px">Username:</label> <Label class="span4">Username:</label>
<input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username"> <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username">
</div> </div>
</div> </div>
<div id="div_id_password" class="clearfix control-group"> <div id="div_id_password" class="clearfix control-group">
<div class="controls" style="height: 30px"> <div class="controls">
<Label class="span4" style="margin-top: 3px">Password:</label> <Label class="span4">Password:</label>
<input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password"> <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password">
</div> </div>
</div> </div>

View File

@ -1,10 +1,12 @@
from __future__ import unicode_literals, absolute_import
from django import template from django import template
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse, NoReverseMatch
from django.http import QueryDict from django.http import QueryDict
from django.utils.encoding import force_unicode
from django.utils.html import escape from django.utils.html import escape
from django.utils.safestring import SafeData, mark_safe from django.utils.safestring import SafeData, mark_safe
from urlparse import urlsplit, urlunsplit from rest_framework.compat import urlparse
from rest_framework.compat import force_text
from rest_framework.compat import six
import re import re
import string import string
@ -29,7 +31,7 @@ try: # Django 1.5+
def do_static(parser, token): def do_static(parser, token):
return StaticFilesNode.handle_token(parser, token) return StaticFilesNode.handle_token(parser, token)
except: except ImportError:
try: # Django 1.4 try: # Django 1.4
from django.contrib.staticfiles.storage import staticfiles_storage from django.contrib.staticfiles.storage import staticfiles_storage
@ -41,7 +43,7 @@ except:
""" """
return staticfiles_storage.url(path) return staticfiles_storage.url(path)
except: # Django 1.3 except ImportError: # Django 1.3
from urlparse import urljoin from urlparse import urljoin
from django import template from django import template
from django.templatetags.static import PrefixNode from django.templatetags.static import PrefixNode
@ -99,11 +101,11 @@ def replace_query_param(url, key, val):
Given a URL and a key/val pair, set or replace an item in the query Given a URL and a key/val pair, set or replace an item in the query
parameters of the URL, and return the new URL. parameters of the URL, and return the new URL.
""" """
(scheme, netloc, path, query, fragment) = urlsplit(url) (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url)
query_dict = QueryDict(query).copy() query_dict = QueryDict(query).copy()
query_dict[key] = val query_dict[key] = val
query = query_dict.urlencode() query = query_dict.urlencode()
return urlunsplit((scheme, netloc, path, query, fragment)) return urlparse.urlunsplit((scheme, netloc, path, query, fragment))
# Regex for adding classes to html snippets # Regex for adding classes to html snippets
@ -135,7 +137,7 @@ def optional_login(request):
""" """
try: try:
login_url = reverse('rest_framework:login') login_url = reverse('rest_framework:login')
except: except NoReverseMatch:
return '' return ''
snippet = "<a href='%s?next=%s'>Log in</a>" % (login_url, request.path) snippet = "<a href='%s?next=%s'>Log in</a>" % (login_url, request.path)
@ -149,7 +151,7 @@ def optional_logout(request):
""" """
try: try:
logout_url = reverse('rest_framework:logout') logout_url = reverse('rest_framework:logout')
except: except NoReverseMatch:
return '' return ''
snippet = "<a href='%s?next=%s'>Log out</a>" % (logout_url, request.path) snippet = "<a href='%s?next=%s'>Log out</a>" % (logout_url, request.path)
@ -179,7 +181,7 @@ def add_class(value, css_class):
In the case of REST Framework, the filter is used to add Bootstrap-specific In the case of REST Framework, the filter is used to add Bootstrap-specific
classes to the forms. classes to the forms.
""" """
html = unicode(value) html = six.text_type(value)
match = class_re.search(html) match = class_re.search(html)
if match: if match:
m = re.search(r'^%s$|^%s\s|\s%s\s|\s%s$' % (css_class, css_class, m = re.search(r'^%s$|^%s\s|\s%s\s|\s%s$' % (css_class, css_class,
@ -213,7 +215,7 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
""" """
trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x
safe_input = isinstance(text, SafeData) safe_input = isinstance(text, SafeData)
words = word_split_re.split(force_unicode(text)) words = word_split_re.split(force_text(text))
nofollow_attr = nofollow and ' rel="nofollow"' or '' nofollow_attr = nofollow and ' rel="nofollow"' or ''
for i, word in enumerate(words): for i, word in enumerate(words):
match = None match = None
@ -249,4 +251,4 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
words[i] = mark_safe(word) words[i] = mark_safe(word)
elif autoescape: elif autoescape:
words[i] = escape(word) words[i] = escape(word)
return mark_safe(u''.join(words)) return mark_safe(''.join(words))

View File

@ -1,13 +1,13 @@
from __future__ import unicode_literals
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.http import HttpResponse from django.http import HttpResponse
from django.test import Client, TestCase from django.test import Client, TestCase
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import permissions from rest_framework import permissions
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
from rest_framework.authentication import TokenAuthentication from rest_framework.authentication import TokenAuthentication, BasicAuthentication, SessionAuthentication
from rest_framework.compat import patterns from rest_framework.compat import patterns
from rest_framework.views import APIView from rest_framework.views import APIView
import json import json
import base64 import base64
@ -21,10 +21,10 @@ class MockView(APIView):
def put(self, request): def put(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3}) return HttpResponse({'a': 1, 'b': 2, 'c': 3})
MockView.authentication_classes += (TokenAuthentication,)
urlpatterns = patterns('', urlpatterns = patterns('',
(r'^$', MockView.as_view()), (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
) )
@ -42,25 +42,30 @@ class BasicAuthTests(TestCase):
def test_post_form_passing_basic_auth(self): def test_post_form_passing_basic_auth(self):
"""Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() credentials = ('%s:%s' % (self.username, self.password))
response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
auth = 'Basic %s' % base64_credentials
response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_post_json_passing_basic_auth(self): def test_post_json_passing_basic_auth(self):
"""Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() credentials = ('%s:%s' % (self.username, self.password))
response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
auth = 'Basic %s' % base64_credentials
response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_post_form_failing_basic_auth(self): def test_post_form_failing_basic_auth(self):
"""Ensure POSTing form over basic auth without correct credentials fails""" """Ensure POSTing form over basic auth without correct credentials fails"""
response = self.csrf_client.post('/', {'example': 'example'}) response = self.csrf_client.post('/basic/', {'example': 'example'})
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 401)
def test_post_json_failing_basic_auth(self): def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails""" """Ensure POSTing json over basic auth without correct credentials fails"""
response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json')
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 401)
self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
class SessionAuthTests(TestCase): class SessionAuthTests(TestCase):
@ -83,7 +88,7 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication without CSRF token fails. Ensure POSTing form over session authentication without CSRF token fails.
""" """
self.csrf_client.login(username=self.username, password=self.password) self.csrf_client.login(username=self.username, password=self.password)
response = self.csrf_client.post('/', {'example': 'example'}) response = self.csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
def test_post_form_session_auth_passing(self): def test_post_form_session_auth_passing(self):
@ -91,7 +96,7 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication with logged in user and CSRF token passes. Ensure POSTing form over session authentication with logged in user and CSRF token passes.
""" """
self.non_csrf_client.login(username=self.username, password=self.password) self.non_csrf_client.login(username=self.username, password=self.password)
response = self.non_csrf_client.post('/', {'example': 'example'}) response = self.non_csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_put_form_session_auth_passing(self): def test_put_form_session_auth_passing(self):
@ -99,14 +104,14 @@ class SessionAuthTests(TestCase):
Ensure PUTting form over session authentication with logged in user and CSRF token passes. Ensure PUTting form over session authentication with logged in user and CSRF token passes.
""" """
self.non_csrf_client.login(username=self.username, password=self.password) self.non_csrf_client.login(username=self.username, password=self.password)
response = self.non_csrf_client.put('/', {'example': 'example'}) response = self.non_csrf_client.put('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_post_form_session_auth_failing(self): def test_post_form_session_auth_failing(self):
""" """
Ensure POSTing form over session authentication without logged in user fails. Ensure POSTing form over session authentication without logged in user fails.
""" """
response = self.csrf_client.post('/', {'example': 'example'}) response = self.csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
@ -127,24 +132,24 @@ class TokenAuthTests(TestCase):
def test_post_form_passing_token_auth(self): def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key auth = "Token " + self.key
response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_post_json_passing_token_auth(self): def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key auth = "Token " + self.key
response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_post_form_failing_token_auth(self): def test_post_form_failing_token_auth(self):
"""Ensure POSTing form over token auth without correct credentials fails""" """Ensure POSTing form over token auth without correct credentials fails"""
response = self.csrf_client.post('/', {'example': 'example'}) response = self.csrf_client.post('/token/', {'example': 'example'})
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 401)
def test_post_json_failing_token_auth(self): def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails""" """Ensure POSTing json over token auth without correct credentials fails"""
response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json')
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 401)
def test_token_has_auto_assigned_key_if_none_provided(self): def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key""" """Ensure creating a token with no key will auto-assign a key"""
@ -158,7 +163,7 @@ class TokenAuthTests(TestCase):
response = client.post('/auth-token/', response = client.post('/auth-token/',
json.dumps({'username': self.username, 'password': self.password}), 'application/json') json.dumps({'username': self.username, 'password': self.password}), 'application/json')
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(json.loads(response.content)['token'], self.key) self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key)
def test_token_login_json_bad_creds(self): def test_token_login_json_bad_creds(self):
"""Ensure token login view using JSON POST fails if bad credentials are used.""" """Ensure token login view using JSON POST fails if bad credentials are used."""
@ -180,4 +185,4 @@ class TokenAuthTests(TestCase):
response = client.post('/auth-token/', response = client.post('/auth-token/',
{'username': self.username, 'password': self.password}) {'username': self.username, 'password': self.password})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(json.loads(response.content)['token'], self.key) self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key)

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from rest_framework.compat import patterns, url from rest_framework.compat import patterns, url
from rest_framework.utils.breadcrumbs import get_breadcrumbs from rest_framework.utils.breadcrumbs import get_breadcrumbs

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
@ -28,13 +29,27 @@ class DecoratorTestCase(TestCase):
response.request = request response.request = request
return APIView.finalize_response(self, request, response, *args, **kwargs) return APIView.finalize_response(self, request, response, *args, **kwargs)
def test_wrap_view(self): def test_api_view_incorrect(self):
"""
If @api_view is not applied correct, we should raise an assertion.
"""
@api_view(['GET']) @api_view
def view(request): def view(request):
return Response({}) return Response()
self.assertTrue(isinstance(view.cls_instance, APIView)) request = self.factory.get('/')
self.assertRaises(AssertionError, view, request)
def test_api_view_incorrect_arguments(self):
"""
If @api_view is missing arguments, we should raise an assertion.
"""
with self.assertRaises(AssertionError):
@api_view('GET')
def view(request):
return Response()
def test_calling_method(self): def test_calling_method(self):

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.compat import apply_markdown from rest_framework.compat import apply_markdown

View File

@ -1,7 +1,7 @@
""" """
General serializer field tests. General serializer field tests.
""" """
from __future__ import unicode_literals
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers

View File

@ -1,9 +1,9 @@
import StringIO from __future__ import unicode_literals
import datetime
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.compat import BytesIO
from rest_framework.compat import six
import datetime
class UploadedFile(object): class UploadedFile(object):
@ -27,9 +27,9 @@ class UploadedFileSerializer(serializers.Serializer):
class FileSerializerTests(TestCase): class FileSerializerTests(TestCase):
def test_create(self): def test_create(self):
now = datetime.datetime.now() now = datetime.datetime.now()
file = StringIO.StringIO('stuff') file = BytesIO(six.b('stuff'))
file.name = 'stuff.txt' file.name = 'stuff.txt'
file.size = file.len file.size = len(file.getvalue())
serializer = UploadedFileSerializer(data={'created': now}, files={'file': file}) serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
uploaded_file = UploadedFile(file=file, created=now) uploaded_file = UploadedFile(file=file, created=now)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
import datetime import datetime
from decimal import Decimal from decimal import Decimal
from django.test import TestCase from django.test import TestCase

View File

@ -1,25 +1,62 @@
from __future__ import unicode_literals
from django.contrib.contenttypes.models import ContentType
from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey
from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import *
class Tag(models.Model):
"""
Tags have a descriptive slug, and are attached to an arbitrary object.
"""
tag = models.SlugField()
content_type = models.ForeignKey(ContentType)
object_id = models.PositiveIntegerField()
tagged_item = GenericForeignKey('content_type', 'object_id')
def __unicode__(self):
return self.tag
class Bookmark(models.Model):
"""
A URL bookmark that may have multiple tags attached.
"""
url = models.URLField()
tags = GenericRelation(Tag)
def __unicode__(self):
return 'Bookmark: %s' % self.url
class Note(models.Model):
"""
A textual note that may have multiple tags attached.
"""
text = models.TextField()
tags = GenericRelation(Tag)
def __unicode__(self):
return 'Note: %s' % self.text
class TestGenericRelations(TestCase): class TestGenericRelations(TestCase):
def setUp(self): def setUp(self):
bookmark = Bookmark(url='https://www.djangoproject.com/') self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
bookmark.save() Tag.objects.create(tagged_item=self.bookmark, tag='django')
django = Tag(tag_name='django') Tag.objects.create(tagged_item=self.bookmark, tag='python')
django.save() self.note = Note.objects.create(text='Remember the milk')
python = Tag(tag_name='python') Tag.objects.create(tagged_item=self.note, tag='reminder')
python.save()
t1 = TaggedItem(content_object=bookmark, tag=django) def test_generic_relation(self):
t1.save() """
t2 = TaggedItem(content_object=bookmark, tag=python) Test a relationship that spans a GenericRelation field.
t2.save() IE. A reverse generic relationship.
self.bookmark = bookmark """
def test_reverse_generic_relation(self):
class BookmarkSerializer(serializers.ModelSerializer): class BookmarkSerializer(serializers.ModelSerializer):
tags = serializers.ManyRelatedField(source='tags') tags = serializers.RelatedField(many=True)
class Meta: class Meta:
model = Bookmark model = Bookmark
@ -27,7 +64,37 @@ class TestGenericRelations(TestCase):
serializer = BookmarkSerializer(self.bookmark) serializer = BookmarkSerializer(self.bookmark)
expected = { expected = {
'tags': [u'django', u'python'], 'tags': ['django', 'python'],
'url': u'https://www.djangoproject.com/' 'url': 'https://www.djangoproject.com/'
} }
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_generic_fk(self):
"""
Test a relationship that spans a GenericForeignKey field.
IE. A forward generic relationship.
"""
class TagSerializer(serializers.ModelSerializer):
tagged_item = serializers.RelatedField()
class Meta:
model = Tag
exclude = ('id', 'content_type', 'object_id')
serializer = TagSerializer(Tag.objects.all())
expected = [
{
'tag': 'django',
'tagged_item': 'Bookmark: https://www.djangoproject.com/'
},
{
'tag': 'python',
'tagged_item': 'Bookmark: https://www.djangoproject.com/'
},
{
'tag': 'reminder',
'tagged_item': 'Note: Remember the milk'
}
]
self.assertEquals(serializer.data, expected)

View File

@ -1,10 +1,11 @@
import json from __future__ import unicode_literals
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import generics, serializers, status from rest_framework import generics, serializers, status
from rest_framework.tests.utils import RequestFactory from rest_framework.tests.utils import RequestFactory
from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
from rest_framework.compat import six
import json
factory = RequestFactory() factory = RequestFactory()
@ -72,7 +73,7 @@ class TestRootView(TestCase):
content_type='application/json') content_type='application/json')
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_201_CREATED) self.assertEquals(response.status_code, status.HTTP_201_CREATED)
self.assertEquals(response.data, {'id': 4, 'text': u'foobar'}) self.assertEquals(response.data, {'id': 4, 'text': 'foobar'})
created = self.objects.get(id=4) created = self.objects.get(id=4)
self.assertEquals(created.text, 'foobar') self.assertEquals(created.text, 'foobar')
@ -127,7 +128,7 @@ class TestRootView(TestCase):
content_type='application/json') content_type='application/json')
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_201_CREATED) self.assertEquals(response.status_code, status.HTTP_201_CREATED)
self.assertEquals(response.data, {'id': 4, 'text': u'foobar'}) self.assertEquals(response.data, {'id': 4, 'text': 'foobar'})
created = self.objects.get(id=4) created = self.objects.get(id=4)
self.assertEquals(created.text, 'foobar') self.assertEquals(created.text, 'foobar')
@ -202,7 +203,7 @@ class TestInstanceView(TestCase):
request = factory.delete('/1') request = factory.delete('/1')
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_204_NO_CONTENT) self.assertEquals(response.status_code, status.HTTP_204_NO_CONTENT)
self.assertEquals(response.content, '') self.assertEquals(response.content, six.b(''))
ids = [obj.id for obj in self.objects.all()] ids = [obj.id for obj in self.objects.all()]
self.assertEquals(ids, [2, 3]) self.assertEquals(ids, [2, 3])
@ -329,7 +330,7 @@ class ClassA(models.Model):
class ClassASerializer(serializers.ModelSerializer): class ClassASerializer(serializers.ModelSerializer):
childs = serializers.ManyPrimaryKeyRelatedField(source='childs') childs = serializers.PrimaryKeyRelatedField(many=True, source='childs')
class Meta: class Meta:
model = ClassA model = ClassA

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.http import Http404 from django.http import Http404
from django.test import TestCase from django.test import TestCase
@ -7,6 +8,7 @@ from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view, renderer_classes from rest_framework.decorators import api_view, renderer_classes
from rest_framework.renderers import TemplateHTMLRenderer from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.compat import six
@api_view(('GET',)) @api_view(('GET',))
@ -68,13 +70,13 @@ class TemplateHTMLRendererTests(TestCase):
def test_not_found_html_view(self): def test_not_found_html_view(self):
response = self.client.get('/not_found') response = self.client.get('/not_found')
self.assertEquals(response.status_code, 404) self.assertEquals(response.status_code, 404)
self.assertEquals(response.content, "404 Not Found") self.assertEquals(response.content, six.b("404 Not Found"))
self.assertEquals(response['Content-Type'], 'text/html') self.assertEquals(response['Content-Type'], 'text/html')
def test_permission_denied_html_view(self): def test_permission_denied_html_view(self):
response = self.client.get('/permission_denied') response = self.client.get('/permission_denied')
self.assertEquals(response.status_code, 403) self.assertEquals(response.status_code, 403)
self.assertEquals(response.content, "403 Forbidden") self.assertEquals(response.content, six.b("403 Forbidden"))
self.assertEquals(response['Content-Type'], 'text/html') self.assertEquals(response['Content-Type'], 'text/html')
@ -105,11 +107,11 @@ class TemplateHTMLRendererExceptionTests(TestCase):
def test_not_found_html_view_with_template(self): def test_not_found_html_view_with_template(self):
response = self.client.get('/not_found') response = self.client.get('/not_found')
self.assertEquals(response.status_code, 404) self.assertEquals(response.status_code, 404)
self.assertEquals(response.content, "404: Not found") self.assertEquals(response.content, six.b("404: Not found"))
self.assertEquals(response['Content-Type'], 'text/html') self.assertEquals(response['Content-Type'], 'text/html')
def test_permission_denied_html_view_with_template(self): def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied') response = self.client.get('/permission_denied')
self.assertEquals(response.status_code, 403) self.assertEquals(response.status_code, 403)
self.assertEquals(response.content, "403: Permission denied") self.assertEquals(response.content, six.b("403: Permission denied"))
self.assertEquals(response['Content-Type'], 'text/html') self.assertEquals(response['Content-Type'], 'text/html')

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
import json import json
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory

View File

@ -1,35 +1,6 @@
from __future__ import unicode_literals
from django.db import models from django.db import models
from django.contrib.contenttypes.models import ContentType
from django.contrib.contenttypes.generic import GenericForeignKey, GenericRelation
# from django.contrib.auth.models import Group
# class CustomUser(models.Model):
# """
# A custom user model, which uses a 'through' table for the foreign key
# """
# username = models.CharField(max_length=255, unique=True)
# groups = models.ManyToManyField(
# to=Group, blank=True, null=True, through='UserGroupMap'
# )
# @models.permalink
# def get_absolute_url(self):
# return ('custom_user', (), {
# 'pk': self.id
# })
# class UserGroupMap(models.Model):
# user = models.ForeignKey(to=CustomUser)
# group = models.ForeignKey(to=Group)
# @models.permalink
# def get_absolute_url(self):
# return ('user_group_map', (), {
# 'pk': self.id
# })
def foobar(): def foobar():
return 'foobar' return 'foobar'
@ -86,27 +57,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
text = models.CharField(max_length=100, default='anchor') text = models.CharField(max_length=100, default='anchor')
rel = models.ManyToManyField(Anchor) rel = models.ManyToManyField(Anchor)
# Models to test generic relations
class Tag(RESTFrameworkModel):
tag_name = models.SlugField()
class TaggedItem(RESTFrameworkModel):
tag = models.ForeignKey(Tag, related_name='items')
content_type = models.ForeignKey(ContentType)
object_id = models.PositiveIntegerField()
content_object = GenericForeignKey('content_type', 'object_id')
def __unicode__(self):
return self.tag.tag_name
class Bookmark(RESTFrameworkModel):
url = models.URLField()
tags = GenericRelation(TaggedItem)
# Model to test filtering. # Model to test filtering.
class FilterableItem(RESTFrameworkModel): class FilterableItem(RESTFrameworkModel):

View File

@ -1,90 +0,0 @@
# from rest_framework.compat import patterns, url
# from django.forms import ModelForm
# from django.contrib.auth.models import Group, User
# from rest_framework.resources import ModelResource
# from rest_framework.views import ListOrCreateModelView, InstanceModelView
# from rest_framework.tests.models import CustomUser
# from rest_framework.tests.testcases import TestModelsTestCase
# class GroupResource(ModelResource):
# model = Group
# class UserForm(ModelForm):
# class Meta:
# model = User
# exclude = ('last_login', 'date_joined')
# class UserResource(ModelResource):
# model = User
# form = UserForm
# class CustomUserResource(ModelResource):
# model = CustomUser
# urlpatterns = patterns('',
# url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'),
# url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)),
# url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'),
# url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)),
# url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'),
# url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)),
# )
# class ModelViewTests(TestModelsTestCase):
# """Test the model views rest_framework provides"""
# urls = 'rest_framework.tests.modelviews'
# def test_creation(self):
# """Ensure that a model object can be created"""
# self.assertEqual(0, Group.objects.count())
# response = self.client.post('/groups/', {'name': 'foo'})
# self.assertEqual(response.status_code, 201)
# self.assertEqual(1, Group.objects.count())
# self.assertEqual('foo', Group.objects.all()[0].name)
# def test_creation_with_m2m_relation(self):
# """Ensure that a model object with a m2m relation can be created"""
# group = Group(name='foo')
# group.save()
# self.assertEqual(0, User.objects.count())
# response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]})
# self.assertEqual(response.status_code, 201)
# self.assertEqual(1, User.objects.count())
# user = User.objects.all()[0]
# self.assertEqual('bar', user.username)
# self.assertEqual('baz', user.password)
# self.assertEqual(1, user.groups.count())
# group = user.groups.all()[0]
# self.assertEqual('foo', group.name)
# def test_creation_with_m2m_relation_through(self):
# """
# Ensure that a model object with a m2m relation can be created where that
# relation uses a through table
# """
# group = Group(name='foo')
# group.save()
# self.assertEqual(0, User.objects.count())
# response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]})
# self.assertEqual(response.status_code, 201)
# self.assertEqual(1, CustomUser.objects.count())
# user = CustomUser.objects.all()[0]
# self.assertEqual('bar', user.username)
# self.assertEqual(1, user.groups.count())
# group = user.groups.all()[0]
# self.assertEqual('foo', group.name)

View File

@ -1,6 +1,9 @@
from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from rest_framework.negotiation import DefaultContentNegotiation from rest_framework.negotiation import DefaultContentNegotiation
from rest_framework.request import Request
factory = RequestFactory() factory = RequestFactory()
@ -22,16 +25,16 @@ class TestAcceptedMediaType(TestCase):
return self.negotiator.select_renderer(request, self.renderers) return self.negotiator.select_renderer(request, self.renderers)
def test_client_without_accept_use_renderer(self): def test_client_without_accept_use_renderer(self):
request = factory.get('/') request = Request(factory.get('/'))
accepted_renderer, accepted_media_type = self.select_renderer(request) accepted_renderer, accepted_media_type = self.select_renderer(request)
self.assertEquals(accepted_media_type, 'application/json') self.assertEquals(accepted_media_type, 'application/json')
def test_client_underspecifies_accept_use_renderer(self): def test_client_underspecifies_accept_use_renderer(self):
request = factory.get('/', HTTP_ACCEPT='*/*') request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
accepted_renderer, accepted_media_type = self.select_renderer(request) accepted_renderer, accepted_media_type = self.select_renderer(request)
self.assertEquals(accepted_media_type, 'application/json') self.assertEquals(accepted_media_type, 'application/json')
def test_client_overspecifies_accept_use_client(self): def test_client_overspecifies_accept_use_client(self):
request = factory.get('/', HTTP_ACCEPT='application/json; indent=8') request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
accepted_renderer, accepted_media_type = self.select_renderer(request) accepted_renderer, accepted_media_type = self.select_renderer(request)
self.assertEquals(accepted_media_type, 'application/json; indent=8') self.assertEquals(accepted_media_type, 'application/json; indent=8')

View File

@ -1,124 +0,0 @@
from django.db import models
from django.test import TestCase
from rest_framework import serializers
class OneToOneTarget(models.Model):
name = models.CharField(max_length=100)
class OneToOneTargetSource(models.Model):
name = models.CharField(max_length=100)
target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
related_name='target_source')
class OneToOneSource(models.Model):
name = models.CharField(max_length=100)
target_source = models.OneToOneField(OneToOneTargetSource, related_name='source')
class OneToOneSourceSerializer(serializers.ModelSerializer):
class Meta:
model = OneToOneSource
exclude = ('target_source', )
class OneToOneTargetSourceSerializer(serializers.ModelSerializer):
source = OneToOneSourceSerializer()
class Meta:
model = OneToOneTargetSource
exclude = ('target', )
class OneToOneTargetSerializer(serializers.ModelSerializer):
target_source = OneToOneTargetSourceSerializer()
class Meta:
model = OneToOneTarget
class NestedOneToOneTests(TestCase):
def setUp(self):
for idx in range(1, 4):
target = OneToOneTarget(name='target-%d' % idx)
target.save()
target_source = OneToOneTargetSource(name='target-source-%d' % idx, target=target)
target_source.save()
source = OneToOneSource(name='source-%d' % idx, target_source=target_source)
source.save()
def test_one_to_one_retrieve(self):
queryset = OneToOneTarget.objects.all()
serializer = OneToOneTargetSerializer(queryset)
expected = [
{'id': 1, 'name': u'target-1', 'target_source': {'id': 1, 'name': u'target-source-1', 'source': {'id': 1, 'name': u'source-1'}}},
{'id': 2, 'name': u'target-2', 'target_source': {'id': 2, 'name': u'target-source-2', 'source': {'id': 2, 'name': u'source-2'}}},
{'id': 3, 'name': u'target-3', 'target_source': {'id': 3, 'name': u'target-source-3', 'source': {'id': 3, 'name': u'source-3'}}}
]
self.assertEquals(serializer.data, expected)
def test_one_to_one_create(self):
data = {'id': 4, 'name': u'target-4', 'target_source': {'id': 4, 'name': u'target-source-4', 'source': {'id': 4, 'name': u'source-4'}}}
serializer = OneToOneTargetSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, u'target-4')
# Ensure (target 4, target_source 4, source 4) are added, and
# everything else is as expected.
queryset = OneToOneTarget.objects.all()
serializer = OneToOneTargetSerializer(queryset)
expected = [
{'id': 1, 'name': u'target-1', 'target_source': {'id': 1, 'name': u'target-source-1', 'source': {'id': 1, 'name': u'source-1'}}},
{'id': 2, 'name': u'target-2', 'target_source': {'id': 2, 'name': u'target-source-2', 'source': {'id': 2, 'name': u'source-2'}}},
{'id': 3, 'name': u'target-3', 'target_source': {'id': 3, 'name': u'target-source-3', 'source': {'id': 3, 'name': u'source-3'}}},
{'id': 4, 'name': u'target-4', 'target_source': {'id': 4, 'name': u'target-source-4', 'source': {'id': 4, 'name': u'source-4'}}}
]
self.assertEquals(serializer.data, expected)
def test_one_to_one_create_with_invalid_data(self):
data = {'id': 4, 'name': u'target-4', 'target_source': {'id': 4, 'name': u'target-source-4', 'source': {'id': 4}}}
serializer = OneToOneTargetSerializer(data=data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target_source': [{'source': [{'name': [u'This field is required.']}]}]})
def test_one_to_one_update(self):
data = {'id': 3, 'name': u'target-3-updated', 'target_source': {'id': 3, 'name': u'target-source-3-updated', 'source': {'id': 3, 'name': u'source-3-updated'}}}
instance = OneToOneTarget.objects.get(pk=3)
serializer = OneToOneTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, u'target-3-updated')
# Ensure (target 3, target_source 3, source 3) are updated,
# and everything else is as expected.
queryset = OneToOneTarget.objects.all()
serializer = OneToOneTargetSerializer(queryset)
expected = [
{'id': 1, 'name': u'target-1', 'target_source': {'id': 1, 'name': u'target-source-1', 'source': {'id': 1, 'name': u'source-1'}}},
{'id': 2, 'name': u'target-2', 'target_source': {'id': 2, 'name': u'target-source-2', 'source': {'id': 2, 'name': u'source-2'}}},
{'id': 3, 'name': u'target-3-updated', 'target_source': {'id': 3, 'name': u'target-source-3-updated', 'source': {'id': 3, 'name': u'source-3-updated'}}}
]
self.assertEquals(serializer.data, expected)
def test_one_to_one_delete(self):
data = {'id': 3, 'name': u'target-3', 'target_source': None}
instance = OneToOneTarget.objects.get(pk=3)
serializer = OneToOneTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
# Ensure (target_source 3, source 3) are deleted,
# and everything else is as expected.
queryset = OneToOneTarget.objects.all()
serializer = OneToOneTargetSerializer(queryset)
expected = [
{'id': 1, 'name': u'target-1', 'target_source': {'id': 1, 'name': u'target-source-1', 'source': {'id': 1, 'name': u'source-1'}}},
{'id': 2, 'name': u'target-2', 'target_source': {'id': 2, 'name': u'target-source-2', 'source': {'id': 2, 'name': u'source-2'}}},
{'id': 3, 'name': u'target-3', 'target_source': None}
]
self.assertEquals(serializer.data, expected)

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
import datetime import datetime
from decimal import Decimal from decimal import Decimal
from django.core.paginator import Paginator from django.core.paginator import Paginator
@ -252,6 +253,8 @@ class TestCustomPaginateByParam(TestCase):
self.assertEquals(response.data['results'], self.data[:5]) self.assertEquals(response.data['results'], self.data[:5])
### Tests for context in pagination serializers
class CustomField(serializers.Field): class CustomField(serializers.Field):
def to_native(self, value): def to_native(self, value):
if not 'view' in self.context: if not 'view' in self.context:
@ -262,6 +265,11 @@ class CustomField(serializers.Field):
class BasicModelSerializer(serializers.Serializer): class BasicModelSerializer(serializers.Serializer):
text = CustomField() text = CustomField()
def __init__(self, *args, **kwargs):
super(BasicModelSerializer, self).__init__(*args, **kwargs)
if not 'view' in self.context:
raise RuntimeError("context isn't getting passed into serializer init")
class TestContextPassedToCustomField(TestCase): class TestContextPassedToCustomField(TestCase):
def setUp(self): def setUp(self):
@ -279,3 +287,39 @@ class TestContextPassedToCustomField(TestCase):
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
### Tests for custom pagination serializers
class LinksSerializer(serializers.Serializer):
next = pagination.NextPageField(source='*')
prev = pagination.PreviousPageField(source='*')
class CustomPaginationSerializer(pagination.BasePaginationSerializer):
links = LinksSerializer(source='*') # Takes the page object as the source
total_results = serializers.Field(source='paginator.count')
results_field = 'objects'
class TestCustomPaginationSerializer(TestCase):
def setUp(self):
objects = ['john', 'paul', 'george', 'ringo']
paginator = Paginator(objects, 2)
self.page = paginator.page(1)
def test_custom_pagination_serializer(self):
request = RequestFactory().get('/foobar')
serializer = CustomPaginationSerializer(
instance=self.page,
context={'request': request}
)
expected = {
'links': {
'next': 'http://testserver/foobar?page=2',
'prev': None
},
'total_results': 4,
'objects': ['john', 'paul']
}
self.assertEquals(serializer.data, expected)

View File

@ -1,137 +1,5 @@
# """ from __future__ import unicode_literals
# .. from rest_framework.compat import StringIO
# >>> from rest_framework.parsers import FormParser
# >>> from django.test.client import RequestFactory
# >>> from rest_framework.views import View
# >>> from StringIO import StringIO
# >>> from urllib import urlencode
# >>> req = RequestFactory().get('/')
# >>> some_view = View()
# >>> some_view.request = req # Make as if this request had been dispatched
#
# FormParser
# ============
#
# Data flatening
# ----------------
#
# Here is some example data, which would eventually be sent along with a post request :
#
# >>> inpt = urlencode([
# ... ('key1', 'bla1'),
# ... ('key2', 'blo1'), ('key2', 'blo2'),
# ... ])
#
# Default behaviour for :class:`parsers.FormParser`, is to return a single value for each parameter :
#
# >>> (data, files) = FormParser(some_view).parse(StringIO(inpt))
# >>> data == {'key1': 'bla1', 'key2': 'blo1'}
# True
#
# However, you can customize this behaviour by subclassing :class:`parsers.FormParser`, and overriding :meth:`parsers.FormParser.is_a_list` :
#
# >>> class MyFormParser(FormParser):
# ...
# ... def is_a_list(self, key, val_list):
# ... return len(val_list) > 1
#
# This new parser only flattens the lists of parameters that contain a single value.
#
# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
# >>> data == {'key1': 'bla1', 'key2': ['blo1', 'blo2']}
# True
#
# .. note:: The same functionality is available for :class:`parsers.MultiPartParser`.
#
# Submitting an empty list
# --------------------------
#
# When submitting an empty select multiple, like this one ::
#
# <select multiple="multiple" name="key2"></select>
#
# The browsers usually strip the parameter completely. A hack to avoid this, and therefore being able to submit an empty select multiple, is to submit a value that tells the server that the list is empty ::
#
# <select multiple="multiple" name="key2"><option value="_empty"></select>
#
# :class:`parsers.FormParser` provides the server-side implementation for this hack. Considering the following posted data :
#
# >>> inpt = urlencode([
# ... ('key1', 'blo1'), ('key1', '_empty'),
# ... ('key2', '_empty'),
# ... ])
#
# :class:`parsers.FormParser` strips the values ``_empty`` from all the lists.
#
# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
# >>> data == {'key1': 'blo1'}
# True
#
# Oh ... but wait a second, the parameter ``key2`` isn't even supposed to be a list, so the parser just stripped it.
#
# >>> class MyFormParser(FormParser):
# ...
# ... def is_a_list(self, key, val_list):
# ... return key == 'key2'
# ...
# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
# >>> data == {'key1': 'blo1', 'key2': []}
# True
#
# Better like that. Note that you can configure something else than ``_empty`` for the empty value by setting :attr:`parsers.FormParser.EMPTY_VALUE`.
# """
# import httplib, mimetypes
# from tempfile import TemporaryFile
# from django.test import TestCase
# from django.test.client import RequestFactory
# from rest_framework.parsers import MultiPartParser
# from rest_framework.views import View
# from StringIO import StringIO
#
# def encode_multipart_formdata(fields, files):
# """For testing multipart parser.
# fields is a sequence of (name, value) elements for regular form fields.
# files is a sequence of (name, filename, value) elements for data to be uploaded as files
# Return (content_type, body)."""
# BOUNDARY = '----------ThIs_Is_tHe_bouNdaRY_$'
# CRLF = '\r\n'
# L = []
# for (key, value) in fields:
# L.append('--' + BOUNDARY)
# L.append('Content-Disposition: form-data; name="%s"' % key)
# L.append('')
# L.append(value)
# for (key, filename, value) in files:
# L.append('--' + BOUNDARY)
# L.append('Content-Disposition: form-data; name="%s"; filename="%s"' % (key, filename))
# L.append('Content-Type: %s' % get_content_type(filename))
# L.append('')
# L.append(value)
# L.append('--' + BOUNDARY + '--')
# L.append('')
# body = CRLF.join(L)
# content_type = 'multipart/form-data; boundary=%s' % BOUNDARY
# return content_type, body
#
# def get_content_type(filename):
# return mimetypes.guess_type(filename)[0] or 'application/octet-stream'
#
#class TestMultiPartParser(TestCase):
# def setUp(self):
# self.req = RequestFactory()
# self.content_type, self.body = encode_multipart_formdata([('key1', 'val1'), ('key1', 'val2')],
# [('file1', 'pic.jpg', 'blablabla'), ('file1', 't.txt', 'blobloblo')])
#
# def test_multipartparser(self):
# """Ensure that MultiPartParser can parse multipart/form-data that contains a mix of several files and parameters."""
# post_req = RequestFactory().post('/', self.body, content_type=self.content_type)
# view = View()
# view.request = post_req
# (data, files) = MultiPartParser(view).parse(StringIO(self.body))
# self.assertEqual(data['key1'], 'val1')
# self.assertEqual(files['file1'].read(), 'blablabla')
from StringIO import StringIO
from django import forms from django import forms
from django.test import TestCase from django.test import TestCase
from rest_framework.parsers import FormParser from rest_framework.parsers import FormParser

View File

@ -1,7 +1,7 @@
""" """
General tests for relational fields. General tests for relational fields.
""" """
from __future__ import unicode_literals
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
@ -31,3 +31,17 @@ class FieldTests(TestCase):
field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
self.assertRaises(serializers.ValidationError, field.from_native, '') self.assertRaises(serializers.ValidationError, field.from_native, '')
self.assertRaises(serializers.ValidationError, field.from_native, []) self.assertRaises(serializers.ValidationError, field.from_native, [])
class TestManyRelateMixin(TestCase):
def test_missing_many_to_many_related_field(self):
'''
Regression test for #632
https://github.com/tomchristie/django-rest-framework/pull/632
'''
field = serializers.RelatedField(many=True, read_only=False)
into = {}
field.field_from_native({}, None, 'field_name', into)
self.assertEqual(into['field_name'], [])

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.compat import patterns, url from rest_framework.compat import patterns, url
@ -19,7 +20,7 @@ urlpatterns = patterns('',
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail') sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail')
class Meta: class Meta:
model = ManyToManyTarget model = ManyToManyTarget
@ -31,7 +32,7 @@ class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail') sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail')
class Meta: class Meta:
model = ForeignKeyTarget model = ForeignKeyTarget
@ -74,9 +75,9 @@ class HyperlinkedManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset) serializer = ManyToManySourceSerializer(queryset)
expected = [ expected = [
{'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']}, {'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']},
{'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, {'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
{'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} {'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -84,14 +85,14 @@ class HyperlinkedManyToManyTests(TestCase):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset) serializer = ManyToManyTargetSerializer(queryset)
expected = [ expected = [
{'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']}, {'url': '/manytomanytarget/1/', 'name': 'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, {'url': '/manytomanytarget/2/', 'name': 'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']} {'url': '/manytomanytarget/3/', 'name': 'target-3', 'sources': ['/manytomanysource/3/']}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_many_to_many_update(self): def test_many_to_many_update(self):
data = {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} data = {'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
instance = ManyToManySource.objects.get(pk=1) instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data) serializer = ManyToManySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -102,14 +103,14 @@ class HyperlinkedManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset) serializer = ManyToManySourceSerializer(queryset)
expected = [ expected = [
{'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}, {'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']},
{'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, {'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
{'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} {'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_reverse_many_to_many_update(self): def test_reverse_many_to_many_update(self):
data = {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']} data = {'url': '/manytomanytarget/1/', 'name': 'target-1', 'sources': ['/manytomanysource/1/']}
instance = ManyToManyTarget.objects.get(pk=1) instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data) serializer = ManyToManyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -120,48 +121,48 @@ class HyperlinkedManyToManyTests(TestCase):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset) serializer = ManyToManyTargetSerializer(queryset)
expected = [ expected = [
{'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']}, {'url': '/manytomanytarget/1/', 'name': 'target-1', 'sources': ['/manytomanysource/1/']},
{'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, {'url': '/manytomanytarget/2/', 'name': 'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']} {'url': '/manytomanytarget/3/', 'name': 'target-3', 'sources': ['/manytomanysource/3/']}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_many_to_many_create(self): def test_many_to_many_create(self):
data = {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']} data = {'url': '/manytomanysource/4/', 'name': 'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']}
serializer = ManyToManySourceSerializer(data=data) serializer = ManyToManySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, u'source-4') self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected # Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset) serializer = ManyToManySourceSerializer(queryset)
expected = [ expected = [
{'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']}, {'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']},
{'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, {'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
{'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}, {'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']},
{'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']} {'url': '/manytomanysource/4/', 'name': 'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_reverse_many_to_many_create(self): def test_reverse_many_to_many_create(self):
data = {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']} data = {'url': '/manytomanytarget/4/', 'name': 'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']}
serializer = ManyToManyTargetSerializer(data=data) serializer = ManyToManyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, u'target-4') self.assertEqual(obj.name, 'target-4')
# Ensure target 4 is added, and everything else is as expected # Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset) serializer = ManyToManyTargetSerializer(queryset)
expected = [ expected = [
{'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']}, {'url': '/manytomanytarget/1/', 'name': 'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, {'url': '/manytomanytarget/2/', 'name': 'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}, {'url': '/manytomanytarget/3/', 'name': 'target-3', 'sources': ['/manytomanysource/3/']},
{'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']} {'url': '/manytomanytarget/4/', 'name': 'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -182,9 +183,9 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/foreignkeysource/1/', 'name': 'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, {'url': '/foreignkeysource/2/', 'name': 'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'} {'url': '/foreignkeysource/3/', 'name': 'source-3', 'target': '/foreignkeytarget/1/'}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -192,13 +193,13 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']}, {'url': '/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
{'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, {'url': '/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update(self): def test_foreign_key_update(self):
data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'} data = {'url': '/foreignkeysource/1/', 'name': 'source-1', 'target': '/foreignkeytarget/2/'}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -209,14 +210,21 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'}, {'url': '/foreignkeysource/1/', 'name': 'source-1', 'target': '/foreignkeytarget/2/'},
{'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, {'url': '/foreignkeysource/2/', 'name': 'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'} {'url': '/foreignkeysource/3/', 'name': 'source-3', 'target': '/foreignkeytarget/1/'}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_incorrect_type(self):
data = {'url': '/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': ['Incorrect type. Expected url string, received int.']})
def test_reverse_foreign_key_update(self): def test_reverse_foreign_key_update(self):
data = {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']} data = {'url': '/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}
instance = ForeignKeyTarget.objects.get(pk=2) instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data) serializer = ForeignKeyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -225,8 +233,8 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset) new_serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']}, {'url': '/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
{'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, {'url': '/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
] ]
self.assertEquals(new_serializer.data, expected) self.assertEquals(new_serializer.data, expected)
@ -237,54 +245,54 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']}, {'url': '/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['/foreignkeysource/2/']},
{'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}, {'url': '/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_create(self): def test_foreign_key_create(self):
data = {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'} data = {'url': '/foreignkeysource/4/', 'name': 'source-4', 'target': '/foreignkeytarget/2/'}
serializer = ForeignKeySourceSerializer(data=data) serializer = ForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, u'source-4') self.assertEqual(obj.name, 'source-4')
# Ensure source 1 is updated, and everything else is as expected # Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/foreignkeysource/1/', 'name': 'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, {'url': '/foreignkeysource/2/', 'name': 'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}, {'url': '/foreignkeysource/3/', 'name': 'source-3', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'}, {'url': '/foreignkeysource/4/', 'name': 'source-4', 'target': '/foreignkeytarget/2/'},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_reverse_foreign_key_create(self): def test_reverse_foreign_key_create(self):
data = {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']} data = {'url': '/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}
serializer = ForeignKeyTargetSerializer(data=data) serializer = ForeignKeyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, u'target-3') self.assertEqual(obj.name, 'target-3')
# Ensure target 4 is added, and everything else is as expected # Ensure target 4 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']}, {'url': '/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['/foreignkeysource/2/']},
{'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, {'url': '/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
{'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}, {'url': '/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self): def test_foreign_key_update_with_invalid_null(self):
data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': None} data = {'url': '/foreignkeysource/1/', 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) self.assertEquals(serializer.errors, {'target': ['This field is required.']})
class HyperlinkedNullableForeignKeyTests(TestCase): class HyperlinkedNullableForeignKeyTests(TestCase):
@ -303,28 +311,28 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/nullableforeignkeysource/1/', 'name': 'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, {'url': '/nullableforeignkeysource/2/', 'name': 'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self): def test_foreign_key_create_with_valid_null(self):
data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} data = {'url': '/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, u'source-4') self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/nullableforeignkeysource/1/', 'name': 'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, {'url': '/nullableforeignkeysource/2/', 'name': 'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} {'url': '/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -333,27 +341,27 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': ''} data = {'url': '/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
expected_data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} expected_data = {'url': '/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, expected_data) self.assertEquals(serializer.data, expected_data)
self.assertEqual(obj.name, u'source-4') self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/nullableforeignkeysource/1/', 'name': 'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, {'url': '/nullableforeignkeysource/2/', 'name': 'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} {'url': '/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self): def test_foreign_key_update_with_valid_null(self):
data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None} data = {'url': '/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -364,9 +372,9 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}, {'url': '/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, {'url': '/nullableforeignkeysource/2/', 'name': 'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -375,8 +383,8 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': ''} data = {'url': '/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
expected_data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None} expected_data = {'url': '/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -387,9 +395,9 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}, {'url': '/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, {'url': '/nullableforeignkeysource/2/', 'name': 'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -398,7 +406,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
# and cannot be arbitrarily set. # and cannot be arbitrarily set.
# def test_reverse_foreign_key_update(self): # def test_reverse_foreign_key_update(self):
# data = {'id': 1, 'name': u'target-1', 'sources': [1]} # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
# instance = ForeignKeyTarget.objects.get(pk=1) # instance = ForeignKeyTarget.objects.get(pk=1)
# serializer = ForeignKeyTargetSerializer(instance, data=data) # serializer = ForeignKeyTargetSerializer(instance, data=data)
# self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
@ -409,8 +417,8 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
# queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
# serializer = ForeignKeyTargetSerializer(queryset) # serializer = ForeignKeyTargetSerializer(queryset)
# expected = [ # expected = [
# {'id': 1, 'name': u'target-1', 'sources': [1]}, # {'id': 1, 'name': 'target-1', 'sources': [1]},
# {'id': 2, 'name': u'target-2', 'sources': []}, # {'id': 2, 'name': 'target-2', 'sources': []},
# ] # ]
# self.assertEquals(serializer.data, expected) # self.assertEquals(serializer.data, expected)
@ -430,7 +438,7 @@ class HyperlinkedNullableOneToOneTests(TestCase):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset) serializer = NullableOneToOneTargetSerializer(queryset)
expected = [ expected = [
{'url': '/onetoonetarget/1/', 'name': u'target-1', 'nullable_source': '/nullableonetoonesource/1/'}, {'url': '/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': '/nullableonetoonesource/1/'},
{'url': '/onetoonetarget/2/', 'name': u'target-2', 'nullable_source': None}, {'url': '/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
@ -15,7 +16,7 @@ class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
class ForeignKeyTargetSerializer(serializers.ModelSerializer): class ForeignKeyTargetSerializer(serializers.ModelSerializer):
sources = FlatForeignKeySourceSerializer() sources = FlatForeignKeySourceSerializer(many=True)
class Meta: class Meta:
model = ForeignKeyTarget model = ForeignKeyTarget
@ -53,9 +54,9 @@ class ReverseForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}}, {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}}, {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 3, 'name': u'source-3', 'target': {'id': 1, 'name': u'target-1'}}, {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -63,12 +64,12 @@ class ReverseForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': [ {'id': 1, 'name': 'target-1', 'sources': [
{'id': 1, 'name': u'source-1', 'target': 1}, {'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': u'source-2', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1},
]}, ]},
{'id': 2, 'name': u'target-2', 'sources': [ {'id': 2, 'name': 'target-2', 'sources': [
]} ]}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -88,9 +89,9 @@ class NestedNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}}, {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}}, {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 3, 'name': u'source-3', 'target': None}, {'id': 3, 'name': 'source-3', 'target': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -108,7 +109,7 @@ class NestedNullableOneToOneTests(TestCase):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset) serializer = NullableOneToOneTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'nullable_source': {'id': 1, 'name': u'source-1', 'target': 1}}, {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}},
{'id': 2, 'name': u'target-2', 'nullable_source': None}, {'id': 2, 'name': 'target-2', 'nullable_source': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)

View File

@ -1,10 +1,11 @@
from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
class ManyToManyTargetSerializer(serializers.ModelSerializer): class ManyToManyTargetSerializer(serializers.ModelSerializer):
sources = serializers.ManyPrimaryKeyRelatedField() sources = serializers.PrimaryKeyRelatedField(many=True)
class Meta: class Meta:
model = ManyToManyTarget model = ManyToManyTarget
@ -16,7 +17,7 @@ class ManyToManySourceSerializer(serializers.ModelSerializer):
class ForeignKeyTargetSerializer(serializers.ModelSerializer): class ForeignKeyTargetSerializer(serializers.ModelSerializer):
sources = serializers.ManyPrimaryKeyRelatedField() sources = serializers.PrimaryKeyRelatedField(many=True)
class Meta: class Meta:
model = ForeignKeyTarget model = ForeignKeyTarget
@ -56,9 +57,9 @@ class PKManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset) serializer = ManyToManySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'targets': [1]}, {'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': u'source-2', 'targets': [1, 2]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]} {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -66,14 +67,14 @@ class PKManyToManyTests(TestCase):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset) serializer = ManyToManyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': u'target-2', 'sources': [2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': u'target-3', 'sources': [3]} {'id': 3, 'name': 'target-3', 'sources': [3]}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_many_to_many_update(self): def test_many_to_many_update(self):
data = {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]} data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
instance = ManyToManySource.objects.get(pk=1) instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data) serializer = ManyToManySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -84,14 +85,14 @@ class PKManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset) serializer = ManyToManySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]}, {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
{'id': 2, 'name': u'source-2', 'targets': [1, 2]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]} {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_reverse_many_to_many_update(self): def test_reverse_many_to_many_update(self):
data = {'id': 1, 'name': u'target-1', 'sources': [1]} data = {'id': 1, 'name': 'target-1', 'sources': [1]}
instance = ManyToManyTarget.objects.get(pk=1) instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data) serializer = ManyToManyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -102,47 +103,47 @@ class PKManyToManyTests(TestCase):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset) serializer = ManyToManyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': [1]}, {'id': 1, 'name': 'target-1', 'sources': [1]},
{'id': 2, 'name': u'target-2', 'sources': [2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': u'target-3', 'sources': [3]} {'id': 3, 'name': 'target-3', 'sources': [3]}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_many_to_many_create(self): def test_many_to_many_create(self):
data = {'id': 4, 'name': u'source-4', 'targets': [1, 3]} data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
serializer = ManyToManySourceSerializer(data=data) serializer = ManyToManySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, u'source-4') self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected # Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset) serializer = ManyToManySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'targets': [1]}, {'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': u'source-2', 'targets': [1, 2]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
{'id': 4, 'name': u'source-4', 'targets': [1, 3]}, {'id': 4, 'name': 'source-4', 'targets': [1, 3]},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_reverse_many_to_many_create(self): def test_reverse_many_to_many_create(self):
data = {'id': 4, 'name': u'target-4', 'sources': [1, 3]} data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
serializer = ManyToManyTargetSerializer(data=data) serializer = ManyToManyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, u'target-4') self.assertEqual(obj.name, 'target-4')
# Ensure target 4 is added, and everything else is as expected # Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset) serializer = ManyToManyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': u'target-2', 'sources': [2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': u'target-3', 'sources': [3]}, {'id': 3, 'name': 'target-3', 'sources': [3]},
{'id': 4, 'name': u'target-4', 'sources': [1, 3]} {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -161,9 +162,9 @@ class PKForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': 1}, {'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': u'source-2', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': 1} {'id': 3, 'name': 'source-3', 'target': 1}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -171,13 +172,13 @@ class PKForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': u'target-2', 'sources': []}, {'id': 2, 'name': 'target-2', 'sources': []},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update(self): def test_foreign_key_update(self):
data = {'id': 1, 'name': u'source-1', 'target': 2} data = {'id': 1, 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -188,14 +189,21 @@ class PKForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': 2}, {'id': 1, 'name': 'source-1', 'target': 2},
{'id': 2, 'name': u'source-2', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': 1} {'id': 3, 'name': 'source-3', 'target': 1}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_incorrect_type(self):
data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': ['Incorrect type. Expected pk value, received str.']})
def test_reverse_foreign_key_update(self): def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': u'target-2', 'sources': [1, 3]} data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
instance = ForeignKeyTarget.objects.get(pk=2) instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data) serializer = ForeignKeyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -204,8 +212,8 @@ class PKForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset) new_serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': u'target-2', 'sources': []}, {'id': 2, 'name': 'target-2', 'sources': []},
] ]
self.assertEquals(new_serializer.data, expected) self.assertEquals(new_serializer.data, expected)
@ -216,54 +224,54 @@ class PKForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': [2]}, {'id': 1, 'name': 'target-1', 'sources': [2]},
{'id': 2, 'name': u'target-2', 'sources': [1, 3]}, {'id': 2, 'name': 'target-2', 'sources': [1, 3]},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_create(self): def test_foreign_key_create(self):
data = {'id': 4, 'name': u'source-4', 'target': 2} data = {'id': 4, 'name': 'source-4', 'target': 2}
serializer = ForeignKeySourceSerializer(data=data) serializer = ForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, u'source-4') self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected # Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': 1}, {'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': u'source-2', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1},
{'id': 4, 'name': u'source-4', 'target': 2}, {'id': 4, 'name': 'source-4', 'target': 2},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_reverse_foreign_key_create(self): def test_reverse_foreign_key_create(self):
data = {'id': 3, 'name': u'target-3', 'sources': [1, 3]} data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
serializer = ForeignKeyTargetSerializer(data=data) serializer = ForeignKeyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, u'target-3') self.assertEqual(obj.name, 'target-3')
# Ensure target 3 is added, and everything else is as expected # Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': [2]}, {'id': 1, 'name': 'target-1', 'sources': [2]},
{'id': 2, 'name': u'target-2', 'sources': []}, {'id': 2, 'name': 'target-2', 'sources': []},
{'id': 3, 'name': u'target-3', 'sources': [1, 3]}, {'id': 3, 'name': 'target-3', 'sources': [1, 3]},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self): def test_foreign_key_update_with_invalid_null(self):
data = {'id': 1, 'name': u'source-1', 'target': None} data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) self.assertEquals(serializer.errors, {'target': ['Value may not be null']})
class PKNullableForeignKeyTests(TestCase): class PKNullableForeignKeyTests(TestCase):
@ -280,28 +288,28 @@ class PKNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': 1}, {'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': u'source-2', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': None}, {'id': 3, 'name': 'source-3', 'target': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self): def test_foreign_key_create_with_valid_null(self):
data = {'id': 4, 'name': u'source-4', 'target': None} data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, u'source-4') self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': 1}, {'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': u'source-2', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': None}, {'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': u'source-4', 'target': None} {'id': 4, 'name': 'source-4', 'target': None}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -310,27 +318,27 @@ class PKNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'id': 4, 'name': u'source-4', 'target': ''} data = {'id': 4, 'name': 'source-4', 'target': ''}
expected_data = {'id': 4, 'name': u'source-4', 'target': None} expected_data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, expected_data) self.assertEquals(serializer.data, expected_data)
self.assertEqual(obj.name, u'source-4') self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': 1}, {'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': u'source-2', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': None}, {'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': u'source-4', 'target': None} {'id': 4, 'name': 'source-4', 'target': None}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self): def test_foreign_key_update_with_valid_null(self):
data = {'id': 1, 'name': u'source-1', 'target': None} data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -341,9 +349,9 @@ class PKNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': None}, {'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': u'source-2', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': None} {'id': 3, 'name': 'source-3', 'target': None}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -352,8 +360,8 @@ class PKNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'id': 1, 'name': u'source-1', 'target': ''} data = {'id': 1, 'name': 'source-1', 'target': ''}
expected_data = {'id': 1, 'name': u'source-1', 'target': None} expected_data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -364,9 +372,9 @@ class PKNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': None}, {'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': u'source-2', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': None} {'id': 3, 'name': 'source-3', 'target': None}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -375,7 +383,7 @@ class PKNullableForeignKeyTests(TestCase):
# and cannot be arbitrarily set. # and cannot be arbitrarily set.
# def test_reverse_foreign_key_update(self): # def test_reverse_foreign_key_update(self):
# data = {'id': 1, 'name': u'target-1', 'sources': [1]} # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
# instance = ForeignKeyTarget.objects.get(pk=1) # instance = ForeignKeyTarget.objects.get(pk=1)
# serializer = ForeignKeyTargetSerializer(instance, data=data) # serializer = ForeignKeyTargetSerializer(instance, data=data)
# self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
@ -386,8 +394,8 @@ class PKNullableForeignKeyTests(TestCase):
# queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
# serializer = ForeignKeyTargetSerializer(queryset) # serializer = ForeignKeyTargetSerializer(queryset)
# expected = [ # expected = [
# {'id': 1, 'name': u'target-1', 'sources': [1]}, # {'id': 1, 'name': 'target-1', 'sources': [1]},
# {'id': 2, 'name': u'target-2', 'sources': []}, # {'id': 2, 'name': 'target-2', 'sources': []},
# ] # ]
# self.assertEquals(serializer.data, expected) # self.assertEquals(serializer.data, expected)
@ -405,7 +413,7 @@ class PKNullableOneToOneTests(TestCase):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset) serializer = NullableOneToOneTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'nullable_source': 1}, {'id': 1, 'name': 'target-1', 'nullable_source': 1},
{'id': 2, 'name': u'target-2', 'nullable_source': None}, {'id': 2, 'name': 'target-2', 'nullable_source': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)

View File

@ -1,16 +1,156 @@
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import NullableForeignKeySource, ForeignKeyTarget from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget
class NullableSlugSourceSerializer(serializers.ModelSerializer): class ForeignKeyTargetSerializer(serializers.ModelSerializer):
target = serializers.SlugRelatedField(slug_field='name', null=True) sources = serializers.SlugRelatedField(many=True, slug_field='name')
class Meta:
model = ForeignKeyTarget
class ForeignKeySourceSerializer(serializers.ModelSerializer):
target = serializers.SlugRelatedField(slug_field='name')
class Meta:
model = ForeignKeySource
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
target = serializers.SlugRelatedField(slug_field='name', required=False)
class Meta: class Meta:
model = NullableForeignKeySource model = NullableForeignKeySource
# TODO: M2M Tests, FKTests (Non-nulable), One2One # TODO: M2M Tests, FKTests (Non-nulable), One2One
class PKForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
new_target = ForeignKeyTarget(name='target-2')
new_target.save()
for idx in range(1, 4):
source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset)
expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'}
]
self.assertEquals(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset)
expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []},
]
self.assertEquals(serializer.data, expected)
def test_foreign_key_update(self):
data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.data, data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset)
expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-2'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'}
]
self.assertEquals(serializer.data, expected)
def test_foreign_key_update_incorrect_type(self):
data = {'id': 1, 'name': 'source-1', 'target': 123}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': ['Object with name=123 does not exist.']})
def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset)
expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []},
]
self.assertEquals(new_serializer.data, expected)
serializer.save()
self.assertEquals(serializer.data, data)
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset)
expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-2']},
{'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
]
self.assertEquals(serializer.data, expected)
def test_foreign_key_create(self):
data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
serializer = ForeignKeySourceSerializer(data=data)
serializer.is_valid()
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset)
expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'},
{'id': 4, 'name': 'source-4', 'target': 'target-2'},
]
self.assertEquals(serializer.data, expected)
def test_reverse_foreign_key_create(self):
data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
serializer = ForeignKeyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, 'target-3')
# Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset)
expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-2']},
{'id': 2, 'name': 'target-2', 'sources': []},
{'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
]
self.assertEquals(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': ['This field is required.']})
class SlugNullableForeignKeyTests(TestCase): class SlugNullableForeignKeyTests(TestCase):
def setUp(self): def setUp(self):
@ -24,30 +164,30 @@ class SlugNullableForeignKeyTests(TestCase):
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableSlugSourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': 'target-1'}, {'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': u'source-2', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': u'source-3', 'target': None}, {'id': 3, 'name': 'source-3', 'target': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self): def test_foreign_key_create_with_valid_null(self):
data = {'id': 4, 'name': u'source-4', 'target': None} data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableSlugSourceSerializer(data=data) serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, u'source-4') self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableSlugSourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': 'target-1'}, {'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': u'source-2', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': u'source-3', 'target': None}, {'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': u'source-4', 'target': None} {'id': 4, 'name': 'source-4', 'target': None}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -56,40 +196,40 @@ class SlugNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'id': 4, 'name': u'source-4', 'target': ''} data = {'id': 4, 'name': 'source-4', 'target': ''}
expected_data = {'id': 4, 'name': u'source-4', 'target': None} expected_data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableSlugSourceSerializer(data=data) serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, expected_data) self.assertEquals(serializer.data, expected_data)
self.assertEqual(obj.name, u'source-4') self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableSlugSourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': 'target-1'}, {'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': u'source-2', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': u'source-3', 'target': None}, {'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': u'source-4', 'target': None} {'id': 4, 'name': 'source-4', 'target': None}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self): def test_foreign_key_update_with_valid_null(self):
data = {'id': 1, 'name': u'source-1', 'target': None} data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableSlugSourceSerializer(instance, data=data) serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
serializer.save() serializer.save()
# Ensure source 1 is updated, and everything else is as expected # Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableSlugSourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': None}, {'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': u'source-2', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': u'source-3', 'target': None} {'id': 3, 'name': 'source-3', 'target': None}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -98,20 +238,20 @@ class SlugNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'id': 1, 'name': u'source-1', 'target': ''} data = {'id': 1, 'name': 'source-1', 'target': ''}
expected_data = {'id': 1, 'name': u'source-1', 'target': None} expected_data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableSlugSourceSerializer(instance, data=data) serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.data, expected_data) self.assertEquals(serializer.data, expected_data)
serializer.save() serializer.save()
# Ensure source 1 is updated, and everything else is as expected # Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableSlugSourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': None}, {'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': u'source-2', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': u'source-3', 'target': None} {'id': 3, 'name': 'source-3', 'target': None}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)

View File

@ -14,7 +14,8 @@ from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
from rest_framework.parsers import YAMLParser, XMLParser from rest_framework.parsers import YAMLParser, XMLParser
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from StringIO import StringIO from rest_framework.compat import StringIO
from rest_framework.compat import six
import datetime import datetime
from decimal import Decimal from decimal import Decimal
@ -22,8 +23,8 @@ from decimal import Decimal
DUMMYSTATUS = status.HTTP_200_OK DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent' DUMMYCONTENT = 'dummycontent'
RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii')
RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
expected_results = [ expected_results = [
@ -140,7 +141,7 @@ class RendererEndToEndTests(TestCase):
resp = self.client.head('/') resp = self.client.head('/')
self.assertEquals(resp.status_code, DUMMYSTATUS) self.assertEquals(resp.status_code, DUMMYSTATUS)
self.assertEquals(resp['Content-Type'], RendererA.media_type) self.assertEquals(resp['Content-Type'], RendererA.media_type)
self.assertEquals(resp.content, '') self.assertEquals(resp.content, six.b(''))
def test_default_renderer_serializes_content_on_accept_any(self): def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response.""" """If the Accept header is set to */* the default renderer should serialize the response."""
@ -267,7 +268,8 @@ class JSONPRendererTests(TestCase):
HTTP_ACCEPT='application/javascript') HTTP_ACCEPT='application/javascript')
self.assertEquals(resp.status_code, 200) self.assertEquals(resp.status_code, 200)
self.assertEquals(resp['Content-Type'], 'application/javascript') self.assertEquals(resp['Content-Type'], 'application/javascript')
self.assertEquals(resp.content, 'callback(%s);' % _flat_repr) self.assertEquals(resp.content,
('callback(%s);' % _flat_repr).encode('ascii'))
def test_without_callback_without_json_renderer(self): def test_without_callback_without_json_renderer(self):
""" """
@ -277,7 +279,8 @@ class JSONPRendererTests(TestCase):
HTTP_ACCEPT='application/javascript') HTTP_ACCEPT='application/javascript')
self.assertEquals(resp.status_code, 200) self.assertEquals(resp.status_code, 200)
self.assertEquals(resp['Content-Type'], 'application/javascript') self.assertEquals(resp['Content-Type'], 'application/javascript')
self.assertEquals(resp.content, 'callback(%s);' % _flat_repr) self.assertEquals(resp.content,
('callback(%s);' % _flat_repr).encode('ascii'))
def test_with_callback(self): def test_with_callback(self):
""" """
@ -288,7 +291,8 @@ class JSONPRendererTests(TestCase):
HTTP_ACCEPT='application/javascript') HTTP_ACCEPT='application/javascript')
self.assertEquals(resp.status_code, 200) self.assertEquals(resp.status_code, 200)
self.assertEquals(resp['Content-Type'], 'application/javascript') self.assertEquals(resp['Content-Type'], 'application/javascript')
self.assertEquals(resp.content, '%s(%s);' % (callback_func, _flat_repr)) self.assertEquals(resp.content,
('%s(%s);' % (callback_func, _flat_repr)).encode('ascii'))
if yaml: if yaml:

View File

@ -1,7 +1,7 @@
""" """
Tests for content parsing, and form-overloaded content parsing. Tests for content parsing, and form-overloaded content parsing.
""" """
import json from __future__ import unicode_literals
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.contrib.auth import authenticate, login, logout from django.contrib.auth import authenticate, login, logout
from django.contrib.sessions.middleware import SessionMiddleware from django.contrib.sessions.middleware import SessionMiddleware
@ -20,6 +20,8 @@ from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.compat import six
import json
factory = RequestFactory() factory = RequestFactory()
@ -79,14 +81,14 @@ class TestContentParsing(TestCase):
data = {'qwerty': 'uiop'} data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data)) request = Request(factory.post('/', data))
request.parsers = (FormParser(), MultiPartParser()) request.parsers = (FormParser(), MultiPartParser())
self.assertEqual(request.DATA.items(), data.items()) self.assertEqual(list(request.DATA.items()), list(data.items()))
def test_request_DATA_with_text_content(self): def test_request_DATA_with_text_content(self):
""" """
Ensure request.DATA returns content for POST request with Ensure request.DATA returns content for POST request with
non-form content. non-form content.
""" """
content = 'qwerty' content = six.b('qwerty')
content_type = 'text/plain' content_type = 'text/plain'
request = Request(factory.post('/', content, content_type=content_type)) request = Request(factory.post('/', content, content_type=content_type))
request.parsers = (PlainTextParser(),) request.parsers = (PlainTextParser(),)
@ -99,7 +101,7 @@ class TestContentParsing(TestCase):
data = {'qwerty': 'uiop'} data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data)) request = Request(factory.post('/', data))
request.parsers = (FormParser(), MultiPartParser()) request.parsers = (FormParser(), MultiPartParser())
self.assertEqual(request.POST.items(), data.items()) self.assertEqual(list(request.POST.items()), list(data.items()))
def test_standard_behaviour_determines_form_content_PUT(self): def test_standard_behaviour_determines_form_content_PUT(self):
""" """
@ -117,14 +119,14 @@ class TestContentParsing(TestCase):
request = Request(factory.put('/', data)) request = Request(factory.put('/', data))
request.parsers = (FormParser(), MultiPartParser()) request.parsers = (FormParser(), MultiPartParser())
self.assertEqual(request.DATA.items(), data.items()) self.assertEqual(list(request.DATA.items()), list(data.items()))
def test_standard_behaviour_determines_non_form_content_PUT(self): def test_standard_behaviour_determines_non_form_content_PUT(self):
""" """
Ensure request.DATA returns content for PUT request with Ensure request.DATA returns content for PUT request with
non-form content. non-form content.
""" """
content = 'qwerty' content = six.b('qwerty')
content_type = 'text/plain' content_type = 'text/plain'
request = Request(factory.put('/', content, content_type=content_type)) request = Request(factory.put('/', content, content_type=content_type))
request.parsers = (PlainTextParser(), ) request.parsers = (PlainTextParser(), )

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from rest_framework.compat import patterns, url, include from rest_framework.compat import patterns, url, include
from rest_framework.response import Response from rest_framework.response import Response
@ -9,6 +10,7 @@ from rest_framework.renderers import (
BrowsableAPIRenderer BrowsableAPIRenderer
) )
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.compat import six
class MockPickleRenderer(BaseRenderer): class MockPickleRenderer(BaseRenderer):
@ -22,8 +24,8 @@ class MockJsonRenderer(BaseRenderer):
DUMMYSTATUS = status.HTTP_200_OK DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent' DUMMYCONTENT = 'dummycontent'
RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii')
RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
class RendererA(BaseRenderer): class RendererA(BaseRenderer):
@ -92,7 +94,7 @@ class RendererIntegrationTests(TestCase):
resp = self.client.head('/') resp = self.client.head('/')
self.assertEquals(resp.status_code, DUMMYSTATUS) self.assertEquals(resp.status_code, DUMMYSTATUS)
self.assertEquals(resp['Content-Type'], RendererA.media_type) self.assertEquals(resp['Content-Type'], RendererA.media_type)
self.assertEquals(resp.content, '') self.assertEquals(resp.content, six.b(''))
def test_default_renderer_serializes_content_on_accept_any(self): def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response.""" """If the Accept header is set to */* the default renderer should serialize the response."""

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from rest_framework.compat import patterns, url from rest_framework.compat import patterns, url

View File

@ -1,10 +1,12 @@
import datetime from __future__ import unicode_literals
import pickle from django.utils.datastructures import MultiValueDict
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel, BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo) ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo)
import datetime
import pickle
class SubComment(object): class SubComment(object):
@ -54,6 +56,19 @@ class ActionItemSerializer(serializers.ModelSerializer):
model = ActionItem model = ActionItem
class ActionItemSerializerCustomRestore(serializers.ModelSerializer):
class Meta:
model = ActionItem
def restore_object(self, data, instance=None):
if instance is None:
return ActionItem(**data)
for key, val in data.items():
setattr(instance, key, val)
return instance
class PersonSerializer(serializers.ModelSerializer): class PersonSerializer(serializers.ModelSerializer):
info = serializers.Field(source='info') info = serializers.Field(source='info')
@ -162,7 +177,6 @@ class BasicTests(TestCase):
""" """
Attempting to update fields set as read_only should have no effect. Attempting to update fields set as read_only should have no effect.
""" """
serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99}) serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99})
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
instance = serializer.save() instance = serializer.save()
@ -171,6 +185,33 @@ class BasicTests(TestCase):
self.assertEquals(instance.age, self.person_data['age']) self.assertEquals(instance.age, self.person_data['age'])
class DictStyleSerializer(serializers.Serializer):
"""
Note that we don't have any `restore_object` method, so the default
case of simply returning a dict will apply.
"""
email = serializers.EmailField()
class DictStyleSerializerTests(TestCase):
def test_dict_style_deserialize(self):
"""
Ensure serializers can deserialize into a dict.
"""
data = {'email': 'foo@example.com'}
serializer = DictStyleSerializer(data=data)
self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.data, data)
def test_dict_style_serialize(self):
"""
Ensure serializers can serialize dict objects.
"""
data = {'email': 'foo@example.com'}
serializer = DictStyleSerializer(data)
self.assertEquals(serializer.data, data)
class ValidationTests(TestCase): class ValidationTests(TestCase):
def setUp(self): def setUp(self):
self.comment = Comment( self.comment = Comment(
@ -183,18 +224,17 @@ class ValidationTests(TestCase):
'content': 'x' * 1001, 'content': 'x' * 1001,
'created': datetime.datetime(2012, 1, 1) 'created': datetime.datetime(2012, 1, 1)
} }
self.actionitem = ActionItem(title='Some to do item', self.actionitem = ActionItem(title='Some to do item',)
)
def test_create(self): def test_create(self):
serializer = CommentSerializer(data=self.data) serializer = CommentSerializer(data=self.data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) self.assertEquals(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
def test_update(self): def test_update(self):
serializer = CommentSerializer(self.comment, data=self.data) serializer = CommentSerializer(self.comment, data=self.data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) self.assertEquals(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
def test_update_missing_field(self): def test_update_missing_field(self):
data = { data = {
@ -203,7 +243,7 @@ class ValidationTests(TestCase):
} }
serializer = CommentSerializer(self.comment, data=data) serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'email': [u'This field is required.']}) self.assertEquals(serializer.errors, {'email': ['This field is required.']})
def test_missing_bool_with_default(self): def test_missing_bool_with_default(self):
"""Make sure that a boolean value with a 'False' value is not """Make sure that a boolean value with a 'False' value is not
@ -216,31 +256,6 @@ class ValidationTests(TestCase):
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.errors, {}) self.assertEquals(serializer.errors, {})
def test_field_validation(self):
class CommentSerializerWithFieldValidator(CommentSerializer):
def validate_content(self, attrs, source):
value = attrs[source]
if "test" not in value:
raise serializers.ValidationError("Test not in value")
return attrs
data = {
'email': 'tom@example.com',
'content': 'A test comment',
'created': datetime.datetime(2012, 1, 1)
}
serializer = CommentSerializerWithFieldValidator(data=data)
self.assertTrue(serializer.is_valid())
data['content'] = 'This should not validate'
serializer = CommentSerializerWithFieldValidator(data=data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'content': [u'Test not in value']})
def test_bad_type_data_is_false(self): def test_bad_type_data_is_false(self):
""" """
Data of the wrong type is not valid. Data of the wrong type is not valid.
@ -248,17 +263,17 @@ class ValidationTests(TestCase):
data = ['i am', 'a', 'list'] data = ['i am', 'a', 'list']
serializer = CommentSerializer(self.comment, data=data) serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) self.assertEquals(serializer.errors, {'non_field_errors': ['Invalid data']})
data = 'and i am a string' data = 'and i am a string'
serializer = CommentSerializer(self.comment, data=data) serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) self.assertEquals(serializer.errors, {'non_field_errors': ['Invalid data']})
data = 42 data = 42
serializer = CommentSerializer(self.comment, data=data) serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) self.assertEquals(serializer.errors, {'non_field_errors': ['Invalid data']})
def test_cross_field_validation(self): def test_cross_field_validation(self):
@ -282,7 +297,7 @@ class ValidationTests(TestCase):
serializer = CommentSerializerWithCrossFieldValidator(data=data) serializer = CommentSerializerWithCrossFieldValidator(data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']}) self.assertEquals(serializer.errors, {'non_field_errors': ['Email address not in content']})
def test_null_is_true_fields(self): def test_null_is_true_fields(self):
""" """
@ -298,7 +313,21 @@ class ValidationTests(TestCase):
} }
serializer = ActionItemSerializer(data=data) serializer = ActionItemSerializer(data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']}) self.assertEquals(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
def test_modelserializer_max_length_exceeded_with_custom_restore(self):
"""
When overriding ModelSerializer.restore_object, validation tests should still apply.
Regression test for #623.
https://github.com/tomchristie/django-rest-framework/pull/623
"""
data = {
'title': 'x' * 201,
}
serializer = ActionItemSerializerCustomRestore(data=data)
self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
def test_default_modelfield_max_length_exceeded(self): def test_default_modelfield_max_length_exceeded(self):
data = { data = {
@ -307,15 +336,72 @@ class ValidationTests(TestCase):
} }
serializer = ActionItemSerializer(data=data) serializer = ActionItemSerializer(data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']}) self.assertEquals(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']})
class CustomValidationTests(TestCase):
class CommentSerializerWithFieldValidator(CommentSerializer):
def validate_email(self, attrs, source):
value = attrs[source]
return attrs
def validate_content(self, attrs, source):
value = attrs[source]
if "test" not in value:
raise serializers.ValidationError("Test not in value")
return attrs
def test_field_validation(self):
data = {
'email': 'tom@example.com',
'content': 'A test comment',
'created': datetime.datetime(2012, 1, 1)
}
serializer = self.CommentSerializerWithFieldValidator(data=data)
self.assertTrue(serializer.is_valid())
data['content'] = 'This should not validate'
serializer = self.CommentSerializerWithFieldValidator(data=data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'content': ['Test not in value']})
def test_missing_data(self):
"""
Make sure that validate_content isn't called if the field is missing
"""
incomplete_data = {
'email': 'tom@example.com',
'created': datetime.datetime(2012, 1, 1)
}
serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'content': ['This field is required.']})
def test_wrong_data(self):
"""
Make sure that validate_content isn't called if the field input is wrong
"""
wrong_data = {
'email': 'not an email',
'content': 'A test comment',
'created': datetime.datetime(2012, 1, 1)
}
serializer = self.CommentSerializerWithFieldValidator(data=wrong_data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'email': ['Enter a valid e-mail address.']})
class PositiveIntegerAsChoiceTests(TestCase): class PositiveIntegerAsChoiceTests(TestCase):
def test_positive_integer_in_json_is_correctly_parsed(self): def test_positive_integer_in_json_is_correctly_parsed(self):
data = {'some_integer':1} data = {'some_integer': 1}
serializer = PositiveIntegerAsChoiceSerializer(data=data) serializer = PositiveIntegerAsChoiceSerializer(data=data)
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
class ModelValidationTests(TestCase): class ModelValidationTests(TestCase):
def test_validate_unique(self): def test_validate_unique(self):
""" """
@ -326,7 +412,7 @@ class ModelValidationTests(TestCase):
serializer.save() serializer.save()
second_serializer = AlbumsSerializer(data={'title': 'a'}) second_serializer = AlbumsSerializer(data={'title': 'a'})
self.assertFalse(second_serializer.is_valid()) self.assertFalse(second_serializer.is_valid())
self.assertEqual(second_serializer.errors, {'title': [u'Album with this Title already exists.']}) self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']})
def test_foreign_key_with_partial(self): def test_foreign_key_with_partial(self):
""" """
@ -364,15 +450,15 @@ class RegexValidationTest(TestCase):
def test_create_failed(self): def test_create_failed(self):
serializer = BookSerializer(data={'isbn': '1234567890'}) serializer = BookSerializer(data={'isbn': '1234567890'})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) self.assertEquals(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
serializer = BookSerializer(data={'isbn': '12345678901234'}) serializer = BookSerializer(data={'isbn': '12345678901234'})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) self.assertEquals(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
serializer = BookSerializer(data={'isbn': 'abcdefghijklm'}) serializer = BookSerializer(data={'isbn': 'abcdefghijklm'})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) self.assertEquals(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
def test_create_success(self): def test_create_success(self):
serializer = BookSerializer(data={'isbn': '1234567890123'}) serializer = BookSerializer(data={'isbn': '1234567890123'})
@ -479,7 +565,8 @@ class ManyToManyTests(TestCase):
containing no items, using a representation that does not support containing no items, using a representation that does not support
lists (eg form data). lists (eg form data).
""" """
data = {'rel': ''} data = MultiValueDict()
data.setlist('rel', [''])
serializer = self.serializer_class(data=data) serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
instance = serializer.save() instance = serializer.save()
@ -491,7 +578,7 @@ class ManyToManyTests(TestCase):
class ReadOnlyManyToManyTests(TestCase): class ReadOnlyManyToManyTests(TestCase):
def setUp(self): def setUp(self):
class ReadOnlyManyToManySerializer(serializers.ModelSerializer): class ReadOnlyManyToManySerializer(serializers.ModelSerializer):
rel = serializers.ManyRelatedField(read_only=True) rel = serializers.RelatedField(many=True, read_only=True)
class Meta: class Meta:
model = ReadOnlyManyToManyModel model = ReadOnlyManyToManyModel
@ -686,11 +773,11 @@ class RelatedTraversalTest(TestCase):
serializer = BlogPostSerializer(instance=post) serializer = BlogPostSerializer(instance=post)
expected = { expected = {
'title': u'Test blog post', 'title': 'Test blog post',
'comments': [{ 'comments': [{
'text': u'I love this blog post', 'text': 'I love this blog post',
'post_owner': { 'post_owner': {
"name": u"django", "name": "django",
"age": None "age": None
} }
}] }]
@ -725,8 +812,8 @@ class SerializerMethodFieldTests(TestCase):
serializer = self.serializer_class(source_data) serializer = self.serializer_class(source_data)
expected = { expected = {
'beep': u'hello!', 'beep': 'hello!',
'boop': [u'a', u'b', u'c'], 'boop': ['a', 'b', 'c'],
'boop_count': 3, 'boop_count': 3,
} }
@ -742,7 +829,7 @@ class BlankFieldTests(TestCase):
model = BlankFieldModel model = BlankFieldModel
class BlankFieldSerializer(serializers.Serializer): class BlankFieldSerializer(serializers.Serializer):
title = serializers.CharField(blank=True) title = serializers.CharField(required=False)
class NotBlankFieldModelSerializer(serializers.ModelSerializer): class NotBlankFieldModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
@ -785,7 +872,7 @@ class BlankFieldTests(TestCase):
serializer = self.not_blank_model_serializer_class(data=self.data) serializer = self.not_blank_model_serializer_class(data=self.data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
def test_create_model_null_field(self): def test_create_model_empty_field(self):
serializer = self.model_serializer_class(data={}) serializer = self.model_serializer_class(data={})
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
@ -825,8 +912,8 @@ class DepthTest(TestCase):
depth = 1 depth = 1
serializer = BlogPostSerializer(instance=post) serializer = BlogPostSerializer(instance=post)
expected = {'id': 1, 'title': u'Test blog post', expected = {'id': 1, 'title': 'Test blog post',
'writer': {'id': 1, 'name': u'django', 'age': 1}} 'writer': {'id': 1, 'name': 'django', 'age': 1}}
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
@ -845,8 +932,8 @@ class DepthTest(TestCase):
model = BlogPost model = BlogPost
serializer = BlogPostSerializer(instance=post) serializer = BlogPostSerializer(instance=post)
expected = {'id': 1, 'title': u'Test blog post', expected = {'id': 1, 'title': 'Test blog post',
'writer': {'id': 1, 'name': u'django', 'age': 1}} 'writer': {'id': 1, 'name': 'django', 'age': 1}}
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)

View File

@ -1,4 +1,5 @@
"""Tests for the settings module""" """Tests for the settings module"""
from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS

View File

@ -1,4 +1,5 @@
"""Tests for the status module""" """Tests for the status module"""
from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from rest_framework import status from rest_framework import status

View File

@ -1,4 +1,5 @@
# http://djangosnippets.org/snippets/1011/ # http://djangosnippets.org/snippets/1011/
from __future__ import unicode_literals
from django.conf import settings from django.conf import settings
from django.core.management import call_command from django.core.management import call_command
from django.db.models import loading from django.db.models import loading

View File

@ -2,6 +2,7 @@
Force import of all modules in this package in order to get the standard test Force import of all modules in this package in order to get the standard test
runner to pick up the tests. Yowzers. runner to pick up the tests. Yowzers.
""" """
from __future__ import unicode_literals
import os import os
modules = [filename.rsplit('.', 1)[0] modules = [filename.rsplit('.', 1)[0]

View File

@ -1,11 +1,10 @@
""" """
Tests for the throttling implementations in the permissions module. Tests for the throttling implementations in the permissions module.
""" """
from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core.cache import cache from django.core.cache import cache
from django.test.client import RequestFactory from django.test.client import RequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.throttling import UserRateThrottle from rest_framework.throttling import UserRateThrottle

View File

@ -0,0 +1,76 @@
from __future__ import unicode_literals
from collections import namedtuple
from django.core import urlresolvers
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework.compat import patterns, url, include
from rest_framework.urlpatterns import format_suffix_patterns
# A container class for test paths for the test case
URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs'])
def dummy_view(request, *args, **kwargs):
pass
class FormatSuffixTests(TestCase):
"""
Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters.
"""
def _resolve_urlpatterns(self, urlpatterns, test_paths):
factory = RequestFactory()
try:
urlpatterns = format_suffix_patterns(urlpatterns)
except Exception:
self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
for test_path in test_paths:
request = factory.get(test_path.path)
try:
callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
except Exception:
self.fail("Failed to resolve URL: %s" % request.path_info)
self.assertEquals(callback_args, test_path.args)
self.assertEquals(callback_kwargs, test_path.kwargs)
def test_format_suffix(self):
urlpatterns = patterns(
'',
url(r'^test$', dummy_view),
)
test_paths = [
URLTestPath('/test', (), {}),
URLTestPath('/test.api', (), {'format': 'api'}),
URLTestPath('/test.asdf', (), {'format': 'asdf'}),
]
self._resolve_urlpatterns(urlpatterns, test_paths)
def test_default_args(self):
urlpatterns = patterns(
'',
url(r'^test$', dummy_view, {'foo': 'bar'}),
)
test_paths = [
URLTestPath('/test', (), {'foo': 'bar', }),
URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}),
URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
]
self._resolve_urlpatterns(urlpatterns, test_paths)
def test_included_urls(self):
nested_patterns = patterns(
'',
url(r'^path$', dummy_view)
)
urlpatterns = patterns(
'',
url(r'^test/', include(nested_patterns), {'foo': 'bar'}),
)
test_paths = [
URLTestPath('/test/path', (), {'foo': 'bar', }),
URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
]
self._resolve_urlpatterns(urlpatterns, test_paths)

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals
from django.test.client import RequestFactory, FakePayload from django.test.client import RequestFactory, FakePayload
from django.test.client import MULTIPART_CONTENT from django.test.client import MULTIPART_CONTENT
from urlparse import urlparse from rest_framework.compat import urlparse
class RequestFactory(RequestFactory): class RequestFactory(RequestFactory):
@ -14,7 +15,7 @@ class RequestFactory(RequestFactory):
patch_data = self._encode_data(data, content_type) patch_data = self._encode_data(data, content_type)
parsed = urlparse(path) parsed = urlparse.urlparse(path)
r = { r = {
'CONTENT_LENGTH': len(patch_data), 'CONTENT_LENGTH': len(patch_data),
'CONTENT_TYPE': content_type, 'CONTENT_TYPE': content_type,

View File

@ -139,7 +139,7 @@
# raise errors on unexpected request data""" # raise errors on unexpected request data"""
# content = {'qwerty': 'uiop', 'extra': 'extra'} # content = {'qwerty': 'uiop', 'extra': 'extra'}
# validator.allow_unknown_form_fields = True # validator.allow_unknown_form_fields = True
# self.assertEqual({'qwerty': u'uiop'}, # self.assertEqual({'qwerty': 'uiop'},
# validator.validate_request(content, None), # validator.validate_request(content, None),
# "Resource didn't accept unknown fields.") # "Resource didn't accept unknown fields.")
# validator.allow_unknown_form_fields = False # validator.allow_unknown_form_fields = False

View File

@ -1,4 +1,4 @@
import copy from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from rest_framework import status from rest_framework import status
@ -6,6 +6,7 @@ from rest_framework.decorators import api_view
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.views import APIView from rest_framework.views import APIView
import copy
factory = RequestFactory() factory = RequestFactory()
@ -49,7 +50,7 @@ class ClassBasedViewIntegrationTests(TestCase):
request = factory.post('/', 'f00bar', content_type='application/json') request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request) response = self.view(request)
expected = { expected = {
'detail': u'JSON parse error - No JSON object could be decoded' 'detail': 'JSON parse error - No JSON object could be decoded'
} }
self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEquals(sanitise_json_error(response.data), expected) self.assertEquals(sanitise_json_error(response.data), expected)
@ -64,7 +65,7 @@ class ClassBasedViewIntegrationTests(TestCase):
request = factory.post('/', form_data) request = factory.post('/', form_data)
response = self.view(request) response = self.view(request)
expected = { expected = {
'detail': u'JSON parse error - No JSON object could be decoded' 'detail': 'JSON parse error - No JSON object could be decoded'
} }
self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEquals(sanitise_json_error(response.data), expected) self.assertEquals(sanitise_json_error(response.data), expected)
@ -78,7 +79,7 @@ class FunctionBasedViewIntegrationTests(TestCase):
request = factory.post('/', 'f00bar', content_type='application/json') request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request) response = self.view(request)
expected = { expected = {
'detail': u'JSON parse error - No JSON object could be decoded' 'detail': 'JSON parse error - No JSON object could be decoded'
} }
self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEquals(sanitise_json_error(response.data), expected) self.assertEquals(sanitise_json_error(response.data), expected)
@ -93,7 +94,7 @@ class FunctionBasedViewIntegrationTests(TestCase):
request = factory.post('/', form_data) request = factory.post('/', form_data)
response = self.view(request) response = self.view(request)
expected = { expected = {
'detail': u'JSON parse error - No JSON object could be decoded' 'detail': 'JSON parse error - No JSON object could be decoded'
} }
self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEquals(sanitise_json_error(response.data), expected) self.assertEquals(sanitise_json_error(response.data), expected)

View File

@ -1,7 +1,8 @@
import time from __future__ import unicode_literals
from django.core.cache import cache from django.core.cache import cache
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
import time
class BaseThrottle(object): class BaseThrottle(object):

View File

@ -1,7 +1,38 @@
from rest_framework.compat import url from __future__ import unicode_literals
from django.core.urlresolvers import RegexURLResolver
from rest_framework.compat import url, include
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required):
ret = []
for urlpattern in urlpatterns:
if isinstance(urlpattern, RegexURLResolver):
# Set of included URL patterns
regex = urlpattern.regex.pattern
namespace = urlpattern.namespace
app_name = urlpattern.app_name
kwargs = urlpattern.default_kwargs
# Add in the included patterns, after applying the suffixes
patterns = apply_suffix_patterns(urlpattern.url_patterns,
suffix_pattern,
suffix_required)
ret.append(url(regex, include(patterns, namespace, app_name), kwargs))
else:
# Regular URL pattern
regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern
view = urlpattern._callback or urlpattern._callback_str
kwargs = urlpattern.default_args
name = urlpattern.name
# Add in both the existing and the new urlpattern
if not suffix_required:
ret.append(urlpattern)
ret.append(url(regex, view, kwargs, name))
return ret
def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None): def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
""" """
Supplement existing urlpatterns with corresponding patterns that also Supplement existing urlpatterns with corresponding patterns that also
@ -28,15 +59,4 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
else: else:
suffix_pattern = r'\.(?P<%s>[a-z]+)$' % suffix_kwarg suffix_pattern = r'\.(?P<%s>[a-z]+)$' % suffix_kwarg
ret = [] return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required)
for urlpattern in urlpatterns:
# Form our complementing '.format' urlpattern
regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern
view = urlpattern._callback or urlpattern._callback_str
kwargs = urlpattern.default_args
name = urlpattern.name
# Add in both the existing and the new urlpattern
if not suffix_required:
ret.append(urlpattern)
ret.append(url(regex, view, kwargs, name))
return ret

View File

@ -12,6 +12,7 @@ your authentication settings include `SessionAuthentication`.
url(r'^auth', include('rest_framework.urls', namespace='rest_framework')) url(r'^auth', include('rest_framework.urls', namespace='rest_framework'))
) )
""" """
from __future__ import unicode_literals
from rest_framework.compat import patterns, url from rest_framework.compat import patterns, url

View File

@ -1,6 +1,8 @@
from django.utils.encoding import smart_unicode from __future__ import unicode_literals
from django.utils.xmlutils import SimplerXMLGenerator from django.utils.xmlutils import SimplerXMLGenerator
from rest_framework.compat import StringIO from rest_framework.compat import StringIO
from rest_framework.compat import six
from rest_framework.compat import smart_text
import re import re
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
@ -70,7 +72,7 @@ class XMLRenderer():
xml.endElement("list-item") xml.endElement("list-item")
elif isinstance(data, dict): elif isinstance(data, dict):
for key, value in data.iteritems(): for key, value in six.iteritems(data):
xml.startElement(key, {}) xml.startElement(key, {})
self._to_xml(xml, value) self._to_xml(xml, value)
xml.endElement(key) xml.endElement(key)
@ -80,10 +82,10 @@ class XMLRenderer():
pass pass
else: else:
xml.characters(smart_unicode(data)) xml.characters(smart_text(data))
def dict2xml(self, data): def dict2xml(self, data):
stream = StringIO.StringIO() stream = StringIO()
xml = SimplerXMLGenerator(stream, "utf-8") xml = SimplerXMLGenerator(stream, "utf-8")
xml.startDocument() xml.startDocument()

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
from django.core.urlresolvers import resolve, get_script_prefix from django.core.urlresolvers import resolve, get_script_prefix

View File

@ -1,13 +1,14 @@
""" """
Helper classes for parsers. Helper classes for parsers.
""" """
from __future__ import unicode_literals
from django.utils.datastructures import SortedDict
from rest_framework.compat import timezone
from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
import datetime import datetime
import decimal import decimal
import types import types
import json import json
from django.utils.datastructures import SortedDict
from rest_framework.compat import timezone
from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
class JSONEncoder(json.JSONEncoder): class JSONEncoder(json.JSONEncoder):

View File

@ -3,8 +3,9 @@ Handling of media types, as found in HTTP Content-Type and Accept headers.
See http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7 See http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7
""" """
from __future__ import unicode_literals
from django.http.multipartparser import parse_header from django.http.multipartparser import parse_header
from rest_framework import HTTP_HEADER_ENCODING
def media_type_matches(lhs, rhs): def media_type_matches(lhs, rhs):
@ -47,7 +48,7 @@ class _MediaType(object):
if media_type_str is None: if media_type_str is None:
media_type_str = '' media_type_str = ''
self.orig = media_type_str self.orig = media_type_str
self.full_type, self.params = parse_header(media_type_str) self.full_type, self.params = parse_header(media_type_str.encode(HTTP_HEADER_ENCODING))
self.main_type, sep, self.sub_type = self.full_type.partition('/') self.main_type, sep, self.sub_type = self.full_type.partition('/')
def match(self, other): def match(self, other):

View File

@ -1,8 +1,7 @@
""" """
Provides an APIView class that is used as the base of all class-based views. Provides an APIView class that is used as the base of all class-based views.
""" """
from __future__ import unicode_literals
import re
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.http import Http404 from django.http import Http404
from django.utils.html import escape from django.utils.html import escape
@ -13,6 +12,7 @@ from rest_framework.compat import View, apply_markdown
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
import re
def _remove_trailing_string(content, trailing): def _remove_trailing_string(content, trailing):
@ -148,6 +148,8 @@ class APIView(View):
""" """
If request is not permitted, determine what kind of exception to raise. If request is not permitted, determine what kind of exception to raise.
""" """
if not self.request.successful_authenticator:
raise exceptions.NotAuthenticated()
raise exceptions.PermissionDenied() raise exceptions.PermissionDenied()
def throttled(self, request, wait): def throttled(self, request, wait):
@ -156,6 +158,15 @@ class APIView(View):
""" """
raise exceptions.Throttled(wait) raise exceptions.Throttled(wait)
def get_authenticate_header(self, request):
"""
If a request is unauthenticated, determine the WWW-Authenticate
header to use for 401 responses, if any.
"""
authenticators = self.get_authenticators()
if authenticators:
return authenticators[0].authenticate_header(request)
def get_parser_context(self, http_request): def get_parser_context(self, http_request):
""" """
Returns a dict that is passed through to Parser.parse(), Returns a dict that is passed through to Parser.parse(),
@ -241,7 +252,7 @@ class APIView(View):
try: try:
return conneg.select_renderer(request, renderers, self.format_kwarg) return conneg.select_renderer(request, renderers, self.format_kwarg)
except: except Exception:
if force: if force:
return (renderers[0], renderers[0].media_type) return (renderers[0], renderers[0].media_type)
raise raise
@ -319,6 +330,16 @@ class APIView(View):
# Throttle wait header # Throttle wait header
self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
if isinstance(exc, (exceptions.NotAuthenticated,
exceptions.AuthenticationFailed)):
# WWW-Authenticate header for 401 responses, else coerce to 403
auth_header = self.get_authenticate_header(self.request)
if auth_header:
self.headers['WWW-Authenticate'] = auth_header
else:
exc.status_code = status.HTTP_403_FORBIDDEN
if isinstance(exc, exceptions.APIException): if isinstance(exc, exceptions.APIException):
return Response({'detail': exc.detail}, return Response({'detail': exc.detail},
status=exc.status_code, status=exc.status_code,

View File

@ -1,6 +1,8 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
#from __future__ import unicode_literals
from setuptools import setup from setuptools import setup
import re import re
import os import os
@ -45,9 +47,9 @@ version = get_version('rest_framework')
if sys.argv[-1] == 'publish': if sys.argv[-1] == 'publish':
os.system("python setup.py sdist upload") os.system("python setup.py sdist upload")
print "You probably want to also tag the version now:" print("You probably want to also tag the version now:")
print " git tag -a %s -m 'version %s'" % (version, version) print(" git tag -a %s -m 'version %s'" % (version, version))
print " git push --tags" print(" git push --tags")
sys.exit() sys.exit()
@ -59,7 +61,7 @@ setup(
license='BSD', license='BSD',
description='A lightweight REST framework for Django.', description='A lightweight REST framework for Django.',
author='Tom Christie', author='Tom Christie',
author_email='tom@tomchristie.com', author_email='tom@tomchristie.com', # SEE NOTE BELOW (*)
packages=get_packages('rest_framework'), packages=get_packages('rest_framework'),
package_data=get_package_data('rest_framework'), package_data=get_package_data('rest_framework'),
test_suite='rest_framework.runtests.runtests.main', test_suite='rest_framework.runtests.runtests.main',
@ -72,6 +74,13 @@ setup(
'License :: OSI Approved :: BSD License', 'License :: OSI Approved :: BSD License',
'Operating System :: OS Independent', 'Operating System :: OS Independent',
'Programming Language :: Python', 'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Topic :: Internet :: WWW/HTTP', 'Topic :: Internet :: WWW/HTTP',
] ]
) )
# (*) Please direct queries to the discussion group, rather than to me directly
# Doing so helps ensure your question is helpful to other users.
# Queries directly to my email are likely to receive a canned response.
#
# Many thanks for your understanding.

34
tox.ini
View File

@ -1,13 +1,28 @@
[tox] [tox]
downloadcache = {toxworkdir}/cache/ downloadcache = {toxworkdir}/cache/
envlist = py2.7-django1.5,py2.7-django1.4,py2.7-django1.3,py2.6-django1.5,py2.6-django1.4,py2.6-django1.3 envlist = py3.3-django1.5,py3.2-django1.5,py2.7-django1.5,py2.7-django1.4,py2.7-django1.3,py2.6-django1.5,py2.6-django1.4,py2.6-django1.3
[testenv] [testenv]
commands = {envpython} rest_framework/runtests/runtests.py commands = {envpython} rest_framework/runtests/runtests.py
[testenv:py3.3-django1.5]
basepython = python3.3
deps = https://www.djangoproject.com/download/1.5c1/tarball/
https://github.com/alex/django-filter/archive/master.tar.gz
[testenv:py3.2-django1.5]
basepython = python3.2
deps = https://www.djangoproject.com/download/1.5c1/tarball/
https://github.com/alex/django-filter/archive/master.tar.gz
[testenv:py2.7-django1.5] [testenv:py2.7-django1.5]
basepython = python2.7 basepython = python2.7
deps = https://github.com/django/django/zipball/master deps = https://www.djangoproject.com/download/1.5c1/tarball/
django-filter==0.5.4
[testenv:py2.6-django1.5]
basepython = python2.6
deps = https://www.djangoproject.com/download/1.5c1/tarball/
django-filter==0.5.4 django-filter==0.5.4
[testenv:py2.7-django1.4] [testenv:py2.7-django1.4]
@ -15,21 +30,16 @@ basepython = python2.7
deps = django==1.4.3 deps = django==1.4.3
django-filter==0.5.4 django-filter==0.5.4
[testenv:py2.6-django1.4]
basepython = python2.6
deps = django==1.4.3
django-filter==0.5.4
[testenv:py2.7-django1.3] [testenv:py2.7-django1.3]
basepython = python2.7 basepython = python2.7
deps = django==1.3.5 deps = django==1.3.5
django-filter==0.5.4 django-filter==0.5.4
[testenv:py2.6-django1.5]
basepython = python2.6
deps = https://github.com/django/django/zipball/master
django-filter==0.5.4
[testenv:py2.6-django1.4]
basepython = python2.6
deps = django==1.4.3
django-filter==0.5.4
[testenv:py2.6-django1.3] [testenv:py2.6-django1.3]
basepython = python2.6 basepython = python2.6
deps = django==1.3.5 deps = django==1.3.5