Corrected OpenAPI schema type for DecimalField (#7254)

This commit is contained in:
Clinton Blackburn 2020-04-09 10:16:17 -07:00 committed by GitHub
parent 41f27c3b43
commit 603aac7db1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 8 deletions

View File

@ -15,6 +15,7 @@ from django.utils.encoding import force_str
from rest_framework import exceptions, renderers, serializers from rest_framework import exceptions, renderers, serializers
from rest_framework.compat import uritemplate from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty from rest_framework.fields import _UnvalidatedField, empty
from rest_framework.settings import api_settings
from .generators import BaseSchemaGenerator from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector from .inspectors import ViewInspector
@ -446,11 +447,17 @@ class AutoSchema(ViewInspector):
content['format'] = field.protocol content['format'] = field.protocol
return content return content
# DecimalField has multipleOf based on decimal_places
if isinstance(field, serializers.DecimalField): if isinstance(field, serializers.DecimalField):
if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
content = {
'type': 'string',
'format': 'decimal',
}
else:
content = { content = {
'type': 'number' 'type': 'number'
} }
if field.decimal_places: if field.decimal_places:
content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1') content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
if field.max_whole_digits: if field.max_whole_digits:
@ -461,7 +468,7 @@ class AutoSchema(ViewInspector):
if isinstance(field, serializers.FloatField): if isinstance(field, serializers.FloatField):
content = { content = {
'type': 'number' 'type': 'number',
} }
self._map_min_max(field, content) self._map_min_max(field, content)
return content return content
@ -560,7 +567,8 @@ class AutoSchema(ViewInspector):
schema['maximum'] = v.limit_value schema['maximum'] = v.limit_value
elif isinstance(v, MinValueValidator): elif isinstance(v, MinValueValidator):
schema['minimum'] = v.limit_value schema['minimum'] = v.limit_value
elif isinstance(v, DecimalValidator): elif isinstance(v, DecimalValidator) and \
not getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
if v.decimal_places: if v.decimal_places:
schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1') schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
if v.max_digits: if v.max_digits:

View File

@ -838,6 +838,16 @@ class TestOperationIntrospection(TestCase):
assert properties['decimal2']['type'] == 'number' assert properties['decimal2']['type'] == 'number'
assert properties['decimal2']['multipleOf'] == .0001 assert properties['decimal2']['multipleOf'] == .0001
assert properties['decimal3'] == {
'type': 'string', 'format': 'decimal', 'maximum': 1000000, 'minimum': -1000000, 'multipleOf': 0.01
}
assert properties['decimal4'] == {
'type': 'string', 'format': 'decimal', 'maximum': 1000000, 'minimum': -1000000, 'multipleOf': 0.01
}
assert properties['decimal5'] == {
'type': 'string', 'format': 'decimal', 'maximum': 10000, 'minimum': -10000, 'multipleOf': 0.01
}
assert properties['email']['type'] == 'string' assert properties['email']['type'] == 'string'
assert properties['email']['format'] == 'email' assert properties['email']['format'] == 'email'
assert properties['email']['default'] == 'foo@bar.com' assert properties['email']['default'] == 'foo@bar.com'

View File

@ -119,9 +119,13 @@ class ExampleValidatedSerializer(serializers.Serializer):
MinLengthValidator(limit_value=2), MinLengthValidator(limit_value=2),
) )
) )
decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2) decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2, coerce_to_string=False)
decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0, decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0, coerce_to_string=False,
validators=(DecimalValidator(max_digits=17, decimal_places=4),)) validators=(DecimalValidator(max_digits=17, decimal_places=4),))
decimal3 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True)
decimal4 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True,
validators=(DecimalValidator(max_digits=17, decimal_places=4),))
decimal5 = serializers.DecimalField(max_digits=6, decimal_places=2)
email = serializers.EmailField(default='foo@bar.com') email = serializers.EmailField(default='foo@bar.com')
url = serializers.URLField(default='http://www.example.com', allow_null=True) url = serializers.URLField(default='http://www.example.com', allow_null=True)
uuid = serializers.UUIDField() uuid = serializers.UUIDField()