mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-06-16 19:43:21 +03:00
Generate OpenAPI schema field types from validators. (#6674)
This commit is contained in:
parent
a63860fc8b
commit
2d65f82dd7
|
@ -1,10 +1,15 @@
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
from django.core.validators import (
|
||||||
|
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
|
||||||
|
MinLengthValidator, MinValueValidator, RegexValidator, URLValidator
|
||||||
|
)
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils.encoding import force_text
|
from django.utils.encoding import force_text
|
||||||
|
|
||||||
from rest_framework import exceptions, serializers
|
from rest_framework import exceptions, serializers
|
||||||
from rest_framework.compat import uritemplate
|
from rest_framework.compat import uritemplate
|
||||||
|
from rest_framework.fields import empty
|
||||||
|
|
||||||
from .generators import BaseSchemaGenerator
|
from .generators import BaseSchemaGenerator
|
||||||
from .inspectors import ViewInspector
|
from .inspectors import ViewInspector
|
||||||
|
@ -268,18 +273,76 @@ class AutoSchema(ViewInspector):
|
||||||
'format': 'date-time',
|
'format': 'date-time',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
|
||||||
|
# see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
|
||||||
|
# see also: https://swagger.io/docs/specification/data-models/data-types/#string
|
||||||
|
if isinstance(field, serializers.EmailField):
|
||||||
|
return {
|
||||||
|
'type': 'string',
|
||||||
|
'format': 'email'
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(field, serializers.URLField):
|
||||||
|
return {
|
||||||
|
'type': 'string',
|
||||||
|
'format': 'uri'
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(field, serializers.UUIDField):
|
||||||
|
return {
|
||||||
|
'type': 'string',
|
||||||
|
'format': 'uuid'
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(field, serializers.IPAddressField):
|
||||||
|
content = {
|
||||||
|
'type': 'string',
|
||||||
|
}
|
||||||
|
if field.protocol != 'both':
|
||||||
|
content['format'] = field.protocol
|
||||||
|
return content
|
||||||
|
|
||||||
|
# DecimalField has multipleOf based on decimal_places
|
||||||
|
if isinstance(field, serializers.DecimalField):
|
||||||
|
content = {
|
||||||
|
'type': 'number'
|
||||||
|
}
|
||||||
|
if field.decimal_places:
|
||||||
|
content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
|
||||||
|
if field.max_whole_digits:
|
||||||
|
content['maximum'] = int(field.max_whole_digits * '9') + 1
|
||||||
|
content['minimum'] = -content['maximum']
|
||||||
|
self._map_min_max(field, content)
|
||||||
|
return content
|
||||||
|
|
||||||
|
if isinstance(field, serializers.FloatField):
|
||||||
|
content = {
|
||||||
|
'type': 'number'
|
||||||
|
}
|
||||||
|
self._map_min_max(field, content)
|
||||||
|
return content
|
||||||
|
|
||||||
|
if isinstance(field, serializers.IntegerField):
|
||||||
|
content = {
|
||||||
|
'type': 'integer'
|
||||||
|
}
|
||||||
|
self._map_min_max(field, content)
|
||||||
|
return content
|
||||||
|
|
||||||
# Simplest cases, default to 'string' type:
|
# Simplest cases, default to 'string' type:
|
||||||
FIELD_CLASS_SCHEMA_TYPE = {
|
FIELD_CLASS_SCHEMA_TYPE = {
|
||||||
serializers.BooleanField: 'boolean',
|
serializers.BooleanField: 'boolean',
|
||||||
serializers.DecimalField: 'number',
|
|
||||||
serializers.FloatField: 'number',
|
|
||||||
serializers.IntegerField: 'integer',
|
|
||||||
|
|
||||||
serializers.JSONField: 'object',
|
serializers.JSONField: 'object',
|
||||||
serializers.DictField: 'object',
|
serializers.DictField: 'object',
|
||||||
}
|
}
|
||||||
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
|
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
|
||||||
|
|
||||||
|
def _map_min_max(self, field, content):
|
||||||
|
if field.max_value:
|
||||||
|
content['maximum'] = field.max_value
|
||||||
|
if field.min_value:
|
||||||
|
content['minimum'] = field.min_value
|
||||||
|
|
||||||
def _map_serializer(self, serializer):
|
def _map_serializer(self, serializer):
|
||||||
# Assuming we have a valid serializer instance.
|
# Assuming we have a valid serializer instance.
|
||||||
# TODO:
|
# TODO:
|
||||||
|
@ -303,6 +366,11 @@ class AutoSchema(ViewInspector):
|
||||||
schema['writeOnly'] = True
|
schema['writeOnly'] = True
|
||||||
if field.allow_null:
|
if field.allow_null:
|
||||||
schema['nullable'] = True
|
schema['nullable'] = True
|
||||||
|
if field.default and field.default != empty: # why don't they use None?!
|
||||||
|
schema['default'] = field.default
|
||||||
|
if field.help_text:
|
||||||
|
schema['description'] = field.help_text
|
||||||
|
self._map_field_validators(field.validators, schema)
|
||||||
|
|
||||||
properties[field.field_name] = schema
|
properties[field.field_name] = schema
|
||||||
return {
|
return {
|
||||||
|
@ -310,6 +378,39 @@ class AutoSchema(ViewInspector):
|
||||||
'properties': properties,
|
'properties': properties,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _map_field_validators(self, validators, schema):
|
||||||
|
"""
|
||||||
|
map field validators
|
||||||
|
:param list:validators: list of field validators
|
||||||
|
:param dict:schema: schema that the validators get added to
|
||||||
|
"""
|
||||||
|
for v in validators:
|
||||||
|
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
|
||||||
|
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
|
||||||
|
if isinstance(v, EmailValidator):
|
||||||
|
schema['format'] = 'email'
|
||||||
|
if isinstance(v, URLValidator):
|
||||||
|
schema['format'] = 'uri'
|
||||||
|
if isinstance(v, RegexValidator):
|
||||||
|
schema['pattern'] = v.regex.pattern
|
||||||
|
elif isinstance(v, MaxLengthValidator):
|
||||||
|
schema['maxLength'] = v.limit_value
|
||||||
|
elif isinstance(v, MinLengthValidator):
|
||||||
|
schema['minLength'] = v.limit_value
|
||||||
|
elif isinstance(v, MaxValueValidator):
|
||||||
|
schema['maximum'] = v.limit_value
|
||||||
|
elif isinstance(v, MinValueValidator):
|
||||||
|
schema['minimum'] = v.limit_value
|
||||||
|
elif isinstance(v, DecimalValidator):
|
||||||
|
if v.decimal_places:
|
||||||
|
schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
|
||||||
|
if v.max_digits:
|
||||||
|
digits = v.max_digits
|
||||||
|
if v.decimal_places is not None and v.decimal_places > 0:
|
||||||
|
digits -= v.decimal_places
|
||||||
|
schema['maximum'] = int(digits * '9') + 1
|
||||||
|
schema['minimum'] = -schema['maximum']
|
||||||
|
|
||||||
def _get_request_body(self, path, method):
|
def _get_request_body(self, path, method):
|
||||||
view = self.view
|
view = self.view
|
||||||
|
|
||||||
|
|
|
@ -257,3 +257,53 @@ class TestGenerator(TestCase):
|
||||||
|
|
||||||
assert response_schema['date']['format'] == 'date'
|
assert response_schema['date']['format'] == 'date'
|
||||||
assert response_schema['datetime']['format'] == 'date-time'
|
assert response_schema['datetime']['format'] == 'date-time'
|
||||||
|
|
||||||
|
def test_serializer_validators(self):
|
||||||
|
patterns = [
|
||||||
|
url(r'^example/?$', views.ExampleValdidatedAPIView.as_view()),
|
||||||
|
]
|
||||||
|
generator = SchemaGenerator(patterns=patterns)
|
||||||
|
|
||||||
|
request = create_request('/')
|
||||||
|
schema = generator.get_schema(request=request)
|
||||||
|
|
||||||
|
response = schema['paths']['/example/']['get']['responses']
|
||||||
|
response_schema = response['200']['content']['application/json']['schema']['properties']
|
||||||
|
|
||||||
|
assert response_schema['integer']['type'] == 'integer'
|
||||||
|
assert response_schema['integer']['maximum'] == 99
|
||||||
|
assert response_schema['integer']['minimum'] == -11
|
||||||
|
|
||||||
|
assert response_schema['string']['minLength'] == 2
|
||||||
|
assert response_schema['string']['maxLength'] == 10
|
||||||
|
|
||||||
|
assert response_schema['regex']['pattern'] == r'[ABC]12{3}'
|
||||||
|
assert response_schema['regex']['description'] == 'must have an A, B, or C followed by 1222'
|
||||||
|
|
||||||
|
assert response_schema['decimal1']['type'] == 'number'
|
||||||
|
assert response_schema['decimal1']['multipleOf'] == .01
|
||||||
|
assert response_schema['decimal1']['maximum'] == 10000
|
||||||
|
assert response_schema['decimal1']['minimum'] == -10000
|
||||||
|
|
||||||
|
assert response_schema['decimal2']['type'] == 'number'
|
||||||
|
assert response_schema['decimal2']['multipleOf'] == .0001
|
||||||
|
|
||||||
|
assert response_schema['email']['type'] == 'string'
|
||||||
|
assert response_schema['email']['format'] == 'email'
|
||||||
|
assert response_schema['email']['default'] == 'foo@bar.com'
|
||||||
|
|
||||||
|
assert response_schema['url']['type'] == 'string'
|
||||||
|
assert response_schema['url']['nullable'] is True
|
||||||
|
assert response_schema['url']['default'] == 'http://www.example.com'
|
||||||
|
|
||||||
|
assert response_schema['uuid']['type'] == 'string'
|
||||||
|
assert response_schema['uuid']['format'] == 'uuid'
|
||||||
|
|
||||||
|
assert response_schema['ip4']['type'] == 'string'
|
||||||
|
assert response_schema['ip4']['format'] == 'ipv4'
|
||||||
|
|
||||||
|
assert response_schema['ip6']['type'] == 'string'
|
||||||
|
assert response_schema['ip6']['format'] == 'ipv6'
|
||||||
|
|
||||||
|
assert response_schema['ip']['type'] == 'string'
|
||||||
|
assert 'format' not in response_schema['ip']
|
||||||
|
|
|
@ -1,3 +1,10 @@
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from django.core.validators import (
|
||||||
|
DecimalValidator, MaxLengthValidator, MaxValueValidator,
|
||||||
|
MinLengthValidator, MinValueValidator, RegexValidator
|
||||||
|
)
|
||||||
|
|
||||||
from rest_framework import generics, permissions, serializers
|
from rest_framework import generics, permissions, serializers
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
|
@ -56,3 +63,45 @@ class ExampleGenericViewSet(GenericViewSet):
|
||||||
@action(detail=False)
|
@action(detail=False)
|
||||||
def old(self, *args, **kwargs):
|
def old(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Validators and/or equivalent Field attributes.
|
||||||
|
class ExampleValidatedSerializer(serializers.Serializer):
|
||||||
|
integer = serializers.IntegerField(
|
||||||
|
validators=(
|
||||||
|
MaxValueValidator(limit_value=99),
|
||||||
|
MinValueValidator(limit_value=-11),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
string = serializers.CharField(
|
||||||
|
validators=(
|
||||||
|
MaxLengthValidator(limit_value=10),
|
||||||
|
MinLengthValidator(limit_value=2),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
regex = serializers.CharField(
|
||||||
|
validators=(
|
||||||
|
RegexValidator(regex=r'[ABC]12{3}'),
|
||||||
|
),
|
||||||
|
help_text='must have an A, B, or C followed by 1222'
|
||||||
|
)
|
||||||
|
decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2)
|
||||||
|
decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0,
|
||||||
|
validators=(DecimalValidator(max_digits=17, decimal_places=4),))
|
||||||
|
email = serializers.EmailField(default='foo@bar.com')
|
||||||
|
url = serializers.URLField(default='http://www.example.com', allow_null=True)
|
||||||
|
uuid = serializers.UUIDField()
|
||||||
|
ip4 = serializers.IPAddressField(protocol='ipv4')
|
||||||
|
ip6 = serializers.IPAddressField(protocol='ipv6')
|
||||||
|
ip = serializers.IPAddressField()
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleValdidatedAPIView(generics.GenericAPIView):
|
||||||
|
serializer_class = ExampleValidatedSerializer
|
||||||
|
|
||||||
|
def get(self, *args, **kwargs):
|
||||||
|
serializer = self.get_serializer(integer=33, string='hello', regex='foo', decimal1=3.55,
|
||||||
|
decimal2=5.33, email='a@b.co',
|
||||||
|
url='http://localhost', uuid=uuid.uuid4(), ip4='127.0.0.1', ip6='::1',
|
||||||
|
ip='192.168.1.1')
|
||||||
|
return Response(serializer.data)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user