mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-29 17:39:48 +03:00
Adding Generic support for various classes
Specifically, have the following classes inherit from Generic: * ClassLookupDict * DataAndFiles * BaseSerializer * Field * RelatedField Addresses: https://github.com/encode/django-rest-framework/issues/7624
This commit is contained in:
parent
56e4508123
commit
853c83c7eb
|
@ -8,6 +8,7 @@ import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.exceptions import ObjectDoesNotExist
|
from django.core.exceptions import ObjectDoesNotExist
|
||||||
|
@ -308,8 +309,13 @@ MISSING_ERROR_MESSAGE = (
|
||||||
'not exist in the `error_messages` dictionary.'
|
'not exist in the `error_messages` dictionary.'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_IN = TypeVar("_IN") # Instance Type
|
||||||
|
_VT = TypeVar("_VT") # Value Type
|
||||||
|
_DT = TypeVar("_DT") # Data Type
|
||||||
|
_RP = TypeVar("_RP") # Representation Type
|
||||||
|
|
||||||
class Field:
|
|
||||||
|
class Field(Generic[_VT, _DT, _RP, _IN]):
|
||||||
_creation_counter = 0
|
_creation_counter = 0
|
||||||
|
|
||||||
default_error_messages = {
|
default_error_messages = {
|
||||||
|
|
|
@ -5,6 +5,7 @@ They give us a generic way of being able to handle various media types
|
||||||
on the request, such as form content or json encoded data.
|
on the request, such as form content or json encoded data.
|
||||||
"""
|
"""
|
||||||
import codecs
|
import codecs
|
||||||
|
from typing import Generic, TypeVar
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
@ -21,8 +22,11 @@ from rest_framework.exceptions import ParseError
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
from rest_framework.utils import json
|
from rest_framework.utils import json
|
||||||
|
|
||||||
|
_Data = TypeVar("_Data")
|
||||||
|
_Files = TypeVar("_Files")
|
||||||
|
|
||||||
class DataAndFiles:
|
|
||||||
|
class DataAndFiles(Generic[_Data, _Files]):
|
||||||
def __init__(self, data, files):
|
def __init__(self, data, files):
|
||||||
self.data = data
|
self.data = data
|
||||||
self.files = files
|
self.files = files
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import sys
|
import sys
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Generic, TypeVar
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
|
|
||||||
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
|
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
|
||||||
from django.db.models import Manager
|
from django.db.models import Manager, Model
|
||||||
from django.db.models.query import QuerySet
|
from django.db.models.query import QuerySet
|
||||||
from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve
|
from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve
|
||||||
from django.utils.encoding import smart_str, uri_to_iri
|
from django.utils.encoding import smart_str, uri_to_iri
|
||||||
|
@ -85,8 +86,12 @@ MANY_RELATION_KWARGS = (
|
||||||
'html_cutoff', 'html_cutoff_text'
|
'html_cutoff', 'html_cutoff_text'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_MT = TypeVar("_MT", bound=Model)
|
||||||
|
_DT = TypeVar("_DT") # Data Type
|
||||||
|
_PT = TypeVar("_PT") # Primitive Type
|
||||||
|
|
||||||
class RelatedField(Field):
|
|
||||||
|
class RelatedField(Generic[_MT, _DT, _PT], Field[_MT, _DT, _PT, Any]):
|
||||||
queryset = None
|
queryset = None
|
||||||
html_cutoff = None
|
html_cutoff = None
|
||||||
html_cutoff_text = None
|
html_cutoff_text = None
|
||||||
|
|
|
@ -12,9 +12,11 @@ response content is handled by parsers and renderers.
|
||||||
"""
|
"""
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
|
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
|
||||||
from django.core.exceptions import ValidationError as DjangoValidationError
|
from django.core.exceptions import ValidationError as DjangoValidationError
|
||||||
|
@ -66,6 +68,13 @@ from rest_framework.fields import ( # NOQA # isort:skip
|
||||||
)
|
)
|
||||||
from rest_framework.relations import Hyperlink, PKOnlyObject # NOQA # isort:skip
|
from rest_framework.relations import Hyperlink, PKOnlyObject # NOQA # isort:skip
|
||||||
|
|
||||||
|
if sys.version_info < (3, 7):
|
||||||
|
from typing import GenericMeta
|
||||||
|
else:
|
||||||
|
class GenericMeta(type):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# We assume that 'validators' are intended for the child serializer,
|
# We assume that 'validators' are intended for the child serializer,
|
||||||
# rather than the parent serializer.
|
# rather than the parent serializer.
|
||||||
LIST_SERIALIZER_KWARGS = (
|
LIST_SERIALIZER_KWARGS = (
|
||||||
|
@ -79,8 +88,10 @@ ALL_FIELDS = '__all__'
|
||||||
|
|
||||||
# BaseSerializer
|
# BaseSerializer
|
||||||
# --------------
|
# --------------
|
||||||
|
_IN = TypeVar("_IN") # Instance Type
|
||||||
|
|
||||||
class BaseSerializer(Field):
|
|
||||||
|
class BaseSerializer(Generic[_IN], Field[Any, Any, Any, _IN]):
|
||||||
"""
|
"""
|
||||||
The BaseSerializer class provides a minimal class which may be used
|
The BaseSerializer class provides a minimal class which may be used
|
||||||
for writing custom serializer implementations.
|
for writing custom serializer implementations.
|
||||||
|
@ -121,10 +132,6 @@ class BaseSerializer(Field):
|
||||||
return cls.many_init(*args, **kwargs)
|
return cls.many_init(*args, **kwargs)
|
||||||
return super().__new__(cls, *args, **kwargs)
|
return super().__new__(cls, *args, **kwargs)
|
||||||
|
|
||||||
# Allow type checkers to make serializers generic.
|
|
||||||
def __class_getitem__(cls, *args, **kwargs):
|
|
||||||
return cls
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def many_init(cls, *args, **kwargs):
|
def many_init(cls, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -268,7 +275,7 @@ class BaseSerializer(Field):
|
||||||
# Serializer & ListSerializer classes
|
# Serializer & ListSerializer classes
|
||||||
# -----------------------------------
|
# -----------------------------------
|
||||||
|
|
||||||
class SerializerMetaclass(type):
|
class SerializerMetaclass(GenericMeta):
|
||||||
"""
|
"""
|
||||||
This metaclass sets a dictionary named `_declared_fields` on the class.
|
This metaclass sets a dictionary named `_declared_fields` on the class.
|
||||||
|
|
||||||
|
@ -301,9 +308,9 @@ class SerializerMetaclass(type):
|
||||||
|
|
||||||
return OrderedDict(base_fields + fields)
|
return OrderedDict(base_fields + fields)
|
||||||
|
|
||||||
def __new__(cls, name, bases, attrs):
|
def __new__(cls, name, bases, attrs, *args, **kwargs):
|
||||||
attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
|
attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
|
||||||
return super().__new__(cls, name, bases, attrs)
|
return super().__new__(cls, name, bases, attrs, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def as_serializer_error(exc):
|
def as_serializer_error(exc):
|
||||||
|
@ -332,7 +339,7 @@ def as_serializer_error(exc):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
|
class Serializer(BaseSerializer[_IN], metaclass=SerializerMetaclass):
|
||||||
default_error_messages = {
|
default_error_messages = {
|
||||||
'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.')
|
'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.')
|
||||||
}
|
}
|
||||||
|
@ -562,7 +569,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
|
||||||
# There's some replication of `ListField` here,
|
# There's some replication of `ListField` here,
|
||||||
# but that's probably better than obfuscating the call hierarchy.
|
# but that's probably better than obfuscating the call hierarchy.
|
||||||
|
|
||||||
class ListSerializer(BaseSerializer):
|
class ListSerializer(BaseSerializer[_IN]):
|
||||||
child = None
|
child = None
|
||||||
many = True
|
many = True
|
||||||
|
|
||||||
|
@ -836,7 +843,10 @@ def raise_errors_on_nested_writes(method_name, serializer, validated_data):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelSerializer(Serializer):
|
_MT = TypeVar("_MT", bound=models.Model) # Model Type
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSerializer(Serializer[_MT]):
|
||||||
"""
|
"""
|
||||||
A `ModelSerializer` is just a regular `Serializer`, except that:
|
A `ModelSerializer` is just a regular `Serializer`, except that:
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ Helper functions for mapping model fields to a dictionary of default
|
||||||
keyword arguments that should be used for their equivalent serializer fields.
|
keyword arguments that should be used for their equivalent serializer fields.
|
||||||
"""
|
"""
|
||||||
import inspect
|
import inspect
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
from django.core import validators
|
from django.core import validators
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
@ -16,7 +17,11 @@ NUMERIC_FIELD_TYPES = (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ClassLookupDict:
|
_K = TypeVar("_K", bound=type)
|
||||||
|
_V = TypeVar("_V")
|
||||||
|
|
||||||
|
|
||||||
|
class ClassLookupDict(Generic[_K, _V]):
|
||||||
"""
|
"""
|
||||||
Takes a dictionary with classes as keys.
|
Takes a dictionary with classes as keys.
|
||||||
Lookups against this object will traverses the object's inheritance
|
Lookups against this object will traverses the object's inheritance
|
||||||
|
|
|
@ -14,7 +14,7 @@ from django.utils.timezone import activate, deactivate, override, utc
|
||||||
import rest_framework
|
import rest_framework
|
||||||
from rest_framework import exceptions, serializers
|
from rest_framework import exceptions, serializers
|
||||||
from rest_framework.fields import (
|
from rest_framework.fields import (
|
||||||
BuiltinSignatureError, DjangoImageField, is_simple_callable
|
BuiltinSignatureError, DjangoImageField, Field, is_simple_callable
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tests for helper functions.
|
# Tests for helper functions.
|
||||||
|
@ -2380,3 +2380,21 @@ class TestValidationErrorCode:
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestField:
|
||||||
|
def test_type_annotation(self):
|
||||||
|
assert Field[int, int, int, int] is not Field
|
||||||
|
|
||||||
|
def test_multiple_type_params_needed_when_hinting_class(self):
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
Field[int]
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
Field[int, int]
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
Field[int, int, int]
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
Field[int, int, int, int, int]
|
||||||
|
|
|
@ -11,7 +11,7 @@ from django.test import TestCase
|
||||||
|
|
||||||
from rest_framework.exceptions import ParseError
|
from rest_framework.exceptions import ParseError
|
||||||
from rest_framework.parsers import (
|
from rest_framework.parsers import (
|
||||||
FileUploadParser, FormParser, JSONParser, MultiPartParser
|
DataAndFiles, FileUploadParser, FormParser, JSONParser, MultiPartParser
|
||||||
)
|
)
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.test import APIRequestFactory
|
from rest_framework.test import APIRequestFactory
|
||||||
|
@ -176,3 +176,19 @@ class TestPOSTAccessed(TestCase):
|
||||||
with pytest.raises(RawPostDataException):
|
with pytest.raises(RawPostDataException):
|
||||||
request.POST
|
request.POST
|
||||||
request.data
|
request.data
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataAndFiles:
|
||||||
|
def test_type_annotation(self):
|
||||||
|
"""
|
||||||
|
This class inherits directly from Generic, so adding type hints to it should
|
||||||
|
yield a different class.
|
||||||
|
"""
|
||||||
|
assert DataAndFiles[int, int] is not DataAndFiles
|
||||||
|
|
||||||
|
def test_need_multiple_type_params_when_hinting_class(self):
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
DataAndFiles[int]
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
DataAndFiles[int, int, int]
|
||||||
|
|
|
@ -376,3 +376,18 @@ class TestHyperlink:
|
||||||
upkled = pickle.loads(pickle.dumps(self.default_hyperlink))
|
upkled = pickle.loads(pickle.dumps(self.default_hyperlink))
|
||||||
assert upkled == self.default_hyperlink
|
assert upkled == self.default_hyperlink
|
||||||
assert upkled.name == self.default_hyperlink.name
|
assert upkled.name == self.default_hyperlink.name
|
||||||
|
|
||||||
|
|
||||||
|
class TestRelatedField:
|
||||||
|
def test_type_annotation(self):
|
||||||
|
assert relations.RelatedField[int, int, int] is not relations.RelatedField
|
||||||
|
|
||||||
|
def test_multiple_type_params_needed_when_hinting_class(self):
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
relations.RelatedField[int]
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
relations.RelatedField[int, int]
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
relations.RelatedField[int, int, int, int]
|
||||||
|
|
|
@ -209,8 +209,19 @@ class TestSerializer:
|
||||||
sys.version_info < (3, 7),
|
sys.version_info < (3, 7),
|
||||||
reason="subscriptable classes requires Python 3.7 or higher",
|
reason="subscriptable classes requires Python 3.7 or higher",
|
||||||
)
|
)
|
||||||
def test_serializer_is_subscriptable(self):
|
def test_type_annotation(self):
|
||||||
assert serializers.Serializer is serializers.Serializer["foo"]
|
assert serializers.Serializer is not serializers.Serializer["foo"]
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info > (3, 5),
|
||||||
|
reason="generic meta class behaviour changed from 3.5 to 3.7",
|
||||||
|
)
|
||||||
|
def test_type_annotation_pre_36(self):
|
||||||
|
"""
|
||||||
|
This class does NOT inherit directly from Generic, so adding type hints to it
|
||||||
|
should not yield a different class.
|
||||||
|
"""
|
||||||
|
assert serializers.Serializer[int] is not serializers.Serializer
|
||||||
|
|
||||||
|
|
||||||
class TestValidateMethod:
|
class TestValidateMethod:
|
||||||
|
@ -322,6 +333,28 @@ class TestBaseSerializer:
|
||||||
{'id': 2, 'name': 'ann', 'domain': 'example.com'}
|
{'id': 2, 'name': 'ann', 'domain': 'example.com'}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 7),
|
||||||
|
reason="generic meta class behaviour changed from 3.5 to 3.7",
|
||||||
|
)
|
||||||
|
def test_type_annotation(self):
|
||||||
|
"""
|
||||||
|
This class does NOT inherit directly from Generic, so adding type hints to it
|
||||||
|
should not yield a different class.
|
||||||
|
"""
|
||||||
|
assert serializers.BaseSerializer[int] is not serializers.BaseSerializer
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info > (3, 5),
|
||||||
|
reason="generic meta class behaviour changed from 3.5 to 3.7",
|
||||||
|
)
|
||||||
|
def test_type_annotation_pre_36(self):
|
||||||
|
"""
|
||||||
|
This class does NOT inherit directly from Generic, so adding type hints to it
|
||||||
|
should not yield a different class.
|
||||||
|
"""
|
||||||
|
assert serializers.BaseSerializer[int] is not serializers.BaseSerializer
|
||||||
|
|
||||||
|
|
||||||
class TestStarredSource:
|
class TestStarredSource:
|
||||||
"""
|
"""
|
||||||
|
@ -740,3 +773,27 @@ class TestDeclaredFieldInheritance:
|
||||||
'f4': serializers.CharField,
|
'f4': serializers.CharField,
|
||||||
'f5': serializers.CharField,
|
'f5': serializers.CharField,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelSerializer:
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 7),
|
||||||
|
reason="generic meta class behaviour changed from 3.5 to 3.7",
|
||||||
|
)
|
||||||
|
def test_type_annotation(self):
|
||||||
|
"""
|
||||||
|
This class does NOT inherit directly from Generic, so adding type hints to it
|
||||||
|
should not yield a different class.
|
||||||
|
"""
|
||||||
|
assert serializers.ModelSerializer[int] is not serializers.ModelSerializer
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info > (3, 5),
|
||||||
|
reason="generic meta class behaviour changed from 3.5 to 3.7",
|
||||||
|
)
|
||||||
|
def test_type_annotation_pre_36(self):
|
||||||
|
"""
|
||||||
|
This class does NOT inherit directly from Generic, so adding type hints to it
|
||||||
|
should not yield a different class.
|
||||||
|
"""
|
||||||
|
assert serializers.ModelSerializer[int] is not serializers.ModelSerializer
|
||||||
|
|
|
@ -62,7 +62,18 @@ class TestListSerializer:
|
||||||
reason="subscriptable classes requires Python 3.7 or higher",
|
reason="subscriptable classes requires Python 3.7 or higher",
|
||||||
)
|
)
|
||||||
def test_list_serializer_is_subscriptable(self):
|
def test_list_serializer_is_subscriptable(self):
|
||||||
assert serializers.ListSerializer is serializers.ListSerializer["foo"]
|
assert serializers.ListSerializer is not serializers.ListSerializer["foo"]
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info > (3, 5),
|
||||||
|
reason="generic meta class behaviour changed from 3.5 to 3.7",
|
||||||
|
)
|
||||||
|
def test_type_annotation_pre_36(self):
|
||||||
|
"""
|
||||||
|
This class does NOT inherit directly from Generic, so adding type hints to it
|
||||||
|
should not yield a different class.
|
||||||
|
"""
|
||||||
|
assert serializers.ListSerializer[int] is not serializers.ListSerializer
|
||||||
|
|
||||||
|
|
||||||
class TestListSerializerContainingNestedSerializer:
|
class TestListSerializerContainingNestedSerializer:
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
from django.test import TestCase, override_settings
|
from django.test import TestCase, override_settings
|
||||||
from django.urls import path
|
from django.urls import path
|
||||||
|
|
||||||
|
@ -8,6 +9,7 @@ from rest_framework.routers import SimpleRouter
|
||||||
from rest_framework.serializers import ModelSerializer
|
from rest_framework.serializers import ModelSerializer
|
||||||
from rest_framework.utils import json
|
from rest_framework.utils import json
|
||||||
from rest_framework.utils.breadcrumbs import get_breadcrumbs
|
from rest_framework.utils.breadcrumbs import get_breadcrumbs
|
||||||
|
from rest_framework.utils.field_mapping import ClassLookupDict
|
||||||
from rest_framework.utils.formatting import lazy_format
|
from rest_framework.utils.formatting import lazy_format
|
||||||
from rest_framework.utils.urls import remove_query_param, replace_query_param
|
from rest_framework.utils.urls import remove_query_param, replace_query_param
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
|
@ -267,3 +269,15 @@ class LazyFormatTests(TestCase):
|
||||||
assert message.format.call_count == 1
|
assert message.format.call_count == 1
|
||||||
str(formatted)
|
str(formatted)
|
||||||
assert message.format.call_count == 1
|
assert message.format.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
class ClassLookupDictTests(TestCase):
|
||||||
|
def test_type_annotation(self):
|
||||||
|
assert ClassLookupDict[int, int] is not ClassLookupDict
|
||||||
|
|
||||||
|
def test_need_multiple_type_params_when_hinting_class(self):
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
ClassLookupDict[None]
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
ClassLookupDict[None, None, None]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user