From 8d4e35d24abd24e1b19d8b0515091c852d6aed03 Mon Sep 17 00:00:00 2001 From: Alan Crosswell Date: Tue, 14 May 2019 15:58:33 -0500 Subject: [PATCH] generate openapa schema field validators --- rest_framework/schemas/openapi.py | 111 ++++++++++++++++++++++++++++-- tests/schemas/test_openapi.py | 50 ++++++++++++++ tests/schemas/views.py | 49 +++++++++++++ 3 files changed, 206 insertions(+), 4 deletions(-) diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 44b281be8..b039c749b 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -1,10 +1,15 @@ import warnings +from django.core.validators import ( + DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator, + MinLengthValidator, MinValueValidator, RegexValidator, URLValidator +) from django.db import models from django.utils.encoding import force_text from rest_framework import exceptions, serializers from rest_framework.compat import uritemplate +from rest_framework.fields import empty from .generators import BaseSchemaGenerator from .inspectors import ViewInspector @@ -266,18 +271,76 @@ class AutoSchema(ViewInspector): 'format': 'date-time', } + # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification." + # see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types + # see also: https://swagger.io/docs/specification/data-models/data-types/#string + if isinstance(field, serializers.EmailField): + return { + 'type': 'string', + 'format': 'email' + } + + if isinstance(field, serializers.URLField): + return { + 'type': 'string', + 'format': 'uri' + } + + if isinstance(field, serializers.UUIDField): + return { + 'type': 'string', + 'format': 'uuid' + } + + if isinstance(field, serializers.IPAddressField): + content = { + 'type': 'string', + } + if field.protocol != 'both': + content['format'] = field.protocol + return content + + # DecimalField has multipleOf based on decimal_places + if isinstance(field, serializers.DecimalField): + content = { + 'type': 'number' + } + if field.decimal_places: + content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1') + if field.max_whole_digits: + content['maximum'] = int(field.max_whole_digits * '9') + 1 + content['minimum'] = -content['maximum'] + self._map_min_max(field, content) + return content + + if isinstance(field, serializers.FloatField): + content = { + 'type': 'number' + } + self._map_min_max(field, content) + return content + + if isinstance(field, serializers.IntegerField): + content = { + 'type': 'integer' + } + self._map_min_max(field, content) + return content + # Simplest cases, default to 'string' type: FIELD_CLASS_SCHEMA_TYPE = { serializers.BooleanField: 'boolean', - serializers.DecimalField: 'number', - serializers.FloatField: 'number', - serializers.IntegerField: 'integer', - serializers.JSONField: 'object', serializers.DictField: 'object', } return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')} + def _map_min_max(self, field, content): + if field.max_value: + content['maximum'] = field.max_value + if field.min_value: + content['minimum'] = field.min_value + def _map_serializer(self, serializer): # Assuming we have a valid serializer instance. # TODO: @@ -301,6 +364,11 @@ class AutoSchema(ViewInspector): schema['writeOnly'] = True if field.allow_null: schema['nullable'] = True + if field.default and field.default != empty: # why don't they use None?! + schema['default'] = field.default + if field.help_text: + schema['description'] = field.help_text + self._map_field_validators(field.validators, schema) properties[field.field_name] = schema return { @@ -308,6 +376,41 @@ class AutoSchema(ViewInspector): 'properties': properties, } + def _map_field_validators(self, validators, schema): + """ + map field validators + :param list:validators: list of field validators + :param dict:schema: schema that the validators get added to + """ + for v in validators: + # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification." + # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types + if isinstance(v, EmailValidator): + schema['format'] = 'email' + if isinstance(v, URLValidator): + schema['format'] = 'uri' + if isinstance(v, URLValidator): + schema['format'] = 'uri' + if isinstance(v, RegexValidator): + schema['pattern'] = v.regex.pattern + elif isinstance(v, MaxLengthValidator): + schema['maxLength'] = v.limit_value + elif isinstance(v, MinLengthValidator): + schema['minLength'] = v.limit_value + elif isinstance(v, MaxValueValidator): + schema['maximum'] = v.limit_value + elif isinstance(v, MinValueValidator): + schema['minimum'] = v.limit_value + elif isinstance(v, DecimalValidator): + if v.decimal_places: + schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1') + if v.max_digits: + digits = v.max_digits + if v.decimal_places is not None and v.decimal_places > 0: + digits -= v.decimal_places + schema['maximum'] = int(digits * '9') + 1 + schema['minimum'] = -schema['maximum'] + def _get_request_body(self, path, method): view = self.view diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index 2ddf54f01..099d397b1 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -243,3 +243,53 @@ class TestGenerator(TestCase): assert response_schema['date']['format'] == 'date' assert response_schema['datetime']['format'] == 'date-time' + + def test_serializer_validators(self): + patterns = [ + url(r'^example/?$', views.ExampleValdidatedAPIView.as_view()), + ] + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + + response = schema['paths']['/example/']['get']['responses'] + response_schema = response['200']['content']['application/json']['schema']['properties'] + + assert response_schema['integer']['type'] == 'integer' + assert response_schema['integer']['maximum'] == 99 + assert response_schema['integer']['minimum'] == -11 + + assert response_schema['string']['minLength'] == 2 + assert response_schema['string']['maxLength'] == 10 + + assert response_schema['regex']['pattern'] == r'[ABC]12{3}' + assert response_schema['regex']['description'] == 'must have an A, B, or C followed by 1222' + + assert response_schema['decimal1']['type'] == 'number' + assert response_schema['decimal1']['multipleOf'] == .01 + assert response_schema['decimal1']['maximum'] == 10000 + assert response_schema['decimal1']['minimum'] == -10000 + + assert response_schema['decimal2']['type'] == 'number' + assert response_schema['decimal2']['multipleOf'] == .0001 + + assert response_schema['email']['type'] == 'string' + assert response_schema['email']['format'] == 'email' + assert response_schema['email']['default'] == 'foo@bar.com' + + assert response_schema['url']['type'] == 'string' + assert response_schema['url']['nullable'] is True + assert response_schema['url']['default'] == 'http://www.example.com' + + assert response_schema['uuid']['type'] == 'string' + assert response_schema['uuid']['format'] == 'uuid' + + assert response_schema['ip4']['type'] == 'string' + assert response_schema['ip4']['format'] == 'ipv4' + + assert response_schema['ip6']['type'] == 'string' + assert response_schema['ip6']['format'] == 'ipv6' + + assert response_schema['ip']['type'] == 'string' + assert 'format' not in response_schema['ip'] diff --git a/tests/schemas/views.py b/tests/schemas/views.py index dc0d6065b..7a920f33e 100644 --- a/tests/schemas/views.py +++ b/tests/schemas/views.py @@ -1,3 +1,10 @@ +import uuid + +from django.core.validators import ( + DecimalValidator, MaxLengthValidator, MaxValueValidator, + MinLengthValidator, MinValueValidator, RegexValidator +) + from rest_framework import generics, permissions, serializers from rest_framework.decorators import action from rest_framework.response import Response @@ -56,3 +63,45 @@ class ExampleGenericViewSet(GenericViewSet): @action(detail=False) def old(self, *args, **kwargs): pass + + +# Validators and/or equivalent Field attributes. +class ExampleValidatedSerializer(serializers.Serializer): + integer = serializers.IntegerField( + validators=( + MaxValueValidator(limit_value=99), + MinValueValidator(limit_value=-11), + ) + ) + string = serializers.CharField( + validators=( + MaxLengthValidator(limit_value=10), + MinLengthValidator(limit_value=2), + ) + ) + regex = serializers.CharField( + validators=( + RegexValidator(regex=r'[ABC]12{3}'), + ), + help_text='must have an A, B, or C followed by 1222' + ) + decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2) + decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0, + validators=(DecimalValidator(max_digits=17, decimal_places=4),)) + email = serializers.EmailField(default='foo@bar.com') + url = serializers.URLField(default='http://www.example.com', allow_null=True) + uuid = serializers.UUIDField() + ip4 = serializers.IPAddressField(protocol='ipv4') + ip6 = serializers.IPAddressField(protocol='ipv6') + ip = serializers.IPAddressField() + + +class ExampleValdidatedAPIView(generics.GenericAPIView): + serializer_class = ExampleValidatedSerializer + + def get(self, *args, **kwargs): + serializer = self.get_serializer(integer=33, string='hello', regex='foo', decimal1=3.55, + decimal2=5.33, email='a@b.co', + url='http://localhost', uuid=uuid.uuid4(), ip4='127.0.0.1', ip6='::1', + ip='192.168.1.1') + return Response(serializer.data)