Convert openapi.AutoSchema methods to public API.

This commit is contained in:
Carlton Gibson 2020-04-06 17:03:10 +02:00 committed by Carlton Gibson
parent d45e0005f3
commit b2497fc245
2 changed files with 148 additions and 51 deletions

View File

@ -12,7 +12,9 @@ from django.core.validators import (
from django.db import models from django.db import models
from django.utils.encoding import force_str from django.utils.encoding import force_str
from rest_framework import exceptions, renderers, serializers from rest_framework import (
RemovedInDRF314Warning, exceptions, renderers, serializers
)
from rest_framework.compat import uritemplate from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty from rest_framework.fields import _UnvalidatedField, empty
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -146,15 +148,15 @@ class AutoSchema(ViewInspector):
operation['description'] = self.get_description(path, method) operation['description'] = self.get_description(path, method)
parameters = [] parameters = []
parameters += self._get_path_parameters(path, method) parameters += self.get_path_parameters(path, method)
parameters += self._get_pagination_parameters(path, method) parameters += self.get_pagination_parameters(path, method)
parameters += self._get_filter_parameters(path, method) parameters += self.get_filter_parameters(path, method)
operation['parameters'] = parameters operation['parameters'] = parameters
request_body = self._get_request_body(path, method) request_body = self.get_request_body(path, method)
if request_body: if request_body:
operation['requestBody'] = request_body operation['requestBody'] = request_body
operation['responses'] = self._get_responses(path, method) operation['responses'] = self.get_responses(path, method)
operation['tags'] = self.get_tags(path, method) operation['tags'] = self.get_tags(path, method)
return operation return operation
@ -190,14 +192,14 @@ class AutoSchema(ViewInspector):
if method.lower() == 'delete': if method.lower() == 'delete':
return {} return {}
serializer = self._get_serializer(path, method) serializer = self.get_serializer(path, method)
if not isinstance(serializer, serializers.Serializer): if not isinstance(serializer, serializers.Serializer):
return {} return {}
component_name = self.get_component_name(serializer) component_name = self.get_component_name(serializer)
content = self._map_serializer(serializer) content = self.map_serializer(serializer)
return {component_name: content} return {component_name: content}
def _to_camel_case(self, snake_str): def _to_camel_case(self, snake_str):
@ -220,8 +222,8 @@ class AutoSchema(ViewInspector):
name = model.__name__ name = model.__name__
# Try with the serializer class name # Try with the serializer class name
elif self._get_serializer(path, method) is not None: elif self.get_serializer(path, method) is not None:
name = self._get_serializer(path, method).__class__.__name__ name = self.get_serializer(path, method).__class__.__name__
if name.endswith('Serializer'): if name.endswith('Serializer'):
name = name[:-10] name = name[:-10]
@ -259,7 +261,7 @@ class AutoSchema(ViewInspector):
return action + name return action + name
def _get_path_parameters(self, path, method): def get_path_parameters(self, path, method):
""" """
Return a list of parameters from templated path variables. Return a list of parameters from templated path variables.
""" """
@ -295,15 +297,15 @@ class AutoSchema(ViewInspector):
return parameters return parameters
def _get_filter_parameters(self, path, method): def get_filter_parameters(self, path, method):
if not self._allows_filters(path, method): if not self.allows_filters(path, method):
return [] return []
parameters = [] parameters = []
for filter_backend in self.view.filter_backends: for filter_backend in self.view.filter_backends:
parameters += filter_backend().get_schema_operation_parameters(self.view) parameters += filter_backend().get_schema_operation_parameters(self.view)
return parameters return parameters
def _allows_filters(self, path, method): def allows_filters(self, path, method):
""" """
Determine whether to include filter Fields in schema. Determine whether to include filter Fields in schema.
@ -316,19 +318,19 @@ class AutoSchema(ViewInspector):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
return method.lower() in ["get", "put", "patch", "delete"] return method.lower() in ["get", "put", "patch", "delete"]
def _get_pagination_parameters(self, path, method): def get_pagination_parameters(self, path, method):
view = self.view view = self.view
if not is_list_view(path, method, view): if not is_list_view(path, method, view):
return [] return []
paginator = self._get_paginator() paginator = self.get_paginator()
if not paginator: if not paginator:
return [] return []
return paginator.get_schema_operation_parameters(view) return paginator.get_schema_operation_parameters(view)
def _map_choicefield(self, field): def map_choicefield(self, field):
choices = list(OrderedDict.fromkeys(field.choices)) # preserve order and remove duplicates choices = list(OrderedDict.fromkeys(field.choices)) # preserve order and remove duplicates
if all(isinstance(choice, bool) for choice in choices): if all(isinstance(choice, bool) for choice in choices):
type = 'boolean' type = 'boolean'
@ -356,16 +358,16 @@ class AutoSchema(ViewInspector):
mapping['type'] = type mapping['type'] = type
return mapping return mapping
def _map_field(self, field): def map_field(self, field):
# Nested Serializers, `many` or not. # Nested Serializers, `many` or not.
if isinstance(field, serializers.ListSerializer): if isinstance(field, serializers.ListSerializer):
return { return {
'type': 'array', 'type': 'array',
'items': self._map_serializer(field.child) 'items': self.map_serializer(field.child)
} }
if isinstance(field, serializers.Serializer): if isinstance(field, serializers.Serializer):
data = self._map_serializer(field) data = self.map_serializer(field)
data['type'] = 'object' data['type'] = 'object'
return data return data
@ -373,7 +375,7 @@ class AutoSchema(ViewInspector):
if isinstance(field, serializers.ManyRelatedField): if isinstance(field, serializers.ManyRelatedField):
return { return {
'type': 'array', 'type': 'array',
'items': self._map_field(field.child_relation) 'items': self.map_field(field.child_relation)
} }
if isinstance(field, serializers.PrimaryKeyRelatedField): if isinstance(field, serializers.PrimaryKeyRelatedField):
model = getattr(field.queryset, 'model', None) model = getattr(field.queryset, 'model', None)
@ -389,11 +391,11 @@ class AutoSchema(ViewInspector):
if isinstance(field, serializers.MultipleChoiceField): if isinstance(field, serializers.MultipleChoiceField):
return { return {
'type': 'array', 'type': 'array',
'items': self._map_choicefield(field) 'items': self.map_choicefield(field)
} }
if isinstance(field, serializers.ChoiceField): if isinstance(field, serializers.ChoiceField):
return self._map_choicefield(field) return self.map_choicefield(field)
# ListField. # ListField.
if isinstance(field, serializers.ListField): if isinstance(field, serializers.ListField):
@ -402,7 +404,7 @@ class AutoSchema(ViewInspector):
'items': {}, 'items': {},
} }
if not isinstance(field.child, _UnvalidatedField): if not isinstance(field.child, _UnvalidatedField):
mapping['items'] = self._map_field(field.child) mapping['items'] = self.map_field(field.child)
return mapping return mapping
# DateField and DateTimeField type is string # DateField and DateTimeField type is string
@ -504,7 +506,7 @@ class AutoSchema(ViewInspector):
if field.min_value: if field.min_value:
content['minimum'] = field.min_value content['minimum'] = field.min_value
def _map_serializer(self, serializer): def map_serializer(self, serializer):
# Assuming we have a valid serializer instance. # Assuming we have a valid serializer instance.
required = [] required = []
properties = {} properties = {}
@ -516,7 +518,7 @@ class AutoSchema(ViewInspector):
if field.required: if field.required:
required.append(field.field_name) required.append(field.field_name)
schema = self._map_field(field) schema = self.map_field(field)
if field.read_only: if field.read_only:
schema['readOnly'] = True schema['readOnly'] = True
if field.write_only: if field.write_only:
@ -527,7 +529,7 @@ class AutoSchema(ViewInspector):
schema['default'] = field.default schema['default'] = field.default
if field.help_text: if field.help_text:
schema['description'] = str(field.help_text) schema['description'] = str(field.help_text)
self._map_field_validators(field, schema) self.map_field_validators(field, schema)
properties[field.field_name] = schema properties[field.field_name] = schema
@ -540,7 +542,7 @@ class AutoSchema(ViewInspector):
return result return result
def _map_field_validators(self, field, schema): def map_field_validators(self, field, schema):
""" """
map field validators map field validators
""" """
@ -578,7 +580,7 @@ class AutoSchema(ViewInspector):
schema['maximum'] = int(digits * '9') + 1 schema['maximum'] = int(digits * '9') + 1
schema['minimum'] = -schema['maximum'] schema['minimum'] = -schema['maximum']
def _get_paginator(self): def get_paginator(self):
pagination_class = getattr(self.view, 'pagination_class', None) pagination_class = getattr(self.view, 'pagination_class', None)
if pagination_class: if pagination_class:
return pagination_class() return pagination_class()
@ -596,7 +598,7 @@ class AutoSchema(ViewInspector):
media_types.append(renderer.media_type) media_types.append(renderer.media_type)
return media_types return media_types
def _get_serializer(self, path, method): def get_serializer(self, path, method):
view = self.view view = self.view
if not hasattr(view, 'get_serializer'): if not hasattr(view, 'get_serializer'):
@ -614,13 +616,13 @@ class AutoSchema(ViewInspector):
def _get_reference(self, serializer): def _get_reference(self, serializer):
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))} return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}
def _get_request_body(self, path, method): def get_request_body(self, path, method):
if method not in ('PUT', 'PATCH', 'POST'): if method not in ('PUT', 'PATCH', 'POST'):
return {} return {}
self.request_media_types = self.map_parsers(path, method) self.request_media_types = self.map_parsers(path, method)
serializer = self._get_serializer(path, method) serializer = self.get_serializer(path, method)
if not isinstance(serializer, serializers.Serializer): if not isinstance(serializer, serializers.Serializer):
item_schema = {} item_schema = {}
@ -634,8 +636,7 @@ class AutoSchema(ViewInspector):
} }
} }
def _get_responses(self, path, method): def get_responses(self, path, method):
# TODO: Handle multiple codes and pagination classes.
if method == 'DELETE': if method == 'DELETE':
return { return {
'204': { '204': {
@ -645,7 +646,7 @@ class AutoSchema(ViewInspector):
self.response_media_types = self.map_renderers(path, method) self.response_media_types = self.map_renderers(path, method)
serializer = self._get_serializer(path, method) serializer = self.get_serializer(path, method)
if not isinstance(serializer, serializers.Serializer): if not isinstance(serializer, serializers.Serializer):
item_schema = {} item_schema = {}
@ -657,7 +658,7 @@ class AutoSchema(ViewInspector):
'type': 'array', 'type': 'array',
'items': item_schema, 'items': item_schema,
} }
paginator = self._get_paginator() paginator = self.get_paginator()
if paginator: if paginator:
response_schema = paginator.get_paginated_response_schema(response_schema) response_schema = paginator.get_paginated_response_schema(response_schema)
else: else:
@ -688,3 +689,99 @@ class AutoSchema(ViewInspector):
path = path[1:] path = path[1:]
return [path.split('/')[0].replace('_', '-')] return [path.split('/')[0].replace('_', '-')]
def _get_path_parameters(self, path, method):
warnings.warn(
"Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.get_path_parameters(path, method)
def _get_filter_parameters(self, path, method):
warnings.warn(
"Method `_get_filter_parameters()` has been renamed to `get_filter_parameters()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.get_filter_parameters(path, method)
def _get_responses(self, path, method):
warnings.warn(
"Method `_get_responses()` has been renamed to `get_responses()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.get_responses(path, method)
def _get_request_body(self, path, method):
warnings.warn(
"Method `_get_request_body()` has been renamed to `get_request_body()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.get_request_body(path, method)
def _get_serializer(self, path, method):
warnings.warn(
"Method `_get_serializer()` has been renamed to `get_serializer()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.get_serializer(path, method)
def _get_paginator(self):
warnings.warn(
"Method `_get_paginator()` has been renamed to `get_paginator()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.get_paginator()
def _map_field_validators(self, field, schema):
warnings.warn(
"Method `_map_field_validators()` has been renamed to `map_field_validators()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.map_field_validators(field, schema)
def _map_serializer(self, serializer):
warnings.warn(
"Method `_map_serializer()` has been renamed to `map_serializer()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.map_serializer(serializer)
def _map_field(self, field):
warnings.warn(
"Method `_map_field()` has been renamed to `map_field()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.map_field(field)
def _map_choicefield(self, field):
warnings.warn(
"Method `_map_choicefield()` has been renamed to `map_choicefield()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.map_choicefield(field)
def _get_pagination_parameters(self, path, method):
warnings.warn(
"Method `_get_pagination_parameters()` has been renamed to `get_pagination_parameters()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.get_pagination_parameters(path, method)
def _allows_filters(self, path, method):
warnings.warn(
"Method `_allows_filters()` has been renamed to `allows_filters()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.allows_filters(path, method)

View File

@ -83,7 +83,7 @@ class TestFieldMapping(TestCase):
] ]
for field, mapping in cases: for field, mapping in cases:
with self.subTest(field=field): with self.subTest(field=field):
assert inspector._map_field(field) == mapping assert inspector.map_field(field) == mapping
def test_lazy_string_field(self): def test_lazy_string_field(self):
class ItemSerializer(serializers.Serializer): class ItemSerializer(serializers.Serializer):
@ -91,7 +91,7 @@ class TestFieldMapping(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
data = inspector._map_serializer(ItemSerializer()) data = inspector.map_serializer(ItemSerializer())
assert isinstance(data['properties']['text']['description'], str), "description must be str" assert isinstance(data['properties']['text']['description'], str), "description must be str"
def test_boolean_default_field(self): def test_boolean_default_field(self):
@ -102,7 +102,7 @@ class TestFieldMapping(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
data = inspector._map_serializer(Serializer()) data = inspector.map_serializer(Serializer())
assert data['properties']['default_true']['default'] is True, "default must be true" assert data['properties']['default_true']['default'] is True, "default must be true"
assert data['properties']['default_false']['default'] is False, "default must be false" assert data['properties']['default_false']['default'] is False, "default must be false"
assert 'default' not in data['properties']['without_default'], "default must not be defined" assert 'default' not in data['properties']['without_default'], "default must not be defined"
@ -202,7 +202,7 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
inspector.view = view inspector.view = view
request_body = inspector._get_request_body(path, method) request_body = inspector.get_request_body(path, method)
print(request_body) print(request_body)
assert request_body['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' assert request_body['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item'
@ -229,7 +229,7 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
inspector.view = view inspector.view = view
serializer = inspector._get_serializer(path, method) serializer = inspector.get_serializer(path, method)
with pytest.raises(Exception) as exc: with pytest.raises(Exception) as exc:
inspector.get_component_name(serializer) inspector.get_component_name(serializer)
@ -259,7 +259,7 @@ class TestOperationIntrospection(TestCase):
# there should be no empty 'required' property, see #6834 # there should be no empty 'required' property, see #6834
assert 'required' not in component assert 'required' not in component
for response in inspector._get_responses(path, method).values(): for response in inspector.get_responses(path, method).values():
assert 'required' not in component assert 'required' not in component
def test_empty_required_with_patch_method(self): def test_empty_required_with_patch_method(self):
@ -285,7 +285,7 @@ class TestOperationIntrospection(TestCase):
component = components['Item'] component = components['Item']
# there should be no empty 'required' property, see #6834 # there should be no empty 'required' property, see #6834
assert 'required' not in component assert 'required' not in component
for response in inspector._get_responses(path, method).values(): for response in inspector.get_responses(path, method).values():
assert 'required' not in component assert 'required' not in component
def test_response_body_generation(self): def test_response_body_generation(self):
@ -307,7 +307,7 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
inspector.view = view inspector.view = view
responses = inspector._get_responses(path, method) responses = inspector.get_responses(path, method)
assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item'
components = inspector.get_components(path, method) components = inspector.get_components(path, method)
@ -337,7 +337,7 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
inspector.view = view inspector.view = view
responses = inspector._get_responses(path, method) responses = inspector.get_responses(path, method)
assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item'
components = inspector.get_components(path, method) components = inspector.get_components(path, method)
assert components['Item'] assert components['Item']
@ -368,7 +368,7 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
inspector.view = view inspector.view = view
responses = inspector._get_responses(path, method) responses = inspector.get_responses(path, method)
assert responses == { assert responses == {
'200': { '200': {
'description': '', 'description': '',
@ -424,7 +424,7 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
inspector.view = view inspector.view = view
responses = inspector._get_responses(path, method) responses = inspector.get_responses(path, method)
assert responses == { assert responses == {
'200': { '200': {
'description': '', 'description': '',
@ -472,7 +472,7 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
inspector.view = view inspector.view = view
responses = inspector._get_responses(path, method) responses = inspector.get_responses(path, method)
assert responses == { assert responses == {
'204': { '204': {
'description': '', 'description': '',
@ -496,7 +496,7 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
inspector.view = view inspector.view = view
request_body = inspector._get_request_body(path, method) request_body = inspector.get_request_body(path, method)
assert len(request_body['content'].keys()) == 2 assert len(request_body['content'].keys()) == 2
assert 'multipart/form-data' in request_body['content'] assert 'multipart/form-data' in request_body['content']
@ -519,7 +519,7 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
inspector.view = view inspector.view = view
responses = inspector._get_responses(path, method) responses = inspector.get_responses(path, method)
# TODO this should be changed once the multiple response # TODO this should be changed once the multiple response
# schema support is there # schema support is there
success_response = responses['200'] success_response = responses['200']
@ -594,7 +594,7 @@ class TestOperationIntrospection(TestCase):
inspector = AutoSchema() inspector = AutoSchema()
inspector.view = view inspector.view = view
responses = inspector._get_responses(path, method) responses = inspector.get_responses(path, method)
assert responses == { assert responses == {
'200': { '200': {
'description': '', 'description': '',