From 91767682a31236494a53bd6aae7262a83e9b1516 Mon Sep 17 00:00:00 2001 From: Clinton Blackburn Date: Sun, 5 Apr 2020 11:34:37 -0700 Subject: [PATCH] fixup! Corrected OpenAPI schema type for DecimalField --- rest_framework/schemas/openapi.py | 41 ++++++++++++++++++++++++------- tests/schemas/test_openapi.py | 11 +++++++-- tests/schemas/views.py | 7 ++++-- 3 files changed, 46 insertions(+), 13 deletions(-) diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 88f4db08c..358e3ef36 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -6,8 +6,8 @@ from operator import attrgetter from urllib.parse import urljoin from django.core.validators import ( - EmailValidator, MaxLengthValidator, MaxValueValidator, MinLengthValidator, - MinValueValidator, RegexValidator, URLValidator + DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator, + MinLengthValidator, MinValueValidator, RegexValidator, URLValidator ) from django.db import models from django.utils.encoding import force_str @@ -15,6 +15,7 @@ from django.utils.encoding import force_str from rest_framework import exceptions, renderers, serializers from rest_framework.compat import uritemplate from rest_framework.fields import _UnvalidatedField, empty +from rest_framework.settings import api_settings from .generators import BaseSchemaGenerator from .inspectors import ViewInspector @@ -329,10 +330,10 @@ class AutoSchema(ViewInspector): type = 'boolean' elif all(isinstance(choice, int) for choice in choices): type = 'integer' - elif all(isinstance(choice, (int, float)) for choice in choices): # `number` includes `integer` + elif all(isinstance(choice, (int, float, Decimal)) for choice in choices): # `number` includes `integer` # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21 type = 'number' - elif all(isinstance(choice, (str, Decimal)) for choice in choices): + elif all(isinstance(choice, str) for choice in choices): type = 'string' else: type = None @@ -443,14 +444,26 @@ class AutoSchema(ViewInspector): return content if isinstance(field, serializers.DecimalField): - return { - 'type': 'string', - 'format': 'decimal' - } + if field.coerce_to_string: + return { + 'type': 'string', + 'format': 'decimal', + } + else: + content = { + 'type': 'number' + } + if field.decimal_places: + content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1') + if field.max_whole_digits: + content['maximum'] = int(field.max_whole_digits * '9') + 1 + content['minimum'] = -content['maximum'] + self._map_min_max(field, content) + return content if isinstance(field, serializers.FloatField): content = { - 'type': 'number' + 'type': 'number', } self._map_min_max(field, content) return content @@ -549,6 +562,16 @@ class AutoSchema(ViewInspector): schema['maximum'] = v.limit_value elif isinstance(v, MinValueValidator): schema['minimum'] = v.limit_value + elif isinstance(v, DecimalValidator) and \ + not getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING): + if v.decimal_places: + schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1') + if v.max_digits: + digits = v.max_digits + if v.decimal_places is not None and v.decimal_places > 0: + digits -= v.decimal_places + schema['maximum'] = int(digits * '9') + 1 + schema['minimum'] = -schema['maximum'] def _get_paginator(self): pagination_class = getattr(self.view, 'pagination_class', None) diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index 604dd7ffa..dfc1bc1c5 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -830,9 +830,16 @@ class TestOperationIntrospection(TestCase): assert properties['regex']['pattern'] == r'[ABC]12{3}' assert properties['regex']['description'] == 'must have an A, B, or C followed by 1222' - assert properties['decimal1'] == {'type': 'string', 'format': 'decimal'} + assert properties['decimal1']['type'] == 'number' + assert properties['decimal1']['multipleOf'] == .01 + assert properties['decimal1']['maximum'] == 10000 + assert properties['decimal1']['minimum'] == -10000 - assert properties['decimal2'] == {'type': 'string', 'format': 'decimal'} + assert properties['decimal2']['type'] == 'number' + assert properties['decimal2']['multipleOf'] == .0001 + + assert properties['decimal3'] == {'type': 'string', 'format': 'decimal'} + assert properties['decimal4'] == {'type': 'string', 'format': 'decimal'} assert properties['email']['type'] == 'string' assert properties['email']['format'] == 'email' diff --git a/tests/schemas/views.py b/tests/schemas/views.py index 5645f59bf..048ae75d6 100644 --- a/tests/schemas/views.py +++ b/tests/schemas/views.py @@ -119,8 +119,11 @@ class ExampleValidatedSerializer(serializers.Serializer): MinLengthValidator(limit_value=2), ) ) - decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2) - decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0, + decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2, coerce_to_string=False) + decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0, coerce_to_string=False, + validators=(DecimalValidator(max_digits=17, decimal_places=4),)) + decimal3 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True) + decimal4 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True, validators=(DecimalValidator(max_digits=17, decimal_places=4),)) email = serializers.EmailField(default='foo@bar.com') url = serializers.URLField(default='http://www.example.com', allow_null=True)