mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-10 19:56:59 +03:00
Generate OpenAPI schema field types from validators. (#6674)
This commit is contained in:
parent
a63860fc8b
commit
2d65f82dd7
|
@ -1,10 +1,15 @@
|
|||
import warnings
|
||||
|
||||
from django.core.validators import (
|
||||
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
|
||||
MinLengthValidator, MinValueValidator, RegexValidator, URLValidator
|
||||
)
|
||||
from django.db import models
|
||||
from django.utils.encoding import force_text
|
||||
|
||||
from rest_framework import exceptions, serializers
|
||||
from rest_framework.compat import uritemplate
|
||||
from rest_framework.fields import empty
|
||||
|
||||
from .generators import BaseSchemaGenerator
|
||||
from .inspectors import ViewInspector
|
||||
|
@ -268,18 +273,76 @@ class AutoSchema(ViewInspector):
|
|||
'format': 'date-time',
|
||||
}
|
||||
|
||||
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
|
||||
# see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
|
||||
# see also: https://swagger.io/docs/specification/data-models/data-types/#string
|
||||
if isinstance(field, serializers.EmailField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'email'
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.URLField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'uri'
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.UUIDField):
|
||||
return {
|
||||
'type': 'string',
|
||||
'format': 'uuid'
|
||||
}
|
||||
|
||||
if isinstance(field, serializers.IPAddressField):
|
||||
content = {
|
||||
'type': 'string',
|
||||
}
|
||||
if field.protocol != 'both':
|
||||
content['format'] = field.protocol
|
||||
return content
|
||||
|
||||
# DecimalField has multipleOf based on decimal_places
|
||||
if isinstance(field, serializers.DecimalField):
|
||||
content = {
|
||||
'type': 'number'
|
||||
}
|
||||
if field.decimal_places:
|
||||
content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
|
||||
if field.max_whole_digits:
|
||||
content['maximum'] = int(field.max_whole_digits * '9') + 1
|
||||
content['minimum'] = -content['maximum']
|
||||
self._map_min_max(field, content)
|
||||
return content
|
||||
|
||||
if isinstance(field, serializers.FloatField):
|
||||
content = {
|
||||
'type': 'number'
|
||||
}
|
||||
self._map_min_max(field, content)
|
||||
return content
|
||||
|
||||
if isinstance(field, serializers.IntegerField):
|
||||
content = {
|
||||
'type': 'integer'
|
||||
}
|
||||
self._map_min_max(field, content)
|
||||
return content
|
||||
|
||||
# Simplest cases, default to 'string' type:
|
||||
FIELD_CLASS_SCHEMA_TYPE = {
|
||||
serializers.BooleanField: 'boolean',
|
||||
serializers.DecimalField: 'number',
|
||||
serializers.FloatField: 'number',
|
||||
serializers.IntegerField: 'integer',
|
||||
|
||||
serializers.JSONField: 'object',
|
||||
serializers.DictField: 'object',
|
||||
}
|
||||
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
|
||||
|
||||
def _map_min_max(self, field, content):
|
||||
if field.max_value:
|
||||
content['maximum'] = field.max_value
|
||||
if field.min_value:
|
||||
content['minimum'] = field.min_value
|
||||
|
||||
def _map_serializer(self, serializer):
|
||||
# Assuming we have a valid serializer instance.
|
||||
# TODO:
|
||||
|
@ -303,6 +366,11 @@ class AutoSchema(ViewInspector):
|
|||
schema['writeOnly'] = True
|
||||
if field.allow_null:
|
||||
schema['nullable'] = True
|
||||
if field.default and field.default != empty: # why don't they use None?!
|
||||
schema['default'] = field.default
|
||||
if field.help_text:
|
||||
schema['description'] = field.help_text
|
||||
self._map_field_validators(field.validators, schema)
|
||||
|
||||
properties[field.field_name] = schema
|
||||
return {
|
||||
|
@ -310,6 +378,39 @@ class AutoSchema(ViewInspector):
|
|||
'properties': properties,
|
||||
}
|
||||
|
||||
def _map_field_validators(self, validators, schema):
|
||||
"""
|
||||
map field validators
|
||||
:param list:validators: list of field validators
|
||||
:param dict:schema: schema that the validators get added to
|
||||
"""
|
||||
for v in validators:
|
||||
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
|
||||
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
|
||||
if isinstance(v, EmailValidator):
|
||||
schema['format'] = 'email'
|
||||
if isinstance(v, URLValidator):
|
||||
schema['format'] = 'uri'
|
||||
if isinstance(v, RegexValidator):
|
||||
schema['pattern'] = v.regex.pattern
|
||||
elif isinstance(v, MaxLengthValidator):
|
||||
schema['maxLength'] = v.limit_value
|
||||
elif isinstance(v, MinLengthValidator):
|
||||
schema['minLength'] = v.limit_value
|
||||
elif isinstance(v, MaxValueValidator):
|
||||
schema['maximum'] = v.limit_value
|
||||
elif isinstance(v, MinValueValidator):
|
||||
schema['minimum'] = v.limit_value
|
||||
elif isinstance(v, DecimalValidator):
|
||||
if v.decimal_places:
|
||||
schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
|
||||
if v.max_digits:
|
||||
digits = v.max_digits
|
||||
if v.decimal_places is not None and v.decimal_places > 0:
|
||||
digits -= v.decimal_places
|
||||
schema['maximum'] = int(digits * '9') + 1
|
||||
schema['minimum'] = -schema['maximum']
|
||||
|
||||
def _get_request_body(self, path, method):
|
||||
view = self.view
|
||||
|
||||
|
|
|
@ -257,3 +257,53 @@ class TestGenerator(TestCase):
|
|||
|
||||
assert response_schema['date']['format'] == 'date'
|
||||
assert response_schema['datetime']['format'] == 'date-time'
|
||||
|
||||
def test_serializer_validators(self):
|
||||
patterns = [
|
||||
url(r'^example/?$', views.ExampleValdidatedAPIView.as_view()),
|
||||
]
|
||||
generator = SchemaGenerator(patterns=patterns)
|
||||
|
||||
request = create_request('/')
|
||||
schema = generator.get_schema(request=request)
|
||||
|
||||
response = schema['paths']['/example/']['get']['responses']
|
||||
response_schema = response['200']['content']['application/json']['schema']['properties']
|
||||
|
||||
assert response_schema['integer']['type'] == 'integer'
|
||||
assert response_schema['integer']['maximum'] == 99
|
||||
assert response_schema['integer']['minimum'] == -11
|
||||
|
||||
assert response_schema['string']['minLength'] == 2
|
||||
assert response_schema['string']['maxLength'] == 10
|
||||
|
||||
assert response_schema['regex']['pattern'] == r'[ABC]12{3}'
|
||||
assert response_schema['regex']['description'] == 'must have an A, B, or C followed by 1222'
|
||||
|
||||
assert response_schema['decimal1']['type'] == 'number'
|
||||
assert response_schema['decimal1']['multipleOf'] == .01
|
||||
assert response_schema['decimal1']['maximum'] == 10000
|
||||
assert response_schema['decimal1']['minimum'] == -10000
|
||||
|
||||
assert response_schema['decimal2']['type'] == 'number'
|
||||
assert response_schema['decimal2']['multipleOf'] == .0001
|
||||
|
||||
assert response_schema['email']['type'] == 'string'
|
||||
assert response_schema['email']['format'] == 'email'
|
||||
assert response_schema['email']['default'] == 'foo@bar.com'
|
||||
|
||||
assert response_schema['url']['type'] == 'string'
|
||||
assert response_schema['url']['nullable'] is True
|
||||
assert response_schema['url']['default'] == 'http://www.example.com'
|
||||
|
||||
assert response_schema['uuid']['type'] == 'string'
|
||||
assert response_schema['uuid']['format'] == 'uuid'
|
||||
|
||||
assert response_schema['ip4']['type'] == 'string'
|
||||
assert response_schema['ip4']['format'] == 'ipv4'
|
||||
|
||||
assert response_schema['ip6']['type'] == 'string'
|
||||
assert response_schema['ip6']['format'] == 'ipv6'
|
||||
|
||||
assert response_schema['ip']['type'] == 'string'
|
||||
assert 'format' not in response_schema['ip']
|
||||
|
|
|
@ -1,3 +1,10 @@
|
|||
import uuid
|
||||
|
||||
from django.core.validators import (
|
||||
DecimalValidator, MaxLengthValidator, MaxValueValidator,
|
||||
MinLengthValidator, MinValueValidator, RegexValidator
|
||||
)
|
||||
|
||||
from rest_framework import generics, permissions, serializers
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
|
@ -56,3 +63,45 @@ class ExampleGenericViewSet(GenericViewSet):
|
|||
@action(detail=False)
|
||||
def old(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
# Validators and/or equivalent Field attributes.
|
||||
class ExampleValidatedSerializer(serializers.Serializer):
|
||||
integer = serializers.IntegerField(
|
||||
validators=(
|
||||
MaxValueValidator(limit_value=99),
|
||||
MinValueValidator(limit_value=-11),
|
||||
)
|
||||
)
|
||||
string = serializers.CharField(
|
||||
validators=(
|
||||
MaxLengthValidator(limit_value=10),
|
||||
MinLengthValidator(limit_value=2),
|
||||
)
|
||||
)
|
||||
regex = serializers.CharField(
|
||||
validators=(
|
||||
RegexValidator(regex=r'[ABC]12{3}'),
|
||||
),
|
||||
help_text='must have an A, B, or C followed by 1222'
|
||||
)
|
||||
decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2)
|
||||
decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0,
|
||||
validators=(DecimalValidator(max_digits=17, decimal_places=4),))
|
||||
email = serializers.EmailField(default='foo@bar.com')
|
||||
url = serializers.URLField(default='http://www.example.com', allow_null=True)
|
||||
uuid = serializers.UUIDField()
|
||||
ip4 = serializers.IPAddressField(protocol='ipv4')
|
||||
ip6 = serializers.IPAddressField(protocol='ipv6')
|
||||
ip = serializers.IPAddressField()
|
||||
|
||||
|
||||
class ExampleValdidatedAPIView(generics.GenericAPIView):
|
||||
serializer_class = ExampleValidatedSerializer
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
serializer = self.get_serializer(integer=33, string='hello', regex='foo', decimal1=3.55,
|
||||
decimal2=5.33, email='a@b.co',
|
||||
url='http://localhost', uuid=uuid.uuid4(), ip4='127.0.0.1', ip6='::1',
|
||||
ip='192.168.1.1')
|
||||
return Response(serializer.data)
|
||||
|
|
Loading…
Reference in New Issue
Block a user