Separate schemas per serializer/method

This commit is contained in:
Bill Collins 2022-03-20 12:54:28 +00:00
parent 203aed7af2
commit 20070f4133
2 changed files with 67 additions and 45 deletions

View File

@ -162,7 +162,7 @@ class AutoSchema(ViewInspector):
return operation
def get_component_name(self, serializer):
def get_component_name(self, serializer, method):
"""
Compute the component's name from the serializer.
Raise an exception if the serializer's class name is "Serializer" (case-insensitive).
@ -183,6 +183,13 @@ class AutoSchema(ViewInspector):
.format(serializer.__class__.__name__)
)
if method.lower() == 'patch':
return component_name + 'PartialUpdate'
if method.lower() == 'put':
return component_name + 'Update'
if method.lower() == 'post':
return component_name + 'Create'
return component_name
def get_components(self, path, method):
@ -194,16 +201,16 @@ class AutoSchema(ViewInspector):
return {}
request_serializer = self.get_request_serializer(path, method)
response_serializer = self.get_response_serializer(path, method)
response_serializer = self.get_response_serializer(path, 'GET')
if isinstance(request_serializer, serializers.Serializer):
component_name = self.get_component_name(request_serializer)
content = self.map_serializer(request_serializer)
component_name = self.get_component_name(request_serializer, method)
content = self.map_serializer(request_serializer, method)
self.components.setdefault(component_name, content)
if isinstance(response_serializer, serializers.Serializer):
component_name = self.get_component_name(response_serializer)
content = self.map_serializer(response_serializer)
component_name = self.get_component_name(response_serializer, 'GET')
content = self.map_serializer(response_serializer, 'GET')
self.components.setdefault(component_name, content)
return self.components
@ -364,27 +371,27 @@ class AutoSchema(ViewInspector):
mapping['type'] = type
return mapping
def map_field(self, field):
def map_field(self, field, method):
# Nested Serializers, `many` or not.
if isinstance(field, serializers.ListSerializer):
return {
'type': 'array',
'items': self.map_serializer(field.child)
'items': self.map_serializer(field.child, method)
}
if isinstance(field, serializers.Serializer):
data = self.map_serializer(field)
data = self.map_serializer(field, method)
data['type'] = 'object'
return data
if isinstance(field, serializers.SerializerMethodField) and field.output_field:
return self.map_field(field.output_field)
return self.map_field(field.output_field, method)
# Related fields.
if isinstance(field, serializers.ManyRelatedField):
return {
'type': 'array',
'items': self.map_field(field.child_relation)
'items': self.map_field(field.child_relation, method)
}
if isinstance(field, serializers.PrimaryKeyRelatedField):
model = getattr(field.queryset, 'model', None)
@ -413,7 +420,7 @@ class AutoSchema(ViewInspector):
'items': {},
}
if not isinstance(field.child, _UnvalidatedField):
mapping['items'] = self.map_field(field.child)
mapping['items'] = self.map_field(field.child, method)
return mapping
# DateField and DateTimeField type is string
@ -515,7 +522,7 @@ class AutoSchema(ViewInspector):
if field.min_value:
content['minimum'] = field.min_value
def map_serializer(self, serializer):
def map_serializer(self, serializer, method):
# Assuming we have a valid serializer instance.
required = []
properties = {}
@ -524,10 +531,24 @@ class AutoSchema(ViewInspector):
if isinstance(field, serializers.HiddenField):
continue
if field.required:
required.append(field.field_name)
if method.lower() == 'get':
if field.write_only:
# Write only fields don't appear in the output
continue
else:
# All non-write-only fields are required in get requests
required.append(field.field_name)
schema = self.map_field(field)
if method.lower() in ('post', 'put', 'patch') and field.read_only:
# Don't emit readonly fields for writable methods
continue
if method.lower() != 'patch' and field.required:
# Only mark required fields as required if we're not patching
if field.field_name not in required:
required.append(field.field_name)
schema = self.map_field(field, method)
if field.read_only:
schema['readOnly'] = True
if field.write_only:
@ -549,9 +570,9 @@ class AutoSchema(ViewInspector):
if required:
result['required'] = required
component_name = self.get_component_name(serializer=serializer)
component_name = self.get_component_name(serializer=serializer, method=method)
self.components[component_name] = result
return self._get_reference(serializer)
return self._get_reference(serializer, method)
def map_field_validators(self, field, schema):
"""
@ -640,8 +661,8 @@ class AutoSchema(ViewInspector):
"""
return self.get_serializer(path, method)
def _get_reference(self, serializer):
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}
def _get_reference(self, serializer, method):
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer, method))}
def get_request_body(self, path, method):
if method not in ('PUT', 'PATCH', 'POST'):
@ -654,7 +675,7 @@ class AutoSchema(ViewInspector):
if not isinstance(serializer, serializers.Serializer):
item_schema = {}
else:
item_schema = self._get_reference(serializer)
item_schema = self._get_reference(serializer, method)
return {
'content': {
@ -673,12 +694,12 @@ class AutoSchema(ViewInspector):
self.response_media_types = self.map_renderers(path, method)
serializer = self.get_response_serializer(path, method)
serializer = self.get_response_serializer(path, 'GET')
if not isinstance(serializer, serializers.Serializer):
item_schema = {}
else:
item_schema = self._get_reference(serializer)
item_schema = self._get_reference(serializer, 'GET')
if is_list_view(path, method, self.view):
response_schema = {
@ -779,7 +800,7 @@ class AutoSchema(ViewInspector):
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.map_serializer(serializer)
return self.map_serializer(serializer, 'GET')
def _map_field(self, field):
warnings.warn(
@ -787,7 +808,7 @@ class AutoSchema(ViewInspector):
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.map_field(field)
return self.map_field(field, 'GET')
def _map_choicefield(self, field):
warnings.warn(

View File

@ -87,7 +87,7 @@ class TestFieldMapping(TestCase):
]
for field, mapping in cases:
with self.subTest(field=field):
assert inspector.map_field(field) == mapping
assert inspector.map_field(field, method='GET') == mapping
def test_lazy_string_field(self):
class ItemSerializer(serializers.Serializer):
@ -95,7 +95,7 @@ class TestFieldMapping(TestCase):
inspector = AutoSchema()
inspector.map_serializer(ItemSerializer())
inspector.map_serializer(ItemSerializer(), method='GET')
data = inspector.components['Item']
assert isinstance(data['properties']['text']['description'], str), "description must be str"
@ -107,7 +107,7 @@ class TestFieldMapping(TestCase):
inspector = AutoSchema()
inspector.map_serializer(BooleanTestSerializer())
inspector.map_serializer(BooleanTestSerializer(), method='GET')
data = inspector.components['BooleanTest']
assert data['properties']['default_true']['default'] is True, "default must be true"
assert data['properties']['default_false']['default'] is False, "default must be false"
@ -126,7 +126,7 @@ class TestFieldMapping(TestCase):
inspector = AutoSchema()
inspector.map_serializer(NullableSerializer())
inspector.map_serializer(NullableSerializer(), method='GET')
data = inspector.components['Nullable']
assert data['properties']['rw_field']['nullable'], "rw_field nullable must be true"
assert data['properties']['ro_field']['nullable'], "ro_field nullable must be true"
@ -142,7 +142,7 @@ class TestFieldMapping(TestCase):
inspector = AutoSchema()
inspector.map_serializer(MethodSerializer())
inspector.map_serializer(MethodSerializer(), method='GET')
data = inspector.components['Method']
assert data['properties']['method_field']['type'] == 'boolean'
@ -243,11 +243,11 @@ class TestOperationIntrospection(TestCase):
request_body = inspector.get_request_body(path, method)
print(request_body)
assert request_body['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item'
assert request_body['content']['application/json']['schema']['$ref'] == '#/components/schemas/ItemCreate'
components = inspector.get_components(path, method)
assert components['Item']['required'] == ['text']
assert sorted(list(components['Item']['properties'].keys())) == ['read_only', 'text']
assert components['ItemCreate']['required'] == ['text']
assert sorted(list(components['ItemCreate']['properties'].keys())) == ['text']
def test_invalid_serializer_class_name(self):
path = '/'
@ -271,7 +271,7 @@ class TestOperationIntrospection(TestCase):
serializer = inspector.get_serializer(path, method)
with pytest.raises(Exception) as exc:
inspector.get_component_name(serializer)
inspector.get_component_name(serializer, method='GET')
assert "is an invalid class name for schema generation" in str(exc.value)
def test_empty_required(self):
@ -294,7 +294,7 @@ class TestOperationIntrospection(TestCase):
inspector.view = view
components = inspector.get_components(path, method)
component = components['Item']
component = components['ItemCreate']
# there should be no empty 'required' property, see #6834
assert 'required' not in component
@ -321,7 +321,7 @@ class TestOperationIntrospection(TestCase):
inspector.view = view
components = inspector.get_components(path, method)
component = components['Item']
component = components['ItemPartialUpdate']
# there should be no empty 'required' property, see #6834
assert 'required' not in component
for response in inspector.get_responses(path, method).values():
@ -350,8 +350,10 @@ class TestOperationIntrospection(TestCase):
assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item'
components = inspector.get_components(path, method)
assert sorted(components['Item']['required']) == ['text', 'write_only']
assert sorted(list(components['Item']['properties'].keys())) == ['text', 'write_only']
assert sorted(components['Item']['required']) == ['text']
assert sorted(list(components['Item']['properties'].keys())) == ['text']
assert sorted(components['ItemCreate']['required']) == ['text', 'write_only']
assert sorted(list(components['ItemCreate']['properties'].keys())) == ['text', 'write_only']
assert 'description' in responses['201']
def test_response_body_nested_serializer(self):
@ -776,7 +778,7 @@ class TestOperationIntrospection(TestCase):
components = inspector.get_components(path, method)
assert components == {
'Request': {
'RequestCreate': {
'properties': {
'text': {
'type': 'string'
@ -805,17 +807,17 @@ class TestOperationIntrospection(TestCase):
'content': {
'application/json': {
'schema': {
'$ref': '#/components/schemas/Request'
'$ref': '#/components/schemas/RequestCreate'
}
},
'application/x-www-form-urlencoded': {
'schema': {
'$ref': '#/components/schemas/Request'
'$ref': '#/components/schemas/RequestCreate'
}
},
'multipart/form-data': {
'schema': {
'$ref': '#/components/schemas/Request'
'$ref': '#/components/schemas/RequestCreate'
}
}
}
@ -1197,14 +1199,13 @@ class TestGenerator(TestCase):
body_schema = route['requestBody']['content']['application/json']['schema']
assert body_schema == {
'$ref': '#/components/schemas/AuthToken'
'$ref': '#/components/schemas/AuthTokenCreate'
}
assert schema['components']['schemas']['AuthToken'] == {
assert schema['components']['schemas']['AuthTokenCreate'] == {
'type': 'object',
'properties': {
'username': {'type': 'string', 'writeOnly': True},
'password': {'type': 'string', 'writeOnly': True},
'token': {'type': 'string', 'readOnly': True},
},
'required': ['username', 'password']
}