diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index fd7576091..af2ceda43 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -192,6 +192,9 @@ class AutoSchema(ViewInspector): return component_name + def get_error_component_name(self, serializer): + return self.get_component_name(serializer, 'GET') + 'Error' + def get_components(self, path, method): """ Return components with their properties from the serializer. @@ -208,6 +211,11 @@ class AutoSchema(ViewInspector): content = self.map_serializer(request_serializer, method) self.components.setdefault(component_name, content) + if method.lower() in ('put', 'post', 'patch'): + error_component_name = self.get_error_component_name(request_serializer) + error_content = self.map_error_serializer(request_serializer) + self.components.setdefault(error_component_name, error_content) + if isinstance(response_serializer, serializers.Serializer): component_name = self.get_component_name(response_serializer, 'GET') content = self.map_serializer(response_serializer, 'GET') @@ -516,6 +524,23 @@ class AutoSchema(ViewInspector): } return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')} + def map_error_field(self, field): + if isinstance(field, serializers.ListSerializer): + return { + 'type': 'array', + 'items': self.map_error_serializer(field.child) + } + if isinstance(field, serializers.Serializer): + return self.map_error_serializer(field) + if isinstance(field, serializers.SerializerMethodField) and isinstance(field.output_field, serializers.Serializer): + return self.map_error_serializer(field.output_field) + return { + 'type': 'array', + 'items': { + 'type': 'string' + }, + } + def _map_min_max(self, field, content): if field.max_value: content['maximum'] = field.max_value @@ -574,6 +599,33 @@ class AutoSchema(ViewInspector): self.components[component_name] = result return self._get_reference(serializer, method) + def map_error_serializer(self, serializer): + properties = { + api_settings.NON_FIELD_ERRORS_KEY: { + 'type': 'array', + 'items': { + 'type': 'string' + } + } + } + + for field in serializer.fields.values(): + if isinstance(field, serializers.HiddenField): + continue + if field.read_only: + continue + + properties[field.field_name] = self.map_error_field(field) + + result = { + 'type': 'object', + 'properties': properties + } + + component_name = self.get_error_component_name(serializer=serializer) + self.components[component_name] = result + return self._get_error_reference(serializer) + def map_field_validators(self, field, schema): """ map field validators @@ -664,6 +716,9 @@ class AutoSchema(ViewInspector): def _get_reference(self, serializer, method): return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer, method))} + def _get_error_reference(self, serializer): + return {'$ref': '#/components/schemas/{}'.format(self.get_error_component_name(serializer))} + def get_request_body(self, path, method): if method not in ('PUT', 'PATCH', 'POST'): return {} @@ -712,7 +767,7 @@ class AutoSchema(ViewInspector): else: response_schema = item_schema status_code = '201' if method == 'POST' else '200' - return { + responses = { status_code: { 'content': { ct: {'schema': response_schema} @@ -725,6 +780,20 @@ class AutoSchema(ViewInspector): } } + if method in ('POST', 'PUT', 'PATCH'): + error_schema = self._get_error_reference(self.get_request_serializer(path, method)) + responses['400'] = { + 'content': { + ct: {'schema': error_schema} + for ct in self.response_media_types + }, + # description is a mandatory property, + # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject + # TODO: put something meaningful into it + 'description': "" + } + return responses + def get_tags(self, path, method): # If user have specified tags, use them. if self._tags: diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index 925220f33..d39f2dbf1 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -795,6 +795,23 @@ class TestOperationIntrospection(TestCase): }, 'required': ['text'], 'type': 'object' + }, + 'RequestError': { + 'properties': { + 'non_field_errors': { + 'type': 'array', + 'items': { + 'type': 'string' + } + }, + 'text': { + 'type': 'array', + 'items': { + 'type': 'string' + } + } + }, + 'type': 'object' } } @@ -832,6 +849,16 @@ class TestOperationIntrospection(TestCase): } }, 'description': '' + }, + '400': { + 'content': { + 'application/json': { + 'schema': { + '$ref': '#/components/schemas/RequestError' + } + } + }, + 'description': '' } }, 'tags': ['']