From 4e561a4ac02afe1f4883ba3347a422c876da1e91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Beaul=C3=A9?= Date: Thu, 5 Mar 2020 20:13:15 -0500 Subject: [PATCH] Allow serializer data to include related objects by object When instantiating a serializer from cleaned data, the relational fields have been already converted into objects, which makes the fields invalid. Having a valid object as the value of a field is not necessarily considered valid, if the object is not in the given queryset. As such, this change adds a check to allow objects of the model type and in the queryset in lieu of the proper representation. --- rest_framework/relations.py | 10 +++++++++- tests/test_relations.py | 13 +++++++++++++ tests/test_versioning.py | 5 +++++ tests/utils.py | 2 ++ 4 files changed, 29 insertions(+), 1 deletion(-) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 3a2a8fb4b..06e886efd 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -250,7 +250,9 @@ class PrimaryKeyRelatedField(RelatedField): return True def to_internal_value(self, data): - if self.pk_field is not None: + if isinstance(data, self.get_queryset().model): + data = data.pk + elif self.pk_field is not None: data = self.pk_field.to_internal_value(data) try: return self.get_queryset().get(pk=data) @@ -331,6 +333,12 @@ class HyperlinkedRelatedField(RelatedField): return self.reverse(view_name, kwargs=kwargs, request=request, format=format) def to_internal_value(self, data): + if isinstance(data, self.get_queryset().model): + try: + return self.get_queryset().get(**{self.lookup_url_kwarg: getattr(data, self.lookup_field)}) + except (ObjectDoesNotExist, ObjectValueError, ObjectTypeError): + self.fail('does_not_exist') + request = self.context.get('request', None) try: http_prefix = data.startswith(('http:', 'https:')) diff --git a/tests/test_relations.py b/tests/test_relations.py index 9f05e3b31..e53182a2a 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -116,6 +116,10 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase): instance = field.to_internal_value(self.instance.pk) assert instance is self.instance + def test_pk_related_object_given(self): + instance = self.field.to_internal_value(self.instance) + assert instance is self.queryset.items[2] + class TestProxiedPrimaryKeyRelatedField(APISimpleTestCase): def setUp(self): @@ -216,6 +220,11 @@ class TestHyperlinkedRelatedField(APISimpleTestCase): def hyperlinked_related_queryset_error(self, exc_type): class QuerySet: + class FakeObject: + pass + + model = FakeObject + def get(self, *args, **kwargs): raise exc_type @@ -235,6 +244,10 @@ class TestHyperlinkedRelatedField(APISimpleTestCase): def test_hyperlinked_related_queryset_value_error(self): self.hyperlinked_related_queryset_error(ValueError) + def test_hyperlinked_related_object_given(self): + instance = self.field.to_internal_value(self.queryset.items[2]) + assert instance is self.queryset.items[2] + class TestHyperlinkedIdentityField(APISimpleTestCase): def setUp(self): diff --git a/tests/test_versioning.py b/tests/test_versioning.py index d4e269df3..a0efd8fc2 100644 --- a/tests/test_versioning.py +++ b/tests/test_versioning.py @@ -322,6 +322,11 @@ class TestHyperlinkedRelatedField(URLPatternsTestCase, APITestCase): super().setUp() class MockQueryset: + class MockObject: + pass + + model = MockObject + def get(self, pk): return 'object %s' % pk diff --git a/tests/utils.py b/tests/utils.py index 06e5b9abe..2d31976f1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -17,6 +17,8 @@ class MockObject: class MockQueryset: + model = MockObject + def __init__(self, iterable): self.items = iterable