diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index cfb54de13..916f8bec4 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -121,6 +121,10 @@ class BaseSerializer(Field): return cls.many_init(*args, **kwargs) return super().__new__(cls, *args, **kwargs) + # Allow type checkers to make serializers generic. + def __class_getitem__(cls, *args, **kwargs): + return cls + @classmethod def many_init(cls, *args, **kwargs): """ diff --git a/tests/test_serializer.py b/tests/test_serializer.py index a58c46b2d..afefd70e1 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -1,6 +1,7 @@ import inspect import pickle import re +import sys from collections import ChainMap from collections.abc import Mapping @@ -204,6 +205,13 @@ class TestSerializer: exceptions.ErrorDetail(string='Raised error', code='invalid') ]} + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_serializer_is_subscriptable(self): + assert serializers.Serializer is serializers.Serializer["foo"] + class TestValidateMethod: def test_non_field_error_validate_method(self): diff --git a/tests/test_serializer_lists.py b/tests/test_serializer_lists.py index 98e72385a..f35c4fcc9 100644 --- a/tests/test_serializer_lists.py +++ b/tests/test_serializer_lists.py @@ -1,3 +1,5 @@ +import sys + import pytest from django.http import QueryDict from django.utils.datastructures import MultiValueDict @@ -55,6 +57,13 @@ class TestListSerializer: assert serializer.is_valid() assert serializer.validated_data == expected_output + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_list_serializer_is_subscriptable(self): + assert serializers.ListSerializer is serializers.ListSerializer["foo"] + class TestListSerializerContainingNestedSerializer: """