Generate OpenAPI schema field types from validators. (#6674)

This commit is contained in:
Alan Crosswell 2019-06-09 08:42:56 -04:00 committed by Carlton Gibson
parent a63860fc8b
commit 2d65f82dd7
3 changed files with 204 additions and 4 deletions

View File

@ -1,10 +1,15 @@
import warnings
from django.core.validators import (
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
MinLengthValidator, MinValueValidator, RegexValidator, URLValidator
)
from django.db import models
from django.utils.encoding import force_text
from rest_framework import exceptions, serializers
from rest_framework.compat import uritemplate
from rest_framework.fields import empty
from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
@ -268,18 +273,76 @@ class AutoSchema(ViewInspector):
'format': 'date-time',
}
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
# see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
# see also: https://swagger.io/docs/specification/data-models/data-types/#string
if isinstance(field, serializers.EmailField):
return {
'type': 'string',
'format': 'email'
}
if isinstance(field, serializers.URLField):
return {
'type': 'string',
'format': 'uri'
}
if isinstance(field, serializers.UUIDField):
return {
'type': 'string',
'format': 'uuid'
}
if isinstance(field, serializers.IPAddressField):
content = {
'type': 'string',
}
if field.protocol != 'both':
content['format'] = field.protocol
return content
# DecimalField has multipleOf based on decimal_places
if isinstance(field, serializers.DecimalField):
content = {
'type': 'number'
}
if field.decimal_places:
content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
if field.max_whole_digits:
content['maximum'] = int(field.max_whole_digits * '9') + 1
content['minimum'] = -content['maximum']
self._map_min_max(field, content)
return content
if isinstance(field, serializers.FloatField):
content = {
'type': 'number'
}
self._map_min_max(field, content)
return content
if isinstance(field, serializers.IntegerField):
content = {
'type': 'integer'
}
self._map_min_max(field, content)
return content
# Simplest cases, default to 'string' type:
FIELD_CLASS_SCHEMA_TYPE = {
serializers.BooleanField: 'boolean',
serializers.DecimalField: 'number',
serializers.FloatField: 'number',
serializers.IntegerField: 'integer',
serializers.JSONField: 'object',
serializers.DictField: 'object',
}
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
def _map_min_max(self, field, content):
if field.max_value:
content['maximum'] = field.max_value
if field.min_value:
content['minimum'] = field.min_value
def _map_serializer(self, serializer):
# Assuming we have a valid serializer instance.
# TODO:
@ -303,6 +366,11 @@ class AutoSchema(ViewInspector):
schema['writeOnly'] = True
if field.allow_null:
schema['nullable'] = True
if field.default and field.default != empty: # why don't they use None?!
schema['default'] = field.default
if field.help_text:
schema['description'] = field.help_text
self._map_field_validators(field.validators, schema)
properties[field.field_name] = schema
return {
@ -310,6 +378,39 @@ class AutoSchema(ViewInspector):
'properties': properties,
}
def _map_field_validators(self, validators, schema):
"""
map field validators
:param list:validators: list of field validators
:param dict:schema: schema that the validators get added to
"""
for v in validators:
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
if isinstance(v, EmailValidator):
schema['format'] = 'email'
if isinstance(v, URLValidator):
schema['format'] = 'uri'
if isinstance(v, RegexValidator):
schema['pattern'] = v.regex.pattern
elif isinstance(v, MaxLengthValidator):
schema['maxLength'] = v.limit_value
elif isinstance(v, MinLengthValidator):
schema['minLength'] = v.limit_value
elif isinstance(v, MaxValueValidator):
schema['maximum'] = v.limit_value
elif isinstance(v, MinValueValidator):
schema['minimum'] = v.limit_value
elif isinstance(v, DecimalValidator):
if v.decimal_places:
schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
if v.max_digits:
digits = v.max_digits
if v.decimal_places is not None and v.decimal_places > 0:
digits -= v.decimal_places
schema['maximum'] = int(digits * '9') + 1
schema['minimum'] = -schema['maximum']
def _get_request_body(self, path, method):
view = self.view

View File

@ -257,3 +257,53 @@ class TestGenerator(TestCase):
assert response_schema['date']['format'] == 'date'
assert response_schema['datetime']['format'] == 'date-time'
def test_serializer_validators(self):
patterns = [
url(r'^example/?$', views.ExampleValdidatedAPIView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
response = schema['paths']['/example/']['get']['responses']
response_schema = response['200']['content']['application/json']['schema']['properties']
assert response_schema['integer']['type'] == 'integer'
assert response_schema['integer']['maximum'] == 99
assert response_schema['integer']['minimum'] == -11
assert response_schema['string']['minLength'] == 2
assert response_schema['string']['maxLength'] == 10
assert response_schema['regex']['pattern'] == r'[ABC]12{3}'
assert response_schema['regex']['description'] == 'must have an A, B, or C followed by 1222'
assert response_schema['decimal1']['type'] == 'number'
assert response_schema['decimal1']['multipleOf'] == .01
assert response_schema['decimal1']['maximum'] == 10000
assert response_schema['decimal1']['minimum'] == -10000
assert response_schema['decimal2']['type'] == 'number'
assert response_schema['decimal2']['multipleOf'] == .0001
assert response_schema['email']['type'] == 'string'
assert response_schema['email']['format'] == 'email'
assert response_schema['email']['default'] == 'foo@bar.com'
assert response_schema['url']['type'] == 'string'
assert response_schema['url']['nullable'] is True
assert response_schema['url']['default'] == 'http://www.example.com'
assert response_schema['uuid']['type'] == 'string'
assert response_schema['uuid']['format'] == 'uuid'
assert response_schema['ip4']['type'] == 'string'
assert response_schema['ip4']['format'] == 'ipv4'
assert response_schema['ip6']['type'] == 'string'
assert response_schema['ip6']['format'] == 'ipv6'
assert response_schema['ip']['type'] == 'string'
assert 'format' not in response_schema['ip']

View File

@ -1,3 +1,10 @@
import uuid
from django.core.validators import (
DecimalValidator, MaxLengthValidator, MaxValueValidator,
MinLengthValidator, MinValueValidator, RegexValidator
)
from rest_framework import generics, permissions, serializers
from rest_framework.decorators import action
from rest_framework.response import Response
@ -56,3 +63,45 @@ class ExampleGenericViewSet(GenericViewSet):
@action(detail=False)
def old(self, *args, **kwargs):
pass
# Validators and/or equivalent Field attributes.
class ExampleValidatedSerializer(serializers.Serializer):
integer = serializers.IntegerField(
validators=(
MaxValueValidator(limit_value=99),
MinValueValidator(limit_value=-11),
)
)
string = serializers.CharField(
validators=(
MaxLengthValidator(limit_value=10),
MinLengthValidator(limit_value=2),
)
)
regex = serializers.CharField(
validators=(
RegexValidator(regex=r'[ABC]12{3}'),
),
help_text='must have an A, B, or C followed by 1222'
)
decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2)
decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0,
validators=(DecimalValidator(max_digits=17, decimal_places=4),))
email = serializers.EmailField(default='foo@bar.com')
url = serializers.URLField(default='http://www.example.com', allow_null=True)
uuid = serializers.UUIDField()
ip4 = serializers.IPAddressField(protocol='ipv4')
ip6 = serializers.IPAddressField(protocol='ipv6')
ip = serializers.IPAddressField()
class ExampleValdidatedAPIView(generics.GenericAPIView):
serializer_class = ExampleValidatedSerializer
def get(self, *args, **kwargs):
serializer = self.get_serializer(integer=33, string='hello', regex='foo', decimal1=3.55,
decimal2=5.33, email='a@b.co',
url='http://localhost', uuid=uuid.uuid4(), ip4='127.0.0.1', ip6='::1',
ip='192.168.1.1')
return Response(serializer.data)