From 5f4e262edcce1fd75bba5ff52c31676b021b8942 Mon Sep 17 00:00:00 2001 From: Erwin Junge Date: Mon, 22 Mar 2021 17:17:10 +0100 Subject: [PATCH] Use method type annotation to determine schema type for method field --- rest_framework/schemas/openapi.py | 11 +++++++++++ tests/schemas/test_openapi.py | 24 ++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 4ecb7a65f..1d3ee7b60 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -1,3 +1,4 @@ +import inspect import re import warnings from collections import OrderedDict @@ -359,6 +360,16 @@ class AutoSchema(ViewInspector): return mapping def map_field(self, field): + if isinstance(field, serializers.SerializerMethodField) and field.parent and field.method_name: + return_type = inspect.signature( + getattr(field.parent, field.method_name) + ).return_annotation + if issubclass(return_type, bool): + return {'type': 'boolean'} + if issubclass(return_type, float): + return {'type': 'number'} + if issubclass(return_type, int): + return {'type': 'integer'} # Nested Serializers, `many` or not. if isinstance(field, serializers.ListSerializer): diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index 871eb1b30..f1b407ea2 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -54,6 +54,25 @@ class TestFieldMapping(TestCase): uuid1 = uuid.uuid4() uuid2 = uuid.uuid4() inspector = AutoSchema() + + class TestSerializer(serializers.Serializer): + unannotated = serializers.SerializerMethodField() + annotated_int = serializers.SerializerMethodField() + annotated_float = serializers.SerializerMethodField() + annotated_bool = serializers.SerializerMethodField() + + def get_unannotated(self): + return 'blub' + + def get_annotated_int(self) -> int: + return 1 + + def get_annotated_float(self) -> float: + return 1.0 + + def get_annotated_bool(self) -> bool: + return True + cases = [ (serializers.ListField(), {'items': {}, 'type': 'array'}), (serializers.ListField(child=serializers.BooleanField()), {'items': {'type': 'boolean'}, 'type': 'array'}), @@ -83,6 +102,11 @@ class TestFieldMapping(TestCase): {'items': {'enum': [1, 2, 3], 'type': 'integer'}, 'type': 'array'}), (serializers.IntegerField(min_value=2147483648), {'type': 'integer', 'minimum': 2147483648, 'format': 'int64'}), + (serializers.SerializerMethodField(), {'type': 'string'}), + (TestSerializer().fields['unannotated'], {'type': 'string'}), + (TestSerializer().fields['annotated_int'], {'type': 'integer'}), + (TestSerializer().fields['annotated_float'], {'type': 'number'}), + (TestSerializer().fields['annotated_bool'], {'type': 'boolean'}), ] for field, mapping in cases: with self.subTest(field=field):