mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-10 19:56:59 +03:00
Allow generic requests, responses, fields, views (#8825)
Allow Request, Response, Field, and GenericAPIView to be subscriptable. This allows the classes to be made generic for type checking. This is especially useful since monkey patching DRF can be problematic as seen in this [issue][1]. [1]: https://github.com/typeddjango/djangorestframework-stubs/issues/299
This commit is contained in:
parent
390daf7a92
commit
15c613a9eb
|
@ -356,6 +356,10 @@ class Field:
|
||||||
messages.update(error_messages or {})
|
messages.update(error_messages or {})
|
||||||
self.error_messages = messages
|
self.error_messages = messages
|
||||||
|
|
||||||
|
# Allow generic typing checking for fields.
|
||||||
|
def __class_getitem__(cls, *args, **kwargs):
|
||||||
|
return cls
|
||||||
|
|
||||||
def bind(self, field_name, parent):
|
def bind(self, field_name, parent):
|
||||||
"""
|
"""
|
||||||
Initializes the field name and parent for the field instance.
|
Initializes the field name and parent for the field instance.
|
||||||
|
|
|
@ -45,6 +45,10 @@ class GenericAPIView(views.APIView):
|
||||||
# The style to use for queryset pagination.
|
# The style to use for queryset pagination.
|
||||||
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
|
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
|
||||||
|
|
||||||
|
# Allow generic typing checking for generic views.
|
||||||
|
def __class_getitem__(cls, *args, **kwargs):
|
||||||
|
return cls
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
"""
|
"""
|
||||||
Get the list of items for this view.
|
Get the list of items for this view.
|
||||||
|
|
|
@ -186,6 +186,10 @@ class Request:
|
||||||
self.method,
|
self.method,
|
||||||
self.get_full_path())
|
self.get_full_path())
|
||||||
|
|
||||||
|
# Allow generic typing checking for requests.
|
||||||
|
def __class_getitem__(cls, *args, **kwargs):
|
||||||
|
return cls
|
||||||
|
|
||||||
def _default_negotiator(self):
|
def _default_negotiator(self):
|
||||||
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
|
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,10 @@ class Response(SimpleTemplateResponse):
|
||||||
for name, value in headers.items():
|
for name, value in headers.items():
|
||||||
self[name] = value
|
self[name] = value
|
||||||
|
|
||||||
|
# Allow generic typing checking for responses.
|
||||||
|
def __class_getitem__(cls, *args, **kwargs):
|
||||||
|
return cls
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rendered_content(self):
|
def rendered_content(self):
|
||||||
renderer = getattr(self, 'accepted_renderer', None)
|
renderer = getattr(self, 'accepted_renderer', None)
|
||||||
|
|
|
@ -2,6 +2,7 @@ import datetime
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
from decimal import ROUND_DOWN, ROUND_UP, Decimal
|
from decimal import ROUND_DOWN, ROUND_UP, Decimal
|
||||||
|
|
||||||
|
@ -625,6 +626,15 @@ class Test5087Regression:
|
||||||
assert field.root is parent
|
assert field.root is parent
|
||||||
|
|
||||||
|
|
||||||
|
class TestTyping(TestCase):
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 7),
|
||||||
|
reason="subscriptable classes requires Python 3.7 or higher",
|
||||||
|
)
|
||||||
|
def test_field_is_subscriptable(self):
|
||||||
|
assert serializers.Field is serializers.Field["foo"]
|
||||||
|
|
||||||
|
|
||||||
# Tests for field input and output values.
|
# Tests for field input and output values.
|
||||||
# ----------------------------------------
|
# ----------------------------------------
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
import sys
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.http import Http404
|
from django.http import Http404
|
||||||
|
@ -698,3 +700,26 @@ class TestSerializer(TestCase):
|
||||||
serializer = response.serializer
|
serializer = response.serializer
|
||||||
|
|
||||||
assert serializer.context is context
|
assert serializer.context is context
|
||||||
|
|
||||||
|
|
||||||
|
class TestTyping(TestCase):
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 7),
|
||||||
|
reason="subscriptable classes requires Python 3.7 or higher",
|
||||||
|
)
|
||||||
|
def test_genericview_is_subscriptable(self):
|
||||||
|
assert generics.GenericAPIView is generics.GenericAPIView["foo"]
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 7),
|
||||||
|
reason="subscriptable classes requires Python 3.7 or higher",
|
||||||
|
)
|
||||||
|
def test_listview_is_subscriptable(self):
|
||||||
|
assert generics.ListAPIView is generics.ListAPIView["foo"]
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 7),
|
||||||
|
reason="subscriptable classes requires Python 3.7 or higher",
|
||||||
|
)
|
||||||
|
def test_instanceview_is_subscriptable(self):
|
||||||
|
assert generics.RetrieveAPIView is generics.RetrieveAPIView["foo"]
|
||||||
|
|
|
@ -3,6 +3,7 @@ Tests for content parsing, and form-overloaded content parsing.
|
||||||
"""
|
"""
|
||||||
import copy
|
import copy
|
||||||
import os.path
|
import os.path
|
||||||
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -352,3 +353,12 @@ class TestDeepcopy(TestCase):
|
||||||
def test_deepcopy_works(self):
|
def test_deepcopy_works(self):
|
||||||
request = Request(factory.get('/', secure=False))
|
request = Request(factory.get('/', secure=False))
|
||||||
copy.deepcopy(request)
|
copy.deepcopy(request)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTyping(TestCase):
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 7),
|
||||||
|
reason="subscriptable classes requires Python 3.7 or higher",
|
||||||
|
)
|
||||||
|
def test_request_is_subscriptable(self):
|
||||||
|
assert Request is Request["foo"]
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
from django.test import TestCase, override_settings
|
from django.test import TestCase, override_settings
|
||||||
from django.urls import include, path, re_path
|
from django.urls import include, path, re_path
|
||||||
|
|
||||||
|
@ -283,3 +286,12 @@ class Issue807Tests(TestCase):
|
||||||
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
|
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
|
||||||
# self.assertContains(resp, 'Text comes here')
|
# self.assertContains(resp, 'Text comes here')
|
||||||
# self.assertContains(resp, 'Text description.')
|
# self.assertContains(resp, 'Text description.')
|
||||||
|
|
||||||
|
|
||||||
|
class TestTyping(TestCase):
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 7),
|
||||||
|
reason="subscriptable classes requires Python 3.7 or higher",
|
||||||
|
)
|
||||||
|
def test_response_is_subscriptable(self):
|
||||||
|
assert Response is Response["foo"]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user