Add error responses to schema

This commit is contained in:
Bill Collins 2022-03-20 13:52:35 +00:00
parent 20070f4133
commit 764e399cfd
2 changed files with 97 additions and 1 deletions

View File

@ -192,6 +192,9 @@ class AutoSchema(ViewInspector):
return component_name return component_name
def get_error_component_name(self, serializer):
return self.get_component_name(serializer, 'GET') + 'Error'
def get_components(self, path, method): def get_components(self, path, method):
""" """
Return components with their properties from the serializer. Return components with their properties from the serializer.
@ -208,6 +211,11 @@ class AutoSchema(ViewInspector):
content = self.map_serializer(request_serializer, method) content = self.map_serializer(request_serializer, method)
self.components.setdefault(component_name, content) self.components.setdefault(component_name, content)
if method.lower() in ('put', 'post', 'patch'):
error_component_name = self.get_error_component_name(request_serializer)
error_content = self.map_error_serializer(request_serializer)
self.components.setdefault(error_component_name, error_content)
if isinstance(response_serializer, serializers.Serializer): if isinstance(response_serializer, serializers.Serializer):
component_name = self.get_component_name(response_serializer, 'GET') component_name = self.get_component_name(response_serializer, 'GET')
content = self.map_serializer(response_serializer, 'GET') content = self.map_serializer(response_serializer, 'GET')
@ -516,6 +524,23 @@ class AutoSchema(ViewInspector):
} }
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')} return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
def map_error_field(self, field):
if isinstance(field, serializers.ListSerializer):
return {
'type': 'array',
'items': self.map_error_serializer(field.child)
}
if isinstance(field, serializers.Serializer):
return self.map_error_serializer(field)
if isinstance(field, serializers.SerializerMethodField) and isinstance(field.output_field, serializers.Serializer):
return self.map_error_serializer(field.output_field)
return {
'type': 'array',
'items': {
'type': 'string'
},
}
def _map_min_max(self, field, content): def _map_min_max(self, field, content):
if field.max_value: if field.max_value:
content['maximum'] = field.max_value content['maximum'] = field.max_value
@ -574,6 +599,33 @@ class AutoSchema(ViewInspector):
self.components[component_name] = result self.components[component_name] = result
return self._get_reference(serializer, method) return self._get_reference(serializer, method)
def map_error_serializer(self, serializer):
properties = {
api_settings.NON_FIELD_ERRORS_KEY: {
'type': 'array',
'items': {
'type': 'string'
}
}
}
for field in serializer.fields.values():
if isinstance(field, serializers.HiddenField):
continue
if field.read_only:
continue
properties[field.field_name] = self.map_error_field(field)
result = {
'type': 'object',
'properties': properties
}
component_name = self.get_error_component_name(serializer=serializer)
self.components[component_name] = result
return self._get_error_reference(serializer)
def map_field_validators(self, field, schema): def map_field_validators(self, field, schema):
""" """
map field validators map field validators
@ -664,6 +716,9 @@ class AutoSchema(ViewInspector):
def _get_reference(self, serializer, method): def _get_reference(self, serializer, method):
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer, method))} return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer, method))}
def _get_error_reference(self, serializer):
return {'$ref': '#/components/schemas/{}'.format(self.get_error_component_name(serializer))}
def get_request_body(self, path, method): def get_request_body(self, path, method):
if method not in ('PUT', 'PATCH', 'POST'): if method not in ('PUT', 'PATCH', 'POST'):
return {} return {}
@ -712,7 +767,7 @@ class AutoSchema(ViewInspector):
else: else:
response_schema = item_schema response_schema = item_schema
status_code = '201' if method == 'POST' else '200' status_code = '201' if method == 'POST' else '200'
return { responses = {
status_code: { status_code: {
'content': { 'content': {
ct: {'schema': response_schema} ct: {'schema': response_schema}
@ -725,6 +780,20 @@ class AutoSchema(ViewInspector):
} }
} }
if method in ('POST', 'PUT', 'PATCH'):
error_schema = self._get_error_reference(self.get_request_serializer(path, method))
responses['400'] = {
'content': {
ct: {'schema': error_schema}
for ct in self.response_media_types
},
# description is a mandatory property,
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
# TODO: put something meaningful into it
'description': ""
}
return responses
def get_tags(self, path, method): def get_tags(self, path, method):
# If user have specified tags, use them. # If user have specified tags, use them.
if self._tags: if self._tags:

View File

@ -795,6 +795,23 @@ class TestOperationIntrospection(TestCase):
}, },
'required': ['text'], 'required': ['text'],
'type': 'object' 'type': 'object'
},
'RequestError': {
'properties': {
'non_field_errors': {
'type': 'array',
'items': {
'type': 'string'
}
},
'text': {
'type': 'array',
'items': {
'type': 'string'
}
}
},
'type': 'object'
} }
} }
@ -832,6 +849,16 @@ class TestOperationIntrospection(TestCase):
} }
}, },
'description': '' 'description': ''
},
'400': {
'content': {
'application/json': {
'schema': {
'$ref': '#/components/schemas/RequestError'
}
}
},
'description': ''
} }
}, },
'tags': [''] 'tags': ['']