Hyperlinked PK optimization. Closes #1872.

This commit is contained in:
Tom Christie 2014-12-09 17:28:56 +00:00
parent 7d70e56ce3
commit 720a37d3de
2 changed files with 45 additions and 31 deletions

View File

@ -84,9 +84,34 @@ class RelatedField(Field):
queryset = queryset.all() queryset = queryset.all()
return queryset return queryset
def use_pk_only_optimization(self):
return False
def get_attribute(self, instance):
if self.use_pk_only_optimization():
try:
# Optimized case, return a mock object only containing the pk attribute.
instance = get_attribute(instance, self.source_attrs[:-1])
return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1]))
except AttributeError:
pass
# Standard case, return the object instance.
return get_attribute(instance, self.source_attrs)
def get_iterable(self, instance, source_attrs): def get_iterable(self, instance, source_attrs):
relationship = get_attribute(instance, source_attrs) relationship = get_attribute(instance, source_attrs)
return relationship.all() if (hasattr(relationship, 'all')) else relationship relationship = relationship.all() if (hasattr(relationship, 'all')) else relationship
if self.use_pk_only_optimization():
# Optimized case, return mock objects only containing the pk attribute.
return [
PKOnlyObject(pk=pk)
for pk in relationship.values_list('pk', flat=True)
]
# Standard case, return the object instances.
return relationship
@property @property
def choices(self): def choices(self):
@ -120,6 +145,9 @@ class PrimaryKeyRelatedField(RelatedField):
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'), 'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
} }
def use_pk_only_optimization(self):
return True
def to_internal_value(self, data): def to_internal_value(self, data):
try: try:
return self.get_queryset().get(pk=data) return self.get_queryset().get(pk=data)
@ -128,32 +156,6 @@ class PrimaryKeyRelatedField(RelatedField):
except (TypeError, ValueError): except (TypeError, ValueError):
self.fail('incorrect_type', data_type=type(data).__name__) self.fail('incorrect_type', data_type=type(data).__name__)
def get_attribute(self, instance):
# We customize `get_attribute` here for performance reasons.
# For relationships the instance will already have the pk of
# the related object. We return this directly instead of returning the
# object itself, which would require a database lookup.
try:
instance = get_attribute(instance, self.source_attrs[:-1])
return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1]))
except AttributeError:
return get_attribute(instance, self.source_attrs)
def get_iterable(self, instance, source_attrs):
# For consistency with `get_attribute` we're using `serializable_value()`
# here. Typically there won't be any difference, but some custom field
# types might return a non-primitive value for the pk otherwise.
#
# We could try to get smart with `values_list('pk', flat=True)`, which
# would be better in some case, but would actually end up with *more*
# queries if the developer is using `prefetch_related` across the
# relationship.
relationship = super(PrimaryKeyRelatedField, self).get_iterable(instance, source_attrs)
return [
PKOnlyObject(pk=item.serializable_value('pk'))
for item in relationship
]
def to_representation(self, value): def to_representation(self, value):
return value.pk return value.pk
@ -184,6 +186,9 @@ class HyperlinkedRelatedField(RelatedField):
super(HyperlinkedRelatedField, self).__init__(**kwargs) super(HyperlinkedRelatedField, self).__init__(**kwargs)
def use_pk_only_optimization(self):
return self.lookup_field == 'pk'
def get_object(self, view_name, view_args, view_kwargs): def get_object(self, view_name, view_args, view_kwargs):
""" """
Return the object corresponding to a matched URL. Return the object corresponding to a matched URL.
@ -285,6 +290,11 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField):
kwargs['source'] = '*' kwargs['source'] = '*'
super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs) super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)
def use_pk_only_optimization(self):
# We have the complete object instance already. We don't need
# to run the 'only get the pk for this relationship' code.
return False
class SlugRelatedField(RelatedField): class SlugRelatedField(RelatedField):
""" """

View File

@ -89,7 +89,8 @@ class HyperlinkedManyToManyTests(TestCase):
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
] ]
self.assertEqual(serializer.data, expected) with self.assertNumQueries(4):
self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_retrieve(self): def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
@ -99,7 +100,8 @@ class HyperlinkedManyToManyTests(TestCase):
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
] ]
self.assertEqual(serializer.data, expected) with self.assertNumQueries(4):
self.assertEqual(serializer.data, expected)
def test_many_to_many_update(self): def test_many_to_many_update(self):
data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
@ -197,7 +199,8 @@ class HyperlinkedForeignKeyTests(TestCase):
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
] ]
self.assertEqual(serializer.data, expected) with self.assertNumQueries(1):
self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self): def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
@ -206,7 +209,8 @@ class HyperlinkedForeignKeyTests(TestCase):
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
] ]
self.assertEqual(serializer.data, expected) with self.assertNumQueries(3):
self.assertEqual(serializer.data, expected)
def test_foreign_key_update(self): def test_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'} data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}