mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 01:47:59 +03:00 
			
		
		
		
	Corrected OpenAPI schema type for DecimalField (#7254)
This commit is contained in:
		
							parent
							
								
									41f27c3b43
								
							
						
					
					
						commit
						603aac7db1
					
				| 
						 | 
					@ -15,6 +15,7 @@ from django.utils.encoding import force_str
 | 
				
			||||||
from rest_framework import exceptions, renderers, serializers
 | 
					from rest_framework import exceptions, renderers, serializers
 | 
				
			||||||
from rest_framework.compat import uritemplate
 | 
					from rest_framework.compat import uritemplate
 | 
				
			||||||
from rest_framework.fields import _UnvalidatedField, empty
 | 
					from rest_framework.fields import _UnvalidatedField, empty
 | 
				
			||||||
 | 
					from rest_framework.settings import api_settings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .generators import BaseSchemaGenerator
 | 
					from .generators import BaseSchemaGenerator
 | 
				
			||||||
from .inspectors import ViewInspector
 | 
					from .inspectors import ViewInspector
 | 
				
			||||||
| 
						 | 
					@ -446,11 +447,17 @@ class AutoSchema(ViewInspector):
 | 
				
			||||||
                content['format'] = field.protocol
 | 
					                content['format'] = field.protocol
 | 
				
			||||||
            return content
 | 
					            return content
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # DecimalField has multipleOf based on decimal_places
 | 
					 | 
				
			||||||
        if isinstance(field, serializers.DecimalField):
 | 
					        if isinstance(field, serializers.DecimalField):
 | 
				
			||||||
 | 
					            if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
 | 
				
			||||||
 | 
					                content = {
 | 
				
			||||||
 | 
					                    'type': 'string',
 | 
				
			||||||
 | 
					                    'format': 'decimal',
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
                content = {
 | 
					                content = {
 | 
				
			||||||
                    'type': 'number'
 | 
					                    'type': 'number'
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if field.decimal_places:
 | 
					            if field.decimal_places:
 | 
				
			||||||
                content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
 | 
					                content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
 | 
				
			||||||
            if field.max_whole_digits:
 | 
					            if field.max_whole_digits:
 | 
				
			||||||
| 
						 | 
					@ -461,7 +468,7 @@ class AutoSchema(ViewInspector):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if isinstance(field, serializers.FloatField):
 | 
					        if isinstance(field, serializers.FloatField):
 | 
				
			||||||
            content = {
 | 
					            content = {
 | 
				
			||||||
                'type': 'number'
 | 
					                'type': 'number',
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            self._map_min_max(field, content)
 | 
					            self._map_min_max(field, content)
 | 
				
			||||||
            return content
 | 
					            return content
 | 
				
			||||||
| 
						 | 
					@ -560,7 +567,8 @@ class AutoSchema(ViewInspector):
 | 
				
			||||||
                schema['maximum'] = v.limit_value
 | 
					                schema['maximum'] = v.limit_value
 | 
				
			||||||
            elif isinstance(v, MinValueValidator):
 | 
					            elif isinstance(v, MinValueValidator):
 | 
				
			||||||
                schema['minimum'] = v.limit_value
 | 
					                schema['minimum'] = v.limit_value
 | 
				
			||||||
            elif isinstance(v, DecimalValidator):
 | 
					            elif isinstance(v, DecimalValidator) and \
 | 
				
			||||||
 | 
					                    not getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
 | 
				
			||||||
                if v.decimal_places:
 | 
					                if v.decimal_places:
 | 
				
			||||||
                    schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
 | 
					                    schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
 | 
				
			||||||
                if v.max_digits:
 | 
					                if v.max_digits:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -838,6 +838,16 @@ class TestOperationIntrospection(TestCase):
 | 
				
			||||||
        assert properties['decimal2']['type'] == 'number'
 | 
					        assert properties['decimal2']['type'] == 'number'
 | 
				
			||||||
        assert properties['decimal2']['multipleOf'] == .0001
 | 
					        assert properties['decimal2']['multipleOf'] == .0001
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert properties['decimal3'] == {
 | 
				
			||||||
 | 
					            'type': 'string', 'format': 'decimal', 'maximum': 1000000, 'minimum': -1000000, 'multipleOf': 0.01
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        assert properties['decimal4'] == {
 | 
				
			||||||
 | 
					            'type': 'string', 'format': 'decimal', 'maximum': 1000000, 'minimum': -1000000, 'multipleOf': 0.01
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        assert properties['decimal5'] == {
 | 
				
			||||||
 | 
					            'type': 'string', 'format': 'decimal', 'maximum': 10000, 'minimum': -10000, 'multipleOf': 0.01
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        assert properties['email']['type'] == 'string'
 | 
					        assert properties['email']['type'] == 'string'
 | 
				
			||||||
        assert properties['email']['format'] == 'email'
 | 
					        assert properties['email']['format'] == 'email'
 | 
				
			||||||
        assert properties['email']['default'] == 'foo@bar.com'
 | 
					        assert properties['email']['default'] == 'foo@bar.com'
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -119,9 +119,13 @@ class ExampleValidatedSerializer(serializers.Serializer):
 | 
				
			||||||
            MinLengthValidator(limit_value=2),
 | 
					            MinLengthValidator(limit_value=2),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2)
 | 
					    decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2, coerce_to_string=False)
 | 
				
			||||||
    decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0,
 | 
					    decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0, coerce_to_string=False,
 | 
				
			||||||
                                        validators=(DecimalValidator(max_digits=17, decimal_places=4),))
 | 
					                                        validators=(DecimalValidator(max_digits=17, decimal_places=4),))
 | 
				
			||||||
 | 
					    decimal3 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True)
 | 
				
			||||||
 | 
					    decimal4 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True,
 | 
				
			||||||
 | 
					                                        validators=(DecimalValidator(max_digits=17, decimal_places=4),))
 | 
				
			||||||
 | 
					    decimal5 = serializers.DecimalField(max_digits=6, decimal_places=2)
 | 
				
			||||||
    email = serializers.EmailField(default='foo@bar.com')
 | 
					    email = serializers.EmailField(default='foo@bar.com')
 | 
				
			||||||
    url = serializers.URLField(default='http://www.example.com', allow_null=True)
 | 
					    url = serializers.URLField(default='http://www.example.com', allow_null=True)
 | 
				
			||||||
    uuid = serializers.UUIDField()
 | 
					    uuid = serializers.UUIDField()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user