mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-29 17:39:48 +03:00
Adding Generic support for various classes
Specifically, have the following classes inherit from Generic: * ClassLookupDict * DataAndFiles * BaseSerializer * Field * RelatedField Addresses: https://github.com/encode/django-rest-framework/issues/7624
This commit is contained in:
parent
56e4508123
commit
853c83c7eb
|
@ -8,6 +8,7 @@ import uuid
|
|||
import warnings
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Mapping
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
|
@ -308,8 +309,13 @@ MISSING_ERROR_MESSAGE = (
|
|||
'not exist in the `error_messages` dictionary.'
|
||||
)
|
||||
|
||||
_IN = TypeVar("_IN") # Instance Type
|
||||
_VT = TypeVar("_VT") # Value Type
|
||||
_DT = TypeVar("_DT") # Data Type
|
||||
_RP = TypeVar("_RP") # Representation Type
|
||||
|
||||
class Field:
|
||||
|
||||
class Field(Generic[_VT, _DT, _RP, _IN]):
|
||||
_creation_counter = 0
|
||||
|
||||
default_error_messages = {
|
||||
|
|
|
@ -5,6 +5,7 @@ They give us a generic way of being able to handle various media types
|
|||
on the request, such as form content or json encoded data.
|
||||
"""
|
||||
import codecs
|
||||
from typing import Generic, TypeVar
|
||||
from urllib import parse
|
||||
|
||||
from django.conf import settings
|
||||
|
@ -21,8 +22,11 @@ from rest_framework.exceptions import ParseError
|
|||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import json
|
||||
|
||||
_Data = TypeVar("_Data")
|
||||
_Files = TypeVar("_Files")
|
||||
|
||||
class DataAndFiles:
|
||||
|
||||
class DataAndFiles(Generic[_Data, _Files]):
|
||||
def __init__(self, data, files):
|
||||
self.data = data
|
||||
self.files = files
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import sys
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Generic, TypeVar
|
||||
from urllib import parse
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
|
||||
from django.db.models import Manager
|
||||
from django.db.models import Manager, Model
|
||||
from django.db.models.query import QuerySet
|
||||
from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve
|
||||
from django.utils.encoding import smart_str, uri_to_iri
|
||||
|
@ -85,8 +86,12 @@ MANY_RELATION_KWARGS = (
|
|||
'html_cutoff', 'html_cutoff_text'
|
||||
)
|
||||
|
||||
_MT = TypeVar("_MT", bound=Model)
|
||||
_DT = TypeVar("_DT") # Data Type
|
||||
_PT = TypeVar("_PT") # Primitive Type
|
||||
|
||||
class RelatedField(Field):
|
||||
|
||||
class RelatedField(Generic[_MT, _DT, _PT], Field[_MT, _DT, _PT, Any]):
|
||||
queryset = None
|
||||
html_cutoff = None
|
||||
html_cutoff_text = None
|
||||
|
|
|
@ -12,9 +12,11 @@ response content is handled by parsers and renderers.
|
|||
"""
|
||||
import copy
|
||||
import inspect
|
||||
import sys
|
||||
import traceback
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
|
||||
from django.core.exceptions import ValidationError as DjangoValidationError
|
||||
|
@ -66,6 +68,13 @@ from rest_framework.fields import ( # NOQA # isort:skip
|
|||
)
|
||||
from rest_framework.relations import Hyperlink, PKOnlyObject # NOQA # isort:skip
|
||||
|
||||
if sys.version_info < (3, 7):
|
||||
from typing import GenericMeta
|
||||
else:
|
||||
class GenericMeta(type):
|
||||
pass
|
||||
|
||||
|
||||
# We assume that 'validators' are intended for the child serializer,
|
||||
# rather than the parent serializer.
|
||||
LIST_SERIALIZER_KWARGS = (
|
||||
|
@ -79,8 +88,10 @@ ALL_FIELDS = '__all__'
|
|||
|
||||
# BaseSerializer
|
||||
# --------------
|
||||
_IN = TypeVar("_IN") # Instance Type
|
||||
|
||||
class BaseSerializer(Field):
|
||||
|
||||
class BaseSerializer(Generic[_IN], Field[Any, Any, Any, _IN]):
|
||||
"""
|
||||
The BaseSerializer class provides a minimal class which may be used
|
||||
for writing custom serializer implementations.
|
||||
|
@ -121,10 +132,6 @@ class BaseSerializer(Field):
|
|||
return cls.many_init(*args, **kwargs)
|
||||
return super().__new__(cls, *args, **kwargs)
|
||||
|
||||
# Allow type checkers to make serializers generic.
|
||||
def __class_getitem__(cls, *args, **kwargs):
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def many_init(cls, *args, **kwargs):
|
||||
"""
|
||||
|
@ -268,7 +275,7 @@ class BaseSerializer(Field):
|
|||
# Serializer & ListSerializer classes
|
||||
# -----------------------------------
|
||||
|
||||
class SerializerMetaclass(type):
|
||||
class SerializerMetaclass(GenericMeta):
|
||||
"""
|
||||
This metaclass sets a dictionary named `_declared_fields` on the class.
|
||||
|
||||
|
@ -301,9 +308,9 @@ class SerializerMetaclass(type):
|
|||
|
||||
return OrderedDict(base_fields + fields)
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
def __new__(cls, name, bases, attrs, *args, **kwargs):
|
||||
attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
|
||||
return super().__new__(cls, name, bases, attrs)
|
||||
return super().__new__(cls, name, bases, attrs, *args, **kwargs)
|
||||
|
||||
|
||||
def as_serializer_error(exc):
|
||||
|
@ -332,7 +339,7 @@ def as_serializer_error(exc):
|
|||
}
|
||||
|
||||
|
||||
class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
|
||||
class Serializer(BaseSerializer[_IN], metaclass=SerializerMetaclass):
|
||||
default_error_messages = {
|
||||
'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.')
|
||||
}
|
||||
|
@ -562,7 +569,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
|
|||
# There's some replication of `ListField` here,
|
||||
# but that's probably better than obfuscating the call hierarchy.
|
||||
|
||||
class ListSerializer(BaseSerializer):
|
||||
class ListSerializer(BaseSerializer[_IN]):
|
||||
child = None
|
||||
many = True
|
||||
|
||||
|
@ -836,7 +843,10 @@ def raise_errors_on_nested_writes(method_name, serializer, validated_data):
|
|||
)
|
||||
|
||||
|
||||
class ModelSerializer(Serializer):
|
||||
_MT = TypeVar("_MT", bound=models.Model) # Model Type
|
||||
|
||||
|
||||
class ModelSerializer(Serializer[_MT]):
|
||||
"""
|
||||
A `ModelSerializer` is just a regular `Serializer`, except that:
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ Helper functions for mapping model fields to a dictionary of default
|
|||
keyword arguments that should be used for their equivalent serializer fields.
|
||||
"""
|
||||
import inspect
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from django.core import validators
|
||||
from django.db import models
|
||||
|
@ -16,7 +17,11 @@ NUMERIC_FIELD_TYPES = (
|
|||
)
|
||||
|
||||
|
||||
class ClassLookupDict:
|
||||
_K = TypeVar("_K", bound=type)
|
||||
_V = TypeVar("_V")
|
||||
|
||||
|
||||
class ClassLookupDict(Generic[_K, _V]):
|
||||
"""
|
||||
Takes a dictionary with classes as keys.
|
||||
Lookups against this object will traverses the object's inheritance
|
||||
|
|
|
@ -14,7 +14,7 @@ from django.utils.timezone import activate, deactivate, override, utc
|
|||
import rest_framework
|
||||
from rest_framework import exceptions, serializers
|
||||
from rest_framework.fields import (
|
||||
BuiltinSignatureError, DjangoImageField, is_simple_callable
|
||||
BuiltinSignatureError, DjangoImageField, Field, is_simple_callable
|
||||
)
|
||||
|
||||
# Tests for helper functions.
|
||||
|
@ -2380,3 +2380,21 @@ class TestValidationErrorCode:
|
|||
),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class TestField:
|
||||
def test_type_annotation(self):
|
||||
assert Field[int, int, int, int] is not Field
|
||||
|
||||
def test_multiple_type_params_needed_when_hinting_class(self):
|
||||
with pytest.raises(TypeError):
|
||||
Field[int]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
Field[int, int]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
Field[int, int, int]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
Field[int, int, int, int, int]
|
||||
|
|
|
@ -11,7 +11,7 @@ from django.test import TestCase
|
|||
|
||||
from rest_framework.exceptions import ParseError
|
||||
from rest_framework.parsers import (
|
||||
FileUploadParser, FormParser, JSONParser, MultiPartParser
|
||||
DataAndFiles, FileUploadParser, FormParser, JSONParser, MultiPartParser
|
||||
)
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.test import APIRequestFactory
|
||||
|
@ -176,3 +176,19 @@ class TestPOSTAccessed(TestCase):
|
|||
with pytest.raises(RawPostDataException):
|
||||
request.POST
|
||||
request.data
|
||||
|
||||
|
||||
class TestDataAndFiles:
|
||||
def test_type_annotation(self):
|
||||
"""
|
||||
This class inherits directly from Generic, so adding type hints to it should
|
||||
yield a different class.
|
||||
"""
|
||||
assert DataAndFiles[int, int] is not DataAndFiles
|
||||
|
||||
def test_need_multiple_type_params_when_hinting_class(self):
|
||||
with pytest.raises(TypeError):
|
||||
DataAndFiles[int]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
DataAndFiles[int, int, int]
|
||||
|
|
|
@ -376,3 +376,18 @@ class TestHyperlink:
|
|||
upkled = pickle.loads(pickle.dumps(self.default_hyperlink))
|
||||
assert upkled == self.default_hyperlink
|
||||
assert upkled.name == self.default_hyperlink.name
|
||||
|
||||
|
||||
class TestRelatedField:
|
||||
def test_type_annotation(self):
|
||||
assert relations.RelatedField[int, int, int] is not relations.RelatedField
|
||||
|
||||
def test_multiple_type_params_needed_when_hinting_class(self):
|
||||
with pytest.raises(TypeError):
|
||||
relations.RelatedField[int]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
relations.RelatedField[int, int]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
relations.RelatedField[int, int, int, int]
|
||||
|
|
|
@ -209,8 +209,19 @@ class TestSerializer:
|
|||
sys.version_info < (3, 7),
|
||||
reason="subscriptable classes requires Python 3.7 or higher",
|
||||
)
|
||||
def test_serializer_is_subscriptable(self):
|
||||
assert serializers.Serializer is serializers.Serializer["foo"]
|
||||
def test_type_annotation(self):
|
||||
assert serializers.Serializer is not serializers.Serializer["foo"]
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info > (3, 5),
|
||||
reason="generic meta class behaviour changed from 3.5 to 3.7",
|
||||
)
|
||||
def test_type_annotation_pre_36(self):
|
||||
"""
|
||||
This class does NOT inherit directly from Generic, so adding type hints to it
|
||||
should not yield a different class.
|
||||
"""
|
||||
assert serializers.Serializer[int] is not serializers.Serializer
|
||||
|
||||
|
||||
class TestValidateMethod:
|
||||
|
@ -322,6 +333,28 @@ class TestBaseSerializer:
|
|||
{'id': 2, 'name': 'ann', 'domain': 'example.com'}
|
||||
]
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 7),
|
||||
reason="generic meta class behaviour changed from 3.5 to 3.7",
|
||||
)
|
||||
def test_type_annotation(self):
|
||||
"""
|
||||
This class does NOT inherit directly from Generic, so adding type hints to it
|
||||
should not yield a different class.
|
||||
"""
|
||||
assert serializers.BaseSerializer[int] is not serializers.BaseSerializer
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info > (3, 5),
|
||||
reason="generic meta class behaviour changed from 3.5 to 3.7",
|
||||
)
|
||||
def test_type_annotation_pre_36(self):
|
||||
"""
|
||||
This class does NOT inherit directly from Generic, so adding type hints to it
|
||||
should not yield a different class.
|
||||
"""
|
||||
assert serializers.BaseSerializer[int] is not serializers.BaseSerializer
|
||||
|
||||
|
||||
class TestStarredSource:
|
||||
"""
|
||||
|
@ -740,3 +773,27 @@ class TestDeclaredFieldInheritance:
|
|||
'f4': serializers.CharField,
|
||||
'f5': serializers.CharField,
|
||||
}
|
||||
|
||||
|
||||
class TestModelSerializer:
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 7),
|
||||
reason="generic meta class behaviour changed from 3.5 to 3.7",
|
||||
)
|
||||
def test_type_annotation(self):
|
||||
"""
|
||||
This class does NOT inherit directly from Generic, so adding type hints to it
|
||||
should not yield a different class.
|
||||
"""
|
||||
assert serializers.ModelSerializer[int] is not serializers.ModelSerializer
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info > (3, 5),
|
||||
reason="generic meta class behaviour changed from 3.5 to 3.7",
|
||||
)
|
||||
def test_type_annotation_pre_36(self):
|
||||
"""
|
||||
This class does NOT inherit directly from Generic, so adding type hints to it
|
||||
should not yield a different class.
|
||||
"""
|
||||
assert serializers.ModelSerializer[int] is not serializers.ModelSerializer
|
||||
|
|
|
@ -62,7 +62,18 @@ class TestListSerializer:
|
|||
reason="subscriptable classes requires Python 3.7 or higher",
|
||||
)
|
||||
def test_list_serializer_is_subscriptable(self):
|
||||
assert serializers.ListSerializer is serializers.ListSerializer["foo"]
|
||||
assert serializers.ListSerializer is not serializers.ListSerializer["foo"]
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info > (3, 5),
|
||||
reason="generic meta class behaviour changed from 3.5 to 3.7",
|
||||
)
|
||||
def test_type_annotation_pre_36(self):
|
||||
"""
|
||||
This class does NOT inherit directly from Generic, so adding type hints to it
|
||||
should not yield a different class.
|
||||
"""
|
||||
assert serializers.ListSerializer[int] is not serializers.ListSerializer
|
||||
|
||||
|
||||
class TestListSerializerContainingNestedSerializer:
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from django.test import TestCase, override_settings
|
||||
from django.urls import path
|
||||
|
||||
|
@ -8,6 +9,7 @@ from rest_framework.routers import SimpleRouter
|
|||
from rest_framework.serializers import ModelSerializer
|
||||
from rest_framework.utils import json
|
||||
from rest_framework.utils.breadcrumbs import get_breadcrumbs
|
||||
from rest_framework.utils.field_mapping import ClassLookupDict
|
||||
from rest_framework.utils.formatting import lazy_format
|
||||
from rest_framework.utils.urls import remove_query_param, replace_query_param
|
||||
from rest_framework.views import APIView
|
||||
|
@ -267,3 +269,15 @@ class LazyFormatTests(TestCase):
|
|||
assert message.format.call_count == 1
|
||||
str(formatted)
|
||||
assert message.format.call_count == 1
|
||||
|
||||
|
||||
class ClassLookupDictTests(TestCase):
|
||||
def test_type_annotation(self):
|
||||
assert ClassLookupDict[int, int] is not ClassLookupDict
|
||||
|
||||
def test_need_multiple_type_params_when_hinting_class(self):
|
||||
with pytest.raises(TypeError):
|
||||
ClassLookupDict[None]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
ClassLookupDict[None, None, None]
|
||||
|
|
Loading…
Reference in New Issue
Block a user