Getting tests passing

This commit is contained in:
Tom Christie 2014-09-02 17:41:23 +01:00
parent ec096a1cac
commit f2852811f9
20 changed files with 3565 additions and 3553 deletions

View File

@ -19,11 +19,12 @@ class AuthTokenSerializer(serializers.Serializer):
if not user.is_active: if not user.is_active:
msg = _('User account is disabled.') msg = _('User account is disabled.')
raise serializers.ValidationError(msg) raise serializers.ValidationError(msg)
attrs['user'] = user
return attrs
else: else:
msg = _('Unable to login with provided credentials.') msg = _('Unable to login with provided credentials.')
raise serializers.ValidationError(msg) raise serializers.ValidationError(msg)
else: else:
msg = _('Must include "username" and "password"') msg = _('Must include "username" and "password"')
raise serializers.ValidationError(msg) raise serializers.ValidationError(msg)
attrs['user'] = user
return attrs

View File

@ -18,7 +18,8 @@ class ObtainAuthToken(APIView):
def post(self, request): def post(self, request):
serializer = self.serializer_class(data=request.DATA) serializer = self.serializer_class(data=request.DATA)
if serializer.is_valid(): if serializer.is_valid():
token, created = Token.objects.get_or_create(user=serializer.object['user']) user = serializer.validated_data['user']
token, created = Token.objects.get_or_create(user=user)
return Response({'token': token.key}) return Response({'token': token.key})
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

View File

@ -1,4 +1,5 @@
from rest_framework.utils import html from rest_framework.utils import html
import inspect
class empty: class empty:
@ -11,6 +12,22 @@ class empty:
pass pass
def is_simple_callable(obj):
"""
True if the object is a callable that takes no arguments.
"""
function = inspect.isfunction(obj)
method = inspect.ismethod(obj)
if not (function or method):
return False
args, _, _, defaults = inspect.getargspec(obj)
len_args = len(args) if function else len(args) - 1
len_defaults = len(defaults) if defaults else 0
return len_args <= len_defaults
def get_attribute(instance, attrs): def get_attribute(instance, attrs):
""" """
Similar to Python's built in `getattr(instance, attr)`, Similar to Python's built in `getattr(instance, attr)`,
@ -98,6 +115,7 @@ class Field(object):
self.field_name = field_name self.field_name = field_name
self.parent = parent self.parent = parent
self.root = root self.root = root
self.context = parent.context
# `self.label` should deafult to being based on the field name. # `self.label` should deafult to being based on the field name.
if self.label is None: if self.label is None:
@ -297,25 +315,55 @@ class IntegerField(Field):
self.fail('invalid_integer') self.fail('invalid_integer')
return data return data
def to_primative(self, value):
if value is None:
return None
return int(value)
class EmailField(CharField): class EmailField(CharField):
pass # TODO pass # TODO
class URLField(CharField):
pass # TODO
class RegexField(CharField): class RegexField(CharField):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.regex = kwargs.pop('regex') self.regex = kwargs.pop('regex')
super(CharField, self).__init__(**kwargs) super(CharField, self).__init__(**kwargs)
class DateField(CharField):
def __init__(self, **kwargs):
self.input_formats = kwargs.pop('input_formats', None)
super(DateField, self).__init__(**kwargs)
class TimeField(CharField):
def __init__(self, **kwargs):
self.input_formats = kwargs.pop('input_formats', None)
super(TimeField, self).__init__(**kwargs)
class DateTimeField(CharField): class DateTimeField(CharField):
pass # TODO def __init__(self, **kwargs):
self.input_formats = kwargs.pop('input_formats', None)
super(DateTimeField, self).__init__(**kwargs)
class FileField(Field): class FileField(Field):
pass # TODO pass # TODO
class ReadOnlyField(Field):
def to_primative(self, value):
if is_simple_callable(value):
return value()
return value
class MethodField(Field): class MethodField(Field):
def __init__(self, **kwargs): def __init__(self, **kwargs):
kwargs['source'] = '*' kwargs['source'] = '*'

View File

@ -13,23 +13,6 @@ from rest_framework.request import clone_request
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
def _get_validation_exclusions(obj, lookup_field=None):
"""
Given a model instance, and an optional pk and slug field,
return the full list of all other field names on that model.
For use when performing full_clean on a model instance,
so we only clean the required fields.
"""
if lookup_field == 'pk':
pk_field = obj._meta.pk
while pk_field.rel:
pk_field = pk_field.rel.to._meta.pk
lookup_field = pk_field.name
return [field.name for field in obj._meta.fields if field.name != lookup_field]
class CreateModelMixin(object): class CreateModelMixin(object):
""" """
Create a model instance. Create a model instance.
@ -92,15 +75,14 @@ class UpdateModelMixin(object):
if not serializer.is_valid(): if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
lookup_value = self.kwargs[lookup_url_kwarg]
extras = {self.lookup_field: lookup_value}
if self.object is None: if self.object is None:
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
lookup_value = self.kwargs[lookup_url_kwarg]
extras = {self.lookup_field: lookup_value}
self.object = serializer.save(extras=extras) self.object = serializer.save(extras=extras)
return Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.data, status=status.HTTP_201_CREATED)
self.object = serializer.save(extras=extras) self.object = serializer.save()
return Response(serializer.data, status=status.HTTP_200_OK) return Response(serializer.data, status=status.HTTP_200_OK)
def partial_update(self, request, *args, **kwargs): def partial_update(self, request, *args, **kwargs):
@ -122,21 +104,6 @@ class UpdateModelMixin(object):
# return a 404 response. # return a 404 response.
raise raise
def pre_save(self, obj):
"""
Set any attributes on the object that are implicit in the request.
"""
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
lookup_value = self.kwargs[lookup_url_kwarg]
setattr(obj, self.lookup_field, lookup_value)
# Ensure we clean the attributes so that we don't eg return integer
# pk using a string representation, as provided by the url conf kwarg.
if hasattr(obj, 'full_clean'):
exclude = _get_validation_exclusions(obj, self.lookup_field)
obj.full_clean(exclude)
class DestroyModelMixin(object): class DestroyModelMixin(object):
""" """

View File

@ -13,7 +13,7 @@ class NextPageField(serializers.Field):
""" """
page_field = 'page' page_field = 'page'
def to_native(self, value): def to_primative(self, value):
if not value.has_next(): if not value.has_next():
return None return None
page = value.next_page_number() page = value.next_page_number()
@ -28,7 +28,7 @@ class PreviousPageField(serializers.Field):
""" """
page_field = 'page' page_field = 'page'
def to_native(self, value): def to_primative(self, value):
if not value.has_previous(): if not value.has_previous():
return None return None
page = value.previous_page_number() page = value.previous_page_number()
@ -48,25 +48,11 @@ class DefaultObjectSerializer(serializers.Field):
super(DefaultObjectSerializer, self).__init__(source=source) super(DefaultObjectSerializer, self).__init__(source=source)
# class PaginationSerializerOptions(serializers.SerializerOptions):
# """
# An object that stores the options that may be provided to a
# pagination serializer by using the inner `Meta` class.
# Accessible on the instance as `serializer.opts`.
# """
# def __init__(self, meta):
# super(PaginationSerializerOptions, self).__init__(meta)
# self.object_serializer_class = getattr(meta, 'object_serializer_class',
# DefaultObjectSerializer)
class BasePaginationSerializer(serializers.Serializer): class BasePaginationSerializer(serializers.Serializer):
""" """
A base class for pagination serializers to inherit from, A base class for pagination serializers to inherit from,
to make implementing custom serializers more easy. to make implementing custom serializers more easy.
""" """
# _options_class = PaginationSerializerOptions
results_field = 'results' results_field = 'results'
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -75,14 +61,16 @@ class BasePaginationSerializer(serializers.Serializer):
""" """
super(BasePaginationSerializer, self).__init__(*args, **kwargs) super(BasePaginationSerializer, self).__init__(*args, **kwargs)
results_field = self.results_field results_field = self.results_field
object_serializer = self.opts.object_serializer_class try:
object_serializer = self.Meta.object_serializer_class
except AttributeError:
object_serializer = DefaultObjectSerializer
if 'context' in kwargs: self.fields[results_field] = serializers.ListSerializer(
context_kwarg = {'context': kwargs['context']} child=object_serializer(),
else: source='object_list'
context_kwarg = {} )
self.fields[results_field].bind(results_field, self, self) # TODO: Support automatic binding
self.fields[results_field] = object_serializer(source='object_list', **context_kwarg)
class PaginationSerializer(BasePaginationSerializer): class PaginationSerializer(BasePaginationSerializer):

View File

@ -73,7 +73,7 @@ class HyperlinkedRelatedField(RelatedField):
try: try:
http_prefix = value.startswith(('http:', 'https:')) http_prefix = value.startswith(('http:', 'https:'))
except AttributeError: except AttributeError:
self.fail('incorrect_type', type(value).__name__) self.fail('incorrect_type', data_type=type(value).__name__)
if http_prefix: if http_prefix:
# If needed convert absolute URLs to relative path # If needed convert absolute URLs to relative path

View File

@ -142,7 +142,7 @@ class Serializer(BaseSerializer):
return super(Serializer, cls).__new__(cls) return super(Serializer, cls).__new__(cls)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs.pop('context', None) self.context = kwargs.pop('context', {})
kwargs.pop('partial', None) kwargs.pop('partial', None)
kwargs.pop('many', False) kwargs.pop('many', False)
@ -202,7 +202,7 @@ class Serializer(BaseSerializer):
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)
return ret return self.validate(ret)
def to_primative(self, instance): def to_primative(self, instance):
""" """
@ -217,6 +217,9 @@ class Serializer(BaseSerializer):
return ret return ret
def validate(self, attrs):
return attrs
def __iter__(self): def __iter__(self):
errors = self.errors if hasattr(self, '_errors') else {} errors = self.errors if hasattr(self, '_errors') else {}
for field in self.fields.values(): for field in self.fields.values():
@ -232,8 +235,7 @@ class ListSerializer(BaseSerializer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.child = kwargs.pop('child', copy.deepcopy(self.child)) self.child = kwargs.pop('child', copy.deepcopy(self.child))
assert self.child is not None, '`child` is a required argument.' assert self.child is not None, '`child` is a required argument.'
self.context = kwargs.pop('context', {})
kwargs.pop('context', None)
kwargs.pop('partial', None) kwargs.pop('partial', None)
super(ListSerializer, self).__init__(*args, **kwargs) super(ListSerializer, self).__init__(*args, **kwargs)
@ -316,19 +318,19 @@ class ModelSerializer(Serializer):
models.PositiveIntegerField: IntegerField, models.PositiveIntegerField: IntegerField,
models.SmallIntegerField: IntegerField, models.SmallIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField,
# models.DateTimeField: DateTimeField, models.DateTimeField: DateTimeField,
# models.DateField: DateField, models.DateField: DateField,
# models.TimeField: TimeField, models.TimeField: TimeField,
# models.DecimalField: DecimalField, # models.DecimalField: DecimalField,
# models.EmailField: EmailField, models.EmailField: EmailField,
models.CharField: CharField, models.CharField: CharField,
# models.URLField: URLField, models.URLField: URLField,
# models.SlugField: SlugField, # models.SlugField: SlugField,
models.TextField: CharField, models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField, models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField, models.BooleanField: BooleanField,
models.NullBooleanField: BooleanField, models.NullBooleanField: BooleanField,
# models.FileField: FileField, models.FileField: FileField,
# models.ImageField: ImageField, # models.ImageField: ImageField,
} }
@ -338,6 +340,15 @@ class ModelSerializer(Serializer):
self.opts = self._options_class(self.Meta) self.opts = self._options_class(self.Meta)
super(ModelSerializer, self).__init__(*args, **kwargs) super(ModelSerializer, self).__init__(*args, **kwargs)
def create(self):
ModelClass = self.opts.model
return ModelClass.objects.create(**self.validated_data)
def update(self, obj):
for attr, value in self.validated_data.items():
setattr(obj, attr, value)
obj.save()
def get_fields(self): def get_fields(self):
# Get the explicitly declared fields. # Get the explicitly declared fields.
fields = copy.deepcopy(self.base_fields) fields = copy.deepcopy(self.base_fields)
@ -566,8 +577,8 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
class HyperlinkedModelSerializer(ModelSerializer): class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions _options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail' _default_view_name = '%(model_name)s-detail'
# _hyperlink_field_class = HyperlinkedRelatedField _hyperlink_field_class = HyperlinkedRelatedField
# _hyperlink_identify_field_class = HyperlinkedIdentityField _hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self): def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields() fields = super(HyperlinkedModelSerializer, self).get_default_fields()
@ -575,15 +586,15 @@ class HyperlinkedModelSerializer(ModelSerializer):
if self.opts.view_name is None: if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model) self.opts.view_name = self._get_default_view_name(self.opts.model)
# if self.opts.url_field_name not in fields: if self.opts.url_field_name not in fields:
# url_field = self._hyperlink_identify_field_class( url_field = self._hyperlink_identify_field_class(
# view_name=self.opts.view_name, view_name=self.opts.view_name,
# lookup_field=self.opts.lookup_field lookup_field=self.opts.lookup_field
# ) )
# ret = self._dict_class() ret = fields.__class__()
# ret[self.opts.url_field_name] = url_field ret[self.opts.url_field_name] = url_field
# ret.update(fields) ret.update(fields)
# fields = ret fields = ret
return fields return fields

File diff suppressed because it is too large Load Diff

View File

@ -1,92 +1,92 @@
from __future__ import unicode_literals # from __future__ import unicode_literals
from django.test import TestCase # from django.test import TestCase
from django.utils import six # from django.utils import six
from rest_framework import serializers # from rest_framework import serializers
from rest_framework.compat import BytesIO # from rest_framework.compat import BytesIO
import datetime # import datetime
class UploadedFile(object): # class UploadedFile(object):
def __init__(self, file=None, created=None): # def __init__(self, file=None, created=None):
self.file = file # self.file = file
self.created = created or datetime.datetime.now() # self.created = created or datetime.datetime.now()
class UploadedFileSerializer(serializers.Serializer): # class UploadedFileSerializer(serializers.Serializer):
file = serializers.FileField(required=False) # file = serializers.FileField(required=False)
created = serializers.DateTimeField() # created = serializers.DateTimeField()
def restore_object(self, attrs, instance=None): # def restore_object(self, attrs, instance=None):
if instance: # if instance:
instance.file = attrs['file'] # instance.file = attrs['file']
instance.created = attrs['created'] # instance.created = attrs['created']
return instance # return instance
return UploadedFile(**attrs) # return UploadedFile(**attrs)
class FileSerializerTests(TestCase): # class FileSerializerTests(TestCase):
def test_create(self): # def test_create(self):
now = datetime.datetime.now() # now = datetime.datetime.now()
file = BytesIO(six.b('stuff')) # file = BytesIO(six.b('stuff'))
file.name = 'stuff.txt' # file.name = 'stuff.txt'
file.size = len(file.getvalue()) # file.size = len(file.getvalue())
serializer = UploadedFileSerializer(data={'created': now}, files={'file': file}) # serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
uploaded_file = UploadedFile(file=file, created=now) # uploaded_file = UploadedFile(file=file, created=now)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.object.created, uploaded_file.created) # self.assertEqual(serializer.object.created, uploaded_file.created)
self.assertEqual(serializer.object.file, uploaded_file.file) # self.assertEqual(serializer.object.file, uploaded_file.file)
self.assertFalse(serializer.object is uploaded_file) # self.assertFalse(serializer.object is uploaded_file)
def test_creation_failure(self): # def test_creation_failure(self):
""" # """
Passing files=None should result in an ValidationError # Passing files=None should result in an ValidationError
Regression test for: # Regression test for:
https://github.com/tomchristie/django-rest-framework/issues/542 # https://github.com/tomchristie/django-rest-framework/issues/542
""" # """
now = datetime.datetime.now() # now = datetime.datetime.now()
serializer = UploadedFileSerializer(data={'created': now}) # serializer = UploadedFileSerializer(data={'created': now})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.object.created, now) # self.assertEqual(serializer.object.created, now)
self.assertIsNone(serializer.object.file) # self.assertIsNone(serializer.object.file)
def test_remove_with_empty_string(self): # def test_remove_with_empty_string(self):
""" # """
Passing empty string as data should cause file to be removed # Passing empty string as data should cause file to be removed
Test for: # Test for:
https://github.com/tomchristie/django-rest-framework/issues/937 # https://github.com/tomchristie/django-rest-framework/issues/937
""" # """
now = datetime.datetime.now() # now = datetime.datetime.now()
file = BytesIO(six.b('stuff')) # file = BytesIO(six.b('stuff'))
file.name = 'stuff.txt' # file.name = 'stuff.txt'
file.size = len(file.getvalue()) # file.size = len(file.getvalue())
uploaded_file = UploadedFile(file=file, created=now) # uploaded_file = UploadedFile(file=file, created=now)
serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''}) # serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''})
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.object.created, uploaded_file.created) # self.assertEqual(serializer.object.created, uploaded_file.created)
self.assertIsNone(serializer.object.file) # self.assertIsNone(serializer.object.file)
def test_validation_error_with_non_file(self): # def test_validation_error_with_non_file(self):
""" # """
Passing non-files should raise a validation error. # Passing non-files should raise a validation error.
""" # """
now = datetime.datetime.now() # now = datetime.datetime.now()
errmsg = 'No file was submitted. Check the encoding type on the form.' # errmsg = 'No file was submitted. Check the encoding type on the form.'
serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'}) # serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'file': [errmsg]}) # self.assertEqual(serializer.errors, {'file': [errmsg]})
def test_validation_with_no_data(self): # def test_validation_with_no_data(self):
""" # """
Validation should still function when no data dictionary is provided. # Validation should still function when no data dictionary is provided.
""" # """
uploaded_file = BytesIO(six.b('stuff')) # uploaded_file = BytesIO(six.b('stuff'))
uploaded_file.name = 'stuff.txt' # uploaded_file.name = 'stuff.txt'
uploaded_file.size = len(uploaded_file.getvalue()) # uploaded_file.size = len(uploaded_file.getvalue())
serializer = UploadedFileSerializer(files={'file': uploaded_file}) # serializer = UploadedFileSerializer(files={'file': uploaded_file})
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())

View File

@ -1,151 +1,151 @@
from __future__ import unicode_literals # from __future__ import unicode_literals
from django.contrib.contenttypes.models import ContentType # from django.contrib.contenttypes.models import ContentType
from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey # from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey
from django.db import models # from django.db import models
from django.test import TestCase # from django.test import TestCase
from rest_framework import serializers # from rest_framework import serializers
from rest_framework.compat import python_2_unicode_compatible # from rest_framework.compat import python_2_unicode_compatible
@python_2_unicode_compatible # @python_2_unicode_compatible
class Tag(models.Model): # class Tag(models.Model):
""" # """
Tags have a descriptive slug, and are attached to an arbitrary object. # Tags have a descriptive slug, and are attached to an arbitrary object.
""" # """
tag = models.SlugField() # tag = models.SlugField()
content_type = models.ForeignKey(ContentType) # content_type = models.ForeignKey(ContentType)
object_id = models.PositiveIntegerField() # object_id = models.PositiveIntegerField()
tagged_item = GenericForeignKey('content_type', 'object_id') # tagged_item = GenericForeignKey('content_type', 'object_id')
def __str__(self): # def __str__(self):
return self.tag # return self.tag
@python_2_unicode_compatible # @python_2_unicode_compatible
class Bookmark(models.Model): # class Bookmark(models.Model):
""" # """
A URL bookmark that may have multiple tags attached. # A URL bookmark that may have multiple tags attached.
""" # """
url = models.URLField() # url = models.URLField()
tags = GenericRelation(Tag) # tags = GenericRelation(Tag)
def __str__(self): # def __str__(self):
return 'Bookmark: %s' % self.url # return 'Bookmark: %s' % self.url
@python_2_unicode_compatible # @python_2_unicode_compatible
class Note(models.Model): # class Note(models.Model):
""" # """
A textual note that may have multiple tags attached. # A textual note that may have multiple tags attached.
""" # """
text = models.TextField() # text = models.TextField()
tags = GenericRelation(Tag) # tags = GenericRelation(Tag)
def __str__(self): # def __str__(self):
return 'Note: %s' % self.text # return 'Note: %s' % self.text
class TestGenericRelations(TestCase): # class TestGenericRelations(TestCase):
def setUp(self): # def setUp(self):
self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') # self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
Tag.objects.create(tagged_item=self.bookmark, tag='django') # Tag.objects.create(tagged_item=self.bookmark, tag='django')
Tag.objects.create(tagged_item=self.bookmark, tag='python') # Tag.objects.create(tagged_item=self.bookmark, tag='python')
self.note = Note.objects.create(text='Remember the milk') # self.note = Note.objects.create(text='Remember the milk')
Tag.objects.create(tagged_item=self.note, tag='reminder') # Tag.objects.create(tagged_item=self.note, tag='reminder')
def test_generic_relation(self): # def test_generic_relation(self):
""" # """
Test a relationship that spans a GenericRelation field. # Test a relationship that spans a GenericRelation field.
IE. A reverse generic relationship. # IE. A reverse generic relationship.
""" # """
class BookmarkSerializer(serializers.ModelSerializer): # class BookmarkSerializer(serializers.ModelSerializer):
tags = serializers.RelatedField(many=True) # tags = serializers.RelatedField(many=True)
class Meta: # class Meta:
model = Bookmark # model = Bookmark
exclude = ('id',) # exclude = ('id',)
serializer = BookmarkSerializer(self.bookmark) # serializer = BookmarkSerializer(self.bookmark)
expected = { # expected = {
'tags': ['django', 'python'], # 'tags': ['django', 'python'],
'url': 'https://www.djangoproject.com/' # 'url': 'https://www.djangoproject.com/'
} # }
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_generic_nested_relation(self): # def test_generic_nested_relation(self):
""" # """
Test saving a GenericRelation field via a nested serializer. # Test saving a GenericRelation field via a nested serializer.
""" # """
class TagSerializer(serializers.ModelSerializer): # class TagSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = Tag # model = Tag
exclude = ('content_type', 'object_id') # exclude = ('content_type', 'object_id')
class BookmarkSerializer(serializers.ModelSerializer): # class BookmarkSerializer(serializers.ModelSerializer):
tags = TagSerializer(many=True) # tags = TagSerializer(many=True)
class Meta: # class Meta:
model = Bookmark # model = Bookmark
exclude = ('id',) # exclude = ('id',)
data = { # data = {
'url': 'https://docs.djangoproject.com/', # 'url': 'https://docs.djangoproject.com/',
'tags': [ # 'tags': [
{'tag': 'contenttypes'}, # {'tag': 'contenttypes'},
{'tag': 'genericrelations'}, # {'tag': 'genericrelations'},
] # ]
} # }
serializer = BookmarkSerializer(data=data) # serializer = BookmarkSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
serializer.save() # serializer.save()
self.assertEqual(serializer.object.tags.count(), 2) # self.assertEqual(serializer.object.tags.count(), 2)
def test_generic_fk(self): # def test_generic_fk(self):
""" # """
Test a relationship that spans a GenericForeignKey field. # Test a relationship that spans a GenericForeignKey field.
IE. A forward generic relationship. # IE. A forward generic relationship.
""" # """
class TagSerializer(serializers.ModelSerializer): # class TagSerializer(serializers.ModelSerializer):
tagged_item = serializers.RelatedField() # tagged_item = serializers.RelatedField()
class Meta: # class Meta:
model = Tag # model = Tag
exclude = ('id', 'content_type', 'object_id') # exclude = ('id', 'content_type', 'object_id')
serializer = TagSerializer(Tag.objects.all(), many=True) # serializer = TagSerializer(Tag.objects.all(), many=True)
expected = [ # expected = [
{ # {
'tag': 'django', # 'tag': 'django',
'tagged_item': 'Bookmark: https://www.djangoproject.com/' # 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
}, # },
{ # {
'tag': 'python', # 'tag': 'python',
'tagged_item': 'Bookmark: https://www.djangoproject.com/' # 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
}, # },
{ # {
'tag': 'reminder', # 'tag': 'reminder',
'tagged_item': 'Note: Remember the milk' # 'tagged_item': 'Note: Remember the milk'
} # }
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_restore_object_generic_fk(self): # def test_restore_object_generic_fk(self):
""" # """
Ensure an object with a generic foreign key can be restored. # Ensure an object with a generic foreign key can be restored.
""" # """
class TagSerializer(serializers.ModelSerializer): # class TagSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = Tag # model = Tag
exclude = ('content_type', 'object_id') # exclude = ('content_type', 'object_id')
serializer = TagSerializer() # serializer = TagSerializer()
bookmark = Bookmark(url='http://example.com') # bookmark = Bookmark(url='http://example.com')
attrs = {'tagged_item': bookmark, 'tag': 'example'} # attrs = {'tagged_item': bookmark, 'tag': 'example'}
tag = serializer.restore_object(attrs) # tag = serializer.restore_object(attrs)
self.assertEqual(tag.tagged_item, bookmark) # self.assertEqual(tag.tagged_item, bookmark)

View File

@ -33,13 +33,9 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView):
""" """
Example description for OPTIONS. Example description for OPTIONS.
""" """
queryset = BasicModel.objects.all() queryset = BasicModel.objects.exclude(text='filtered out')
serializer_class = BasicSerializer serializer_class = BasicSerializer
def get_queryset(self):
queryset = super(InstanceView, self).get_queryset()
return queryset.exclude(text='filtered out')
class FKInstanceView(generics.RetrieveUpdateDestroyAPIView): class FKInstanceView(generics.RetrieveUpdateDestroyAPIView):
""" """
@ -50,11 +46,11 @@ class FKInstanceView(generics.RetrieveUpdateDestroyAPIView):
class SlugSerializer(serializers.ModelSerializer): class SlugSerializer(serializers.ModelSerializer):
slug = serializers.Field() # read only slug = serializers.Field(read_only=True)
class Meta: class Meta:
model = SlugBasedModel model = SlugBasedModel
exclude = ('id',) fields = ('text', 'slug')
class SlugBasedInstanceView(InstanceView): class SlugBasedInstanceView(InstanceView):
@ -125,46 +121,46 @@ class TestRootView(TestCase):
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."}) self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."})
def test_options_root_view(self): # def test_options_root_view(self):
""" # """
OPTIONS requests to ListCreateAPIView should return metadata # OPTIONS requests to ListCreateAPIView should return metadata
""" # """
request = factory.options('/') # request = factory.options('/')
with self.assertNumQueries(0): # with self.assertNumQueries(0):
response = self.view(request).render() # response = self.view(request).render()
expected = { # expected = {
'parses': [ # 'parses': [
'application/json', # 'application/json',
'application/x-www-form-urlencoded', # 'application/x-www-form-urlencoded',
'multipart/form-data' # 'multipart/form-data'
], # ],
'renders': [ # 'renders': [
'application/json', # 'application/json',
'text/html' # 'text/html'
], # ],
'name': 'Root', # 'name': 'Root',
'description': 'Example description for OPTIONS.', # 'description': 'Example description for OPTIONS.',
'actions': { # 'actions': {
'POST': { # 'POST': {
'text': { # 'text': {
'max_length': 100, # 'max_length': 100,
'read_only': False, # 'read_only': False,
'required': True, # 'required': True,
'type': 'string', # 'type': 'string',
"label": "Text comes here", # "label": "Text comes here",
"help_text": "Text description." # "help_text": "Text description."
}, # },
'id': { # 'id': {
'read_only': True, # 'read_only': True,
'required': False, # 'required': False,
'type': 'integer', # 'type': 'integer',
'label': 'ID', # 'label': 'ID',
}, # },
} # }
} # }
} # }
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected) # self.assertEqual(response.data, expected)
def test_post_cannot_set_id(self): def test_post_cannot_set_id(self):
""" """
@ -223,10 +219,10 @@ class TestInstanceView(TestCase):
""" """
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.put('/1', data, format='json') request = factory.put('/1', data, format='json')
with self.assertNumQueries(2): with self.assertNumQueries(3):
response = self.view(request, pk='1').render() response = self.view(request, pk='1').render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) self.assertEqual(dict(response.data), {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1) updated = self.objects.get(id=1)
self.assertEqual(updated.text, 'foobar') self.assertEqual(updated.text, 'foobar')
@ -237,7 +233,7 @@ class TestInstanceView(TestCase):
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.patch('/1', data, format='json') request = factory.patch('/1', data, format='json')
with self.assertNumQueries(2): with self.assertNumQueries(3):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
@ -256,88 +252,88 @@ class TestInstanceView(TestCase):
ids = [obj.id for obj in self.objects.all()] ids = [obj.id for obj in self.objects.all()]
self.assertEqual(ids, [2, 3]) self.assertEqual(ids, [2, 3])
def test_options_instance_view(self): # def test_options_instance_view(self):
""" # """
OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata # OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
""" # """
request = factory.options('/1') # request = factory.options('/1')
with self.assertNumQueries(1): # with self.assertNumQueries(1):
response = self.view(request, pk=1).render() # response = self.view(request, pk=1).render()
expected = { # expected = {
'parses': [ # 'parses': [
'application/json', # 'application/json',
'application/x-www-form-urlencoded', # 'application/x-www-form-urlencoded',
'multipart/form-data' # 'multipart/form-data'
], # ],
'renders': [ # 'renders': [
'application/json', # 'application/json',
'text/html' # 'text/html'
], # ],
'name': 'Instance', # 'name': 'Instance',
'description': 'Example description for OPTIONS.', # 'description': 'Example description for OPTIONS.',
'actions': { # 'actions': {
'PUT': { # 'PUT': {
'text': { # 'text': {
'max_length': 100, # 'max_length': 100,
'read_only': False, # 'read_only': False,
'required': True, # 'required': True,
'type': 'string', # 'type': 'string',
'label': 'Text comes here', # 'label': 'Text comes here',
'help_text': 'Text description.' # 'help_text': 'Text description.'
}, # },
'id': { # 'id': {
'read_only': True, # 'read_only': True,
'required': False, # 'required': False,
'type': 'integer', # 'type': 'integer',
'label': 'ID', # 'label': 'ID',
}, # },
} # }
} # }
} # }
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected) # self.assertEqual(response.data, expected)
def test_options_before_instance_create(self): # def test_options_before_instance_create(self):
""" # """
OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata # OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
before the instance has been created # before the instance has been created
""" # """
request = factory.options('/999') # request = factory.options('/999')
with self.assertNumQueries(1): # with self.assertNumQueries(1):
response = self.view(request, pk=999).render() # response = self.view(request, pk=999).render()
expected = { # expected = {
'parses': [ # 'parses': [
'application/json', # 'application/json',
'application/x-www-form-urlencoded', # 'application/x-www-form-urlencoded',
'multipart/form-data' # 'multipart/form-data'
], # ],
'renders': [ # 'renders': [
'application/json', # 'application/json',
'text/html' # 'text/html'
], # ],
'name': 'Instance', # 'name': 'Instance',
'description': 'Example description for OPTIONS.', # 'description': 'Example description for OPTIONS.',
'actions': { # 'actions': {
'PUT': { # 'PUT': {
'text': { # 'text': {
'max_length': 100, # 'max_length': 100,
'read_only': False, # 'read_only': False,
'required': True, # 'required': True,
'type': 'string', # 'type': 'string',
'label': 'Text comes here', # 'label': 'Text comes here',
'help_text': 'Text description.' # 'help_text': 'Text description.'
}, # },
'id': { # 'id': {
'read_only': True, # 'read_only': True,
'required': False, # 'required': False,
'type': 'integer', # 'type': 'integer',
'label': 'ID', # 'label': 'ID',
}, # },
} # }
} # }
} # }
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected) # self.assertEqual(response.data, expected)
def test_get_instance_view_incorrect_arg(self): def test_get_instance_view_incorrect_arg(self):
""" """
@ -355,7 +351,7 @@ class TestInstanceView(TestCase):
""" """
data = {'id': 999, 'text': 'foobar'} data = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', data, format='json') request = factory.put('/1', data, format='json')
with self.assertNumQueries(2): with self.assertNumQueries(3):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
@ -370,7 +366,7 @@ class TestInstanceView(TestCase):
self.objects.get(id=1).delete() self.objects.get(id=1).delete()
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.put('/1', data, format='json') request = factory.put('/1', data, format='json')
with self.assertNumQueries(3): with self.assertNumQueries(2):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
@ -396,7 +392,7 @@ class TestInstanceView(TestCase):
data = {'text': 'foobar'} data = {'text': 'foobar'}
# pk fields can not be created on demand, only the database can set the pk for a new object # pk fields can not be created on demand, only the database can set the pk for a new object
request = factory.put('/5', data, format='json') request = factory.put('/5', data, format='json')
with self.assertNumQueries(3): with self.assertNumQueries(2):
response = self.view(request, pk=5).render() response = self.view(request, pk=5).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
new_obj = self.objects.get(pk=5) new_obj = self.objects.get(pk=5)
@ -446,52 +442,52 @@ class TestFKInstanceView(TestCase):
] ]
self.view = FKInstanceView.as_view() self.view = FKInstanceView.as_view()
def test_options_root_view(self): # def test_options_root_view(self):
""" # """
OPTIONS requests to ListCreateAPIView should return metadata # OPTIONS requests to ListCreateAPIView should return metadata
""" # """
request = factory.options('/999') # request = factory.options('/999')
with self.assertNumQueries(1): # with self.assertNumQueries(1):
response = self.view(request, pk=999).render() # response = self.view(request, pk=999).render()
expected = { # expected = {
'name': 'Fk Instance', # 'name': 'Fk Instance',
'description': 'FK: example description for OPTIONS.', # 'description': 'FK: example description for OPTIONS.',
'renders': [ # 'renders': [
'application/json', # 'application/json',
'text/html' # 'text/html'
], # ],
'parses': [ # 'parses': [
'application/json', # 'application/json',
'application/x-www-form-urlencoded', # 'application/x-www-form-urlencoded',
'multipart/form-data' # 'multipart/form-data'
], # ],
'actions': { # 'actions': {
'PUT': { # 'PUT': {
'id': { # 'id': {
'type': 'integer', # 'type': 'integer',
'required': False, # 'required': False,
'read_only': True, # 'read_only': True,
'label': 'ID' # 'label': 'ID'
}, # },
'name': { # 'name': {
'type': 'string', # 'type': 'string',
'required': True, # 'required': True,
'read_only': False, # 'read_only': False,
'label': 'name', # 'label': 'name',
'max_length': 100 # 'max_length': 100
}, # },
'target': { # 'target': {
'type': 'field', # 'type': 'field',
'required': True, # 'required': True,
'read_only': False, # 'read_only': False,
'label': 'Target', # 'label': 'Target',
'help_text': 'Target' # 'help_text': 'Target'
} # }
} # }
} # }
} # }
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected) # self.assertEqual(response.data, expected)
class TestOverriddenGetObject(TestCase): class TestOverriddenGetObject(TestCase):

View File

@ -1,406 +1,406 @@
from __future__ import unicode_literals # from __future__ import unicode_literals
import json # import json
from django.test import TestCase # from django.test import TestCase
from rest_framework import generics, status, serializers # from rest_framework import generics, status, serializers
from django.conf.urls import patterns, url # from django.conf.urls import patterns, url
from rest_framework.settings import api_settings # from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory # from rest_framework.test import APIRequestFactory
from tests.models import ( # from tests.models import (
Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, # Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment,
Album, Photo, OptionalRelationModel # Album, Photo, OptionalRelationModel
) # )
factory = APIRequestFactory() # factory = APIRequestFactory()
class BlogPostCommentSerializer(serializers.ModelSerializer): # class BlogPostCommentSerializer(serializers.ModelSerializer):
url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail') # url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail')
text = serializers.CharField() # text = serializers.CharField()
blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail') # blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail')
class Meta: # class Meta:
model = BlogPostComment # model = BlogPostComment
fields = ('text', 'blog_post_url', 'url') # fields = ('text', 'blog_post_url', 'url')
class PhotoSerializer(serializers.Serializer): # class PhotoSerializer(serializers.Serializer):
description = serializers.CharField() # description = serializers.CharField()
album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title') # album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title')
def restore_object(self, attrs, instance=None): # def restore_object(self, attrs, instance=None):
return Photo(**attrs) # return Photo(**attrs)
class AlbumSerializer(serializers.ModelSerializer): # class AlbumSerializer(serializers.ModelSerializer):
url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title') # url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title')
class Meta: # class Meta:
model = Album # model = Album
fields = ('title', 'url') # fields = ('title', 'url')
class BasicSerializer(serializers.HyperlinkedModelSerializer): # class BasicSerializer(serializers.HyperlinkedModelSerializer):
class Meta: # class Meta:
model = BasicModel # model = BasicModel
class AnchorSerializer(serializers.HyperlinkedModelSerializer): # class AnchorSerializer(serializers.HyperlinkedModelSerializer):
class Meta: # class Meta:
model = Anchor # model = Anchor
class ManyToManySerializer(serializers.HyperlinkedModelSerializer): # class ManyToManySerializer(serializers.HyperlinkedModelSerializer):
class Meta: # class Meta:
model = ManyToManyModel # model = ManyToManyModel
class BlogPostSerializer(serializers.ModelSerializer): # class BlogPostSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = BlogPost # model = BlogPost
class OptionalRelationSerializer(serializers.HyperlinkedModelSerializer): # class OptionalRelationSerializer(serializers.HyperlinkedModelSerializer):
class Meta: # class Meta:
model = OptionalRelationModel # model = OptionalRelationModel
class BasicList(generics.ListCreateAPIView): # class BasicList(generics.ListCreateAPIView):
queryset = BasicModel.objects.all() # queryset = BasicModel.objects.all()
serializer_class = BasicSerializer # serializer_class = BasicSerializer
class BasicDetail(generics.RetrieveUpdateDestroyAPIView): # class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
queryset = BasicModel.objects.all() # queryset = BasicModel.objects.all()
serializer_class = BasicSerializer # serializer_class = BasicSerializer
class AnchorDetail(generics.RetrieveAPIView): # class AnchorDetail(generics.RetrieveAPIView):
queryset = Anchor.objects.all() # queryset = Anchor.objects.all()
serializer_class = AnchorSerializer # serializer_class = AnchorSerializer
class ManyToManyList(generics.ListAPIView): # class ManyToManyList(generics.ListAPIView):
queryset = ManyToManyModel.objects.all() # queryset = ManyToManyModel.objects.all()
serializer_class = ManyToManySerializer # serializer_class = ManyToManySerializer
class ManyToManyDetail(generics.RetrieveAPIView): # class ManyToManyDetail(generics.RetrieveAPIView):
queryset = ManyToManyModel.objects.all() # queryset = ManyToManyModel.objects.all()
serializer_class = ManyToManySerializer # serializer_class = ManyToManySerializer
class BlogPostCommentListCreate(generics.ListCreateAPIView): # class BlogPostCommentListCreate(generics.ListCreateAPIView):
queryset = BlogPostComment.objects.all() # queryset = BlogPostComment.objects.all()
serializer_class = BlogPostCommentSerializer # serializer_class = BlogPostCommentSerializer
class BlogPostCommentDetail(generics.RetrieveAPIView): # class BlogPostCommentDetail(generics.RetrieveAPIView):
queryset = BlogPostComment.objects.all() # queryset = BlogPostComment.objects.all()
serializer_class = BlogPostCommentSerializer # serializer_class = BlogPostCommentSerializer
class BlogPostDetail(generics.RetrieveAPIView): # class BlogPostDetail(generics.RetrieveAPIView):
queryset = BlogPost.objects.all() # queryset = BlogPost.objects.all()
serializer_class = BlogPostSerializer # serializer_class = BlogPostSerializer
class PhotoListCreate(generics.ListCreateAPIView): # class PhotoListCreate(generics.ListCreateAPIView):
queryset = Photo.objects.all() # queryset = Photo.objects.all()
serializer_class = PhotoSerializer # serializer_class = PhotoSerializer
class AlbumDetail(generics.RetrieveAPIView): # class AlbumDetail(generics.RetrieveAPIView):
queryset = Album.objects.all() # queryset = Album.objects.all()
serializer_class = AlbumSerializer # serializer_class = AlbumSerializer
lookup_field = 'title' # lookup_field = 'title'
class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView): # class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
queryset = OptionalRelationModel.objects.all() # queryset = OptionalRelationModel.objects.all()
serializer_class = OptionalRelationSerializer # serializer_class = OptionalRelationSerializer
urlpatterns = patterns( # urlpatterns = patterns(
'', # '',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), # url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), # url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'), # url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'), # url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'), # url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'), # url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'),
url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'), # url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'),
url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'), # url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'),
url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'), # url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'),
url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'), # url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'),
url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'), # url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'),
) # )
class TestBasicHyperlinkedView(TestCase): # class TestBasicHyperlinkedView(TestCase):
urls = 'tests.test_hyperlinkedserializers' # urls = 'tests.test_hyperlinkedserializers'
def setUp(self): # def setUp(self):
""" # """
Create 3 BasicModel instances. # Create 3 BasicModel instances.
""" # """
items = ['foo', 'bar', 'baz'] # items = ['foo', 'bar', 'baz']
for item in items: # for item in items:
BasicModel(text=item).save() # BasicModel(text=item).save()
self.objects = BasicModel.objects # self.objects = BasicModel.objects
self.data = [ # self.data = [
{'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text} # {'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text}
for obj in self.objects.all() # for obj in self.objects.all()
] # ]
self.list_view = BasicList.as_view() # self.list_view = BasicList.as_view()
self.detail_view = BasicDetail.as_view() # self.detail_view = BasicDetail.as_view()
def test_get_list_view(self): # def test_get_list_view(self):
""" # """
GET requests to ListCreateAPIView should return list of objects. # GET requests to ListCreateAPIView should return list of objects.
""" # """
request = factory.get('/basic/') # request = factory.get('/basic/')
response = self.list_view(request).render() # response = self.list_view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, self.data) # self.assertEqual(response.data, self.data)
def test_get_detail_view(self): # def test_get_detail_view(self):
""" # """
GET requests to ListCreateAPIView should return list of objects. # GET requests to ListCreateAPIView should return list of objects.
""" # """
request = factory.get('/basic/1') # request = factory.get('/basic/1')
response = self.detail_view(request, pk=1).render() # response = self.detail_view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, self.data[0]) # self.assertEqual(response.data, self.data[0])
class TestManyToManyHyperlinkedView(TestCase): # class TestManyToManyHyperlinkedView(TestCase):
urls = 'tests.test_hyperlinkedserializers' # urls = 'tests.test_hyperlinkedserializers'
def setUp(self): # def setUp(self):
""" # """
Create 3 BasicModel instances. # Create 3 BasicModel instances.
""" # """
items = ['foo', 'bar', 'baz'] # items = ['foo', 'bar', 'baz']
anchors = [] # anchors = []
for item in items: # for item in items:
anchor = Anchor(text=item) # anchor = Anchor(text=item)
anchor.save() # anchor.save()
anchors.append(anchor) # anchors.append(anchor)
manytomany = ManyToManyModel() # manytomany = ManyToManyModel()
manytomany.save() # manytomany.save()
manytomany.rel.add(*anchors) # manytomany.rel.add(*anchors)
self.data = [{ # self.data = [{
'url': 'http://testserver/manytomany/1/', # 'url': 'http://testserver/manytomany/1/',
'rel': [ # 'rel': [
'http://testserver/anchor/1/', # 'http://testserver/anchor/1/',
'http://testserver/anchor/2/', # 'http://testserver/anchor/2/',
'http://testserver/anchor/3/', # 'http://testserver/anchor/3/',
] # ]
}] # }]
self.list_view = ManyToManyList.as_view() # self.list_view = ManyToManyList.as_view()
self.detail_view = ManyToManyDetail.as_view() # self.detail_view = ManyToManyDetail.as_view()
def test_get_list_view(self): # def test_get_list_view(self):
""" # """
GET requests to ListCreateAPIView should return list of objects. # GET requests to ListCreateAPIView should return list of objects.
""" # """
request = factory.get('/manytomany/') # request = factory.get('/manytomany/')
response = self.list_view(request) # response = self.list_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, self.data) # self.assertEqual(response.data, self.data)
def test_get_detail_view(self): # def test_get_detail_view(self):
""" # """
GET requests to ListCreateAPIView should return list of objects. # GET requests to ListCreateAPIView should return list of objects.
""" # """
request = factory.get('/manytomany/1/') # request = factory.get('/manytomany/1/')
response = self.detail_view(request, pk=1) # response = self.detail_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, self.data[0]) # self.assertEqual(response.data, self.data[0])
class TestHyperlinkedIdentityFieldLookup(TestCase): # class TestHyperlinkedIdentityFieldLookup(TestCase):
urls = 'tests.test_hyperlinkedserializers' # urls = 'tests.test_hyperlinkedserializers'
def setUp(self): # def setUp(self):
""" # """
Create 3 Album instances. # Create 3 Album instances.
""" # """
titles = ['foo', 'bar', 'baz'] # titles = ['foo', 'bar', 'baz']
for title in titles: # for title in titles:
album = Album(title=title) # album = Album(title=title)
album.save() # album.save()
self.detail_view = AlbumDetail.as_view() # self.detail_view = AlbumDetail.as_view()
self.data = { # self.data = {
'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'}, # 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'},
'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'}, # 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'},
'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'} # 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'}
} # }
def test_lookup_field(self): # def test_lookup_field(self):
""" # """
GET requests to AlbumDetail view should return serialized Albums # GET requests to AlbumDetail view should return serialized Albums
with a url field keyed by `title`. # with a url field keyed by `title`.
""" # """
for album in Album.objects.all(): # for album in Album.objects.all():
request = factory.get('/albums/{0}/'.format(album.title)) # request = factory.get('/albums/{0}/'.format(album.title))
response = self.detail_view(request, title=album.title) # response = self.detail_view(request, title=album.title)
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, self.data[album.title]) # self.assertEqual(response.data, self.data[album.title])
class TestCreateWithForeignKeys(TestCase): # class TestCreateWithForeignKeys(TestCase):
urls = 'tests.test_hyperlinkedserializers' # urls = 'tests.test_hyperlinkedserializers'
def setUp(self): # def setUp(self):
""" # """
Create a blog post # Create a blog post
""" # """
self.post = BlogPost.objects.create(title="Test post") # self.post = BlogPost.objects.create(title="Test post")
self.create_view = BlogPostCommentListCreate.as_view() # self.create_view = BlogPostCommentListCreate.as_view()
def test_create_comment(self): # def test_create_comment(self):
data = { # data = {
'text': 'A test comment', # 'text': 'A test comment',
'blog_post_url': 'http://testserver/posts/1/' # 'blog_post_url': 'http://testserver/posts/1/'
} # }
request = factory.post('/comments/', data=data) # request = factory.post('/comments/', data=data)
response = self.create_view(request) # response = self.create_view(request)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) # self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response['Location'], 'http://testserver/comments/1/') # self.assertEqual(response['Location'], 'http://testserver/comments/1/')
self.assertEqual(self.post.blogpostcomment_set.count(), 1) # self.assertEqual(self.post.blogpostcomment_set.count(), 1)
self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment') # self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment')
class TestCreateWithForeignKeysAndCustomSlug(TestCase): # class TestCreateWithForeignKeysAndCustomSlug(TestCase):
urls = 'tests.test_hyperlinkedserializers' # urls = 'tests.test_hyperlinkedserializers'
def setUp(self): # def setUp(self):
""" # """
Create an Album # Create an Album
""" # """
self.post = Album.objects.create(title='test-album') # self.post = Album.objects.create(title='test-album')
self.list_create_view = PhotoListCreate.as_view() # self.list_create_view = PhotoListCreate.as_view()
def test_create_photo(self): # def test_create_photo(self):
data = { # data = {
'description': 'A test photo', # 'description': 'A test photo',
'album_url': 'http://testserver/albums/test-album/' # 'album_url': 'http://testserver/albums/test-album/'
} # }
request = factory.post('/photos/', data=data) # request = factory.post('/photos/', data=data)
response = self.list_create_view(request) # response = self.list_create_view(request)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) # self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer') # self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
self.assertEqual(self.post.photo_set.count(), 1) # self.assertEqual(self.post.photo_set.count(), 1)
self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo') # self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')
class TestOptionalRelationHyperlinkedView(TestCase): # class TestOptionalRelationHyperlinkedView(TestCase):
urls = 'tests.test_hyperlinkedserializers' # urls = 'tests.test_hyperlinkedserializers'
def setUp(self): # def setUp(self):
""" # """
Create 1 OptionalRelationModel instances. # Create 1 OptionalRelationModel instances.
""" # """
OptionalRelationModel().save() # OptionalRelationModel().save()
self.objects = OptionalRelationModel.objects # self.objects = OptionalRelationModel.objects
self.detail_view = OptionalRelationDetail.as_view() # self.detail_view = OptionalRelationDetail.as_view()
self.data = {"url": "http://testserver/optionalrelation/1/", "other": None} # self.data = {"url": "http://testserver/optionalrelation/1/", "other": None}
def test_get_detail_view(self): # def test_get_detail_view(self):
""" # """
GET requests to RetrieveAPIView with optional relations should return None # GET requests to RetrieveAPIView with optional relations should return None
for non existing relations. # for non existing relations.
""" # """
request = factory.get('/optionalrelationmodel-detail/1') # request = factory.get('/optionalrelationmodel-detail/1')
response = self.detail_view(request, pk=1) # response = self.detail_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, self.data) # self.assertEqual(response.data, self.data)
def test_put_detail_view(self): # def test_put_detail_view(self):
""" # """
PUT requests to RetrieveUpdateDestroyAPIView with optional relations # PUT requests to RetrieveUpdateDestroyAPIView with optional relations
should accept None for non existing relations. # should accept None for non existing relations.
""" # """
response = self.client.put('/optionalrelation/1/', # response = self.client.put('/optionalrelation/1/',
data=json.dumps(self.data), # data=json.dumps(self.data),
content_type='application/json') # content_type='application/json')
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
class TestOverriddenURLField(TestCase): # class TestOverriddenURLField(TestCase):
def setUp(self): # def setUp(self):
class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer): # class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer):
url = serializers.SerializerMethodField('get_url') # url = serializers.SerializerMethodField('get_url')
class Meta: # class Meta:
model = BlogPost # model = BlogPost
fields = ('title', 'url') # fields = ('title', 'url')
def get_url(self, obj): # def get_url(self, obj):
return 'foo bar' # return 'foo bar'
self.Serializer = OverriddenURLSerializer # self.Serializer = OverriddenURLSerializer
self.obj = BlogPost.objects.create(title='New blog post') # self.obj = BlogPost.objects.create(title='New blog post')
def test_overridden_url_field(self): # def test_overridden_url_field(self):
""" # """
The 'url' field should respect overriding. # The 'url' field should respect overriding.
Regression test for #936. # Regression test for #936.
""" # """
serializer = self.Serializer(self.obj) # serializer = self.Serializer(self.obj)
self.assertEqual( # self.assertEqual(
serializer.data, # serializer.data,
{'title': 'New blog post', 'url': 'foo bar'} # {'title': 'New blog post', 'url': 'foo bar'}
) # )
class TestURLFieldNameBySettings(TestCase): # class TestURLFieldNameBySettings(TestCase):
urls = 'tests.test_hyperlinkedserializers' # urls = 'tests.test_hyperlinkedserializers'
def setUp(self): # def setUp(self):
self.saved_url_field_name = api_settings.URL_FIELD_NAME # self.saved_url_field_name = api_settings.URL_FIELD_NAME
api_settings.URL_FIELD_NAME = 'global_url_field' # api_settings.URL_FIELD_NAME = 'global_url_field'
class Serializer(serializers.HyperlinkedModelSerializer): # class Serializer(serializers.HyperlinkedModelSerializer):
class Meta: # class Meta:
model = BlogPost # model = BlogPost
fields = ('title', api_settings.URL_FIELD_NAME) # fields = ('title', api_settings.URL_FIELD_NAME)
self.Serializer = Serializer # self.Serializer = Serializer
self.obj = BlogPost.objects.create(title="New blog post") # self.obj = BlogPost.objects.create(title="New blog post")
def tearDown(self): # def tearDown(self):
api_settings.URL_FIELD_NAME = self.saved_url_field_name # api_settings.URL_FIELD_NAME = self.saved_url_field_name
def test_overridden_url_field_name(self): # def test_overridden_url_field_name(self):
request = factory.get('/posts/') # request = factory.get('/posts/')
serializer = self.Serializer(self.obj, context={'request': request}) # serializer = self.Serializer(self.obj, context={'request': request})
self.assertIn(api_settings.URL_FIELD_NAME, serializer.data) # self.assertIn(api_settings.URL_FIELD_NAME, serializer.data)
class TestURLFieldNameByOptions(TestCase): # class TestURLFieldNameByOptions(TestCase):
urls = 'tests.test_hyperlinkedserializers' # urls = 'tests.test_hyperlinkedserializers'
def setUp(self): # def setUp(self):
class Serializer(serializers.HyperlinkedModelSerializer): # class Serializer(serializers.HyperlinkedModelSerializer):
class Meta: # class Meta:
model = BlogPost # model = BlogPost
fields = ('title', 'serializer_url_field') # fields = ('title', 'serializer_url_field')
url_field_name = 'serializer_url_field' # url_field_name = 'serializer_url_field'
self.Serializer = Serializer # self.Serializer = Serializer
self.obj = BlogPost.objects.create(title="New blog post") # self.obj = BlogPost.objects.create(title="New blog post")
def test_overridden_url_field_name(self): # def test_overridden_url_field_name(self):
request = factory.get('/posts/') # request = factory.get('/posts/')
serializer = self.Serializer(self.obj, context={'request': request}) # serializer = self.Serializer(self.obj, context={'request': request})
self.assertIn(self.Serializer.Meta.url_field_name, serializer.data) # self.assertIn(self.Serializer.Meta.url_field_name, serializer.data)

View File

@ -1,39 +1,39 @@
from django.core.urlresolvers import reverse # from django.core.urlresolvers import reverse
from django.conf.urls import patterns, url # from django.conf.urls import patterns, url
from rest_framework import serializers, generics # from rest_framework import serializers, generics
from rest_framework.test import APITestCase # from rest_framework.test import APITestCase
from tests.models import NullableForeignKeySource # from tests.models import NullableForeignKeySource
class NullableFKSourceSerializer(serializers.ModelSerializer): # class NullableFKSourceSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = NullableForeignKeySource # model = NullableForeignKeySource
class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView): # class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView):
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer_class = NullableFKSourceSerializer # serializer_class = NullableFKSourceSerializer
urlpatterns = patterns( # urlpatterns = patterns(
'', # '',
url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'), # url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'),
) # )
class NullableForeignKeyTests(APITestCase): # class NullableForeignKeyTests(APITestCase):
""" # """
DRF should be able to handle nullable foreign keys when a test # DRF should be able to handle nullable foreign keys when a test
Client POST/PUT request is made with its own serialized object. # Client POST/PUT request is made with its own serialized object.
""" # """
urls = 'tests.test_nullable_fields' # urls = 'tests.test_nullable_fields'
def test_updating_object_with_null_fk(self): # def test_updating_object_with_null_fk(self):
obj = NullableForeignKeySource(name='example', target=None) # obj = NullableForeignKeySource(name='example', target=None)
obj.save() # obj.save()
serialized_data = NullableFKSourceSerializer(obj).data # serialized_data = NullableFKSourceSerializer(obj).data
response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data) # response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data)
self.assertEqual(response.data, serialized_data) # self.assertEqual(response.data, serialized_data)

View File

@ -391,10 +391,10 @@ class CustomField(serializers.Field):
class BasicModelSerializer(serializers.Serializer): class BasicModelSerializer(serializers.Serializer):
text = CustomField() text = CustomField()
def __init__(self, *args, **kwargs): def to_native(self, value):
super(BasicModelSerializer, self).__init__(*args, **kwargs)
if 'view' not in self.context: if 'view' not in self.context:
raise RuntimeError("context isn't getting passed into serializer init") raise RuntimeError("context isn't getting passed into serializer")
return super(BasicSerializer, self).to_native(value)
class TestContextPassedToCustomField(TestCase): class TestContextPassedToCustomField(TestCase):
@ -423,7 +423,7 @@ class LinksSerializer(serializers.Serializer):
class CustomPaginationSerializer(pagination.BasePaginationSerializer): class CustomPaginationSerializer(pagination.BasePaginationSerializer):
links = LinksSerializer(source='*') # Takes the page object as the source links = LinksSerializer(source='*') # Takes the page object as the source
total_results = serializers.Field(source='paginator.count') total_results = serializers.ReadOnlyField(source='paginator.count')
results_field = 'objects' results_field = 'objects'

View File

@ -108,59 +108,59 @@ class ModelPermissionsIntegrationTests(TestCase):
response = instance_view(request, pk='2') response = instance_view(request, pk='2')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_options_permitted(self): # def test_options_permitted(self):
request = factory.options( # request = factory.options(
'/', # '/',
HTTP_AUTHORIZATION=self.permitted_credentials # HTTP_AUTHORIZATION=self.permitted_credentials
) # )
response = root_view(request, pk='1') # response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data) # self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['POST']) # self.assertEqual(list(response.data['actions'].keys()), ['POST'])
request = factory.options( # request = factory.options(
'/1', # '/1',
HTTP_AUTHORIZATION=self.permitted_credentials # HTTP_AUTHORIZATION=self.permitted_credentials
) # )
response = instance_view(request, pk='1') # response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data) # self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['PUT']) # self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
def test_options_disallowed(self): # def test_options_disallowed(self):
request = factory.options( # request = factory.options(
'/', # '/',
HTTP_AUTHORIZATION=self.disallowed_credentials # HTTP_AUTHORIZATION=self.disallowed_credentials
) # )
response = root_view(request, pk='1') # response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) # self.assertNotIn('actions', response.data)
request = factory.options( # request = factory.options(
'/1', # '/1',
HTTP_AUTHORIZATION=self.disallowed_credentials # HTTP_AUTHORIZATION=self.disallowed_credentials
) # )
response = instance_view(request, pk='1') # response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) # self.assertNotIn('actions', response.data)
def test_options_updateonly(self): # def test_options_updateonly(self):
request = factory.options( # request = factory.options(
'/', # '/',
HTTP_AUTHORIZATION=self.updateonly_credentials # HTTP_AUTHORIZATION=self.updateonly_credentials
) # )
response = root_view(request, pk='1') # response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) # self.assertNotIn('actions', response.data)
request = factory.options( # request = factory.options(
'/1', # '/1',
HTTP_AUTHORIZATION=self.updateonly_credentials # HTTP_AUTHORIZATION=self.updateonly_credentials
) # )
response = instance_view(request, pk='1') # response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) # self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data) # self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['PUT']) # self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
class BasicPermModel(models.Model): class BasicPermModel(models.Model):

View File

@ -1,149 +1,149 @@
""" # """
General tests for relational fields. # General tests for relational fields.
""" # """
from __future__ import unicode_literals # from __future__ import unicode_literals
from django import get_version # from django import get_version
from django.db import models # from django.db import models
from django.test import TestCase # from django.test import TestCase
from django.utils import unittest # from django.utils import unittest
from rest_framework import serializers # from rest_framework import serializers
from tests.models import BlogPost # from tests.models import BlogPost
class NullModel(models.Model): # class NullModel(models.Model):
pass # pass
class FieldTests(TestCase): # class FieldTests(TestCase):
def test_pk_related_field_with_empty_string(self): # def test_pk_related_field_with_empty_string(self):
""" # """
Regression test for #446 # Regression test for #446
https://github.com/tomchristie/django-rest-framework/issues/446 # https://github.com/tomchristie/django-rest-framework/issues/446
""" # """
field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) # field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
self.assertRaises(serializers.ValidationError, field.from_native, '') # self.assertRaises(serializers.ValidationError, field.from_native, '')
self.assertRaises(serializers.ValidationError, field.from_native, []) # self.assertRaises(serializers.ValidationError, field.from_native, [])
def test_hyperlinked_related_field_with_empty_string(self): # def test_hyperlinked_related_field_with_empty_string(self):
field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') # field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
self.assertRaises(serializers.ValidationError, field.from_native, '') # self.assertRaises(serializers.ValidationError, field.from_native, '')
self.assertRaises(serializers.ValidationError, field.from_native, []) # self.assertRaises(serializers.ValidationError, field.from_native, [])
def test_slug_related_field_with_empty_string(self): # def test_slug_related_field_with_empty_string(self):
field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') # field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
self.assertRaises(serializers.ValidationError, field.from_native, '') # self.assertRaises(serializers.ValidationError, field.from_native, '')
self.assertRaises(serializers.ValidationError, field.from_native, []) # self.assertRaises(serializers.ValidationError, field.from_native, [])
class TestManyRelatedMixin(TestCase): # class TestManyRelatedMixin(TestCase):
def test_missing_many_to_many_related_field(self): # def test_missing_many_to_many_related_field(self):
''' # '''
Regression test for #632 # Regression test for #632
https://github.com/tomchristie/django-rest-framework/pull/632 # https://github.com/tomchristie/django-rest-framework/pull/632
''' # '''
field = serializers.RelatedField(many=True, read_only=False) # field = serializers.RelatedField(many=True, read_only=False)
into = {} # into = {}
field.field_from_native({}, None, 'field_name', into) # field.field_from_native({}, None, 'field_name', into)
self.assertEqual(into['field_name'], []) # self.assertEqual(into['field_name'], [])
# Regression tests for #694 (`source` attribute on related fields) # # Regression tests for #694 (`source` attribute on related fields)
class RelatedFieldSourceTests(TestCase): # class RelatedFieldSourceTests(TestCase):
def test_related_manager_source(self): # def test_related_manager_source(self):
""" # """
Relational fields should be able to use manager-returning methods as their source. # Relational fields should be able to use manager-returning methods as their source.
""" # """
BlogPost.objects.create(title='blah') # BlogPost.objects.create(title='blah')
field = serializers.RelatedField(many=True, source='get_blogposts_manager') # field = serializers.RelatedField(many=True, source='get_blogposts_manager')
class ClassWithManagerMethod(object): # class ClassWithManagerMethod(object):
def get_blogposts_manager(self): # def get_blogposts_manager(self):
return BlogPost.objects # return BlogPost.objects
obj = ClassWithManagerMethod() # obj = ClassWithManagerMethod()
value = field.field_to_native(obj, 'field_name') # value = field.field_to_native(obj, 'field_name')
self.assertEqual(value, ['BlogPost object']) # self.assertEqual(value, ['BlogPost object'])
def test_related_queryset_source(self): # def test_related_queryset_source(self):
""" # """
Relational fields should be able to use queryset-returning methods as their source. # Relational fields should be able to use queryset-returning methods as their source.
""" # """
BlogPost.objects.create(title='blah') # BlogPost.objects.create(title='blah')
field = serializers.RelatedField(many=True, source='get_blogposts_queryset') # field = serializers.RelatedField(many=True, source='get_blogposts_queryset')
class ClassWithQuerysetMethod(object): # class ClassWithQuerysetMethod(object):
def get_blogposts_queryset(self): # def get_blogposts_queryset(self):
return BlogPost.objects.all() # return BlogPost.objects.all()
obj = ClassWithQuerysetMethod() # obj = ClassWithQuerysetMethod()
value = field.field_to_native(obj, 'field_name') # value = field.field_to_native(obj, 'field_name')
self.assertEqual(value, ['BlogPost object']) # self.assertEqual(value, ['BlogPost object'])
def test_dotted_source(self): # def test_dotted_source(self):
""" # """
Source argument should support dotted.source notation. # Source argument should support dotted.source notation.
""" # """
BlogPost.objects.create(title='blah') # BlogPost.objects.create(title='blah')
field = serializers.RelatedField(many=True, source='a.b.c') # field = serializers.RelatedField(many=True, source='a.b.c')
class ClassWithQuerysetMethod(object): # class ClassWithQuerysetMethod(object):
a = { # a = {
'b': { # 'b': {
'c': BlogPost.objects.all() # 'c': BlogPost.objects.all()
} # }
} # }
obj = ClassWithQuerysetMethod() # obj = ClassWithQuerysetMethod()
value = field.field_to_native(obj, 'field_name') # value = field.field_to_native(obj, 'field_name')
self.assertEqual(value, ['BlogPost object']) # self.assertEqual(value, ['BlogPost object'])
# Regression for #1129 # # Regression for #1129
def test_exception_for_incorect_fk(self): # def test_exception_for_incorect_fk(self):
""" # """
Check that the exception message are correct if the source field # Check that the exception message are correct if the source field
doesn't exist. # doesn't exist.
""" # """
from tests.models import ManyToManySource # from tests.models import ManyToManySource
class Meta: # class Meta:
model = ManyToManySource # model = ManyToManySource
attrs = { # attrs = {
'name': serializers.SlugRelatedField( # 'name': serializers.SlugRelatedField(
slug_field='name', source='banzai'), # slug_field='name', source='banzai'),
'Meta': Meta, # 'Meta': Meta,
} # }
TestSerializer = type( # TestSerializer = type(
str('TestSerializer'), # str('TestSerializer'),
(serializers.ModelSerializer,), # (serializers.ModelSerializer,),
attrs # attrs
) # )
with self.assertRaises(AttributeError): # with self.assertRaises(AttributeError):
TestSerializer(data={'name': 'foo'}) # TestSerializer(data={'name': 'foo'})
@unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6') # @unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6')
class RelatedFieldChoicesTests(TestCase): # class RelatedFieldChoicesTests(TestCase):
""" # """
Tests for #1408 "Web browseable API doesn't have blank option on drop down list box" # Tests for #1408 "Web browseable API doesn't have blank option on drop down list box"
https://github.com/tomchristie/django-rest-framework/issues/1408 # https://github.com/tomchristie/django-rest-framework/issues/1408
""" # """
def test_blank_option_is_added_to_choice_if_required_equals_false(self): # def test_blank_option_is_added_to_choice_if_required_equals_false(self):
""" # """
""" # """
post = BlogPost(title="Checking blank option is added") # post = BlogPost(title="Checking blank option is added")
post.save() # post.save()
queryset = BlogPost.objects.all() # queryset = BlogPost.objects.all()
field = serializers.RelatedField(required=False, queryset=queryset) # field = serializers.RelatedField(required=False, queryset=queryset)
choice_count = BlogPost.objects.count() # choice_count = BlogPost.objects.count()
widget_count = len(field.widget.choices) # widget_count = len(field.widget.choices)
self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added') # self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added')

File diff suppressed because it is too large Load Diff

View File

@ -1,326 +1,326 @@
from __future__ import unicode_literals # from __future__ import unicode_literals
from django.db import models # from django.db import models
from django.test import TestCase # from django.test import TestCase
from rest_framework import serializers # from rest_framework import serializers
from .models import OneToOneTarget # from .models import OneToOneTarget
class OneToOneSource(models.Model): # class OneToOneSource(models.Model):
name = models.CharField(max_length=100) # name = models.CharField(max_length=100)
target = models.OneToOneField(OneToOneTarget, related_name='source', # target = models.OneToOneField(OneToOneTarget, related_name='source',
null=True, blank=True) # null=True, blank=True)
class OneToManyTarget(models.Model): # class OneToManyTarget(models.Model):
name = models.CharField(max_length=100) # name = models.CharField(max_length=100)
class OneToManySource(models.Model): # class OneToManySource(models.Model):
name = models.CharField(max_length=100) # name = models.CharField(max_length=100)
target = models.ForeignKey(OneToManyTarget, related_name='sources') # target = models.ForeignKey(OneToManyTarget, related_name='sources')
class ReverseNestedOneToOneTests(TestCase): # class ReverseNestedOneToOneTests(TestCase):
def setUp(self): # def setUp(self):
class OneToOneSourceSerializer(serializers.ModelSerializer): # class OneToOneSourceSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = OneToOneSource # model = OneToOneSource
fields = ('id', 'name') # fields = ('id', 'name')
class OneToOneTargetSerializer(serializers.ModelSerializer): # class OneToOneTargetSerializer(serializers.ModelSerializer):
source = OneToOneSourceSerializer() # source = OneToOneSourceSerializer()
class Meta: # class Meta:
model = OneToOneTarget # model = OneToOneTarget
fields = ('id', 'name', 'source') # fields = ('id', 'name', 'source')
self.Serializer = OneToOneTargetSerializer # self.Serializer = OneToOneTargetSerializer
for idx in range(1, 4): # for idx in range(1, 4):
target = OneToOneTarget(name='target-%d' % idx) # target = OneToOneTarget(name='target-%d' % idx)
target.save() # target.save()
source = OneToOneSource(name='source-%d' % idx, target=target) # source = OneToOneSource(name='source-%d' % idx, target=target)
source.save() # source.save()
def test_one_to_one_retrieve(self): # def test_one_to_one_retrieve(self):
queryset = OneToOneTarget.objects.all() # queryset = OneToOneTarget.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
{'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
{'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}} # {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_one_create(self): # def test_one_to_one_create(self):
data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} # data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
serializer = self.Serializer(data=data) # serializer = self.Serializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-4') # self.assertEqual(obj.name, 'target-4')
# Ensure (target 4, target_source 4, source 4) are added, and # # Ensure (target 4, target_source 4, source 4) are added, and
# everything else is as expected. # # everything else is as expected.
queryset = OneToOneTarget.objects.all() # queryset = OneToOneTarget.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
{'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
{'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}, # {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}},
{'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} # {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_one_create_with_invalid_data(self): # def test_one_to_one_create_with_invalid_data(self):
data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}} # data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}}
serializer = self.Serializer(data=data) # serializer = self.Serializer(data=data)
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]}) # self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]})
def test_one_to_one_update(self): # def test_one_to_one_update(self):
data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} # data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
instance = OneToOneTarget.objects.get(pk=3) # instance = OneToOneTarget.objects.get(pk=3)
serializer = self.Serializer(instance, data=data) # serializer = self.Serializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-3-updated') # self.assertEqual(obj.name, 'target-3-updated')
# Ensure (target 3, target_source 3, source 3) are updated, # # Ensure (target 3, target_source 3, source 3) are updated,
# and everything else is as expected. # # and everything else is as expected.
queryset = OneToOneTarget.objects.all() # queryset = OneToOneTarget.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
{'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
{'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} # {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
class ForwardNestedOneToOneTests(TestCase): # class ForwardNestedOneToOneTests(TestCase):
def setUp(self): # def setUp(self):
class OneToOneTargetSerializer(serializers.ModelSerializer): # class OneToOneTargetSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = OneToOneTarget # model = OneToOneTarget
fields = ('id', 'name') # fields = ('id', 'name')
class OneToOneSourceSerializer(serializers.ModelSerializer): # class OneToOneSourceSerializer(serializers.ModelSerializer):
target = OneToOneTargetSerializer() # target = OneToOneTargetSerializer()
class Meta: # class Meta:
model = OneToOneSource # model = OneToOneSource
fields = ('id', 'name', 'target') # fields = ('id', 'name', 'target')
self.Serializer = OneToOneSourceSerializer # self.Serializer = OneToOneSourceSerializer
for idx in range(1, 4): # for idx in range(1, 4):
target = OneToOneTarget(name='target-%d' % idx) # target = OneToOneTarget(name='target-%d' % idx)
target.save() # target.save()
source = OneToOneSource(name='source-%d' % idx, target=target) # source = OneToOneSource(name='source-%d' % idx, target=target)
source.save() # source.save()
def test_one_to_one_retrieve(self): # def test_one_to_one_retrieve(self):
queryset = OneToOneSource.objects.all() # queryset = OneToOneSource.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, # {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, # {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
{'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}} # {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_one_create(self): # def test_one_to_one_create(self):
data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} # data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
serializer = self.Serializer(data=data) # serializer = self.Serializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure (target 4, target_source 4, source 4) are added, and # # Ensure (target 4, target_source 4, source 4) are added, and
# everything else is as expected. # # everything else is as expected.
queryset = OneToOneSource.objects.all() # queryset = OneToOneSource.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, # {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, # {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
{'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}, # {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}},
{'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} # {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_one_create_with_invalid_data(self): # def test_one_to_one_create_with_invalid_data(self):
data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}} # data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}}
serializer = self.Serializer(data=data) # serializer = self.Serializer(data=data)
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]}) # self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]})
def test_one_to_one_update(self): # def test_one_to_one_update(self):
data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} # data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
instance = OneToOneSource.objects.get(pk=3) # instance = OneToOneSource.objects.get(pk=3)
serializer = self.Serializer(instance, data=data) # serializer = self.Serializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-3-updated') # self.assertEqual(obj.name, 'source-3-updated')
# Ensure (target 3, target_source 3, source 3) are updated, # # Ensure (target 3, target_source 3, source 3) are updated,
# and everything else is as expected. # # and everything else is as expected.
queryset = OneToOneSource.objects.all() # queryset = OneToOneSource.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, # {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, # {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
{'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} # {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_one_update_to_null(self): # def test_one_to_one_update_to_null(self):
data = {'id': 3, 'name': 'source-3-updated', 'target': None} # data = {'id': 3, 'name': 'source-3-updated', 'target': None}
instance = OneToOneSource.objects.get(pk=3) # instance = OneToOneSource.objects.get(pk=3)
serializer = self.Serializer(instance, data=data) # serializer = self.Serializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-3-updated') # self.assertEqual(obj.name, 'source-3-updated')
self.assertEqual(obj.target, None) # self.assertEqual(obj.target, None)
queryset = OneToOneSource.objects.all() # queryset = OneToOneSource.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, # {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
{'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, # {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
{'id': 3, 'name': 'source-3-updated', 'target': None} # {'id': 3, 'name': 'source-3-updated', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
# TODO: Nullable 1-1 tests # # TODO: Nullable 1-1 tests
# def test_one_to_one_delete(self): # # def test_one_to_one_delete(self):
# data = {'id': 3, 'name': 'target-3', 'target_source': None} # # data = {'id': 3, 'name': 'target-3', 'target_source': None}
# instance = OneToOneTarget.objects.get(pk=3) # # instance = OneToOneTarget.objects.get(pk=3)
# serializer = self.Serializer(instance, data=data) # # serializer = self.Serializer(instance, data=data)
# self.assertTrue(serializer.is_valid()) # # self.assertTrue(serializer.is_valid())
# serializer.save() # # serializer.save()
# # Ensure (target_source 3, source 3) are deleted, # # # Ensure (target_source 3, source 3) are deleted,
# # and everything else is as expected. # # # and everything else is as expected.
# queryset = OneToOneTarget.objects.all() # # queryset = OneToOneTarget.objects.all()
# serializer = self.Serializer(queryset) # # serializer = self.Serializer(queryset)
# expected = [ # # expected = [
# {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, # # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
# {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, # # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
# {'id': 3, 'name': 'target-3', 'source': None} # # {'id': 3, 'name': 'target-3', 'source': None}
# ] # # ]
# self.assertEqual(serializer.data, expected) # # self.assertEqual(serializer.data, expected)
class ReverseNestedOneToManyTests(TestCase): # class ReverseNestedOneToManyTests(TestCase):
def setUp(self): # def setUp(self):
class OneToManySourceSerializer(serializers.ModelSerializer): # class OneToManySourceSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = OneToManySource # model = OneToManySource
fields = ('id', 'name') # fields = ('id', 'name')
class OneToManyTargetSerializer(serializers.ModelSerializer): # class OneToManyTargetSerializer(serializers.ModelSerializer):
sources = OneToManySourceSerializer(many=True, allow_add_remove=True) # sources = OneToManySourceSerializer(many=True, allow_add_remove=True)
class Meta: # class Meta:
model = OneToManyTarget # model = OneToManyTarget
fields = ('id', 'name', 'sources') # fields = ('id', 'name', 'sources')
self.Serializer = OneToManyTargetSerializer # self.Serializer = OneToManyTargetSerializer
target = OneToManyTarget(name='target-1') # target = OneToManyTarget(name='target-1')
target.save() # target.save()
for idx in range(1, 4): # for idx in range(1, 4):
source = OneToManySource(name='source-%d' % idx, target=target) # source = OneToManySource(name='source-%d' % idx, target=target)
source.save() # source.save()
def test_one_to_many_retrieve(self): # def test_one_to_many_retrieve(self):
queryset = OneToManyTarget.objects.all() # queryset = OneToManyTarget.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, # {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 2, 'name': 'source-2'}, # {'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}]}, # {'id': 3, 'name': 'source-3'}]},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_many_create(self): # def test_one_to_many_create(self):
data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, # data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 2, 'name': 'source-2'}, # {'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}, # {'id': 3, 'name': 'source-3'},
{'id': 4, 'name': 'source-4'}]} # {'id': 4, 'name': 'source-4'}]}
instance = OneToManyTarget.objects.get(pk=1) # instance = OneToManyTarget.objects.get(pk=1)
serializer = self.Serializer(instance, data=data) # serializer = self.Serializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-1') # self.assertEqual(obj.name, 'target-1')
# Ensure source 4 is added, and everything else is as # # Ensure source 4 is added, and everything else is as
# expected. # # expected.
queryset = OneToManyTarget.objects.all() # queryset = OneToManyTarget.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, # {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 2, 'name': 'source-2'}, # {'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}, # {'id': 3, 'name': 'source-3'},
{'id': 4, 'name': 'source-4'}]} # {'id': 4, 'name': 'source-4'}]}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_many_create_with_invalid_data(self): # def test_one_to_many_create_with_invalid_data(self):
data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, # data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 2, 'name': 'source-2'}, # {'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}, # {'id': 3, 'name': 'source-3'},
{'id': 4}]} # {'id': 4}]}
serializer = self.Serializer(data=data) # serializer = self.Serializer(data=data)
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]}) # self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]})
def test_one_to_many_update(self): # def test_one_to_many_update(self):
data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, # data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
{'id': 2, 'name': 'source-2'}, # {'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}]} # {'id': 3, 'name': 'source-3'}]}
instance = OneToManyTarget.objects.get(pk=1) # instance = OneToManyTarget.objects.get(pk=1)
serializer = self.Serializer(instance, data=data) # serializer = self.Serializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-1-updated') # self.assertEqual(obj.name, 'target-1-updated')
# Ensure (target 1, source 1) are updated, # # Ensure (target 1, source 1) are updated,
# and everything else is as expected. # # and everything else is as expected.
queryset = OneToManyTarget.objects.all() # queryset = OneToManyTarget.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, # {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
{'id': 2, 'name': 'source-2'}, # {'id': 2, 'name': 'source-2'},
{'id': 3, 'name': 'source-3'}]} # {'id': 3, 'name': 'source-3'}]}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_one_to_many_delete(self): # def test_one_to_many_delete(self):
data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, # data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 3, 'name': 'source-3'}]} # {'id': 3, 'name': 'source-3'}]}
instance = OneToManyTarget.objects.get(pk=1) # instance = OneToManyTarget.objects.get(pk=1)
serializer = self.Serializer(instance, data=data) # serializer = self.Serializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
serializer.save() # serializer.save()
# Ensure source 2 is deleted, and everything else is as # # Ensure source 2 is deleted, and everything else is as
# expected. # # expected.
queryset = OneToManyTarget.objects.all() # queryset = OneToManyTarget.objects.all()
serializer = self.Serializer(queryset, many=True) # serializer = self.Serializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, # {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
{'id': 3, 'name': 'source-3'}]} # {'id': 3, 'name': 'source-3'}]}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)

File diff suppressed because it is too large Load Diff

View File

@ -1,257 +1,257 @@
from django.test import TestCase # from django.test import TestCase
from rest_framework import serializers # from rest_framework import serializers
from tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget # from tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget
class ForeignKeyTargetSerializer(serializers.ModelSerializer): # class ForeignKeyTargetSerializer(serializers.ModelSerializer):
sources = serializers.SlugRelatedField(many=True, slug_field='name') # sources = serializers.SlugRelatedField(many=True, slug_field='name')
class Meta: # class Meta:
model = ForeignKeyTarget # model = ForeignKeyTarget
class ForeignKeySourceSerializer(serializers.ModelSerializer): # class ForeignKeySourceSerializer(serializers.ModelSerializer):
target = serializers.SlugRelatedField(slug_field='name') # target = serializers.SlugRelatedField(slug_field='name')
class Meta: # class Meta:
model = ForeignKeySource # model = ForeignKeySource
class NullableForeignKeySourceSerializer(serializers.ModelSerializer): # class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
target = serializers.SlugRelatedField(slug_field='name', required=False) # target = serializers.SlugRelatedField(slug_field='name', required=False)
class Meta: # class Meta:
model = NullableForeignKeySource # model = NullableForeignKeySource
# TODO: M2M Tests, FKTests (Non-nullable), One2One # # TODO: M2M Tests, FKTests (Non-nullable), One2One
class SlugForeignKeyTests(TestCase): # class SlugForeignKeyTests(TestCase):
def setUp(self): # def setUp(self):
target = ForeignKeyTarget(name='target-1') # target = ForeignKeyTarget(name='target-1')
target.save() # target.save()
new_target = ForeignKeyTarget(name='target-2') # new_target = ForeignKeyTarget(name='target-2')
new_target.save() # new_target.save()
for idx in range(1, 4): # for idx in range(1, 4):
source = ForeignKeySource(name='source-%d' % idx, target=target) # source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save() # source.save()
def test_foreign_key_retrieve(self): # def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all() # queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) # serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'}, # {'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'} # {'id': 3, 'name': 'source-3', 'target': 'target-1'}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self): # def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) # serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, # {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []}, # {'id': 2, 'name': 'target-2', 'sources': []},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update(self): # def test_foreign_key_update(self):
data = {'id': 1, 'name': 'source-1', 'target': 'target-2'} # data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
instance = ForeignKeySource.objects.get(pk=1) # instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) # serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
serializer.save() # serializer.save()
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() # queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) # serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-2'}, # {'id': 1, 'name': 'source-1', 'target': 'target-2'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'} # {'id': 3, 'name': 'source-3', 'target': 'target-1'}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_incorrect_type(self): # def test_foreign_key_update_incorrect_type(self):
data = {'id': 1, 'name': 'source-1', 'target': 123} # data = {'id': 1, 'name': 'source-1', 'target': 123}
instance = ForeignKeySource.objects.get(pk=1) # instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) # serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']}) # self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']})
def test_reverse_foreign_key_update(self): # def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']} # data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
instance = ForeignKeyTarget.objects.get(pk=2) # instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data) # serializer = ForeignKeyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
# We shouldn't have saved anything to the db yet since save # # We shouldn't have saved anything to the db yet since save
# hasn't been called. # # hasn't been called.
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True) # new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, # {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []}, # {'id': 2, 'name': 'target-2', 'sources': []},
] # ]
self.assertEqual(new_serializer.data, expected) # self.assertEqual(new_serializer.data, expected)
serializer.save() # serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
# Ensure target 2 is update, and everything else is as expected # # Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) # serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-2']}, # {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
{'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, # {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_create(self): # def test_foreign_key_create(self):
data = {'id': 4, 'name': 'source-4', 'target': 'target-2'} # data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
serializer = ForeignKeySourceSerializer(data=data) # serializer = ForeignKeySourceSerializer(data=data)
serializer.is_valid() # serializer.is_valid()
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected # # Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all() # queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) # serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'}, # {'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'}, # {'id': 3, 'name': 'source-3', 'target': 'target-1'},
{'id': 4, 'name': 'source-4', 'target': 'target-2'}, # {'id': 4, 'name': 'source-4', 'target': 'target-2'},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_create(self): # def test_reverse_foreign_key_create(self):
data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']} # data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
serializer = ForeignKeyTargetSerializer(data=data) # serializer = ForeignKeyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'target-3') # self.assertEqual(obj.name, 'target-3')
# Ensure target 3 is added, and everything else is as expected # # Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) # serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-2']}, # {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
{'id': 2, 'name': 'target-2', 'sources': []}, # {'id': 2, 'name': 'target-2', 'sources': []},
{'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, # {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self): # def test_foreign_key_update_with_invalid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None} # data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1) # instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) # serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) # self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['This field is required.']}) # self.assertEqual(serializer.errors, {'target': ['This field is required.']})
class SlugNullableForeignKeyTests(TestCase): # class SlugNullableForeignKeyTests(TestCase):
def setUp(self): # def setUp(self):
target = ForeignKeyTarget(name='target-1') # target = ForeignKeyTarget(name='target-1')
target.save() # target.save()
for idx in range(1, 4): # for idx in range(1, 4):
if idx == 3: # if idx == 3:
target = None # target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) # source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save() # source.save()
def test_foreign_key_retrieve_with_null(self): # def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'}, # {'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None}, # {'id': 3, 'name': 'source-3', 'target': None},
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self): # def test_foreign_key_create_with_valid_null(self):
data = {'id': 4, 'name': 'source-4', 'target': None} # data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) # serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected # # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'}, # {'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None}, # {'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None} # {'id': 4, 'name': 'source-4', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_emptystring(self): # def test_foreign_key_create_with_valid_emptystring(self):
""" # """
The emptystring should be interpreted as null in the context # The emptystring should be interpreted as null in the context
of relationships. # of relationships.
""" # """
data = {'id': 4, 'name': 'source-4', 'target': ''} # data = {'id': 4, 'name': 'source-4', 'target': ''}
expected_data = {'id': 4, 'name': 'source-4', 'target': None} # expected_data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) # serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
obj = serializer.save() # obj = serializer.save()
self.assertEqual(serializer.data, expected_data) # self.assertEqual(serializer.data, expected_data)
self.assertEqual(obj.name, 'source-4') # self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected # # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'}, # {'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None}, # {'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None} # {'id': 4, 'name': 'source-4', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self): # def test_foreign_key_update_with_valid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None} # data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) # instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) # serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.data, data) # self.assertEqual(serializer.data, data)
serializer.save() # serializer.save()
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': None}, # {'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None} # {'id': 3, 'name': 'source-3', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_emptystring(self): # def test_foreign_key_update_with_valid_emptystring(self):
""" # """
The emptystring should be interpreted as null in the context # The emptystring should be interpreted as null in the context
of relationships. # of relationships.
""" # """
data = {'id': 1, 'name': 'source-1', 'target': ''} # data = {'id': 1, 'name': 'source-1', 'target': ''}
expected_data = {'id': 1, 'name': 'source-1', 'target': None} # expected_data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) # instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) # serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.data, expected_data) # self.assertEqual(serializer.data, expected_data)
serializer.save() # serializer.save()
# Ensure source 1 is updated, and everything else is as expected # # Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() # queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) # serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ # expected = [
{'id': 1, 'name': 'source-1', 'target': None}, # {'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, # {'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None} # {'id': 3, 'name': 'source-3', 'target': None}
] # ]
self.assertEqual(serializer.data, expected) # self.assertEqual(serializer.data, expected)