diff --git a/rest_framework/fields.py b/rest_framework/fields.py index fdfba13f2..3b41e4015 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -8,6 +8,7 @@ import uuid import warnings from collections import OrderedDict from collections.abc import Mapping +from typing import Generic, TypeVar from django.conf import settings from django.core.exceptions import ObjectDoesNotExist @@ -308,8 +309,13 @@ MISSING_ERROR_MESSAGE = ( 'not exist in the `error_messages` dictionary.' ) +_IN = TypeVar("_IN") # Instance Type +_VT = TypeVar("_VT") # Value Type +_DT = TypeVar("_DT") # Data Type +_RP = TypeVar("_RP") # Representation Type -class Field: + +class Field(Generic[_VT, _DT, _RP, _IN]): _creation_counter = 0 default_error_messages = { diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index fc4eb1428..8b0fc7381 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -5,6 +5,7 @@ They give us a generic way of being able to handle various media types on the request, such as form content or json encoded data. """ import codecs +from typing import Generic, TypeVar from urllib import parse from django.conf import settings @@ -21,8 +22,11 @@ from rest_framework.exceptions import ParseError from rest_framework.settings import api_settings from rest_framework.utils import json +_Data = TypeVar("_Data") +_Files = TypeVar("_Files") -class DataAndFiles: + +class DataAndFiles(Generic[_Data, _Files]): def __init__(self, data, files): self.data = data self.files = files diff --git a/rest_framework/relations.py b/rest_framework/relations.py index eaf27e1d9..3a535f34c 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,9 +1,10 @@ import sys from collections import OrderedDict +from typing import Any, Generic, TypeVar from urllib import parse from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist -from django.db.models import Manager +from django.db.models import Manager, Model from django.db.models.query import QuerySet from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve from django.utils.encoding import smart_str, uri_to_iri @@ -85,8 +86,12 @@ MANY_RELATION_KWARGS = ( 'html_cutoff', 'html_cutoff_text' ) +_MT = TypeVar("_MT", bound=Model) +_DT = TypeVar("_DT") # Data Type +_PT = TypeVar("_PT") # Primitive Type -class RelatedField(Field): + +class RelatedField(Generic[_MT, _DT, _PT], Field[_MT, _DT, _PT, Any]): queryset = None html_cutoff = None html_cutoff_text = None diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 49eec8259..35d52610f 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -12,9 +12,11 @@ response content is handled by parsers and renderers. """ import copy import inspect +import sys import traceback from collections import OrderedDict, defaultdict from collections.abc import Mapping +from typing import Any, Generic, TypeVar from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured from django.core.exceptions import ValidationError as DjangoValidationError @@ -66,6 +68,13 @@ from rest_framework.fields import ( # NOQA # isort:skip ) from rest_framework.relations import Hyperlink, PKOnlyObject # NOQA # isort:skip +if sys.version_info < (3, 7): + from typing import GenericMeta +else: + class GenericMeta(type): + pass + + # We assume that 'validators' are intended for the child serializer, # rather than the parent serializer. LIST_SERIALIZER_KWARGS = ( @@ -79,8 +88,10 @@ ALL_FIELDS = '__all__' # BaseSerializer # -------------- +_IN = TypeVar("_IN") # Instance Type -class BaseSerializer(Field): + +class BaseSerializer(Generic[_IN], Field[Any, Any, Any, _IN]): """ The BaseSerializer class provides a minimal class which may be used for writing custom serializer implementations. @@ -121,10 +132,6 @@ class BaseSerializer(Field): return cls.many_init(*args, **kwargs) return super().__new__(cls, *args, **kwargs) - # Allow type checkers to make serializers generic. - def __class_getitem__(cls, *args, **kwargs): - return cls - @classmethod def many_init(cls, *args, **kwargs): """ @@ -268,7 +275,7 @@ class BaseSerializer(Field): # Serializer & ListSerializer classes # ----------------------------------- -class SerializerMetaclass(type): +class SerializerMetaclass(GenericMeta): """ This metaclass sets a dictionary named `_declared_fields` on the class. @@ -301,9 +308,9 @@ class SerializerMetaclass(type): return OrderedDict(base_fields + fields) - def __new__(cls, name, bases, attrs): + def __new__(cls, name, bases, attrs, *args, **kwargs): attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) - return super().__new__(cls, name, bases, attrs) + return super().__new__(cls, name, bases, attrs, *args, **kwargs) def as_serializer_error(exc): @@ -332,7 +339,7 @@ def as_serializer_error(exc): } -class Serializer(BaseSerializer, metaclass=SerializerMetaclass): +class Serializer(BaseSerializer[_IN], metaclass=SerializerMetaclass): default_error_messages = { 'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.') } @@ -562,7 +569,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass): # There's some replication of `ListField` here, # but that's probably better than obfuscating the call hierarchy. -class ListSerializer(BaseSerializer): +class ListSerializer(BaseSerializer[_IN]): child = None many = True @@ -836,7 +843,10 @@ def raise_errors_on_nested_writes(method_name, serializer, validated_data): ) -class ModelSerializer(Serializer): +_MT = TypeVar("_MT", bound=models.Model) # Model Type + + +class ModelSerializer(Serializer[_MT]): """ A `ModelSerializer` is just a regular `Serializer`, except that: diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index 4f8a4f192..79154c19d 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -3,6 +3,7 @@ Helper functions for mapping model fields to a dictionary of default keyword arguments that should be used for their equivalent serializer fields. """ import inspect +from typing import Generic, TypeVar from django.core import validators from django.db import models @@ -16,7 +17,11 @@ NUMERIC_FIELD_TYPES = ( ) -class ClassLookupDict: +_K = TypeVar("_K", bound=type) +_V = TypeVar("_V") + + +class ClassLookupDict(Generic[_K, _V]): """ Takes a dictionary with classes as keys. Lookups against this object will traverses the object's inheritance diff --git a/tests/test_fields.py b/tests/test_fields.py index fdd570d8a..c8d2f87a2 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -14,7 +14,7 @@ from django.utils.timezone import activate, deactivate, override, utc import rest_framework from rest_framework import exceptions, serializers from rest_framework.fields import ( - BuiltinSignatureError, DjangoImageField, is_simple_callable + BuiltinSignatureError, DjangoImageField, Field, is_simple_callable ) # Tests for helper functions. @@ -2380,3 +2380,21 @@ class TestValidationErrorCode: ), ] } + + +class TestField: + def test_type_annotation(self): + assert Field[int, int, int, int] is not Field + + def test_multiple_type_params_needed_when_hinting_class(self): + with pytest.raises(TypeError): + Field[int] + + with pytest.raises(TypeError): + Field[int, int] + + with pytest.raises(TypeError): + Field[int, int, int] + + with pytest.raises(TypeError): + Field[int, int, int, int, int] diff --git a/tests/test_parsers.py b/tests/test_parsers.py index dcd62fac9..0e5a915f5 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -11,7 +11,7 @@ from django.test import TestCase from rest_framework.exceptions import ParseError from rest_framework.parsers import ( - FileUploadParser, FormParser, JSONParser, MultiPartParser + DataAndFiles, FileUploadParser, FormParser, JSONParser, MultiPartParser ) from rest_framework.request import Request from rest_framework.test import APIRequestFactory @@ -176,3 +176,19 @@ class TestPOSTAccessed(TestCase): with pytest.raises(RawPostDataException): request.POST request.data + + +class TestDataAndFiles: + def test_type_annotation(self): + """ + This class inherits directly from Generic, so adding type hints to it should + yield a different class. + """ + assert DataAndFiles[int, int] is not DataAndFiles + + def test_need_multiple_type_params_when_hinting_class(self): + with pytest.raises(TypeError): + DataAndFiles[int] + + with pytest.raises(TypeError): + DataAndFiles[int, int, int] diff --git a/tests/test_relations.py b/tests/test_relations.py index 92aeecf6c..f02b637eb 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -376,3 +376,18 @@ class TestHyperlink: upkled = pickle.loads(pickle.dumps(self.default_hyperlink)) assert upkled == self.default_hyperlink assert upkled.name == self.default_hyperlink.name + + +class TestRelatedField: + def test_type_annotation(self): + assert relations.RelatedField[int, int, int] is not relations.RelatedField + + def test_multiple_type_params_needed_when_hinting_class(self): + with pytest.raises(TypeError): + relations.RelatedField[int] + + with pytest.raises(TypeError): + relations.RelatedField[int, int] + + with pytest.raises(TypeError): + relations.RelatedField[int, int, int, int] diff --git a/tests/test_serializer.py b/tests/test_serializer.py index afefd70e1..cc40407e4 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -209,8 +209,19 @@ class TestSerializer: sys.version_info < (3, 7), reason="subscriptable classes requires Python 3.7 or higher", ) - def test_serializer_is_subscriptable(self): - assert serializers.Serializer is serializers.Serializer["foo"] + def test_type_annotation(self): + assert serializers.Serializer is not serializers.Serializer["foo"] + + @pytest.mark.skipif( + sys.version_info > (3, 5), + reason="generic meta class behaviour changed from 3.5 to 3.7", + ) + def test_type_annotation_pre_36(self): + """ + This class does NOT inherit directly from Generic, so adding type hints to it + should not yield a different class. + """ + assert serializers.Serializer[int] is not serializers.Serializer class TestValidateMethod: @@ -322,6 +333,28 @@ class TestBaseSerializer: {'id': 2, 'name': 'ann', 'domain': 'example.com'} ] + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="generic meta class behaviour changed from 3.5 to 3.7", + ) + def test_type_annotation(self): + """ + This class does NOT inherit directly from Generic, so adding type hints to it + should not yield a different class. + """ + assert serializers.BaseSerializer[int] is not serializers.BaseSerializer + + @pytest.mark.skipif( + sys.version_info > (3, 5), + reason="generic meta class behaviour changed from 3.5 to 3.7", + ) + def test_type_annotation_pre_36(self): + """ + This class does NOT inherit directly from Generic, so adding type hints to it + should not yield a different class. + """ + assert serializers.BaseSerializer[int] is not serializers.BaseSerializer + class TestStarredSource: """ @@ -740,3 +773,27 @@ class TestDeclaredFieldInheritance: 'f4': serializers.CharField, 'f5': serializers.CharField, } + + +class TestModelSerializer: + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="generic meta class behaviour changed from 3.5 to 3.7", + ) + def test_type_annotation(self): + """ + This class does NOT inherit directly from Generic, so adding type hints to it + should not yield a different class. + """ + assert serializers.ModelSerializer[int] is not serializers.ModelSerializer + + @pytest.mark.skipif( + sys.version_info > (3, 5), + reason="generic meta class behaviour changed from 3.5 to 3.7", + ) + def test_type_annotation_pre_36(self): + """ + This class does NOT inherit directly from Generic, so adding type hints to it + should not yield a different class. + """ + assert serializers.ModelSerializer[int] is not serializers.ModelSerializer diff --git a/tests/test_serializer_lists.py b/tests/test_serializer_lists.py index f35c4fcc9..5ad7eeb28 100644 --- a/tests/test_serializer_lists.py +++ b/tests/test_serializer_lists.py @@ -62,7 +62,18 @@ class TestListSerializer: reason="subscriptable classes requires Python 3.7 or higher", ) def test_list_serializer_is_subscriptable(self): - assert serializers.ListSerializer is serializers.ListSerializer["foo"] + assert serializers.ListSerializer is not serializers.ListSerializer["foo"] + + @pytest.mark.skipif( + sys.version_info > (3, 5), + reason="generic meta class behaviour changed from 3.5 to 3.7", + ) + def test_type_annotation_pre_36(self): + """ + This class does NOT inherit directly from Generic, so adding type hints to it + should not yield a different class. + """ + assert serializers.ListSerializer[int] is not serializers.ListSerializer class TestListSerializerContainingNestedSerializer: diff --git a/tests/test_utils.py b/tests/test_utils.py index c72f680fe..a8371d8b6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,6 @@ from unittest import mock +import pytest from django.test import TestCase, override_settings from django.urls import path @@ -8,6 +9,7 @@ from rest_framework.routers import SimpleRouter from rest_framework.serializers import ModelSerializer from rest_framework.utils import json from rest_framework.utils.breadcrumbs import get_breadcrumbs +from rest_framework.utils.field_mapping import ClassLookupDict from rest_framework.utils.formatting import lazy_format from rest_framework.utils.urls import remove_query_param, replace_query_param from rest_framework.views import APIView @@ -267,3 +269,15 @@ class LazyFormatTests(TestCase): assert message.format.call_count == 1 str(formatted) assert message.format.call_count == 1 + + +class ClassLookupDictTests(TestCase): + def test_type_annotation(self): + assert ClassLookupDict[int, int] is not ClassLookupDict + + def test_need_multiple_type_params_when_hinting_class(self): + with pytest.raises(TypeError): + ClassLookupDict[None] + + with pytest.raises(TypeError): + ClassLookupDict[None, None, None]