Handle asyncio.CancelledError in Server.application_checker (#341)

As of [bpo-32528](https://bugs.python.org/issue32528), asyncio.CancelledError is
not a subclass of concurrent.futures.CancelledError. This means that if an
asyncio future raises an exception, it won't be caught. Therefore, the
exception will bubble past the try-except within the loop in application_checker,
resulting in done applications not being cleaned up, and the application_checker
task not being queued again.
This commit is contained in:
Patrick Gingras 2020-11-11 10:12:33 -05:00 committed by GitHub
parent a69723ca3f
commit aae0870971
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 76 additions and 9 deletions

View File

@ -277,7 +277,7 @@ class Server(object):
if application_instance and application_instance.done(): if application_instance and application_instance.done():
try: try:
exception = application_instance.exception() exception = application_instance.exception()
except CancelledError: except (CancelledError, asyncio.CancelledError):
# Future cancellation. We can ignore this. # Future cancellation. We can ignore this.
pass pass
else: else:

View File

@ -7,7 +7,7 @@ import traceback
from concurrent.futures import CancelledError from concurrent.futures import CancelledError
class DaphneTestingInstance: class BaseDaphneTestingInstance:
""" """
Launches an instance of Daphne in a subprocess, with a host and port Launches an instance of Daphne in a subprocess, with a host and port
attribute allowing you to call it. attribute allowing you to call it.
@ -17,17 +17,16 @@ class DaphneTestingInstance:
startup_timeout = 2 startup_timeout = 2
def __init__(self, xff=False, http_timeout=None, request_buffer_size=None): def __init__(
self, xff=False, http_timeout=None, request_buffer_size=None, *, application
):
self.xff = xff self.xff = xff
self.http_timeout = http_timeout self.http_timeout = http_timeout
self.host = "127.0.0.1" self.host = "127.0.0.1"
self.lock = multiprocessing.Lock()
self.request_buffer_size = request_buffer_size self.request_buffer_size = request_buffer_size
self.application = application
def __enter__(self): def __enter__(self):
# Clear result storage
TestApplication.delete_setup()
TestApplication.delete_result()
# Option Daphne features # Option Daphne features
kwargs = {} kwargs = {}
if self.request_buffer_size: if self.request_buffer_size:
@ -42,7 +41,7 @@ class DaphneTestingInstance:
# Start up process # Start up process
self.process = DaphneProcess( self.process = DaphneProcess(
host=self.host, host=self.host,
application=TestApplication(lock=self.lock), application=self.application,
kwargs=kwargs, kwargs=kwargs,
setup=self.process_setup, setup=self.process_setup,
teardown=self.process_teardown, teardown=self.process_teardown,
@ -76,6 +75,21 @@ class DaphneTestingInstance:
""" """
pass pass
def get_received(self):
pass
class DaphneTestingInstance(BaseDaphneTestingInstance):
def __init__(self, *args, **kwargs):
self.lock = multiprocessing.Lock()
super().__init__(*args, **kwargs, application=TestApplication(lock=self.lock))
def __enter__(self):
# Clear result storage
TestApplication.delete_setup()
TestApplication.delete_result()
return super().__enter__()
def get_received(self): def get_received(self):
""" """
Returns the scope and messages the test application has received Returns the scope and messages the test application has received
@ -149,7 +163,7 @@ class DaphneProcess(multiprocessing.Process):
self.server.run() self.server.run()
finally: finally:
self.teardown() self.teardown()
except Exception as e: except BaseException as e:
# Put the error on our queue so the parent gets it # Put the error on our queue so the parent gets it
self.errors.put((e, traceback.format_exc())) self.errors.put((e, traceback.format_exc()))

View File

@ -8,6 +8,8 @@ import http_strategies
from http_base import DaphneTestCase, DaphneTestingInstance from http_base import DaphneTestCase, DaphneTestingInstance
from hypothesis import given, settings from hypothesis import given, settings
from daphne.testing import BaseDaphneTestingInstance
class TestWebsocket(DaphneTestCase): class TestWebsocket(DaphneTestCase):
""" """
@ -261,3 +263,54 @@ class TestWebsocket(DaphneTestCase):
self.websocket_send_frame(sock, "still alive?") self.websocket_send_frame(sock, "still alive?")
# Receive a frame and make sure it's correct # Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == "cake" assert self.websocket_receive_frame(sock) == "cake"
def test_application_checker_handles_asyncio_cancellederror(self):
with CancellingTestingInstance() as app:
# Connect to the websocket app, it will immediately raise
# asyncio.CancelledError
sock, _ = self.websocket_handshake(app)
# Disconnect from the socket
sock.close()
# Wait for application_checker to clean up the applications for
# disconnected clients, and for the server to be stopped.
time.sleep(3)
# Make sure we received either no error, or a ConnectionsNotEmpty
while not app.process.errors.empty():
err, _tb = app.process.errors.get()
if not isinstance(err, ConnectionsNotEmpty):
raise err
self.fail(
"Server connections were not cleaned up after an asyncio.CancelledError was raised"
)
async def cancelling_application(scope, receive, send):
import asyncio
from twisted.internet import reactor
# Stop the server after a short delay so that the teardown is run.
reactor.callLater(2, lambda: reactor.stop())
await send({"type": "websocket.accept"})
raise asyncio.CancelledError()
class ConnectionsNotEmpty(Exception):
pass
class CancellingTestingInstance(BaseDaphneTestingInstance):
def __init__(self):
super().__init__(application=cancelling_application)
def process_teardown(self):
import multiprocessing
# Get a hold of the enclosing DaphneProcess (we're currently running in
# the same process as the application).
proc = multiprocessing.current_process()
# By now the (only) socket should have disconnected, and the
# application_checker should have run. If there are any connections
# still, it means that the application_checker did not clean them up.
if proc.server.connections:
raise ConnectionsNotEmpty()