From 1535ee2a3c0f861a040c2599c3e4755f394a1747 Mon Sep 17 00:00:00 2001 From: Thorsten Franzel Date: Fri, 13 Dec 2019 23:04:24 +0100 Subject: [PATCH] improvement: add operationId override to @extend_schema; fix tags --- rest_framework/schemas/openapi.py | 47 ++++++++++++------------- rest_framework/schemas/openapi_utils.py | 9 ++++- tests/schemas/test_openapi.py | 6 ++-- 3 files changed, 33 insertions(+), 29 deletions(-) diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index f0d133173..9353ac333 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -141,7 +141,7 @@ class AutoSchema(ViewInspector): def get_operation(self, path, method): operation = {} - operation['operationId'] = self._get_operation_id(path, method) + operation['operationId'] = self.get_operation_id(path, method) operation['description'] = self.get_description(path, method) operation['parameters'] = sorted( [ @@ -207,38 +207,35 @@ class AutoSchema(ViewInspector): def get_tags(self, path, method): """ override this for custom behaviour """ - path = re.sub( - pattern=api_settings.SCHEMA_PATH_PREFIX, - repl='', - string=path, - flags=re.IGNORECASE - ).split('/') - return [path[0]] + tokenized_path = self._tokenize_path(path) + # use first non-parameter path part as tag + return tokenized_path[:1] - def _get_operation_id(self, path, method): - """ - Compute an operation ID from the model, serializer or view name. - """ - # remove path prefix - sub_path = re.sub( - pattern=api_settings.SCHEMA_PATH_PREFIX, - repl='', - string=path, - flags=re.IGNORECASE - ) - # cleanup, normalize and tokenize remaining parts. + def get_operation_id(self, path, method): + """ override this for custom behaviour """ + tokenized_path = self._tokenize_path(path) # replace dashes as they can be problematic later in code generation - sub_path = sub_path.replace('-', '_').rstrip('/').lstrip('/') - sub_path = sub_path.split('/') if sub_path else [] - # remove path variables - sub_path = [p for p in sub_path if not p.startswith('{')] + tokenized_path = [t.replace('-', '_') for t in tokenized_path] if is_list_view(path, method, self.view): action = 'list' else: action = self.method_mapping[method.lower()] - return '_'.join(sub_path + [action]) + return '_'.join(tokenized_path + [action]) + + def _tokenize_path(self, path): + # remove path prefix + path = re.sub( + pattern=api_settings.SCHEMA_PATH_PREFIX, + repl='', + string=path, + flags=re.IGNORECASE + ) + # cleanup and tokenize remaining parts. + path = path.rstrip('/').lstrip('/').split('/') + # remove path variables and empty tokens + return [t for t in path if t and not t.startswith('{')] def _get_path_parameters(self, path, method): """ diff --git a/rest_framework/schemas/openapi_utils.py b/rest_framework/schemas/openapi_utils.py index 8b64ee86a..7336268b9 100644 --- a/rest_framework/schemas/openapi_utils.py +++ b/rest_framework/schemas/openapi_utils.py @@ -92,6 +92,7 @@ class QueryParameter(OpenApiSchemaBase): def extend_schema( operation=None, + operation_id=None, extra_parameters=None, responses=None, request=None, @@ -117,10 +118,16 @@ def extend_schema( return operation return super().get_operation(path, method) + def get_operation_id(self, path, method): + if operation_id: + return operation_id + return super().get_operation_id(path, method) + def get_extra_parameters(self, path, method): if extra_parameters: return [ - p.to_schema() if isinstance(p, OpenApiSchemaBase) else p for p in extra_parameters + p.to_schema() if isinstance(p, OpenApiSchemaBase) else p + for p in extra_parameters ] return super().get_extra_parameters(path, method) diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index 277c12769..faec64c8e 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -90,7 +90,7 @@ class TestOperationIntrospection(TestCase): 'operationId': 'example_list', 'description': 'get: A description of my GET operation.\npost: A description of my POST operation.', 'parameters': [], - 'tags': [''], + 'tags': ['example'], 'security': [{'cookieAuth': []}, {'basicAuth': []}, {}], 'responses': { '200': { @@ -138,7 +138,7 @@ class TestOperationIntrospection(TestCase): } } ], - 'tags': [''], + 'tags': ['example'], 'security': [{'cookieAuth': []}, {'basicAuth': []}, {}], 'responses': { '200': { @@ -555,7 +555,7 @@ class TestOperationIntrospection(TestCase): inspector.view = view inspector.init(ComponentRegistry()) - operationId = inspector._get_operation_id(path, method) + operationId = inspector.get_operation_id(path, method) assert operationId == 'list' def test_repeat_operation_ids(self):