Fixed #33738 -- Allowed handling ASGI http.disconnect in long-lived requests.

This commit is contained in:
th3nn3ss 2022-12-21 14:25:24 -05:00 committed by Mariusz Felisiak
parent 4e4eda6d6c
commit 1d1ddffc27
5 changed files with 157 additions and 3 deletions

View File

@ -1,3 +1,4 @@
import asyncio
import logging import logging
import sys import sys
import tempfile import tempfile
@ -177,15 +178,49 @@ class ASGIHandler(base.BaseHandler):
body_file.close() body_file.close()
await self.send_response(error_response, send) await self.send_response(error_response, send)
return return
# Get the response, using the async mode of BaseHandler. # Try to catch a disconnect while getting response.
tasks = [
asyncio.create_task(self.run_get_response(request)),
asyncio.create_task(self.listen_for_disconnect(receive)),
]
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
done, pending = done.pop(), pending.pop()
# Allow views to handle cancellation.
pending.cancel()
try:
await pending
except asyncio.CancelledError:
# Task re-raised the CancelledError as expected.
pass
try:
response = done.result()
except RequestAborted:
body_file.close()
return
except AssertionError:
body_file.close()
raise
# Send the response.
await self.send_response(response, send)
async def listen_for_disconnect(self, receive):
"""Listen for disconnect from the client."""
message = await receive()
if message["type"] == "http.disconnect":
raise RequestAborted()
# This should never happen.
assert False, "Invalid ASGI message after request body: %s" % message["type"]
async def run_get_response(self, request):
"""Get async response."""
# Use the async mode of BaseHandler.
response = await self.get_response_async(request) response = await self.get_response_async(request)
response._handler_class = self.__class__ response._handler_class = self.__class__
# Increase chunk size on file responses (ASGI servers handles low-level # Increase chunk size on file responses (ASGI servers handles low-level
# chunking). # chunking).
if isinstance(response, FileResponse): if isinstance(response, FileResponse):
response.block_size = self.chunk_size response.block_size = self.chunk_size
# Send the response. return response
await self.send_response(response, send)
async def read_body(self, receive): async def read_body(self, receive):
"""Reads an HTTP body from an ASGI connection.""" """Reads an HTTP body from an ASGI connection."""

View File

@ -192,6 +192,13 @@ Minor features
* ... * ...
Asynchronous views
~~~~~~~~~~~~~~~~~~
* Under ASGI, ``http.disconnect`` events are now handled. This allows views to
perform any necessary cleanup if a client disconnects before the response is
generated. See :ref:`async-handling-disconnect` for more details.
Cache Cache
~~~~~ ~~~~~

View File

@ -136,6 +136,26 @@ a purely synchronous codebase under ASGI because the request-handling code is
still all running asynchronously. In general you will only want to enable ASGI still all running asynchronously. In general you will only want to enable ASGI
mode if you have asynchronous code in your project. mode if you have asynchronous code in your project.
.. _async-handling-disconnect:
Handling disconnects
--------------------
.. versionadded:: 5.0
For long-lived requests, a client may disconnect before the view returns a
response. In this case, an ``asyncio.CancelledError`` will be raised in the
view. You can catch this error and handle it if you need to perform any
cleanup::
async def my_view(request):
try:
# Do some work
...
except asyncio.CancelledError:
# Handle disconnect
raise
.. _async-safety: .. _async-safety:
Async safety Async safety

View File

@ -7,8 +7,10 @@ from asgiref.testing import ApplicationCommunicator
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
from django.core.asgi import get_asgi_application from django.core.asgi import get_asgi_application
from django.core.handlers.asgi import ASGIHandler, ASGIRequest
from django.core.signals import request_finished, request_started from django.core.signals import request_finished, request_started
from django.db import close_old_connections from django.db import close_old_connections
from django.http import HttpResponse
from django.test import ( from django.test import (
AsyncRequestFactory, AsyncRequestFactory,
SimpleTestCase, SimpleTestCase,
@ -16,6 +18,7 @@ from django.test import (
modify_settings, modify_settings,
override_settings, override_settings,
) )
from django.urls import path
from django.utils.http import http_date from django.utils.http import http_date
from .urls import sync_waiter, test_filename from .urls import sync_waiter, test_filename
@ -234,6 +237,34 @@ class ASGITest(SimpleTestCase):
with self.assertRaises(asyncio.TimeoutError): with self.assertRaises(asyncio.TimeoutError):
await communicator.receive_output() await communicator.receive_output()
async def test_disconnect_with_body(self):
application = get_asgi_application()
scope = self.async_request_factory._base_scope(path="/")
communicator = ApplicationCommunicator(application, scope)
await communicator.send_input({"type": "http.request", "body": b"some body"})
await communicator.send_input({"type": "http.disconnect"})
with self.assertRaises(asyncio.TimeoutError):
await communicator.receive_output()
async def test_assert_in_listen_for_disconnect(self):
application = get_asgi_application()
scope = self.async_request_factory._base_scope(path="/")
communicator = ApplicationCommunicator(application, scope)
await communicator.send_input({"type": "http.request"})
await communicator.send_input({"type": "http.not_a_real_message"})
msg = "Invalid ASGI message after request body: http.not_a_real_message"
with self.assertRaisesMessage(AssertionError, msg):
await communicator.receive_output()
async def test_delayed_disconnect_with_body(self):
application = get_asgi_application()
scope = self.async_request_factory._base_scope(path="/delayed_hello/")
communicator = ApplicationCommunicator(application, scope)
await communicator.send_input({"type": "http.request", "body": b"some body"})
await communicator.send_input({"type": "http.disconnect"})
with self.assertRaises(asyncio.TimeoutError):
await communicator.receive_output()
async def test_wrong_connection_type(self): async def test_wrong_connection_type(self):
application = get_asgi_application() application = get_asgi_application()
scope = self.async_request_factory._base_scope(path="/", type="other") scope = self.async_request_factory._base_scope(path="/", type="other")
@ -318,3 +349,56 @@ class ASGITest(SimpleTestCase):
self.assertEqual(len(sync_waiter.active_threads), 2) self.assertEqual(len(sync_waiter.active_threads), 2)
sync_waiter.active_threads.clear() sync_waiter.active_threads.clear()
async def test_asyncio_cancel_error(self):
# Flag to check if the view was cancelled.
view_did_cancel = False
# A view that will listen for the cancelled error.
async def view(request):
nonlocal view_did_cancel
try:
await asyncio.sleep(0.2)
return HttpResponse("Hello World!")
except asyncio.CancelledError:
# Set the flag.
view_did_cancel = True
raise
# Request class to use the view.
class TestASGIRequest(ASGIRequest):
urlconf = (path("cancel/", view),)
# Handler to use request class.
class TestASGIHandler(ASGIHandler):
request_class = TestASGIRequest
# Request cycle should complete since no disconnect was sent.
application = TestASGIHandler()
scope = self.async_request_factory._base_scope(path="/cancel/")
communicator = ApplicationCommunicator(application, scope)
await communicator.send_input({"type": "http.request"})
response_start = await communicator.receive_output()
self.assertEqual(response_start["type"], "http.response.start")
self.assertEqual(response_start["status"], 200)
response_body = await communicator.receive_output()
self.assertEqual(response_body["type"], "http.response.body")
self.assertEqual(response_body["body"], b"Hello World!")
# Give response.close() time to finish.
await communicator.wait()
self.assertIs(view_did_cancel, False)
# Request cycle with a disconnect before the view can respond.
application = TestASGIHandler()
scope = self.async_request_factory._base_scope(path="/cancel/")
communicator = ApplicationCommunicator(application, scope)
await communicator.send_input({"type": "http.request"})
# Let the view actually start.
await asyncio.sleep(0.1)
# Disconnect the client.
await communicator.send_input({"type": "http.disconnect"})
# The handler should not send a response.
with self.assertRaises(asyncio.TimeoutError):
await communicator.receive_output()
await communicator.wait()
self.assertIs(view_did_cancel, True)

View File

@ -1,4 +1,5 @@
import threading import threading
import time
from django.http import FileResponse, HttpResponse from django.http import FileResponse, HttpResponse
from django.urls import path from django.urls import path
@ -10,6 +11,12 @@ def hello(request):
return HttpResponse("Hello %s!" % name) return HttpResponse("Hello %s!" % name)
def hello_with_delay(request):
name = request.GET.get("name") or "World"
time.sleep(1)
return HttpResponse(f"Hello {name}!")
def hello_meta(request): def hello_meta(request):
return HttpResponse( return HttpResponse(
"From %s" % request.META.get("HTTP_REFERER") or "", "From %s" % request.META.get("HTTP_REFERER") or "",
@ -46,4 +53,5 @@ urlpatterns = [
path("meta/", hello_meta), path("meta/", hello_meta),
path("post/", post_echo), path("post/", post_echo),
path("wait/", sync_waiter), path("wait/", sync_waiter),
path("delayed_hello/", hello_with_delay),
] ]