From 20070f4133ed763c9202b5c5b647d76f59a7e021 Mon Sep 17 00:00:00 2001 From: Bill Collins Date: Sun, 20 Mar 2022 12:54:28 +0000 Subject: [PATCH] Separate schemas per serializer/method --- rest_framework/schemas/openapi.py | 71 ++++++++++++++++++++----------- tests/schemas/test_openapi.py | 41 +++++++++--------- 2 files changed, 67 insertions(+), 45 deletions(-) diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 6546fee13..fd7576091 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -162,7 +162,7 @@ class AutoSchema(ViewInspector): return operation - def get_component_name(self, serializer): + def get_component_name(self, serializer, method): """ Compute the component's name from the serializer. Raise an exception if the serializer's class name is "Serializer" (case-insensitive). @@ -183,6 +183,13 @@ class AutoSchema(ViewInspector): .format(serializer.__class__.__name__) ) + if method.lower() == 'patch': + return component_name + 'PartialUpdate' + if method.lower() == 'put': + return component_name + 'Update' + if method.lower() == 'post': + return component_name + 'Create' + return component_name def get_components(self, path, method): @@ -194,16 +201,16 @@ class AutoSchema(ViewInspector): return {} request_serializer = self.get_request_serializer(path, method) - response_serializer = self.get_response_serializer(path, method) + response_serializer = self.get_response_serializer(path, 'GET') if isinstance(request_serializer, serializers.Serializer): - component_name = self.get_component_name(request_serializer) - content = self.map_serializer(request_serializer) + component_name = self.get_component_name(request_serializer, method) + content = self.map_serializer(request_serializer, method) self.components.setdefault(component_name, content) if isinstance(response_serializer, serializers.Serializer): - component_name = self.get_component_name(response_serializer) - content = self.map_serializer(response_serializer) + component_name = self.get_component_name(response_serializer, 'GET') + content = self.map_serializer(response_serializer, 'GET') self.components.setdefault(component_name, content) return self.components @@ -364,27 +371,27 @@ class AutoSchema(ViewInspector): mapping['type'] = type return mapping - def map_field(self, field): + def map_field(self, field, method): # Nested Serializers, `many` or not. if isinstance(field, serializers.ListSerializer): return { 'type': 'array', - 'items': self.map_serializer(field.child) + 'items': self.map_serializer(field.child, method) } if isinstance(field, serializers.Serializer): - data = self.map_serializer(field) + data = self.map_serializer(field, method) data['type'] = 'object' return data if isinstance(field, serializers.SerializerMethodField) and field.output_field: - return self.map_field(field.output_field) + return self.map_field(field.output_field, method) # Related fields. if isinstance(field, serializers.ManyRelatedField): return { 'type': 'array', - 'items': self.map_field(field.child_relation) + 'items': self.map_field(field.child_relation, method) } if isinstance(field, serializers.PrimaryKeyRelatedField): model = getattr(field.queryset, 'model', None) @@ -413,7 +420,7 @@ class AutoSchema(ViewInspector): 'items': {}, } if not isinstance(field.child, _UnvalidatedField): - mapping['items'] = self.map_field(field.child) + mapping['items'] = self.map_field(field.child, method) return mapping # DateField and DateTimeField type is string @@ -515,7 +522,7 @@ class AutoSchema(ViewInspector): if field.min_value: content['minimum'] = field.min_value - def map_serializer(self, serializer): + def map_serializer(self, serializer, method): # Assuming we have a valid serializer instance. required = [] properties = {} @@ -524,10 +531,24 @@ class AutoSchema(ViewInspector): if isinstance(field, serializers.HiddenField): continue - if field.required: - required.append(field.field_name) + if method.lower() == 'get': + if field.write_only: + # Write only fields don't appear in the output + continue + else: + # All non-write-only fields are required in get requests + required.append(field.field_name) - schema = self.map_field(field) + if method.lower() in ('post', 'put', 'patch') and field.read_only: + # Don't emit readonly fields for writable methods + continue + + if method.lower() != 'patch' and field.required: + # Only mark required fields as required if we're not patching + if field.field_name not in required: + required.append(field.field_name) + + schema = self.map_field(field, method) if field.read_only: schema['readOnly'] = True if field.write_only: @@ -549,9 +570,9 @@ class AutoSchema(ViewInspector): if required: result['required'] = required - component_name = self.get_component_name(serializer=serializer) + component_name = self.get_component_name(serializer=serializer, method=method) self.components[component_name] = result - return self._get_reference(serializer) + return self._get_reference(serializer, method) def map_field_validators(self, field, schema): """ @@ -640,8 +661,8 @@ class AutoSchema(ViewInspector): """ return self.get_serializer(path, method) - def _get_reference(self, serializer): - return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))} + def _get_reference(self, serializer, method): + return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer, method))} def get_request_body(self, path, method): if method not in ('PUT', 'PATCH', 'POST'): @@ -654,7 +675,7 @@ class AutoSchema(ViewInspector): if not isinstance(serializer, serializers.Serializer): item_schema = {} else: - item_schema = self._get_reference(serializer) + item_schema = self._get_reference(serializer, method) return { 'content': { @@ -673,12 +694,12 @@ class AutoSchema(ViewInspector): self.response_media_types = self.map_renderers(path, method) - serializer = self.get_response_serializer(path, method) + serializer = self.get_response_serializer(path, 'GET') if not isinstance(serializer, serializers.Serializer): item_schema = {} else: - item_schema = self._get_reference(serializer) + item_schema = self._get_reference(serializer, 'GET') if is_list_view(path, method, self.view): response_schema = { @@ -779,7 +800,7 @@ class AutoSchema(ViewInspector): "The old name will be removed in DRF v3.14.", RemovedInDRF314Warning, stacklevel=2 ) - return self.map_serializer(serializer) + return self.map_serializer(serializer, 'GET') def _map_field(self, field): warnings.warn( @@ -787,7 +808,7 @@ class AutoSchema(ViewInspector): "The old name will be removed in DRF v3.14.", RemovedInDRF314Warning, stacklevel=2 ) - return self.map_field(field) + return self.map_field(field, 'GET') def _map_choicefield(self, field): warnings.warn( diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index 994701988..925220f33 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -87,7 +87,7 @@ class TestFieldMapping(TestCase): ] for field, mapping in cases: with self.subTest(field=field): - assert inspector.map_field(field) == mapping + assert inspector.map_field(field, method='GET') == mapping def test_lazy_string_field(self): class ItemSerializer(serializers.Serializer): @@ -95,7 +95,7 @@ class TestFieldMapping(TestCase): inspector = AutoSchema() - inspector.map_serializer(ItemSerializer()) + inspector.map_serializer(ItemSerializer(), method='GET') data = inspector.components['Item'] assert isinstance(data['properties']['text']['description'], str), "description must be str" @@ -107,7 +107,7 @@ class TestFieldMapping(TestCase): inspector = AutoSchema() - inspector.map_serializer(BooleanTestSerializer()) + inspector.map_serializer(BooleanTestSerializer(), method='GET') data = inspector.components['BooleanTest'] assert data['properties']['default_true']['default'] is True, "default must be true" assert data['properties']['default_false']['default'] is False, "default must be false" @@ -126,7 +126,7 @@ class TestFieldMapping(TestCase): inspector = AutoSchema() - inspector.map_serializer(NullableSerializer()) + inspector.map_serializer(NullableSerializer(), method='GET') data = inspector.components['Nullable'] assert data['properties']['rw_field']['nullable'], "rw_field nullable must be true" assert data['properties']['ro_field']['nullable'], "ro_field nullable must be true" @@ -142,7 +142,7 @@ class TestFieldMapping(TestCase): inspector = AutoSchema() - inspector.map_serializer(MethodSerializer()) + inspector.map_serializer(MethodSerializer(), method='GET') data = inspector.components['Method'] assert data['properties']['method_field']['type'] == 'boolean' @@ -243,11 +243,11 @@ class TestOperationIntrospection(TestCase): request_body = inspector.get_request_body(path, method) print(request_body) - assert request_body['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' + assert request_body['content']['application/json']['schema']['$ref'] == '#/components/schemas/ItemCreate' components = inspector.get_components(path, method) - assert components['Item']['required'] == ['text'] - assert sorted(list(components['Item']['properties'].keys())) == ['read_only', 'text'] + assert components['ItemCreate']['required'] == ['text'] + assert sorted(list(components['ItemCreate']['properties'].keys())) == ['text'] def test_invalid_serializer_class_name(self): path = '/' @@ -271,7 +271,7 @@ class TestOperationIntrospection(TestCase): serializer = inspector.get_serializer(path, method) with pytest.raises(Exception) as exc: - inspector.get_component_name(serializer) + inspector.get_component_name(serializer, method='GET') assert "is an invalid class name for schema generation" in str(exc.value) def test_empty_required(self): @@ -294,7 +294,7 @@ class TestOperationIntrospection(TestCase): inspector.view = view components = inspector.get_components(path, method) - component = components['Item'] + component = components['ItemCreate'] # there should be no empty 'required' property, see #6834 assert 'required' not in component @@ -321,7 +321,7 @@ class TestOperationIntrospection(TestCase): inspector.view = view components = inspector.get_components(path, method) - component = components['Item'] + component = components['ItemPartialUpdate'] # there should be no empty 'required' property, see #6834 assert 'required' not in component for response in inspector.get_responses(path, method).values(): @@ -350,8 +350,10 @@ class TestOperationIntrospection(TestCase): assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' components = inspector.get_components(path, method) - assert sorted(components['Item']['required']) == ['text', 'write_only'] - assert sorted(list(components['Item']['properties'].keys())) == ['text', 'write_only'] + assert sorted(components['Item']['required']) == ['text'] + assert sorted(list(components['Item']['properties'].keys())) == ['text'] + assert sorted(components['ItemCreate']['required']) == ['text', 'write_only'] + assert sorted(list(components['ItemCreate']['properties'].keys())) == ['text', 'write_only'] assert 'description' in responses['201'] def test_response_body_nested_serializer(self): @@ -776,7 +778,7 @@ class TestOperationIntrospection(TestCase): components = inspector.get_components(path, method) assert components == { - 'Request': { + 'RequestCreate': { 'properties': { 'text': { 'type': 'string' @@ -805,17 +807,17 @@ class TestOperationIntrospection(TestCase): 'content': { 'application/json': { 'schema': { - '$ref': '#/components/schemas/Request' + '$ref': '#/components/schemas/RequestCreate' } }, 'application/x-www-form-urlencoded': { 'schema': { - '$ref': '#/components/schemas/Request' + '$ref': '#/components/schemas/RequestCreate' } }, 'multipart/form-data': { 'schema': { - '$ref': '#/components/schemas/Request' + '$ref': '#/components/schemas/RequestCreate' } } } @@ -1197,14 +1199,13 @@ class TestGenerator(TestCase): body_schema = route['requestBody']['content']['application/json']['schema'] assert body_schema == { - '$ref': '#/components/schemas/AuthToken' + '$ref': '#/components/schemas/AuthTokenCreate' } - assert schema['components']['schemas']['AuthToken'] == { + assert schema['components']['schemas']['AuthTokenCreate'] == { 'type': 'object', 'properties': { 'username': {'type': 'string', 'writeOnly': True}, 'password': {'type': 'string', 'writeOnly': True}, - 'token': {'type': 'string', 'readOnly': True}, }, 'required': ['username', 'password'] }