Separate schemas per serializer/method

This commit is contained in:
Bill Collins 2022-03-20 12:54:28 +00:00
parent 203aed7af2
commit 20070f4133
2 changed files with 67 additions and 45 deletions

View File

@ -162,7 +162,7 @@ class AutoSchema(ViewInspector):
return operation return operation
def get_component_name(self, serializer): def get_component_name(self, serializer, method):
""" """
Compute the component's name from the serializer. Compute the component's name from the serializer.
Raise an exception if the serializer's class name is "Serializer" (case-insensitive). Raise an exception if the serializer's class name is "Serializer" (case-insensitive).
@ -183,6 +183,13 @@ class AutoSchema(ViewInspector):
.format(serializer.__class__.__name__) .format(serializer.__class__.__name__)
) )
if method.lower() == 'patch':
return component_name + 'PartialUpdate'
if method.lower() == 'put':
return component_name + 'Update'
if method.lower() == 'post':
return component_name + 'Create'
return component_name return component_name
def get_components(self, path, method): def get_components(self, path, method):
@ -194,16 +201,16 @@ class AutoSchema(ViewInspector):
return {} return {}
request_serializer = self.get_request_serializer(path, method) request_serializer = self.get_request_serializer(path, method)
response_serializer = self.get_response_serializer(path, method) response_serializer = self.get_response_serializer(path, 'GET')
if isinstance(request_serializer, serializers.Serializer): if isinstance(request_serializer, serializers.Serializer):
component_name = self.get_component_name(request_serializer) component_name = self.get_component_name(request_serializer, method)
content = self.map_serializer(request_serializer) content = self.map_serializer(request_serializer, method)
self.components.setdefault(component_name, content) self.components.setdefault(component_name, content)
if isinstance(response_serializer, serializers.Serializer): if isinstance(response_serializer, serializers.Serializer):
component_name = self.get_component_name(response_serializer) component_name = self.get_component_name(response_serializer, 'GET')
content = self.map_serializer(response_serializer) content = self.map_serializer(response_serializer, 'GET')
self.components.setdefault(component_name, content) self.components.setdefault(component_name, content)
return self.components return self.components
@ -364,27 +371,27 @@ class AutoSchema(ViewInspector):
mapping['type'] = type mapping['type'] = type
return mapping return mapping
def map_field(self, field): def map_field(self, field, method):
# Nested Serializers, `many` or not. # Nested Serializers, `many` or not.
if isinstance(field, serializers.ListSerializer): if isinstance(field, serializers.ListSerializer):
return { return {
'type': 'array', 'type': 'array',
'items': self.map_serializer(field.child) 'items': self.map_serializer(field.child, method)
} }
if isinstance(field, serializers.Serializer): if isinstance(field, serializers.Serializer):
data = self.map_serializer(field) data = self.map_serializer(field, method)
data['type'] = 'object' data['type'] = 'object'
return data return data
if isinstance(field, serializers.SerializerMethodField) and field.output_field: if isinstance(field, serializers.SerializerMethodField) and field.output_field:
return self.map_field(field.output_field) return self.map_field(field.output_field, method)
# Related fields. # Related fields.
if isinstance(field, serializers.ManyRelatedField): if isinstance(field, serializers.ManyRelatedField):
return { return {
'type': 'array', 'type': 'array',
'items': self.map_field(field.child_relation) 'items': self.map_field(field.child_relation, method)
} }
if isinstance(field, serializers.PrimaryKeyRelatedField): if isinstance(field, serializers.PrimaryKeyRelatedField):
model = getattr(field.queryset, 'model', None) model = getattr(field.queryset, 'model', None)
@ -413,7 +420,7 @@ class AutoSchema(ViewInspector):
'items': {}, 'items': {},
} }
if not isinstance(field.child, _UnvalidatedField): if not isinstance(field.child, _UnvalidatedField):
mapping['items'] = self.map_field(field.child) mapping['items'] = self.map_field(field.child, method)
return mapping return mapping
# DateField and DateTimeField type is string # DateField and DateTimeField type is string
@ -515,7 +522,7 @@ class AutoSchema(ViewInspector):
if field.min_value: if field.min_value:
content['minimum'] = field.min_value content['minimum'] = field.min_value
def map_serializer(self, serializer): def map_serializer(self, serializer, method):
# Assuming we have a valid serializer instance. # Assuming we have a valid serializer instance.
required = [] required = []
properties = {} properties = {}
@ -524,10 +531,24 @@ class AutoSchema(ViewInspector):
if isinstance(field, serializers.HiddenField): if isinstance(field, serializers.HiddenField):
continue continue
if field.required: if method.lower() == 'get':
required.append(field.field_name) if field.write_only:
# Write only fields don't appear in the output
continue
else:
# All non-write-only fields are required in get requests
required.append(field.field_name)
schema = self.map_field(field) if method.lower() in ('post', 'put', 'patch') and field.read_only:
# Don't emit readonly fields for writable methods
continue
if method.lower() != 'patch' and field.required:
# Only mark required fields as required if we're not patching
if field.field_name not in required:
required.append(field.field_name)
schema = self.map_field(field, method)
if field.read_only: if field.read_only:
schema['readOnly'] = True schema['readOnly'] = True
if field.write_only: if field.write_only:
@ -549,9 +570,9 @@ class AutoSchema(ViewInspector):
if required: if required:
result['required'] = required result['required'] = required
component_name = self.get_component_name(serializer=serializer) component_name = self.get_component_name(serializer=serializer, method=method)
self.components[component_name] = result self.components[component_name] = result
return self._get_reference(serializer) return self._get_reference(serializer, method)
def map_field_validators(self, field, schema): def map_field_validators(self, field, schema):
""" """
@ -640,8 +661,8 @@ class AutoSchema(ViewInspector):
""" """
return self.get_serializer(path, method) return self.get_serializer(path, method)
def _get_reference(self, serializer): def _get_reference(self, serializer, method):
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))} return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer, method))}
def get_request_body(self, path, method): def get_request_body(self, path, method):
if method not in ('PUT', 'PATCH', 'POST'): if method not in ('PUT', 'PATCH', 'POST'):
@ -654,7 +675,7 @@ class AutoSchema(ViewInspector):
if not isinstance(serializer, serializers.Serializer): if not isinstance(serializer, serializers.Serializer):
item_schema = {} item_schema = {}
else: else:
item_schema = self._get_reference(serializer) item_schema = self._get_reference(serializer, method)
return { return {
'content': { 'content': {
@ -673,12 +694,12 @@ class AutoSchema(ViewInspector):
self.response_media_types = self.map_renderers(path, method) self.response_media_types = self.map_renderers(path, method)
serializer = self.get_response_serializer(path, method) serializer = self.get_response_serializer(path, 'GET')
if not isinstance(serializer, serializers.Serializer): if not isinstance(serializer, serializers.Serializer):
item_schema = {} item_schema = {}
else: else:
item_schema = self._get_reference(serializer) item_schema = self._get_reference(serializer, 'GET')
if is_list_view(path, method, self.view): if is_list_view(path, method, self.view):
response_schema = { response_schema = {
@ -779,7 +800,7 @@ class AutoSchema(ViewInspector):
"The old name will be removed in DRF v3.14.", "The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2 RemovedInDRF314Warning, stacklevel=2
) )
return self.map_serializer(serializer) return self.map_serializer(serializer, 'GET')
def _map_field(self, field): def _map_field(self, field):
warnings.warn( warnings.warn(
@ -787,7 +808,7 @@ class AutoSchema(ViewInspector):
"The old name will be removed in DRF v3.14.", "The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2 RemovedInDRF314Warning, stacklevel=2
) )
return self.map_field(field) return self.map_field(field, 'GET')
def _map_choicefield(self, field): def _map_choicefield(self, field):
warnings.warn( warnings.warn(

View File

@ -87,7 +87,7 @@ class TestFieldMapping(TestCase):
] ]
for field, mapping in cases: for field, mapping in cases:
with self.subTest(field=field): with self.subTest(field=field):
assert inspector.map_field(field) == mapping assert inspector.map_field(field, method='GET') == mapping
def test_lazy_string_field(self): def test_lazy_string_field(self):
class ItemSerializer(serializers.Serializer): class ItemSerializer(serializers.Serializer):
@ -95,7 +95,7 @@ class TestFieldMapping(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
inspector.map_serializer(ItemSerializer()) inspector.map_serializer(ItemSerializer(), method='GET')
data = inspector.components['Item'] data = inspector.components['Item']
assert isinstance(data['properties']['text']['description'], str), "description must be str" assert isinstance(data['properties']['text']['description'], str), "description must be str"
@ -107,7 +107,7 @@ class TestFieldMapping(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
inspector.map_serializer(BooleanTestSerializer()) inspector.map_serializer(BooleanTestSerializer(), method='GET')
data = inspector.components['BooleanTest'] data = inspector.components['BooleanTest']
assert data['properties']['default_true']['default'] is True, "default must be true" assert data['properties']['default_true']['default'] is True, "default must be true"
assert data['properties']['default_false']['default'] is False, "default must be false" assert data['properties']['default_false']['default'] is False, "default must be false"
@ -126,7 +126,7 @@ class TestFieldMapping(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
inspector.map_serializer(NullableSerializer()) inspector.map_serializer(NullableSerializer(), method='GET')
data = inspector.components['Nullable'] data = inspector.components['Nullable']
assert data['properties']['rw_field']['nullable'], "rw_field nullable must be true" assert data['properties']['rw_field']['nullable'], "rw_field nullable must be true"
assert data['properties']['ro_field']['nullable'], "ro_field nullable must be true" assert data['properties']['ro_field']['nullable'], "ro_field nullable must be true"
@ -142,7 +142,7 @@ class TestFieldMapping(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
inspector.map_serializer(MethodSerializer()) inspector.map_serializer(MethodSerializer(), method='GET')
data = inspector.components['Method'] data = inspector.components['Method']
assert data['properties']['method_field']['type'] == 'boolean' assert data['properties']['method_field']['type'] == 'boolean'
@ -243,11 +243,11 @@ class TestOperationIntrospection(TestCase):
request_body = inspector.get_request_body(path, method) request_body = inspector.get_request_body(path, method)
print(request_body) print(request_body)
assert request_body['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' assert request_body['content']['application/json']['schema']['$ref'] == '#/components/schemas/ItemCreate'
components = inspector.get_components(path, method) components = inspector.get_components(path, method)
assert components['Item']['required'] == ['text'] assert components['ItemCreate']['required'] == ['text']
assert sorted(list(components['Item']['properties'].keys())) == ['read_only', 'text'] assert sorted(list(components['ItemCreate']['properties'].keys())) == ['text']
def test_invalid_serializer_class_name(self): def test_invalid_serializer_class_name(self):
path = '/' path = '/'
@ -271,7 +271,7 @@ class TestOperationIntrospection(TestCase):
serializer = inspector.get_serializer(path, method) serializer = inspector.get_serializer(path, method)
with pytest.raises(Exception) as exc: with pytest.raises(Exception) as exc:
inspector.get_component_name(serializer) inspector.get_component_name(serializer, method='GET')
assert "is an invalid class name for schema generation" in str(exc.value) assert "is an invalid class name for schema generation" in str(exc.value)
def test_empty_required(self): def test_empty_required(self):
@ -294,7 +294,7 @@ class TestOperationIntrospection(TestCase):
inspector.view = view inspector.view = view
components = inspector.get_components(path, method) components = inspector.get_components(path, method)
component = components['Item'] component = components['ItemCreate']
# there should be no empty 'required' property, see #6834 # there should be no empty 'required' property, see #6834
assert 'required' not in component assert 'required' not in component
@ -321,7 +321,7 @@ class TestOperationIntrospection(TestCase):
inspector.view = view inspector.view = view
components = inspector.get_components(path, method) components = inspector.get_components(path, method)
component = components['Item'] component = components['ItemPartialUpdate']
# there should be no empty 'required' property, see #6834 # there should be no empty 'required' property, see #6834
assert 'required' not in component assert 'required' not in component
for response in inspector.get_responses(path, method).values(): for response in inspector.get_responses(path, method).values():
@ -350,8 +350,10 @@ class TestOperationIntrospection(TestCase):
assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item'
components = inspector.get_components(path, method) components = inspector.get_components(path, method)
assert sorted(components['Item']['required']) == ['text', 'write_only'] assert sorted(components['Item']['required']) == ['text']
assert sorted(list(components['Item']['properties'].keys())) == ['text', 'write_only'] assert sorted(list(components['Item']['properties'].keys())) == ['text']
assert sorted(components['ItemCreate']['required']) == ['text', 'write_only']
assert sorted(list(components['ItemCreate']['properties'].keys())) == ['text', 'write_only']
assert 'description' in responses['201'] assert 'description' in responses['201']
def test_response_body_nested_serializer(self): def test_response_body_nested_serializer(self):
@ -776,7 +778,7 @@ class TestOperationIntrospection(TestCase):
components = inspector.get_components(path, method) components = inspector.get_components(path, method)
assert components == { assert components == {
'Request': { 'RequestCreate': {
'properties': { 'properties': {
'text': { 'text': {
'type': 'string' 'type': 'string'
@ -805,17 +807,17 @@ class TestOperationIntrospection(TestCase):
'content': { 'content': {
'application/json': { 'application/json': {
'schema': { 'schema': {
'$ref': '#/components/schemas/Request' '$ref': '#/components/schemas/RequestCreate'
} }
}, },
'application/x-www-form-urlencoded': { 'application/x-www-form-urlencoded': {
'schema': { 'schema': {
'$ref': '#/components/schemas/Request' '$ref': '#/components/schemas/RequestCreate'
} }
}, },
'multipart/form-data': { 'multipart/form-data': {
'schema': { 'schema': {
'$ref': '#/components/schemas/Request' '$ref': '#/components/schemas/RequestCreate'
} }
} }
} }
@ -1197,14 +1199,13 @@ class TestGenerator(TestCase):
body_schema = route['requestBody']['content']['application/json']['schema'] body_schema = route['requestBody']['content']['application/json']['schema']
assert body_schema == { assert body_schema == {
'$ref': '#/components/schemas/AuthToken' '$ref': '#/components/schemas/AuthTokenCreate'
} }
assert schema['components']['schemas']['AuthToken'] == { assert schema['components']['schemas']['AuthTokenCreate'] == {
'type': 'object', 'type': 'object',
'properties': { 'properties': {
'username': {'type': 'string', 'writeOnly': True}, 'username': {'type': 'string', 'writeOnly': True},
'password': {'type': 'string', 'writeOnly': True}, 'password': {'type': 'string', 'writeOnly': True},
'token': {'type': 'string', 'readOnly': True},
}, },
'required': ['username', 'password'] 'required': ['username', 'password']
} }