Allow generic requests, responses, fields, views (#8825)

Allow Request, Response, Field, and GenericAPIView to be subscriptable.
This allows the classes to be made generic for type checking.

This is especially useful since monkey patching DRF can be problematic
as seen in this [issue][1].

[1]: https://github.com/typeddjango/djangorestframework-stubs/issues/299
This commit is contained in:
Jameel Al-Aziz 2023-02-22 07:39:01 -08:00 committed by GitHub
parent 390daf7a92
commit 15c613a9eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 73 additions and 0 deletions

View File

@ -356,6 +356,10 @@ class Field:
messages.update(error_messages or {})
self.error_messages = messages
# Allow generic typing checking for fields.
def __class_getitem__(cls, *args, **kwargs):
return cls
def bind(self, field_name, parent):
"""
Initializes the field name and parent for the field instance.

View File

@ -45,6 +45,10 @@ class GenericAPIView(views.APIView):
# The style to use for queryset pagination.
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
# Allow generic typing checking for generic views.
def __class_getitem__(cls, *args, **kwargs):
return cls
def get_queryset(self):
"""
Get the list of items for this view.

View File

@ -186,6 +186,10 @@ class Request:
self.method,
self.get_full_path())
# Allow generic typing checking for requests.
def __class_getitem__(cls, *args, **kwargs):
return cls
def _default_negotiator(self):
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()

View File

@ -46,6 +46,10 @@ class Response(SimpleTemplateResponse):
for name, value in headers.items():
self[name] = value
# Allow generic typing checking for responses.
def __class_getitem__(cls, *args, **kwargs):
return cls
@property
def rendered_content(self):
renderer = getattr(self, 'accepted_renderer', None)

View File

@ -2,6 +2,7 @@ import datetime
import math
import os
import re
import sys
import uuid
from decimal import ROUND_DOWN, ROUND_UP, Decimal
@ -625,6 +626,15 @@ class Test5087Regression:
assert field.root is parent
class TestTyping(TestCase):
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="subscriptable classes requires Python 3.7 or higher",
)
def test_field_is_subscriptable(self):
assert serializers.Field is serializers.Field["foo"]
# Tests for field input and output values.
# ----------------------------------------

View File

@ -1,3 +1,5 @@
import sys
import pytest
from django.db import models
from django.http import Http404
@ -698,3 +700,26 @@ class TestSerializer(TestCase):
serializer = response.serializer
assert serializer.context is context
class TestTyping(TestCase):
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="subscriptable classes requires Python 3.7 or higher",
)
def test_genericview_is_subscriptable(self):
assert generics.GenericAPIView is generics.GenericAPIView["foo"]
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="subscriptable classes requires Python 3.7 or higher",
)
def test_listview_is_subscriptable(self):
assert generics.ListAPIView is generics.ListAPIView["foo"]
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="subscriptable classes requires Python 3.7 or higher",
)
def test_instanceview_is_subscriptable(self):
assert generics.RetrieveAPIView is generics.RetrieveAPIView["foo"]

View File

@ -3,6 +3,7 @@ Tests for content parsing, and form-overloaded content parsing.
"""
import copy
import os.path
import sys
import tempfile
import pytest
@ -352,3 +353,12 @@ class TestDeepcopy(TestCase):
def test_deepcopy_works(self):
request = Request(factory.get('/', secure=False))
copy.deepcopy(request)
class TestTyping(TestCase):
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="subscriptable classes requires Python 3.7 or higher",
)
def test_request_is_subscriptable(self):
assert Request is Request["foo"]

View File

@ -1,3 +1,6 @@
import sys
import pytest
from django.test import TestCase, override_settings
from django.urls import include, path, re_path
@ -283,3 +286,12 @@ class Issue807Tests(TestCase):
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
# self.assertContains(resp, 'Text comes here')
# self.assertContains(resp, 'Text description.')
class TestTyping(TestCase):
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="subscriptable classes requires Python 3.7 or higher",
)
def test_response_is_subscriptable(self):
assert Response is Response["foo"]