Nested serializers generates separate OpenApi components.

This commit is contained in:
Torstein A. Bø 2020-08-27 15:35:46 +02:00
parent 05120c4658
commit d44afafbc7
2 changed files with 44 additions and 35 deletions

View File

@ -197,10 +197,8 @@ class AutoSchema(ViewInspector):
if not isinstance(serializer, serializers.Serializer): if not isinstance(serializer, serializers.Serializer):
return {} return {}
component_name = self.get_component_name(serializer) _, components = self.map_serializer(serializer)
return components
content = self.map_serializer(serializer)
return {component_name: content}
def _to_camel_case(self, snake_str): def _to_camel_case(self, snake_str):
components = snake_str.split('_') components = snake_str.split('_')
@ -365,24 +363,24 @@ class AutoSchema(ViewInspector):
return { return {
'type': 'array', 'type': 'array',
'items': self.map_serializer(field.child) 'items': self.map_serializer(field.child)
} }, {}
if isinstance(field, serializers.Serializer): if isinstance(field, serializers.Serializer):
data = self.map_serializer(field) data, components = self.map_serializer(field)
data['type'] = 'object' return data, components
return data
# Related fields. # Related fields.
if isinstance(field, serializers.ManyRelatedField): if isinstance(field, serializers.ManyRelatedField):
items, components = self.map_field(field.child_relation)
return { return {
'type': 'array', 'type': 'array',
'items': self.map_field(field.child_relation) 'items': items
} }, components
if isinstance(field, serializers.PrimaryKeyRelatedField): if isinstance(field, serializers.PrimaryKeyRelatedField):
model = getattr(field.queryset, 'model', None) model = getattr(field.queryset, 'model', None)
if model is not None: if model is not None:
model_field = model._meta.pk model_field = model._meta.pk
if isinstance(model_field, models.AutoField): if isinstance(model_field, models.AutoField):
return {'type': 'integer'} return {'type': 'integer'}, {}
# ChoiceFields (single and multiple). # ChoiceFields (single and multiple).
# Q: # Q:
@ -392,33 +390,35 @@ class AutoSchema(ViewInspector):
return { return {
'type': 'array', 'type': 'array',
'items': self.map_choicefield(field) 'items': self.map_choicefield(field)
} }, {}
if isinstance(field, serializers.ChoiceField): if isinstance(field, serializers.ChoiceField):
return self.map_choicefield(field) return self.map_choicefield(field), {}
# ListField. # ListField.
if isinstance(field, serializers.ListField): if isinstance(field, serializers.ListField):
components = {}
mapping = { mapping = {
'type': 'array', 'type': 'array',
'items': {}, 'items': {},
} }
if not isinstance(field.child, _UnvalidatedField): if not isinstance(field.child, _UnvalidatedField):
mapping['items'] = self.map_field(field.child) items, components = self.map_field(field.child)
return mapping mapping['items'] = items
return mapping, components
# DateField and DateTimeField type is string # DateField and DateTimeField type is string
if isinstance(field, serializers.DateField): if isinstance(field, serializers.DateField):
return { return {
'type': 'string', 'type': 'string',
'format': 'date', 'format': 'date',
} }, {}
if isinstance(field, serializers.DateTimeField): if isinstance(field, serializers.DateTimeField):
return { return {
'type': 'string', 'type': 'string',
'format': 'date-time', 'format': 'date-time',
} }, {}
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification." # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
# see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types # see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
@ -427,19 +427,19 @@ class AutoSchema(ViewInspector):
return { return {
'type': 'string', 'type': 'string',
'format': 'email' 'format': 'email'
} }, {}
if isinstance(field, serializers.URLField): if isinstance(field, serializers.URLField):
return { return {
'type': 'string', 'type': 'string',
'format': 'uri' 'format': 'uri'
} }, {}
if isinstance(field, serializers.UUIDField): if isinstance(field, serializers.UUIDField):
return { return {
'type': 'string', 'type': 'string',
'format': 'uuid' 'format': 'uuid'
} }, {}
if isinstance(field, serializers.IPAddressField): if isinstance(field, serializers.IPAddressField):
content = { content = {
@ -447,7 +447,7 @@ class AutoSchema(ViewInspector):
} }
if field.protocol != 'both': if field.protocol != 'both':
content['format'] = field.protocol content['format'] = field.protocol
return content return content, {}
if isinstance(field, serializers.DecimalField): if isinstance(field, serializers.DecimalField):
if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING): if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
@ -466,14 +466,14 @@ class AutoSchema(ViewInspector):
content['maximum'] = int(field.max_whole_digits * '9') + 1 content['maximum'] = int(field.max_whole_digits * '9') + 1
content['minimum'] = -content['maximum'] content['minimum'] = -content['maximum']
self._map_min_max(field, content) self._map_min_max(field, content)
return content return content, {}
if isinstance(field, serializers.FloatField): if isinstance(field, serializers.FloatField):
content = { content = {
'type': 'number', 'type': 'number',
} }
self._map_min_max(field, content) self._map_min_max(field, content)
return content return content, {}
if isinstance(field, serializers.IntegerField): if isinstance(field, serializers.IntegerField):
content = { content = {
@ -483,13 +483,13 @@ class AutoSchema(ViewInspector):
# 2147483647 is max for int32_size, so we use int64 for format # 2147483647 is max for int32_size, so we use int64 for format
if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647: if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647:
content['format'] = 'int64' content['format'] = 'int64'
return content return content, {}
if isinstance(field, serializers.FileField): if isinstance(field, serializers.FileField):
return { return {
'type': 'string', 'type': 'string',
'format': 'binary' 'format': 'binary'
} }, {}
# Simplest cases, default to 'string' type: # Simplest cases, default to 'string' type:
FIELD_CLASS_SCHEMA_TYPE = { FIELD_CLASS_SCHEMA_TYPE = {
@ -498,7 +498,7 @@ class AutoSchema(ViewInspector):
serializers.DictField: 'object', serializers.DictField: 'object',
serializers.HStoreField: 'object', serializers.HStoreField: 'object',
} }
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')} return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}, {}
def _map_min_max(self, field, content): def _map_min_max(self, field, content):
if field.max_value: if field.max_value:
@ -516,6 +516,9 @@ class AutoSchema(ViewInspector):
except AttributeError: except AttributeError:
pass pass
component_name = self.get_component_name(serializer)
components = {}
for field in serializer.fields.values(): for field in serializer.fields.values():
if isinstance(field, serializers.HiddenField): if isinstance(field, serializers.HiddenField):
continue continue
@ -523,7 +526,8 @@ class AutoSchema(ViewInspector):
if field.required: if field.required:
required.append(field.field_name) required.append(field.field_name)
schema = self.map_field(field) schema, subcomponents = self.map_field(field)
components.update(subcomponents)
if field.read_only: if field.read_only:
schema['readOnly'] = True schema['readOnly'] = True
if field.write_only: if field.write_only:
@ -544,8 +548,9 @@ class AutoSchema(ViewInspector):
} }
if required: if required:
result['required'] = required result['required'] = required
components[component_name] = result
return result return {'$ref': '#/components/schemas/%s' % component_name}, components
def map_field_validators(self, field, schema): def map_field_validators(self, field, schema):
""" """

View File

@ -83,7 +83,8 @@ class TestFieldMapping(TestCase):
] ]
for field, mapping in cases: for field, mapping in cases:
with self.subTest(field=field): with self.subTest(field=field):
assert inspector.map_field(field) == mapping data, _ = inspector.map_field(field)
assert data == mapping
def test_lazy_string_field(self): def test_lazy_string_field(self):
class ItemSerializer(serializers.Serializer): class ItemSerializer(serializers.Serializer):
@ -91,18 +92,20 @@ class TestFieldMapping(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
data = inspector.map_serializer(ItemSerializer()) ref, component = inspector.map_serializer(ItemSerializer())
data = component['Item']
assert isinstance(data['properties']['text']['description'], str), "description must be str" assert isinstance(data['properties']['text']['description'], str), "description must be str"
def test_boolean_default_field(self): def test_boolean_default_field(self):
class Serializer(serializers.Serializer): class BooleanSerializer(serializers.Serializer):
default_true = serializers.BooleanField(default=True) default_true = serializers.BooleanField(default=True)
default_false = serializers.BooleanField(default=False) default_false = serializers.BooleanField(default=False)
without_default = serializers.BooleanField() without_default = serializers.BooleanField()
inspector = AutoSchema() inspector = AutoSchema()
data = inspector.map_serializer(Serializer()) ref, component = inspector.map_serializer(BooleanSerializer())
data = component['Boolean']
assert data['properties']['default_true']['default'] is True, "default must be true" assert data['properties']['default_true']['default'] is True, "default must be true"
assert data['properties']['default_false']['default'] is False, "default must be false" assert data['properties']['default_false']['default'] is False, "default must be false"
assert 'default' not in data['properties']['without_default'], "default must not be defined" assert 'default' not in data['properties']['without_default'], "default must not be defined"
@ -345,9 +348,10 @@ class TestOperationIntrospection(TestCase):
schema = components['Item'] schema = components['Item']
assert sorted(schema['required']) == ['nested', 'text'] assert sorted(schema['required']) == ['nested', 'text']
assert sorted(list(schema['properties'].keys())) == ['nested', 'text'] assert sorted(list(schema['properties'].keys())) == ['nested', 'text']
assert schema['properties']['nested']['type'] == 'object' assert schema['properties']['nested']['$ref'] == '#/components/schemas/Nested'
assert list(schema['properties']['nested']['properties'].keys()) == ['number'] nested = components['Nested']
assert schema['properties']['nested']['required'] == ['number'] assert list(nested['properties'].keys()) == ['number']
assert nested['required'] == ['number']
def test_list_response_body_generation(self): def test_list_response_body_generation(self):
"""Test that an array schema is returned for list views.""" """Test that an array schema is returned for list views."""