Ensure implementation of reverse relations in 'fields' is backwards compatible

This commit is contained in:
Tom Christie 2013-04-23 11:31:38 +01:00
parent bcf4cb2b4e
commit 4bf1a09bae
5 changed files with 96 additions and 87 deletions

View File

@ -25,7 +25,7 @@ class BasePermission(object):
""" """
Return `True` if permission is granted, `False` otherwise. Return `True` if permission is granted, `False` otherwise.
""" """
if len(inspect.getargspec(self.has_permission)[0]) == 4: if len(inspect.getargspec(self.has_permission).args) == 4:
warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. ' warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. '
'Use `has_object_permission()` instead for object permissions.', 'Use `has_object_permission()` instead for object permissions.',
PendingDeprecationWarning, stacklevel=2) PendingDeprecationWarning, stacklevel=2)

View File

@ -568,54 +568,73 @@ class ModelSerializer(Serializer):
assert cls is not None, \ assert cls is not None, \
"Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__ "Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__
opts = get_concrete_model(cls)._meta opts = get_concrete_model(cls)._meta
pk_field = opts.pk
# If model is a child via multitable inheritance, use parent's pk
while pk_field.rel and pk_field.rel.parent_link:
pk_field = pk_field.rel.to._meta.pk
fields = [pk_field]
fields += [field for field in opts.fields if field.serialize]
fields += [field for field in opts.many_to_many if field.serialize]
ret = SortedDict() ret = SortedDict()
nested = bool(self.opts.depth) nested = bool(self.opts.depth)
is_pk = True # First field in the list is the pk
for model_field in fields: # Deal with adding the primary key field
if is_pk: pk_field = opts.pk
field = self.get_pk_field(model_field) while pk_field.rel and pk_field.rel.parent_link:
is_pk = False # If model is a child via multitable inheritance, use parent's pk
elif model_field.rel and nested: pk_field = pk_field.rel.to._meta.pk
field = self.get_nested_field(model_field)
elif model_field.rel: field = self.get_pk_field(pk_field)
if field:
ret[pk_field.name] = field
# Deal with forward relationships
forward_rels = [field for field in opts.fields if field.serialize]
forward_rels += [field for field in opts.many_to_many if field.serialize]
for model_field in forward_rels:
if model_field.rel:
to_many = isinstance(model_field, to_many = isinstance(model_field,
models.fields.related.ManyToManyField) models.fields.related.ManyToManyField)
field = self.get_related_field(model_field, to_many=to_many) related_model = model_field.rel.to
if model_field.rel and nested:
if len(inspect.getargspec(self.get_nested_field).args) == 2:
# TODO: deprecation warning
field = self.get_nested_field(model_field)
else:
field = self.get_nested_field(model_field, related_model, to_many)
elif model_field.rel:
if len(inspect.getargspec(self.get_nested_field).args) == 3:
# TODO: deprecation warning
field = self.get_related_field(model_field, to_many=to_many)
else:
field = self.get_related_field(model_field, related_model, to_many)
else: else:
field = self.get_field(model_field) field = self.get_field(model_field)
if field: if field:
ret[model_field.name] = field ret[model_field.name] = field
# Reverse relationships are only included if they are explicitly # Deal with reverse relationships
# present in `Meta.fields`. if not self.opts.fields:
if self.opts.fields: reverse_rels = []
reverse = opts.get_all_related_objects() else:
reverse += opts.get_all_related_many_to_many_objects() # Reverse relationships are only included if they are explicitly
for rel in reverse: # present in the `fields` option on the serializer
name = rel.get_accessor_name() reverse_rels = opts.get_all_related_objects()
if name not in self.opts.fields: reverse_rels += opts.get_all_related_many_to_many_objects()
continue
if nested: for relation in reverse_rels:
field = self.get_nested_field(None, rel) accessor_name = relation.get_accessor_name()
else: if accessor_name not in self.opts.fields:
field = self.get_related_field(None, rel, to_many=True) continue
related_model = relation.model
to_many = relation.field.rel.multiple
if field: if nested:
ret[name] = field field = self.get_nested_field(None, related_model, to_many)
else:
field = self.get_related_field(None, related_model, to_many)
if field:
ret[accessor_name] = field
# Add the `read_only` flag to any fields that have bee specified
# in the `read_only_fields` option
for field_name in self.opts.read_only_fields: for field_name in self.opts.read_only_fields:
assert field_name in ret, \ assert field_name in ret, \
"read_only_fields on '%s' included invalid item '%s'" % \ "read_only_fields on '%s' included invalid item '%s'" % \
@ -630,39 +649,30 @@ class ModelSerializer(Serializer):
""" """
return self.get_field(model_field) return self.get_field(model_field)
def get_nested_field(self, model_field, rel=None): def get_nested_field(self, model_field, related_model, to_many):
""" """
Creates a default instance of a nested relational field. Creates a default instance of a nested relational field.
""" """
if rel:
model_class = rel.model
else:
model_class = model_field.rel.to
class NestedModelSerializer(ModelSerializer): class NestedModelSerializer(ModelSerializer):
class Meta: class Meta:
model = model_class model = related_model
return NestedModelSerializer() return NestedModelSerializer(many=to_many)
def get_related_field(self, model_field, rel=None, to_many=False): def get_related_field(self, model_field, related_model, to_many):
""" """
Creates a default instance of a flat relational field. Creates a default instance of a flat relational field.
""" """
# TODO: filter queryset using: # TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to) # .using(db).complex_filter(self.rel.limit_choices_to)
if rel:
model_class = rel.model
required = True
else:
model_class = model_field.rel.to
required = not(model_field.null or model_field.blank)
kwargs = { kwargs = {
'required': required, 'queryset': related_model._default_manager,
'queryset': model_class._default_manager,
'many': to_many 'many': to_many
} }
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
return PrimaryKeyRelatedField(**kwargs) return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field): def get_field(self, model_field):
@ -830,19 +840,21 @@ class HyperlinkedModelSerializer(ModelSerializer):
if self.opts.fields and model_field.name in self.opts.fields: if self.opts.fields and model_field.name in self.opts.fields:
return self.get_field(model_field) return self.get_field(model_field)
def get_related_field(self, model_field, to_many): def get_related_field(self, model_field, related_model, to_many):
""" """
Creates a default instance of a flat relational field. Creates a default instance of a flat relational field.
""" """
# TODO: filter queryset using: # TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to) # .using(db).complex_filter(self.rel.limit_choices_to)
rel = model_field.rel.to
kwargs = { kwargs = {
'required': not(model_field.null or model_field.blank), 'queryset': related_model._default_manager,
'queryset': rel._default_manager, 'view_name': self._get_default_view_name(related_model),
'view_name': self._get_default_view_name(rel),
'many': to_many 'many': to_many
} }
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
return HyperlinkedRelatedField(**kwargs) return HyperlinkedRelatedField(**kwargs)
def get_identity(self, data): def get_identity(self, data):

View File

@ -26,42 +26,44 @@ urlpatterns = patterns('',
) )
# ManyToMany
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail')
class Meta: class Meta:
model = ManyToManyTarget model = ManyToManyTarget
fields = ('url', 'name', 'sources')
class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer): class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta: class Meta:
model = ManyToManySource model = ManyToManySource
fields = ('url', 'name', 'targets')
# ForeignKey
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail')
class Meta: class Meta:
model = ForeignKeyTarget model = ForeignKeyTarget
fields = ('url', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta: class Meta:
model = ForeignKeySource model = ForeignKeySource
fields = ('url', 'name', 'target')
# Nullable ForeignKey # Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta: class Meta:
model = NullableForeignKeySource model = NullableForeignKeySource
fields = ('url', 'name', 'target')
# OneToOne # Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail')
class Meta: class Meta:
model = OneToOneTarget model = OneToOneTarget
fields = ('url', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid # TODO: Add test that .data cannot be accessed prior to .is_valid

View File

@ -6,38 +6,30 @@ from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, Null
class ForeignKeySourceSerializer(serializers.ModelSerializer): class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ForeignKeySource
fields = ('id', 'name', 'target')
depth = 1 depth = 1
model = ForeignKeySource
class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
class ForeignKeyTargetSerializer(serializers.ModelSerializer): class ForeignKeyTargetSerializer(serializers.ModelSerializer):
sources = FlatForeignKeySourceSerializer(many=True)
class Meta: class Meta:
model = ForeignKeyTarget model = ForeignKeyTarget
fields = ('id', 'name', 'sources')
depth = 1
class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta: class Meta:
depth = 1
model = NullableForeignKeySource model = NullableForeignKeySource
fields = ('id', 'name', 'target')
depth = 1
class NullableOneToOneSourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableOneToOneSource
class NullableOneToOneTargetSerializer(serializers.ModelSerializer): class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
nullable_source = NullableOneToOneSourceSerializer()
class Meta: class Meta:
model = OneToOneTarget model = OneToOneTarget
fields = ('id', 'name', 'nullable_source')
depth = 1
class ReverseForeignKeyTests(TestCase): class ReverseForeignKeyTests(TestCase):

View File

@ -5,41 +5,44 @@ from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, Fore
from rest_framework.compat import six from rest_framework.compat import six
# ManyToMany
class ManyToManyTargetSerializer(serializers.ModelSerializer): class ManyToManyTargetSerializer(serializers.ModelSerializer):
sources = serializers.PrimaryKeyRelatedField(many=True)
class Meta: class Meta:
model = ManyToManyTarget model = ManyToManyTarget
fields = ('id', 'name', 'sources')
class ManyToManySourceSerializer(serializers.ModelSerializer): class ManyToManySourceSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ManyToManySource model = ManyToManySource
fields = ('id', 'name', 'targets')
# ForeignKey
class ForeignKeyTargetSerializer(serializers.ModelSerializer): class ForeignKeyTargetSerializer(serializers.ModelSerializer):
sources = serializers.PrimaryKeyRelatedField(many=True)
class Meta: class Meta:
model = ForeignKeyTarget model = ForeignKeyTarget
fields = ('id', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.ModelSerializer): class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ForeignKeySource model = ForeignKeySource
fields = ('id', 'name', 'target')
# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = NullableForeignKeySource model = NullableForeignKeySource
fields = ('id', 'name', 'target')
# OneToOne # Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.ModelSerializer): class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
nullable_source = serializers.PrimaryKeyRelatedField()
class Meta: class Meta:
model = OneToOneTarget model = OneToOneTarget
fields = ('id', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid # TODO: Add test that .data cannot be accessed prior to .is_valid