Adding Generic support for various classes

Specifically, have the following classes inherit from Generic:

* ClassLookupDict
* DataAndFiles
* BaseSerializer
* Field
* RelatedField

Addresses: https://github.com/encode/django-rest-framework/issues/7624
This commit is contained in:
Brady Kieffer 2020-11-03 15:51:31 -05:00
parent 56e4508123
commit 853c83c7eb
11 changed files with 182 additions and 21 deletions

View File

@ -8,6 +8,7 @@ import uuid
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Mapping from collections.abc import Mapping
from typing import Generic, TypeVar
from django.conf import settings from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
@ -308,8 +309,13 @@ MISSING_ERROR_MESSAGE = (
'not exist in the `error_messages` dictionary.' 'not exist in the `error_messages` dictionary.'
) )
_IN = TypeVar("_IN") # Instance Type
_VT = TypeVar("_VT") # Value Type
_DT = TypeVar("_DT") # Data Type
_RP = TypeVar("_RP") # Representation Type
class Field:
class Field(Generic[_VT, _DT, _RP, _IN]):
_creation_counter = 0 _creation_counter = 0
default_error_messages = { default_error_messages = {

View File

@ -5,6 +5,7 @@ They give us a generic way of being able to handle various media types
on the request, such as form content or json encoded data. on the request, such as form content or json encoded data.
""" """
import codecs import codecs
from typing import Generic, TypeVar
from urllib import parse from urllib import parse
from django.conf import settings from django.conf import settings
@ -21,8 +22,11 @@ from rest_framework.exceptions import ParseError
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import json from rest_framework.utils import json
_Data = TypeVar("_Data")
_Files = TypeVar("_Files")
class DataAndFiles:
class DataAndFiles(Generic[_Data, _Files]):
def __init__(self, data, files): def __init__(self, data, files):
self.data = data self.data = data
self.files = files self.files = files

View File

@ -1,9 +1,10 @@
import sys import sys
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Generic, TypeVar
from urllib import parse from urllib import parse
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
from django.db.models import Manager from django.db.models import Manager, Model
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve
from django.utils.encoding import smart_str, uri_to_iri from django.utils.encoding import smart_str, uri_to_iri
@ -85,8 +86,12 @@ MANY_RELATION_KWARGS = (
'html_cutoff', 'html_cutoff_text' 'html_cutoff', 'html_cutoff_text'
) )
_MT = TypeVar("_MT", bound=Model)
_DT = TypeVar("_DT") # Data Type
_PT = TypeVar("_PT") # Primitive Type
class RelatedField(Field):
class RelatedField(Generic[_MT, _DT, _PT], Field[_MT, _DT, _PT, Any]):
queryset = None queryset = None
html_cutoff = None html_cutoff = None
html_cutoff_text = None html_cutoff_text = None

View File

@ -12,9 +12,11 @@ response content is handled by parsers and renderers.
""" """
import copy import copy
import inspect import inspect
import sys
import traceback import traceback
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Generic, TypeVar
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
from django.core.exceptions import ValidationError as DjangoValidationError from django.core.exceptions import ValidationError as DjangoValidationError
@ -66,6 +68,13 @@ from rest_framework.fields import ( # NOQA # isort:skip
) )
from rest_framework.relations import Hyperlink, PKOnlyObject # NOQA # isort:skip from rest_framework.relations import Hyperlink, PKOnlyObject # NOQA # isort:skip
if sys.version_info < (3, 7):
from typing import GenericMeta
else:
class GenericMeta(type):
pass
# We assume that 'validators' are intended for the child serializer, # We assume that 'validators' are intended for the child serializer,
# rather than the parent serializer. # rather than the parent serializer.
LIST_SERIALIZER_KWARGS = ( LIST_SERIALIZER_KWARGS = (
@ -79,8 +88,10 @@ ALL_FIELDS = '__all__'
# BaseSerializer # BaseSerializer
# -------------- # --------------
_IN = TypeVar("_IN") # Instance Type
class BaseSerializer(Field):
class BaseSerializer(Generic[_IN], Field[Any, Any, Any, _IN]):
""" """
The BaseSerializer class provides a minimal class which may be used The BaseSerializer class provides a minimal class which may be used
for writing custom serializer implementations. for writing custom serializer implementations.
@ -121,10 +132,6 @@ class BaseSerializer(Field):
return cls.many_init(*args, **kwargs) return cls.many_init(*args, **kwargs)
return super().__new__(cls, *args, **kwargs) return super().__new__(cls, *args, **kwargs)
# Allow type checkers to make serializers generic.
def __class_getitem__(cls, *args, **kwargs):
return cls
@classmethod @classmethod
def many_init(cls, *args, **kwargs): def many_init(cls, *args, **kwargs):
""" """
@ -268,7 +275,7 @@ class BaseSerializer(Field):
# Serializer & ListSerializer classes # Serializer & ListSerializer classes
# ----------------------------------- # -----------------------------------
class SerializerMetaclass(type): class SerializerMetaclass(GenericMeta):
""" """
This metaclass sets a dictionary named `_declared_fields` on the class. This metaclass sets a dictionary named `_declared_fields` on the class.
@ -301,9 +308,9 @@ class SerializerMetaclass(type):
return OrderedDict(base_fields + fields) return OrderedDict(base_fields + fields)
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs, *args, **kwargs):
attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
return super().__new__(cls, name, bases, attrs) return super().__new__(cls, name, bases, attrs, *args, **kwargs)
def as_serializer_error(exc): def as_serializer_error(exc):
@ -332,7 +339,7 @@ def as_serializer_error(exc):
} }
class Serializer(BaseSerializer, metaclass=SerializerMetaclass): class Serializer(BaseSerializer[_IN], metaclass=SerializerMetaclass):
default_error_messages = { default_error_messages = {
'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.') 'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.')
} }
@ -562,7 +569,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
# There's some replication of `ListField` here, # There's some replication of `ListField` here,
# but that's probably better than obfuscating the call hierarchy. # but that's probably better than obfuscating the call hierarchy.
class ListSerializer(BaseSerializer): class ListSerializer(BaseSerializer[_IN]):
child = None child = None
many = True many = True
@ -836,7 +843,10 @@ def raise_errors_on_nested_writes(method_name, serializer, validated_data):
) )
class ModelSerializer(Serializer): _MT = TypeVar("_MT", bound=models.Model) # Model Type
class ModelSerializer(Serializer[_MT]):
""" """
A `ModelSerializer` is just a regular `Serializer`, except that: A `ModelSerializer` is just a regular `Serializer`, except that:

View File

@ -3,6 +3,7 @@ Helper functions for mapping model fields to a dictionary of default
keyword arguments that should be used for their equivalent serializer fields. keyword arguments that should be used for their equivalent serializer fields.
""" """
import inspect import inspect
from typing import Generic, TypeVar
from django.core import validators from django.core import validators
from django.db import models from django.db import models
@ -16,7 +17,11 @@ NUMERIC_FIELD_TYPES = (
) )
class ClassLookupDict: _K = TypeVar("_K", bound=type)
_V = TypeVar("_V")
class ClassLookupDict(Generic[_K, _V]):
""" """
Takes a dictionary with classes as keys. Takes a dictionary with classes as keys.
Lookups against this object will traverses the object's inheritance Lookups against this object will traverses the object's inheritance

View File

@ -14,7 +14,7 @@ from django.utils.timezone import activate, deactivate, override, utc
import rest_framework import rest_framework
from rest_framework import exceptions, serializers from rest_framework import exceptions, serializers
from rest_framework.fields import ( from rest_framework.fields import (
BuiltinSignatureError, DjangoImageField, is_simple_callable BuiltinSignatureError, DjangoImageField, Field, is_simple_callable
) )
# Tests for helper functions. # Tests for helper functions.
@ -2380,3 +2380,21 @@ class TestValidationErrorCode:
), ),
] ]
} }
class TestField:
def test_type_annotation(self):
assert Field[int, int, int, int] is not Field
def test_multiple_type_params_needed_when_hinting_class(self):
with pytest.raises(TypeError):
Field[int]
with pytest.raises(TypeError):
Field[int, int]
with pytest.raises(TypeError):
Field[int, int, int]
with pytest.raises(TypeError):
Field[int, int, int, int, int]

View File

@ -11,7 +11,7 @@ from django.test import TestCase
from rest_framework.exceptions import ParseError from rest_framework.exceptions import ParseError
from rest_framework.parsers import ( from rest_framework.parsers import (
FileUploadParser, FormParser, JSONParser, MultiPartParser DataAndFiles, FileUploadParser, FormParser, JSONParser, MultiPartParser
) )
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
@ -176,3 +176,19 @@ class TestPOSTAccessed(TestCase):
with pytest.raises(RawPostDataException): with pytest.raises(RawPostDataException):
request.POST request.POST
request.data request.data
class TestDataAndFiles:
def test_type_annotation(self):
"""
This class inherits directly from Generic, so adding type hints to it should
yield a different class.
"""
assert DataAndFiles[int, int] is not DataAndFiles
def test_need_multiple_type_params_when_hinting_class(self):
with pytest.raises(TypeError):
DataAndFiles[int]
with pytest.raises(TypeError):
DataAndFiles[int, int, int]

View File

@ -376,3 +376,18 @@ class TestHyperlink:
upkled = pickle.loads(pickle.dumps(self.default_hyperlink)) upkled = pickle.loads(pickle.dumps(self.default_hyperlink))
assert upkled == self.default_hyperlink assert upkled == self.default_hyperlink
assert upkled.name == self.default_hyperlink.name assert upkled.name == self.default_hyperlink.name
class TestRelatedField:
def test_type_annotation(self):
assert relations.RelatedField[int, int, int] is not relations.RelatedField
def test_multiple_type_params_needed_when_hinting_class(self):
with pytest.raises(TypeError):
relations.RelatedField[int]
with pytest.raises(TypeError):
relations.RelatedField[int, int]
with pytest.raises(TypeError):
relations.RelatedField[int, int, int, int]

View File

@ -209,8 +209,19 @@ class TestSerializer:
sys.version_info < (3, 7), sys.version_info < (3, 7),
reason="subscriptable classes requires Python 3.7 or higher", reason="subscriptable classes requires Python 3.7 or higher",
) )
def test_serializer_is_subscriptable(self): def test_type_annotation(self):
assert serializers.Serializer is serializers.Serializer["foo"] assert serializers.Serializer is not serializers.Serializer["foo"]
@pytest.mark.skipif(
sys.version_info > (3, 5),
reason="generic meta class behaviour changed from 3.5 to 3.7",
)
def test_type_annotation_pre_36(self):
"""
This class does NOT inherit directly from Generic, so adding type hints to it
should not yield a different class.
"""
assert serializers.Serializer[int] is not serializers.Serializer
class TestValidateMethod: class TestValidateMethod:
@ -322,6 +333,28 @@ class TestBaseSerializer:
{'id': 2, 'name': 'ann', 'domain': 'example.com'} {'id': 2, 'name': 'ann', 'domain': 'example.com'}
] ]
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="generic meta class behaviour changed from 3.5 to 3.7",
)
def test_type_annotation(self):
"""
This class does NOT inherit directly from Generic, so adding type hints to it
should not yield a different class.
"""
assert serializers.BaseSerializer[int] is not serializers.BaseSerializer
@pytest.mark.skipif(
sys.version_info > (3, 5),
reason="generic meta class behaviour changed from 3.5 to 3.7",
)
def test_type_annotation_pre_36(self):
"""
This class does NOT inherit directly from Generic, so adding type hints to it
should not yield a different class.
"""
assert serializers.BaseSerializer[int] is not serializers.BaseSerializer
class TestStarredSource: class TestStarredSource:
""" """
@ -740,3 +773,27 @@ class TestDeclaredFieldInheritance:
'f4': serializers.CharField, 'f4': serializers.CharField,
'f5': serializers.CharField, 'f5': serializers.CharField,
} }
class TestModelSerializer:
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="generic meta class behaviour changed from 3.5 to 3.7",
)
def test_type_annotation(self):
"""
This class does NOT inherit directly from Generic, so adding type hints to it
should not yield a different class.
"""
assert serializers.ModelSerializer[int] is not serializers.ModelSerializer
@pytest.mark.skipif(
sys.version_info > (3, 5),
reason="generic meta class behaviour changed from 3.5 to 3.7",
)
def test_type_annotation_pre_36(self):
"""
This class does NOT inherit directly from Generic, so adding type hints to it
should not yield a different class.
"""
assert serializers.ModelSerializer[int] is not serializers.ModelSerializer

View File

@ -62,7 +62,18 @@ class TestListSerializer:
reason="subscriptable classes requires Python 3.7 or higher", reason="subscriptable classes requires Python 3.7 or higher",
) )
def test_list_serializer_is_subscriptable(self): def test_list_serializer_is_subscriptable(self):
assert serializers.ListSerializer is serializers.ListSerializer["foo"] assert serializers.ListSerializer is not serializers.ListSerializer["foo"]
@pytest.mark.skipif(
sys.version_info > (3, 5),
reason="generic meta class behaviour changed from 3.5 to 3.7",
)
def test_type_annotation_pre_36(self):
"""
This class does NOT inherit directly from Generic, so adding type hints to it
should not yield a different class.
"""
assert serializers.ListSerializer[int] is not serializers.ListSerializer
class TestListSerializerContainingNestedSerializer: class TestListSerializerContainingNestedSerializer:

View File

@ -1,5 +1,6 @@
from unittest import mock from unittest import mock
import pytest
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from django.urls import path from django.urls import path
@ -8,6 +9,7 @@ from rest_framework.routers import SimpleRouter
from rest_framework.serializers import ModelSerializer from rest_framework.serializers import ModelSerializer
from rest_framework.utils import json from rest_framework.utils import json
from rest_framework.utils.breadcrumbs import get_breadcrumbs from rest_framework.utils.breadcrumbs import get_breadcrumbs
from rest_framework.utils.field_mapping import ClassLookupDict
from rest_framework.utils.formatting import lazy_format from rest_framework.utils.formatting import lazy_format
from rest_framework.utils.urls import remove_query_param, replace_query_param from rest_framework.utils.urls import remove_query_param, replace_query_param
from rest_framework.views import APIView from rest_framework.views import APIView
@ -267,3 +269,15 @@ class LazyFormatTests(TestCase):
assert message.format.call_count == 1 assert message.format.call_count == 1
str(formatted) str(formatted)
assert message.format.call_count == 1 assert message.format.call_count == 1
class ClassLookupDictTests(TestCase):
def test_type_annotation(self):
assert ClassLookupDict[int, int] is not ClassLookupDict
def test_need_multiple_type_params_when_hinting_class(self):
with pytest.raises(TypeError):
ClassLookupDict[None]
with pytest.raises(TypeError):
ClassLookupDict[None, None, None]