Add a method for getting serializer field name (OpenAPI) (#7493)

* Add a method for getting serializer field name

* Add docs and test

Co-authored-by: Tom Christie <tom@tomchristie.com>
This commit is contained in:
Den 2022-10-17 13:47:45 +04:00 committed by GitHub
parent 0cb693700f
commit 35c5be6ec2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 2 deletions

View File

@ -523,7 +523,7 @@ class AutoSchema(ViewInspector):
continue continue
if field.required: if field.required:
required.append(field.field_name) required.append(self.get_field_name(field))
schema = self.map_field(field) schema = self.map_field(field)
if field.read_only: if field.read_only:
@ -538,7 +538,7 @@ class AutoSchema(ViewInspector):
schema['description'] = str(field.help_text) schema['description'] = str(field.help_text)
self.map_field_validators(field, schema) self.map_field_validators(field, schema)
properties[field.field_name] = schema properties[self.get_field_name(field)] = schema
result = { result = {
'type': 'object', 'type': 'object',
@ -589,6 +589,13 @@ class AutoSchema(ViewInspector):
schema['maximum'] = int(digits * '9') + 1 schema['maximum'] = int(digits * '9') + 1
schema['minimum'] = -schema['maximum'] schema['minimum'] = -schema['maximum']
def get_field_name(self, field):
"""
Override this method if you want to change schema field name.
For example, convert snake_case field name to camelCase.
"""
return field.field_name
def get_paginator(self): def get_paginator(self):
pagination_class = getattr(self.view, 'pagination_class', None) pagination_class = getattr(self.view, 'pagination_class', None)
if pagination_class: if pagination_class:

View File

@ -111,6 +111,20 @@ class TestFieldMapping(TestCase):
assert data['properties']['default_false']['default'] is False, "default must be false" assert data['properties']['default_false']['default'] is False, "default must be false"
assert 'default' not in data['properties']['without_default'], "default must not be defined" assert 'default' not in data['properties']['without_default'], "default must not be defined"
def test_custom_field_name(self):
class CustomSchema(AutoSchema):
def get_field_name(self, field):
return 'custom_' + field.field_name
class Serializer(serializers.Serializer):
text_field = serializers.CharField()
inspector = CustomSchema()
data = inspector.map_serializer(Serializer())
assert 'custom_text_field' in data['properties']
assert 'text_field' not in data['properties']
def test_nullable_fields(self): def test_nullable_fields(self):
class Model(models.Model): class Model(models.Model):
rw_field = models.CharField(null=True) rw_field = models.CharField(null=True)