fixup! Corrected OpenAPI schema type for DecimalField

This commit is contained in:
Clinton Blackburn 2020-04-05 11:34:37 -07:00
parent e21156ff17
commit 91767682a3
3 changed files with 46 additions and 13 deletions

View File

@ -6,8 +6,8 @@ from operator import attrgetter
from urllib.parse import urljoin
from django.core.validators import (
EmailValidator, MaxLengthValidator, MaxValueValidator, MinLengthValidator,
MinValueValidator, RegexValidator, URLValidator
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
MinLengthValidator, MinValueValidator, RegexValidator, URLValidator
)
from django.db import models
from django.utils.encoding import force_str
@ -15,6 +15,7 @@ from django.utils.encoding import force_str
from rest_framework import exceptions, renderers, serializers
from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty
from rest_framework.settings import api_settings
from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
@ -329,10 +330,10 @@ class AutoSchema(ViewInspector):
type = 'boolean'
elif all(isinstance(choice, int) for choice in choices):
type = 'integer'
elif all(isinstance(choice, (int, float)) for choice in choices): # `number` includes `integer`
elif all(isinstance(choice, (int, float, Decimal)) for choice in choices): # `number` includes `integer`
# Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
type = 'number'
elif all(isinstance(choice, (str, Decimal)) for choice in choices):
elif all(isinstance(choice, str) for choice in choices):
type = 'string'
else:
type = None
@ -443,14 +444,26 @@ class AutoSchema(ViewInspector):
return content
if isinstance(field, serializers.DecimalField):
return {
'type': 'string',
'format': 'decimal'
}
if field.coerce_to_string:
return {
'type': 'string',
'format': 'decimal',
}
else:
content = {
'type': 'number'
}
if field.decimal_places:
content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
if field.max_whole_digits:
content['maximum'] = int(field.max_whole_digits * '9') + 1
content['minimum'] = -content['maximum']
self._map_min_max(field, content)
return content
if isinstance(field, serializers.FloatField):
content = {
'type': 'number'
'type': 'number',
}
self._map_min_max(field, content)
return content
@ -549,6 +562,16 @@ class AutoSchema(ViewInspector):
schema['maximum'] = v.limit_value
elif isinstance(v, MinValueValidator):
schema['minimum'] = v.limit_value
elif isinstance(v, DecimalValidator) and \
not getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
if v.decimal_places:
schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
if v.max_digits:
digits = v.max_digits
if v.decimal_places is not None and v.decimal_places > 0:
digits -= v.decimal_places
schema['maximum'] = int(digits * '9') + 1
schema['minimum'] = -schema['maximum']
def _get_paginator(self):
pagination_class = getattr(self.view, 'pagination_class', None)

View File

@ -830,9 +830,16 @@ class TestOperationIntrospection(TestCase):
assert properties['regex']['pattern'] == r'[ABC]12{3}'
assert properties['regex']['description'] == 'must have an A, B, or C followed by 1222'
assert properties['decimal1'] == {'type': 'string', 'format': 'decimal'}
assert properties['decimal1']['type'] == 'number'
assert properties['decimal1']['multipleOf'] == .01
assert properties['decimal1']['maximum'] == 10000
assert properties['decimal1']['minimum'] == -10000
assert properties['decimal2'] == {'type': 'string', 'format': 'decimal'}
assert properties['decimal2']['type'] == 'number'
assert properties['decimal2']['multipleOf'] == .0001
assert properties['decimal3'] == {'type': 'string', 'format': 'decimal'}
assert properties['decimal4'] == {'type': 'string', 'format': 'decimal'}
assert properties['email']['type'] == 'string'
assert properties['email']['format'] == 'email'

View File

@ -119,8 +119,11 @@ class ExampleValidatedSerializer(serializers.Serializer):
MinLengthValidator(limit_value=2),
)
)
decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2)
decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0,
decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2, coerce_to_string=False)
decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0, coerce_to_string=False,
validators=(DecimalValidator(max_digits=17, decimal_places=4),))
decimal3 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True)
decimal4 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True,
validators=(DecimalValidator(max_digits=17, decimal_places=4),))
email = serializers.EmailField(default='foo@bar.com')
url = serializers.URLField(default='http://www.example.com', allow_null=True)