diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 4bd1b31e4..74969ceb8 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -263,7 +263,7 @@ class AutoSchema(ViewInspector): if isinstance(field, serializers.ManyRelatedField): return { 'type': 'array', - 'items': self._map_field(field.child_relation) + 'items': self._map_field(method, field.child_relation) } if isinstance(field, serializers.PrimaryKeyRelatedField): model = getattr(field.queryset, 'model', None) diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index 8e7dffc5e..ef4fee651 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -18,6 +18,21 @@ from rest_framework.schemas.openapi import ( from . import views +class ExampleModel(models.Model): + text = models.TextField() + + +class ExampleSerializer(serializers.ModelSerializer): + class Meta: + model = ExampleModel + fields = ['id', 'text'] + + +class ExampleViewSet(viewsets.ModelViewSet): + serializer_class = ExampleSerializer + queryset = ExampleModel.objects.none() + + def create_request(path): factory = RequestFactory() request = Request(factory.get(path)) @@ -50,11 +65,15 @@ class TestBasics(TestCase): class TestFieldMapping(TestCase): def test_list_field_mapping(self): inspector = AutoSchema() + inspector.init(ComponentRegistry()) + cases = [ (serializers.ListField(), {'items': {}, 'type': 'array'}), (serializers.ListField(child=serializers.BooleanField()), {'items': {'type': 'boolean'}, 'type': 'array'}), (serializers.ListField(child=serializers.FloatField()), {'items': {'type': 'number'}, 'type': 'array'}), (serializers.ListField(child=serializers.CharField()), {'items': {'type': 'string'}, 'type': 'array'}), + (serializers.ManyRelatedField(child_relation=ExampleSerializer(), read_only=True), + {'items': {'$ref': '#/components/schemas/Example', 'type': 'object'}, 'type': 'array'}), (serializers.ListField(child=serializers.IntegerField(max_value=4294967295)), {'items': {'type': 'integer', 'format': 'int64'}, 'type': 'array'}), (serializers.IntegerField(min_value=2147483648), @@ -682,18 +701,6 @@ class TestOperationIntrospection(TestCase): assert 'format' not in properties['ip'] def test_modelviewset(self): - class ExampleModel(models.Model): - text = models.TextField() - - class ExampleSerializer(serializers.ModelSerializer): - class Meta: - model = ExampleModel - fields = ['id', 'text'] - - class ExampleViewSet(viewsets.ModelViewSet): - serializer_class = ExampleSerializer - queryset = ExampleModel.objects.none() - router = routers.DefaultRouter() router.register(r'example', ExampleViewSet)