From b2497fc2456c607a3c639ed2355c28dac672a70f Mon Sep 17 00:00:00 2001 From: Carlton Gibson Date: Mon, 6 Apr 2020 17:03:10 +0200 Subject: [PATCH] Convert openapi.AutoSchema methods to public API. --- rest_framework/schemas/openapi.py | 169 +++++++++++++++++++++++------- tests/schemas/test_openapi.py | 30 +++--- 2 files changed, 148 insertions(+), 51 deletions(-) diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 7af013444..9b3082822 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -12,7 +12,9 @@ from django.core.validators import ( from django.db import models from django.utils.encoding import force_str -from rest_framework import exceptions, renderers, serializers +from rest_framework import ( + RemovedInDRF314Warning, exceptions, renderers, serializers +) from rest_framework.compat import uritemplate from rest_framework.fields import _UnvalidatedField, empty from rest_framework.settings import api_settings @@ -146,15 +148,15 @@ class AutoSchema(ViewInspector): operation['description'] = self.get_description(path, method) parameters = [] - parameters += self._get_path_parameters(path, method) - parameters += self._get_pagination_parameters(path, method) - parameters += self._get_filter_parameters(path, method) + parameters += self.get_path_parameters(path, method) + parameters += self.get_pagination_parameters(path, method) + parameters += self.get_filter_parameters(path, method) operation['parameters'] = parameters - request_body = self._get_request_body(path, method) + request_body = self.get_request_body(path, method) if request_body: operation['requestBody'] = request_body - operation['responses'] = self._get_responses(path, method) + operation['responses'] = self.get_responses(path, method) operation['tags'] = self.get_tags(path, method) return operation @@ -190,14 +192,14 @@ class AutoSchema(ViewInspector): if method.lower() == 'delete': return {} - serializer = self._get_serializer(path, method) + serializer = self.get_serializer(path, method) if not isinstance(serializer, serializers.Serializer): return {} component_name = self.get_component_name(serializer) - content = self._map_serializer(serializer) + content = self.map_serializer(serializer) return {component_name: content} def _to_camel_case(self, snake_str): @@ -220,8 +222,8 @@ class AutoSchema(ViewInspector): name = model.__name__ # Try with the serializer class name - elif self._get_serializer(path, method) is not None: - name = self._get_serializer(path, method).__class__.__name__ + elif self.get_serializer(path, method) is not None: + name = self.get_serializer(path, method).__class__.__name__ if name.endswith('Serializer'): name = name[:-10] @@ -259,7 +261,7 @@ class AutoSchema(ViewInspector): return action + name - def _get_path_parameters(self, path, method): + def get_path_parameters(self, path, method): """ Return a list of parameters from templated path variables. """ @@ -295,15 +297,15 @@ class AutoSchema(ViewInspector): return parameters - def _get_filter_parameters(self, path, method): - if not self._allows_filters(path, method): + def get_filter_parameters(self, path, method): + if not self.allows_filters(path, method): return [] parameters = [] for filter_backend in self.view.filter_backends: parameters += filter_backend().get_schema_operation_parameters(self.view) return parameters - def _allows_filters(self, path, method): + def allows_filters(self, path, method): """ Determine whether to include filter Fields in schema. @@ -316,19 +318,19 @@ class AutoSchema(ViewInspector): return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] return method.lower() in ["get", "put", "patch", "delete"] - def _get_pagination_parameters(self, path, method): + def get_pagination_parameters(self, path, method): view = self.view if not is_list_view(path, method, view): return [] - paginator = self._get_paginator() + paginator = self.get_paginator() if not paginator: return [] return paginator.get_schema_operation_parameters(view) - def _map_choicefield(self, field): + def map_choicefield(self, field): choices = list(OrderedDict.fromkeys(field.choices)) # preserve order and remove duplicates if all(isinstance(choice, bool) for choice in choices): type = 'boolean' @@ -356,16 +358,16 @@ class AutoSchema(ViewInspector): mapping['type'] = type return mapping - def _map_field(self, field): + def map_field(self, field): # Nested Serializers, `many` or not. if isinstance(field, serializers.ListSerializer): return { 'type': 'array', - 'items': self._map_serializer(field.child) + 'items': self.map_serializer(field.child) } if isinstance(field, serializers.Serializer): - data = self._map_serializer(field) + data = self.map_serializer(field) data['type'] = 'object' return data @@ -373,7 +375,7 @@ class AutoSchema(ViewInspector): if isinstance(field, serializers.ManyRelatedField): return { 'type': 'array', - 'items': self._map_field(field.child_relation) + 'items': self.map_field(field.child_relation) } if isinstance(field, serializers.PrimaryKeyRelatedField): model = getattr(field.queryset, 'model', None) @@ -389,11 +391,11 @@ class AutoSchema(ViewInspector): if isinstance(field, serializers.MultipleChoiceField): return { 'type': 'array', - 'items': self._map_choicefield(field) + 'items': self.map_choicefield(field) } if isinstance(field, serializers.ChoiceField): - return self._map_choicefield(field) + return self.map_choicefield(field) # ListField. if isinstance(field, serializers.ListField): @@ -402,7 +404,7 @@ class AutoSchema(ViewInspector): 'items': {}, } if not isinstance(field.child, _UnvalidatedField): - mapping['items'] = self._map_field(field.child) + mapping['items'] = self.map_field(field.child) return mapping # DateField and DateTimeField type is string @@ -504,7 +506,7 @@ class AutoSchema(ViewInspector): if field.min_value: content['minimum'] = field.min_value - def _map_serializer(self, serializer): + def map_serializer(self, serializer): # Assuming we have a valid serializer instance. required = [] properties = {} @@ -516,7 +518,7 @@ class AutoSchema(ViewInspector): if field.required: required.append(field.field_name) - schema = self._map_field(field) + schema = self.map_field(field) if field.read_only: schema['readOnly'] = True if field.write_only: @@ -527,7 +529,7 @@ class AutoSchema(ViewInspector): schema['default'] = field.default if field.help_text: schema['description'] = str(field.help_text) - self._map_field_validators(field, schema) + self.map_field_validators(field, schema) properties[field.field_name] = schema @@ -540,7 +542,7 @@ class AutoSchema(ViewInspector): return result - def _map_field_validators(self, field, schema): + def map_field_validators(self, field, schema): """ map field validators """ @@ -578,7 +580,7 @@ class AutoSchema(ViewInspector): schema['maximum'] = int(digits * '9') + 1 schema['minimum'] = -schema['maximum'] - def _get_paginator(self): + def get_paginator(self): pagination_class = getattr(self.view, 'pagination_class', None) if pagination_class: return pagination_class() @@ -596,7 +598,7 @@ class AutoSchema(ViewInspector): media_types.append(renderer.media_type) return media_types - def _get_serializer(self, path, method): + def get_serializer(self, path, method): view = self.view if not hasattr(view, 'get_serializer'): @@ -614,13 +616,13 @@ class AutoSchema(ViewInspector): def _get_reference(self, serializer): return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))} - def _get_request_body(self, path, method): + def get_request_body(self, path, method): if method not in ('PUT', 'PATCH', 'POST'): return {} self.request_media_types = self.map_parsers(path, method) - serializer = self._get_serializer(path, method) + serializer = self.get_serializer(path, method) if not isinstance(serializer, serializers.Serializer): item_schema = {} @@ -634,8 +636,7 @@ class AutoSchema(ViewInspector): } } - def _get_responses(self, path, method): - # TODO: Handle multiple codes and pagination classes. + def get_responses(self, path, method): if method == 'DELETE': return { '204': { @@ -645,7 +646,7 @@ class AutoSchema(ViewInspector): self.response_media_types = self.map_renderers(path, method) - serializer = self._get_serializer(path, method) + serializer = self.get_serializer(path, method) if not isinstance(serializer, serializers.Serializer): item_schema = {} @@ -657,7 +658,7 @@ class AutoSchema(ViewInspector): 'type': 'array', 'items': item_schema, } - paginator = self._get_paginator() + paginator = self.get_paginator() if paginator: response_schema = paginator.get_paginated_response_schema(response_schema) else: @@ -688,3 +689,99 @@ class AutoSchema(ViewInspector): path = path[1:] return [path.split('/')[0].replace('_', '-')] + + def _get_path_parameters(self, path, method): + warnings.warn( + "Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.get_path_parameters(path, method) + + def _get_filter_parameters(self, path, method): + warnings.warn( + "Method `_get_filter_parameters()` has been renamed to `get_filter_parameters()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.get_filter_parameters(path, method) + + def _get_responses(self, path, method): + warnings.warn( + "Method `_get_responses()` has been renamed to `get_responses()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.get_responses(path, method) + + def _get_request_body(self, path, method): + warnings.warn( + "Method `_get_request_body()` has been renamed to `get_request_body()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.get_request_body(path, method) + + def _get_serializer(self, path, method): + warnings.warn( + "Method `_get_serializer()` has been renamed to `get_serializer()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.get_serializer(path, method) + + def _get_paginator(self): + warnings.warn( + "Method `_get_paginator()` has been renamed to `get_paginator()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.get_paginator() + + def _map_field_validators(self, field, schema): + warnings.warn( + "Method `_map_field_validators()` has been renamed to `map_field_validators()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.map_field_validators(field, schema) + + def _map_serializer(self, serializer): + warnings.warn( + "Method `_map_serializer()` has been renamed to `map_serializer()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.map_serializer(serializer) + + def _map_field(self, field): + warnings.warn( + "Method `_map_field()` has been renamed to `map_field()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.map_field(field) + + def _map_choicefield(self, field): + warnings.warn( + "Method `_map_choicefield()` has been renamed to `map_choicefield()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.map_choicefield(field) + + def _get_pagination_parameters(self, path, method): + warnings.warn( + "Method `_get_pagination_parameters()` has been renamed to `get_pagination_parameters()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.get_pagination_parameters(path, method) + + def _allows_filters(self, path, method): + warnings.warn( + "Method `_allows_filters()` has been renamed to `allows_filters()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.allows_filters(path, method) diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index 774636972..0e86a7f50 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -83,7 +83,7 @@ class TestFieldMapping(TestCase): ] for field, mapping in cases: with self.subTest(field=field): - assert inspector._map_field(field) == mapping + assert inspector.map_field(field) == mapping def test_lazy_string_field(self): class ItemSerializer(serializers.Serializer): @@ -91,7 +91,7 @@ class TestFieldMapping(TestCase): inspector = AutoSchema() - data = inspector._map_serializer(ItemSerializer()) + data = inspector.map_serializer(ItemSerializer()) assert isinstance(data['properties']['text']['description'], str), "description must be str" def test_boolean_default_field(self): @@ -102,7 +102,7 @@ class TestFieldMapping(TestCase): inspector = AutoSchema() - data = inspector._map_serializer(Serializer()) + data = inspector.map_serializer(Serializer()) assert data['properties']['default_true']['default'] is True, "default must be true" assert data['properties']['default_false']['default'] is False, "default must be false" assert 'default' not in data['properties']['without_default'], "default must not be defined" @@ -202,7 +202,7 @@ class TestOperationIntrospection(TestCase): inspector = AutoSchema() inspector.view = view - request_body = inspector._get_request_body(path, method) + request_body = inspector.get_request_body(path, method) print(request_body) assert request_body['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' @@ -229,7 +229,7 @@ class TestOperationIntrospection(TestCase): inspector = AutoSchema() inspector.view = view - serializer = inspector._get_serializer(path, method) + serializer = inspector.get_serializer(path, method) with pytest.raises(Exception) as exc: inspector.get_component_name(serializer) @@ -259,7 +259,7 @@ class TestOperationIntrospection(TestCase): # there should be no empty 'required' property, see #6834 assert 'required' not in component - for response in inspector._get_responses(path, method).values(): + for response in inspector.get_responses(path, method).values(): assert 'required' not in component def test_empty_required_with_patch_method(self): @@ -285,7 +285,7 @@ class TestOperationIntrospection(TestCase): component = components['Item'] # there should be no empty 'required' property, see #6834 assert 'required' not in component - for response in inspector._get_responses(path, method).values(): + for response in inspector.get_responses(path, method).values(): assert 'required' not in component def test_response_body_generation(self): @@ -307,7 +307,7 @@ class TestOperationIntrospection(TestCase): inspector = AutoSchema() inspector.view = view - responses = inspector._get_responses(path, method) + responses = inspector.get_responses(path, method) assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' components = inspector.get_components(path, method) @@ -337,7 +337,7 @@ class TestOperationIntrospection(TestCase): inspector = AutoSchema() inspector.view = view - responses = inspector._get_responses(path, method) + responses = inspector.get_responses(path, method) assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' components = inspector.get_components(path, method) assert components['Item'] @@ -368,7 +368,7 @@ class TestOperationIntrospection(TestCase): inspector = AutoSchema() inspector.view = view - responses = inspector._get_responses(path, method) + responses = inspector.get_responses(path, method) assert responses == { '200': { 'description': '', @@ -424,7 +424,7 @@ class TestOperationIntrospection(TestCase): inspector = AutoSchema() inspector.view = view - responses = inspector._get_responses(path, method) + responses = inspector.get_responses(path, method) assert responses == { '200': { 'description': '', @@ -472,7 +472,7 @@ class TestOperationIntrospection(TestCase): inspector = AutoSchema() inspector.view = view - responses = inspector._get_responses(path, method) + responses = inspector.get_responses(path, method) assert responses == { '204': { 'description': '', @@ -496,7 +496,7 @@ class TestOperationIntrospection(TestCase): inspector = AutoSchema() inspector.view = view - request_body = inspector._get_request_body(path, method) + request_body = inspector.get_request_body(path, method) assert len(request_body['content'].keys()) == 2 assert 'multipart/form-data' in request_body['content'] @@ -519,7 +519,7 @@ class TestOperationIntrospection(TestCase): inspector = AutoSchema() inspector.view = view - responses = inspector._get_responses(path, method) + responses = inspector.get_responses(path, method) # TODO this should be changed once the multiple response # schema support is there success_response = responses['200'] @@ -594,7 +594,7 @@ class TestOperationIntrospection(TestCase): inspector = AutoSchema() inspector.view = view - responses = inspector._get_responses(path, method) + responses = inspector.get_responses(path, method) assert responses == { '200': { 'description': '',