From b29445a196da78a4e98b5046a6ddb8058f0677ee Mon Sep 17 00:00:00 2001 From: Martin Desrumaux Date: Fri, 28 Feb 2020 17:45:40 +0100 Subject: [PATCH] feat(openapi/operationId): Warn user about duplicate operationId in the schema --- rest_framework/schemas/openapi.py | 28 ++++++++++++++++++++++++++++ tests/schemas/test_openapi.py | 19 +++++++++++++++++++ tests/schemas/views.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+) diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 5277f17a6..ceec4517e 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -34,6 +34,32 @@ class SchemaGenerator(BaseSchemaGenerator): return info + def check_duplicate_operation_id(self, paths): + ids = {} + for route in paths: + for method in paths[route]: + if 'operationId' not in paths[route][method]: + continue + operation_id = paths[route][method]['operationId'] + if operation_id in ids: + warnings.warn( + 'You have a duplicated operationId in your OpenAPI schema: {operation_id}\n' + '\tRoute: {route1}, Method: {method1}\n' + '\tRoute: {route2}, Method: {method2}\n' + '\tAn operationId has to be unique accros your schema. Your schema may not work in other tools.' + .format( + route1=ids[operation_id]['route'], + method1=ids[operation_id]['method'], + route2=route, + method2=method, + operation_id=operation_id + ) + ) + ids[operation_id] = { + 'route': route, + 'method': method + } + def get_schema(self, request=None, public=False): """ Generate a OpenAPI schema. @@ -57,6 +83,8 @@ class SchemaGenerator(BaseSchemaGenerator): paths.setdefault(path, {}) paths[path][method.lower()] = operation + self.check_duplicate_operation_id(paths) + # Compile final schema. schema = { 'openapi': '3.0.2', diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index 7f73c8c30..9ce281d9a 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -1,4 +1,5 @@ import uuid +import warnings import pytest from django.conf.urls import url @@ -592,6 +593,24 @@ class TestOperationIntrospection(TestCase): assert schema_str.count("newExample") == 1 assert schema_str.count("oldExample") == 1 + def test_duplicate_operation_id(self): + patterns = [ + url(r'^duplicate1/?$', views.ExampleOperationIdDuplicate1.as_view()), + url(r'^duplicate2/?$', views.ExampleOperationIdDuplicate2.as_view()), + ] + + generator = SchemaGenerator(patterns=patterns) + request = create_request('/') + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + generator.get_schema(request=request) + + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + print(str(w[-1].message)) + assert 'You have a duplicated operationId' in str(w[-1].message) + def test_serializer_datefield(self): path = '/' method = 'GET' diff --git a/tests/schemas/views.py b/tests/schemas/views.py index e8307ccbd..5835a5572 100644 --- a/tests/schemas/views.py +++ b/tests/schemas/views.py @@ -4,6 +4,7 @@ from django.core.validators import ( DecimalValidator, MaxLengthValidator, MaxValueValidator, MinLengthValidator, MinValueValidator, RegexValidator ) +from django.db import models from rest_framework import generics, permissions, serializers from rest_framework.decorators import action @@ -137,3 +138,32 @@ class ExampleValidatedAPIView(generics.GenericAPIView): url='http://localhost', uuid=uuid.uuid4(), ip4='127.0.0.1', ip6='::1', ip='192.168.1.1') return Response(serializer.data) + + +# Serializer with model. +class OpenAPIExample(models.Model): + first_name = models.CharField(max_length=30) + + +class ExampleSerializerModel(serializers.Serializer): + date = serializers.DateField() + datetime = serializers.DateTimeField() + hstore = serializers.HStoreField() + uuid_field = serializers.UUIDField(default=uuid.uuid4) + + class Meta: + model = OpenAPIExample + + +class ExampleOperationIdDuplicate1(generics.GenericAPIView): + serializer_class = ExampleSerializerModel + + def get(self, *args, **kwargs): + pass + + +class ExampleOperationIdDuplicate2(generics.GenericAPIView): + serializer_class = ExampleSerializerModel + + def get(self, *args, **kwargs): + pass