diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 5e9d59f8b..24bc3f6bf 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -128,6 +128,7 @@ class AutoSchema(ViewInspector): self._tags = tags self.operation_id_base = operation_id_base self.component_name = component_name + self.components = {} super().__init__() request_media_types = [] @@ -195,19 +196,17 @@ class AutoSchema(ViewInspector): request_serializer = self.get_request_serializer(path, method) response_serializer = self.get_response_serializer(path, method) - components = {} - if isinstance(request_serializer, serializers.Serializer): component_name = self.get_component_name(request_serializer) content = self.map_serializer(request_serializer) - components.setdefault(component_name, content) + self.components.setdefault(component_name, content) if isinstance(response_serializer, serializers.Serializer): component_name = self.get_component_name(response_serializer) content = self.map_serializer(response_serializer) - components.setdefault(component_name, content) + self.components.setdefault(component_name, content) - return components + return self.components def _to_camel_case(self, snake_str): components = snake_str.split('_') @@ -547,7 +546,9 @@ class AutoSchema(ViewInspector): if required: result['required'] = required - return result + component_name = self.get_component_name(serializer=serializer) + self.components[component_name] = result + return self._get_reference(serializer) def map_field_validators(self, field, schema): """ diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index daa035a3f..ed3cb4d81 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -95,18 +95,20 @@ class TestFieldMapping(TestCase): inspector = AutoSchema() - data = inspector.map_serializer(ItemSerializer()) + inspector.map_serializer(ItemSerializer()) + data = inspector.components['Item'] assert isinstance(data['properties']['text']['description'], str), "description must be str" def test_boolean_default_field(self): - class Serializer(serializers.Serializer): + class BooleanTestSerializer(serializers.Serializer): default_true = serializers.BooleanField(default=True) default_false = serializers.BooleanField(default=False) without_default = serializers.BooleanField() inspector = AutoSchema() - data = inspector.map_serializer(Serializer()) + inspector.map_serializer(BooleanTestSerializer()) + data = inspector.components['BooleanTest'] assert data['properties']['default_true']['default'] is True, "default must be true" assert data['properties']['default_false']['default'] is False, "default must be false" assert 'default' not in data['properties']['without_default'], "default must not be defined" @@ -116,7 +118,7 @@ class TestFieldMapping(TestCase): rw_field = models.CharField(null=True) ro_field = models.CharField(null=True) - class Serializer(serializers.ModelSerializer): + class NullableSerializer(serializers.ModelSerializer): class Meta: model = Model fields = ["rw_field", "ro_field"] @@ -124,7 +126,8 @@ class TestFieldMapping(TestCase): inspector = AutoSchema() - data = inspector.map_serializer(Serializer()) + inspector.map_serializer(NullableSerializer()) + data = inspector.components['Nullable'] assert data['properties']['rw_field']['nullable'], "rw_field nullable must be true" assert data['properties']['ro_field']['nullable'], "ro_field nullable must be true" assert data['properties']['ro_field']['readOnly'], "ro_field read_only must be true" @@ -368,8 +371,10 @@ class TestOperationIntrospection(TestCase): assert sorted(schema['required']) == ['nested', 'text'] assert sorted(list(schema['properties'].keys())) == ['nested', 'text'] assert schema['properties']['nested']['type'] == 'object' - assert list(schema['properties']['nested']['properties'].keys()) == ['number'] - assert schema['properties']['nested']['required'] == ['number'] + assert schema['properties']['nested']['$ref'] == '#/components/schemas/Nested' + nested_schema = components['Nested'] + assert list(nested_schema['properties'].keys()) == ['number'] + assert nested_schema['required'] == ['number'] def test_list_response_body_generation(self): """Test that an array schema is returned for list views."""