From 35c5be6ec23af6e68914812599c905fe0fa2c0cc Mon Sep 17 00:00:00 2001 From: Den Date: Mon, 17 Oct 2022 13:47:45 +0400 Subject: [PATCH] Add a method for getting serializer field name (OpenAPI) (#7493) * Add a method for getting serializer field name * Add docs and test Co-authored-by: Tom Christie --- rest_framework/schemas/openapi.py | 11 +++++++++-- tests/schemas/test_openapi.py | 14 ++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index ee614fdf6..2f9fb9f28 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -523,7 +523,7 @@ class AutoSchema(ViewInspector): continue if field.required: - required.append(field.field_name) + required.append(self.get_field_name(field)) schema = self.map_field(field) if field.read_only: @@ -538,7 +538,7 @@ class AutoSchema(ViewInspector): schema['description'] = str(field.help_text) self.map_field_validators(field, schema) - properties[field.field_name] = schema + properties[self.get_field_name(field)] = schema result = { 'type': 'object', @@ -589,6 +589,13 @@ class AutoSchema(ViewInspector): schema['maximum'] = int(digits * '9') + 1 schema['minimum'] = -schema['maximum'] + def get_field_name(self, field): + """ + Override this method if you want to change schema field name. + For example, convert snake_case field name to camelCase. + """ + return field.field_name + def get_paginator(self): pagination_class = getattr(self.view, 'pagination_class', None) if pagination_class: diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index daa035a3f..7542bb615 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -111,6 +111,20 @@ class TestFieldMapping(TestCase): assert data['properties']['default_false']['default'] is False, "default must be false" assert 'default' not in data['properties']['without_default'], "default must not be defined" + def test_custom_field_name(self): + class CustomSchema(AutoSchema): + def get_field_name(self, field): + return 'custom_' + field.field_name + + class Serializer(serializers.Serializer): + text_field = serializers.CharField() + + inspector = CustomSchema() + + data = inspector.map_serializer(Serializer()) + assert 'custom_text_field' in data['properties'] + assert 'text_field' not in data['properties'] + def test_nullable_fields(self): class Model(models.Model): rw_field = models.CharField(null=True)