Ensure implementation of reverse relations in 'fields' is backwards compatible

This commit is contained in:
Tom Christie 2013-04-23 11:31:38 +01:00
parent bcf4cb2b4e
commit 4bf1a09bae
5 changed files with 96 additions and 87 deletions

View File

@ -25,7 +25,7 @@ class BasePermission(object):
"""
Return `True` if permission is granted, `False` otherwise.
"""
if len(inspect.getargspec(self.has_permission)[0]) == 4:
if len(inspect.getargspec(self.has_permission).args) == 4:
warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. '
'Use `has_object_permission()` instead for object permissions.',
PendingDeprecationWarning, stacklevel=2)

View File

@ -568,54 +568,73 @@ class ModelSerializer(Serializer):
assert cls is not None, \
"Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__
opts = get_concrete_model(cls)._meta
pk_field = opts.pk
# If model is a child via multitable inheritance, use parent's pk
while pk_field.rel and pk_field.rel.parent_link:
pk_field = pk_field.rel.to._meta.pk
fields = [pk_field]
fields += [field for field in opts.fields if field.serialize]
fields += [field for field in opts.many_to_many if field.serialize]
ret = SortedDict()
nested = bool(self.opts.depth)
is_pk = True # First field in the list is the pk
for model_field in fields:
if is_pk:
field = self.get_pk_field(model_field)
is_pk = False
elif model_field.rel and nested:
field = self.get_nested_field(model_field)
elif model_field.rel:
# Deal with adding the primary key field
pk_field = opts.pk
while pk_field.rel and pk_field.rel.parent_link:
# If model is a child via multitable inheritance, use parent's pk
pk_field = pk_field.rel.to._meta.pk
field = self.get_pk_field(pk_field)
if field:
ret[pk_field.name] = field
# Deal with forward relationships
forward_rels = [field for field in opts.fields if field.serialize]
forward_rels += [field for field in opts.many_to_many if field.serialize]
for model_field in forward_rels:
if model_field.rel:
to_many = isinstance(model_field,
models.fields.related.ManyToManyField)
field = self.get_related_field(model_field, to_many=to_many)
related_model = model_field.rel.to
if model_field.rel and nested:
if len(inspect.getargspec(self.get_nested_field).args) == 2:
# TODO: deprecation warning
field = self.get_nested_field(model_field)
else:
field = self.get_nested_field(model_field, related_model, to_many)
elif model_field.rel:
if len(inspect.getargspec(self.get_nested_field).args) == 3:
# TODO: deprecation warning
field = self.get_related_field(model_field, to_many=to_many)
else:
field = self.get_related_field(model_field, related_model, to_many)
else:
field = self.get_field(model_field)
if field:
ret[model_field.name] = field
# Reverse relationships are only included if they are explicitly
# present in `Meta.fields`.
if self.opts.fields:
reverse = opts.get_all_related_objects()
reverse += opts.get_all_related_many_to_many_objects()
for rel in reverse:
name = rel.get_accessor_name()
if name not in self.opts.fields:
continue
# Deal with reverse relationships
if not self.opts.fields:
reverse_rels = []
else:
# Reverse relationships are only included if they are explicitly
# present in the `fields` option on the serializer
reverse_rels = opts.get_all_related_objects()
reverse_rels += opts.get_all_related_many_to_many_objects()
if nested:
field = self.get_nested_field(None, rel)
else:
field = self.get_related_field(None, rel, to_many=True)
for relation in reverse_rels:
accessor_name = relation.get_accessor_name()
if accessor_name not in self.opts.fields:
continue
related_model = relation.model
to_many = relation.field.rel.multiple
if field:
ret[name] = field
if nested:
field = self.get_nested_field(None, related_model, to_many)
else:
field = self.get_related_field(None, related_model, to_many)
if field:
ret[accessor_name] = field
# Add the `read_only` flag to any fields that have bee specified
# in the `read_only_fields` option
for field_name in self.opts.read_only_fields:
assert field_name in ret, \
"read_only_fields on '%s' included invalid item '%s'" % \
@ -630,39 +649,30 @@ class ModelSerializer(Serializer):
"""
return self.get_field(model_field)
def get_nested_field(self, model_field, rel=None):
def get_nested_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a nested relational field.
"""
if rel:
model_class = rel.model
else:
model_class = model_field.rel.to
class NestedModelSerializer(ModelSerializer):
class Meta:
model = model_class
return NestedModelSerializer()
model = related_model
return NestedModelSerializer(many=to_many)
def get_related_field(self, model_field, rel=None, to_many=False):
def get_related_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a flat relational field.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
if rel:
model_class = rel.model
required = True
else:
model_class = model_field.rel.to
required = not(model_field.null or model_field.blank)
kwargs = {
'required': required,
'queryset': model_class._default_manager,
'queryset': related_model._default_manager,
'many': to_many
}
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
@ -830,19 +840,21 @@ class HyperlinkedModelSerializer(ModelSerializer):
if self.opts.fields and model_field.name in self.opts.fields:
return self.get_field(model_field)
def get_related_field(self, model_field, to_many):
def get_related_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a flat relational field.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
rel = model_field.rel.to
kwargs = {
'required': not(model_field.null or model_field.blank),
'queryset': rel._default_manager,
'view_name': self._get_default_view_name(rel),
'queryset': related_model._default_manager,
'view_name': self._get_default_view_name(related_model),
'many': to_many
}
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
return HyperlinkedRelatedField(**kwargs)
def get_identity(self, data):

View File

@ -26,42 +26,44 @@ urlpatterns = patterns('',
)
# ManyToMany
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail')
class Meta:
model = ManyToManyTarget
fields = ('url', 'name', 'sources')
class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ManyToManySource
fields = ('url', 'name', 'targets')
# ForeignKey
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail')
class Meta:
model = ForeignKeyTarget
fields = ('url', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ForeignKeySource
fields = ('url', 'name', 'target')
# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = NullableForeignKeySource
fields = ('url', 'name', 'target')
# OneToOne
# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail')
class Meta:
model = OneToOneTarget
fields = ('url', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid

View File

@ -6,38 +6,30 @@ from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, Null
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
fields = ('id', 'name', 'target')
depth = 1
model = ForeignKeySource
class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
sources = FlatForeignKeySourceSerializer(many=True)
class Meta:
model = ForeignKeyTarget
fields = ('id', 'name', 'sources')
depth = 1
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
depth = 1
model = NullableForeignKeySource
class NullableOneToOneSourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableOneToOneSource
fields = ('id', 'name', 'target')
depth = 1
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
nullable_source = NullableOneToOneSourceSerializer()
class Meta:
model = OneToOneTarget
fields = ('id', 'name', 'nullable_source')
depth = 1
class ReverseForeignKeyTests(TestCase):

View File

@ -5,41 +5,44 @@ from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, Fore
from rest_framework.compat import six
# ManyToMany
class ManyToManyTargetSerializer(serializers.ModelSerializer):
sources = serializers.PrimaryKeyRelatedField(many=True)
class Meta:
model = ManyToManyTarget
fields = ('id', 'name', 'sources')
class ManyToManySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ManyToManySource
fields = ('id', 'name', 'targets')
# ForeignKey
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
sources = serializers.PrimaryKeyRelatedField(many=True)
class Meta:
model = ForeignKeyTarget
fields = ('id', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
fields = ('id', 'name', 'target')
# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableForeignKeySource
fields = ('id', 'name', 'target')
# OneToOne
# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
nullable_source = serializers.PrimaryKeyRelatedField()
class Meta:
model = OneToOneTarget
fields = ('id', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid