diff --git a/channels/exceptions.py b/channels/exceptions.py index 53cdd37..ffdb5a2 100644 --- a/channels/exceptions.py +++ b/channels/exceptions.py @@ -21,3 +21,11 @@ class RequestTimeout(Exception): Raised when it takes too long to read a request body. """ pass + + +class RequestAborted(Exception): + """ + Raised when the incoming request tells us it's aborted partway through + reading the body. + """ + pass diff --git a/channels/handler.py b/channels/handler.py index 26ab586..b9664f3 100644 --- a/channels/handler.py +++ b/channels/handler.py @@ -18,7 +18,7 @@ from django.http import FileResponse, HttpResponse, HttpResponseServerError from django.utils import six from django.utils.functional import cached_property -from .exceptions import ResponseLater as ResponseLaterOuter, RequestTimeout +from .exceptions import ResponseLater as ResponseLaterOuter, RequestTimeout, RequestAborted logger = logging.getLogger('django.request') @@ -118,6 +118,9 @@ class AsgiRequest(http.HttpRequest): [message['body_channel']], block=True, ) + # If chunk contains close, abort. + if chunk.get("closed", False): + raise RequestAborted() # Add content to body self._body += chunk.get("content", "") # Exit loop if this was the last @@ -197,6 +200,9 @@ class AsgiHandler(base.BaseHandler): except RequestTimeout: # Parsing the rquest failed, so the response is a Request Timeout error response = HttpResponse("408 Request Timeout (upload too slow)", status_code=408) + except RequestAborted: + # Client closed connection on us mid request. Abort! + return else: try: response = self.get_response(request) diff --git a/channels/tests/test_request.py b/channels/tests/test_request.py index 028d3da..52712a4 100644 --- a/channels/tests/test_request.py +++ b/channels/tests/test_request.py @@ -4,7 +4,7 @@ from django.utils import six from channels import Channel from channels.tests import ChannelTestCase from channels.handler import AsgiRequest -from channels.exceptions import RequestTimeout +from channels.exceptions import RequestTimeout, RequestAborted class RequestTests(ChannelTestCase): @@ -216,3 +216,26 @@ class RequestTests(ChannelTestCase): body_receive_timeout = 0 with self.assertRaises(RequestTimeout): VeryImpatientRequest(self.get_next_message("test")) + + def test_request_abort(self): + """ + Tests that the code aborts when a request-body close is sent. + """ + Channel("test").send({ + "reply_channel": "test", + "http_version": "1.1", + "method": "POST", + "path": b"/test/", + "body": b"there_a", + "body_channel": "test-input", + "headers": { + "host": b"example.com", + "content-type": b"application/x-www-form-urlencoded", + "content-length": b"21", + }, + }) + Channel("test-input").send({ + "closed": True, + }) + with self.assertRaises(RequestAborted): + AsgiRequest(self.get_next_message("test"))